def test_append_sample_controlled(rank, trans_list):
     """Deterministic append/sample round trip across a three-member group.

     Ranks 0 and 1 each append the first five ``(transition, priority)``
     pairs of ``trans_list`` and then sleep for 5 seconds (presumably to
     stay alive while the sampler runs). Any other rank sleeps 2 seconds and
     draws a batch of 10 that includes the ``index`` attribute. It then
     checks the sampled indexes, the keys of the returned index mapping and
     the returned priorities against hard-coded values, and writes the
     priorities back. The expected values presumably depend on
     ``np.random.seed(0)`` -- confirm if the seed or buffer changes.

     Returns ``True`` on every rank.
     """
     world = get_world()
     default_logger.info("{} started".format(rank))
     np.random.seed(0)
     group = world.create_rpc_group("group", ["0", "1", "2"])
     buffer = DistributedPrioritizedBuffer("buffer", group, 5)

     if rank not in (0, 1):
         # Sampler: give the appending ranks a head start first.
         sleep(2)
         batch_size, sample, indexes, priorities = buffer.sample_batch(
             10, sample_attrs=["index"]
         )
         for message in (
             "sampled batch size: {}".format(batch_size),
             sample,
             indexes,
             priorities,
         ):
             default_logger.info(message)

         expected_sample = [0, 1, 2, 2, 4] * 2
         expected_priorities = [0.75316421] * 4 + [1.0]
         expected_priorities = expected_priorities * 2

         assert batch_size == 10
         assert sample[0] == expected_sample
         assert list(indexes.keys()) == ["0", "1"]
         deviation = np.abs(priorities - expected_priorities)
         assert np.all(deviation < 1e-6)
         buffer.update_priority(priorities, indexes)
     else:
         # Appending ranks: push the first five (transition, priority) pairs.
         for idx in range(5):
             transition, priority = trans_list[idx]
             buffer.append(transition, priority)
         sleep(5)
     return True
 def test_append_sample_random(rank, trans_list):
     """Randomised, time-boxed append/sample run across a three-member group.

     Ranks 0 and 1 keep appending randomly chosen transitions from
     ``trans_list`` for about 10 seconds, pausing a random 0-0.5 s between
     appends. Any other rank waits 5 seconds and then, for about 5 seconds,
     repeatedly requests batches of 10. For each batch it checks that the
     batch is non-empty, checks the shape of every sampled field, and
     writes the returned priorities back with ``update_priority``.

     Elapsed time is measured with ``time.monotonic`` rather than the
     wall-clock ``time.time``, so system clock adjustments cannot stretch or
     cut short the time-boxed loops.

     Returns ``True`` on every rank.
     """
     from time import monotonic

     world = get_world()
     count = 0
     default_logger.info("{} started".format(rank))
     group = world.create_rpc_group("group", ["0", "1", "2"])
     buffer = DistributedPrioritizedBuffer("buffer", group, 5)
     if rank in (0, 1):
         begin = monotonic()
         while monotonic() - begin < 10:
             # The paired priority is discarded and ``append`` is called
             # without one. NOTE(review): confirm this is intentional, since
             # the controlled test passes the priority explicitly.
             trans, _ = random.choice(trans_list)
             buffer.append(trans)
             default_logger.info("{} append {} success".format(rank, count))
             count += 1
             sleep(random.random() * 0.5)
     else:
         # Let the appending ranks fill the buffer before sampling starts.
         sleep(5)
         begin = monotonic()
         while monotonic() - begin < 5:
             batch_size, sample, indexes, priorities = \
                 buffer.sample_batch(10)
             default_logger.info(
                 "sampled batch size: {}".format(batch_size))
             assert batch_size > 0
             # state
             assert list(sample[0]["state_1"].shape) == [batch_size, 2]
             # action
             assert list(sample[1]["action_1"].shape) == [batch_size, 3]
             # next state
             assert list(sample[2]["next_state_1"].shape) == [batch_size, 2]
             # reward
             assert list(sample[3].shape) == [batch_size, 1]
             # terminal
             assert list(sample[4].shape) == [batch_size, 1]
             # index: one entry per sampled transition
             assert len(sample[5]) == batch_size
             # simulate a backward pass before updating priorities
             sleep(1)
             buffer.update_priority(priorities, indexes)
             default_logger.info("{} sample {} success".format(rank, count))
             count += 1
             sleep(1)
     return True