    def forward(self, input_dict, state, seq_lens):
        if SampleBatch.OBS in input_dict and "obs_flat" in input_dict:
            orig_obs = input_dict[SampleBatch.OBS]
        else:
            orig_obs = restore_original_dimensions(input_dict[SampleBatch.OBS],
                                                   self.processed_obs_space,
                                                   tensorlib="torch")
        # Push image observations through our CNNs.
        outs = []
        for i, component in enumerate(tree.flatten(orig_obs)):
            if i in self.cnns:
                cnn_out, _ = self.cnns[i]({SampleBatch.OBS: component})
                outs.append(cnn_out)
            elif i in self.one_hot:
                if component.dtype in [torch.int32, torch.int64, torch.uint8]:
                    outs.append(
                        one_hot(component, self.flattened_input_space[i]))
                else:
                    outs.append(component)
            else:
                outs.append(torch.reshape(component, [-1, self.flatten[i]]))
        # Concat all outputs and the non-image inputs.
        out = torch.cat(outs, dim=1)
        # Push through (optional) FC-stack (this may be an empty stack).
        out, _ = self.post_fc_stack({SampleBatch.OBS: out}, [], None)

        # No logits/value branches.
        if self.logits_layer is None:
            return out, []

        # Logits- and value branches.
        logits, values = self.logits_layer(out), self.value_layer(out)
        self._value_out = torch.reshape(values, [-1])
        return logits, []
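
The loop above routes each flattened observation component to one of three branches: image components go through a CNN, integer (Discrete/MultiDiscrete) components get one-hot encoded, and everything else is simply flattened, after which all branch outputs are concatenated along the feature axis. Below is a minimal plain-PyTorch shape sketch of that concatenation; all sizes are invented and the CNN output is faked, so this only illustrates the tensor shapes, not RLlib's actual helpers.

import torch
import torch.nn.functional as F

batch_size = 4

# A Discrete(5) component arrives as integer ids -> one-hot to [B, 5].
discrete_component = torch.randint(0, 5, (batch_size,))
one_hot_part = F.one_hot(discrete_component, num_classes=5).float()

# A Box component of shape (3, 2) is flattened to [B, 6].
box_component = torch.randn(batch_size, 3, 2)
flat_part = torch.reshape(box_component, [batch_size, -1])

# An image component would go through a CNN producing, say, [B, 64];
# here that output is simply faked.
cnn_out = torch.randn(batch_size, 64)

# Concatenate all branch outputs, just like `torch.cat(outs, dim=1)` above.
out = torch.cat([cnn_out, one_hot_part, flat_part], dim=1)
print(out.shape)  # torch.Size([4, 75])
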
Example #2

    def forward(self, input_dict: Dict[str,
                                       TensorType], state: List[TensorType],
                seq_lens: TensorType) -> (TensorType, List[TensorType]):
        assert seq_lens is not None
        # Push obs through "unwrapped" net's `forward()` first.
        wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

        # Concat. prev-action/reward if required.
        prev_a_r = []
        if self.model_config["lstm_use_prev_action"]:
            if isinstance(self.action_space, (Discrete, MultiDiscrete)):
                prev_a = one_hot(input_dict[SampleBatch.PREV_ACTIONS].float(),
                                 self.action_space)
            else:
                prev_a = input_dict[SampleBatch.PREV_ACTIONS].float()
            prev_a_r.append(torch.reshape(prev_a, [-1, self.action_dim]))
        if self.model_config["lstm_use_prev_reward"]:
            prev_a_r.append(
                torch.reshape(input_dict[SampleBatch.PREV_REWARDS].float(),
                              [-1, 1]))

        if prev_a_r:
            wrapped_out = torch.cat([wrapped_out] + prev_a_r, dim=1)

        # Then through our LSTM.
        input_dict["obs_flat"] = wrapped_out
        return super().forward(input_dict, state, seq_lens)
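
With both `lstm_use_prev_action` and `lstm_use_prev_reward` enabled, the one-hot previous action and the scalar previous reward are appended column-wise to the wrapped model's output before it enters the LSTM. A quick shape sketch in plain PyTorch (feature and action sizes are invented):

import torch
import torch.nn.functional as F

batch_size, feature_dim, num_actions = 4, 256, 3

wrapped_out = torch.randn(batch_size, feature_dim)            # [B, 256]
prev_a = F.one_hot(torch.randint(0, num_actions, (batch_size,)),
                   num_classes=num_actions).float()           # [B, 3]
prev_r = torch.randn(batch_size).reshape(-1, 1)               # [B, 1]

# The LSTM therefore sees an input of width 256 + 3 + 1 = 260.
lstm_in = torch.cat([wrapped_out, prev_a, prev_r], dim=1)
print(lstm_in.shape)  # torch.Size([4, 260])
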
Example #3

    def forward(self, input_dict: Dict[str,
                                       TensorType], state: List[TensorType],
                seq_lens: TensorType) -> (TensorType, List[TensorType]):
        assert seq_lens is not None
        # Push obs through "unwrapped" net's `forward()` first.
        wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

        # Concat. prev-action/reward if required.
        prev_a_r = []
        if self.use_n_prev_actions:
            if isinstance(self.action_space, Discrete):
                for i in range(self.use_n_prev_actions):
                    prev_a_r.append(
                        one_hot(
                            input_dict[SampleBatch.PREV_ACTIONS][:, i].float(),
                            self.action_space))
            elif isinstance(self.action_space, MultiDiscrete):
                for i in range(0, self.use_n_prev_actions,
                               self.action_space.shape[0]):
                    prev_a_r.append(
                        one_hot(
                            input_dict[SampleBatch.PREV_ACTIONS]
                            [:, i:i + self.action_space.shape[0]].float(),
                            self.action_space))
            else:
                prev_a_r.append(
                    torch.reshape(
                        input_dict[SampleBatch.PREV_ACTIONS].float(),
                        [-1, self.use_n_prev_actions * self.action_dim]))
        if self.use_n_prev_rewards:
            prev_a_r.append(
                torch.reshape(input_dict[SampleBatch.PREV_REWARDS].float(),
                              [-1, self.use_n_prev_rewards]))

        if prev_a_r:
            wrapped_out = torch.cat([wrapped_out] + prev_a_r, dim=1)

        # Then through our GTrXL.
        input_dict["obs_flat"] = input_dict["obs"] = wrapped_out

        self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens)
        model_out = self._logits_branch(self._features)
        return model_out, memory_outs
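
For MultiDiscrete spaces the loop strides over the previous-actions buffer one whole action (action_space.shape[0] sub-actions) at a time and one-hot encodes each slice. The plain-PyTorch sketch below illustrates that idea under the assumption that the last n actions arrive as a flat [B, n * k] integer tensor; the sub-space sizes are invented, and RLlib's own `one_hot` helper would do the per-sub-space encoding in a single call.

import torch
import torch.nn.functional as F

batch_size, n_prev, nvec = 4, 2, [3, 5]    # two sub-actions per step (k = 2)
k = len(nvec)

# Flat buffer of the last n_prev MultiDiscrete actions: [B, n_prev * k] ids,
# ordered step by step.
prev_n_actions = torch.cat(
    [torch.randint(0, high, (batch_size, 1))
     for _ in range(n_prev) for high in nvec], dim=1)

chunks = []
for i in range(0, n_prev * k, k):              # stride over whole actions
    step_slice = prev_n_actions[:, i:i + k]    # [B, k] -> one step's action
    # One-hot each sub-action against its own size, concat -> [B, sum(nvec)].
    chunks.append(torch.cat(
        [F.one_hot(step_slice[:, j], num_classes=nvec[j]).float()
         for j in range(k)], dim=1))

prev_a = torch.cat(chunks, dim=1)              # [B, n_prev * sum(nvec)]
print(prev_a.shape)  # torch.Size([4, 16])
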
Example #4
    def _postprocess_torch(self, policy, sample_batch):
        # Push both observations through feature net to get both phis.
        phis, _ = self.model._curiosity_feature_net({
            SampleBatch.OBS:
            torch.cat([
                torch.from_numpy(sample_batch[SampleBatch.OBS]).to(
                    policy.device),
                torch.from_numpy(sample_batch[SampleBatch.NEXT_OBS]).to(
                    policy.device),
            ])
        })
        phi, next_phi = torch.chunk(phis, 2)
        actions_tensor = (torch.from_numpy(
            sample_batch[SampleBatch.ACTIONS]).long().to(policy.device))

        # Predict next phi with forward model.
        predicted_next_phi = self.model._curiosity_forward_fcnet(
            torch.cat(
                [phi, one_hot(actions_tensor, self.action_space).float()],
                dim=-1))

        # Forward loss term (predicted phi', given phi and action vs actually
        # observed phi').
        forward_l2_norm_squared = 0.5 * torch.sum(
            torch.pow(predicted_next_phi - next_phi, 2.0), dim=-1)
        forward_loss = torch.mean(forward_l2_norm_squared)

        # Scale intrinsic reward by eta hyper-parameter.
        sample_batch[SampleBatch.REWARDS] = (
            sample_batch[SampleBatch.REWARDS] +
            self.eta * forward_l2_norm_squared.detach().cpu().numpy())

        # Inverse loss term (predicted action that led from phi to phi' vs
        # actual action taken).
        phi_cat_next_phi = torch.cat([phi, next_phi], dim=-1)
        dist_inputs = self.model._curiosity_inverse_fcnet(phi_cat_next_phi)
        action_dist = (TorchCategorical(dist_inputs, self.model) if isinstance(
            self.action_space, Discrete) else TorchMultiCategorical(
                dist_inputs, self.model, self.action_space.nvec))
        # Neg log(p); p=probability of observed action given the inverse-NN
        # predicted action distribution.
        inverse_loss = -action_dist.logp(actions_tensor)
        inverse_loss = torch.mean(inverse_loss)

        # Calculate the ICM loss.
        loss = (1.0 - self.beta) * inverse_loss + self.beta * forward_loss
        # Perform an optimizer step.
        self._optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()

        # Return the postprocessed sample batch (with the corrected rewards).
        return sample_batch
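
The quantities computed here are the standard ICM terms: the per-sample forward error 0.5 * ||phi'_pred - phi'||^2 doubles (scaled by eta) as the intrinsic reward added to the sampled rewards, the inverse model is trained on the negative log-likelihood of the action actually taken, and the two losses are mixed by beta. The self-contained numeric sketch below uses random tensors and plain PyTorch, with F.cross_entropy standing in for the TorchCategorical log-probability used above.

import torch
import torch.nn.functional as F

batch_size, phi_dim, num_actions = 4, 8, 3
beta, eta = 0.2, 1.0

phi = torch.randn(batch_size, phi_dim)                 # phi(s_t)
next_phi = torch.randn(batch_size, phi_dim)            # phi(s_{t+1})
predicted_next_phi = torch.randn(batch_size, phi_dim)  # forward-model output
actions = torch.randint(0, num_actions, (batch_size,))
inverse_logits = torch.randn(batch_size, num_actions)  # inverse-model output

# Forward loss: 0.5 * ||phi'_pred - phi'||^2 per sample; the same per-sample
# value (scaled by eta and detached) is what gets added to the rewards.
forward_l2 = 0.5 * torch.sum((predicted_next_phi - next_phi) ** 2, dim=-1)
forward_loss = forward_l2.mean()
intrinsic_reward = eta * forward_l2.detach()

# Inverse loss: negative log-likelihood of the taken action under the
# inverse model's predicted distribution (cross-entropy over the logits).
inverse_loss = F.cross_entropy(inverse_logits, actions)

# Total ICM loss, mixed by beta.
icm_loss = (1.0 - beta) * inverse_loss + beta * forward_loss
print(intrinsic_reward.shape, icm_loss.item())
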
Example #5
    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> (TensorType, List[TensorType]):
        assert seq_lens is not None
        # Push obs through "unwrapped" net's `forward()` first.
        wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

        # Concat. prev-action/reward if required.
        prev_a_r = []

        # Prev actions.
        if self.model_config["lstm_use_prev_action"]:
            prev_a = input_dict[SampleBatch.PREV_ACTIONS]
            # If actions are not processed yet (still in their original form,
            # as sent to the environment):
            # Flatten/one-hot them into a 1D array.
            if self.model_config["_disable_action_flattening"]:
                prev_a_r.append(
                    flatten_inputs_to_1d_tensor(
                        prev_a,
                        spaces_struct=self.action_space_struct,
                        time_axis=False))
            # If actions are already flattened (but not one-hot'd yet!),
            # one-hot discrete/multi-discrete actions here.
            else:
                if isinstance(self.action_space, (Discrete, MultiDiscrete)):
                    prev_a = one_hot(prev_a.float(), self.action_space)
                else:
                    prev_a = prev_a.float()
                prev_a_r.append(torch.reshape(prev_a, [-1, self.action_dim]))
        # Prev rewards.
        if self.model_config["lstm_use_prev_reward"]:
            prev_a_r.append(
                torch.reshape(input_dict[SampleBatch.PREV_REWARDS].float(),
                              [-1, 1]))

        # Concat prev. actions + rewards to the "main" input.
        if prev_a_r:
            wrapped_out = torch.cat([wrapped_out] + prev_a_r, dim=1)

        # Push everything through our LSTM.
        input_dict["obs_flat"] = wrapped_out
        return super().forward(input_dict, state, seq_lens)
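
The `_disable_action_flattening` branch handles previous actions that still carry their original (possibly nested) structure; `flatten_inputs_to_1d_tensor` one-hot encodes the discrete parts and flattens the rest into a single [B, D] tensor. The rough plain-PyTorch illustration below does this by hand for a hypothetical Tuple(Discrete(3), Box(2,)) action; it only sketches the idea and is not RLlib's actual helper.

import torch
import torch.nn.functional as F

batch_size = 4
discrete_part = torch.randint(0, 3, (batch_size,))   # Discrete(3) component
box_part = torch.randn(batch_size, 2)                # Box(2,) component

# One-hot the discrete part, keep the continuous part, concat -> [B, 3 + 2].
flat_prev_a = torch.cat(
    [F.one_hot(discrete_part, num_classes=3).float(), box_part], dim=1)
print(flat_prev_a.shape)  # torch.Size([4, 5])
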
Example #6
    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> (TensorType, List[TensorType]):
        assert seq_lens is not None
        # Push obs through "unwrapped" net's `forward()` first.
        wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

        # Concat. prev-action/reward if required.
        prev_a_r = []

        # Prev actions.
        if self.use_n_prev_actions:
            prev_n_actions = input_dict[SampleBatch.PREV_ACTIONS]
            # If actions are not processed yet (still in their original form,
            # as sent to the environment):
            # Flatten/one-hot them into a 1D array.
            if self.model_config["_disable_action_flattening"]:
                # Merge prev n actions into flat tensor.
                flat = flatten_inputs_to_1d_tensor(
                    prev_n_actions,
                    spaces_struct=self.action_space_struct,
                    time_axis=True,
                )
                # Fold time-axis into flattened data.
                flat = torch.reshape(flat, [flat.shape[0], -1])
                prev_a_r.append(flat)
            # If actions are already flattened (but not one-hot'd yet!),
            # one-hot discrete/multi-discrete actions here and concatenate the
            # n most recent actions together.
            else:
                if isinstance(self.action_space, Discrete):
                    for i in range(self.use_n_prev_actions):
                        prev_a_r.append(
                            one_hot(prev_n_actions[:, i].float(),
                                    space=self.action_space))
                elif isinstance(self.action_space, MultiDiscrete):
                    for i in range(0, self.use_n_prev_actions,
                                   self.action_space.shape[0]):
                        prev_a_r.append(
                            one_hot(
                                prev_n_actions[:, i:i +
                                               self.action_space.shape[0]]
                                .float(),
                                space=self.action_space,
                            ))
                else:
                    prev_a_r.append(
                        torch.reshape(
                            prev_n_actions.float(),
                            [-1, self.use_n_prev_actions * self.action_dim],
                        ))
        # Prev rewards.
        if self.use_n_prev_rewards:
            prev_a_r.append(
                torch.reshape(
                    input_dict[SampleBatch.PREV_REWARDS].float(),
                    [-1, self.use_n_prev_rewards],
                ))

        # Concat prev. actions + rewards to the "main" input.
        if prev_a_r:
            wrapped_out = torch.cat([wrapped_out] + prev_a_r, dim=1)

        # Then through our GTrXL.
        input_dict["obs_flat"] = input_dict["obs"] = wrapped_out

        self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens)
        model_out = self._logits_branch(self._features)
        return model_out, memory_outs
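
Because this wrapper stacks the last n previous actions, the flattened structured actions keep a time axis that gets folded into each batch row (the `torch.reshape(flat, [flat.shape[0], -1])` above) before being concatenated onto the wrapped model's output. Shape sketch with invented sizes:

import torch

batch_size, n_prev, flat_action_dim = 4, 3, 5
per_step_flat = torch.randn(batch_size, n_prev, flat_action_dim)

# Fold the time axis into the feature axis: [B, n_prev, D] -> [B, n_prev * D].
folded = torch.reshape(per_step_flat, [per_step_flat.shape[0], -1])
print(folded.shape)  # torch.Size([4, 15])

wrapped_out = torch.randn(batch_size, 256)
gtrxl_in = torch.cat([wrapped_out, folded], dim=1)    # [B, 256 + 15]
print(gtrxl_in.shape)  # torch.Size([4, 271])
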