Example #1
def test_ddpg(args=get_args()):
    env = gym.make(args.task)
    if args.task == 'Pendulum-v0':
        env.spec.reward_threshold = -250
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    # train_envs = gym.make(args.task)
    train_envs = VectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    actor = Actor(
        args.layer_num, args.state_shape, args.action_shape,
        args.max_action, args.device
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
    critic = Critic(
        args.layer_num, args.state_shape, args.action_shape, args.device
    ).to(args.device)
    critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
    policy = DDPGPolicy(
        actor, actor_optim, critic, critic_optim,
        args.tau, args.gamma, args.exploration_noise,
        [env.action_space.low[0], env.action_space.high[0]],
        reward_normalization=True, ignore_done=True)
    # collector
    train_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size))
    test_collector = Collector(policy, test_envs)
    # log
    writer = SummaryWriter(args.logdir + '/' + 'ddpg')

    def stop_fn(x):
        return x >= env.spec.reward_threshold

    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
    assert stop_fn(result['best_reward'])
    train_collector.close()
    test_collector.close()
    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
        collector.close()
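The excerpt above assumes a set of imports that are not shown. A minimal sketch of what they would typically look like for the older Tianshou release this code targets; the module paths are assumptions and may differ between versions, and get_args, Actor and Critic are helper definitions that ship next to the original test rather than library APIs:

import pprint

import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, ReplayBuffer
from tianshou.env import SubprocVectorEnv, VectorEnv
from tianshou.policy import DDPGPolicy
from tianshou.trainer import offpolicy_trainer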
Example #2
def test_env_obs_dtype():
    for obs_type in ["array", "object"]:
        envs = SubprocVectorEnv(
            [lambda i=x: NXEnv(i, obs_type) for x in [5, 10, 15, 20]])
        obs = envs.reset()
        assert obs.dtype == object
        obs = envs.step([1, 1, 1, 1])[0]
        assert obs.dtype == object
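The `lambda i=x:` default argument above is not cosmetic: a plain closure in a list comprehension would capture the variable x itself, so every factory would build an env with the last value. A small self-contained illustration of the difference:

# Late binding: every lambda sees the final value of x.
late = [lambda: x for x in [5, 10, 15, 20]]
print([f() for f in late])    # [20, 20, 20, 20]

# Default-argument binding: each lambda keeps the value it was built with.
bound = [lambda i=x: i for x in [5, 10, 15, 20]]
print([f() for f in bound])   # [5, 10, 15, 20]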
Example #3
    def __init__(self, args, mask=None, name=None):
        env = gym.make(args.task)
        if args.task == 'Pendulum-v0':
            env.spec.reward_threshold = -250
        self.state_shape = env.observation_space.shape or env.observation_space.n
        self.action_shape = env.action_space.shape or env.action_space.n
        self.max_action = env.action_space.high[0]

        self.stop_fn = lambda x: x >= env.spec.reward_threshold

        # env
        self.train_envs = VectorEnv(
            [lambda: gym.make(args.task) for _ in range(args.training_num)])
        self.test_envs = SubprocVectorEnv(
            [lambda: gym.make(args.task) for _ in range(args.test_num)])

        state_dim = int(np.prod(self.state_shape))
        self._view_mask = torch.ones(state_dim)
        if mask == 'even':
            for i in range(0, state_dim, 2):
                self._view_mask[i] = 0
        elif mask == "odd":
            for i in range(1, state_dim, 2):
                self._view_mask[i] = 0

        # policy
        self.actor = ActorWithView(args.layer_num, self.state_shape,
                                   self.action_shape, self.max_action,
                                   self._view_mask,
                                   args.device).to(args.device)
        self.actor_optim = torch.optim.Adam(self.actor.parameters(),
                                            lr=args.actor_lr)
        self.critic = CriticWithView(args.layer_num, self.state_shape,
                                     self._view_mask, self.action_shape,
                                     args.device).to(args.device)
        self.critic_optim = torch.optim.Adam(self.critic.parameters(),
                                             lr=args.critic_lr)
        self.policy = DDPGPolicy(
            self.actor,
            self.actor_optim,
            self.critic,
            self.critic_optim,
            args.tau,
            args.gamma,
            args.exploration_noise,
            [env.action_space.low[0], env.action_space.high[0]],
            reward_normalization=True,
            ignore_done=True)

        # collector
        self.train_collector = Collector(self.policy, self.train_envs,
                                         ReplayBuffer(args.buffer_size))
        self.test_collector = Collector(self.policy, self.test_envs)

        # log
        self.writer = SummaryWriter(
            f"{args.logdir}/{args.task}/ddpg/{args.note}/{name}")
Example #4
def watch():
    print("Testing agent ...")
    policy.eval()
    policy.set_eps(args.eps_test)
    envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
                             for _ in range(args.test_num)])
    envs.seed(args.seed)
    collector = Collector(policy, envs)
    result = collector.collect(n_episode=args.test_num, render=args.render)
    pprint.pprint(result)
Example #5
def train(hyper: dict):
    env_id = 'CartPole-v1'
    env = gym.make(env_id)
    hyper['state_dim'] = 4
    hyper['action_dim'] = 2

    train_envs = VectorEnv([lambda: gym.make(env_id) for _ in range(hyper['training_num'])])
    test_envs = SubprocVectorEnv([lambda: gym.make(env_id) for _ in range(hyper['test_num'])])

    if hyper['seed']:
        np.random.seed(hyper['random_seed'])
        torch.manual_seed(hyper['random_seed'])
        train_envs.seed(hyper['random_seed'])
        test_envs.seed(hyper['random_seed'])

    device = Pytorch.device()

    net = Net(hyper['layer_num'], hyper['state_dim'], device=device)
    actor = Actor(net, hyper['action_dim']).to(device)
    critic = Critic(net).to(device)
    optim = torch.optim.Adam(list(
        actor.parameters()) + list(critic.parameters()), lr=hyper['learning_rate'])
    dist = torch.distributions.Categorical
    policy = A2CPolicy(
        actor, critic, optim, dist, hyper['gamma'], vf_coef=hyper['vf_coef'],
        ent_coef=hyper['ent_coef'], max_grad_norm=hyper['max_grad_norm'])
    # collector
    train_collector = Collector(
        policy, train_envs, ReplayBuffer(hyper['capacity']))
    test_collector = Collector(policy, test_envs)

    writer = SummaryWriter('./a2c')

    def stop_fn(x):
        if env.spec.reward_threshold:
            return x >= env.spec.reward_threshold
        else:
            return False

    result = onpolicy_trainer(
        policy, train_collector, test_collector, hyper['epoch'],
        hyper['step_per_epoch'], hyper['collect_per_step'], hyper['repeat_per_collect'],
        hyper['test_num'], hyper['batch_size'], stop_fn=stop_fn, writer=writer,
        task=env_id)
    train_collector.close()
    test_collector.close()
    pprint.pprint(result)
    # test
    env = gym.make(env_id)
    collector = Collector(policy, env)
    result = collector.collect(n_episode=1, render=hyper['render'])
    print(f'Final reward: {result["rew"]}, length: {result["len"]}')
    collector.close()
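train() reads everything from a flat hyper dict. The keys below mirror the lookups in the function above; the values are purely illustrative, and Pytorch.device() is a project-specific helper that is not shown here:

hyper = {
    'training_num': 8,
    'test_num': 8,
    'seed': True,
    'random_seed': 0,
    'layer_num': 1,
    'learning_rate': 3e-4,
    'gamma': 0.99,
    'vf_coef': 0.5,
    'ent_coef': 0.001,
    'max_grad_norm': 0.5,
    'capacity': 20000,
    'epoch': 10,
    'step_per_epoch': 1000,
    'collect_per_step': 10,
    'repeat_per_collect': 1,
    'batch_size': 64,
    'render': 0.0,
}
# train(hyper)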
Example #6
def test_ppo(args=get_args()):
    env = create_atari_environment(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space().shape or env.action_space().n
    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv([
        lambda: create_atari_environment(args.task)
        for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv([
        lambda: create_atari_environment(args.task)
        for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.layer_num, args.state_shape, device=args.device)
    actor = Actor(net, args.action_shape).to(args.device)
    critic = Critic(net).to(args.device)
    optim = torch.optim.Adam(list(
        actor.parameters()) + list(critic.parameters()), lr=args.lr)
    dist = torch.distributions.Categorical
    policy = PPOPolicy(
        actor, critic, optim, dist, args.gamma,
        max_grad_norm=args.max_grad_norm,
        eps_clip=args.eps_clip,
        vf_coef=args.vf_coef,
        ent_coef=args.ent_coef,
        action_range=None)
    # collector
    train_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size),
        preprocess_fn=preprocess_fn)
    test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
    # log
    writer = SummaryWriter(args.logdir + '/' + 'ppo')

    def stop_fn(x):
        if env.env.spec.reward_threshold:
            return x >= env.spec.reward_threshold
        else:
            return False

    # trainer
    result = onpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
    train_collector.close()
    test_collector.close()
    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = create_atari_environment(args.task)
        collector = Collector(policy, env, preprocess_fn=preprocess_fn)
        result = collector.collect(n_step=2000, render=args.render)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
        collector.close()
Example #7
def data():
    np.random.seed(0)
    env = SimpleEnv()
    env.seed(0)
    env_vec = DummyVectorEnv([lambda: SimpleEnv() for _ in range(100)])
    env_vec.seed(np.random.randint(1000, size=100).tolist())
    env_subproc = SubprocVectorEnv([lambda: SimpleEnv() for _ in range(8)])
    env_subproc.seed(np.random.randint(1000, size=100).tolist())
    env_subproc_init = SubprocVectorEnv(
        [lambda: SimpleEnv() for _ in range(8)])
    env_subproc_init.seed(np.random.randint(1000, size=100).tolist())
    buffer = ReplayBuffer(50000)
    policy = SimplePolicy()
    collector = Collector(policy, env, ReplayBuffer(50000))
    collector_vec = Collector(policy, env_vec, ReplayBuffer(50000))
    collector_subproc = Collector(policy, env_subproc, ReplayBuffer(50000))
    return {
        "env": env,
        "env_vec": env_vec,
        "env_subproc": env_subproc,
        "env_subproc_init": env_subproc_init,
        "policy": policy,
        "buffer": buffer,
        "collector": collector,
        "collector_vec": collector_vec,
        "collector_subproc": collector_subproc,
    }
Example #8
    def __init__(self, args, mask=None, name='full'):
        env = gym.make(args.task)
        self.stop_fn = lambda x: x >= env.spec.reward_threshold
        self.state_shape = env.observation_space.shape or env.observation_space.n
        self.action_shape = env.action_space.shape or env.action_space.n
        self.buffer = ReplayBuffer(400000)

        # Env
        # train_envs = gym.make(args.task)
        self.train_envs = SubprocVectorEnv(
            [lambda: gym.make(args.task) for _ in range(args.training_num)])
        # test_envs = gym.make(args.task)
        self.test_envs = SubprocVectorEnv(
            [lambda: gym.make(args.task) for _ in range(args.test_num)])

        # Mask
        state_dim = int(np.prod(self.state_shape))
        self._view_mask = torch.ones(state_dim)
        if mask == 'even':
            for i in range(0, state_dim, 2):
                self._view_mask[i] = 0
        elif mask == "odd":
            for i in range(1, state_dim, 2):
                self._view_mask[i] = 0
        elif type(mask) == int:
            self._view_mask[mask] = 0

        # Model
        net = NetWithView(args.layer_num, self.state_shape, device=args.device,
                          mask=self._view_mask)
        self.actor = Actor(net, self.action_shape).to(args.device)
        self.critic = Critic(net).to(args.device)
        optim = torch.optim.Adam(list(
            self.actor.parameters()) + list(self.critic.parameters()), lr=args.lr)
        dist = torch.distributions.Categorical
        self.policy = PPOPolicy(
            self.actor, self.critic, optim, dist, args.gamma,
            max_grad_norm=args.max_grad_norm,
            eps_clip=args.eps_clip,
            vf_coef=args.vf_coef,
            ent_coef=args.ent_coef,
            action_range=None)

        # Collector
        self.train_collector = Collector(
            self.policy, self.train_envs, ReplayBuffer(args.buffer_size))
        self.test_collector = Collector(self.policy, self.test_envs)

        # Log
        self.writer = SummaryWriter(f'{args.logdir}/{args.task}/ppo/{args.note}/{name}')
Example #9
def test_td3(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.layer_num, args.state_shape, device=args.device)
    actor = Actor(
        net, args.action_shape,
        args.max_action, args.device
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
    net = Net(args.layer_num, args.state_shape,
              args.action_shape, concat=True, device=args.device)
    critic1 = Critic(net, args.device).to(args.device)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    critic2 = Critic(net, args.device).to(args.device)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
    policy = TD3Policy(
        actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
        args.tau, args.gamma,
        GaussianNoise(sigma=args.exploration_noise), args.policy_noise,
        args.update_actor_freq, args.noise_clip,
        [env.action_space.low[0], env.action_space.high[0]],
        reward_normalization=True, ignore_done=True)
    # collector
    train_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size))
    test_collector = Collector(policy, test_envs)
    # train_collector.collect(n_step=args.buffer_size)
    # log
    writer = SummaryWriter(args.logdir + '/' + 'td3')

    def stop_fn(x):
        return x >= env.spec.reward_threshold

    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, stop_fn=stop_fn, writer=writer)
    assert stop_fn(result['best_reward'])
    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
Example #10
def test_collector():
    writer = SummaryWriter('log/collector')
    logger = Logger(writer)
    env_fns = [
        lambda: MyTestEnv(size=2, sleep=0),
        lambda: MyTestEnv(size=3, sleep=0),
        lambda: MyTestEnv(size=4, sleep=0),
        lambda: MyTestEnv(size=5, sleep=0),
    ]

    venv = SubprocVectorEnv(env_fns)
    policy = MyPolicy()
    env = env_fns[0]()
    c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False))
    c0.collect(n_step=3, log_fn=logger.log)
    assert equal(c0.buffer.obs[:3], [0, 1, 0])
    assert equal(c0.buffer[:3].obs_next, [1, 2, 1])
    c0.collect(n_episode=3, log_fn=logger.log)
    assert equal(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1])
    assert equal(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2])
    c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False))
    c1.collect(n_step=6)
    assert equal(c1.buffer.obs[:11], [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3])
    assert equal(c1.buffer[:11].obs_next, [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4])
    c1.collect(n_episode=2)
    assert equal(c1.buffer.obs[11:21], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2])
    assert equal(c1.buffer[11:21].obs_next, [1, 2, 3, 4, 5, 1, 2, 1, 2, 3])
    c2 = Collector(policy, gym.make('CartPole-v1'),
                   ReplayBuffer(size=20000, ignore_obs_next=False))
    r = c2.collect(n_step=10000, sampling=True)
    assert len(c2.buffer) > 10000
    print(r)
Example #11
def test_collector():
    env_fns = [
        lambda: MyTestEnv(size=2, sleep=0),
        lambda: MyTestEnv(size=3, sleep=0),
        lambda: MyTestEnv(size=4, sleep=0),
        lambda: MyTestEnv(size=5, sleep=0),
    ]
    venv = SubprocVectorEnv(env_fns)
    policy = MyPolicy()
    env = env_fns[0]()
    c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False))
    c0.collect(n_step=3)
    assert equal(c0.buffer.obs[:3], [0, 1, 0])
    assert equal(c0.buffer[:3].obs_next, [1, 2, 1])
    c0.collect(n_episode=3)
    assert equal(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1])
    assert equal(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2])
    c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False))
    c1.collect(n_step=6)
    assert equal(c1.buffer.obs[:11], [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3])
    assert equal(c1.buffer[:11].obs_next, [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4])
    c1.collect(n_episode=2)
    assert equal(c1.buffer.obs[11:21], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2])
    assert equal(c1.buffer[11:21].obs_next, [1, 2, 3, 4, 5, 1, 2, 1, 2, 3])
    c2 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False))
    c2.collect(n_episode=[1, 2, 2, 2])
    assert equal(c2.buffer.obs_next[:26], [
        1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5,
        1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
    c2.reset_env()
    c2.collect(n_episode=[2, 2, 2, 2])
    assert equal(c2.buffer.obs_next[26:54], [
        1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4, 1, 2, 3, 4, 5,
        1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
Example #12
def test_collector_with_exact_episodes():
    env_lens = [2, 6, 3, 10]
    writer = SummaryWriter('log/exact_collector')
    logger = Logger(writer)
    env_fns = [
        lambda x=i: MyTestEnv(size=x, sleep=0.1, random_sleep=True)
        for i in env_lens
    ]

    venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1)
    policy = MyPolicy()
    c1 = Collector(policy, venv, ReplayBuffer(size=1000,
                                              ignore_obs_next=False),
                   logger.preprocess_fn)
    n_episode1 = [2, 0, 5, 1]
    n_episode2 = [1, 3, 2, 0]
    c1.collect(n_episode=n_episode1)
    expected_steps = sum([a * b for a, b in zip(env_lens, n_episode1)])
    actual_steps = sum(venv.steps)
    assert expected_steps == actual_steps
    c1.collect(n_episode=n_episode2)
    expected_steps = sum(
        [a * (b + c) for a, b, c in zip(env_lens, n_episode1, n_episode2)])
    actual_steps = sum(venv.steps)
    assert expected_steps == actual_steps
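With the concrete numbers above, the step bookkeeping is easy to verify by hand; a quick standalone check of the same arithmetic:

env_lens = [2, 6, 3, 10]
n_episode1 = [2, 0, 5, 1]
n_episode2 = [1, 3, 2, 0]
# first collect: 2*2 + 6*0 + 3*5 + 10*1 = 29 steps
print(sum(a * b for a, b in zip(env_lens, n_episode1)))
# cumulative total after the second collect: 6 + 18 + 21 + 10 = 55 steps
print(sum(a * (b + c) for a, b, c in zip(env_lens, n_episode1, n_episode2)))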
Example #13
def reload(args=get_args()):
    slot_set = []
    with open('./dataset/slot_set.txt', 'r', encoding='utf-8') as f:
        for line in f.readlines():
            slot_set.append(line.strip())
    # slot_set =
    goals = {}
    with open('./dataset/test.pk', 'rb') as f:
        goals['test'] = pickle.load(f)

    for dic in goals['test']:
        dic['disease_tag'] = 'Esophagitis'

    total_disease = []
    with open('./dataset/disease.txt', 'r', encoding='utf-8') as f:
        for line in f.readlines():
            total_disease.append(line.strip())
    print(len(slot_set), slot_set)
    disease_num = len(total_disease)

    env = MedicalEnvrionment(slot_set, goals['test'], disease_num=disease_num)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n

    test_envs = SubprocVectorEnv([
        lambda: MedicalEnvrionment(slot_set,
                                   goals['test'],
                                   max_turn=args.max_episode_steps,
                                   flag="test",
                                   disease_num=disease_num)
        for _ in range(args.test_num)
    ])

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)
    random.seed(args.seed)
    policy = torch.load('./model/ehr/policy.pth')
    test_collector = MyCollector(policy, test_envs)
    result = test_episode(policy,
                          test_collector,
                          test_fn=None,
                          epoch=1,
                          n_episode=len(goals['test']),
                          writer=None)

    return result
Example #14
def test_collector():
    writer = SummaryWriter('log/collector')
    logger = Logger(writer)
    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]

    venv = SubprocVectorEnv(env_fns)
    dum = DummyVectorEnv(env_fns)
    policy = MyPolicy()
    env = env_fns[0]()
    c0 = Collector(policy, env, ReplayBuffer(size=100), logger.preprocess_fn)
    c0.collect(n_step=3)
    assert len(c0.buffer) == 3
    assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 0])
    assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1])
    c0.collect(n_episode=3)
    assert len(c0.buffer) == 8
    assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0])
    assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
    c0.collect(n_step=3, random=True)
    c1 = Collector(policy, venv,
                   VectorReplayBuffer(total_size=100, buffer_num=4),
                   logger.preprocess_fn)
    c1.collect(n_step=8)
    obs = np.zeros(100)
    obs[[0, 1, 25, 26, 50, 51, 75, 76]] = [0, 1, 0, 1, 0, 1, 0, 1]

    assert np.allclose(c1.buffer.obs[:, 0], obs)
    assert np.allclose(c1.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
    c1.collect(n_episode=4)
    assert len(c1.buffer) == 16
    obs[[2, 3, 27, 52, 53, 77, 78, 79]] = [0, 1, 2, 2, 3, 2, 3, 4]
    assert np.allclose(c1.buffer.obs[:, 0], obs)
    assert np.allclose(c1.buffer[:].obs_next[..., 0],
                       [1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
    c1.collect(n_episode=4, random=True)
    c2 = Collector(policy, dum, VectorReplayBuffer(total_size=100,
                                                   buffer_num=4),
                   logger.preprocess_fn)
    c2.collect(n_episode=7)
    obs1 = obs.copy()
    obs1[[4, 5, 28, 29, 30]] = [0, 1, 0, 1, 2]
    obs2 = obs.copy()
    obs2[[28, 29, 30, 54, 55, 56, 57]] = [0, 1, 2, 0, 1, 2, 3]
    c2obs = c2.buffer.obs[:, 0]
    assert np.all(c2obs == obs1) or np.all(c2obs == obs2)
    c2.reset_env()
    c2.reset_buffer()
    assert c2.collect(n_episode=8)['n/ep'] == 8
    obs[[4, 5, 28, 29, 30, 54, 55, 56, 57]] = [0, 1, 0, 1, 2, 0, 1, 2, 3]
    assert np.all(c2.buffer.obs[:, 0] == obs)
    c2.collect(n_episode=4, random=True)

    # test corner case
    with pytest.raises(TypeError):
        Collector(policy, dum, ReplayBuffer(10))
    with pytest.raises(TypeError):
        Collector(policy, dum, PrioritizedReplayBuffer(10, 0.5, 0.5))
    with pytest.raises(TypeError):
        c2.collect()
Example #15
def test_psrl(args=get_args()):
    env = gym.make(args.task)
    if args.task == "NChain-v0":
        env.spec.reward_threshold = 3647  # described in PSRL paper
    print("reward threshold:", env.spec.reward_threshold)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    n_action = args.action_shape
    n_state = args.state_shape
    trans_count_prior = np.ones((n_state, n_action, n_state))
    rew_mean_prior = np.full((n_state, n_action), args.rew_mean_prior)
    rew_std_prior = np.full((n_state, n_action), args.rew_std_prior)
    policy = PSRLPolicy(
        trans_count_prior, rew_mean_prior, rew_std_prior, args.gamma, args.eps,
        args.add_done_loop)
    # collector
    train_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size))
    test_collector = Collector(policy, test_envs)
    # log
    writer = SummaryWriter(args.logdir + '/' + args.task)

    def stop_fn(x):
        if env.spec.reward_threshold:
            return x >= env.spec.reward_threshold
        else:
            return False

    train_collector.collect(n_step=args.buffer_size, random=True)
    # trainer
    result = onpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, 1,
        args.test_num, 0, stop_fn=stop_fn, writer=writer,
        test_in_train=False)

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        policy.eval()
        test_envs.seed(args.seed)
        test_collector.reset()
        result = test_collector.collect(n_episode=[1] * args.test_num,
                                        render=args.render)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
    elif env.spec.reward_threshold:
        assert result["best_reward"] >= env.spec.reward_threshold
Example #16
def test_collector_with_async(gym_reset_kwargs):
    env_lens = [2, 3, 4, 5]
    writer = SummaryWriter('log/async_collector')
    logger = Logger(writer)
    env_fns = [
        lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True)
        for i in env_lens
    ]

    venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1)
    policy = MyPolicy()
    bufsize = 60
    c1 = AsyncCollector(
        policy,
        venv,
        VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4),
        logger.preprocess_fn,
    )
    ptr = [0, 0, 0, 0]
    for n_episode in tqdm.trange(1, 30, desc="test async n_episode"):
        result = c1.collect(n_episode=n_episode,
                            gym_reset_kwargs=gym_reset_kwargs)
        assert result["n/ep"] >= n_episode
        # check buffer data, obs and obs_next, env_id
        for i, count in enumerate(
                np.bincount(result["lens"], minlength=6)[2:]):
            env_len = i + 2
            total = env_len * count
            indices = np.arange(ptr[i], ptr[i] + total) % bufsize
            ptr[i] = (ptr[i] + total) % bufsize
            seq = np.arange(env_len)
            buf = c1.buffer.buffers[i]
            assert np.all(buf.info.env_id[indices] == i)
            assert np.all(buf.obs[indices].reshape(count, env_len) == seq)
            assert np.all(
                buf.obs_next[indices].reshape(count, env_len) == seq + 1)
    # test async n_step, for now the buffer should be full of data
    for n_step in tqdm.trange(1, 15, desc="test async n_step"):
        result = c1.collect(n_step=n_step, gym_reset_kwargs=gym_reset_kwargs)
        assert result["n/st"] >= n_step
        for i in range(4):
            env_len = i + 2
            seq = np.arange(env_len)
            buf = c1.buffer.buffers[i]
            assert np.all(buf.info.env_id == i)
            assert np.all(buf.obs.reshape(-1, env_len) == seq)
            assert np.all(buf.obs_next.reshape(-1, env_len) == seq + 1)
    with pytest.raises(TypeError):
        c1.collect()
Example #17
def test_vecenv(size=10, num=8, sleep=0.001):
    env_fns = [
        lambda i=i: MyTestEnv(size=i, sleep=sleep, recurse_state=True)
        for i in range(size, size + num)
    ]
    venv = [
        DummyVectorEnv(env_fns),
        SubprocVectorEnv(env_fns),
        ShmemVectorEnv(env_fns),
    ]
    if has_ray():
        venv += [RayVectorEnv(env_fns)]
    for v in venv:
        v.seed(0)
    action_list = [1] * 5 + [0] * 10 + [1] * 20
    o = [v.reset() for v in venv]
    for a in action_list:
        o = []
        for v in venv:
            A, B, C, D = v.step([a] * num)
            if sum(C):
                A = v.reset(np.where(C)[0])
            o.append([A, B, C, D])
        for index, infos in enumerate(zip(*o)):
            if index == 3:  # do not check info here
                continue
            for info in infos:
                assert recurse_comp(infos[0], info)

    if __name__ == '__main__':
        t = [0] * len(venv)
        for i, e in enumerate(venv):
            t[i] = time.time()
            e.reset()
            for a in action_list:
                done = e.step([a] * num)[2]
                if sum(done) > 0:
                    e.reset(np.where(done)[0])
            t[i] = time.time() - t[i]
        for i, v in enumerate(venv):
            print(f'{type(v)}: {t[i]:.6f}s')

    for v in venv:
        assert v.size == list(range(size, size + num))
        assert v.env_num == num
        assert v.action_space == [Discrete(2)] * num
    for v in venv:
        v.close()
Example #18
def test_vecenv(size=10, num=8, sleep=0.001):
    verbose = __name__ == '__main__'
    env_fns = [
        lambda: MyTestEnv(size=size, sleep=sleep),
        lambda: MyTestEnv(size=size + 1, sleep=sleep),
        lambda: MyTestEnv(size=size + 2, sleep=sleep),
        lambda: MyTestEnv(size=size + 3, sleep=sleep),
        lambda: MyTestEnv(size=size + 4, sleep=sleep),
        lambda: MyTestEnv(size=size + 5, sleep=sleep),
        lambda: MyTestEnv(size=size + 6, sleep=sleep),
        lambda: MyTestEnv(size=size + 7, sleep=sleep),
    ]
    venv = [
        VectorEnv(env_fns),
        SubprocVectorEnv(env_fns),
    ]
    if verbose:
        venv.append(RayVectorEnv(env_fns))
    for v in venv:
        v.seed()
    action_list = [1] * 5 + [0] * 10 + [1] * 20
    if not verbose:
        o = [v.reset() for v in venv]
        for i, a in enumerate(action_list):
            o = []
            for v in venv:
                A, B, C, D = v.step([a] * num)
                if sum(C):
                    A = v.reset(np.where(C)[0])
                o.append([A, B, C, D])
            for i in zip(*o):
                for j in range(1, len(i)):
                    assert (i[0] == i[j]).all()
    else:
        t = [0, 0, 0]
        for i, e in enumerate(venv):
            t[i] = time.time()
            e.reset()
            for a in action_list:
                done = e.step([a] * num)[2]
                if sum(done) > 0:
                    e.reset(np.where(done)[0])
            t[i] = time.time() - t[i]
        print(f'VectorEnv: {t[0]:.6f}s')
        print(f'SubprocVectorEnv: {t[1]:.6f}s')
        print(f'RayVectorEnv: {t[2]:.6f}s')
    for v in venv:
        v.close()
Example #19
def test_collector():
    writer = SummaryWriter('log/collector')
    logger = Logger(writer)
    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]

    venv = SubprocVectorEnv(env_fns)
    dum = DummyVectorEnv(env_fns)
    policy = MyPolicy()
    env = env_fns[0]()
    c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False),
                   logger.preprocess_fn)
    c0.collect(n_step=3)
    assert np.allclose(c0.buffer.obs[:4],
                       np.expand_dims([0, 1, 0, 1], axis=-1))
    assert np.allclose(c0.buffer[:4].obs_next,
                       np.expand_dims([1, 2, 1, 2], axis=-1))
    c0.collect(n_episode=3)
    assert np.allclose(c0.buffer.obs[:10],
                       np.expand_dims([0, 1, 0, 1, 0, 1, 0, 1, 0, 1], axis=-1))
    assert np.allclose(c0.buffer[:10].obs_next,
                       np.expand_dims([1, 2, 1, 2, 1, 2, 1, 2, 1, 2], axis=-1))
    c0.collect(n_step=3, random=True)
    c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False),
                   logger.preprocess_fn)
    c1.collect(n_step=6)
    assert np.allclose(c1.buffer.obs[:11], np.expand_dims(
        [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3], axis=-1))
    assert np.allclose(c1.buffer[:11].obs_next, np.expand_dims(
        [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4], axis=-1))
    c1.collect(n_episode=2)
    assert np.allclose(c1.buffer.obs[11:21],
                       np.expand_dims([0, 1, 2, 3, 4, 0, 1, 0, 1, 2], axis=-1))
    assert np.allclose(c1.buffer[11:21].obs_next,
                       np.expand_dims([1, 2, 3, 4, 5, 1, 2, 1, 2, 3], axis=-1))
    c1.collect(n_episode=3, random=True)
    c2 = Collector(policy, dum, ReplayBuffer(size=100, ignore_obs_next=False),
                   logger.preprocess_fn)
    c2.collect(n_episode=[1, 2, 2, 2])
    assert np.allclose(c2.buffer.obs_next[:26], np.expand_dims([
        1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5,
        1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5], axis=-1))
    c2.reset_env()
    c2.collect(n_episode=[2, 2, 2, 2])
    assert np.allclose(c2.buffer.obs_next[26:54], np.expand_dims([
        1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4, 1, 2, 3, 4, 5,
        1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5], axis=-1))
    c2.collect(n_episode=[1, 1, 1, 1], random=True)
Example #20
def test_collector_with_async():
    env_lens = [2, 3, 4, 5]
    writer = SummaryWriter('log/async_collector')
    logger = Logger(writer)
    env_fns = [
        lambda x=i: MyTestEnv(size=x, sleep=0.1, random_sleep=True)
        for i in env_lens
    ]

    venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1)
    policy = MyPolicy()
    c1 = Collector(policy, venv, ReplayBuffer(size=1000,
                                              ignore_obs_next=False),
                   logger.preprocess_fn)
    c1.collect(n_episode=10)
    # check if the data in the buffer is chronological
    # i.e. data in the buffer are full episodes, and each episode is
    # returned by the same environment
    env_id = c1.buffer.info['env_id']
    size = len(c1.buffer)
    obs = c1.buffer.obs[:size]
    done = c1.buffer.done[:size]
    obs_ground_truth = []
    i = 0
    while i < size:
        # i is the start of an episode
        if done[i]:
            # this episode has one transition
            assert env_lens[env_id[i]] == 1
            i += 1
            continue
        j = i
        while True:
            j += 1
            # in one episode, the environment id is the same
            assert env_id[j] == env_id[i]
            if done[j]:
                break
        j = j + 1  # j is the start of the next episode
        assert j - i == env_lens[env_id[i]]
        obs_ground_truth += list(range(j - i))
        i = j
    obs_ground_truth = np.expand_dims(np.array(obs_ground_truth), axis=-1)
    assert np.allclose(obs, obs_ground_truth)
Example #21
def test_asynccollector():
    env_lens = [2, 3, 4, 5]
    env_fns = [
        lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True)
        for i in env_lens
    ]

    venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1)
    policy = MyPolicy()
    bufsize = 300
    c1 = AsyncCollector(
        policy, venv, VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4))
    ptr = [0, 0, 0, 0]
    for n_episode in tqdm.trange(1, 100, desc="test async n_episode"):
        result = c1.collect(n_episode=n_episode)
        assert result["n/ep"] >= n_episode
        # check buffer data, obs and obs_next, env_id
        for i, count in enumerate(
                np.bincount(result["lens"], minlength=6)[2:]):
            env_len = i + 2
            total = env_len * count
            indices = np.arange(ptr[i], ptr[i] + total) % bufsize
            ptr[i] = (ptr[i] + total) % bufsize
            seq = np.arange(env_len)
            buf = c1.buffer.buffers[i]
            assert np.all(buf.info.env_id[indices] == i)
            assert np.all(buf.obs[indices].reshape(count, env_len) == seq)
            assert np.all(
                buf.obs_next[indices].reshape(count, env_len) == seq + 1)
    # test async n_step, for now the buffer should be full of data
    for n_step in tqdm.trange(1, 150, desc="test async n_step"):
        result = c1.collect(n_step=n_step)
        assert result["n/st"] >= n_step
        for i in range(4):
            env_len = i + 2
            seq = np.arange(env_len)
            buf = c1.buffer.buffers[i]
            assert np.all(buf.info.env_id == i)
            assert np.all(buf.obs.reshape(-1, env_len) == seq)
            assert np.all(buf.obs_next.reshape(-1, env_len) == seq + 1)
Example #22
def test_collector():
    writer = SummaryWriter('log/collector')
    logger = Logger(writer)
    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]

    venv = SubprocVectorEnv(env_fns)
    policy = MyPolicy()
    env = env_fns[0]()
    c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False),
                   preprocess_fn)
    c0.collect(n_step=3, log_fn=logger.log)
    assert np.allclose(c0.buffer.obs[:3], [0, 1, 0])
    assert np.allclose(c0.buffer[:3].obs_next, [1, 2, 1])
    c0.collect(n_episode=3, log_fn=logger.log)
    assert np.allclose(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1])
    assert np.allclose(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2])
    c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False),
                   preprocess_fn)
    c1.collect(n_step=6)
    assert np.allclose(c1.buffer.obs[:11], [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3])
    assert np.allclose(c1.buffer[:11].obs_next,
                       [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4])
    c1.collect(n_episode=2)
    assert np.allclose(c1.buffer.obs[11:21], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2])
    assert np.allclose(c1.buffer[11:21].obs_next,
                       [1, 2, 3, 4, 5, 1, 2, 1, 2, 3])
    c2 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False),
                   preprocess_fn)
    c2.collect(n_episode=[1, 2, 2, 2])
    assert np.allclose(c2.buffer.obs_next[:26], [
        1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3,
        4, 5
    ])
    c2.reset_env()
    c2.collect(n_episode=[2, 2, 2, 2])
    assert np.allclose(c2.buffer.obs_next[26:54], [
        1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4, 1, 2, 3, 4, 5, 1, 2, 3, 1, 2, 3, 4, 1,
        2, 3, 4, 5
    ])
Example #23
def test_reinforce(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low),
          np.max(env.action_space.high))
    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)],
        norm_obs=True)
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)],
        norm_obs=True,
        obs_rms=train_envs.obs_rms,
        update_obs_rms=False)

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net_a = Net(args.state_shape,
                hidden_sizes=args.hidden_sizes,
                activation=nn.Tanh,
                device=args.device)
    actor = ActorProb(net_a,
                      args.action_shape,
                      max_action=args.max_action,
                      unbounded=True,
                      device=args.device).to(args.device)
    torch.nn.init.constant_(actor.sigma_param, -0.5)
    for m in actor.modules():
        if isinstance(m, torch.nn.Linear):
            # orthogonal initialization
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            torch.nn.init.zeros_(m.bias)
    # do last policy layer scaling, this will make initial actions have (close to)
    # 0 mean and std, and will help boost performances,
    # see https://arxiv.org/abs/2006.05990, Fig.24 for details
    for m in actor.mu.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.zeros_(m.bias)
            m.weight.data.copy_(0.01 * m.weight.data)

    optim = torch.optim.Adam(actor.parameters(), lr=args.lr)
    lr_scheduler = None
    if args.lr_decay:
        # decay learning rate to 0 linearly
        max_update_num = np.ceil(
            args.step_per_epoch / args.step_per_collect) * args.epoch

        lr_scheduler = LambdaLR(
            optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)

    def dist(*logits):
        return Independent(Normal(*logits), 1)

    policy = PGPolicy(actor,
                      optim,
                      dist,
                      discount_factor=args.gamma,
                      reward_normalization=args.rew_norm,
                      action_scaling=True,
                      action_bound_method=args.action_bound_method,
                      lr_scheduler=lr_scheduler,
                      action_space=env.action_space)

    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(
            torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    if args.training_num > 1:
        buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
    else:
        buffer = ReplayBuffer(args.buffer_size)
    train_collector = Collector(policy,
                                train_envs,
                                buffer,
                                exploration_noise=True)
    test_collector = Collector(policy, test_envs)
    # log
    t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
    log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_reinforce'
    log_path = os.path.join(args.logdir, args.task, 'reinforce', log_file)
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = BasicLogger(writer, update_interval=10, train_interval=100)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    if not args.watch:
        # trainer
        result = onpolicy_trainer(policy,
                                  train_collector,
                                  test_collector,
                                  args.epoch,
                                  args.step_per_epoch,
                                  args.repeat_per_collect,
                                  args.test_num,
                                  args.batch_size,
                                  step_per_collect=args.step_per_collect,
                                  save_fn=save_fn,
                                  logger=logger,
                                  test_in_train=False)
        pprint.pprint(result)

    # Let's watch its performance!
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    result = test_collector.collect(n_episode=args.test_num,
                                    render=args.render)
    print(
        f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}'
    )
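The optional linear decay above scales the base learning rate by 1 - k / max_update_num at the k-th policy update, reaching zero after the last scheduled update. A quick check with illustrative numbers (not taken from the excerpt):

import numpy as np

step_per_epoch, step_per_collect, epoch = 30000, 2048, 100  # illustrative values
max_update_num = np.ceil(step_per_epoch / step_per_collect) * epoch  # 15 * 100 = 1500
for k in (0, 750, 1500):
    # multiplier applied to the base lr by LambdaLR at update k
    print(k, 1 - k / max_update_num)  # 1.0, 0.5, 0.0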
Example #24
def test_sac(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
    # train_envs = gym.make(args.task)
    if args.training_num > 1:
        train_envs = SubprocVectorEnv(
            [lambda: gym.make(args.task) for _ in range(args.training_num)]
        )
    else:
        train_envs = gym.make(args.task)
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)]
    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
    actor = ActorProb(
        net_a,
        args.action_shape,
        max_action=args.max_action,
        device=args.device,
        unbounded=True,
        conditioned_sigma=True
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
    net_c1 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device
    )
    net_c2 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device
    )
    critic1 = Critic(net_c1, device=args.device).to(args.device)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    critic2 = Critic(net_c2, device=args.device).to(args.device)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

    if args.auto_alpha:
        target_entropy = -np.prod(env.action_space.shape)
        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
        args.alpha = (target_entropy, log_alpha, alpha_optim)

    policy = SACPolicy(
        actor,
        actor_optim,
        critic1,
        critic1_optim,
        critic2,
        critic2_optim,
        tau=args.tau,
        gamma=args.gamma,
        alpha=args.alpha,
        estimation_step=args.n_step,
        action_space=env.action_space
    )

    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    if args.training_num > 1:
        buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
    else:
        buffer = ReplayBuffer(args.buffer_size)
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs)
    train_collector.collect(n_step=args.start_timesteps, random=True)
    # log
    t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
    log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_sac'
    log_path = os.path.join(args.logdir, args.task, 'sac', log_file)
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = TensorboardLogger(writer)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    if not args.watch:
        # trainer
        result = offpolicy_trainer(
            policy,
            train_collector,
            test_collector,
            args.epoch,
            args.step_per_epoch,
            args.step_per_collect,
            args.test_num,
            args.batch_size,
            save_fn=save_fn,
            logger=logger,
            update_per_step=args.update_per_step,
            test_in_train=False
        )
        pprint.pprint(result)

    # Let's watch its performance!
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    result = test_collector.collect(n_episode=args.test_num, render=args.render)
    print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
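In the auto_alpha branch above, the entropy target defaults to the negative action dimensionality and alpha itself becomes a learnable parameter. A small standalone illustration for a 6-dimensional continuous action space (the learning rate is illustrative):

import numpy as np
import torch

action_shape = (6,)                      # e.g. a 6-dimensional continuous action space
target_entropy = -np.prod(action_shape)  # -6
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
# SACPolicy is then given alpha=(target_entropy, log_alpha, alpha_optim)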
Example #25
def test_dqn(args=get_args()):
    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.env.action_space.shape or env.env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # make environments
    train_envs = SubprocVectorEnv(
        [lambda: make_atari_env(args) for _ in range(args.training_num)])
    test_envs = SubprocVectorEnv(
        [lambda: make_atari_env_watch(args) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # define model
    net = DQN(*args.state_shape, args.action_shape,
              args.device).to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = DQNPolicy(net,
                       optim,
                       args.gamma,
                       args.n_step,
                       target_update_freq=args.target_update_freq)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(
            torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)
    # replay buffer: `save_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
    buffer = ReplayBuffer(args.buffer_size,
                          ignore_obs_next=True,
                          save_only_last_obs=True,
                          stack_num=args.frames_stack)
    # collector
    train_collector = Collector(policy, train_envs, buffer)
    test_collector = Collector(policy, test_envs)
    # log
    log_path = os.path.join(args.logdir, args.task, 'dqn')
    writer = SummaryWriter(log_path)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        if env.env.spec.reward_threshold:
            return mean_rewards >= env.spec.reward_threshold
        elif 'Pong' in args.task:
            return mean_rewards >= 20
        else:
            return False

    def train_fn(epoch, env_step):
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = args.eps_train - env_step / 1e6 * \
                (args.eps_train - args.eps_train_final)
        else:
            eps = args.eps_train_final
        policy.set_eps(eps)
        writer.add_scalar('train/eps', eps, global_step=env_step)

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # watch agent's performance
    def watch():
        print("Setup test envs ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        if args.save_buffer_name:
            print(f"Generate buffer with size {args.buffer_size}")
            buffer = ReplayBuffer(args.buffer_size,
                                  ignore_obs_next=True,
                                  save_only_last_obs=True,
                                  stack_num=args.frames_stack)
            collector = Collector(policy, test_envs, buffer)
            result = collector.collect(n_step=args.buffer_size)
            print(f"Save buffer into {args.save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            buffer.save_hdf5(args.save_buffer_name)
        else:
            print("Testing agent ...")
            test_collector.reset()
            result = test_collector.collect(n_episode=[1] * args.test_num,
                                            render=args.render)
        pprint.pprint(result)

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * 4)
    # trainer
    result = offpolicy_trainer(policy,
                               train_collector,
                               test_collector,
                               args.epoch,
                               args.step_per_epoch,
                               args.collect_per_step,
                               args.test_num,
                               args.batch_size,
                               train_fn=train_fn,
                               test_fn=test_fn,
                               stop_fn=stop_fn,
                               save_fn=save_fn,
                               writer=writer,
                               test_in_train=False)

    pprint.pprint(result)
    watch()
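train_fn above implements the Nature-DQN style schedule: epsilon decays linearly over the first 1M environment steps and then stays at its final value. With the common settings eps_train=1.0 and eps_train_final=0.05 (illustrative, not taken from the excerpt):

def linear_eps(env_step, eps_train=1.0, eps_train_final=0.05):
    # linear decay over the first 1M environment steps, then constant
    if env_step <= 1e6:
        return eps_train - env_step / 1e6 * (eps_train - eps_train_final)
    return eps_train_final

print(linear_eps(0), linear_eps(500_000), linear_eps(2_000_000))  # 1.0 0.525 0.05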
Example #26
def test_fqf(args=get_args()):
    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # make environments
    train_envs = SubprocVectorEnv(
        [lambda: make_atari_env(args) for _ in range(args.training_num)])
    test_envs = SubprocVectorEnv(
        [lambda: make_atari_env_watch(args) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # define model
    feature_net = DQN(*args.state_shape,
                      args.action_shape,
                      args.device,
                      features_only=True)
    net = FullQuantileFunction(feature_net,
                               args.action_shape,
                               args.hidden_sizes,
                               args.num_cosines,
                               device=args.device).to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim)
    fraction_optim = torch.optim.RMSprop(fraction_net.parameters(),
                                         lr=args.fraction_lr)
    # define policy
    policy = FQFPolicy(net,
                       optim,
                       fraction_net,
                       fraction_optim,
                       args.gamma,
                       args.num_fractions,
                       args.ent_coef,
                       args.n_step,
                       target_update_freq=args.target_update_freq).to(
                           args.device)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(
            torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)
    # replay buffer: `save_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
    buffer = VectorReplayBuffer(args.buffer_size,
                                buffer_num=len(train_envs),
                                ignore_obs_next=True,
                                save_only_last_obs=True,
                                stack_num=args.frames_stack)
    # collector
    train_collector = Collector(policy,
                                train_envs,
                                buffer,
                                exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # log
    log_path = os.path.join(args.logdir, args.task, 'fqf')
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = BasicLogger(writer)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        if env.spec.reward_threshold:
            return mean_rewards >= env.spec.reward_threshold
        elif 'Pong' in args.task:
            return mean_rewards >= 20
        else:
            return False

    def train_fn(epoch, env_step):
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = args.eps_train - env_step / 1e6 * \
                (args.eps_train - args.eps_train_final)
        else:
            eps = args.eps_train_final
        policy.set_eps(eps)
        logger.write('train/eps', env_step, eps)

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # watch agent's performance
    def watch():
        print("Setup test envs ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        if args.save_buffer_name:
            print(f"Generate buffer with size {args.buffer_size}")
            buffer = VectorReplayBuffer(args.buffer_size,
                                        buffer_num=len(test_envs),
                                        ignore_obs_next=True,
                                        save_only_last_obs=True,
                                        stack_num=args.frames_stack)
            collector = Collector(policy,
                                  test_envs,
                                  buffer,
                                  exploration_noise=True)
            result = collector.collect(n_step=args.buffer_size)
            print(f"Save buffer into {args.save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            buffer.save_hdf5(args.save_buffer_name)
        else:
            print("Testing agent ...")
            test_collector.reset()
            result = test_collector.collect(n_episode=args.test_num,
                                            render=args.render)
        rew = result["rews"].mean()
        print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # trainer
    result = offpolicy_trainer(policy,
                               train_collector,
                               test_collector,
                               args.epoch,
                               args.step_per_epoch,
                               args.step_per_collect,
                               args.test_num,
                               args.batch_size,
                               train_fn=train_fn,
                               test_fn=test_fn,
                               stop_fn=stop_fn,
                               save_fn=save_fn,
                               logger=logger,
                               update_per_step=args.update_per_step,
                               test_in_train=False)

    pprint.pprint(result)
    watch()
Example #27
def test_td3(args=get_args()):
    # initialize environment
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    train_envs = VectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    actor = Actor(args.layer_num,
                  args.state_shape,
                  args.action_shape,
                  args.max_action,
                  args.device,
                  hidden_layer_size=args.hidden_size).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
    critic1 = Critic(args.layer_num,
                     args.state_shape,
                     args.action_shape,
                     args.device,
                     hidden_layer_size=args.hidden_size).to(args.device)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    critic2 = Critic(args.layer_num,
                     args.state_shape,
                     args.action_shape,
                     args.device,
                     hidden_layer_size=args.hidden_size).to(args.device)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

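    # TD3: twin critics, target policy smoothing (policy_noise / noise_clip)
    # and delayed actor updates (update_actor_freq)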
    policy = TD3Policy(
        actor,
        actor_optim,
        critic1,
        critic1_optim,
        critic2,
        critic2_optim,
        args.tau,
        args.gamma,
        GaussianNoise(sigma=args.exploration_noise),
        args.policy_noise,
        args.update_actor_freq,
        args.noise_clip,
        action_range=[env.action_space.low[0], env.action_space.high[0]],
        reward_normalization=args.rew_norm,
        ignore_done=False)
    # collector
    if args.training_num == 0:
        max_episode_steps = train_envs._max_episode_steps
    else:
        max_episode_steps = train_envs.envs[0]._max_episode_steps
    train_collector = Collector(
        policy, train_envs,
        ReplayBuffer(args.buffer_size, max_ep_len=max_episode_steps))
    test_collector = Collector(policy, test_envs, mode='test')
    # log
    log_path = os.path.join(args.logdir, args.task, 'td3', str(args.seed))
    writer = SummaryWriter(log_path)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    env.spec.reward_threshold = 100000

    def stop_fn(x):
        return x >= env.spec.reward_threshold

    # trainer
    result = offpolicy_exact_trainer(policy,
                                     train_collector,
                                     test_collector,
                                     args.epoch,
                                     args.step_per_epoch,
                                     args.collect_per_step,
                                     args.test_num,
                                     args.batch_size,
                                     stop_fn=stop_fn,
                                     save_fn=save_fn,
                                     writer=writer)
    assert stop_fn(result['best_reward'])
    train_collector.close()
    test_collector.close()
    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
        collector.close()
Example #28
def test_dqn(args=get_args()):
    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.env.action_space.shape or env.env.action_space.n

    # should be N_FRAMES x H x W

    print("Observations shape: ", args.state_shape)
    print("Actions shape: ", args.action_shape)
    # make environments
    train_envs = SubprocVectorEnv(
        [lambda: make_atari_env(args) for _ in range(args.training_num)])
    test_envs = SubprocVectorEnv(
        [lambda: make_atari_env_watch(args) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # log
    log_path = os.path.join(args.logdir, args.task, 'embedding')
    embedding_writer = SummaryWriter(log_path + '/with_init')

    embedding_net = embedding_prediction.Prediction(
        *args.state_shape, args.action_shape,
        args.device).to(device=args.device)
    embedding_net.apply(embedding_prediction.weights_init)

    if args.embedding_path:
        embedding_net.load_state_dict(torch.load(log_path + '/embedding.pth'))
        print("Loaded agent from: ", log_path + '/embedding.pth')
    # numel_list = [p.numel() for p in embedding_net.parameters()]
    # print(sum(numel_list), numel_list)

    pre_buffer = ReplayBuffer(args.buffer_size,
                              save_only_last_obs=True,
                              stack_num=args.frames_stack)
    pre_test_buffer = ReplayBuffer(args.buffer_size // 100,
                                   save_only_last_obs=True,
                                   stack_num=args.frames_stack)

    train_collector = Collector(None, train_envs, pre_buffer)
    test_collector = Collector(None, test_envs, pre_test_buffer)
    if args.embedding_data_path:
        pre_buffer = pickle.load(open(log_path + '/train_data.pkl', 'rb'))
        pre_test_buffer = pickle.load(open(log_path + '/test_data.pkl', 'rb'))
        train_collector.buffer = pre_buffer
        test_collector.buffer = pre_test_buffer
        print('load success')
    else:
        print('collect start')
        train_collector.collect(n_step=args.buffer_size, random=True)
        test_collector.collect(n_step=args.buffer_size // 100, random=True)
        print(len(train_collector.buffer))
        print(len(test_collector.buffer))
        if not os.path.exists(log_path):
            os.makedirs(log_path)
        pickle.dump(pre_buffer, open(log_path + '/train_data.pkl', 'wb'))
        pickle.dump(pre_test_buffer, open(log_path + '/test_data.pkl', 'wb'))
        print('collect finish')

    # train the embedding network on the collected data
    def part_loss(x, device='cpu'):
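        # per-element binarisation penalty: min((1 - x)^2, x^2) is small only
        # when each activation is close to 0 or 1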
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=device, dtype=torch.float32)
        x = x.view(x.size(0), -1)  # flatten per sample (works for any batch size)
        temp = torch.cat(
            ((1 - x).pow(2.0).unsqueeze_(0), x.pow(2.0).unsqueeze_(0)), dim=0)
        temp_2 = torch.min(temp, dim=0)[0]
        return torch.sum(temp_2)

    pre_optim = torch.optim.Adam(embedding_net.parameters(), lr=1e-5)
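    # decay the embedding learning rate by a factor of 10 every 50k updates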
    scheduler = torch.optim.lr_scheduler.StepLR(pre_optim,
                                                step_size=50000,
                                                gamma=0.1,
                                                last_epoch=-1)
    test_batch_data = test_collector.sample(batch_size=0)

    loss_fn = torch.nn.NLLLoss()
    # train_loss = []
    for epoch in range(1, 100001):
        embedding_net.train()
        batch_data = train_collector.sample(batch_size=128)
        # print(batch_data)
        # print(batch_data['obs'][0] == batch_data['obs'][1])
        pred = embedding_net(batch_data['obs'], batch_data['obs_next'])
        x1 = pred[1]
        x2 = pred[2]
        # print(torch.argmax(pred[0], dim=1))
        if not isinstance(batch_data['act'], torch.Tensor):
            act = torch.tensor(batch_data['act'],
                               device=args.device,
                               dtype=torch.int64)
        else:
            act = batch_data['act'].to(device=args.device, dtype=torch.int64)
        # print(pred[0].dtype)
        # print(act.dtype)
        # l2_norm = sum(p.pow(2.0).sum() for p in embedding_net.net.parameters())
        # loss = loss_fn(pred[0], act) + 0.001 * (part_loss(x1) + part_loss(x2)) / 64
        loss_1 = loss_fn(pred[0], act)
        loss_2 = 0.01 * (part_loss(x1, args.device) +
                         part_loss(x2, args.device)) / 128
        loss = loss_1 + loss_2
        # print(loss_1)
        # print(loss_2)
        embedding_writer.add_scalar('training loss1', loss_1.item(), epoch)
        embedding_writer.add_scalar('training loss2', loss_2, epoch)
        embedding_writer.add_scalar('training loss', loss.item(), epoch)
        # train_loss.append(loss.detach().item())
        pre_optim.zero_grad()
        loss.backward()
        pre_optim.step()
        scheduler.step()

        if epoch % 10000 == 0 or epoch == 1:
            print(pre_optim.state_dict()['param_groups'][0]['lr'])
            # print("Epoch: %d,Train: Loss: %f" % (epoch, float(loss.item())))
            correct = 0
            numel_list = [p for p in embedding_net.parameters()][-2]
            print(numel_list)
            embedding_net.eval()
            with torch.no_grad():
                test_pred, x1, x2, _ = embedding_net(
                    test_batch_data['obs'], test_batch_data['obs_next'])
                if not isinstance(test_batch_data['act'], torch.Tensor):
                    act = torch.tensor(test_batch_data['act'],
                                       device=args.device,
                                       dtype=torch.int64)
                else:
                    act = test_batch_data['act'].to(device=args.device,
                                                    dtype=torch.int64)
                loss_1 = loss_fn(test_pred, act)
                loss_2 = 0.01 * (part_loss(x1, args.device) +
                                 part_loss(x2, args.device)) / 128
                loss = loss_1 + loss_2
                embedding_writer.add_scalar('test loss', loss.item(), epoch)
                # print("Test Loss: %f" % (float(loss)))
                print(torch.argmax(test_pred, dim=1))
                print(act)
                correct += int((torch.argmax(test_pred, dim=1) == act).sum())
                print('Acc:', correct / len(test_batch_data))

    torch.save(embedding_net.state_dict(),
               os.path.join(log_path, 'embedding.pth'))
    embedding_writer.close()
    # plt.figure()
    # plt.plot(np.arange(100000),train_loss)
    # plt.show()
    exit()
    # build the hash table
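    # --- hedged sketch (not the original implementation) -------------------
    # part_loss pushes the embedding activations towards {0, 1}, so a state
    # can be hashed by binarising its embedding and counting how often each
    # binary code has been visited.  The sketch below assumes that
    # embedding_net(obs, obs)[1] returns the (near-binary) embedding of a
    # batch of observations, as in the training loop above.
    from collections import defaultdict

    hash_counts = defaultdict(int)

    def obs_to_code(obs):
        # binarise the embedding of a single (frame-stacked) observation and
        # turn it into a hashable byte string; obs[None] adds the batch dim
        with torch.no_grad():
            emb = embedding_net(obs[None], obs[None])[1]
        return (emb > 0.5).to(torch.uint8).cpu().numpy().tobytes()

    def count_bonus(obs, beta=0.01):
        # count-based exploration bonus beta / sqrt(N(code)), which could be
        # added to the reward, e.g. through the collector's preprocess_fn
        code = obs_to_code(obs)
        hash_counts[code] += 1
        return beta / np.sqrt(hash_counts[code])
    # ------------------------------------------------------------------------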

    # log
    log_path = os.path.join(args.logdir, args.task, 'dqn')
    writer = SummaryWriter(log_path)

    # define model
    net = DQN(*args.state_shape, args.action_shape,
              args.device).to(device=args.device)

    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = DQNPolicy(net,
                       optim,
                       args.gamma,
                       args.n_step,
                       target_update_freq=args.target_update_freq)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path))
        print("Loaded agent from: ", args.resume_path)

    # replay buffer: `save_only_last_obs` and `stack_num` can be removed
    # together when you have enough RAM
    pre_buffer.reset()
    buffer = ReplayBuffer(args.buffer_size,
                          ignore_obs_next=True,
                          save_only_last_obs=True,
                          stack_num=args.frames_stack)
    # collector
    # pass a preprocess_fn into train_collector to reshape the reward
    train_collector = Collector(policy, train_envs, buffer)
    test_collector = Collector(policy, test_envs)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(x):
        if env.spec.reward_threshold:
            return x >= env.spec.reward_threshold
        elif 'Pong' in args.task:
            return x >= 20
        else:
            return False

    def train_fn(x):
        # nature DQN setting, linear decay in the first 1M steps
        now = x * args.collect_per_step * args.step_per_epoch
        if now <= 1e6:
            eps = args.eps_train - now / 1e6 * \
                (args.eps_train - args.eps_train_final)
            policy.set_eps(eps)
        else:
            policy.set_eps(args.eps_train_final)
        print("set eps =", policy.eps)

    def test_fn(x):
        policy.set_eps(args.eps_test)

    # watch agent's performance
    def watch():
        print("Testing agent ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        test_collector.reset()
        result = test_collector.collect(n_episode=[1] * args.test_num,
                                        render=1 / 30)
        pprint.pprint(result)

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * 4)
    # trainer
    result = offpolicy_trainer(policy,
                               train_collector,
                               test_collector,
                               args.epoch,
                               args.step_per_epoch,
                               args.collect_per_step,
                               args.test_num,
                               args.batch_size,
                               train_fn=train_fn,
                               test_fn=test_fn,
                               stop_fn=stop_fn,
                               save_fn=save_fn,
                               writer=writer,
                               test_in_train=False)

    pprint.pprint(result)
    watch()
Example #29
def test_dqn(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    # you can also use tianshou.env.SubprocVectorEnv
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
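    # dueling DQN: separate hidden streams for the advantage (Q) and
    # state-value (V) heads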
    Q_param = {"hidden_sizes": args.dueling_q_hidden_sizes}
    V_param = {"hidden_sizes": args.dueling_v_hidden_sizes}
    net = Net(args.state_shape,
              args.action_shape,
              hidden_sizes=args.hidden_sizes,
              device=args.device,
              dueling_param=(Q_param, V_param)).to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = DQNPolicy(net,
                       optim,
                       args.gamma,
                       args.n_step,
                       target_update_freq=args.target_update_freq)
    # collector
    train_collector = Collector(policy,
                                train_envs,
                                VectorReplayBuffer(args.buffer_size,
                                                   len(train_envs)),
                                exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # policy.set_eps(1)
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # log
    log_path = os.path.join(args.logdir, args.task, 'dqn')
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= env.spec.reward_threshold

    def train_fn(epoch, env_step):  # exp decay
        eps = max(args.eps_train * (1 - 5e-6)**env_step, args.eps_test)
        policy.set_eps(eps)

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # trainer
    result = offpolicy_trainer(policy,
                               train_collector,
                               test_collector,
                               args.epoch,
                               args.step_per_epoch,
                               args.step_per_collect,
                               args.test_num,
                               args.batch_size,
                               update_per_step=args.update_per_step,
                               stop_fn=stop_fn,
                               train_fn=train_fn,
                               test_fn=test_fn,
                               save_fn=save_fn,
                               logger=logger)

    assert stop_fn(result['best_reward'])
    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        test_collector.reset()
        result = test_collector.collect(n_episode=args.test_num,
                                        render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
Example #30
def test_cql():
    args = get_args()
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]  # float
    print("device:", args.device)
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low),
          np.max(env.action_space.high))

    args.state_dim = args.state_shape[0]
    args.action_dim = args.action_shape[0]
    print("Max_action", args.max_action)

    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)

    # model
    # actor network
    net_a = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
    )
    actor = ActorProb(net_a,
                      action_shape=args.action_shape,
                      max_action=args.max_action,
                      device=args.device,
                      unbounded=True,
                      conditioned_sigma=True).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)

    # critic network
    net_c1 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    net_c2 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    critic1 = Critic(net_c1, device=args.device).to(args.device)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    critic2 = Critic(net_c2, device=args.device).to(args.device)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

    if args.auto_alpha:
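        # automatic entropy tuning (as in SAC): learn log_alpha so that the
        # policy entropy stays close to the target -dim(action_space)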
        target_entropy = -np.prod(env.action_space.shape)
        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
        args.alpha = (target_entropy, log_alpha, alpha_optim)

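    # CQL adds a conservative regularizer (cql_weight, temperature) that
    # penalizes Q-values of out-of-distribution actions; with_lagrange tunes
    # its strength against lagrange_threshold via a Lagrange multiplier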
    policy = CQLPolicy(
        actor,
        actor_optim,
        critic1,
        critic1_optim,
        critic2,
        critic2_optim,
        cql_alpha_lr=args.cql_alpha_lr,
        cql_weight=args.cql_weight,
        tau=args.tau,
        gamma=args.gamma,
        alpha=args.alpha,
        temperature=args.temperature,
        with_lagrange=args.with_lagrange,
        lagrange_threshold=args.lagrange_threshold,
        min_action=np.min(env.action_space.low),
        max_action=np.max(env.action_space.high),
        device=args.device,
    )

    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(
            torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    test_collector = Collector(policy, test_envs)

    # log
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    args.algo_name = "cql"
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
    log_path = os.path.join(args.logdir, log_name)

    # logger
    if args.logger == "wandb":
        logger = WandbLogger(
            save_interval=1,
            name=log_name.replace(os.path.sep, "__"),
            run_id=args.resume_id,
            config=args,
            project=args.wandb_project,
        )
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    if args.logger == "tensorboard":
        logger = TensorboardLogger(writer)
    else:  # wandb
        logger.load(writer)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def watch():
        if args.resume_path is None:
            args.resume_path = os.path.join(log_path, "policy.pth")

        policy.load_state_dict(
            torch.load(args.resume_path, map_location=torch.device("cpu")))
        policy.eval()
        collector = Collector(policy, env)
        collector.collect(n_episode=1, render=1 / 35)

    if not args.watch:
        dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task))
        dataset_size = dataset["rewards"].size

        print("dataset_size", dataset_size)
        replay_buffer = ReplayBuffer(dataset_size)

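        # copy the D4RL transitions into a Tianshou ReplayBuffer one step at a time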
        for i in range(dataset_size):
            replay_buffer.add(
                Batch(
                    obs=dataset["observations"][i],
                    act=dataset["actions"][i],
                    rew=dataset["rewards"][i],
                    done=dataset["terminals"][i],
                    obs_next=dataset["next_observations"][i],
                ))
        print("dataset loaded")
        # trainer
        result = offline_trainer(
            policy,
            replay_buffer,
            test_collector,
            args.epoch,
            args.step_per_epoch,
            args.test_num,
            args.batch_size,
            save_best_fn=save_best_fn,
            logger=logger,
        )
        pprint.pprint(result)
    else:
        watch()

    # Let's watch its performance!
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    result = test_collector.collect(n_episode=args.test_num,
                                    render=args.render)
    print(
        f"Final reward: {result['rews'].mean()}, length: {result['lens'].mean()}"
    )