def __init__(self, node_dim, edge_dim, hidden_dim, residual: bool = True, pairwise_distances: bool = False,
             activation: Union[Callable, str] = "relu", last_activation: Union[Callable, str] = "none",
             mid_batch_norm: bool = False, last_batch_norm: bool = False, propagation_depth: int = 5,
             dropout: float = 0.0, posttrans_layers: int = 1, pretrans_layers: int = 1, **kwargs):
    """Create the node/edge input embeddings and the stack of MPNN layers.

    Args:
        node_dim: Width of the raw node features consumed by ``node_input_net``.
        edge_dim: Width of the raw edge features. ``self.edge_input`` is only
            created when this is positive; otherwise the attribute is never set.
        hidden_dim: Width of the embeddings and of every message-passing layer.
        residual: Forwarded to each ``MPNNLayer``.
        pairwise_distances: Forwarded to each ``MPNNLayer``.
        activation: Activation forwarded to each ``MPNNLayer``. The embedding
            MLPs always use ``'relu'`` as their mid activation.
        last_activation: Final activation of the embedding MLPs and the layers.
        mid_batch_norm: Mid batch-norm flag for the embedding MLPs and the layers.
        last_batch_norm: Last batch-norm flag for the embedding MLPs and the layers.
        propagation_depth: Number of ``MPNNLayer`` modules in ``self.mp_layers``.
        dropout: Dropout rate for the embedding MLPs and the layers.
        posttrans_layers: Depth of each layer's post-transformation MLP.
        pretrans_layers: Depth of each layer's pre-transformation MLP.
        **kwargs: Accepted for config compatibility and ignored.
    """
    super(MPNNGNN, self).__init__()

    # Settings shared by the single-layer node and edge embedding MLPs.
    embedding_cfg = dict(
        hidden_size=hidden_dim,
        out_dim=hidden_dim,
        mid_batch_norm=mid_batch_norm,
        last_batch_norm=last_batch_norm,
        layers=1,
        mid_activation='relu',
        dropout=dropout,
        last_activation=last_activation,
    )
    self.node_input_net = MLP(in_dim=node_dim, **embedding_cfg)
    if edge_dim > 0:
        self.edge_input = MLP(in_dim=edge_dim, **embedding_cfg)

    # NOTE(review): each layer is told its edges have edge_dim channels, while
    # edge_input (when built) maps edges to hidden_dim channels. This only lines
    # up if forward() feeds raw edge features or edge_dim == hidden_dim -- confirm.
    self.mp_layers = nn.ModuleList([
        MPNNLayer(
            in_dim=hidden_dim,
            out_dim=int(hidden_dim),
            in_dim_edges=edge_dim,
            pairwise_distances=pairwise_distances,
            residual=residual,
            dropout=dropout,
            activation=activation,
            last_activation=last_activation,
            mid_batch_norm=mid_batch_norm,
            last_batch_norm=last_batch_norm,
            posttrans_layers=posttrans_layers,
            pretrans_layers=pretrans_layers,
        )
        for _ in range(propagation_depth)
    ])
def __init__(self, hidden_dim, target_dim, projection_dim=3, distance_net=False, projection_layers=1, **kwargs):
    """Build a PNA encoder plus optional node-projection and distance heads.

    Args:
        hidden_dim: Hidden width of the ``PNAGNN`` encoder.
        target_dim: Output width of the optional distance head.
        projection_dim: Output width of the node projection head. When it is
            not positive, ``self.node_projection_net`` is ``None``. It is also
            used as the hidden width of the distance head.
        distance_net: When truthy, build ``self.distance_net``; otherwise it is ``None``.
        projection_layers: Depth of both optional MLP heads.
        **kwargs: Forwarded unchanged to ``PNAGNN``.
    """
    super(DistancePredictor, self).__init__()
    self.gnn = PNAGNN(hidden_dim=hidden_dim, **kwargs)

    self.node_projection_net = (
        MLP(in_dim=hidden_dim, hidden_size=32, mid_batch_norm=True,
            out_dim=projection_dim, layers=projection_layers)
        if projection_dim > 0
        else None
    )

    if not distance_net:
        self.distance_net = None
    else:
        # Input width is two hidden_dim vectors, presumably one per endpoint
        # of a node pair (forward() not visible -- confirm).
        self.distance_net = MLP(in_dim=hidden_dim * 2, hidden_size=projection_dim,
                                mid_batch_norm=True, out_dim=target_dim,
                                layers=projection_layers)
