コード例 #1
0
ファイル: train.py プロジェクト: drorlab/atom3d
def main():
    """Train and evaluate a Cormorant model on the 'lba' task.

    Pipeline:
    1. Parse the 'lba' options and resolve file paths.
    2. Set up logging and the device/dtype.
    3. Load the datasets and wrap each split in a DataLoader.
    4. Build the network and run the covariance / permutation invariance
       tests.
    5. Train (resuming from a checkpoint if one exists) and evaluate.

    When ``args.siamese`` is truthy, three things switch to their Siamese
    variants: the collate function (``collate_lba_siamese``), the network
    class (``ENN_LBA_Siamese``) and the invariance tests.
    """
    # --- Configuration, logging and device ---------------------------------
    args = init_cormorant_file_paths(init_cormorant_argparse('lba'))
    init_logger(args)
    device, dtype = init_cuda(args)

    # --- Data ----------------------------------------------------------------
    # The loader hands back a (possibly updated) args together with the data.
    args, datasets, num_species, charge_scale = initialize_lba_data(
        args, args.datadir)

    collate = collate_lba_siamese if args.siamese else collate_lba

    # Only the 'train' split honours args.shuffle; every other split is
    # iterated in a fixed order.
    dataloaders = {}
    for split, dataset in datasets.items():
        shuffle_split = args.shuffle if split == 'train' else False
        dataloaders[split] = DataLoader(dataset,
                                        batch_size=args.batch_size,
                                        shuffle=shuffle_split,
                                        num_workers=args.num_workers,
                                        collate_fn=collate)

    # --- Model ---------------------------------------------------------------
    # Both network classes share the same constructor signature, so pick the
    # class first and call it once.
    network_cls = ENN_LBA_Siamese if args.siamese else ENN_LBA
    model = network_cls(args.maxl,
                        args.max_sh,
                        args.num_cg_levels,
                        args.num_channels,
                        num_species,
                        args.cutoff_type,
                        args.hard_cut_rad,
                        args.soft_cut_rad,
                        args.soft_cut_width,
                        args.weight_init,
                        args.level_gain,
                        args.charge_power,
                        args.basis_set,
                        charge_scale,
                        args.gaussian_mask,
                        cgprod_bounded=args.cgprod_bounded,
                        cg_pow_normalization=args.cg_pow_normalization,
                        cg_agg_normalization=args.cg_agg_normalization,
                        device=device,
                        dtype=dtype)

    # --- Optimisation --------------------------------------------------------
    optimizer = init_optimizer(args, model)
    scheduler, restart_epochs = init_scheduler(args, optimizer)
    # Regression objective: plain mean-squared (L2) error for now.
    loss_fn = torch.nn.functional.mse_loss

    # Check covariance and permutation invariance on the training loader.
    cormorant_tests(model,
                    dataloaders['train'],
                    args,
                    charge_scale=charge_scale,
                    siamese=args.siamese)

    # --- Training and evaluation --------------------------------------------
    trainer = Engine(args, dataloaders, model, loss_fn, optimizer, scheduler,
                     restart_epochs, device, dtype, clip_value=None)
    # Resumes from a checkpoint when one exists; otherwise this does nothing.
    trainer.load_checkpoint()
    trainer.train()
    # Evaluate both the best model and the last checkpointed model.
    trainer.evaluate()
