def main(dir, interval):
    """Rebuild an algorithm run from the state saved under ``dir`` and execute it.

    Args:
        dir: Directory used both as the logger session directory and as the
            location handed to ``SnapshotSaver``.
        interval: Snapshot interval forwarded to ``SnapshotSaver``.
    """
    with logger.session(dir):
        snapshot_saver = SnapshotSaver(dir, interval=interval)
        snapshot = snapshot_saver.get_state()
        # The stored algorithm keyword arguments also carry the env factory.
        algorithm_kwargs = snapshot['alg_state']
        environment = algorithm_kwargs['env_maker'].make()
        algorithm = snapshot['alg']
        algorithm(
            env=environment,
            snapshot_saver=snapshot_saver,
            **algorithm_kwargs
        )
# Example No. 2
# 0
def main(dir, interval):
    """Load the saved state from ``dir`` and re-run the stored algorithm.

    The state returned by ``SnapshotSaver.get_state()`` is expected to hold the
    algorithm callable under ``'alg'`` and its keyword arguments (including an
    ``'env_maker'``) under ``'alg_state'``.

    Args:
        dir: Log/snapshot directory shared by the logger and the saver.
        interval: Snapshot interval passed to ``SnapshotSaver``.
    """
    with logger.session(dir):
        saver = SnapshotSaver(dir, interval=interval)
        state = saver.get_state()
        kwargs = state['alg_state']
        env = kwargs['env_maker'].make()
        # The callee expression is evaluated before the arguments, so the
        # lookup order matches reading state['alg'] into a local first.
        state['alg'](env=env, snapshot_saver=saver, **kwargs)
