def forward(self, input_dict: Dict[str, TensorType],
            state: List[TensorType],
            seq_lens: TensorType) -> (TensorType, List[TensorType]):
    assert seq_lens is not None
    # Push obs through "unwrapped" net's `forward()` first.
    wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

    # Concat. prev-action/reward if required.
    prev_a_r = []
    if self.model_config["lstm_use_prev_action"]:
        if isinstance(self.action_space, (Discrete, MultiDiscrete)):
            prev_a = one_hot(input_dict[SampleBatch.PREV_ACTIONS].float(),
                             self.action_space)
        else:
            prev_a = input_dict[SampleBatch.PREV_ACTIONS].float()
        prev_a_r.append(torch.reshape(prev_a, [-1, self.action_dim]))
    if self.model_config["lstm_use_prev_reward"]:
        prev_a_r.append(
            torch.reshape(input_dict[SampleBatch.PREV_REWARDS].float(),
                          [-1, 1]))

    if prev_a_r:
        wrapped_out = torch.cat([wrapped_out] + prev_a_r, dim=1)

    # Then through our LSTM.
    input_dict["obs_flat"] = wrapped_out
    return super().forward(input_dict, state, seq_lens)
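
# A minimal sketch (not part of the original module) of what the concat
# above produces: with both `lstm_use_prev_action` and
# `lstm_use_prev_reward` set, the LSTM's input size grows from the wrapped
# net's output size by `action_dim + 1`. All names and sizes below
# (`wrapped_size`, `action_dim`, `batch`) are illustrative assumptions.
import torch

wrapped_size, action_dim, batch = 256, 4, 32
wrapped_out = torch.zeros(batch, wrapped_size)
prev_a = torch.zeros(batch, action_dim)   # one-hot prev. action
prev_r = torch.zeros(batch, 1)            # prev. reward (scalar per step)
lstm_in = torch.cat([wrapped_out, prev_a, prev_r], dim=1)
assert lstm_in.shape == (batch, wrapped_size + action_dim + 1)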

def forward(self, input_dict, state, seq_lens):
    if SampleBatch.OBS in input_dict and "obs_flat" in input_dict:
        orig_obs = input_dict[SampleBatch.OBS]
    else:
        orig_obs = restore_original_dimensions(
            input_dict[SampleBatch.OBS],
            self.processed_obs_space,
            tensorlib="torch")
    # Push image observations through our CNNs.
    outs = []
    for i, component in enumerate(tree.flatten(orig_obs)):
        if i in self.cnns:
            cnn_out, _ = self.cnns[i]({SampleBatch.OBS: component})
            outs.append(cnn_out)
        elif i in self.one_hot:
            if component.dtype in [torch.int32, torch.int64, torch.uint8]:
                outs.append(
                    one_hot(component, self.flattened_input_space[i]))
            else:
                outs.append(component)
        else:
            outs.append(torch.reshape(component, [-1, self.flatten[i]]))
    # Concat all outputs and the non-image inputs.
    out = torch.cat(outs, dim=1)
    # Push through (optional) FC-stack (this may be an empty stack).
    out, _ = self.post_fc_stack({SampleBatch.OBS: out}, [], None)

    # No logits/value branches.
    if self.logits_layer is None:
        return out, []

    # Logits- and value branches.
    logits, values = self.logits_layer(out), self.value_layer(out)
    self._value_out = torch.reshape(values, [-1])
    return logits, []
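
# Hedged sketch of the one-hot branch above: an integer component of a
# Discrete(n) sub-space becomes an n-dim float vector before concatenation.
# This uses plain torch.nn.functional rather than RLlib's `one_hot` helper;
# the sub-space size and batch values are made up for illustration.
import torch
import torch.nn.functional as F

n = 5                                   # assumed Discrete(5) sub-space
component = torch.tensor([0, 3, 4])     # batch of integer observations
encoded = F.one_hot(component, num_classes=n).float()
assert encoded.shape == (3, 5)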

def forward(self, input_dict, state, seq_lens):
    # Push image observations through our CNNs.
    outs = []
    for i, component in enumerate(input_dict["obs"]):
        if i in self.cnns:
            cnn_out, _ = self.cnns[i]({"obs": component})
            outs.append(cnn_out)
        elif i in self.one_hot:
            if component.dtype in [torch.int32, torch.int64, torch.uint8]:
                outs.append(
                    one_hot(component, self.original_space.spaces[i]))
            else:
                outs.append(component)
        else:
            outs.append(torch.reshape(component, [-1, self.flatten[i]]))
    # Concat all outputs and the non-image inputs.
    out = torch.cat(outs, dim=1)
    # Push through (optional) FC-stack (this may be an empty stack).
    out, _ = self.post_fc_stack({"obs": out}, [], None)

    # No logits/value branches.
    if self.logits_layer is None:
        return out, []

    # Logits- and value branches.
    logits, values = self.logits_layer(out), self.value_layer(out)
    self._value_out = torch.reshape(values, [-1])
    return logits, []

def forward(self, input_dict, state, seq_lens):
    # Restore the original (possibly nested) observation structure.
    orig_obs = restore_original_dimensions(
        input_dict.get("obs"), self.new_obs_space, "torch")
    # Only enable train-mode behavior (e.g. dropout) for actual batched
    # training passes, never for single-sample inference.
    mode = input_dict.get("is_training", False) \
        if input_dict.get("obs").shape[0] > 1 else False

    # Push image observations through our CNNs.
    outs = []
    v_outs = []
    # The last obs component is the action mask -> skip it here.
    for i, component in enumerate(orig_obs[:-1]):
        if i in self.cnns:
            cnn_out, _ = self.cnns[i]({"obs": component})
            outs.append(cnn_out)
            v_outs.append(self.cnns[i].value_function())
        elif i in self.one_hot:
            if component.dtype in [torch.int32, torch.int64, torch.uint8]:
                outs.append(
                    one_hot(component, self.original_space.spaces[i]))
                v_outs.append(
                    one_hot(component, self.original_space.spaces[i]))
            else:
                outs.append(component)
                v_outs.append(component)
        else:
            outs.append(torch.reshape(component, [-1, self.flatten[i]]))
            v_outs.append(torch.reshape(component, [-1, self.flatten[i]]))

    # Concat all outputs and the non-image inputs.
    out = torch.cat(outs, dim=1)
    v_out = torch.cat(v_outs, dim=1)

    # Push through (optional) FC-stacks (these may be empty stacks).
    self.post_fc_stack.train(mode=mode)
    self.post_fc_stack_vf.train(mode=mode)
    out_p = self.post_fc_stack(out)
    out_v = self.post_fc_stack_vf(v_out)

    # No logits/value branches.
    if self.logits_layer is None:
        return out, []

    # Logits- and value branches.
    logits, values = self.logits_layer(out_p), self.value_layer(out_v)
    # Build the -inf mask from the action-mask component (last obs
    # component): log(1)=0 for valid actions, log(0)=-inf for invalid
    # ones. Create the -inf constant on the logits' device instead of
    # hard-coding CUDA, so this also runs on CPU.
    inf = torch.tensor(float("-inf"), device=logits.device)
    inf_mask = torch.maximum(torch.log(orig_obs[-1]), inf)
    self._value_out = torch.reshape(values, [-1])
    return logits + inf_mask, []
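
# Hedged sketch of the action-masking trick used above: adding log(mask)
# to the logits sends invalid actions' logits to -inf, so softmax assigns
# them zero probability. All tensors here are made up for illustration.
import torch

logits = torch.tensor([[1.0, 2.0, 0.5]])
mask = torch.tensor([[1.0, 0.0, 1.0]])       # action 1 is invalid
masked = logits + torch.log(mask)            # -inf at masked positions
probs = torch.softmax(masked, dim=1)
assert probs[0, 1] == 0.0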

def postprocess_trajectory(self, policy, sample_batch, tf_sess=None):
    """Calculates phi values (obs, obs', and predicted obs') and ri.

    Also calculates forward and inverse losses and updates the curiosity
    module on the provided batch using our optimizer.
    """
    # Push both observations through feature net to get both phis.
    phis, _ = self.model._curiosity_feature_net({
        SampleBatch.OBS: torch.cat([
            torch.from_numpy(sample_batch[SampleBatch.OBS]),
            torch.from_numpy(sample_batch[SampleBatch.NEXT_OBS])
        ])
    })
    phi, next_phi = torch.chunk(phis, 2)
    actions_tensor = torch.from_numpy(
        sample_batch[SampleBatch.ACTIONS]).long().to(policy.device)

    # Predict next phi with forward model.
    predicted_next_phi = self.model._curiosity_forward_fcnet(
        torch.cat(
            [phi, one_hot(actions_tensor, self.action_space).float()],
            dim=-1))

    # Forward loss term (predicted phi', given phi and action vs actually
    # observed phi').
    forward_l2_norm_squared = 0.5 * torch.sum(
        torch.pow(predicted_next_phi - next_phi, 2.0), dim=-1)
    forward_loss = torch.mean(forward_l2_norm_squared)

    # Scale intrinsic reward by eta hyper-parameter.
    sample_batch[SampleBatch.REWARDS] = \
        sample_batch[SampleBatch.REWARDS] + \
        self.eta * forward_l2_norm_squared.detach().cpu().numpy()

    # Inverse loss term (predicted action that led from phi to phi' vs
    # actual action taken).
    phi_cat_next_phi = torch.cat([phi, next_phi], dim=-1)
    dist_inputs = self.model._curiosity_inverse_fcnet(phi_cat_next_phi)
    action_dist = TorchCategorical(dist_inputs, self.model) if \
        isinstance(self.action_space, Discrete) else \
        TorchMultiCategorical(
            dist_inputs, self.model, self.action_space.nvec)
    # Neg log(p); p=probability of observed action given the inverse-NN
    # predicted action distribution.
    inverse_loss = -action_dist.logp(actions_tensor)
    inverse_loss = torch.mean(inverse_loss)

    # Calculate the ICM loss.
    loss = (1.0 - self.beta) * inverse_loss + self.beta * forward_loss

    # Perform an optimizer step.
    self._optimizer.zero_grad()
    loss.backward()
    self._optimizer.step()

    # Return the postprocessed sample batch (with the corrected rewards).
    return sample_batch
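
# Hedged numeric sketch of the intrinsic reward computed above, per
# timestep: r_i = eta * 0.5 * ||phi' - predicted_phi'||^2. The feature
# vectors and eta value below are illustrative only.
import torch

eta = 1.0
next_phi = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
predicted_next_phi = torch.tensor([[0.5, 0.0], [0.0, 1.0]])
forward_l2 = 0.5 * torch.sum((predicted_next_phi - next_phi) ** 2, dim=-1)
intrinsic_reward = eta * forward_l2      # tensor([0.1250, 0.0000])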

def forward(self, input_dict: Dict[str, TensorType],
            state: List[TensorType],
            seq_lens: TensorType) -> (TensorType, List[TensorType]):
    assert seq_lens is not None
    # Push obs through "unwrapped" net's `forward()` first.
    wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

    # Concat. prev-action/reward if required.
    prev_a_r = []
    if self.use_n_prev_actions:
        if isinstance(self.action_space, Discrete):
            for i in range(self.use_n_prev_actions):
                prev_a_r.append(
                    one_hot(
                        input_dict[SampleBatch.PREV_ACTIONS][:, i].float(),
                        self.action_space))
        elif isinstance(self.action_space, MultiDiscrete):
            # Step through the flattened prev-actions in chunks of one
            # (multi-discrete) action each. Note: `range()` takes no
            # keyword arguments, so the step must be positional.
            for i in range(0, self.use_n_prev_actions,
                           self.action_space.shape[0]):
                prev_a_r.append(
                    one_hot(
                        input_dict[SampleBatch.PREV_ACTIONS]
                        [:, i:i + self.action_space.shape[0]].float(),
                        self.action_space))
        else:
            prev_a_r.append(
                torch.reshape(
                    input_dict[SampleBatch.PREV_ACTIONS].float(),
                    [-1, self.use_n_prev_actions * self.action_dim]))
    if self.use_n_prev_rewards:
        prev_a_r.append(
            torch.reshape(input_dict[SampleBatch.PREV_REWARDS].float(),
                          [-1, self.use_n_prev_rewards]))

    if prev_a_r:
        wrapped_out = torch.cat([wrapped_out] + prev_a_r, dim=1)

    # Then through our GTrXL.
    input_dict["obs_flat"] = input_dict["obs"] = wrapped_out

    self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens)
    model_out = self._logits_branch(self._features)
    return model_out, memory_outs
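
# Hedged sketch of the Discrete branch above: each of the n prev. actions
# is one-hot encoded separately and concatenated, contributing
# n * action_space.n extra input features to the GTrXL. Sizes below are
# illustrative; plain torch.nn.functional stands in for RLlib's `one_hot`.
import torch
import torch.nn.functional as F

n_prev, n_actions, batch = 2, 4, 3
prev_actions = torch.randint(n_actions, (batch, n_prev))
cols = [F.one_hot(prev_actions[:, i], n_actions).float()
        for i in range(n_prev)]
extra = torch.cat(cols, dim=1)
assert extra.shape == (batch, n_prev * n_actions)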