TFcomb main tutorial

This notebook walks through the main TFcomb workflow for identifying the TFs that reprogram cells from a source state to a target state. It also provides several visualization examples.

import

[1]:
import torch
import os
import scanpy as sc
import celloracle as co
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
[2]:
import TFcomb as tfc
[3]:
import importlib
import TFcomb
# Development helper: re-import the TFcomb submodules, in the same order as
# before, so the kernel runs the current source of each module.
_tfcomb_dev_modules = [
    TFcomb.preprocessing.pca_umap,
    TFcomb.preprocessing.oracle_process,
    TFcomb.tools.utils,
    TFcomb.tools.GRN_func,
    TFcomb.tools.GNN_module,
    TFcomb.tools.link_recover,
    TFcomb.tools.link_recover_module,
    TFcomb.tools.tf_inference,
    TFcomb.plotting.plot,
]
for _dev_module in _tfcomb_dev_modules:
    importlib.reload(_dev_module)

# Reloaded last and kept as the final expression, so the cell still displays
# the plot_grn module object as its output.
importlib.reload(TFcomb.plotting.plot_grn)
[3]:
<module 'TFcomb.plotting.plot_grn' from '/nfs/public/lichen/code/TFcomb_github_doc/TFcomb/TFcomb/plotting/plot_grn.py'>

Read processed adata

[4]:
# - column of adata.obs holding the cell-state labels used throughout the notebook
cluster_name = 'celltype'

# - input folder, and the output folder that will receive all results
data_dir = '../../data/iPSC_example/RNA_data'
save_dir = '../../data/iPSC_example/save'
os.makedirs(save_dir, exist_ok=True)

# - load the processed RNA AnnData
adata_path = os.path.join(data_dir, 'adata_rna.h5ad')
adata = sc.read(adata_path)

# - report basic size information, then display the number of cells per state
print(f"Cell number is :{adata.shape[0]}")
print(f"Gene number is :{adata.shape[1]}")
print('the cluster_name count is:')
adata.obs[cluster_name].value_counts()
Cell number is :7841
Gene number is :1838
the cluster_name count is:
[4]:
Fibroblasts    4665
iPSCs          3176
Name: celltype, dtype: int64

Load base GRN & create celloracle object

[5]:
# - read the base GRN table from disk and report its shape
base_GRN_path = os.path.join('../../data/iPSC_example/base_GRN',
                             "base_GRN_dataframe_stream.parquet")
base_GRN = pd.read_parquet(base_GRN_path)
print(f"the base_GRN shape is {base_GRN.shape}")

# - fit PCA and UMAP models on adata and wrap adata in an oracle object;
#   the PCA/UMAP models are kept so the transition can be visualized later
pca_umap_kwargs = dict(
    cluster_column_name=cluster_name,
    embedding_name="X_umap",
    n_components=50,
    svd_solver='arpack',
    random_seed=2022,
)
pca_train, umap_train, oracle = tfc.pp.pca_umap_train(adata, **pca_umap_kwargs)
the base_GRN shape is (31844, 1100)
WARNING: adata.X seems to be already log-transformed.
[6]:
# - register the base GRN with the oracle, then run TFcomb's oracle
#   preprocessing (which, per the original note, performs the imputation)
oracle.import_TF_data(TF_info_matrix=base_GRN)
oracle = tfc.pp.oracle_preprocess(oracle)
# - constant label: every cell gets the same 'total' value, so a GRN inferred
#   with 'total' as the grouping column treats all cells as one group
oracle.adata.obs['total'] = 'total'
../../_images/tutorial_example_TFcomb_main_TFcomb_main_10_0.png
n_comps is: 25
cell number is :7841
Auto-selected k is :196
but we set default k is: 10

Construct a total GRN for all cells

[7]:
%%time
# Infer one GRN over all cells: the GRN unit passed below is the constant
# 'total' column, so every cell falls into a single group.
# This step may take some time.
# NOTE(review): cluster_name_for_GRN_unit is assigned here but is not passed to
# get_links below, which receives the literal 'total' instead.
cluster_name_for_GRN_unit = cluster_name
links = oracle.get_links(cluster_name_for_GRN_unit='total',
                         alpha=10,
                         verbose_level=10) # Note: link inference uses the imputed counts stored in oracle.adata
Inferring GRN for total...
../../_images/tutorial_example_TFcomb_main_TFcomb_main_12_3.png
CPU times: user 2min 9s, sys: 12.1 s, total: 2min 21s
Wall time: 3min 9s
[8]:
# - persist the all-cell oracle and its inferred links under <save_dir>/oracle
oracle_dir = os.path.join(save_dir, 'oracle')
os.makedirs(oracle_dir, exist_ok=True)