def __init__(self, node_dim, edge_dim, batch_norm_momentum, residual, in_feat_dropout, dropout, layer_norm,
             batch_norm, gamma, full_graph, GT_hidden_dim, GT_n_heads, GT_out_dim, GT_layers, LPE_n_heads,
             LPE_layers, LPE_dim, **kwargs):
    """Build embeddings, a Laplacian-PE transformer and GraphTransformer layers.

    Args:
        node_dim: Width of the raw node features.
        edge_dim: Width of the raw edge features.
        batch_norm_momentum: Accepted but not used by this constructor.
        residual: Residual flag stored and passed to every GraphTransformerLayer.
        in_feat_dropout: Rate of the input-feature ``nn.Dropout``.
        dropout: Dropout passed to every GraphTransformerLayer.
        layer_norm: Layer-norm flag stored and passed to every layer.
        batch_norm: Batch-norm flag stored and passed to every layer.
        gamma: First positional argument of every GraphTransformerLayer.
        full_graph: Passed to every GraphTransformerLayer.
        GT_hidden_dim: Width of the transformer stack.
        GT_n_heads: Number of attention heads per GraphTransformerLayer.
        GT_out_dim: Output width of the final GraphTransformerLayer.
        GT_layers: Total number of GraphTransformerLayers. The last one maps to
            ``GT_out_dim``; at least one layer is always built.
        LPE_n_heads: Attention heads of the PE transformer encoder layer.
        LPE_layers: Number of layers in the PE transformer encoder.
        LPE_dim: Width of the positional-encoding channels.
        **kwargs: Accepted and ignored.
    """
    super().__init__()
    self.residual = residual
    self.layer_norm = layer_norm
    self.batch_norm = batch_norm
    self.in_feat_dropout = nn.Dropout(in_feat_dropout)

    # Node embeddings leave LPE_dim channels free out of GT_hidden_dim,
    # presumably for the positional encoding to be concatenated in forward().
    self.embedding_h = MLP(in_dim=node_dim, hidden_size=node_dim, layers=1,
                           out_dim=GT_hidden_dim - LPE_dim)
    self.embedding_e_real = MLP(in_dim=edge_dim, hidden_size=edge_dim, layers=1,
                                out_dim=GT_hidden_dim)
    self.embedding_e_fake = MLP(in_dim=edge_dim, hidden_size=edge_dim, layers=1,
                                out_dim=GT_hidden_dim)

    # Input of width 2 lifted to LPE_dim, then refined by a transformer encoder.
    self.linear_A = nn.Linear(2, LPE_dim)
    pe_block = nn.TransformerEncoderLayer(d_model=LPE_dim, nhead=LPE_n_heads)
    self.PE_Transformer = nn.TransformerEncoder(pe_block, num_layers=LPE_layers)

    def build_gt_layer(width_out):
        return GraphTransformerLayer(gamma, GT_hidden_dim, width_out, GT_n_heads, full_graph, dropout,
                                     self.layer_norm, self.batch_norm, self.residual)

    self.layers = nn.ModuleList([build_gt_layer(GT_hidden_dim) for _ in range(GT_layers - 1)])
    self.layers.append(build_gt_layer(GT_out_dim))
