Ejemplo n.º 1
0
    def get_feat(self, graph_list):
        """Embed a batch of graphs by iterated message passing.

        ``prepare_gnn`` unpacks ``graph_list`` into node features, optional edge
        features (``None`` is handled), edge endpoint indices and ``g_idx``.
        The initial message is ``w_n2l(node_feat)``, plus (when edge features
        exist) ``w_e2l(edge_feat)`` summed onto each edge's target node. Each of
        the ``self.max_lv`` rounds applies ``conv_layers[level]`` and re-adds
        that initial message before the activation.

        Args:
            graph_list: graphs to embed; forwarded to ``prepare_gnn``.

        Returns:
            Whatever ``self.readout_net`` returns for the list of per-round node
            embeddings (initial round included), ``g_idx`` and the batch size.
        """
        batch = prepare_gnn(graph_list, self.is_cuda())
        node_feat, edge_feat, edge_from_idx, edge_to_idx, g_idx = batch
        num_nodes = node_feat.shape[0]

        # Initial message: node projection, optionally plus the per-target-node
        # sum of projected edge features (aggregated on edge_to_idx).
        base_message = self.w_n2l(node_feat)
        if edge_feat is not None:
            edge_term = self.w_e2l(edge_feat)
            base_message += scatter_add(edge_term, edge_to_idx,
                                        dim=0, dim_size=num_nodes)

        # Edge endpoints stacked along dim 0: row 0 = sources, row 1 = targets.
        edge_index = torch.stack((edge_from_idx, edge_to_idx))

        hidden = self.act_func(base_message)
        level_embeds = [hidden]
        for level in range(self.max_lv):
            conv_out = self.conv_layers[level](hidden, edge_index)
            # Re-add the initial message computed above at every round.
            hidden = self.act_func(conv_out + base_message)
            level_embeds.append(hidden)

        return self.readout_net(level_embeds, g_idx, len(graph_list))
Ejemplo n.º 2
0
Archivo: s2v.py Proyecto: phymucs/GLN
    def get_feat(self, graph_list):
        """Embed a batch of graphs with residual message passing.

        ``prepare_gnn`` unpacks ``graph_list`` into node features, optional edge
        features (may be ``None``), edge endpoint indices and ``g_idx``.

        The initial node message is ``w_n2l(node_feat)``, plus, when edge
        features exist, the per-target-node sum of ``w_e2l[0](edge_feat)``.

        Each of the ``self.max_lv`` levels then:

        * applies ``conv_layers[lv]``;
        * adds the level's pooled edge term (``w_e2l[lv + 1]``) when edge
          features exist;
        * performs a residual update through ``conv_l2[lv]``.

        ``hidden_bn[lv]`` / ``msg_bn[...]`` are applied after the activations.

        Args:
            graph_list: graphs to embed; forwarded to ``prepare_gnn``.

        Returns:
            The result of ``self.readout_net`` on the per-level embeddings
            (initial level included), ``g_idx`` and ``len(graph_list)``.
        """
        node_feat, edge_feat, edge_from_idx, edge_to_idx, g_idx = prepare_gnn(
            graph_list, self.is_cuda())
        has_edge_feat = edge_feat is not None

        def pool_edge_message(level, num_nodes):
            # Project edge features with the level-specific layer and sum each
            # edge's contribution onto its target node (edge_to_idx).
            return scatter_add(self.w_e2l[level](edge_feat),
                               edge_to_idx,
                               dim=0,
                               dim_size=num_nodes)

        input_message = self.w_n2l(node_feat)
        if has_edge_feat:
            input_message += pool_edge_message(0, node_feat.shape[0])
        input_potential = self.act_func(input_message)
        input_potential = self.msg_bn[0](input_potential)

        cur_message_layer = input_potential
        all_embeds = [cur_message_layer]
        # Passed to the conv layers as a plain [from, to] list (not stacked).
        edge_index = [edge_from_idx, edge_to_idx]
        for lv in range(self.max_lv):
            node_linear = self.conv_layers[lv](cur_message_layer, edge_index)
            # Guarded exactly like the input stage: edge_feat may be None, and
            # calling w_e2l / scatter_add on None would raise.
            if has_edge_feat:
                node_linear = node_linear + pool_edge_message(
                    lv + 1, node_linear.shape[0])
            merged_hidden = self.act_func(node_linear)
            merged_hidden = self.hidden_bn[lv](merged_hidden)
            # Residual connection back to this level's input message.
            residual_out = self.conv_l2[lv](merged_hidden) + cur_message_layer
            cur_message_layer = self.act_func(residual_out)
            cur_message_layer = self.msg_bn[lv + 1](cur_message_layer)
            all_embeds.append(cur_message_layer)
        return self.readout_net(all_embeds, g_idx, len(graph_list))
