コード例 #1
0
    def as_cem_training_batch(self, batch_first=False):
        """
        Build the one-step training batch consumed by the CEM trainer.

        Each stored transition becomes a sequence of length one: a singleton
        sequence axis is inserted into the states, actions and next states.
        The resulting samples are meant for training the ensemble of world
        models used by CEM.

        Resulting layout (assuming the stored rewards/not_terminal are
        batch_size x 1 -- TODO confirm against the class definition):
            batch_first=True:
                state/next state: batch_size x 1 x state_dim
                action: batch_size x 1 x action_dim
                reward/not_terminal: batch_size x 1
            batch_first=False (default):
                state/next state: 1 x batch_size x state_dim
                action: 1 x batch_size x action_dim
                reward/not_terminal: 1 x batch_size

        :param batch_first: whether the batch axis precedes the sequence axis.

        :returns: an rlt.PreprocessedTrainingBatch wrapping an
            rlt.PreprocessedMemoryNetworkInput and the rlt.ExtraData.
        """
        # Index at which the length-one sequence axis is inserted.
        seq_len_dim = 1 if batch_first else 0

        if batch_first:
            reward, not_terminal = self.rewards, self.not_terminal
        else:
            # Sequence-major layout: swap the axes of the per-step scalars.
            reward, not_terminal = transpose(self.rewards, self.not_terminal)

        state = rlt.PreprocessedFeatureVector(
            float_features=self.states.unsqueeze(seq_len_dim)
        )
        action = self.actions.unsqueeze(seq_len_dim)
        next_state = rlt.PreprocessedFeatureVector(
            float_features=self.next_states.unsqueeze(seq_len_dim)
        )

        memory_network_input = rlt.PreprocessedMemoryNetworkInput(
            state=state,
            action=action,
            next_state=next_state,
            reward=reward,
            not_terminal=not_terminal,
            step=self.step,
            time_diff=self.time_diffs,
        )
        extra_data = rlt.ExtraData(
            mdp_id=self.mdp_ids,
            sequence_number=self.sequence_numbers,
            action_probability=self.propensities,
            max_num_actions=self.max_num_actions,
            metrics=self.metrics,
        )
        return rlt.PreprocessedTrainingBatch(
            training_input=memory_network_input,
            extras=extra_data,
        )
コード例 #2
0
    def get_loss(
        self,
        training_batch: rlt.PreprocessedTrainingBatch,
        state_dim: Optional[int] = None,
        batch_first: bool = False,
    ):
        """
        Compute the MDN-RNN training losses for one batch.

        The combined loss is:
            GMMLoss(next_state, GMMPredicted) / (STATE_DIM + 2)
            + MSE(reward, predicted_reward)
            + BCE(not_terminal, logit_not_terminal)

        Each term is first multiplied by its weight from ``self.params``. The
        division by STATE_DIM + 2 is applied only when ``state_dim`` is given;
        it compensates for the GMM loss growing roughly linearly with the
        state feature size.

        :param training_batch:
            training_batch.training_input must be an
            rlt.PreprocessedMemoryNetworkInput with these fields:
            - state: (BATCH_SIZE, SEQ_LEN, STATE_DIM) torch tensor
            - action: (BATCH_SIZE, SEQ_LEN, ACTION_DIM) torch tensor
            - reward: (BATCH_SIZE, SEQ_LEN) torch tensor
            - not-terminal: (BATCH_SIZE, SEQ_LEN) torch tensor
            - next_state: (BATCH_SIZE, SEQ_LEN, STATE_DIM) torch tensor
            The two leading dimensions are swapped when batch_first is False.

        :param state_dim: the dimension of states. If provided, the GMM loss
            is divided by (state_dim + 2) inside the combined loss.

        :param batch_first: whether the first dimension of the data is the
            batch size. If True, the tensors are transposed to sequence-major
            before being fed to the model; in that case time_diff is replaced
            by ones and step by None in the rebuilt input.

        :returns: dictionary with the weighted "gmm", "bce" and "mse" losses
            and the combined "loss".
        """
        memory_input = training_batch.training_input
        assert isinstance(memory_input, rlt.PreprocessedMemoryNetworkInput)

        if batch_first:
            # The MDN-RNN expects seq_len as the leading dimension.
            (
                seq_state,
                seq_action,
                seq_next_state,
                seq_reward,
                seq_not_terminal,
            ) = transpose(
                memory_input.state.float_features,
                memory_input.action,
                memory_input.next_state.float_features,
                memory_input.reward,
                memory_input.not_terminal,  # type: ignore
            )
            memory_input = rlt.PreprocessedMemoryNetworkInput(  # type: ignore
                state=rlt.PreprocessedFeatureVector(float_features=seq_state),
                reward=seq_reward,
                time_diff=torch.ones_like(seq_reward).float(),
                action=seq_action,
                not_terminal=seq_not_terminal,
                next_state=rlt.PreprocessedFeatureVector(
                    float_features=seq_next_state
                ),
                step=None,
            )

        action_features = rlt.PreprocessedFeatureVector(
            float_features=memory_input.action  # type: ignore
        )
        model_output = self.mdnrnn(
            rlt.PreprocessedStateAction(
                state=memory_input.state,  # type: ignore
                action=action_features,
            )
        )
        mus = model_output.mus
        sigmas = model_output.sigmas
        logpi = model_output.logpi
        predicted_reward = model_output.reward
        not_terminal_logits = model_output.not_terminal

        target_next_state = memory_input.next_state.float_features
        target_not_terminal = memory_input.not_terminal  # type: ignore
        target_reward = memory_input.reward

        if self.params.fit_only_one_next_step:
            # Keep only the last sequence step of every target and prediction.
            (
                target_next_state,
                target_not_terminal,
                target_reward,
                mus,
                sigmas,
                logpi,
                not_terminal_logits,
                predicted_reward,
            ) = (
                tensor[-1:]
                for tensor in (
                    target_next_state,
                    target_not_terminal,
                    target_reward,
                    mus,
                    sigmas,
                    logpi,
                    not_terminal_logits,
                    predicted_reward,
                )
            )

        raw_gmm = gmm_loss(target_next_state, mus, sigmas, logpi)
        gmm = raw_gmm * self.params.next_state_loss_weight
        raw_bce = F.binary_cross_entropy_with_logits(
            not_terminal_logits, target_not_terminal
        )
        bce = raw_bce * self.params.not_terminal_loss_weight
        mse = F.mse_loss(predicted_reward, target_reward) * (
            self.params.reward_loss_weight
        )

        losses = {"gmm": gmm, "bce": bce, "mse": mse}
        # Normalise only the GMM term; the reported "gmm" stays unnormalised.
        gmm_term = gmm if state_dim is None else gmm / (state_dim + 2)
        losses["loss"] = gmm_term + bce + mse
        return losses