# Example No. 3
# 0
def main(env_id, double, render):
    """Train (or evaluate) a DQN agent on GridWorld-v0 or Pong-ram-v0.

    Args:
        env_id: ``'GridWorld-v0'`` or ``'Pong-ram-v0'``; selects the
            environment, observation preprocessing and hyper-parameters.
        double: If truthy, enable Double Q-learning (``double_q``) and
            self-test ``compute_double_q_learning_loss`` on GridWorld;
            otherwise self-test ``compute_q_learning_loss``.
        render: If truthy, run ``dqn.test(epsilon=0.0)`` instead of
            ``dqn.train()``; also forwarded to ``DQN``.

    Raises:
        ValueError: If ``env_id`` is not a supported environment.
    """
    def get_act_dim(x):
        return x.action_space.n

    if env_id == 'GridWorld-v0':
        # Imported only for its side effect; presumably it registers
        # 'GridWorld-v0' with gym so gym.make can find it (TODO confirm).
        from simpledqn import gridworld_env  # noqa: F401
        env = gym.make('GridWorld-v0')

        def get_obs_dim(x):
            return x.observation_space.n

        obs_preprocessor = preprocess_obs_gridworld
        max_steps = 100000
        log_freq = 1000
        target_q_update_freq = 100
        initial_step = 0
        log_dir = "data/local/dqn_gridworld"
    elif env_id == 'Pong-ram-v0':
        env = EpisodicLifeEnv(NoopResetEnv(gym.make('Pong-ram-v0')))

        def get_obs_dim(x):
            return x.observation_space.shape[0]

        obs_preprocessor = preprocess_obs_ram
        max_steps = 10000000
        log_freq = 10000
        target_q_update_freq = 1000
        # Training resumes from the warm-start weights/buffer loaded below.
        initial_step = 1000000
        log_dir = "data/local/dqn_pong"
    else:
        raise ValueError(
            "Unsupported environment: must be one of 'GridWorld-v0' 'Pong-ram-v0'"
        )

    try:
        # Use the logger session as a real context manager: the previous
        # manual __enter__() call never exited the session.
        with logger.session(log_dir):
            env.seed(42)

            # Initialize the replay buffer that we will use.
            replay_buffer = ReplayBuffer(max_size=10000)

            # Initialize DQN training procedure.
            dqn = DQN(
                env=env,
                get_obs_dim=get_obs_dim,
                get_act_dim=get_act_dim,
                obs_preprocessor=obs_preprocessor,
                replay_buffer=replay_buffer,

                # Q-value parameters
                q_dim_hid=[256, 256] if env_id == 'Pong-ram-v0' else [],
                opt_batch_size=64,

                # DQN gamma parameter
                discount=0.99,

                # Training procedure length
                initial_step=initial_step,
                max_steps=max_steps,
                learning_start_itr=max_steps // 100,
                # Frequency of copying the actual Q to the target Q
                target_q_update_freq=target_q_update_freq,
                # Frequency of updating the Q-value function
                train_q_freq=4,
                # Double Q
                double_q=double,

                # Exploration parameters
                initial_eps=1.0,
                final_eps=0.05,
                fraction_eps=0.1,

                # Logging
                log_freq=log_freq,
                render=render,
            )

            if env_id == 'Pong-ram-v0':
                # Warm start Q-function and target Q-function from the same
                # weights file (each loads its own copy).
                dqn._q.set_params(dqn._q.load('simpledqn/weights_warm_start.pkl'))
                dqn._qt.set_params(dqn._qt.load('simpledqn/weights_warm_start.pkl'))
                # Warm start replay buffer
                dqn._replay_buffer.load('simpledqn/replay_buffer_warm_start.pkl')
                print("Warm-starting Pong training!")

            if env_id == 'GridWorld-v0':
                # Self-test the loss implementation against a reference value
                # computed on deterministic inputs (nprs(k) with fixed seeds).
                test_args = dict(
                    l_obs=nprs(0).rand(64, 16).astype(np.float32),
                    l_act=nprs(1).randint(0, 3, size=(64, )),
                    l_rew=nprs(2).randint(0, 3, size=(64, )).astype(np.float32),
                    l_next_obs=nprs(3).rand(64, 16).astype(np.float32),
                    l_done=nprs(4).randint(0, 2, size=(64, )).astype(np.float32),
                )
                if double:
                    test_name = "compute_double_q_learning_loss"
                    loss_fn = dqn.compute_double_q_learning_loss
                    tgt = np.array([1.9066928625106812], dtype=np.float32)
                else:
                    test_name = "compute_q_learning_loss"
                    loss_fn = dqn.compute_q_learning_loss
                    tgt = np.array([1.909377098083496], dtype=np.float32)
                actual_var = loss_fn(**test_args)
                assert isinstance(
                    actual_var,
                    C.Variable), "%s should return a Chainer variable" % test_name
                try:
                    assert_allclose(tgt, actual_var.data)
                    print("Test for %s passed!" % test_name)
                except AssertionError as e:
                    # A mismatch is reported but the user may still continue.
                    print("Warning: test for %s didn't pass!" % test_name)
                    print(e)
                    input(
                        "** Test failed. Press Ctrl+C to exit or press enter to continue training anyways"
                    )

            if render:
                dqn.test(epsilon=0.0)
            else:
                # Train the agent!
                dqn.train()
    finally:
        # Close gym environment even if training fails or is interrupted.
        env.close()
            test_feed.epoch_init(valid_config.batch_size,
                                 valid_config.backward_size,
                                 valid_config.step_size,
                                 shuffle=False,
                                 intra_shuffle=False)
            valid_model.valid("ELBO_TEST", sess, test_feed)

            dest_f = open(os.path.join(log_dir, "test.txt"), "wb")
            test_feed.epoch_init(test_config.batch_size,
                                 test_config.backward_size,
                                 test_config.step_size,
                                 shuffle=False,
                                 intra_shuffle=False)
            test_model.test_mul_ref(sess,
                                    test_feed,
                                    num_batch=None,
                                    repeat=5,
                                    dest=dest_f)
            dest_f.close()


