from typing import List, Tuple, Optional
import numpy as np
from anndata import AnnData
import ot
from sklearn.decomposition import NMF
from .helper import intersect, kl_divergence_backend, to_dense_array, extract_data_matrix
def pairwise_align(
    sliceA: AnnData,
    sliceB: AnnData,
    alpha: float = 0.1,
    dissimilarity: str = 'kl',
    use_rep: Optional[str] = None,
    G_init = None,
    a_distribution = None,
    b_distribution = None,
    norm: bool = False,
    numItermax: int = 200,
    backend = ot.backend.NumpyBackend(),
    use_gpu: bool = False,
    return_obj: bool = False,
    verbose: bool = False,
    gpu_verbose: bool = True,
    **kwargs) -> Tuple[np.ndarray, Optional[int]]:
    """
    Calculates and returns optimal alignment of two slices.

    Args:
        sliceA: Slice A to align.
        sliceB: Slice B to align.
        alpha: Alignment tuning parameter. Note: 0 <= alpha <= 1.
        dissimilarity: Expression dissimilarity measure: ``'kl'`` or ``'euclidean'``.
            ``'euc'`` is accepted as an alias of ``'euclidean'``; any other value
            falls through to the KL branch.
        use_rep: If ``None``, uses ``slice.X`` to calculate dissimilarity between spots, otherwise uses the representation given by ``slice.obsm[use_rep]``.
        G_init (array-like, optional): Initial mapping to be used in FGW-OT, otherwise default is uniform mapping.
        a_distribution (array-like, optional): Distribution of sliceA spots, otherwise default is uniform.
        b_distribution (array-like, optional): Distribution of sliceB spots, otherwise default is uniform.
        norm: If ``True``, scales spatial distances such that neighboring spots are at distance 1. Otherwise, spatial distances remain unchanged.
        numItermax: Max number of iterations during FGW-OT.
        backend: Type of backend to run calculations. For list of backends available on system: ``ot.backend.get_backend_list()``.
        use_gpu: If ``True``, use gpu. Otherwise, use cpu. Currently we only have gpu support for Pytorch.
        return_obj: If ``True``, additionally returns objective function output of FGW-OT.
        verbose: If ``True``, FGW-OT is verbose.
        gpu_verbose: If ``True``, print whether gpu is being used to user.
        **kwargs: Accepted for call compatibility; not used by this function.

    Returns:
        - Alignment of spots.

        If ``return_obj = True``, additionally returns:

        - Objective function output of FGW-OT.
    """
    # Determine if gpu or cpu is being used
    if use_gpu:
        try:
            import torch
        except ImportError:
            # Remember the failure so the check below falls back to cpu
            # instead of dereferencing an unbound ``torch`` name.
            torch = None
            print("We currently only have gpu support for Pytorch. Please install torch.")

        if torch is not None and isinstance(backend, ot.backend.TorchBackend):
            if torch.cuda.is_available():
                if gpu_verbose:
                    print("gpu is available, using gpu.")
            else:
                if gpu_verbose:
                    print("gpu is not available, resorting to torch cpu.")
                use_gpu = False
        else:
            print("We currently only have gpu support for Pytorch, please set backend = ot.backend.TorchBackend(). Reverting to selected backend cpu.")
            use_gpu = False
    else:
        if gpu_verbose:
            print("Using selected backend cpu. If you want to use gpu, set use_gpu = True.")

    # subset for common genes
    common_genes = intersect(sliceA.var.index, sliceB.var.index)
    sliceA = sliceA[:, common_genes]
    sliceB = sliceB[:, common_genes]

    # Backend
    nx = backend
    is_torch = isinstance(nx, ot.backend.TorchBackend)
    # ``use_gpu`` is final at this point, so this predicate can be computed once.
    on_gpu = is_torch and use_gpu

    # Calculate spatial distances
    coordinatesA = nx.from_numpy(sliceA.obsm['spatial'].copy())
    coordinatesB = nx.from_numpy(sliceB.obsm['spatial'].copy())
    if is_torch:
        coordinatesA = coordinatesA.float()
        coordinatesB = coordinatesB.float()
    D_A = ot.dist(coordinatesA, coordinatesA, metric='euclidean')
    D_B = ot.dist(coordinatesB, coordinatesB, metric='euclidean')
    if on_gpu:
        D_A = D_A.cuda()
        D_B = D_B.cuda()

    # Calculate expression dissimilarity
    A_X = nx.from_numpy(to_dense_array(extract_data_matrix(sliceA, use_rep)))
    B_X = nx.from_numpy(to_dense_array(extract_data_matrix(sliceB, use_rep)))
    if on_gpu:
        A_X = A_X.cuda()
        B_X = B_X.cuda()

    if dissimilarity.lower() in ('euclidean', 'euc'):
        M = ot.dist(A_X, B_X)
    else:
        # NOTE(review): the 0.01 offset is presumably a pseudocount that keeps
        # zero entries out of the KL computation -- confirm against the helper.
        s_A = A_X + 0.01
        s_B = B_X + 0.01
        M = kl_divergence_backend(s_A, s_B)
        M = nx.from_numpy(M)
    if on_gpu:
        M = M.cuda()

    # init distributions
    if a_distribution is None:
        a = nx.ones((sliceA.shape[0],)) / sliceA.shape[0]
    else:
        a = nx.from_numpy(a_distribution)

    if b_distribution is None:
        b = nx.ones((sliceB.shape[0],)) / sliceB.shape[0]
    else:
        b = nx.from_numpy(b_distribution)

    if on_gpu:
        a = a.cuda()
        b = b.cuda()

    if norm:
        # Divide by the smallest strictly positive distance of each slice.
        D_A /= nx.min(D_A[D_A > 0])
        D_B /= nx.min(D_B[D_B > 0])

    # Run OT
    if G_init is not None:
        G_init = nx.from_numpy(G_init)
        if is_torch:
            G_init = G_init.float()
            if use_gpu:
                # ``Tensor.cuda()`` returns a new tensor; the result must be kept.
                G_init = G_init.cuda()

    pi, logw = my_fused_gromov_wasserstein(
        M, D_A, D_B, a, b,
        G_init=G_init,
        loss_fun='square_loss',
        alpha=alpha,
        log=True,
        numItermax=numItermax,
        verbose=verbose,
        use_gpu=use_gpu,
    )
    pi = nx.to_numpy(pi)
    obj = nx.to_numpy(logw['fgw_dist'])
    if on_gpu:
        torch.cuda.empty_cache()

    if return_obj:
        return pi, obj
    return pi
