class TFUniformReplayBuffer(TFReplayBufferAbstract):
    def _init_replay_buffer(self, batch_size, data_spec):
        self._batch_size = batch_size
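        # max_length=1: the underlying buffer keeps only the most recently added item per batch entry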
        buffer_config = {
            "batch_size": self._batch_size,
            "data_spec": data_spec,
            "max_length": 1
        }
        tf.compat.v2.summary.scalar(name="replay_buffer_size",
                                    data=self._batch_size)
        self._replay_buffer = TFReplayBuffer(**buffer_config)

    def add_batch(self, traj_dict):
        """
        add a trajectory to the replay buffer

        Params
            traj (dict[dim]:numpy): a dict of tensors representing the trajectory to be added it to the replay buffer
        """

        collect_spec_dict = self.collect_data_spec._asdict()
        traj_tf, traj_spec = build_tf_trajectory(traj_dict, collect_spec_dict)

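        # lazily create the underlying buffer on the first call, sized to the incoming batch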
        if not self._replay_buffer:
            batch_size = len(traj_dict["observation"])
            self._init_replay_buffer(batch_size, traj_spec)

        self._replay_buffer.add_batch(traj_tf)

    def get_batch(self, batch_size=None):
        if batch_size is None:
            batch_size = self._batch_size
        # TODO: convert the replay buffer to a dataset and iterate over it
        traj, metadata = self._replay_buffer.get_next(
            sample_batch_size=batch_size)
        return traj, metadata
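# A minimal usage sketch for the wrapper above, assuming `rb` is an already
# constructed TFUniformReplayBuffer whose collect_data_spec uses the standard
# tf_agents Trajectory fields; the keys, shapes, and dtypes below are
# illustrative placeholders, not taken from the source.
#
#     traj_dict = {
#         "step_type": np.zeros((32,), dtype=np.int32),
#         "observation": np.zeros((32, 4), dtype=np.float32),
#         "action": np.zeros((32,), dtype=np.int64),
#         "policy_info": (),
#         "next_step_type": np.ones((32,), dtype=np.int32),
#         "reward": np.zeros((32,), dtype=np.float32),
#         "discount": np.ones((32,), dtype=np.float32),
#     }
#     rb.add_batch(traj_dict)        # first call lazily builds the buffer with batch_size=32
#     traj, metadata = rb.get_batch(8)   # sample a sub-batch of 8 entries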
Example #2
    global_step = tf.compat.v1.train.get_global_step()

    # Create a policy saver
    policy_saver = PolicySaver(agent.policy)

    # Main training loop
    time_step, policy_state = None, None
    for it in range(N_ITERATIONS):
        if COLLECT_RANDOM:
            print('Running random driver...')
            time_step, policy_state = random_driver.run(time_step, policy_state)
        print('Running agent driver...')
        time_step, policy_state = driver.run(time_step, policy_state)
        print('Training...')
        for train_it in range(BUFFER_LENGTH//BATCH_SIZE):
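            # sample pairs of adjacent time steps; agents doing one-step TD updates (e.g. DQN) expect num_steps=2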
            experience, _ = replay_buffer.get_next(sample_batch_size=BATCH_SIZE, num_steps=2)
            agent.train(experience)
            if (train_it + 1) % 100 == 0:
                print('{0} training iterations'.format(train_it + 1))
        print('Saving...')
        # Save to checkpoint
        checkpointer.save(global_step)
        # Save policy
        policy_saver.save(os.path.relpath('policy'))
        # Show the total reward of the current policy for one episode
        total_reward = 0.0
        eval_ts = eval_env.reset()
        num_steps = 0
        while (not eval_ts.is_last()) and num_steps < EVAL_MAX_STEPS:
            action_step = agent.policy.action(eval_ts)
            eval_ts = eval_env.step(action_step.action)