def forward(self, t, z, sysP, wgrad=True):
    """Evaluate the vector field induced by the learned Hamiltonian at ``(t, z)``.

    The Hamiltonian is ``self.compute_H(z, sysP)``; the time argument is ignored
    by it. The result is whatever ``HamiltonianDynamics(...)(t, z)`` returns.

    Args:
        t: Time(s) forwarded to the dynamics object (unused by the Hamiltonian).
        z: State(s) at which to evaluate the dynamics.
        sysP: System parameters passed through to ``self.compute_H``.
        wgrad: Forwarded as ``HamiltonianDynamics(..., wgrad=wgrad)``;
            presumably controls whether a differentiable graph is kept — confirm.
    """
    def learned_hamiltonian(_time, state):
        # Time-independent: only the state and system parameters matter.
        return self.compute_H(state, sysP)

    vector_field = HamiltonianDynamics(learned_hamiltonian, wgrad=wgrad)
    return vector_field(t, z)
def _get_dynamics(self, sys_params):
    """Return ground-truth Kepler dynamics for the given system parameters.

    Builds a time-independent Hamiltonian ``KeplerH(z, *sys_params)`` and wraps
    it in ``HamiltonianDynamics`` with ``wgrad=False``.

    Args:
        sys_params: Unpacked positionally into ``KeplerH`` after the state.
            NOTE(review): this uses ``*sys_params`` while other call sites pass
            ``sys_params[..., 0].squeeze(-1)``; confirm which layout callers use.
    """
    def kepler_hamiltonian(_time, z):
        return KeplerH(z, *sys_params)

    return HamiltonianDynamics(kepler_hamiltonian, wgrad=False)
def forward(self, data):
    """Roll out the model on a batch and compute the training loss.

    Args:
        data: Nested tuple ``((z0, sys_params, ts), true_zs)`` holding the
            initial states, system parameters, evaluation times, and the
            target trajectory.

    Returns:
        AttrDict with ``prediction`` (output of ``self._rollout_model``),
        ``mse`` and ``loss`` (mean squared error between prediction and
        ``true_zs``; ``loss`` is what the model is trained on) and ``reports``.
        When ``self.debug`` is set it also holds ``mse_V`` and ``mse_dyn``
        (see ``_debug_metrics``), which are added to ``reports`` as well.

    Raises:
        ValueError: if ``self.debug`` is set and ``self.task`` is neither
            ``"spring"`` nor ``"nbody"``.
    """
    o = AttrDict()
    (z0, sys_params, ts), true_zs = data
    pred_zs = self._rollout_model(z0, ts, sys_params)
    mse = (pred_zs - true_zs).pow(2).mean()

    reports = {"mse": mse}
    if self.debug:
        mse_V, mse_dyn = self._debug_metrics(z0, sys_params, ts)
        # Only set when computed, so a non-debug run never touches unbound names.
        o.mse_dyn = mse_dyn
        o.mse_V = mse_V
        reports["mse_V"] = mse_V
        reports["mse_dyn"] = mse_dyn

    o.prediction = pred_zs
    o.mse = mse
    o.loss = mse  # loss wrt which we train the model
    o.reports = AttrDict(reports)
    return o


def _debug_metrics(self, z0, sys_params, ts):
    """Compare the predictor against the analytic system for ``self.task``.

    Returns:
        ``(mse_V, mse_dyn)``:
        - ``mse_V`` is the mean squared error between
          ``self.predictor.compute_V`` and the analytic potential
          (``SpringV`` / ``KeplerV``) at the positions in ``z0``.
        - ``mse_dyn`` is the mean squared error between the predictor's
          dynamics and ``HamiltonianDynamics`` of the analytic Hamiltonian
          (``SpringH`` / ``KeplerH``), both evaluated at ``(ts, z0)``.

    Raises:
        ValueError: if ``self.task`` is not ``"spring"`` or ``"nbody"``.
    """
    task = self.task
    if task not in ("spring", "nbody"):
        raise ValueError(
            f"debug metrics are only implemented for tasks 'spring' and "
            f"'nbody', got {task!r}"
        )

    with torch.no_grad():
        m = sys_params[..., 0]  # assume the first component encodes masses
        D = z0.shape[-1]  # number of ODE dims, 2*num_particles*space_dim
        # The first half of the state holds positions; regroup them per m entry.
        q = z0[:, :D // 2].reshape(*m.shape, -1)

        if task == "spring":
            k = sys_params[..., 1]  # presumably spring constants — confirm

            def true_V(positions):
                return SpringV(positions, k)

            def true_H(t, z):
                return SpringH(z, m.squeeze(-1), k.squeeze(-1))
        else:  # "nbody"
            def true_V(positions):
                return KeplerV(positions, m)

            def true_H(t, z):
                return KeplerH(z, m.squeeze(-1))

        V_pred = self.predictor.compute_V((q, sys_params))
        mse_V = (V_pred - true_V(q)).pow(2).mean()

        dyn_tz_pred = self.predictor(ts, z0, sys_params)
        dyn_tz_true = HamiltonianDynamics(true_H, wgrad=False)(ts, z0)
        mse_dyn = (dyn_tz_true - dyn_tz_pred).pow(2).mean()

    return mse_V, mse_dyn