import numpy as np
import scipy
import ot
def match_spots_using_spatial_heuristic(
    X,
    Y,
    use_ot: bool = True) -> np.ndarray:
    """
    Calculates and returns a mapping of spots using a spatial heuristic.

    Both coordinate sets are centered and rescaled with
    ``norm_and_center_coordinates`` and the pairwise Euclidean distance
    between them is used as the matching cost.

    Args:
        X (array-like): Coordinates for spots X, one row per spot.
        Y (array-like): Coordinates for spots Y, one row per spot.
        use_ot: If ``True``, use optimal transport ``ot.emd()`` to calculate mapping. Otherwise, use Scipy's ``min_weight_full_bipartite_matching()`` algorithm.

    Returns:
        Mapping of spots using a spatial heuristic, as an array with one row
        per spot in X and one column per spot in Y.
    """
    n1, n2 = len(X), len(Y)
    # Convert up front so plain lists/tuples work with the array methods
    # used by norm_and_center_coordinates.
    X, Y = np.asarray(X), np.asarray(Y)
    X, Y = norm_and_center_coordinates(X), norm_and_center_coordinates(Y)
    dist = scipy.spatial.distance_matrix(X, Y)

    if use_ot:
        # Uniform marginals over both sets of spots.
        return ot.emd(np.ones(n1) / n1, np.ones(n2) / n2, dist)

    # NOTE(review): csr_matrix() drops exact zeros, so a pair of spots at
    # distance 0 would not be stored as an edge -- confirm this cannot occur
    # for the inputs used here.
    row_ind, col_ind = scipy.sparse.csgraph.min_weight_full_bipartite_matching(
        scipy.sparse.csr_matrix(dist))
    pi = np.zeros((n1, n2))
    pi[row_ind, col_ind] = 1 / max(n1, n2)

    # Spots of the larger set that were left unmatched get a uniform weight
    # across the whole row/column. Boolean masks are built in one vectorized
    # step instead of testing membership index by index.
    if n1 < n2:
        unmatched_cols = np.ones(n2, dtype=bool)
        unmatched_cols[col_ind] = False
        pi[:, unmatched_cols] = 1 / (n1 * n2)
    elif n2 < n1:
        unmatched_rows = np.ones(n1, dtype=bool)
        unmatched_rows[row_ind] = False
        pi[unmatched_rows, :] = 1 / (n1 * n2)
    return pi
def kl_divergence(X, Y):
    """
    Returns pairwise KL divergence (over all pairs of samples) of two matrices X and Y.

    Every row of X and Y is first rescaled to sum to 1. Entry ``D[i, j]`` is
    then ``sum_k X[i, k] * (log X[i, k] - log Y[j, k])``.

    Note:
        Zero entries go through ``np.log`` unchanged, so they yield
        ``-inf``/``nan`` in the result; no pseudocount is added here.

    Args:
        X: np array with dim (n_samples by n_features)
        Y: np array with dim (m_samples by n_features)

    Returns:
        D: np array with dim (n_samples by m_samples). Pairwise KL divergence matrix.
    """
    assert X.shape[1] == Y.shape[1], "X and Y do not have the same number of features."

    X = X / X.sum(axis=1, keepdims=True)
    Y = Y / Y.sum(axis=1, keepdims=True)
    log_X = np.log(X)
    log_Y = np.log(Y)
    # Row-wise sum_k X*log(X) in one vectorized call; avoids the discouraged
    # np.matrix type and a Python-level loop over samples.
    X_log_X = np.einsum('ij,ij->i', X, log_X)
    # Column vector broadcasts against the (n_samples, m_samples) cross term.
    D = X_log_X[:, np.newaxis] - np.dot(X, log_Y.T)
    return np.asarray(D)
def kl_divergence_backend(X, Y):
    """
    Computes the pairwise KL divergence between every row of X and every row of Y.

    The array operations are dispatched through the POT backend selected by
    ``ot.backend.get_backend(X, Y)``, and the result is converted to NumPy
    before it is returned.

    Args:
        X: array with dim (n_samples by n_features)
        Y: array with dim (m_samples by n_features)

    Returns:
        D: np array with dim (n_samples by m_samples). Pairwise KL divergence matrix.
    """
    assert X.shape[1] == Y.shape[1], "X and Y do not have the same number of features."

    nx = ot.backend.get_backend(X, Y)

    # Rescale each row so that it sums to 1.
    P = X / nx.sum(X, axis=1, keepdims=True)
    Q = Y / nx.sum(Y, axis=1, keepdims=True)

    log_P = nx.log(P)
    log_Q = nx.log(Q)

    # Per-row sum of P * log(P), shaped as a column so it broadcasts over
    # the columns of the cross term.
    self_term = nx.einsum('ij,ij->i', P, log_P)
    self_term = nx.reshape(self_term, (self_term.shape[0], 1))

    cross_term = nx.dot(P, log_Q.T)
    return nx.to_numpy(self_term - cross_term)
def intersect(lst1, lst2):
    """
    Collects the elements of ``lst1`` that also occur in ``lst2``.

    The order (and any repetition) of ``lst1`` is kept in the result.

    Args:
        lst1: List
        lst2: List

    Returns:
        lst3: List of common elements.
    """
    # Set lookup keeps each membership test constant-time.
    members = set(lst2)
    common = []
    for item in lst1:
        if item in members:
            common.append(item)
    return common
def norm_and_center_coordinates(X):
    """
    Normalizes and centers coordinates at the origin.

    The column-wise mean is subtracted and the result is divided by the
    smallest pairwise distance between rows of X.

    Note:
        If two rows of X coincide the smallest distance is 0 and the division
        produces ``inf``/``nan`` values.

    Args:
        X: Numpy array

    Returns:
        X_new: Updated coordinates.

    Raises:
        ValueError: If X has fewer than two rows (there is no pairwise
            distance to take the minimum of).
    """
    # ndarray.min() reduces in native code; builtin min() would walk all
    # n*(n-1)/2 pairwise distances one Python object at a time.
    min_dist = scipy.spatial.distance.pdist(X).min()
    return (X - X.mean(axis=0)) / min_dist
## Convert a sparse matrix into a dense np array
def to_dense_array(X):
    """
    Converts a sparse matrix into a dense np array.

    Args:
        X: SciPy sparse matrix/array, or any array-like.

    Returns:
        ``X.toarray()`` when X is a SciPy sparse container, otherwise
        ``np.array(X)`` (a new array).
    """
    # issparse() is the public check; the scipy.sparse.csr namespace used
    # previously is deprecated.
    if scipy.sparse.issparse(X):
        return X.toarray()
    return np.array(X)
## Returns the data matrix or representation
def extract_data_matrix(adata, rep):
    """
    Returns the data matrix or a named representation of ``adata``.

    Args:
        adata: Object exposing ``.X`` and a mapping ``.obsm``.
        rep: Key into ``adata.obsm``, or ``None`` to select ``adata.X``.

    Returns:
        ``adata.obsm[rep]`` when ``rep`` is given, otherwise ``adata.X``.
    """
    if rep is not None:
        return adata.obsm[rep]
    return adata.X