oracle_file = os.path.join(oracle_dir, "data.celloracle.oracle")
links_file = os.path.join(oracle_dir, "links.celloracle.links")
oracle.to_hdf5(oracle_file)
links.to_hdf5(file_path=links_file)

Infer the results with TFcomb

Init the parameters

[11]:
# - display the number of cells per state, to pick the source and target states
adata.obs[cluster_name].value_counts()
[11]:
Fibroblasts    4665
iPSCs          3176
Name: celltype, dtype: int64
[12]:
# - source and target states, and the '<source>_<target>' tag used to name
#   the per-transition outputs
source_state = 'Fibroblasts'
target_state = 'iPSCs'
combine = '_'.join([source_state, target_state])

# - ground-truth TFs for this transition
gt_tfs = ['POU5F1', 'MYC', 'SOX2', 'NANOG', 'KLF4']

# - folder that receives the results of this transition
save_dir_part = os.path.join(save_dir, combine)
os.makedirs(save_dir_part, exist_ok=True)

Generate the CellOracle with only cells of source state and target state

[13]:
# Build (or reload from cache) an oracle restricted to the source and target
# cells, with all of those cells grouped under one label so that a single GRN
# is inferred for the transition.
col_name = 'combine'
oracle_part_path = os.path.join(save_dir_part, "data.celloracle.oracle")
links_part_path = os.path.join(save_dir_part, "links.celloracle.links")

# Reuse the cache only when BOTH files exist; otherwise a missing links file
# would make the load branch fail.
if os.path.exists(oracle_part_path) and os.path.exists(links_part_path):
    oracle_part = co.load_hdf5(oracle_part_path)
    links_part = co.load_hdf5(file_path=links_part_path)
else:
    state_mask = oracle.adata.obs[cluster_name].isin([source_state, target_state])

    oracle_part = oracle.copy()
    # Take a real copy of the subset: slicing an AnnData yields a view, and
    # the .obs / .uns assignments below would otherwise write into that view.
    oracle_part.adata = oracle.adata[state_mask].copy()

    # Label every kept cell with the same '<source>_<target>' tag. Reusing
    # `combine` keeps the label identical to the one passed to
    # get_GRN_parameters later on.
    oracle_part.adata.obs[col_name] = combine
    oracle_part.adata.obs[col_name] = oracle_part.adata.obs[col_name].astype('category')
    oracle_part.cluster_column_name = col_name
    oracle_part.adata.uns[f'{col_name}_colors'] = np.array(['#ead3c6'])

    # Infer the GRN for the single combined group.
    # This step may take some time.
    links_part = oracle_part.get_links(cluster_name_for_GRN_unit=col_name,
                                       alpha=10,
                                       verbose_level=10)

    oracle_part.to_hdf5(oracle_part_path)
    links_part.to_hdf5(file_path=links_part_path)

print('finish oracle_links_part load')
Inferring GRN for Fibroblasts_iPSCs...
finish oracle_links_part load
[14]:
# - set the link threshold and refit the GRN on the filtered links
# threshold_number / p / weight are forwarded to CellOracle's link filter
# (presumably: keep at most threshold_number links with p-value below p,
# ranked by the 'coef_abs' weight -- confirm against the CellOracle docs).
threshold_number = 10000
alpha_fit_GRN = 10
links_part.filter_links(threshold_number=threshold_number,
                    p=0.001,
                    weight='coef_abs')
# - build the cluster-specific TF dictionary from the filtered links, then fit
#   the GRN used for simulation with that dictionary
oracle_part.get_cluster_specific_TFdict_from_Links(links_object=links_part)
oracle_part.fit_GRN_for_simulation(alpha=alpha_fit_GRN,
                              use_cluster_specific_TFdict=True)
[15]:
# - get the fitted GRN parameters for the combined source/target group
gene_GRN_mtx_ori, tf_GRN_mtx_ori, tf_GRN_dict_ori = tfc.tl.get_GRN_parameters(oracle_part, combine)

# - working copies used by the downstream cells. No link-recovery step is run
#   in this notebook, so they are simply aliases of the fitted objects.
#   tf_GRN_dict must be defined here as well: later cells (directing score,
#   GRN plot, TF combinations) read it, and without this assignment a fresh
#   "Restart & Run All" fails with a NameError.
#   If a link-recovery step is added, rebind these three names to its outputs.
gene_GRN_mtx, tf_GRN_mtx, tf_GRN_dict = gene_GRN_mtx_ori, tf_GRN_mtx_ori, tf_GRN_dict_ori

