示例#1
0
def test_ppo(args=get_args()):
    """Train a PPO agent on an Atari task and, in script mode, watch it play.

    Builds vectorised train/test environments, an actor and a critic that
    share one ``Net`` feature extractor, and runs ``onpolicy_trainer``.
    When this module is executed as a script, the trainer's result is
    printed and the trained policy is run for 2000 steps in a fresh
    environment.

    Args:
        args: Parsed option namespace (default: ``get_args()``). This
            function writes ``args.state_shape`` and ``args.action_shape``
            into it.

    Returns:
        None.
    """
    env = create_atari_environment(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    # NOTE(review): this wrapper's ``action_space`` is called as a method
    # rather than read as a gym-style attribute; confirm against
    # ``create_atari_environment``.
    action_space = env.action_space()
    args.action_shape = action_space.shape or action_space.n
    # Read the early-stopping threshold once, from the same object whose
    # presence is checked. A missing spec or threshold disables early
    # stopping, and it keeps ``stop_fn`` independent of later rebinding
    # of ``env``.
    spec = getattr(env.env, 'spec', None)
    reward_threshold = getattr(spec, 'reward_threshold', None)
    # The lambdas use no loop variable, so late binding is harmless here.
    train_envs = SubprocVectorEnv([
        lambda: create_atari_environment(args.task)
        for _ in range(args.training_num)])
    test_envs = SubprocVectorEnv([
        lambda: create_atari_environment(args.task)
        for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model: actor and critic are both built on the same ``net``.
    net = Net(args.layer_num, args.state_shape, device=args.device)
    actor = Actor(net, args.action_shape).to(args.device)
    critic = Critic(net).to(args.device)
    # The shared ``net`` parameters can be reached through both modules.
    # Deduplicate them while keeping their order, so the optimizer does not
    # hold (and step) the same tensor twice.
    params = list(dict.fromkeys([*actor.parameters(), *critic.parameters()]))
    optim = torch.optim.Adam(params, lr=args.lr)
    dist = torch.distributions.Categorical
    policy = PPOPolicy(
        actor, critic, optim, dist, args.gamma,
        max_grad_norm=args.max_grad_norm,
        eps_clip=args.eps_clip,
        vf_coef=args.vf_coef,
        ent_coef=args.ent_coef,
        action_range=None)
    # collector
    train_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size),
        preprocess_fn=preprocess_fn)
    test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
    # log
    writer = SummaryWriter(args.logdir + '/' + 'ppo')

    def stop_fn(x):
        """Return True once ``x`` reaches the env's reward threshold.

        ``x`` is presumably the mean test reward passed by the trainer.
        Always returns False when the env defines no threshold.
        """
        return reward_threshold is not None and x >= reward_threshold

    # trainer
    try:
        result = onpolicy_trainer(
            policy, train_collector, test_collector, args.epoch,
            args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
            args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
    finally:
        # Release the collectors and the log writer even if training raises.
        train_collector.close()
        test_collector.close()
        writer.close()
    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        eval_env = create_atari_environment(args.task)
        collector = Collector(policy, eval_env, preprocess_fn=preprocess_fn)
        try:
            result = collector.collect(n_step=2000, render=args.render)
            print(f'Final reward: {result["rew"]}, length: {result["len"]}')
        finally:
            collector.close()
示例#2
0
def test_dqn(args=get_args()):
    """Train a DQN agent on an Atari task and, in script mode, watch it play.

    Builds vectorised train/test environments and a ``DQN`` network, then
    pre-fills the replay buffer with ``4 * batch_size`` steps. It then runs
    ``offpolicy_trainer``, switching epsilon between ``args.eps_train`` and
    ``args.eps_test``. When this module is executed as a script, the
    trainer's result is printed and the trained policy plays one episode in
    a fresh environment.

    Args:
        args: Parsed option namespace (default: ``get_args()``). This
            function writes ``args.state_shape`` and ``args.action_shape``
            into it.

    Returns:
        None.
    """
    env = create_atari_environment(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    # The action space is read from the wrapped environment (``env.env``).
    action_space = env.env.action_space
    args.action_shape = action_space.shape or action_space.n
    # Read the early-stopping threshold once, from the same object whose
    # presence is checked. A missing spec or threshold disables early
    # stopping, and it keeps ``stop_fn`` independent of later rebinding
    # of ``env``.
    spec = getattr(env.env, 'spec', None)
    reward_threshold = getattr(spec, 'reward_threshold', None)
    # The lambdas use no loop variable, so late binding is harmless here.
    train_envs = SubprocVectorEnv([
        lambda: create_atari_environment(args.task)
        for _ in range(args.training_num)])
    test_envs = SubprocVectorEnv([
        lambda: create_atari_environment(args.task)
        for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = DQN(
        args.state_shape[0], args.state_shape[1],
        args.action_shape, args.device)
    net = net.to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = DQNPolicy(
        net, optim, args.gamma, args.n_step,
        target_update_freq=args.target_update_freq)
    # collector
    train_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size),
        preprocess_fn=preprocess_fn)
    test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
    # log
    writer = SummaryWriter(args.logdir + '/' + 'dqn')

    def stop_fn(x):
        """Return True once ``x`` reaches the env's reward threshold.

        ``x`` is presumably the mean test reward passed by the trainer.
        Always returns False when the env defines no threshold.
        """
        return reward_threshold is not None and x >= reward_threshold

    def train_fn(x):
        """Switch the policy to the training epsilon (``x`` is unused)."""
        policy.set_eps(args.eps_train)

    def test_fn(x):
        """Switch the policy to the test epsilon (``x`` is unused)."""
        policy.set_eps(args.eps_test)

    try:
        # Pre-fill the replay buffer before training.
        # NOTE(review): epsilon is not set explicitly here (an earlier
        # ``policy.set_eps(1)`` was left commented out), so this uses the
        # policy's current epsilon; confirm that is intended.
        train_collector.collect(n_step=args.batch_size * 4)
        print(len(train_collector.buffer))
        # trainer
        result = offpolicy_trainer(
            policy, train_collector, test_collector, args.epoch,
            args.step_per_epoch, args.collect_per_step, args.test_num,
            args.batch_size, train_fn=train_fn, test_fn=test_fn,
            stop_fn=stop_fn, writer=writer)
    finally:
        # Release the collectors and the log writer even if training raises.
        train_collector.close()
        test_collector.close()
        writer.close()

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        eval_env = create_atari_environment(args.task)
        collector = Collector(policy, eval_env, preprocess_fn=preprocess_fn)
        try:
            result = collector.collect(n_episode=1, render=args.render)
            print(f'Final reward: {result["rew"]}, length: {result["len"]}')
        finally:
            collector.close()