コード例 #3
0
    def evaluate(self, tdp: TrainingBatch):
        """Estimate how sensitive predicted next-state features are to actions.

        The world model is run twice on the same batch: once unchanged and once
        with the actions randomly permuted across the batch dimension. For
        every state feature, the sensitivity is the absolute difference between
        the two sets of predicted next-state means, summed over the columns
        belonging to that feature and averaged over sequence steps, batch
        entries and mixture components.

        The model is put in eval mode for the duration of the call and is
        always switched back to training mode afterwards, even when the
        evaluation raises.

        :param tdp: training batch whose ``training_input`` is batch-first,
            i.e. ``next_state`` has size (BATCH_SIZE, SEQ_LEN, STATE_DIM).

        :returns: dictionary with the single key "feature_sensitivity" holding
            a numpy array with one value per state feature.

        :raises AssertionError: if the predicted means do not have size
            (SEQ_LEN, BATCH_SIZE, NUM_GAUSSIANS, STATE_DIM).
        """
        mdnrnn_module = self.trainer.mdnrnn.mdnrnn
        mdnrnn_module.eval()
        try:
            # Only forward passes are needed and the results are read out via
            # .item(), so skip autograd bookkeeping entirely.
            with torch.no_grad():
                batch_size, seq_len, state_dim = (
                    tdp.training_input.next_state.size()  # type: ignore
                )  # type: ignore
                state_feature_num = self.state_feature_num
                feature_sensitivity = torch.zeros(state_feature_num)

                mdnrnn_input = tdp.training_input
                # The input of mdnrnn has seq-len as the first dimension.
                state, action, next_state, reward, not_terminal = transpose(
                    mdnrnn_input.state.float_features,
                    mdnrnn_input.action.float_features,  # type: ignore
                    mdnrnn_input.next_state,
                    mdnrnn_input.reward,
                    mdnrnn_input.not_terminal,
                )
                mdnrnn_input = MemoryNetworkInput(
                    state=FeatureVector(float_features=state),
                    action=FeatureVector(float_features=action),
                    next_state=next_state,
                    reward=reward,
                    not_terminal=not_terminal,
                )
                mdnrnn_output = self.trainer.mdnrnn(mdnrnn_input)
                predicted_next_state_means = mdnrnn_output.mus

                shuffled_mdnrnn_input = MemoryNetworkInput(
                    state=FeatureVector(float_features=state),
                    # Shuffle the actions across the batch (dimension 1 after
                    # the transpose above).
                    action=FeatureVector(
                        float_features=action[:, torch.randperm(batch_size), :]
                    ),
                    next_state=next_state,
                    reward=reward,
                    not_terminal=not_terminal,
                )
                shuffled_mdnrnn_output = self.trainer.mdnrnn(shuffled_mdnrnn_input)
                shuffled_predicted_next_state_means = shuffled_mdnrnn_output.mus

                assert (
                    predicted_next_state_means.size()
                    == shuffled_predicted_next_state_means.size()
                    == (
                        seq_len,
                        batch_size,
                        self.trainer.params.num_gaussians,
                        state_dim,
                    )
                )

                # The element-wise deviation does not depend on the feature, so
                # compute it once and slice it per feature below.
                abs_deviation = torch.abs(
                    shuffled_predicted_next_state_means - predicted_next_state_means
                )

                state_feature_boundaries = (
                    self.sorted_state_feature_start_indices + [state_dim]
                )
                for i in range(state_feature_num):
                    boundary_start = state_feature_boundaries[i]
                    boundary_end = state_feature_boundaries[i + 1]
                    # Sum over the feature's columns, then average the rest.
                    abs_diff = torch.mean(
                        torch.sum(
                            abs_deviation[:, :, :, boundary_start:boundary_end],
                            dim=3,
                        )
                    )
                    feature_sensitivity[i] = abs_diff.item()
        finally:
            # Hand the model back in training mode even if evaluation failed.
            mdnrnn_module.train()

        logger.info(
            "**** Debug tool feature sensitivity ****: {}".format(feature_sensitivity)
        )
        return {"feature_sensitivity": feature_sensitivity.numpy()}
