示例#1
0
文件: r_unimp.py 项目: WenjinW/PGL
    def forward(self, graph_list, feature, m2v_feature, label_y, label_idx):
        """Run the relation-aware GAT stack with label and m2v inputs.

        Args:
            graph_list: iterable of ``(sg, sub_index)`` pairs, one per layer;
                ``sub_index`` selects the rows of ``feature`` kept by that
                layer.
            feature: node feature tensor fed to the first layer.
            m2v_feature: extra per-node features; projected by
                ``self.m2v_fc`` and added to ``feature``.
            label_y: labels passed to ``self.label_embed``.
            label_idx: row indices of ``feature`` that are overwritten with
                the label-mixed representation.

        Returns:
            The output of ``self.mlp`` on the last layer's features.
        """
        # Fuse the projected m2v features into the node features.
        feature = feature + self.input_drop(self.m2v_fc(m2v_feature))

        # Embed the labels, mix them with the features of the labelled rows,
        # and write the result back over exactly those rows.
        label_repr = self.input_drop(self.label_embed(label_y))
        labelled_feat = paddle.gather(feature, label_idx)
        label_repr = self.label_mlp(
            paddle.concat([label_repr, labelled_feat], axis=1))
        feature = paddle.scatter(feature, label_idx, label_repr, overwrite=True)

        for layer, (sg, sub_index) in enumerate(graph_list):
            norms = self.norms[layer]

            # Path 0: skip projection of the rows selected by sub_index.
            residual = paddle.gather(feature, sub_index, axis=0)
            residual = F.elu(norms[0](self.skips[layer](residual)))
            paths = [residual]

            # One extra path per edge type whose masked subgraph exists, so
            # the number of paths P is 1 + (#edge types present), not fixed.
            for etype in range(self.edge_type):
                m_sg = self.get_subgraph_by_masked(
                    sg, sg.edge_feat['edge_type'] == etype)
                if m_sg is None:
                    continue
                rel_feat = self.gats[layer][etype](m_sg, feature)
                rel_feat = paddle.gather(rel_feat, sub_index, axis=0)
                paths.append(F.elu(norms[etype + 1](rel_feat)))

            # Attention-weighted sum over the P paths:
            # stacked (b, P, dim); weights softmaxed over P, transposed to
            # (b, 1, P); bmm gives (b, 1, dim) and [:, 0] drops the 1.
            stacked = paddle.stack(paths, axis=1)
            weights = F.softmax(self.path_attns[layer](stacked), axis=1)
            weights = paddle.transpose(weights, perm=[0, 2, 1])
            mixed = paddle.bmm(weights, stacked)[:, 0]
            feature = self.dropout(self.path_norms[layer](mixed))

        return self.mlp(feature)
示例#2
0
    def forward(self, x):
        """Upsample ``x`` trilinearly, then apply conv3d, batch norm and ELU.

        The target size is computed from the three axes after the first two
        (D, H, W for the ``NCDHW`` layout used below). Each axis is
        multiplied by the matching entry of ``self.scale_factor``.

        Args:
            x: 5-D input tensor in ``NCDHW`` layout (assumed from the
                ``data_format``/``mode`` arguments; confirm with callers).

        Returns:
            ``F.elu(self.bn(self.conv3d(upsampled)))``.
        """
        # Build a new list rather than assigning into ``x.shape[2:]``. The
        # ``shape`` object may be a tuple depending on the execution mode,
        # and tuples reject item assignment. Cast to int because
        # ``scale_factor`` entries may be floats, while ``interpolate``
        # expects integer output sizes.
        spatial = x.shape[2:]
        out_size = [int(self.scale_factor[i] * spatial[i]) for i in range(3)]

        upsampled = F.interpolate(x,
                                  size=out_size,
                                  mode='trilinear',
                                  align_corners=False,
                                  data_format='NCDHW',
                                  align_mode=0)
        return F.elu(self.bn(self.conv3d(upsampled)))
示例#3
0
    def forward(self, graph_list, feature):
        """Run the relation-aware GAT stack and apply the output MLP.

        Each layer adds one GAT output per edge type to a skip projection of
        the rows selected by ``sub_index``. It then applies
        ``self.norms[layer]``, ELU and dropout.

        Args:
            graph_list: iterable of ``(sg, sub_index)`` pairs, one per layer.
            feature: node feature tensor fed to the first layer.

        Returns:
            The output of ``self.mlp`` on the last layer's features.
        """
        for layer, (sg, sub_index) in enumerate(graph_list):
            hidden = self.skips[layer](
                paddle.gather(feature, sub_index, axis=0))

            for etype in range(self.edge_type):
                edge_mask = sg.edge_feat['edge_type'] == etype
                m_sg = self.get_subgraph_by_masked(sg, edge_mask)
                if m_sg is None:
                    # No subgraph for this edge type: it contributes nothing.
                    continue
                rel_out = self.gats[layer][etype](m_sg, feature)
                hidden += paddle.gather(rel_out, sub_index, axis=0)

            feature = self.dropout(F.elu(self.norms[layer](hidden)))

        return self.mlp(feature)
示例#4
0
    def forward(self, graphs, padded_feats):
        """Embed node sequences, run the GCNN stack, pool, and predict.

        Args:
            graphs: batched graph; ``graphs.node_feat["seq"]`` is the input
                to ``self.aa_emb``.
            padded_feats: intended input for ``self.lm_model``; currently
                unused (see the TODO below).

        Returns:
            The output of ``self.func_predictor``.
        """
        hidden = self.aa_emb(graphs.node_feat["seq"])

        if self.lm_model is not None:
            # TODO: Sum output from lm_model using 'padded_feats' as input
            # and variable 'hidden' above. Until then lm_model is ignored.
            pass

        layer_outputs = []
        for gcnn in self.gcnn_list:
            hidden = F.elu(gcnn(graphs, hidden))
            layer_outputs.append(hidden)

        # Concatenate every GCNN layer's output along axis 1, then pool.
        pooled = self.global_pool(graphs, paddle.concat(layer_outputs, axis=1))

        for fc in self.fc_list:
            pooled = F.dropout(F.relu(fc(pooled)),
                               p=self.drop,
                               training=self.training)

        return self.func_predictor(pooled)
示例#5
0
 def forward(self, x):
     """Apply ``self.conv3d``, then ``self.bn``, then ELU to ``x``."""
     y = self.conv3d(x)
     y = self.bn(y)
     return F.elu(y)
示例#6
0
 def projection_3d(self, z):
     """Project ``z`` through ``proj_3d_fc1`` -> ELU -> ``proj_3d_fc2``."""
     hidden = F.elu(self.proj_3d_fc1(z))
     projected = self.proj_3d_fc2(hidden)
     return projected