Exemplo n.º 1
0
    def __init__(self, args):
        """Prepare the dataset, model and optimizer for a training run.

        Args:
            args: parsed options namespace. Attributes read here: dataset,
                cuda, cuda_num (only when cuda is truthy), miss_rate,
                type_model, epochs, grad_clip, weight_decay, lr,
                random_seed, type_norm and skip_weight.

        Raises:
            Exception: if ``args.dataset`` is not one of Cora / Citeseer /
                Pubmed / CoauthorCS, or ``args.type_model`` is not one of
                GCN / simpleGCN / GAT.
        """
        self.dataset = args.dataset
        if args.cuda:
            self.device = torch.device(f'cuda:{args.cuda_num}')
        else:
            self.device = torch.device('cpu')

        # Only these datasets have a loader here; all of them use NLL loss.
        if self.dataset not in ("Cora", "Citeseer", "Pubmed", "CoauthorCS"):
            raise Exception(f'the dataset of {self.dataset} has not been implemented')
        self.data = load_data(self.dataset)
        self.loss_fn = torch.nn.functional.nll_loss

        self.miss_rate = args.miss_rate
        if self.miss_rate > 0.0:
            # Node features are replaced by whatever remove_feature returns
            # (presumably a copy with a miss_rate fraction removed -- confirm).
            self.data.x = remove_feature(self.data, self.miss_rate)

        self.type_model = args.type_model
        self.epochs = args.epochs
        self.grad_clip = args.grad_clip
        self.weight_decay = args.weight_decay

        # Resolve the model class first, then build it with the full args namespace.
        model_name = self.type_model
        if model_name == 'GCN':
            model_cls = GCN
        elif model_name == 'simpleGCN':
            model_cls = simpleGCN
        elif model_name == 'GAT':
            model_cls = GAT
        else:
            raise Exception(f'the model of {self.type_model} has not been implemented')
        self.model = model_cls(args)

        # The return value of data.to() is discarded, so this relies on it
        # moving the tensors in place (true for PyG Data objects -- confirm).
        self.data.to(self.device)
        self.model.to(self.device)

        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
        )
        self.seed = args.random_seed
        self.type_norm = args.type_norm
        self.skip_weight = args.skip_weight
Exemplo n.º 2
0
 def __init__(self, input_dim, hidden, output_dim):
     """Create the block's GAT and residual sub-modules.

     Args:
         input_dim: input size of the ``residual`` linear layer.
         hidden: hidden size handed to the GAT.
         output_dim: output size of both the GAT and the ``residual`` layer;
             also used as the GAT's input size.
     """
     super().__init__()
     self.input_dim, self.output_dim = input_dim, output_dim
     # NOTE(review): the GAT's input size is output_dim, not input_dim, so
     # forward() presumably feeds it already-projected features -- confirm.
     self.homogeneousGAT = GAT(output_dim, hidden, output_dim, nheads=3, conv=False)
     # Assigned after the GAT on purpose: nn.Module registration order fixes
     # the parameters()/state_dict() order, so keep it unchanged.
     self.residual = nn.Linear(input_dim, output_dim)
Exemplo n.º 3
0
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Load data
adj, features, labels, idx_train, idx_val, idx_test = load_data()

# Model: 'GCN' selects GCN; any other value of args.model selects GAT.
if args.model == 'GCN':
    model = GCN(nfeat=features.shape[1],
                nhid=args.hidden,
                nclass=labels.max().item() + 1,
                dropout=args.dropout)
else:
    model = GAT(nfeat=features.shape[1],
                nhid=args.hidden,
                nclass=labels.max().item() + 1,
                dropout=args.dropout,
                num_head=args.num_head)

# Move the model (and data) to the GPU *before* building the optimizer, as the
# torch.optim documentation recommends, so the optimizer is constructed over
# the device-resident parameters.
if args.cuda:
    model.cuda()
    features = features.cuda()
    adj = adj.cuda()
    labels = labels.cuda()
    idx_train = idx_train.cuda()
    idx_val = idx_val.cuda()
    idx_test = idx_test.cuda()

# Optimizer
optimizer = optim.Adam(model.parameters(),
                       lr=args.lr, weight_decay=args.weight_decay)
