Exemplo n.º 1
0
    def __init__(self,
                 node_dim,
                 edge_dim,
                 hidden_dim,
                 residual: bool = True,
                 pairwise_distances: bool = False,
                 activation: Union[Callable, str] = "relu",
                 last_activation: Union[Callable, str] = "none",
                 mid_batch_norm: bool = False,
                 last_batch_norm: bool = False,
                 propagation_depth: int = 5,
                 dropout: float = 0.0,
                 posttrans_layers: int = 1,
                 pretrans_layers: int = 1,
                 **kwargs):
        """Create the input encoder(s) and a stack of ``MPNNLayer`` modules.

        Args:
            node_dim: Input width of ``node_input_net``.
            edge_dim: Input width of the optional ``edge_input`` encoder,
                which is only created when ``edge_dim > 0``. Also passed to
                every ``MPNNLayer`` as ``in_dim_edges``.
            hidden_dim: Output width of the encoders and the in/out width of
                every ``MPNNLayer``.
            residual: Forwarded to each ``MPNNLayer``.
            pairwise_distances: Forwarded to each ``MPNNLayer``.
            activation: Forwarded to each ``MPNNLayer``. The encoders always
                use ``'relu'`` as their mid activation.
            last_activation: Last activation of the encoders and the layers.
            mid_batch_norm: Forwarded to the encoders and the layers.
            last_batch_norm: Forwarded to the encoders and the layers.
            propagation_depth: Number of ``MPNNLayer`` modules to stack.
            dropout: Forwarded to the encoders and the layers.
            posttrans_layers: Forwarded to each ``MPNNLayer``.
            pretrans_layers: Forwarded to each ``MPNNLayer``.
            **kwargs: Accepted and ignored by this constructor.
        """
        super(MPNNGNN, self).__init__()

        # Both single-layer encoders project to hidden_dim with identical
        # settings; only their input width differs.
        encoder_settings = dict(
            hidden_size=hidden_dim,
            out_dim=hidden_dim,
            mid_batch_norm=mid_batch_norm,
            last_batch_norm=last_batch_norm,
            layers=1,
            mid_activation='relu',
            dropout=dropout,
            last_activation=last_activation,
        )
        self.node_input_net = MLP(in_dim=node_dim, **encoder_settings)
        if edge_dim > 0:
            # No edge encoder is registered when edges carry no features.
            self.edge_input = MLP(in_dim=edge_dim, **encoder_settings)

        # NOTE(review): the layers receive the raw edge_dim as in_dim_edges,
        # even though edge_input maps edges to hidden_dim. Confirm which width
        # forward() actually feeds them.
        layer_settings = dict(
            in_dim=hidden_dim,
            out_dim=int(hidden_dim),
            in_dim_edges=edge_dim,
            pairwise_distances=pairwise_distances,
            residual=residual,
            dropout=dropout,
            activation=activation,
            last_activation=last_activation,
            mid_batch_norm=mid_batch_norm,
            last_batch_norm=last_batch_norm,
            posttrans_layers=posttrans_layers,
            pretrans_layers=pretrans_layers,
        )
        self.mp_layers = nn.ModuleList(
            [MPNNLayer(**layer_settings) for _ in range(propagation_depth)]
        )
    def __init__(self, hidden_dim, target_dim, projection_dim=3, distance_net=False, projection_layers=1, **kwargs):
        """Build a PNA encoder with optional node-projection and distance heads.

        Args:
            hidden_dim: Width passed to ``PNAGNN``. Also the input width of
                ``node_projection_net``, and half the input width of
                ``distance_net``.
            target_dim: Output width of ``distance_net``.
            projection_dim: Output width of ``node_projection_net``. When it
                is not positive, ``node_projection_net`` is ``None``. Also
                used as the hidden width of ``distance_net``.
            distance_net: If truthy, build ``distance_net``; otherwise the
                attribute is ``None``.
            projection_layers: Depth of both optional MLP heads.
            **kwargs: Forwarded unchanged to ``PNAGNN``.
        """
        super(DistancePredictor, self).__init__()
        self.gnn = PNAGNN(hidden_dim=hidden_dim, **kwargs)

        projection_head = None
        if projection_dim > 0:
            projection_head = MLP(in_dim=hidden_dim,
                                  hidden_size=32,
                                  mid_batch_norm=True,
                                  out_dim=projection_dim,
                                  layers=projection_layers)
        self.node_projection_net = projection_head

        distance_head = None
        if distance_net:
            # NOTE(review): the hidden width is projection_dim, so
            # projection_dim <= 0 together with distance_net=True may yield a
            # degenerate hidden layer. Confirm intent.
            distance_head = MLP(in_dim=hidden_dim * 2,
                                hidden_size=projection_dim,
                                mid_batch_norm=True,
                                out_dim=target_dim,
                                layers=projection_layers)
        self.distance_net = distance_head