# - TF names come from the rows of tf_GRN_mtx, gene names from its columns
tf_list, gene_list = list(tf_GRN_mtx.index), list(tf_GRN_mtx.columns)

Solve the inverse problem and obtain the expected alterations

[20]:
# - folder that receives every figure produced below
save_dir_fig = os.path.join(save_dir, 'figures')
os.makedirs(save_dir_fig, exist_ok=True)

# - keyword arguments forwarded to tfc.tl.TF_inference
regression_fig_path = os.path.join(save_dir_fig, 'regression_coef.png')
tf_infer_params = dict(
    layer_use='normalized_count',
    model='ridge',
    alpha=1,
    a1=0.6,
    a2=0.2,
    a3=0.2,
    save=regression_fig_path,
)
[21]:
# - figure size for the regression plot (this changes the global matplotlib default)
plt.rcParams["figure.figsize"] = [8, 4]
# - run the TF inference from source_state to target_state.
#   rr: fitted model whose coef_ are used below as the 'Expected alteration';
#   source_ave / target_ave: used below to form the target-minus-source
#   difference passed to the directing score.
rr, X, source_ave, target_ave = tfc.tl.TF_inference(
                adata,
                cluster_name,
                tf_list,
                tf_GRN_mtx,
                gene_GRN_mtx,
                tf_GRN_mtx_ori,
                gene_GRN_mtx_ori,
                source_state,
                target_state,
                **tf_infer_params)
==========model:ridge, alpha:1
correlation is: 0.9641523233486307
../../_images/tutorial_example_TFcomb_main_TFcomb_main_34_1.png

Obtain directing scores

[22]:
# - compute the directing scores; only the first returned value (TF_ds_dict)
#   is kept, the second is discarded.
#   diff_ave is the flattened target-minus-source difference.
TF_ds_dict, _ = tfc.tl.get_directing_score(
                    tf_list,
                    rr,
                    tf_GRN_mtx,
                    tf_GRN_dict,
                    diff_ave = (target_ave - source_ave).ravel(),
                    )

Plot results

First, we calculate the DE genes to annotate in our figures.

[23]:
# - DE genes between source_state and target_state, used to annotate the figures.
#   presumably pos_* / neg_* are the up- / down-regulated sets and p_val is the
#   significance cutoff -- confirm against tfc.tl.get_de_genes
pos_genes, neg_genes, pos_tfs, neg_tfs = tfc.tl.get_de_genes(
                    adata,
                    cluster_name,
                    source_state,
                    target_state,
                    tf_list,
                    p_val=0.05)

Plot the scatter plots to show the directing scores and expected alterations for all TFs.

[24]:
# - labels of the two score types, and how many top TFs to show
x_name, y_name = 'Expected alteration', 'Directing score'
TF_number = 20 # top show TFs

# - long-format table: one row per (TF, score type).
#   NOTE(review): the directing scores are taken in TF_ds_dict's own order and
#   paired with tf_list by position -- this assumes both share the same order.
alteration_values = list(rr.coef_)
directing_values = list(TF_ds_dict.values())
n_coef = len(rr.coef_)

df = pd.DataFrame({
    'index': tf_list + tf_list,
    'value': alteration_values + directing_values,
    'cluster': [x_name] * n_coef + [y_name] * n_coef,
})

# - percentile threshold used when plotting the top TFs
percentile = tfc.tl.get_percentile_thre(df, 'value', True, x_name, y_name, TF_number)
the percentile threshold is 75.0
[25]:
# - keyword arguments forwarded to tfc.pl.plot_score_comparison
de_gene_list = list(pos_genes) + list(neg_genes)
score_fig_path = os.path.join(save_dir_fig, 'score.png')

plot_score_params = dict(
    cluster1=x_name,
    cluster2=y_name,
    percentile1=percentile,   # same threshold on both axes
    percentile2=percentile,
    de_genes=de_gene_list,
    gt_tfs=gt_tfs,
    title='->'.join([source_state, target_state]),
    plot_all_tf=False,
    goi_size=10,
    point_size=15,
    save=score_fig_path,
)
[26]:
# - square figure for the score scatter plot (this changes the global matplotlib default)
plt.rcParams["figure.figsize"] = [8, 8]
# - compare the two score types stored in df['value'], using the settings in
#   plot_score_params; the figure is written to plot_score_params['save']
change_tf = tfc.pl.plot_score_comparison(
                            df,
                            value='value',
                            **plot_score_params)
