def _make_env(
    self,
    observation_space,
    action_space,
    vocab_size,
    predict_fn=None,
    reward_fn=None,
    done_fn=None,
    batch_size=None,
    max_trajectory_length=None,
):
    """Builds a SerializedSequenceSimulatedEnvProblem around a mocked model.

    Args:
      observation_space: forwarded unchanged to the environment.
      action_space: forwarded unchanged to the environment.
      vocab_size: forwarded unchanged to the environment.
      predict_fn: optional object returned when the mocked model constructor
        is called. Its `initialize_once.return_value` is set to
        `(base.EMPTY_WEIGHTS, base.EMPTY_STATE)`, so it must expose an
        `initialize_once` attribute with an assignable `return_value`
        (e.g. a `mock.MagicMock`). When None, the constructor's return value
        is left as MagicMock's auto-created child mock.
      reward_fn: forwarded unchanged to the environment.
      done_fn: forwarded unchanged to the environment.
      batch_size: forwarded unchanged to the environment.
      max_trajectory_length: forwarded unchanged to the environment.

    Returns:
      A `simulated_env_problem.SerializedSequenceSimulatedEnvProblem`. It is
      configured with reward range (-1, 1), `discrete_rewards=False`, an
      endless stream of `None` histories and no output directory.
    """
    model_ctor = mock.MagicMock()
    if predict_fn is not None:
        # Calling the mocked constructor hands back `predict_fn` itself, so
        # configuring `predict_fn` configures the "constructed" model.
        model_ctor.return_value = predict_fn
        predict_fn.initialize_once.return_value = (
            base.EMPTY_WEIGHTS,
            base.EMPTY_STATE,
        )

    env_kwargs = dict(
        model=model_ctor,
        reward_fn=reward_fn,
        done_fn=done_fn,
        vocab_size=vocab_size,
        max_trajectory_length=max_trajectory_length,
        batch_size=batch_size,
        observation_space=observation_space,
        action_space=action_space,
        reward_range=(-1, 1),
        discrete_rewards=False,
        history_stream=itertools.repeat(None),
        output_dir=None,
    )
    return simulated_env_problem.SerializedSequenceSimulatedEnvProblem(
        **env_kwargs
    )
Ejemplo n.º 2
0
  def _make_env(
      mock_restore_state, observation_space, action_space,
      max_trajectory_length, batch_size,
  ):
    """Creates a SerializedSequenceSimulatedEnvProblem with a stubbed model.

    Side effect: binds the gin parameter `BoxSpaceSerializer.precision` to 1.

    Args:
      mock_restore_state: mock whose return value's `params` is set to
        `(None, None)`, i.e. (model_params, opt_state). NOTE(review): this is
        presumably injected by a patch decorator outside this view; confirm.
      observation_space: forwarded to the environment.
      action_space: forwarded to the environment.
      max_trajectory_length: forwarded to the environment.
      batch_size: environment batch size; also sizes the stubbed predictions,
        rewards and done flags.

    Returns:
      The environment, configured as follows:
        * Every call to the mocked model instance returns the same zero
          array, of shape (batch_size, 1, 1) for positive batch_size.
        * Its `init` returns `(base.EMPTY_WEIGHTS, base.EMPTY_STATE)`.
        * Rewards are all zeros and done flags are all False.
        * vocab_size is 1.
    """
    restored = mock_restore_state.return_value
    restored.params = (None, None)  # (model_params, opt_state)

    gin.bind_parameter('BoxSpaceSerializer.precision', 1)

    # Keep the list-repetition form: it is not equivalent to np.zeros for
    # batch_size <= 0.
    zeros_prediction = np.array([[[0.0]]] * batch_size)

    model = mock.MagicMock()
    model_instance = model.return_value
    # An infinite iterator as side_effect: each call yields the same array.
    model_instance.side_effect = itertools.repeat(zeros_prediction)
    model_instance.init.return_value = (base.EMPTY_WEIGHTS, base.EMPTY_STATE)

    def zero_rewards(_1, _2):
      return np.zeros(batch_size)

    def never_done(_1, _2):
      return np.full((batch_size,), False)

    return simulated_env_problem.SerializedSequenceSimulatedEnvProblem(
        model=model,
        reward_fn=zero_rewards,
        done_fn=never_done,
        vocab_size=1,
        max_trajectory_length=max_trajectory_length,
        batch_size=batch_size,
        observation_space=observation_space,
        action_space=action_space,
        reward_range=(-1, 1),
        discrete_rewards=False,
        history_stream=itertools.repeat(None),
        output_dir=None,
    )
Ejemplo n.º 3
0
  def test_runs_with_transformer(self):
    """Steps a tiny-transformer simulated env and checks the done flags.

    `done_fn` always returns False, yet the second step is expected to
    report done. NOTE(review): this is presumably caused by
    `max_trajectory_length=3` (with the reset observation counting toward
    the length); confirm.
    """
    tiny_model = functools.partial(
        transformer.TransformerLM, d_model=2, d_ff=2, n_heads=1, n_layers=1)
    env = simulated_env_problem.SerializedSequenceSimulatedEnvProblem(
        model=tiny_model,
        reward_fn=lambda _1, _2: np.array([0.5]),
        done_fn=lambda _1, _2: np.array([False]),
        vocab_size=4,
        max_trajectory_length=3,
        batch_size=1,
        observation_space=gym.spaces.Box(low=0, high=5, shape=(4,)),
        action_space=gym.spaces.Discrete(n=2),
        reward_range=(-1, 1),
        discrete_rewards=False,
        history_stream=itertools.repeat(None),
        output_dir=None,
    )

    env.reset()
    expected_dones = (False, True)
    for expected in expected_dones:
      # A fresh action array per step, as before, in case step mutates it.
      _, _, dones, _ = env.step(np.array([0]))
      np.testing.assert_array_equal(dones, [expected])