def __init__(self, in_dim: int, out_dim: int, in_dim_edges: int, activation: Union[Callable, str] = "relu",
             last_activation: Union[Callable, str] = "none", dropout: float = 0.0, residual: bool = True,
             pairwise_distances: bool = False, mid_batch_norm: bool = False, last_batch_norm: bool = False,
             posttrans_layers: int = 2, pretrans_layers: int = 1, ):
    """Build the pre- (message) and post- (update) transformation MLPs of one MPNN layer.

    Args:
        in_dim: Width of the incoming node features.
        out_dim: Width of the produced node features.
        in_dim_edges: Width of the edge features; ``self.edge_features`` is
            ``True`` only when it is positive.
        activation: Mid activation of the post-transformation MLP.
        last_activation: Final activation of both MLPs.
        dropout: Dropout rate of both MLPs.
        residual: Requested residual flag. It is forced to ``False`` when
            ``in_dim != out_dim``.
        pairwise_distances: When truthy, the pre-transformation MLP takes one
            extra input channel.
        mid_batch_norm: Mid batch-norm flag for both MLPs.
        last_batch_norm: Last batch-norm flag for both MLPs.
        posttrans_layers: Depth of the post-transformation MLP.
        pretrans_layers: Depth of the pre-transformation MLP.
    """
    super(MPNNLayer, self).__init__()
    self.edge_features = in_dim_edges > 0
    self.activation = activation
    self.pairwise_distances = pairwise_distances
    # Residual connections require matching input/output widths.
    self.residual = residual if in_dim == out_dim else False

    # Pre-transformation input: two node vectors plus the edge features,
    # and one more channel when pairwise distances are enabled.
    message_in_dim = 2 * in_dim + in_dim_edges
    if self.pairwise_distances:
        message_in_dim += 1

    self.pretrans = MLP(
        in_dim=message_in_dim,
        hidden_size=in_dim,
        out_dim=in_dim,
        mid_batch_norm=mid_batch_norm,
        last_batch_norm=last_batch_norm,
        layers=pretrans_layers,
        mid_activation='relu',
        dropout=dropout,
        last_activation=last_activation,
    )
    self.posttrans = MLP(
        in_dim=in_dim,
        hidden_size=out_dim,
        out_dim=out_dim,
        layers=posttrans_layers,
        mid_activation=self.activation,
        last_activation=last_activation,
        dropout=dropout,
        mid_batch_norm=mid_batch_norm,
        last_batch_norm=last_batch_norm,
    )
def __init__(self, GT_out_dim, readout_hidden_dim, readout_batchnorm, readout_aggregators, target_dim,
             readout_layers, batch_norm_momentum, **kwargs):
    """Wrap a ``SAN_NodeLPE`` encoder with an MLP readout head.

    Args:
        GT_out_dim: Output width of the encoder, forwarded to ``SAN_NodeLPE``.
        readout_hidden_dim: Hidden width of the readout MLP.
        readout_batchnorm: Mid batch-norm flag of the readout MLP.
        readout_aggregators: Sequence of readout aggregator names. Its length
            multiplies ``GT_out_dim`` to give the readout input width.
        target_dim: Output width of the readout MLP.
        readout_layers: Depth of the readout MLP.
        batch_norm_momentum: Passed to both the encoder and the readout MLP.
        **kwargs: Forwarded unchanged to ``SAN_NodeLPE``.
    """
    super().__init__()
    self.readout_aggregators = readout_aggregators
    self.gnn = SAN_NodeLPE(GT_out_dim=GT_out_dim, batch_norm_momentum=batch_norm_momentum, **kwargs)

    pooled_width = GT_out_dim * len(self.readout_aggregators)
    self.output = MLP(
        in_dim=pooled_width,
        hidden_size=readout_hidden_dim,
        mid_batch_norm=readout_batchnorm,
        out_dim=target_dim,
        layers=readout_layers,
        batch_norm_momentum=batch_norm_momentum,
    )
