def train_qm9(args, device, metrics_dict):
    all_data = QM9Dataset(return_types=args.required_data, target_tasks=args.targets,
                          dist_embedding=args.dist_embedding, num_radial=args.num_radial,
                          prefetch_graphs=args.prefetch_graphs)

    all_idx = get_random_indices(len(all_data), args.seed_data)
    model_idx = all_idx[:100000]
    test_idx = all_idx[len(model_idx): len(model_idx) + int(0.1 * len(all_data))]
    val_idx = all_idx[len(model_idx) + len(test_idx):]
    train_idx = model_idx[:args.num_train]
    # for debugging purposes:
    # test_idx = all_idx[len(model_idx): len(model_idx) + 200]
    # val_idx = all_idx[len(model_idx) + len(test_idx): len(model_idx) + len(test_idx) + 3000]

    if args.use_e_features:
        try:
            edge_dim = all_data[0][0].edata['feat'].shape[1]
        except (KeyError, dgl.DGLError):
            # heterographs with a 'bond' edge type do not expose .edata directly
            edge_dim = all_data[0][0].edges['bond'].data['feat'].shape[1]
    else:
        edge_dim = 0

    model = globals()[args.model_type](node_dim=all_data[0][0].ndata['feat'].shape[1],
                                       edge_dim=edge_dim,
                                       avg_d=all_data.avg_degree,
                                       **args.model_parameters)

    if args.pretrain_checkpoint:
        # get the arguments that were used during pretraining
        with open(os.path.join(os.path.dirname(args.pretrain_checkpoint), 'train_arguments.yaml'), 'r') as arg_file:
            pretrain_dict = yaml.load(arg_file, Loader=yaml.FullLoader)
        pretrain_args = argparse.Namespace()
        pretrain_args.__dict__.update(pretrain_dict)
        # avoid finetuning on molecules that were already seen during pretraining
        train_idx = model_idx[pretrain_args.num_train: pretrain_args.num_train + args.num_train]
        checkpoint = torch.load(args.pretrain_checkpoint, map_location=device)
        # get all weights whose key matches something in 'args.transfer_layers', but only if the key
        # does not contain 'teacher'; strip the 'student.' prefix, which we need for loading from BYOLWrapper
        pretrained_gnn_dict = {k.replace('student.', ''): v
                               for k, v in checkpoint['model_state_dict'].items()
                               if any(transfer_layer in k for transfer_layer in args.transfer_layers)
                               and 'teacher' not in k
                               and not any(to_exclude in k for to_exclude in args.exclude_from_transfer)}
        model_state_dict = model.state_dict()
        model_state_dict.update(pretrained_gnn_dict)  # update the gnn layers with the pretrained weights
        model.load_state_dict(model_state_dict)

    print('model trainable params: ', sum(p.numel() for p in model.parameters() if p.requires_grad))
    print(f'Training on {len(train_idx)} samples from the model split')

    if args.collate_params == {}:
        collate_function = globals()[args.collate_function]
    else:
        collate_function = globals()[args.collate_function](**args.collate_params)

    if args.train_sampler is not None:
        sampler = globals()[args.train_sampler](data_source=all_data, batch_size=args.batch_size,
                                                indices=train_idx)
        train_loader = DataLoader(Subset(all_data, train_idx), batch_sampler=sampler,
                                  collate_fn=collate_function)
    else:
        train_loader = DataLoader(Subset(all_data, train_idx), batch_size=args.batch_size, shuffle=True,
                                  collate_fn=collate_function)
    val_loader = DataLoader(Subset(all_data, val_idx), batch_size=args.batch_size, collate_fn=collate_function)
    test_loader = DataLoader(Subset(all_data, test_idx), batch_size=args.batch_size, collate_fn=collate_function)

    metrics_dict.update({'mae_denormalized': QM9DenormalizedL1(dataset=all_data),
                         'mse_denormalized': QM9DenormalizedL2(dataset=all_data)})
    metrics = {metric: metrics_dict[metric] for metric in args.metrics if metric != 'qm9_properties'}
    tensorboard_functions = {function: TENSORBOARD_FUNCTIONS[function] for function in args.tensorboard_functions}
    if 'qm9_properties' in args.metrics:
        metrics.update({task: QM9SingleTargetDenormalizedL1(dataset=all_data, task=task)
                        for task in all_data.target_tasks})

    # Needs "from torch.optim import *" and "from models import *" to work
    if args.model3d_type:
        model3d = globals()[args.model3d_type](
            node_dim=all_data[0][1].ndata['feat'].shape[1]
            if isinstance(all_data[0][1], dgl.DGLGraph) else all_data[0][1].shape[-1],
            edge_dim=all_data[0][1].edata['d'].shape[1]
            if args.use_e_features and isinstance(all_data[0][1], dgl.DGLGraph) else 0,
            avg_d=all_data.avg_degree,
            **args.model3d_parameters)
        print('3D model trainable params: ',
              sum(p.numel() for p in model3d.parameters() if p.requires_grad))

        critic = None
        if args.ssl_mode == 'byol':
            ssl_trainer = BYOLTrainer
        elif args.ssl_mode == 'alternating':
            ssl_trainer = SelfSupervisedAlternatingTrainer
        elif args.ssl_mode == 'contrastive':
            ssl_trainer = SelfSupervisedTrainer
        elif args.ssl_mode == 'philosophy':
            ssl_trainer = PhilosophyTrainer
            critic = globals()[args.critic_type](**args.critic_parameters)
        else:
            raise ValueError(f'unsupported ssl_mode: {args.ssl_mode}')
        trainer = ssl_trainer(model=model,
                              model3d=model3d,
                              critic=critic,
                              args=args,
                              metrics=metrics,
                              main_metric=args.main_metric,
                              main_metric_goal=args.main_metric_goal,
                              optim=globals()[args.optimizer],
                              loss_func=globals()[args.loss_func](**args.loss_params),
                              critic_loss=globals()[args.critic_loss](**args.critic_loss_params),
                              device=device,
                              tensorboard_functions=tensorboard_functions,
                              scheduler_step_per_batch=args.scheduler_step_per_batch)
    else:
        trainer = Trainer(model=model,
                          args=args,
                          metrics=metrics,
                          main_metric=args.main_metric,
                          main_metric_goal=args.main_metric_goal,
                          optim=globals()[args.optimizer],
                          loss_func=globals()[args.loss_func](**args.loss_params),
                          device=device,
                          tensorboard_functions=tensorboard_functions,
                          scheduler_step_per_batch=args.scheduler_step_per_batch)
    trainer.train(train_loader, val_loader)
    if args.eval_on_test:
        trainer.evaluation(test_loader, data_split='test')
def train_molhiv(args, device, metrics_dict):
    dataset = DglGraphPropPredDataset(name='ogbg-molhiv')
    split_idx = dataset.get_idx_split()
    train_loader = DataLoader(dataset[split_idx["train"]], batch_size=32, shuffle=True, collate_fn=collate_dgl)
    val_loader = DataLoader(dataset[split_idx["valid"]], batch_size=32, shuffle=False, collate_fn=collate_dgl)
    test_loader = DataLoader(dataset[split_idx["test"]], batch_size=32, shuffle=False, collate_fn=collate_dgl)

    model = globals()[args.model_type](node_dim=dataset[0][0].ndata['feat'].shape[1],
                                       edge_dim=dataset[0][0].edata['feat'].shape[1] if args.use_e_features else 0,
                                       **args.model_parameters)
    print('model trainable params: ', sum(p.numel() for p in model.parameters() if p.requires_grad))

    if args.collate_params == {}:
        collate_function = globals()[args.collate_function]
    else:
        collate_function = globals()[args.collate_function](**args.collate_params)

    metrics = {metric: metrics_dict[metric] for metric in args.metrics}
    tensorboard_functions = {function: TENSORBOARD_FUNCTIONS[function] for function in args.tensorboard_functions}

    # Needs "from torch.optim import *" and "from models import *" to work
    transferred_params = [v for k, v in model.named_parameters()
                          if any(transfer_name in k for transfer_name in args.transfer_layers)]
    new_params = [v for k, v in model.named_parameters()
                  if all(transfer_name not in k for transfer_name in args.transfer_layers)]
    transfer_lr = args.optimizer_params['lr'] if args.transferred_lr is None else args.transferred_lr
    # separate parameter group so the transferred layers can use their own learning rate
    optim = globals()[args.optimizer]([{'params': new_params},
                                       {'params': transferred_params, 'lr': transfer_lr}],
                                      **args.optimizer_params)

    trainer = Trainer(model=model,
                      args=args,
                      metrics=metrics,
                      main_metric=args.main_metric,
                      main_metric_goal=args.main_metric_goal,
                      optim=optim,
                      loss_func=globals()[args.loss_func](**args.loss_params),
                      device=device,
                      tensorboard_functions=tensorboard_functions,
                      scheduler_step_per_batch=args.scheduler_step_per_batch)
    trainer.train(train_loader, val_loader)
    if args.eval_on_test:
        trainer.evaluation(test_loader, data_split='test')
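# The optimizer above is built with two parameter groups so that transferred layers can be
# fine-tuned with their own learning rate while new layers use the default from
# 'args.optimizer_params'. A minimal sketch of the same torch.optim parameter-group mechanism
# (the network, group assignment, and values here are hypothetical):
def _param_group_sketch():
    import torch
    from torch import nn

    net = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 1))
    transferred_params = list(net[0].parameters())  # pretend layer 0 was transferred
    new_params = list(net[1].parameters())
    # groups without an explicit 'lr' fall back to the keyword default (1e-3 here)
    return torch.optim.Adam([{'params': new_params},
                             {'params': transferred_params, 'lr': 1e-4}],
                            lr=1e-3)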
def train_zinc(args, device, metrics_dict):
    train_data = ZINCDataset(split='train', device=device, prefetch_graphs=args.prefetch_graphs)
    val_data = ZINCDataset(split='val', device=device, prefetch_graphs=args.prefetch_graphs)
    test_data = ZINCDataset(split='test', device=device, prefetch_graphs=args.prefetch_graphs)

    model = globals()[args.model_type](node_dim=train_data[0][0].ndata['feat'].shape[1],
                                       edge_dim=train_data[0][0].edata['feat'].shape[1]
                                       if args.use_e_features else 0,
                                       **args.model_parameters)
    print('model trainable params: ', sum(p.numel() for p in model.parameters() if p.requires_grad))

    if args.collate_params == {}:
        collate_function = globals()[args.collate_function]
    else:
        collate_function = globals()[args.collate_function](**args.collate_params)

    if args.train_sampler is not None:
        sampler = globals()[args.train_sampler](data_source=train_data, batch_size=args.batch_size,
                                                indices=range(len(train_data)))
        train_loader = DataLoader(train_data, batch_sampler=sampler, collate_fn=collate_function)
    else:
        train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                                  collate_fn=collate_function)
    val_loader = DataLoader(val_data, batch_size=args.batch_size, collate_fn=collate_function)
    test_loader = DataLoader(test_data, batch_size=args.batch_size, collate_fn=collate_function)

    metrics = {metric: metrics_dict[metric] for metric in args.metrics}
    tensorboard_functions = {function: TENSORBOARD_FUNCTIONS[function] for function in args.tensorboard_functions}

    # Needs "from torch.optim import *" and "from models import *" to work
    transferred_params = [v for k, v in model.named_parameters()
                          if any(transfer_name in k for transfer_name in args.transfer_layers)]
    new_params = [v for k, v in model.named_parameters()
                  if all(transfer_name not in k for transfer_name in args.transfer_layers)]
    transfer_lr = args.optimizer_params['lr'] if args.transferred_lr is None else args.transferred_lr
    # separate parameter group so the transferred layers can use their own learning rate
    optim = globals()[args.optimizer]([{'params': new_params},
                                       {'params': transferred_params, 'lr': transfer_lr}],
                                      **args.optimizer_params)

    trainer = Trainer(model=model,
                      args=args,
                      metrics=metrics,
                      main_metric=args.main_metric,
                      main_metric_goal=args.main_metric_goal,
                      optim=optim,
                      loss_func=globals()[args.loss_func](**args.loss_params),
                      device=device,
                      tensorboard_functions=tensorboard_functions,
                      scheduler_step_per_batch=args.scheduler_step_per_batch)
    trainer.train(train_loader, val_loader)
    if args.eval_on_test:
        trainer.evaluation(test_loader, data_split='test')
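# All three training functions resolve classes from config strings via globals(), e.g.
# globals()[args.model_type] or globals()[args.optimizer], which only works because of the
# wildcard imports noted in the comments above ("from torch.optim import *", "from models
# import *"). An equivalent lookup without wildcard imports uses getattr on the module
# itself; a minimal sketch (the 'Adam' string stands in for a value read from a config file):
def _dispatch_sketch():
    import torch
    import torch.optim

    optimizer_name = 'Adam'  # e.g. parsed from a yaml config
    optimizer_cls = getattr(torch.optim, optimizer_name)  # same effect as globals()['Adam']
    dummy_param = torch.nn.Parameter(torch.zeros(1))
    return optimizer_cls([dummy_param], lr=1e-3)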