Example #1
def main(args):
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    graph = load(args.dataset)

    model = SkipGramModel(
        graph.num_nodes,
        args.embed_size,
        args.neg_num,
        sparse=not args.use_cuda)
    model = paddle.DataParallel(model)

    optim = Adam(
        learning_rate=args.learning_rate,
        parameters=model.parameters(),
        weight_decay=args.weight_decay)

    train_ds = ShardedDataset(graph.nodes)
    collate_fn = BatchRandWalk(graph, args.walk_len, args.win_size,
                               args.neg_num, args.neg_sample_type)
    data_loader = Dataloader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.sample_workers,
        collate_fn=collate_fn)

    for epoch in tqdm.tqdm(range(args.epoch)):
        train_loss = train(model, data_loader, optim)
        log.info("Runing epoch:%s\t train_loss:%.6f", epoch, train_loss)
Example #2
    def get_data_loader(self,
                        batch_size,
                        num_workers=4,
                        shuffle=False,
                        collate_fn=None):
        """
        It returns an batch iterator which yields a batch of data. Firstly, a sub-list of
        `data` of size ``batch_size`` will be draw from the ``data_list``, then 
        the function ``collate_fn`` will be applied to the sub-list to create a batch and 
        yield back. This process is accelerated by multiprocess.

        Args:
            batch_size(int): the size of each yielded batch.
            num_workers(int): the number of worker processes used to generate
                batch data in parallel.
            shuffle(bool): whether to shuffle the order of ``data_list``.
            collate_fn(function): used to convert a sub-list of ``data_list`` into
                the aggregated batch data.

        Yields:
            the batch data processed by ``collate_fn``.
        """
        return Dataloader(self,
                          batch_size=batch_size,
                          num_workers=num_workers,
                          shuffle=shuffle,
                          collate_fn=collate_fn)
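A minimal usage sketch for the method above. The toy dataset and collate function are hypothetical; since get_data_loader simply forwards to Dataloader, calling Dataloader directly is equivalent (this assumes pgl.utils.data exposes Dataset and Dataloader with the usual map-style __getitem__/__len__ contract):

import numpy as np
from pgl.utils.data import Dataset, Dataloader

class ToyDataset(Dataset):
    def __init__(self, n):
        self.data = [np.random.rand(4).astype("float32") for _ in range(n)]

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

def stack_batch(sub_list):
    # collate_fn: aggregate the drawn sub-list into one stacked array
    return np.stack(sub_list)

loader = Dataloader(ToyDataset(100), batch_size=8, num_workers=2,
                    shuffle=True, collate_fn=stack_batch)
for batch in loader:
    print(batch.shape)  # (8, 4) for full batches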
Example #3
def mp_pool_map(list_input, func, num_workers):
    """list_output = [func(input) for input in list_input]"""
    class _CollateFn(object):
        def __init__(self, func):
            self.func = func
        def __call__(self, data_list):
            new_data_list = []
            for data in data_list:
                index, input = data
                new_data_list.append((index, self.func(input)))
            return new_data_list

    # add index
    list_new_input = [(index, x) for index, x in enumerate(list_input)]
    data_gen = Dataloader(list_new_input, 
            batch_size=8, 
            num_workers=num_workers, 
            shuffle=False,
            collate_fn=_CollateFn(func))  

    list_output = []
    for sub_outputs in data_gen:
        list_output += sub_outputs
    list_output = sorted(list_output, key=lambda x: x[0])
    # remove index
    list_output = [x[1] for x in list_output]
    return list_output
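A usage sketch for mp_pool_map. The worker function should be defined at module level so worker processes can pickle it on platforms that spawn rather than fork; the index/sort bookkeeping above restores input order, since batches may complete out of order:

def square(x):
    return x * x

if __name__ == "__main__":
    outputs = mp_pool_map(list(range(10)), square, num_workers=2)
    print(outputs)  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]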
Example #4
def main(args):
    if not args.use_cuda:
        paddle.set_device("cpu")
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    graph = load(args.dataset)

    model = SkipGramModel(graph.num_nodes,
                          args.embed_size,
                          args.neg_num,
                          sparse=not args.use_cuda)
    model = paddle.DataParallel(model)

    train_steps = int(graph.num_nodes / args.batch_size) * args.epoch
    scheduler = paddle.optimizer.lr.PolynomialDecay(
        learning_rate=args.learning_rate,
        decay_steps=train_steps,
        end_lr=0.0001)

    optim = Adam(learning_rate=scheduler, parameters=model.parameters())

    train_ds = ShardedDataset(graph.nodes)
    collate_fn = BatchRandWalk(graph, args.walk_len, args.win_size,
                               args.neg_num, args.neg_sample_type)
    data_loader = Dataloader(train_ds,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=args.sample_workers,
                             collate_fn=collate_fn)

    for epoch in tqdm.tqdm(range(args.epoch)):
        train_loss = train(model, data_loader, optim)
        log.info("Runing epoch:%s\t train_loss:%.6f", epoch, train_loss)
    paddle.save(model.state_dict(), "model.pdparams")
Example #5
File: test.py Project: WenjinW/PGL
def infer(config, output_path):
    model = getattr(M, config.model_type)(config)

    log.info("infer model from %s" % config.infer_from)
    model.set_state_dict(paddle.load(config.infer_from))

    log.info("loading data")
    ds = getattr(DS, config.dataset_type)(config)

    split_idx = ds.get_idx_split()
    test_ds = DS.Subset(ds, split_idx['test'], mode='test')
    log.info("Test exapmles: %s" % len(test_ds))

    test_loader = Dataloader(test_ds,
                             batch_size=config.valid_batch_size,
                             shuffle=False,
                             num_workers=1,
                             collate_fn=DS.CollateFn(config))

    ### automatic evaluator. takes dataset name as input
    evaluator = PCQM4MEvaluator()

    # ---------------- test ----------------------- #
    log.info("testing ...")
    pred_dict = evaluate(model, test_loader)

    test_output_path = os.path.join(config.output_dir, config.task_name)
    make_dir(test_output_path)
    test_output_file = os.path.join(test_output_path, "test_pred.npz")

    log.info("saving test result to %s" % test_output_file)
    np.savez_compressed(test_output_file,
                        pred_dict['y_pred'].astype(np.float32))
Example #6
def test_InferDataset():
    config_file = "../../../config.yaml"
    ip_list_file = "../../../ip_list.txt"
    config = prepare_config(config_file)

    ds = InferDataset(config, ip_list_file)
    loader = Dataloader(ds, batch_size=1, num_workers=1)
    for data in loader:
        print(data[0])
        break
Example #7
def main(args):
    paddle.set_device("cpu")
    paddle.enable_static()

    fleet.init()

    fake_num_nodes = 1
    py_reader, loss = StaticSkipGramModel(
        fake_num_nodes,
        args.neg_num,
        args.embed_size,
        sparse_embedding=True,
        shared_embedding=args.shared_embedding)

    optimizer = F.optimizer.Adam(args.learning_rate, lazy_mode=True)
    dist_strategy = fleet.DistributedStrategy()
    dist_strategy.a_sync = True
    optimizer = fleet.distributed_optimizer(optimizer, dist_strategy)
    optimizer.minimize(loss)

    # init and run server or worker
    if fleet.is_server():
        fleet.init_server()
        fleet.run_server()

    if fleet.is_worker():
        place = paddle.CPUPlace()
        exe = paddle.static.Executor(place)
        exe.run(paddle.static.default_startup_program())
        fleet.init_worker()

        graph = build_graph(args)
        # bind gen
        train_ds = ShardedDataset(graph.nodes, args.epoch)
        collate_fn = BatchRandWalk(graph, args.walk_len, args.win_size,
                                   args.neg_num, args.neg_sample_type)
        data_loader = Dataloader(train_ds,
                                 batch_size=args.cpu_batch_size,
                                 shuffle=True,
                                 num_workers=args.sample_workers,
                                 collate_fn=collate_fn)
        py_reader.set_batch_generator(lambda: data_loader)

        train_loss = train(exe, paddle.static.default_main_program(),
                           py_reader, loss)
        fleet.stop_worker()

        if fleet.is_first_worker():
            fleet.save_persistables(exe, "./model",
                                    paddle.static.default_main_program())
Example #8
def main(args):
    paddle.set_device("cpu")
    paddle.enable_static()
    role = role_maker.PaddleCloudRoleMaker()
    fleet.init(role)

    if args.num_nodes is None:
        num_nodes = load(args.dataset).num_nodes
    else:
        num_nodes = args.num_nodes

    loss = StaticSkipGramModel(
        num_nodes, args.neg_num, args.embed_size, sparse=True)

    optimizer = F.optimizer.Adam(args.learning_rate, lazy_mode=True)
    dist_strategy = fleet.DistributedStrategy()
    dist_strategy.a_sync = True
    optimizer = fleet.distributed_optimizer(optimizer, dist_strategy)
    optimizer.minimize(loss)

    # init and run server or worker
    if fleet.is_server():
        fleet.init_server()
        fleet.run_server()

    if fleet.is_worker():
        place = paddle.CPUPlace()
        exe = paddle.static.Executor(place)
        exe.run(paddle.static.default_startup_program())
        fleet.init_worker()

        graph = load(args.dataset)
        # bind gen
        train_ds = ShardedDataset(graph.nodes)
        collate_fn = BatchRandWalk(graph, args.walk_len, args.win_size,
                                   args.neg_num, args.neg_sample_type)
        data_loader = Dataloader(
            train_ds,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.sample_workers,
            collate_fn=collate_fn)

        for epoch in range(args.epoch):
            train_loss = train(exe,
                               paddle.static.default_main_program(),
                               data_loader, loss)
            log.info("Runing epoch:%s\t train_loss:%.6f", epoch, train_loss)
        fleet.stop_worker()
Example #9
def main(args):
    paddle.enable_static()
    paddle.set_device('gpu:%d' % paddle.distributed.ParallelEnv().dev_id)

    fleet.init(is_collective=True)

    graph = load(args.dataset)

    loss = StaticSkipGramModel(graph.num_nodes,
                               args.neg_num,
                               args.embed_size,
                               num_emb_part=args.num_emb_part,
                               shared_embedding=args.shared_embedding)

    optimizer = F.optimizer.Adam(args.learning_rate)
    dist_strategy = fleet.DistributedStrategy()
    dist_strategy.sharding = True
    dist_strategy.sharding_configs = {
        "segment_anchors": None,
        "sharding_segment_strategy": "segment_broadcast_MB",
        "segment_broadcast_MB": 32,
        "sharding_degree": int(paddle.distributed.get_world_size()),
    }
    optimizer = fleet.distributed_optimizer(optimizer, dist_strategy)
    optimizer.minimize(loss)

    place = paddle.CUDAPlace(paddle.distributed.ParallelEnv().dev_id)
    exe = paddle.static.Executor(place)
    exe.run(paddle.static.default_startup_program())

    # bind gen
    train_ds = ShardedDataset(graph.nodes)
    collate_fn = BatchRandWalk(graph, args.walk_len, args.win_size,
                               args.neg_num, args.neg_sample_type)
    data_loader = Dataloader(train_ds,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=args.sample_workers,
                             collate_fn=collate_fn)

    for epoch in range(args.epoch):
        train_loss = train(exe, paddle.static.default_main_program(),
                           data_loader, loss)
        log.info("Runing epoch:%s\t train_loss:%.6f", epoch, train_loss)
    fleet.stop_worker()

    if fleet.is_first_worker():
        fleet.save_persistables(exe, "./model",
                                paddle.static.default_main_program())
Example #10
    def get_data_loader(self, batch_size, num_workers=1,
                        shuffle=False, collate_fn=None):
        """Get dataloader.

        Args:
            batch_size (int): number of data items in a batch.
            num_workers (int): number of parallel workers.
            shuffle (bool): whether to shuffle the yielded data.
            collate_fn: callable that processes batch data into a list of paddle tensors.
        """
        return Dataloader(
            self,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=shuffle,
            collate_fn=collate_fn)
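The docstring above expects collate_fn to turn a list of dataset samples into paddle tensors. A minimal sketch of such a function, assuming (hypothetically) that each sample is a (feature, label) pair:

import numpy as np
import paddle

def to_tensor_batch(batch):
    # batch is the sub-list of samples drawn by the Dataloader
    feats, labels = zip(*batch)
    feats = paddle.to_tensor(np.stack(feats), dtype="float32")
    labels = paddle.to_tensor(np.array(labels, dtype="int64").reshape(-1, 1))
    return feats, labels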
Example #11
def test_PairDataset():
    config_file = "../../../config.yaml"
    ip_list_file = "../../../ip_list.txt"
    config = prepare_config(config_file)

    ds = TrainPairDataset(config, ip_list_file)

    loader = Dataloader(ds,
                        batch_size=4,
                        num_workers=1,
                        stream_shuffle_size=100,
                        collate_fn=CollateFn())
    pairs = []
    start = time.time()
    for batch_data in loader:
        pairs.extend(batch_data)
        print(batch_data)
        time.sleep(10)
    print("total time: %s" % (time.time() - start))
Example #12
def main(args):
    if not args.use_cuda:
        paddle.set_device("cpu")
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    if args.edge_file:
        graph = load_from_file(args.edge_file)
    else:
        graph = load(args.dataset)

    edges = np.load("./edges.npy")
    edges = np.concatenate([edges, edges[:, [1, 0]]])
    graph = pgl.Graph(edges)

    model = SkipGramModel(graph.num_nodes,
                          args.embed_size,
                          args.neg_num,
                          sparse=not args.use_cuda)
    model = paddle.DataParallel(model)

    train_ds = ShardedDataset(graph.nodes, repeat=args.epoch)

    train_steps = int(len(train_ds) // args.batch_size)
    log.info("train_steps: %s" % train_steps)
    scheduler = paddle.optimizer.lr.PolynomialDecay(
        learning_rate=args.learning_rate,
        decay_steps=train_steps,
        end_lr=0.0001)

    optim = Adam(learning_rate=scheduler, parameters=model.parameters())

    collate_fn = BatchRandWalk(graph, args.walk_len, args.win_size,
                               args.neg_num, args.neg_sample_type)
    data_loader = Dataloader(train_ds,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=args.sample_workers,
                             collate_fn=collate_fn)

    train_loss = train(model, data_loader, optim)
    paddle.save(model.state_dict(), "model.pdparams")
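Note how this example builds the graph straight from an edge array, appending column-swapped edges so every edge also exists in the reverse direction. In isolation the pattern looks like this (toy edges; num_nodes appears to be inferred from the largest node id when omitted, so passing it explicitly is the safer choice for graphs with isolated trailing nodes):

import numpy as np
import pgl

edges = np.array([[0, 1], [1, 2], [2, 0]])
edges = np.concatenate([edges, edges[:, [1, 0]]])  # add reverse edges
graph = pgl.Graph(edges, num_nodes=3)
print(graph.num_nodes, graph.num_edges)  # 3 6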
Example #13
File: train.py Project: WenjinW/PGL
def main(config, ip_list_file):
    ds = TrainPairDataset(config, ip_list_file)
    loader = Dataloader(
        ds,
        batch_size=config.batch_pair_size,
        num_workers=config.num_workers,
        stream_shuffle_size=config.pair_stream_shuffle_size,
        collate_fn=CollateFn())

    model = SkipGramModel(config)

    if config.warm_start_from:
        log.info("warm start from %s" % config.warm_start_from)
        model.set_state_dict(paddle.load(config.warm_start_from))

    optim = Adam(
        learning_rate=config.lr,
        parameters=model.parameters(),
        lazy_mode=config.lazy_mode)

    log.info("starting training...")
    train(config, model, loader, optim)
Example #14
def main(args):
    """
    Args:
        -ddi: drug drug synergy file.
        -rna: cell line gene expression file.
        -lincs: gene embeddings.
        -dropout: dropout rate for transformer blocks.
        -epochs: training epochs.
        -batch_size
        -lr: learning rate.

    """
    #paddle.set_device('cpu')
    ddi = pd.read_csv(args.ddi)
    rna = pd.read_csv(args.rna, index_col=0)
    lincs = pd.read_csv(args.lincs, index_col=0, header=None).values
    lincs = paddle.to_tensor(lincs, 'float32')

    ############## independent validation ##############
    # 5-fold cross validation (disabled; enable together with the test
    # evaluation below):
    # NUM_CROSS = 5
    # ddi_shuffle = shuffle(ddi)
    # data_size = len(ddi)
    # fold_num = int(data_size / NUM_CROSS)
    # for fold in range(NUM_CROSS):
    #     ddi_test = ddi_shuffle.iloc[fold * fold_num:fold_num * (fold + 1), :]
    #     ddi_train_before = ddi_shuffle.iloc[:fold * fold_num, :]
    #     ddi_train_after = ddi_shuffle.iloc[fold_num * (fold + 1):, :]
    #     ddi_train = pd.concat([ddi_train_before, ddi_train_after])

    ddi_train = ddi.copy()
    train_cell = join_cell(ddi_train, rna)
    bt_tr = DDsData(ddi_train['drug1'].values, ddi_train['drug2'].values,
                    train_cell, ddi_train['label'].values)
    """test_cell = join_cell(ddi_test, rna)
            #test_pta, test_ptb = join_pert(ddi_test, drugs_pert)
        bt_test = DDsData(ddi_test['drug1'].values,   
                    ddi_test['drug2'].values, 
                    test_cell, 
                    
                    ddi_test['label'].values)"""

    loader_tr = Dataloader(bt_tr,
                           batch_size=args.batch_size,
                           num_workers=4,
                           collate_fn=collate)
    #loader_test = Dataloader(bt_test, batch_size=args.batch_size, num_workers=4, collate_fn=collate)
    #loader_val = Dataloader(bt_val, batch_size=args.batch_size, num_workers=1, collate_fn=collate)

    model = TSNet(num_drug_feat=78,
                  num_L_feat=978,
                  num_cell_feat=rna.shape[1],
                  num_drug_out=128,
                  coarsed_heads=4,
                  fined_heads=4,
                  coarse_hidd=64,
                  fine_hidd=64,
                  dropout=args.dropout)
    opt = paddle.optimizer.Adam(learning_rate=args.lr,
                                parameters=model.parameters())
    loss_fn = paddle.nn.CrossEntropyLoss()

    for e in range(args.epochs):
        train_loss = train(model, loader_tr, lincs, loss_fn, opt)
        print('Epoch {}---training loss:{}'.format(e, train_loss))
        # NOTE: the test evaluation below needs loader_test, which is only
        # built when the cross-validation split above is enabled.
        # t_auc, test_prauc, test_loss, acc, bacc, prec, tpr, kappa = test_auc(
        #     model, loader_test, lincs, loss_fn)
        # print('---Testing loss:{:.4f}, AUC:{:.4f}, PRAUC:{:.4f}, ACC:{:.4f}, '
        #       'BACC:{:.4f}, PREC:{:.4f}, TPR:{:.4f}, KAPPA:{:.4f}'
        #       .format(test_loss, t_auc, test_prauc, acc, bacc, prec, tpr, kappa))
Example #15
File: main.py Project: WenjinW/PGL
def main(config):
    if dist.get_world_size() > 1:
        dist.init_parallel_env()

    if not config.use_cuda:
        paddle.set_device("cpu")

    model = getattr(M, config.model_type)(config)

    if config.warm_start_from:
        log.info("warm start from %s" % config.warm_start_from)
        model.set_state_dict(paddle.load(config.warm_start_from))

    model = paddle.DataParallel(model)

    num_params = sum(p.numel() for p in model.parameters())
    log.info("total Parameters: %s" % num_params)

    if config.lr_mode == "step_decay":
        scheduler = paddle.optimizer.lr.StepDecay(learning_rate=config.lr,
                                                  step_size=config.step_size,
                                                  gamma=config.gamma)
    elif config.lr_mode == "multistep":
        scheduler = paddle.optimizer.lr.MultiStepDecay(
            learning_rate=config.lr,
            milestones=config.milestones,
            gamma=config.gamma)
    elif config.lr_mode == "piecewise":
        log.info(['boundery: ', config.boundery])
        log.info(['lr_value: ', config.lr_value])
        for i in config.lr_value:
            if not isinstance(i, float):
                raise "lr_value %s is not float number" % i
        scheduler = paddle.optimizer.lr.PiecewiseDecay(config.boundery,
                                                       config.lr_value)
    elif config.lr_mode == "Reduce":
        scheduler = paddle.optimizer.lr.ReduceOnPlateau(
            learning_rate=config.lr,
            factor=config.factor,
            patience=config.patience)
    else:
        scheduler = config.lr
    optimizer = getattr(paddle.optimizer,
                        config.optim)(learning_rate=scheduler,
                                      parameters=model.parameters())

    log.info("loading data")
    ds = getattr(DS, config.dataset_type)(config)

    if config.split_mode == "cross1":
        split_idx = ds.get_cross_idx_split()
        train_ds = DS.Subset(ds, split_idx['cross_train_1'], mode='train')
        valid_ds = DS.Subset(ds, split_idx['cross_valid_1'], mode='valid')
        left_valid_ds = DS.Subset(ds,
                                  split_idx['valid_left_1percent'],
                                  mode='valid')
        test_ds = DS.Subset(ds, split_idx['test'], mode='test')
    elif config.split_mode == "cross1_few":
        split_idx = ds.get_cross_idx_split()
        train_ds = DS.Subset(ds,
                             split_idx['cross_train_1'][:10000],
                             mode='train')
        valid_ds = DS.Subset(ds,
                             split_idx['cross_train_1'][10000:11000],
                             mode='valid')
        left_valid_ds = DS.Subset(ds,
                                  split_idx['cross_train_1'][10000:11000],
                                  mode='valid')
        test_ds = DS.Subset(ds,
                            split_idx['cross_train_1'][11000:12000],
                            mode='test')
    elif config.split_mode == "cross2":
        split_idx = ds.get_cross_idx_split()
        train_ds = DS.Subset(ds, split_idx['cross_train_2'], mode='train')
        valid_ds = DS.Subset(ds, split_idx['cross_valid_2'], mode='valid')
        left_valid_ds = DS.Subset(ds,
                                  split_idx['valid_left_1percent'],
                                  mode='valid')
        test_ds = DS.Subset(ds, split_idx['test'], mode='test')
    else:
        split_idx = ds.get_idx_split()
        train_ds = DS.Subset(ds, split_idx['train'], mode='train')
        valid_ds = DS.Subset(ds, split_idx['valid'], mode='valid')
        left_valid_ds = DS.Subset(ds, split_idx['valid'], mode='valid')
        test_ds = DS.Subset(ds, split_idx['test'], mode='test')

    log.info("Train exapmles: %s" % len(train_ds))
    log.info("Valid exapmles: %s" % len(valid_ds))
    log.info("Test exapmles: %s" % len(test_ds))
    log.info("Left Valid exapmles: %s" % len(left_valid_ds))

    train_loader = Dataloader(train_ds,
                              batch_size=config.batch_size,
                              shuffle=True,
                              num_workers=config.num_workers,
                              collate_fn=DS.CollateFn(config),
                              drop_last=True)

    valid_loader = Dataloader(valid_ds,
                              batch_size=config.valid_batch_size,
                              shuffle=False,
                              num_workers=1,
                              collate_fn=DS.CollateFn(config))

    left_valid_loader = Dataloader(left_valid_ds,
                                   batch_size=config.valid_batch_size,
                                   shuffle=False,
                                   num_workers=1,
                                   collate_fn=DS.CollateFn(config))

    test_loader = Dataloader(test_ds,
                             batch_size=config.valid_batch_size,
                             shuffle=False,
                             num_workers=1,
                             collate_fn=DS.CollateFn(config))

    if config.split_mode is not None:
        valids = {'valid': valid_loader, 'left': left_valid_loader}
    else:
        valids = {'valid': valid_loader}

    if config.pretrain_tasks:
        pretrain_train_and_eval(model, config, train_loader, valids,
                                test_loader, optimizer, scheduler)
    else:
        train_and_eval(model, config, train_loader, valids, test_loader,
                       optimizer, scheduler)
Example #16
def collate_fn(batch):
    a2a_gs, b2a_gs, b2b_gs_l, feats, types, counts, labels = map(
        list, zip(*batch))

    a2a_g = pgl.Graph.batch(a2a_gs).tensor()
    b2a_g = pgl.BiGraph.batch(b2a_gs).tensor()
    b2b_gl = [
        pgl.Graph.batch([g[i] for g in b2b_gs_l]).tensor()
        for i in range(len(b2b_gs_l[0]))
    ]
    feats = paddle.concat(
        [paddle.to_tensor(f, dtype='float32') for f in feats])
    types = paddle.concat([paddle.to_tensor(t) for t in types])
    counts = paddle.stack([paddle.to_tensor(c) for c in counts], axis=1)
    labels = paddle.to_tensor(np.array(labels), dtype='float32')

    return a2a_g, b2a_g, b2b_gl, feats, types, counts, labels


if __name__ == "__main__":
    complex_data = ComplexDataset("./data/", "pdbbind2016_test", 5, 6)
    loader = Dataloader(complex_data,
                        batch_size=32,
                        shuffle=False,
                        num_workers=1,
                        collate_fn=collate_fn)
    cc = 0
    for batch in loader:
        a2a_g, b2a_g, b2b_gl, feats, types, counts, labels = batch
        print(labels)
        cc += 1
        if cc == 2:
            break
Example #17
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='GNN baselines on pcqm4m with PGL')
    parser.add_argument('--use_cuda', action='store_true')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument(
        '--gnn',
        type=str,
        default='gin-virtual',
        help=
        'GNN type: gin, gin-virtual, gcn, or gcn-virtual (default: gin-virtual)')
    parser.add_argument(
        '--graph_pooling',
        type=str,
        default='sum',
        help='graph pooling strategy mean or sum (default: sum)')
    parser.add_argument('--drop_ratio',
                        type=float,
                        default=0,
                        help='dropout ratio (default: 0)')
    parser.add_argument(
        '--num_layers',
        type=int,
        default=5,
        help='number of GNN message passing layers (default: 5)')
    parser.add_argument(
        '--emb_dim',
        type=int,
        default=600,
        help='dimensionality of hidden units in GNNs (default: 600)')
    parser.add_argument('--train_subset', action='store_true')
    parser.add_argument('--batch_size',
                        type=int,
                        default=256,
                        help='input batch size for training (default: 256)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--num_workers',
                        type=int,
                        default=1,
                        help='number of workers (default: 1)')
    parser.add_argument('--log_dir',
                        type=str,
                        default="",
                        help='tensorboard log directory')
    parser.add_argument('--checkpoint_dir',
                        type=str,
                        default='',
                        help='directory to save checkpoint')
    parser.add_argument('--save_test_dir',
                        type=str,
                        default='',
                        help='directory to save test submission file')
    args = parser.parse_args()

    print(args)

    random.seed(42)
    np.random.seed(42)
    paddle.seed(42)

    if not args.use_cuda:
        paddle.set_device("cpu")

    ### automatic dataloading and splitting
    class Config():
        def __init__(self):
            self.base_data_path = "./dataset"

    config = Config()
    ds = MolDataset(config)
    split_idx = ds.get_idx_split()
    test_ds = Subset(ds, split_idx['test'])

    print("Test exapmles: ", len(test_ds))

    ### automatic evaluator. takes dataset name as input
    evaluator = PCQM4MEvaluator()

    test_loader = Dataloader(test_ds,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.num_workers,
                             collate_fn=CollateFn())

    shared_params = {
        'num_layers': args.num_layers,
        'emb_dim': args.emb_dim,
        'drop_ratio': args.drop_ratio,
        'graph_pooling': args.graph_pooling
    }

    if args.gnn == 'gin':
        model = GNN(gnn_type='gin', virtual_node=False, **shared_params)
    elif args.gnn == 'gin-virtual':
        model = GNN(gnn_type='gin', virtual_node=True, **shared_params)
    elif args.gnn == 'gcn':
        model = GNN(gnn_type='gcn', virtual_node=False, **shared_params)
    elif args.gnn == 'gcn-virtual':
        model = GNN(gnn_type='gcn', virtual_node=True, **shared_params)
    else:
        raise ValueError('Invalid GNN type')

    num_params = sum(p.numel() for p in model.parameters())
    print(f'#Params: {num_params}')

    checkpoint_path = os.path.join(args.checkpoint_dir, 'checkpoint.pdparams')
    if not os.path.exists(checkpoint_path):
        raise RuntimeError(f'Checkpoint file not found at {checkpoint_path}')

    model.set_state_dict(paddle.load(checkpoint_path))

    print('Predicting on test data...')
    y_pred = test(model, test_loader)
    print('Saving test submission file...')
    evaluator.save_test_submission({'y_pred': y_pred}, args.save_test_dir)
Example #18
File: main.py Project: Yelrose/PGL
def main(config):
    if dist.get_world_size() > 1:
        dist.init_parallel_env()

    if dist.get_rank() == 0:
        timestamp = datetime.now().strftime("%Hh%Mm%Ss")
        log_path = os.path.join(config.log_dir,
                                "tensorboard_log_%s" % timestamp)
        writer = SummaryWriter(log_path)

    log.info("loading data")
    raw_dataset = GraphPropPredDataset(name=config.dataset_name)
    config.num_class = raw_dataset.num_tasks
    config.eval_metric = raw_dataset.eval_metric
    config.task_type = raw_dataset.task_type

    mol_dataset = MolDataset(config,
                             raw_dataset,
                             transform=make_multihop_edges)
    splitted_index = raw_dataset.get_idx_split()
    train_ds = Subset(mol_dataset, splitted_index['train'], mode='train')
    valid_ds = Subset(mol_dataset, splitted_index['valid'], mode="valid")
    test_ds = Subset(mol_dataset, splitted_index['test'], mode="test")

    log.info("Train Examples: %s" % len(train_ds))
    log.info("Val Examples: %s" % len(valid_ds))
    log.info("Test Examples: %s" % len(test_ds))

    fn = CollateFn(config)

    train_loader = Dataloader(train_ds,
                              batch_size=config.batch_size,
                              shuffle=True,
                              num_workers=config.num_workers,
                              collate_fn=fn)

    valid_loader = Dataloader(valid_ds,
                              batch_size=config.batch_size,
                              num_workers=config.num_workers,
                              collate_fn=fn)

    test_loader = Dataloader(test_ds,
                             batch_size=config.batch_size,
                             num_workers=config.num_workers,
                             collate_fn=fn)

    model = ClassifierNetwork(config.hidden_size, config.out_dim,
                              config.num_layers, config.dropout_prob,
                              config.virt_node, config.K, config.conv_type,
                              config.appnp_hop, config.alpha)
    model = paddle.DataParallel(model)

    optim = Adam(learning_rate=config.lr, parameters=model.parameters())
    criterion = nn.loss.BCEWithLogitsLoss()

    evaluator = Evaluator(config.dataset_name)

    best_valid = 0

    global_step = 0
    for epoch in range(1, config.epochs + 1):
        model.train()
        for idx, batch_data in enumerate(train_loader):
            g, mh_graphs, labels, unmask = batch_data
            g = g.tensor()
            multihop_graphs = []
            for item in mh_graphs:
                multihop_graphs.append(item.tensor())
            g.multi_hop_graphs = multihop_graphs
            labels = paddle.to_tensor(labels)
            unmask = paddle.to_tensor(unmask)

            pred = model(g)
            pred = paddle.masked_select(pred, unmask)
            labels = paddle.masked_select(labels, unmask)
            train_loss = criterion(pred, labels)
            train_loss.backward()
            optim.step()
            optim.clear_grad()

            if global_step % 80 == 0:
                message = "train: epoch %d | step %d | " % (epoch, global_step)
                message += "loss %.6f" % (train_loss.numpy())
                log.info(message)
                if dist.get_rank() == 0:
                    writer.add_scalar("loss", train_loss.numpy(), global_step)
            global_step += 1

        valid_result = evaluate(model, valid_loader, criterion, evaluator)
        message = "valid: epoch %d | step %d | " % (epoch, global_step)
        for key, value in valid_result.items():
            message += " | %s %.6f" % (key, value)
            if dist.get_rank() == 0:
                writer.add_scalar("valid_%s" % key, value, global_step)
        log.info(message)

        test_result = evaluate(model, test_loader, criterion, evaluator)
        message = "test: epoch %d | step %d | " % (epoch, global_step)
        for key, value in test_result.items():
            message += " | %s %.6f" % (key, value)
            if dist.get_rank() == 0:
                writer.add_scalar("test_%s" % key, value, global_step)
        log.info(message)

        if best_valid < valid_result[config.metrics]:
            best_valid = valid_result[config.metrics]
            best_valid_result = valid_result
            best_test_result = test_result

        message = "best result: epoch %d | " % (epoch)
        message += "valid %s: %.6f | " % (config.metrics,
                                          best_valid_result[config.metrics])
        message += "test %s: %.6f | " % (config.metrics,
                                         best_test_result[config.metrics])
        log.info(message)

    message = "final eval best result:%.6f" % best_valid_result[config.metrics]
    log.info(message)
    message = "final test best result:%.6f" % best_test_result[config.metrics]
    log.info(message)
Example #19
File: main.py Project: Yelrose/PGL
def main(args):
    ds = GINDataset(args.data_path,
                    args.dataset_name,
                    self_loop=not args.train_eps,
                    degree_as_nlabel=True)
    args.feat_size = ds.dim_nfeats

    train_ds, test_ds = fold10_split(ds,
                                     fold_idx=args.fold_idx,
                                     seed=args.seed)

    train_loader = Dataloader(train_ds,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=1,
                              collate_fn=collate_fn)
    test_loader = Dataloader(test_ds,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=1,
                             collate_fn=collate_fn)

    model = GINModel(args, ds.gclasses)

    epoch_step = len(train_loader)
    boundaries = [
        i for i in range(50 * epoch_step, args.epochs *
                         epoch_step, epoch_step * 50)
    ]
    values = [args.lr * 0.5**i for i in range(0, len(boundaries) + 1)]
    scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=boundaries,
                                                   values=values,
                                                   verbose=False)
    optim = Adam(learning_rate=scheduler, parameters=model.parameters())
    criterion = nn.loss.CrossEntropyLoss()

    global_step = 0
    best_acc = 0.0
    for epoch in range(1, args.epochs + 1):
        model.train()
        for idx, batch_data in enumerate(train_loader):
            graphs, labels = batch_data
            g = pgl.Graph.batch(graphs).tensor()
            labels = paddle.to_tensor(labels)

            pred = model(g)
            train_loss = criterion(pred, labels)
            train_loss.backward()
            train_acc = paddle.metric.accuracy(input=pred, label=labels, k=1)
            optim.step()
            optim.clear_grad()
            scheduler.step()

            global_step += 1
            if global_step % 10 == 0:
                message = "train: epoch %d | step %d | " % (epoch, global_step)
                message += "loss %.6f | acc %.4f" % (train_loss, train_acc)
                log.info(message)

        result = evaluate(model, test_loader, criterion)
        message = "eval: epoch %d | step %d | " % (epoch, global_step)
        for key, value in result.items():
            message += " | %s %.6f" % (key, value)
        log.info(message)

        if best_acc < result['acc']:
            best_acc = result['acc']

    log.info("best evaluating accuracy: %.6f" % best_acc)
Example #20
def main(args):
    paddle.set_device("cpu")
    paddle.enable_static()

    fleet.init()

    if args.num_nodes is None:
        num_nodes = load(args.dataset).num_nodes
    else:
        num_nodes = args.num_nodes

    loss = StaticSkipGramModel(num_nodes,
                               args.neg_num,
                               args.embed_size,
                               sparse=True)

    optimizer = F.optimizer.Adam(args.learning_rate, lazy_mode=True)
    dist_strategy = fleet.DistributedStrategy()
    dist_strategy.a_sync = True
    optimizer = fleet.distributed_optimizer(optimizer, dist_strategy)
    optimizer.minimize(loss)

    # init and run server or worker
    if fleet.is_server():
        fleet.init_server()
        fleet.run_server()

    if fleet.is_worker():
        place = paddle.CPUPlace()
        exe = paddle.static.Executor(place)
        exe.run(paddle.static.default_startup_program())
        fleet.init_worker()

        graph = load(args.dataset)
        # bind gen
        train_ds = ShardedDataset(graph.nodes)
        collate_fn = BatchRandWalk(graph, args.walk_len, args.win_size,
                                   args.neg_num, args.neg_sample_type)
        data_loader = Dataloader(train_ds,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.sample_workers,
                                 collate_fn=collate_fn)

        cpu_num = int(os.environ.get('CPU_NUM', 1))
        if cpu_num > 1:
            parallel_places = [paddle.CPUPlace()] * cpu_num
            exec_strategy = paddle.static.ExecutionStrategy()
            exec_strategy.num_threads = int(cpu_num)
            build_strategy = paddle.static.BuildStrategy()
            build_strategy.reduce_strategy = paddle.static.BuildStrategy.ReduceStrategy.Reduce
            compiled_prog = paddle.static.CompiledProgram(
                paddle.static.default_main_program()).with_data_parallel(
                    loss_name=loss.name,
                    places=parallel_places,
                    build_strategy=build_strategy,
                    exec_strategy=exec_strategy)
        else:
            compiled_prog = paddle.static.default_main_program()

        for epoch in range(args.epoch):
            train_loss = train(exe, compiled_prog, data_loader, loss)
            log.info("Runing epoch:%s\t train_loss:%.6f", epoch, train_loss)
        fleet.stop_worker()

        if fleet.is_first_worker():
            fleet.save_persistables(exe, "./model",
                                    paddle.static.default_main_program())
Example #21
    train_dataset = MyDataset(
        train_chain_list,
        args.n_channels,
        args.pad_len,
        args.protein_chain_graphs,
        args.cmap_thresh,
        args.label_data_path,
        args.use_cache,
    )
    valid_dataset = MyDataset(
        valid_chain_list,
        args.n_channels,
        args.pad_len,
        args.protein_chain_graphs,
        args.cmap_thresh,
        args.label_data_path,
        args.use_cache,
    )

    train_loader = Dataloader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=train_dataset.collate_fn,
    )
    valid_loader = Dataloader(valid_dataset,
                              batch_size=args.batch_size,
                              collate_fn=valid_dataset.collate_fn)

    args.n_labels = train_dataset.n_labels
    model = DeepFRI(args)
    task_name = os.path.split(args.label_data_path)[-1]
    task_name = os.path.splitext(task_name)[0]
    args.task = task_name
    time_stamp = str(datetime.now()).replace(":",
                                             "-").replace(" ",
                                                          "_").split(".")[0]
    args.model_name = (
Example #22
        labels = np.nan_to_num(labels).astype("float32")

        g = pgl.Graph.batch(graph_list)
        multihop_graphs = []
        for g_list in multihop_graph_list:
            multihop_graphs.append(pgl.Graph.batch(g_list))

        return g, multihop_graphs, labels, batch_valid


if __name__ == "__main__":
    config = prepare_config("pcba_config.yaml", isCreate=False, isSave=False)
    raw_dataset = GraphPropPredDataset(name=config.dataset_name)
    ds = MolDataset(config, raw_dataset, transform=make_multihop_edges)
    splitted_index = raw_dataset.get_idx_split()
    train_ds = Subset(ds, splitted_index['train'], mode='train')
    valid_ds = Subset(ds, splitted_index['valid'], mode="valid")
    test_ds = Subset(ds, splitted_index['test'], mode="test")

    Fn = CollateFn(config)
    loader = Dataloader(train_ds,
                        batch_size=3,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=Fn)
    for batch_data in loader:
        print("batch", batch_data[0][0].node_feat)
        g = pgl.Graph.batch(batch_data[0])
        print(g.node_feat)
        time.sleep(3)
Example #23
File: train.py Project: Yelrose/PGL
def main(args):
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    data = pgl.dataset.RedditDataset(args.normalize, args.symmetry)
    log.info("Preprocess finish")
    log.info("Train Examples: %s" % len(data.train_index))
    log.info("Val Examples: %s" % len(data.val_index))
    log.info("Test Examples: %s" % len(data.test_index))
    log.info("Num nodes %s" % data.graph.num_nodes)
    log.info("Num edges %s" % data.graph.num_edges)
    log.info("Average Degree %s" % np.mean(data.graph.indegree()))

    graph = data.graph
    train_index = data.train_index
    val_index = data.val_index
    test_index = data.test_index

    train_label = data.train_label
    val_label = data.val_label
    test_label = data.test_label

    model = GraphSage(
        input_size=data.feature.shape[-1],
        num_class=data.num_classes,
        hidden_size=args.hidden_size,
        num_layers=len(args.samples))

    model = paddle.DataParallel(model)

    criterion = paddle.nn.loss.CrossEntropyLoss()

    optim = paddle.optimizer.Adam(
        learning_rate=args.lr,
        parameters=model.parameters(),
        weight_decay=0.001)

    feature = paddle.to_tensor(data.feature)

    train_ds = ShardedDataset(train_index, train_label)
    val_ds = ShardedDataset(val_index, val_label)
    test_ds = ShardedDataset(test_index, test_label)

    collate_fn = partial(batch_fn, graph=graph, samples=args.samples)

    train_loader = Dataloader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.sample_workers,
        collate_fn=collate_fn)
    val_loader = Dataloader(
        val_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.sample_workers,
        collate_fn=collate_fn)
    test_loader = Dataloader(
        test_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.sample_workers,
        collate_fn=collate_fn)

    cal_val_acc = []
    cal_test_acc = []
    cal_val_loss = []
    for epoch in tqdm.tqdm(range(args.epoch)):
        train_loss, train_acc = train(train_loader, model, feature, criterion,
                                      optim)
        log.info("Runing epoch:%s\t train_loss:%s\t train_acc:%s", epoch,
                 train_loss, train_acc)
        val_loss, val_acc = eval(val_loader, model, feature, criterion)
        cal_val_acc.append(val_acc)
        cal_val_loss.append(val_loss)
        log.info("Runing epoch:%s\t val_loss:%s\t val_acc:%s", epoch, val_loss,
                 val_acc)
        test_loss, test_acc = eval(test_loader, model, feature, criterion)
        cal_test_acc.append(test_acc)
        log.info("Runing epoch:%s\t test_loss:%s\t test_acc:%s", epoch,
                 test_loss, test_acc)

    log.info("Runs %s: Model: %s Best Test Accuracy: %f" %
             (0, "graphsage", cal_test_acc[np.argmax(cal_val_acc)]))
Example #24
    test_chain_list = [p.strip() for p in open(args.test_file)]

    saved_state_dict = paddle.load(args.model_name)
    # In-place assignment
    add_saved_args_and_params(args, saved_state_dict)
    test_dataset = MyDataset(
        test_chain_list,
        args.n_channels,
        args.pad_len,
        args.protein_chain_graphs,
        args.cmap_thresh,
        args.label_data_path,
        args.use_cache,
    )

    test_loader = Dataloader(test_dataset,
                             batch_size=args.batch_size,
                             collate_fn=test_dataset.collate_fn)

    args.n_labels = test_dataset.n_labels
    model = DeepFRI(args)
    model.set_state_dict(saved_state_dict["model"])
    model.eval()

    print(f"\n{args.task}: Testing on {len(test_dataset)} protein samples.")
    print(f"Starting  at {datetime.now()}\n")
    print(args)

    test(model, test_loader)
Example #25
def collate_fn(batch_data):
    graphs = []
    labels = []
    for g, l in batch_data:
        graphs.append(g)
        labels.append(l)

    labels = np.array(labels, dtype="int64").reshape(-1, 1)

    return graphs, labels


if __name__ == "__main__":
    gindataset = GINDataset("./gin_data/",
                            "MUTAG",
                            self_loop=True,
                            degree_as_nlabel=False)
    loader = Dataloader(gindataset,
                        batch_size=3,
                        shuffle=False,
                        num_workers=1,
                        collate_fn=collate_fn)
    cc = 0
    for batch in loader:
        g, label = batch
        print(label)
        cc += 1
        if cc == 2:
            break
Example #26
def infer(config):
    model = getattr(M, config.model_type)(config)

    log.info("infer model from %s" % config.infer_from)
    model.set_state_dict(paddle.load(config.infer_from))

    log.info("loading data")
    ds = getattr(DS, config.dataset_type)(config)

    split_idx = ds.get_idx_split()
    train_ds = DS.Subset(ds, split_idx['train'], mode='train')
    valid_ds = DS.Subset(ds, split_idx['valid'], mode='valid')
    test_ds = DS.Subset(ds, split_idx['test'], mode='test')

    log.info("Train exapmles: %s" % len(train_ds))
    log.info("Valid exapmles: %s" % len(valid_ds))
    log.info("Test exapmles: %s" % len(test_ds))

    train_loader = Dataloader(train_ds,
                              batch_size=config.batch_size,
                              shuffle=False,
                              num_workers=config.num_workers,
                              collate_fn=DS.CollateFn(config),
                              drop_last=True)

    valid_loader = Dataloader(valid_ds,
                              batch_size=config.valid_batch_size,
                              shuffle=False,
                              num_workers=1,
                              collate_fn=DS.CollateFn(config))

    test_loader = Dataloader(test_ds,
                             batch_size=config.valid_batch_size,
                             shuffle=False,
                             num_workers=1,
                             collate_fn=DS.CollateFn(config))

    try:
        task_name = config.infer_from.split("/")[-2]
    except IndexError:
        task_name = "ogb_kdd"
    log.info("task_name: %s" % task_name)

    ### automatic evaluator. takes dataset name as input
    evaluator = PCQM4MEvaluator()

    # ---------------- valid ----------------------- #
    #  log.info("validating ...")
    #  pred_dict = evaluate(model, valid_loader)
    #
    #  log.info("valid MAE: %s" % evaluator.eval(pred_dict)["mae"])
    #  valid_output_path = os.path.join(config.output_dir, task_name)
    #  make_dir(valid_output_path)
    #  valid_output_file = os.path.join(valid_output_path, "valid_mae.txt")
    #
    #  log.info("saving valid result to %s" % valid_output_file)
    #  with open(valid_output_file, 'w') as f:
    #      for y_pred, idx in zip(pred_dict['y_pred'], split_idx['valid']):
    #          smiles, label = ds.raw_dataset[idx]
    #          f.write("%s\t%s\t%s\n" % (y_pred, label, smiles))
    #
    # ---------------- test ----------------------- #

    log.info("testing ...")
    pred_dict = evaluate(model, test_loader)

    test_output_path = os.path.join(config.output_dir, task_name)
    make_dir(test_output_path)
    test_output_file = os.path.join(test_output_path, "test_mae.txt")

    log.info("saving test result to %s" % test_output_file)
    with open(test_output_file, 'w') as f:
        for y_pred, idx in zip(pred_dict['y_pred'], split_idx['test']):
            smiles, label = ds.raw_dataset[idx]
            f.write("%s\t%s\n" % (y_pred, smiles))

    log.info("saving submition format to %s" % test_output_path)
    evaluator.save_test_submission({'y_pred': pred_dict['y_pred']},
                                   test_output_path)
Example #27
def main(args):
    role = role_maker.PaddleCloudRoleMaker()
    fleet.init(role)
    data = pgl.dataset.RedditDataset(args.normalize, args.symmetry)
    log.info("Preprocess finish")
    log.info("Train Examples: %s" % len(data.train_index))
    log.info("Val Examples: %s" % len(data.val_index))
    log.info("Test Examples: %s" % len(data.test_index))
    log.info("Num nodes %s" % data.graph.num_nodes)
    log.info("Num edges %s" % data.graph.num_edges)
    log.info("Average Degree %s" % np.mean(data.graph.indegree()))

    graph = data.graph
    train_index = data.train_index
    val_index = data.val_index
    test_index = data.test_index

    train_label = data.train_label
    val_label = data.val_label
    test_label = data.test_label

    loss, acc = build_net(
        input_size=data.feature.shape[-1],
        num_class=data.num_classes,
        hidden_size=args.hidden_size,
        num_layers=len(args.samples))
    test_program = paddle.static.default_main_program().clone(for_test=True)

    strategy = fleet.DistributedStrategy()
    strategy.a_sync = True
    optimizer = paddle.fluid.optimizer.Adam(learning_rate=args.lr)
    optimizer = fleet.distributed_optimizer(optimizer, strategy)
    optimizer.minimize(loss)

    if fleet.is_server():
        fleet.init_server()
        fleet.run_server()
    else:
        place = paddle.CPUPlace()
        exe = paddle.static.Executor(place)
        exe.run(paddle.static.default_startup_program())
        fleet.init_worker()

        train_ds = ShardedDataset(train_index, train_label)
        valid_ds = ShardedDataset(val_index, val_label)
        test_ds = ShardedDataset(test_index, test_label)

        collate_fn = partial(batch_fn, graph=graph, samples=args.samples)

        train_loader = Dataloader(
            train_ds,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.sample_workers,
            collate_fn=collate_fn)

        valid_loader = Dataloader(
            valid_ds,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.sample_workers,
            collate_fn=collate_fn)

        test_loader = Dataloader(
            test_ds,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.sample_workers,
            collate_fn=collate_fn)

        compiled_prog, cpu_num = setup_compiled_prog(loss)

        for epoch in tqdm.tqdm(range(args.epoch)):
            train_loss, train_acc = run(train_loader,
                                        data.feature,
                                        exe,
                                        compiled_prog,
                                        loss,
                                        acc,
                                        phase="train",
                                        cpu_num=cpu_num)

            valid_loss, valid_acc = run(valid_loader,
                                        data.feature,
                                        exe,
                                        test_program,
                                        loss,
                                        acc,
                                        phase="valid",
                                        cpu_num=1)

            log.info("Epoch %s Valid-Loss %s Valid-Acc %s" %
                     (epoch, valid_loss, valid_acc))
        test_loss, test_acc = run(test_loader,
                                  data.feature,
                                  exe,
                                  test_program,
                                  loss,
                                  acc,
                                  phase="test",
                                  cpu_num=1)
        log.info("Epoch %s Test-Loss %s Test-Acc %s" %
                 (epoch, test_loss, test_acc))

        fleet.stop_worker()
Example #28
    parser.add_argument('--epochs', type=int, default=300)

    parser.add_argument("--num_convs", type=int, default=2)
    parser.add_argument("--hidden_dim", type=int, default=128)
    parser.add_argument("--infeat_dim", type=int, default=36)
    parser.add_argument("--dense_dims", type=str, default='128*4,128*2,128')

    parser.add_argument('--num_heads', type=int, default=4)
    parser.add_argument('--cut_dist', type=float, default=5.)
    parser.add_argument('--num_angle', type=int, default=6)
    parser.add_argument('--merge_b2b', type=str, default='cat')
    parser.add_argument('--merge_b2a', type=str, default='mean')

    args = parser.parse_args()
    args.activation = F.relu
    args.dense_dims = [eval(dim) for dim in args.dense_dims.split(',')]

    if int(args.cuda) == -1:
        paddle.set_device('cpu')
    else:
        paddle.set_device('gpu:%s' % args.cuda)

    tst_complex = ComplexDataset(args.data_dir, "%s_test" % args.dataset, args.cut_dist, args.num_angle)
    tst_loader = Dataloader(tst_complex, args.batch_size, shuffle=False, num_workers=1, collate_fn=collate_fn)

    model = SIGN(args)
    path = os.path.join(args.model_dir, 'saved_model')
    load_state_dict = paddle.load(path)
    model.set_state_dict(load_state_dict['model'])
    rmse_tst, mae_tst, sd_tst, r_tst = evaluate(model, tst_loader)
    print('Test - RMSE: %.6f, MAE: %.6f, SD: %.6f, R: %.6f.\n' % (rmse_tst, mae_tst, sd_tst, r_tst))
Example #29
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='GNN baselines on pcqm4m with PGL')
    parser.add_argument('--use_cuda', action='store_true')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument(
        '--gnn',
        type=str,
        default='gin-virtual',
        help=
        'GNN type: gin, gin-virtual, gcn, or gcn-virtual (default: gin-virtual)')
    parser.add_argument(
        '--graph_pooling',
        type=str,
        default='sum',
        help='graph pooling strategy mean or sum (default: sum)')
    parser.add_argument('--drop_ratio',
                        type=float,
                        default=0,
                        help='dropout ratio (default: 0)')
    parser.add_argument(
        '--num_layers',
        type=int,
        default=5,
        help='number of GNN message passing layers (default: 5)')
    parser.add_argument(
        '--emb_dim',
        type=int,
        default=600,
        help='dimensionality of hidden units in GNNs (default: 600)')
    parser.add_argument('--train_subset', action='store_true')
    parser.add_argument('--batch_size',
                        type=int,
                        default=256,
                        help='input batch size for training (default: 256)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--num_workers',
                        type=int,
                        default=1,
                        help='number of workers (default: 1)')
    parser.add_argument('--log_dir',
                        type=str,
                        default="",
                        help='tensorboard log directory')
    parser.add_argument('--checkpoint_dir',
                        type=str,
                        default='',
                        help='directory to save checkpoint')
    parser.add_argument('--save_test_dir',
                        type=str,
                        default='',
                        help='directory to save test submission file')
    args = parser.parse_args()

    print(args)

    random.seed(42)
    np.random.seed(42)
    paddle.seed(42)

    if not args.use_cuda:
        paddle.set_device("cpu")

    ### automatic dataloading and splitting
    class Config():
        def __init__(self):
            self.base_data_path = "./dataset"

    config = Config()
    ds = MolDataset(config)

    split_idx = ds.get_idx_split()
    train_ds = Subset(ds, split_idx['train'])
    valid_ds = Subset(ds, split_idx['valid'])
    test_ds = Subset(ds, split_idx['test'])

    print("Train exapmles: ", len(train_ds))
    print("Valid exapmles: ", len(valid_ds))
    print("Test exapmles: ", len(test_ds))

    ### automatic evaluator. takes dataset name as input
    evaluator = PCQM4MEvaluator()

    train_loader = Dataloader(train_ds,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              collate_fn=CollateFn())

    valid_loader = Dataloader(valid_ds,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.num_workers,
                              collate_fn=CollateFn())

    if args.save_test_dir != '':
        test_loader = Dataloader(test_ds,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 collate_fn=CollateFn())

    if args.checkpoint_dir != '':
        os.makedirs(args.checkpoint_dir, exist_ok=True)

    shared_params = {
        'num_layers': args.num_layers,
        'emb_dim': args.emb_dim,
        'drop_ratio': args.drop_ratio,
        'graph_pooling': args.graph_pooling
    }

    if args.gnn == 'gin':
        model = GNN(gnn_type='gin', virtual_node=False, **shared_params)
    elif args.gnn == 'gin-virtual':
        model = GNN(gnn_type='gin', virtual_node=True, **shared_params)
    elif args.gnn == 'gcn':
        model = GNN(gnn_type='gcn', virtual_node=False, **shared_params)
    elif args.gnn == 'gcn-virtual':
        model = GNN(gnn_type='gcn', virtual_node=True, **shared_params)
    else:
        raise ValueError('Invalid GNN type')

    num_params = sum(p.numel() for p in model.parameters())
    print(f'#Params: {num_params}')

    if args.log_dir != '':
        writer = SummaryWriter(log_dir=args.log_dir)

    best_valid_mae = 1000

    scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.001,
                                              step_size=300,
                                              gamma=0.25)

    optimizer = paddle.optimizer.Adam(learning_rate=scheduler,
                                      parameters=model.parameters())

    msg = "ogbg_lsc_paddle_baseline\n"
    for epoch in range(1, args.epochs + 1):
        print("=====Epoch {}".format(epoch))
        print('Training...')
        train_mae = train(model, train_loader, optimizer)

        print('Evaluating...')
        valid_mae = eval(model, valid_loader, evaluator)

        print({'Train': train_mae, 'Validation': valid_mae})

        if args.log_dir != '':
            writer.add_scalar('valid/mae', valid_mae, epoch)
            writer.add_scalar('train/mae', train_mae, epoch)

        if valid_mae < best_valid_mae:
            best_valid_mae = valid_mae
            if args.checkpoint_dir != '':
                print('Saving checkpoint...')
                paddle.save(
                    model.state_dict(),
                    os.path.join(args.checkpoint_dir, 'checkpoint.pdparams'))

            if args.save_test_dir != '':
                print('Predicting on test data...')
                y_pred = test(model, test_loader)
                print('Saving test submission file...')
                evaluator.save_test_submission({'y_pred': y_pred},
                                               args.save_test_dir)

        scheduler.step()

        print(f'Best validation MAE so far: {best_valid_mae}')

        try:
            msg +="Epoch: %d | Train: %.6f | Valid: %.6f | Best Valid: %.6f\n" \
                    % (epoch, train_mae, valid_mae, best_valid_mae)
            print(msg)
        except:
            continue

    if args.log_dir != '':
        writer.close()