def __init__(self, hidden_dim, target_dim, batch_norm=False, readout_batchnorm=True, batch_norm_momentum=0.1,
             dropout=0.0, readout_layers: int = 2, readout_hidden_dim=None, fourier_encodings=0,
             activation: str = 'SiLU', **kwargs):
    """Build the distance input network and the readout MLP.

    Args:
        hidden_dim: Width of the embedded distance features.
        target_dim: Output width of the readout MLP.
        batch_norm: Mid/last batch-norm flag of the input network.
        readout_batchnorm: Mid batch-norm flag of the readout MLP.
        batch_norm_momentum: Momentum passed to both MLPs.
        dropout: Dropout rate of the input network.
        readout_layers: Depth of the readout MLP.
        readout_hidden_dim: Hidden width of the readout MLP; defaults to
            ``hidden_dim`` when ``None``.
        fourier_encodings: Input width is 1 when this is 0, otherwise
            ``2 * fourier_encodings + 1``.
        activation: Mid and last activation of the input network.
        **kwargs: Accepted and ignored.
    """
    super(DistanceEncoder, self).__init__()
    self.fourier_encodings = fourier_encodings
    input_dim = 1 if fourier_encodings == 0 else 2 * fourier_encodings + 1
    self.input_net = MLP(
        in_dim=input_dim,
        hidden_size=hidden_dim,
        out_dim=hidden_dim,
        mid_batch_norm=batch_norm,
        last_batch_norm=batch_norm,
        batch_norm_momentum=batch_norm_momentum,
        layers=1,
        mid_activation=activation,
        dropout=dropout,
        last_activation=activation,
    )
    # Compare to the None singleton by identity (PEP 8), not equality.
    if readout_hidden_dim is None:
        readout_hidden_dim = hidden_dim
    # NOTE(review): the readout input width is hard-coded to 4 * hidden_dim,
    # presumably four pooled readouts concatenated in forward() -- confirm.
    self.output = MLP(in_dim=hidden_dim * 4, hidden_size=readout_hidden_dim,
                      mid_batch_norm=readout_batchnorm, batch_norm_momentum=batch_norm_momentum,
                      out_dim=target_dim, layers=readout_layers)
def __init__(self, metric_dim, repeats, dropout=0.8, mid_batch_norm=True, last_batch_norm=True, **kwargs):
    """Build a square two-layer critic MLP over ``metric_dim * repeats`` features.

    Args:
        metric_dim: Width of one metric vector.
        repeats: Number of metric vectors; stored on ``self.repeats``.
        dropout: Stored on ``self.dropout``. The MLP itself is built with
            dropout 0; how this rate is applied is not visible here.
        mid_batch_norm: Mid batch-norm flag of the critic MLP.
        last_batch_norm: Last batch-norm flag of the critic MLP.
        **kwargs: Accepted and ignored.
    """
    super(BasicCritic, self).__init__()
    self.repeats = repeats
    self.dropout = dropout
    width = metric_dim * repeats
    self.criticise = MLP(
        in_dim=width,
        hidden_size=width,
        mid_batch_norm=mid_batch_norm,
        out_dim=width,
        last_batch_norm=last_batch_norm,
        dropout=0,
        layers=2,
    )
