def test_lstm_example():
    '''
    Roll out a single episode with an lstm policy on 'Reacher-v2' and
    check that the episode lasts for more than 5 steps.
    '''
    import tensorflow as tf
    from common import policies, models, cmd_util
    from common.vec_env.dummy_vec_env import DummyVecEnv

    def make_env():
        return cmd_util.make_mujoco_env('Reacher-v2', seed=0)

    # vectorized environment wrapping a single env factory
    venv = DummyVecEnv([make_env])

    with tf.Session() as sess:
        # policy constructor for an lstm network with 128 units, then the
        # policy itself built with nbatch=1 and nsteps=1
        policy_fn = policies.build_policy(venv, models.lstm(128))
        policy = policy_fn(nbatch=1, nsteps=1)

        # tensorflow variables must be initialized before stepping the policy
        sess.run(tf.global_variables_initializer())

        # rollout bookkeeping: observation, recurrent state, done mask
        ob = venv.reset()
        state = policy.initial_state
        done = [False]
        step_counter = 0

        # keep stepping until the environment reports the episode is done;
        # the previous step's done flag is fed back to the policy as M
        episode_over = False
        while not episode_over:
            action, _, state, _ = policy.step(ob, S=state, M=done)
            ob, _, done, _ = venv.step(action)
            step_counter += 1
            episode_over = bool(done)

        assert step_counter > 5
def test_serialization(learn_fn, network_fn):
    '''
    Test that a trained model survives a save / load round trip.

    A model is trained for 100 timesteps and saved; a second model is then
    built in a fresh graph from the saved file with 0 training timesteps.
    The serialized variables and the action statistics of both models must
    agree within tolerance.
    '''

    # TODO make acktr work with recurrent policies
    # and test
    # github issue: https://github.com/openai/baselines/issues/660
    skipped_with_lstm = ('acer', 'acktr', 'trpo_mpi', 'deepq')
    if network_fn.endswith('lstm') and learn_fn in skipped_with_lstm:
        return

    def make_env():
        mnist_env = MnistEnv(episode_len=100)
        mnist_env.seed(10)
        return mnist_env

    env = DummyVecEnv([make_env])
    # copy of the initial observation, reused to query both models
    ob = env.reset().copy()
    base_learn = get_learn_function(learn_fn)

    # network kwargs first, learner kwargs second, so the latter win on
    # duplicate keys
    kwargs = dict(network_kwargs[network_fn])
    kwargs.update(learn_kwargs[learn_fn])

    learn = partial(
        base_learn, env=env, network=network_fn, seed=0, **kwargs)

    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = os.path.join(tmp_dir, 'serialization_test_model')

        # first graph / session: train, save and record reference values
        with tf.Graph().as_default(), make_session().as_default():
            agent = learn(total_timesteps=100)
            agent.save(model_path)
            mean_saved, std_saved = _get_action_stats(agent, ob)
            saved_vars = _serialize_variables()

        # second graph / session: load from disk without further training
        with tf.Graph().as_default(), make_session().as_default():
            agent = learn(total_timesteps=0, load_path=model_path)
            mean_loaded, std_loaded = _get_action_stats(agent, ob)
            loaded_vars = _serialize_variables()

        # every variable recorded before saving must match after loading
        for name, saved_value in saved_vars.items():
            mismatch_msg = (
                'saved and loaded variable {} value mismatch'.format(name))
            np.testing.assert_allclose(
                saved_value,
                loaded_vars[name],
                atol=0.01,
                err_msg=mismatch_msg)

        np.testing.assert_allclose(mean_saved, mean_loaded, atol=0.5)
        np.testing.assert_allclose(std_saved, std_loaded, atol=0.5)