def _make_env(
    self,
    observation_space,
    action_space,
    vocab_size,
    predict_fn=None,
    reward_fn=None,
    done_fn=None,
    batch_size=None,
    max_trajectory_length=None,
):
    """Builds a serialized-sequence simulated env backed by a mocked model.

    The model factory is a MagicMock; when predict_fn is given, it becomes the
    object the factory returns, so the env's model calls hit predict_fn.
    """
    model_factory = mock.MagicMock()
    if predict_fn is not None:
        model_factory.return_value = predict_fn
    # The env initializes the model once; hand back empty weights/state.
    model_factory.return_value.initialize_once.return_value = (
        base.EMPTY_WEIGHTS, base.EMPTY_STATE)
    return simulated_env_problem.SerializedSequenceSimulatedEnvProblem(
        model=model_factory,
        reward_fn=reward_fn,
        done_fn=done_fn,
        vocab_size=vocab_size,
        max_trajectory_length=max_trajectory_length,
        batch_size=batch_size,
        observation_space=observation_space,
        action_space=action_space,
        reward_range=(-1, 1),
        discrete_rewards=False,
        history_stream=itertools.repeat(None),
        output_dir=None,
    )
def _make_env(
    mock_restore_state,
    observation_space,
    action_space,
    max_trajectory_length,
    batch_size,
):
    """Builds a simulated env whose mocked model always predicts zeros.

    mock_restore_state is the patched restore-state hook; its params are set
    to a (model_params, opt_state) pair of Nones so no real checkpoint loads.
    """
    # (model_params, opt_state)
    mock_restore_state.return_value.params = (None, None)
    gin.bind_parameter('BoxSpaceSerializer.precision', 1)
    # One zero-logit prediction per batch element, repeated forever.
    zero_logits = (np.array([[[0.0]]] * batch_size))
    model_factory = mock.MagicMock()
    model_factory.return_value.side_effect = itertools.repeat(zero_logits)
    model_factory.return_value.init.return_value = (
        base.EMPTY_WEIGHTS, base.EMPTY_STATE)
    return simulated_env_problem.SerializedSequenceSimulatedEnvProblem(
        model=model_factory,
        # Every step yields zero reward and never terminates on its own.
        reward_fn=(lambda _1, _2: np.zeros(batch_size)),
        done_fn=(lambda _1, _2: np.full((batch_size,), False)),
        vocab_size=1,
        max_trajectory_length=max_trajectory_length,
        batch_size=batch_size,
        observation_space=observation_space,
        action_space=action_space,
        reward_range=(-1, 1),
        discrete_rewards=False,
        history_stream=itertools.repeat(None),
        output_dir=None,
    )
def test_runs_with_transformer(self):
    """Smoke test: the env steps with a tiny real TransformerLM model."""
    tiny_transformer = functools.partial(
        transformer.TransformerLM,
        d_model=2,
        d_ff=2,
        n_heads=1,
        n_layers=1,
    )
    env = simulated_env_problem.SerializedSequenceSimulatedEnvProblem(
        model=tiny_transformer,
        reward_fn=(lambda _1, _2: np.array([0.5])),
        done_fn=(lambda _1, _2: np.array([False])),
        vocab_size=4,
        max_trajectory_length=3,
        batch_size=1,
        observation_space=gym.spaces.Box(low=0, high=5, shape=(4,)),
        action_space=gym.spaces.Discrete(n=2),
        reward_range=(-1, 1),
        discrete_rewards=False,
        history_stream=itertools.repeat(None),
        output_dir=None,
    )
    env.reset()
    # Trajectory length 3 with one reset obs: the second step hits the limit.
    for expected_done in (False, True):
        (_, _, dones, _) = env.step(np.array([0]))
        np.testing.assert_array_equal(dones, [expected_done])