Exemplo n.º 3
0
    def __init__(self, node_dim, edge_dim, batch_norm_momentum, residual,
                 in_feat_dropout, dropout, layer_norm, batch_norm, gamma,
                 full_graph, GT_hidden_dim, GT_n_heads, GT_out_dim, GT_layers,
                 LPE_n_heads, LPE_layers, LPE_dim, **kwargs):
        """Build the feature embeddings, the LPE transformer and the GT stack.

        Node features are embedded to ``GT_hidden_dim - LPE_dim`` channels.
        Presumably an ``LPE_dim``-wide positional encoding is concatenated to
        reach ``GT_hidden_dim``; confirm this in ``forward``. Edge features
        get two separately parameterised embeddings of width ``GT_hidden_dim``.

        ``GT_layers`` graph-transformer layers are created. All but the last
        map ``GT_hidden_dim -> GT_hidden_dim``; the last maps
        ``GT_hidden_dim -> GT_out_dim``.

        Note: ``batch_norm_momentum`` and ``**kwargs`` are accepted but not
        used by this constructor.
        """
        super().__init__()

        self.residual = residual
        self.layer_norm = layer_norm
        self.batch_norm = batch_norm

        self.in_feat_dropout = nn.Dropout(in_feat_dropout)

        def edge_embedding():
            # Single-layer MLP lifting edge features to GT_hidden_dim.
            return MLP(in_dim=edge_dim,
                       hidden_size=edge_dim,
                       layers=1,
                       out_dim=GT_hidden_dim)

        self.embedding_h = MLP(in_dim=node_dim,
                               hidden_size=node_dim,
                               layers=1,
                               out_dim=GT_hidden_dim - LPE_dim)
        self.embedding_e_real = edge_embedding()
        self.embedding_e_fake = edge_embedding()

        # Maps 2-channel inputs to LPE_dim. What the two channels hold is not
        # visible here.
        self.linear_A = nn.Linear(2, LPE_dim)

        self.PE_Transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=LPE_dim, nhead=LPE_n_heads),
            num_layers=LPE_layers)

        # Output width per layer: hidden width for all but the final layer.
        out_dims = [GT_hidden_dim] * (GT_layers - 1) + [GT_out_dim]
        self.layers = nn.ModuleList([
            GraphTransformerLayer(gamma, GT_hidden_dim, out_dim, GT_n_heads,
                                  full_graph, dropout, layer_norm,
                                  batch_norm, residual)
            for out_dim in out_dims
        ])
