Ejemplo n.º 1
0
    def forward(self, gw):
        """Run the stacked GIN encoder over a graph batch.

        Each of the ``self.num_layers`` layers encodes bonds with
        ``self._bond_encoder``, applies ``gin_layer`` and then batch norm.
        Every layer's node features are kept.

        Args:
            gw: graph wrapper handed to the atom/bond encoders, to
                ``gin_layer`` and to ``pgl.layers.graph_pooling``.

        Returns:
            tuple ``(global_repr, patch_summary)`` where
            ``patch_summary`` concatenates every layer's node features along
            axis 1, and ``global_repr`` concatenates the per-layer
            sum-pooled graph features along axis 1.
        """
        h = self._atom_encoder(gw)
        layer_outputs = []
        for layer_idx in range(self.num_layers):
            edge_feat = self._bond_encoder(gw, name='l%d' % layer_idx)
            h = gin_layer(gw, h, edge_feat, 'gin_%s' % layer_idx)
            bn_attr = F.ParamAttr(name='batchnorm_%s' % layer_idx)
            h = L.batch_norm(h, param_attr=bn_attr)
            layer_outputs.append(h)  # $h_i^{(k)}$

        # Node-level ("patch") summary: all layers side by side.
        patch_summary = L.concat(layer_outputs, axis=1)  # $h_{\phi}^i$

        # Graph-level readout: sum-pool each layer's node features separately.
        pooled = []
        for layer_feat in layer_outputs:
            pooled.append(pgl.layers.graph_pooling(gw, layer_feat, 'sum'))
        global_repr = L.concat(pooled, axis=1)
        return global_repr, patch_summary
Ejemplo n.º 2
0
    def forward(self, graph_wrapper, is_test=False):
        """
        Build the network.

        Encodes atoms with ``self._mol_encoder``, then applies
        ``self.layer_num`` message-passing layers. Each layer runs
        gcn / gat / gin according to ``self.gnn_type``, with any value other
        than "gcn" or "gat" using gin. Each layer then applies normalization,
        optional graph norm, relu (all but the last layer), dropout and an
        optional residual connection. The per-layer outputs are combined
        according to ``self.JK``.

        Args:
            graph_wrapper: graph wrapper passed to the encoders, the GNN
                layers and ``pgl.layers.graph_norm``.
            is_test: forwarded to batch norm and dropout to select inference
                behaviour.

        Returns:
            The node representation. When ``self.JK`` is "sum" or "mean" it
            is the element-wise sum or mean of the input features and every
            layer output. Otherwise ("last" or any other value) it is the
            final layer's output.

        Raises:
            ValueError: if ``self.norm_type`` is neither 'batch_norm' nor
                'layer_norm'.
        """
        node_features = self._mol_encoder(graph_wrapper, name=self.name)

        # features_list[k] is the input to layer k; index 0 is the encoder output.
        features_list = [node_features]
        for layer in range(self.layer_num):
            # Common prefix for every parameter name created in this layer.
            prefix = '%s_layer%s' % (self.name, layer)
            edge_features = self._bond_encoder(graph_wrapper, name=prefix)
            layer_input = features_list[layer]

            if self.gnn_type == "gcn":
                feat = gcn_layer(
                        graph_wrapper,
                        layer_input,
                        edge_features,
                        act="relu",
                        name=prefix + "_gcn")
            elif self.gnn_type == "gat":
                feat = gat_layer(
                        graph_wrapper,
                        layer_input,
                        edge_features,
                        self.embed_dim,
                        act="relu",
                        name=prefix + "_gat")
            else:
                feat = gin_layer(
                        graph_wrapper,
                        layer_input,
                        edge_features,
                        name=prefix + "_gin")

            if self.norm_type == 'batch_norm':
                # NOTE: "avearage" is misspelled, but it is a persisted
                # parameter name. Renaming it would presumably break loading
                # existing checkpoints, so it is kept as-is.
                feat = layers.batch_norm(
                        feat,
                        param_attr=fluid.ParamAttr(
                            name=prefix + "_batch_norm_scale",
                            initializer=fluid.initializer.Constant(1.0)),
                        bias_attr=fluid.ParamAttr(
                            name=prefix + "_batch_norm_bias",
                            initializer=fluid.initializer.Constant(0.0)),
                        moving_mean_name=prefix + "_batch_norm_moving_avearage",
                        moving_variance_name=prefix + "_batch_norm_moving_variance",
                        is_test=is_test)
            elif self.norm_type == 'layer_norm':
                feat = layers.layer_norm(
                        feat,
                        param_attr=fluid.ParamAttr(
                            name=prefix + "_layer_norm_scale",
                            initializer=fluid.initializer.Constant(1.0)),
                        bias_attr=fluid.ParamAttr(
                            name=prefix + "_layer_norm_bias",
                            initializer=fluid.initializer.Constant(0.0)))
            else:
                raise ValueError('%s not supported.' % self.norm_type)

            if self.graph_norm:
                feat = pgl.layers.graph_norm(graph_wrapper, feat)

            # No activation after the final layer.
            if layer < self.layer_num - 1:
                feat = layers.relu(feat)
            feat = layers.dropout(
                    feat,
                    self.dropout_rate,
                    dropout_implementation="upscale_in_train",
                    is_test=is_test)

            # residual
            if self.residual:
                feat = feat + layer_input

            features_list.append(feat)

        if self.JK in ("sum", "mean"):
            # features_list is a Python list of tensors. Stack it into one
            # tensor so the reduction runs over the layer axis. fluid's
            # reduce ops take a single Variable and a `dim` keyword (not
            # `axis`). This assumes every entry has the same shape, which is
            # also required for the element-wise combination to be
            # meaningful; TODO confirm for the non-residual case.
            stacked = layers.stack(features_list, axis=0)
            if self.JK == "sum":
                node_repr = layers.reduce_sum(stacked, dim=0)
            else:
                node_repr = layers.reduce_mean(stacked, dim=0)
        else:
            # "last" and any unrecognised JK value use the final layer output.
            node_repr = features_list[-1]
        return node_repr