コード例 #1
0
def test_multidiscrete_identity(alg):
    '''
    Test if the algorithm (with an mlp policy)
    can learn an identity transformation (i.e. return observation as an action)
    on a MultiDiscrete (3, 3) action space.

    :param alg: algorithm name; used as a key into ``learn_kwargs`` and passed
        to ``get_learn_function``.
    '''

    # Copy before merging: ``learn_kwargs[alg]`` is module-level state, and
    # updating it in place would leak ``common_kwargs`` into every other test
    # that reads the same entry. ``common_kwargs`` still takes precedence.
    kwargs = dict(learn_kwargs[alg])
    kwargs.update(common_kwargs)

    learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs)
    env_fn = lambda: MultiDiscreteIdentityEnv((3,3), episode_len=100)
    simple_test(env_fn, learn_fn, 0.9)
コード例 #2
0
def test_serialization(learn_fn, network_fn):
    '''
    Check that a trained model survives a save/load round trip.

    A model is trained for 100 timesteps in one graph and saved to a
    temporary directory. A second, fresh graph then loads it with zero
    training. The serialized variables and the action statistics on one
    fixed observation must agree between the two.
    '''

    recurrent_unsupported = ('acer', 'acktr', 'trpo_mpi', 'deepq')
    if network_fn.endswith('lstm') and learn_fn in recurrent_unsupported:
        # TODO make acktr work with recurrent policies and test
        # github issue: https://github.com/openai/baselines/issues/660
        return

    def make_env():
        mnist_env = MnistEnv(episode_len=100)
        mnist_env.seed(10)
        return mnist_env

    env = DummyVecEnv([make_env])
    # Snapshot one observation so both models are queried on identical input.
    fixed_ob = env.reset().copy()
    base_learn = get_learn_function(learn_fn)

    # Learner-specific settings override network-specific ones.
    merged_kwargs = dict(network_kwargs[network_fn])
    merged_kwargs.update(learn_kwargs[learn_fn])

    train = partial(base_learn, env=env, network=network_fn, seed=0,
                    **merged_kwargs)

    with tempfile.TemporaryDirectory() as tmp_dir:
        save_path = os.path.join(tmp_dir, 'serialization_test_model')

        # First graph: train briefly, persist, and record the resulting state.
        with tf.Graph().as_default(), make_session().as_default():
            agent = train(total_timesteps=100)
            agent.save(save_path)
            saved_mean, saved_std = _get_action_stats(agent, fixed_ob)
            saved_vars = _serialize_variables()

        # Second graph: no training, only restore from disk.
        with tf.Graph().as_default(), make_session().as_default():
            agent = train(total_timesteps=0, load_path=save_path)
            loaded_mean, loaded_std = _get_action_stats(agent, fixed_ob)
            loaded_vars = _serialize_variables()

        for var_name, saved_value in saved_vars.items():
            mismatch_msg = 'saved and loaded variable {} value mismatch'.format(
                var_name)
            np.testing.assert_allclose(saved_value, loaded_vars[var_name],
                                       atol=0.01, err_msg=mismatch_msg)

        np.testing.assert_allclose(saved_mean, loaded_mean, atol=0.5)
        np.testing.assert_allclose(saved_std, loaded_std, atol=0.5)
コード例 #3
0
def test_continuous_identity(alg):
    '''
    Test if the algorithm (with an mlp policy)
    can learn an identity transformation (i.e. return observation as an action)
    to a required precision, on a 1-D Box action space.

    :param alg: algorithm name; used as a key into ``learn_kwargs`` and passed
        to ``get_learn_function``.
    '''

    # Copy before merging: ``learn_kwargs[alg]`` is module-level state, and
    # updating it in place would leak ``common_kwargs`` into every other test
    # that reads the same entry. ``common_kwargs`` still takes precedence.
    kwargs = dict(learn_kwargs[alg])
    kwargs.update(common_kwargs)
    learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs)

    env_fn = lambda: BoxIdentityEnv((1,), episode_len=100)
    simple_test(env_fn, learn_fn, -0.1)