Exemplo n.º 4
0
def choose_model(conf):
    """Build the model selected by ``conf['model_name']`` on ``conf['device']``.

    NOTE(review): the graph ``G`` and the ``features`` / ``labels`` tensors are
    read from module-level globals rather than passed in, so they must exist
    before this function is called.

    Args:
        conf: configuration dict. ``'model_name'`` selects the architecture and
            ``'device'`` is used by every supported model. Other keys are read
            only by the branch that needs them: ``'hidden'``/``'dropout'`` (GCN),
            ``'embed_dim'``/``'agg_type'`` (GraphSAGE), and ``'dataset'``,
            ``'layer'``, ``'hidden'``, ``'lamda'``, ``'dropout'``, ``'alpha'``
            (GCNII).

    Returns:
        The constructed model, already moved to ``conf['device']``.

    Raises:
        ValueError: if ``conf['model_name']`` is not a supported model name.
            Previously an unknown name fell through every branch and failed
            with UnboundLocalError at ``return model``.

    Side effects:
        For GCNII on 'citeseer' or 'pubmed', dataset-specific hyper-parameters
        are written back into ``conf`` in place.
    """
    if conf['model_name'] == 'GCN':
        model = GCN(g=G,
                    in_feats=features.shape[1],
                    n_hidden=conf['hidden'],
                    n_classes=labels.max().item() + 1,
                    n_layers=1,
                    activation=F.relu,
                    dropout=conf['dropout']).to(conf['device'])
    elif conf['model_name'] in ['GAT', 'SGAT']:
        # SGAT is the same architecture with a single attention head.
        num_heads = 8 if conf['model_name'] == 'GAT' else 1
        num_layers = 1
        num_out_heads = 1
        heads = ([num_heads] * num_layers) + [num_out_heads]
        model = GAT(
            g=G,
            num_layers=num_layers,
            in_dim=features.shape[1],
            num_hidden=8,
            num_classes=labels.max().item() + 1,
            heads=heads,
            activation=F.relu,
            feat_drop=0.6,
            attn_drop=0.6,
            negative_slope=0.2,  # negative slope of leaky relu
            residual=False).to(conf['device'])
    elif conf['model_name'] == 'GraphSAGE':
        model = GraphSAGE(in_feats=features.shape[1],
                          n_hidden=conf['embed_dim'],
                          n_classes=labels.max().item() + 1,
                          n_layers=2,
                          activation=F.relu,
                          dropout=0.5,
                          aggregator_type=conf['agg_type']).to(conf['device'])
    elif conf['model_name'] == 'APPNP':
        model = APPNP(g=G,
                      in_feats=features.shape[1],
                      hiddens=[64],
                      n_classes=labels.max().item() + 1,
                      activation=F.relu,
                      feat_drop=0.5,
                      edge_drop=0.5,
                      alpha=0.1,
                      k=10).to(conf['device'])
    elif conf['model_name'] == 'MoNet':
        model = MoNet(g=G,
                      in_feats=features.shape[1],
                      n_hidden=64,
                      out_feats=labels.max().item() + 1,
                      n_layers=1,
                      dim=2,
                      n_kernels=3,
                      dropout=0.7).to(conf['device'])
    elif conf['model_name'] == 'SGC':
        model = SGConv(in_feats=features.shape[1],
                       out_feats=labels.max().item() + 1,
                       k=2,
                       cached=True,
                       bias=False).to(conf['device'])
    elif conf['model_name'] == 'GCNII':
        # Dataset-specific overrides; note these mutate the caller's conf.
        if conf['dataset'] == 'citeseer':
            conf['layer'] = 32
            conf['hidden'] = 256
            conf['lamda'] = 0.6
            conf['dropout'] = 0.7
        elif conf['dataset'] == 'pubmed':
            conf['hidden'] = 256
            conf['lamda'] = 0.4
            conf['dropout'] = 0.5
        model = GCNII(nfeat=features.shape[1],
                      nlayers=conf['layer'],
                      nhidden=conf['hidden'],
                      nclass=labels.max().item() + 1,
                      dropout=conf['dropout'],
                      lamda=conf['lamda'],
                      alpha=conf['alpha'],
                      variant=False).to(conf['device'])
    else:
        raise ValueError(f"Undefined Model: {conf['model_name']!r}.")
    return model