def __init__(self, node_dim, edge_dim, hidden_dim, aggregators: List[str], scalers: List[str],
             activation: Union[Callable, str] = "relu", last_activation: Union[Callable, str] = "none",
             mid_batch_norm: bool = False, last_batch_norm: bool = False, propagation_depth: int = 5,
             dropout: float = 0.0, posttrans_layers: int = 1, pretrans_layers: int = 1, **kwargs):
    """Create the node input embedding and a stack of DGN message-passing layers.

    Args:
        node_dim: Width of the raw node features.
        edge_dim: Edge feature width, passed to each layer as ``in_dim_edges``.
        hidden_dim: Width of the node embedding and of every layer.
        aggregators: Aggregator names forwarded to each layer.
        scalers: Scaler names forwarded to each layer.
        activation: Activation forwarded to each layer.
        last_activation: Final activation of the input MLP and of each layer.
        mid_batch_norm: Used only by the input MLP.
        last_batch_norm: Used only by the input MLP.
        propagation_depth: Number of layers in ``self.mp_layers``.
        dropout: Dropout rate of the input MLP and of each layer.
        posttrans_layers: Post-transformation depth forwarded to each layer.
        pretrans_layers: Pre-transformation depth forwarded to each layer.
        **kwargs: Accepted and ignored.
    """
    super(DGNGNN, self).__init__()
    self.node_input_net = MLP(
        in_dim=node_dim,
        hidden_size=hidden_dim,
        out_dim=hidden_dim,
        mid_batch_norm=mid_batch_norm,
        last_batch_norm=last_batch_norm,
        layers=1,
        mid_activation='relu',
        dropout=dropout,
        last_activation=last_activation,
    )
    # NOTE(review): the batch-norm flags are not forwarded to the message-passing
    # layers, and avg_d is fixed to {"log": 1.0} -- confirm this is intended.
    self.mp_layers = nn.ModuleList([
        DGNMessagePassingLayer(
            in_dim=hidden_dim,
            out_dim=hidden_dim,
            in_dim_edges=edge_dim,
            aggregators=aggregators,
            scalers=scalers,
            dropout=dropout,
            activation=activation,
            last_activation=last_activation,
            avg_d={"log": 1.0},
            posttrans_layers=posttrans_layers,
            pretrans_layers=pretrans_layers,
        )
        for _ in range(propagation_depth)
    ])
def __init__(
        self,
        model_type,
        model_parameters,
        predictor_layers=1,
        predictor_hidden_size=256,
        predictor_batchnorm=False,
        metric_dim=256,
        ma_decay=0.99,  # moving average decay
        **kwargs):
    """Build a BYOL student/teacher pair with an optional predictor head.

    Args:
        model_type: Name of a class defined in this module. It is looked up in
            ``globals()`` and instantiated as the student.
        model_parameters: Keyword arguments for the student. Its
            ``'target_dim'`` entry is the predictor's input width.
        predictor_layers: Depth of the predictor MLP; no ``self.predictor``
            is created when this is not positive.
        predictor_hidden_size: Hidden width of the predictor MLP.
        predictor_batchnorm: Mid batch-norm flag of the predictor MLP.
        metric_dim: Output width of the predictor MLP.
        ma_decay: Moving-average decay, stored on ``self.ma_decay``.
        **kwargs: Also forwarded to the student constructor.
    """
    super(BYOLwrapper, self).__init__()
    student_cls = globals()[model_type]
    self.student = student_cls(**model_parameters, **kwargs)
    # The teacher starts as an exact copy of the student.
    self.teacher = copy.deepcopy(self.student)

    self.predictor_layers = predictor_layers
    if predictor_layers > 0:
        self.predictor = MLP(
            in_dim=model_parameters['target_dim'],
            hidden_size=predictor_hidden_size,
            mid_batch_norm=predictor_batchnorm,
            out_dim=metric_dim,
            layers=predictor_layers,
        )
    self.ma_decay = ma_decay

    # The teacher receives no gradients.
    self.teacher.requires_grad_(False)