コード例 #4
0
def test_fixed_sequence(alg, rnn):
    '''
    Test if the algorithm (with the given network ``rnn``)
    can learn a FixedSequenceEnv task (10 actions, episodes of length 5).

    :param alg: algorithm name; used as a key into ``learn_kwargs`` and passed
        to ``get_learn_function``.
    :param rnn: network name forwarded to the learner as ``network``.
    '''

    # Copy before merging: ``learn_kwargs[alg]`` is module-level state, and
    # updating it in place would leak ``common_kwargs`` into every other test
    # that reads the same entry. ``common_kwargs`` still takes precedence.
    kwargs = dict(learn_kwargs[alg])
    kwargs.update(common_kwargs)

    env_fn = lambda: FixedSequenceEnv(n_actions=10, episode_len=5)
    learn = lambda e: get_learn_function(alg)(env=e, network=rnn, **kwargs)

    simple_test(env_fn, learn, 0.7)
コード例 #5
0
def test_mnist(alg):
    '''
    Test if the algorithm can learn to classify MNIST digits.
    Uses CNN policy.

    :param alg: algorithm name; used as a key into ``learn_args`` and passed
        to ``get_learn_function``.
    '''

    # Copy before merging: ``learn_args[alg]`` is module-level state, and
    # updating it in place would leak ``common_kwargs`` into every other test
    # that reads the same entry. ``common_kwargs`` still takes precedence.
    kwargs = dict(learn_args[alg])
    kwargs.update(common_kwargs)

    learn = get_learn_function(alg)
    learn_fn = lambda e: learn(env=e, **kwargs)
    env_fn = lambda: MnistEnv(episode_len=100)

    simple_test(env_fn, learn_fn, 0.6)
コード例 #6
0
def test_env_after_learn(algo):
    '''
    Regression check: a SubprocVecEnv must still be resettable after ``learn``.

    Runs ``learn`` for zero timesteps on the env, then resets and closes it.
    '''
    def make_env():
        # acktr requires too much RAM, fails on travis
        if algo == 'acktr':
            return gym.make('CartPole-v1')
        return gym.make('PongNoFrameskip-v4')

    make_session(make_default=True, graph=tf.Graph())
    env = SubprocVecEnv([make_env])

    learner = get_learn_function(algo)

    # This call is the regression trigger: without it no crash occurs,
    # while with it the crash shows up at env.reset() below.
    learner(network='mlp', env=env, total_timesteps=0, load_path=None, seed=None)

    env.reset()
    env.close()
コード例 #7
0
def test_cartpole(alg):
    '''
    Check that the algorithm (with an mlp policy)
    learns to balance CartPole-v0.

    Algorithm-specific settings from ``learn_kwargs`` override the shared
    ``common_kwargs``. Neither mapping is modified.
    '''

    kwargs = dict(common_kwargs)
    kwargs.update(learn_kwargs[alg])

    def learn_fn(e):
        return get_learn_function(alg)(env=e, **kwargs)

    def env_fn():
        cartpole = gym.make('CartPole-v0')
        cartpole.seed(0)
        return cartpole

    reward_per_episode_test(env_fn, learn_fn, 100)
コード例 #8
0
def test_coexistence(learn_fn, network_fn):
    '''
    Check that two independently built models can exist at the same time.

    Two models (seeds 1 and 2) are built with zero training, each in its own
    fresh default graph/session. Both must then be able to ``step`` on a
    sampled CartPole observation.
    '''

    if learn_fn == 'deepq':
        # TODO enable multiple DQN models to be useable at the same time
        # github issue https://github.com/openai/baselines/issues/656
        return

    recurrent_unsupported = ('acktr', 'trpo_mpi', 'deepq')
    if network_fn.endswith('lstm') and learn_fn in recurrent_unsupported:
        # TODO make acktr work with recurrent policies and test
        # github issue: https://github.com/openai/baselines/issues/660
        return

    env = DummyVecEnv([lambda: gym.make('CartPole-v0')])
    base_learn = get_learn_function(learn_fn)

    # Learner-specific settings override network-specific ones.
    merged_kwargs = dict(network_kwargs[network_fn])
    merged_kwargs.update(learn_kwargs[learn_fn])

    build = partial(base_learn, env=env, network=network_fn,
                    total_timesteps=0, **merged_kwargs)

    models = []
    for seed in (1, 2):
        # A brand-new graph per model, installed as the default session.
        make_session(make_default=True, graph=tf.Graph())
        models.append(build(seed=seed))

    for model in models:
        model.step(env.observation_space.sample())