def center_align(
    A: AnnData,
    slices: List[AnnData],
    lmbda = None,
    alpha: float = 0.1,
    n_components: int = 15,
    threshold: float = 0.001,
    max_iter: int = 10,
    dissimilarity: str = 'kl',
    norm: bool = False,
    random_seed: Optional[int] = None,
    pis_init: Optional[List[np.ndarray]] = None,
    distributions = None,
    backend = ot.backend.NumpyBackend(),
    use_gpu: bool = False,
    verbose: bool = False,
    gpu_verbose: bool = True) -> Tuple[AnnData, List[np.ndarray]]:
    """
    Computes center alignment of slices.

    Args:
        A: Slice to use as the initialization for center alignment; Make sure to include gene expression and spatial information.
        slices: List of slices to use in the center alignment. The list passed
            in is not modified; gene subsetting is done on a local copy of the list.
        lmbda (array-like, optional): List of probability weights assigned to each slice; If ``None``, use uniform weights.
        alpha: Alignment tuning parameter. Note: 0 <= alpha <= 1.
        n_components: Number of components in NMF decomposition.
        threshold: Threshold for convergence of W and H during NMF decomposition.
        max_iter: Maximum number of iterations for our center alignment algorithm.
        dissimilarity: Expression dissimilarity measure: ``'kl'`` or ``'euclidean'``.
        norm: If ``True``, scales spatial distances such that neighboring spots are at distance 1. Otherwise, spatial distances remain unchanged.
        random_seed: Set random seed for reproducibility.
        pis_init: Initial list of mappings between 'A' and 'slices' to solver. Otherwise, default will automatically calculate mappings.
        distributions (List[array-like], optional): Distributions of spots for each slice. Otherwise, default is uniform.
        backend: Type of backend to run calculations. For list of backends available on system: ``ot.backend.get_backend_list()``.
        use_gpu: If ``True``, use gpu. Otherwise, use cpu. Currently we only have gpu support for Pytorch.
        verbose: If ``True``, FGW-OT is verbose.
        gpu_verbose: If ``True``, print whether gpu is being used to user.

    Returns:
        - Inferred center slice with full and low dimensional representations (W, H) of the gene expression matrix.
        - List of pairwise alignment mappings of the center slice (rows) to each input slice (columns).
    """
    # Determine if gpu or cpu is being used
    if use_gpu:
        try:
            import torch
        except ImportError:
            # Remember the failure so the check below falls back to cpu
            # instead of dereferencing an unbound ``torch`` name.
            torch = None
            print("We currently only have gpu support for Pytorch. Please install torch.")

        if torch is not None and isinstance(backend, ot.backend.TorchBackend):
            if torch.cuda.is_available():
                if gpu_verbose:
                    print("gpu is available, using gpu.")
            else:
                if gpu_verbose:
                    print("gpu is not available, resorting to torch cpu.")
                use_gpu = False
        else:
            print("We currently only have gpu support for Pytorch, please set backend = ot.backend.TorchBackend(). Reverting to selected backend cpu.")
            use_gpu = False
    else:
        if gpu_verbose:
            print("Using selected backend cpu. If you want to use gpu, set use_gpu = True.")

    if lmbda is None:
        lmbda = len(slices) * [1 / len(slices)]

    if distributions is None:
        distributions = len(slices) * [None]

    # get common genes
    common_genes = A.var.index
    for s in slices:
        common_genes = intersect(common_genes, s.var.index)

    # subset common genes
    A = A[:, common_genes]
    # Rebind to a new list instead of assigning into ``slices`` so the
    # caller's list keeps its original, unfiltered entries.
    slices = [s[:, common_genes] for s in slices]
    print('Filtered all slices for common genes. There are ' + str(len(common_genes)) + ' common genes.')

    def _weighted_mapped_expression(mappings):
        # n_center_spots * sum_i lmbda[i] * (mappings[i] @ dense(slices[i].X))
        return A.shape[0] * sum([lmbda[i] * np.dot(mappings[i], to_dense_array(slices[i].X)) for i in range(len(slices))])

    # Run initial NMF
    if dissimilarity.lower() in ('euclidean', 'euc'):
        model = NMF(n_components=n_components, init='random', random_state=random_seed, verbose=verbose)
    else:
        model = NMF(n_components=n_components, solver='mu', beta_loss='kullback-leibler', init='random', random_state=random_seed, verbose=verbose)

    if pis_init is None:
        pis = [None for _ in slices]
        W = model.fit_transform(A.X)
    else:
        pis = pis_init
        W = model.fit_transform(_weighted_mapped_expression(pis))
    H = model.components_

    center_coordinates = A.obsm['spatial']
    if not isinstance(center_coordinates, np.ndarray):
        print("Warning: A.obsm['spatial'] is not of type numpy array.")

    # Minimize R: alternate between solving the pairwise alignments and
    # refitting the NMF factors until the objective change drops to the
    # threshold or the iteration cap is reached.
    iteration_count = 0
    R = 0
    R_diff = 100
    while R_diff > threshold and iteration_count < max_iter:
        print("Iteration: " + str(iteration_count))
        pis, r = center_ot(W, H, slices, center_coordinates, common_genes, alpha, backend, use_gpu, dissimilarity=dissimilarity, norm=norm, G_inits=pis, distributions=distributions, verbose=verbose)
        W, H = center_NMF(W, H, slices, pis, lmbda, n_components, random_seed, dissimilarity=dissimilarity, verbose=verbose)
        R_new = np.dot(r, lmbda)
        iteration_count += 1
        R_diff = abs(R - R_new)
        print("Objective ", R_new)
        print("Difference: " + str(R_diff) + "\n")
        R = R_new

    # Assemble the result on a copy of the gene-filtered A.
    center_slice = A.copy()
    center_slice.X = np.dot(W, H)
    center_slice.uns['paste_W'] = W
    center_slice.uns['paste_H'] = H
    center_slice.uns['full_rank'] = _weighted_mapped_expression(pis)
    center_slice.uns['obj'] = R
    return center_slice, pis
