class PHMSoftAttentionPooling(nn.Module):
    """Soft-attention global pooling over parametrized hypercomplex (PHM) node features.

    Per-node gating scores are produced by a PHM linear layer, collapsed to a
    real-valued vector by ``RealTransformer`` and squashed with a sigmoid; the
    scores are broadcast-multiplied onto every hypercomplex component of the
    node features before a global sum pooling over the batch.

    Args:
        embed_dim: Feature dimension of a single hypercomplex component.
        phm_dim: Number of hypercomplex components.
        phm_rule: Optional PHM multiplication-rule parameters, passed through
            to ``PHMLinear``.
        learn_phm: Whether the PHM rule is learnable.
        bias: Whether the gating ``PHMLinear`` uses a bias.
        w_init: Weight-initialisation scheme for ``PHMLinear``.
        c_init: Contribution-initialisation scheme for ``PHMLinear``.
        real_trafo: Type of the real-valued transformation.
    """

    def __init__(self, embed_dim: int, phm_dim: int,
                 phm_rule: Union[None, nn.ParameterList],
                 learn_phm: bool = True,
                 bias: bool = True,
                 w_init: str = "phm", c_init: str = "standard",
                 real_trafo: str = "linear"):
        super(PHMSoftAttentionPooling, self).__init__()
        self.embed_dim = embed_dim
        self.phm_dim = phm_dim
        self.w_init = w_init
        self.c_init = c_init
        self.phm_rule = phm_rule
        self.learn_phm = learn_phm
        self.real_trafo_type = real_trafo
        self.bias = bias
        # Produces one gating-logit vector per node, still in PHM layout.
        self.linear = PHMLinear(in_features=self.embed_dim, out_features=self.embed_dim,
                                phm_dim=self.phm_dim, phm_rule=phm_rule,
                                learn_phm=learn_phm, w_init=w_init, c_init=c_init,
                                bias=bias)
        # Collapses the phm components to a single real-valued score vector.
        self.real_trafo = RealTransformer(type=self.real_trafo_type, phm_dim=self.phm_dim,
                                          in_features=self.embed_dim, bias=True)
        self.sigmoid = nn.Sigmoid()
        self.sum_pooling = PHMGlobalSumPooling(phm_dim=self.phm_dim)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the parameters of the gating and real-transform layers."""
        self.real_trafo.reset_parameters()
        self.linear.reset_parameters()

    def forward(self, x: torch.Tensor, batch: Batch) -> torch.Tensor:
        """Pool node features ``x`` (shape ``(num_nodes, phm_dim * embed_dim)``)
        into one vector per graph, weighted by soft-attention scores."""
        out = self.linear(x)            # get logits
        out = self.real_trafo(out)      # "transform" to real-valued
        out = self.sigmoid(out)         # get "probabilities"
        # Expose the phm components as a separate axis so the real-valued
        # attention scores broadcast identically over every component.
        x = x.reshape(x.size(0), self.phm_dim, self.embed_dim)
        out = out.unsqueeze(dim=1)
        x = out * x                     # element-wise Hadamard via broadcasting
        x = x.reshape(x.size(0), self.phm_dim * self.embed_dim)
        x = self.sum_pooling(x, batch=batch)
        return x

    def __repr__(self):
        # Fixed: previously formatted with the non-existent attribute
        # `self.init`; the constructor stores `w_init` and `c_init`.
        return "{}(embed_dim={}, phm_dim={}, phm_rule={}, learn_phm={}," \
               " bias={}, w_init='{}', c_init='{}', real_trafo='{}')".format(
                   self.__class__.__name__, self.embed_dim, self.phm_dim,
                   self.phm_rule, self.learn_phm, self.bias,
                   self.w_init, self.c_init, self.real_trafo_type)
class PNAAggregator(nn.Module):
    """Principal Neighbourhood Aggregator.

    Inherits from nn.Module in case we want to further parametrize.
    """

    def __init__(self, phm_dim: int, in_features: int, out_features: int,
                 learn_phm: bool, init: str, phm_rule,
                 aggregators: List[str],
                 scalers: Optional[List[str]],
                 deg: Optional[torch.Tensor]) -> None:
        super(PNAAggregator, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.phm_dim = phm_dim
        self.aggregators_l = aggregators
        self.aggregators = [AGGREGATORS[name] for name in aggregators]
        self.scalers_l = scalers
        if scalers:
            self.scalers = [SCALERS[name] for name in scalers]
            # Every (aggregator, scaler) pair contributes one slice of width
            # `in_features` to the concatenated output.
            out_trafo_dim = in_features * (len(aggregators) * len(scalers))
            self.deg = deg.to(torch.float)
            # Average-degree statistics consumed by the degree scalers.
            self.avg_deg: Dict[str, float] = {
                'lin': self.deg.mean().item(),
                'log': (self.deg + 1).log().mean().item(),
                'exp': self.deg.exp().mean().item(),
            }
        else:
            self.scalers = None
            self.avg_deg = None
            out_trafo_dim = in_features * len(aggregators)
        # NOTE(review): this call passes `init=`, while other PHMLinear call
        # sites in this file pass `w_init=`/`c_init=` — verify against the
        # PHMLinear signature.
        self.transform = PHMLinear(in_features=out_trafo_dim, out_features=out_features,
                                   bias=True, phm_dim=phm_dim, phm_rule=phm_rule,
                                   learn_phm=learn_phm, init=init)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the output transformation."""
        self.transform.reset_parameters()

    def forward(self, x: torch.Tensor, idx: torch.Tensor,
                dim_size: Optional[int] = None, dim: int = 0) -> torch.Tensor:
        """Aggregate messages ``x`` grouped by ``idx``, optionally rescale by
        degree, and project with the PHM linear transform."""
        # Run every aggregator, then stitch the results back together while
        # respecting the hypercomplex component layout.
        aggregated = []
        for aggregate in self.aggregators:
            aggregated.append(aggregate(x, idx, dim_size))
        out = phm_cat(tensors=aggregated, phm_dim=self.phm_dim, dim=-1)
        if self.scalers is not None:
            node_deg = degree(idx, dim_size, dtype=x.dtype).view(-1, 1)
            rescaled = [rescale(out, node_deg, self.avg_deg) for rescale in self.scalers]
            out = phm_cat(tensors=rescaled, phm_dim=self.phm_dim, dim=-1)
        return self.transform(out)
