Example #1
import gym
import nnabla as nn
from nnabla.ext_utils import get_extension_context

# AtariWrapper, NoisyNetDQN, q_function, ReplayBuffer, ConstantEpsilonGreedy,
# prepare_monitor, update, evaluate and train are project code assumed to be
# imported from the surrounding repository
def main(args):
    if args.gpu:
        ctx = get_extension_context('cudnn', device_id=str(args.device))
        nn.set_default_context(ctx)

    # Atari environments for training and evaluation
    env = AtariWrapper(gym.make(args.env), args.seed, episodic=True)
    eval_env = AtariWrapper(gym.make(args.env), args.seed, episodic=False)
    num_actions = env.action_space.n

    # action-value function approximated with a neural network
    model = NoisyNetDQN(q_function, num_actions, args.batch_size, args.gamma,
                        args.lr)
    if args.load is not None:
        nn.load_parameters(args.load)
    model.update_target()

    buffer = ReplayBuffer(args.buffer_size, args.batch_size)

    # epsilon fixed at 0: NoisyNet explores through its parameter noise
    exploration = ConstantEpsilonGreedy(num_actions, 0.0)

    monitor = prepare_monitor(args.logdir)

    update_fn = update(model, buffer, args.target_update_interval)

    eval_fn = evaluate(eval_env, model, render=args.render)

    train(env, model, buffer, exploration, monitor, update_fn, eval_fn,
          args.final_step, args.update_start, args.update_interval,
          args.save_interval, args.evaluate_interval, ['loss'])
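
main() above only consumes an args namespace, so a minimal argparse setup can drive it. The flag names below are inferred from the attributes the function reads; the default values are placeholders, not the project's actual settings.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--env', type=str, default='BreakoutNoFrameskip-v4')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--gpu', action='store_true')
parser.add_argument('--device', type=int, default=0)
parser.add_argument('--load', type=str, default=None)
parser.add_argument('--render', action='store_true')
parser.add_argument('--logdir', type=str, default='logs')
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--lr', type=float, default=2.5e-4)
parser.add_argument('--buffer_size', type=int, default=10**5)
parser.add_argument('--target_update_interval', type=int, default=10**4)
parser.add_argument('--final_step', type=int, default=10**7)
parser.add_argument('--update_start', type=int, default=5 * 10**4)
parser.add_argument('--update_interval', type=int, default=4)
parser.add_argument('--save_interval', type=int, default=10**6)
parser.add_argument('--evaluate_interval', type=int, default=10**5)

main(parser.parse_args())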
Example #2
import gym
import nnabla as nn
from nnabla.ext_utils import get_extension_context

# SAC, ReplayBuffer, EmptyNoise, prepare_monitor, update, evaluate and
# train are project code assumed to be imported from the surrounding repo
def main(args):
    env = gym.make(args.env)
    env.seed(args.seed)
    eval_env = gym.make(args.env)
    eval_env.seed(50)
    action_shape = env.action_space.shape

    # GPU
    if args.gpu:
        ctx = get_extension_context('cudnn', device_id=str(args.device))
        nn.set_default_context(ctx)

    if args.load:
        nn.load_parameters(args.load)

    model = SAC(env.observation_space.shape, action_shape[0], args.batch_size,
                args.critic_lr, args.actor_lr, args.temp_lr, args.tau,
                args.gamma)
    # initialize target networks to match the online networks
    model.sync_target()

    buffer = ReplayBuffer(args.buffer_size, args.batch_size)

    monitor = prepare_monitor(args.logdir)

    update_fn = update(model, buffer)

    eval_fn = evaluate(eval_env, model, render=args.render)

    train(env, model, buffer, EmptyNoise(), monitor, update_fn, eval_fn,
          args.final_step, args.batch_size, 1, args.save_interval,
          args.evaluate_interval, ['critic_loss', 'actor_loss', 'temp_loss'])
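
update(model, buffer) returns the closure that train() calls every step; the real implementation is project code. A hypothetical sketch of the shape implied by the call sites, where buffer.sample() and model.train_on_batch() are assumed names:

def update(model, buffer):
    def update_fn(step):
        # draw a transition minibatch from the replay buffer
        batch = buffer.sample()
        # one gradient step each for critic, actor and temperature
        critic_loss, actor_loss, temp_loss = model.train_on_batch(batch)
        # the returned values are logged under the keys passed to train()
        return critic_loss, actor_loss, temp_loss
    return update_fn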
Example #3
import gym
import numpy as np
import nnabla as nn
from nnabla.ext_utils import get_extension_context

# TD3, NormalNoise, ReplayBuffer, prepare_monitor, update, evaluate and
# train are project code assumed to be imported from the surrounding repo
def main(args):
    env = gym.make(args.env)
    env.seed(args.seed)
    eval_env = gym.make(args.env)
    eval_env.seed(50)
    action_shape = env.action_space.shape

    # GPU
    if args.gpu:
        ctx = get_extension_context('cudnn', device_id=str(args.device))
        nn.set_default_context(ctx)

    if args.load:
        nn.load_parameters(args.load)

    model = TD3(env.observation_space.shape, action_shape[0], args.batch_size,
                args.critic_lr, args.actor_lr, args.tau, args.gamma,
                args.target_reg_sigma, args.target_reg_clip)
    model.sync_target()

    # Gaussian exploration noise; the scalar sigma is broadcast to a
    # per-dimension vector (see the note after this example)
    noise = NormalNoise(np.zeros(action_shape),
                        args.exploration_sigma + np.zeros(action_shape))

    buffer = ReplayBuffer(args.buffer_size, args.batch_size)

    monitor = prepare_monitor(args.logdir)

    update_fn = update(model, buffer, args.update_actor_freq)

    eval_fn = evaluate(eval_env, model, render=args.render)

    train(env, model, buffer, noise, monitor, update_fn, eval_fn,
          args.final_step, args.batch_size, 1, args.save_interval,
          args.evaluate_interval, ['critic_loss', 'actor_loss'])
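
The NormalNoise construction above relies on NumPy broadcasting: adding the scalar args.exploration_sigma to np.zeros(action_shape) yields a per-dimension sigma vector. A tiny self-contained check of that equivalence (shapes and values here are illustrative only):

import numpy as np

action_shape = (3,)
sigma = 0.1
via_broadcast = sigma + np.zeros(action_shape)  # broadcast scalar to vector
via_full = np.full(action_shape, sigma)         # same result, stated directly
assert np.array_equal(via_broadcast, via_full)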
Example #4
import gym
import nnabla as nn
from nnabla.ext_utils import get_extension_context

# AtariWrapper, BatchEnv, A2C, learning_rate_scheduler, prepare_directory,
# evaluate, compute_returns and train_loop are project code assumed to be
# imported from the surrounding repository
def main(args):
    if args.gpu:
        ctx = get_extension_context('cudnn', device_id=str(args.device))
        nn.set_default_context(ctx)

    # batched Atari environments for training, a single env for evaluation
    num_envs = args.num_envs
    envs = [gym.make(args.env) for _ in range(num_envs)]
    batch_env = BatchEnv([AtariWrapper(env, args.seed) for env in envs])
    eval_env = AtariWrapper(gym.make(args.env), 50, episodic=False)
    num_actions = envs[0].action_space.n

    # actor-critic model (policy and value function) built with a neural network
    lr_scheduler = learning_rate_scheduler(args.lr, 10**7)
    model = A2C(num_actions, num_envs, num_envs * args.time_horizon,
                args.v_coeff, args.ent_coeff, lr_scheduler)
    if args.load is not None:
        nn.load_parameters(args.load)

    logdir = prepare_directory(args.logdir)

    eval_fn = evaluate(eval_env, model, args.render)

    # start training loop
    return_fn = compute_returns(args.gamma)
    train_loop(batch_env, model, num_actions, return_fn, logdir, eval_fn, args)
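
BatchEnv is project code, not part of gym. A hypothetical sketch of the minimal lockstep interface the A2C rollout above implies; the method names, the stacked return types and the auto-reset behaviour are all assumptions:

import numpy as np

class BatchEnv:
    def __init__(self, envs):
        self.envs = envs

    def reset(self):
        # reset every sub-environment and stack the observations
        return np.stack([env.reset() for env in self.envs])

    def step(self, actions):
        obs, rewards, dones = [], [], []
        for env, action in zip(self.envs, actions):
            o, r, done, _ = env.step(action)
            if done:
                o = env.reset()  # start a new episode in finished envs
            obs.append(o)
            rewards.append(r)
            dones.append(done)
        return np.stack(obs), np.array(rewards), np.array(dones)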