Source code for TFcomb.preprocessing.pca_umap

import celloracle as co
import numpy as np
import pandas as pd
import scanpy as sc
from sklearn.decomposition import PCA
import umap
import matplotlib.pyplot as plt
import scipy
# from trajectory.oracle_utility import _adata_to_df
from TFcomb.tools.utils_celloracle import _adata_to_df
import seaborn as sns

from TFcomb.tools.utils import fix_seed

# def test():
#     fix_seed(0)

[docs] def pca_umap_train(adata, cluster_column_name=None, embedding_name=None, n_components=50, svd_solver='arpack', random_seed=2022): """ Trains PCA and UMAP models on Oracle-normalized data. This function processes the input AnnData object with Oracle's normalization method, performs PCA for dimensionality reduction, and fits a UMAP model for visualization. Args: adata (anndata.AnnData): The AnnData object containing the dataset to be processed. cluster_column_name (str, optional): The name of the column in `adata.obs` containing cluster labels. Used during Oracle normalization. Defaults to None. embedding_name (str, optional): The name of the embedding field to be added during Oracle normalization. Defaults to None. n_components (int, optional): The number of principal components to retain in PCA. Defaults to 50. svd_solver (str, optional): The SVD solver to use for PCA. Options include 'auto', 'full', 'arpack', and 'randomized'. Defaults to 'arpack'. random_seed (int, optional): The random seed for reproducibility. Defaults to 2022. Returns: tuple: A tuple containing: - sklearn.decomposition.PCA: The trained PCA model. - umap.UMAP: The trained UMAP model. - co.Oracle: The Oracle object with the normalized data. """ plt.rcParams["figure.figsize"] = [5, 5] oracle = co.Oracle() adata.X = adata.layers["raw_count"].copy() oracle.import_anndata_as_raw_count(adata=adata, cluster_column_name=cluster_column_name, embedding_name=embedding_name) gem_imputed = _adata_to_df(oracle.adata, "normalized_count") pca_train=PCA(n_components=n_components, svd_solver=svd_solver,random_state=random_seed) X_pca = pca_train.fit_transform(gem_imputed.values) umap_train=umap.UMAP(random_state=random_seed) proj = umap_train.fit_transform(X_pca) return pca_train,umap_train,oracle
[docs] def pca_umap_vis(pca_train=None, umap_train=None, exp_mtx=None, label=None, title=None, bbox=1, figsize=(12,5), save=None ): """ Visualizes UMAP projections of PCA-transformed data. This function applies a trained PCA model to the input expression matrix, projects it onto the UMAP embedding, and visualizes the results in a scatter plot. Args: pca_train (sklearn.decomposition.PCA, optional): Trained PCA model. Defaults to None. umap_train (umap.UMAP, optional): Trained UMAP model. Defaults to None. exp_mtx (numpy.ndarray or scipy.sparse.csr_matrix, optional): Input expression matrix to be projected. If sparse, it will be converted to a dense array. Defaults to None. label (array-like, optional): Labels for the data points. Used to color the scatter plot. Defaults to None. title (str, optional): Title for the plot. Defaults to None. bbox (float, optional): Positioning factor for the legend box. Defaults to 1. figsize (tuple, optional): Size of the figure in inches (width, height). Defaults to (12, 5). save (str, optional): Path to save the plot as an image file. If None, the plot is not saved. Defaults to None. """ if isinstance(exp_mtx,scipy.sparse.csr.csr_matrix): exp_mtx = exp_mtx.toarray() X_pca = pca_train.transform(exp_mtx) proj_ori = umap_train.transform(X_pca) fig=plt.figure(figsize=figsize) for i in range(1): ax_ = fig.add_subplot(1,1,i+1) df = {'UMAP_1':proj_ori[:, 0],\ 'UMAP_2':proj_ori[:, 1], \ 'label':label} df = pd.DataFrame(df) ax = sns.scatterplot(x="UMAP_1", y="UMAP_2", hue="label", edgecolor='none', # hue_order=celltypes, # saturation=1, palette = 'tab10', s=8,linewidth = 0.0001, data=df) plt.xticks(rotation=0,fontsize=15) plt.yticks(rotation=0,fontsize=15) # ax.set(title='original UMAP',xlabel='UMAP_1') ax.set_xlabel('') ax.set_ylabel('') if i>0: ax.set_xlim(lim1_x,lim2_x) ax.set_ylim(lim1_y,lim2_y) ax.set_title(title,fontsize=18) axLine, axLabel = ax.get_legend_handles_labels() # ax.legend([],[],frameon=False) ax.legend(loc='upper right',bbox_to_anchor=(bbox, 1), frameon=False) if save: fig.savefig(save,facecolor='white',bbox_inches='tight',dpi=400)