Exemplo n.º 4
0
    def __init__(self,
                 in_dim: int,
                 out_dim: int,
                 in_dim_edges: int,
                 activation: Union[Callable, str] = "relu",
                 last_activation: Union[Callable, str] = "none",
                 dropout: float = 0.0,
                 residual: bool = True,
                 pairwise_distances: bool = False,
                 mid_batch_norm: bool = False,
                 last_batch_norm: bool = False,
                 posttrans_layers: int = 2,
                 pretrans_layers: int = 1):
        """Build the message (``pretrans``) and update (``posttrans``) MLPs.

        Args:
            in_dim: Node feature width entering the layer.
            out_dim: Node feature width leaving the layer.
            in_dim_edges: Edge feature width. ``edge_features`` is set to
                ``in_dim_edges > 0``.
            activation: Mid activation of ``posttrans``. It is stored on the
                module as ``self.activation``.
            last_activation: Last activation of both MLPs.
            dropout: Dropout of both MLPs.
            residual: Requested residual flag. It is forced to ``False``
                when ``in_dim != out_dim``.
            pairwise_distances: Adds one extra input channel to ``pretrans``.
            mid_batch_norm: Forwarded to both MLPs.
            last_batch_norm: Forwarded to both MLPs.
            posttrans_layers: Depth of ``posttrans``.
            pretrans_layers: Depth of ``pretrans``.
        """
        super(MPNNLayer, self).__init__()

        self.edge_features = in_dim_edges > 0
        self.activation = activation
        self.pairwise_distances = pairwise_distances
        # A residual connection is only possible when widths match.
        self.residual = False if in_dim != out_dim else residual

        # pretrans input: 2 * in_dim (presumably source + destination node
        # features) + edge features, plus one channel when distances are used.
        message_in_dim = 2 * in_dim + in_dim_edges
        if self.pairwise_distances:
            message_in_dim += 1

        self.pretrans = MLP(
            in_dim=message_in_dim,
            hidden_size=in_dim,
            out_dim=in_dim,
            mid_batch_norm=mid_batch_norm,
            last_batch_norm=last_batch_norm,
            layers=pretrans_layers,
            mid_activation='relu',
            dropout=dropout,
            last_activation=last_activation,
        )
        self.posttrans = MLP(
            in_dim=in_dim,
            hidden_size=out_dim,
            out_dim=out_dim,
            layers=posttrans_layers,
            mid_activation=activation,
            last_activation=last_activation,
            dropout=dropout,
            mid_batch_norm=mid_batch_norm,
            last_batch_norm=last_batch_norm,
        )
Exemplo n.º 5
0
 def __init__(self, GT_out_dim, readout_hidden_dim, readout_batchnorm,
              readout_aggregators, target_dim, readout_layers,
              batch_norm_momentum, **kwargs):
     """Wrap a ``SAN_NodeLPE`` encoder with an MLP readout head.

     Args:
         GT_out_dim: Encoder output width. The readout input width is
             ``GT_out_dim * len(readout_aggregators)``.
         readout_hidden_dim: Hidden width of the readout MLP.
         readout_batchnorm: ``mid_batch_norm`` flag of the readout MLP.
         readout_aggregators: Stored on the module; only its length is used
             in this constructor.
         target_dim: Output width of the readout MLP.
         readout_layers: Depth of the readout MLP.
         batch_norm_momentum: Passed to both the encoder and the readout MLP.
         **kwargs: Forwarded unchanged to ``SAN_NodeLPE``.
     """
     super().__init__()
     self.readout_aggregators = readout_aggregators
     self.gnn = SAN_NodeLPE(GT_out_dim=GT_out_dim,
                            batch_norm_momentum=batch_norm_momentum,
                            **kwargs)
     # One GT_out_dim-wide slot per readout aggregator (presumably
     # concatenated in forward()).
     readout_in_dim = GT_out_dim * len(readout_aggregators)
     self.output = MLP(
         in_dim=readout_in_dim,
         hidden_size=readout_hidden_dim,
         mid_batch_norm=readout_batchnorm,
         out_dim=target_dim,
         layers=readout_layers,
         batch_norm_momentum=batch_norm_momentum,
     )
    def __init__(self,
                 hidden_dim,
                 target_dim,
                 batch_norm=False,
                 readout_batchnorm=True,
                 batch_norm_momentum=0.1,
                 dropout=0.0,
                 readout_layers: int = 2,
                 readout_hidden_dim=None,
                 fourier_encodings=0,
                 activation: str = 'SiLU',
                 **kwargs):
        """Build the distance input encoder and the readout MLP.

        Args:
            hidden_dim: Output width of ``input_net``.
            target_dim: Output width of ``output``.
            batch_norm: Mid and last batch-norm flag of ``input_net``.
            readout_batchnorm: Mid batch-norm flag of ``output``.
            batch_norm_momentum: Passed to both MLPs.
            dropout: Dropout of ``input_net``.
            readout_layers: Depth of ``output``.
            readout_hidden_dim: Hidden width of ``output``. Defaults to
                ``hidden_dim`` when ``None``.
            fourier_encodings: If 0, ``input_net`` takes a single feature.
                Otherwise it takes ``2 * fourier_encodings + 1`` features.
            activation: Mid and last activation of ``input_net``.
            **kwargs: Accepted and ignored by this constructor.
        """
        super(DistanceEncoder, self).__init__()
        self.fourier_encodings = fourier_encodings

        # Presumably the raw distance plus a sin/cos pair per Fourier
        # frequency. TODO: confirm against the encoding function.
        input_dim = 1 if fourier_encodings == 0 else 2 * fourier_encodings + 1
        self.input_net = MLP(
            in_dim=input_dim,
            hidden_size=hidden_dim,
            out_dim=hidden_dim,
            mid_batch_norm=batch_norm,
            last_batch_norm=batch_norm,
            batch_norm_momentum=batch_norm_momentum,
            layers=1,
            mid_activation=activation,
            dropout=dropout,
            last_activation=activation,
        )

        if readout_hidden_dim is None:
            readout_hidden_dim = hidden_dim
        # NOTE(review): the readout input is 4 * hidden_dim. The origin of the
        # factor 4 is not visible here (possibly four pooled aggregations);
        # confirm in forward().
        self.output = MLP(in_dim=hidden_dim * 4,
                          hidden_size=readout_hidden_dim,
                          mid_batch_norm=readout_batchnorm,
                          batch_norm_momentum=batch_norm_momentum,
                          out_dim=target_dim,
                          layers=readout_layers)