#--------------------------- HELPER METHODS -----------------------------------
def center_ot(W, H, slices, center_coordinates, common_genes, alpha, backend, use_gpu, dissimilarity = 'kl', norm = False, G_inits = None, distributions=None, verbose = False):
    """
    Align the current center slice (reconstructed as ``W @ H``) to every slice.

    Args:
        W, H: Factor matrices whose product is used as the center expression matrix.
        slices: Slices to align the center slice against.
        center_coordinates: Stored as ``obsm['spatial']`` of the center slice.
        common_genes: Used as the variable index of the center slice.
        alpha, backend, use_gpu, dissimilarity, norm, verbose: Forwarded to ``pairwise_align``.
        G_inits: Per-slice initial mappings forwarded as ``G_init``; ``None`` means no initial mapping for any slice.
        distributions: Per-slice spot distributions forwarded as ``b_distribution``; ``None`` means none for any slice.

    Returns:
        - List with one mapping per slice.
        - Numpy array with one objective value per slice.
    """
    center_slice = AnnData(np.dot(W, H))
    center_slice.var.index = common_genes
    center_slice.obsm['spatial'] = center_coordinates

    if distributions is None:
        distributions = len(slices) * [None]
    # The default ``G_inits=None`` is indexed below, so expand it the same
    # way as ``distributions`` rather than failing with a TypeError.
    if G_inits is None:
        G_inits = len(slices) * [None]

    pis = []
    r = []
    print('Solving Pairwise Slice Alignment Problem.')
    for i, slice_i in enumerate(slices):
        p, r_q = pairwise_align(
            center_slice,
            slice_i,
            alpha=alpha,
            dissimilarity=dissimilarity,
            norm=norm,
            return_obj=True,
            G_init=G_inits[i],
            b_distribution=distributions[i],
            backend=backend,
            use_gpu=use_gpu,
            verbose=verbose,
            gpu_verbose=False,
        )
        pis.append(p)
        r.append(r_q)
    return pis, np.array(r)
