コード例 #1
0
ファイル: smp_cycles.py プロジェクト: vijaydwivedi75/SMP
    def forward(self, data):
        """Compute log-probabilities for a batch of graphs.

        Args:
            data: batched graph data. This method reads ``data.x`` (documented
                as ``(num_nodes, num_features)``), ``data.edge_index``,
                ``data.num_nodes`` and ``data.coloring``, and passes ``data``
                itself to ``create_batch_info``.

        Returns:
            ``F.log_softmax`` of the final linear layer's output along the last
            dimension (presumably one row per graph -- confirm against the
            feature extractors).
        """
        x, edge_index = data.x, data.edge_index

        batch_info = create_batch_info(data, self.edge_counter)

        # Create the context matrix u.
        if self.use_x:
            u = x
        else:
            # Write a 1 at each node's colour index (taken from data.coloring)
            # across batch_info['n_colors'] columns, then add a trailing axis.
            u = data.x.new_zeros((data.num_nodes, batch_info['n_colors']))
            u.scatter_(1, data.coloring, 1)
            u = u[..., None]

        # Readout of the context matrix before any message passing.
        out = self.no_prop(u, batch_info)
        u = self.initial_lin(u)
        num_layers = len(self.convs)
        for i, (conv, bn, extractor) in enumerate(
                zip(self.convs, self.batch_norm_list,
                    self.feature_extractors)):
            # Batch norm is not applied before the first layer.
            if self.use_batch_norm and i > 0:
                u = bn(u)
            u = conv(u, edge_index, batch_info)
            # Call the module instance (not .forward) so registered hooks run.
            global_features = extractor(u, batch_info)
            # Out-of-place accumulation: never mutate the tensor returned by
            # no_prop, which autograd may have saved for the backward pass.
            out = out + global_features / num_layers

        # Two linear layers (after_conv, final_lin) with a residual connection
        # around after_conv and dropout in between.
        out = torch.relu(self.after_conv(out)) + out
        out = F.dropout(out, p=self.dropout_prob, training=self.training)
        out = self.final_lin(out)
        return F.log_softmax(out, dim=-1)
コード例 #2
0
ファイル: smp_multi_task.py プロジェクト: vijaydwivedi75/SMP
    def forward(self, data):
        """Run the multi-task network and return node- and graph-level outputs.

        Args:
            data: batched graph data. This method reads ``data.x`` (documented
                as ``(num_nodes, num_features)``; column 0 is used as
                ``shortest_path_ids`` and column 1 as ``lap_feat``),
                ``data.edge_index``, ``data.batch``, ``data.num_nodes`` and
                ``data.coloring``, and passes ``data`` to ``create_batch_info``.

        Returns:
            Tuple ``(nodes_out, graph_out)``: ``self.final_node`` applied to the
            last hidden state of ``self.gru``, and ``self.final_graph`` applied
            to ``self.set2set`` pooling of that state over
            ``batch_info['batch']``.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        batch_info = create_batch_info(data, self.edge_counter)
        coloring = batch_info['coloring']

        # Create the context matrix: a 1 at each node's colour index
        # (from data.coloring), plus a trailing feature axis.
        u = data.x.new_zeros((data.num_nodes, batch_info['n_colors']))
        u.scatter_(1, data.coloring, 1)
        u = u[..., None]

        # Map x to u
        shortest_path_ids = x[:, 0]
        lap_feat = x[:, 1]
        u_shortest_path = torch.zeros_like(u)
        u_lap_feat = torch.zeros_like(u)

        # For every node k with a non-zero shortest-path id, set colour
        # coloring[k, 0] to 1 on every node in the same graph as k.
        # Vectorised: pair each node i with each such k sharing its graph id,
        # instead of one masked assignment per k in a Python loop.
        non_zero = shortest_path_ids.nonzero()[:, 0]
        nonzero_batch = batch_info['batch'][non_zero]
        nonzero_color = coloring[non_zero][:, 0]
        same_graph = batch[:, None] == nonzero_batch[None, :]
        node_idx, source_idx = same_graph.nonzero(as_tuple=True)
        u_shortest_path[node_idx, nonzero_color[source_idx]] = 1

        # Write each node's lap_feat value into its own colour slot(s);
        # one advanced-indexing assignment instead of a per-node loop.
        node_ids = torch.arange(lap_feat.shape[0], device=coloring.device)
        u_lap_feat[node_ids[:, None], coloring] = lap_feat[:, None, None]

        u = torch.cat((u, u_shortest_path, u_lap_feat), dim=2)

        # Forward pass
        u = self.initial_lin_u(u)
        hidden_state = None
        for i, (conv, bn_u) in enumerate(zip(self.convs, self.batch_norm_u)):
            # Batch norm is not applied before the first layer.
            if i > 0:
                u = bn_u(u)
            u = conv(u, edge_index, batch_info, self.debug_model)
            # Leading axis of size 1 added for the GRU (presumably the
            # sequence axis of an nn.GRU -- confirm).
            extracted = self.extractor(x, u, batch_info)[None, :, :]
            # Keep only the second return value (the hidden state if
            # self.gru is an nn.GRU) and feed it back next layer.
            hidden_state = self.gru(extracted, hidden_state)[1]

        # Compute the final representation
        out = hidden_state[0, :, :]
        nodes_out = self.final_node(out)
        after_set2set = self.set2set(out, batch_info['batch'])
        graph_out = self.final_graph(after_set2set)

        return nodes_out, graph_out
コード例 #3
0
    def forward(self, data):
        """Predict one scalar per output row for a batch of graphs.

        Args:
            data: batched graph data. This method reads ``data.x`` (documented
                as ``(num_nodes, num_node_features)``), ``data.edge_index``,
                ``data.edge_attr``, ``data.num_nodes`` and ``data.coloring``,
                and passes ``data`` to ``create_batch_info`` (and to
                ``map_x_to_u`` when ``self.map_x_to_u`` is set).

        Returns:
            1-D tensor holding column 0 of the final layer's output (presumably
            one value per graph -- confirm against the feature extractors).

        Raises:
            AssertionError: if ``self.use_x`` is set but ``data.x`` is None, or
                if the final layer does not produce exactly one column.
        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

        # Compute information about the batch
        batch_info = create_batch_info(data, self.edge_counter)

        # Create the context matrix
        if self.use_x:
            assert x is not None
            u = x
        elif self.map_x_to_u:
            u = map_x_to_u(data, batch_info)
        else:
            # A 1 at each node's colour index (from data.coloring), plus a
            # trailing feature axis.
            u = data.x.new_zeros((data.num_nodes, batch_info['n_colors']))
            u.scatter_(1, data.coloring, 1)
            u = u[..., None]

        # Forward pass
        out = self.no_prop(u, batch_info)
        u = self.initial_lin(u)
        num_layers = len(self.convs)
        for i, conv in enumerate(self.convs):
            bn = self.batch_norm_list[i]
            extractor = (self.feature_extractor if self.shared_extractor
                         else self.feature_extractors[i])
            # Batch norm is not applied before the first layer.
            if self.use_batch_norm and i > 0:
                u = bn(u)
            conv_out = conv(u, edge_index, edge_attr, batch_info)
            u = conv_out + u if self.residual else conv_out
            # Call the module instance (not .forward) so registered hooks run.
            global_features = extractor(u, batch_info)
            # Out-of-place accumulation: never mutate the tensor returned by
            # no_prop, which autograd may have saved for the backward pass.
            out = out + global_features / num_layers

        out = self.final_lin(torch.relu(self.after_conv(out)) + out)
        assert out.shape[1] == 1
        return out[:, 0]