Exemplo n.º 7
0
 def __init__(self,
              metric_dim,
              repeats,
              dropout=0.8,
              mid_batch_norm=True,
              last_batch_norm=True,
              **kwargs):
     """Build a two-layer critic MLP over ``metric_dim * repeats`` features.

     Args:
         metric_dim: Width of one metric vector.
         repeats: Repetition count. The MLP's input, hidden and output widths
             are all ``metric_dim * repeats``.
         dropout: Stored as ``self.dropout`` but NOT passed to the MLP, which
             is built with ``dropout=0``. Presumably it is applied elsewhere;
             confirm in forward().
         mid_batch_norm: Mid batch-norm flag of the MLP.
         last_batch_norm: Last batch-norm flag of the MLP.
         **kwargs: Accepted and ignored by this constructor.
     """
     super(BasicCritic, self).__init__()
     self.repeats = repeats
     self.dropout = dropout
     width = metric_dim * repeats
     self.criticise = MLP(
         in_dim=width,
         hidden_size=width,
         out_dim=width,
         mid_batch_norm=mid_batch_norm,
         last_batch_norm=last_batch_norm,
         dropout=0,
         layers=2,
     )
Exemplo n.º 8
0
    def __init__(self,
                 node_dim,
                 edge_dim,
                 hidden_dim,
                 aggregators: List[str],
                 scalers: List[str],
                 activation: Union[Callable, str] = "relu",
                 last_activation: Union[Callable, str] = "none",
                 mid_batch_norm: bool = False,
                 last_batch_norm: bool = False,
                 propagation_depth: int = 5,
                 dropout: float = 0.0,
                 posttrans_layers: int = 1,
                 pretrans_layers: int = 1,
                 **kwargs):
        """Build the node input encoder and a stack of DGN message-passing layers.

        Args:
            node_dim: Input width of ``node_input_net``.
            edge_dim: Passed to every layer as ``in_dim_edges``.
            hidden_dim: Encoder output width and per-layer in/out width.
            aggregators: Aggregator names shared by every layer.
            scalers: Scaler names shared by every layer.
            activation: Forwarded to each layer. The encoder's mid activation
                is always ``'relu'``.
            last_activation: Last activation of the encoder and the layers.
            mid_batch_norm: Used by the encoder only.
            last_batch_norm: Used by the encoder only.
            propagation_depth: Number of layers to stack.
            dropout: Forwarded to the encoder and the layers.
            posttrans_layers: Forwarded to each layer.
            pretrans_layers: Forwarded to each layer.
            **kwargs: Accepted and ignored by this constructor.
        """
        super(DGNGNN, self).__init__()
        self.node_input_net = MLP(
            in_dim=node_dim,
            hidden_size=hidden_dim,
            out_dim=hidden_dim,
            mid_batch_norm=mid_batch_norm,
            last_batch_norm=last_batch_norm,
            layers=1,
            mid_activation='relu',
            dropout=dropout,
            last_activation=last_activation,
        )

        # NOTE(review): mid/last batch-norm flags are not forwarded to the
        # message-passing layers, and avg_d is fixed to {"log": 1.0}.
        # Confirm both are intended.
        self.mp_layers = nn.ModuleList([
            DGNMessagePassingLayer(
                in_dim=hidden_dim,
                out_dim=hidden_dim,
                in_dim_edges=edge_dim,
                aggregators=aggregators,
                scalers=scalers,
                dropout=dropout,
                activation=activation,
                last_activation=last_activation,
                avg_d={"log": 1.0},
                posttrans_layers=posttrans_layers,
                pretrans_layers=pretrans_layers,
            )
            for _ in range(propagation_depth)
        ])
