def test_append_sample_controlled(rank, trans_list):
    """Deterministic sampling test.

    Ranks 0 and 1 each append five transitions with explicit priorities;
    rank 2 waits, draws a fixed 10-element batch (RNG seeded), and checks
    the sampled indexes, owner ranks and priorities exactly.
    """
    world = get_world()
    default_logger.info("{} started".format(rank))
    np.random.seed(0)
    group = world.create_rpc_group("group", ["0", "1", "2"])
    buffer = DistributedPrioritizedBuffer("buffer", group, 5)
    if rank not in (0, 1):
        # Sampler process: give the writers a head start, then pull a batch.
        sleep(2)
        batch_size, sample, indexes, priorities = buffer.sample_batch(
            10, sample_attrs=["index"])
        default_logger.info("sampled batch size: {}".format(batch_size))
        default_logger.info(sample)
        default_logger.info(indexes)
        default_logger.info(priorities)
        assert batch_size == 10
        assert sample[0] == [0, 1, 2, 2, 4, 0, 1, 2, 2, 4]
        assert list(indexes.keys()) == ["0", "1"]
        expected = [0.75316421, 0.75316421, 0.75316421, 0.75316421, 1.0,
                    0.75316421, 0.75316421, 0.75316421, 0.75316421, 1.0]
        assert np.all(np.abs(priorities - expected) < 1e-6)
        buffer.update_priority(priorities, indexes)
    else:
        # Writer process: append five known (transition, priority) pairs.
        for idx in range(5):
            transition, priority = trans_list[idx]
            buffer.append(transition, priority)
        sleep(5)
    return True
def test_append_sample_empty(rank, trans_list):
    """Sampling a zero-sized batch must return an all-empty result.

    Ranks 0 and 1 append five transitions (default priority); rank 2
    requests a batch of size 0 and expects no data back.
    """
    world = get_world()
    default_logger.info(f"{rank} started")
    np.random.seed(0)
    group = world.create_rpc_group("group", ["0", "1", "2"])
    buffer = DistributedPrioritizedBuffer("buffer", group, 5)
    if rank in (0, 1):
        # Writers: priorities from trans_list are deliberately ignored here.
        for idx in range(5):
            transition, _ = trans_list[idx]
            buffer.append(transition)
        sleep(5)
    else:
        sleep(2)
        batch_size, sample, indexes, priorities = buffer.sample_batch(0)
        assert batch_size == 0
        assert sample is None
        assert indexes is None
        assert priorities is None
    return True
def test_append_sample_random(rank, trans_list):
    """Concurrent stress test: random appends racing against sampling.

    Ranks 0 and 1 append random transitions for ~10 seconds; rank 2
    samples batches for ~5 seconds, validating the shape of every
    returned attribute, then feeds priorities back.
    """
    world = get_world()
    count = 0
    default_logger.info("{} started".format(rank))
    group = world.create_rpc_group("group", ["0", "1", "2"])
    buffer = DistributedPrioritizedBuffer("buffer", group, 5)
    if rank in (0, 1):
        # Writer loop: append at a jittered rate for ten seconds.
        deadline = time() + 10
        while time() < deadline:
            transition, _ = random.choice(trans_list)
            buffer.append(transition)
            default_logger.info("{} append {} success".format(rank, count))
            count += 1
            sleep(random.random() * 0.5)
    else:
        # Sampler: let the writers fill the buffer first.
        sleep(5)
        deadline = time() + 5
        while time() < deadline:
            batch_size, sample, indexes, priorities = buffer.sample_batch(10)
            default_logger.info(
                "sampled batch size: {}".format(batch_size))
            assert batch_size > 0
            # state
            assert (list(sample[0]["state_1"].shape) == [batch_size, 2])
            # action
            assert (list(sample[1]["action_1"].shape) == [batch_size, 3])
            # next state
            assert (list(
                sample[2]["next_state_1"].shape) == [batch_size, 2])
            # reward
            assert list(sample[3].shape) == [batch_size, 1]
            # terminal
            assert list(sample[4].shape) == [batch_size, 1]
            # index
            assert len(sample[5]) == batch_size
            # simulate performing a backward pass before the update
            sleep(1)
            buffer.update_priority(priorities, indexes)
            default_logger.info("{} sample {} success".format(rank, count))
            count += 1
            sleep(1)
    return True
def test_append_clear(rank, trans_list):
    """Local vs. global clearing semantics.

    Both writers append five transitions; rank 0 clears only its local
    shard (size drops to 0 there), while rank 2 verifies the remaining
    global size is 5, clears everything, and checks the global size is 0.
    """
    world = get_world()
    default_logger.info(f"{rank} started")
    np.random.seed(0)
    group = world.create_rpc_group("group", ["0", "1", "2"])
    buffer = DistributedPrioritizedBuffer("buffer", group, 5)
    if rank in (0, 1):
        for idx in range(5):
            transition, priority = trans_list[idx]
            buffer.append(transition, priority)
        if rank == 0:
            # Local clear only empties this rank's shard.
            buffer.clear()
            assert buffer.size() == 0
        sleep(5)
    else:
        sleep(2)
        # Only rank 1's five entries should remain globally.
        assert buffer.all_size() == 5
        buffer.all_clear()
        assert buffer.all_size() == 0
    return True
def test_append_size(rank, trans_list):
    """Local vs. global size reporting.

    Only rank 0 appends; its local size is 5 while rank 1's is 0.
    Rank 2 holds no data (local size 0) but sees a global size of 5.
    """
    world = get_world()
    default_logger.info("{} started".format(rank))
    np.random.seed(0)
    group = world.create_rpc_group("group", ["0", "1", "2"])
    buffer = DistributedPrioritizedBuffer("buffer", group, 5)
    if rank in (0, 1):
        if rank == 0:
            for idx in range(5):
                transition, priority = trans_list[idx]
                buffer.append(transition, priority)
            assert buffer.size() == 5
        else:
            # Rank 1 never appended, so its local shard is empty.
            assert buffer.size() == 0
        sleep(5)
    else:
        sleep(2)
        assert buffer.size() == 0
        assert buffer.all_size() == 5
    return True