def test_multidiscrete_identity(alg):
    '''
    Test if the algorithm (with an mlp policy)
    can learn an identity transformation (i.e. return observation as an action)
    on a MultiDiscreteIdentityEnv.

    alg: key used to look up per-algorithm settings in learn_kwargs and
         passed to get_learn_function to obtain the learn callable.
    '''
    # Merge into a fresh dict. Calling .update() directly on learn_kwargs[alg]
    # would write common_kwargs into the shared table defined outside this
    # function, leaking state into every later test that reads the same entry.
    kwargs = dict(learn_kwargs[alg])
    kwargs.update(common_kwargs)

    def env_fn():
        return MultiDiscreteIdentityEnv((3, 3), episode_len=100)

    def learn_fn(e):
        # The learn function is resolved lazily, each time this is called.
        return get_learn_function(alg)(env=e, **kwargs)

    # 0.9 is the threshold handed to simple_test (presumably the minimum
    # reward fraction required to pass -- confirm against simple_test).
    simple_test(env_fn, learn_fn, 0.9)
def test_continuous_identity(alg):
    '''
    Test if the algorithm (with an mlp policy)
    can learn an identity transformation (i.e. return observation as an action)
    to a required precision, on a BoxIdentityEnv.

    alg: key used to look up per-algorithm settings in learn_kwargs and
         passed to get_learn_function to obtain the learn callable.
    '''
    # Work on a copy so the shared learn_kwargs[alg] entry is never modified;
    # an in-place .update() would leak common_kwargs into other tests that
    # read the same entry.
    kwargs = dict(learn_kwargs[alg])
    kwargs.update(common_kwargs)

    def env_fn():
        return BoxIdentityEnv((1, ), episode_len=100)

    def learn_fn(e):
        # The learn function is resolved lazily, each time this is called.
        return get_learn_function(alg)(env=e, **kwargs)

    # -0.1 is the threshold handed to simple_test (presumably a reward bound
    # expressing the required precision -- confirm against simple_test).
    simple_test(env_fn, learn_fn, -0.1)
def test_mnist(alg):
    '''
    Test if the algorithm can learn to classify MNIST digits.
    Uses CNN policy (presumably selected through the keyword arguments
    assembled below -- the values are defined outside this function).

    alg: key used to look up per-algorithm settings in learn_args and
         passed to get_learn_function to obtain the learn callable.
    '''
    # Copy before merging: updating learn_args[alg] in place would modify the
    # shared per-algorithm table and leak common_kwargs into later lookups.
    kwargs = dict(learn_args[alg])
    kwargs.update(common_kwargs)

    # Resolved once, up front, so a bad `alg` fails before simple_test runs.
    learn = get_learn_function(alg)

    def env_fn():
        return MnistEnv(episode_len=100)

    def learn_fn(e):
        return learn(env=e, **kwargs)

    # 0.6 is the threshold handed to simple_test (presumably the minimum
    # reward fraction required to pass -- confirm against simple_test).
    simple_test(env_fn, learn_fn, 0.6)
def test_fixed_sequence(alg, rnn):
    '''
    Test if the algorithm (with a given policy network)
    can learn on FixedSequenceEnv (presumably: reproduce a fixed sequence
    of actions -- confirm against FixedSequenceEnv).

    alg: key used to look up per-algorithm settings in learn_kwargs and
         passed to get_learn_function to obtain the learn callable.
    rnn: value forwarded to the learn function as its `network` argument.
    '''
    # Build the keyword arguments in a fresh dict so the shared
    # learn_kwargs[alg] entry is left untouched between tests.
    kwargs = dict(learn_kwargs[alg])
    kwargs.update(common_kwargs)

    def env_fn():
        return FixedSequenceEnv(n_actions=10, episode_len=5)

    def learn(e):
        # NOTE(review): `network` is passed explicitly, so kwargs must not
        # also contain a 'network' key or this call raises TypeError.
        return get_learn_function(alg)(
            env=e,
            network=rnn,
            **kwargs
        )

    # 0.7 is the threshold handed to simple_test (presumably the minimum
    # reward fraction required to pass -- confirm against simple_test).
    simple_test(env_fn, learn, 0.7)