def reset_parameters(self):
    # weight matrices W_i
    if self.w_init == "phm":
        W_init = phm_init(phm_dim=self.phm_dim, in_features=self.in_features,
                          out_features=self.out_features)
        for W_param, W_i in zip(self.W, W_init):
            W_param.data = W_i.data
    elif self.w_init == "glorot-normal":
        for i in range(self.phm_dim):
            self.W[i] = glorot_normal(self.W[i])
    elif self.w_init == "glorot-uniform":
        for i in range(self.phm_dim):
            self.W[i] = glorot_uniform(self.W[i])
    else:
        raise ValueError(f"Unknown weight initialization '{self.w_init}'.")
    # biases: zero for the first (real) axis, 0.2 for the remaining axes
    if self.bias_flag:
        self.b[0].data.fill_(0.0)
        for bias in self.b[1:]:
            bias.data.fill_(0.2)
    # contribution matrices C_i (only if this layer owns its own phm-rule)
    if not self.shared_phm:
        phm_rule = get_multiplication_matrices(phm_dim=self.phm_dim, type=self.c_init)
        for i, init_data in enumerate(phm_rule):
            self.phm_rule[i].data = init_data
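# `glorot_normal` and `glorot_uniform` are imported from elsewhere in the
# repository. A minimal sketch of the usual Glorot/Xavier schemes, assuming
# the helpers modify the given tensor in place and return it (which the
# assignments above rely on):
import math

import torch

def glorot_normal(tensor: torch.Tensor) -> torch.Tensor:
    # Xavier normal: std = sqrt(2 / (fan_in + fan_out)) over the last two dims
    fan_in, fan_out = tensor.size(-2), tensor.size(-1)
    std = math.sqrt(2.0 / (fan_in + fan_out))
    with torch.no_grad():
        return tensor.normal_(0.0, std)

def glorot_uniform(tensor: torch.Tensor) -> torch.Tensor:
    # Xavier uniform: bound = sqrt(6 / (fan_in + fan_out)) over the last two dims
    fan_in, fan_out = tensor.size(-2), tensor.size(-1)
    bound = math.sqrt(6.0 / (fan_in + fan_out))
    with torch.no_grad():
        return tensor.uniform_(-bound, bound)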
def reset_parameters(self):
    if not self.variable_phm:
        phm_rule = get_multiplication_matrices(phm_dim=self.phm_dim)
        for i, init_data in enumerate(phm_rule):
            self.phm_rule[i].data = init_data
    # atom encoder
    self.atomencoder.reset_parameters()
    # bond encoders
    for encoder in self.bondencoders:
        encoder.reset_parameters()
    # message-passing and norm layers
    for conv, norm in zip(self.convs, self.norms):
        conv.reset_parameters()
        if self.norm_mp:
            norm.reset_parameters()
    # pooling
    self.pooling.reset_parameters()
    # downstream network
    self.downstream.reset_parameters()
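# `get_multiplication_matrices` is defined elsewhere in the repository; it
# returns the phm_dim contribution matrices C_i of the underlying algebra.
# A minimal sketch, assuming "standard" yields the left-multiplication
# matrices of the algebra (written out here only for the quaternion case,
# phm_dim=4, with units 1, i, j, k) and "random" draws normalized random
# matrices; the repository's exact matrices and its rule for other values of
# phm_dim may differ.
import torch

def get_multiplication_matrices(phm_dim: int, type: str = "standard") -> list:
    if type == "random":
        return [torch.randn(phm_dim, phm_dim) / phm_dim for _ in range(phm_dim)]
    assert phm_dim == 4, "standard rule only sketched for quaternions here"
    one = torch.eye(4)
    i = torch.tensor([[0., -1., 0., 0.],
                      [1., 0., 0., 0.],
                      [0., 0., 0., -1.],
                      [0., 0., 1., 0.]])
    j = torch.tensor([[0., 0., -1., 0.],
                      [0., 0., 0., 1.],
                      [1., 0., 0., 0.],
                      [0., -1., 0., 0.]])
    k = torch.tensor([[0., 0., 0., -1.],
                      [0., 0., -1., 0.],
                      [0., 1., 0., 0.],
                      [1., 0., 0., 0.]])
    return [one, i, j, k]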
def reset_parameters(self):
    # weight matrices W_i
    if self.w_init == "phm":
        W_init = phm_init(phm_dim=self.phm_dim,
                          in_features=self._in_feats_per_axis,
                          out_features=self._out_feats_per_axis,
                          transpose=False)
        self.W.data = W_init
    elif self.w_init == "glorot-normal":
        for i in range(self.phm_dim):
            self.W.data[i] = glorot_normal(self.W.data[i])
    elif self.w_init == "glorot-uniform":
        for i in range(self.phm_dim):
            self.W.data[i] = glorot_uniform(self.W.data[i])
    else:
        raise ValueError(f"Unknown weight initialization '{self.w_init}'.")
    # biases: zero for the first (real) axis, 0.2 for the remaining axes
    # (the original slice started at `_out_feats_per_axis + 1`, which left one
    # entry uninitialized)
    if self.bias_flag:
        self.b.data[:self._out_feats_per_axis] = 0.0
        self.b.data[self._out_feats_per_axis:] = 0.2
    # contribution matrices C_i
    self.phm_rule.data = torch.stack(
        get_multiplication_matrices(self.phm_dim, type=self.c_init), dim=0)
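# `phm_init` is defined elsewhere in the repository and its exact scheme is
# not shown here. A hedged stand-in that matches the shapes used in the
# stacked-tensor variant above: one Glorot-style matrix per axis, with the
# fan scaled by phm_dim so the variance of the summed Kronecker products
# stays roughly constant. The repository's actual implementation (e.g. a
# quaternion-style init in the spirit of Parcollet et al.) may differ.
import math
import torch

def phm_init(phm_dim: int, in_features: int, out_features: int,
             transpose: bool = False) -> torch.Tensor:
    std = math.sqrt(2.0 / (phm_dim * (in_features + out_features)))
    W = torch.randn(phm_dim, in_features, out_features) * std
    return W.transpose(-2, -1) if transpose else W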
def __init__(self,
             in_features: int,
             out_features: int,
             phm_dim: int,
             phm_rule: Union[None, nn.Parameter, nn.ParameterList, list, torch.Tensor] = None,
             bias: bool = True,
             w_init: str = "phm",
             c_init: str = "standard",
             learn_phm: bool = True) -> None:
    super(PHMLinear, self).__init__()
    assert w_init in ["phm", "glorot-normal", "glorot-uniform"]
    assert c_init in ["standard", "random"]
    self.in_features = in_features
    self.out_features = out_features
    self.learn_phm = learn_phm
    self.phm_dim = phm_dim
    self.shared_phm = False
    if phm_rule is not None:
        # a phm-rule shared across layers was passed in; register it so it is
        # moved with the module (requires_grad encodes whether it is learned)
        self.shared_phm = True
        self.phm_rule = phm_rule
        if not isinstance(phm_rule, nn.ParameterList):
            self.phm_rule = nn.ParameterList(
                [nn.Parameter(mat, requires_grad=learn_phm) for mat in self.phm_rule])
    else:
        # this layer owns its own phm-rule
        self.phm_rule = get_multiplication_matrices(phm_dim, type=c_init)
        self.phm_rule = nn.ParameterList(
            [nn.Parameter(mat, requires_grad=learn_phm) for mat in self.phm_rule])
    self.bias_flag = bias
    self.w_init = w_init
    self.c_init = c_init
    self.W = nn.ParameterList(
        [nn.Parameter(torch.Tensor(out_features, in_features), requires_grad=True)
         for _ in range(phm_dim)])
    if self.bias_flag:
        self.b = nn.ParameterList(
            [nn.Parameter(torch.Tensor(out_features), requires_grad=True)
             for _ in range(phm_dim)])
    else:
        self.register_parameter("b", None)
    self.reset_parameters()
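# A quick usage sketch of the ParameterList-based variant above (hypothetical
# sizes). Note that, unlike the stacked-tensor variant below, each W_i here is
# a full (out_features, in_features) matrix, so in/out sizes need not be
# divisible by phm_dim:
layer = PHMLinear(in_features=64, out_features=128, phm_dim=4,
                  bias=True, w_init="phm", c_init="standard", learn_phm=True)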
def __init__(self,
             in_features: int,
             out_features: int,
             phm_dim: int,
             phm_rule: Union[None, torch.Tensor] = None,
             bias: bool = True,
             w_init: str = "phm",
             c_init: str = "random",
             learn_phm: bool = True) -> None:
    super(PHMLinear, self).__init__()
    assert w_init in ["phm", "glorot-normal", "glorot-uniform"]
    assert c_init in ["standard", "random"]
    assert in_features % phm_dim == 0, \
        f"Argument `in_features`={in_features} is not divisible by `phm_dim`={phm_dim}"
    assert out_features % phm_dim == 0, \
        f"Argument `out_features`={out_features} is not divisible by `phm_dim`={phm_dim}"
    self.in_features = in_features
    self.out_features = out_features
    self.learn_phm = learn_phm
    self.phm_dim = phm_dim
    self._in_feats_per_axis = in_features // phm_dim
    self._out_feats_per_axis = out_features // phm_dim
    if phm_rule is not None:
        self.phm_rule = phm_rule
    else:
        self.phm_rule = get_multiplication_matrices(phm_dim, type=c_init)
    self.phm_rule = nn.Parameter(torch.stack([*self.phm_rule], dim=0),
                                 requires_grad=learn_phm)
    self.bias_flag = bias
    self.w_init = w_init
    self.c_init = c_init
    self.W = nn.Parameter(
        torch.Tensor(size=(phm_dim, self._in_feats_per_axis, self._out_feats_per_axis)),
        requires_grad=True)
    if self.bias_flag:
        self.b = nn.Parameter(torch.Tensor(out_features))
    else:
        self.register_parameter("b", None)
    self.reset_parameters()
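# The forward pass is not shown in this excerpt. A minimal sketch, assuming
# the standard PHM formulation: the full weight matrix is the Kronecker sum
# H = sum_i phm_rule[i] ⊗ W[i] of shape (in_features, out_features), applied
# as an ordinary affine map. Names follow the constructor above.
def forward(self, x: torch.Tensor) -> torch.Tensor:
    H = torch.zeros(self.in_features, self.out_features,
                    device=x.device, dtype=x.dtype)
    for C_i, W_i in zip(self.phm_rule, self.W):
        # (phm_dim, phm_dim) kron (in/d, out/d) -> (in_features, out_features)
        H = H + torch.kron(C_i, W_i)
    out = x.matmul(H)
    if self.b is not None:
        out = out + self.b
    return out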
def main():
    args = get_parser()
    # get some argparse arguments that are passed as a bool string
    naive_encoder = not str2bool(args.full_encoder)
    pin_memory = str2bool(args.pin_memory)
    use_bias = str2bool(args.bias)
    downstream_bn = str(args.d_bn)
    same_dropout = str2bool(args.same_dropout)
    mlp_mp = str2bool(args.mlp_mp)
    phm_dim = args.phm_dim
    learn_phm = str2bool(args.learn_phm)

    base_dir = "cifar10/"
    if not os.path.exists(base_dir):
        os.makedirs(base_dir)
    if base_dir not in args.save_dir:
        args.save_dir = os.path.join(base_dir, args.save_dir)
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    set_logging(save_dir=args.save_dir)
    logging.info(f"Creating log directory at {args.save_dir}.")
    with open(os.path.join(args.save_dir, "params.json"), 'w') as fp:
        json.dump(args.__dict__, fp)

    mp_layers = [int(item) for item in args.mp_units.split(',')]
    downstream_layers = [int(item) for item in args.d_units.split(',')]
    mp_dropout = [float(item) for item in args.dropout_mpnn.split(',')]
    dn_dropout = [float(item) for item in args.dropout_dn.split(',')]
    logging.info(
        f'Initialising model with {mp_layers} hidden units with dropout {mp_dropout} '
        f'and downstream units: {downstream_layers} with dropout {dn_dropout}.')

    if args.pooling == "globalsum":
        logging.info("Using GlobalSum Pooling")
    else:
        logging.info("Using SoftAttention Pooling")
    logging.info(
        f"Using Adam optimizer with weight_decay ({args.weightdecay}) and regularization "
        f"norm ({args.regularization})")
    logging.info(f"Weight init: {args.w_init} \n Contribution init: {args.c_init}")

    # data
    path = osp.join(osp.dirname(osp.realpath(__file__)), 'dataset')
    transform = concat_x_pos
    train_data = GNNBenchmarkDataset(path, name="CIFAR10", split='train', transform=transform)
    valid_data = GNNBenchmarkDataset(path, name="CIFAR10", split='val', transform=transform)
    test_data = GNNBenchmarkDataset(path, name="CIFAR10", split='test', transform=transform)
    evaluator = Evaluator()
    train_loader = DataLoader(train_data, batch_size=args.batch_size, drop_last=False,
                              shuffle=True, num_workers=args.nworkers, pin_memory=pin_memory)
    valid_loader = DataLoader(valid_data, batch_size=args.batch_size, drop_last=False,
                              shuffle=False, num_workers=args.nworkers, pin_memory=pin_memory)
    test_loader = DataLoader(test_data, batch_size=args.batch_size, drop_last=False,
                             shuffle=False, num_workers=args.nworkers, pin_memory=pin_memory)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    #device = "cpu"

    # for hypercomplex model
    unique_phm = str2bool(args.unique_phm)
    if unique_phm:
        phm_rule = get_multiplication_matrices(phm_dim=args.phm_dim, type="phm")
        phm_rule = torch.nn.ParameterList(
            [torch.nn.Parameter(a, requires_grad=learn_phm) for a in phm_rule])
    else:
        phm_rule = None

    FULL_ATOM_FEATURE_DIMS = 5
    FULL_BOND_FEATURE_DIMS = 1

    if args.aggr_msg == "pna" or args.aggr_node == "pna":  # if PNA is used
        # compute the in-degree histogram over the training data
        deg = torch.zeros(19, dtype=torch.long)
        for data in train_data:
            d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
            deg += torch.bincount(d, minlength=deg.numel())
    else:
        deg = None

    aggr_kwargs = {
        "aggregators": ['mean', 'min', 'max', 'std'],
        "scalers": ['identity', 'amplification', 'attenuation'],
        "deg": deg,
        "post_layers": 1,
        "msg_scalers": str2bool(args.msg_scale),  # this key is for directional message-passing layers
"initial_beta": 1.0, # Softmax "learn_beta": True } if "quaternion" in args.type: if args.aggr_msg == "pna" or args.aggr_msg == "pna": logging.info("PNA not implemented for quaternion models.") raise NotImplementedError if args.type == "undirectional-quaternion-sc-add": logging.info( "Using Quaternion Undirectional MPNN with Skip Connection through Addition" ) model = UQ_SC_ADD(atom_input_dims=FULL_ATOM_FEATURE_DIMS, atom_encoded_dim=args.input_embed_dim, bond_input_dims=FULL_BOND_FEATURE_DIMS, naive_encoder=naive_encoder, mp_layers=mp_layers, dropout_mpnn=mp_dropout, init=args.w_init, same_dropout=same_dropout, norm_mp=args.mp_norm, add_self_loops=True, msg_aggr=args.aggr_msg, node_aggr=args.aggr_node, mlp=mlp_mp, pooling=args.pooling, activation=args.activation, real_trafo=args.real_trafo, downstream_layers=downstream_layers, target_dim=train_data.num_classes, dropout_dn=dn_dropout, norm_dn=downstream_bn, msg_encoder=args.msg_encoder, **aggr_kwargs) elif args.type == "undirectional-quaternion-sc-cat": logging.info( "Using Quaternion Undirectional MPNN with Skip Connection through Concatenation" ) model = UQ_SC_CAT(atom_input_dims=FULL_ATOM_FEATURE_DIMS, atom_encoded_dim=args.input_embed_dim, bond_input_dims=FULL_BOND_FEATURE_DIMS, naive_encoder=naive_encoder, mp_layers=mp_layers, dropout_mpnn=mp_dropout, init=args.w_init, same_dropout=same_dropout, norm_mp=args.mp_norm, add_self_loops=True, msg_aggr=args.aggr_msg, node_aggr=args.aggr_node, mlp=mlp_mp, pooling=args.pooling, activation=args.activation, real_trafo=args.real_trafo, downstream_layers=downstream_layers, target_dim=train_data.num_classes, dropout_dn=dn_dropout, norm_dn=downstream_bn, msg_encoder=args.msg_encoder, **aggr_kwargs) elif args.type == "undirectional-phm-sc-add": logging.info( "Using PHM Undirectional MPNN with Skip Connection through Addition" ) model = UPH_SC_ADD(phm_dim=phm_dim, learn_phm=learn_phm, phm_rule=phm_rule, atom_input_dims=FULL_ATOM_FEATURE_DIMS, atom_encoded_dim=args.input_embed_dim, bond_input_dims=FULL_BOND_FEATURE_DIMS, naive_encoder=naive_encoder, mp_layers=mp_layers, dropout_mpnn=mp_dropout, w_init=args.w_init, c_init=args.c_init, same_dropout=same_dropout, norm_mp=args.mp_norm, add_self_loops=True, msg_aggr=args.aggr_msg, node_aggr=args.aggr_node, mlp=mlp_mp, pooling=args.pooling, activation=args.activation, real_trafo=args.real_trafo, downstream_layers=downstream_layers, target_dim=train_data.num_classes, dropout_dn=dn_dropout, norm_dn=downstream_bn, msg_encoder=args.msg_encoder, sc_type=args.sc_type, **aggr_kwargs) elif args.type == "undirectional-phm-sc-cat": logging.info( "Using PHM Undirectional MPNN with Skip Connection through Concatenation" ) model = UPH_SC_CAT(phm_dim=phm_dim, learn_phm=learn_phm, phm_rule=phm_rule, atom_input_dims=FULL_ATOM_FEATURE_DIMS, atom_encoded_dim=args.input_embed_dim, bond_input_dims=FULL_BOND_FEATURE_DIMS, naive_encoder=naive_encoder, mp_layers=mp_layers, dropout_mpnn=mp_dropout, w_init=args.w_init, c_init=args.c_init, same_dropout=same_dropout, norm_mp=args.mp_norm, add_self_loops=True, msg_aggr=args.aggr_msg, node_aggr=args.aggr_node, mlp=mlp_mp, pooling=args.pooling, activation=args.activation, real_trafo=args.real_trafo, downstream_layers=downstream_layers, target_dim=train_data.num_classes, dropout_dn=dn_dropout, norm_dn=downstream_bn, msg_encoder=args.msg_encoder, **aggr_kwargs) else: raise ModuleNotFoundError logging.info( f"Model consists of {model.get_number_of_params_()} trainable parameters" ) # do runs test_best_epoch_metrics_arr = [] 
    test_last_epoch_metrics_arr = []
    val_metrics_arr = []
    t0 = time.time()
    for i in range(1, args.n_runs + 1):
        ogb_bestEpoch_test_metrics, ogb_lastEpoch_test_metric, ogb_val_metrics = do_run(
            i, model, args, transform, train_loader, valid_loader, test_loader,
            device, evaluator, t0)
        test_best_epoch_metrics_arr.append(ogb_bestEpoch_test_metrics)
        test_last_epoch_metrics_arr.append(ogb_lastEpoch_test_metric)
        val_metrics_arr.append(ogb_val_metrics)

    logging.info(f"Performance of model across {args.n_runs} runs:")
    test_bestEpoch_perf = torch.tensor(test_best_epoch_metrics_arr)
    test_lastEpoch_perf = torch.tensor(test_last_epoch_metrics_arr)
    valid_perf = torch.tensor(val_metrics_arr)
    logging.info('===========================')
    logging.info(
        f'Final Test (best val-epoch) '
        f'"{evaluator.eval_metric}": {test_bestEpoch_perf.mean():.4f} ± {test_bestEpoch_perf.std():.4f}')
    logging.info(
        f'Final Test (last-epoch) '
        f'"{evaluator.eval_metric}": {test_lastEpoch_perf.mean():.4f} ± {test_lastEpoch_perf.std():.4f}')
    logging.info(
        f'Final (best) Valid "{evaluator.eval_metric}": {valid_perf.mean():.4f} ± {valid_perf.std():.4f}')
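# `str2bool` is imported from a utils module elsewhere in the repository; a
# minimal sketch of the usual implementation, together with the standard
# entry-point guard (assumed; not part of the excerpt above):
def str2bool(v: str) -> bool:
    return str(v).lower() in ("yes", "true", "t", "1")

if __name__ == "__main__":
    main()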