def __init__(self, node_dim, edge_dim, hidden_dim, target_dim, aggregators: List[str], scalers: List[str],
             readout_aggregators: List[str], frozen_readout_aggregators: List[str], latent3d_dim: int = 256,
             readout_batchnorm: bool = True, readout_hidden_dim=None, readout_layers: int = 2,
             residual: bool = True, pairwise_distances: bool = False, activation: Union[Callable, str] = "relu",
             last_activation: Union[Callable, str] = "none", mid_batch_norm: bool = False,
             last_batch_norm: bool = False, propagation_depth: int = 5, dropout: float = 0.0,
             posttrans_layers: int = 1, pretrans_layers: int = 1, **kwargs):
    """Build two identically-configured PNA encoders and their readout heads.

    Args:
        node_dim, edge_dim, hidden_dim, aggregators, scalers, residual,
        pairwise_distances, activation, last_activation, mid_batch_norm,
        last_batch_norm, propagation_depth, dropout, posttrans_layers,
        pretrans_layers: Forwarded unchanged to both ``PNAGNN`` encoders.
        target_dim: Output width of ``self.output2D``.
        readout_aggregators: Readout names for ``self.output2D``. The length
            multiplies ``hidden_dim`` in that head's input width.
        frozen_readout_aggregators: Readout names for ``self.output``. The
            length multiplies ``hidden_dim`` in that head's input width.
        latent3d_dim: Output (and hidden) width of ``self.output``. It is also
            added to the input width of ``self.output2D``.
        readout_batchnorm: Mid batch-norm flag of ``self.output2D``.
        readout_hidden_dim: Hidden width of ``self.output2D``; defaults to
            ``hidden_dim`` when ``None``.
        readout_layers: Depth of ``self.output2D``.
        **kwargs: Accepted and ignored.
    """
    super(PNAFrozenCombined, self).__init__()
    # Both encoders share exactly the same configuration.
    gnn_kwargs = dict(
        node_dim=node_dim,
        edge_dim=edge_dim,
        hidden_dim=hidden_dim,
        aggregators=aggregators,
        scalers=scalers,
        residual=residual,
        pairwise_distances=pairwise_distances,
        activation=activation,
        last_activation=last_activation,
        mid_batch_norm=mid_batch_norm,
        last_batch_norm=last_batch_norm,
        propagation_depth=propagation_depth,
        dropout=dropout,
        posttrans_layers=posttrans_layers,
        pretrans_layers=pretrans_layers,
    )
    # the pretrained GNN
    self.node_gnn = PNAGNN(**gnn_kwargs)
    self.node_gnn2D = PNAGNN(**gnn_kwargs)

    # Compare to the None singleton by identity (PEP 8), not equality.
    if readout_hidden_dim is None:
        readout_hidden_dim = hidden_dim

    self.frozen_readout_aggregators = frozen_readout_aggregators
    self.output = MLP(in_dim=hidden_dim * len(self.frozen_readout_aggregators), hidden_size=latent3d_dim,
                      mid_batch_norm=False, out_dim=latent3d_dim, layers=1)

    self.readout_aggregators = readout_aggregators
    # Input: pooled 2D readouts plus a latent3d_dim-wide vector, presumably the
    # output of self.output (forward() not visible -- confirm).
    self.output2D = MLP(in_dim=hidden_dim * len(self.readout_aggregators) + latent3d_dim,
                        hidden_size=readout_hidden_dim, mid_batch_norm=readout_batchnorm,
                        out_dim=target_dim, layers=readout_layers)
def __init__(self, hidden_dim, target_dim, **kwargs):
    """Build a PNA encoder and a two-layer MLP head over pairs of hidden vectors.

    Args:
        hidden_dim: Hidden width of the ``PNAGNN`` encoder.
        target_dim: Output width of ``self.distance_net``.
        **kwargs: Forwarded unchanged to ``PNAGNN``.
    """
    super(GraphRepresentation, self).__init__()
    self.gnn = PNAGNN(hidden_dim=hidden_dim, **kwargs)
    pair_width = hidden_dim * 2
    self.distance_net = MLP(
        in_dim=pair_width,
        hidden_size=32,
        mid_batch_norm=True,
        out_dim=target_dim,
        layers=2,
    )
