Source code for TFcomb.tools.GNN_module

import os
os.environ["DGLBACKEND"] = "pytorch"
from dgl.data import DGLDataset
import dgl
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd


def get_edges_data(tf_GRN_mtx):
    """Convert a GRN weight matrix into an edge list.

    Every non-zero cell ``(row, column)`` of ``tf_GRN_mtx`` becomes one edge.
    Both endpoints are encoded as the position of the gene in
    ``tf_GRN_mtx.columns`` (first occurrence if a name is duplicated), so each
    row label that owns a non-zero cell must also appear among the columns.

    Args:
        tf_GRN_mtx (pd.DataFrame): Weight matrix; rows are edge sources and
            columns are edge destinations.

    Returns:
        pd.DataFrame: Columns ``Src`` and ``Dst`` (integer node ids) and
        ``Weight`` (the non-zero cell values), in row-major order.

    Raises:
        ValueError: If a row containing a non-zero value has a label that is
            not present in ``tf_GRN_mtx.columns``.
    """
    gene_list = list(tf_GRN_mtx.columns)
    index_labels = list(tf_GRN_mtx.index)

    # Position of the first occurrence of each gene, mirroring list.index().
    first_pos = {}
    for pos, gene in enumerate(gene_list):
        first_pos.setdefault(gene, pos)

    values = tf_GRN_mtx.to_numpy()
    # np.nonzero walks the mask in row-major order, the same order in which
    # a nested row/column loop would visit the cells.
    row_idx, col_idx = np.nonzero(np.asarray(values != 0, dtype=bool))

    # Only rows that actually contribute an edge need a valid node id.
    for r in np.unique(row_idx):
        if index_labels[r] not in first_pos:
            raise ValueError(
                f"{index_labels[r]!r} is a row of the GRN matrix with non-zero "
                "weights but is not one of its columns"
            )

    row_lookup = np.array(
        [first_pos.get(label, -1) for label in index_labels], dtype=np.int64
    )
    col_lookup = np.array([first_pos[gene] for gene in gene_list], dtype=np.int64)

    return pd.DataFrame({'Src': row_lookup[row_idx],
                         'Dst': col_lookup[col_idx],
                         'Weight': values[row_idx, col_idx]})

class GRN_Dataset(DGLDataset):
    """
    A custom dataset class for representing Gene Regulatory Networks (GRN) as DGL graphs.

    This class constructs a single DGL graph from gene expression data,
    transcription factors (TFs), and regulatory relationships encoded in a
    GRN matrix.

    Args:
        adata_part (anndata.AnnData): A subset of AnnData containing gene
            expression data. ``adata_part.X`` is transposed to form the node
            feature matrix, so its variables are assumed to line up with
            ``tf_GRN_mtx.columns`` -- TODO confirm against callers.
        tf_GRN_mtx (pd.DataFrame): A GRN matrix with TFs as rows and target
            genes as columns. The values represent regulatory weights.
        tf_list (list): A list of transcription factors used to label nodes.

    Attributes:
        adata_part (anndata.AnnData): Input AnnData object.
        tf_GRN_mtx (pd.DataFrame): Input GRN matrix.
        tf_list (list): List of transcription factors.
        graph (dgl.DGLGraph): Constructed DGL graph with node features
            ``feat``, node labels ``label`` and edge feature ``weight``.
    """

    def __init__(self, adata_part, tf_GRN_mtx, tf_list):
        # The inputs must be stored BEFORE the base constructor runs:
        # process() reads them, and the DGLDataset constructor is expected to
        # invoke process() while it is still executing.
        self.adata_part = adata_part
        self.tf_GRN_mtx = tf_GRN_mtx
        self.tf_list = tf_list
        super().__init__(name="GRN")

    def process(self):
        """Build ``self.graph`` from the GRN matrix and the expression data."""
        expr = self.adata_part.X
        # AnnData may hold X as a SciPy sparse matrix, which torch.from_numpy
        # cannot consume; densify it first.
        if hasattr(expr, "toarray"):
            expr = expr.toarray()
        # One row of features per gene (node).
        node_features = torch.from_numpy(expr.T)

        # Label is 1 for genes listed as TFs and 0 otherwise.
        tf_set = set(self.tf_list)
        node_labels = torch.from_numpy(
            np.array([1 if gene in tf_set else 0 for gene in self.tf_GRN_mtx.columns])
        )

        edges_data = get_edges_data(self.tf_GRN_mtx)
        edge_features = torch.from_numpy(edges_data["Weight"].to_numpy())
        edges_src = torch.from_numpy(edges_data["Src"].to_numpy())
        edges_dst = torch.from_numpy(edges_data["Dst"].to_numpy())

        # One node per GRN column, including genes that have no edges.
        self.graph = dgl.graph(
            (edges_src, edges_dst), num_nodes=node_labels.shape[0]
        )
        self.graph.ndata["feat"] = node_features
        self.graph.ndata["label"] = node_labels
        self.graph.edata["weight"] = edge_features

    def __getitem__(self, i):
        """Return the graph; the dataset holds a single graph, so ``i`` is ignored."""
        return self.graph

    def __len__(self):
        """Return the number of graphs in the dataset, which is always 1."""
        return 1
