Example #1
# __init__ of an evaluation helper: cache the test split as ready-made batches.
def __init__(self, dataset_name, dev=torch.device('cpu')):
    dataset = LoadData(dataset_name)
    test_loader = DataLoader(dataset.test,
                             shuffle=False,
                             drop_last=False,
                             collate_fn=dataset.collate)
    self.test_data_ls = list(test_loader)  # materialize all test batches once
    self.device = dev
def main(config):
    """
        USER CONTROLS
    """
    # parameters
    params = config['params']

    # dataset
    DATASET_NAME = config['dataset']
    dataset = LoadData(DATASET_NAME)

    # device
    device = gpu_setup(config['gpu']['use'], config['gpu']['id'])
    out_dir = config['out_dir']

    # GNN model
    MODEL_NAME = config['model']

    # network parameters
    net_params = config['net_params']
    net_params['device'] = device
    net_params['gpu_id'] = config['gpu']['id']
    net_params['batch_size'] = params['batch_size']

    # TSP
    net_params['in_dim'] = dataset.train[0][0].ndata['feat'][0].shape[0]
    net_params['in_dim_edge'] = dataset.train[0][0].edata['feat'][0].size(0)
    num_classes = len(np.unique(np.concatenate(dataset.train[:][1])))
    net_params['n_classes'] = num_classes

    if MODEL_NAME == 'RingGNN':
        num_nodes = [
            dataset.train[i][0].number_of_nodes()
            for i in range(len(dataset.train))
        ]
        net_params['avg_node_num'] = int(np.ceil(np.mean(num_nodes)))

    root_log_dir = out_dir + 'logs/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    root_ckpt_dir = out_dir + 'checkpoints/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    write_file_name = out_dir + 'results/result_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    write_config_file = out_dir + 'configs/config_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    dirs = root_log_dir, root_ckpt_dir, write_file_name, write_config_file

    if not os.path.exists(out_dir + 'results'):
        os.makedirs(out_dir + 'results')

    if not os.path.exists(out_dir + 'configs'):
        os.makedirs(out_dir + 'configs')

    net_params['total_param'] = view_model_param(MODEL_NAME, net_params)
    train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs)
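
For reference, a minimal sketch of the config dict this main() expects, inferred only from the keys read above; every value is illustrative, not a repo default:

config = {
    'dataset': 'TSP',                    # illustrative
    'model': 'GatedGCN',                 # illustrative
    'out_dir': 'out/TSP/',               # illustrative
    'gpu': {'use': False, 'id': 0},
    'params': {'batch_size': 128},       # plus any training params used downstream
    'net_params': {},                    # main() fills in device, dims, classes, ...
}
# main(config)
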
Example #3
def main():
    """
        USER CONTROLS
    """

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config',
        help="Please give a config.json file with training/model/data/param details")
    parser.add_argument('--framework',
                        type=str,
                        default=None,
                        help="Please give a framework to use")
    parser.add_argument('--gpu_id', help="Please give a value for gpu id")
    parser.add_argument('--model', help="Please give a value for model name")
    parser.add_argument('--dataset',
                        help="Please give a value for dataset name")
    parser.add_argument('--out_dir', help="Please give a value for out_dir")
    parser.add_argument('--seed', help="Please give a value for seed")
    parser.add_argument('--epochs', help="Please give a value for epochs")
    parser.add_argument('--batch_size',
                        help="Please give a value for batch_size")
    parser.add_argument('--init_lr', help="Please give a value for init_lr")
    parser.add_argument('--lr_reduce_factor',
                        help="Please give a value for lr_reduce_factor")
    parser.add_argument('--lr_schedule_patience',
                        help="Please give a value for lr_schedule_patience")
    parser.add_argument('--min_lr', help="Please give a value for min_lr")
    parser.add_argument('--weight_decay',
                        help="Please give a value for weight_decay")
    parser.add_argument('--print_epoch_interval',
                        help="Please give a value for print_epoch_interval")
    parser.add_argument('--L', help="Please give a value for L")
    parser.add_argument('--hidden_dim',
                        help="Please give a value for hidden_dim")
    parser.add_argument('--out_dim', help="Please give a value for out_dim")
    parser.add_argument('--residual', help="Please give a value for residual")
    parser.add_argument('--edge_feat',
                        help="Please give a value for edge_feat")
    parser.add_argument('--readout', help="Please give a value for readout")
    parser.add_argument('--kernel', help="Please give a value for kernel")
    parser.add_argument('--n_heads', help="Please give a value for n_heads")
    parser.add_argument('--gated', help="Please give a value for gated")
    parser.add_argument('--in_feat_dropout',
                        help="Please give a value for in_feat_dropout")
    parser.add_argument('--dropout', help="Please give a value for dropout")
    parser.add_argument('--layer_norm',
                        help="Please give a value for layer_norm")
    parser.add_argument('--batch_norm',
                        help="Please give a value for batch_norm")
    parser.add_argument('--sage_aggregator',
                        help="Please give a value for sage_aggregator")
    parser.add_argument('--data_mode',
                        help="Please give a value for data_mode")
    parser.add_argument('--num_pool', help="Please give a value for num_pool")
    parser.add_argument('--gnn_per_block',
                        help="Please give a value for gnn_per_block")
    parser.add_argument('--embedding_dim',
                        help="Please give a value for embedding_dim")
    parser.add_argument('--pool_ratio',
                        help="Please give a value for pool_ratio")
    parser.add_argument('--linkpred', help="Please give a value for linkpred")
    parser.add_argument('--cat', help="Please give a value for cat")
    parser.add_argument('--self_loop',
                        help="Please give a value for self_loop")
    parser.add_argument('--max_time', help="Please give a value for max_time")
    parser.add_argument('--pos_enc_dim',
                        help="Please give a value for pos_enc_dim")
    parser.add_argument('--pos_enc', help="Please give a value for pos_enc")
    args = parser.parse_args()
    with open(args.config) as f:
        config = json.load(f)

    # device
    if args.gpu_id is not None:
        config['gpu']['id'] = int(args.gpu_id)
        config['gpu']['use'] = True
    device = gpu_setup(config['gpu']['use'], config['gpu']['id'])
    # model, dataset, out_dir
    if args.model is not None:
        MODEL_NAME = args.model
    else:
        MODEL_NAME = config['model']
    if args.dataset is not None:
        DATASET_NAME = args.dataset
    else:
        DATASET_NAME = config['dataset']
    # parameters
    params = config['params']
    params['framework'] = 'pyg' if MODEL_NAME.endswith('pyg') else 'dgl'
    if args.framework is not None:
        params['framework'] = str(args.framework)
    dataset = LoadData(DATASET_NAME, framework=params['framework'])
    if args.out_dir is not None:
        out_dir = args.out_dir
    else:
        out_dir = config['out_dir']
    if args.seed is not None:
        params['seed'] = int(args.seed)
    if args.epochs is not None:
        params['epochs'] = int(args.epochs)
    if args.batch_size is not None:
        params['batch_size'] = int(args.batch_size)
    if args.init_lr is not None:
        params['init_lr'] = float(args.init_lr)
    if args.lr_reduce_factor is not None:
        params['lr_reduce_factor'] = float(args.lr_reduce_factor)
    if args.lr_schedule_patience is not None:
        params['lr_schedule_patience'] = int(args.lr_schedule_patience)
    if args.min_lr is not None:
        params['min_lr'] = float(args.min_lr)
    if args.weight_decay is not None:
        params['weight_decay'] = float(args.weight_decay)
    if args.print_epoch_interval is not None:
        params['print_epoch_interval'] = int(args.print_epoch_interval)
    if args.max_time is not None:
        params['max_time'] = float(args.max_time)

    # network parameters
    net_params = config['net_params']
    net_params['device'] = device
    net_params['gpu_id'] = config['gpu']['id']
    net_params['batch_size'] = params['batch_size']
    if args.L is not None:
        net_params['L'] = int(args.L)
    if args.hidden_dim is not None:
        net_params['hidden_dim'] = int(args.hidden_dim)
    if args.out_dim is not None:
        net_params['out_dim'] = int(args.out_dim)
    if args.residual is not None:
        net_params['residual'] = True if args.residual == 'True' else False
    if args.edge_feat is not None:
        net_params['edge_feat'] = True if args.edge_feat == 'True' else False
    if args.readout is not None:
        net_params['readout'] = args.readout
    if args.kernel is not None:
        net_params['kernel'] = int(args.kernel)
    if args.n_heads is not None:
        net_params['n_heads'] = int(args.n_heads)
    if args.gated is not None:
        net_params['gated'] = True if args.gated == 'True' else False
    if args.in_feat_dropout is not None:
        net_params['in_feat_dropout'] = float(args.in_feat_dropout)
    if args.dropout is not None:
        net_params['dropout'] = float(args.dropout)
    if args.layer_norm is not None:
        net_params['layer_norm'] = True if args.layer_norm == 'True' else False
    if args.batch_norm is not None:
        net_params['batch_norm'] = True if args.batch_norm == 'True' else False
    if args.sage_aggregator is not None:
        net_params['sage_aggregator'] = args.sage_aggregator
    if args.data_mode is not None:
        net_params['data_mode'] = args.data_mode
    if args.num_pool is not None:
        net_params['num_pool'] = int(args.num_pool)
    if args.gnn_per_block is not None:
        net_params['gnn_per_block'] = int(args.gnn_per_block)
    if args.embedding_dim is not None:
        net_params['embedding_dim'] = int(args.embedding_dim)
    if args.pool_ratio is not None:
        net_params['pool_ratio'] = float(args.pool_ratio)
    if args.linkpred is not None:
        net_params['linkpred'] = True if args.linkpred == 'True' else False
    if args.cat is not None:
        net_params['cat'] = True if args.cat == 'True' else False
    if args.self_loop is not None:
        net_params['self_loop'] = True if args.self_loop == 'True' else False
    if args.pos_enc is not None:
        net_params['pos_enc'] = True if args.pos_enc == 'True' else False
    if args.pos_enc_dim is not None:
        net_params['pos_enc_dim'] = int(args.pos_enc_dim)

    # SBM
    # net_params['in_dim'] = torch.unique(dataset.train[0][0].ndata['feat'],dim=0).size(0) # node_dim (feat is an integer)
    # net_params['n_classes'] = torch.unique(dataset.train[0][1],dim=0).size(0)
    if params['framework'] == 'pyg':
        net_params['in_dim'] = torch.unique(dataset.train[0].x, dim=0).size(0)
        net_params['n_classes'] = torch.unique(dataset.train[0].y, dim=0).size(0)
    else:
        net_params['in_dim'] = torch.unique(dataset.train[0][0].ndata['feat'], dim=0).size(0)
        net_params['n_classes'] = torch.unique(dataset.train[0][1], dim=0).size(0)

    if MODEL_NAME == 'RingGNN':
        num_nodes = [
            dataset.train[i][0].number_of_nodes()
            for i in range(len(dataset.train))
        ]
        net_params['avg_node_num'] = int(np.ceil(np.mean(num_nodes)))

    root_log_dir = out_dir + 'logs/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    root_ckpt_dir = out_dir + 'checkpoints/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    write_file_name = out_dir + 'results/result_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    write_config_file = out_dir + 'configs/config_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    dirs = root_log_dir, root_ckpt_dir, write_file_name, write_config_file

    if not os.path.exists(out_dir + 'results'):
        os.makedirs(out_dir + 'results')

    if not os.path.exists(out_dir + 'configs'):
        os.makedirs(out_dir + 'configs')

    net_params['total_param'] = view_model_param(MODEL_NAME, net_params)
    train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs)
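
The repeated boolean parsing above (True if args.x == 'True' else False) can be collapsed; a hedged helper sketch (the name str2bool is ours, not the repo's):

def str2bool(value):
    # These CLIs pass booleans as the literal strings 'True' / 'False'.
    return value == 'True'

# e.g. net_params['residual'] = str2bool(args.residual)
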
Example #4
                        type=int,
                        default=4,
                        help='Number of WL-iterations')
    parser.add_argument('-bs', '--batch_size', type=int, default=128)
    parser.add_argument('--dataset', default='ZINC')
    parser.add_argument('--net', required=True)
    parser.add_argument('--seed', default=1, type=int)

    return parser.parse_args()


args = get_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
torch.backends.cudnn.benchmark = True

dataset = LoadData(args.dataset)

config = 'configs/molecules_graph_regression_nns.json'
with open(config) as inf:
    config = json.load(inf)

net_params = config[args.net]

params = {
    "seed": args.seed,
    "epochs": 1000,
    "init_lr": 1e-3,
    "lr_reduce_factor": 0.5,
    "lr_schedule_patience": 5,
    "min_lr": 1e-5,
    "weight_decay": 0.0,
Example #5
def start(args):
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    cudnn.benchmark = True
    cudnn.enabled = True
    logging.info("args = %s", args)

    dataset = LoadData(args.data_name)
    if args.data_name == 'SBM_PATTERN':
        in_dim = 3
        num_classes = 2
    elif args.data_name == 'SBM_CLUSTER':
        in_dim = 7
        num_classes = 6
    else:
        raise ValueError(f"unsupported dataset: {args.data_name}")
    print(f"input dimension: {in_dim}, number of classes: {num_classes}")

    criterion = MyCriterion(num_classes)
    criterion = criterion.cuda()

    model = Network(args.layers, args.nodes, in_dim, args.feature_dim, num_classes, criterion, args.data_type, args.readout)
    model = model.cuda()
    logging.info("param size = %fMB", count_parameters_in_MB(model))

    train_data, val_data, test_data = dataset.train, dataset.val, dataset.test

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))
    print(f"train set full size : {num_train}; split train set size : {split}")
    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=args.workers,
        collate_fn=dataset.collate)

    valid_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
        pin_memory=True,
        num_workers=args.workers,
        collate_fn=dataset.collate)

    true_valid_queue = torch.utils.data.DataLoader(
        val_data, batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.workers,
        collate_fn=dataset.collate)

    test_queue = torch.utils.data.DataLoader(
        test_data, batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.workers,
        collate_fn=dataset.collate)

    optimizer = torch.optim.SGD(model.parameters(), args.learning_rate,
                                momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.learning_rate_min)
    architect = Architect(model, args)

    # viz = Visdom(env = '{} {}'.format(args.data_name,  time.asctime(time.localtime(time.time()))  ))
    viz = None
    save_file = open(args.save_result, "w")
    for epoch in range(args.epochs):
        scheduler.step()
        lr = scheduler.get_lr()[0]
        logging.info('[LR]\t%f', lr)

        if epoch % args.save_freq == 0:
            print(model.show_genotypes())
            save_file.write(f"Epoch : {epoch}\n{model.show_genotypes()}\n")
            for i in range(args.layers):
                logging.info('layer = %d', i)
                genotype = model.show_genotype(i)
                logging.info('genotype = %s', genotype)
            '''
            w1, w2, w3 = model.show_weights(0)
            print('[1] weights in first cell\n',w1)
            print('[2] weights in middle cell\n', w2)
            print('[3] weights in last cell\n', w3)
            '''
        # training
        macro_acc, micro_acc, loss = train(train_queue, valid_queue, model, architect, criterion, optimizer, lr, epoch, viz)
        # true validation
        macro_acc, micro_acc, loss = infer(true_valid_queue, model, criterion, stage='validating')
        # testing
        macro_acc, micro_acc, loss = infer(test_queue, model, criterion, stage='testing')
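
start() above carves the labelled training set into a search-train half and a search-validation half with SubsetRandomSampler; the same split in isolation, with made-up sizes and a tensor standing in for the dataset:

import numpy as np
import torch

train_portion = 0.5                      # illustrative
num_train = 100                          # illustrative dataset size
indices = list(range(num_train))
split = int(np.floor(train_portion * num_train))

data = torch.arange(num_train, dtype=torch.float32).unsqueeze(1)  # stand-in dataset
train_queue = torch.utils.data.DataLoader(
    data, batch_size=10,
    sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]))
valid_queue = torch.utils.data.DataLoader(
    data, batch_size=10,
    sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]))
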
Example #6
def start(args):
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    cudnn.enabled = True
    cudnn.benchmark = True
    logging.info("args = %s", args)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    dataset = LoadData(args.data_name)
    in_dim = dataset.train[0][0].ndata['feat'][0].size(0)
    num_classes = len(np.unique(np.array(dataset.train[:][1])))
    print(f"=> input dimension: {in_dim}, number classes: {num_classes}")
    
    if args.data_name == 'MNIST':
        genotype = MNIST_Net
    elif args.data_name == 'CIFAR10':
        genotype = CIFAR10_Net
    else:
        print("Unknown dataset.")
        sys.exit(1)

    print('=> loading from genotype: \n', genotype)
    model = Network(args, genotype, num_classes, in_dim, criterion)
    model = model.cuda()
    logging.info("param size = %fMB", count_parameters_in_MB(model))

    train_data, val_data, test_data = dataset.train, dataset.val, dataset.test

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.workers,
        collate_fn=dataset.collate,
        shuffle=True)  # added shuffle

    valid_queue = torch.utils.data.DataLoader(
        val_data, batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.workers,
        collate_fn=dataset.collate,
        shuffle=False)

    test_queue = torch.utils.data.DataLoader(
        test_data, batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.workers,
        collate_fn=dataset.collate,
        shuffle=False)
    
    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), args.learning_rate,
                                    momentum=args.momentum, weight_decay=args.weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, float(args.epochs), eta_min=args.learning_rate_min)
    elif args.optimizer == 'ADAM':
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=3e-6)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                               factor=0.5,
                                                               patience=5,
                                                               verbose=True)
    else:
        raise ValueError(f"unsupported optimizer: {args.optimizer}")

    for epoch in range(args.epochs):
        logging.info('[EPOCH]\t%d', epoch)
        if args.optimizer == 'SGD':
            scheduler.step()
            lr = scheduler.get_lr()[0]
            logging.info('[LR]\t%f', lr)

        # training
        train_acc, train_obj = train(train_queue, model, criterion, optimizer)
        # validation
        valid_acc, valid_obj = infer(valid_queue, model, criterion, stage='validating')
        # testing
        test_acc, test_obj = infer(test_queue, model, criterion, stage='testing')
        desc = '[train] acc: {:.3f}, loss: {:.3f}\t[validate] acc:{:.3f}, loss: {:.3f}\t[test] acc: {:.3f}, loss: {:.3f}'.format(
            train_acc, train_obj, valid_acc, valid_obj, test_acc, test_obj
        )
        logging.info(desc)

        if args.optimizer == 'ADAM':
            scheduler.step(valid_obj)
            if optimizer.param_groups[0]['lr'] < 1e-5:
                print("\n!! LR SMALLER OR EQUAL TO MIN LR THRESHOLD.")
                break
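
Both start() variants repeat the same seeding preamble; a hedged helper capturing it (the name and factoring are ours):

import numpy as np
import torch
import torch.backends.cudnn as cudnn

def set_seed(seed):
    # Seed numpy and torch (CPU and current GPU) and enable cudnn autotuning.
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    cudnn.enabled = True
    cudnn.benchmark = True
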
def main():
    """
        USER CONTROLS
    """

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', help="Please give a config.json file with training/model/data/param details")
    parser.add_argument('--gpu_id', help="Please give a value for gpu id")
    parser.add_argument('--model', help="Please give a value for model name")
    parser.add_argument('--dataset', help="Please give a value for dataset name")
    parser.add_argument('--out_dir', help="Please give a value for out_dir")
    parser.add_argument('--seed', help="Please give a value for seed")
    parser.add_argument('--epochs', help="Please give a value for epochs")
    parser.add_argument('--batch_size', help="Please give a value for batch_size")
    parser.add_argument('--init_lr', help="Please give a value for init_lr")
    parser.add_argument('--lr_reduce_factor', help="Please give a value for lr_reduce_factor")
    parser.add_argument('--lr_schedule_patience', help="Please give a value for lr_schedule_patience")
    parser.add_argument('--min_lr', help="Please give a value for min_lr")
    parser.add_argument('--weight_decay', help="Please give a value for weight_decay")
    parser.add_argument('--print_epoch_interval', help="Please give a value for print_epoch_interval")    
    parser.add_argument('--L', help="Please give a value for L")
    parser.add_argument('--hidden_dim', help="Please give a value for hidden_dim")
    parser.add_argument('--out_dim', help="Please give a value for out_dim")
    parser.add_argument('--residual', help="Please give a value for residual")
    parser.add_argument('--edge_feat', help="Please give a value for edge_feat")
    parser.add_argument('--readout', help="Please give a value for readout")
    parser.add_argument('--kernel', help="Please give a value for kernel")
    parser.add_argument('--n_heads', help="Please give a value for n_heads")
    parser.add_argument('--gated', help="Please give a value for gated")
    parser.add_argument('--in_feat_dropout', help="Please give a value for in_feat_dropout")
    parser.add_argument('--dropout', help="Please give a value for dropout")
    parser.add_argument('--layer_norm', help="Please give a value for layer_norm")
    parser.add_argument('--batch_norm', help="Please give a value for batch_norm")
    parser.add_argument('--sage_aggregator', help="Please give a value for sage_aggregator")
    parser.add_argument('--data_mode', help="Please give a value for data_mode")
    parser.add_argument('--num_pool', help="Please give a value for num_pool")
    parser.add_argument('--gnn_per_block', help="Please give a value for gnn_per_block")
    parser.add_argument('--embedding_dim', help="Please give a value for embedding_dim")
    parser.add_argument('--pool_ratio', help="Please give a value for pool_ratio")
    parser.add_argument('--linkpred', help="Please give a value for linkpred")
    parser.add_argument('--cat', help="Please give a value for cat")
    parser.add_argument('--self_loop', help="Please give a value for self_loop")
    parser.add_argument('--max_time', help="Please give a value for max_time")
    parser.add_argument('--layer_type', help="Please give a value for layer_type (for GAT and GatedGCN only)")
    parser.add_argument('--pos_enc_dim', help="Please give a value for pos_enc_dim")
    parser.add_argument('--pos_enc', help="Please give a value for pos_enc")
    args = parser.parse_args()
    with open(args.config) as f:
        config = json.load(f)
        
    
    # device
    if args.gpu_id is not None:
        config['gpu']['id'] = int(args.gpu_id)
        config['gpu']['use'] = True
    device = gpu_setup(config['gpu']['use'], config['gpu']['id'])
    # model, dataset, out_dir
    if args.model is not None:
        MODEL_NAME = args.model
    else:
        MODEL_NAME = config['model']
    if args.dataset is not None:
        DATASET_NAME = args.dataset
    else:
        DATASET_NAME = config['dataset']
    dataset = LoadData(DATASET_NAME)
    if args.out_dir is not None:
        out_dir = args.out_dir
    else:
        out_dir = config['out_dir']
    # parameters
    params = config['params']
    if args.seed is not None:
        params['seed'] = int(args.seed)
    if args.epochs is not None:
        params['epochs'] = int(args.epochs)
    if args.batch_size is not None:
        params['batch_size'] = int(args.batch_size)
    if args.init_lr is not None:
        params['init_lr'] = float(args.init_lr)
    if args.lr_reduce_factor is not None:
        params['lr_reduce_factor'] = float(args.lr_reduce_factor)
    if args.lr_schedule_patience is not None:
        params['lr_schedule_patience'] = int(args.lr_schedule_patience)
    if args.min_lr is not None:
        params['min_lr'] = float(args.min_lr)
    if args.weight_decay is not None:
        params['weight_decay'] = float(args.weight_decay)
    if args.print_epoch_interval is not None:
        params['print_epoch_interval'] = int(args.print_epoch_interval)
    if args.max_time is not None:
        params['max_time'] = float(args.max_time)
    # network parameters
    net_params = config['net_params']
    net_params['device'] = device
    net_params['gpu_id'] = config['gpu']['id']
    net_params['batch_size'] = params['batch_size']
    if args.L is not None:
        net_params['L'] = int(args.L)
    if args.hidden_dim is not None:
        net_params['hidden_dim'] = int(args.hidden_dim)
    if args.out_dim is not None:
        net_params['out_dim'] = int(args.out_dim)   
    if args.residual is not None:
        net_params['residual'] = True if args.residual=='True' else False
    if args.edge_feat is not None:
        net_params['edge_feat'] = True if args.edge_feat=='True' else False
    if args.readout is not None:
        net_params['readout'] = args.readout
    if args.kernel is not None:
        net_params['kernel'] = int(args.kernel)
    if args.n_heads is not None:
        net_params['n_heads'] = int(args.n_heads)
    if args.gated is not None:
        net_params['gated'] = True if args.gated=='True' else False
    if args.in_feat_dropout is not None:
        net_params['in_feat_dropout'] = float(args.in_feat_dropout)
    if args.dropout is not None:
        net_params['dropout'] = float(args.dropout)
    if args.layer_norm is not None:
        net_params['layer_norm'] = True if args.layer_norm=='True' else False
    if args.batch_norm is not None:
        net_params['batch_norm'] = True if args.batch_norm=='True' else False
    if args.sage_aggregator is not None:
        net_params['sage_aggregator'] = args.sage_aggregator
    if args.data_mode is not None:
        net_params['data_mode'] = args.data_mode
    if args.num_pool is not None:
        net_params['num_pool'] = int(args.num_pool)
    if args.gnn_per_block is not None:
        net_params['gnn_per_block'] = int(args.gnn_per_block)
    if args.embedding_dim is not None:
        net_params['embedding_dim'] = int(args.embedding_dim)
    if args.pool_ratio is not None:
        net_params['pool_ratio'] = float(args.pool_ratio)
    if args.linkpred is not None:
        net_params['linkpred'] = True if args.linkpred=='True' else False
    if args.cat is not None:
        net_params['cat'] = True if args.cat=='True' else False
    if args.self_loop is not None:
        net_params['self_loop'] = True if args.self_loop=='True' else False
    if args.layer_type is not None:
        net_params['layer_type'] = args.layer_type
    if args.pos_enc is not None:
        net_params['pos_enc'] = True if args.pos_enc=='True' else False
    if args.pos_enc_dim is not None:
        net_params['pos_enc_dim'] = int(args.pos_enc_dim)

    # COLLAB
    net_params['in_dim'] = dataset.graph.ndata['feat'].shape[-1]
    net_params['in_dim_edge'] = dataset.graph.edata['feat'].shape[-1]
    net_params['n_classes'] = 1  # binary prediction
    
    root_log_dir = out_dir + 'logs/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    root_ckpt_dir = out_dir + 'checkpoints/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    write_file_name = out_dir + 'results/result_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    write_config_file = out_dir + 'configs/config_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    dirs = root_log_dir, root_ckpt_dir, write_file_name, write_config_file

    if not os.path.exists(out_dir + 'results'):
        os.makedirs(out_dir + 'results')
        
    if not os.path.exists(out_dir + 'configs'):
        os.makedirs(out_dir + 'configs')

    net_params['total_param'] = view_model_param(MODEL_NAME, net_params)
    train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs)
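
The four run paths built above repeat one naming pattern; a hedged refactor sketch that shares a single timestamp across all of them (a deliberate difference: the original calls time.strftime four times, which can straddle a second boundary):

import os
import time

def run_paths(out_dir, model_name, dataset_name, gpu_id):
    # Build log/checkpoint/result/config paths that share one timestamp.
    stamp = time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    tag = f"{model_name}_{dataset_name}_GPU{gpu_id}_{stamp}"
    return (os.path.join(out_dir, 'logs', tag),
            os.path.join(out_dir, 'checkpoints', tag),
            os.path.join(out_dir, 'results', 'result_' + tag),
            os.path.join(out_dir, 'configs', 'config_' + tag))
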
# MODEL_NAME = 'GAT'
# MODEL_NAME = 'GraphSage'
# MODEL_NAME = 'DiffPool'
# MODEL_NAME = 'GIN'
MODEL_NAME = 'GatedGCN'

DATASET_NAME = 'TSP'

out_dir = 'out/TSP_edge_classification/debug/'
root_log_dir = out_dir + 'logs/' + MODEL_NAME + "_" + DATASET_NAME + "_" + time.strftime(
    '%Hh%Mm%Ss_on_%b_%d_%Y')
root_ckpt_dir = out_dir + 'checkpoints/' + MODEL_NAME + "_" + DATASET_NAME + "_" + time.strftime(
    '%Hh%Mm%Ss_on_%b_%d_%Y')

print("[I] Loading data (notebook) ...")
dataset = LoadData(DATASET_NAME)
trainset, valset, testset = dataset.train, dataset.val, dataset.test
print("[I] Finished loading.")

n_heads = -1
edge_feat = False
pseudo_dim_MoNet = -1
kernel = -1
gnn_per_block = -1
embedding_dim = -1
pool_ratio = -1
n_mlp_GIN = -1
gated = False
self_loop = False
def main(config=None):
    if config is None:
        raise ValueError("main() requires a config dict")

    # parameters
    params = config['params']

    # dataset
    DATASET_NAME = config['dataset']
    dataset = LoadData(DATASET_NAME)

    # device
    device = gpu_setup(config['gpu']['use'], config['gpu']['id'])
    out_dir = config['out_dir']

    # GNN model
    MODEL_NAME = config['model']

    # network parameters
    net_params = config['net_params']
    net_params['device'] = device
    net_params['gpu_id'] = config['gpu']['id']
    net_params['batch_size'] = params['batch_size']

    # CSL
    net_params['num_node_type'] = dataset.all.num_node_type
    net_params['num_edge_type'] = dataset.all.num_edge_type
    num_classes = len(np.unique(dataset.all.graph_labels))
    net_params['n_classes'] = num_classes

    # RingGNN
    if MODEL_NAME == 'RingGNN':
        num_nodes_train = [
            dataset.train[0][i][0].number_of_nodes()
            for i in range(len(dataset.train))
        ]
        num_nodes_test = [
            dataset.test[0][i][0].number_of_nodes()
            for i in range(len(dataset.test))
        ]
        num_nodes = num_nodes_train + num_nodes_test
        net_params['avg_node_num'] = int(np.ceil(np.mean(num_nodes)))

    # RingGNN, 3WLGNN
    if MODEL_NAME in ['RingGNN', '3WLGNN']:
        if net_params['pos_enc']:
            net_params['in_dim'] = net_params['pos_enc_dim']
        else:
            net_params['in_dim'] = 1

    root_log_dir = out_dir + 'logs/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    root_ckpt_dir = out_dir + 'checkpoints/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    write_file_name = out_dir + 'results/result_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    write_config_file = out_dir + 'configs/config_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    dirs = root_log_dir, root_ckpt_dir, write_file_name, write_config_file

    if not os.path.exists(out_dir + 'results'):
        os.makedirs(out_dir + 'results')

    if not os.path.exists(out_dir + 'configs'):
        os.makedirs(out_dir + 'configs')

    net_params['total_param'] = view_model_param(MODEL_NAME, net_params)
    train_val_pipeline(MODEL_NAME, DATASET_NAME, params, net_params, dirs)
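
The avg_node_num that RingGNN consumes is just the ceiling of the mean graph size; standalone, with made-up sizes:

import numpy as np

num_nodes = [10, 12, 15]                 # illustrative graph sizes
avg_node_num = int(np.ceil(np.mean(num_nodes)))
print(avg_node_num)  # -> 13
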
def main():
    """
        USER CONTROLS
    """

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config',
        help="Please give a config.json file with training/model/data/param details")
    parser.add_argument('--gpu_id', help="Please give a value for gpu id")
    parser.add_argument('--model', help="Please give a value for model name")
    parser.add_argument('--dataset',
                        help="Please give a value for dataset name")
    parser.add_argument('--builtin', help="Please give a value for builtin")
    parser.add_argument('--out_dir', help="Please give a value for out_dir")
    parser.add_argument('--seed', help="Please give a value for seed")
    parser.add_argument('--epochs', help="Please give a value for epochs")
    parser.add_argument('--runs', help="Please give a value for runs")
    parser.add_argument('--batch_size',
                        help="Please give a value for batch_size")
    parser.add_argument('--init_lr', help="Please give a value for init_lr")
    parser.add_argument('--lr_reduce_factor',
                        help="Please give a value for lr_reduce_factor")
    parser.add_argument('--lr_schedule_patience',
                        help="Please give a value for lr_schedule_patience")
    parser.add_argument('--min_lr', help="Please give a value for min_lr")
    parser.add_argument('--weight_decay',
                        help="Please give a value for weight_decay")
    parser.add_argument('--print_epoch_interval',
                        help="Please give a value for print_epoch_interval")
    parser.add_argument('--L', help="Please give a value for L")
    parser.add_argument('--K', help="Please give a value for K")
    parser.add_argument('--num_low',
                        help="Please give a value for number of low filters")
    parser.add_argument('--num_high',
                        help="Please give a value for number of high filters")
    parser.add_argument(
        '--num_mid', help="Please give a value for number of middle filters")
    parser.add_argument('--opt', help="Please give a value for AutoGCN option")
    parser.add_argument('--gate',
                        help="Please give a value for AutoGCN gate option")
    parser.add_argument('--hidden_dim',
                        help="Please give a value for hidden_dim")
    parser.add_argument('--out_dim', help="Please give a value for out_dim")
    parser.add_argument('--residual', help="Please give a value for residual")
    parser.add_argument('--edge_feat',
                        help="Please give a value for edge_feat")
    parser.add_argument('--readout', help="Please give a value for readout")
    parser.add_argument('--kernel', help="Please give a value for kernel")
    parser.add_argument('--n_heads', help="Please give a value for n_heads")
    parser.add_argument('--gated', help="Please give a value for gated")
    parser.add_argument('--in_feat_dropout',
                        help="Please give a value for in_feat_dropout")
    parser.add_argument('--dropout', help="Please give a value for dropout")
    parser.add_argument('--graph_norm',
                        help="Please give a value for graph_norm")
    parser.add_argument('--batch_norm',
                        help="Please give a value for batch_norm")
    parser.add_argument('--sage_aggregator',
                        help="Please give a value for sage_aggregator")
    parser.add_argument('--data_mode',
                        help="Please give a value for data_mode")
    parser.add_argument('--num_pool', help="Please give a value for num_pool")
    parser.add_argument('--gnn_per_block',
                        help="Please give a value for gnn_per_block")
    parser.add_argument('--embedding_dim',
                        help="Please give a value for embedding_dim")
    parser.add_argument('--pool_ratio',
                        help="Please give a value for pool_ratio")
    parser.add_argument('--linkpred', help="Please give a value for linkpred")
    parser.add_argument('--cat', help="Please give a value for cat")
    parser.add_argument('--self_loop',
                        help="Please give a value for self_loop")
    parser.add_argument('--max_time', help="Please give a value for max_time")
    args = parser.parse_args()
    with open(args.config) as f:
        config = json.load(f)

    # device
    if args.gpu_id is not None:
        config['gpu']['id'] = int(args.gpu_id)
        config['gpu']['use'] = True
    device = gpu_setup(config['gpu']['use'], config['gpu']['id'])
    # model, dataset, out_dir
    if args.model is not None:
        MODEL_NAME = args.model
    else:
        MODEL_NAME = config['model']
    if args.dataset is not None:
        DATASET_NAME = args.dataset
    else:
        DATASET_NAME = config['dataset']
    dataset = LoadData(DATASET_NAME)
    if args.out_dir is not None:
        out_dir = args.out_dir
    else:
        out_dir = config['out_dir']
    # parameters
    params = config['params']
    if args.seed is not None:
        params['seed'] = int(args.seed)
    if args.epochs is not None:
        params['epochs'] = int(args.epochs)
    if args.runs is not None:
        params['runs'] = int(args.runs)
    if args.batch_size is not None:
        params['batch_size'] = int(args.batch_size)
    if args.init_lr is not None:
        params['init_lr'] = float(args.init_lr)
    if args.lr_reduce_factor is not None:
        params['lr_reduce_factor'] = float(args.lr_reduce_factor)
    if args.lr_schedule_patience is not None:
        params['lr_schedule_patience'] = int(args.lr_schedule_patience)
    if args.min_lr is not None:
        params['min_lr'] = float(args.min_lr)
    if args.weight_decay is not None:
        params['weight_decay'] = float(args.weight_decay)
    if args.print_epoch_interval is not None:
        params['print_epoch_interval'] = int(args.print_epoch_interval)
    if args.max_time is not None:
        params['max_time'] = float(args.max_time)
    # network parameters
    net_params = config['net_params']
    net_params['device'] = device
    net_params['gpu_id'] = config['gpu']['id']
    net_params['batch_size'] = params['batch_size']
    if args.L is not None:
        net_params['L'] = int(args.L)
    if args.K is not None:
        net_params['K'] = int(args.K)
    if args.num_low is not None:
        net_params['num_low'] = int(args.num_low)
    if args.num_high is not None:
        net_params['num_high'] = int(args.num_high)
    if args.num_mid is not None:
        net_params['num_mid'] = int(args.num_mid)
    if args.opt is not None:
        net_params['opt'] = args.opt
    if args.gate is not None:
        net_params['gate'] = True if args.gate == 'True' else False
    if args.hidden_dim is not None:
        net_params['hidden_dim'] = int(args.hidden_dim)
    if args.out_dim is not None:
        net_params['out_dim'] = int(args.out_dim)
    if args.residual is not None:
        net_params['residual'] = True if args.residual == 'True' else False
    if args.edge_feat is not None:
        net_params['edge_feat'] = True if args.edge_feat == 'True' else False
    if args.readout is not None:
        net_params['readout'] = args.readout
    if args.kernel is not None:
        net_params['kernel'] = int(args.kernel)
    if args.n_heads is not None:
        net_params['n_heads'] = int(args.n_heads)
    if args.gated is not None:
        net_params['gated'] = True if args.gated == 'True' else False
    if args.in_feat_dropout is not None:
        net_params['in_feat_dropout'] = float(args.in_feat_dropout)
    if args.dropout is not None:
        net_params['dropout'] = float(args.dropout)
    if args.graph_norm is not None:
        net_params['graph_norm'] = True if args.graph_norm == 'True' else False
    if args.batch_norm is not None:
        net_params['batch_norm'] = True if args.batch_norm == 'True' else False
    if args.sage_aggregator is not None:
        net_params['sage_aggregator'] = args.sage_aggregator
    if args.data_mode is not None:
        net_params['data_mode'] = args.data_mode
    if args.num_pool is not None:
        net_params['num_pool'] = int(args.num_pool)
    if args.gnn_per_block is not None:
        net_params['gnn_per_block'] = int(args.gnn_per_block)
    if args.embedding_dim is not None:
        net_params['embedding_dim'] = int(args.embedding_dim)
    if args.pool_ratio is not None:
        net_params['pool_ratio'] = float(args.pool_ratio)
    if args.linkpred is not None:
        net_params['linkpred'] = True if args.linkpred == 'True' else False
    if args.cat is not None:
        net_params['cat'] = True if args.cat == 'True' else False
    if args.self_loop is not None:
        net_params['self_loop'] = True if args.self_loop == 'True' else False
    if args.builtin is not None:
        net_params['builtin'] = True if args.builtin == 'True' else False

    # CitationGraph
    net_params['in_dim'] = dataset.num_dims  # node_dim (feat is an integer)
    net_params['n_classes'] = dataset.num_classes

    if MODEL_NAME in ['MLP', 'MLP_GATED']:
        builtin = ''
    else:
        builtin = 'DGL' if net_params['builtin'] else 'Custom'
    root_log_dir = out_dir + 'logs/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime(
            '%Hh%Mm%Ss_on_%b_%d_%Y') + builtin
    root_ckpt_dir = out_dir + 'checkpoints/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime(
            '%Hh%Mm%Ss_on_%b_%d_%Y') + builtin
    write_file_name = out_dir + 'results/result_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime(
            '%Hh%Mm%Ss_on_%b_%d_%Y') + builtin
    write_config_file = out_dir + 'configs/config_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime(
            '%Hh%Mm%Ss_on_%b_%d_%Y') + builtin
    dirs = root_log_dir, root_ckpt_dir, write_file_name, write_config_file

    if not os.path.exists(out_dir + 'results'):
        os.makedirs(out_dir + 'results')

    if not os.path.exists(out_dir + 'configs'):
        os.makedirs(out_dir + 'configs')

    val_acc = []
    test_acc = []
    np.random.seed(121)
    seeds = [np.random.randint(10000) for _ in range(params['runs'])]
    print('seeds', seeds)
    for i in range(params['runs']):
        params['seed'] = seeds[i]
        net_params['total_param'] = view_model_param(MODEL_NAME, net_params)
        val_ac, test_ac = train_val_pipeline(MODEL_NAME, dataset, params,
                                             net_params, dirs)
        val_acc.append(val_ac)
        test_acc.append(test_ac)
    val_acc = np.array(val_acc)
    test_acc = np.array(test_acc)

    print('\n')
    print('-' * 89)
    print('Val F1: {:.3f} ± {:.3f}, Test F1: {:.3f} ± {:.3f}'.format(
        val_acc.mean(), val_acc.std(), test_acc.mean(), test_acc.std()))
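
The multi-run summary above is a plain mean and population standard deviation over per-run scores; in isolation (scores illustrative):

import numpy as np

val_acc = np.array([0.81, 0.83, 0.82])   # illustrative per-run scores
print('Val F1: {:.3f} ± {:.3f}'.format(val_acc.mean(), val_acc.std()))
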
def main():
    """
        USER CONTROLS
    """

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config',
        help="Please give a config.json file with training/model/data/param details")
    parser.add_argument('--gpu_id', help="Please give a value for gpu id")
    parser.add_argument('--model', help="Please give a value for model name")
    parser.add_argument('--dataset',
                        help="Please give a value for dataset name")
    parser.add_argument('--builtin', help="Please give a value for builtin")
    parser.add_argument('--out_dir', help="Please give a value for out_dir")
    parser.add_argument('--seed', help="Please give a value for seed")
    parser.add_argument('--epochs', help="Please give a value for epochs")
    parser.add_argument('--batch_size',
                        help="Please give a value for batch_size")
    parser.add_argument('--init_lr', help="Please give a value for init_lr")
    parser.add_argument('--lr_reduce_factor',
                        help="Please give a value for lr_reduce_factor")
    parser.add_argument('--lr_schedule_patience',
                        help="Please give a value for lr_schedule_patience")
    parser.add_argument('--min_lr', help="Please give a value for min_lr")
    parser.add_argument('--weight_decay',
                        help="Please give a value for weight_decay")
    parser.add_argument('--print_epoch_interval',
                        help="Please give a value for print_epoch_interval")
    parser.add_argument('--L', help="Please give a value for L")
    parser.add_argument('--hidden_dim',
                        help="Please give a value for hidden_dim")
    parser.add_argument('--out_dim', help="Please give a value for out_dim")
    parser.add_argument('--residual', help="Please give a value for residual")
    parser.add_argument('--edge_feat',
                        help="Please give a value for edge_feat")
    parser.add_argument('--readout', help="Please give a value for readout")
    parser.add_argument('--kernel', help="Please give a value for kernel")
    parser.add_argument('--n_heads', help="Please give a value for n_heads")
    parser.add_argument('--gated', help="Please give a value for gated")
    parser.add_argument('--in_feat_dropout',
                        help="Please give a value for in_feat_dropout")
    parser.add_argument('--dropout', help="Please give a value for dropout")
    parser.add_argument('--graph_norm',
                        help="Please give a value for graph_norm")
    parser.add_argument('--batch_norm',
                        help="Please give a value for batch_norm")
    parser.add_argument('--sage_aggregator',
                        help="Please give a value for sage_aggregator")
    parser.add_argument('--data_mode',
                        help="Please give a value for data_mode")
    parser.add_argument('--num_pool', help="Please give a value for num_pool")
    parser.add_argument('--gnn_per_block',
                        help="Please give a value for gnn_per_block")
    parser.add_argument('--embedding_dim',
                        help="Please give a value for embedding_dim")
    parser.add_argument('--pool_ratio',
                        help="Please give a value for pool_ratio")
    parser.add_argument('--linkpred', help="Please give a value for linkpred")
    parser.add_argument('--cat', help="Please give a value for cat")
    parser.add_argument('--self_loop',
                        help="Please give a value for self_loop")
    parser.add_argument('--max_time', help="Please give a value for max_time")
    parser.add_argument('--my_layer', help="Please give a value for my_layer")
    args = parser.parse_args()
    with open(args.config) as f:
        config = json.load(f)

    # device
    if args.gpu_id is not None:
        config['gpu']['id'] = int(args.gpu_id)
        config['gpu']['use'] = True
    device = gpu_setup(config['gpu']['use'], config['gpu']['id'])
    # model, dataset, out_dir
    if args.model is not None:
        MODEL_NAME = args.model
    else:
        MODEL_NAME = config['model']
    if args.dataset is not None:
        DATASET_NAME = args.dataset
    else:
        DATASET_NAME = config['dataset']
    dataset = LoadData(DATASET_NAME)
    if args.out_dir is not None:
        out_dir = args.out_dir
    else:
        out_dir = config['out_dir']
    # parameters
    params = config['params']
    if args.seed is not None:
        params['seed'] = int(args.seed)
    if args.epochs is not None:
        params['epochs'] = int(args.epochs)
    if args.batch_size is not None:
        params['batch_size'] = int(args.batch_size)
    if args.init_lr is not None:
        params['init_lr'] = float(args.init_lr)
    if args.lr_reduce_factor is not None:
        params['lr_reduce_factor'] = float(args.lr_reduce_factor)
    if args.lr_schedule_patience is not None:
        params['lr_schedule_patience'] = int(args.lr_schedule_patience)
    if args.min_lr is not None:
        params['min_lr'] = float(args.min_lr)
    if args.weight_decay is not None:
        params['weight_decay'] = float(args.weight_decay)
    if args.print_epoch_interval is not None:
        params['print_epoch_interval'] = int(args.print_epoch_interval)
    if args.max_time is not None:
        params['max_time'] = float(args.max_time)
    # network parameters
    net_params = config['net_params']
    net_params['device'] = device
    net_params['gpu_id'] = config['gpu']['id']
    net_params['batch_size'] = params['batch_size']
    net_params['my_layer'] = False
    if args.L is not None:
        net_params['L'] = int(args.L)
    if args.hidden_dim is not None:
        net_params['hidden_dim'] = int(args.hidden_dim)
    if args.out_dim is not None:
        net_params['out_dim'] = int(args.out_dim)
    if args.residual is not None:
        net_params['residual'] = True if args.residual == 'True' else False
    if args.edge_feat is not None:
        net_params['edge_feat'] = True if args.edge_feat == 'True' else False
    if args.readout is not None:
        net_params['readout'] = args.readout
    if args.kernel is not None:
        net_params['kernel'] = int(args.kernel)
    if args.n_heads is not None:
        net_params['n_heads'] = int(args.n_heads)
    if args.gated is not None:
        net_params['gated'] = True if args.gated == 'True' else False
    if args.in_feat_dropout is not None:
        net_params['in_feat_dropout'] = float(args.in_feat_dropout)
    if args.dropout is not None:
        net_params['dropout'] = float(args.dropout)
    if args.graph_norm is not None:
        net_params['graph_norm'] = True if args.graph_norm == 'True' else False
    if args.batch_norm is not None:
        net_params['batch_norm'] = True if args.batch_norm == 'True' else False
    if args.sage_aggregator is not None:
        net_params['sage_aggregator'] = args.sage_aggregator
    if args.data_mode is not None:
        net_params['data_mode'] = args.data_mode
    if args.num_pool is not None:
        net_params['num_pool'] = int(args.num_pool)
    if args.gnn_per_block is not None:
        net_params['gnn_per_block'] = int(args.gnn_per_block)
    if args.embedding_dim is not None:
        net_params['embedding_dim'] = int(args.embedding_dim)
    if args.pool_ratio is not None:
        net_params['pool_ratio'] = float(args.pool_ratio)
    if args.linkpred is not None:
        net_params['linkpred'] = True if args.linkpred == 'True' else False
    if args.cat is not None:
        net_params['cat'] = True if args.cat == 'True' else False
    if args.self_loop is not None:
        net_params['self_loop'] = True if args.self_loop == 'True' else False
    if args.builtin is not None:
        net_params['builtin'] = True if args.builtin == 'True' else False
    if args.my_layer is not None:
        net_params['my_layer'] = True if args.my_layer == 'True' else False

    # CitationGraph
    net_params['in_dim'] = dataset.num_dims  # node_dim (feat is an integer)
    net_params['n_classes'] = dataset.num_classes

    if net_params['my_layer']:
        net_params['builtin'] = False

    if MODEL_NAME in ['MLP', 'MLP_GATED']:
        builtin = ''
    else:
        builtin = 'DGL' if net_params['builtin'] else 'Custom'

    # root_log_dir = out_dir + 'logs/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y') + builtin
    # root_ckpt_dir = out_dir + 'checkpoints/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y') + builtin
    # write_file_name = out_dir + 'results/result_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y') + builtin
    # write_config_file = out_dir + 'configs/config_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y') + builtin
    # dirs = root_log_dir, root_ckpt_dir, write_file_name, MODEL_NAME

    if not os.path.exists(out_dir + 'results'):
        os.makedirs(out_dir + 'results')

    if not os.path.exists(out_dir + 'configs'):
        os.makedirs(out_dir + 'configs')

    vis_model_name = ('My' if net_params['my_layer'] else '') + MODEL_NAME
    N = max([
        int(fn.split('.')[0].split('_')[-1])
        for fn in os.listdir(out_dir + 'results/')
    ] + [0]) + 1
    root_log_dir = out_dir + 'logs/' + DATASET_NAME + "_" + vis_model_name + "_" + builtin + "_" + str(
        N)
    root_ckpt_dir = out_dir + 'checkpoints/' + DATASET_NAME + "_" + vis_model_name + "_" + builtin + "_" + str(
        N)
    write_file_name = out_dir + 'results/result_' + DATASET_NAME + "_" + vis_model_name + "_" + builtin + "_" + str(
        N)
    write_config_file = out_dir + 'configs/config_' + DATASET_NAME + "_" + vis_model_name + "_" + builtin + "_" + str(
        N)
    dirs = root_log_dir, root_ckpt_dir, write_file_name, write_config_file

    net_params['total_param'] = view_model_param(MODEL_NAME, net_params)
    train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs)
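
The run index N above is parsed out of existing result filenames; a self-contained illustration of that parsing (filenames are made up):

fns = ['result_CORA_GCN_DGL_1.txt', 'result_CORA_GCN_DGL_7.txt']
N = max([int(fn.split('.')[0].split('_')[-1]) for fn in fns] + [0]) + 1
print(N)  # -> 8
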
Example #12
def main():
    """
        USER CONTROLS
    """

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config',
        help="Please give a config.json file with training/model/data/param details")
    parser.add_argument('--gpu_id', help="Please give a value for gpu id")
    parser.add_argument('--model', help="Please give a value for model name")
    parser.add_argument('--dataset',
                        help="Please give a value for dataset name")
    parser.add_argument('--out_dir', help="Please give a value for out_dir")
    parser.add_argument('--seed', help="Please give a value for seed")
    parser.add_argument('--epochs', help="Please give a value for epochs")
    parser.add_argument('--runs', help="Please give a value for runs")
    parser.add_argument('--batch_size',
                        help="Please give a value for batch_size")
    parser.add_argument('--init_lr', help="Please give a value for init_lr")
    parser.add_argument('--lr_reduce_factor',
                        help="Please give a value for lr_reduce_factor")
    parser.add_argument('--lr_schedule_patience',
                        help="Please give a value for lr_schedule_patience")
    parser.add_argument('--min_lr', help="Please give a value for min_lr")
    parser.add_argument('--weight_decay',
                        help="Please give a value for weight_decay")
    parser.add_argument('--print_epoch_interval',
                        help="Please give a value for print_epoch_interval")
    parser.add_argument('--L', help="Please give a value for L")
    parser.add_argument('--K', help="Please give a value for K")
    parser.add_argument('--num_low',
                        help="Please give a value for number of low filters")
    parser.add_argument('--num_high',
                        help="Please give a value for number of high filters")
    parser.add_argument(
        '--num_mid', help="Please give a value for number of middle filters")
    parser.add_argument('--opt', help="Please give a value for AUTOGCN option")
    parser.add_argument('--gate',
                        help="Please give a value for AutoGCN gate option")
    parser.add_argument('--hidden_dim',
                        help="Please give a value for hidden_dim")
    parser.add_argument('--out_dim', help="Please give a value for out_dim")
    parser.add_argument('--residual', help="Please give a value for residual")
    parser.add_argument('--edge_feat',
                        help="Please give a value for edge_feat")
    parser.add_argument('--readout', help="Please give a value for readout")
    parser.add_argument('--kernel', help="Please give a value for kernel")
    parser.add_argument('--n_heads', help="Please give a value for n_heads")
    parser.add_argument('--gated', help="Please give a value for gated")
    parser.add_argument('--in_feat_dropout',
                        help="Please give a value for in_feat_dropout")
    parser.add_argument('--dropout', help="Please give a value for dropout")
    parser.add_argument('--graph_norm',
                        help="Please give a value for graph_norm")
    parser.add_argument('--batch_norm',
                        help="Please give a value for batch_norm")
    parser.add_argument('--sage_aggregator',
                        help="Please give a value for sage_aggregator")
    parser.add_argument('--data_mode',
                        help="Please give a value for data_mode")
    parser.add_argument('--num_pool', help="Please give a value for num_pool")
    parser.add_argument('--gnn_per_block',
                        help="Please give a value for gnn_per_block")
    parser.add_argument('--embedding_dim',
                        help="Please give a value for embedding_dim")
    parser.add_argument('--pool_ratio',
                        help="Please give a value for pool_ratio")
    parser.add_argument('--linkpred', help="Please give a value for linkpred")
    parser.add_argument('--cat', help="Please give a value for cat")
    parser.add_argument('--self_loop',
                        help="Please give a value for self_loop")
    parser.add_argument('--max_time', help="Please give a value for max_time")
    parser.add_argument('--model_path',
                        help="Please give a value for model_path")

    args = parser.parse_args()
    with open(args.config) as f:
        config = json.load(f)

    # device
    if args.gpu_id is not None:
        config['gpu']['id'] = int(args.gpu_id)
        config['gpu']['use'] = True
    device = gpu_setup(config['gpu']['use'], config['gpu']['id'])
    # model, dataset, out_dir
    if args.model is not None:
        MODEL_NAME = args.model
    else:
        MODEL_NAME = config['model']
    if args.dataset is not None:
        DATASET_NAME = args.dataset
    else:
        DATASET_NAME = config['dataset']
    dataset = LoadData(DATASET_NAME)
    if args.out_dir is not None:
        out_dir = args.out_dir
    else:
        out_dir = config['out_dir']
    # parameters
    params = config['params']
    if args.seed is not None:
        params['seed'] = int(args.seed)
    if args.epochs is not None:
        params['epochs'] = int(args.epochs)
    if args.runs is not None:
        params['runs'] = int(args.runs)
    if args.batch_size is not None:
        params['batch_size'] = int(args.batch_size)
    if args.init_lr is not None:
        params['init_lr'] = float(args.init_lr)
    if args.lr_reduce_factor is not None:
        params['lr_reduce_factor'] = float(args.lr_reduce_factor)
    if args.lr_schedule_patience is not None:
        params['lr_schedule_patience'] = int(args.lr_schedule_patience)
    if args.min_lr is not None:
        params['min_lr'] = float(args.min_lr)
    if args.weight_decay is not None:
        params['weight_decay'] = float(args.weight_decay)
    if args.print_epoch_interval is not None:
        params['print_epoch_interval'] = int(args.print_epoch_interval)
    if args.max_time is not None:
        params['max_time'] = float(args.max_time)
    # network parameters
    net_params = config['net_params']
    net_params['device'] = device
    net_params['gpu_id'] = config['gpu']['id']
    net_params['batch_size'] = params['batch_size']
    if args.L is not None:
        net_params['L'] = int(args.L)
    if args.K is not None:
        net_params['K'] = int(args.K)
    if args.num_low is not None:
        net_params['num_low'] = int(args.num_low)
    if args.num_high is not None:
        net_params['num_high'] = int(args.num_high)
    if args.num_mid is not None:
        net_params['num_mid'] = int(args.num_mid)
    if args.opt is not None:
        net_params['opt'] = args.opt
    if args.gate is not None:
        net_params['gate'] = True if args.gate == 'True' else False
    if args.hidden_dim is not None:
        net_params['hidden_dim'] = int(args.hidden_dim)
    if args.out_dim is not None:
        net_params['out_dim'] = int(args.out_dim)
    if args.residual is not None:
        net_params['residual'] = True if args.residual == 'True' else False
    if args.edge_feat is not None:
        net_params['edge_feat'] = True if args.edge_feat == 'True' else False
    if args.readout is not None:
        net_params['readout'] = args.readout
    if args.kernel is not None:
        net_params['kernel'] = int(args.kernel)
    if args.n_heads is not None:
        net_params['n_heads'] = int(args.n_heads)
    if args.gated is not None:
        net_params['gated'] = True if args.gated == 'True' else False
    if args.in_feat_dropout is not None:
        net_params['in_feat_dropout'] = float(args.in_feat_dropout)
    if args.dropout is not None:
        net_params['dropout'] = float(args.dropout)
    if args.graph_norm is not None:
        net_params['graph_norm'] = True if args.graph_norm == 'True' else False
    if args.batch_norm is not None:
        net_params['batch_norm'] = True if args.batch_norm == 'True' else False
    if args.sage_aggregator is not None:
        net_params['sage_aggregator'] = args.sage_aggregator
    if args.data_mode is not None:
        net_params['data_mode'] = args.data_mode
    if args.num_pool is not None:
        net_params['num_pool'] = int(args.num_pool)
    if args.gnn_per_block is not None:
        net_params['gnn_per_block'] = int(args.gnn_per_block)
    if args.embedding_dim is not None:
        net_params['embedding_dim'] = int(args.embedding_dim)
    if args.pool_ratio is not None:
        net_params['pool_ratio'] = float(args.pool_ratio)
    if args.linkpred is not None:
        net_params['linkpred'] = True if args.linkpred == 'True' else False
    if args.cat is not None:
        net_params['cat'] = True if args.cat == 'True' else False
    if args.self_loop is not None:
        net_params['self_loop'] = True if args.self_loop == 'True' else False

    # ZINC
    net_params['num_atom_type'] = dataset.num_atom_type
    net_params['num_bond_type'] = dataset.num_bond_type

    if MODEL_NAME == 'DiffPool':
        # calculate assignment dimension: pool_ratio * largest graph's maximum
        # number of nodes  in the dataset
        num_nodes = [
            dataset.train[i][0].number_of_nodes()
            for i in range(len(dataset.train))
        ]
        max_num_node = max(num_nodes)
        net_params['assign_dim'] = int(
            max_num_node * net_params['pool_ratio']) * net_params['batch_size']
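        # e.g. (hypothetical numbers) max_num_node = 64, pool_ratio = 0.1,
        # batch_size = 128  ->  assign_dim = int(64 * 0.1) * 128 = 6 * 128 = 768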

    if not os.path.exists(out_dir + 'results'):
        os.makedirs(out_dir + 'results')

    if not os.path.exists(out_dir + 'configs'):
        os.makedirs(out_dir + 'configs')

    net_params['total_param'] = view_model_param(MODEL_NAME, net_params)
    inference(MODEL_NAME, dataset, params, net_params, args.model_path)
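The boolean options above are all parsed with the repeated pattern "True if args.x == 'True' else False". A small helper could factor this out; the sketch below is illustrative only, and the name str2bool is not part of these examples:

def str2bool(value):
    # these configs pass booleans as the strings 'True' / 'False'
    return str(value) == 'True'

# usage sketch:
# if args.residual is not None:
#     net_params['residual'] = str2bool(args.residual)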
Example #13
def start(args):
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    cudnn.benchmark = True
    cudnn.enabled = True
    logging.info("args = %s", args)

    dataset = LoadData(args.data_name)
    if args.data_name == 'SBM_PATTERN':
        in_dim = 3
        num_classes = 2
    else:
        in_dim = 7
        num_classes = 6
    print(f"=> input dimension: {in_dim}, number classes: {num_classes}")

    criterion = MyCriterion(num_classes)
    criterion = criterion.cuda()
    
    if args.data_name == 'SBM_PATTERN':
        genotype = PATTERN_Net
    elif args.data_name == 'SBM_CLUSTER':
        genotype = CLUSTER_Net
    else:
        print("Unknown dataset.")
        sys.exit(1)

    print('=> loading from genotype: \n', genotype)
    model = Network(args, genotype, num_classes, in_dim, criterion)
    model = model.cuda()
    logging.info("param size = %fMB", count_parameters_in_MB(model))

    train_data, val_data, test_data = dataset.train, dataset.val, dataset.test

    train_queue = torch.utils.data.DataLoader(train_data,
                                              batch_size=args.batch_size,
                                              pin_memory=True,
                                              num_workers=args.workers,
                                              collate_fn=dataset.collate,
                                              shuffle=True)

    valid_queue = torch.utils.data.DataLoader(val_data,
                                              batch_size=args.batch_size,
                                              pin_memory=True,
                                              num_workers=args.workers,
                                              collate_fn=dataset.collate,
                                              shuffle=False)

    test_queue = torch.utils.data.DataLoader(test_data,
                                             batch_size=args.batch_size,
                                             pin_memory=True,
                                             num_workers=args.workers,
                                             collate_fn=dataset.collate,
                                             shuffle=False)

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, float(args.epochs), eta_min=args.learning_rate_min)
    elif args.optimizer == 'ADAM':
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               mode='min',
                                                               factor=0.5,
                                                               patience=5,
                                                               verbose=True)
    
    for epoch in range(args.epochs):
        logging.info('[EPOCH]\t%d', epoch)
        if args.optimizer == 'SGD':
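            # note: since PyTorch 1.1 the expected order is optimizer.step() before
            # scheduler.step(); stepping the scheduler at the top of the epoch (and
            # scheduler.get_lr(), deprecated in favor of get_last_lr()) follow the
            # older API and trigger warnings on recent PyTorch versions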
            scheduler.step()
            lr = scheduler.get_lr()[0]
            logging.info('[LR]\t%f', lr)

        macro_acc, micro_acc, train_obj = train(train_queue, model, criterion, optimizer)
        # validation
        macro_acc, micro_acc, valid_obj = infer(valid_queue, model, criterion, stage = 'validating')
        # testing
        macro_acc, micro_acc, test_obj = infer(test_queue, model, criterion, stage = ' testing   ')

        if args.optimizer == 'ADAM':
            scheduler.step(valid_obj)
            if optimizer.param_groups[0]['lr'] < 1e-5:
                print("\n!! LR SMALLER OR EQUAL TO MIN LR THRESHOLD.")
                break
Example #14
def main(config):
    """
        USER CONTROLS
    """

    # parameters
    params = config['params']

    # dataset
    DATASET_NAME = config['dataset']
    dataset = LoadData(DATASET_NAME)

    # device
    device = gpu_setup(config['gpu']['use'], config['gpu']['id'])
    out_dir = config['out_dir']

    # GNN model
    MODEL_NAME = config['model']

    # network parameters
    net_params = config['net_params']
    net_params['device'] = device
    net_params['gpu_id'] = config['gpu']['id']
    net_params['batch_size'] = params['batch_size']

    # Superpixels
    net_params['in_dim'] = dataset.train[0][0].ndata['feat'][0].size(0)
    net_params['in_dim_edge'] = dataset.train[0][0].edata['feat'][0].size(0)
    num_classes = len(np.unique(np.array(dataset.train[:][1])))
    net_params['n_classes'] = num_classes

    if MODEL_NAME == 'DiffPool':
        # calculate assignment dimension: pool_ratio * largest graph's maximum
        # number of nodes  in the dataset
        max_num_nodes_train = max([
            dataset.train[i][0].number_of_nodes()
            for i in range(len(dataset.train))
        ])
        max_num_nodes_test = max([
            dataset.test[i][0].number_of_nodes()
            for i in range(len(dataset.test))
        ])
        max_num_node = max(max_num_nodes_train, max_num_nodes_test)
        net_params['assign_dim'] = int(
            max_num_node * net_params['pool_ratio']) * net_params['batch_size']

    if MODEL_NAME == 'RingGNN':
        num_nodes_train = [
            dataset.train[i][0].number_of_nodes()
            for i in range(len(dataset.train))
        ]
        num_nodes_test = [
            dataset.test[i][0].number_of_nodes()
            for i in range(len(dataset.test))
        ]
        num_nodes = num_nodes_train + num_nodes_test
        net_params['avg_node_num'] = int(np.ceil(np.mean(num_nodes)))

    root_log_dir = out_dir + 'logs/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    root_ckpt_dir = out_dir + 'checkpoints/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    write_file_name = out_dir + 'results/result_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    write_config_file = out_dir + 'configs/config_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    dirs = root_log_dir, root_ckpt_dir, write_file_name, write_config_file

    if not os.path.exists(out_dir + 'results'):
        os.makedirs(out_dir + 'results')

    if not os.path.exists(out_dir + 'configs'):
        os.makedirs(out_dir + 'configs')

    net_params['total_param'] = view_model_param(MODEL_NAME, net_params)
    train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs)
def main():

    config_path = ['configs/superpixels_graph_classification_GCN_MNIST.json',
                   'configs/superpixels_graph_classification_GIN_MNIST.json',
                   'configs/superpixels_graph_classification_GAT_MNIST.json',
                   
                   'configs/superpixels_graph_classification_GCN_CIFAR10.json',
                   'configs/superpixels_graph_classification_GIN_CIFAR10.json',
                   'configs/superpixels_graph_classification_GAT_CIFAR10.json']
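    # NOTE: --config below is an integer index into config_path, not a file path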
    
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=int, default=0,
        help="Please give a config.json file with training/model/data/param details")

    parser.add_argument('--debug', action='store_true', default=False,
        help="Run in debug mode")

    parser.add_argument('--resume', action='store_true', default=False,
        help="Resume training from a checkpoint")

    parser.add_argument('--head', action='store_true', default=False,
        help="use head or not")

    parser.add_argument('--aug', type=str, default='nn',
        help="Please give a value for augmentation type")

    parser.add_argument('--temp', type=float, default=0.5,
        help="Please give a value for temperature")

    parser.add_argument('--drop_percent', type=float, default=0.2,
        help="Please give a value for drop_percent")
        
    parser.add_argument('--seed', default=41,
        help="Please give a value for seed")

    parser.add_argument('--gpu_id', default=0,
        help="Please give a value for gpu id")

    parser.add_argument('--epochs', type=int,  default=80, help="Please give a value for epochs")
    parser.add_argument('--decreasing_lr', default='50, 60', help='decreasing strategy')
    parser.add_argument('--init_lr',  help="Please give a value for init_lr")
    parser.add_argument('--dataset', help="Please give a value for dataset name")
    parser.add_argument('--model', help="Please give a value for model name")
    parser.add_argument('--out_dir', help="Please give a value for out_dir")
    parser.add_argument('--batch_size', help="Please give a value for batch_size")
    parser.add_argument('--lr_reduce_factor', help="Please give a value for lr_reduce_factor")
    parser.add_argument('--lr_schedule_patience', help="Please give a value for lr_schedule_patience")
    parser.add_argument('--min_lr', help="Please give a value for min_lr")
    parser.add_argument('--weight_decay',  type=float, default=1e-6,  help="Please give a value for weight_decay")
    parser.add_argument('--print_epoch_interval', help="Please give a value for print_epoch_interval")    
    parser.add_argument('--L', help="Please give a value for L")
    parser.add_argument('--hidden_dim', help="Please give a value for hidden_dim")
    parser.add_argument('--out_dim', help="Please give a value for out_dim")
    parser.add_argument('--residual', help="Please give a value for residual")
    parser.add_argument('--edge_feat', help="Please give a value for edge_feat")
    parser.add_argument('--readout', help="Please give a value for readout")
    parser.add_argument('--kernel', help="Please give a value for kernel")
    parser.add_argument('--n_heads', help="Please give a value for n_heads")
    parser.add_argument('--gated', help="Please give a value for gated")
    parser.add_argument('--in_feat_dropout', help="Please give a value for in_feat_dropout")
    parser.add_argument('--dropout', help="Please give a value for dropout")
    parser.add_argument('--graph_norm', help="Please give a value for graph_norm")
    parser.add_argument('--batch_norm', help="Please give a value for batch_norm")
    parser.add_argument('--sage_aggregator', help="Please give a value for sage_aggregator")
    parser.add_argument('--data_mode', help="Please give a value for data_mode")
    parser.add_argument('--num_pool', help="Please give a value for num_pool")
    parser.add_argument('--gnn_per_block', help="Please give a value for gnn_per_block")
    parser.add_argument('--embedding_dim', help="Please give a value for embedding_dim")
    parser.add_argument('--pool_ratio', help="Please give a value for pool_ratio")
    parser.add_argument('--linkpred', help="Please give a value for linkpred")
    parser.add_argument('--cat', help="Please give a value for cat")
    parser.add_argument('--self_loop', help="Please give a value for self_loop")
    parser.add_argument('--max_time', help="Please give a value for max_time")
    args = parser.parse_args()
    with open(config_path[args.config]) as f:
        config = json.load(f)
        
    # device
    if args.gpu_id is not None:
        config['gpu']['id'] = int(args.gpu_id)
        config['gpu']['use'] = True
    device = gpu_setup(config['gpu']['use'], config['gpu']['id'])
    # model, dataset, out_dir
    if args.model is not None:
        MODEL_NAME = args.model
    else:
        MODEL_NAME = config['model']
    if args.dataset is not None:
        DATASET_NAME = args.dataset
    else:
        DATASET_NAME = config['dataset']
    dataset = LoadData(DATASET_NAME)
    if args.out_dir is not None:
        out_dir = args.out_dir
    else:
        out_dir = config['out_dir']
    # parameters
    params = config['params']
    if args.seed is not None:
        params['seed'] = int(args.seed)
    if args.epochs is not None:
        params['epochs'] = int(args.epochs)
    if args.batch_size is not None:
        params['batch_size'] = int(args.batch_size)
    if args.init_lr is not None:
        params['init_lr'] = float(args.init_lr)
    if args.lr_reduce_factor is not None:
        params['lr_reduce_factor'] = float(args.lr_reduce_factor)
    if args.lr_schedule_patience is not None:
        params['lr_schedule_patience'] = int(args.lr_schedule_patience)
    if args.min_lr is not None:
        params['min_lr'] = float(args.min_lr)
    if args.weight_decay is not None:
        params['weight_decay'] = float(args.weight_decay)
    if args.print_epoch_interval is not None:
        params['print_epoch_interval'] = int(args.print_epoch_interval)
    if args.max_time is not None:
        params['max_time'] = float(args.max_time)
    # network parameters
    net_params = config['net_params']
    net_params['device'] = device
    net_params['gpu_id'] = config['gpu']['id']
    net_params['batch_size'] = params['batch_size']
    if args.L is not None:
        net_params['L'] = int(args.L)
    if args.hidden_dim is not None:
        net_params['hidden_dim'] = int(args.hidden_dim)
    if args.out_dim is not None:
        net_params['out_dim'] = int(args.out_dim)   
    if args.residual is not None:
        net_params['residual'] = True if args.residual=='True' else False
    if args.edge_feat is not None:
        net_params['edge_feat'] = True if args.edge_feat=='True' else False
    if args.readout is not None:
        net_params['readout'] = args.readout
    if args.kernel is not None:
        net_params['kernel'] = int(args.kernel)
    if args.n_heads is not None:
        net_params['n_heads'] = int(args.n_heads)
    if args.gated is not None:
        net_params['gated'] = True if args.gated=='True' else False
    if args.in_feat_dropout is not None:
        net_params['in_feat_dropout'] = float(args.in_feat_dropout)
    if args.dropout is not None:
        net_params['dropout'] = float(args.dropout)
    if args.graph_norm is not None:
        net_params['graph_norm'] = True if args.graph_norm=='True' else False
    if args.batch_norm is not None:
        net_params['batch_norm'] = True if args.batch_norm=='True' else False
    if args.sage_aggregator is not None:
        net_params['sage_aggregator'] = args.sage_aggregator
    if args.data_mode is not None:
        net_params['data_mode'] = args.data_mode
    if args.num_pool is not None:
        net_params['num_pool'] = int(args.num_pool)
    if args.gnn_per_block is not None:
        net_params['gnn_per_block'] = int(args.gnn_per_block)
    if args.embedding_dim is not None:
        net_params['embedding_dim'] = int(args.embedding_dim)
    if args.pool_ratio is not None:
        net_params['pool_ratio'] = float(args.pool_ratio)
    if args.linkpred is not None:
        net_params['linkpred'] = True if args.linkpred=='True' else False
    if args.cat is not None:
        net_params['cat'] = True if args.cat=='True' else False
    if args.self_loop is not None:
        net_params['self_loop'] = True if args.self_loop=='True' else False
        
    # Superpixels
    net_params['in_dim'] = dataset.train[0][0].ndata['feat'][0].size(0)
    net_params['in_dim_edge'] = dataset.train[0][0].edata['feat'][0].size(0)
    num_classes = len(np.unique(np.array(dataset.train[:][1])))
    net_params['n_classes'] = num_classes

    if MODEL_NAME == 'DiffPool':
        # calculate assignment dimension: pool_ratio * largest graph's maximum
        # number of nodes  in the dataset
        max_num_nodes_train = max([dataset.train[i][0].number_of_nodes() for i in range(len(dataset.train))])
        max_num_nodes_test = max([dataset.test[i][0].number_of_nodes() for i in range(len(dataset.test))])
        max_num_node = max(max_num_nodes_train, max_num_nodes_test)
        net_params['assign_dim'] = int(max_num_node * net_params['pool_ratio']) * net_params['batch_size']
    
    
    net_params['total_param'] = view_model_param(MODEL_NAME, net_params)
    train_val_pipeline(MODEL_NAME, dataset, params, net_params, args)
def main():
    """
        USER CONTROLS
    """

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config',
        help=
        "Please give a config.json file with training/model/data/param details"
    )
    parser.add_argument('--gpu_id', help="Please give a value for gpu id")
    parser.add_argument('--model', help="Please give a value for model name")
    parser.add_argument('--dataset',
                        help="Please give a value for dataset name")
    parser.add_argument('--out_dir', help="Please give a value for out_dir")
    parser.add_argument('--seed', help="Please give a value for seed")
    parser.add_argument('--epochs', help="Please give a value for epochs")
    parser.add_argument('--batch_size',
                        help="Please give a value for batch_size")
    parser.add_argument('--init_lr', help="Please give a value for init_lr")
    parser.add_argument('--lr_reduce_factor',
                        help="Please give a value for lr_reduce_factor")
    parser.add_argument('--lr_schedule_patience',
                        help="Please give a value for lr_schedule_patience")
    parser.add_argument('--min_lr', help="Please give a value for min_lr")
    parser.add_argument('--weight_decay',
                        help="Please give a value for weight_decay")
    parser.add_argument('--print_epoch_interval',
                        help="Please give a value for print_epoch_interval")
    parser.add_argument('--L', help="Please give a value for L")
    parser.add_argument('--hidden_dim',
                        help="Please give a value for hidden_dim")
    parser.add_argument('--out_dim', help="Please give a value for out_dim")
    parser.add_argument('--residual', help="Please give a value for residual")
    parser.add_argument('--edge_feat',
                        help="Please give a value for edge_feat")
    parser.add_argument('--readout', help="Please give a value for readout")
    parser.add_argument('--kernel', help="Please give a value for kernel")
    parser.add_argument('--n_heads', help="Please give a value for n_heads")
    parser.add_argument('--gated', help="Please give a value for gated")
    parser.add_argument('--in_feat_dropout',
                        help="Please give a value for in_feat_dropout")
    parser.add_argument('--dropout', help="Please give a value for dropout")
    parser.add_argument('--graph_norm',
                        help="Please give a value for graph_norm")
    parser.add_argument('--batch_norm',
                        help="Please give a value for batch_norm")
    parser.add_argument('--sage_aggregator',
                        help="Please give a value for sage_aggregator")
    parser.add_argument('--data_mode',
                        help="Please give a value for data_mode")
    parser.add_argument('--num_pool', help="Please give a value for num_pool")
    parser.add_argument('--gnn_per_block',
                        help="Please give a value for gnn_per_block")
    parser.add_argument('--embedding_dim',
                        help="Please give a value for embedding_dim")
    parser.add_argument('--pool_ratio',
                        help="Please give a value for pool_ratio")
    parser.add_argument('--linkpred', help="Please give a value for linkpred")
    parser.add_argument('--cat', help="Please give a value for cat")
    parser.add_argument('--self_loop',
                        help="Please give a value for self_loop")
    parser.add_argument('--max_time', help="Please give a value for max_time")
    args = parser.parse_args()
    with open(args.config) as f:
        config = json.load(f)

    # device
    if args.gpu_id is not None:
        config['gpu']['id'] = int(args.gpu_id)
        config['gpu']['use'] = True
    device = gpu_setup(config['gpu']['use'], config['gpu']['id'])
    # model, dataset, out_dir
    if args.model is not None:
        MODEL_NAME = args.model
    else:
        MODEL_NAME = config['model']
    if args.dataset is not None:
        DATASET_NAME = args.dataset
    else:
        DATASET_NAME = config['dataset']
    dataset = LoadData(DATASET_NAME)
    if args.out_dir is not None:
        out_dir = args.out_dir
    else:
        out_dir = config['out_dir']
    # parameters
    params = config['params']
    if args.seed is not None:
        params['seed'] = int(args.seed)
    if args.epochs is not None:
        params['epochs'] = int(args.epochs)
    if args.batch_size is not None:
        params['batch_size'] = int(args.batch_size)
    if args.init_lr is not None:
        params['init_lr'] = float(args.init_lr)
    if args.lr_reduce_factor is not None:
        params['lr_reduce_factor'] = float(args.lr_reduce_factor)
    if args.lr_schedule_patience is not None:
        params['lr_schedule_patience'] = int(args.lr_schedule_patience)
    if args.min_lr is not None:
        params['min_lr'] = float(args.min_lr)
    if args.weight_decay is not None:
        params['weight_decay'] = float(args.weight_decay)
    if args.print_epoch_interval is not None:
        params['print_epoch_interval'] = int(args.print_epoch_interval)
    if args.max_time is not None:
        params['max_time'] = float(args.max_time)
    # network parameters
    net_params = config['net_params']
    net_params['device'] = device
    net_params['gpu_id'] = config['gpu']['id']
    net_params['batch_size'] = params['batch_size']
    if args.L is not None:
        net_params['L'] = int(args.L)
    if args.hidden_dim is not None:
        net_params['hidden_dim'] = int(args.hidden_dim)
    if args.out_dim is not None:
        net_params['out_dim'] = int(args.out_dim)
    if args.residual is not None:
        net_params['residual'] = True if args.residual == 'True' else False
    if args.edge_feat is not None:
        net_params['edge_feat'] = True if args.edge_feat == 'True' else False
    if args.readout is not None:
        net_params['readout'] = args.readout
    if args.kernel is not None:
        net_params['kernel'] = int(args.kernel)
    if args.n_heads is not None:
        net_params['n_heads'] = int(args.n_heads)
    if args.gated is not None:
        net_params['gated'] = True if args.gated == 'True' else False
    if args.in_feat_dropout is not None:
        net_params['in_feat_dropout'] = float(args.in_feat_dropout)
    if args.dropout is not None:
        net_params['dropout'] = float(args.dropout)
    if args.graph_norm is not None:
        net_params['graph_norm'] = True if args.graph_norm == 'True' else False
    if args.batch_norm is not None:
        net_params['batch_norm'] = True if args.batch_norm == 'True' else False
    if args.sage_aggregator is not None:
        net_params['sage_aggregator'] = args.sage_aggregator
    if args.data_mode is not None:
        net_params['data_mode'] = args.data_mode
    if args.num_pool is not None:
        net_params['num_pool'] = int(args.num_pool)
    if args.gnn_per_block is not None:
        net_params['gnn_per_block'] = int(args.gnn_per_block)
    if args.embedding_dim is not None:
        net_params['embedding_dim'] = int(args.embedding_dim)
    if args.pool_ratio is not None:
        net_params['pool_ratio'] = float(args.pool_ratio)
    if args.linkpred is not None:
        net_params['linkpred'] = True if args.linkpred == 'True' else False
    if args.cat is not None:
        net_params['cat'] = True if args.cat == 'True' else False
    if args.self_loop is not None:
        net_params['self_loop'] = True if args.self_loop == 'True' else False

    # Superpixels
    net_params['in_dim'] = dataset.train[0][0].ndata['feat'][0].size(0)
    net_params['in_dim_edge'] = dataset.train[0][0].edata['feat'][0].size(0)
    num_classes = len(np.unique(np.array(dataset.train[:][1])))
    net_params['n_classes'] = num_classes

    if MODEL_NAME == 'DiffPool':
        # calculate assignment dimension: pool_ratio * largest graph's maximum
        # number of nodes  in the dataset
        max_num_nodes_train = max([
            dataset.train[i][0].number_of_nodes()
            for i in range(len(dataset.train))
        ])
        max_num_nodes_test = max([
            dataset.test[i][0].number_of_nodes()
            for i in range(len(dataset.test))
        ])
        max_num_node = max(max_num_nodes_train, max_num_nodes_test)
        net_params['assign_dim'] = int(
            max_num_node * net_params['pool_ratio']) * net_params['batch_size']

    root_log_dir = out_dir + 'logs/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    root_ckpt_dir = out_dir + 'checkpoints/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    write_file_name = out_dir + 'results/result_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    write_config_file = out_dir + 'configs/config_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    dirs = root_log_dir, root_ckpt_dir, write_file_name, write_config_file

    if not os.path.exists(out_dir + 'results'):
        os.makedirs(out_dir + 'results')

    if not os.path.exists(out_dir + 'configs'):
        os.makedirs(out_dir + 'configs')

    net_params['total_param'] = view_model_param(MODEL_NAME, net_params)
    train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs)
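For reference, the strftime pattern '%Hh%Mm%Ss_on_%b_%d_%Y' used in the path names above expands to a timestamp such as 14h22m05s_on_Mar_03_2021 (illustrative), giving paths like out/logs/GCN_MNIST_GPU0_14h22m05s_on_Mar_03_2021.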
Example #17
def start(args):
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    cudnn.enabled = True
    cudnn.benchmark = True
    logging.info("args = %s", args)

    dataset = LoadData(args.data_name)
    in_dim = dataset.num_atom_type
    num_classes = 1
    criterion = nn.L1Loss()
    criterion = criterion.cuda()
    print(f"=> input dimension: {in_dim}, number classes: {num_classes}")

    genotype = ZINC_Net
    print('=> loading from genotype: \n', genotype)
    # model = Network(genotype, args.layers, in_dim, args.feature_dim, num_classes, criterion, args.data_type, args.readout, args.dropout)
    model = Network(args, genotype, num_classes, in_dim, criterion)
    model = model.cuda()
    logging.info("=> param size = %f", count_parameters_in_MB(model) * 1e6)

    train_data, val_data, test_data = dataset.train, dataset.val, dataset.test

    train_queue = torch.utils.data.DataLoader(train_data,
                                              batch_size=args.batch_size,
                                              pin_memory=True,
                                              num_workers=args.workers,
                                              collate_fn=dataset.collate,
                                              shuffle=True)

    valid_queue = torch.utils.data.DataLoader(val_data,
                                              batch_size=args.batch_size,
                                              pin_memory=True,
                                              num_workers=args.workers,
                                              collate_fn=dataset.collate,
                                              shuffle=False)

    test_queue = torch.utils.data.DataLoader(test_data,
                                             batch_size=args.batch_size,
                                             pin_memory=True,
                                             num_workers=args.workers,
                                             collate_fn=dataset.collate,
                                             shuffle=False)

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, float(args.epochs), eta_min=args.learning_rate_min)
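        # the cosine schedule above uses T_max = epochs, so the LR decays from
        # args.learning_rate down to learning_rate_min over the full run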
    elif args.optimizer == 'ADAM':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=0.001,
                                     weight_decay=0.0)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               mode='min',
                                                               factor=0.5,
                                                               patience=10,
                                                               verbose=True)

    for epoch in range(args.epochs):
        logging.info('[EPOCH]\t%d', epoch)
        if args.optimizer == 'SGD':
            scheduler.step()
            lr = scheduler.get_lr()[0]
            logging.info('[LR]\t%f', lr)

        # training
        train_mae, train_obj = train(train_queue, model, criterion, optimizer)
        # validation
        valid_mae, valid_obj = infer(valid_queue,
                                     model,
                                     criterion,
                                     stage='validating')
        # testing
        test_mae, test_obj = infer(test_queue,
                                   model,
                                   criterion,
                                   stage='testing   ')
        desc = '[train] mae: {:.3f}, loss: {:.3f}\t[validate] mae:{:.3f}, loss: {:.3f}\t[test] mae: {:.3f}, loss: {:.3f}'.format(
            train_mae, train_obj, valid_mae, valid_obj, test_mae, test_obj)
        logging.info(desc)

        if args.optimizer == 'ADAM':
            scheduler.step(valid_obj)
            if optimizer.param_groups[0]['lr'] < 1e-5:
                print("\n!! LR SMALLER OR EQUAL TO MIN LR THRESHOLD.")
                break
Example #18
def main():

    start_time_str = time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')

    """
        USER CONTROLS
    """
    
    
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', help="Please give a config.json file with training/model/data/param details")
    parser.add_argument('--gpu_id', help="Please give a value for gpu id")
    parser.add_argument('--model', help="Please give a value for model name")
    parser.add_argument('--dataset', help="Please give a value for dataset name")
    parser.add_argument('--out_dir', help="Please give a value for out_dir")
    parser.add_argument('--seed', help="Please give a value for seed")
    parser.add_argument('--epochs', help="Please give a value for epochs")
    parser.add_argument('--batch_size', help="Please give a value for batch_size")
    parser.add_argument('--init_lr', help="Please give a value for init_lr")
    parser.add_argument('--lr_reduce_factor', help="Please give a value for lr_reduce_factor")
    parser.add_argument('--lr_schedule_patience', help="Please give a value for lr_schedule_patience")
    parser.add_argument('--min_lr', help="Please give a value for min_lr")
    parser.add_argument('--weight_decay', help="Please give a value for weight_decay")
    parser.add_argument('--print_epoch_interval', help="Please give a value for print_epoch_interval")    
    parser.add_argument('--L', help="Please give a value for L")
    parser.add_argument('--hidden_dim', help="Please give a value for hidden_dim")
    parser.add_argument('--out_dim', help="Please give a value for out_dim")
    parser.add_argument('--residual', help="Please give a value for residual")
    parser.add_argument('--readout', help="Please give a value for readout")
    parser.add_argument('--gated', help="Please give a value for gated")
    parser.add_argument('--in_feat_dropout', help="Please give a value for in_feat_dropout")
    parser.add_argument('--dropout', help="Please give a value for dropout")
    parser.add_argument('--batch_norm', help="Please give a value for batch_norm")
    parser.add_argument('--self_loop', help="Please give a value for self_loop")
    parser.add_argument('--max_time', help="Please give a value for max_time")
    parser.add_argument('--verbose', help="Please give a value for verbose")
    parser.add_argument('--only_view_params', help="Please give a value for only_view_params")
    parser.add_argument('--undirected', help="Please give a value for undirected")
    parser.add_argument('--max_order', help="Please give a value for max_order")
    parser.add_argument('--gru', help="Please give a value for gru")
    parser.add_argument('--activation', help="Please give a value for activation")
    parser.add_argument('--test', help="Please give a value for test")
    args = parser.parse_args()

    if args.config is not None:
        with open(args.config) as f:
            config = json.load(f)
    else:
        config = {'gpu': {'use': False, 'id': 0}, 'params': {}, 'net_params': {}}


    only_view_params = False
    if args.only_view_params is not None:
        only_view_params = True if args.only_view_params=='True' else False

    test_mode = False
    if args.test is not None:
        test_mode = True if args.test=='True' else False

    verbose_mode = False
    if args.verbose is not None:
        verbose_mode = True if args.verbose=='True' else False
        
    # device
    if args.gpu_id is not None:
        config['gpu']['id'] = int(args.gpu_id)
        config['gpu']['use'] = True
    device = gpu_setup(config['gpu']['use'], config['gpu']['id'])
    # Model name
    if args.model is not None:
        MODEL_NAME = args.model
    elif 'model' in config:
        MODEL_NAME = config['model']
    else:
        raise Exception('No specified model (--model)')
    # Dataset name
    if args.dataset is not None:
        DATASET_NAME = args.dataset
    elif 'dataset' in config:
        DATASET_NAME = config['dataset']
    else:
        raise Exception('No specified dataset (--dataset)')
    # Out directory
    if args.out_dir is not None:
        out_dir = args.out_dir
    elif 'out_dir' in config:
        out_dir = config['out_dir']
    else:
        raise Exception('No specified out directory (--out_dir)')


    '''
    Load dataset
    '''
    # Superpixels dataset
    dataset = LoadData(DATASET_NAME)


    '''
        TEST model pipeline
    '''
    if test_mode:
        print('=' * 10 + ' TEST mode ' + '=' * 10)
        test_pipeline(MODEL_NAME, dataset, device, verbose_mode, out_dir)

        return


    '''
        TRAIN model pipeline
    '''
    # parameters
    params = config['params']
    if 'verbose' not in params:
        params['verbose'] = False
    if args.seed is not None:
        params['seed'] = int(args.seed)
    if args.epochs is not None:
        params['epochs'] = int(args.epochs)
    if args.batch_size is not None:
        params['batch_size'] = int(args.batch_size)
    if args.init_lr is not None:
        params['init_lr'] = float(args.init_lr)
    if args.lr_reduce_factor is not None:
        params['lr_reduce_factor'] = float(args.lr_reduce_factor)
    if args.lr_schedule_patience is not None:
        params['lr_schedule_patience'] = int(args.lr_schedule_patience)
    if args.min_lr is not None:
        params['min_lr'] = float(args.min_lr)
    if args.weight_decay is not None:
        params['weight_decay'] = float(args.weight_decay)
    if args.print_epoch_interval is not None:
        params['print_epoch_interval'] = int(args.print_epoch_interval)
    if args.max_time is not None:
        params['max_time'] = float(args.max_time)
    if args.verbose is not None:
        params['verbose'] = True if args.verbose=='True' else False
    # network parameters
    net_params = config['net_params']
    net_params['device'] = device
    net_params['gpu_id'] = config['gpu']['id']
    net_params['batch_size'] = params['batch_size']
    if 'max_order' not in net_params:
        net_params['max_order'] = 2
    if 'gru' not in net_params:
        net_params['gru'] = False
    if 'undirected' not in net_params:
        net_params['undirected'] = False
    if 'activation' not in net_params:
        net_params['activation'] = 'relu'
    if args.L is not None:
        net_params['L'] = int(args.L)
    if args.hidden_dim is not None:
        net_params['hidden_dim'] = int(args.hidden_dim)
    if args.out_dim is not None:
        net_params['out_dim'] = int(args.out_dim)   
    if args.residual is not None:
        net_params['residual'] = True if args.residual=='True' else False
    if args.readout is not None:
        net_params['readout'] = args.readout
    if args.gated is not None:
        net_params['gated'] = True if args.gated=='True' else False
    if args.in_feat_dropout is not None:
        net_params['in_feat_dropout'] = float(args.in_feat_dropout)
    if args.dropout is not None:
        net_params['dropout'] = float(args.dropout)
    if args.batch_norm is not None:
        net_params['batch_norm'] = True if args.batch_norm=='True' else False
    if args.self_loop is not None:
        net_params['self_loop'] = True if args.self_loop=='True' else False
    if args.undirected is not None:
        net_params['undirected'] = True if args.undirected=='True' else False
    if args.max_order is not None:
        net_params['max_order'] = int(args.max_order)
    if args.gru is not None:
        net_params['gru'] = True if args.gru=='True' else False
    if args.activation is not None:
        net_params['activation'] = args.activation

    net_params['in_dim'] = dataset.train[0][0].ndata['feat'][0].size(0)
    net_params['in_dim_edge'] = dataset.train[0][0].edata['feat'][0].size(0)
    num_classes = len(np.unique(np.array(dataset.train[:][1])))
    net_params['n_classes'] = num_classes

    # Set random seed
    set_random_seed(params['seed'], device)

    # View parameters
    net_params['total_param'] = view_model_param(MODEL_NAME, net_params, params['verbose'])
    if only_view_params:
        print('== View Parameters only ==')
        return
    
    root_log_dir = out_dir + 'logs/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + start_time_str
    root_ckpt_dir = out_dir + 'checkpoints/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + start_time_str
    write_file_name = out_dir + 'results/result_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + start_time_str
    write_config_file = out_dir + 'configs/config_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + start_time_str
    dirs = root_log_dir, root_ckpt_dir, write_file_name, write_config_file

    if not os.path.exists(out_dir + 'results'):
        os.makedirs(out_dir + 'results')
        
    if not os.path.exists(out_dir + 'configs'):
        os.makedirs(out_dir + 'configs')

    print('\nResult output:', write_file_name)

    train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs)
def train_val_pipeline(MODEL_NAME, DATASET_NAME, params, net_params, dirs):
    avg_test_acc = []
    avg_train_acc = []
    avg_epochs = []

    t0 = time.time()
    per_epoch_time = []

    dataset = LoadData(DATASET_NAME)

    if MODEL_NAME in ['GCN', 'GAT']:
        if net_params['self_loop']:
            print(
                "[!] Adding graph self-loops for GCN/GAT models (central node trick)."
            )
            dataset._add_self_loops()

    if net_params['pos_enc']:
        print("[!] Adding graph positional encoding.")
        dataset._add_positional_encodings(net_params['pos_enc_dim'])  #TODO

    trainset, valset, testset = dataset.train, dataset.val, dataset.test

    root_log_dir, root_ckpt_dir, write_file_name, write_config_file = dirs
    device = net_params['device']

    # Write the network and optimization hyper-parameters in folder config/
    with open(write_config_file + '.txt', 'w') as f:
        f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n\nTotal Parameters: {}\n\n""" \
                .format(DATASET_NAME, MODEL_NAME, params, net_params, net_params['total_param']))

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for split_number in range(5):

            t0_split = time.time()
            log_dir = os.path.join(root_log_dir, "RUN_" + str(split_number))
            writer = SummaryWriter(log_dir=log_dir)

            # setting seeds
            random.seed(params['seed'])
            np.random.seed(params['seed'])
            torch.manual_seed(params['seed'])
            if device.type == 'cuda':
                torch.cuda.manual_seed(params['seed'])

            print("RUN NUMBER: ", split_number)
            trainset, valset, testset = (dataset.train[split_number],
                                         dataset.val[split_number],
                                         dataset.test[split_number])
            print("Training Graphs: ", len(trainset))
            print("Validation Graphs: ", len(valset))
            print("Test Graphs: ", len(testset))
            print("Number of Classes: ", net_params['n_classes'])

            model = gnn_model(MODEL_NAME, net_params)
            model = model.to(device)
            optimizer = optim.Adam(model.parameters(),
                                   lr=params['init_lr'],
                                   weight_decay=params['weight_decay'])
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode='min',
                factor=params['lr_reduce_factor'],
                patience=params['lr_schedule_patience'],
                verbose=True)

            epoch_train_losses, epoch_val_losses = [], []
            epoch_train_accs, epoch_val_accs = [], []

            # batching exception for Diffpool
            drop_last = True if MODEL_NAME == 'DiffPool' else False
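            # DiffPool sizes its assignment matrices for a full batch (assign_dim is
            # multiplied by batch_size where it is computed), so a smaller final
            # batch would not fit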

            if MODEL_NAME in ['RingGNN', '3WLGNN']:
                # import train functions specific for WL-GNNs
                from train.train_CSL_graph_classification import train_epoch_dense as train_epoch, evaluate_network_dense as evaluate_network
                from functools import partial  # util function to pass pos_enc flag to collate function

                train_loader = DataLoader(trainset,
                                          shuffle=True,
                                          collate_fn=partial(
                                              dataset.collate_dense_gnn,
                                              pos_enc=net_params['pos_enc']))
                val_loader = DataLoader(valset,
                                        shuffle=False,
                                        collate_fn=partial(
                                            dataset.collate_dense_gnn,
                                            pos_enc=net_params['pos_enc']))
                test_loader = DataLoader(testset,
                                         shuffle=False,
                                         collate_fn=partial(
                                             dataset.collate_dense_gnn,
                                             pos_enc=net_params['pos_enc']))

            else:
                # import train functions for all other GCNs
                from train.train_CSL_graph_classification import train_epoch_sparse as train_epoch, evaluate_network_sparse as evaluate_network

                train_loader = DataLoader(trainset,
                                          batch_size=params['batch_size'],
                                          shuffle=True,
                                          drop_last=drop_last,
                                          collate_fn=dataset.collate)
                val_loader = DataLoader(valset,
                                        batch_size=params['batch_size'],
                                        shuffle=False,
                                        drop_last=drop_last,
                                        collate_fn=dataset.collate)
                test_loader = DataLoader(testset,
                                         batch_size=params['batch_size'],
                                         shuffle=False,
                                         drop_last=drop_last,
                                         collate_fn=dataset.collate)

            with tqdm(range(params['epochs'])) as t:
                for epoch in t:

                    t.set_description('Epoch %d' % epoch)

                    start = time.time()

                    if MODEL_NAME in ['RingGNN', '3WLGNN']:
                        # dense GNNs use a different batch-training function
                        epoch_train_loss, epoch_train_acc, optimizer = train_epoch(
                            model, optimizer, device, train_loader, epoch,
                            params['batch_size'])
                    else:  # for all other models common train function
                        epoch_train_loss, epoch_train_acc, optimizer = train_epoch(
                            model, optimizer, device, train_loader, epoch)

                    epoch_val_loss, epoch_val_acc = evaluate_network(
                        model, device, val_loader, epoch)

                    _, epoch_test_acc = evaluate_network(
                        model, device, test_loader, epoch)

                    epoch_train_losses.append(epoch_train_loss)
                    epoch_val_losses.append(epoch_val_loss)
                    epoch_train_accs.append(epoch_train_acc)
                    epoch_val_accs.append(epoch_val_acc)

                    writer.add_scalar('train/_loss', epoch_train_loss, epoch)
                    writer.add_scalar('val/_loss', epoch_val_loss, epoch)
                    writer.add_scalar('train/_acc', epoch_train_acc, epoch)
                    writer.add_scalar('val/_acc', epoch_val_acc, epoch)
                    writer.add_scalar('test/_acc', epoch_test_acc, epoch)
                    writer.add_scalar('learning_rate',
                                      optimizer.param_groups[0]['lr'], epoch)

                    epoch_train_acc = 100. * epoch_train_acc
                    epoch_test_acc = 100. * epoch_test_acc

                    t.set_postfix(time=time.time() - start,
                                  lr=optimizer.param_groups[0]['lr'],
                                  train_loss=epoch_train_loss,
                                  val_loss=epoch_val_loss,
                                  train_acc=epoch_train_acc,
                                  val_acc=epoch_val_acc,
                                  test_acc=epoch_test_acc)

                    per_epoch_time.append(time.time() - start)

                    # Saving checkpoint
                    ckpt_dir = os.path.join(root_ckpt_dir,
                                            "RUN_" + str(split_number))
                    if not os.path.exists(ckpt_dir):
                        os.makedirs(ckpt_dir)
                    torch.save(
                        model.state_dict(),
                        '{}.pkl'.format(ckpt_dir + "/epoch_" + str(epoch)))

                    files = glob.glob(ckpt_dir + '/*.pkl')
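                    # keep only the two most recent checkpoints (current and previous epoch)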
                    for file in files:
                        epoch_nb = file.split('_')[-1]
                        epoch_nb = int(epoch_nb.split('.')[0])
                        if epoch_nb < epoch - 1:
                            os.remove(file)

                    scheduler.step(epoch_val_loss)

                    if optimizer.param_groups[0]['lr'] < params['min_lr']:
                        print("\n!! LR EQUAL TO MIN LR SET.")
                        break

                    # Stop training once this split has used its share of
                    # params['max_time'] hours (the budget is divided by 10 across runs)
                    if time.time() - t0_split > params['max_time'] * 3600 / 10:
                        print('-' * 89)
                        print(
                            "Max_time for one train-val-test split experiment elapsed {:.3f} hours, so stopping"
                            .format(params['max_time'] / 10))
                        break

            _, test_acc = evaluate_network(model, device, test_loader, epoch)
            _, train_acc = evaluate_network(model, device, train_loader, epoch)
            avg_test_acc.append(test_acc)
            avg_train_acc.append(train_acc)
            avg_epochs.append(epoch)

            print("Test Accuracy [LAST EPOCH]: {:.4f}".format(test_acc))
            print("Train Accuracy [LAST EPOCH]: {:.4f}".format(train_acc))

    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early because of KeyboardInterrupt')

    print("TOTAL TIME TAKEN: {:.4f}hrs".format((time.time() - t0) / 3600))
    print("AVG TIME PER EPOCH: {:.4f}s".format(np.mean(per_epoch_time)))

    # Final test accuracy value averaged over 5-fold
    print("""\n\n\nFINAL RESULTS\n\nTEST ACCURACY averaged: {:.4f} with s.d. {:.4f}""" \
          .format(np.mean(np.array(avg_test_acc)) * 100, np.std(avg_test_acc) * 100))
    print("\nAll splits Test Accuracies:\n", avg_test_acc)
    print("""\n\n\nFINAL RESULTS\n\nTRAIN ACCURACY averaged: {:.4f} with s.d. {:.4f}""" \
          .format(np.mean(np.array(avg_train_acc)) * 100, np.std(avg_train_acc) * 100))
    print("\nAll splits Train Accuracies:\n", avg_train_acc)

    writer.close()
    """
        Write the results in out/results folder
    """
    with open(write_file_name + '.txt', 'w') as f:
        f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n{}\n\nTotal Parameters: {}\n\n
    FINAL RESULTS\nTEST ACCURACY averaged: {:.3f}\n with test acc s.d. {:.3f}\nTRAIN ACCURACY averaged: {:.3f}\n with train s.d. {:.3f}\n\n
    Convergence Time (Epochs): {:.3f}\nTotal Time Taken: {:.3f} hrs\nAverage Time Per Epoch: {:.3f} s\n\n\nAll Splits Test Accuracies: {}\n\nAll Splits Train Accuracies: {}""" \
                .format(DATASET_NAME, MODEL_NAME, params, net_params, model, net_params['total_param'],
                        np.mean(np.array(avg_test_acc)) * 100, np.std(avg_test_acc) * 100,
                        np.mean(np.array(avg_train_acc)) * 100, np.std(avg_train_acc) * 100, np.mean(np.array(avg_epochs)),
                        (time.time() - t0) / 3600, np.mean(per_epoch_time), avg_test_acc, avg_train_acc))
def train_val_pipeline(MODEL_NAME, DATASET_NAME, params, net_params, args):

    # setting seeds
    random.seed(params['seed'])
    np.random.seed(params['seed'])
    torch.manual_seed(params['seed'])
    device = net_params['device']
    if device.type == 'cuda':
        torch.cuda.manual_seed(params['seed'])

    dataset = LoadData(DATASET_NAME)
    trainset, valset, testset = dataset.train, dataset.val, dataset.test

    net_params['in_dim'] = torch.unique(dataset.train[0][0].ndata['feat'],
                                        dim=0).size(
                                            0)  # node_dim (feat is an integer)
    net_params['n_classes'] = torch.unique(dataset.train[0][1], dim=0).size(0)
    net_params['total_param'] = view_model_param(MODEL_NAME, net_params)

    load_model = args.load_model
    aug_type_list = [
        'drop_nodes', 'drop_add_edges', 'noise', 'mask', 'subgraph'
    ]

    if MODEL_NAME in ['GCN', 'GAT']:
        if net_params['self_loop']:
            print(
                "[!] Adding graph self-loops for GCN/GAT models (central node trick)."
            )
            dataset._add_self_loops()

    print('-' * 40 + "Finetune Option" + '-' * 40)
    print('SEED:           [{}]'.format(params['seed']))
    print("Data  Name:     [{}]".format(DATASET_NAME))
    print("Model Name:     [{}]".format(MODEL_NAME))
    print("Training Graphs:[{}]".format(len(trainset)))
    print("Valid Graphs:   [{}]".format(len(valset)))
    print("Test Graphs:    [{}]".format(len(testset)))
    print("Number Classes: [{}]".format(net_params['n_classes']))
    print("Learning rate:  [{}]".format(params['init_lr']))
    print('-' * 40 + "Contrastive Option" + '-' * 40)
    print("Load model:     [{}]".format(load_model))
    print("Aug Type:       [{}]".format(aug_type_list[args.aug]))
    print("Projection head:[{}]".format(args.head))
    print('-' * 100)

    model = gnn_model(MODEL_NAME, net_params)

    if load_model:
        output_path = './001_contrastive_models'
        save_model_dir0 = os.path.join(output_path, DATASET_NAME)
        save_model_dir1 = os.path.join(save_model_dir0,
                                       aug_type_list[args.aug])

        if args.head:
            save_model_dir1 += "_head"
        else:
            save_model_dir1 += "_no_head"
        save_model_dir2 = os.path.join(save_model_dir1, MODEL_NAME)
        # sorted() makes [-1] deterministic; glob order is arbitrary
        load_file_name = sorted(glob.glob(save_model_dir2 + '/*.pkl'))
        checkpoint = torch.load(load_file_name[-1])
        model_dict = model.state_dict()

        # Keep only the checkpoint entries whose keys exist in this model,
        # then merge them into the current state dict before loading.
        state_dict = {
            k: v
            for k, v in checkpoint.items() if k in model_dict
        }
        model_dict.update(state_dict)
        model.load_state_dict(model_dict)
        print('Successfully loaded the pre-trained model: [{}]'.format(
            load_file_name[-1]))
    else:
        print('No model loaded: testing the baseline!')

    model = model.to(device)
    optimizer = optim.Adam(model.parameters(),
                           lr=params['init_lr'],
                           weight_decay=params['weight_decay'])
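    # ReduceLROnPlateau (below) multiplies the LR by lr_reduce_factor once the
    # val loss has stopped improving for lr_schedule_patience epochs.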
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=params['lr_reduce_factor'],
        patience=params['lr_schedule_patience'],
        verbose=True)

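    # Only the training loader shuffles; val/test batches keep a fixed order.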
    train_loader = DataLoader(trainset,
                              batch_size=params['batch_size'],
                              shuffle=True,
                              collate_fn=dataset.collate)
    val_loader = DataLoader(valset,
                            batch_size=params['batch_size'],
                            shuffle=False,
                            collate_fn=dataset.collate)
    test_loader = DataLoader(testset,
                             batch_size=params['batch_size'],
                             shuffle=False,
                             collate_fn=dataset.collate)

    for epoch in range(params['epochs']):

        epoch_train_loss, epoch_train_acc, optimizer = train_epoch(
            model, optimizer, device, train_loader, epoch)
        epoch_val_loss, epoch_val_acc = evaluate_network(
            model, device, val_loader, epoch)
        epoch_test_loss, epoch_test_acc = evaluate_network(
            model, device, test_loader, epoch)

        print('-' * 80)
        print(
            time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + ' | ' +
            "Epoch [{:>2d}]  Test Acc: [{:.4f}]".format(
                epoch + 1, epoch_test_acc))
        print('-' * 80)

        scheduler.step(epoch_val_loss)

        if optimizer.param_groups[0]['lr'] < params['min_lr']:
            print("\n!! LR SMALLER OR EQUAL TO MIN LR THRESHOLD.")
            break

    _, test_acc = evaluate_network(model, device, test_loader, epoch)
    _, train_acc = evaluate_network(model, device, train_loader, epoch)
    print("Test Accuracy: {:.4f}".format(test_acc))
    print("Train Accuracy: {:.4f}".format(train_acc))
    return train_acc, test_acc
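The filter-and-update idiom used above for loading the pre-trained weights can be isolated into a reusable helper; a sketch assuming a plain torch.nn.Module (the function name is illustrative, not from the source):
import torch


def load_matching_weights(model, ckpt_path):
    # Copy over only the checkpoint tensors whose keys the current model also
    # defines, so a pretrained encoder can be reused even when heads differ.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    model_dict = model.state_dict()
    matched = {k: v for k, v in checkpoint.items() if k in model_dict}
    model_dict.update(matched)
    model.load_state_dict(model_dict)
    return model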
Example #21
def p():
    return Dataset(LoadData('SBM_PATTERN'))
Example #22
def c():
    return Dataset(LoadData('SBM_CLUSTER'))
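Examples #21 and #22 are one-line factories that wrap LoadData for the two SBM node-classification benchmarks; a hypothetical call site (variable names assumed) would simply be:
pattern_eval = p()  # evaluator over SBM_PATTERN
cluster_eval = c()  # evaluator over SBM_CLUSTER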
Example #23
def main():
    """
        USER CONTROLS
    """

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config',
        help=
        "Please give a config.json file with training/model/data/param details"
    )
    parser.add_argument('--gpu_id', help="Please give a value for gpu id")
    parser.add_argument('--model', help="Please give a value for model name")
    parser.add_argument('--dataset',
                        help="Please give a value for dataset name")
    parser.add_argument('--out_dir', help="Please give a value for out_dir")
    parser.add_argument('--seed', help="Please give a value for seed")
    parser.add_argument('--epochs', help="Please give a value for epochs")
    parser.add_argument('--batch_size',
                        help="Please give a value for batch_size")
    parser.add_argument('--init_lr', help="Please give a value for init_lr")
    parser.add_argument('--lr_reduce_factor',
                        help="Please give a value for lr_reduce_factor")
    parser.add_argument('--lr_schedule_patience',
                        help="Please give a value for lr_schedule_patience")
    parser.add_argument('--min_lr', help="Please give a value for min_lr")
    parser.add_argument('--weight_decay',
                        help="Please give a value for weight_decay")
    parser.add_argument('--print_epoch_interval',
                        help="Please give a value for print_epoch_interval")
    parser.add_argument('--L', help="Please give a value for L")
    parser.add_argument('--hidden_dim',
                        help="Please give a value for hidden_dim")
    parser.add_argument('--out_dim', help="Please give a value for out_dim")
    parser.add_argument('--residual', help="Please give a value for residual")
    parser.add_argument('--edge_feat',
                        help="Please give a value for edge_feat")
    parser.add_argument('--readout', help="Please give a value for readout")
    parser.add_argument('--n_heads', help="Please give a value for n_heads")
    parser.add_argument('--in_feat_dropout',
                        help="Please give a value for in_feat_dropout")
    parser.add_argument('--dropout', help="Please give a value for dropout")
    parser.add_argument('--layer_norm',
                        help="Please give a value for layer_norm")
    parser.add_argument('--batch_norm',
                        help="Please give a value for batch_norm")
    parser.add_argument('--self_loop',
                        help="Please give a value for self_loop")
    parser.add_argument('--max_time', help="Please give a value for max_time")
    parser.add_argument('--pos_enc_dim',
                        help="Please give a value for pos_enc_dim")
    parser.add_argument('--lap_pos_enc',
                        help="Please give a value for lap_pos_enc")
    parser.add_argument('--wl_pos_enc',
                        help="Please give a value for wl_pos_enc")
    args = parser.parse_args()
    with open(args.config) as f:
        config = json.load(f)

    # device
    if args.gpu_id is not None:
        config['gpu']['id'] = int(args.gpu_id)
        config['gpu']['use'] = True
    device = gpu_setup(config['gpu']['use'], config['gpu']['id'])
    # model, dataset, out_dir
    if args.model is not None:
        MODEL_NAME = args.model
    else:
        MODEL_NAME = config['model']
    if args.dataset is not None:
        DATASET_NAME = args.dataset
    else:
        DATASET_NAME = config['dataset']
    dataset = LoadData(DATASET_NAME)
    if args.out_dir is not None:
        out_dir = args.out_dir
    else:
        out_dir = config['out_dir']
    # parameters
    params = config['params']
    if args.seed is not None:
        params['seed'] = int(args.seed)
    if args.epochs is not None:
        params['epochs'] = int(args.epochs)
    if args.batch_size is not None:
        params['batch_size'] = int(args.batch_size)
    if args.init_lr is not None:
        params['init_lr'] = float(args.init_lr)
    if args.lr_reduce_factor is not None:
        params['lr_reduce_factor'] = float(args.lr_reduce_factor)
    if args.lr_schedule_patience is not None:
        params['lr_schedule_patience'] = int(args.lr_schedule_patience)
    if args.min_lr is not None:
        params['min_lr'] = float(args.min_lr)
    if args.weight_decay is not None:
        params['weight_decay'] = float(args.weight_decay)
    if args.print_epoch_interval is not None:
        params['print_epoch_interval'] = int(args.print_epoch_interval)
    if args.max_time is not None:
        params['max_time'] = float(args.max_time)
    # network parameters
    net_params = config['net_params']
    net_params['device'] = device
    net_params['gpu_id'] = config['gpu']['id']
    net_params['batch_size'] = params['batch_size']
    if args.L is not None:
        net_params['L'] = int(args.L)
    if args.hidden_dim is not None:
        net_params['hidden_dim'] = int(args.hidden_dim)
    if args.out_dim is not None:
        net_params['out_dim'] = int(args.out_dim)
    if args.residual is not None:
        net_params['residual'] = True if args.residual == 'True' else False
    if args.edge_feat is not None:
        net_params['edge_feat'] = True if args.edge_feat == 'True' else False
    if args.readout is not None:
        net_params['readout'] = args.readout
    if args.n_heads is not None:
        net_params['n_heads'] = int(args.n_heads)
    if args.in_feat_dropout is not None:
        net_params['in_feat_dropout'] = float(args.in_feat_dropout)
    if args.dropout is not None:
        net_params['dropout'] = float(args.dropout)
    if args.layer_norm is not None:
        net_params['layer_norm'] = True if args.layer_norm == 'True' else False
    if args.batch_norm is not None:
        net_params['batch_norm'] = True if args.batch_norm == 'True' else False
    if args.self_loop is not None:
        net_params['self_loop'] = True if args.self_loop == 'True' else False
    if args.lap_pos_enc is not None:
        net_params['lap_pos_enc'] = True if args.lap_pos_enc == 'True' else False
    if args.pos_enc_dim is not None:
        net_params['pos_enc_dim'] = int(args.pos_enc_dim)
    if args.wl_pos_enc is not None:
        net_params['wl_pos_enc'] = True if args.wl_pos_enc == 'True' else False

    # SBM
    net_params['in_dim'] = torch.unique(dataset.train[0][0].ndata['feat'],
                                        dim=0).size(
                                            0)  # node_dim (feat is an integer)
    net_params['n_classes'] = torch.unique(dataset.train[0][1], dim=0).size(0)

    root_log_dir = out_dir + 'logs/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    root_ckpt_dir = out_dir + 'checkpoints/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    write_file_name = out_dir + 'results/result_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    write_config_file = out_dir + 'configs/config_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    dirs = root_log_dir, root_ckpt_dir, write_file_name, write_config_file

    if not os.path.exists(out_dir + 'results'):
        os.makedirs(out_dir + 'results')

    if not os.path.exists(out_dir + 'configs'):
        os.makedirs(out_dir + 'configs')

    net_params['total_param'] = view_model_param(MODEL_NAME, net_params)
    train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs)
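The repeated `True if args.x == 'True' else False` conversions in the main above can be centralized with a custom argparse type; a sketch (str2bool is not in the source):
import argparse


def str2bool(v):
    # Accept the usual textual spellings of a boolean flag.
    if isinstance(v, bool):
        return v
    if v.lower() in ('true', '1', 'yes'):
        return True
    if v.lower() in ('false', '0', 'no'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % v)


# usage: parser.add_argument('--residual', type=str2bool, default=None)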
Example #24
def main():
    """
        USER CONTROLS
    """
    # --dataset ALKBH5_Baltz2012 --config configs/RNAgraph_graph_classification_GCN_in_vivo_100k.json --model GCN --debias False --seed 66
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--config',
        default='configs/RNAgraph_graph_classification_GCN_in_vivo_100k.json',
        help=
        "Please give a config.json file with training/model/data/param details"
    )
    parser.add_argument('--gpu_id', help="Please give a value for gpu id")
    parser.add_argument('--model',
                        default='GCN',
                        help="Please give a value for model name")
    parser.add_argument('--dataset',
                        default='ALKBH5_Baltz2012',
                        help="Please give a value for dataset name")
    parser.add_argument('--out_dir', help="Please give a value for out_dir")
    parser.add_argument('--seed',
                        default=39,
                        help="Please give a value for seed")
    parser.add_argument('--epochs', help="Please give a value for epochs")
    parser.add_argument('--batch_size',
                        help="Please give a value for batch_size")
    parser.add_argument('--init_lr', help="Please give a value for init_lr")
    parser.add_argument('--lr_reduce_factor',
                        help="Please give a value for lr_reduce_factor")
    parser.add_argument('--lr_schedule_patience',
                        help="Please give a value for lr_schedule_patience")
    parser.add_argument('--min_lr', help="Please give a value for min_lr")
    parser.add_argument('--weight_decay',
                        help="Please give a value for weight_decay")
    parser.add_argument('--print_epoch_interval',
                        help="Please give a value for print_epoch_interval")
    parser.add_argument('--L', help="Please give a value for L")
    parser.add_argument('--hidden_dim',
                        help="Please give a value for hidden_dim")
    parser.add_argument('--out_dim', help="Please give a value for out_dim")
    parser.add_argument('--residual', help="Please give a value for residual")
    parser.add_argument('--edge_feat',
                        help="Please give a value for edge_feat")
    parser.add_argument('--readout', help="Please give a value for readout")
    parser.add_argument('--kernel', help="Please give a value for kernel")
    parser.add_argument('--n_heads', help="Please give a value for n_heads")
    parser.add_argument('--gated', help="Please give a value for gated")
    parser.add_argument('--in_feat_dropout',
                        help="Please give a value for in_feat_dropout")
    parser.add_argument('--dropout', help="Please give a value for dropout")
    parser.add_argument('--layer_norm',
                        help="Please give a value for layer_norm")
    parser.add_argument('--batch_norm',
                        help="Please give a value for batch_norm")
    parser.add_argument('--sage_aggregator',
                        help="Please give a value for sage_aggregator")
    parser.add_argument('--data_mode',
                        help="Please give a value for data_mode")
    parser.add_argument('--num_pool', help="Please give a value for num_pool")
    parser.add_argument('--gnn_per_block',
                        help="Please give a value for gnn_per_block")
    parser.add_argument('--embedding_dim',
                        help="Please give a value for embedding_dim")
    parser.add_argument('--pool_ratio',
                        help="Please give a value for pool_ratio")
    parser.add_argument('--linkpred', help="Please give a value for linkpred")
    parser.add_argument('--cat', help="Please give a value for cat")
    parser.add_argument('--self_loop',
                        help="Please give a value for self_loop")
    parser.add_argument('--max_time', help="Please give a value for max_time")
    parser.add_argument('--debias',
                        default='False',
                        help="Debias the data or not")
    parser.add_argument('--motif', help="get motifs or not")
    parser.add_argument('--best_epoch', help="best epoch for the result")
    args = parser.parse_args()
    with open(args.config) as f:
        config = json.load(f)

    # device
    if args.gpu_id is not None:
        config['gpu']['id'] = int(args.gpu_id)
        config['gpu']['use'] = True
    device = gpu_setup(config['gpu']['use'], config['gpu']['id'])
    # model, dataset, out_dir
    if args.model is not None:
        MODEL_NAME = args.model
    else:
        MODEL_NAME = config['model']
    if args.debias is not None:
        config['debias'] = args.debias
    if args.motif is not None:
        config['motif'] = args.motif
    if args.best_epoch is not None:
        config['best_epoch'] = args.best_epoch
    if args.dataset is not None:
        DATASET_NAME = args.dataset
    else:
        DATASET_NAME = config['dataset']
    dataset = LoadData(DATASET_NAME, config)
    if args.out_dir is not None:
        out_dir = args.out_dir
    else:
        out_dir = config['out_dir']
    # parameters
    params = config['params']
    if args.seed is not None:
        params['seed'] = int(args.seed)
    if args.epochs is not None:
        params['epochs'] = int(args.epochs)
    if args.batch_size is not None:
        params['batch_size'] = int(args.batch_size)
    if args.init_lr is not None:
        params['init_lr'] = float(args.init_lr)
    if args.lr_reduce_factor is not None:
        params['lr_reduce_factor'] = float(args.lr_reduce_factor)
    if args.lr_schedule_patience is not None:
        params['lr_schedule_patience'] = int(args.lr_schedule_patience)
    if args.min_lr is not None:
        params['min_lr'] = float(args.min_lr)
    if args.weight_decay is not None:
        params['weight_decay'] = float(args.weight_decay)
    if args.print_epoch_interval is not None:
        params['print_epoch_interval'] = int(args.print_epoch_interval)
    if args.max_time is not None:
        params['max_time'] = float(args.max_time)
    # network parameters
    net_params = config['net_params']
    net_params['device'] = device
    net_params['gpu_id'] = config['gpu']['id']
    net_params['batch_size'] = params['batch_size']
    if args.L is not None:
        net_params['L'] = int(args.L)
    if args.hidden_dim is not None:
        net_params['hidden_dim'] = int(args.hidden_dim)
    if args.out_dim is not None:
        net_params['out_dim'] = int(args.out_dim)
    if args.residual is not None:
        net_params['residual'] = True if args.residual == 'True' else False
    if args.edge_feat is not None:
        net_params['edge_feat'] = True if args.edge_feat == 'True' else False
    if args.readout is not None:
        net_params['readout'] = args.readout
    if args.kernel is not None:
        net_params['kernel'] = int(args.kernel)
    if args.n_heads is not None:
        net_params['n_heads'] = int(args.n_heads)
    if args.gated is not None:
        net_params['gated'] = True if args.gated == 'True' else False
    if args.in_feat_dropout is not None:
        net_params['in_feat_dropout'] = float(args.in_feat_dropout)
    if args.dropout is not None:
        net_params['dropout'] = float(args.dropout)
    if args.layer_norm is not None:
        net_params['layer_norm'] = True if args.layer_norm == 'True' else False
    if args.batch_norm is not None:
        net_params['batch_norm'] = True if args.batch_norm == 'True' else False
    if args.sage_aggregator is not None:
        net_params['sage_aggregator'] = args.sage_aggregator
    if args.data_mode is not None:
        net_params['data_mode'] = args.data_mode
    if args.num_pool is not None:
        net_params['num_pool'] = int(args.num_pool)
    if args.gnn_per_block is not None:
        net_params['gnn_per_block'] = int(args.gnn_per_block)
    if args.embedding_dim is not None:
        net_params['embedding_dim'] = int(args.embedding_dim)
    if args.pool_ratio is not None:
        net_params['pool_ratio'] = float(args.pool_ratio)
    if args.linkpred is not None:
        net_params['linkpred'] = True if args.linkpred == 'True' else False
    if args.cat is not None:
        net_params['cat'] = True if args.cat == 'True' else False
    if args.self_loop is not None:
        net_params['self_loop'] = True if args.self_loop == 'True' else False

    net_params['in_dim'] = dataset.train[0][0].ndata['feat'][0].size(0)
    # net_params['in_dim_edge'] = dataset.train_utils[0][0].edata['feat'][0].size(0)
    net_params['in_dim_edge'] = 1
    num_classes = len(np.unique(np.array(dataset.train[:][1])))
    net_params['n_classes'] = num_classes

    root_log_dir = out_dir + 'logs/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    root_ckpt_dir = out_dir + 'checkpoints/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    write_file_name = out_dir + 'results/result_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    write_config_file = out_dir + 'configs/config_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(
        config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    dirs = root_log_dir, root_ckpt_dir, write_file_name, write_config_file

    if not os.path.exists(out_dir + 'results'):
        os.makedirs(out_dir + 'results')

    if not os.path.exists(out_dir + 'configs'):
        os.makedirs(out_dir + 'configs')

    net_params['total_param'] = view_model_param(MODEL_NAME, net_params)

    train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs, config)
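Both mains follow the same pattern: every CLI option that was actually supplied (argparse leaves unset options as None) overrides the matching JSON-config entry, cast to its expected type. A generic sketch of that step (the helper and the cast table are illustrative, not from the source):
def to_bool(s):
    # These configs encode booleans as the strings 'True'/'False'.
    return s == 'True'


def apply_overrides(section, args, casts):
    # Copy each CLI value that was actually supplied into the config
    # section, cast to its expected type.
    for name, cast in casts.items():
        value = getattr(args, name, None)
        if value is not None:
            section[name] = cast(value)
    return section


# e.g. apply_overrides(params, args, {'seed': int, 'init_lr': float})
#      apply_overrides(net_params, args, {'hidden_dim': int, 'residual': to_bool})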