def __init__(self, node_dim, edge_dim, hidden_dim, target_dim, readout_aggregators: List[str], batch_norm=False,
             node_wise_output_layers=2, readout_batchnorm=True, batch_norm_momentum=0.1, reduce_func='sum',
             dropout=0.0, propagation_depth: int = 4, readout_layers: int = 2, readout_hidden_dim=None,
             fourier_encodings=0, activation: str = 'SiLU', update_net_layers=2, message_net_layers=2,
             **kwargs):
    """Build the Net3D embeddings, message-passing stack and readout head.

    Args:
        node_dim: Width of the node input. It is the input width of
            ``self.input`` and is also passed positionally to each ``Net3DLayer``.
        edge_dim: Accepted but not used here; layers receive
            ``edge_dim=hidden_dim``.
        hidden_dim: Width of node/edge embeddings and of every layer.
        target_dim: Output width of the readout MLP.
        readout_aggregators: Readout names. The length multiplies
            ``hidden_dim`` in the readout input width.
        batch_norm: Batch-norm flag for the embedding, layer and node-wise MLPs.
        node_wise_output_layers: Depth of the optional node-wise output MLP;
            it is only built when positive.
        readout_batchnorm: Mid batch-norm flag of the readout MLP.
        batch_norm_momentum: Momentum passed to every MLP and layer.
        reduce_func: Forwarded to each ``Net3DLayer``.
        dropout: Dropout rate of the embeddings, layers and node-wise MLP.
        propagation_depth: Number of ``Net3DLayer`` modules.
        readout_layers: Depth of the readout MLP.
        readout_hidden_dim: Hidden width of the readout MLP; defaults to
            ``hidden_dim`` when ``None``.
        fourier_encodings: Edge input width is 1 when this is 0, otherwise
            ``2 * fourier_encodings + 1``.
        activation: Activation used by the embeddings, layers and node-wise MLP.
        update_net_layers: Forwarded to each ``Net3DLayer``.
        message_net_layers: Forwarded to each ``Net3DLayer``.
        **kwargs: Accepted and ignored.
    """
    super(Net3D, self).__init__()
    self.fourier_encodings = fourier_encodings

    # The node and edge embeddings differ only in their input width.
    embedding_cfg = dict(
        hidden_size=hidden_dim,
        out_dim=hidden_dim,
        mid_batch_norm=batch_norm,
        last_batch_norm=batch_norm,
        batch_norm_momentum=batch_norm_momentum,
        layers=1,
        mid_activation=activation,
        dropout=dropout,
        last_activation=activation,
    )
    self.input = MLP(in_dim=node_dim, **embedding_cfg)
    edge_in_dim = 1 if fourier_encodings == 0 else 2 * fourier_encodings + 1
    self.edge_input = MLP(in_dim=edge_in_dim, **embedding_cfg)

    # NOTE(review): node_dim is passed positionally although self.input already
    # maps nodes to hidden_dim -- confirm what Net3DLayer uses it for.
    self.mp_layers = nn.ModuleList()
    for _ in range(propagation_depth):
        self.mp_layers.append(
            Net3DLayer(node_dim, edge_dim=hidden_dim, hidden_dim=hidden_dim, batch_norm=batch_norm,
                       batch_norm_momentum=batch_norm_momentum, dropout=dropout, mid_activation=activation,
                       reduce_func=reduce_func, message_net_layers=message_net_layers,
                       update_net_layers=update_net_layers))

    self.node_wise_output_layers = node_wise_output_layers
    if self.node_wise_output_layers > 0:
        self.node_wise_output_network = MLP(
            in_dim=hidden_dim,
            hidden_size=hidden_dim,
            out_dim=hidden_dim,
            mid_batch_norm=batch_norm,
            last_batch_norm=batch_norm,
            batch_norm_momentum=batch_norm_momentum,
            layers=node_wise_output_layers,
            mid_activation=activation,
            dropout=dropout,
            last_activation='None',
        )

    # Learned hidden_dim-sized vector, initialised from a standard normal.
    self.node_embedding = nn.Parameter(torch.empty((hidden_dim, )))
    nn.init.normal_(self.node_embedding)

    # Compare to the None singleton by identity (PEP 8), not equality.
    if readout_hidden_dim is None:
        readout_hidden_dim = hidden_dim
    self.readout_aggregators = readout_aggregators
    self.output = MLP(in_dim=hidden_dim * len(self.readout_aggregators), hidden_size=readout_hidden_dim,
                      mid_batch_norm=readout_batchnorm, batch_norm_momentum=batch_norm_momentum,
                      out_dim=target_dim, layers=readout_layers)
