예제 #1
0
def test_pickle():
    """Pickle three buffer flavours and compare each copy to its source.

    A stacked ``ReplayBuffer``, a ``ListReplayBuffer`` and a
    ``PrioritizedReplayBuffer`` are filled with a few transitions whose
    reward is a torch tensor, serialised with ``pickle`` and restored. The
    restored buffers must match the originals in length, stored actions,
    ``stack_num`` and priority weights.
    """
    capacity = 100
    stacked = ReplayBuffer(capacity, stack_num=2)
    listed = ListReplayBuffer()
    prioritized = PrioritizedReplayBuffer(capacity, 0.6, 0.4)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    reward = torch.tensor([1.]).to(device)
    for step in range(4):
        stacked.add(
            obs=Batch(index=np.array([step])), act=0, rew=reward, done=0)
    for step in range(3):
        listed.add(
            obs=Batch(index=np.array([step])), act=1, rew=reward, done=0)
    for step in range(5):
        prioritized.add(obs=Batch(index=np.array([step])),
                        act=2, rew=reward, done=0, weight=np.random.rand())

    # Serialise and restore every buffer.
    stacked_copy = pickle.loads(pickle.dumps(stacked))
    listed_copy = pickle.loads(pickle.dumps(listed))
    prioritized_copy = pickle.loads(pickle.dumps(prioritized))

    # Length is checked first so the action comparison only runs on
    # equally sized buffers, as in a short-circuiting ``and``.
    assert len(stacked_copy) == len(stacked)
    assert np.allclose(stacked_copy.act, stacked.act)
    assert len(listed_copy) == len(listed)
    assert np.allclose(listed_copy.act, listed.act)
    assert len(prioritized_copy) == len(prioritized)
    assert np.allclose(prioritized_copy.act, prioritized.act)

    # Buffer metadata has to survive the round trip as well.
    assert stacked_copy.stack_num == stacked.stack_num
    restored_weight = prioritized_copy.weight[
        np.arange(len(prioritized_copy))]
    source_weight = prioritized.weight[np.arange(len(prioritized))]
    assert np.allclose(restored_weight, source_weight)
예제 #2
0
def test_hdf5():
    """Round-trip replay buffers through HDF5 files and compare the result.

    Three buffers (array, list, prioritized) receive four transitions each,
    are written with ``save_hdf5`` into temporary files, and are restored
    with the ``load_hdf5`` classmethod of their own class. Lengths, actions,
    ``stack_num``, ``maxsize``, ``_indices`` and nested ``info`` content are
    compared. The temporary files are removed even when an assertion fails.

    Finally ``to_hdf5`` is expected to raise for values it cannot store.
    """
    size = 100
    buffers = {
        "array": ReplayBuffer(size, stack_num=2),
        "list": ListReplayBuffer(),
        "prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4),
    }
    buffer_types = {k: b.__class__ for k, b in buffers.items()}
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    info_t = torch.tensor([1.]).to(device)
    for i in range(4):
        kwargs = {
            'obs': Batch(index=np.array([i])),
            'act': i,
            'rew': np.array([1, 2]),
            'done': i % 3 == 2,
            'info': {"number": {"n": i, "t": info_t}, 'extra': None},
        }
        buffers["array"].add(**kwargs)
        buffers["list"].add(**kwargs)
        buffers["prioritized"].add(weight=np.random.rand(), **kwargs)

    paths = {}
    try:
        # save
        for k, buf in buffers.items():
            f, path = tempfile.mkstemp(suffix='.hdf5')
            os.close(f)
            # Register the path before saving so that a failing save still
            # gets its temporary file cleaned up.
            paths[k] = path
            buf.save_hdf5(path)

        # load replay buffer
        _buffers = {
            k: buffer_types[k].load_hdf5(path) for k, path in paths.items()
        }

        # compare
        for k in buffers:
            assert len(_buffers[k]) == len(buffers[k])
            assert np.allclose(_buffers[k].act, buffers[k].act)
            assert _buffers[k].stack_num == buffers[k].stack_num
            assert _buffers[k].maxsize == buffers[k].maxsize
            assert np.all(_buffers[k]._indices == buffers[k]._indices)
        for k in ["array", "prioritized"]:
            assert _buffers[k]._index == buffers[k]._index
            assert isinstance(buffers[k].get(0, "info"), Batch)
            assert isinstance(_buffers[k].get(0, "info"), Batch)
        for k in ["array"]:
            assert np.all(
                buffers[k][:].info.number.n == _buffers[k][:].info.number.n)
            assert np.all(
                buffers[k][:].info.extra == _buffers[k][:].info.extra)
    finally:
        # mkstemp leaves deletion to the caller; without this every run
        # would leak three .hdf5 files into the temp directory.
        for path in paths.values():
            os.remove(path)

    # raise exception when value cannot be pickled
    data = {"not_supported": lambda x: x * x}
    grp = h5py.Group
    with pytest.raises(NotImplementedError):
        to_hdf5(data, grp)
    # ndarray with data type not supported by HDF5 that cannot be pickled
    data = {"not_supported": np.array(lambda x: x * x)}
    grp = h5py.Group
    with pytest.raises(RuntimeError):
        to_hdf5(data, grp)