from sklearn.metrics import (
    roc_auc_score,
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
)
from dgl.nn import SAGEConv
import dgl.nn as dglnn
import dgl.function as fn


def _load_matching_state_dict(module, path):
    """Load the checkpoint at ``path`` into ``module``.

    Tensors are mapped to CPU storage while loading. Entries of the checkpoint
    whose key is not part of ``module.state_dict()`` are ignored; keys missing
    from the checkpoint keep the module's current values.
    """
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    pretrained_dict = torch.load(path, map_location=lambda storage, loc: storage)
    model_dict = module.state_dict()
    model_dict.update(
        {k: v for k, v in pretrained_dict.items() if k in model_dict}
    )
    module.load_state_dict(model_dict)


def compute_metrics(pos_score, neg_score, thre=0.5):
    """Evaluate edge scores against positive/negative ground truth.

    Args:
        pos_score (torch.Tensor): Scores of positive edges (label 1).
        neg_score (torch.Tensor): Scores of negative edges (label 0).
        thre (float): Scores strictly greater than this value are predicted
            as 1, all others as 0. The threshold is applied to the scores as
            given; no sigmoid is applied here.

    Returns:
        tuple: ``(roc_auc, accuracy, f1, precision, recall)``. The ROC AUC is
        computed from the raw scores, the other metrics from the thresholded
        predictions.
    """
    scores = torch.cat([pos_score, neg_score]).cpu().detach().numpy()
    labels = torch.cat(
        [torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]
    ).cpu().numpy()
    predictions = scores.copy()
    predictions[scores > thre] = 1
    predictions[scores <= thre] = 0
    return (
        roc_auc_score(labels, scores),
        accuracy_score(labels, predictions),
        f1_score(labels, predictions),
        precision_score(labels, predictions),
        recall_score(labels, predictions),
    )


def compute_loss(pos_score, neg_score, device='cpu'):
    """Binary cross-entropy (with logits) over positive and negative edge scores.

    Args:
        pos_score (torch.Tensor): Logits of positive edges (target 1).
        neg_score (torch.Tensor): Logits of negative edges (target 0).
        device: Device the target tensor is moved to; it has to match the
            device of the scores.

    Returns:
        torch.Tensor: The scalar loss.
    """
    scores = torch.cat([pos_score, neg_score])
    labels = torch.cat(
        [torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]
    ).to(device)
    return F.binary_cross_entropy_with_logits(scores, labels)


# ----- models