def __init__(self, node_dim, edge_dim, hidden_dim, target_dim, readout_aggregators: List[str], batch_norm=False,
             node_wise_output_layers=2, readout_batchnorm=True, batch_norm_momentum=0.1, reduce_func='sum',
             dropout=0.0, readout_layers: int = 2, readout_hidden_dim=None, fourier_encodings=0,
             activation: str = 'SiLU', **kwargs):
    """Build the edge (distance) embedding, optional node-wise MLP and readout head.

    Args:
        node_dim: Accepted but not used here.
        edge_dim: Accepted but not used here.
        hidden_dim: Width of the edge embedding and the node-wise MLP.
        target_dim: Output width of the readout MLP.
        readout_aggregators: Readout names. The length multiplies
            ``hidden_dim`` in the readout input width.
        batch_norm: Batch-norm flag of the edge embedding and node-wise MLP.
        node_wise_output_layers: Depth of the optional node-wise output MLP;
            it is only built when positive.
        readout_batchnorm: Mid batch-norm flag of the readout MLP.
        batch_norm_momentum: Momentum passed to every MLP.
        reduce_func: ``'sum'`` or ``'mean'``; selects ``fn.sum`` / ``fn.mean``.
        dropout: Dropout of the edge embedding and node-wise MLP.
        readout_layers: Depth of the readout MLP.
        readout_hidden_dim: Hidden width of the readout MLP; defaults to
            ``hidden_dim`` when ``None``.
        fourier_encodings: Edge input width is 1 when this is 0, otherwise
            ``2 * fourier_encodings + 1``.
        activation: Activation of the edge embedding and node-wise MLP.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``reduce_func`` is neither ``'sum'`` nor ``'mean'``.
    """
    super(DistanceAggregator, self).__init__()
    self.fourier_encodings = fourier_encodings
    if reduce_func == 'sum':
        self.reduce_func = fn.sum
    elif reduce_func == 'mean':
        self.reduce_func = fn.mean
    else:
        # A single formatted message. Passing (msg, value) as two args made
        # str(exc) render as a tuple.
        raise ValueError(f'reduce function not supported: {reduce_func!r}')

    edge_in_dim = 1 if fourier_encodings == 0 else 2 * fourier_encodings + 1
    self.edge_input = MLP(
        in_dim=edge_in_dim,
        hidden_size=hidden_dim,
        out_dim=hidden_dim,
        mid_batch_norm=batch_norm,
        last_batch_norm=batch_norm,
        batch_norm_momentum=batch_norm_momentum,
        layers=1,
        mid_activation=activation,
        dropout=dropout,
        last_activation=activation,
    )

    self.node_wise_output_layers = node_wise_output_layers
    if self.node_wise_output_layers > 0:
        self.node_wise_output_network = MLP(
            in_dim=hidden_dim,
            hidden_size=hidden_dim,
            out_dim=hidden_dim,
            mid_batch_norm=batch_norm,
            last_batch_norm=batch_norm,
            batch_norm_momentum=batch_norm_momentum,
            layers=node_wise_output_layers,
            mid_activation=activation,
            dropout=dropout,
            last_activation='None',
        )

    # Compare to the None singleton by identity (PEP 8), not equality.
    if readout_hidden_dim is None:
        readout_hidden_dim = hidden_dim
    self.readout_aggregators = readout_aggregators
    self.output = MLP(in_dim=hidden_dim * len(self.readout_aggregators), hidden_size=readout_hidden_dim,
                      mid_batch_norm=readout_batchnorm, batch_norm_momentum=batch_norm_momentum,
                      out_dim=target_dim, layers=readout_layers)