Example #1
    def reset_parameters(self):
        if self.w_init == "phm":
            W_init = phm_init(phm_dim=self.phm_dim,
                              in_features=self.in_features,
                              out_features=self.out_features)
            for W_param, W_i in zip(self.W, W_init):
                W_param.data = W_i.data

        elif self.w_init == "glorot-normal":
            for i in range(self.phm_dim):
                self.W[i] = glorot_normal(self.W[i])
        elif self.w_init == "glorot-uniform":
            for i in range(self.phm_dim):
                self.W[i] = glorot_uniform(self.W[i])
        else:
            raise ValueError(f"Unknown w_init: {self.w_init}")
        if self.bias_flag:
            self.b[0].data.fill_(0.0)
            for bias in self.b[1:]:
                bias.data.fill_(0.2)

        if not self.shared_phm:
            phm_rule = get_multiplication_matrices(phm_dim=self.phm_dim,
                                                   type=self.c_init)
            for i, init_data in enumerate(phm_rule):
                self.phm_rule[i].data = init_data
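The helpers phm_init, glorot_normal and glorot_uniform used above are not defined in this example. A minimal sketch of what the two Glorot helpers might look like, assuming they simply wrap torch.nn.init and return the tensor they were given:

import torch
import torch.nn as nn

def glorot_uniform(tensor: torch.Tensor) -> torch.Tensor:
    # hypothetical helper: Glorot/Xavier uniform init, in place, returning the tensor
    nn.init.xavier_uniform_(tensor.data)
    return tensor

def glorot_normal(tensor: torch.Tensor) -> torch.Tensor:
    # hypothetical helper: Glorot/Xavier normal init, in place, returning the tensor
    nn.init.xavier_normal_(tensor.data)
    return tensor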
Example #2
    def reset_parameters(self):

        if not self.variable_phm:
            phm_rule = get_multiplication_matrices(phm_dim=self.phm_dim)
            for i, init_data in enumerate(phm_rule):
                self.phm_rule[i].data = init_data

        # atom encoder
        self.atomencoder.reset_parameters()

        # bond encoders
        for encoder in self.bondencoders:
            encoder.reset_parameters()

        # mp and norm layers
        for conv, norm in zip(self.convs, self.norms):
            conv.reset_parameters()
            if self.norm_mp:
                norm.reset_parameters()

        # pooling
        self.pooling.reset_parameters()

        # downstream network
        self.downstream.reset_parameters()
Example #3
    def reset_parameters(self):
        # weight matrices W_i
        if self.w_init == "phm":
            W_init = phm_init(phm_dim=self.phm_dim,
                              in_features=self._in_feats_per_axis,
                              out_features=self._out_feats_per_axis,
                              transpose=False)
            self.W.data = W_init

        elif self.w_init == "glorot-normal":
            for i in range(self.phm_dim):
                self.W.data[i] = glorot_normal(self.W.data[i])
        elif self.w_init == "glorot-uniform":
            for i in range(self.phm_dim):
                self.W.data[i] = glorot_uniform(self.W.data[i])
        else:
            raise ValueError(f"Unknown w_init: {self.w_init}")
        if self.bias_flag:
            # first out_feats_per_axis bias entries 0.0, all remaining entries 0.2
            self.b.data[:self._out_feats_per_axis] = 0.0
            self.b.data[self._out_feats_per_axis:] = 0.2

        # contribution matrices C_i
        self.phm_rule.data = torch.stack(
            get_multiplication_matrices(self.phm_dim, type=self.c_init), dim=0)
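For context, the tensors initialised above are the building blocks of a parameterized-hypercomplex (PHM) layer: phm_rule holds the contribution matrices C_i of shape (phm_dim, phm_dim) and W holds the per-axis weights W_i of shape (in_feats_per_axis, out_feats_per_axis). They are typically combined into one full weight through a sum of Kronecker products, H = sum_i C_i ⊗ W_i. A minimal sketch of that combination, assuming the forward pass (not shown here) is y = x @ H + b:

import torch

def build_phm_weight(phm_rule: torch.Tensor, W: torch.Tensor) -> torch.Tensor:
    # phm_rule: (phm_dim, phm_dim, phm_dim), W: (phm_dim, in_per_axis, out_per_axis)
    # returns H = sum_i kron(C_i, W_i) with shape (phm_dim*in_per_axis, phm_dim*out_per_axis)
    return sum(torch.kron(c, w) for c, w in zip(phm_rule, W))

# hypothetical forward: y = x @ build_phm_weight(self.phm_rule, self.W) + self.b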
Example #4
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 phm_dim: int,
                 phm_rule: Union[None, nn.Parameter, nn.ParameterList, list,
                                 torch.Tensor] = None,
                 bias: bool = True,
                 w_init: str = "phm",
                 c_init: str = "standard",
                 learn_phm: bool = True) -> None:
        super(PHMLinear, self).__init__()
        assert w_init in ["phm", "glorot-normal", "glorot-uniform"]
        assert c_init in ["standard", "random"]
        self.in_features = in_features
        self.out_features = out_features
        self.learn_phm = learn_phm
        self.phm_dim = phm_dim

        self.shared_phm = False
        if phm_rule is not None:
            # shared multiplication rule passed in from outside; keep the same
            # parameter objects so all layers using this rule truly share them
            self.shared_phm = True
            self.phm_rule = phm_rule
            if not isinstance(phm_rule, nn.ParameterList):
                self.phm_rule = nn.ParameterList([
                    nn.Parameter(mat, requires_grad=learn_phm)
                    for mat in self.phm_rule
                ])
        else:
            # layer-local multiplication rule, freshly initialised
            self.phm_rule = nn.ParameterList([
                nn.Parameter(mat, requires_grad=learn_phm)
                for mat in get_multiplication_matrices(phm_dim, type=c_init)
            ])

        self.bias_flag = bias
        self.w_init = w_init
        self.c_init = c_init
        self.W = nn.ParameterList([
            nn.Parameter(torch.Tensor(out_features, in_features),
                         requires_grad=True) for _ in range(phm_dim)
        ])
        if self.bias_flag:
            self.b = nn.ParameterList([
                nn.Parameter(torch.Tensor(out_features), requires_grad=True)
                for _ in range(phm_dim)
            ])
        else:
            self.register_parameter("b", None)

        self.reset_parameters()
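One way the shared-rule path in this example's constructor might be used: build a single ParameterList of contribution matrices and pass it to several layers so they optimise the same multiplication rule. A hedged sketch, with made-up layer sizes, assuming get_multiplication_matrices returns a list of phm_dim square matrices:

import torch.nn as nn

# one shared set of C_i matrices, reused by two hypothetical PHM layers
shared_rule = nn.ParameterList([
    nn.Parameter(mat, requires_grad=True)
    for mat in get_multiplication_matrices(phm_dim=4, type="standard")
])
lin1 = PHMLinear(in_features=64, out_features=64, phm_dim=4, phm_rule=shared_rule)
lin2 = PHMLinear(in_features=64, out_features=32, phm_dim=4, phm_rule=shared_rule)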
Example #5
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 phm_dim: int,
                 phm_rule: Union[None, torch.Tensor] = None,
                 bias: bool = True,
                 w_init: str = "phm",
                 c_init: str = "random",
                 learn_phm: bool = True) -> None:
        super(PHMLinear, self).__init__()
        assert w_init in ["phm", "glorot-normal", "glorot-uniform"]
        assert c_init in ["standard", "random"]
        assert in_features % phm_dim == 0, f"Argument `in_features`={in_features} is not divisible by `phm_dim`={phm_dim}"
        assert out_features % phm_dim == 0, f"Argument `out_features`={out_features} is not divisible by `phm_dim`={phm_dim}"

        self.in_features = in_features
        self.out_features = out_features
        self.learn_phm = learn_phm
        self.phm_dim = phm_dim

        self._in_feats_per_axis = in_features // phm_dim
        self._out_feats_per_axis = out_features // phm_dim

        if phm_rule is not None:
            self.phm_rule = phm_rule
        else:
            self.phm_rule = get_multiplication_matrices(phm_dim, type=c_init)

        self.phm_rule = nn.Parameter(torch.stack([*self.phm_rule], dim=0),
                                     requires_grad=learn_phm)

        self.bias_flag = bias
        self.w_init = w_init
        self.c_init = c_init
        self.W = nn.Parameter(torch.Tensor(size=(phm_dim,
                                                 self._in_feats_per_axis,
                                                 self._out_feats_per_axis)),
                              requires_grad=True)
        if self.bias_flag:
            self.b = nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter("b", None)

        self.reset_parameters()
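A minimal usage sketch for this variant, assuming the forward pass (not shown) maps (batch, in_features) to (batch, out_features); both feature sizes must be divisible by phm_dim, as the asserts above enforce:

import torch

# 16 and 32 are both divisible by phm_dim=4, so the asserts pass
layer = PHMLinear(in_features=16, out_features=32, phm_dim=4, c_init="random")
x = torch.randn(8, 16)
y = layer(x)  # assumed forward; expected output shape (8, 32)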
Example #6
def main():
    args = get_parser()
    # convert string-valued argparse arguments to their proper (boolean) types
    naive_encoder = not str2bool(args.full_encoder)
    pin_memory = str2bool(args.pin_memory)
    use_bias = str2bool(args.bias)
    downstream_bn = str(args.d_bn)
    same_dropout = str2bool(args.same_dropout)
    mlp_mp = str2bool(args.mlp_mp)
    phm_dim = args.phm_dim
    learn_phm = str2bool(args.learn_phm)

    base_dir = "cifar10/"
    if not os.path.exists(base_dir):
        os.makedirs(base_dir)

    if base_dir not in args.save_dir:
        args.save_dir = os.path.join(base_dir, args.save_dir)

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    set_logging(save_dir=args.save_dir)
    logging.info(f"Creating log directory at {args.save_dir}.")
    with open(os.path.join(args.save_dir, "params.json"), 'w') as fp:
        json.dump(args.__dict__, fp)

    mp_layers = [int(item) for item in args.mp_units.split(',')]
    downstream_layers = [int(item) for item in args.d_units.split(',')]
    mp_dropout = [float(item) for item in args.dropout_mpnn.split(',')]
    dn_dropout = [float(item) for item in args.dropout_dn.split(',')]
    logging.info(
        f'Initialising model with {mp_layers} hidden units with dropout {mp_dropout} '
        f'and downstream units: {downstream_layers} with dropout {dn_dropout}.'
    )

    if args.pooling == "globalsum":
        logging.info("Using GlobalSum Pooling")
    else:
        logging.info("Using SoftAttention Pooling")

    logging.info(
        f"Using Adam optimizer with weight_decay ({args.weightdecay}) and regularization "
        f"norm ({args.regularization})")
    logging.info(
        f"Weight init: {args.w_init} \n Contribution init: {args.c_init}")

    # data
    path = osp.join(osp.dirname(osp.realpath(__file__)), 'dataset')
    transform = concat_x_pos
    train_data = GNNBenchmarkDataset(path,
                                     name="CIFAR10",
                                     split='train',
                                     transform=transform)
    valid_data = GNNBenchmarkDataset(path,
                                     name="CIFAR10",
                                     split='val',
                                     transform=transform)
    test_data = GNNBenchmarkDataset(path,
                                    name="CIFAR10",
                                    split='test',
                                    transform=transform)
    evaluator = Evaluator()

    train_loader = DataLoader(train_data,
                              batch_size=args.batch_size,
                              drop_last=False,
                              shuffle=True,
                              num_workers=args.nworkers,
                              pin_memory=pin_memory)
    valid_loader = DataLoader(valid_data,
                              batch_size=args.batch_size,
                              drop_last=False,
                              shuffle=False,
                              num_workers=args.nworkers,
                              pin_memory=pin_memory)
    test_loader = DataLoader(test_data,
                             batch_size=args.batch_size,
                             drop_last=False,
                             shuffle=False,
                             num_workers=args.nworkers,
                             pin_memory=pin_memory)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    #device = "cpu"

    # for hypercomplex model
    unique_phm = str2bool(args.unique_phm)
    if unique_phm:
        phm_rule = get_multiplication_matrices(phm_dim=args.phm_dim,
                                               type="phm")
        phm_rule = torch.nn.ParameterList(
            [torch.nn.Parameter(a, requires_grad=learn_phm) for a in phm_rule])
    else:
        phm_rule = None

    FULL_ATOM_FEATURE_DIMS = 5
    FULL_BOND_FEATURE_DIMS = 1

    if args.aggr_msg == "pna" or args.aggr_node == "pna":
        # if PNA is used
        # Compute in-degree histogram over training data.
        deg = torch.zeros(19, dtype=torch.long)
        for data in train_data:
            d = degree(data.edge_index[1],
                       num_nodes=data.num_nodes,
                       dtype=torch.long)
            deg += torch.bincount(d, minlength=deg.numel())
    else:
        deg = None

    aggr_kwargs = {
        "aggregators": ['mean', 'min', 'max', 'std'],
        "scalers": ['identity', 'amplification', 'attenuation'],
        "deg": deg,
        "post_layers": 1,
        "msg_scalers":
        str2bool(args.msg_scale
                 ),  # this key is for directional messagepassing layers.
        "initial_beta": 1.0,  # Softmax
        "learn_beta": True
    }

    if "quaternion" in args.type:
        if args.aggr_msg == "pna" or args.aggr_node == "pna":
            logging.info("PNA not implemented for quaternion models.")
            raise NotImplementedError

    if args.type == "undirectional-quaternion-sc-add":
        logging.info(
            "Using Quaternion Undirectional MPNN with Skip Connection through Addition"
        )
        model = UQ_SC_ADD(atom_input_dims=FULL_ATOM_FEATURE_DIMS,
                          atom_encoded_dim=args.input_embed_dim,
                          bond_input_dims=FULL_BOND_FEATURE_DIMS,
                          naive_encoder=naive_encoder,
                          mp_layers=mp_layers,
                          dropout_mpnn=mp_dropout,
                          init=args.w_init,
                          same_dropout=same_dropout,
                          norm_mp=args.mp_norm,
                          add_self_loops=True,
                          msg_aggr=args.aggr_msg,
                          node_aggr=args.aggr_node,
                          mlp=mlp_mp,
                          pooling=args.pooling,
                          activation=args.activation,
                          real_trafo=args.real_trafo,
                          downstream_layers=downstream_layers,
                          target_dim=train_data.num_classes,
                          dropout_dn=dn_dropout,
                          norm_dn=downstream_bn,
                          msg_encoder=args.msg_encoder,
                          **aggr_kwargs)
    elif args.type == "undirectional-quaternion-sc-cat":
        logging.info(
            "Using Quaternion Undirectional MPNN with Skip Connection through Concatenation"
        )
        model = UQ_SC_CAT(atom_input_dims=FULL_ATOM_FEATURE_DIMS,
                          atom_encoded_dim=args.input_embed_dim,
                          bond_input_dims=FULL_BOND_FEATURE_DIMS,
                          naive_encoder=naive_encoder,
                          mp_layers=mp_layers,
                          dropout_mpnn=mp_dropout,
                          init=args.w_init,
                          same_dropout=same_dropout,
                          norm_mp=args.mp_norm,
                          add_self_loops=True,
                          msg_aggr=args.aggr_msg,
                          node_aggr=args.aggr_node,
                          mlp=mlp_mp,
                          pooling=args.pooling,
                          activation=args.activation,
                          real_trafo=args.real_trafo,
                          downstream_layers=downstream_layers,
                          target_dim=train_data.num_classes,
                          dropout_dn=dn_dropout,
                          norm_dn=downstream_bn,
                          msg_encoder=args.msg_encoder,
                          **aggr_kwargs)
    elif args.type == "undirectional-phm-sc-add":
        logging.info(
            "Using PHM Undirectional MPNN with Skip Connection through Addition"
        )
        model = UPH_SC_ADD(phm_dim=phm_dim,
                           learn_phm=learn_phm,
                           phm_rule=phm_rule,
                           atom_input_dims=FULL_ATOM_FEATURE_DIMS,
                           atom_encoded_dim=args.input_embed_dim,
                           bond_input_dims=FULL_BOND_FEATURE_DIMS,
                           naive_encoder=naive_encoder,
                           mp_layers=mp_layers,
                           dropout_mpnn=mp_dropout,
                           w_init=args.w_init,
                           c_init=args.c_init,
                           same_dropout=same_dropout,
                           norm_mp=args.mp_norm,
                           add_self_loops=True,
                           msg_aggr=args.aggr_msg,
                           node_aggr=args.aggr_node,
                           mlp=mlp_mp,
                           pooling=args.pooling,
                           activation=args.activation,
                           real_trafo=args.real_trafo,
                           downstream_layers=downstream_layers,
                           target_dim=train_data.num_classes,
                           dropout_dn=dn_dropout,
                           norm_dn=downstream_bn,
                           msg_encoder=args.msg_encoder,
                           sc_type=args.sc_type,
                           **aggr_kwargs)
    elif args.type == "undirectional-phm-sc-cat":
        logging.info(
            "Using PHM Undirectional MPNN with Skip Connection through Concatenation"
        )
        model = UPH_SC_CAT(phm_dim=phm_dim,
                           learn_phm=learn_phm,
                           phm_rule=phm_rule,
                           atom_input_dims=FULL_ATOM_FEATURE_DIMS,
                           atom_encoded_dim=args.input_embed_dim,
                           bond_input_dims=FULL_BOND_FEATURE_DIMS,
                           naive_encoder=naive_encoder,
                           mp_layers=mp_layers,
                           dropout_mpnn=mp_dropout,
                           w_init=args.w_init,
                           c_init=args.c_init,
                           same_dropout=same_dropout,
                           norm_mp=args.mp_norm,
                           add_self_loops=True,
                           msg_aggr=args.aggr_msg,
                           node_aggr=args.aggr_node,
                           mlp=mlp_mp,
                           pooling=args.pooling,
                           activation=args.activation,
                           real_trafo=args.real_trafo,
                           downstream_layers=downstream_layers,
                           target_dim=train_data.num_classes,
                           dropout_dn=dn_dropout,
                           norm_dn=downstream_bn,
                           msg_encoder=args.msg_encoder,
                           **aggr_kwargs)
    else:
        raise ValueError(f"Unknown model type: {args.type}")

    logging.info(
        f"Model consists of {model.get_number_of_params_()} trainable parameters"
    )
    # do runs
    test_best_epoch_metrics_arr = []
    test_last_epoch_metrics_arr = []
    val_metrics_arr = []
    t0 = time.time()

    for i in range(1, args.n_runs + 1):
        ogb_bestEpoch_test_metrics, ogb_lastEpoch_test_metric, ogb_val_metrics = do_run(
            i, model, args, transform, train_loader, valid_loader, test_loader,
            device, evaluator, t0)

        test_best_epoch_metrics_arr.append(ogb_bestEpoch_test_metrics)
        test_last_epoch_metrics_arr.append(ogb_lastEpoch_test_metric)
        val_metrics_arr.append(ogb_val_metrics)

    logging.info(f"Performance of model across {args.n_runs} runs:")
    test_bestEpoch_perf = torch.tensor(test_best_epoch_metrics_arr)
    test_lastEpoch_perf = torch.tensor(test_last_epoch_metrics_arr)
    valid_perf = torch.tensor(val_metrics_arr)
    logging.info('===========================')
    logging.info(
        f'Final Test (best val-epoch) '
        f'"{evaluator.eval_metric}": {test_bestEpoch_perf.mean():.4f} ± {test_bestEpoch_perf.std():.4f}'
    )
    logging.info(
        f'Final Test (last-epoch) '
        f'"{evaluator.eval_metric}": {test_lastEpoch_perf.mean():.4f} ± {test_lastEpoch_perf.std():.4f}'
    )
    logging.info(
        f'Final (best) Valid "{evaluator.eval_metric}": {valid_perf.mean():.4f} ± {valid_perf.std():.4f}'
    )