def forward(self, input_dict, state, seq_lens):
    if SampleBatch.OBS in input_dict and "obs_flat" in input_dict:
        orig_obs = input_dict[SampleBatch.OBS]
    else:
        orig_obs = restore_original_dimensions(
            input_dict[SampleBatch.OBS],
            self.processed_obs_space,
            tensorlib="torch")

    # Push image observations through our CNNs.
    outs = []
    for i, component in enumerate(tree.flatten(orig_obs)):
        if i in self.cnns:
            cnn_out, _ = self.cnns[i]({SampleBatch.OBS: component})
            outs.append(cnn_out)
        elif i in self.one_hot:
            if component.dtype in [torch.int32, torch.int64, torch.uint8]:
                outs.append(
                    one_hot(component, self.flattened_input_space[i]))
            else:
                outs.append(component)
        else:
            outs.append(torch.reshape(component, [-1, self.flatten[i]]))

    # Concat all outputs and the non-image inputs.
    out = torch.cat(outs, dim=1)

    # Push through (optional) FC-stack (this may be an empty stack).
    out, _ = self.post_fc_stack({SampleBatch.OBS: out}, [], None)

    # No logits/value branches.
    if self.logits_layer is None:
        return out, []

    # Logits- and value branches.
    logits, values = self.logits_layer(out), self.value_layer(out)
    self._value_out = torch.reshape(values, [-1])
    return logits, []
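
# A minimal standalone sketch of the per-component routing above, in plain
# PyTorch (illustrative shapes and dummy tensors; RLlib's CNN and `one_hot`
# helpers are replaced by a stand-in tensor and `torch.nn.functional.one_hot`):
# one image component goes through a CNN, one int component is one-hot'd,
# one float component is flattened, and everything is concatenated on dim 1.
import torch
import torch.nn.functional as F

batch = 4
cnn_out = torch.rand(batch, 16)                 # stand-in for a CNN output
int_component = torch.randint(0, 5, (batch,))   # e.g. a Discrete(5) component
box_component = torch.rand(batch, 2, 3)         # e.g. a Box component

outs = [
    cnn_out,                                          # [4, 16]
    F.one_hot(int_component, num_classes=5).float(),  # [4, 5]
    torch.reshape(box_component, [batch, -1]),        # [4, 6]
]
out = torch.cat(outs, dim=1)                          # [4, 27]
assert out.shape == (4, 27)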
def forward(self, input_dict: Dict[str, TensorType],
            state: List[TensorType],
            seq_lens: TensorType) -> (TensorType, List[TensorType]):
    assert seq_lens is not None
    # Push obs through "unwrapped" net's `forward()` first.
    wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

    # Concat. prev-action/reward if required.
    prev_a_r = []
    if self.model_config["lstm_use_prev_action"]:
        if isinstance(self.action_space, (Discrete, MultiDiscrete)):
            prev_a = one_hot(input_dict[SampleBatch.PREV_ACTIONS].float(),
                             self.action_space)
        else:
            prev_a = input_dict[SampleBatch.PREV_ACTIONS].float()
        prev_a_r.append(torch.reshape(prev_a, [-1, self.action_dim]))
    if self.model_config["lstm_use_prev_reward"]:
        prev_a_r.append(
            torch.reshape(input_dict[SampleBatch.PREV_REWARDS].float(),
                          [-1, 1]))

    if prev_a_r:
        wrapped_out = torch.cat([wrapped_out] + prev_a_r, dim=1)

    # Then through our LSTM.
    input_dict["obs_flat"] = wrapped_out
    return super().forward(input_dict, state, seq_lens)
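
# Minimal sketch of the prev-action/reward concatenation above for a
# Discrete(4) action space, in plain PyTorch (all shapes are illustrative
# assumptions; RLlib's `one_hot` helper is replaced by
# `torch.nn.functional.one_hot`):
import torch
import torch.nn.functional as F

batch, feature_dim, num_actions = 4, 8, 4
wrapped_out = torch.rand(batch, feature_dim)        # output of wrapped net
prev_actions = torch.randint(0, num_actions, (batch,))
prev_rewards = torch.rand(batch)

prev_a = F.one_hot(prev_actions, num_classes=num_actions).float()  # [4, 4]
prev_r = torch.reshape(prev_rewards, [-1, 1])                      # [4, 1]
lstm_in = torch.cat([wrapped_out, prev_a, prev_r], dim=1)          # [4, 13]
assert lstm_in.shape == (4, 13)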
def forward(self, input_dict: Dict[str, TensorType],
            state: List[TensorType],
            seq_lens: TensorType) -> (TensorType, List[TensorType]):
    assert seq_lens is not None
    # Push obs through "unwrapped" net's `forward()` first.
    wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

    # Concat. prev-action/reward if required.
    prev_a_r = []
    if self.use_n_prev_actions:
        if isinstance(self.action_space, Discrete):
            for i in range(self.use_n_prev_actions):
                prev_a_r.append(
                    one_hot(
                        input_dict[SampleBatch.PREV_ACTIONS][:, i].float(),
                        self.action_space))
        elif isinstance(self.action_space, MultiDiscrete):
            # Step over the concatenated prev actions in chunks of one
            # full MultiDiscrete action each.
            for i in range(0, self.use_n_prev_actions,
                           self.action_space.shape[0]):
                prev_a_r.append(
                    one_hot(
                        input_dict[SampleBatch.PREV_ACTIONS]
                        [:, i:i + self.action_space.shape[0]].float(),
                        self.action_space))
        else:
            prev_a_r.append(
                torch.reshape(
                    input_dict[SampleBatch.PREV_ACTIONS].float(),
                    [-1, self.use_n_prev_actions * self.action_dim]))
    if self.use_n_prev_rewards:
        prev_a_r.append(
            torch.reshape(input_dict[SampleBatch.PREV_REWARDS].float(),
                          [-1, self.use_n_prev_rewards]))

    if prev_a_r:
        wrapped_out = torch.cat([wrapped_out] + prev_a_r, dim=1)

    # Then through our GTrXL.
    input_dict["obs_flat"] = input_dict["obs"] = wrapped_out
    self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens)
    model_out = self._logits_branch(self._features)
    return model_out, memory_outs
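
# Minimal sketch of the MultiDiscrete branch above, in plain PyTorch
# (illustrative values): n previous MultiDiscrete([3, 2]) actions arrive
# concatenated as a [batch, n * 2] int tensor; each width-2 slice is one
# full action, and each of its components is one-hot'd with its own size.
import torch
import torch.nn.functional as F

batch, nvec = 4, [3, 2]
# Two prev MultiDiscrete([3, 2]) actions per row, concatenated widthwise.
a_t0 = torch.tensor([[2, 1], [0, 0], [1, 1], [2, 0]])
a_t1 = torch.tensor([[0, 1], [1, 0], [2, 1], [0, 0]])
prev_actions = torch.cat([a_t0, a_t1], dim=1)          # [4, 4]

outs = []
for i in range(0, prev_actions.shape[1], len(nvec)):   # one action per step
    chunk = prev_actions[:, i:i + len(nvec)]           # [4, 2]
    outs.append(torch.cat(
        [F.one_hot(chunk[:, j], num_classes=n).float()
         for j, n in enumerate(nvec)],
        dim=1))                                        # [4, 3 + 2]
prev_a = torch.cat(outs, dim=1)                        # [4, 2 * 5]
assert prev_a.shape == (4, 10)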
def _postprocess_torch(self, policy, sample_batch):
    # Push both observations through feature net to get both phis.
    phis, _ = self.model._curiosity_feature_net({
        SampleBatch.OBS: torch.cat([
            torch.from_numpy(sample_batch[SampleBatch.OBS]).to(
                policy.device),
            torch.from_numpy(sample_batch[SampleBatch.NEXT_OBS]).to(
                policy.device),
        ])
    })
    phi, next_phi = torch.chunk(phis, 2)
    actions_tensor = (torch.from_numpy(
        sample_batch[SampleBatch.ACTIONS]).long().to(policy.device))

    # Predict next phi with forward model.
    predicted_next_phi = self.model._curiosity_forward_fcnet(
        torch.cat(
            [phi, one_hot(actions_tensor, self.action_space).float()],
            dim=-1))

    # Forward loss term (predicted phi', given phi and action vs actually
    # observed phi').
    forward_l2_norm_squared = 0.5 * torch.sum(
        torch.pow(predicted_next_phi - next_phi, 2.0), dim=-1)
    forward_loss = torch.mean(forward_l2_norm_squared)

    # Scale intrinsic reward by eta hyper-parameter.
    sample_batch[SampleBatch.REWARDS] = (
        sample_batch[SampleBatch.REWARDS] +
        self.eta * forward_l2_norm_squared.detach().cpu().numpy())

    # Inverse loss term (predicted action that led from phi to phi' vs
    # actual action taken).
    phi_cat_next_phi = torch.cat([phi, next_phi], dim=-1)
    dist_inputs = self.model._curiosity_inverse_fcnet(phi_cat_next_phi)
    action_dist = (TorchCategorical(dist_inputs, self.model) if isinstance(
        self.action_space, Discrete) else TorchMultiCategorical(
            dist_inputs, self.model, self.action_space.nvec))
    # Neg log(p); p=probability of observed action given the inverse-NN
    # predicted action distribution.
    inverse_loss = -action_dist.logp(actions_tensor)
    inverse_loss = torch.mean(inverse_loss)

    # Calculate the ICM loss.
    loss = (1.0 - self.beta) * inverse_loss + self.beta * forward_loss

    # Perform an optimizer step.
    self._optimizer.zero_grad()
    loss.backward()
    self._optimizer.step()

    # Return the postprocessed sample batch (with the corrected rewards).
    return sample_batch
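
# Standalone sketch of the two ICM loss terms above, with tiny stand-in
# tensors instead of the feature/forward/inverse networks (all shapes and
# the eta/beta values are illustrative assumptions, and
# `torch.distributions.Categorical` stands in for RLlib's TorchCategorical):
import torch
from torch.distributions import Categorical

batch, phi_dim, num_actions = 4, 8, 3
eta, beta = 1.0, 0.2
next_phi = torch.rand(batch, phi_dim)            # observed phi'
predicted_next_phi = torch.rand(batch, phi_dim)  # forward net's phi'
dist_inputs = torch.rand(batch, num_actions)     # inverse net's logits
actions = torch.randint(0, num_actions, (batch,))

# Forward term: squared L2 between predicted and observed phi'; its
# per-sample value (scaled by eta) is also the intrinsic reward.
forward_l2 = 0.5 * torch.sum((predicted_next_phi - next_phi) ** 2, dim=-1)
intrinsic_reward = eta * forward_l2.detach()
forward_loss = forward_l2.mean()

# Inverse term: -log p(actual action | predicted action distribution).
inverse_loss = -Categorical(logits=dist_inputs).log_prob(actions).mean()

loss = (1.0 - beta) * inverse_loss + beta * forward_loss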
def forward(
    self,
    input_dict: Dict[str, TensorType],
    state: List[TensorType],
    seq_lens: TensorType,
) -> (TensorType, List[TensorType]):
    assert seq_lens is not None
    # Push obs through "unwrapped" net's `forward()` first.
    wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

    # Concat. prev-action/reward if required.
    prev_a_r = []

    # Prev actions.
    if self.model_config["lstm_use_prev_action"]:
        prev_a = input_dict[SampleBatch.PREV_ACTIONS]
        # If actions are not yet processed (still in their original form,
        # as sent to the environment): flatten/one-hot into a 1D array.
        if self.model_config["_disable_action_flattening"]:
            prev_a_r.append(
                flatten_inputs_to_1d_tensor(
                    prev_a,
                    spaces_struct=self.action_space_struct,
                    time_axis=False))
        # If actions are already flattened (but not one-hot'd yet!),
        # one-hot discrete/multi-discrete actions here.
        else:
            if isinstance(self.action_space, (Discrete, MultiDiscrete)):
                prev_a = one_hot(prev_a.float(), self.action_space)
            else:
                prev_a = prev_a.float()
            prev_a_r.append(torch.reshape(prev_a, [-1, self.action_dim]))
    # Prev rewards.
    if self.model_config["lstm_use_prev_reward"]:
        prev_a_r.append(
            torch.reshape(input_dict[SampleBatch.PREV_REWARDS].float(),
                          [-1, 1]))

    # Concat prev. actions + rewards to the "main" input.
    if prev_a_r:
        wrapped_out = torch.cat([wrapped_out] + prev_a_r, dim=1)

    # Push everything through our LSTM.
    input_dict["obs_flat"] = wrapped_out
    return super().forward(input_dict, state, seq_lens)
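
# Rough stand-in for what `flatten_inputs_to_1d_tensor` produces in the
# `_disable_action_flattening` branch above (a single timestep and a
# hypothetical Tuple(Discrete(3), Box(shape=(2,))) action space; this is an
# assumption-based sketch, not RLlib's actual implementation): each leaf of
# the action struct is one-hot'd or flattened, then concatenated.
import torch
import torch.nn.functional as F

batch = 4
discrete_part = torch.randint(0, 3, (batch,))  # Discrete(3) leaf
box_part = torch.rand(batch, 2)                # Box leaf, shape (2,)

flat = torch.cat(
    [F.one_hot(discrete_part, num_classes=3).float(),  # [4, 3]
     torch.reshape(box_part, [batch, -1])],            # [4, 2]
    dim=1)                                             # [4, 5]
assert flat.shape == (4, 5)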
def forward(
    self,
    input_dict: Dict[str, TensorType],
    state: List[TensorType],
    seq_lens: TensorType,
) -> (TensorType, List[TensorType]):
    assert seq_lens is not None
    # Push obs through "unwrapped" net's `forward()` first.
    wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

    # Concat. prev-action/reward if required.
    prev_a_r = []

    # Prev actions.
    if self.use_n_prev_actions:
        prev_n_actions = input_dict[SampleBatch.PREV_ACTIONS]
        # If actions are not yet processed (still in their original form,
        # as sent to the environment): flatten/one-hot into a 1D array.
        if self.model_config["_disable_action_flattening"]:
            # Merge prev n actions into one flat tensor.
            flat = flatten_inputs_to_1d_tensor(
                prev_n_actions,
                spaces_struct=self.action_space_struct,
                time_axis=True,
            )
            # Fold time-axis into flattened data.
            flat = torch.reshape(flat, [flat.shape[0], -1])
            prev_a_r.append(flat)
        # If actions are already flattened (but not one-hot'd yet!),
        # one-hot discrete/multi-discrete actions here and concatenate the
        # n most recent actions together.
        else:
            if isinstance(self.action_space, Discrete):
                for i in range(self.use_n_prev_actions):
                    prev_a_r.append(
                        one_hot(prev_n_actions[:, i].float(),
                                space=self.action_space))
            elif isinstance(self.action_space, MultiDiscrete):
                for i in range(0, self.use_n_prev_actions,
                               self.action_space.shape[0]):
                    prev_a_r.append(
                        one_hot(
                            prev_n_actions[:, i:i +
                                           self.action_space.shape[0]]
                            .float(),
                            space=self.action_space,
                        ))
            else:
                prev_a_r.append(
                    torch.reshape(
                        prev_n_actions.float(),
                        [-1, self.use_n_prev_actions * self.action_dim],
                    ))
    # Prev rewards.
    if self.use_n_prev_rewards:
        prev_a_r.append(
            torch.reshape(
                input_dict[SampleBatch.PREV_REWARDS].float(),
                [-1, self.use_n_prev_rewards],
            ))

    # Concat prev. actions + rewards to the "main" input.
    if prev_a_r:
        wrapped_out = torch.cat([wrapped_out] + prev_a_r, dim=1)

    # Then through our GTrXL.
    input_dict["obs_flat"] = input_dict["obs"] = wrapped_out
    self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens)
    model_out = self._logits_branch(self._features)
    return model_out, memory_outs
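
# Minimal sketch of the "fold time-axis" step above, in plain PyTorch
# (illustrative shapes): n previous actions, already flattened per timestep
# to [batch, n, action_dim], are folded into a single [batch, n * action_dim]
# tensor before being concatenated to the wrapped net's output.
import torch

batch, n_prev, action_dim = 4, 3, 5
flat = torch.rand(batch, n_prev, action_dim)      # [4, 3, 5]
flat = torch.reshape(flat, [flat.shape[0], -1])   # [4, 15]
assert flat.shape == (4, 15)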