Exemplo n.º 9
0
 def __init__(
         self,
         model_type,
         model_parameters,
         predictor_layers=1,
         predictor_hidden_size=256,
         predictor_batchnorm=False,
         metric_dim=256,
         ma_decay=0.99,  # moving average decay
         **kwargs):
     """Build a student model, a gradient-free teacher copy and an optional predictor.

     Args:
         model_type: Name of a class looked up in this module's ``globals()``.
         model_parameters: Keyword arguments for that class. Must contain
             ``'target_dim'`` when ``predictor_layers > 0``.
         predictor_layers: Depth of the predictor MLP. No predictor is
             created when this is not positive.
         predictor_hidden_size: Hidden width of the predictor MLP.
         predictor_batchnorm: Mid batch-norm flag of the predictor MLP.
         metric_dim: Output width of the predictor MLP.
         ma_decay: Moving average decay, stored as ``self.ma_decay``.
         **kwargs: Forwarded to the student constructor.
     """
     super(BYOLwrapper, self).__init__()
     # NOTE(review): any module-level name can be instantiated through
     # globals(); model_type should only come from trusted configuration.
     model_class = globals()[model_type]
     self.student = model_class(**model_parameters, **kwargs)
     # The teacher starts as an exact deep copy of the student.
     self.teacher = copy.deepcopy(self.student)
     self.predictor_layers = predictor_layers
     if predictor_layers > 0:
         self.predictor = MLP(
             in_dim=model_parameters['target_dim'],
             hidden_size=predictor_hidden_size,
             mid_batch_norm=predictor_batchnorm,
             out_dim=metric_dim,
             layers=predictor_layers,
         )
     self.ma_decay = ma_decay
     # Teacher weights never receive gradients.
     for param in self.teacher.parameters():
         param.requires_grad = False