if __name__ == "__main__":
    # Script entry point: derive the run directory from FLAGS and keep the
    # logs of forward-only runs ("testinfo") apart from training runs
    # ("traininfo").
    log_dir = os.path.join(FLAGS.work_dir, FLAGS.sub_dir)
    logger_dir = log_dir + ("/testinfo" if FLAGS.forward_only else "/traininfo")

    with logger.session(dir=logger_dir,
                        format_strs=['stdout', 'csv', 'log']):
        main()
# Example No. 5
# 0
#!/usr/bin/env python
import os
import shutil

import numpy as np

import logger
from algs import a2c
from env_makers import EnvMaker
from models import CategoricalCNNPolicy
from utils import SnapshotSaver

# Train an A2C agent on PongNoFrameskip-v4 with 16 environments, logging and
# snapshotting into log_dir (snapshot interval=10; units defined by
# SnapshotSaver).
log_dir = "data/local/a2c-pong"

np.random.seed(42)

# Clean up existing logs. shutil.rmtree avoids building a shell command
# string; ignore_errors=True mirrors `rm -rf` being silent when the directory
# does not exist.
shutil.rmtree(log_dir, ignore_errors=True)

with logger.session(log_dir):
    env_maker = EnvMaker('PongNoFrameskip-v4')
    env = env_maker.make()
    policy = CategoricalCNNPolicy(env.observation_space, env.action_space,
                                  env.spec)
    vf = policy.create_vf()
    a2c(
        env=env,
        env_maker=env_maker,
        n_envs=16,
        policy=policy,
        vf=vf,
        snapshot_saver=SnapshotSaver(log_dir, interval=10),
    )
from algs import trpo
from env_makers import EnvMaker
from models import CategoricalMLPPolicy, MLPBaseline
from utils import SnapshotSaver
import numpy as np
import os
import logger

# Output directory for the TRPO CartPole run; wiped before each run.
log_dir = "data/local/trpo-cartpole"

np.random.seed(42)

# Clean up existing logs without spawning a shell; ignore_errors=True mirrors
# `rm -rf` being silent when the directory does not exist.
shutil.rmtree(log_dir, ignore_errors=True)

