def postprocess_trajectory(self, policy, sample_batch, tf_sess=None):
    """Calculates phi values (obs, obs', and predicted obs') and ri.

    Also calculates forward and inverse losses and updates the curiosity
    module on the provided batch using our optimizer.
    """
    # Push both observations through the feature net to get both phis.
    phis, _ = self.model._curiosity_feature_net({
        SampleBatch.OBS: torch.cat([
            torch.from_numpy(sample_batch[SampleBatch.OBS]),
            torch.from_numpy(sample_batch[SampleBatch.NEXT_OBS])
        ])
    })
    phi, next_phi = torch.chunk(phis, 2)
    actions_tensor = torch.from_numpy(
        sample_batch[SampleBatch.ACTIONS]).long().to(policy.device)

    # Predict next phi with the forward model.
    predicted_next_phi = self.model._curiosity_forward_fcnet(
        torch.cat(
            [phi, one_hot(actions_tensor, self.action_space).float()],
            dim=-1))

    # Forward loss term (predicted phi', given phi and action, vs. the
    # actually observed phi').
    forward_l2_norm_squared = 0.5 * torch.sum(
        torch.pow(predicted_next_phi - next_phi, 2.0), dim=-1)
    forward_loss = torch.mean(forward_l2_norm_squared)

    # Scale the intrinsic reward by the eta hyper-parameter and add it to the
    # extrinsic reward.
    sample_batch[SampleBatch.REWARDS] = \
        sample_batch[SampleBatch.REWARDS] + \
        self.eta * forward_l2_norm_squared.detach().cpu().numpy()

    # Inverse loss term (predicted action that led from phi to phi' vs. the
    # action actually taken).
    phi_cat_next_phi = torch.cat([phi, next_phi], dim=-1)
    dist_inputs = self.model._curiosity_inverse_fcnet(phi_cat_next_phi)
    action_dist = TorchCategorical(dist_inputs, self.model) if \
        isinstance(self.action_space, Discrete) else \
        TorchMultiCategorical(
            dist_inputs, self.model, self.action_space.nvec)
    # -log(p), where p is the probability of the observed action under the
    # inverse-NN's predicted action distribution.
    inverse_loss = -action_dist.logp(actions_tensor)
    inverse_loss = torch.mean(inverse_loss)

    # Calculate the ICM loss.
    loss = (1.0 - self.beta) * inverse_loss + self.beta * forward_loss

    # Perform an optimizer step.
    self._optimizer.zero_grad()
    loss.backward()
    self._optimizer.step()

    # Return the postprocessed sample batch (with the corrected rewards).
    return sample_batch
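
# A minimal, self-contained sketch (not part of the module above) of how the
# intrinsic reward added to SampleBatch.REWARDS is computed: eta times half the
# squared L2 distance between the forward model's predicted next feature vector
# and the actually observed one. The function name and tensor sizes below are
# made up for illustration only.
import torch


def icm_intrinsic_reward(predicted_next_phi, next_phi, eta):
    """Per-transition intrinsic reward: eta * 0.5 * ||phi'_pred - phi'||^2."""
    forward_l2_norm_squared = 0.5 * torch.sum(
        torch.pow(predicted_next_phi - next_phi, 2.0), dim=-1)
    # Detached: the reward signal should not backpropagate into the ICM nets.
    return eta * forward_l2_norm_squared.detach()


# Example with a hypothetical batch of 4 transitions and 16-dim features:
# r_i = icm_intrinsic_reward(torch.randn(4, 16), torch.randn(4, 16), eta=0.05)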
def build_CAT_vtrace_loss(policy, model, dist_class, train_batch):
    action_space_parts = model.action_space_parts

    def _make_time_major(*args, **kw):
        return make_time_major(policy, train_batch.get("seq_lens"), *args, **kw)

    # Repeat the output_hidden_shape depending on the number of actions that
    # have been generated.
    # output_hidden_shape = np.tile(output_hidden_shape, action_repeats)

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_action_logp = train_batch[SampleBatch.ACTION_LOGP]
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]
    invalid_action_mask = train_batch['invalid_action_mask']
    autoregressive_actions = policy.config['autoregressive_actions']

    if 'seq_lens' in train_batch:
        max_seq_len = policy.config['rollout_fragment_length']
        mask_orig = sequence_mask(train_batch["seq_lens"], max_seq_len)
        mask = torch.reshape(mask_orig, [-1])
    else:
        mask = torch.ones_like(rewards)

    actions_per_step = policy.config["actions_per_step"]

    # Collect the RNN state inputs (if any) from the train batch.
    states = []
    i = 0
    while "state_in_{}".format(i) in train_batch:
        states.append(train_batch["state_in_{}".format(i)])
        i += 1
    seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else []

    model.observation_features_module(train_batch, states, seq_lens)
    action_features, _ = model.action_features_module(train_batch, states, seq_lens)

    previous_action = None
    embedded_action = None

    logp_list = []
    entropy_list = []
    logits_list = []

    # Split the concatenated actions and masks into one chunk per action step.
    multi_actions = torch.chunk(actions, actions_per_step, dim=1)
    multi_invalid_action_mask = torch.chunk(invalid_action_mask, actions_per_step, dim=1)

    for a in range(actions_per_step):
        if autoregressive_actions:
            if a == 0:
                # The first step is conditioned on an all-zero "previous action".
                batch_size = action_features.shape[0]
                previous_action = torch.zeros(
                    [batch_size, len(action_space_parts)]).to(action_features.device)
            else:
                previous_action = multi_actions[a - 1]
            embedded_action = model.embed_action_module(previous_action)

        logits = model.action_module(action_features, embedded_action)

        # Mask out invalid actions by adding log(mask) to the logits, clamped at
        # the smallest finite float so invalid actions get ~zero probability
        # (see the masking sketch after this function).
        logits += torch.maximum(torch.tensor(torch.finfo().min),
                                torch.log(multi_invalid_action_mask[a]))

        cat = TorchMultiCategorical(logits, model, action_space_parts)

        logits_list.append(logits)
        logp_list.append(cat.logp(multi_actions[a]))
        entropy_list.append(cat.entropy())

    # Sum log-probs and entropies over the action steps.
    logp = torch.stack(logp_list, dim=1).sum(dim=1)
    entropy = torch.stack(entropy_list, dim=1).sum(dim=1)
    target_logits = torch.hstack(logits_list)

    unpack_shape = np.tile(action_space_parts, actions_per_step)
    unpacked_behaviour_logits = torch.split(behaviour_logits, list(unpack_shape), dim=1)
    unpacked_outputs = torch.split(target_logits, list(unpack_shape), dim=1)

    values = model.value_function()

    # Inputs are reshaped from [B * T] => [T - 1, B] for the V-trace calculation.
    policy.loss = VTraceLoss(
        actions=_make_time_major(actions, drop_last=True),
        actions_logp=_make_time_major(logp, drop_last=True),
        actions_entropy=_make_time_major(entropy, drop_last=True),
        dones=_make_time_major(dones, drop_last=True),
        behaviour_action_logp=_make_time_major(
            behaviour_action_logp, drop_last=True),
        behaviour_logits=_make_time_major(
            unpacked_behaviour_logits, drop_last=True),
        target_logits=_make_time_major(unpacked_outputs, drop_last=True),
        discount=policy.config["gamma"],
        rewards=_make_time_major(rewards, drop_last=True),
        values=_make_time_major(values, drop_last=True),
        bootstrap_value=_make_time_major(values)[-1],
        dist_class=TorchCategorical,
        model=model,
        valid_mask=_make_time_major(mask, drop_last=True),
        config=policy.config,
        vf_loss_coeff=policy.config["vf_loss_coeff"],
        entropy_coeff=policy.entropy_coeff,
        clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
        clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"])

    return policy.loss.total_loss
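
# The two sketches below are minimal, self-contained illustrations of steps
# used in build_CAT_vtrace_loss above. The function names and tensor shapes
# are made up, and they are not drop-in replacements for the RLlib helpers.
import torch


def mask_logits(logits, invalid_action_mask):
    """Invalid-action masking as used in the per-action-step loop above.

    Adding log(mask) pushes the logits of invalid actions (mask == 0) towards
    -inf, clamped at the smallest finite float so that softmax assigns them
    effectively zero probability without producing NaNs.
    """
    return logits + torch.maximum(
        torch.tensor(torch.finfo(logits.dtype).min),
        torch.log(invalid_action_mask))


def to_time_major(flat, trajectory_len, drop_last=False):
    """Rough stand-in for make_time_major in the non-recurrent case.

    Reshapes a [B * T, ...] tensor to [T, B, ...]; with drop_last=True the
    last timestep is dropped, giving the [T - 1, B] inputs V-trace expects
    (the final timestep's value is instead used as the bootstrap value).
    """
    time_major = flat.reshape(-1, trajectory_len, *flat.shape[1:]).transpose(0, 1)
    return time_major[:-1] if drop_last else time_major


# Example: mask out the middle of three actions, then reshape a flat reward
# vector of 2 trajectories x 4 timesteps to time-major layout:
# masked = mask_logits(torch.tensor([[1.0, 2.0, 3.0]]),
#                      torch.tensor([[1.0, 0.0, 1.0]]))
# rewards_tb = to_time_major(torch.arange(8.0), trajectory_len=4, drop_last=True)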