Example #1
        def test(self):
            """Test."""
            def policy_fn():
                return Policy(TicTacToeNet())

            a0 = partial(AlphaZero, game=TicTacToe(), policy=policy_fn())
            train('test', a0, maxt=100000, eval=True, eval_period=1000)
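All of these examples share one pattern: functools.partial pre-binds the algorithm's constructor arguments, and train later calls the resulting factory with just the log directory. A minimal, self-contained sketch of that mechanism (the constructor below is a stand-in, not part of dl):

from functools import partial

def make_alg(logdir, learning_starts=0, buffer_size=0):
    """Stand-in for an algorithm constructor such as DDPG or SAC."""
    return (logdir, learning_starts, buffer_size)

factory = partial(make_alg, learning_starts=300, buffer_size=500)
# train(...) later supplies the one unbound argument, the log directory:
assert factory('logs') == ('logs', 300, 500)

This is also why the later examples can rebuild the algorithm with a call like alg = ddpg('logs') and then load its checkpoint.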
Example #2
        def test_feed_forward_ppo(self):
            """Test feed forward ppo."""
            def env_fn(nenv):
                return make_atari_env('Pong', nenv, frame_stack=4)

            def policy_fn(env):
                return Policy(NatureDQN(env.observation_space,
                                        env.action_space))

            ppo = partial(PPO, env_fn=env_fn, policy_fn=policy_fn)
            train('test', ppo, maxt=1000, eval=True, eval_period=1000)
            shutil.rmtree('test')
Example #3
File: ddpg.py Project: amackeith/dl
    def test_ddpg(self):
        """Test DDPG."""
        ddpg = partial(DDPG,
                       env_fn=env_fn,
                       policy_fn=policy_fn,
                       qf_fn=qf_fn,
                       learning_starts=300,
                       eval_num_episodes=1,
                       buffer_size=500)
        train('logs', ddpg, maxt=1000, eval=False, eval_period=1000)
        alg = ddpg('logs')
        assert alg.load() == 1000
        shutil.rmtree('logs')
Example #4
    def test_sac(self):
        """Test."""
        sac = partial(SAC,
                      env_fn=env_fn,
                      policy_fn=policy_fn,
                      qf_fn=qf_fn,
                      learning_starts=300,
                      eval_num_episodes=1,
                      buffer_size=500,
                      target_update_period=100)
        train('logs', sac, maxt=1000, eval=False, eval_period=1000)
        alg = sac('logs')
        alg.load()
        shutil.rmtree('logs')
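Here target_update_period=100 governs how often the target Q-network is synchronized with the online one. A minimal sketch of a hard sync every N steps, assuming PyTorch modules (dl may instead use polyak averaging):

import torch.nn as nn

qf = nn.Linear(4, 2)         # stand-in online Q-network
target_qf = nn.Linear(4, 2)  # stand-in target Q-network

def maybe_sync_target(t, period=100):
    """Hard-copy online weights into the target every `period` steps."""
    if t % period == 0:
        target_qf.load_state_dict(qf.state_dict())

maybe_sync_target(t=100)  # t is a multiple of the period, so the copy happens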
Example #5
File: td3.py Project: amackeith/dl
    def test_td3(self):
        """Test TD3."""
        td3 = partial(TD3,
                      env_fn=env_fn,
                      policy_fn=policy_fn,
                      qf_fn=qf_fn,
                      learning_starts=300,
                      eval_num_episodes=1,
                      buffer_size=500,
                      policy_update_period=2)
        train('logs', td3, maxt=1000, eval=False, eval_period=1000)
        alg = td3('logs')
        assert alg.load() == 1000
        shutil.rmtree('logs')
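policy_update_period=2 is TD3's delayed policy update: the critics are updated every step, while the actor updates only every second step. A sketch of that gating logic (illustrative, not dl's internals):

def td3_step(t, policy_update_period=2):
    update_critic = True                          # critics update every step
    update_actor = t % policy_update_period == 0  # actor updates are delayed
    return update_critic, update_actor

assert td3_step(1) == (True, False)
assert td3_step(2) == (True, True)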
Example #6
        def test_feed_forward_ppo2(self):
            """Test feed forward ppo2."""
            def env_fn(nenv):
                return make_env('LunarLanderContinuous-v2', nenv)

            def policy_fn(env):
                return Policy(PolicyNet(env.observation_space,
                                        env.action_space))

            def vf_fn(env):
                return ValueFunction(VFNet(env.observation_space,
                                           env.action_space))

            ppo = partial(ResidualPPO2, env_fn=env_fn, policy_fn=policy_fn,
                          value_fn=vf_fn)
            train('test', ppo, maxt=1000, eval=True, eval_period=1000)
            shutil.rmtree('test')
Example #7
        def test_feed_forward_ppo2(self):
            """Test feed forward ppo2."""
            def env_fn(nenv):
                return make_atari_env('Pong', nenv, frame_stack=4)

            def policy_fn(env):
                return Policy(NatureDQN(env.observation_space,
                                        env.action_space))

            def vf_fn(env):
                return ValueFunction(NatureDQNVF(env.observation_space,
                                                 env.action_space))

            ppo = partial(PPO2RND, env_fn=env_fn, policy_fn=policy_fn,
                          value_fn=vf_fn, rnd_net=RNDNet)
            train('test', ppo, maxt=100, eval=True, eval_period=100)
            shutil.rmtree('test')
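PPO2RND augments PPO with Random Network Distillation: a predictor network is trained to match a fixed, randomly initialized target network, and its prediction error serves as an intrinsic exploration bonus. A minimal sketch of that bonus, independent of dl's actual RNDNet:

import torch
import torch.nn as nn

target = nn.Linear(8, 16)     # fixed random network, never trained
predictor = nn.Linear(8, 16)  # trained to imitate the target
for p in target.parameters():
    p.requires_grad_(False)

def intrinsic_reward(obs):
    """Novel observations are predicted poorly, yielding a larger bonus."""
    return (predictor(obs) - target(obs)).pow(2).mean(dim=-1)

bonus = intrinsic_reward(torch.randn(4, 8))  # one bonus per observation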
Example #8
        def test_feed_forward_ppo(self):
            """Test feed forward ppo."""
            def env_fn(nenv):
                return make_env(env_id="LunarLanderRandomConstrained-v2",
                                nenv=nenv)

            def policy_fn(env):
                return Policy(
                    FeedForwardActorCriticBase(env.observation_space,
                                               env.action_space))

            ppo = partial(ConstrainedResidualPPO,
                          env_fn=env_fn,
                          nenv=32,
                          policy_fn=policy_fn,
                          base_actor_cls=RandomActor,
                          policy_training_start=500)
            train('test', ppo, maxt=1000, eval=True, eval_period=1000)
            alg = ppo('test')
            alg.load()
            shutil.rmtree('test')
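ConstrainedResidualPPO learns a correction on top of a fixed base actor (here a RandomActor), and the learned policy only begins training after policy_training_start steps. A sketch of additive residual action composition, the standard residual-policy-learning setup (names are illustrative, not dl's internals):

import numpy as np

def base_act(obs):
    """Stand-in base actor: uniform random actions, like RandomActor."""
    return np.random.uniform(-1.0, 1.0, size=2)

def residual_act(obs, residual):
    # The learned correction is added to the base actor's action.
    return base_act(obs) + residual(obs)

action = residual_act(obs=None, residual=lambda obs: np.zeros(2))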
Example #9
        def test_ql(self):
            """Test."""
            def env_fn(nenv):
                return make_atari_env('Pong', nenv, frame_stack=1)

            def qf_fn(env):
                return QFunction(
                    NatureDQN(env.observation_space, env.action_space))

            ql = partial(DQN,
                         env_fn=env_fn,
                         qf_fn=qf_fn,
                         learning_starts=100,
                         buffer_size=200,
                         update_period=4,
                         frame_stack=4,
                         exploration_timesteps=500,
                         target_update_period=100)
            train('logs', ql, maxt=1000, eval=True, eval_period=1000)
            alg = ql('logs')
            alg.load()
            assert np.allclose(alg.eps_schedule.value(alg.t), 0.1)
            shutil.rmtree('logs')
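The final assertion checks that the epsilon-greedy exploration rate has finished annealing to 0.1 once t has passed exploration_timesteps=500. A minimal sketch of such a linear schedule, assuming the common baselines-style behavior:

class LinearSchedule:
    """Linearly anneal from `initial` to `final` over `n` timesteps."""

    def __init__(self, n, final, initial=1.0):
        self.n, self.final, self.initial = n, final, initial

    def value(self, t):
        frac = min(t / self.n, 1.0)
        return self.initial + frac * (self.final - self.initial)

eps = LinearSchedule(n=500, final=0.1)
assert abs(eps.value(1000) - 0.1) < 1e-9  # fully annealed after 500 steps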
Example #10
import argparse
import dl

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train Agent.')
    parser.add_argument('--expdir', type=str, help='expdir', required=True)
    parser.add_argument('--gin_config',
                        type=str,
                        help='gin config',
                        required=True)
    parser.add_argument('-b',
                        '--gin_bindings',
                        nargs='+',
                        help='gin bindings to overwrite config')
    args = parser.parse_args()
    dl.load_config(args.gin_config, args.gin_bindings)
    dl.train(args.expdir)
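The -b/--gin_bindings flag takes gin binding strings of the form configurable.parameter=value, which override whatever the config file sets. A sketch of the equivalent programmatic calls; the paths and parameter names below are hypothetical placeholders, not confirmed configurables of the dl library:

import dl

bindings = ['train.maxt=100000', 'train.eval_period=5000']  # hypothetical
dl.load_config('configs/example.gin', bindings)  # bindings override the file
dl.train('logs/example')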
Example #11
import argparse
import dl
import yaml

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train Agent.')
    parser.add_argument('config', type=str, help='config')
    args = parser.parse_args()
    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)

    gin_bindings = []
    for k, v in config['gin_bindings'].items():
        if isinstance(v, str) and not v.startswith('@'):
            # Plain strings must be quoted for gin; values starting with '@'
            # are references to gin configurables and must stay unquoted.
            gin_bindings.append(f'{k}="{v}"')
        else:
            gin_bindings.append(f"{k}={v}")
    dl.load_config(config['base_config'], gin_bindings)
    dl.train(config['logdir'])
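The YAML file this script expects therefore needs three keys: base_config, logdir, and a gin_bindings mapping. A hypothetical config, expressed here as the Python dict that yaml.safe_load would produce (all values are placeholders):

config = {
    'base_config': 'configs/ppo.gin',   # gin file handed to dl.load_config
    'logdir': 'logs/experiment',        # directory handed to dl.train
    'gin_bindings': {
        'train.maxt': 100000,           # numbers pass through unquoted
        'make_env.env_id': 'Pong',      # plain strings get quoted for gin
        'PPO.policy_fn': '@policy_fn',  # '@' marks a gin configurable
    },
}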
Example #12
"""Main script for training models."""
import dl
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train Script.')
    parser.add_argument('logdir', type=str, help='logdir')
    parser.add_argument('config', type=str, help='gin config')
    parser.add_argument('-b',
                        '--bindings',
                        nargs='+',
                        help='gin bindings to overwrite config')
    args = parser.parse_args()

    dl.load_config(args.config, args.bindings)
    dl.train(args.logdir)