with logger.session(log_dir):
    env_maker = EnvMaker('CartPole-v0')
    env = env_maker.make()
    policy = CategoricalMLPPolicy(
        observation_space=env.observation_space,
        action_space=env.action_space,
        env_spec=env.spec
    )
    baseline = MLPBaseline(
        observation_space=env.observation_space,
        action_space=env.action_space,
        env_spec=env.spec
    )
    trpo(
        env=env,
        env_maker=env_maker,
# Example No. 7
# 0
def main(env_id, double, render):
    """Train (or evaluate) a DQN agent on GridWorld-v0 or Pong-ram-v0.

    Args:
        env_id: ``'GridWorld-v0'`` or ``'Pong-ram-v0'``; selects the
            environment, observation preprocessing and hyper-parameters.
        double: If truthy, enable Double Q-learning (``double_q``) and
            self-test ``compute_double_q_learning_loss`` on GridWorld;
            otherwise self-test ``compute_q_learning_loss``.
        render: If truthy, run ``dqn.test(epsilon=0.0)`` instead of
            ``dqn.train()``; also forwarded to ``DQN``.

    Raises:
        ValueError: If ``env_id`` is not a supported environment.
    """
    def get_act_dim(x):
        return x.action_space.n

    if env_id == 'GridWorld-v0':
        # Imported only for its side effect; presumably it registers
        # 'GridWorld-v0' with gym so gym.make can find it (TODO confirm).
        from simpledqn import gridworld_env  # noqa: F401
        env = gym.make('GridWorld-v0')

        def get_obs_dim(x):
            return x.observation_space.n

        obs_preprocessor = preprocess_obs_gridworld
        max_steps = 100000
        log_freq = 1000
        target_q_update_freq = 100
        initial_step = 0
        log_dir = "data/local/dqn_gridworld"
    elif env_id == 'Pong-ram-v0':
        env = EpisodicLifeEnv(NoopResetEnv(gym.make('Pong-ram-v0')))

        def get_obs_dim(x):
            return x.observation_space.shape[0]

        obs_preprocessor = preprocess_obs_ram
        max_steps = 10000000
        log_freq = 10000
        target_q_update_freq = 1000
        # Training resumes from the warm-start weights/buffer loaded below.
        initial_step = 1000000
        log_dir = "data/local/dqn_pong"
    else:
        raise ValueError(
            "Unsupported environment: must be one of 'GridWorld-v0' 'Pong-ram-v0'")

    try:
        # Use the logger session as a real context manager: the previous
        # manual __enter__() call never exited the session.
        with logger.session(log_dir):
            env.seed(42)

            # Initialize the replay buffer that we will use.
            replay_buffer = ReplayBuffer(max_size=10000)

            # Initialize DQN training procedure.
            dqn = DQN(
                env=env,
                get_obs_dim=get_obs_dim,
                get_act_dim=get_act_dim,
                obs_preprocessor=obs_preprocessor,
                replay_buffer=replay_buffer,

                # Q-value parameters
                q_dim_hid=[256, 256] if env_id == 'Pong-ram-v0' else [],
                opt_batch_size=64,

                # DQN gamma parameter
                discount=0.99,

                # Training procedure length
                initial_step=initial_step,
                max_steps=max_steps,
                learning_start_itr=max_steps // 100,
                # Frequency of copying the actual Q to the target Q
                target_q_update_freq=target_q_update_freq,
                # Frequency of updating the Q-value function
                train_q_freq=4,
                # Double Q
                double_q=double,

                # Exploration parameters
                initial_eps=1.0,
                final_eps=0.05,
                fraction_eps=0.1,

                # Logging
                log_freq=log_freq,
                render=render,
            )

            if env_id == 'Pong-ram-v0':
                # Warm start Q-function and target Q-function from the same
                # weights file (each loads its own copy).
                dqn._q.set_params(dqn._q.load('simpledqn/weights_warm_start.pkl'))
                dqn._qt.set_params(dqn._qt.load('simpledqn/weights_warm_start.pkl'))
                # Warm start replay buffer
                dqn._replay_buffer.load('simpledqn/replay_buffer_warm_start.pkl')
                print("Warm-starting Pong training!")

            if env_id == 'GridWorld-v0':
                # Self-test the loss implementation against a reference value
                # computed on deterministic inputs (nprs(k) with fixed seeds).
                test_args = dict(
                    l_obs=nprs(0).rand(64, 16).astype(np.float32),
                    l_act=nprs(1).randint(0, 3, size=(64,)),
                    l_rew=nprs(2).randint(0, 3, size=(64,)).astype(np.float32),
                    l_next_obs=nprs(3).rand(64, 16).astype(np.float32),
                    l_done=nprs(4).randint(0, 2, size=(64,)).astype(np.float32),
                )
                if double:
                    test_name = "compute_double_q_learning_loss"
                    loss_fn = dqn.compute_double_q_learning_loss
                    tgt = np.array([1.9066928625106812], dtype=np.float32)
                else:
                    test_name = "compute_q_learning_loss"
                    loss_fn = dqn.compute_q_learning_loss
                    tgt = np.array([1.909377098083496], dtype=np.float32)
                actual_var = loss_fn(**test_args)
                assert isinstance(
                    actual_var, C.Variable), "%s should return a Chainer variable" % test_name
                try:
                    assert_allclose(tgt, actual_var.data)
                    print("Test for %s passed!" % test_name)
                except AssertionError as e:
                    # A mismatch is reported but the user may still continue.
                    print("Warning: test for %s didn't pass!" % test_name)
                    print(e)
                    input(
                        "** Test failed. Press Ctrl+C to exit or press enter to continue training anyways")

            if render:
                dqn.test(epsilon=0.0)
            else:
                # Train the agent!
                dqn.train()
    finally:
        # Close gym environment even if training fails or is interrupted.
        env.close()