class PHMConv(MessagePassing):
    r"""Parametrized Hypercomplex graph-convolution operator that uses
    edge attributes. Transformation is a PHM linear layer.

    Messages are ``msg_encoder(x_j + edge_attr)``; the aggregated result is
    projected by a ``PHMLinear`` transform, with an optional residual
    self-connection (``add_self_loops``) applied before or after the
    transform depending on ``same_dim``.
    """

    def __init__(self, in_features: int, out_features: int, phm_dim: int,
                 phm_rule: Union[None, nn.ParameterList],
                 # Fixed: the annotation was the literal `True` instead of `bool`,
                 # which also left the parameter without a default value.
                 learn_phm: bool = True,
                 bias: bool = True, add_self_loops: bool = True,
                 w_init: str = "phm", c_init: str = "standard",
                 aggr: str = "add", same_dim: bool = True,
                 msg_encoder: str = "identity") -> None:
        super(PHMConv, self).__init__(aggr=aggr)
        self.in_features = in_features
        self.out_features = out_features
        self.phm_dim = phm_dim
        self.phm_rule = phm_rule
        self.learn_phm = learn_phm
        self.bias = bias
        self.add_self_loops = add_self_loops
        self.w_init = w_init
        self.c_init = c_init
        self.aggr = aggr
        # When in/out dims match, the residual can be added after the
        # transform; otherwise it must be added before (see forward()).
        self.same_dim = same_dim
        self.transform = PHMLinear(in_features=in_features, out_features=out_features,
                                   phm_rule=phm_rule, phm_dim=phm_dim, bias=bias,
                                   w_init=w_init, c_init=c_init, learn_phm=learn_phm)
        self.msg_encoder_str = msg_encoder
        self.msg_encoder = get_module_activation(activation=msg_encoder)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the PHM linear transform."""
        self.transform.reset_parameters()

    def forward(self, x: torch.Tensor, edge_index: Adj, edge_attr: torch.Tensor,
                size: Size = None) -> torch.Tensor:
        """Propagate messages over ``edge_index`` and apply the PHM transform."""
        if self.add_self_loops:
            # Keep the pre-propagation features for the residual connection.
            x_c = x.clone()
        x = self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr, size=size)
        if self.same_dim:
            x = self.transform(x)
            if self.add_self_loops:
                x += x_c
        else:
            # Output dim differs from input dim: add the residual while shapes
            # still match, i.e. before the transform.
            if self.add_self_loops:
                x += x_c
            x = self.transform(x)
        return x

    def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        """Encode the sum of the neighbour features and the edge attributes."""
        assert x_j.size(-1) == edge_attr.size(-1)
        return self.msg_encoder(x_j + edge_attr)

    def __repr__(self):
        # Fixed: the split format string previously rendered a double comma
        # after `add_self_loops={}`.
        return "{}(in_features={}, out_features={}, phm_dim={}, phm_rule={}," \
               " learn_phm={}, bias={}, add_self_loops={}," \
               " w_init='{}', c_init='{}', aggr='{}')".format(
                   self.__class__.__name__, self.in_features, self.out_features,
                   self.phm_dim, self.phm_rule, self.learn_phm,
                   self.bias, self.add_self_loops,
                   self.w_init, self.c_init, self.aggr)