class GraphSAGE(nn.Module):
    """Two-layer GraphSAGE encoder with mean aggregation."""

    def __init__(self, in_feats, h_feats):
        super().__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, "mean")
        self.conv2 = SAGEConv(h_feats, h_feats, "mean")

    def forward(self, g, in_feat):
        """Return node embeddings of size ``h_feats`` for graph ``g``."""
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h

    def load_model(self, path):
        """Load the parameters stored at ``path`` (unknown keys are ignored)."""
        _load_matching_state_dict(self, path)


class GCN(nn.Module):
    """Two-layer GCN with dropout (p=0.5) applied before the second layer."""

    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # two-layer GCN
        self.layers.append(
            dglnn.GraphConv(in_size, hid_size, activation=F.relu)
        )
        self.layers.append(dglnn.GraphConv(hid_size, out_size))
        self.dropout = nn.Dropout(0.5)

    def forward(self, g, features):
        """Return node embeddings of size ``out_size`` for graph ``g``."""
        h = features
        for i, layer in enumerate(self.layers):
            # Dropout is applied to the input of every layer except the first.
            if i != 0:
                h = self.dropout(h)
            h = layer(g, h)
        return h

    def load_model(self, path):
        """Load the parameters stored at ``path`` (unknown keys are ignored)."""
        _load_matching_state_dict(self, path)


class GAT(nn.Module):
    """Three-layer graph attention network.

    Args:
        in_size (int): Input feature size.
        hid_size (int): Per-head hidden size of the first two layers.
        out_size (int): Per-head output size of the last layer.
        heads: Indexable with at least three entries giving the number of
            attention heads of each layer.
    """

    def __init__(self, in_size, hid_size, out_size, heads):
        super().__init__()
        self.gat_layers = nn.ModuleList()
        # three-layer GAT
        self.gat_layers.append(
            dglnn.GATConv(in_size, hid_size, heads[0], activation=F.elu)
        )
        self.gat_layers.append(
            dglnn.GATConv(
                hid_size * heads[0],
                hid_size,
                heads[1],
                residual=True,
                activation=F.elu,
            )
        )
        self.gat_layers.append(
            dglnn.GATConv(
                hid_size * heads[1],
                out_size,
                heads[2],
                residual=True,
                activation=None,
            )
        )

    def forward(self, g, inputs):
        """Return node embeddings; heads are concatenated between layers and
        averaged after the last layer."""
        h = inputs
        last = len(self.gat_layers) - 1
        for i, layer in enumerate(self.gat_layers):
            h = layer(g, h)
            if i == last:  # last layer: average over heads
                h = h.mean(1)
            else:  # other layer(s): concatenate heads
                h = h.flatten(1)
        return h

    def load_model(self, path):
        """Load the parameters stored at ``path`` (unknown keys are ignored)."""
        _load_matching_state_dict(self, path)


