Example #1
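The snippets below assume the following standard-library and third-party imports; project-local helpers (get_parser, str2bool, set_logging, set_seed_all, concat_x_pos, get_multiplication_matrices, get_model_blocks, Evaluator, train, test_validate, and the model classes) are provided elsewhere in the repository.

import json
import logging
import os
import os.path as osp
import pickle
import time
from datetime import datetime

import torch
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.datasets import GNNBenchmarkDataset
from torch_geometric.loader import DataLoader  # torch_geometric.data.DataLoader in older PyG releases
from torch_geometric.utils import degree
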
def main():
    args = get_parser()
    # parse argparse string arguments into booleans (see the str2bool sketch
    # after this example)
    naive_encoder = not str2bool(args.full_encoder)
    pin_memory = str2bool(args.pin_memory)
    use_bias = str2bool(args.bias)
    downstream_bn = str(args.d_bn)
    same_dropout = str2bool(args.same_dropout)
    mlp_mp = str2bool(args.mlp_mp)
    phm_dim = args.phm_dim
    learn_phm = str2bool(args.learn_phm)

    base_dir = "cifar10/"
    if not os.path.exists(base_dir):
        os.makedirs(base_dir)

    if base_dir not in args.save_dir:
        args.save_dir = os.path.join(base_dir, args.save_dir)

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    set_logging(save_dir=args.save_dir)
    logging.info(f"Creating log directory at {args.save_dir}.")
    with open(os.path.join(args.save_dir, "params.json"), 'w') as fp:
        json.dump(args.__dict__, fp)

    mp_layers = [int(item) for item in args.mp_units.split(',')]
    downstream_layers = [int(item) for item in args.d_units.split(',')]
    mp_dropout = [float(item) for item in args.dropout_mpnn.split(',')]
    dn_dropout = [float(item) for item in args.dropout_dn.split(',')]
    logging.info(
        f'Initialising model with {mp_layers} hidden units with dropout {mp_dropout} '
        f'and downstream units: {downstream_layers} with dropout {dn_dropout}.'
    )

    if args.pooling == "globalsum":
        logging.info("Using GlobalSum Pooling")
    else:
        logging.info("Using SoftAttention Pooling")

    logging.info(
        f"Using Adam optimizer with weight_decay ({args.weightdecay}) and regularization "
        f"norm ({args.regularization})")
    logging.info(
        f"Weight init: {args.w_init} \n Contribution init: {args.c_init}")

    # data
    path = osp.join(osp.dirname(osp.realpath(__file__)), 'dataset')
    transform = concat_x_pos
    train_data = GNNBenchmarkDataset(path,
                                     name="CIFAR10",
                                     split='train',
                                     transform=transform)
    valid_data = GNNBenchmarkDataset(path,
                                     name="CIFAR10",
                                     split='val',
                                     transform=transform)
    test_data = GNNBenchmarkDataset(path,
                                    name="CIFAR10",
                                    split='test',
                                    transform=transform)
    evaluator = Evaluator()

    train_loader = DataLoader(train_data,
                              batch_size=args.batch_size,
                              drop_last=False,
                              shuffle=True,
                              num_workers=args.nworkers,
                              pin_memory=pin_memory)
    valid_loader = DataLoader(valid_data,
                              batch_size=args.batch_size,
                              drop_last=False,
                              shuffle=False,
                              num_workers=args.nworkers,
                              pin_memory=pin_memory)
    test_loader = DataLoader(test_data,
                             batch_size=args.batch_size,
                             drop_last=False,
                             shuffle=False,
                             num_workers=args.nworkers,
                             pin_memory=pin_memory)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'

    # shared PHM multiplication rule for the hypercomplex models
    unique_phm = str2bool(args.unique_phm)
    if unique_phm:
        phm_rule = get_multiplication_matrices(phm_dim=args.phm_dim,
                                               type="phm")
        phm_rule = torch.nn.ParameterList(
            [torch.nn.Parameter(a, requires_grad=learn_phm) for a in phm_rule])
    else:
        phm_rule = None

    # CIFAR10 superpixels: 5 node features (3 RGB channels + 2 coordinates
    # after concat_x_pos) and 1 edge feature
    FULL_ATOM_FEATURE_DIMS = 5
    FULL_BOND_FEATURE_DIMS = 1

    if args.aggr_msg == "pna" or args.aggr_node == "pna":
        # PNA needs the in-degree histogram of the training data
        # to normalize its degree scalers
        deg = torch.zeros(19, dtype=torch.long)
        for data in train_data:
            d = degree(data.edge_index[1],
                       num_nodes=data.num_nodes,
                       dtype=torch.long)
            deg += torch.bincount(d, minlength=deg.numel())
    else:
        deg = None

    aggr_kwargs = {
        "aggregators": ['mean', 'min', 'max', 'std'],
        "scalers": ['identity', 'amplification', 'attenuation'],
        "deg": deg,
        "post_layers": 1,
        # used only by directional message-passing layers
        "msg_scalers": str2bool(args.msg_scale),
        "initial_beta": 1.0,  # softmax aggregation
        "learn_beta": True
    }
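    # presumably (an assumption about the model constructors) keys that a
    # given layer type does not use are ignored, so one kwargs dict can be
    # shared across all model variants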

    if "quaternion" in args.type:
        if args.aggr_msg == "pna" or args.aggr_msg == "pna":
            logging.info("PNA not implemented for quaternion models.")
            raise NotImplementedError

    if args.type == "undirectional-quaternion-sc-add":
        logging.info(
            "Using Quaternion Undirectional MPNN with Skip Connection through Addition"
        )
        model = UQ_SC_ADD(atom_input_dims=FULL_ATOM_FEATURE_DIMS,
                          atom_encoded_dim=args.input_embed_dim,
                          bond_input_dims=FULL_BOND_FEATURE_DIMS,
                          naive_encoder=naive_encoder,
                          mp_layers=mp_layers,
                          dropout_mpnn=mp_dropout,
                          init=args.w_init,
                          same_dropout=same_dropout,
                          norm_mp=args.mp_norm,
                          add_self_loops=True,
                          msg_aggr=args.aggr_msg,
                          node_aggr=args.aggr_node,
                          mlp=mlp_mp,
                          pooling=args.pooling,
                          activation=args.activation,
                          real_trafo=args.real_trafo,
                          downstream_layers=downstream_layers,
                          target_dim=train_data.num_classes,
                          dropout_dn=dn_dropout,
                          norm_dn=downstream_bn,
                          msg_encoder=args.msg_encoder,
                          **aggr_kwargs)
    elif args.type == "undirectional-quaternion-sc-cat":
        logging.info(
            "Using Quaternion Undirectional MPNN with Skip Connection through Concatenation"
        )
        model = UQ_SC_CAT(atom_input_dims=FULL_ATOM_FEATURE_DIMS,
                          atom_encoded_dim=args.input_embed_dim,
                          bond_input_dims=FULL_BOND_FEATURE_DIMS,
                          naive_encoder=naive_encoder,
                          mp_layers=mp_layers,
                          dropout_mpnn=mp_dropout,
                          init=args.w_init,
                          same_dropout=same_dropout,
                          norm_mp=args.mp_norm,
                          add_self_loops=True,
                          msg_aggr=args.aggr_msg,
                          node_aggr=args.aggr_node,
                          mlp=mlp_mp,
                          pooling=args.pooling,
                          activation=args.activation,
                          real_trafo=args.real_trafo,
                          downstream_layers=downstream_layers,
                          target_dim=train_data.num_classes,
                          dropout_dn=dn_dropout,
                          norm_dn=downstream_bn,
                          msg_encoder=args.msg_encoder,
                          **aggr_kwargs)
    elif args.type == "undirectional-phm-sc-add":
        logging.info(
            "Using PHM Undirectional MPNN with Skip Connection through Addition"
        )
        model = UPH_SC_ADD(phm_dim=phm_dim,
                           learn_phm=learn_phm,
                           phm_rule=phm_rule,
                           atom_input_dims=FULL_ATOM_FEATURE_DIMS,
                           atom_encoded_dim=args.input_embed_dim,
                           bond_input_dims=FULL_BOND_FEATURE_DIMS,
                           naive_encoder=naive_encoder,
                           mp_layers=mp_layers,
                           dropout_mpnn=mp_dropout,
                           w_init=args.w_init,
                           c_init=args.c_init,
                           same_dropout=same_dropout,
                           norm_mp=args.mp_norm,
                           add_self_loops=True,
                           msg_aggr=args.aggr_msg,
                           node_aggr=args.aggr_node,
                           mlp=mlp_mp,
                           pooling=args.pooling,
                           activation=args.activation,
                           real_trafo=args.real_trafo,
                           downstream_layers=downstream_layers,
                           target_dim=train_data.num_classes,
                           dropout_dn=dn_dropout,
                           norm_dn=downstream_bn,
                           msg_encoder=args.msg_encoder,
                           sc_type=args.sc_type,
                           **aggr_kwargs)
    elif args.type == "undirectional-phm-sc-cat":
        logging.info(
            "Using PHM Undirectional MPNN with Skip Connection through Concatenation"
        )
        model = UPH_SC_CAT(phm_dim=phm_dim,
                           learn_phm=learn_phm,
                           phm_rule=phm_rule,
                           atom_input_dims=FULL_ATOM_FEATURE_DIMS,
                           atom_encoded_dim=args.input_embed_dim,
                           bond_input_dims=FULL_BOND_FEATURE_DIMS,
                           naive_encoder=naive_encoder,
                           mp_layers=mp_layers,
                           dropout_mpnn=mp_dropout,
                           w_init=args.w_init,
                           c_init=args.c_init,
                           same_dropout=same_dropout,
                           norm_mp=args.mp_norm,
                           add_self_loops=True,
                           msg_aggr=args.aggr_msg,
                           node_aggr=args.aggr_node,
                           mlp=mlp_mp,
                           pooling=args.pooling,
                           activation=args.activation,
                           real_trafo=args.real_trafo,
                           downstream_layers=downstream_layers,
                           target_dim=train_data.num_classes,
                           dropout_dn=dn_dropout,
                           norm_dn=downstream_bn,
                           msg_encoder=args.msg_encoder,
                           **aggr_kwargs)
    else:
        raise ValueError(f"Unknown model type: {args.type}")

    logging.info(
        f"Model consists of {model.get_number_of_params_()} trainable parameters"
    )
    # do runs
    test_best_epoch_metrics_arr = []
    test_last_epoch_metrics_arr = []
    val_metrics_arr = []
    t0 = time.time()

    for i in range(1, args.n_runs + 1):
        ogb_bestEpoch_test_metrics, ogb_lastEpoch_test_metric, ogb_val_metrics = do_run(
            i, model, args, transform, train_loader, valid_loader, test_loader,
            device, evaluator, t0)

        test_best_epoch_metrics_arr.append(ogb_bestEpoch_test_metrics)
        test_last_epoch_metrics_arr.append(ogb_lastEpoch_test_metric)
        val_metrics_arr.append(ogb_val_metrics)

    logging.info(f"Performance of model across {args.n_runs} runs:")
    test_bestEpoch_perf = torch.tensor(test_best_epoch_metrics_arr)
    test_lastEpoch_perf = torch.tensor(test_last_epoch_metrics_arr)
    valid_perf = torch.tensor(val_metrics_arr)
    logging.info('===========================')
    logging.info(
        f'Final Test (best val-epoch) '
        f'"{evaluator.eval_metric}": {test_bestEpoch_perf.mean():.4f} ± {test_bestEpoch_perf.std():.4f}'
    )
    logging.info(
        f'Final Test (last-epoch) '
        f'"{evaluator.eval_metric}": {test_lastEpoch_perf.mean():.4f} ± {test_lastEpoch_perf.std():.4f}'
    )
    logging.info(
        f'Final (best) Valid "{evaluator.eval_metric}": {valid_perf.mean():.4f} ± {valid_perf.std():.4f}'
    )
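
Both examples rely on two small project helpers that are not part of this excerpt. The following are minimal hypothetical sketches of what they plausibly look like, not the repository's actual code: str2bool parses argparse string flags into booleans, and concat_x_pos assembles the 5-dimensional node features.

def str2bool(v):
    # hypothetical sketch: common truthy strings map to True, everything else to False
    if isinstance(v, bool):
        return v
    return str(v).lower() in ("yes", "true", "t", "1")

def concat_x_pos(data):
    # hypothetical sketch: append the 2-D superpixel coordinates to the RGB
    # node features, giving the FULL_ATOM_FEATURE_DIMS = 5 inputs
    data.x = torch.cat([data.x, data.pos], dim=-1)
    return data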
Example #2
def do_run(i, model, args, transform, train_loader, valid_loader, test_loader,
           device, evaluator, t0):

    logging.info(f"Run {i}/{args.n_runs}, seed: {args.seed + i - 1}")
    logging.info("Reset model parameters")
    set_seed_all(args.seed + i - 1)
    model.reset_parameters()
    model = model.to(device)

    # set up per-block parameter groups for the optimizer
    # blocks present in all quaternion/hypercomplex models
    params_mp = get_model_blocks(model, attr="convs",
                                 lr=args.lr, weight_decay=0.0)
    params_pooling = get_model_blocks(model, attr="pooling",
                                      lr=args.lr, weight_decay=0.0)
    params_downstream = get_model_blocks(model, attr="downstream",
                                         lr=args.lr, weight_decay=0.0)
    # blocks specific to the undirectional variants
    params_norms = get_model_blocks(model, attr="norms",
                                    lr=args.lr, weight_decay=0.0)
    params_embedding_atom = get_model_blocks(model, attr="atomencoder",
                                             lr=args.lr, weight_decay=0.0)
    params_embedding_bonds = get_model_blocks(model, attr="bondencoders",
                                              lr=args.lr, weight_decay=0.0)

    params = params_mp + params_pooling + params_downstream + \
             params_norms + params_embedding_atom + params_embedding_bonds

    # sanity check: the parameter groups must cover every trainable parameter
    total_params_split = sum(
        sum(p.numel() for p in pl["params"] if p.requires_grad)
        for pl in params)
    total_params_model = sum(p.numel() for p in model.parameters()
                             if p.requires_grad)
    assert total_params_model == total_params_split, \
        f"Parameter groups cover {total_params_split} parameters, " \
        f"but the model has {total_params_model}."

    # optimizer
    optimizer = torch.optim.Adam(params, weight_decay=0.0)

    # higher validation metric is better, hence mode="max"; the scheduler is
    # stepped with the validation metric after every epoch
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.factor,
        patience=args.patience,
        mode="max",
        min_lr=1e-7,
        verbose=True)

    save_dir = os.path.join(args.save_dir, f"run_{i}")
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # tensorboard logging
    logging.info(f"Creating new tensorboard logging for run {i}.")
    writer = SummaryWriter(log_dir=save_dir)

    lr_arr = []
    train_metrics_arr = []
    val_metrics_arr = []
    start_time = datetime.now()
    best_metric = float("-inf")
    model_save_dir = os.path.join(save_dir, "model.pt")
    model_last_save_dir = os.path.join(save_dir, "model_last.pt")
    metric_key = evaluator.eval_metric

    # At any point you can hit Ctrl + C to break out of training early.
    exit_script = "n"

    try:
        for epoch in range(1, args.epochs + 1):
            if str2bool(args.log_weights):
                # log parameter histograms once per epoch; tracking per-batch
                # updates would require moving this into the train function
                for name, value in model.named_parameters():
                    if value.requires_grad:
                        writer.add_histogram(tag=name,
                                             values=value.data.cpu().numpy(),
                                             global_step=epoch)
                        # only weights are logged here; gradients would have to
                        # be read inside train_step after loss.backward(),
                        # otherwise they are None or 0.0

            lr = scheduler.optimizer.param_groups[0]['lr']
            lr_arr.append(lr)
            logging.info(f"Epoch: {epoch}/{args.epochs}, Learning Rate: {lr}.")
            train_metrics = train(epoch=epoch,
                                  model=model,
                                  device=device,
                                  transform=transform,
                                  loader=train_loader,
                                  optimizer=optimizer,
                                  evaluator=evaluator,
                                  kwargs={
                                      "run": i,
                                      "nruns": args.n_runs,
                                      "grad_clipping": args.grad_clipping,
                                      "lr": lr,
                                      "weight_decay": args.weightdecay,
                                      "weight_decay2": args.weightdecay2,
                                      "p": args.regularization
                                  })
            train_metrics_arr.append(train_metrics)
            logging.info(
                f"Training Metrics in Epoch: {epoch} \n Metrics: {train_metrics}."
            )

            validation_metrics = test_validate(model=model,
                                               device=device,
                                               transform=transform,
                                               loader=valid_loader,
                                               evaluator=evaluator)
            val_metrics_arr.append(validation_metrics)
            logging.info(
                f"Validation Metrics in Epoch: {epoch} \n Metrics: {validation_metrics}."
            )
            if validation_metrics[metric_key] > best_metric:
                logging.info(
                    f"Saving model with validation '{metric_key}': {validation_metrics[metric_key]}."
                )
                best_metric = validation_metrics[metric_key]

                torch.save(model, model_save_dir)

            scheduler.step(validation_metrics[metric_key])

            # tensorboard logging
            writer.add_scalar(tag="lr", scalar_value=lr, global_step=epoch)
            writer.add_scalar(tag="train_loss",
                              scalar_value=train_metrics["loss"],
                              global_step=epoch)
            writer.add_scalar(tag=f"train_{metric_key}",
                              scalar_value=train_metrics[metric_key],
                              global_step=epoch)
            writer.add_scalar(tag="valid_loss",
                              scalar_value=validation_metrics["loss"],
                              global_step=epoch)
            writer.add_scalar(tag=f"valid_{metric_key}",
                              scalar_value=validation_metrics[metric_key],
                              global_step=epoch)

            # early stopping as in the Benchmarking GNNs setup
            if optimizer.param_groups[0]['lr'] < args.min_lr:
                logging.info("\n!! LR fell below the configured min_lr; stopping early.")
                break

            if time.time() - t0 > args.max_time * 3600:
                logging.info('-' * 89)
                logging.info(
                    f"Max training time of {args.max_time:.2f} hours elapsed; stopping.")
                break

    except KeyboardInterrupt:
        logging.info("-" * 80)
        logging.info(f"Training interrupted in epoch {epoch}.")
        logging.info(f"Saving model at {model_last_save_dir}.")
        exit_script = input(
            "Should the entire run be exited after evaluation? Type (y/n): ")

    torch.save(model, model_last_save_dir)

    end_time = datetime.now()
    training_time = end_time - start_time
    logging.info(f"Training time: {training_time}")

    # after training, load the best model and test
    try:
        model = torch.load(model_save_dir)
    except FileNotFoundError:
        logging.info(
            f"File '{model_save_dir}' not found. Cannot load model. Will use current model."
        )

    model = model.to(device)
    logging.info(f"Testing model from best validation epoch")
    test_metrics = test_validate(model=model,
                                 loader=test_loader,
                                 device=device,
                                 transform=transform,
                                 evaluator=evaluator)
    logging.info(f"Test Metrics of model: \n Metrics: {test_metrics}.")

    # we also load the last model and test, just to see how it performs
    model = torch.load(model_last_save_dir)
    model = model.to(device)
    logging.info(f"Testing model from last epoch")
    test_metrics_last = test_validate(model=model,
                                      loader=test_loader,
                                      device=device,
                                      transform=transform,
                                      evaluator=evaluator)
    logging.info(f"Test Metrics of model: \n Metrics: {test_metrics_last}.")

    svz = {
        'train_metrics': train_metrics_arr,
        'lr': lr_arr,
        'val_metrics': val_metrics_arr,
        'test_metrics': test_metrics,
        'test_metrics_lastepoch': test_metrics_last
    }

    with open(os.path.join(save_dir, "arrays.pickle"), "wb") as fp:
        pickle.dump(svz, fp)

    svz2 = {
        "best_val": best_metric,
        "test_best_valEpoch": test_metrics[metric_key],
        "test_lastEpoch": test_metrics_last[metric_key]
    }

    with open(os.path.join(save_dir, "val_test.json"), 'w') as f:
        json.dump(svz2, f)

    # close tensorboard writer for run
    writer.close()

    if exit_script.lower() == "y":
        logging.info("exit script.")
        exit()

    return test_metrics[metric_key], test_metrics_last[metric_key], best_metric
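
The parameter groups in do_run come from the project helper get_model_blocks, which is also not part of this excerpt. A rough, hypothetical sketch of the expected behavior, under the assumption that each named attribute (convs, pooling, downstream, ...) is an nn.Module or ModuleList:

def get_model_blocks(model, attr, **group_kwargs):
    # hypothetical sketch: wrap one sub-module's parameters in a single
    # optimizer parameter-group dict with its own lr / weight_decay, so the
    # coverage assertion in do_run can sum over all groups
    block = getattr(model, attr, None)
    if block is None:
        return []
    return [dict(params=list(block.parameters()), **group_kwargs)]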