Exemplo n.º 10
0
    def __init__(self,
                 node_dim,
                 edge_dim,
                 hidden_dim,
                 target_dim,
                 aggregators: List[str],
                 scalers: List[str],
                 readout_aggregators: List[str],
                 frozen_readout_aggregators: List[str],
                 latent3d_dim: int = 256,
                 readout_batchnorm: bool = True,
                 readout_hidden_dim=None,
                 readout_layers: int = 2,
                 residual: bool = True,
                 pairwise_distances: bool = False,
                 activation: Union[Callable, str] = "relu",
                 last_activation: Union[Callable, str] = "none",
                 mid_batch_norm: bool = False,
                 last_batch_norm: bool = False,
                 propagation_depth: int = 5,
                 dropout: float = 0.0,
                 posttrans_layers: int = 1,
                 pretrans_layers: int = 1,
                 **kwargs):
        """Build two PNA encoders with identical hyper-parameters plus two heads.

        ``node_gnn`` and ``output`` form one branch. ``output`` maps
        ``hidden_dim * len(frozen_readout_aggregators)`` features to
        ``latent3d_dim``. ``node_gnn2D`` and ``output2D`` form the other
        branch; ``output2D`` takes
        ``hidden_dim * len(readout_aggregators) + latent3d_dim`` features
        and maps them to ``target_dim``.

        Args:
            readout_hidden_dim: Hidden width of ``output2D``. Defaults to
                ``hidden_dim`` when ``None``.
            **kwargs: Accepted and ignored by this constructor.
            All other arguments configure the two ``PNAGNN`` encoders or the
            heads, as their names indicate.
        """
        super(PNAFrozenCombined, self).__init__()
        # Both encoders share exactly the same hyper-parameters.
        gnn_kwargs = dict(node_dim=node_dim,
                          edge_dim=edge_dim,
                          hidden_dim=hidden_dim,
                          aggregators=aggregators,
                          scalers=scalers,
                          residual=residual,
                          pairwise_distances=pairwise_distances,
                          activation=activation,
                          last_activation=last_activation,
                          mid_batch_norm=mid_batch_norm,
                          last_batch_norm=last_batch_norm,
                          propagation_depth=propagation_depth,
                          dropout=dropout,
                          posttrans_layers=posttrans_layers,
                          pretrans_layers=pretrans_layers)
        # The pretrained GNN.
        self.node_gnn = PNAGNN(**gnn_kwargs)
        # A second, independently parameterised encoder for the 2D branch.
        self.node_gnn2D = PNAGNN(**gnn_kwargs)

        if readout_hidden_dim is None:
            readout_hidden_dim = hidden_dim
        self.frozen_readout_aggregators = frozen_readout_aggregators
        self.output = MLP(in_dim=hidden_dim * len(frozen_readout_aggregators),
                          hidden_size=latent3d_dim,
                          mid_batch_norm=False,
                          out_dim=latent3d_dim,
                          layers=1)
        self.readout_aggregators = readout_aggregators
        # Input = 2D readout (one hidden_dim slot per aggregator) plus the
        # latent3d_dim vector (presumably the output of self.output).
        self.output2D = MLP(in_dim=hidden_dim * len(readout_aggregators) +
                            latent3d_dim,
                            hidden_size=readout_hidden_dim,
                            mid_batch_norm=readout_batchnorm,
                            out_dim=target_dim,
                            layers=readout_layers)
 def __init__(self, hidden_dim, target_dim, **kwargs):
     """Build a PNA encoder plus a two-layer MLP head of width ``2 * hidden_dim``.

     Args:
         hidden_dim: Width passed to ``PNAGNN``. The head's input width is
             twice this (presumably two concatenated node embeddings;
             confirm in forward()).
         target_dim: Output width of ``distance_net``.
         **kwargs: Forwarded unchanged to ``PNAGNN``.
     """
     super(GraphRepresentation, self).__init__()
     self.gnn = PNAGNN(hidden_dim=hidden_dim, **kwargs)
     head_in_dim = hidden_dim * 2
     self.distance_net = MLP(
         in_dim=head_in_dim,
         hidden_size=32,
         mid_batch_norm=True,
         out_dim=target_dim,
         layers=2,
     )
