Example #1
class Trainer:
    def __init__(self, params):
        self.params = params

        self.device = torch.device('cpu' if params.gpu == -1
                                   else f'cuda:{params.gpu}')
        print(self.device)
        # self.log_dir = get_dump_path(params)

        self.batch_size = params.batch_size

        self.load_pretrained_model = params.load_pretrained_model
        self.pretrained_model_path = params.pretrained_model_path
        self.save_model_path = params.save_model_path

        self.using_mmd = params.using_mmd

        # data
        (self.num_cells, self.num_genes, self.num_classes, self.graph,
         self.features, self.train_dataset, self.train_mask, self.vali_mask,
         self.test_dataset) = load_PPP_mammary_gland(params)
        # self.vae = torch.load('./saved_model/vae.pkl', self.features.device)
        # self.features = self.vae.get_hidden(self.features)
        # model
        self.model = GraphSAGE(in_feats=params.dense_dim,
                               n_hidden=params.hidden_dim,
                               n_classes=self.num_classes,
                               n_layers=params.n_layers,
                               activation=F.relu,
                               dropout=params.dropout,
                               aggregator_type=params.aggregator_type,
                               num_genes=self.num_genes)
        # self.model = GCN(
        #                  in_feats=params.dense_dim,
        #                  n_hidden=params.hidden_dim,
        #                  n_classes=self.num_classes,
        #                  n_layers=params.n_layers,
        #                  activation=F.relu)
        # self.model = GAT(in_feats=params.dense_dim,
        #                  n_hidden=100,
        #                  n_classes=self.num_classes,
        #                  n_layers=params.n_layers,
        #                  activation=F.relu)
        self.graph.readonly(readonly_state=True)
        self.loss_fn = nn.BCEWithLogitsLoss().to(self.device)
        self.model.to(self.device)
        self.features = self.features.to(self.device)
        self.train_mask = self.train_mask.to(self.device)
        self.vali_mask = self.vali_mask.to(self.device)
        self.train_dataset = self.train_dataset.to(self.device)
        self.trainset = TrainSet(self.train_dataset[self.train_mask])
        self.test_dataset = self.test_dataset.to(self.device)
        self.train_dataloader = DataLoader(self.trainset,
                                           batch_size=self.batch_size,
                                           shuffle=True,
                                           drop_last=True)
        # note: shuffle/drop_last on the test loader silently discards the
        # final partial batch at evaluation time
        self.test_dataloader = DataLoader(self.test_dataset,
                                          batch_size=self.batch_size,
                                          shuffle=True,
                                          drop_last=True)
        self.loss_weight = torch.Tensor(params.loss_weight).to(self.device)
Example #2
def build_model(args, in_feats, n_hidden, n_classes, device, n_layers=1):
    if args.model == 'gcn_cv_sc':
        infer_device = torch.device("cpu")  # for sampling
        train_model = GCNSampling(in_feats, n_hidden, n_classes, 2, F.relu,
                                  args.dropout).to(device)
        infer_model = GCNInfer(in_feats, n_hidden, n_classes, 2, F.relu)
        model = (train_model, infer_model)
    elif args.model == 'gs-mean':
        model = GraphSAGE(in_feats, n_hidden, n_classes, n_layers, F.relu,
                          args.dropout, 'mean').to(device)
    elif args.model == 'mlp':
        model = MLP(in_feats, n_hidden, n_classes, n_layers, F.relu,
                    args.dropout).to(device)
    elif args.model == 'mostfrequent':
        model = MostFrequentClass()
    # elif args.model == 'egcn':
    #     if n_layers != 2:
    #         print("Warning, EGCN doesn't respect n_layers")
    #     egcn_args = egcn_utils.Namespace({'feats_per_node': in_feats,
    #                                       'layer_1_feats': n_hidden,
    #                                       'layer_2_feats': n_classes})
    #     model = EGCN(egcn_args, torch.nn.RReLU(), device=device, skipfeats=False)
    elif args.model == 'gat':
        print("Warning, GAT doesn't respect n_layers")
        heads = [8, args.gat_out_heads]  # Fixed head config
        # Div num_hidden by heads for same capacity
        n_hidden_per_head = int(n_hidden / heads[0])
        assert n_hidden_per_head * heads[
            0] == n_hidden, f"{n_hidden} not divisible by {heads[0]}"
        model = GAT(1, in_feats, n_hidden_per_head, n_classes, heads, F.elu,
                    0.6, 0.6, 0.2, False).to(device)
    else:
        raise NotImplementedError("Model not implemented")

    return model
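
A minimal sketch of driving build_model above, assuming an argparse-style namespace; the flag names come from the branches the function reads, and every value is a placeholder:

# Hypothetical invocation of build_model; all values are illustrative only.
from types import SimpleNamespace

import torch

args = SimpleNamespace(model='gs-mean', dropout=0.5, gat_out_heads=1)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = build_model(args, in_feats=128, n_hidden=64, n_classes=10,
                    device=device, n_layers=2)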
Example #3
class Trainer:
    def __init__(self, params):
        self.params = params
        self.train_device = torch.device('cpu' if params.use_cpu else 'cuda:0')
        self.test_device = torch.device('cpu' if params.use_cpu else 'cuda:0')
        # self.log_dir = get_dump_path(params)

        # data
        (self.num_cells, self.num_genes, self.graph, self.features,
         self.labels, self.train_mask, self.test_mask) = load_tissue(params)
        # model
        self.model = GraphSAGE(self.graph,
                               in_feats=params.dense_dim,
                               n_hidden=params.hidden_dim,
                               n_classes=params.n_classes,
                               n_layers=params.n_layers,
                               activation=F.relu,
                               dropout=params.dropout,
                               aggregator_type=params.aggregator_type)

    def train(self):
        self.model.train()
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.params.lr)
        loss_fn = nn.CrossEntropyLoss()

        # move the model and data to the training device once, before the loop
        self.model.to(self.train_device)
        self.features = self.features.to(self.train_device)
        self.train_mask = self.train_mask.to(self.train_device)
        self.test_mask = self.test_mask.to(self.train_device)
        self.labels = self.labels.to(self.train_device)

        for epoch in range(self.params.n_epochs):
            # forward
            logits = self.model(self.features)
            loss = loss_fn(logits[self.num_genes:][self.train_mask],
                           self.labels[self.train_mask])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # acc = self.evaluate(self.train_mask)
            # print("Train Accuracy {:.4f}".format(acc))
            _, _, train_acc = self.evaluate(self.train_mask)
            c, t, test_acc = self.evaluate(self.test_mask)
            if epoch % 20 == 0:
                print(
                    f"Epoch {epoch:04d}: Acc {train_acc:.4f} / {test_acc:.4f}, Loss {loss:.4f}, [{c}/{t}]"
                )

    def evaluate(self, mask):
        self.model.eval()
        with torch.no_grad():
            logits = self.model(self.features)
            logits = logits[self.num_genes:][mask]
            labels = self.labels[mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels).item()
        total = mask.long().sum().item()
        return correct, total, correct / total
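
A hedged driver for the Trainer above; the params fields mirror the attributes __init__ and train() read, and all values (including dense_dim, which must match the feature width load_tissue returns) are placeholders:

from types import SimpleNamespace

# Illustrative configuration only; field names follow the code above.
params = SimpleNamespace(use_cpu=True, dense_dim=400, hidden_dim=200,
                         n_classes=16, n_layers=2, dropout=0.1,
                         aggregator_type='mean', lr=1e-3, n_epochs=100)
trainer = Trainer(params)
trainer.train()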
Example #4
class Trainer:
    def __init__(self, params):
        self.params = params
        self.train_device = torch.device('cpu' if params.use_cpu else 'cuda:0')
        self.test_device = torch.device('cpu' if params.use_cpu else 'cuda:0')
        # self.log_dir = get_dump_path(params)

        # data
        (self.num_cells, self.num_genes, self.graph, self.features,
         self.labels, self.train_mask, self.test_mask) = load_tissue(params)
        # model
        self.model = GraphSAGE(self.graph,
                               in_feats=params.dense_dim,
                               n_hidden=params.hidden_dim,
                               n_classes=params.n_classes,
                               n_layers=params.n_layers,
                               activation=F.relu,
                               dropout=params.dropout,
                               aggregator_type=params.aggregator_type)
Example #5
def build_model(args,
                in_feats,
                n_hidden,
                n_classes,
                device,
                n_layers=1,
                backend='geometric',
                edge_index=None,
                num_nodes=None):
    if backend == 'geometric':
        if args.model == 'gs-mean':
            model = geo.GraphSAGE(in_feats, n_hidden, n_classes, n_layers,
                                  F.relu, args.dropout).to(device)
        elif args.model == "gcn":
            model = geo.GCN(in_feats, n_hidden, n_classes, n_layers, F.relu,
                            args.dropout).to(device)
        elif args.model == "gat":
            print("Warning, GAT doesn't respect n_layers")
            heads = [8, args.gat_out_heads]  # Fixed head config
            n_hidden_per_head = int(n_hidden / heads[0])
            model = geo.GAT(in_feats, n_hidden_per_head, n_classes, F.relu,
                            args.dropout, 0.6, heads).to(device)
        elif args.model == "mlp":
            model = geo.MLP(in_feats, n_hidden, n_classes, n_layers, F.relu,
                            args.dropout).to(device)
        elif args.model == "jknet":
            model = geo.JKNet(in_feats, n_hidden, n_classes, n_layers, F.relu,
                              args.dropout).to(device)
        elif args.model == "sgnet":
            model = SGNet(in_channels=in_feats,
                          out_channels=n_classes,
                          K=n_layers).to(device)
        else:
            raise NotImplementedError
    else:
        if args.model == 'gs-mean':
            model = GraphSAGE(in_feats, n_hidden, n_classes, n_layers, F.relu,
                              args.dropout, 'mean').to(device)
        elif args.model == 'mlp':
            model = MLP(in_feats, n_hidden, n_classes, n_layers, F.relu,
                        args.dropout).to(device)
        elif args.model == 'mostfrequent':
            model = MostFrequentClass()
        elif args.model == 'gat':
            print("Warning, GAT doesn't respect n_layers")
            heads = [8, args.gat_out_heads]  # Fixed head config
            # Div num_hidden by heads for same capacity
            n_hidden_per_head = int(n_hidden / heads[0])
            assert n_hidden_per_head * heads[
                0] == n_hidden, f"{n_hidden} not divisible by {heads[0]}"
            model = GAT(1, in_feats, n_hidden_per_head, n_classes, heads,
                        F.elu, 0.6, 0.6, 0.2, False).to(device)
        else:
            raise NotImplementedError("Model not implemented")

    return model
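
A sketch of selecting a backend with this build_model; the namespace carries only the fields the chosen branches read, and all sizes are invented:

from types import SimpleNamespace

import torch

args = SimpleNamespace(model='gcn', dropout=0.5, gat_out_heads=1)
device = torch.device('cpu')
model = build_model(args, in_feats=64, n_hidden=32, n_classes=7,
                    device=device, n_layers=2, backend='geometric')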
Example #6
def main(args):
    start_time_str = time.strftime("_%m_%d_%H_%M_%S", time.localtime())
    log_path = os.path.join(args.log_dir, args.model + start_time_str)
    os.makedirs(log_path, exist_ok=True)

    logging.basicConfig(filename=os.path.join(log_path, 'log_file'),
                        filemode='w',
                        format='| %(asctime)s |\n%(message)s',
                        datefmt='%b %d %H:%M:%S',
                        level=logging.INFO)
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))

    logging.info(args)
    # load data
    if args.model in ['EGNN', 'ECConv', 'GTEA-ST']:
        data = Dataset(data_dir=args.data_dir, batch_size=args.batch_size, use_static=True)
    else:
        data = Dataset(data_dir=args.data_dir, batch_size=args.batch_size)

    g = data.g

    # features = torch.FloatTensor(data.features)
    # labels = torch.LongTensor(data.labels)

    train_loader = data.train_loader
    val_loader = data.val_loader
    test_loader = data.test_loader

    num_nodes = data.num_nodes
    node_in_dim = args.node_in_dim

    num_edges = data.num_edges
    edge_in_dim = data.edge_in_dim
    edge_timestep_len = data.edge_timestep_len    

    num_train_samples = data.num_train_samples
    num_val_samples = data.num_val_samples
    num_test_samples = data.num_test_samples

    logging.info("""----Data statistics------'
      #Nodes %d
      #Edges %d
      #Node_feat %d
      #Edge_feat %d
      #Edge_timestep %d
      #Train samples %d
      #Val samples %d
      #Test samples %d""" %
          (num_nodes, num_edges, 
           node_in_dim, edge_in_dim, edge_timestep_len,
              num_train_samples,
              num_val_samples,
              num_test_samples))
    


    device = torch.device("cuda:"+str(args.gpu) if torch.cuda.is_available() and args.gpu >=0 else "cpu")
    infer_device = device if args.infer_gpu else torch.device('cpu')

    # g = g.to(device)

    # create model
    if args.model == 'GCN':
        model = GCN(num_nodes=num_nodes,
                    in_feats=node_in_dim,
                    n_hidden=args.node_hidden_dim,
                    n_layers=args.num_layers,
                    activation=F.relu,
                    dropout=args.dropout)
    elif args.model == 'GraphSAGE':
        model = GraphSAGE(num_nodes=num_nodes,
                          in_feats=node_in_dim,
                          n_hidden=args.node_hidden_dim,
                          n_layers=args.num_layers,
                          activation=F.relu,
                          dropout=args.dropout)
    elif args.model == 'GAT':
        model = GAT(num_nodes=num_nodes,
                    in_dim=node_in_dim,
                    hidden_dim=args.node_hidden_dim,
                    num_layers=args.num_layers,
                    num_heads=args.num_heads)
    elif args.model == 'ECConv':
        model = ECConv(num_nodes=num_nodes,
                       node_in_dim=node_in_dim,
                       edge_in_dim=edge_in_dim,
                       hidden_dim=args.node_hidden_dim,
                       num_layers=args.num_layers,
                       drop_prob=args.dropout,
                       device=device)
    elif args.model == 'EGNN':
        model = EGNN(num_nodes=num_nodes,
                     node_in_dim=node_in_dim,
                     edge_in_dim=edge_in_dim,
                     hidden_dim=args.node_hidden_dim,
                     num_layers=args.num_layers,
                     drop_prob=args.dropout,
                     device=device)
    elif args.model == 'GTEA-ST':
        model = GTEAST(num_nodes=num_nodes,
                       node_in_dim=node_in_dim,
                       edge_in_dim=edge_in_dim,
                       node_hidden_dim=args.node_hidden_dim,
                       num_layers=args.num_layers,
                       drop_prob=args.dropout,
                       device=device)
    elif args.model == 'TGAT':
        model = TGAT(num_nodes=num_nodes,
                     node_in_dim=node_in_dim,
                     node_hidden_dim=args.node_hidden_dim,
                     edge_in_dim=edge_in_dim - 1,
                     time_hidden_dim=args.time_hidden_dim,
                     num_class=0,
                     num_layers=args.num_layers,
                     num_heads=args.num_heads,
                     device=device,
                     drop_prob=args.dropout)
    elif args.model == 'GTEA-LSTM':
        model = GTEALSTM(num_nodes=num_nodes,
                         node_in_dim=node_in_dim,
                         node_hidden_dim=args.node_hidden_dim,
                         edge_in_dim=edge_in_dim,
                         num_class=0,
                         num_layers=args.num_layers,
                         num_time_layers=args.num_lstm_layers,
                         bidirectional=args.bidirectional,
                         device=device,
                         drop_prob=args.dropout)
    elif args.model == 'GTEA-LSTM+T2V':
        model = GTEALSTMT2V(num_nodes=num_nodes,
                            node_in_dim=node_in_dim,
                            node_hidden_dim=args.node_hidden_dim,
                            edge_in_dim=edge_in_dim - 1,
                            time_hidden_dim=args.time_hidden_dim,
                            num_class=0,
                            num_layers=args.num_layers,
                            num_time_layers=args.num_lstm_layers,
                            bidirectional=args.bidirectional,
                            device=device,
                            drop_prob=args.dropout)
    elif args.model == 'GTEA-Trans':
        model = GTEATrans(num_nodes=num_nodes,
                          node_in_dim=node_in_dim,
                          node_hidden_dim=args.node_hidden_dim,
                          edge_in_dim=edge_in_dim,
                          num_class=0,
                          num_layers=args.num_layers,
                          num_heads=args.num_heads,
                          num_time_layers=args.num_lstm_layers,
                          device=device,
                          drop_prob=args.dropout)
    elif args.model == 'GTEA-Trans+T2V':
        model = GTEATransT2V(num_nodes=num_nodes,
                             node_in_dim=node_in_dim,
                             node_hidden_dim=args.node_hidden_dim,
                             edge_in_dim=edge_in_dim - 1,
                             time_hidden_dim=args.time_hidden_dim,
                             num_class=0,
                             num_layers=args.num_layers,
                             num_heads=args.num_heads,
                             num_time_layers=args.num_lstm_layers,
                             device=device,
                             drop_prob=args.dropout)
    else:
        logging.info('Model {} not found.'.format(args.model))
        sys.exit(1)

    # send model to device
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    checkpoint_path = os.path.join(log_path, str(args.model) + '_checkpoint.pt')
    
    trainer = Trainer(g=g,
                     model=model, 
                     optimizer=optimizer, 
                     epochs=args.epochs, 
                     train_loader=train_loader, 
                     val_loader=val_loader, 
                     test_loader=test_loader,
                     patience=args.patience, 
                     batch_size=args.batch_size,
                     num_neighbors=args.num_neighbors, 
                     num_layers=args.num_layers, 
                     num_workers=args.num_workers, 
                     device=device,
                     infer_device=infer_device, 
                     log_path=log_path,
                     checkpoint_path=checkpoint_path)

    logging.info('Start training')

    best_val_result, test_result = trainer.train()

    # recording the result
    line = [start_time_str[1:], args.model, 'K=' + str(args.use_K)] + \
           [str(x) for x in best_val_result] + \
           [str(x) for x in test_result] + [str(args)]
    line = ','.join(line) + '\n'

    with open(os.path.join(args.log_dir, str(args.model) + '_result.csv'), 'a') as f:
        f.write(line)
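
A minimal entry-point sketch for this script; only a few of the flags main() reads are shown (the remaining ones, e.g. node_in_dim, num_layers, lr, epochs, would be added the same way), and the defaults are illustrative:

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    # Subset of the attributes main() reads; defaults are placeholders.
    parser.add_argument('--model', type=str, default='GCN')
    parser.add_argument('--data_dir', type=str, default='./data')
    parser.add_argument('--log_dir', type=str, default='./logs')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--gpu', type=int, default=-1)
    parser.add_argument('--infer_gpu', type=int, default=0)
    main(parser.parse_args())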
Example #7
class Trainer:
    def __init__(self, params):
        self.params = params
        
        self.device = torch.device('cpu' if self.params.gpu == -1 else f'cuda:{params.gpu}')
        print(self.device)
        # self.log_dir = get_dump_path(params) 

        self.batch_size = params.batch_size

        self.load_pretrained_model = params.load_pretrained_model
        self.pretrained_model_path = params.pretrained_model_path
        self.save_model_path = params.save_model_path

        self.using_mmd = params.using_mmd

        # data
        (self.num_cells, self.num_genes, self.num_classes, self.graph,
         self.features, self.train_dataset, self.train_mask, self.vali_mask,
         self.test_dataset) = load_PPP_mammary_gland(params)
        # self.vae = torch.load('./saved_model/vae.pkl', self.features.device)
        # self.features = self.vae.get_hidden(self.features)
        # model
        self.model = GraphSAGE(in_feats=params.dense_dim,
                               n_hidden=params.hidden_dim,
                               n_classes=self.num_classes,
                               n_layers=params.n_layers,
                               activation=F.relu,
                               dropout=params.dropout,
                               aggregator_type=params.aggregator_type,
                               num_genes=self.num_genes)
        # self.model = GCN(
        #                  in_feats=params.dense_dim,
        #                  n_hidden=params.hidden_dim,
        #                  n_classes=self.num_classes,
        #                  n_layers=params.n_layers,
        #                  activation=F.relu)
        # self.model = GAT(in_feats=params.dense_dim,
        #                  n_hidden=100,
        #                  n_classes=self.num_classes,
        #                  n_layers=params.n_layers,
        #                  activation=F.relu)
        self.graph.readonly(readonly_state=True)
        self.model.to(self.device)
        self.features = self.features.to(self.device)
        self.train_mask = self.train_mask.to(self.device)
        self.vali_mask = self.vali_mask.to(self.device)
        self.train_dataset = self.train_dataset.to(self.device)
        self.trainset = TrainSet(self.train_dataset[self.train_mask])
        self.test_dataset = self.test_dataset.to(self.device)
        self.train_dataloader = DataLoader(self.trainset, batch_size=self.batch_size, shuffle=True, drop_last=True)
        self.test_dataloader = DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True)
        self.loss_weight = torch.Tensor(params.loss_weight).to(self.device)

    def train(self):
        if self.load_pretrained_model:
            print(f'load model from {self.pretrained_model_path}')
            self.model.load_state_dict(torch.load(self.pretrained_model_path))
        self.model.train()
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.params.lr)
        loss_fn = nn.CrossEntropyLoss(weight=self.loss_weight)

        ll_loss = 1e5
        print("start train")
        for epoch in range(self.params.n_epochs):
            self.model.train()
            for step, (batch_x1, batch_x2, batch_y) in enumerate(self.train_dataloader):
                # list_tar = list(enumerate(self.test_dataloader))

                logits = self.model(self.graph, self.features, batch_x1, batch_x2)
                # import pdb; pdb.set_trace()
                loss = loss_fn(logits, batch_y)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
            _, _, train_loss = self.evaluate(self.train_mask)
            precision, recall, vali_loss = self.evaluate(self.vali_mask)
                
            # if vali_loss < ll_loss:
            #     torch.save(self.model.state_dict(), self.save_model_path)
            #     ll_loss = vali_loss
            if train_loss < ll_loss:
                torch.save(self.model.state_dict(), self.save_model_path)
                ll_loss = train_loss

            if epoch % 1 == 0:  # evaluate every epoch; raise the modulus to print less often
                precision, recall, train_loss = self.evaluate(self.train_mask)
                print(f"Epoch {epoch:04d}: precision {precision:.5f}, recall {recall:.5f}, train loss: {train_loss}")
                if self.params.just_train == 0:
                    precision, recall, vali_loss = self.evaluate(self.vali_mask)
                    print(f"Epoch {epoch:04d}: precision {precision:.5f}, recall {recall:.5f}, vali loss: {vali_loss}")
                    precision, recall, test_loss = self.test(self.test_dataset)
                    print(f"Epoch {epoch:04d}: precision {precision:.5f}, recall {recall:.5f}, test loss: {test_loss}")

    def evaluate(self, mask):
        self.model.eval()
        eval_dataset = self.train_dataset[mask]
        loss_fn = nn.CrossEntropyLoss(self.loss_weight)
        with torch.no_grad():
            logits = self.model(self.graph, self.features, eval_dataset[:, 0], eval_dataset[:, 1])
            loss = loss_fn(logits, eval_dataset[:, 2])
        _, indices = torch.max(logits, dim=1)
        ap_score = average_precision_score(eval_dataset[:,2].tolist(), indices.tolist())
        precision, recall, f1_score, _ = sklearn.metrics.precision_recall_fscore_support(eval_dataset[:,2].tolist(), indices.tolist(), labels=[0,1])
        return precision[1], recall[1], loss

    def test(self, test_dataset):
        self.model.eval()
        eval_dataset = test_dataset
        loss_fn = nn.CrossEntropyLoss(self.loss_weight)
        with torch.no_grad():
            logits = self.model(self.graph, self.features, eval_dataset[:, 0], eval_dataset[:, 1])
            loss = loss_fn(logits, eval_dataset[:, 2])
        _, indices = torch.max(logits, dim=1)
        precision, recall, f1_score, _ = sklearn.metrics.precision_recall_fscore_support(eval_dataset[:,2].tolist(), indices.tolist(), labels=[0,1])
        return precision[1], recall[1], loss
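
A hedged driver for this Trainer; the field names follow the attributes read in __init__, train() and evaluate() above, and every value is a placeholder:

from types import SimpleNamespace

# Illustrative configuration only.
params = SimpleNamespace(gpu=-1, batch_size=256, dense_dim=400,
                         hidden_dim=200, n_layers=2, dropout=0.1,
                         aggregator_type='mean', lr=1e-3, n_epochs=50,
                         loss_weight=[1.0, 1.0], using_mmd=False,
                         load_pretrained_model=False,
                         pretrained_model_path='',
                         save_model_path='./saved_model/model.pkl',
                         just_train=0)
trainer = Trainer(params)
trainer.train()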
Example #8
                        type=int,
                        help='GraphSAGE hidden dimension')
    parser.add_argument('--drop',
                        dest="dropout",
                        default=0.5,
                        type=float,
                        help='Dropout for full connected layer')

    args = parser.parse_args()

    # set seed
    set_seed(args)

    # set logger
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)

    # set cuda
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    # load data
    data, num_classes = load_data(args)
    data.to(args.device)

    # train the model
    model = GraphSAGE(data, num_classes, args)
    model.to(args.device)
    train(args, data, model)
def build_model(args,
                in_feats,
                n_hidden,
                n_classes,
                device,
                n_layers=1,
                backend='geometric'):
    if args.model == 'graphsaint':
        assert backend == 'geometric'
        model_spec = args.variant
    else:
        model_spec = args.model

    if backend == 'geometric':
        print("Using Geometric Backend")
        if model_spec == 'gs-mean':
            model = geo.GraphSAGE(in_feats, n_hidden, n_classes, n_layers,
                                  F.relu, args.dropout).to(device)
        elif model_spec == "gcn":
            model = geo.GCN(in_feats, n_hidden, n_classes, n_layers, F.relu,
                            args.dropout).to(device)
        elif model_spec == "gat":
            print("Warning, GAT doesn't respect n_layers")
            heads = [8, args.gat_out_heads]  # Fixed head config
            n_hidden_per_head = int(n_hidden / heads[0])
            model = geo.GAT(in_feats, n_hidden_per_head, n_classes, F.relu,
                            args.dropout, 0.6, heads).to(device)
        elif model_spec == "mlp":
            model = geo.MLP(in_feats, n_hidden, n_classes, n_layers, F.relu,
                            args.dropout).to(device)
        elif model_spec == 'jknet-sageconv':
            # Geometric JKNet with SAGEConv
            model = JKNet(tg.nn.SAGEConv,
                          in_feats,
                          n_hidden,
                          n_classes,
                          n_layers,
                          F.relu,
                          args.dropout,
                          mode="cat",
                          conv_kwargs={
                              "normalize": False
                          },
                          backend="geometric").to(device)
        elif model_spec == 'jknet-graphconv':
            model = JKNet(tg.nn.GraphConv,
                          in_feats,
                          n_hidden,
                          n_classes,
                          n_layers,
                          F.relu,
                          args.dropout,
                          mode="cat",
                          conv_kwargs={
                              "aggr": "mean"
                          },
                          backend="geometric").to(device)
        elif model_spec == "sgnet":
            model = geo.SGNet(in_channels=in_feats,
                              out_channels=n_classes,
                              K=n_layers,
                              cached=True).to(device)
        else:
            raise NotImplementedError(
                f"Unknown model spec '{model_spec}' for backend {backend}")
    elif backend == 'dgl':  # DGL models
        if model_spec == 'gs-mean':
            model = GraphSAGE(in_feats, n_hidden, n_classes, n_layers, F.relu,
                              args.dropout, 'mean').to(device)
        elif model_spec == 'mlp':
            model = MLP(in_feats, n_hidden, n_classes, n_layers, F.relu,
                        args.dropout).to(device)
        elif model_spec == 'mostfrequent':
            model = MostFrequentClass()
        elif model_spec == 'gat':
            print("Warning, GAT doesn't respect n_layers")
            heads = [8, args.gat_out_heads]  # Fixed head config
            # Div num_hidden by heads for same capacity
            n_hidden_per_head = int(n_hidden / heads[0])
            assert n_hidden_per_head * heads[
                0] == n_hidden, f"{n_hidden} not divisible by {heads[0]}"
            model = GAT(1, in_feats, n_hidden_per_head, n_classes, heads,
                        F.elu, 0.6, 0.6, 0.2, False).to(device)
        elif model_spec == 'node2vec':
            raise NotImplementedError(
                "Node2vec initializer needs to move to different location")
            # model = tg.nn.Node2Vec(
            #     edge_index,
            #     n_hidden,
            #     args.n2v_walk_length,
            #     args.n2v_context_size,
            #     walks_per_node=args.n2v_walks_per_node,
            #     p=args.n2v_p,
            #     q=args.n2v_q,
            #     num_negative_samples=args.n2v_num_negative_samples,
            #     num_nodes=num_nodes,
            #     sparse=True
            # )
        elif model_spec == 'jknet-sageconv':
            # DGL JKNet
            model = JKNet(dgl.nn.pytorch.SAGEConv,
                          in_feats,
                          n_hidden,
                          n_classes,
                          n_layers,
                          F.relu,
                          args.dropout,
                          mode="cat",
                          conv_args=["mean"],
                          backend='dgl').to(device)
        elif model_spec == 'sgnet':
            model = SGNet(in_feats,
                          n_classes,
                          k=n_layers,
                          cached=True,
                          bias=True,
                          norm=None).to(device)
        else:
            raise NotImplementedError(
                f"Unknown model spec '{model_spec}' for backend {backend}")
    else:
        raise NotImplementedError(f"Unknown backend: {backend}")

    return model
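
A sketch of exercising both backends of this build_model; args.variant is only consulted when args.model == 'graphsaint', and all values here are invented:

from types import SimpleNamespace

import torch

args = SimpleNamespace(model='gs-mean', variant=None, dropout=0.5,
                       gat_out_heads=1)
device = torch.device('cpu')
geo_model = build_model(args, in_feats=64, n_hidden=32, n_classes=7,
                        device=device, n_layers=2, backend='geometric')
dgl_model = build_model(args, in_feats=64, n_hidden=32, n_classes=7,
                        device=device, n_layers=2, backend='dgl')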