Exemplo n.º 5
0
def choose_model(conf, G, features, labels, byte_idx_train, labels_one_hot):
    """Instantiate the model named by ``conf['model_name']`` on ``conf['device']``.

    Args:
        conf: configuration dict. ``'model_name'`` picks the architecture and
            ``'device'`` is the target device. Other keys are read only by the
            builder that needs them (e.g. ``'hidden'``, ``'dropout'``,
            ``'num_layers'``, ``'emb_dim'``, ``'feat_drop'``, ``'attn_drop'``,
            ``'ptype'``, ``'mlp_layers'``, ``'agg_type'``).
        G: graph object; ``G.ndata['feat'].shape[1]`` is the input size for
            every model except GCN.
        features: node features; only the GCN builder reads ``features.shape[1]``.
        labels: labels; ``labels.max().item() + 1`` is the number of classes.
        byte_idx_train: forwarded to PLP only.
        labels_one_hot: forwarded to PLP only.

    Returns:
        The constructed model, moved to ``conf['device']``.

    Raises:
        ValueError: if ``conf['model_name']`` is not a supported model name.
    """
    # Shared quantities are computed lazily inside each builder, so an unknown
    # model name is rejected without touching G, features or labels.
    def num_classes():
        return labels.max().item() + 1

    def graph_feat_dim():
        return G.ndata['feat'].shape[1]

    def build_gcn():
        return GCN(g=G,
                   in_feats=features.shape[1],
                   n_hidden=conf['hidden'],
                   n_classes=num_classes(),
                   n_layers=1,
                   activation=F.relu,
                   dropout=conf['dropout'])

    def build_gat():
        hidden_layers = 1
        return GAT(g=G,
                   num_layers=hidden_layers,
                   in_dim=graph_feat_dim(),
                   num_hidden=8,
                   num_classes=num_classes(),
                   # 8 heads on the hidden layer, 1 head on the output layer.
                   heads=[8] * hidden_layers + [1],
                   activation=F.relu,
                   feat_drop=0.6,
                   attn_drop=0.6,
                   negative_slope=0.2,  # LeakyReLU negative slope
                   residual=False)

    def build_plp():
        return PLP(g=G,
                   num_layers=conf['num_layers'],
                   in_dim=graph_feat_dim(),
                   emb_dim=conf['emb_dim'],
                   num_classes=num_classes(),
                   activation=F.relu,
                   feat_drop=conf['feat_drop'],
                   attn_drop=conf['attn_drop'],
                   residual=False,
                   byte_idx_train=byte_idx_train,
                   labels_one_hot=labels_one_hot,
                   ptype=conf['ptype'],
                   mlp_layers=conf['mlp_layers'])

    def build_graphsage():
        return GraphSAGE(in_feats=graph_feat_dim(),
                         n_hidden=16,
                         n_classes=num_classes(),
                         n_layers=1,
                         activation=F.relu,
                         dropout=0.5,
                         aggregator_type=conf['agg_type'])

    def build_appnp():
        return APPNP(g=G,
                     in_feats=graph_feat_dim(),
                     hiddens=[64],
                     n_classes=num_classes(),
                     activation=F.relu,
                     feat_drop=0.5,
                     edge_drop=0.5,
                     alpha=0.1,
                     k=10)

    def build_logreg():
        # 'LogReg' is an MLP with a single layer, no hidden size and no dropout.
        return MLP(num_layers=1,
                   input_dim=graph_feat_dim(),
                   hidden_dim=None,
                   output_dim=num_classes(),
                   dropout=0)

    def build_mlp():
        return MLP(num_layers=2,
                   input_dim=graph_feat_dim(),
                   hidden_dim=conf['hidden'],
                   output_dim=num_classes(),
                   dropout=conf['dropout'])

    builders = {
        'GCN': build_gcn,
        'GAT': build_gat,
        'PLP': build_plp,
        'GraphSAGE': build_graphsage,
        'APPNP': build_appnp,
        'LogReg': build_logreg,
        'MLP': build_mlp,
    }
    build = builders.get(conf['model_name'])
    if build is None:
        raise ValueError('Undefined Model.')
    return build().to(conf['device'])
Exemplo n.º 6
0
        
        # Build graph classification model
        if model_name == 'GCN':
            model = GCN(n_feat=datareader.data['features_dim'],
                    n_class=datareader.data['n_classes'],
                    n_layer=args.n_agg_layer,
                    agg_hidden=args.agg_hidden,
                    fc_hidden=args.fc_hidden,
                    dropout=args.dropout,
                    readout=readout_name,
                    device=device).to(device)
        elif model_name == 'GAT':
            model = GAT(n_feat=datareader.data['features_dim'],
                    n_class=datareader.data['n_classes'],
                    n_layer=args.n_agg_layer,
                    agg_hidden=args.agg_hidden,
                    fc_hidden=args.fc_hidden,
                    dropout=args.dropout,
                    readout=readout_name,
                    device=device).to(device)