コード例 #4
0
ファイル: mdnrnn_trainer.py プロジェクト: getcontrol/Horizon
    def get_loss(
        self,
        training_batch: rlt.TrainingBatch,
        state_dim: Optional[int] = None,
        batch_first: bool = False,
    ):
        """ Compute the MDN-RNN training losses for one batch.

        The combined loss is:
            GMMLoss(next_state, GMMPredicted) / (STATE_DIM + 2)
            + BCE(not_terminal, logit_not_terminal)
            + MSE(reward, predicted_reward)
        Each term is first multiplied by its weight from ``self.params``. Only
        the GMM term is divided by STATE_DIM + 2, and only when ``state_dim``
        is given; the division compensates for the GMM loss growing roughly
        linearly with the state feature size.

        :param training_batch
        training_batch.training_input has these fields:
            state: (BATCH_SIZE, SEQ_LEN, STATE_DIM) torch tensor
            action: (BATCH_SIZE, SEQ_LEN, ACTION_DIM) torch tensor
            reward: (BATCH_SIZE, SEQ_LEN) torch tensor
            not-terminal: (BATCH_SIZE, SEQ_LEN) torch tensor
            next_state: (BATCH_SIZE, SEQ_LEN, STATE_DIM) torch tensor
        The two leading dimensions are swapped when batch_first is False.
        :param state_dim: the dimension of states. If provided, the GMM loss
            is divided by (state_dim + 2) inside the combined loss.
        :param batch_first: whether the first dimension of the data is the
            batch size. If True, the tensors are transposed to sequence-major
            before being fed to the model.

        :returns: dictionary with the weighted "gmm", "bce" and "mse" losses
            and the combined "loss".
        """
        batch_input = training_batch.training_input

        if batch_first:
            # The MDN-RNN expects seq_len as the leading dimension.
            (
                seq_state,
                seq_action,
                seq_next_state,
                seq_reward,
                seq_not_terminal,
            ) = transpose(
                batch_input.state.float_features,
                batch_input.action.float_features,
                batch_input.next_state,
                batch_input.reward,
                batch_input.not_terminal,
            )
            batch_input = rlt.MemoryNetworkInput(
                state=rlt.FeatureVector(float_features=seq_state),
                action=rlt.FeatureVector(float_features=seq_action),
                next_state=seq_next_state,
                reward=seq_reward,
                not_terminal=seq_not_terminal,
            )

        state_action = rlt.StateAction(
            state=batch_input.state, action=batch_input.action
        )
        model_output = self.mdnrnn(state_action)
        mus = model_output.mus
        sigmas = model_output.sigmas
        logpi = model_output.logpi
        predicted_reward = model_output.reward
        not_terminal_logits = model_output.not_terminal

        raw_gmm = gmm_loss(batch_input.next_state, mus, sigmas, logpi)
        gmm = raw_gmm * self.params.next_state_loss_weight
        raw_bce = F.binary_cross_entropy_with_logits(
            not_terminal_logits, batch_input.not_terminal
        )
        bce = raw_bce * self.params.not_terminal_loss_weight
        mse = F.mse_loss(predicted_reward, batch_input.reward) * (
            self.params.reward_loss_weight
        )

        losses = {"gmm": gmm, "bce": bce, "mse": mse}
        # Normalise only the GMM term; the reported "gmm" stays unnormalised.
        gmm_term = gmm if state_dim is None else gmm / (state_dim + 2)
        losses["loss"] = gmm_term + bce + mse
        return losses