Example #1
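# Imports this fragment appears to rely on (a sketch assuming the Tianshou
# 0.2.x-era API; training_config, policy, train_envs, save_path, env, and
# writer are defined elsewhere in the original script):
import os
import torch
from tianshou.data import Collector, PrioritizedReplayBuffer, ReplayBuffer
from tianshou.trainer import offpolicy_trainer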
# pick the replay buffer type from the training config
if training_config['prioritized_replay']:
    buf = PrioritizedReplayBuffer(training_config['buffer_size'],
                                  alpha=training_config['alpha'],
                                  beta=training_config['beta'])
else:
    buf = ReplayBuffer(training_config['buffer_size'])
# start with fully random exploration (epsilon = 1) for the warm-up phase
policy.set_eps(1)
train_collector = Collector(policy, train_envs, buf)
# seed the replay buffer with an initial transition before training starts
train_collector.collect(n_step=1)

def train_fn(epoch):
    # anneal epsilon linearly towards a floor of 0.05 over the exploration phase
    policy.set_eps(
        max(0.05, 1 - epoch / training_config['epoch'] /
            training_config['exploration_ratio']))
    # checkpoint the policy weights at every epoch
    torch.save(policy.state_dict(),
               os.path.join(save_path, 'policy_%d.pth' % epoch))

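# run the off-policy training loop, logging to the TensorBoard writer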
result = offpolicy_trainer(policy,
                           train_collector,
                           training_config['epoch'],
                           training_config['step_per_epoch'],
                           training_config['collect_per_step'],
                           training_config['batch_size'],
                           update_per_step=training_config['update_per_step'],
                           train_fn=train_fn,
                           writer=writer)

env.close()

def test_dqn(args=get_args()):
    env = LimitWrapper(StateBonus(ImgObsWrapper(gym.make(args.task))))
    # print(env.observation_space.spaces['image'])
    # args.state_shape = env.observation_space.spaces['image'].shape
    args.state_shape = env.observation_space.shape

    args.action_shape = env.env.action_space.shape or env.env.action_space.n

    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv([
        lambda: LimitWrapper(StateBonus(ImgObsWrapper(gym.make(args.task))))
        for _ in range(args.training_num)
    ])
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv([
        lambda: LimitWrapper(StateBonus(ImgObsWrapper(gym.make(args.task))))
        for _ in range(args.test_num)
    ])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = DQN(args.state_shape[0], args.state_shape[1], args.action_shape,
              args.device)
    net = net.to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = DQNPolicy(net,
                       optim,
                       args.gamma,
                       args.n_step,
                       use_target_network=args.target_update_freq > 0,
                       target_update_freq=args.target_update_freq)
    # collector
    train_collector = Collector(policy, train_envs,
                                ReplayBuffer(args.buffer_size))
    test_collector = Collector(policy, test_envs)
    # policy.set_eps(1)
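    # pre-fill the replay buffer with batch_size * 4 transitions before training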
    train_collector.collect(n_step=args.batch_size * 4)
    print(len(train_collector.buffer))
    # log
    writer = SummaryWriter(args.logdir + '/' + 'dqn')

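    # stop_fn never triggers early stopping; train_fn and test_fn only switch
    # the exploration epsilon between its training and evaluation values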
    def stop_fn(x):
        # if env.env.spec.reward_threshold:
        #     return x >= env.spec.reward_threshold
        # else:
        #     return False

        return False

    def train_fn(x):
        policy.set_eps(args.eps_train)

    def test_fn(x):
        policy.set_eps(args.eps_test)

    # trainer
    result = offpolicy_trainer(policy,
                               train_collector,
                               test_collector,
                               args.epoch,
                               args.step_per_epoch,
                               args.collect_per_step,
                               args.test_num,
                               args.batch_size,
                               train_fn=train_fn,
                               test_fn=test_fn,
                               stop_fn=stop_fn,
                               writer=writer,
                               task=args.task)

    train_collector.close()
    test_collector.close()
    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = LimitWrapper(StateBonus(ImgObsWrapper(gym.make(args.task))))
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
        collector.close()

        torch.save(policy.state_dict(), 'dqn.pth')
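        # A minimal sketch (an assumption, not part of the original example) of
        # how the weights saved above could be reloaded for later evaluation:
        # policy.load_state_dict(torch.load('dqn.pth', map_location=args.device))
        # policy.eval()
        # policy.set_eps(args.eps_test)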