コード例 #1
0
 def forward(self, t, z, sysP, wgrad=True):
     """Evaluate the model's Hamiltonian dynamics at time ``t`` and state ``z``.

     The Hamiltonian handed to ``HamiltonianDynamics`` is ``self.compute_H``
     evaluated with the fixed system parameters ``sysP``; it ignores time.

     Args:
         t: time argument forwarded to the dynamics object.
         z: state forwarded to the dynamics object.
         sysP: system parameters passed through to ``self.compute_H``.
         wgrad: forwarded unchanged to ``HamiltonianDynamics``.

     Returns:
         Whatever ``HamiltonianDynamics(...)(t, z)`` returns (presumably the
         time derivative dz/dt -- confirm against ``HamiltonianDynamics``).
     """
     def hamiltonian(t, z):
         # Time-independent: only the state enters the energy.
         return self.compute_H(z, sysP)

     dyn = HamiltonianDynamics(hamiltonian, wgrad=wgrad)
     return dyn(t, z)
コード例 #2
0
 def _get_dynamics(self, sys_params):
     """Return ground-truth Kepler ``HamiltonianDynamics`` for ``sys_params``.

     The Hamiltonian is ``KeplerH(z, *sys_params)``: ``sys_params`` is
     unpacked into positional arguments, and time is ignored. The dynamics
     object is built with ``wgrad=False``.
     """
     def kepler_hamiltonian(t, z):
         # NOTE(review): ``*sys_params`` iterates over the leading dimension
         # of ``sys_params``; confirm this matches KeplerH's signature.
         return KeplerH(z, *sys_params)

     return HamiltonianDynamics(kepler_hamiltonian, wgrad=False)
コード例 #3
0
    def forward(self, data):
        """Roll out the model from ``z0`` and compute the training loss.

        Args:
            data: nested tuple ``((z0, sys_params, ts), true_zs)`` holding the
                initial state, system parameters, time points and the target
                trajectory.

        Returns:
            AttrDict with ``prediction`` (output of ``self._rollout_model``),
            ``mse`` and ``loss`` (the same mean squared trajectory error; the
            loss used for training) and ``reports``. When ``self.debug`` is
            set, ``mse_dyn`` and ``mse_V`` are also attached and included in
            ``reports``.

        Raises:
            ValueError: if ``self.debug`` is set and ``self.task`` is neither
                ``"spring"`` nor ``"nbody"`` (no analytic ground truth).
        """
        o = AttrDict()

        (z0, sys_params, ts), true_zs = data

        pred_zs = self._rollout_model(z0, ts, sys_params)
        mse = (pred_zs - true_zs).pow(2).mean()

        if self.debug:
            mse_V, mse_dyn = self._debug_metrics(z0, sys_params, ts)
            o.mse_dyn = mse_dyn
            o.mse_V = mse_V

        o.prediction = pred_zs
        o.mse = mse
        o.loss = mse  # loss wrt which we train the model

        if self.debug:
            o.reports = AttrDict({
                "mse": o.mse,
                "mse_V": o.mse_V,
                "mse_dyn": o.mse_dyn
            })
        else:
            o.reports = AttrDict({"mse": o.mse})

        return o

    def _debug_metrics(self, z0, sys_params, ts):
        """Compare the learned potential and dynamics with the analytic ones.

        Only the ``"spring"`` (SpringV / SpringH) and ``"nbody"``
        (KeplerV / KeplerH) tasks have an analytic ground truth.

        Returns:
            ``(mse_V, mse_dyn)``: mean squared error between
            ``self.predictor.compute_V`` and the true potential at the
            positions of ``z0``, and between ``self.predictor(ts, z0, ...)``
            and the true ``HamiltonianDynamics`` evaluated at ``z0``.

        Raises:
            ValueError: for any other ``self.task``.
        """
        if self.task not in ("spring", "nbody"):
            raise ValueError(
                "debug metrics are only available for the 'spring' and "
                f"'nbody' tasks, got task={self.task!r}")

        # TODO: this debug pass re-evaluates the predictor and may be
        # inefficient.
        with torch.no_grad():
            m = sys_params[..., 0]  # assume the first component encodes masses
            D = z0.shape[-1]  # of ODE dims, 2*num_particles*space_dim
            # First half of the state holds positions, regrouped per m entry.
            q = z0[:, :D // 2].reshape(*m.shape, -1)

            # Potential energy at the initial positions.
            V_pred = self.predictor.compute_V((q, sys_params))
            if self.task == "spring":
                V_true = SpringV(q, sys_params[..., 1])

                def true_H(t, z):
                    return SpringH(z, sys_params[..., 0].squeeze(-1),
                                   sys_params[..., 1].squeeze(-1))
            else:  # "nbody"
                V_true = KeplerV(q, m)

                def true_H(t, z):
                    return KeplerH(z, sys_params[..., 0].squeeze(-1))
            mse_V = (V_pred - V_true).pow(2).mean()

            # Dynamics evaluated at the initial state.
            dyn_tz_pred = self.predictor(ts, z0, sys_params)
            dyn_tz_true = HamiltonianDynamics(true_H, wgrad=False)(ts, z0)
            mse_dyn = (dyn_tz_true - dyn_tz_pred).pow(2).mean()

        return mse_V, mse_dyn