Ejemplo n.º 3
0
    def get_feat(self, graph_list):
        """Embed a batch of graphs with convolution rounds and a recurrent update.

        Node features are projected by ``lin0`` and activated. For
        ``self.max_lv`` rounds, ``self.conv`` (which also receives the edge
        features) produces an activated message that updates the node state
        through ``self.gru``. The final node states are pooled per graph with
        ``self.set2set`` (using ``g_idx``) and passed through ``self.readout``.

        Args:
            graph_list: graphs to embed; forwarded to ``prepare_gnn``.

        Returns:
            ``(graph_embedding, None)``.
        """
        node_feat, edge_feat, edge_from_idx, edge_to_idx, g_idx = prepare_gnn(
            graph_list, self.is_cuda())
        # Edge endpoints stacked along dim 0: row 0 = sources, row 1 = targets.
        edge_index = torch.stack((edge_from_idx, edge_to_idx))

        node_state = self.act_func(self.lin0(node_feat))
        # Both the GRU input and its hidden state get a leading singleton axis
        # via unsqueeze(0); the hidden state is carried across rounds.
        hidden = node_state.unsqueeze(0)
        for _ in range(self.max_lv):
            msg = self.act_func(self.conv(node_state, edge_index, edge_feat))
            gru_out, hidden = self.gru(msg.unsqueeze(0), hidden)
            node_state = gru_out.squeeze(0)

        pooled = self.set2set(node_state, g_idx)
        return self.readout(pooled), None
Ejemplo n.º 4
0
Archivo: ggnn.py Proyecto: phymucs/GLN
    def get_feat(self, graph_list):
        """Embed a batch of graphs with ``self.max_lv`` rounds of ``self.layers``.

        ``self.out_method`` selects how graph-level embeddings are produced:

        * ``'last'``: only the final node states are read out, with
          ``readout_funcs[0]``; returns ``(graph_embed, (g_idx, out_states))``.
        * ``'gru'``: each level's graph embedding is folded into the running
          result with ``self.final_cell``.
        * ``'mean'``: the initial and per-level graph embeddings are summed,
          then divided by ``max_lv + 1``.
        * any other value: the initial and per-level graph embeddings are
          summed.

        Non-``'last'`` modes return ``(outs, None)``.

        Args:
            graph_list: graphs to embed; forwarded to ``prepare_gnn``.
        """
        node_feat, edge_feat, edge_from_idx, edge_to_idx, g_idx = prepare_gnn(
            graph_list, self.is_cuda())
        edge_index = [edge_from_idx, edge_to_idx]
        num_graphs = len(graph_list)

        def pool(states, readout_func):
            # Per-node readout, then aggregation into one row per graph via g_idx.
            per_node = readout_func(states)
            per_graph = self.readout_agg(per_node, g_idx, dim=0,
                                         dim_size=num_graphs)
            return per_node, per_graph

        last_only = self.out_method == 'last'

        node_states = self.node2hidden(node_feat)
        # The initial graph embedding seeds `outs` in every mode, even though
        # the 'last' mode never reads it afterwards.
        _, outs = pool(node_states, self.readout_funcs[-1])

        for layer_idx in range(self.max_lv):
            node_states = self.layers[layer_idx](node_states, edge_index,
                                                 edge_feat)
            if last_only:
                continue
            _, graph_embed = pool(node_states, self.readout_funcs[layer_idx])
            if self.out_method == 'gru':
                outs = self.final_cell(graph_embed, outs)
            else:
                outs += graph_embed

        if last_only:
            out_states, graph_embed = pool(node_states, self.readout_funcs[0])
            return graph_embed, (g_idx, out_states)

        if self.out_method == 'mean':
            outs /= self.max_lv + 1
        return outs, None