Exemplo n.º 1
0
def main():
    """Set up, train and evaluate a CormorantPDBBind model.

    Parses the 'pdbbind' arguments, prepares file paths, logging and the
    device/dtype pair, builds the datasets and their dataloaders, constructs
    the model with its optimizer and scheduler, runs the Cormorant tests on
    the training loader, and finally drives training and evaluation through
    ``Engine``.
    """
    # Arguments, file paths and logging.
    args = init_argparse('pdbbind')
    args = init_file_paths(args)
    init_logger(args)

    # Device and data type shared by the model and the trainer.
    device, dtype = init_cuda(args)

    # Datasets, plus the species count and charge scale the model consumes.
    args, datasets, num_species, charge_scale = initialize_datasets(
        args,
        args.datadir,
        'pdbbind',
        force_download=args.force_download,
        ignore_check=args.ignore_check,
    )

    # One PyTorch dataloader per split; shuffling is only ever considered
    # for the 'train' split, every other split keeps its order.
    dataloaders = {}
    for split_name, split_data in datasets.items():
        if split_name == 'train':
            shuffle_split = args.shuffle
        else:
            shuffle_split = False
        dataloaders[split_name] = DataLoader(
            split_data,
            batch_size=args.batch_size,
            shuffle=shuffle_split,
            num_workers=args.num_workers,
            collate_fn=collate_fn,
        )

    # Positional model arguments, in the order the constructor is called with.
    model_positional = (
        args.maxl,
        args.max_sh,
        args.num_cg_levels,
        args.num_channels,
        num_species,
        args.cutoff_type,
        args.hard_cut_rad,
        args.soft_cut_rad,
        args.soft_cut_width,
        args.weight_init,
        args.level_gain,
        args.charge_power,
        args.basis_set,
        charge_scale,
        args.gaussian_mask,
        args.top,
        args.input,
    )
    model = CormorantPDBBind(
        *model_positional,
        cgprod_bounded=True,
        device=device,
        dtype=dtype,
    )

    # Optimizer first, since the scheduler is built on top of it.
    optimizer = init_optimizer(args, model)
    scheduler, restart_epochs = init_scheduler(args, optimizer)

    # Plain L2 (mean squared error) loss for now.
    loss_fn = torch.nn.functional.mse_loss

    # Covariance and permutation invariance tests on the training loader.
    cormorant_tests(model, dataloaders['train'], args, charge_scale=charge_scale)

    # Training engine; the clip value is explicitly passed as None here.
    trainer = Engine(
        args,
        dataloaders,
        model,
        loss_fn,
        optimizer,
        scheduler,
        restart_epochs,
        device,
        dtype,
        clip_value=None,
    )
    print('Initialized the trainer with clip value:', trainer.clip_value)

    # Resume from a checkpoint file (per the original note, this does
    # nothing when no checkpoint exists).
    trainer.load_checkpoint()

    # Train the model.
    trainer.train()

    # Test predictions on the best model and on the last checkpointed model.
    trainer.evaluate()
Exemplo n.º 2
0
def main():
    """Set up, train and evaluate a CormorantQM9 model.

    Parses the 'qm9' arguments, prepares file paths, logging and the
    device/dtype pair, builds the datasets, applies the unit-conversion
    table to every split, constructs the dataloaders, model, optimizer and
    scheduler, runs the Cormorant tests on the training loader, and finally
    drives training and evaluation through ``Engine``.
    """
    # Initialize arguments
    args = init_argparse('qm9')

    # Initialize file paths
    args = init_file_paths(args)

    # Initialize logger
    init_logger(args)

    # Initialize device and data type
    device, dtype = init_cuda(args)

    # Initialize dataloader
    args, datasets, num_species, charge_scale = initialize_datasets(
        args,
        args.datadir,
        'qm9',
        subtract_thermo=args.subtract_thermo,
        force_download=args.force_download)

    # Unit conversion (U0, U, G, H are not converted when loaded from our or
    # the MoleculeNet data because the identifiers differ).
    # The HOMO key must be spelled 'homo'; it had been mangled to 'h**o',
    # which presumably left that target unconverted -- confirm against
    # ``convert_units``.
    # NOTE(review): 'zpve' uses a factor 1000x larger than the others,
    # presumably to express it in meV rather than eV -- confirm.
    qm9_to_eV = {
        'U0': 27.2114,
        'U': 27.2114,
        'G': 27.2114,
        'H': 27.2114,
        'zpve': 27211.4,
        'gap': 27.2114,
        'homo': 27.2114,
        'lumo': 27.2114
    }
    for dataset in datasets.values():
        dataset.convert_units(qm9_to_eV)

    # Construct PyTorch dataloaders from datasets; only the 'train' split
    # may be shuffled.
    dataloaders = {
        split: DataLoader(dataset,
                          batch_size=args.batch_size,
                          shuffle=args.shuffle if
                          (split == 'train') else False,
                          num_workers=args.num_workers,
                          collate_fn=collate_fn)
        for split, dataset in datasets.items()
    }

    # Initialize model
    model = CormorantQM9(args.maxl,
                         args.max_sh,
                         args.num_cg_levels,
                         args.num_channels,
                         num_species,
                         args.cutoff_type,
                         args.hard_cut_rad,
                         args.soft_cut_rad,
                         args.soft_cut_width,
                         args.weight_init,
                         args.level_gain,
                         args.charge_power,
                         args.basis_set,
                         charge_scale,
                         args.gaussian_mask,
                         args.top,
                         args.input,
                         args.num_mpnn_levels,
                         device=device,
                         dtype=dtype)

    # Initialize the scheduler and optimizer
    optimizer = init_optimizer(args, model)
    scheduler, restart_epochs = init_scheduler(args, optimizer)

    # Define a loss function. Just use L2 loss for now.
    loss_fn = torch.nn.functional.mse_loss

    # Apply the covariance and permutation invariance tests.
    cormorant_tests(model,
                    dataloaders['train'],
                    args,
                    charge_scale=charge_scale)

    # Instantiate the training class
    trainer = Engine(args, dataloaders, model, loss_fn, optimizer, scheduler,
                     restart_epochs, device, dtype)

    # Load from checkpoint file. If no checkpoint file exists, automatically does nothing.
    trainer.load_checkpoint()

    # Train model.
    trainer.train()

    # Test predictions on best model and also last checkpointed model.
    trainer.evaluate()