Source code for SpaRCL._reference_centers

# -*- coding: utf-8 -*-

from typing import Optional
from ._compat import Literal
from anndata import AnnData

import numpy as np
import pandas as pd
from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.metrics.pairwise import euclidean_distances

from scanpy.tools._utils import _choose_representation

# Type alias for the selection strategies accepted by the ``method``
# argument of ``reference_centers``.
_Method = Literal['MiniBatchKMeans', 'KMeans', 'Random']


def reference_centers(
    adata: AnnData,
    n_centers: int,
    method: _Method = 'MiniBatchKMeans',
    n_pcs: Optional[int] = None,
    use_rep: Optional[str] = None,
    random_state: int = 0,
    copy: bool = False,
) -> Optional[AnnData]:
    '''
    Select reference centers for mini-batch relational contrastive learning.

    Parameters
    ----------
    adata
        Annotated data matrix.
    n_centers
        Number of reference centers to be selected. Must satisfy
        ``1 <= n_centers <= n_obs``.
    method
        Method to use for reference center selection.

        * ``'MiniBatchKMeans'``
            Use `scikit-learn` :class:`~sklearn.cluster.MiniBatchKMeans`
            to select reference centers.
        * ``'KMeans'``
            Use `scikit-learn` :class:`~sklearn.cluster.KMeans`
            to select reference centers.
        * ``'Random'``
            Randomly choose reference centers.

        For the two k-means methods, the observation closest (Euclidean
        distance) to each fitted cluster center is selected. If several
        cluster centers share the same closest observation, fewer than
        ``n_centers`` observations are marked.
    n_pcs
        Use this many PCs. If `n_pcs==0` use `.X` if `use_rep is None`.
    use_rep
        Use the indicated representation. `'X'` or any key for `.obsm` is
        valid. If `None`, the representation is chosen automatically:
        For `.n_vars` < 50, `.X` is used, otherwise 'X_pca' is used.
        If 'X_pca' is not present, it is computed with default parameters.
    random_state
        Change to use different initial states for the optimization.
    copy
        Return a copy instead of writing to ``adata``.

    Returns
    -------
    Depending on ``copy``, returns or updates ``adata`` with the following
    fields.

    .obs['reference_centers']
        Boolean indicator of reference centers.
    .uns['reference_centers']['params']
        The parameters of the call (``method``, ``n_pcs``, ``use_rep``,
        ``random_state``) together with ``n_centers``, which holds the
        number of observations actually marked as reference centers.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names, or if
        ``n_centers`` is smaller than 1 or larger than the number of
        observations.
    '''
    if method not in ['MiniBatchKMeans', 'KMeans', 'Random']:
        raise ValueError("method needs to be 'MiniBatchKMeans', 'KMeans' or 'Random'")

    n_obs = adata.shape[0]
    # Reject non-positive requests up front: with 'Random' a value of 0 would
    # otherwise succeed silently and flag no observation at all.
    if n_centers < 1:
        raise ValueError(f'Expected n_centers >= 1, but n_centers = {n_centers}')
    if n_centers > n_obs:
        raise ValueError(f'Expected n_centers <= n_obs, but n_centers = {n_centers}, n_obs = {n_obs}')

    adata = adata.copy() if copy else adata
    # Resolved for every method (including 'Random') to keep the original
    # call order; only the k-means paths read X afterwards.
    X = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs)

    if method == 'Random':
        rng = np.random.RandomState(seed=random_state)
        selected = rng.choice(np.arange(n_obs), size=n_centers, replace=False)
    else:
        # Both k-means variants share the same post-processing; only the
        # estimator class differs.
        estimator_cls = MiniBatchKMeans if method == 'MiniBatchKMeans' else KMeans
        kmeans = estimator_cls(n_clusters=n_centers, random_state=random_state).fit(X)
        # Column j of the distance matrix belongs to cluster center j, so the
        # argmin over axis 0 is the index of the observation nearest to it.
        distances = euclidean_distances(X, kmeans.cluster_centers_)
        # np.unique collapses centers that map to the same observation.
        selected = np.unique(np.argmin(distances, axis=0))

    is_center = pd.Series(False, index=adata.obs_names)
    is_center.iloc[selected] = True

    adata.uns['reference_centers'] = {
        'params': {
            # Actual number of flagged observations, which may be lower than
            # the requested n_centers after de-duplication.
            'n_centers': np.count_nonzero(is_center),
            'method': method,
            'n_pcs': n_pcs,
            'use_rep': use_rep,
            'random_state': random_state,
        }
    }
    adata.obs['reference_centers'] = is_center

    return adata if copy else None