WARNING:adjustText:Looks like you are using a tranform that doesn't support FancyArrowPatch, using ax.annotate instead. The arrows might strike through texts. Increasing shrinkA in arrowprops might help.
../../_images/tutorial_example_TFcomb_main_TFcomb_main_43_1.png

Plot GRN.

[27]:
# - get the score list, if you want to get the TFs for overexpression, set the direction to 'pos'
#   the result has one row per TF with its directing score and expected alteration
df_score = tfc.tl.get_single_TF(
    tf_list,
    rr,
    TF_ds_dict,
    direction = 'pos')
# - preview the first rows
df_score.head()
[27]:
TF directing_score expected_alteration
0 POU5F1 0.658345 2.307519
1 RXRA 0.365671 0.237054
2 MYC 0.360068 0.533985
3 ID3 0.318769 0.203920
4 USF2 0.309472 0.446271
[28]:
# - colour mapping passed to tfc.pl.plot_GRN below as gene_color_dict
#   (presumably gene -> colour, derived from oracle_part and source_state -- confirm)
gene_color_dict = tfc.pl.get_gene_color_dict(oracle_part,
                                      source_state)
../../_images/tutorial_example_TFcomb_main_TFcomb_main_46_0.png
[29]:
# - get tf_GRN_dict_delta: for each TF, the targets present in tf_GRN_dict but
#   absent from tf_GRN_dict_ori. A TF missing from tf_GRN_dict_ori is treated
#   as having no original targets instead of raising a KeyError.
tf_GRN_dict_delta = {
    key_1: {
        key_2: value
        for key_2, value in dict_1.items()
        if key_2 not in tf_GRN_dict_ori.get(key_1, {})
    }
    for key_1, dict_1 in tf_GRN_dict.items()
}

# - set plot TFs: the first 10 rows of df_score
plot_tf_list = list(df_score['TF'].head(10))

# - collect, for the plotted TFs, the newly added targets and the matching
#   'TF_target' edge names. A plotted TF with no entry in tf_GRN_dict_delta
#   contributes nothing instead of raising a KeyError.
recover_tf_list, recover_pair_list = [], []
for tf in plot_tf_list:
    new_targets = list(tf_GRN_dict_delta.get(tf, {}).keys())
    recover_tf_list.extend(new_targets)
    recover_pair_list.extend([f'{tf}_{target}' for target in new_targets])
recover_tfs = np.unique(recover_tf_list)

Note: pygraphviz must be installed to plot the GRN. Refer to the pygraphviz installation instructions for setup.

[30]:
# - draw the GRN around the TFs in plot_tf_list and save it to
#   <save_dir_fig>/GRN.png; anno_tfs / anno_edges receive the genes and
#   'TF_target' pairs collected in recover_tfs / recover_pair_list
tfc.pl.plot_GRN(tf_GRN_dict = tf_GRN_dict,
                plot_tf_list = plot_tf_list, # plot top 10 TFs
                save_dir_GRN = os.path.join(save_dir_fig,'GRN.png'),
                filter_link = False,
                anno_tfs = recover_tfs,
                anno_edges = recover_pair_list,
                show_mode = True,
                figsize = (15,15),
                gene_color_dict = gene_color_dict)
../../_images/tutorial_example_TFcomb_main_TFcomb_main_49_0.png

Calculate TIS score

[31]:
# - TIS score of the plotted TFs (plot_tf_list) against the ground-truth TFs (gt_tfs)
tis_score = tfc.tl.get_benchmark_score(gt_tfs, plot_tf_list)
print('TIS is: ', tis_score)
TIS is:  0.44

Infer TF combinations

[32]:
# - score TF combinations; with number = 3 each row of the result is a
#   3-TF combination with its directing score.
#   direction = 'pos' is the same setting used for the single-TF scores above.
df_score_combo = tfc.tl.get_multi_TF(
    tf_list,
    rr,
    tf_GRN_mtx,
    tf_GRN_dict,
    source_ave,
    target_ave,
    direction = 'pos',
    number = 3)
100%|██████████| 19600/19600 [00:09<00:00, 2100.71it/s]
[33]:
# - preview the first rows of the TF-combination table
df_score_combo.head()
[33]:
TF_combination directing_score
0 (POU5F1, MYC, NANOG) 0.766688
1 (POU5F1, MYC, USF2) 0.755908
2 (SOX2, POU5F1, MYC) 0.755667
3 (POU5F1, MYC, RXRA) 0.752582
4 (POU5F1, MYC, STAT3) 0.750849
[ ]: