Example #1
0
def train_qm9(args, device, metrics_dict):
    all_data = QM9Dataset(return_types=args.required_data,
                          target_tasks=args.targets,
                          dist_embedding=args.dist_embedding,
                          num_radial=args.num_radial,
                          prefetch_graphs=args.prefetch_graphs)

    all_idx = get_random_indices(len(all_data), args.seed_data)
    model_idx = all_idx[:100000]
    test_idx = all_idx[len(model_idx):len(model_idx) +
                       int(0.1 * len(all_data))]
    val_idx = all_idx[len(model_idx) + len(test_idx):]
    train_idx = model_idx[:args.num_train]
    # for debugging purposes:
    # test_idx = all_idx[len(model_idx): len(model_idx) + 200]
    # val_idx = all_idx[len(model_idx) + len(test_idx): len(model_idx) + len(test_idx) + 3000]
    try:
        edge_dim = all_data[0][0].edata['feat'].shape[
            1] if args.use_e_features else 0
    except:
        edge_dim = all_data[0][0].edges['bond'].data['feat'].shape[
            1] if args.use_e_features else 0
    model = globals()[args.model_type](
        node_dim=all_data[0][0].ndata['feat'].shape[1],
        edge_dim=edge_dim,
        avg_d=all_data.avg_degree,
        **args.model_parameters)

    if args.pretrain_checkpoint:
        # get arguments used during pretraining
        with open(
                os.path.join(os.path.dirname(args.pretrain_checkpoint),
                             'train_arguments.yaml'), 'r') as arg_file:
            pretrain_dict = yaml.load(arg_file, Loader=yaml.FullLoader)
        pretrain_args = argparse.Namespace()
        pretrain_args.__dict__.update(pretrain_dict)
        train_idx = model_idx[pretrain_args.num_train:pretrain_args.num_train +
                              args.num_train]

        checkpoint = torch.load(args.pretrain_checkpoint, map_location=device)
        # get all the weights that have something from 'args.transfer_layers' in their keys name
        # but only if they do not contain 'teacher' and remove 'student.' which we need for loading from BYOLWrapper
        pretrained_gnn_dict = {
            k.replace('student.', ''): v
            for k, v in checkpoint['model_state_dict'].items()
            if any(transfer_layer in k
                   for transfer_layer in args.transfer_layers) and 'teacher'
            not in k and not any(to_exclude in k
                                 for to_exclude in args.exclude_from_transfer)
        }
        model_state_dict = model.state_dict()
        model_state_dict.update(
            pretrained_gnn_dict
        )  # update the gnn layers with the pretrained weights
        model.load_state_dict(model_state_dict)
    print('model trainable params: ',
          sum(p.numel() for p in model.parameters() if p.requires_grad))

    print(f'Training on {len(train_idx)} samples from the model sequences')
    collate_function = globals()[
        args.collate_function] if args.collate_params == {} else globals()[
            args.collate_function](**args.collate_params)

    if args.train_sampler != None:
        sampler = globals()[args.train_sampler](data_source=all_data,
                                                batch_size=args.batch_size,
                                                indices=train_idx)
        train_loader = DataLoader(Subset(all_data, train_idx),
                                  batch_sampler=sampler,
                                  collate_fn=collate_function)
    else:
        train_loader = DataLoader(Subset(all_data, train_idx),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  collate_fn=collate_function)
    val_loader = DataLoader(Subset(all_data, val_idx),
                            batch_size=args.batch_size,
                            collate_fn=collate_function)
    test_loader = DataLoader(Subset(all_data, test_idx),
                             batch_size=args.batch_size,
                             collate_fn=collate_function)

    metrics_dict.update({
        'mae_denormalized': QM9DenormalizedL1(dataset=all_data),
        'mse_denormalized': QM9DenormalizedL2(dataset=all_data)
    })
    metrics = {
        metric: metrics_dict[metric]
        for metric in args.metrics if metric != 'qm9_properties'
    }
    tensorboard_functions = {
        function: TENSORBOARD_FUNCTIONS[function]
        for function in args.tensorboard_functions
    }
    if 'qm9_properties' in args.metrics:
        metrics.update({
            task: QM9SingleTargetDenormalizedL1(dataset=all_data, task=task)
            for task in all_data.target_tasks
        })

    # Needs "from torch.optim import *" and "from models import *" to work
    if args.model3d_type:
        model3d = globals()[args.model3d_type](
            node_dim=all_data[0][1].ndata['feat'].shape[1] if isinstance(
                all_data[0][1], dgl.DGLGraph) else all_data[0][1].shape[-1],
            edge_dim=all_data[0][1].edata['d'].shape[1] if args.use_e_features
            and isinstance(all_data[0][1], dgl.DGLGraph) else 0,
            avg_d=all_data.avg_degree,
            **args.model3d_parameters)
        print('3D model trainable params: ',
              sum(p.numel() for p in model3d.parameters() if p.requires_grad))

        critic = None
        if args.ssl_mode == 'byol':
            ssl_trainer = BYOLTrainer
        elif args.ssl_mode == 'alternating':
            ssl_trainer = SelfSupervisedAlternatingTrainer
        elif args.ssl_mode == 'contrastive':
            ssl_trainer = SelfSupervisedTrainer
        elif args.ssl_mode == 'philosophy':
            ssl_trainer = PhilosophyTrainer
            critic = globals()[args.critic_type](**args.critic_parameters)
        trainer = ssl_trainer(
            model=model,
            model3d=model3d,
            critic=critic,
            args=args,
            metrics=metrics,
            main_metric=args.main_metric,
            main_metric_goal=args.main_metric_goal,
            optim=globals()[args.optimizer],
            loss_func=globals()[args.loss_func](**args.loss_params),
            critic_loss=globals()[args.critic_loss](**args.critic_loss_params),
            device=device,
            tensorboard_functions=tensorboard_functions,
            scheduler_step_per_batch=args.scheduler_step_per_batch)
    else:
        trainer = Trainer(
            model=model,
            args=args,
            metrics=metrics,
            main_metric=args.main_metric,
            main_metric_goal=args.main_metric_goal,
            optim=globals()[args.optimizer],
            loss_func=globals()[args.loss_func](**args.loss_params),
            device=device,
            tensorboard_functions=tensorboard_functions,
            scheduler_step_per_batch=args.scheduler_step_per_batch)
    trainer.train(train_loader, val_loader)

    if args.eval_on_test:
        trainer.evaluation(test_loader, data_split='test')
Example #2
0
def train_molhiv(args, device, metrics_dict):
    dataset = DglGraphPropPredDataset(name='ogbg-molhiv')
    split_idx = dataset.get_idx_split()
    train_loader = DataLoader(dataset[split_idx["train"]],
                              batch_size=32,
                              shuffle=True,
                              collate_fn=collate_dgl)
    val_loader = DataLoader(dataset[split_idx["valid"]],
                            batch_size=32,
                            shuffle=False,
                            collate_fn=collate_dgl)
    test_loader = DataLoader(dataset[split_idx["test"]],
                             batch_size=32,
                             shuffle=False,
                             collate_fn=collate_dgl)

    model = globals()[args.model_type](
        node_dim=dataset[0][0].ndata['feat'].shape[1],
        edge_dim=dataset[0][0].edata['feat'].shape[1]
        if args.use_e_features else 0,
        **args.model_parameters)
    print('model trainable params: ',
          sum(p.numel() for p in model.parameters() if p.requires_grad))
    collate_function = globals()[
        args.collate_function] if args.collate_params == {} else globals()[
            args.collate_function](**args.collate_params)

    metrics = {metric: metrics_dict[metric] for metric in args.metrics}
    tensorboard_functions = {
        function: TENSORBOARD_FUNCTIONS[function]
        for function in args.tensorboard_functions
    }

    # Needs "from torch.optim import *" and "from models import *" to work
    transferred_params = [
        v for k, v in model.named_parameters()
        if any(transfer_name in k for transfer_name in args.transfer_layers)
    ]
    new_params = [
        v for k, v in model.named_parameters() if all(
            transfer_name not in k for transfer_name in args.transfer_layers)
    ]
    transfer_lr = args.optimizer_params[
        'lr'] if args.transferred_lr == None else args.transferred_lr
    optim = globals()[args.optimizer]([{
        'params': new_params
    }, {
        'params': transferred_params,
        'lr': transfer_lr
    }], **args.optimizer_params)
    trainer = Trainer(model=model,
                      args=args,
                      metrics=metrics,
                      main_metric=args.main_metric,
                      main_metric_goal=args.main_metric_goal,
                      optim=optim,
                      loss_func=globals()[args.loss_func](**args.loss_params),
                      device=device,
                      tensorboard_functions=tensorboard_functions,
                      scheduler_step_per_batch=args.scheduler_step_per_batch)
    trainer.train(train_loader, val_loader)

    if args.eval_on_test:
        trainer.evaluation(test_loader, data_split='test')
Example #3
0
def train_zinc(args, device, metrics_dict):
    train_data = ZINCDataset(split='train',
                             device=device,
                             prefetch_graphs=args.prefetch_graphs)
    val_data = ZINCDataset(split='val',
                           device=device,
                           prefetch_graphs=args.prefetch_graphs)
    test_data = ZINCDataset(split='test',
                            device=device,
                            prefetch_graphs=args.prefetch_graphs)

    model = globals()[args.model_type](
        node_dim=train_data[0][0].ndata['feat'].shape[1],
        edge_dim=train_data[0][0].edata['feat'].shape[1]
        if args.use_e_features else 0,
        **args.model_parameters)
    print('model trainable params: ',
          sum(p.numel() for p in model.parameters() if p.requires_grad))
    collate_function = globals()[
        args.collate_function] if args.collate_params == {} else globals()[
            args.collate_function](**args.collate_params)
    if args.train_sampler != None:
        sampler = globals()[args.train_sampler](data_source=train_data,
                                                batch_size=args.batch_size,
                                                indices=range(len(train_data)))
        train_loader = DataLoader(train_data,
                                  batch_sampler=sampler,
                                  collate_fn=collate_function)
    else:
        train_loader = DataLoader(train_data,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  collate_fn=collate_function)
    val_loader = DataLoader(val_data,
                            batch_size=args.batch_size,
                            collate_fn=collate_function)
    test_loader = DataLoader(test_data,
                             batch_size=args.batch_size,
                             collate_fn=collate_function)

    metrics = {metric: metrics_dict[metric] for metric in args.metrics}
    tensorboard_functions = {
        function: TENSORBOARD_FUNCTIONS[function]
        for function in args.tensorboard_functions
    }

    # Needs "from torch.optim import *" and "from models import *" to work
    transferred_params = [
        v for k, v in model.named_parameters()
        if any(transfer_name in k for transfer_name in args.transfer_layers)
    ]
    new_params = [
        v for k, v in model.named_parameters() if all(
            transfer_name not in k for transfer_name in args.transfer_layers)
    ]
    transfer_lr = args.optimizer_params[
        'lr'] if args.transferred_lr == None else args.transferred_lr
    optim = globals()[args.optimizer]([{
        'params': new_params
    }, {
        'params': transferred_params,
        'lr': transfer_lr
    }], **args.optimizer_params)
    trainer = Trainer(model=model,
                      args=args,
                      metrics=metrics,
                      main_metric=args.main_metric,
                      main_metric_goal=args.main_metric_goal,
                      optim=optim,
                      loss_func=globals()[args.loss_func](**args.loss_params),
                      device=device,
                      tensorboard_functions=tensorboard_functions,
                      scheduler_step_per_batch=args.scheduler_step_per_batch)
    trainer.train(train_loader, val_loader)

    if args.eval_on_test:
        trainer.evaluation(test_loader, data_split='test')