コード例 #2
0
ファイル: train_msp.py プロジェクト: maschka/atom3d
def main():
    """Train and evaluate a Cormorant classifier on the 'msp' task.

    How the datasets are loaded depends on ``args.format``:

    * If it starts with 'lmdb' (case-insensitive), they are read directly
      via ``initialize_msp_data``.
    * Otherwise they go through ``initialize_datasets``.

    The ``ENN_MSP`` network is trained as a 'classification' task with a
    cross-entropy loss. The invariance tests always run in Siamese mode.
    """
    # Configuration, logging and device.
    args = init_cormorant_file_paths(init_cormorant_argparse('msp'))
    init_logger(args)
    device, dtype = init_cuda(args)

    # NOTE: loading LMDB directly (initialize_msp_data) needs much more memory.
    if args.format.lower().startswith('lmdb'):
        loaded = initialize_msp_data(args, args.datadir)
    else:
        loaded = initialize_datasets(args, args.datadir, 'msp',
                                     args.ddir_suffix)
    args, datasets, num_species, charge_scale = loaded

    def make_loader(split, dataset):
        # Only the 'train' split may be shuffled (controlled by args.shuffle).
        return DataLoader(dataset,
                          batch_size=args.batch_size,
                          shuffle=args.shuffle if split == 'train' else False,
                          num_workers=args.num_workers,
                          collate_fn=collate_msp)

    dataloaders = {split: make_loader(split, dataset)
                   for split, dataset in datasets.items()}

    # Network.
    model = ENN_MSP(args.maxl,
                    args.max_sh,
                    args.num_cg_levels,
                    args.num_channels,
                    num_species,
                    args.cutoff_type,
                    args.hard_cut_rad,
                    args.soft_cut_rad,
                    args.soft_cut_width,
                    args.weight_init,
                    args.level_gain,
                    args.charge_power,
                    args.basis_set,
                    charge_scale,
                    args.gaussian_mask,
                    num_classes=args.num_classes,
                    device=device,
                    dtype=dtype)

    # Optimiser, LR scheduler and the cross-entropy classification loss.
    optimizer = init_optimizer(args, model)
    scheduler, restart_epochs = init_scheduler(args, optimizer)
    loss_fn = torch.nn.functional.cross_entropy

    # Explicit GC pass before the invariance tests. Presumably this reclaims
    # memory left over from data loading (TODO confirm).
    gc.collect()

    # Covariance and permutation invariance tests, in Siamese mode.
    cormorant_tests(model, dataloaders['train'], args,
                    charge_scale=charge_scale, siamese=True)

    trainer = Engine(args, dataloaders, model, loss_fn, optimizer, scheduler,
                     restart_epochs, device, dtype,
                     task='classification', clip_value=None)
    print('Initialized a', trainer.task, 'trainer.')

    # Resume from a checkpoint if present (no-op otherwise), train, then
    # evaluate the best and the last checkpointed models.
    trainer.load_checkpoint()
    trainer.train()
    trainer.evaluate()
コード例 #3
0
ファイル: train_resdel.py プロジェクト: sailfish009/atom3d
def main():
    """Train and evaluate a Cormorant classifier on the 'resdel' task.

    Builds a ``CormorantResDel`` network with these fixed settings:

    * ``num_classes=20``
    * ``cgprod_bounded=False``
    * 'relu' for both ``cg_agg_normalization`` and ``cg_pow_normalization``

    It then trains the network as a 'classification' task with a
    cross-entropy loss and evaluates it.
    """
    # Parse the 'resdel' options, resolve file paths, then set up logging and
    # the device/dtype.
    args = init_file_paths(init_argparse('resdel'))
    init_logger(args)
    device, dtype = init_cuda(args)

    # Load the datasets; args may come back updated.
    args, datasets, num_species, charge_scale = initialize_datasets(
        args, args.datadir, 'resdel',
        force_download=args.force_download,
        ignore_check=args.ignore_check,
    )

    # One DataLoader per split; only the 'train' split honours args.shuffle.
    dataloaders = {}
    for split, dataset in datasets.items():
        dataloaders[split] = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=args.shuffle if split == 'train' else False,
            num_workers=args.num_workers,
            collate_fn=collate_fn,
        )

    # Network: the CG settings are hard-coded here rather than read from args.
    model = CormorantResDel(
        args.maxl, args.max_sh, args.num_cg_levels, args.num_channels,
        num_species,
        args.cutoff_type, args.hard_cut_rad, args.soft_cut_rad,
        args.soft_cut_width,
        args.weight_init, args.level_gain, args.charge_power, args.basis_set,
        charge_scale, args.gaussian_mask,
        args.top, args.input, args.num_mpnn_levels,
        num_classes=20,
        cgprod_bounded=False,
        cg_agg_normalization='relu',
        cg_pow_normalization='relu',
        device=device,
        dtype=dtype,
    )

    # Optimiser, LR scheduler and the cross-entropy classification loss.
    optimizer = init_optimizer(args, model)
    scheduler, restart_epochs = init_scheduler(args, optimizer)
    loss_fn = torch.nn.functional.cross_entropy

    # Covariance and permutation invariance tests on the training loader.
    cormorant_tests(model, dataloaders['train'], args, charge_scale=charge_scale)

    trainer = Engine(args, dataloaders, model, loss_fn, optimizer, scheduler,
                     restart_epochs, device, dtype,
                     task='classification', clip_value=None)
    print('Initialized a', trainer.task, 'trainer.')

    # Resume from a checkpoint if one exists (otherwise a no-op), train, then
    # evaluate the best and the last checkpointed models.
    trainer.load_checkpoint()
    trainer.train()
    trainer.evaluate()