Exemplo n.º 1
0
    def _test_multi_step(self, predictor, fx_train_0, fx_test_0, g_td,
                         momentum):
        # Test multi-time prediction
        ts = np.arange(6).reshape((2, 1, 3))

        fx_train_single, fx_test_single = predictor(ts, fx_train_0, fx_test_0,
                                                    g_td)

        fx_train_concat, fx_test_concat = [], []
        for t in ts.ravel():
            fx_train_concat_t, fx_test_concat_t = predictor(
                t, fx_train_0, fx_test_0, g_td)
            fx_train_concat += [fx_train_concat_t]
            fx_test_concat += [fx_test_concat_t]
        fx_train_concat = np.stack(fx_train_concat).reshape(
            ts.shape + fx_train_single.shape[ts.ndim:])
        fx_test_concat = np.stack(fx_test_concat).reshape(
            ts.shape + fx_test_single.shape[ts.ndim:])

        self.assertAllClose(fx_train_concat, fx_train_single)
        self.assertAllClose(fx_test_concat, fx_test_single)

        if momentum is not None:
            state_0 = predict.ODEState(fx_train_0, fx_test_0)
            t_1 = (0, 0, 2)
            state_1 = predictor(ts[t_1], state_0, None, g_td)
            self.assertAllClose(fx_train_single[t_1], state_1.fx_train)
            self.assertAllClose(fx_test_single[t_1], state_1.fx_test)

            t_max = (-1, ) * ts.ndim
            state_max = predictor(ts[t_max] - ts[t_1], state_1, None, g_td)
            self.assertAllClose(fx_train_single[t_max], state_max.fx_train)
            self.assertAllClose(fx_test_single[t_max], state_max.fx_test)
  def _test_zero_time(self, predictor, fx_train_0, fx_test_0, g_td, momentum):
    fx_train_t0, fx_test_t0 = predictor(0.0, fx_train_0, fx_test_0, g_td)
    self.assertAllClose(fx_train_0, fx_train_t0)
    self.assertAllClose(fx_test_0, fx_test_t0)
    fx_train_only_t0 = predictor(0.0, fx_train_0, None, g_td)
    self.assertAllClose(fx_train_0, fx_train_only_t0)

    if momentum is not None:
      # Test state-based prediction
      state_0 = predict.ODEState(fx_train_0, fx_test_0)
      state_t0 = predictor(0.0, state_0, None, g_td)
      self.assertAllClose(state_0.fx_train, state_t0.fx_train)
      self.assertAllClose(state_0.fx_test, state_t0.fx_test)

      state_train_only_0 = predict.ODEState(fx_train_0)
      state_train_only_t0 = predictor(0.0, state_0, None, g_td)
      self.assertAllClose(state_train_only_0.fx_train,
                          state_train_only_t0.fx_train)