Beispiel #1
0
 def forward_onestep(self, data):
     """Predict one step ahead for every sub-graph with ``self.var_layer``.

     Node features ``data.x`` are split per graph (via ``data.batch``) and
     then per sub-graph (via each graph's ``graph_batch`` values). Every
     sub-graph's features are flattened into a single row and passed through
     ``self.var_layer``; the result is reshaped to one ``output_dim`` vector
     per node and added as a residual to ``data.x[:, -1, -self.output_dim:]``.

     Args:
         data: batched graph object exposing ``x``, ``batch`` and
             ``graph_batch`` (type defined elsewhere in the project).

     Returns:
         Whatever ``replace_graph(data, x=...)`` returns, with ``x`` set to a
         ``(num_sub_graphs * self.node_num, self.output_dim)`` prediction.
     """
     per_graph_x = unbatch_node_feature(data, 'x', data.batch)  # list
     per_graph_ids = unbatch_node_feature(data, 'graph_batch', data.batch)
     # Flatten the two-level split (graph -> sub-graph) into a single list.
     sub_graphs = [
         sub_x
         for graph_x, graph_ids in zip(per_graph_x, per_graph_ids)
         for sub_x in unbatch_node_feature_mat(graph_x, graph_ids)
     ]
     stacked = torch.stack(sub_graphs, dim=0)
     n_sub = stacked.shape[0]
     flat = stacked.reshape(n_sub, -1)  # one feature row per sub-graph
     delta = self.var_layer(flat)
     # One output_dim-vector per node; assumes each sub-graph has exactly
     # self.node_num nodes, ordered as in data.x — TODO confirm.
     delta = delta.reshape(n_sub * self.node_num, self.output_dim)
     # Residual on index -1 of axis 1 (presumably the latest time step).
     last_step = data.x[:, -1, -self.output_dim:]
     return replace_graph(data, x=delta + last_step)
Beispiel #2
0
    def forward_onestep(self, data):
        """Advance the recurrent model by one call and predict node outputs.

        Node features ``data.x`` are split per graph (``data.batch``) and per
        sub-graph (``graph_batch``), stacked, rearranged with
        ``transpose(1, 2).flatten(2, 3)`` and fed to ``self.rnn`` together with
        the hidden state stored on ``data.node_hidden``. If ``data`` has no
        ``node_hidden`` yet, a zero state of shape
        ``(self.num_layers, num_sub_graphs, self.rnn.hidden_size)`` is created.
        The RNN output goes through ``self.decoder`` and is added as a residual
        to ``data.x[:, -1, -self.output_dim:]``.

        Args:
            data: batched graph object exposing ``x``, ``batch``,
                ``graph_batch`` and optionally ``node_hidden``.

        Returns:
            Whatever ``replace_graph`` returns, with ``x`` set to the
            ``(num_sub_graphs * self.node_num, self.output_dim)`` prediction
            and ``node_hidden`` set to the RNN's next hidden state.
        """
        x_per_graph = unbatch_node_feature(data, 'x', data.batch)  # list
        sub_ids_per_graph = unbatch_node_feature(data, 'graph_batch', data.batch)
        samples = []
        for x_graph, sub_ids in zip(x_per_graph, sub_ids_per_graph):
            samples.extend(unbatch_node_feature_mat(x_graph, sub_ids))
        # NOTE(review): assuming each sample is (node_num, time, feat), this
        # yields (samples, time, node_num * feat), which implies self.rnn was
        # built with batch_first=True — confirm.
        seq = torch.stack(samples, dim=0).transpose(1, 2).flatten(2, 3)
        n_samples = seq.shape[0]

        if not hasattr(data, 'node_hidden'):
            # NOTE(review): a single zero tensor suits a unidirectional
            # GRU/RNN; an LSTM or bidirectional RNN would need a different
            # initial state — confirm.
            h0 = data.x.new_zeros(self.num_layers, n_samples, self.rnn.hidden_size)
            data = replace_graph(data, node_hidden=h0)
        rnn_out, hidden_next = self.rnn(seq, data.node_hidden)
        decoded = self.decoder(rnn_out)
        # NOTE(review): requires decoded to hold exactly
        # n_samples * node_num * output_dim elements (presumably one time
        # step per call) — confirm.
        prediction = decoded.reshape(n_samples * self.node_num, self.output_dim)
        prediction = prediction + data.x[:, -1, -self.output_dim:]
        return replace_graph(data, x=prediction, node_hidden=hidden_next)
Beispiel #3
0
 def unpadding_edge_output(output, data):
     """Split a batched edge-level output per graph and trim each piece.

     Each edge is attributed to the graph of its source node, i.e.
     ``data.batch`` looked up at ``data.edge_index[0, :]``. The output is split
     with ``unbatch_node_feature_mat``. Then, for every ``i`` in
     ``range(data.tlens.shape[0])``, piece ``i`` is truncated along axis 1 to
     its first ``data.tlens[i]`` entries (presumably removing time padding).

     Args:
         output: batched edge-level tensor (at least 3-D for the slicing).
         data: batched graph object exposing ``batch``, ``edge_index`` and
             ``tlens``.

     Returns:
         The list produced by ``unbatch_node_feature_mat``, modified in place.
         An ``IndexError`` is raised if it has fewer entries than ``data.tlens``.
     """
     src_nodes = data.edge_index[0, :]
     per_graph = unbatch_node_feature_mat(output, data.batch[src_nodes])
     tlens = data.tlens
     for graph_idx in range(tlens.shape[0]):
         valid_len = tlens[graph_idx]
         per_graph[graph_idx] = per_graph[graph_idx][:, :valid_len, :]
     return per_graph
Beispiel #4
0
 def unpadding_output(output, data):
     """Split a batched node-level output per graph and trim each piece.

     The output is split by ``data.batch`` using ``unbatch_node_feature_mat``.
     Then, for every ``i`` in ``range(data.tlens.shape[0])``, piece ``i`` is
     truncated along axis 1 to its first ``data.tlens[i]`` entries (presumably
     removing time padding).

     Args:
         output: batched node-level tensor (at least 3-D for the slicing).
         data: batched graph object exposing ``batch`` and ``tlens``.

     Returns:
         The list produced by ``unbatch_node_feature_mat``, modified in place.
         An ``IndexError`` is raised if it has fewer entries than ``data.tlens``.
     """
     segments = unbatch_node_feature_mat(output, data.batch)
     lengths = data.tlens
     for idx in range(lengths.shape[0]):
         keep = lengths[idx]
         segments[idx] = segments[idx][:, :keep, :]
     return segments