def testClearAllVariables(self):
  batch_size = 1
  spec = self._data_spec()
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      spec, batch_size=batch_size, max_length=10)
  action = tf.constant(
      1 * np.ones(spec[0].shape.as_list(), dtype=np.float32))
  lidar = tf.constant(
      2 * np.ones(spec[1][0].shape.as_list(), dtype=np.float32))
  camera = tf.constant(
      3 * np.ones(spec[1][1].shape.as_list(), dtype=np.float32))
  values = [action, [lidar, camera]]
  values_batched = tf.nest.map_structure(
      lambda t: tf.stack([t] * batch_size), values)

  if tf.executing_eagerly():
    add_op = lambda: replay_buffer.add_batch(values_batched)
  else:
    add_op = replay_buffer.add_batch(values_batched)

  def get_table_vars():
    return [var for var in replay_buffer.variables() if 'Table' in var.name]

  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.evaluate(replay_buffer._clear(clear_all_variables=True))
  empty_table_vars = self.evaluate(get_table_vars())
  initial_id = self.evaluate(replay_buffer._get_last_id())
  empty_items = self.evaluate(replay_buffer.gather_all())
  self.evaluate(add_op)
  self.evaluate(add_op)
  self.evaluate(add_op)
  self.evaluate(add_op)
  values_ = self.evaluate(values)
  sample, _ = self.evaluate(replay_buffer.get_next(sample_batch_size=3))
  tf.nest.map_structure(lambda x, y: self._assertContains([x], list(y)),
                        values_, sample)
  self.assertNotEqual(initial_id,
                      self.evaluate(replay_buffer._get_last_id()))
  tf.nest.map_structure(lambda x, y: self.assertFalse(np.all(x == y)),
                        empty_table_vars, self.evaluate(get_table_vars()))

  self.evaluate(replay_buffer._clear(clear_all_variables=True))
  self.assertEqual(initial_id, self.evaluate(replay_buffer._get_last_id()))

  def check_np_arrays_everything_equal(x, y):
    np.testing.assert_equal(x, y)
    self.assertEqual(x.dtype, y.dtype)

  tf.nest.map_structure(check_np_arrays_everything_equal, empty_items,
                        self.evaluate(replay_buffer.gather_all()))
def create_replay_buffer(self):
  # A batched replay buffer which can be sampled uniformly during training.
  self._replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      data_spec=self._agent.collect_data_spec,  # the type of data being collected
      batch_size=self._train_env.batch_size,
      max_length=param.BUFFER_LENGTH)
  # add_batch adds a batch of items to the replay buffer; it is used as the
  # observer in the routine update of a dynamic_step_driver (see the sketch
  # below).
  self._replay_buffer_observer = self._replay_buffer.add_batch
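# Hedged usage sketch (not part of the original snippet): one common way the
# observer above is wired into a DynamicStepDriver so every collected step is
# written straight into the buffer. `collect_steps_per_iteration` is a
# hypothetical parameter; `self._agent.collect_policy` is assumed to exist.
def create_collect_driver(self, collect_steps_per_iteration=1):
  from tf_agents.drivers import dynamic_step_driver

  self._collect_driver = dynamic_step_driver.DynamicStepDriver(
      self._train_env,
      self._agent.collect_policy,
      observers=[self._replay_buffer_observer],
      num_steps=collect_steps_per_iteration)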
def _init_replay_buffer(self, mac, agent, train_env):
  """Replay buffer keeps track of data collected from the environment.

  We will be using TFUniformReplayBuffer.
  """
  self.replay_buffer[mac] = replay.TFUniformReplayBuffer(
      data_spec=agent.collect_data_spec,
      batch_size=train_env.batch_size,
      max_length=self.replay_buffer_max_length)
def get_replay_buffer(self):
  """Replay buffer.

  Returns:
    ReplayBuffer -- a tf-agents replay buffer.
  """
  return tf_uniform_replay_buffer.TFUniformReplayBuffer(
      data_spec=self._agent.collect_data_spec,
      batch_size=self._params["ML"]["Agent"]["num_parallel_environments"],
      max_length=self._params["ML"]["Agent"]["replay_buffer_capacity"])
def testGetNextEmpty(self):
  spec = self._data_spec()
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      spec, batch_size=1, max_length=10)

  sample, _ = replay_buffer.get_next()
  # Use the compat.v1 initializer so the test also runs under TF2.
  self.evaluate(tf.compat.v1.global_variables_initializer())
  with self.assertRaisesRegexp(
      tf.errors.InvalidArgumentError,
      'TFUniformReplayBuffer is empty. Make '
      'sure to add items before sampling the buffer.'):
    self.evaluate(sample)
def create_replay_buffer(rb_type, data_spec, batch_size, max_length):
  if rb_type == 'uniform':
    prb_flag = False
    return (tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=data_spec, batch_size=batch_size, max_length=max_length),
            prb_flag)
  elif rb_type == 'prioritized':
    prb_flag = True
    return (tf_prioritized_replay_buffer.TFPrioritizedReplayBuffer(
        data_spec=data_spec, batch_size=batch_size, max_length=max_length),
            prb_flag)
  else:
    # Reject unknown buffer types instead of silently returning None.
    raise ValueError('Unknown replay buffer type: {}'.format(rb_type))
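# Hedged usage sketch (assumption, not from the original snippet): unpacking
# the (buffer, prioritized-flag) pair returned by the factory above. `agent`
# and `train_env` are hypothetical stand-ins for a TF-Agents agent and
# TF environment.
replay_buffer, use_prioritized = create_replay_buffer(
    rb_type='uniform',
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=10000)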
def make_replay_buffer(tf_env):
  """Default replay buffer factory."""
  time_step_spec = tf_env.time_step_spec()
  action_step_spec = policy_step.PolicyStep(
      tf_env.action_spec(), (), tensor_spec.TensorSpec((), tf.int32))
  trajectory_spec = trajectory.from_transition(
      time_step_spec, action_step_spec, time_step_spec)
  return tf_uniform_replay_buffer.TFUniformReplayBuffer(
      trajectory_spec, batch_size=1)
def testGatherAllEmpty(self, batch_size):
  spec = specs.TensorSpec([], tf.int32, 'action')
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      spec, batch_size=batch_size)

  items = replay_buffer.gather_all()
  expected = [[]] * batch_size

  self.evaluate(tf.compat.v1.global_variables_initializer())
  items_ = self.evaluate(items)
  self.assertAllClose(expected, items_)
def testDeterministicAsDatasetNumStepsGreaterThanMaxLengthFails(self):
  spec = specs.TensorSpec([], tf.int32, 'action')
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      spec,
      batch_size=2,
      max_length=3,
      # If this isn't turned on, then the batching works fine.
      dataset_drop_remainder=True)
  with self.assertRaisesRegexp(ValueError, 'ALL data will be dropped'):
    replay_buffer.as_dataset(single_deterministic_pass=True, num_steps=4)
def __init__(self, data_spec, batch_size=1, n_steps=2):
  self.data_spec = data_spec
  self.batch_size = batch_size
  self.n_steps = n_steps
  self.replay_buffer_capacity = 10000
  self.buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      data_spec=self.data_spec,
      batch_size=self.batch_size,
      max_length=self.replay_buffer_capacity,
  )
  self.state = None
def testGatherAllEmpty(self, batch_size):
  spec = specs.TensorSpec([], tf.int32, 'action')
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      spec, batch_size=batch_size)

  items = replay_buffer.gather_all()
  expected = [[]] * batch_size

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    items_ = sess.run(items)
    self.assertAllClose(expected, items_)
def get_tf_buffers(c, max_length=270):
  obs_spec, ac_spec = get_env_specs(c)
  time_step_spec = ts.time_step_spec(obs_spec)
  action_spec = policy_step.PolicyStep(ac_spec)
  trajectory_spec = trajectory.from_transition(
      time_step_spec, action_spec, time_step_spec)
  the_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      data_spec=trajectory_spec, batch_size=1, max_length=max_length)
  return the_replay_buffer
def testClearAllVariables(self):
  batch_size = 1
  spec = self._data_spec()
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      spec, batch_size=batch_size, max_length=10)

  action = tf.constant(
      1 * np.ones(spec[0].shape.as_list(), dtype=np.float32))
  lidar = tf.constant(
      2 * np.ones(spec[1][0].shape.as_list(), dtype=np.float32))
  camera = tf.constant(
      3 * np.ones(spec[1][1].shape.as_list(), dtype=np.float32))
  values = [action, [lidar, camera]]
  values_batched = nest.map_structure(
      lambda t: tf.stack([t] * batch_size), values)

  last_id_op = replay_buffer._get_last_id()
  add_op = replay_buffer.add_batch(values_batched)
  sample, _ = replay_buffer.get_next(sample_batch_size=3)
  clear_op = replay_buffer._clear(clear_all_variables=True)
  items_op = replay_buffer.gather_all()
  table_vars = [
      var for var in replay_buffer.variables() if 'Table' in var.name
  ]

  with self.cached_session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(clear_op)
    empty_table_vars = sess.run(table_vars)
    last_id = sess.run(last_id_op)
    empty_items = sess.run(items_op)
    sess.run(add_op)
    sess.run(add_op)
    sess.run(add_op)
    sess.run(add_op)
    values_ = sess.run(values)
    sample_ = sess.run(sample)
    nest.map_structure(lambda x, y: self._assertContains([x], list(y)),
                       values_, sample_)
    self.assertNotEqual(last_id, sess.run(last_id_op))
    nest.map_structure(lambda x, y: self.assertFalse(np.all(x == y)),
                       empty_table_vars, sess.run(table_vars))

    sess.run(clear_op)
    self.assertEqual(last_id, sess.run(last_id_op))

    def check_np_arrays_everything_equal(x, y):
      np.testing.assert_equal(x, y)
      self.assertEqual(x.dtype, y.dtype)

    nest.map_structure(check_np_arrays_everything_equal, empty_items,
                       sess.run(items_op))
def initOnlineAndOfflineBuffers(self, bufferType="online"):
  # Transition tuple: (s, a, camera, s', next_camera, r, terminated).
  data_spec = (tf.TensorSpec(self._state_size, tf.float64, 'state'),
               tf.TensorSpec(self._action_size, tf.float64, 'action'),
               tf.TensorSpec(self._imgReshapeWithDepth, tf.float32, 'camera'),
               tf.TensorSpec(self._state_size, tf.float64, 'next_state'),
               tf.TensorSpec(self._imgReshapeWithDepth, tf.float32,
                             'next_camera'),
               tf.TensorSpec(1, tf.float64, 'reward'),
               tf.TensorSpec(1, tf.bool, 'terminated'))
  if bufferType == "online":
    return tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec, batch_size=1, max_length=self.onlineBufferMaxLengh)
  elif bufferType == "offline":
    return tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec, batch_size=1, max_length=self.offlineBuffferMaxLength)
  else:
    raise ValueError("Unknown buffer type: {}".format(bufferType))
def _create_collect_rb_dataset(self,
                               max_length,
                               buffer_batch_size,
                               num_adds,
                               sample_batch_size,
                               num_steps=None):
  """Create a replay buffer, add items to it, and collect from its dataset."""
  spec = specs.TensorSpec([], tf.int32, 'action')
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      spec, batch_size=buffer_batch_size, max_length=max_length)
  ds = replay_buffer.as_dataset(
      single_deterministic_pass=True,
      sample_batch_size=sample_batch_size,
      num_steps=num_steps)

  if tf.executing_eagerly():
    ix = [0]

    def add_op():
      replay_buffer.add_batch(10 * tf.range(buffer_batch_size) + ix[0])
      ix[0] += 1

    itr = iter(ds)
    get_next = lambda: next(itr)
  else:
    actions = 10 * tf.range(buffer_batch_size) + tf.Variable(0).count_up_to(9)
    add_op = replay_buffer.add_batch(actions)
    itr = tf.compat.v1.data.make_initializable_iterator(ds)
    get_next = itr.get_next()

  self.evaluate(tf.compat.v1.global_variables_initializer())

  for _ in range(num_adds):
    # Add 10*range(buffer_batch_size), then 1 + 10*range(buffer_batch_size), ...
    # The actual episodes are:
    #   [0, 1, 2, ...],
    #   [10, 11, 12, ...],
    #   [20, 21, 22, ...]
    #   ... (buffer_batch_size of these)
    self.evaluate(add_op)

  rb_values = []
  if not tf.executing_eagerly():
    self.evaluate(itr.initializer)
  try:
    while True:
      rb_values.append(self.evaluate(get_next)[0].tolist())
  except (tf.errors.OutOfRangeError, StopIteration):
    pass

  return replay_buffer, rb_values
def testRNNTrain(self, compute_value_and_advantage_in_train):
  actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
      self._time_step_spec.observation,
      self._action_spec,
      input_fc_layer_params=None,
      output_fc_layer_params=None,
      lstm_size=(20,))
  value_net = value_rnn_network.ValueRnnNetwork(
      self._time_step_spec.observation,
      input_fc_layer_params=None,
      output_fc_layer_params=None,
      lstm_size=(10,))
  global_step = tf.compat.v1.train.get_or_create_global_step()
  agent = ppo_agent.PPOAgent(
      self._time_step_spec,
      self._action_spec,
      optimizer=tf.compat.v1.train.AdamOptimizer(),
      actor_net=actor_net,
      value_net=value_net,
      num_epochs=1,
      train_step_counter=global_step,
      compute_value_and_advantage_in_train=(
          compute_value_and_advantage_in_train))

  # Use a random env, policy, and replay buffer to collect training data.
  random_env = random_tf_environment.RandomTFEnvironment(
      self._time_step_spec, self._action_spec, batch_size=1)
  collection_policy = random_tf_policy.RandomTFPolicy(
      self._time_step_spec,
      self._action_spec,
      info_spec=agent.collect_policy.info_spec)
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      collection_policy.trajectory_spec, batch_size=1, max_length=7)
  collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
      random_env,
      collection_policy,
      observers=[replay_buffer.add_batch],
      num_episodes=1)

  # In graph mode: finish building the graph so the optimizer
  # variables are created.
  if not tf.executing_eagerly():
    _, _ = agent.train(experience=replay_buffer.gather_all())

  # Initialize.
  self.evaluate(agent.initialize())
  self.evaluate(tf.compat.v1.global_variables_initializer())

  # Train one step.
  self.assertEqual(0, self.evaluate(global_step))
  self.evaluate(collect_driver.run())
  self.evaluate(agent.train(experience=replay_buffer.gather_all()))
  self.assertEqual(1, self.evaluate(global_step))
def __init__(self,
             data_spec,
             batch_size=1,
             sample_batch_size=64,
             replay_buffer_capacity=2000):
  self.data_spec = data_spec
  self.batch_size = batch_size
  self.sample_batch_size = sample_batch_size
  self.replay_buffer_capacity = replay_buffer_capacity
  self.buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      data_spec=self.data_spec,
      batch_size=self.batch_size,
      max_length=self.replay_buffer_capacity)
def build(self):
  # Build environments (CartPole parameters can be changed here).
  self.train_py_env = suite_gym.load(self.env_name)
  self.eval_py_env = suite_gym.load(self.env_name)
  self.train_env = tf_py_environment.TFPyEnvironment(self.train_py_env)
  self.eval_env = tf_py_environment.TFPyEnvironment(self.eval_py_env)

  # Build the agent.
  q_net = q_network.QNetwork(
      self.train_env.observation_spec(),
      self.train_env.action_spec(),
      fc_layer_params=self.fc_layer_params)
  optimizer = tf.compat.v1.train.AdamOptimizer(
      learning_rate=self.learning_rate)
  train_step_counter = tf.Variable(0)
  self.agent = dqn_agent.DqnAgent(
      self.train_env.time_step_spec(),
      self.train_env.action_spec(),
      q_network=q_net,
      optimizer=optimizer,
      td_errors_loss_fn=common.element_wise_squared_loss,
      train_step_counter=train_step_counter)
  self.agent.initialize()

  # Build the policy used for initial data collection.
  self.random_policy = random_tf_policy.RandomTFPolicy(
      self.train_env.time_step_spec(), self.train_env.action_spec())

  # Build the replay buffer.
  self.replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      data_spec=self.agent.collect_data_spec,
      batch_size=self.train_env.batch_size,
      max_length=self.replay_buffer_max_length)

  # Collect initial experience.
  self.collect_data(self.train_env, self.random_policy, self.replay_buffer,
                    self.initial_collect_steps)

  # Build the dataset.
  self.dataset = self.replay_buffer.as_dataset(
      num_parallel_calls=3,
      sample_batch_size=self.batch_size,
      num_steps=2).prefetch(3)
  self.iterator = iter(self.dataset)
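# Hedged usage sketch (assumption, not from the original snippet): one DQN
# training step that consumes the iterator built in `build` above. The method
# name `train_step` is hypothetical.
def train_step(self):
  experience, unused_info = next(self.iterator)
  return self.agent.train(experience).loss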
def initTrainBuffer(self):
  # Transition tuple: (s, a, camera, s', next_camera, r, terminated, q_value).
  data_spec = (tf.TensorSpec(self._state_size, tf.float64, 'state'),
               tf.TensorSpec(self._action_size, tf.float64, 'action'),
               tf.TensorSpec(self._imgReshapeWithDepth, tf.float32, 'camera'),
               tf.TensorSpec(self._state_size, tf.float64, 'next_state'),
               tf.TensorSpec(self._imgReshapeWithDepth, tf.float32,
                             'next_camera'),
               tf.TensorSpec(1, tf.float64, 'reward'),
               tf.TensorSpec(1, tf.bool, 'terminated'),
               tf.TensorSpec(1, tf.float32, 'q_value'))
  return tf_uniform_replay_buffer.TFUniformReplayBuffer(
      data_spec, batch_size=1, max_length=self.trainBufferMaxLength)
def replayBufferInit(train_env):
  """Replay buffer storing the transitions of all the agents."""
  replay_buffer_capacity = 100000
  policy_step_spec = policy_step.PolicyStep(
      action=train_env.action_spec(), state=(), info=())
  replay_buffer_data_spec = trajectory.from_transition(
      train_env.time_step_spec(), policy_step_spec,
      train_env.time_step_spec())
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      data_spec=replay_buffer_data_spec,
      batch_size=train_env.batch_size,
      max_length=replay_buffer_capacity)
  return replay_buffer
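# Hedged usage sketch (assumption, not from the original snippet): writing one
# transition into the buffer created above. `time_step`, `action_step`, and
# `next_time_step` are hypothetical names for values obtained by stepping
# `train_env` with a policy; they must be batched to match the buffer spec.
replay_buffer = replayBufferInit(train_env)
traj = trajectory.from_transition(time_step, action_step, next_time_step)
replay_buffer.add_batch(traj)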
def testAddSingleSampleBatch(self):
  batch_size = 1
  spec = self._data_spec()
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      spec, batch_size=batch_size, max_length=10)
  values, add_op = _get_add_op(spec, replay_buffer, batch_size)
  sample, _ = replay_buffer.get_next(sample_batch_size=3)

  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.evaluate(add_op)
  values_ = self.evaluate(values)
  sample_ = self.evaluate(sample)

  tf.nest.map_structure(lambda x, y: self._assertContains([x], list(y)),
                        values_, sample_)
def testMultiStepStackedSampling(self, batch_size):
  spec = specs.TensorSpec([], tf.int32, 'action')
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      spec, batch_size=batch_size)

  actions = tf.stack([tf.Variable(0).count_up_to(10)] * batch_size)
  add_op = replay_buffer.add_batch(actions)
  steps, _ = replay_buffer.get_next(num_steps=2)

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(10):
      sess.run(add_op)
    for _ in range(100):
      steps_ = sess.run(steps)
      self.assertEqual((steps_[0] + 1) % 10, steps_[1])
def testAdd(self, batch_size):
  spec = self._data_spec()
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      spec,
      batch_size=batch_size,
      max_length=10,
      scope='rb{}'.format(batch_size))
  values, add_op = _get_add_op(spec, replay_buffer, batch_size)
  sample, _ = replay_buffer.get_next()

  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.evaluate(add_op)
  sample_ = self.evaluate(sample)
  values_ = self.evaluate(values)
  tf.nest.map_structure(self.assertAllClose, values_, sample_)
def create_replay_buffer(train_env, agent):
  # Replay buffer for data collection, so that training can use the
  # collected Trajectory items.
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      data_spec=agent.collect_data_spec,
      batch_size=train_env.batch_size,
      max_length=replay_buffer_capacity)

  # The dataset generates trajectories with shape [B x T x ...] where
  # T = n_step_update + 1.
  dataset = replay_buffer.as_dataset(
      num_parallel_calls=8,
      sample_batch_size=batch_size,
      num_steps=n_step_update + 1)  # .prefetch(3)
  replay_buffer_itr = iter(dataset)  # <-- iterator used in training

  return replay_buffer, replay_buffer_itr
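# Hedged usage sketch (assumption, not from the original snippet): consuming
# the iterator returned above in a training loop. `num_iterations` is a
# hypothetical config value; `agent` is the same agent passed to
# create_replay_buffer.
replay_buffer, replay_buffer_itr = create_replay_buffer(train_env, agent)
for _ in range(num_iterations):
  experience, unused_info = next(replay_buffer_itr)
  train_loss = agent.train(experience).loss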
def train(num_iterations):
  train_env = tf_py_environment.TFPyEnvironment(Cliff())
  test_env = tf_py_environment.TFPyEnvironment(Cliff())
  counter = tf.Variable(0)

  # Build the Q-network and agent.
  network = q_network.QNetwork(
      train_env.observation_spec(),
      train_env.action_spec(),
      fc_layer_params=(100,))
  agent = dqn_agent.DqnAgent(
      train_env.time_step_spec(),
      train_env.action_spec(),
      q_network=network,
      optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3),
      td_errors_loss_fn=common.element_wise_squared_loss,
      train_step_counter=counter)
  agent.initialize()
  agent.train = common.function(agent.train)
  agent.train_step_counter.assign(0)

  buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      data_spec=agent.collect_data_spec,
      batch_size=train_env.batch_size,
      max_length=100)
  dataset = buffer.as_dataset(sample_batch_size=32, num_steps=2)
  iterator = iter(dataset)

  first_reward = compute_average_reward(train_env, agent.policy,
                                        num_episodes=10)
  print(f'Before training: {first_reward}')
  rewards = [first_reward]

  for _ in range(num_iterations):
    for _ in range(2):
      collect_steps(train_env, agent.collect_policy, buffer)
    experience, info = next(iterator)
    loss = agent.train(experience).loss
    step_number = agent.train_step_counter.numpy()
    if step_number % 10 == 0:
      print(f'step={step_number}: loss={loss}')
    if step_number % 20 == 0:
      average_reward = compute_average_reward(test_env, agent.policy, 1)
      print(f'step={step_number}: reward={average_reward}')
      rewards.append(average_reward)
def testMultiStepSampling(self, batch_size):
  spec = specs.TensorSpec([], tf.int32, 'action')
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      spec, batch_size=batch_size)

  action = tf.stack([tf.Variable(0).count_up_to(10)] * batch_size)
  add_op = replay_buffer.add_batch(action)
  (step, next_step), _ = replay_buffer.get_next(
      num_steps=2, time_stacked=False)

  with self.cached_session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(10):
      sess.run(add_op)
    for _ in range(100):
      step_, next_step_ = sess.run([step, next_step])
      self.assertEqual((step_ + 1) % 10, next_step_)
def _create_replay_buffer(self, num_steps=2):
  # Initialize the replay buffer and collect the initial data.
  self.replay_buffer = buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      data_spec=self.agent.collect_data_spec,
      batch_size=self.env.batch_size,  # actually 1, env isn't batched
      max_length=replay_buffer_capacity)
  print(buffer.capacity.numpy(), buffer._batch_size)
  print(buffer.data_spec)

  self._collect_data(self.agent.collect_policy, initial_collect_steps)

  dataset = buffer.as_dataset(
      num_parallel_calls=3,
      num_steps=num_steps,
      sample_batch_size=batch_size,
  ).prefetch(batch_size)
  self.data_iterator = iter(dataset)
def __init__(self, env_name, learning_rate, num_episodes,
             replay_buffer_max_length, steps, num_iterations,
             collect_step_per_iteration, log_interval, eval_interval,
             batch_size, total_run_minutes, perturb_duration,
             kernel_initializer):
  self.env_name = env_name
  train_py_env = suite_gym.load(self.env_name)
  eval_py_env = suite_gym.load(self.env_name)
  train_py_env.reset()
  eval_py_env.reset()
  self.train_env = tf_py_environment.TFPyEnvironment(train_py_env)
  self.eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)
  self.learning_rate = learning_rate
  self.num_episodes = num_episodes
  self.replay_buffer_max_length = replay_buffer_max_length
  self.steps = steps
  self.num_iterations = num_iterations
  self.collect_step_per_iteration = collect_step_per_iteration
  self.log_interval = log_interval
  self.eval_interval = eval_interval
  self.batch_size = batch_size
  self.returns = []
  self.total_run_minutes = total_run_minutes
  self.numLogs = 2 * self.total_run_minutes
  self.numEvals = self.total_run_minutes
  self.eval_times = []
  self.perturb_duration = perturb_duration
  self.kernel_initializer = kernel_initializer
  self.perturb_times = []

  self.agent_init()
  self.replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      data_spec=self.data_spec,
      batch_size=self.train_env.batch_size,
      max_length=self.replay_buffer_max_length)

  self.start_time = time.time()
  agent_thread = threading.Thread(target=self.train_agent)
  actor_thread = threading.Thread(target=self.act)
  actor_thread.start()
  agent_thread.start()
  actor_thread.join()
  agent_thread.join()
def train_on_policy_tf_agent(
    model: TFAgent,
    train_env: TFPyEnvironment,
    total_timesteps: int,
    callback: callbacks.BaseKindoRLCallback = None,
):
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      model.collect_data_spec,
      batch_size=train_env.batch_size,
      max_length=100000)
  if callback is not None:
    callback.init_callback(model, train_env=train_env)

  collect_policy = model.collect_policy
  locals_ = {
      "episode_rewards": [],
      "episode_losses": [],
      "episode_lengths": []
  }
  passed_timesteps = 0
  if callback is not None:
    callback.on_training_start(locals_, {})

  while passed_timesteps < total_timesteps:
    episode_reward, episode_length = utils.step_episode(
        train_env, collect_policy, replay_buffer)
    passed_timesteps += episode_length
    locals_["episode_rewards"].append(episode_reward)
    locals_["episode_lengths"].append(episode_length)

    experience = replay_buffer.gather_all()
    train_loss = model.train(experience).loss.numpy()
    locals_["episode_losses"].append(train_loss)
    replay_buffer.clear()

    if callback is not None:
      callback.update_locals(locals_)
      continue_training = callback.on_steps(num_steps=episode_length)
      if not continue_training:
        break

  if callback is not None:
    callback.on_training_end()
def testMultiStepStackedSampling(self, batch_size):
  if tf.executing_eagerly():
    self.skipTest('b/123885577')
  spec = specs.TensorSpec([], tf.int32, 'action')
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      spec, batch_size=batch_size)

  actions = tf.stack([tf.Variable(0).count_up_to(10)] * batch_size)
  add_op = replay_buffer.add_batch(actions)
  steps, _ = replay_buffer.get_next(num_steps=2)

  self.evaluate(tf.compat.v1.global_variables_initializer())
  for _ in range(10):
    self.evaluate(add_op)
  for _ in range(100):
    steps_ = self.evaluate(steps)
    self.assertEqual((steps_[0] + 1) % 10, steps_[1])