def test_discrete_identity(alg):
    '''
    Test if the algorithm (with an mlp policy)
    can learn an identity transformation (i.e. return the observation as the action)
    '''

    kwargs = learn_kwargs[alg]
    kwargs.update(common_kwargs)

    learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs)
    env_fn = lambda: DiscreteIdentityEnv(10, episode_len=100)
    simple_test(env_fn, learn_fn, 0.9)
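
# The snippets in this section rely on module-level dicts of shared and
# per-algorithm learn() arguments, and receive the algorithm name through
# pytest parametrization; get_learn_function, simple_test and the toy
# environments come from baselines' run and test-utility modules (not shown
# here). A minimal, hypothetical sketch of that scaffolding: the dict names
# match the snippets, but the values below are placeholders, not baselines'
# actual settings.

import pytest

common_kwargs = dict(seed=0, total_timesteps=30000)

learn_kwargs = {
    'a2c': {},
    'deepq': dict(total_timesteps=20000),
    'ppo2': dict(nsteps=64, nminibatches=1),
}

# Parametrizing over the dict keys is what feeds `alg` into the tests;
# this trivial test only demonstrates the mechanism.
@pytest.mark.parametrize('alg', sorted(learn_kwargs.keys()))
def test_alg_is_configured(alg):
    assert alg in learn_kwargs
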
def test_continuous_identity(alg):
    '''
    Test if the algorithm (with an mlp policy)
    can learn an identity transformation (i.e. return the observation as the action)
    to a required precision
    '''

    kwargs = learn_kwargs[alg]
    kwargs.update(common_kwargs)
    learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs)

    env_fn = lambda: BoxIdentityEnv((1, ), episode_len=100)
    simple_test(env_fn, learn_fn, -0.1)
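
# simple_test is not shown in these snippets; its third argument reads as a
# minimum average per-step reward (0.9 for the discrete identity task, and
# -0.1 for the continuous one, whose reward is a negative squared error).
# A rough, hypothetical sketch of such a helper, not baselines' actual
# implementation:

import numpy as np

def simple_test(env_fn, learn_fn, min_avg_reward, n_steps=1000):
    env = env_fn()
    model = learn_fn(env)                       # train on a fresh environment

    total, ob, done = 0.0, env.reset(), False
    for _ in range(n_steps):
        if done:
            ob = env.reset()
        # assumes a batched step() API returning (actions, ...) like baselines models
        action = model.step(np.asarray(ob)[None])[0][0]
        ob, reward, done, _ = env.step(action)
        total += reward

    avg = total / n_steps
    assert avg >= min_avg_reward, 'average reward {:.3f} is below {}'.format(avg, min_avg_reward)
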
def test_fixed_sequence(alg, rnn):
    '''
    Test if the algorithm (with a given recurrent policy)
    can learn to reproduce a fixed sequence of actions
    '''

    kwargs = learn_kwargs[alg]
    kwargs.update(common_kwargs)

    episode_len = 5
    env_fn = lambda: FixedSequenceEnv(10, episode_len=episode_len)
    learn = lambda e: get_learn_function(alg)(env=e, network=rnn, **kwargs)

    simple_test(env_fn, learn, 0.7)
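
# FixedSequenceEnv gives the agent no useful observation, so the target action
# sequence can only be learned by a policy with memory; that is why this test
# takes an rnn network argument. A hypothetical sketch of such an environment,
# not baselines' actual implementation:

import numpy as np
from gym import Env, spaces

class FixedSequenceEnv(Env):
    def __init__(self, n_actions=10, episode_len=5, seed=0):
        self.action_space = spaces.Discrete(n_actions)
        self.observation_space = spaces.Discrete(1)     # a single, uninformative observation
        self.episode_len = episode_len
        self.sequence = np.random.RandomState(seed).randint(n_actions, size=episode_len)
        self.t = 0

    def reset(self):
        self.t = 0
        return 0

    def step(self, action):
        # reward 1 only when the action matches the fixed target sequence
        reward = 1.0 if action == self.sequence[self.t] else 0.0
        self.t += 1
        return 0, reward, self.t >= self.episode_len, {}
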
def test_mnist(alg):
    '''
    Test if the algorithm can learn to classify MNIST digits.
    Uses a CNN policy.
    '''

    kwargs = learn_kwargs[alg]
    kwargs.update(common_kwargs)

    learn = get_learn_function(alg)
    learn_fn = lambda e: learn(env=e, **kwargs)
    env_fn = lambda: MnistEnv(seed=0, episode_len=100)

    simple_test(env_fn, learn_fn, 0.6)
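
# MnistEnv turns digit classification into a reinforcement learning problem:
# the observation is an image, the action is the predicted class, and the
# reward is 1 for a correct guess, so the 0.6 threshold above is roughly 60%
# accuracy. A hypothetical sketch; the real environment in baselines may
# differ, and tensorflow.keras is used here only as a convenient MNIST loader:

import numpy as np
from gym import Env, spaces
from tensorflow.keras.datasets import mnist

class MnistEnv(Env):
    def __init__(self, seed=0, episode_len=100):
        (images, labels), _ = mnist.load_data()
        self.images = images[..., None]                   # add a channel axis
        self.labels = labels
        self.observation_space = spaces.Box(0, 255, shape=(28, 28, 1), dtype=np.uint8)
        self.action_space = spaces.Discrete(10)
        self.episode_len = episode_len
        self.rng = np.random.RandomState(seed)
        self.t = 0

    def _sample(self):
        self.idx = self.rng.randint(len(self.images))
        return self.images[self.idx]

    def reset(self):
        self.t = 0
        return self._sample()

    def step(self, action):
        reward = 1.0 if action == self.labels[self.idx] else 0.0
        self.t += 1
        return self._sample(), reward, self.t >= self.episode_len, {}
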
def test_serialization(learn_fn, network_fn):
    '''
    Test that a trained model can be saved and reloaded
    with the same variable values and action statistics
    '''

    if network_fn.endswith('lstm') and learn_fn in [
            'acktr', 'trpo_mpi', 'deepq'
    ]:
        # TODO make acktr work with recurrent policies
        # and test
        # github issue: https://github.com/openai/baselines/issues/194
        return

    env = DummyVecEnv([lambda: MnistEnv(10, episode_len=100)])
    ob = env.reset().copy()
    learn = get_learn_function(learn_fn)

    kwargs = {}
    kwargs.update(network_kwargs[network_fn])
    kwargs.update(learn_kwargs[learn_fn])

    learn = partial(learn, env=env, network=network_fn, seed=0, **kwargs)

    with tempfile.TemporaryDirectory() as td:
        model_path = os.path.join(td, 'serialization_test_model')

        with tf.Graph().as_default(), make_session().as_default():
            model = learn(total_timesteps=100)
            model.save(model_path)
            mean1, std1 = _get_action_stats(model, ob)
            variables_dict1 = _serialize_variables()

        with tf.Graph().as_default(), make_session().as_default():
            model = learn(total_timesteps=0, load_path=model_path)
            mean2, std2 = _get_action_stats(model, ob)
            variables_dict2 = _serialize_variables()

        for k, v in variables_dict1.items():
            np.testing.assert_allclose(
                v,
                variables_dict2[k],
                atol=0.01,
                err_msg='saved and loaded variable {} value mismatch'.format(
                    k))

        np.testing.assert_allclose(mean1, mean2, atol=0.5)
        np.testing.assert_allclose(std1, std2, atol=0.5)
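
# _serialize_variables and _get_action_stats are not shown in the snippet.
# Rough sketches of what they might look like under the TF1-style sessions
# used above (the loose atol of 0.5 on the action statistics is there because
# the policy samples stochastically); the real helpers in baselines may
# differ, and a recurrent policy would also need its state passed to step():

import numpy as np
import tensorflow as tf

def _serialize_variables():
    # snapshot every trainable variable in the current default session
    sess = tf.get_default_session()
    variables = tf.trainable_variables()
    values = sess.run(variables)
    return {var.name: value for var, value in zip(variables, values)}

def _get_action_stats(model, ob, n_trials=1000):
    # sample the policy repeatedly on the same observation and report the
    # mean and spread of the actions it draws
    actions = np.array([model.step(ob)[0] for _ in range(n_trials)])
    return np.mean(actions, axis=0), np.std(actions, axis=0)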