Example #1
File: train.py  Project: yinxx/ProteinGCN
def main():
    global args, best_error_global, best_error_local, savepath, dataset

    parser = buildParser()
    args = parser.parse_args()

    print('Torch Device being used: ', cfg.device)

    # create the savepath
    savepath = args.save_dir + str(args.name) + '/'
    if not os.path.exists(savepath):
        os.makedirs(savepath)

    # Writes to file and also to terminal
    sys.stdout = Logger(savepath)
    print(vars(args))

    best_error_global, best_error_local = 1e10, 1e10

    randomSeed(args.seed)

    # create train/val/test dataset separately
    assert os.path.exists(args.protein_dir), '{} does not exist!'.format(
        args.protein_dir)
    all_dirs = [
        d for d in os.listdir(args.protein_dir)
        if not d.startswith('.DS_Store')
    ]
    dir_len = len(all_dirs)
    indices = list(range(dir_len))
    random.shuffle(indices)
    # apply the shuffled order so the train/val/test split below is random
    all_dirs = [all_dirs[i] for i in indices]

    train_size = math.floor(args.train * dir_len)
    val_size = math.floor(args.val * dir_len)
    test_size = math.floor(args.test * dir_len)
    test_dirs = all_dirs[:test_size]
    train_dirs = all_dirs[test_size:test_size + train_size]
    val_dirs = all_dirs[test_size + train_size:test_size + train_size +
                        val_size]
    print('Testing on {} protein directories:'.format(len(test_dirs)))

    dataset = ProteinDataset(args.pkl_dir,
                             args.id_prop,
                             args.atom_init,
                             random_seed=args.seed)

    print('Dataset length: ', len(dataset))

    # load all model args from pretrained model
    if args.pretrained is not None and os.path.isfile(args.pretrained):
        print("=> loading model params '{}'".format(args.pretrained))
        model_checkpoint = torch.load(
            args.pretrained, map_location=lambda storage, loc: storage)
        model_args = argparse.Namespace(**model_checkpoint['args'])
        # override these args values with the pretrained model's args
        args.h_a = model_args.h_a
        args.h_g = model_args.h_g
        args.n_conv = model_args.n_conv
        args.random_seed = model_args.seed
        args.lr = model_args.lr

        print("=> loaded model params '{}'".format(args.pretrained))
    else:
        print("=> no model params found at '{}'".format(args.pretrained))

    # build model
    kwargs = {
        'pkl_dir': args.pkl_dir,  # Root directory for data
        'atom_init': args.atom_init,  # Atom Init filename
        'h_a': args.h_a,  # Dim of the hidden atom embedding learnt
        'h_g': args.h_g,  # Dim of the hidden graph embedding after pooling
        'n_conv': args.n_conv,  # Number of GCN layers
        'random_seed': args.seed,  # Seed to fix the simulation
        'lr': args.lr,  # Learning rate for optimizer
    }

    structures, _, _ = dataset[0]
    h_b = structures[1].shape[-1]
    kwargs['h_b'] = h_b  # Dim of the bond embedding initialization

    # Use DataParallel for faster training
    print("Let's use", torch.cuda.device_count(),
          "GPUs and Data Parallel Model.")
    model = ProteinGCN(**kwargs)
    model = torch.nn.DataParallel(model)
    model.cuda()

    print('Trainable Model Parameters: ', count_parameters(model))

    # Create dataloader to iterate through the dataset in batches
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset,
        train_dirs,
        val_dirs,
        test_dirs,
        collate_fn=collate_pool,
        num_workers=args.workers,
        batch_size=args.batch_size,
        pin_memory=False)

    try:
        print('Training data    : ', len(train_loader.sampler))
        print('Validation data  : ', len(val_loader.sampler))
        print('Testing data     : ', len(test_loader.sampler))
    except Exception as e:
        # sometimes test may not be defined
        print('\nException Cause: {}'.format(e.args[0]))

    # obtain target value normalizer
    if len(dataset) < args.avg_sample:
        sample_data_list = [dataset[i] for i in tqdm(range(len(dataset)))]
    else:
        sample_data_list = [
            dataset[i]
            for i in tqdm(random.sample(range(len(dataset)), args.avg_sample))
        ]

    _, _, sample_target = collate_pool(sample_data_list)
    normalizer_global = Normalizer(sample_target[0])
    normalizer_local = Normalizer(sample_target[1])

    # load the model state dict from given pretrained model
    if args.pretrained is not None and os.path.isfile(args.pretrained):
        print("=> loading model '{}'".format(args.pretrained))
        checkpoint = torch.load(args.pretrained,
                                map_location=lambda storage, loc: storage)

        print('Best error global: ', checkpoint['best_error_global'])
        print('Best error local: ', checkpoint['best_error_local'])

        best_error_global = checkpoint['best_error_global']
        best_error_local = checkpoint['best_error_local']

        model.module.load_state_dict(checkpoint['state_dict'])
        model.module.optimizer.load_state_dict(checkpoint['optimizer'])
        normalizer_local.load_state_dict(checkpoint['normalizer_local'])
        normalizer_global.load_state_dict(checkpoint['normalizer_global'])
    else:
        print("=> no model found at '{}'".format(args.pretrained))

    # Main training loop
    for epoch in range(args.epochs):
        # Training
        [train_error_global, train_error_local,
         train_loss] = trainModel(train_loader,
                                  model,
                                  normalizer_global,
                                  normalizer_local,
                                  epoch=epoch)
        # Validation
        [val_error_global, val_error_local,
         val_loss] = trainModel(val_loader,
                                model,
                                normalizer_global,
                                normalizer_local,
                                epoch=epoch,
                                evaluation=True)

        # check for NaN: a value is NaN iff it compares unequal to itself
        if (val_error_global != val_error_global) or (val_error_local !=
                                                      val_error_local):
            print('Exit due to NaN')
            sys.exit(1)

        # remember the best error and possibly save checkpoint
        is_best = val_error_global < best_error_global
        best_error_global = min(val_error_global, best_error_global)
        best_error_local = val_error_local

        # save best model
        if args.save_checkpoints:
            model.module.save(
                {
                    'epoch': epoch,
                    'state_dict': model.module.state_dict(),
                    'best_error_global': best_error_global,
                    'best_error_local': best_error_local,
                    'optimizer': model.module.optimizer.state_dict(),
                    'normalizer_global': normalizer_global.state_dict(),
                    'normalizer_local': normalizer_local.state_dict(),
                    'args': vars(args)
                }, is_best, savepath)

    # test best model using saved checkpoints
    if args.save_checkpoints and len(test_loader):
        print('---------Evaluate Model on Test Set---------------')
        # fall back to the pretrained checkpoint path if no best checkpoint was saved during this run
        try:
            best_checkpoint = torch.load(savepath + 'model_best.pth.tar')
        except Exception as e:
            best_checkpoint = torch.load(args.pretrained)

        model.module.load_state_dict(best_checkpoint['state_dict'])
        [test_error_global, test_error_local,
         test_loss] = trainModel(test_loader,
                                 model,
                                 normalizer_global,
                                 normalizer_local,
                                 testing=True)
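
The Normalizer used above (for both normalizer_global and normalizer_local) is defined elsewhere in the project. As a rough sketch, assuming it follows the usual CGCNN-style mean/std normalizer with a saveable state dict, it might look like:

import torch

class Normalizer(object):
    """Normalize a tensor with the mean and std of a sample, and restore it later."""

    def __init__(self, tensor):
        # statistics estimated from the sampled targets passed in above
        self.mean = torch.mean(tensor)
        self.std = torch.std(tensor)

    def norm(self, tensor):
        return (tensor - self.mean) / self.std

    def denorm(self, normed_tensor):
        return normed_tensor * self.std + self.mean

    def state_dict(self):
        # stored inside the checkpoint ('normalizer_global' / 'normalizer_local')
        return {'mean': self.mean, 'std': self.std}

    def load_state_dict(self, state_dict):
        self.mean = state_dict['mean']
        self.std = state_dict['std']
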
Example #2
File: main.py  Project: lukemelas/mt-cgcnn
def main():
    global args, best_mae_error

    # Dataset from CIF files
    dataset = CIFData(*args.data_options)
    print(f'Dataset size: {len(dataset)}')

    # Dataloader from dataset
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset=dataset,
        collate_fn=collate_pool,
        batch_size=args.batch_size,
        train_size=args.train_size,
        num_workers=args.workers,
        val_size=args.val_size,
        test_size=args.test_size,
        pin_memory=args.cuda,
        return_test=True)

    # Initialize data normalizer with sample of 500 points
    if args.task == 'classification':
        normalizer = Normalizer(torch.zeros(2))
        normalizer.load_state_dict({'mean': 0., 'std': 1.})
    elif args.task == 'regression':
        if len(dataset) < 500:
            warnings.warn('Dataset has less than 500 data points. '
                          'Lower accuracy is expected. ')
            sample_data_list = [dataset[i] for i in range(len(dataset))]
        else:
            sample_data_list = [
                dataset[i] for i in sample(range(len(dataset)), 500)
            ]
        _, sample_target, _ = collate_pool(sample_data_list)
        normalizer = Normalizer(sample_target)
    else:
        raise NameError('task argument must be regression or classification')

    # Build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(orig_atom_fea_len,
                                nbr_fea_len,
                                atom_fea_len=args.atom_fea_len,
                                n_conv=args.n_conv,
                                h_fea_len=args.h_fea_len,
                                n_h=args.n_h,
                                classification=(args.task == 'classification'))

    # GPU
    if args.cuda:
        model.cuda()

    # Loss function
    criterion = nn.NLLLoss() if args.task == 'classification' else nn.MSELoss()

    # Optimizer
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(),
                              args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.optim == 'Adam':
        optimizer = optim.Adam(model.parameters(),
                               args.lr,
                               weight_decay=args.weight_decay)
    else:
        raise NameError('optim argument must be SGD or Adam')

    # Scheduler
    scheduler = MultiStepLR(optimizer,
                            milestones=args.lr_milestones,
                            gamma=0.1)

    # Resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_mae_error = checkpoint['best_mae_error']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            normalizer.load_state_dict(checkpoint['normalizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Train
    for epoch in range(args.start_epoch, args.epochs):

        # Train (one epoch)
        train(train_loader, model, criterion, optimizer, epoch, normalizer)

        # Validate
        mae_error = validate(val_loader, model, criterion, normalizer)
        assert mae_error == mae_error, 'NaN :('

        # Step learning rate scheduler
        scheduler.step()  # MultiStepLR steps once per epoch and takes no metric argument

        # Save checkpoint
        if args.task == 'regression':
            is_best = mae_error < best_mae_error
            best_mae_error = min(mae_error, best_mae_error)
        else:
            is_best = mae_error > best_mae_error
            best_mae_error = max(mae_error, best_mae_error)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_mae_error': best_mae_error,
                'optimizer': optimizer.state_dict(),
                'normalizer': normalizer.state_dict(),
                'args': vars(args)
            }, is_best)

    # Evaluate best model on test set
    print('--------- Evaluate model on test set ---------------')
    best_checkpoint = torch.load('model_best.pth.tar')
    model.load_state_dict(best_checkpoint['state_dict'])
    validate(test_loader, model, criterion, normalizer, test=True)
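
The save_checkpoint helper and the model_best.pth.tar file loaded at the end are defined outside this function. A minimal sketch, assuming the common pattern of writing the latest state and copying it when it is the best seen so far, might be:

import shutil
import torch

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    # always write the most recent training state
    torch.save(state, filename)
    # keep a separate copy of the best-so-far model for the final test evaluation
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')
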