def _make_env(
    self,
    observation_space,
    action_space,
    vocab_size,
    predict_fn=None,
    reward_fn=None,
    done_fn=None,
    batch_size=None,
    max_trajectory_length=None,
):
    """Build a SerializedSequenceSimulatedEnvProblem around a mock model.

    Args:
        observation_space: Forwarded to the env as ``observation_space``.
        action_space: Forwarded to the env as ``action_space``.
        vocab_size: Forwarded to the env as ``vocab_size``.
        predict_fn: Optional object installed as the return value of the mock
            model constructor. When ``None`` the constructor returns a
            default ``MagicMock``.
        reward_fn: Forwarded to the env as ``reward_fn``.
        done_fn: Forwarded to the env as ``done_fn``.
        batch_size: Forwarded to the env as ``batch_size``.
        max_trajectory_length: Forwarded to the env as
            ``max_trajectory_length``. When ``None``, 3 is used.

    Returns:
        The constructed ``SerializedSequenceSimulatedEnvProblem``.
    """
    mock_model_fn = mock.MagicMock()
    if predict_fn is not None:
        mock_model_fn.return_value = predict_fn

    # The argument used to be silently ignored in favour of a hard-coded 3.
    # Keep 3 as the fallback so callers that omit it see no change.
    if max_trajectory_length is None:
        max_trajectory_length = 3

    return simulated_env_problem.SerializedSequenceSimulatedEnvProblem(
        model=mock_model_fn,
        reward_fn=reward_fn,
        done_fn=done_fn,
        vocab_size=vocab_size,
        max_trajectory_length=max_trajectory_length,
        batch_size=batch_size,
        observation_space=observation_space,
        action_space=action_space,
        reward_range=(-1, 1),
        discrete_rewards=False,
        history_stream=itertools.repeat(None),
        output_dir=None,
    )
# Example #2
# 0
    def _make_env(
        mock_restore_state,
        observation_space,
        action_space,
        max_trajectory_length,
        batch_size,
    ):
        """Build a simulated env problem whose mock model always outputs zeros.

        As a side effect, binds the gin parameter
        ``BoxSpaceSerializer.precision`` to 1.

        Args:
            mock_restore_state: Mock whose ``return_value.params`` is set to
                ``(None, None)``, standing in for (model_params, opt_state).
            observation_space: Space forwarded to the env; its ``shape``
                contributes to the length of the mocked model output.
            action_space: Space forwarded to the env; its ``shape``
                contributes to the length of the mocked model output.
            max_trajectory_length: Forwarded to the env and used as a
                multiplier for the length of the mocked model output.
            batch_size: Forwarded to the env and used to size the arrays
                returned by the stub reward and done functions.

        Returns:
            The constructed ``SerializedSequenceSimulatedEnvProblem``.
        """
        # (model_params, opt_state)
        mock_restore_state.return_value.params = (None, None)
        gin.bind_parameter("BoxSpaceSerializer.precision", 1)

        per_step_size = int(
            np.prod(observation_space.shape) + np.prod(action_space.shape)
        )
        seq_length = max_trajectory_length * per_step_size

        # Each call to the mocked model instance yields this same pair: an
        # array holding one sequence of `seq_length` single-zero entries, and
        # an empty tuple.
        zero_sequence = [[0.0] for _ in range(seq_length)]
        model_output = (np.array([zero_sequence]), ())
        model_factory = mock.MagicMock()
        model_factory.return_value.side_effect = itertools.repeat(model_output)

        def zero_rewards(_1, _2):
            return np.zeros(batch_size)

        def never_done(_1, _2):
            return np.full((batch_size,), False)

        env_kwargs = dict(
            model=model_factory,
            reward_fn=zero_rewards,
            done_fn=never_done,
            vocab_size=1,
            max_trajectory_length=max_trajectory_length,
            batch_size=batch_size,
            observation_space=observation_space,
            action_space=action_space,
            reward_range=(-1, 1),
            discrete_rewards=False,
            history_stream=itertools.repeat(None),
            output_dir=None,
        )
        return simulated_env_problem.SerializedSequenceSimulatedEnvProblem(
            **env_kwargs
        )
  def test_communicates_with_model(self, mock_restore_state):
    """Checks what the env feeds to the mocked model and what it returns.

    The mocked model emits a fixed stream of symbols, one per call. After
    each `step`, the test inspects the most recent model call to verify that
    the inputs contain the expected observation symbols and the action just
    taken, and checks the observation, reward and done flag from the env.

    Args:
      mock_restore_state: Mock whose `return_value.params` is set to
        `(None, None)`; presumably injected by a patch decorator outside
        this view.
    """
    gin.bind_parameter("BoxSpaceSerializer.precision", 1)
    vocab_size = 16

    # Symbols the mock model predicts, in call order. They are chosen so that
    # the first two observations are equal and the third one differs.
    obs1_symbols = [1, 1, 2, 2]
    obs2_symbols = [1, 1, 2, 2]
    obs3_symbols = [1, 2, 2, 1]
    symbols = obs1_symbols + obs2_symbols + obs3_symbols

    def make_prediction(symbol):
      one_hot = np.eye(vocab_size)[symbol]
      # Virtually deterministic: every symbol but the chosen one gets -100.
      log_probs = (1 - one_hot) * -100.0
      # (4 obs symbols + 1 action symbol) * 3 timesteps = 15.
      return np.tile(log_probs, (1, 15, 1))

    model_factory = mock.MagicMock()
    model = model_factory.return_value
    model.side_effect = (make_prediction(symbol) for symbol in symbols)

    def latest_model_inputs():
      # The unpacking requires the latest call to have had exactly one
      # positional argument.
      (model_inputs,), _ = model.call_args
      return model_inputs

    with backend.use_backend("numpy"):
      # (model_params, opt_state)
      mock_restore_state.return_value.params = (None, None)
      env = simulated_env_problem.SerializedSequenceSimulatedEnvProblem(
          model=model_factory,
          reward_fn=(lambda _1, _2: np.array([0.5])),
          done_fn=(lambda _1, _2: np.array([False])),
          vocab_size=vocab_size,
          max_trajectory_length=3,
          batch_size=1,
          observation_space=gym.spaces.Box(low=0, high=5, shape=(4,)),
          action_space=gym.spaces.Discrete(2),
          reward_range=(-1, 1),
          discrete_rewards=False,
          history_stream=itertools.repeat(None),
          output_dir=None,
      )

      obs1 = env.reset()
      # Result unused; this only verifies the shape of the latest model call.
      latest_model_inputs()

      # First step: take action 0.
      act1 = 0
      obs2, reward, done, _ = env.step(np.array([act1]))
      inputs = latest_model_inputs()
      self.assertEqual(inputs[0, 4], act1)
      np.testing.assert_array_equal(inputs[0, :4], obs1_symbols)
      np.testing.assert_array_equal(obs1, obs2)
      np.testing.assert_array_equal(reward, [0.5])
      np.testing.assert_array_equal(done, [False])

      # Second step: take action 1.
      act2 = 1
      obs3, reward, done, _ = env.step(np.array([act2]))
      inputs = latest_model_inputs()
      self.assertEqual(inputs[0, 9], act2)
      np.testing.assert_array_equal(inputs[0, 5:9], obs2_symbols)
      self.assertFalse(np.array_equal(obs2, obs3))
      np.testing.assert_array_equal(reward, [0.5])
      np.testing.assert_array_equal(done, [False])