def _test_multi_step(self, predictor, fx_train_0, fx_test_0, g_td, momentum):
    """Check that predicting on a grid of times matches per-time predictions.

    A single call with a ``(2, 1, 3)`` grid holding the times ``0..5`` is
    compared against six scalar-time calls whose results are stacked back into
    the same grid shape. When ``momentum`` is not ``None``, prediction from a
    ``predict.ODEState`` is also checked to compose. Evolving to an
    intermediate time and then for the remaining duration must reproduce the
    grid prediction at the final time.

    Args:
      predictor: callable invoked as ``predictor(t, fx_train_0, fx_test_0,
        g_td)`` and unpacked into a ``(fx_train, fx_test)`` pair. When given a
        state as its second argument (with ``None`` third), its result is read
        through ``.fx_train`` / ``.fx_test`` attributes.
      fx_train_0: initial train-set outputs.
      fx_test_0: initial test-set outputs.
      g_td: forwarded unchanged as the predictor's fourth argument
        (presumably the test/train kernel -- confirm against callers).
      momentum: only compared against ``None`` here; it gates the
        state-based checks.
    """
    times = np.arange(6).reshape((2, 1, 3))

    # One call covering the whole time grid.
    train_grid, test_grid = predictor(times, fx_train_0, fx_test_0, g_td)

    # The same predictions, one scalar time at a time, in row-major order.
    per_time = [
        predictor(t, fx_train_0, fx_test_0, g_td) for t in times.ravel()
    ]
    train_steps, test_steps = zip(*per_time)

    def regrid(steps, reference):
        # Stack along a new leading axis, then fold that axis back into the
        # grid shape while keeping the trailing (non-time) dims of `reference`.
        return np.stack(steps).reshape(
            times.shape + reference.shape[times.ndim:])

    train_regridded = regrid(train_steps, train_grid)
    test_regridded = regrid(test_steps, test_grid)
    self.assertAllClose(train_regridded, train_grid)
    self.assertAllClose(test_regridded, test_grid)

    if momentum is None:
        return

    # State-based prediction must compose across consecutive time intervals.
    state_start = predict.ODEState(fx_train_0, fx_test_0)
    mid_index = (0, 0, 2)
    state_mid = predictor(times[mid_index], state_start, None, g_td)
    self.assertAllClose(train_grid[mid_index], state_mid.fx_train)
    self.assertAllClose(test_grid[mid_index], state_mid.fx_test)

    last_index = (-1,) * times.ndim
    remaining = times[last_index] - times[mid_index]
    state_last = predictor(remaining, state_mid, None, g_td)
    self.assertAllClose(train_grid[last_index], state_last.fx_train)
    self.assertAllClose(test_grid[last_index], state_last.fx_test)
def _test_zero_time(self, predictor, fx_train_0, fx_test_0, g_td, momentum):
    """Check that predicting at ``t = 0`` returns the initial outputs unchanged.

    The following call forms are checked:

    * the ``(fx_train_0, fx_test_0)`` form, unpacked into a pair;
    * the train-only form (``fx_test_0=None``), which yields a single value;
    * when ``momentum`` is not ``None``, the ``predict.ODEState`` forms, both a
      full state and a train-only state.

    Args:
      predictor: callable invoked as ``predictor(t, fx_train_0, fx_test_0,
        g_td)``. When given a state as its second argument (with ``None``
        third), its result is read through ``.fx_train`` / ``.fx_test``.
      fx_train_0: initial train-set outputs.
      fx_test_0: initial test-set outputs.
      g_td: forwarded unchanged as the predictor's fourth argument
        (presumably the test/train kernel -- confirm against callers).
      momentum: only compared against ``None`` here; it gates the
        state-based checks.
    """
    fx_train_t0, fx_test_t0 = predictor(0.0, fx_train_0, fx_test_0, g_td)
    self.assertAllClose(fx_train_0, fx_train_t0)
    self.assertAllClose(fx_test_0, fx_test_t0)

    fx_train_only_t0 = predictor(0.0, fx_train_0, None, g_td)
    self.assertAllClose(fx_train_0, fx_train_only_t0)

    if momentum is not None:
        # Test state-based prediction with both train and test outputs.
        state_0 = predict.ODEState(fx_train_0, fx_test_0)
        state_t0 = predictor(0.0, state_0, None, g_td)
        self.assertAllClose(state_0.fx_train, state_t0.fx_train)
        self.assertAllClose(state_0.fx_test, state_t0.fx_test)

        # Train-only state. The train-only state itself must be passed to
        # `predictor`. Passing `state_0` (as before) re-tested the full state
        # and never exercised the train-only path.
        # NOTE(review): assumes `predictor` accepts an ODEState whose fx_test
        # is left at its default -- confirm if this starts failing.
        state_train_only_0 = predict.ODEState(fx_train_0)
        state_train_only_t0 = predictor(0.0, state_train_only_0, None, g_td)
        self.assertAllClose(state_train_only_0.fx_train,
                            state_train_only_t0.fx_train)