Example #1
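This script trains a Cormorant model on the PDBBind binding-affinity data. The excerpt starts at main(), so its imports are not shown; the block below is a sketch of a plausible header, assuming the upstream cormorant package layout for the shared helpers (CormorantPDBBind is fork-specific, so its module path is a guess):

import torch

# Assumed header (module paths are hypothetical -- verify against the source file).
from cormorant.data.collate import collate_fn
from cormorant.data.utils import initialize_datasets
from cormorant.engine import (Engine, init_argparse, init_cuda, init_file_paths,
                              init_logger, init_optimizer, init_scheduler)
from cormorant.models import CormorantPDBBind
from cormorant.models.autotest import cormorant_tests
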
def main():

    # Initialize arguments
    args = init_argparse('pdbbind')

    # Initialize file paths
    args = init_file_paths(args)

    # Initialize logger
    init_logger(args)

    # Initialize device and data type
    device, dtype = init_cuda(args)

    # Initialize dataloader
    args, datasets, num_species, charge_scale = initialize_datasets(args, args.datadir, 'pdbbind', 
                                                                    force_download=args.force_download,
                                                                    ignore_check=args.ignore_check
                                                                    )

    # Construct PyTorch dataloaders from datasets
    dataloaders = {split: DataLoader(dataset,
                                     batch_size=args.batch_size,
                                     shuffle=args.shuffle if (split == 'train') else False,
                                     num_workers=args.num_workers,
                                     collate_fn=collate_fn)
                         for split, dataset in datasets.items()}

    # Initialize model
    model = CormorantPDBBind(args.maxl, args.max_sh, args.num_cg_levels, args.num_channels, num_species,
                             args.cutoff_type, args.hard_cut_rad, args.soft_cut_rad, args.soft_cut_width,
                             args.weight_init, args.level_gain, args.charge_power, args.basis_set,
                             charge_scale, args.gaussian_mask, args.top, args.input,
                             cgprod_bounded=True,
                             device=device, dtype=dtype)

    # Initialize the scheduler and optimizer
    optimizer = init_optimizer(args, model)
    scheduler, restart_epochs = init_scheduler(args, optimizer)

    # Define a loss function. Just use L2 loss for now.
    loss_fn = torch.nn.functional.mse_loss

    # Apply the covariance and permutation invariance tests.
    cormorant_tests(model, dataloaders['train'], args, charge_scale=charge_scale)

    # Instantiate the training class
    trainer = Engine(args, dataloaders, model, loss_fn, optimizer, scheduler, restart_epochs, device, dtype, clip_value=None)
    print('Initialized the trainer with clip value:', trainer.clip_value)

    # Load from checkpoint file. If no checkpoint file exists, automatically does nothing.
    trainer.load_checkpoint()

    # Train model.
    trainer.train()

    # Test predictions on best model and also last checkpointed model.
    trainer.evaluate()
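
Like the other training scripts in this collection, the full source file presumably ends with the standard entry-point guard so it can be run directly:

if __name__ == '__main__':
    main()
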
Example #2
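This variant targets the ATOM3D LBA (ligand binding affinity) task and can optionally run as a Siamese network. The imports are again omitted from the excerpt; the sketch below uses hypothetical module paths for the benchmark-specific pieces, so adjust them to the project's layout:

import torch

# Assumed header (module paths are hypothetical -- adjust to the benchmark's layout).
from cormorant.engine import Engine, init_logger, init_cuda, init_optimizer, init_scheduler
from data import initialize_datasets, initialize_lba_data, collate_lba, collate_lba_siamese
from models import ENN_LBA, ENN_LBA_Siamese
from utils import init_cormorant_argparse, init_cormorant_file_paths, cormorant_tests
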
def main():
    # Initialize arguments
    args = init_cormorant_argparse('lba')
    # Initialize file paths
    args = init_cormorant_file_paths(args)
    # Initialize logger
    init_logger(args)
    # Initialize device and data type
    device, dtype = init_cuda(args)
    # Initialize dataloader
    if args.format == 'npz':
        args, datasets, num_species, charge_scale = initialize_datasets(args, args.datadir, 'lba', args.ddir_suffix) 
    else:
        args, datasets, num_species, charge_scale = initialize_lba_data(args, args.datadir)        
    # Further differences for Siamese networks
    if args.siamese:
        collate_fn = collate_lba_siamese
    else:
        collate_fn = collate_lba
    # Construct PyTorch dataloaders from datasets
    dataloaders = {split: DataLoader(dataset, batch_size=args.batch_size,
                                     shuffle=args.shuffle if (split == 'train') else False,
                                     num_workers=args.num_workers, collate_fn=collate_fn)
                   for split, dataset in datasets.items()}
    # Initialize model
    if args.siamese:
        model = ENN_LBA_Siamese(args.maxl, args.max_sh, args.num_cg_levels, args.num_channels, num_species,
                                args.cutoff_type, args.hard_cut_rad, args.soft_cut_rad, args.soft_cut_width,
                                args.weight_init, args.level_gain, args.charge_power, args.basis_set,
                                charge_scale, args.gaussian_mask,
                                device=device, dtype=dtype)
    else:
        model = ENN_LBA(args.maxl, args.max_sh, args.num_cg_levels, args.num_channels, num_species,
                        args.cutoff_type, args.hard_cut_rad, args.soft_cut_rad, args.soft_cut_width,
                        args.weight_init, args.level_gain, args.charge_power, args.basis_set,
                        charge_scale, args.gaussian_mask, 
                        device=device, dtype=dtype)
    # Initialize the scheduler and optimizer
    optimizer = init_optimizer(args, model)
    scheduler, restart_epochs = init_scheduler(args, optimizer)
    # Define a loss function. Just use L2 loss for now.
    loss_fn = torch.nn.functional.mse_loss
    # Apply the covariance and permutation invariance tests.
    cormorant_tests(model, dataloaders['train'], args, charge_scale=charge_scale, siamese=args.siamese)
    # Instantiate the training class
    trainer = Engine(args, dataloaders, model, loss_fn, optimizer, scheduler, restart_epochs, device, dtype, clip_value=None)
    # Load from checkpoint file. If no checkpoint file exists, automatically does nothing.
    trainer.load_checkpoint()
    # Train model.
    trainer.train()
    # Test predictions on best model and also last checkpointed model.
    trainer.evaluate()
Example #3
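This is the QM9 regression example from the upstream cormorant repository. Its header (omitted here) should look close to the following, matching the import paths demonstrated in Example #6 below:

import torch

# Reconstructed header (verify against the source file).
from cormorant.data.collate import collate_fn
from cormorant.data.utils import initialize_datasets
from cormorant.engine import (Engine, init_argparse, init_cuda, init_file_paths,
                              init_logger, init_optimizer, init_scheduler)
from cormorant.models import CormorantQM9
from cormorant.models.autotest import cormorant_tests
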
def main():

    # Initialize arguments
    args = init_argparse('qm9')

    # Initialize file paths
    args = init_file_paths(args)

    # Initialize logger
    init_logger(args)

    # Initialize device and data type
    device, dtype = init_cuda(args)

    # Initialize dataloader
    args, datasets, num_species, charge_scale = initialize_datasets(
        args,
        args.datadir,
        'qm9',
        subtract_thermo=args.subtract_thermo,
        force_download=args.force_download)

    # Unit conversion to eV (zpve to meV). U0, U, G, and H are not converted
    # automatically when loaded from our data or the MoleculeNet data because the identifiers differ.
    qm9_to_eV = {
        'U0': 27.2114,
        'U': 27.2114,
        'G': 27.2114,
        'H': 27.2114,
        'zpve': 27211.4,
        'gap': 27.2114,
        'homo': 27.2114,
        'lumo': 27.2114
    }
    for dataset in datasets.values():
        dataset.convert_units(qm9_to_eV)

    # Construct PyTorch dataloaders from datasets
    dataloaders = {
        split: DataLoader(dataset,
                          batch_size=args.batch_size,
                          shuffle=args.shuffle if (split == 'train') else False,
                          num_workers=args.num_workers,
                          collate_fn=collate_fn)
        for split, dataset in datasets.items()
    }

    # Initialize model
    model = CormorantQM9(args.maxl,
                         args.max_sh,
                         args.num_cg_levels,
                         args.num_channels,
                         num_species,
                         args.cutoff_type,
                         args.hard_cut_rad,
                         args.soft_cut_rad,
                         args.soft_cut_width,
                         args.weight_init,
                         args.level_gain,
                         args.charge_power,
                         args.basis_set,
                         charge_scale,
                         args.gaussian_mask,
                         args.top,
                         args.input,
                         args.num_mpnn_levels,
                         device=device,
                         dtype=dtype)

    # Initialize the scheduler and optimizer
    optimizer = init_optimizer(args, model)
    scheduler, restart_epochs = init_scheduler(args, optimizer)

    # Define a loss function. Just use L2 loss for now.
    loss_fn = torch.nn.functional.mse_loss

    # Apply the covariance and permutation invariance tests.
    cormorant_tests(model,
                    dataloaders['train'],
                    args,
                    charge_scale=charge_scale)

    # Instantiate the training class
    trainer = Engine(args, dataloaders, model, loss_fn, optimizer, scheduler,
                     restart_epochs, device, dtype)

    # Load from checkpoint file. If no checkpoint file exists, automatically does nothing.
    trainer.load_checkpoint()

    # Train model.
    trainer.train()

    # Test predictions on best model and also last checkpointed model.
    trainer.evaluate()
Example #4
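This script trains a 20-way classifier for the ATOM3D RES (residue identity) task; note the switch to cross-entropy loss and the task='classification' flag on the Engine. The header sketch below again uses hypothetical module paths for the benchmark-specific pieces:

import torch

# Assumed header (module paths are hypothetical -- adjust to the benchmark's layout).
from cormorant.engine import Engine, init_logger, init_cuda, init_optimizer, init_scheduler
from data import initialize_datasets, collate_fn
from models import ENN_RES
from utils import init_cormorant_argparse, init_cormorant_file_paths, cormorant_tests
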
def main():
    # Initialize arguments
    args = init_cormorant_argparse('res')
    # Initialize file paths
    args = init_cormorant_file_paths(args)
    # Initialize logger
    init_logger(args)
    # Initialize device and data type
    device, dtype = init_cuda(args)
    # Initialize dataloader
    args, datasets, num_species, charge_scale = initialize_datasets(
        args, args.datadir, 'res', args.ddir_suffix)
    # Construct PyTorch dataloaders from datasets
    dataloaders = {
        split: DataLoader(dataset,
                          batch_size=args.batch_size,
                          shuffle=args.shuffle if (split == 'train') else False,
                          num_workers=args.num_workers,
                          collate_fn=collate_fn)
        for split, dataset in datasets.items()
    }
    # Initialize model
    model = ENN_RES(args.maxl,
                    args.max_sh,
                    args.num_cg_levels,
                    args.num_channels,
                    num_species,
                    args.cutoff_type,
                    args.hard_cut_rad,
                    args.soft_cut_rad,
                    args.soft_cut_width,
                    args.weight_init,
                    args.level_gain,
                    args.charge_power,
                    args.basis_set,
                    charge_scale,
                    args.gaussian_mask,
                    args.top,
                    args.input,
                    args.num_mpnn_levels,
                    num_classes=20,
                    cgprod_bounded=False,
                    cg_agg_normalization='relu',
                    cg_pow_normalization='relu',
                    device=device,
                    dtype=dtype)
    # Initialize the scheduler and optimizer
    optimizer = init_optimizer(args, model)
    scheduler, restart_epochs = init_scheduler(args, optimizer)
    # Define cross-entropy as the loss function.
    loss_fn = torch.nn.functional.cross_entropy
    # Apply the covariance and permutation invariance tests.
    cormorant_tests(model,
                    dataloaders['train'],
                    args,
                    charge_scale=charge_scale)
    # Instantiate the training class
    trainer = Engine(args,
                     dataloaders,
                     model,
                     loss_fn,
                     optimizer,
                     scheduler,
                     restart_epochs,
                     device,
                     dtype,
                     task='classification',
                     clip_value=None)
    print('Initialized a', trainer.task, 'trainer.')
    # Load from checkpoint file. If no checkpoint file exists, automatically does nothing.
    trainer.load_checkpoint()
    # Train model.
    trainer.train()
    # Test predictions on best model and also last checkpointed model.
    trainer.evaluate()
Example #5
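The final training script handles the ATOM3D MSP (mutation stability prediction) task. It needs gc for the explicit gc.collect() call in main(); the remaining module paths in this header sketch are hypothetical, as in the previous examples:

import gc

import torch

# Assumed header (module paths are hypothetical -- adjust to the benchmark's layout).
from cormorant.engine import Engine, init_logger, init_cuda, init_optimizer, init_scheduler
from data import initialize_datasets, initialize_msp_data, collate_msp
from models import ENN_MSP
from utils import init_cormorant_argparse, init_cormorant_file_paths, cormorant_tests
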
def main():
    # Initialize arguments
    args = init_cormorant_argparse('msp')
    # Initialize file paths
    args = init_cormorant_file_paths(args)
    # Initialize logger
    init_logger(args)
    # Initialize device and data type
    device, dtype = init_cuda(args)
    # Initialize dataloader.
    # Use initialize_msp_data to load the LMDB directly (needs much more memory).
    if args.format.lower().startswith('lmdb'):
        init = initialize_msp_data(args, args.datadir)
    else:
        init = initialize_datasets(args, args.datadir, 'msp', args.ddir_suffix)
    args, datasets, num_species, charge_scale = init
    # Construct PyTorch dataloaders from datasets
    dataloaders = {
        split: DataLoader(dataset,
                          batch_size=args.batch_size,
                          shuffle=args.shuffle if (split == 'train') else False,
                          num_workers=args.num_workers,
                          collate_fn=collate_msp)
        for split, dataset in datasets.items()
    }
    # Initialize model
    model = ENN_MSP(args.maxl,
                    args.max_sh,
                    args.num_cg_levels,
                    args.num_channels,
                    num_species,
                    args.cutoff_type,
                    args.hard_cut_rad,
                    args.soft_cut_rad,
                    args.soft_cut_width,
                    args.weight_init,
                    args.level_gain,
                    args.charge_power,
                    args.basis_set,
                    charge_scale,
                    args.gaussian_mask,
                    num_classes=args.num_classes,
                    device=device,
                    dtype=dtype)
    # Initialize the scheduler and optimizer
    optimizer = init_optimizer(args, model)
    scheduler, restart_epochs = init_scheduler(args, optimizer)
    # Define cross-entropy as the loss function.
    loss_fn = torch.nn.functional.cross_entropy
    gc.collect()
    # Apply the covariance and permutation invariance tests.
    cormorant_tests(model,
                    dataloaders['train'],
                    args,
                    charge_scale=charge_scale,
                    siamese=True)
    # Instantiate the training class
    trainer = Engine(args,
                     dataloaders,
                     model,
                     loss_fn,
                     optimizer,
                     scheduler,
                     restart_epochs,
                     device,
                     dtype,
                     task='classification',
                     clip_value=None)
    print('Initialized a', trainer.task, 'trainer.')
    # Load from checkpoint file. If no checkpoint file exists, automatically does nothing.
    trainer.load_checkpoint()
    # Train model.
    trainer.train()
    # Test predictions on best model and also last checkpointed model.
    trainer.evaluate()
Example #6
from cormorant.data.utils import initialize_datasets
from cormorant.data.collate import collate_fn

import logging

logging.basicConfig(level=logging.INFO)

# NB: in the training scripts above, initialize_datasets takes args first and
# returns four values; this snippet appears to target an older three-value
# signature, so adjust the call if it fails against your version.
datasets, num_species, max_charge = initialize_datasets('/tmp/test', 'qm9')

train = datasets['train']

# Hand-pick four samples from the training split and collate them into a batch.
batch = [train[i] for i in [5, 123, 5436, 43132]]

batch_coll = collate_fn(batch)

# Drop into the debugger (Python 3.7+) to inspect the collated batch.
breakpoint()
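
Unlike the training scripts above, this last snippet is a quick smoke test: it loads QM9 under /tmp/test, collates four hand-picked samples into a single batch, and then stops in the debugger so the resulting batch can be inspected interactively.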