Example #1
File: main_pyg.py Project: rpatil524/ogb
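The listing below is the main() function only; the imports it relies on sit at the top of main_pyg.py and are not part of the snippet. A minimal sketch of what they would look like, assuming the layout of the OGB code2 example (the module paths of the project-local helpers and the DataLoader location are assumptions):

import argparse
import os

import numpy as np
import pandas as pd
import torch
import torch.optim as optim
from torchvision import transforms

from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
from torch_geometric.loader import DataLoader  # torch_geometric.data.DataLoader in older PyG releases

# project-local helpers; train() and eval() are defined earlier in main_pyg.py itself
from gnn import GNN
from utils import (ASTNodeEncoder, get_vocab_mapping, augment_edge,
                   encode_y_to_arr, decode_arr_to_seq)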
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='GNN baselines on ogbg-code2 data with PyTorch Geometric')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument(
        '--gnn',
        type=str,
        default='gcn-virtual',
        help=
        'GNN to use: gin, gin-virtual, gcn, or gcn-virtual (default: gcn-virtual)')
    parser.add_argument('--drop_ratio',
                        type=float,
                        default=0,
                        help='dropout ratio (default: 0)')
    parser.add_argument('--max_seq_len',
                        type=int,
                        default=5,
                        help='maximum sequence length to predict (default: 5)')
    parser.add_argument(
        '--num_vocab',
        type=int,
        default=5000,
        help=
        'size of the vocabulary used for sequence prediction (default: 5000)'
    )
    parser.add_argument(
        '--num_layer',
        type=int,
        default=5,
        help='number of GNN message passing layers (default: 5)')
    parser.add_argument(
        '--emb_dim',
        type=int,
        default=300,
        help='dimensionality of hidden units in GNNs (default: 300)')
    parser.add_argument('--batch_size',
                        type=int,
                        default=128,
                        help='input batch size for training (default: 128)')
    parser.add_argument('--epochs',
                        type=int,
                        default=25,
                        help='number of epochs to train (default: 25)')
    parser.add_argument('--random_split', action='store_true')
    parser.add_argument('--num_workers',
                        type=int,
                        default=0,
                        help='number of workers (default: 0)')
    parser.add_argument('--dataset',
                        type=str,
                        default="ogbg-code2",
                        help='dataset name (default: ogbg-code2)')

    parser.add_argument('--filename',
                        type=str,
                        default="",
                        help='filename to output result (default: )')
    args = parser.parse_args()
    print(args)

    device = torch.device(
        "cuda:" +
        str(args.device)) if torch.cuda.is_available() else torch.device("cpu")

    ### automatic dataloading and splitting
    dataset = PygGraphPropPredDataset(name=args.dataset)

    seq_len_list = np.array([len(seq) for seq in dataset.data.y])
    print('Fraction of target sequences with length <= {}: {:.4f}'.format(
        args.max_seq_len,
        np.sum(seq_len_list <= args.max_seq_len) / len(seq_len_list)))

    split_idx = dataset.get_idx_split()

    if args.random_split:
        print('Using random split')
        perm = torch.randperm(len(dataset))
        num_train, num_valid, num_test = len(split_idx['train']), len(
            split_idx['valid']), len(split_idx['test'])
        split_idx['train'] = perm[:num_train]
        split_idx['valid'] = perm[num_train:num_train + num_valid]
        split_idx['test'] = perm[num_train + num_valid:]

        assert (len(split_idx['train']) == num_train)
        assert (len(split_idx['valid']) == num_valid)
        assert (len(split_idx['test']) == num_test)

    # print(split_idx['train'])
    # print(split_idx['valid'])
    # print(split_idx['test'])

    # train_method_name = [' '.join(dataset.data.y[i]) for i in split_idx['train']]
    # valid_method_name = [' '.join(dataset.data.y[i]) for i in split_idx['valid']]
    # test_method_name = [' '.join(dataset.data.y[i]) for i in split_idx['test']]
    # print('#train')
    # print(len(train_method_name))
    # print('#valid')
    # print(len(valid_method_name))
    # print('#test')
    # print(len(test_method_name))

    # train_method_name_set = set(train_method_name)
    # valid_method_name_set = set(valid_method_name)
    # test_method_name_set = set(test_method_name)

    # # unique method name
    # print('#unique train')
    # print(len(train_method_name_set))
    # print('#unique valid')
    # print(len(valid_method_name_set))
    # print('#unique test')
    # print(len(test_method_name_set))

    # # unique valid/test method name
    # print('#valid unseen during training')
    # print(len(valid_method_name_set - train_method_name_set))
    # print('#test unseen during training')
    # print(len(test_method_name_set - train_method_name_set))

    ### Build the vocabulary for sequence prediction. Only training data is used.

    vocab2idx, idx2vocab = get_vocab_mapping(
        [dataset.data.y[i] for i in split_idx['train']], args.num_vocab)
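    # Note (assumed from the OGB code2 example utilities): vocab2idx maps each of the
    # args.num_vocab most frequent training tokens to an integer index, and idx2vocab is
    # the inverse lookup; tokens outside the vocabulary are later encoded with a special
    # unknown-token index.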

    # test encoder and decoder
    # for data in dataset:
    #     # PyG >= 1.5.0
    #     print(data.y)
    #
    #     # PyG 1.4.3
    #     # print(data.y[0])
    #     data = encode_y_to_arr(data, vocab2idx, args.max_seq_len)
    #     print(data.y_arr[0])
    #     decoded_seq = decode_arr_to_seq(data.y_arr[0], idx2vocab)
    #     print(decoded_seq)
    #     print('')

    ## test augment_edge
    # data = dataset[2]
    # print(data)
    # data_augmented = augment_edge(data)
    # print(data_augmented)

    ### set the transform function
    # augment_edge: add next-token edges as well as inverse edges, and add edge attributes.
    # encode_y_to_arr: add y_arr to the PyG data object, i.e. the array representation of the target sequence.
    dataset.transform = transforms.Compose([
        augment_edge,
        lambda data: encode_y_to_arr(data, vocab2idx, args.max_seq_len)
    ])
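    # Assumed behavior of the two helpers: after the transform each Data object carries the
    # augmented edge_index/edge_attr and a fixed-width y_arr of length max_seq_len (the
    # target sequence padded or truncated as needed).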

    ### automatic evaluator. takes dataset name as input
    evaluator = Evaluator(args.dataset)

    train_loader = DataLoader(dataset[split_idx["train"]],
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers)
    valid_loader = DataLoader(dataset[split_idx["valid"]],
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.num_workers)
    test_loader = DataLoader(dataset[split_idx["test"]],
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.num_workers)

    nodetypes_mapping = pd.read_csv(
        os.path.join(dataset.root, 'mapping', 'typeidx2type.csv.gz'))
    nodeattributes_mapping = pd.read_csv(
        os.path.join(dataset.root, 'mapping', 'attridx2attr.csv.gz'))

    print(nodeattributes_mapping)

    ### Encoding node features into emb_dim vectors.
    ### The following three node features are used.
    # 1. node type
    # 2. node attribute
    # 3. node depth
    node_encoder = ASTNodeEncoder(args.emb_dim,
                                  num_nodetypes=len(nodetypes_mapping['type']),
                                  num_nodeattributes=len(
                                      nodeattributes_mapping['attr']),
                                  max_depth=20)
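    # Assumed behavior of ASTNodeEncoder (per the OGB code2 example): it sums three
    # embeddings -- node type (x[:, 0]), node attribute (x[:, 1]), and node depth capped
    # at max_depth -- into a single emb_dim-dimensional vector per AST node.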

    if args.gnn == 'gin':
        model = GNN(num_vocab=len(vocab2idx),
                    max_seq_len=args.max_seq_len,
                    node_encoder=node_encoder,
                    num_layer=args.num_layer,
                    gnn_type='gin',
                    emb_dim=args.emb_dim,
                    drop_ratio=args.drop_ratio,
                    virtual_node=False).to(device)
    elif args.gnn == 'gin-virtual':
        model = GNN(num_vocab=len(vocab2idx),
                    max_seq_len=args.max_seq_len,
                    node_encoder=node_encoder,
                    num_layer=args.num_layer,
                    gnn_type='gin',
                    emb_dim=args.emb_dim,
                    drop_ratio=args.drop_ratio,
                    virtual_node=True).to(device)
    elif args.gnn == 'gcn':
        model = GNN(num_vocab=len(vocab2idx),
                    max_seq_len=args.max_seq_len,
                    node_encoder=node_encoder,
                    num_layer=args.num_layer,
                    gnn_type='gcn',
                    emb_dim=args.emb_dim,
                    drop_ratio=args.drop_ratio,
                    virtual_node=False).to(device)
    elif args.gnn == 'gcn-virtual':
        model = GNN(num_vocab=len(vocab2idx),
                    max_seq_len=args.max_seq_len,
                    node_encoder=node_encoder,
                    num_layer=args.num_layer,
                    gnn_type='gcn',
                    emb_dim=args.emb_dim,
                    drop_ratio=args.drop_ratio,
                    virtual_node=True).to(device)
    else:
        raise ValueError('Invalid GNN type')

    optimizer = optim.Adam(model.parameters(), lr=0.001)

    print(f'#Params: {sum(p.numel() for p in model.parameters())}')

    valid_curve = []
    test_curve = []
    train_curve = []

    for epoch in range(1, args.epochs + 1):
        print("=====Epoch {}".format(epoch))
        print('Training...')
        train(model, device, train_loader, optimizer)

        print('Evaluating...')
        train_perf = eval(
            model,
            device,
            train_loader,
            evaluator,
            arr_to_seq=lambda arr: decode_arr_to_seq(arr, idx2vocab))
        valid_perf = eval(
            model,
            device,
            valid_loader,
            evaluator,
            arr_to_seq=lambda arr: decode_arr_to_seq(arr, idx2vocab))
        test_perf = eval(
            model,
            device,
            test_loader,
            evaluator,
            arr_to_seq=lambda arr: decode_arr_to_seq(arr, idx2vocab))

        print({
            'Train': train_perf,
            'Validation': valid_perf,
            'Test': test_perf
        })

        train_curve.append(train_perf[dataset.eval_metric])
        valid_curve.append(valid_perf[dataset.eval_metric])
        test_curve.append(test_perf[dataset.eval_metric])

    print('F1')
    best_val_epoch = np.argmax(np.array(valid_curve))
    best_train = max(train_curve)
    print('Finished training!')
    print('Best validation score: {}'.format(valid_curve[best_val_epoch]))
    print('Test score: {}'.format(test_curve[best_val_epoch]))

    if not args.filename == '':
        result_dict = {
            'Val': valid_curve[best_val_epoch],
            'Test': test_curve[best_val_epoch],
            'Train': train_curve[best_val_epoch],
            'BestTrain': best_train
        }
        torch.save(result_dict, args.filename)
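With the imports sketched above in place, the example can be run from the command line; the flags simply mirror the argparse definitions in the snippet:

python main_pyg.py --gnn gcn-virtual --batch_size 128 --epochs 25 --device 0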
Example #2
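This variant adapts the same ogbg-code baseline to a DAGNN-style setup: model construction goes through init_model, the model is wrapped in DataParallel, and fold/checkpoint bookkeeping is added. As in Example #1, the imports are not shown; a rough sketch of what the function assumes is available (module paths, the DIR_* values, and the home of the custom DataLoader are all assumptions):

import argparse
import os
import random

import numpy as np
import pandas as pd
import torch
from torchvision import transforms
from torch_geometric.nn import DataParallel

from ogb.graphproppred import PygGraphPropPredDataset, Evaluator

# project-local helpers and constants; the exact modules are not known from the snippet.
# DataLoader here is a project-specific loader that accepts an n_devices argument, and
# train()/eval() are defined in the same file as main().
from utils import (get_vocab_mapping, augment_edge, augment_edge2, encode_y_to_arr,
                   decode_arr_to_seq, DataLoader, init_model, load_checkpoint,
                   load_checkpoint_results, create_checkpoint, remove_checkpoint,
                   summary_report, DIR_RESULTS, DIR_SAVED_MODELS)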
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='GNN baselines on ogbg-code data with PyTorch Geometric')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument(
        '--gnn',
        type=str,
        default="mostperfect",  #M_DAGNN_GRU,
        help=
        'GNN to use: gin, gin-virtual, gcn, gcn-virtual, or a DAGNN variant (default: mostperfect)')
    parser.add_argument('--drop_ratio',
                        type=float,
                        default=0,
                        help='dropout ratio (default: 0)')
    parser.add_argument('--max_seq_len',
                        type=int,
                        default=5,
                        help='maximum sequence length to predict (default: 5)')
    parser.add_argument(
        '--num_vocab',
        type=int,
        default=5000,
        help=
        'size of the vocabulary used for sequence prediction (default: 5000)'
    )
    parser.add_argument(
        '--num_layer',
        type=int,
        default=5,
        help='number of GNN message passing layers (default: 5)')
    parser.add_argument(
        '--emb_dim',
        type=int,
        default=300,
        help='dimensionality of hidden units in GNNs (default: 300)')
    parser.add_argument('--batch_size',
                        type=int,
                        default=1,
                        help='input batch size for training (default: 1)')
    parser.add_argument('--epochs',
                        type=int,
                        default=30,
                        help='number of epochs to train (default: 30)')
    parser.add_argument('--num_workers',
                        type=int,
                        default=0,
                        help='number of workers (default: 0)')
    parser.add_argument('--dataset',
                        type=str,
                        default="ogbg-code",
                        help='dataset name (default: ogbg-code)')

    parser.add_argument('--filename',
                        type=str,
                        default="test",
                        help='filename to output result (default: test)')

    parser.add_argument('--dir_data', type=str, default=None, help='... dir')
    parser.add_argument('--dir_results',
                        type=str,
                        default=DIR_RESULTS,
                        help='results dir')
    parser.add_argument('--dir_save',
                        default=DIR_SAVED_MODELS,
                        help='directory to save checkpoints in')
    parser.add_argument('--train_idx', default="", help='...')
    parser.add_argument('--checkpointing',
                        default=1,
                        type=int,
                        choices=[0, 1],
                        help='...')
    parser.add_argument('--checkpoint', default="", help='...')
    parser.add_argument('--folds', default=10, type=int, help='...')
    parser.add_argument('--clip', default=0, type=float, help='...')
    parser.add_argument('--lr',
                        default=1e-3,
                        type=float,
                        help='learning rate (default: 1e-3)')
    parser.add_argument('--patience',
                        default=20,
                        type=float,
                        help='early-stopping patience in epochs (default: 20)')
    ###

    args = parser.parse_args()
    # hard-coded overrides for a single-fold, single-epoch smoke run
    args.folds = 1
    args.epochs = 1
    # checkpointing is disabled: the optimizer below is left as None, so saving/restoring it would fail
    args.checkpointing = 0
    device = torch.device(
        "cuda:" +
        str(args.device)) if torch.cuda.is_available() else torch.device("cpu")

    os.makedirs(args.dir_results, exist_ok=True)
    os.makedirs(args.dir_save, exist_ok=True)

    train_file = os.path.join(args.dir_results, args.filename + '_train.csv')
    if not os.path.exists(train_file):
        with open(train_file, 'w') as f:
            f.write("fold,epoch,loss,train,valid,test\n")
    res_file = os.path.join(args.dir_results, args.filename + '.csv')
    if not os.path.exists(res_file):
        with open(res_file, 'w') as f:
            f.write("fold,epoch,bestv_train,bestv_valid,bestv_test\n")

    ### automatic dataloading and splitting
    dataset = PygGraphPropPredDataset(
        name=args.dataset,
        root="dataset" if args.dir_data is None else args.dir_data)

    seq_len_list = np.array([len(seq) for seq in dataset.data.y])
    print('Fraction of target sequences with length <= {}: {:.4f}'.format(
        args.max_seq_len,
        np.sum(seq_len_list <= args.max_seq_len) / len(seq_len_list)))

    split_idx = dataset.get_idx_split()

    if args.train_idx:
        train_idx = pd.read_csv(os.path.join("dataset",
                                             args.train_idx + ".csv.gz"),
                                compression="gzip",
                                header=None).values.T[0]
        train_idx = torch.tensor(train_idx, dtype=torch.long)
        split_idx['train'] = train_idx

    ### Build the vocabulary for sequence prediction. Only training data is used.

    vocab2idx, idx2vocab = get_vocab_mapping(
        [dataset.data.y[i] for i in split_idx['train']], args.num_vocab)

    # if not torch.cuda.is_available():
    #     split_idx['valid'] = list(range(50, 60))
    #     split_idx['test'] = list(range(60, 70))
    # pass

    ### set the transform function
    # augment_edge: add next-token edges as well as inverse edges, and add edge attributes.
    # encode_y_to_arr: add y_arr to the PyG data object, i.e. the array representation of the target sequence.
    # DAGNN
    augment = augment_edge2 if "dagnn" in args.gnn else augment_edge
    dataset.transform = transforms.Compose([
        augment,
        lambda data: encode_y_to_arr(data, vocab2idx, args.max_seq_len)
    ])

    ### automatic evaluator. takes dataset name as input
    evaluator = Evaluator(args.dataset)

    nodetypes_mapping = pd.read_csv(
        os.path.join(dataset.root, 'mapping', 'typeidx2type.csv.gz'))
    nodeattributes_mapping = pd.read_csv(
        os.path.join(dataset.root, 'mapping', 'attridx2attr.csv.gz'))
    ### Encoding node features into emb_dim vectors.
    ### The following three node features are used.
    # 1. node type
    # 2. node attribute
    # 3. node depth
    # node_encoder = ASTNodeEncoder(args.emb_dim, num_nodetypes = len(nodetypes_mapping['type']), num_nodeattributes = len(nodeattributes_mapping['attr']), max_depth = 20)

    start_fold = 1
    checkpoint_fn = ""
    train_results, valid_results, test_results = [], [], []  # on fold level

    if args.checkpointing and args.checkpoint:
        s = args.checkpoint[:-3].split("_")
        start_fold = int(s[-2])
        start_epoch = int(s[-1]) + 1

        checkpoint_fn = os.path.join(
            args.dir_save, args.checkpoint)  # need to remove it in any case

        if start_epoch > args.epochs:  # discard the checkpoint's model (but keep its results); a new model is needed
            args.checkpoint = ""
            start_fold += 1

            results = load_checkpoint_results(checkpoint_fn)
            train_results, valid_results, test_results, train_curve, valid_curve, test_curve = results

    # start
    for fold in range(start_fold, args.folds + 1):
        # fold-specific settings & data splits
        torch.manual_seed(fold)
        random.seed(fold)
        np.random.seed(fold)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(fold)
            torch.backends.cudnn.benchmark = True
            # torch.backends.cudnn.deterministic = True
            # torch.backends.cudnn.benchmark = False

        n_devices = max(torch.cuda.device_count(), 1)
        # train_loader = DataLoader(dataset[split_idx["train"]], batch_size=args.batch_size, shuffle=True,
        #                           num_workers = args.num_workers, n_devices=n_devices)
        valid_loader = DataLoader(dataset[split_idx["valid"]],
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.num_workers,
                                  n_devices=n_devices)
        test_loader = DataLoader(dataset[split_idx["test"]],
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 n_devices=n_devices)

        start_epoch = 1

        # model etc.
        model = init_model(args, vocab2idx, nodeattributes_mapping, idx2vocab)

        print("Let's use", torch.cuda.device_count(),
              "GPUs! -- DataParallel running also on CPU only")
        n_gpus = torch.cuda.device_count()
        device_ids = list(range(n_gpus)) if n_gpus > 0 else None
        model = DataParallel(model, device_ids)
        model.to(device)

        optimizer = None  # optim.Adam(model.parameters(), lr=args.lr) is disabled in this stripped-down run

        # overwrite some settings
        if args.checkpointing and args.checkpoint:
            # signal that it has been used
            args.checkpoint = ""

            results, start_epoch, model, optimizer = load_checkpoint(
                checkpoint_fn, model, optimizer)
            train_results, valid_results, test_results, train_curve, valid_curve, test_curve = results
            start_epoch += 1
        else:
            valid_curve, test_curve, train_curve = [], [], []

        # start new epoch
        for epoch in range(start_epoch, args.epochs + 1):
            old_checkpoint_fn = checkpoint_fn
            checkpoint_fn = '%s.pt' % os.path.join(
                args.dir_save,
                args.filename + "_" + str(fold) + "_" + str(epoch))

            print("=====Fold {}, Epoch {}".format(fold, epoch))
            # training and validation evaluation are stubbed out here; only the test-set
            # evaluation below is actually executed.
            # loss, train_perf = train(model, device, train_loader, optimizer, args, evaluator,
            #                          arr_to_seq=lambda arr: decode_arr_to_seq(arr, idx2vocab),
            #                          vocab2index=vocab2idx)
            loss, train_perf = 0, {"F1": 0}
            # valid_perf = eval(model, device, valid_loader, evaluator,
            #                   arr_to_seq=lambda arr: decode_arr_to_seq(arr, idx2vocab),
            #                   vocab2index=vocab2idx)
            valid_perf = {"F1": 0}
            test_perf = eval(
                model,
                device,
                test_loader,
                evaluator,
                arr_to_seq=lambda arr: decode_arr_to_seq(arr, idx2vocab),
                vocab2index=vocab2idx)

            print({
                'Train': train_perf,
                'Validation': valid_perf,
                'Test': test_perf
            })
            with open(train_file, 'a') as f:
                f.write("{},{},{:.4f},{:.4f},{:.4f},{:.4f}\n".format(
                    fold, epoch, loss, train_perf[dataset.eval_metric],
                    valid_perf[dataset.eval_metric],
                    test_perf[dataset.eval_metric]))

            train_curve.append(train_perf[dataset.eval_metric])
            valid_curve.append(valid_perf[dataset.eval_metric])
            test_curve.append(test_perf[dataset.eval_metric])

            ### DAGNN
            if args.checkpointing:
                create_checkpoint(checkpoint_fn, epoch, model, optimizer,
                                  (train_results, valid_results, test_results,
                                   train_curve, valid_curve, test_curve))
                if fold > 1 or epoch > 1:
                    remove_checkpoint(old_checkpoint_fn)

            best_val_epoch = np.argmax(np.array(valid_curve))
            if args.patience > 0 and best_val_epoch + 1 + args.patience < epoch:
                print("Early stopping!")
                break

        print('Finished training for fold {} !'.format(fold) + "*" * 20)
        print('Best validation score: {}'.format(valid_curve[best_val_epoch]))
        print('Test score: {}'.format(test_curve[best_val_epoch]))

        with open(res_file, 'a') as f:
            results = [
                fold, best_val_epoch, train_curve[best_val_epoch],
                valid_curve[best_val_epoch], test_curve[best_val_epoch]
            ]
            f.writelines(",".join([str(v) for v in results]) + "\n")

        train_results += [train_curve[best_val_epoch]]
        valid_results += [valid_curve[best_val_epoch]]
        test_results += [test_curve[best_val_epoch]]

        results = list(summary_report(train_results)) + list(
            summary_report(valid_results)) + list(summary_report(test_results))
        # with open(res_file, 'a') as f:
        #     f.writelines(str(fold)+ ",_," + ",".join([str(v) for v in results]) + "\n")
        print(",".join([str(v) for v in results]))

    results = list(summary_report(train_results)) + list(
        summary_report(valid_results)) + list(summary_report(test_results))
    with open(res_file, 'a') as f:
        f.writelines(
            str(fold) + ",_," + ",".join([str(v) for v in results]) + "\n")