Example #1
    def forward_onestep_linear(self, data, return_features=False):
        data_input = replace_graph(data, x=data.x[:, -1, :])
        if self.node_meta_dim > 0:
            node_meta = data.node_meta  # should be N x 2 (x, y)
        else:
            node_meta = None

        op_calc = self._calculate_op(
            data_input, node_meta,
            self.optype)  # N x self.parameterized_order_num x input_dim
        op_calc = op_calc['op_agg_msgs']
        batch_op_calc = op_calc.view(-1, self.node_num,
                                     self.parameterized_order_num,
                                     self.input_dim)
        delta_input = torch.sum(batch_op_calc * self.reg_param.unsqueeze(0),
                                dim=2)  # B x N x F
        delta_input = delta_input.view(-1, self.input_dim)
        next_x = data_input.x + delta_input
        output_graph = replace_graph(data_input, x=next_x)

        if return_features:
            DG_output_data = replace_graph(
                data_input,
                x=data_input.x,
                edge_attr=None,
                gradient_weight=self.reg_param.detach(),
                laplacian_weight=None)
            return output_graph, DG_output_data
        else:
            return output_graph
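The helper replace_graph appears in every example but is never shown. Below is a minimal sketch of what it plausibly does, assuming torch_geometric Data semantics (copy the graph, overwrite the given attributes). This is an assumption for readability, not the repository's actual implementation.

import torch

def replace_graph(graph, **kwargs):
    # Sketch (assumption): copy a torch_geometric Data object and overwrite
    # the given attributes, leaving the input graph untouched.
    new_graph = graph.clone()
    for key, value in kwargs.items():
        new_graph[key] = value  # Data supports item-style attribute access
    return new_graph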
Example #2
    def forward_onestep(self, data, return_features=False):
        if self.predict_model == 'GN':
            if self.is_recurrent:
                if not hasattr(data, 'node_hidden'):
                    data = replace_graph(
                        data,
                        node_hidden=data.x.new_zeros(
                            self.gn_layer_num,
                            self.gn_net.gn_layers[0][1].net.num_layers,
                            data.x.shape[0],
                            self.gn_net.gn_layers[0][1].net.latent_dim))
                if not hasattr(data, 'edge_hidden'):
                    data = replace_graph(
                        data,
                        edge_hidden=data.x.new_zeros(
                            self.gn_layer_num,
                            self.gn_net.gn_layers[0][0].net.num_layers,
                            data.edge_index.shape[1],
                            self.gn_net.gn_layers[0][0].net.latent_dim))

        # One-step prediction:
        # reads data of shape (B, T, N, F), returns (B, 1, N, output_dim).

        length = data.x.shape[1]  # T

        # 1. Derivative Cell
        DC_output = self.derivative_cell(data, length=length)  # dictionary

        # 2. NN_PDE
        PDE_params = self.NN_PDE(DC_output)  # (N,F) or (E,F)

        # 3. Derivative Graph
        DG_output = self.build_DG(data, DC_output,
                                  PDE_params)  # torch_geometric.Data

        # Detached copy of the derivative graph, for the optional feature return
        DG_output_data = DG_output.clone().apply(lambda x: x.detach())

        # 4. Prediction
        if self.predict_model == 'GN':
            output_graph = self.gn_net(DG_output)  # torch_geometric.Data
        else:
            gradient = DC_output['gradient']
            laplacian = DC_output['laplacian']
            # Aggregate edge-wise gradients onto receiver nodes, then combine
            # with the node-wise Laplacian to form the state update
            gradient_out = torch.zeros_like(laplacian)
            gradient_out = scatter_add(gradient,
                                       DC_output['edge_index'][1, :],
                                       dim=0,
                                       out=gradient_out)
            dx = gradient_out + laplacian
            output_graph = replace_graph(DG_output, x=dx)

        # 5. Outputs: residual connection from the last observed frame
        output_graph.x = output_graph.x + data.x[:, -1, -self.output_dim:]

        if return_features:
            return output_graph, DG_output_data
        else:
            return output_graph
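The one-step interface above suggests an autoregressive rollout: predict one frame, slide it into the history window, repeat. The following is a hedged sketch under the assumptions that data.x is laid out (N, T, F) and that output_dim equals F, reusing the replace_graph sketch above; rollout is a hypothetical helper, not repository code.

import torch

def rollout(model, data, steps):
    # Sketch (assumptions): data.x is (N, T, F) and output_dim == F, so each
    # prediction can be appended to the history window.
    predictions = []
    for _ in range(steps):
        out_graph = model.forward_onestep(data)
        predictions.append(out_graph.x)  # (N, F) one-step prediction
        # Slide the history window: drop the oldest frame, append the newest
        new_window = torch.cat((data.x[:, 1:, :], out_graph.x.unsqueeze(1)),
                               dim=1)
        data = replace_graph(out_graph, x=new_window)
    return torch.stack(predictions, dim=1)  # (N, steps, F)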
Example #3
    def forward_onestep(self, data):
        # Lazily create the GRU hidden state: (num_layers, N, hidden_dim)
        if not hasattr(data, 'node_hidden'):
            data = replace_graph(
                data,
                node_hidden=data.x.new_zeros(self.num_layers, data.x.shape[0],
                                             self.hidden_dim))
        node_hidden_output, node_hidden_next = self.rnn(data.x,
                                                        data.node_hidden)
        # Decode, then add a residual from the last observed frame
        node_output = (self.decoder(node_hidden_output)
                       + data.x[:, -1:, -self.output_dim:])
        node_output = node_output.squeeze(1)
        output_graph = replace_graph(data,
                                     x=node_output,
                                     node_hidden=node_hidden_next)
        return output_graph
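For reference, here is a minimal sketch of the modules this forward pass appears to assume: a batch_first GRU whose hidden state matches the (num_layers, N, hidden_dim) zeros created above, plus a linear decoder. The class name and layer sizes are illustrative assumptions.

import torch.nn as nn

class NodeGRUBaseline(nn.Module):
    # Sketch (assumption): per-node GRU baseline consistent with the
    # forward_onestep above. Nodes act as the batch dimension, so data.x
    # is consumed as (N, T, F).
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.rnn = nn.GRU(input_dim, hidden_dim,
                          num_layers=num_layers, batch_first=True)
        self.decoder = nn.Linear(hidden_dim, output_dim)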
Example #4
    def forward_onestep(self, data):
        # Unbatch node features per graph, then re-stack into a dense tensor
        input_features_list = unbatch_node_feature(data, 'x', data.batch)
        graph_batch_list = unbatch_node_feature(data, 'graph_batch', data.batch)
        input_features = list(itertools.chain.from_iterable(
            unbatch_node_feature_mat(input_features_i, graph_batch_i)
            for input_features_i, graph_batch_i in
            zip(input_features_list, graph_batch_list)))
        input_features = torch.stack(input_features, dim=0)  # (B, N, T, F)
        input_features = input_features.transpose(1, 2).flatten(2, 3)  # (B, T, N*F)

        # Lazily create the RNN hidden state: (num_layers, B, hidden_size)
        if not hasattr(data, 'node_hidden'):
            data = replace_graph(
                data,
                node_hidden=data.x.new_zeros(self.num_layers,
                                             input_features.shape[0],
                                             self.rnn.hidden_size))
        node_hidden_output, node_hidden_next = self.rnn(input_features,
                                                        data.node_hidden)
        node_hidden_output = self.decoder(node_hidden_output)
        # Back to node-level layout, then a residual from the last observed frame
        node_hidden_output = node_hidden_output.reshape(
            input_features.shape[0], self.node_num,
            self.output_dim).flatten(0, 1)
        node_output = node_hidden_output + data.x[:, -1, -self.output_dim:]
        output_graph = replace_graph(data,
                                     x=node_output,
                                     node_hidden=node_hidden_next)
        return output_graph
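unbatch_node_feature and unbatch_node_feature_mat are repository helpers that are not shown here. Behaviour consistent with the calls above would be to split a node-level tensor into per-graph chunks along dimension 0 using a batch-assignment vector; the sketch below is an assumption, not the actual implementation.

import torch

def unbatch_node_feature_mat(feature, batch):
    # Sketch (assumption): feature is an (N, ...) node-level tensor and
    # batch is an (N,) vector of graph indices; split feature per graph.
    counts = torch.bincount(batch).tolist()
    return list(torch.split(feature, counts, dim=0))

def unbatch_node_feature(data, key, batch):
    # Sketch (assumption): same split, applied to a named Data attribute.
    return unbatch_node_feature_mat(data[key], batch)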
Example #5
    def forward_onestep(self, data):
        # Unbatch node features per graph, then re-stack into a dense tensor
        input_features_list = unbatch_node_feature(data, 'x', data.batch)
        graph_batch_list = unbatch_node_feature(data, 'graph_batch', data.batch)
        input_features = list(itertools.chain.from_iterable(
            unbatch_node_feature_mat(input_features_i, graph_batch_i)
            for input_features_i, graph_batch_i in
            zip(input_features_list, graph_batch_list)))
        input_features = torch.stack(input_features, dim=0)
        # Flatten each sample's full history into a single feature vector
        input_features = input_features.reshape(input_features.shape[0], -1)
        out = self.var_layer(input_features)
        # Back to node-level layout, then a residual from the last observed frame
        out = out.reshape(input_features.shape[0], self.node_num,
                          self.output_dim).flatten(0, 1)
        out = out + data.x[:, -1, -self.output_dim:]
        out_graph = replace_graph(data, x=out)
        return out_graph
Example #6
    def forward_onestep_singlemlp(self, data, return_features=False):
        data_input = replace_graph(data, x=data.x[:, -1, :])  # N x F
        if self.node_meta_dim > 0:
            node_meta = data.node_meta  # should be N x 2 (x, y)
        else:
            node_meta = None

        op_calc = self._calculate_op(
            data_input, node_meta,
            self.optype)  # N x self.parameterized_order_num x input_dim
        op_calc = op_calc['op_agg_msgs']  # N x 3 x F

        net_input = torch.cat((data_input.x.unsqueeze(1), op_calc),
                              dim=1).view(-1, 4 * self.input_dim)  # N x (4*F)
        delta_input = self.prediction_net(net_input)  # N x F
        next_x = data_input.x + delta_input
        output_graph = replace_graph(data_input, x=next_x)
        if return_features:
            raise NotImplementedError()
        else:
            return output_graph
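Given the N x (4*F) input assembled above, prediction_net is presumably an MLP from 4*F features (the current state plus the three aggregated operator channels) down to an F-dimensional per-node update. A hedged sketch; make_prediction_net and the hidden width are illustrative assumptions.

import torch.nn as nn

def make_prediction_net(input_dim, hidden_dim=128):
    # Sketch (assumption): maps the concatenated [state, op channels]
    # vector (4 * F) to a per-node state update (F).
    return nn.Sequential(
        nn.Linear(4 * input_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, input_dim),
    )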
Example #7
    def forward(self, graph):
        # 1st dim is layer rank
        node_hidden_list = graph.node_hidden
        edge_hidden_list = graph.edge_hidden
        updated_node_hidden_list = []
        updated_edge_hidden_list = []
        assert len(node_hidden_list) == self.gn_layer_num

        graph_li = replace_graph(graph)

        for li in range(self.gn_layer_num):
            graph_li = replace_graph(graph_li,
                                     node_hidden=node_hidden_list[li],
                                     edge_hidden=edge_hidden_list[li])
            graph_li = self.gn_layers[li](graph_li)
            updated_node_hidden_list.append(graph_li.node_hidden)
            updated_edge_hidden_list.append(graph_li.edge_hidden)

        graph = replace_graph(
            graph_li,
            node_hidden=torch.stack(updated_node_hidden_list, dim=0),
            edge_hidden=torch.stack(updated_edge_hidden_list, dim=0))
        return graph
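A hedged usage sketch for this stacked recurrent GN: zero-initialise the per-layer hidden stacks (as in Example #2) and thread them through repeated forward calls. run_recurrent_gn and its arguments are hypothetical names; the shapes follow the initialisation shown in Example #2.

def run_recurrent_gn(gn_net, graph, gn_layer_num, rnn_layer_num,
                     latent_dim, num_steps):
    # Sketch (assumptions): hidden stacks are indexed first by GN layer,
    # then by RNN layer, matching the forward() above.
    graph = replace_graph(
        graph,
        node_hidden=graph.x.new_zeros(gn_layer_num, rnn_layer_num,
                                      graph.x.shape[0], latent_dim),
        edge_hidden=graph.x.new_zeros(gn_layer_num, rnn_layer_num,
                                      graph.edge_index.shape[1], latent_dim))
    for _ in range(num_steps):
        graph = gn_net(graph)  # each call consumes and re-stacks the hiddens
    return graph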
Example #8
    def forward_onestep_gn_rgn(self, data, return_features=False):
        if self.prediction_net_is_recurrent and self.prediction_model == 'RGN':
            if not hasattr(data, 'edge_hidden_prediction_net'):
                data = replace_graph(
                    data,
                    edge_hidden_prediction_net=data.x.new_zeros(
                        self.prediction_net_layer_num,
                        self.prediction_net.gn_layers[0][0].net.num_layers,
                        data.edge_index.shape[1],
                        self.prediction_net.gn_layers[0][0].net.latent_dim))
            if not hasattr(data, 'node_hidden_prediction_net'):
                data = replace_graph(
                    data,
                    node_hidden_prediction_net=data.x.new_zeros(
                        self.prediction_net_layer_num,
                        self.prediction_net.gn_layers[0][1].net.num_layers,
                        data.x.shape[0],
                        self.prediction_net.gn_layers[0][1].net.latent_dim))

        data_input = replace_graph(data, x=data.x[:, -1, :])
        if self.node_meta_dim > 0:
            node_meta = data.node_meta  # should be N x 2 (x, y)
        else:
            node_meta = None
        op_calc = self._calculate_op(data_input, node_meta, self.optype)

        if self.optype == 'standard':
            new_x_input = torch.cat(
                (data_input.x.unsqueeze(1), op_calc['op_agg_msgs'][:, 2:3, :]),
                dim=1).flatten(-2, -1)
            new_ea_input = op_calc['op_msgs'][:, 0:1, :].flatten(-2, -1)
        elif self.optype == 'trimesh':
            new_x_input = torch.cat(
                (data_input.x.unsqueeze(1), op_calc['op_agg_msgs']),
                dim=1).flatten(-2, -1)
            new_ea_input = None
        else:
            raise NotImplementedError()
        prediction_input_graph = replace_graph(data_input,
                                               x=new_x_input,
                                               edge_attr=new_ea_input)
        if self.prediction_net_is_recurrent:
            prediction_input_graph = replace_graph(
                prediction_input_graph,
                node_hidden=prediction_input_graph.node_hidden_prediction_net,
                edge_hidden=prediction_input_graph.edge_hidden_prediction_net)
            model_prediction_out_graph = self.prediction_net(
                prediction_input_graph)
            model_prediction_out_graph = replace_graph(
                model_prediction_out_graph,
                node_hidden_prediction_net=model_prediction_out_graph.node_hidden,
                edge_hidden_prediction_net=model_prediction_out_graph.edge_hidden)
        else:
            model_prediction_out_graph = self.prediction_net(
                prediction_input_graph)
        output_graph = replace_graph(model_prediction_out_graph,
                                     x=(data_input.x[..., -self.output_dim:] +
                                        model_prediction_out_graph.x))

        if return_features:
            raise NotImplementedError()
        else:
            return output_graph
Example #9
    def forward_onestep(self, data):
        # Flatten the (N, T, F) history, predict an update, add a residual
        out = self.mlp(data.x.flatten(1, 2))
        out = out + data.x[:, -1, -self.output_dim:]
        out_graph = replace_graph(data, x=out)
        return out_graph
Example #10
    def build_DG(self, data, DC_output, PDE_params):
        """
        Module for generating Derivative Graph.
        It builds a new graph having derivatives and PDE parameters as node-, edge-attributes.
        For instance, if a given PDE is Diffusion Eqn,
        this module will concat node-wise attributes with PDE_params (diffusion-coefficient).

        Input:
            DC_output:
                - Output of derivative_cell()
                - dictionary and key: du_dt, gradient, laplacian, curr
            PDE_params:
                - Output of NN_PDE()
                - Depending on "mode", it may be node-wise or edge-wise features.
            mode: (self)
                - This should be same as "mode" in NN_PDE
        Output:
            output_graph:
                - output_graph.x : curr, laplacian, du_dt
                - output_graph.edge_attr : gradient
                - Additionally, PDE_params will be concatenated properly.
        """

        curr = DC_output["curr"]  # (N, F)

        if self.nophysics_mode == 'nopad':
            if self.use_dist:
                output_graph = replace_graph(data,
                                             x=curr,
                                             edge_attr=data.edge_dist)
            else:
                output_graph = replace_graph(data, x=curr, edge_attr=None)
        else:
            du_dt = DC_output["du_dt"]  # (N, F)
            gradient = DC_output["gradient"]  # (E, F)
            laplacian = DC_output["laplacian"]  # (N, F)
            if self.mode == "diff":
                node_attr_to_cat = [curr]
                if self.use_laplacian:
                    node_attr_to_cat.append(laplacian)
                if self.use_time_grad:
                    node_attr_to_cat.append(du_dt)
                if self.use_pde_params:
                    node_attr_to_cat.append(PDE_params)

                edge_attr_to_cat = []
                if self.use_dist:
                    edge_attr_to_cat.append(data.edge_dist)
                if self.use_edge_grad:
                    edge_attr_to_cat.append(gradient)
            elif self.mode == "adv":
                node_attr_to_cat = [curr]
                if self.use_laplacian:
                    node_attr_to_cat.append(laplacian)
                if self.use_time_grad:
                    node_attr_to_cat.append(du_dt)

                edge_attr_to_cat = []
                if self.use_dist:
                    edge_attr_to_cat.append(data.edge_dist)
                if self.use_edge_grad:
                    edge_attr_to_cat.append(gradient)
                if self.use_pde_params:
                    edge_attr_to_cat.append(PDE_params)
            else:
                # TODO: other PDE modes are not supported yet
                raise NotImplementedError()
            node_attr = torch.cat(node_attr_to_cat, dim=-1)
            if len(edge_attr_to_cat) > 0:
                edge_attr = torch.cat(edge_attr_to_cat, dim=-1)
            else:
                edge_attr = None
            output_graph = replace_graph(
                data,
                x=node_attr,
                edge_attr=edge_attr,
                gradient_weight=DC_output['gradient_weight'],
                laplacian_weight=DC_output['laplacian_weight'])
        return output_graph
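To make the docstring concrete, here is a hedged sketch of driving build_DG in "diff" mode, mirroring steps 1-3 of Example #2. All sizes are illustrative, and model and data stand in for the objects used in Example #2; none of these names come from the repository.

import torch

# Illustrative sizes (assumptions)
num_nodes, num_edges, feat_dim = 100, 400, 3
DC_output = {
    'curr': torch.randn(num_nodes, feat_dim),       # (N, F)
    'du_dt': torch.randn(num_nodes, feat_dim),      # (N, F)
    'gradient': torch.randn(num_edges, feat_dim),   # (E, F)
    'laplacian': torch.randn(num_nodes, feat_dim),  # (N, F)
    'gradient_weight': None,
    'laplacian_weight': None,
}
PDE_params = torch.randn(num_nodes, 1)  # node-wise features in "diff" mode
dg = model.build_DG(data, DC_output, PDE_params)
# dg.x concatenates [curr, laplacian, du_dt, PDE_params] along the last
# dimension (subject to the use_* flags); dg.edge_attr concatenates
# [edge_dist, gradient] when distances and edge gradients are enabled.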