Exemplo n.º 12
0
    def __init__(self,
                 node_dim,
                 edge_dim,
                 hidden_dim,
                 target_dim,
                 readout_aggregators: List[str],
                 batch_norm=False,
                 node_wise_output_layers=2,
                 readout_batchnorm=True,
                 batch_norm_momentum=0.1,
                 reduce_func='sum',
                 dropout=0.0,
                 propagation_depth: int = 4,
                 readout_layers: int = 2,
                 readout_hidden_dim=None,
                 fourier_encodings=0,
                 activation: str = 'SiLU',
                 update_net_layers=2,
                 message_net_layers=2,
                 **kwargs):
        """Build the node/edge encoders, the Net3D layers and the readout MLP.

        Args:
            node_dim: Input width of ``self.input``; also passed positionally
                to every ``Net3DLayer``.
            edge_dim: Accepted but not used by this constructor. The layers
                receive ``edge_dim=hidden_dim``.
            hidden_dim: Width of the encoders, the layers and
                ``node_embedding``.
            target_dim: Output width of ``output``.
            readout_aggregators: Its length sets the input width of ``output``.
            batch_norm: Batch-norm flag for the encoders, layers and node-wise
                output network.
            node_wise_output_layers: Depth of ``node_wise_output_network``.
                The network is not created when this is not positive.
            readout_batchnorm: Mid batch-norm flag of ``output``.
            batch_norm_momentum: Passed to every MLP and layer.
            reduce_func: Forwarded to each ``Net3DLayer``.
            dropout: Forwarded to every MLP and layer.
            propagation_depth: Number of ``Net3DLayer`` modules.
            readout_layers: Depth of ``output``.
            readout_hidden_dim: Hidden width of ``output``. Defaults to
                ``hidden_dim`` when ``None``.
            fourier_encodings: If 0, ``edge_input`` takes a single feature.
                Otherwise it takes ``2 * fourier_encodings + 1`` features.
            activation: Activation used by the encoders and layers.
            update_net_layers: Forwarded to each ``Net3DLayer``.
            message_net_layers: Forwarded to each ``Net3DLayer``.
            **kwargs: Accepted and ignored by this constructor.
        """
        super(Net3D, self).__init__()
        self.fourier_encodings = fourier_encodings

        # The node and edge encoders share every setting except input width.
        encoder_kwargs = dict(
            hidden_size=hidden_dim,
            out_dim=hidden_dim,
            mid_batch_norm=batch_norm,
            last_batch_norm=batch_norm,
            batch_norm_momentum=batch_norm_momentum,
            layers=1,
            mid_activation=activation,
            dropout=dropout,
            last_activation=activation,
        )
        self.input = MLP(in_dim=node_dim, **encoder_kwargs)

        edge_in_dim = 1 if fourier_encodings == 0 else 2 * fourier_encodings + 1
        self.edge_input = MLP(in_dim=edge_in_dim, **encoder_kwargs)

        # NOTE(review): the raw node_dim is passed positionally here, even
        # though self.input maps node features to hidden_dim. Confirm what
        # Net3DLayer uses its first argument for.
        self.mp_layers = nn.ModuleList([
            Net3DLayer(node_dim,
                       edge_dim=hidden_dim,
                       hidden_dim=hidden_dim,
                       batch_norm=batch_norm,
                       batch_norm_momentum=batch_norm_momentum,
                       dropout=dropout,
                       mid_activation=activation,
                       reduce_func=reduce_func,
                       message_net_layers=message_net_layers,
                       update_net_layers=update_net_layers)
            for _ in range(propagation_depth)
        ])

        self.node_wise_output_layers = node_wise_output_layers
        if self.node_wise_output_layers > 0:
            self.node_wise_output_network = MLP(
                in_dim=hidden_dim,
                hidden_size=hidden_dim,
                out_dim=hidden_dim,
                mid_batch_norm=batch_norm,
                last_batch_norm=batch_norm,
                batch_norm_momentum=batch_norm_momentum,
                layers=node_wise_output_layers,
                mid_activation=activation,
                dropout=dropout,
                last_activation='None',
            )
        # Learnable hidden_dim-wide vector, initialised from N(0, 1).
        self.node_embedding = nn.Parameter(torch.empty((hidden_dim, )))
        nn.init.normal_(self.node_embedding)
        if readout_hidden_dim is None:
            readout_hidden_dim = hidden_dim
        self.readout_aggregators = readout_aggregators
        self.output = MLP(in_dim=hidden_dim * len(self.readout_aggregators),
                          hidden_size=readout_hidden_dim,
                          mid_batch_norm=readout_batchnorm,
                          batch_norm_momentum=batch_norm_momentum,
                          out_dim=target_dim,
                          layers=readout_layers)
    def __init__(self,
                 node_dim,
                 edge_dim,
                 hidden_dim,
                 target_dim,
                 readout_aggregators: List[str],
                 batch_norm=False,
                 node_wise_output_layers=2,
                 readout_batchnorm=True,
                 batch_norm_momentum=0.1,
                 reduce_func='sum',
                 dropout=0.0,
                 readout_layers: int = 2,
                 readout_hidden_dim=None,
                 fourier_encodings=0,
                 activation: str = 'SiLU',
                 **kwargs):
        """Build the edge encoder, the optional node-wise MLP and the readout MLP.

        Args:
            node_dim: Accepted but not used by this constructor.
            edge_dim: Accepted but not used by this constructor.
            hidden_dim: Width of ``edge_input`` and the node-wise network.
            target_dim: Output width of ``output``.
            readout_aggregators: Its length sets the input width of ``output``.
            batch_norm: Batch-norm flag for ``edge_input`` and the node-wise
                network.
            node_wise_output_layers: Depth of ``node_wise_output_network``.
                The network is not created when this is not positive.
            readout_batchnorm: Mid batch-norm flag of ``output``.
            batch_norm_momentum: Passed to every MLP.
            reduce_func: ``'sum'`` or ``'mean'``. Selects ``fn.sum`` or
                ``fn.mean`` as ``self.reduce_func``.
            dropout: Dropout of ``edge_input`` and the node-wise network.
            readout_layers: Depth of ``output``.
            readout_hidden_dim: Hidden width of ``output``. Defaults to
                ``hidden_dim`` when ``None``.
            fourier_encodings: If 0, ``edge_input`` takes a single feature.
                Otherwise it takes ``2 * fourier_encodings + 1`` features.
            activation: Activation of ``edge_input`` and the node-wise network.
            **kwargs: Accepted and ignored by this constructor.

        Raises:
            ValueError: If ``reduce_func`` is neither ``'sum'`` nor ``'mean'``.
        """
        super(DistanceAggregator, self).__init__()
        self.fourier_encodings = fourier_encodings

        if reduce_func == 'sum':
            self.reduce_func = fn.sum
        elif reduce_func == 'mean':
            self.reduce_func = fn.mean
        else:
            # Single formatted message (the old call passed a tuple of args
            # and misspelled "supported").
            raise ValueError(f"reduce function not supported: {reduce_func!r}")

        edge_in_dim = 1 if fourier_encodings == 0 else 2 * fourier_encodings + 1
        self.edge_input = MLP(
            in_dim=edge_in_dim,
            hidden_size=hidden_dim,
            out_dim=hidden_dim,
            mid_batch_norm=batch_norm,
            last_batch_norm=batch_norm,
            batch_norm_momentum=batch_norm_momentum,
            layers=1,
            mid_activation=activation,
            dropout=dropout,
            last_activation=activation,
        )

        self.node_wise_output_layers = node_wise_output_layers
        if self.node_wise_output_layers > 0:
            self.node_wise_output_network = MLP(
                in_dim=hidden_dim,
                hidden_size=hidden_dim,
                out_dim=hidden_dim,
                mid_batch_norm=batch_norm,
                last_batch_norm=batch_norm,
                batch_norm_momentum=batch_norm_momentum,
                layers=node_wise_output_layers,
                mid_activation=activation,
                dropout=dropout,
                last_activation='None',
            )
        if readout_hidden_dim is None:
            readout_hidden_dim = hidden_dim
        self.readout_aggregators = readout_aggregators
        self.output = MLP(in_dim=hidden_dim * len(self.readout_aggregators),
                          hidden_size=readout_hidden_dim,
                          mid_batch_norm=readout_batchnorm,
                          batch_norm_momentum=batch_norm_momentum,
                          out_dim=target_dim,
                          layers=readout_layers)