def update(self, update_buffer: Buffer, num_sequences: int) -> Dict[str, float]:
    """
    Updates the Curiosity model using the training buffer. Divides the training buffer
    into mini batches and performs gradient descent.
    :param update_buffer: Update buffer from which to pull data.
    :param num_sequences: Number of sequences to draw per mini batch.
    :return: Dict of stats that should be reported to Tensorboard.
    """
    forward_total: List[float] = []
    inverse_total: List[float] = []
    for _ in range(self.num_epoch):
        # Reshuffle the buffer each epoch so mini batches differ between epochs.
        update_buffer.shuffle()
        buffer = update_buffer
        for i in range(len(update_buffer["actions"]) // num_sequences):
            start = i * num_sequences
            end = (i + 1) * num_sequences
            run_out_curio = self._update_batch(
                buffer.make_mini_batch(start, end), num_sequences
            )
            inverse_total.append(run_out_curio["inverse_loss"])
            forward_total.append(run_out_curio["forward_loss"])

    update_stats = {
        "Losses/Curiosity Forward Loss": np.mean(forward_total),
        "Losses/Curiosity Inverse Loss": np.mean(inverse_total),
    }
    return update_stats
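
# A minimal, hedged sketch (not part of the trainer) of how the loop above walks a
# shuffled buffer: contiguous slices of `num_sequences` rows, dropping any trailing
# partial batch. The helper name `_sketch_minibatch_slices` is invented for
# illustration only and mirrors the Buffer.make_mini_batch(start, end) indexing.
from typing import List, Tuple


def _sketch_minibatch_slices(buffer_rows: int, num_sequences: int) -> List[Tuple[int, int]]:
    """Return the [start, end) row ranges the curiosity update would visit per epoch."""
    slices = []
    for i in range(buffer_rows // num_sequences):
        slices.append((i * num_sequences, (i + 1) * num_sequences))
    return slices


# Example: a 14-row buffer with num_sequences=4 yields [(0, 4), (4, 8), (8, 12)];
# rows 12-13 are dropped because they do not fill a whole mini batch.
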
def update(self, update_buffer: Buffer, n_sequences: int) -> Dict[str, float]:
    """
    Updates the GAIL model using demonstration and policy trajectories.
    :param update_buffer: The policy buffer containing the trajectories for the current policy.
    :param n_sequences: The number of sequences from demo and policy used in each mini batch.
    :return: Dict of stats that should be reported to Tensorboard.
    """
    batch_losses = []
    # Divide by 2 since we draw from two buffers (demo and policy), so the combined
    # mini batch stays roughly the same size.
    n_sequences = max(n_sequences // 2, 1)
    possible_demo_batches = (
        len(self.demonstration_buffer.update_buffer["actions"]) // n_sequences
    )
    possible_policy_batches = len(update_buffer["actions"]) // n_sequences
    possible_batches = min(possible_policy_batches, possible_demo_batches)

    max_batches = self.samples_per_update // n_sequences

    kl_loss = []
    policy_estimate = []
    expert_estimate = []
    z_log_sigma_sq = []
    z_mean_expert = []
    z_mean_policy = []

    n_epoch = self.num_epoch
    for _epoch in range(n_epoch):
        self.demonstration_buffer.update_buffer.shuffle()
        update_buffer.shuffle()
        if max_batches == 0:
            num_batches = possible_batches
        else:
            num_batches = min(possible_batches, max_batches)
        for i in range(num_batches):
            demo_update_buffer = self.demonstration_buffer.update_buffer
            policy_update_buffer = update_buffer
            start = i * n_sequences
            end = (i + 1) * n_sequences
            mini_batch_demo = demo_update_buffer.make_mini_batch(start, end)
            mini_batch_policy = policy_update_buffer.make_mini_batch(start, end)
            run_out = self._update_batch(mini_batch_demo, mini_batch_policy)
            loss = run_out["gail_loss"]

            policy_estimate.append(run_out["policy_estimate"])
            expert_estimate.append(run_out["expert_estimate"])
            if self.model.use_vail:
                kl_loss.append(run_out["kl_loss"])
                z_log_sigma_sq.append(run_out["z_log_sigma_sq"])
                z_mean_policy.append(run_out["z_mean_policy"])
                z_mean_expert.append(run_out["z_mean_expert"])

            batch_losses.append(loss)
    self.has_updated = True

    print_list = ["n_epoch", "beta", "policy_estimate", "expert_estimate"]
    print_vals = [
        n_epoch,
        self.policy.sess.run(self.model.beta),
        np.mean(policy_estimate),
        np.mean(expert_estimate),
    ]
    if self.model.use_vail:
        print_list += [
            "kl_loss",
            "z_mean_expert",
            "z_mean_policy",
            "z_log_sigma_sq",
        ]
        print_vals += [
            np.mean(kl_loss),
            np.mean(z_mean_expert),
            np.mean(z_mean_policy),
            np.mean(z_log_sigma_sq),
        ]
    LOGGER.debug(
        "GAIL Debug:\n\t\t"
        + "\n\t\t".join(
            "{0}: {1}".format(_name, _val)
            for _name, _val in zip(print_list, print_vals)
        )
    )
    update_stats = {"Losses/GAIL Loss": np.mean(batch_losses)}
    return update_stats
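
# A hedged sketch (illustration only, not the trainer's API) of the batch-count
# bookkeeping in the GAIL update above: the per-batch sequence count is halved so a
# demo half-batch and a policy half-batch together keep roughly the original size,
# the epoch is limited by the smaller of the two buffers, and `samples_per_update`
# caps the batch count unless it is 0 (no cap). The function name and arguments
# below are invented for this sketch.
def _sketch_gail_batches_per_epoch(
    demo_rows: int, policy_rows: int, n_sequences: int, samples_per_update: int
) -> int:
    n_sequences = max(n_sequences // 2, 1)
    possible_batches = min(demo_rows // n_sequences, policy_rows // n_sequences)
    max_batches = samples_per_update // n_sequences
    return possible_batches if max_batches == 0 else min(possible_batches, max_batches)


# Example: _sketch_gail_batches_per_epoch(demo_rows=64, policy_rows=256,
# n_sequences=16, samples_per_update=32) halves n_sequences to 8, finds
# min(8, 32) = 8 possible batches, and caps them at 32 // 8 = 4 per epoch.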