    def update(self, update_buffer: Buffer, num_sequences: int) -> Dict[str, float]:
        """
        Updates the Curiosity model using the training buffer. Splits the buffer into
        mini-batches and performs gradient descent on each.
        :param update_buffer: Update buffer from which to pull data.
        :param num_sequences: Number of sequences in the update buffer.
        :return: Dict of stats that should be reported to TensorBoard.
        """
        forward_total: List[float] = []
        inverse_total: List[float] = []
        for _ in range(self.num_epoch):
            update_buffer.shuffle()
            for i in range(len(update_buffer["actions"]) // num_sequences):
                start = i * num_sequences
                end = (i + 1) * num_sequences
                run_out_curio = self._update_batch(
                    update_buffer.make_mini_batch(start, end), num_sequences
                )
                inverse_total.append(run_out_curio["inverse_loss"])
                forward_total.append(run_out_curio["forward_loss"])

        update_stats = {
            "Losses/Curiosity Forward Loss": np.mean(forward_total),
            "Losses/Curiosity Inverse Loss": np.mean(inverse_total),
        }
        return update_stats
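
    # Illustrative usage (hypothetical trainer-side code, not part of this class):
    #
    #     stats = curiosity_signal.update(trainer.update_buffer, num_sequences=batch_size)
    #     for key, value in stats.items():
    #         trainer.stats[key].append(value)
    #
    # Here `curiosity_signal`, `trainer`, and `batch_size` are assumed names for
    # whatever trainer owns this reward signal.
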
    def update(self, update_buffer: Buffer, n_sequences: int) -> Dict[str, float]:
        """
        Updates the GAIL model using both the demonstration buffer and the policy buffer.
        :param update_buffer: The policy buffer containing the trajectories for the current policy.
        :param n_sequences: The number of sequences from demo and policy used in each mini-batch.
        :return: Dict of stats that should be reported to TensorBoard.
        """
        batch_losses = []
        # Halve n_sequences since each mini-batch draws from two buffers (demo and
        # policy), keeping the combined mini-batch size roughly the same.
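        # For example, n_sequences=16 yields 8 sequences drawn from each buffer.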
        n_sequences = max(n_sequences // 2, 1)
        possible_demo_batches = (
            len(self.demonstration_buffer.update_buffer["actions"]) // n_sequences
        )
        possible_policy_batches = len(update_buffer["actions"]) // n_sequences
        possible_batches = min(possible_policy_batches, possible_demo_batches)

        max_batches = self.samples_per_update // n_sequences
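        # For example, samples_per_update=2048 with n_sequences=64 caps each epoch at
        # 32 mini-batches; samples_per_update=0 leaves max_batches at 0, which the
        # max_batches == 0 branch below treats as "no cap".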

        kl_loss = []
        policy_estimate = []
        expert_estimate = []
        z_log_sigma_sq = []
        z_mean_expert = []
        z_mean_policy = []

        n_epoch = self.num_epoch
        for _epoch in range(n_epoch):
            self.demonstration_buffer.update_buffer.shuffle()
            update_buffer.shuffle()
            if max_batches == 0:
                num_batches = possible_batches
            else:
                num_batches = min(possible_batches, max_batches)
            for i in range(num_batches):
                start = i * n_sequences
                end = (i + 1) * n_sequences
                mini_batch_demo = self.demonstration_buffer.update_buffer.make_mini_batch(
                    start, end
                )
                mini_batch_policy = update_buffer.make_mini_batch(start, end)
                run_out = self._update_batch(mini_batch_demo, mini_batch_policy)
                loss = run_out["gail_loss"]

                policy_estimate.append(run_out["policy_estimate"])
                expert_estimate.append(run_out["expert_estimate"])
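                # When the variational bottleneck (VAIL) is enabled, also collect its
                # KL loss and latent statistics for debug logging.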
                if self.model.use_vail:
                    kl_loss.append(run_out["kl_loss"])
                    z_log_sigma_sq.append(run_out["z_log_sigma_sq"])
                    z_mean_policy.append(run_out["z_mean_policy"])
                    z_mean_expert.append(run_out["z_mean_expert"])

                batch_losses.append(loss)
        self.has_updated = True

        print_list = ["n_epoch", "beta", "policy_estimate", "expert_estimate"]
        print_vals = [
            n_epoch,
            self.policy.sess.run(self.model.beta),
            np.mean(policy_estimate),
            np.mean(expert_estimate),
        ]
        if self.model.use_vail:
            print_list += [
                "kl_loss",
                "z_mean_expert",
                "z_mean_policy",
                "z_log_sigma_sq",
            ]
            print_vals += [
                np.mean(kl_loss),
                np.mean(z_mean_expert),
                np.mean(z_mean_policy),
                np.mean(z_log_sigma_sq),
            ]
        LOGGER.debug(
            "GAIL Debug:\n\t\t"
            + "\n\t\t".join(
                "{0}: {1}".format(_name, _val)
                for _name, _val in zip(print_list, print_vals)
            )
        )
        update_stats = {"Losses/GAIL Loss": np.mean(batch_losses)}
        return update_stats
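
    # Illustrative usage (hypothetical trainer-side code): the returned dict can be
    # folded into the trainer's TensorBoard stats like any other reward signal, e.g.:
    #
    #     stats = gail_signal.update(trainer.update_buffer, n_sequences=batch_size)
    #     for key, value in stats.items():
    #         trainer.stats[key].append(value)
    #
    # `gail_signal`, `trainer`, and `batch_size` are assumed names, not part of this module.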