def main():
    args = get_parser()
    # argparse arguments; several flags are passed as strings and converted to booleans
    naive_encoder = not str2bool(args.full_encoder)
    pin_memory = str2bool(args.pin_memory)
    use_bias = str2bool(args.bias)
    downstream_bn = str(args.d_bn)
    same_dropout = str2bool(args.same_dropout)
    mlp_mp = str2bool(args.mlp_mp)
    phm_dim = args.phm_dim
    learn_phm = str2bool(args.learn_phm)

    base_dir = "cifar10/"
    if not os.path.exists(base_dir):
        os.makedirs(base_dir)
    if base_dir not in args.save_dir:
        args.save_dir = os.path.join(base_dir, args.save_dir)
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    set_logging(save_dir=args.save_dir)
    logging.info(f"Creating log directory at {args.save_dir}.")
    with open(os.path.join(args.save_dir, "params.json"), 'w') as fp:
        json.dump(args.__dict__, fp)

    mp_layers = [int(item) for item in args.mp_units.split(',')]
    downstream_layers = [int(item) for item in args.d_units.split(',')]
    mp_dropout = [float(item) for item in args.dropout_mpnn.split(',')]
    dn_dropout = [float(item) for item in args.dropout_dn.split(',')]
    logging.info(
        f'Initialising model with {mp_layers} hidden units with dropout {mp_dropout} '
        f'and downstream units: {downstream_layers} with dropout {dn_dropout}.'
    )
    if args.pooling == "globalsum":
        logging.info("Using GlobalSum Pooling")
    else:
        logging.info("Using SoftAttention Pooling")
    logging.info(
        f"Using Adam optimizer with weight_decay ({args.weightdecay}) and regularization "
        f"norm ({args.regularization})")
    logging.info(
        f"Weight init: {args.w_init} \n Contribution init: {args.c_init}")

    # data
    path = osp.join(osp.dirname(osp.realpath(__file__)), 'dataset')
    transform = concat_x_pos
    train_data = GNNBenchmarkDataset(path, name="CIFAR10", split='train', transform=transform)
    valid_data = GNNBenchmarkDataset(path, name="CIFAR10", split='val', transform=transform)
    test_data = GNNBenchmarkDataset(path, name="CIFAR10", split='test', transform=transform)
    evaluator = Evaluator()

    train_loader = DataLoader(train_data, batch_size=args.batch_size, drop_last=False,
                              shuffle=True, num_workers=args.nworkers, pin_memory=pin_memory)
    valid_loader = DataLoader(valid_data, batch_size=args.batch_size, drop_last=False,
                              shuffle=False, num_workers=args.nworkers, pin_memory=pin_memory)
    test_loader = DataLoader(test_data, batch_size=args.batch_size, drop_last=False,
                             shuffle=False, num_workers=args.nworkers, pin_memory=pin_memory)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    # device = "cpu"

    # for hypercomplex models
    unique_phm = str2bool(args.unique_phm)
    if unique_phm:
        phm_rule = get_multiplication_matrices(phm_dim=args.phm_dim, type="phm")
        phm_rule = torch.nn.ParameterList(
            [torch.nn.Parameter(a, requires_grad=learn_phm) for a in phm_rule])
    else:
        phm_rule = None

    FULL_ATOM_FEATURE_DIMS = 5
    FULL_BOND_FEATURE_DIMS = 1

    if args.aggr_msg == "pna" or args.aggr_node == "pna":  # if PNA is used
        # Compute the in-degree histogram over the training data.
        deg = torch.zeros(19, dtype=torch.long)
        for data in train_data:
            d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
            deg += torch.bincount(d, minlength=deg.numel())
    else:
        deg = None

    aggr_kwargs = {
        "aggregators": ['mean', 'min', 'max', 'std'],
        "scalers": ['identity', 'amplification', 'attenuation'],
        "deg": deg,
        "post_layers": 1,
        "msg_scalers": str2bool(args.msg_scale),  # this key is for directional message-passing layers
        "initial_beta": 1.0,  # Softmax
        "learn_beta": True
    }

    if "quaternion" in args.type:
        if args.aggr_msg == "pna" or args.aggr_node == "pna":
            logging.info("PNA not implemented for quaternion models.")
            raise NotImplementedError

    if args.type == "undirectional-quaternion-sc-add":
        logging.info(
            "Using Quaternion Undirectional MPNN with Skip Connection through Addition"
        )
        model = UQ_SC_ADD(atom_input_dims=FULL_ATOM_FEATURE_DIMS,
                          atom_encoded_dim=args.input_embed_dim,
                          bond_input_dims=FULL_BOND_FEATURE_DIMS,
                          naive_encoder=naive_encoder,
                          mp_layers=mp_layers,
                          dropout_mpnn=mp_dropout,
                          init=args.w_init,
                          same_dropout=same_dropout,
                          norm_mp=args.mp_norm,
                          add_self_loops=True,
                          msg_aggr=args.aggr_msg,
                          node_aggr=args.aggr_node,
                          mlp=mlp_mp,
                          pooling=args.pooling,
                          activation=args.activation,
                          real_trafo=args.real_trafo,
                          downstream_layers=downstream_layers,
                          target_dim=train_data.num_classes,
                          dropout_dn=dn_dropout,
                          norm_dn=downstream_bn,
                          msg_encoder=args.msg_encoder,
                          **aggr_kwargs)
    elif args.type == "undirectional-quaternion-sc-cat":
        logging.info(
            "Using Quaternion Undirectional MPNN with Skip Connection through Concatenation"
        )
        model = UQ_SC_CAT(atom_input_dims=FULL_ATOM_FEATURE_DIMS,
                          atom_encoded_dim=args.input_embed_dim,
                          bond_input_dims=FULL_BOND_FEATURE_DIMS,
                          naive_encoder=naive_encoder,
                          mp_layers=mp_layers,
                          dropout_mpnn=mp_dropout,
                          init=args.w_init,
                          same_dropout=same_dropout,
                          norm_mp=args.mp_norm,
                          add_self_loops=True,
                          msg_aggr=args.aggr_msg,
                          node_aggr=args.aggr_node,
                          mlp=mlp_mp,
                          pooling=args.pooling,
                          activation=args.activation,
                          real_trafo=args.real_trafo,
                          downstream_layers=downstream_layers,
                          target_dim=train_data.num_classes,
                          dropout_dn=dn_dropout,
                          norm_dn=downstream_bn,
                          msg_encoder=args.msg_encoder,
                          **aggr_kwargs)
    elif args.type == "undirectional-phm-sc-add":
        logging.info(
            "Using PHM Undirectional MPNN with Skip Connection through Addition"
        )
        model = UPH_SC_ADD(phm_dim=phm_dim,
                           learn_phm=learn_phm,
                           phm_rule=phm_rule,
                           atom_input_dims=FULL_ATOM_FEATURE_DIMS,
                           atom_encoded_dim=args.input_embed_dim,
                           bond_input_dims=FULL_BOND_FEATURE_DIMS,
                           naive_encoder=naive_encoder,
                           mp_layers=mp_layers,
                           dropout_mpnn=mp_dropout,
                           w_init=args.w_init,
                           c_init=args.c_init,
                           same_dropout=same_dropout,
                           norm_mp=args.mp_norm,
                           add_self_loops=True,
                           msg_aggr=args.aggr_msg,
                           node_aggr=args.aggr_node,
                           mlp=mlp_mp,
                           pooling=args.pooling,
                           activation=args.activation,
                           real_trafo=args.real_trafo,
                           downstream_layers=downstream_layers,
                           target_dim=train_data.num_classes,
                           dropout_dn=dn_dropout,
                           norm_dn=downstream_bn,
                           msg_encoder=args.msg_encoder,
                           sc_type=args.sc_type,
                           **aggr_kwargs)
    elif args.type == "undirectional-phm-sc-cat":
        logging.info(
            "Using PHM Undirectional MPNN with Skip Connection through Concatenation"
        )
        model = UPH_SC_CAT(phm_dim=phm_dim,
                           learn_phm=learn_phm,
                           phm_rule=phm_rule,
                           atom_input_dims=FULL_ATOM_FEATURE_DIMS,
                           atom_encoded_dim=args.input_embed_dim,
                           bond_input_dims=FULL_BOND_FEATURE_DIMS,
                           naive_encoder=naive_encoder,
                           mp_layers=mp_layers,
                           dropout_mpnn=mp_dropout,
                           w_init=args.w_init,
                           c_init=args.c_init,
                           same_dropout=same_dropout,
                           norm_mp=args.mp_norm,
                           add_self_loops=True,
                           msg_aggr=args.aggr_msg,
                           node_aggr=args.aggr_node,
                           mlp=mlp_mp,
                           pooling=args.pooling,
                           activation=args.activation,
                           real_trafo=args.real_trafo,
                           downstream_layers=downstream_layers,
                           target_dim=train_data.num_classes,
                           dropout_dn=dn_dropout,
                           norm_dn=downstream_bn,
                           msg_encoder=args.msg_encoder,
                           **aggr_kwargs)
    else:
        raise ModuleNotFoundError(f"Unknown model type: {args.type}")

    logging.info(
        f"Model consists of {model.get_number_of_params_()} trainable parameters"
    )

    # do runs
    test_best_epoch_metrics_arr = []
    test_last_epoch_metrics_arr = []
    val_metrics_arr = []
    t0 = time.time()
    for i in range(1, args.n_runs + 1):
        ogb_bestEpoch_test_metrics, ogb_lastEpoch_test_metric, ogb_val_metrics = do_run(
            i, model, args, transform, train_loader, valid_loader, test_loader,
            device, evaluator, t0)
        test_best_epoch_metrics_arr.append(ogb_bestEpoch_test_metrics)
        test_last_epoch_metrics_arr.append(ogb_lastEpoch_test_metric)
        val_metrics_arr.append(ogb_val_metrics)

    logging.info(f"Performance of model across {args.n_runs} runs:")
    test_bestEpoch_perf = torch.tensor(test_best_epoch_metrics_arr)
    test_lastEpoch_perf = torch.tensor(test_last_epoch_metrics_arr)
    valid_perf = torch.tensor(val_metrics_arr)
    logging.info('===========================')
    logging.info(
        f'Final Test (best val-epoch) '
        f'"{evaluator.eval_metric}": {test_bestEpoch_perf.mean():.4f} ± {test_bestEpoch_perf.std():.4f}'
    )
    logging.info(
        f'Final Test (last-epoch) '
        f'"{evaluator.eval_metric}": {test_lastEpoch_perf.mean():.4f} ± {test_lastEpoch_perf.std():.4f}'
    )
    logging.info(
        f'Final (best) Valid "{evaluator.eval_metric}": {valid_perf.mean():.4f} ± {valid_perf.std():.4f}'
    )

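
# do_run executes a single training run: it reseeds and resets the model,
# builds per-block optimizer parameter groups (message passing, pooling,
# downstream head, norms, atom/bond encoders), trains with Adam and a
# ReduceLROnPlateau scheduler, logs to TensorBoard, stops early on min LR,
# max training time or Ctrl+C, and finally evaluates both the best-validation
# and last-epoch checkpoints on the test split.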
def do_run(i, model, args, transform, train_loader, valid_loader, test_loader,
           device, evaluator, t0):
    logging.info(f"Run {i}/{args.n_runs}, seed: {args.seed + i - 1}")
    logging.info("Reset model parameters")
    set_seed_all(args.seed + i - 1)
    model.reset_parameters()
    model = model.to(device)

    # setting up parameter groups for optimization
    # quaternion/hypercomplex - in general
    params_mp = get_model_blocks(model, attr="convs",
                                 **dict(lr=args.lr, weight_decay=0.0))
    params_pooling = get_model_blocks(model, attr="pooling",
                                      **dict(lr=args.lr, weight_decay=0.0))
    params_downstream = get_model_blocks(model, attr="downstream",
                                         **dict(lr=args.lr, weight_decay=0.0))
    # quaternion/hypercomplex - undirectional
    params_norms = get_model_blocks(model, attr="norms",
                                    **dict(lr=args.lr, weight_decay=0.0))
    params_embedding_atom = get_model_blocks(model, attr="atomencoder",
                                             **dict(lr=args.lr, weight_decay=0.0))
    params_embedding_bonds = get_model_blocks(model, attr="bondencoders",
                                              **dict(lr=args.lr, weight_decay=0.0))
    params = params_mp + params_pooling + params_downstream + \
             params_norms + params_embedding_atom + params_embedding_bonds

    # check that all parameters are captured by the parameter groups
    total_params_splitted = sum([
        sum([p.numel() for p in pl["params"] if p.requires_grad])
        for pl in params
    ])
    total_params_model = sum(
        [p.numel() for p in model.parameters() if p.requires_grad])
    assert total_params_model == total_params_splitted, \
        f"Total params in parameter groups: {total_params_splitted}. " \
        f"However, total params of model are: {total_params_model}"

    # optimizer
    optimizer = torch.optim.Adam(params, weight_decay=0.0)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=args.factor, patience=args.patience, mode="max",
        min_lr=1e-7, verbose=True)

    save_dir = os.path.join(args.save_dir, f"run_{i}")
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # tensorboard logging
    logging.info(f"Creating new tensorboard logging for run {i}.")
    writer = SummaryWriter(log_dir=save_dir)

    lr_arr = []
    train_metrics_arr = []
    val_metrics_arr = []
    start_time = datetime.now()
    best_metric = float("-inf")
    model_save_dir = os.path.join(save_dir, "model.pt")
    model_last_save_dir = os.path.join(save_dir, "model_last.pt")
    metric_key = evaluator.eval_metric

    # At any point you can hit Ctrl + C to break out of training early.
    exit_script = "n"
    try:
        for epoch in range(1, args.epochs + 1):
            if str2bool(args.log_weights):
                # weight logging after each epoch
                # -> might need to move this into the train function to track after batch updates
                for name, value in model.named_parameters():
                    if value.requires_grad:
                        writer.add_histogram(tag=name,
                                             values=value.data.cpu().numpy(),
                                             global_step=epoch)
                # log only the parameter weights.
                # if the grads should be logged too, that code needs to move into the train_step
                # function and retrieve gradients after loss.backward() is called.
                # Otherwise the grads are None or 0.0
            lr = scheduler.optimizer.param_groups[0]['lr']
            lr_arr.append(lr)
            logging.info(f"Epoch: {epoch}/{args.epochs}, Learning Rate: {lr}.")
            train_metrics = train(epoch=epoch, model=model, device=device,
                                  transform=transform, loader=train_loader,
                                  optimizer=optimizer, evaluator=evaluator,
                                  kwargs={
                                      "run": i,
                                      "nruns": args.n_runs,
                                      "grad_clipping": args.grad_clipping,
                                      "lr": lr,
                                      "weight_decay": args.weightdecay,
                                      "weight_decay2": args.weightdecay2,
                                      "p": args.regularization
                                  })
            train_metrics_arr.append(train_metrics)
            logging.info(
                f"Training Metrics in Epoch: {epoch} \n Metrics: {train_metrics}."
            )
            validation_metrics = test_validate(model=model, device=device,
                                               transform=transform,
                                               loader=valid_loader,
                                               evaluator=evaluator)
            val_metrics_arr.append(validation_metrics)
            logging.info(
                f"Validation Metrics in Epoch: {epoch} \n Metrics: {validation_metrics}."
            )
            if validation_metrics[metric_key] > best_metric:
                logging.info(
                    f"Saving model with validation '{metric_key}': {validation_metrics[metric_key]}."
                )
                best_metric = validation_metrics[metric_key]
                torch.save(model, model_save_dir)
            scheduler.step(validation_metrics[metric_key])

            # tensorboard logging
            writer.add_scalar(tag="lr", scalar_value=lr, global_step=epoch)
            writer.add_scalar(tag="train_loss",
                              scalar_value=train_metrics["loss"],
                              global_step=epoch)
            writer.add_scalar(tag=f"train_{metric_key}",
                              scalar_value=train_metrics[metric_key],
                              global_step=epoch)
            writer.add_scalar(tag="valid_loss",
                              scalar_value=validation_metrics["loss"],
                              global_step=epoch)
            writer.add_scalar(tag=f"valid_{metric_key}",
                              scalar_value=validation_metrics[metric_key],
                              global_step=epoch)

            # benchmarking GNNs
            if optimizer.param_groups[0]['lr'] < args.min_lr:
                logging.info("\n!! LR EQUAL TO MIN LR SET.")
                break
            if time.time() - t0 > args.max_time * 3600:
                logging.info('-' * 89)
                logging.info(
                    "Max_time for training elapsed {:.2f} hours, so stopping".
                    format(args.max_time))
                break
    except KeyboardInterrupt:
        logging.info("-" * 80)
        logging.info(f"training interrupted in epoch {epoch}")
        logging.info(f"saving model at {model_last_save_dir}")
        exit_script = input(
            "Should the entire run be exited after evaluation? Type (Y/N) ")

    torch.save(model, model_last_save_dir)
    end_time = datetime.now()
    training_time = end_time - start_time
    logging.info(f"Training time: {training_time}")

    # after training, load the best model and test
    # if exit_script.lower() == "y":
    try:
        model = torch.load(model_save_dir)
    except FileNotFoundError:
        logging.info(
            f"File '{model_save_dir}' not found. Cannot load model. Will use current model."
        )
    model = model.to(device)
    logging.info("Testing model from best validation epoch")
    test_metrics = test_validate(model=model, loader=test_loader, device=device,
                                 transform=transform, evaluator=evaluator)
    logging.info(f"Test Metrics of model: \n Metrics: {test_metrics}.")

    # we also load the last model and test, just to see how it performs
    model = torch.load(model_last_save_dir)
    model = model.to(device)
    logging.info("Testing model from last epoch")
    test_metrics_last = test_validate(model=model, loader=test_loader, device=device,
                                      transform=transform, evaluator=evaluator)
    logging.info(f"Test Metrics of model: \n Metrics: {test_metrics_last}.")

    svz = {
        'train_metrics': train_metrics_arr,
        'lr': lr_arr,
        'val_metrics': val_metrics_arr,
        'test_metrics': test_metrics,
        'test_metrics_lastepoch': test_metrics_last
    }
    with open(os.path.join(save_dir, "arrays.pickle"), "wb") as fp:
        pickle.dump(svz, fp)

    svz2 = {
        "best_val": best_metric,
        "test_best_valEpoch": test_metrics[metric_key],
        "test_lastEpoch": test_metrics_last[metric_key]
    }
    with open(os.path.join(save_dir, "val_test.json"), 'w') as f:
        json.dump(svz2, f)

    # close tensorboard writer for run
    writer.close()

    if exit_script.lower() == "y":
        logging.info("exit script.")
        exit()

    return test_metrics[metric_key], test_metrics_last[metric_key], best_metric
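
# Assumption: no call site for main() appears in this section. If the surrounding
# script does not already define an entry point elsewhere, a standard guard such
# as the following would start the CIFAR10 runs when the file is executed directly.
if __name__ == "__main__":
    main()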