class MLP(torch.nn.Module):
    def __init__(self, sizes, batch_norm=True, last_layer_act="linear"):
        """
        Multi-layer perceptron
        :param sizes: list of sizes of the layers
        :param batch_norm: whether to use batch normalization
        :param last_layer_act: activation function of the last layer
            (stored in ``self.activation``; ``forward`` does not use it)
        """
        super().__init__()
        layers = []
        for s in range(len(sizes) - 1):
            layers.append(torch.nn.Linear(sizes[s], sizes[s + 1]))
            # NOTE(review): `s < len(sizes) - 1` is always true inside this
            # loop, so batch norm also follows the last Linear layer. Left
            # unchanged so existing checkpoints keep the same state-dict keys.
            if batch_norm and s < len(sizes) - 1:
                layers.append(torch.nn.BatchNorm1d(sizes[s + 1]))
            layers.append(torch.nn.ReLU())
        # Drop the trailing ReLU so the network does not end with an activation.
        layers = layers[:-1]
        self.activation = last_layer_act
        self.network = torch.nn.Sequential(*layers)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Apply the stacked layers to ``x``."""
        return self.network(x)

    def load_model(self, path):
        """Load the parameters stored at ``path`` (unknown keys are ignored)."""
        _load_matching_state_dict(self, path)


class DotPredictor(nn.Module):
    """Scores each edge by the dot product of its endpoint embeddings."""

    def forward(self, g, h):
        """Return one score per edge of ``g`` given node embeddings ``h``."""
        with g.local_scope():
            g.ndata["h"] = h
            # Compute a new edge feature named 'score' by a dot-product between the
            # source node feature 'h' and destination node feature 'h'.
            g.apply_edges(fn.u_dot_v("h", "h", "score"))
            # u_dot_v returns a 1-element vector for each edge so you need to squeeze it.
            return g.edata["score"][:, 0]

    def load_model(self, path):
        """Load the state stored at ``path`` (unknown keys are ignored)."""
        _load_matching_state_dict(self, path)


class MLPPredictor(nn.Module):
    """Scores each edge with a two-layer MLP over the concatenated endpoint embeddings."""

    def __init__(self, h_feats):
        super().__init__()
        self.W1 = nn.Linear(h_feats * 2, h_feats)
        self.W2 = nn.Linear(h_feats, 1)

    def apply_edges(self, edges):
        """
        Computes a scalar score for each edge of the given graph.

        Parameters
        ----------
        edges :
            Has three members ``src``, ``dst`` and ``data``, each of
            which is a dictionary representing the features of the
            source nodes, the destination nodes, and the edges
            themselves.

        Returns
        -------
        dict
            A dictionary of new edge features.
        """
        h = torch.cat([edges.src["h"], edges.dst["h"]], 1)
        return {"score": self.W2(F.relu(self.W1(h))).squeeze(1)}

    def forward(self, g, h):
        """Return one score per edge of ``g`` given node embeddings ``h``."""
        with g.local_scope():
            g.ndata["h"] = h
            g.apply_edges(self.apply_edges)
            return g.edata["score"]

    def load_model(self, path):
        """Load the parameters stored at ``path`` (unknown keys are ignored)."""
        _load_matching_state_dict(self, path)


class EarlyStopping:
    """Early stops the training if loss doesn't improve after a given patience."""

    def __init__(self, patience=10, verbose=False, checkpoint_file_model='', checkpoint_file_pred=''):
        """
        Args:
            patience (int): How long to wait after last time loss improved.
                            Default: 10
            verbose (bool): If True, prints a message for each loss improvement.
                            Default: False
            checkpoint_file_model (str): File the model state dict is saved to.
            checkpoint_file_pred (str): File the predictor state dict is saved to.
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf rather than the np.Inf alias, which NumPy 2.0 removed.
        self.loss_min = np.inf
        self.checkpoint_file_model = checkpoint_file_model
        self.checkpoint_file_pred = checkpoint_file_pred

    def __call__(self, loss, model, pred):
        """Record ``loss`` for the current epoch and update ``early_stop``.

        An improved loss saves both state dicts and resets the counter. A
        non-improving loss increments the counter; once it reaches
        ``patience`` the best checkpoint is loaded back into ``model`` and
        ``pred`` and ``early_stop`` is set. A NaN loss stops immediately.
        """
        if np.isnan(loss):
            # A NaN loss must never be recorded as an improvement, otherwise
            # the best checkpoint would be overwritten with the current
            # weights. Stop, and fall back to the best checkpoint if one was
            # already written.
            self.early_stop = True
            if self.best_score is not None:
                model.load_model(self.checkpoint_file_model)
                pred.load_model(self.checkpoint_file_pred)
            return

        score = -loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(loss, model, pred)
        elif score <= self.best_score:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
                model.load_model(self.checkpoint_file_model)
                pred.load_model(self.checkpoint_file_pred)
        else:
            self.best_score = score
            self.save_checkpoint(loss, model, pred)
            self.counter = 0

    def save_checkpoint(self, loss, model, pred):
        """
        Saves model when loss decrease
        """
        if self.verbose:
            print(f'Loss decreased ({self.loss_min:.6f} --> {loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.checkpoint_file_model)
        torch.save(pred.state_dict(), self.checkpoint_file_pred)
        self.loss_min = loss