Source code for TFcomb.plotting.plot_grn

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import networkx as nx
from matplotlib.colors import LinearSegmentedColormap

def _as_name_set(names):
    """Normalize an optional name, or iterable of names, into a set."""
    if names is None:
        return set()
    if isinstance(names, str):
        # A bare string would otherwise be treated as an iterable of characters.
        return {names}
    return set(names)


def plot_GRN(
    tf_GRN_dict=None,
    plot_tf_list=None,
    save_dir_GRN=None,
    filter_link=True,
    figsize=(15, 15),
    anno_tfs=None,
    anno_edges=None,
    show_mode=True,
    seed=2,
    scale=0.5,
    iterations=5,
    k=0.1,
    gene_color_dict=None,
):
    """Plot the GRN with TFs and target genes.

    Args:
        tf_GRN_dict (dict): Maps each TF to a dict of ``{target gene: weight}``.
            Required despite the ``None`` default.
        plot_tf_list (list or str): TF, or list of TFs, whose outgoing links
            are plotted. Required despite the ``None`` default.
        save_dir_GRN (str, optional): Path to save the figure. The figure is
            saved a second time at 300 dpi to the same path with ``'.pdf'``
            replaced by ``'.png'``. Defaults to None (nothing is saved).
        filter_link (bool, optional): Whether to drop links whose absolute
            weight is below 0.1. Defaults to True.
        figsize (tuple, optional): Size of figure. Defaults to (15, 15).
        anno_tfs (list, optional): Nodes whose labels are drawn in dark red.
            Defaults to None (no node is highlighted).
        anno_edges (list, optional): Edges drawn in black, each written as
            ``"<TF>_<target>"``. Defaults to None (no edge is highlighted).
        show_mode (bool, optional): Show the figure if True, otherwise close
            it. Defaults to True.
        seed (int, optional): Random seed of ``nx.spring_layout``. Defaults to 2.
        scale (float, optional): ``scale`` of ``nx.spring_layout``.
            Defaults to 0.5.
        iterations (int, optional): ``iterations`` of ``nx.spring_layout``.
            Defaults to 5.
        k (float, optional): ``k`` of ``nx.spring_layout``. Defaults to 0.1.
        gene_color_dict (dict, optional): Maps genes to colors. Target genes
            that are missing from it (or all of them when it is None) are
            drawn in a neutral color. TFs in ``plot_tf_list`` are always red.
            Defaults to None.

    Raises:
        TypeError: If ``plot_tf_list`` is None, or if ``tf_GRN_dict`` is None
            while there are TFs to plot.

    Note:
        Node positions come from ``graphviz_layout``. The spring-layout
        parameters (``seed``, ``scale``, ``iterations``, ``k``) are only used
        as a fallback when pygraphviz is not installed.
    """
    if plot_tf_list is None:
        raise TypeError("plot_GRN() requires 'plot_tf_list'")
    if isinstance(plot_tf_list, str):
        plot_tf_list = [plot_tf_list]
    plot_tf_list = list(plot_tf_list)
    if tf_GRN_dict is None and plot_tf_list:
        raise TypeError("plot_GRN() requires 'tf_GRN_dict'")

    tf_names = set(plot_tf_list)
    highlighted_nodes = _as_name_set(anno_tfs)
    highlighted_edges = _as_name_set(anno_edges)

    # Collect the nodes and the weighted TF -> target edges.
    genes = []
    weighted_edges = []
    for tf in plot_tf_list:
        genes.append(tf)
        for target, weight in tf_GRN_dict[tf].items():
            if filter_link and abs(weight) < 0.1:
                continue
            genes.append(target)
            weighted_edges.append((tf, target, weight))

    G = nx.DiGraph()
    # np.unique sorts the names, so the node order (and hence the layout)
    # does not depend on the iteration order of the input dicts.
    G.add_nodes_from(list(np.unique(genes)))
    G.add_weighted_edges_from(weighted_edges)

    # Edge style: black when annotated, otherwise red for positive and blue
    # for non-positive weights; line width scales with |weight|.
    edge_colors = []
    edge_widths = []
    for start, end, weight in G.edges(data="weight"):
        if "_".join([start, end]) in highlighted_edges:
            edge_colors.append("black")
        elif weight > 0:
            edge_colors.append("#ff8884")
        else:
            edge_colors.append("#9ac9db")
        edge_widths.append(3 * abs(weight))  # The multiplier can be adjusted as needed

    # Node size and label font size both grow with the out-degree.
    out_degree = dict(G.out_degree())
    node_sizes = [150 + 5 * out_degree[node] for node in G.nodes]
    font_sizes = {node: np.log(200 + 300 * deg) for node, deg in out_degree.items()}

    # TFs are red; targets take their color from gene_color_dict when given.
    target_colors = {} if gene_color_dict is None else gene_color_dict
    node_colors = [
        "red" if node in tf_names else target_colors.get(node, "#E7DAD2")
        for node in G.nodes
    ]

    # Define node layout.
    from networkx.drawing.nx_agraph import graphviz_layout

    try:
        pos = graphviz_layout(G)
    except ImportError:
        # pygraphviz is an optional dependency of networkx.
        pos = nx.spring_layout(G, seed=seed, scale=scale, iterations=iterations, k=k)

    # Draw the network
    plt.figure(figsize=figsize)
    nx.draw(
        G,
        pos,
        with_labels=False,
        node_color=node_colors,
        node_size=node_sizes,
        font_color="black",
        width=edge_widths,
        edge_color=edge_colors,
    )

    # Draw node labels, sized by out-degree; annotated nodes are dark red.
    for node, (x, y) in pos.items():
        label_style = {"color": "darkred"} if node in highlighted_nodes else {}
        plt.text(
            x,
            y,
            s=node,
            ha="center",
            va="center",
            fontsize=font_sizes[node],
            **label_style,
        )

    plt.title("GRN", fontsize=12)
    if save_dir_GRN:
        plt.savefig(save_dir_GRN, facecolor="white", bbox_inches="tight")
        plt.savefig(
            save_dir_GRN.replace(".pdf", ".png"),
            facecolor="white",
            bbox_inches="tight",
            dpi=300,
        )
    if show_mode:
        plt.show()
    else:
        plt.close()
def create_gradient_color_list(n, start_color, middle_color, end_color):
    """Sample ``n`` colors from a three-stop gradient.

    Args:
        n (int): Number of colors in the gradient.
        start_color: Color at the beginning of the gradient.
        middle_color: Color at the midpoint of the gradient.
        end_color: Color at the end of the gradient.

    Returns:
        list: One color per colormap entry, from ``start_color`` to
        ``end_color``, as returned by calling the colormap with an integer
        index.
    """
    # Build a colormap quantized to n entries that runs through the three stops.
    gradient = LinearSegmentedColormap.from_list(
        "custom_cmap",
        [start_color, middle_color, end_color],
        N=n,
    )
    # Look up every entry of the colormap by its integer index.
    return list(map(gradient, range(gradient.N)))
def get_gene_color_dict(oracle_part, source_state, groupby="celltype", show_gradient=True):
    """Generate the gene_color_dict.

    Genes are ranked with ``sc.tl.rank_genes_groups`` (Wilcoxon test,
    ``source_state`` as reference) and colored along a purple -> white -> gold
    gradient in the order of that ranking.

    Args:
        oracle_part (celloracle object): CellOracle object. Its ``adata`` is
            copied, so the object itself is not modified.
        source_state (str): Name of the source state, used as the reference
            group of the differential expression test.
        groupby (str, optional): Column of ``adata.obs`` that defines the
            groups. Defaults to 'celltype'.
        show_gradient (bool, optional): Whether to display a preview of the
            color gradient. Defaults to True.

    Returns:
        gene_color_dict (dict): A dict mapping genes to colors.
    """
    adata = oracle_part.adata.copy()
    sc.tl.rank_genes_groups(
        adata,
        use_raw=False,
        groupby=groupby,
        reference=source_state,
        rankby_abs=False,
        method="wilcoxon",
    )
    ranked_names = pd.DataFrame(adata.uns["rank_genes_groups"]["names"])
    # Only the first column of the result is used.
    # NOTE(review): presumably one column per non-reference group - confirm
    # that ignoring further groups is intended when there are more than two.
    de_genes = list(ranked_names.iloc[:, 0])

    # Define colors
    start_color = "#800080"  # Purple
    middle_color = "white"
    end_color = "#FFD700"  # Gold

    color_list = create_gradient_color_list(
        len(de_genes), start_color, middle_color, end_color
    )
    gene_color_dict = dict(zip(de_genes, color_list))

    if show_gradient:
        # Visualize the gradient effect
        plt.figure(figsize=(10, 2))
        plt.imshow([color_list], aspect="auto")
        plt.axis("off")
        plt.title("Color Gradient from Purple to White to Gold")
        plt.show()

    return gene_color_dict