def center_NMF(W, H, slices, pis, lmbda, n_components, random_seed, dissimilarity = 'kl', verbose = False):
    """
    Refit the NMF factors from the slices mapped through the current alignments.

    The matrix that gets factorized is
    ``W.shape[0] * sum_i lmbda[i] * (pis[i] @ dense(slices[i].X))``.

    Args:
        W: Current factor matrix; only its row count is read.
        H: Not used by this function.
        slices: Slices whose ``.X`` matrices are combined.
        pis: One mapping per slice.
        lmbda: One weight per slice.
        n_components: Number of NMF components.
        random_seed: Passed to NMF as ``random_state``.
        dissimilarity: ``'euclidean'``/``'euc'`` uses NMF with its default solver and loss; anything else uses the ``'mu'`` solver with ``'kullback-leibler'`` loss.
        verbose: Passed to NMF.

    Returns:
        The newly fitted ``(W, H)`` pair.
    """
    print('Solving Center Mapping NMF Problem.')
    num_rows = W.shape[0]
    weighted_terms = [
        lmbda[k] * np.dot(pis[k], to_dense_array(slice_k.X))
        for k, slice_k in enumerate(slices)
    ]
    B = num_rows * sum(weighted_terms)

    # Options shared by both variants; the KL variant adds solver and loss.
    nmf_options = dict(n_components=n_components, init='random', random_state=random_seed, verbose=verbose)
    if not (dissimilarity.lower() == 'euclidean' or dissimilarity.lower() == 'euc'):
        nmf_options.update(solver='mu', beta_loss='kullback-leibler')
    model = NMF(**nmf_options)

    W_new = model.fit_transform(B)
    return W_new, model.components_
def my_fused_gromov_wasserstein(M, C1, C2, p, q, G_init = None, loss_fun='square_loss', alpha=0.5, armijo=False, log=False,numItermax=200, use_gpu = False, **kwargs):
    """
    Adapted fused_gromov_wasserstein with the added capability of defining a G_init (initial mapping).
    Also added capability of utilizing different POT backends to speed up computation.

    For more info, see: https://pythonot.github.io/gen_modules/ot.gromov.html

    Args:
        M: Matrix passed to the solver scaled by ``(1 - alpha)``.
        C1, C2: Matrices passed to ``ot.gromov.init_matrix`` and to the solver.
        p, q: Distributions; their outer product is the default starting plan.
        G_init: Optional starting plan; it is rescaled to sum to 1 before use.
        loss_fun: Loss name passed to ``ot.gromov.init_matrix``.
        alpha: Weight passed to the solver as the regularization term.
        armijo: Forwarded to the solver.
        log: If true, also return the solver log with an added ``'fgw_dist'`` entry.
        numItermax: Maximum number of solver iterations.
        use_gpu: If true and ``G_init`` is given, the starting plan is moved with ``.cuda()``.
        **kwargs: Forwarded to ``ot.gromov.cg`` (e.g. ``verbose``).

    Returns:
        The solver result, or ``(result, log)`` when ``log`` is true.
    """
    p, q = ot.utils.list_to_array(p, q)
    nx = ot.backend.get_backend(p, q, C1, C2, M)

    constC, hC1, hC2 = ot.gromov.init_matrix(C1, C2, p, q, loss_fun)

    if G_init is None:
        G0 = p[:, None] * q[None, :]
    else:
        G0 = (1 / nx.sum(G_init)) * G_init
        if use_gpu:
            G0 = G0.cuda()

    def f(G):
        return ot.gromov.gwloss(constC, hC1, hC2, G)

    def df(G):
        return ot.gromov.gwggrad(constC, hC1, hC2, G)

    # ``numItermax`` is a named parameter here, so it never lands in
    # ``kwargs``; it has to be forwarded explicitly or the solver silently
    # falls back to its own default.
    if log:
        res, log_dict = ot.gromov.cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, numItermax=numItermax, log=True, **kwargs)
        log_dict['fgw_dist'] = log_dict['loss'][-1]
        return res, log_dict
    return ot.gromov.cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, numItermax=numItermax, **kwargs)