#        elif model_name == 'GraphWaveletNet':
#            model = GraphWaveletNet(n_feat=datareader.data['features_dim'],
#                    n_class=datareader.data['n_classes'],
#                    n_layer=args.n_agg_layer,
#                    agg_hidden=args.agg_hidden,
#                    fc_hidden=args.fc_hidden,
#                    dropout=args.dropout,
#                    readout=readout_name,
#                    device=device,
#                    n_node=datareader.data['N_nodes_max']).to(device)
        elif model_name == 'GraphSAGE':
            model = GraphUNet(n_feat=datareader.data['features_dim'],
Exemplo n.º 7
0
    def __init__(self, args):
        """Load the dataset, build the model and optimizer, and start W&B logging.

        Args:
            args: parsed options namespace. Attributes read here: dataset,
                cuda, cuda_num (only when cuda is truthy), ptb, ptb_rate,
                ptb_type, type_model, epochs, weight_decay, alpha, gamma,
                beta, lamb, num_classes, metric, num_layers and lr.

        Raises:
            Exception: if ``args.dataset`` or ``args.type_model`` is not one
                of the supported names handled below.
        """
        self.dataset = args.dataset
        device_name = f"cuda:{args.cuda_num}" if args.cuda else "cpu"
        self.device = torch.device(device_name)

        node_datasets = ("Cora", "Citeseer", "Pubmed", "CoauthorCS")
        if self.dataset in node_datasets:
            if args.ptb:
                # Perturbed variant of the graph, controlled by ptb_rate / ptb_type.
                self.data = load_perterbued_data(self.dataset, args.ptb_rate,
                                                 args.ptb_type)
            else:
                self.data = load_data(self.dataset)
            self.loss_fn = torch.nn.functional.nll_loss
        elif self.dataset == "PPI":
            # PPI uses a BCE-with-logits loss module instead of NLL loss.
            self.data = load_ppi_data()
            self.loss_fn = torch.nn.BCEWithLogitsLoss()
        else:
            raise Exception(
                f"the dataset of {self.dataset} has not been implemented")

        self.entropy_loss = torch.nn.functional.binary_cross_entropy_with_logits

        # Hyper-parameters copied verbatim from args, in their original order.
        for attr_name in ("type_model", "epochs", "weight_decay", "alpha",
                          "gamma", "beta", "lamb", "num_classes", "ptb_rate",
                          "ptb_type", "metric", "num_layers"):
            setattr(self, attr_name, getattr(args, attr_name))

        model_name = self.type_model
        if model_name == "GCN":
            model_cls = GCN
        elif model_name == "GAT":
            model_cls = GAT
        elif model_name == "NLGCN":
            model_cls = NLGCN
        elif model_name == "g_U_Net":
            model_cls = gunet
        elif model_name == "JKNet":
            model_cls = JKNetMaxpool
        elif model_name == "SGC":
            model_cls = simpleGCN
        elif model_name == "APPNP":
            model_cls = APPNP
        else:
            raise Exception(
                f"the model of {self.type_model} has not been implemented")
        self.model = model_cls(args)

        # NOTE(review): PPI data is not moved to self.device here -- confirm
        # that is intended.
        if self.dataset in node_datasets:
            if args.ptb:
                # utils.preprocess receives device=self.device and its outputs
                # replace the graph tensors (presumably already on that device).
                self.data.edge_index, self.data.x, self.data.y = utils.preprocess(
                    self.data.edge_index, self.data.x, self.data.y,
                    preprocess_adj=False, sparse=False, device=self.device)
            else:
                self.data.to(self.device)
        self.model.to(self.device)

        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        wandb.init(project="Gref", config=args)
        wandb.watch(self.model)
Exemplo n.º 8
0
                              num_workers=3,
                              batch_size=BATCH_SIZE)

    #
    # g = datasets.GraphDataset(load_cora_data(args.e))
    with open("test_data.pkl", 'rb') as f:
        # pickle.dump(g, f)
        g = pickle.load(f)

    test_loader = DataLoader(g,
                             collate_fn=datasets.graph_collate,
                             shuffle=True,
                             num_workers=3,
                             batch_size=BATCH_SIZE)

    net = GAT(133, 14).to(dev)
    print("TOTAL PARMS",
          sum(p.numel() for p in net.parameters() if p.requires_grad))

    # create optimizer
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

    dur = []

    lossf = F.mse_loss
    second_lossf = torch.nn.MSELoss(reduction='none')
    for epoch in range(50):
        net.train()
        train_avg = Avg()
        train2_avg = Avg()
        # if epoch < 10: