class TFUniformReplayBuffer(TFReplayBufferAbstract):
    """Replay buffer wrapper that stores batched trajectories uniformly.

    The underlying ``TFReplayBuffer`` is created lazily on the first
    ``add_batch`` call, once the batch size can be read off the incoming
    trajectory. ``max_length`` is fixed at 1, so each batch slot holds only
    the most recently added trajectory.
    """

    def _init_replay_buffer(self, batch_size, data_spec):
        """Create the backing buffer for a given batch size and spec.

        Args:
            batch_size: number of parallel trajectory slots to allocate.
            data_spec: spec describing a single trajectory element.
        """
        self._batch_size = batch_size
        # Surface the chosen size as a TF summary for monitoring.
        tf.compat.v2.summary.scalar(
            name="replay_buffer_size", data=self._batch_size)
        # max_length=1: keep only the latest trajectory per slot.
        self._replay_buffer = TFReplayBuffer(
            batch_size=self._batch_size,
            data_spec=data_spec,
            max_length=1,
        )

    def add_batch(self, traj_dict):
        """Append one batched trajectory to the replay buffer.

        Args:
            traj_dict: dict mapping trajectory field names to numpy arrays;
                converted to tensors via ``build_tf_trajectory`` against this
                buffer's collect-data spec.
        """
        spec_dict = self.collect_data_spec._asdict()
        trajectory, trajectory_spec = build_tf_trajectory(traj_dict, spec_dict)

        # Lazily build the backing buffer the first time data arrives;
        # the batch size is the leading dimension of the observations.
        if not self._replay_buffer:
            self._init_replay_buffer(
                len(traj_dict["observation"]), trajectory_spec)

        self._replay_buffer.add_batch(trajectory)

    def get_batch(self, batch_size):
        """Sample a batch of trajectories (defaults to the stored batch size).

        Returns:
            A ``(trajectory, metadata)`` pair as produced by the underlying
            buffer's ``get_next``.
        """
        # TODO: convert the replay buffer to a dataset and iterate over it
        effective_size = self._batch_size if batch_size is None else batch_size
        sampled, info = self._replay_buffer.get_next(
            sample_batch_size=effective_size)
        return sampled, info
global_step = tf.compat.v1.train.get_global_step() # Create a policy saver policy_saver = PolicySaver(agent.policy) # Main training loop time_step, policy_state = None, None for it in range(N_ITERATIONS): if COLLECT_RANDOM: print('Running random driver...') time_step, policy_state = random_driver.run(time_step, policy_state) print('Running agent driver...') time_step, policy_state = driver.run(time_step, policy_state) print('Training...') for train_it in range(BUFFER_LENGTH//BATCH_SIZE): experience, _ = replay_buffer.get_next(sample_batch_size=BATCH_SIZE, num_steps=2) agent.train(experience) if (train_it + 1) % 100 == 0: print('{0} training iterations'.format(train_it + 1)) print('Saving...') # Save to checkpoint checkpointer.save(global_step) # Save policy policy_saver.save(os.path.relpath('policy')) # Show total reward of actual policy for 1 episode total_reward = 0.0 eval_ts = eval_env.reset() num_steps = 0 while (not eval_ts.is_last()) and num_steps < EVAL_MAX_STEPS: action_step = agent.policy.action(eval_ts) eval_ts = eval_env.step(action_step.action)