Example #1
def train(env_id, num_timesteps, seed):
    """
    Train PPO1 model for Atari environments, for testing purposes

    :param env_id: (str) Environment ID
    :param num_timesteps: (int) The total number of samples
    :param seed: (int) The initial seed for training
    """
    rank = MPI.COMM_WORLD.Get_rank()

    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])
    workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
    set_global_seeds(workerseed)
    env = make_atari(env_id)

    env = bench.Monitor(env, logger.get_dir() and
                        os.path.join(logger.get_dir(), str(rank)))
    env.seed(workerseed)

    env = wrap_deepmind(env)
    env.seed(workerseed)

    model = PPO1(CnnPolicy, env, timesteps_per_actorbatch=256, clip_param=0.2, entcoeff=0.01, optim_epochs=4,
                 optim_stepsize=1e-3, optim_batchsize=64, gamma=0.99, lam=0.95, schedule='linear', verbose=2)
    model.learn(total_timesteps=num_timesteps)
    env.close()
    del env
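A minimal sketch of how this helper might be invoked (hypothetical entry point; it assumes the imports used inside train(), e.g. stable_baselines, mpi4py and the Atari wrappers, are already available):

if __name__ == '__main__':
    # launch several MPI workers, e.g.: mpirun -np 4 python run_atari_ppo1.py
    train(env_id='BreakoutNoFrameskip-v4', num_timesteps=int(1e6), seed=0)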
Example #2
def test_deepq():
    """
    test DeepQ on atari
    """
    logger.configure()
    set_global_seeds(SEED)
    env = make_atari(ENV_ID)
    env = bench.Monitor(env, logger.get_dir())
    env = wrap_atari_dqn(env)
    q_func = deepq_models.cnn_to_mlp(convs=[(32, 8, 4), (64, 4, 2),
                                            (64, 3, 1)],
                                     hiddens=[256],
                                     dueling=True)

    model = DeepQ(env=env,
                  policy=q_func,
                  learning_rate=1e-4,
                  buffer_size=10000,
                  exploration_fraction=0.1,
                  exploration_final_eps=0.01,
                  train_freq=4,
                  learning_starts=10000,
                  target_network_update_freq=1000,
                  gamma=0.99,
                  prioritized_replay=True,
                  prioritized_replay_alpha=0.6,
                  checkpoint_freq=10000)
    model.learn(total_timesteps=NUM_TIMESTEPS)

    env.close()
    del model, env
Example #3
def test_deepq():
    """
    test DeepQ on atari
    """
    logger.configure()
    set_global_seeds(SEED)
    env = make_atari(ENV_ID)
    env = bench.Monitor(env, logger.get_dir())
    env = wrap_atari_dqn(env)

    model = DQN(env=env,
                policy=CnnPolicy,
                learning_rate=1e-4,
                buffer_size=10000,
                exploration_fraction=0.1,
                exploration_final_eps=0.01,
                train_freq=4,
                learning_starts=10000,
                target_network_update_freq=1000,
                gamma=0.99,
                prioritized_replay=True,
                prioritized_replay_alpha=0.6,
                checkpoint_freq=10000)
    model.learn(total_timesteps=NUM_TIMESTEPS)

    env.close()
    del model, env
Example #4
    def _thunk():
        env = make_atari(env_id)
        env = gym.wrappers.Monitor(env, '/tmp/video', force=True, video_callable=lambda ep: True)
        env.seed(seed + rank)
        env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)),
                      allow_early_resets=allow_early_resets)
        return wrap_deepmind(env, **wrapper_kwargs)
Example #5
    def _thunk():
        # random_seed(seed)
        if env_id.startswith("dm"):
            import dm_control2gym
            _, domain, task = env_id.split('-')
            env = dm_control2gym.make(domain_name=domain, task_name=task)
        else:
            env = gym.make(env_id)
        is_atari = hasattr(gym.envs, 'atari') and isinstance(
            env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
        if is_atari:
            env = make_atari(env_id)
        # env.seed(seed + rank)
        env = OriginalReturnWrapper(env)
        if is_atari:
            env = wrap_deepmind(env,
                                episode_life=episode_life,
                                clip_rewards=False,
                                frame_stack=False,
                                scale=False)
            obs_shape = env.observation_space.shape
            if len(obs_shape) == 3:
                env = TransposeImage(env)
            env = FrameStack(env, 4)

        return env
Example #6
    def _thunk():
        env = make_atari(env_id)
        env.seed(seed + rank)
        env = Monitor(env,
                      logger.get_dir()
                      and os.path.join(logger.get_dir(), str(rank)),
                      allow_early_resets=allow_early_resets)
        return wrap_deepmind(env, **wrapper_kwargs)
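Factories like this _thunk are typically handed to a vectorized environment so that each worker builds its own monitored, DeepMind-wrapped Atari env. A minimal sketch, assuming a hypothetical make_env(env_id, seed, rank) wrapper that returns the _thunk above:

from stable_baselines.common.vec_env import SubprocVecEnv

num_env = 4
# one subprocess per environment copy
vec_env = SubprocVecEnv([make_env(env_id, seed, rank) for rank in range(num_env)])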
Example #7
def make_env():

	env = wrap_deepmind(make_atari("PongNoFrameskip-v4"))
	workerseed = MPI.COMM_WORLD.Get_rank()*10000
	env.seed(workerseed)

	env = single_agent_wrapper(env)
	return env
	
Example #8
    def _init():
        env = make_atari(env_id)
        # if 'BoxWorld' in env_id:
        #     print('using wrap_boxworld!')
        #     env = wrap_boxworld(env, episode_life=False, clip_rewards=False, frame_stack=False, scale=False)
        # else:
        #     env = wrap_deepmind(env, episode_life=False, clip_rewards=False, frame_stack=False, scale=False)
        if useMonitor:
            env = Monitor(env, log_dir + str(rank), allow_early_resets=True)
        env.seed(seed + rank)
        return env
Example #9
def train_trpo(env_id, num_timesteps, seed, policy):

    # env_id: type str, identifies each environment uniquely
    # num_timesteps: number of timesteps to run the algorithm
    # seed: initial random seed
    # policy: key of the policy to use ('cnn', 'lstm', 'lnlstm' or 'mlp')

    # set up the environment
    rank = MPI.COMM_WORLD.Get_rank()
    sseed = seed + 10000 * rank
    set_global_seeds(sseed)
    env = make_atari(env_id)
    env.seed(sseed)
    env = wrap_deepmind(env)
    env.seed(sseed)
    # define policies
    policy = {
        'cnn': CnnPolicy,
        'lstm': CnnLstmPolicy,
        'lnlstm': CnnLnLstmPolicy,
        'mlp': MlpPolicy
    }[policy]
    # define TRPO class object
    model = TRPO(policy=policy,
                 env=env,
                 timesteps_per_batch=1024,
                 max_kl=0.01,
                 cg_iters=10,
                 cg_damping=1e-3,
                 entcoeff=0.0,
                 gamma=0.99,
                 lam=1,
                 vf_iters=3,
                 vf_stepsize=1e-4,
                 verbose=1)
    # Train TRPO for num_timesteps
    model.learn(total_timesteps=num_timesteps)
    # save the hyperparameters and weights
    model.save('trpo' + env_id)
    env.close()
    # free the memory
    del model
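The saved agent can later be restored with the matching load() call; a minimal sketch, assuming the same imports as in train_trpo (the file name simply mirrors the model.save() call above):

# rebuild a wrapped Atari env and reload the trained TRPO agent
env = wrap_deepmind(make_atari(env_id))
model = TRPO.load('trpo' + env_id, env=env)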
Example #10
def make_env():

	# create pong environment and use wrappers from stable baselines
	env = wrap_deepmind(make_atari("PongNoFrameskip-v4"))
	workerseed = MPI.COMM_WORLD.Get_rank()*10000
	env.seed(workerseed)

	# convert standard gym interface to multiagent interface expected by ai arena
	env = single_agent_wrapper(env)
	return env
	
	
Example #11
    def _thunk():
        env = make_atari(env_id)
        env.seed(seed + rank)
        # env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)),
        #               allow_early_resets=allow_early_resets)
        if logdir is not None:
            env = Monitor(env, os.path.join(logdir, str(rank)), allow_early_resets=allow_early_resets)
        env = wrap_deepmind(env, **wrapper_kwargs)
        if extra_wrapper_func is not None:
            return extra_wrapper_func(env)
        else:
            return env
Example #12
def train_ppo(env_id,
              num_timesteps,
              seed,
              policy,
              save_params,
              n_envs=1,
              nminibatches=5,
              n_steps=8000):
    """
     env_id: typr str, identifies each environment uniquely
     num_timesteps: number of timesteps to run the algorithm
     seed: initial random seed
     policy: policy to be followed (mlp, cnn, lstm, etc)
     n_env: number of envs to run in parallel
     nminibatches: number of minibatches of mini batch gradient descent (first-order optimization) to update the policy params
     n_steps: number of steps in each update
    """
    # Train PPO algorithm for num_timesteps
    # stack the frames for the vectorized environment
    # Note: PPO2 works only with vectorized environment

    set_global_seeds(seed)
    env = make_atari(env_id)
    env.seed(seed)
    env = Monitor(env, log_dir, allow_early_resets=True)
    env = wrap_deepmind(env, frame_stack=True)
    # define the policy
    policy = {
        'cnn': CnnPolicy,
        'lstm': CnnLstmPolicy,
        'lnlstm': CnnLnLstmPolicy,
        'mlp': MlpPolicy
    }[policy]
    # create model object for class PPO2
    model = PPO2(policy=policy,
                 env=env,
                 n_steps=n_steps,
                 nminibatches=nminibatches,
                 lam=0.95,
                 gamma=0.99,
                 noptepochs=4,
                 ent_coef=.01,
                 learning_rate=lambda f: f * 2.5e-4,
                 cliprange=lambda f: f * 0.1,
                 verbose=1)
    # train the model
    # trained for 2e7 timesteps with seed = 5
    model.learn(total_timesteps=num_timesteps, callback=callback)
    # save the hyperparameters and weights
    model.save(save_params)
    env.close()
    # free the memory
    del model
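PPO2 expects a vectorized environment (see the note in the comments above); stable-baselines will wrap a single gym env automatically, but the wrapping can also be made explicit. A minimal sketch, assuming the same imports as in train_ppo:

from stable_baselines.common.vec_env import DummyVecEnv

atari_env = wrap_deepmind(make_atari(env_id), frame_stack=True)
vec_env = DummyVecEnv([lambda: atari_env])  # single-process vectorized wrapper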
Example #13
def setup_env(env_name, train=True):
    if args.env == "CartPole-v0":
        env = gym.make(env_name)
    else:
        env = make_atari(env_name)
        if train:
            env = wrap_deepmind(env, episode_life=True, clip_rewards=False,
                                frame_stack=True, scale=True)    
        else:
            env = wrap_deepmind(env, episode_life=False, clip_rewards=False,
                                frame_stack=True, scale=True)    

    return env
Example #14
def create_env(args, idx):
    """
    Create and return an environment according to args (parsed arguments).
    idx specifies idx of this environment among parallel environments.
    """
    monitor_file = os.path.join(args.output, ("env_%d" % idx))

    # Check for Atari envs
    if "NoFrameskip" in args.env:
        env = make_atari(args.env)
        env = wrap_deepmind(env, frame_stack=True)
    else:
        env = gym.make(args.env)
    env = Monitor(env, monitor_file)

    return env
Example #15
def main():
    """
    run the atari test
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--prioritized', type=int, default=1)
    parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6)
    parser.add_argument('--dueling', type=int, default=1)
    parser.add_argument('--num-timesteps', type=int, default=int(10e6))
    parser.add_argument('--checkpoint-freq', type=int, default=10000)
    parser.add_argument('--checkpoint-path', type=str, default=None)

    args = parser.parse_args()
    logger.configure()
    set_global_seeds(args.seed)
    env = make_atari(args.env)
    env = bench.Monitor(env, logger.get_dir())
    env = wrap_atari_dqn(env)
    q_func = deepq_models.cnn_to_mlp(
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=bool(args.dueling),
    )

    model = DeepQ(
        env=env,
        policy=q_func,
        learning_rate=1e-4,
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
        prioritized_replay=bool(args.prioritized),
        prioritized_replay_alpha=args.prioritized_replay_alpha,
        checkpoint_freq=args.checkpoint_freq,
        checkpoint_path=args.checkpoint_path,
    )
    model.learn(total_timesteps=args.num_timesteps)

    env.close()
Example #16
def setup_env(env_name, train=True):
    if env_name in ["CartPole-v0", "SpaceInvaders-ram-v0"]:
        env = gym.make(env_name)
    else:
        env = make_atari(env_name)
        if train:
            env = wrap_deepmind(env,
                                episode_life=True,
                                clip_rewards=False,
                                frame_stack=True,
                                scale=True)
        else:
            env = wrap_deepmind(env,
                                episode_life=False,
                                clip_rewards=False,
                                frame_stack=True,
                                scale=True)

    return env
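A minimal usage sketch for this helper (environment name chosen for illustration):

train_env = setup_env("BreakoutNoFrameskip-v4", train=True)   # keep the episodic-life wrapper during training
eval_env = setup_env("BreakoutNoFrameskip-v4", train=False)   # full episodes for evaluation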
Example #17
def train(env_id, num_timesteps, seed):
    """
    Train TRPO model for the atari environment, for testing purposes

    :param env_id: (str) Environment ID
    :param num_timesteps: (int) The total number of samples
    :param seed: (int) The initial seed for training
    """
    rank = MPI.COMM_WORLD.Get_rank()

    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])

    workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
    set_global_seeds(workerseed)
    env = make_atari(env_id)

    # def policy_fn(name, ob_space, ac_space, sess=None, placeholders=None):  # pylint: disable=W0613
    #     return CnnPolicy(name=name, ob_space=ob_space, ac_space=ac_space, sess=sess, placeholders=placeholders)

    env = bench.Monitor(
        env,
        logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
    env.seed(workerseed)

    env = wrap_deepmind(env)
    env.seed(workerseed)

    model = TRPO(CnnPolicy,
                 env,
                 timesteps_per_batch=512,
                 max_kl=0.001,
                 cg_iters=10,
                 cg_damping=1e-3,
                 entcoeff=0.0,
                 gamma=0.98,
                 lam=1,
                 vf_iters=3,
                 vf_stepsize=1e-4)
    model.learn(total_timesteps=int(num_timesteps * 1.1))
    env.close()
Example #18
def main():
    """
    Run the atari test
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env',
                        help='environment ID',
                        default='BreakoutNoFrameskip-v4')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--prioritized', type=int, default=1)
    parser.add_argument('--dueling', type=int, default=1)
    parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6)
    parser.add_argument('--num-timesteps', type=int, default=int(10e6))

    args = parser.parse_args()
    logger.configure()
    set_global_seeds(args.seed)
    env = make_atari(args.env)
    env = bench.Monitor(env, logger.get_dir())
    env = wrap_atari_dqn(env)
    policy = partial(CnnPolicy, dueling=args.dueling == 1)

    model = DQN(
        env=env,
        policy=policy,
        learning_rate=1e-4,
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
        prioritized_replay=bool(args.prioritized),
        prioritized_replay_alpha=args.prioritized_replay_alpha,
    )
    model.learn(total_timesteps=args.num_timesteps)

    env.close()
Example #19
def train_dqn_adv(env_id, train_timesteps, seed, policy, save_params, n_envs=1):
    set_global_seeds(seed)
    env = make_atari(env_id)
    env.seed(seed)
    env = Monitor(env, log_dir, allow_early_resets=True)
    env = wrap_deepmind(env, frame_stack=True)
    # define the policy
    policy = {'cnn': CnnPolicy, 'mlp': MlpPolicy}[policy]
    # create model object for class DQN
    model = DQN(policy=policy, env=env, gamma=0.99, learning_rate=0.0001, buffer_size=10000,
                exploration_fraction=0.1, exploration_final_eps=0.01, exploration_initial_eps=1.0,
                train_freq=4, batch_size=32, double_q=True, learning_starts=10000,
                target_network_update_freq=1000, prioritized_replay=True, prioritized_replay_alpha=0.6,
                prioritized_replay_beta0=0.4, prioritized_replay_beta_iters=None,
                prioritized_replay_eps=1e-06, param_noise=False, n_cpu_tf_sess=None, verbose=1)
    callback = save_best_model_callback(save_freq=100, log_dir=log_dir, save_params=save_params, verbose=1)
    # train the model
    # trained for 2e7 timesteps with seed = 7
    model.learn(total_timesteps=train_timesteps, callback=callback)
    plot_results([log_dir], train_timesteps, results_plotter.X_TIMESTEPS, "DQNPong_TrainedByAdversary")
    plt.show()
    env.close()
    # free the memory
    del model
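After training, the model stored by the callback can be reloaded for evaluation; a minimal sketch, assuming save_params is the path used by save_best_model_callback and that the same wrappers as in train_dqn_adv are applied:

eval_env = wrap_deepmind(make_atari(env_id), frame_stack=True)
model = DQN.load(save_params)
obs = eval_env.reset()
for _ in range(1000):
    # greedy action from the loaded Q-network
    action, _ = model.predict(obs)
    obs, reward, done, info = eval_env.step(action)
    if done:
        obs = eval_env.reset()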
Example #20
    def _thunk():
        if env_id.startswith("dm"):
            _, domain, task = env_id.split('.')
            env = dm_control2gym.make(domain_name=domain, task_name=task)
        else:
            env = gym.make(env_id)

        is_atari = hasattr(gym.envs, 'atari') and isinstance(
            env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
        if is_atari:
            env = make_atari(env_id)

        env.seed(seed + rank)

        if str(env.__class__.__name__).find('TimeLimit') >= 0:
            env = TimeLimitMask(env)

        if log_dir is not None:
            env = bench.Monitor(env,
                                os.path.join(log_dir, str(rank)),
                                allow_early_resets=allow_early_resets)

        if is_atari:
            if len(env.observation_space.shape) == 3:
                env = wrap_deepmind(env)
        elif len(env.observation_space.shape) == 3:
            raise NotImplementedError(
                "CNN models work only for atari,\n"
                "please use a custom wrapper for a custom pixel input env.\n"
                "See wrap_deepmind for an example.")

        # If the input has shape (W,H,3), wrap for PyTorch convolutions
        obs_shape = env.observation_space.shape
        if len(obs_shape) == 3 and obs_shape[2] in [1, 3]:
            env = TransposeImage(env, op=[2, 0, 1])

        return env
Example #21
    def _thunk():
        env = make_atari(env_id)
        env.seed(seed + rank)
        env = Monitor(env, os.path.join(logdir, '{:03d}.monitor.csv'.format(rank)),
                      allow_early_resets=allow_early_resets)
        return wrap_deepmind(env, **wrapper_kwargs)
Example #22
    def _init():
        env = make_atari(env_id)
        # env = VecFrameStack(env, n_stack=4)
        env.seed(seed + rank)
        return env
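The commented-out VecFrameStack line hints at frame stacking, but VecFrameStack operates on a vectorized env rather than on a single gym env, so it belongs outside the factory. A minimal sketch, assuming a hypothetical make_env_fn(env_id, seed, rank) wrapper that returns the _init above:

from stable_baselines.common.vec_env import DummyVecEnv, VecFrameStack

vec_env = DummyVecEnv([make_env_fn(env_id, seed, rank=0)])
vec_env = VecFrameStack(vec_env, n_stack=4)  # stack the last 4 frames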
Example #23
from stable_baselines.common.atari_wrappers import make_atari
from stable_baselines.deepq.policies import CnnPolicy
# from stable_baselines.common.policies import CnnPolicy
from stable_baselines.common.vec_env import VecFrameStack
from stable_baselines import DQN

env = make_atari('BreakoutNoFrameskip-v4')
# env = VecFrameStack(env, n_stack=4)

model = DQN(CnnPolicy, env, verbose=1)
model.learn(total_timesteps=25000)
model.save("dqn_breakout")

obs = env.reset()

while True:
    action, _states = model.predict(obs)
    obs, rewards, done, info = env.step(action)
    env.render()

'''
from stable_baselines.common.cmd_util import make_atari_env
# from stable_baselines.deepq.policies import CnnPolicy
from stable_baselines.common.policies import CnnPolicy
from stable_baselines.common.vec_env import VecFrameStack
from stable_baselines import ACER

env = make_atari_env('BreakoutNoFrameskip-v4', num_env=4, seed=0)
env = VecFrameStack(env, n_stack=4)

model = ACER(CnnPolicy, env, verbose=1)
Example #24
    def _init():
        env = make_atari(env_id)
        env = Monitor(env, log_dir + str(rank), allow_early_resets=True)
        env.seed(seed + rank)
        return env
Example #25
import gym
from wrap_env import wrap_env
import numpy as np
import gym_boxworld
import matplotlib.pyplot as plt
from stable_baselines.common.atari_wrappers import make_atari, wrap_deepmind, wrap_boxworld

env_name = 'Breakout'
env_id = env_name + 'NoFrameskip-v4'
print(env_id)
env = make_atari(env_id)
# if 'BoxWorld' in env_id:
#     print('using wrap_boxworld!')
#     env = wrap_boxworld(env, episode_life=False, clip_rewards=True, frame_stack=True, scale=True)
# else:
#     env = wrap_deepmind(env, episode_life=True, clip_rewards=False, frame_stack=True, scale=False)

observation = env.reset()
print(env.action_space)
i = 0
for i in range(10000):
    observation, reward, done, info = env.step(env.action_space.sample())
    # if done:
    #     observation = env.reset()
    print(reward, done, info)
    # observation = np.array(observation)
    # print(observation.shape)
    # img = observation[:, :, 0]
    # fig = plt.figure(2)
    # plt.clf()
    # plt.imshow(img)
Example #26
def main():
    """
    Run the atari test
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env',
                        help='environment ID',
                        default='BreakoutNoFrameskip-v4')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--prioritized', type=int, default=0)
    parser.add_argument('--dueling', type=int, default=0)
    parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6)
    parser.add_argument('--num-timesteps', type=int, default=int(10e6))
    parser.add_argument('--checkpoint-freq', type=int, default=10000)
    parser.add_argument('--checkpoint-path', type=str, default=None)
    parser.add_argument('--kappa', default=1.0, type=float, help="Kappa value")
    parser.add_argument('--log',
                        default=False,
                        type=bool,
                        help="True if you wanna log progress, else False")
    parser.add_argument('--run', default=0, type=int, help="which run?")
    parser.add_argument('--phi_grad_update_freq',
                        default=1,
                        type=int,
                        help="gradient scaling")

    args = parser.parse_args()
    if not args.log:
        logger.configure(folder='./experiments/Atari/' + str(args.env) +
                         '/final_nkappa_cfa0.001_new/' + str(args.kappa) +
                         '_' + str(args.run) + '_' +
                         str(args.phi_grad_update_freq),
                         format_strs=["csv"])
        checkpoint_path = "./experiments/Atari/" + str(
            args.env) + "/models/" + str(args.kappa) + "_" + str(
                args.run) + '_xnew' + str(args.phi_grad_update_freq) + '.pkl'
    else:
        logger.configure()
        checkpoint_path = None

    set_global_seeds(args.run)
    env = make_atari(args.env)
    env = bench.Monitor(env, logger.get_dir())
    env = wrap_atari_dqn(env)
    policy = partial(CnnPolicy, dueling=args.dueling == 1)

    #test_env = make_atari(args.env)
    #lives = test_env.env.ale.lives()
    #if lives == 0:
    #    lives = 1
    #test_env = bench.Monitor(test_env, None)
    #test_env = wrap_atari_dqn(test_env)
    test_env = None
    lives = 1

    model = DQN(
        env=env,
        test_env=test_env,
        policy=policy,
        learning_rate=1e-4,  #0.00025, #1e-4,
        buffer_size=1e5,  #1e6, #1e5,
        exploration_fraction=0.1,
        exploration_final_eps=0.1,  #0.01
        train_freq=4,
        learning_starts=10000,  #50000,
        target_network_update_freq=1000,  #10000,
        gamma=0.99,
        kappa=args.kappa,
        verbose=1,
        checkpoint_freq=args.checkpoint_freq,
        checkpoint_path=checkpoint_path,
        phi_grad_update_freq=args.phi_grad_update_freq,
        seed=args.run,
        eval_episodes=lives)
    #model = DQN.load(args.env+"/models/"+str(args.kappa)+"_0_xnew1.pkl", env, test_env=test_env, checkpoint_path=checkpoint_path, eval_episodes=lives, kappa=args.kappa)

    model.learn(total_timesteps=args.num_timesteps)

    env.close()
Example #27
def setup_wandb(args):
    config = dict(env=args.env, max_frames=args.max_frames)
    wandb.init(project='rlmp',
               notes='Random Agent',
               tags=['Random'],
               config=config)


if __name__ == "__main__":

    args = get_args()
    setup_wandb(args)
    video_path = 'tmp/video/{}'.format(wandb.run.id)

    env = make_atari(args.env)
    env = wrap_deepmind(env)
    env = wrappers.Monitor(gym.make(args.env),
                           video_path,
                           video_callable=lambda x: x % 20 == 0)

    # Configure display
    virtual_display = Display(visible=0, size=(320, 240))
    virtual_display.start()

    num_frames = 0
    while num_frames < args.max_frames:

        state = env.reset()
        done = False
        ep_reward = 0