예제 #3
0
 def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None:
     """Create ``buffer_num`` prioritized sub-buffers and hand them to the base.

     Each sub-buffer is sized ``ceil(total_size / buffer_num)`` and receives
     ``kwargs`` unchanged.

     :param total_size: capacity to split across the sub-buffers.
     :param buffer_num: number of sub-buffers; must be positive.
     :param kwargs: forwarded to every ``PrioritizedReplayBuffer``.
     """
     assert buffer_num > 0
     per_buffer_size = int(np.ceil(total_size / buffer_num))
     sub_buffers = []
     for _ in range(buffer_num):
         sub_buffers.append(
             PrioritizedReplayBuffer(per_buffer_size, **kwargs))
     super().__init__(sub_buffers)
예제 #4
0
def test_collector():
    """Check what Collector stores for a single env and two vector envs.

    Data is collected by step count and by episode count, and the ``obs`` /
    ``obs_next`` columns of each buffer are compared against hard-coded
    expectations. The last section checks that unsupported collector/buffer
    combinations and an argument-less ``collect()`` raise ``TypeError``.
    """
    writer = SummaryWriter('log/collector')
    logger = Logger(writer)
    # ``x=i`` binds the loop value at definition time (no late binding).
    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]

    venv = SubprocVectorEnv(env_fns)
    dum = DummyVectorEnv(env_fns)
    policy = MyPolicy()
    env = env_fns[0]()
    # c0: one plain environment writing into a single ReplayBuffer.
    c0 = Collector(policy, env, ReplayBuffer(size=100), logger.preprocess_fn)
    c0.collect(n_step=3)
    assert len(c0.buffer) == 3
    assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 0])
    assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1])
    c0.collect(n_episode=3)
    assert len(c0.buffer) == 8
    assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0])
    assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
    c0.collect(n_step=3, random=True)
    # c1: subprocess vector env writing into a 4-way VectorReplayBuffer.
    c1 = Collector(policy, venv,
                   VectorReplayBuffer(total_size=100, buffer_num=4),
                   logger.preprocess_fn)
    c1.collect(n_step=8)
    # Expected flat layout: writes start at offsets 0, 25, 50 and 75
    # (presumably one 25-slot sub-buffer per env -- confirm against
    # VectorReplayBuffer).
    obs = np.zeros(100)
    obs[[0, 1, 25, 26, 50, 51, 75, 76]] = [0, 1, 0, 1, 0, 1, 0, 1]

    assert np.allclose(c1.buffer.obs[:, 0], obs)
    assert np.allclose(c1.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
    c1.collect(n_episode=4)
    assert len(c1.buffer) == 16
    obs[[2, 3, 27, 52, 53, 77, 78, 79]] = [0, 1, 2, 2, 3, 2, 3, 4]
    assert np.allclose(c1.buffer.obs[:, 0], obs)
    assert np.allclose(c1.buffer[:].obs_next[..., 0],
                       [1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
    c1.collect(n_episode=4, random=True)
    # c2: same scenario with the in-process DummyVectorEnv.
    c2 = Collector(policy, dum, VectorReplayBuffer(total_size=100,
                                                   buffer_num=4),
                   logger.preprocess_fn)
    c2.collect(n_episode=7)
    # Two different buffer layouts are accepted after seven episodes.
    obs1 = obs.copy()
    obs1[[4, 5, 28, 29, 30]] = [0, 1, 0, 1, 2]
    obs2 = obs.copy()
    obs2[[28, 29, 30, 54, 55, 56, 57]] = [0, 1, 2, 0, 1, 2, 3]
    c2obs = c2.buffer.obs[:, 0]
    assert np.all(c2obs == obs1) or np.all(c2obs == obs2)
    c2.reset_env()
    c2.reset_buffer()
    assert c2.collect(n_episode=8)['n/ep'] == 8
    obs[[4, 5, 28, 29, 30, 54, 55, 56, 57]] = [0, 1, 0, 1, 2, 0, 1, 2, 3]
    assert np.all(c2.buffer.obs[:, 0] == obs)
    c2.collect(n_episode=4, random=True)

    # test corner case
    # A vector env combined with a non-vector buffer must be rejected.
    with pytest.raises(TypeError):
        Collector(policy, dum, ReplayBuffer(10))
    with pytest.raises(TypeError):
        Collector(policy, dum, PrioritizedReplayBuffer(10, 0.5, 0.5))
    # collect() without n_step or n_episode must be rejected as well.
    with pytest.raises(TypeError):
        c2.collect()
예제 #5
0
def test_init():
    """Construct each buffer type repeatedly to exercise the constructors.

    Nothing is asserted; the loop only builds a ``ReplayBuffer``, a
    ``PrioritizedReplayBuffer`` and a ``ListReplayBuffer`` on every pass.
    """
    rounds = np.arange(1e5)
    for _ in rounds:
        # The instances are discarded immediately; only construction matters.
        ReplayBuffer(1e5)
        PrioritizedReplayBuffer(
            size=int(1e5), alpha=0.5, beta=0.5, repeat_sample=True)
        ListReplayBuffer()
예제 #6
0
def test_pickle():
    """Pickle a stacked and a prioritized buffer and compare the copies.

    Both buffers are filled through the ``Batch``-based ``add`` call,
    serialised with ``pickle`` and restored. Length, stored actions,
    ``stack_num`` and priority weights must match the originals.
    """
    capacity = 100
    stacked = ReplayBuffer(capacity, stack_num=2)
    prioritized = PrioritizedReplayBuffer(capacity, 0.6, 0.4)
    reward = np.array([1, 1])
    for step in range(4):
        transition = Batch(
            obs=Batch(index=np.array([step])), act=0, rew=reward, done=0)
        stacked.add(transition)
    for step in range(5):
        transition = Batch(obs=Batch(index=np.array([step])),
                           act=2,
                           rew=reward,
                           done=0,
                           info=np.random.rand())
        prioritized.add(transition)

    # Serialise and restore both buffers.
    stacked_copy = pickle.loads(pickle.dumps(stacked))
    prioritized_copy = pickle.loads(pickle.dumps(prioritized))

    # Length first, then content, mirroring a short-circuiting ``and``.
    assert len(stacked_copy) == len(stacked)
    assert np.allclose(stacked_copy.act, stacked.act)
    assert len(prioritized_copy) == len(prioritized)
    assert np.allclose(prioritized_copy.act, prioritized.act)

    # Buffer metadata has to survive the round trip as well.
    assert stacked_copy.stack_num == stacked.stack_num
    restored_weight = prioritized_copy.weight[
        np.arange(len(prioritized_copy))]
    source_weight = prioritized.weight[np.arange(len(prioritized))]
    assert np.allclose(restored_weight, source_weight)
예제 #7
0
def test_priortized_replaybuffer(size=32, bufsize=15):
    """Exercise a prioritized buffer and its vectorised counterpart.

    Every environment step is added once to a single
    ``PrioritizedReplayBuffer`` and three times (one per sub-buffer) to a
    ``PrioritizedVectorReplayBuffer``. Sample sizes and buffer lengths are
    checked after each step, then ``update_weight`` is checked on both.

    :param size: passed to ``MyTestEnv``.
    :param bufsize: capacity used for both buffers.
    """
    env = MyTestEnv(size)
    buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
    buf2 = PrioritizedVectorReplayBuffer(bufsize,
                                         buffer_num=3,
                                         alpha=0.5,
                                         beta=0.5)
    obs = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        batch = Batch(obs=obs,
                      act=a,
                      rew=rew,
                      done=done,
                      obs_next=obs_next,
                      info=info,
                      policy=np.random.randn() - 0.5)
        # The same transition is stacked three times for the vector buffer.
        batch_stack = Batch.stack([batch, batch, batch])
        buf.add(Batch.stack([batch]), buffer_ids=[0])
        buf2.add(batch_stack, buffer_ids=[0, 1, 2])
        obs = obs_next
        data, indices = buf.sample(len(buf) // 2)
        # A requested batch size of 0 is expected to return the whole buffer.
        if len(buf) // 2 == 0:
            assert len(data) == len(buf)
        else:
            assert len(data) == len(buf) // 2
        assert len(buf) == min(bufsize, i + 1)
        assert len(buf2) == min(bufsize, 3 * (i + 1))
    # check single buffer's data
    assert buf.info.key.shape == (buf.maxsize, )
    assert buf.rew.dtype == float
    assert buf.done.dtype == bool
    data, indices = buf.sample(len(buf) // 2)
    buf.update_weight(indices, -data.weight / 2)
    # Stored weight is expected to equal |new weight| ** alpha.
    assert np.allclose(buf.weight[indices],
                       np.abs(-data.weight / 2)**buf._alpha)
    # check multi buffer's data
    assert np.allclose(buf2[np.arange(buf2.maxsize)].weight, 1)
    batch, indices = buf2.sample(10)
    buf2.update_weight(indices, batch.weight * 0)
    weight = buf2[np.arange(buf2.maxsize)].weight
    # Split slots into those just updated and those left alone: each group
    # must be uniform, and untouched slots must weigh less than updated ones.
    mask = np.isin(np.arange(buf2.maxsize), indices)
    assert np.all(weight[mask] == weight[mask][0])
    assert np.all(weight[~mask] == weight[~mask][0])
    assert weight[~mask][0] < weight[mask][0] and weight[mask][0] <= 1
예제 #8
0
def test_priortized_replaybuffer(size=32, bufsize=15):
    """Fill a prioritized buffer step by step and check sampling sizes.

    After the rollout, ``update_weight`` is called with negated, halved
    sample weights and the stored weights are compared against
    ``abs(value) ** buf._alpha``.

    :param size: passed to ``MyTestEnv``.
    :param bufsize: capacity of the prioritized buffer.
    """
    env = MyTestEnv(size)
    buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
    obs = env.reset()
    actions = [1] * 5 + [0] * 10 + [1] * 10
    for step, action in enumerate(actions):
        obs_next, rew, done, info = env.step(action)
        extra = np.random.randn() - 0.5
        buf.add(obs, action, rew, done, obs_next, info, extra)
        obs = obs_next
        batch, _ = buf.sample(len(buf) // 2)
        half = len(buf) // 2
        # A requested batch size of 0 is expected to return the whole buffer.
        expected_len = len(buf) if half == 0 else half
        assert len(batch) == expected_len
        assert len(buf) == min(bufsize, step + 1)
    batch, sampled = buf.sample(len(buf) // 2)
    new_weight = -batch.weight / 2
    buf.update_weight(sampled, new_weight)
    assert np.allclose(
        buf.weight[sampled], np.abs(-batch.weight / 2) ** buf._alpha)
예제 #9
0
def test_priortized_replaybuffer(size=32, bufsize=15):
    """Check sampling sizes and weight bookkeeping of a prioritized buffer.

    After every added transition the normalised weights of the filled slots
    must sum to one and ``buf._weight_sum`` must equal the sum of
    ``buf.weight``. After the rollout, ``update_weight`` is checked against
    ``abs(value) ** buf._alpha``.

    :param size: passed to ``MyTestEnv``.
    :param bufsize: capacity of the prioritized buffer.
    """
    env = MyTestEnv(size)
    buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
    obs = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        buf.add(obs, a, rew, done, obs_next, info, np.random.randn() - 0.5)
        obs = obs_next
        assert np.isclose(np.sum((buf.weight / buf._weight_sum)[:buf._size]),
                          1,
                          rtol=1e-12)
        data, indice = buf.sample(len(buf) // 2)
        # A requested batch size of 0 is expected to return the whole buffer.
        if len(buf) // 2 == 0:
            assert len(data) == len(buf)
        else:
            assert len(data) == len(buf) // 2
        # The message used to be ``print(...)``, which printed as a side
        # effect and left the AssertionError with a ``None`` message.
        assert len(buf) == min(bufsize, i + 1), \
            f'len(buf)={len(buf)}, i={i}'
        assert np.isclose(buf._weight_sum, (buf.weight).sum())
    data, indice = buf.sample(len(buf) // 2)
    buf.update_weight(indice, -data.weight / 2)
    assert np.isclose(buf.weight[indice],
                      np.power(np.abs(-data.weight / 2), buf._alpha)).all()
    assert np.isclose(buf._weight_sum, (buf.weight).sum())
예제 #10
0
def test_c51(args=get_args()):
    """Train a C51 policy on ``args.task`` and assert it reaches the threshold.

    Builds training/test vector environments, the network, optimizer and
    policy, pre-fills the replay buffer, runs ``offpolicy_trainer`` and
    asserts that the best reward satisfies ``stop_fn``. When executed as a
    script, one additional episode is rolled out and its result printed.

    :param args: configuration namespace; the default is evaluated once, at
        function definition time, by calling ``get_args()``.
    """
    env = gym.make(args.task)
    # Fall back to ``.n`` when ``.shape`` is falsy.
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    # you can also use tianshou.env.SubprocVectorEnv
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.state_shape, args.action_shape,
              hidden_sizes=args.hidden_sizes, device=args.device,
              softmax=True, num_atoms=args.num_atoms)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = C51Policy(
        net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max,
        args.n_step, target_update_freq=args.target_update_freq
    ).to(args.device)
    # buffer
    if args.prioritized_replay:
        buf = PrioritizedReplayBuffer(
            args.buffer_size, alpha=args.alpha, beta=args.beta)
    else:
        buf = ReplayBuffer(args.buffer_size)
    # collector
    train_collector = Collector(policy, train_envs, buf)
    test_collector = Collector(policy, test_envs)
    # policy.set_eps(1)
    # Pre-fill the buffer with one batch worth of steps before training.
    train_collector.collect(n_step=args.batch_size)
    # log
    log_path = os.path.join(args.logdir, args.task, 'c51')
    writer = SummaryWriter(log_path)

    def save_fn(policy):
        # Write the policy weights to ``policy.pth`` under the log directory.
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        # Stop once the environment's own reward threshold is reached.
        return mean_rewards >= env.spec.reward_threshold

    def train_fn(epoch, env_step):
        # eps annealing, just a demo
        # Constant up to 10k steps, linear decay down to 10% of eps_train
        # between 10k and 50k steps, constant 10% afterwards.
        if env_step <= 10000:
            policy.set_eps(args.eps_train)
        elif env_step <= 50000:
            eps = args.eps_train - (env_step - 10000) / \
                40000 * (0.9 * args.eps_train)
            policy.set_eps(eps)
        else:
            policy.set_eps(0.1 * args.eps_train)

    def test_fn(epoch, env_step):
        # Use the evaluation epsilon while testing.
        policy.set_eps(args.eps_test)

    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, train_fn=train_fn, test_fn=test_fn,
        stop_fn=stop_fn, save_fn=save_fn, writer=writer)

    assert stop_fn(result['best_reward'])
    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        policy.set_eps(args.eps_test)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
예제 #11
0
# Q-network and its optimizer, configured from ``training_config``.
net = DuelingDQN(state_shape,
                 action_shape,
                 hidden_layer=training_config['hidden_layer'],
                 cnn=training_config['cnn'])
optim = torch.optim.Adam(net.parameters(), lr=training_config['learning_rate'])

# DQN policy wrapping the network.
policy = DQNPolicy(net,
                   optim,
                   training_config['gamma'],
                   training_config['n_step'],
                   grad_norm_clipping=training_config['grad_norm_clipping'],
                   target_update_freq=training_config['target_update_freq'])

# Replay buffer: prioritized when the config asks for it, plain otherwise.
if training_config['prioritized_replay']:
    buf = PrioritizedReplayBuffer(training_config['buffer_size'],
                                  alpha=training_config['alpha'],
                                  beta=training_config['beta'])
else:
    buf = ReplayBuffer(training_config['buffer_size'])
# Epsilon is set to 1 before the initial ``pre_collect`` steps are gathered
# (presumably fully random actions -- confirm against DQNPolicy.set_eps).
policy.set_eps(1)
train_collector = Collector(policy, train_envs, buf)
train_collector.collect(n_step=training_config['pre_collect'])


def delect_log(log_root='/u/zifan/.ros/log'):
    """Delete ``*.log`` files found in sub-directories of ``log_root``.

    Files located directly in ``log_root`` are kept; only files inside
    nested directories are removed. Files with other suffixes are untouched.

    Args:
        log_root: Directory tree to walk. Defaults to the path that used to
            be hard-coded here, so existing zero-argument calls behave as
            before.
    """
    for dirname, _dirnames, filenames in os.walk(log_root):
        # os.walk reports the top directory exactly as it was passed in, so
        # a plain string comparison identifies the root level.
        if dirname == log_root:
            continue
        for filename in filenames:
            if filename.endswith('.log'):
                os.remove(os.path.join(dirname, filename))
예제 #12
0
파일: manager.py 프로젝트: wide725/tianshou
 def __init__(self, buffer_list: Sequence[PrioritizedReplayBuffer]) -> None:
     """Combine already-built prioritized buffers under one manager.

     The manager base class is initialised first. Each sub-buffer's own
     ``weight`` attribute is then deleted, and the prioritized base class is
     initialised with this object's ``maxsize`` and the ``options`` of the
     first sub-buffer.

     :param buffer_list: the prioritized buffers to manage; the first
         entry supplies the keyword options.
     """
     ReplayBufferManager.__init__(self, buffer_list)  # type: ignore
     shared_options = buffer_list[0].options
     for sub_buffer in buffer_list:
         del sub_buffer.weight
     PrioritizedReplayBuffer.__init__(self, self.maxsize, **shared_options)
예제 #13
0
def test_multibuf_stack():
    """Check CachedReplayBuffer together with frame stacking options.

    Three configurations are covered: a plain main buffer with ``stack_num``
    and ``ignore_obs_next`` (``buf4``), a prioritized main buffer that also
    sets ``sample_avail`` (``buf5``), and a main buffer with
    ``save_only_last_obs`` fed with image-shaped observations (``buf6``).
    Stored values, ``done`` flags, sampled indices, stacked observations and
    priority weights are compared against hard-coded expectations.
    """
    size = 5
    bufsize = 9
    stack_num = 4
    cached_num = 3
    env = MyTestEnv(size)
    # test if CachedReplayBuffer can handle stack_num + ignore_obs_next
    buf4 = CachedReplayBuffer(
        ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True),
        cached_num, size)
    # test if CachedReplayBuffer can handle super corner case:
    # prio-buffer + stack_num + ignore_obs_next + sample_avail
    buf5 = CachedReplayBuffer(
        PrioritizedReplayBuffer(bufsize,
                                0.6,
                                0.4,
                                stack_num=stack_num,
                                ignore_obs_next=True,
                                sample_avail=True), cached_num, size)
    obs = env.reset(1)
    for i in range(18):
        obs_next, rew, done, info = env.step(1)
        # One observation per cached buffer, offset by ``size`` so the
        # origin of every stored value can be recognised later.
        obs_list = np.array([obs + size * i for i in range(cached_num)])
        act_list = [1] * cached_num
        rew_list = [rew] * cached_num
        done_list = [done] * cached_num
        # obs_next is deliberately bogus (negated obs); both buffers are
        # built with ignore_obs_next=True.
        obs_next_list = -obs_list
        info_list = [info] * cached_num
        buf4.add(obs_list, act_list, rew_list, done_list, obs_next_list,
                 info_list)
        buf5.add(obs_list, act_list, rew_list, done_list, obs_next_list,
                 info_list)
        obs = obs_next
        if done:
            obs = env.reset(1)
    # check the `add` order is correct
    assert np.allclose(
        buf4.obs.reshape(-1),
        [
            12,
            13,
            14,
            4,
            6,
            7,
            8,
            9,
            11,  # main_buffer
            1,
            2,
            3,
            4,
            0,  # cached_buffer[0]
            6,
            7,
            8,
            9,
            0,  # cached_buffer[1]
            11,
            12,
            13,
            14,
            0,  # cached_buffer[2]
        ]), buf4.obs
    assert np.allclose(
        buf4.done,
        [
            0,
            0,
            1,
            1,
            0,
            0,
            0,
            1,
            0,  # main_buffer
            0,
            0,
            0,
            1,
            0,  # cached_buffer[0]
            0,
            0,
            0,
            1,
            0,  # cached_buffer[1]
            0,
            0,
            0,
            1,
            0,  # cached_buffer[2]
        ]), buf4.done
    assert np.allclose(buf4.unfinished_index(), [10, 15, 20])
    # Batch size 0 is used to obtain every index that can be sampled.
    indice = sorted(buf4.sample_index(0))
    assert np.allclose(indice, list(range(bufsize)) + [9, 10, 14, 15, 19, 20])
    # Each row below is the stack_num-long history for one sampled index.
    assert np.allclose(buf4[indice].obs[..., 0], [
        [11, 11, 11, 12],
        [11, 11, 12, 13],
        [11, 12, 13, 14],
        [4, 4, 4, 4],
        [6, 6, 6, 6],
        [6, 6, 6, 7],
        [6, 6, 7, 8],
        [6, 7, 8, 9],
        [11, 11, 11, 11],
        [1, 1, 1, 1],
        [1, 1, 1, 2],
        [6, 6, 6, 6],
        [6, 6, 6, 7],
        [11, 11, 11, 11],
        [11, 11, 11, 12],
    ])
    assert np.allclose(buf4[indice].obs_next[..., 0], [
        [11, 11, 12, 13],
        [11, 12, 13, 14],
        [11, 12, 13, 14],
        [4, 4, 4, 4],
        [6, 6, 6, 7],
        [6, 6, 7, 8],
        [6, 7, 8, 9],
        [6, 7, 8, 9],
        [11, 11, 11, 12],
        [1, 1, 1, 2],
        [1, 1, 1, 2],
        [6, 6, 6, 7],
        [6, 6, 6, 7],
        [11, 11, 11, 12],
        [11, 11, 11, 12],
    ])
    # buf5 received the same transitions, so its done flags must match.
    assert np.all(buf4.done == buf5.done)
    indice = buf5.sample_index(0)
    assert np.allclose(sorted(indice), [2, 7])
    assert np.all(np.isin(buf5.sample_index(100), indice))
    # manually change the stack num
    buf5.stack_num = 2
    for buf in buf5.buffers:
        buf.stack_num = 2
    indice = buf5.sample_index(0)
    assert np.allclose(sorted(indice), [0, 1, 2, 5, 6, 7, 10, 15, 20])
    batch, _ = buf5.sample(0)
    assert np.allclose(buf5[np.arange(buf5.maxsize)].weight, 1)
    buf5.update_weight(indice, batch.weight * 0)
    weight = buf5[np.arange(buf5.maxsize)].weight
    # Three groups are checked: main-buffer slots that were updated,
    # main-buffer slots that were not, and all cached-buffer slots.
    modified_weight = weight[[0, 1, 2, 5, 6, 7]]
    assert modified_weight.min() == modified_weight.max()
    assert modified_weight.max() < 1
    unmodified_weight = weight[[3, 4, 8]]
    assert unmodified_weight.min() == unmodified_weight.max()
    assert unmodified_weight.max() < 1
    cached_weight = weight[9:]
    assert cached_weight.min() == cached_weight.max() == 1
    # test Atari with CachedReplayBuffer, save_only_last_obs + ignore_obs_next
    buf6 = CachedReplayBuffer(
        ReplayBuffer(bufsize,
                     stack_num=stack_num,
                     save_only_last_obs=True,
                     ignore_obs_next=True), cached_num, size)
    obs = np.random.rand(size, 4, 84, 84)
    buf6.add(obs=[obs[2], obs[0]],
             act=[1, 1],
             rew=[0, 0],
             done=[0, 1],
             obs_next=[obs[3], obs[1]],
             cached_buffer_ids=[1, 2])
    # Only the last frame of each 4-frame observation is stored ...
    assert buf6.obs.shape == (buf6.maxsize, 84, 84)
    assert np.allclose(buf6.obs[0], obs[0, -1])
    assert np.allclose(buf6.obs[14], obs[2, -1])
    assert np.allclose(buf6.obs[19], obs[0, -1])
    # ... while indexing returns a 4-frame stack again.
    assert buf6[0].obs.shape == (4, 84, 84)
예제 #14
0
def test_pdqn(args=get_args()):
    """Train DQN, optionally with prioritized replay, and check the reward.

    Builds training/test vector environments, the network, optimizer and
    policy, pre-fills the replay buffer, runs ``offpolicy_trainer`` and
    asserts that the best reward satisfies ``stop_fn``. When executed as a
    script, one additional episode is rolled out and its result printed.

    :param args: configuration namespace; the default is evaluated once, at
        function definition time, by calling ``get_args()``.
    """
    env = gym.make(args.task)
    # Fall back to ``.n`` when ``.shape`` is falsy.
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    # you can also use tianshou.env.SubprocVectorEnv
    train_envs = VectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = VectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
    net = net.to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = DQNPolicy(net,
                       optim,
                       args.gamma,
                       args.n_step,
                       use_target_network=args.target_update_freq > 0,
                       target_update_freq=args.target_update_freq)
    # collector
    if args.prioritized_replay > 0:
        # ``beta`` used to be fed from ``args.alpha``, which silently
        # ignored the configured beta value.
        buf = PrioritizedReplayBuffer(args.buffer_size,
                                      alpha=args.alpha,
                                      beta=args.beta)
    else:
        buf = ReplayBuffer(args.buffer_size)
    train_collector = Collector(policy, train_envs, buf)
    test_collector = Collector(policy, test_envs)
    # policy.set_eps(1)
    # Pre-fill the buffer with one batch worth of steps before training.
    train_collector.collect(n_step=args.batch_size)
    # log
    log_path = os.path.join(args.logdir, args.task, 'dqn')
    writer = SummaryWriter(log_path)

    def save_fn(policy):
        # Write the policy weights to ``policy.pth`` under the log directory.
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(x):
        # Stop once the environment's own reward threshold is reached.
        return x >= env.spec.reward_threshold

    def train_fn(x):
        # Use the training epsilon while collecting.
        policy.set_eps(args.eps_train)

    def test_fn(x):
        # Use the evaluation epsilon while testing.
        policy.set_eps(args.eps_test)

    # trainer
    result = offpolicy_trainer(policy,
                               train_collector,
                               test_collector,
                               args.epoch,
                               args.step_per_epoch,
                               args.collect_per_step,
                               args.test_num,
                               args.batch_size,
                               train_fn=train_fn,
                               test_fn=test_fn,
                               stop_fn=stop_fn,
                               save_fn=save_fn,
                               writer=writer)

    assert stop_fn(result['best_reward'])
    train_collector.close()
    test_collector.close()
    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
        collector.close()
        collector.close()
예제 #15
0
def test_dqn(args=load_args()):
    """Train a DQN policy using a configuration loaded by ``load_args``.

    ``args`` is indexed as a pair: element 0 is the training configuration
    namespace and element 1 holds keyword arguments forwarded to
    ``gym.make``. The run logs into a timestamped directory and copies the
    config file there. No assertion is made on the training result.

    :param args: pair of (training config, env kwargs); the default is
        evaluated once, at function definition time.
    """
    # load config
    env_args = args[1]
    args = args[0]
    # load environments
    env = gym.make(args.task)
    # Fall back to ``.n`` when ``.shape`` is falsy.
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    train_envs = DummyVectorEnv([
        lambda: gym.make(args.task, **env_args)
        for _ in range(args.training_num)
    ])
    # Test environments additionally receive ``test=True``.
    test_envs = DummyVectorEnv([
        lambda: gym.make(args.task, test=True, **env_args)
        for _ in range(args.test_num)
    ])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(
        args.layer_num,
        args.state_shape,
        args.action_shape,
        args.device,  # dueling=(1, 1)
    ).to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # learning schedule
    if args.lr_schedule == "linear":
        # Multiplier falls linearly from 1 towards 0 over ``args.epoch``.
        lr_lambda = lambda epoch: (1 - float(epoch) / args.epoch)
    else:
        lr_lambda = lambda epoch: 1  # constant lr
    scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda)
    policy = DQNPolicy(net,
                       optim,
                       args.gamma,
                       args.n_step,
                       target_update_freq=args.target_update_freq)
    # buffer
    if args.prioritized_replay > 0:
        buf = PrioritizedReplayBuffer(args.buffer_size,
                                      alpha=args.alpha,
                                      beta=args.beta)
    else:
        buf = ReplayBuffer(args.buffer_size)
    # collector
    train_collector = Collector(policy, train_envs, buf)
    test_collector = Collector(policy, test_envs)
    # policy.set_eps(1)
    # Pre-fill the buffer with one batch worth of steps before training.
    train_collector.collect(n_step=args.batch_size)
    # log
    now = datetime.now()
    dt_string = now.strftime("%Y_%m_%d_%H_%M")
    log_path = os.path.join("log", args.task, 'dqn', dt_string)
    writer = SummaryWriter(log_path)
    # Keep a copy of the configuration next to the logs.
    copyfile(CONFIG_PATH, os.path.join(log_path, "default.json"))

    def save_fn(policy):
        # Only the underlying model's weights are saved, not the full policy.
        torch.save(policy.model.state_dict(),
                   os.path.join(log_path, 'policy.pth'))

    def train_fn(epoch, env_step):
        # Epsilon decays linearly with the epoch and is floored at final_eps.
        # NOTE(review): the divisor ``args.epoch - 1`` is zero when
        # args.epoch == 1 -- confirm configs always use more than one epoch.
        policy.set_eps(
            max(args.final_eps,
                args.init_eps * (1 - 2 * (epoch - 1) / (args.epoch - 1))))

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)
        # The learning-rate scheduler is advanced from the test hook.
        scheduler.step()

    # trainer
    result = offpolicy_trainer(policy,
                               train_collector,
                               test_collector,
                               args.epoch,
                               args.step_per_epoch,
                               args.collect_per_step,
                               args.test_num,
                               args.batch_size,
                               train_fn=train_fn,
                               test_fn=test_fn,
                               save_fn=save_fn,
                               writer=writer,
                               log_interval=10)
예제 #16
0
def test_collector(gym_reset_kwargs):
    """Check what Collector stores, forwarding ``gym_reset_kwargs``.

    Covers a single env, a subprocess vector env and a dummy vector env.
    Besides ``obs`` / ``obs_next``, the ``info`` entries ``key``, ``env``,
    ``env_id`` and ``rew`` in each buffer are compared against hard-coded
    expectations. The last sections check the ``TypeError`` corner cases and
    that ``NXEnv`` observations are stored with ``object`` dtype.

    :param gym_reset_kwargs: passed through to every ``collect`` and
        ``reset_env`` call.
    """
    writer = SummaryWriter('log/collector')
    logger = Logger(writer)
    # ``x=i`` binds the loop value at definition time (no late binding).
    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]

    venv = SubprocVectorEnv(env_fns)
    dum = DummyVectorEnv(env_fns)
    policy = MyPolicy()
    env = env_fns[0]()
    # c0: one plain environment writing into a single ReplayBuffer.
    c0 = Collector(
        policy,
        env,
        ReplayBuffer(size=100),
        logger.preprocess_fn,
    )
    c0.collect(n_step=3, gym_reset_kwargs=gym_reset_kwargs)
    assert len(c0.buffer) == 3
    assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 0])
    assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1])
    keys = np.zeros(100)
    keys[:3] = 1
    assert np.allclose(c0.buffer.info["key"], keys)
    for e in c0.buffer.info["env"][:3]:
        assert isinstance(e, MyTestEnv)
    assert np.allclose(c0.buffer.info["env_id"], 0)
    rews = np.zeros(100)
    rews[:3] = [0, 1, 0]
    assert np.allclose(c0.buffer.info["rew"], rews)
    c0.collect(n_episode=3, gym_reset_kwargs=gym_reset_kwargs)
    assert len(c0.buffer) == 8
    assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0])
    assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
    assert np.allclose(c0.buffer.info["key"][:8], 1)
    for e in c0.buffer.info["env"][:8]:
        assert isinstance(e, MyTestEnv)
    assert np.allclose(c0.buffer.info["env_id"][:8], 0)
    assert np.allclose(c0.buffer.info["rew"][:8], [0, 1, 0, 1, 0, 1, 0, 1])
    c0.collect(n_step=3, random=True, gym_reset_kwargs=gym_reset_kwargs)

    # c1: subprocess vector env writing into a 4-way VectorReplayBuffer.
    c1 = Collector(policy, venv,
                   VectorReplayBuffer(total_size=100, buffer_num=4),
                   logger.preprocess_fn)
    c1.collect(n_step=8, gym_reset_kwargs=gym_reset_kwargs)
    # Expected flat layout: writes start at offsets 0, 25, 50 and 75
    # (presumably one 25-slot sub-buffer per env -- confirm against
    # VectorReplayBuffer).
    obs = np.zeros(100)
    valid_indices = [0, 1, 25, 26, 50, 51, 75, 76]
    obs[valid_indices] = [0, 1, 0, 1, 0, 1, 0, 1]
    assert np.allclose(c1.buffer.obs[:, 0], obs)
    assert np.allclose(c1.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
    keys = np.zeros(100)
    keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1]
    assert np.allclose(c1.buffer.info["key"], keys)
    for e in c1.buffer.info["env"][valid_indices]:
        assert isinstance(e, MyTestEnv)
    env_ids = np.zeros(100)
    env_ids[valid_indices] = [0, 0, 1, 1, 2, 2, 3, 3]
    assert np.allclose(c1.buffer.info["env_id"], env_ids)
    rews = np.zeros(100)
    rews[valid_indices] = [0, 1, 0, 0, 0, 0, 0, 0]
    assert np.allclose(c1.buffer.info["rew"], rews)
    c1.collect(n_episode=4, gym_reset_kwargs=gym_reset_kwargs)
    assert len(c1.buffer) == 16
    # Slots written by the second collect; the expectation arrays are
    # updated in place on top of the earlier values.
    valid_indices = [2, 3, 27, 52, 53, 77, 78, 79]
    obs[[2, 3, 27, 52, 53, 77, 78, 79]] = [0, 1, 2, 2, 3, 2, 3, 4]
    assert np.allclose(c1.buffer.obs[:, 0], obs)
    assert np.allclose(c1.buffer[:].obs_next[..., 0],
                       [1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
    keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1]
    assert np.allclose(c1.buffer.info["key"], keys)
    for e in c1.buffer.info["env"][valid_indices]:
        assert isinstance(e, MyTestEnv)
    env_ids[valid_indices] = [0, 0, 1, 2, 2, 3, 3, 3]
    assert np.allclose(c1.buffer.info["env_id"], env_ids)
    rews[valid_indices] = [0, 1, 1, 0, 1, 0, 0, 1]
    assert np.allclose(c1.buffer.info["rew"], rews)
    c1.collect(n_episode=4, random=True, gym_reset_kwargs=gym_reset_kwargs)

    # c2: same scenario with the in-process DummyVectorEnv.
    c2 = Collector(policy, dum, VectorReplayBuffer(total_size=100,
                                                   buffer_num=4),
                   logger.preprocess_fn)
    c2.collect(n_episode=7, gym_reset_kwargs=gym_reset_kwargs)
    # Two different buffer layouts are accepted after seven episodes.
    obs1 = obs.copy()
    obs1[[4, 5, 28, 29, 30]] = [0, 1, 0, 1, 2]
    obs2 = obs.copy()
    obs2[[28, 29, 30, 54, 55, 56, 57]] = [0, 1, 2, 0, 1, 2, 3]
    c2obs = c2.buffer.obs[:, 0]
    assert np.all(c2obs == obs1) or np.all(c2obs == obs2)
    c2.reset_env(gym_reset_kwargs=gym_reset_kwargs)
    c2.reset_buffer()
    assert c2.collect(n_episode=8,
                      gym_reset_kwargs=gym_reset_kwargs)['n/ep'] == 8
    valid_indices = [4, 5, 28, 29, 30, 54, 55, 56, 57]
    obs[valid_indices] = [0, 1, 0, 1, 2, 0, 1, 2, 3]
    assert np.all(c2.buffer.obs[:, 0] == obs)
    keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1, 1]
    assert np.allclose(c2.buffer.info["key"], keys)
    for e in c2.buffer.info["env"][valid_indices]:
        assert isinstance(e, MyTestEnv)
    env_ids[valid_indices] = [0, 0, 1, 1, 1, 2, 2, 2, 2]
    assert np.allclose(c2.buffer.info["env_id"], env_ids)
    rews[valid_indices] = [0, 1, 0, 0, 1, 0, 0, 0, 1]
    assert np.allclose(c2.buffer.info["rew"], rews)
    c2.collect(n_episode=4, random=True, gym_reset_kwargs=gym_reset_kwargs)

    # test corner case
    # A vector env combined with a non-vector buffer must be rejected.
    with pytest.raises(TypeError):
        Collector(policy, dum, ReplayBuffer(10))
    with pytest.raises(TypeError):
        Collector(policy, dum, PrioritizedReplayBuffer(10, 0.5, 0.5))
    # collect() without n_step or n_episode must be rejected as well.
    with pytest.raises(TypeError):
        c2.collect()

    # test NXEnv
    for obs_type in ["array", "object"]:
        # Default arguments bind both loop variables at definition time.
        envs = SubprocVectorEnv(
            [lambda i=x, t=obs_type: NXEnv(i, t) for x in [5, 10, 15, 20]])
        c3 = Collector(policy, envs,
                       VectorReplayBuffer(total_size=100, buffer_num=4))
        c3.collect(n_step=6, gym_reset_kwargs=gym_reset_kwargs)
        # Both observation types must end up stored with object dtype.
        assert c3.buffer.obs.dtype == object