Example #1
    def forward(self, input_dict: Dict[str, TensorType],
                state: List[TensorType],
                seq_lens: TensorType) -> Tuple[TensorType, List[TensorType]]:
        assert seq_lens is not None
        # Push obs through "unwrapped" net's `forward()` first.
        wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

        # Concat. prev-action/reward if required.
        prev_a_r = []
        if self.model_config["lstm_use_prev_action"]:
            if isinstance(self.action_space, (Discrete, MultiDiscrete)):
                prev_a = one_hot(input_dict[SampleBatch.PREV_ACTIONS].float(),
                                 self.action_space)
            else:
                prev_a = input_dict[SampleBatch.PREV_ACTIONS].float()
            prev_a_r.append(torch.reshape(prev_a, [-1, self.action_dim]))
        if self.model_config["lstm_use_prev_reward"]:
            prev_a_r.append(
                torch.reshape(input_dict[SampleBatch.PREV_REWARDS].float(),
                              [-1, 1]))

        if prev_a_r:
            wrapped_out = torch.cat([wrapped_out] + prev_a_r, dim=1)

        # Then through our LSTM.
        input_dict["obs_flat"] = wrapped_out
        return super().forward(input_dict, state, seq_lens)
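
The two model_config flags checked above are standard RLlib model-config keys. A minimal sketch of a config that routes a model through this LSTM-wrapper path (values are illustrative, not a definitive setup):

config = {
    "model": {
        "use_lstm": True,               # wrap the base model with an LSTM
        "lstm_cell_size": 256,
        "lstm_use_prev_action": True,   # concat one-hot prev. action
        "lstm_use_prev_reward": True,   # concat prev. reward as [-1, 1]
        "max_seq_len": 20,              # max RNN sequence length
    },
}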
Example #2
    def forward(self, input_dict, state, seq_lens):
        if SampleBatch.OBS in input_dict and "obs_flat" in input_dict:
            orig_obs = input_dict[SampleBatch.OBS]
        else:
            orig_obs = restore_original_dimensions(
                input_dict[SampleBatch.OBS],
                self.processed_obs_space,
                tensorlib="torch")
        # Push image observations through our CNNs.
        outs = []
        for i, component in enumerate(tree.flatten(orig_obs)):
            if i in self.cnns:
                cnn_out, _ = self.cnns[i]({SampleBatch.OBS: component})
                outs.append(cnn_out)
            elif i in self.one_hot:
                if component.dtype in [torch.int32, torch.int64, torch.uint8]:
                    outs.append(
                        one_hot(component, self.flattened_input_space[i]))
                else:
                    outs.append(component)
            else:
                outs.append(torch.reshape(component, [-1, self.flatten[i]]))
        # Concat all outputs and the non-image inputs.
        out = torch.cat(outs, dim=1)
        # Push through (optional) FC-stack (this may be an empty stack).
        out, _ = self.post_fc_stack({SampleBatch.OBS: out}, [], None)

        # No logits/value branches.
        if self.logits_layer is None:
            return out, []

        # Logits- and value branches.
        logits, values = self.logits_layer(out), self.value_layer(out)
        self._value_out = torch.reshape(values, [-1])
        return logits, []
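
For Discrete components, the one_hot call above expands integer inputs into a fixed-width float vector. A minimal sketch of that behavior for a gym Discrete space (an assumption about the helper's semantics, not the actual RLlib utility):

import torch
import torch.nn.functional as F
from gym.spaces import Discrete

def one_hot_discrete(x: torch.Tensor, space: Discrete) -> torch.Tensor:
    # Encode integer class ids as a [batch, space.n] float tensor.
    return F.one_hot(x.long(), num_classes=space.n).float()

# one_hot_discrete(torch.tensor([0, 2]), Discrete(3))
# -> [[1., 0., 0.], [0., 0., 1.]]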
Example #3
    def forward(self, input_dict, state, seq_lens):
        # Push image observations through our CNNs.
        outs = []
        for i, component in enumerate(input_dict["obs"]):
            if i in self.cnns:
                cnn_out, _ = self.cnns[i]({"obs": component})
                outs.append(cnn_out)
            elif i in self.one_hot:
                if component.dtype in [torch.int32, torch.int64, torch.uint8]:
                    outs.append(
                        one_hot(component, self.original_space.spaces[i]))
                else:
                    outs.append(component)
            else:
                outs.append(torch.reshape(component, [-1, self.flatten[i]]))
        # Concat all outputs and the non-image inputs.
        out = torch.cat(outs, dim=1)
        # Push through (optional) FC-stack (this may be an empty stack).
        out, _ = self.post_fc_stack({"obs": out}, [], None)

        # No logits/value branches.
        if self.logits_layer is None:
            return out, []

        # Logits- and value branches.
        logits, values = self.logits_layer(out), self.value_layer(out)
        self._value_out = torch.reshape(values, [-1])
        return logits, []
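
A model like this is typically hooked up through RLlib's model catalog. A sketch, where ComplexInputNetwork stands in for whichever class defines the forward() above:

from ray.rllib.models import ModelCatalog

ModelCatalog.register_custom_model("complex_input_net", ComplexInputNetwork)

config = {
    "model": {
        "custom_model": "complex_input_net",
    },
}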
Example #4
    def forward(self, input_dict, state, seq_lens):
        # Push image observations through our CNNs.
        orig_obs = restore_original_dimensions(input_dict.get("obs"),
                                               self.new_obs_space, "torch")
        # Run the post-FC stacks in train mode only for real (multi-row)
        # train batches.
        mode = (input_dict.get("is_training", False)
                if input_dict.get("obs").shape[0] > 1 else False)
        outs = []
        v_outs = []
        for i, component in enumerate(orig_obs[:-1]):
            if i in self.cnns:
                cnn_out, _ = self.cnns[i]({"obs": component})
                outs.append(cnn_out)
                v_outs.append(self.cnns[i].value_function())
            elif i in self.one_hot:
                if component.dtype in [torch.int32, torch.int64, torch.uint8]:
                    outs.append(
                        one_hot(component, self.original_space.spaces[i]))
                    v_outs.append(
                        one_hot(component, self.original_space.spaces[i]))
                else:
                    outs.append(component)
                    v_outs.append(component)
            else:
                outs.append(torch.reshape(component, [-1, self.flatten[i]]))
                v_outs.append(torch.reshape(component, [-1, self.flatten[i]]))
        # Concat all outputs and the non-image inputs.
        out = torch.cat(outs, dim=1)
        v_out = torch.cat(v_outs, dim=1)
        # Push through the (optional, possibly empty) policy- and value-
        # FC-stacks, with train mode set as computed above.
        self.post_fc_stack.train(mode=mode)
        self.post_fc_stack_vf.train(mode=mode)
        out_p = self.post_fc_stack(out)
        out_v = self.post_fc_stack_vf(v_out)

        # No logits/value branches: return the post-stack features.
        if self.logits_layer is None:
            return out_p, []

        # Logits- and value branches.
        logits, values = self.logits_layer(out_p), self.value_layer(out_v)
        # Build the action mask: log(0) = -inf for invalid actions (the
        # last obs component holds the 0/1 validity mask). Keep the -inf
        # constant on the logits' device so this also runs on CPU.
        inf = torch.tensor(float("-inf"), device=logits.device)
        inf_mask = torch.maximum(torch.log(orig_obs[-1]), inf)
        self._value_out = torch.reshape(values, [-1])
        return logits + inf_mask, []
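
The masking trick above relies on log of a 0/1 validity mask being 0 for valid actions and -inf for invalid ones, so adding it to the logits drives masked actions to probability zero under softmax. A tiny self-contained illustration:

import torch

mask = torch.tensor([1.0, 0.0, 1.0])    # 1 = valid action, 0 = invalid
logits = torch.tensor([0.5, 2.0, -1.0])
inf_mask = torch.maximum(torch.log(mask),
                         torch.tensor(float("-inf")))
probs = torch.softmax(logits + inf_mask, dim=-1)
# probs[1] == 0.0: the masked action can never be sampled.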
Example #5
    def postprocess_trajectory(self, policy, sample_batch, tf_sess=None):
        """Calculates phi values (obs, obs', and predicted obs') and ri.

        Also calculates forward and inverse losses and updates the curiosity
        module on the provided batch using our optimizer.
        """
        # Push both observations through feature net to get both phis.
        phis, _ = self.model._curiosity_feature_net({
            SampleBatch.OBS:
            torch.cat([
                torch.from_numpy(sample_batch[SampleBatch.OBS]),
                torch.from_numpy(sample_batch[SampleBatch.NEXT_OBS])
            ])
        })
        phi, next_phi = torch.chunk(phis, 2)
        actions_tensor = torch.from_numpy(
            sample_batch[SampleBatch.ACTIONS]).long().to(policy.device)

        # Predict next phi with forward model.
        predicted_next_phi = self.model._curiosity_forward_fcnet(
            torch.cat(
                [phi, one_hot(actions_tensor, self.action_space).float()],
                dim=-1))

        # Forward loss term (predicted phi', given phi and action, vs.
        # actually observed phi').
        forward_l2_norm_squared = 0.5 * torch.sum(
            torch.pow(predicted_next_phi - next_phi, 2.0), dim=-1)
        forward_loss = torch.mean(forward_l2_norm_squared)

        # Scale the intrinsic reward by the eta hyper-parameter and add it
        # to the extrinsic environment reward.
        sample_batch[SampleBatch.REWARDS] = \
            sample_batch[SampleBatch.REWARDS] + \
            self.eta * forward_l2_norm_squared.detach().cpu().numpy()

        # Inverse loss term (predicted action that led from phi to phi' vs.
        # the actual action taken).
        phi_cat_next_phi = torch.cat([phi, next_phi], dim=-1)
        dist_inputs = self.model._curiosity_inverse_fcnet(phi_cat_next_phi)
        action_dist = TorchCategorical(dist_inputs, self.model) if \
            isinstance(self.action_space, Discrete) else \
            TorchMultiCategorical(
                dist_inputs, self.model, self.action_space.nvec)
        # Neg log(p); p=probability of observed action given the inverse-NN
        # predicted action distribution.
        inverse_loss = -action_dist.logp(actions_tensor)
        inverse_loss = torch.mean(inverse_loss)

        # Calculate the ICM loss.
        loss = (1.0 - self.beta) * inverse_loss + self.beta * forward_loss
        # Perform an optimizer step.
        self._optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()

        # Return the postprocessed sample batch (with the corrected rewards).
        return sample_batch
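
The final loss matches the ICM formulation used above: a beta-weighted sum of the inverse and forward losses, with the forward model's per-sample error doubling as the intrinsic reward. A minimal numeric sketch (shapes and hyper-parameter values are illustrative):

import torch

eta, beta = 1.0, 0.2
next_phi = torch.randn(4, 8)            # phi(obs')
predicted_next_phi = torch.randn(4, 8)  # forward model's prediction

forward_l2 = 0.5 * torch.sum((predicted_next_phi - next_phi) ** 2, dim=-1)
intrinsic_reward = eta * forward_l2.detach()  # added to extrinsic rewards
forward_loss = forward_l2.mean()
inverse_loss = torch.tensor(0.7)              # stand-in for mean -logp
loss = (1.0 - beta) * inverse_loss + beta * forward_loss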
Example #6
    def forward(self, input_dict: Dict[str, TensorType],
                state: List[TensorType],
                seq_lens: TensorType) -> Tuple[TensorType, List[TensorType]]:
        assert seq_lens is not None
        # Push obs through "unwrapped" net's `forward()` first.
        wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

        # Concat. prev-action/reward if required.
        prev_a_r = []
        if self.use_n_prev_actions:
            if isinstance(self.action_space, Discrete):
                for i in range(self.use_n_prev_actions):
                    prev_a_r.append(
                        one_hot(
                            input_dict[SampleBatch.PREV_ACTIONS][:, i].float(),
                            self.action_space))
            elif isinstance(self.action_space, MultiDiscrete):
                for i in range(0, self.use_n_prev_actions,
                               self.action_space.shape[0]):
                    prev_a_r.append(
                        one_hot(
                            input_dict[SampleBatch.PREV_ACTIONS]
                            [:, i:i + self.action_space.shape[0]].float(),
                            self.action_space))
            else:
                prev_a_r.append(
                    torch.reshape(
                        input_dict[SampleBatch.PREV_ACTIONS].float(),
                        [-1, self.use_n_prev_actions * self.action_dim]))
        if self.use_n_prev_rewards:
            prev_a_r.append(
                torch.reshape(input_dict[SampleBatch.PREV_REWARDS].float(),
                              [-1, self.use_n_prev_rewards]))

        if prev_a_r:
            wrapped_out = torch.cat([wrapped_out] + prev_a_r, dim=1)

        # Then through our GTrXL.
        input_dict["obs_flat"] = input_dict["obs"] = wrapped_out

        self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens)
        model_out = self._logits_branch(self._features)
        return model_out, memory_outs
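
As with the LSTM wrapper in Example #1, this GTrXL (attention) wrapper is driven by standard RLlib model-config keys. A sketch with illustrative values:

config = {
    "model": {
        "use_attention": True,
        "attention_dim": 64,
        "attention_use_n_prev_actions": 2,  # -> self.use_n_prev_actions
        "attention_use_n_prev_rewards": 2,  # -> self.use_n_prev_rewards
    },
}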