Example #1
    # Module-level imports assumed: torch, torch.distributions.Normal,
    # a logger, and the gmm_loss function under test.
    def test_gmm_loss(self):
        # seq_len x batch_size x gaussian_size x feature_size
        # 1 x 1 x 2 x 2
        mus = torch.Tensor([[[[0.0, 0.0], [6.0, 6.0]]]])
        sigmas = torch.Tensor([[[[2.0, 2.0], [2.0, 2.0]]]])
        # seq_len x batch_size x gaussian_size
        pi = torch.Tensor([[[0.5, 0.5]]])
        logpi = torch.log(pi)

        # seq_len x batch_size x feature_size
        batch = torch.Tensor([[[3.0, 3.0]]])
        gl = gmm_loss(batch, mus, sigmas, logpi)

        # first component, first dimension
        n11 = Normal(mus[0, 0, 0, 0], sigmas[0, 0, 0, 0])
        # first component, second dimension
        n12 = Normal(mus[0, 0, 0, 1], sigmas[0, 0, 0, 1])
        p1 = (pi[0, 0, 0] * torch.exp(n11.log_prob(batch[0, 0, 0])) *
              torch.exp(n12.log_prob(batch[0, 0, 1])))
        # second component, first dimension
        n21 = Normal(mus[0, 0, 1, 0], sigmas[0, 0, 1, 0])
        # second component, second dimension
        n22 = Normal(mus[0, 0, 1, 1], sigmas[0, 0, 1, 1])
        p2 = (pi[0, 0, 1] * torch.exp(n21.log_prob(batch[0, 0, 0])) *
              torch.exp(n22.log_prob(batch[0, 0, 1])))

        logger.info(
            "gmm loss={}, p1={}, p2={}, p1+p2={}, -log(p1+p2)={}".format(
                gl, p1, p2, p1 + p2, -torch.log(p1 + p2)))
        assert torch.allclose(-torch.log(p1 + p2), gl)
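For reference, here is a minimal sketch of the quantity this test checks: the negative log-likelihood of a diagonal-covariance Gaussian mixture, reduced with a numerically stable logsumexp. The name gmm_loss_sketch is hypothetical; the real gmm_loss may differ in its reduction details.

    import torch
    from torch.distributions import Normal

    def gmm_loss_sketch(batch, mus, sigmas, logpi):
        # batch:        (seq_len, batch_size, feature_size)
        # mus, sigmas:  (seq_len, batch_size, num_gaussians, feature_size)
        # logpi:        (seq_len, batch_size, num_gaussians)
        batch = batch.unsqueeze(-2)  # broadcast against the gaussian dim
        # Diagonal covariance: per-feature log-densities sum over features.
        log_probs = Normal(mus, sigmas).log_prob(batch).sum(dim=-1)
        # log sum_k pi_k * N_k(x), computed stably in log space.
        log_prob = torch.logsumexp(logpi + log_probs, dim=-1)
        # Average over the sequence and batch dimensions.
        return -torch.mean(log_prob)

On the test's inputs this reduces to -log(p1 + p2), which is exactly what the assertion compares against.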
Example #2
    # Module-level imports assumed: torch.nn.functional as F,
    # the rlt types module, and gmm_loss.
    def get_loss(
        self,
        training_batch: rlt.PreprocessedMemoryNetworkInput,
        state_dim: Optional[int] = None,
    ):
        """
        Compute losses:
            GMMLoss(next_state, GMMPredicted) / (STATE_DIM + 2)
            + MSE(reward, predicted_reward)
            + BCE(not_terminal, logit_not_terminal)

        The STATE_DIM + 2 factor counteracts the fact that the GMM loss scales
            approximately linearly with STATE_DIM, the dimension of the states.
            All losses are averaged over both the batch and the sequence
            dimensions (the first two dimensions).

        :param training_batch:
            training_batch has these fields:
            - state: (SEQ_LEN, BATCH_SIZE, STATE_DIM) torch tensor
            - action: (SEQ_LEN, BATCH_SIZE, ACTION_DIM) torch tensor
            - reward: (SEQ_LEN, BATCH_SIZE) torch tensor
            - not_terminal: (SEQ_LEN, BATCH_SIZE) torch tensor
            - next_state: (SEQ_LEN, BATCH_SIZE, STATE_DIM) torch tensor

        :param state_dim: the dimension of states. If provided, use it to normalize
            gmm loss

        :returns: dictionary of losses, containing the gmm, the mse, the bce,
            and the combined loss.
        """
        assert isinstance(training_batch, rlt.PreprocessedMemoryNetworkInput)
        # mdnrnn's input should have seq_len as the first dimension

        mdnrnn_output = self.memory_network(
            training_batch.state, rlt.FeatureData(training_batch.action))
        # mus, sigmas: [seq_len, batch_size, num_gaussian, state_dim]
        mus, sigmas, logpi, rs, nts = (
            mdnrnn_output.mus,
            mdnrnn_output.sigmas,
            mdnrnn_output.logpi,
            mdnrnn_output.reward,
            mdnrnn_output.not_terminal,
        )

        next_state = training_batch.next_state.float_features
        not_terminal = training_batch.not_terminal
        reward = training_batch.reward
        if self.params.fit_only_one_next_step:
            next_state, not_terminal, reward, mus, sigmas, logpi, nts, rs = tuple(
                map(
                    lambda x: x[-1:],
                    (next_state, not_terminal, reward, mus, sigmas, logpi, nts,
                     rs),
                ))

        gmm = (gmm_loss(next_state, mus, sigmas, logpi) *
               self.params.next_state_loss_weight)
        bce = (F.binary_cross_entropy_with_logits(nts, not_terminal) *
               self.params.not_terminal_loss_weight)
        mse = F.mse_loss(rs, reward) * self.params.reward_loss_weight
        if state_dim is not None:
            loss = gmm / (state_dim + 2) + bce + mse
        else:
            loss = gmm + bce + mse
        return {"gmm": gmm, "bce": bce, "mse": mse, "loss": loss}
Example #3
    # Module-level imports assumed: torch, torch.nn.functional as F,
    # the rlt types module, gmm_loss, and a transpose helper.
    def get_loss(
        self,
        training_batch: rlt.PreprocessedTrainingBatch,
        state_dim: Optional[int] = None,
        batch_first: bool = False,
    ):
        """
        Compute losses:
            GMMLoss(next_state, GMMPredicted) / (STATE_DIM + 2)
            + MSE(reward, predicted_reward)
            + BCE(not_terminal, logit_not_terminal)

        The STATE_DIM + 2 factor counteracts the fact that the GMM loss scales
            approximately linearly with STATE_DIM, the feature size of the
            states. All losses are averaged over both the batch and the
            sequence dimensions (the first two dimensions).

        :param training_batch:
            training_batch.training_input has these fields:
            - state: (BATCH_SIZE, SEQ_LEN, STATE_DIM) torch tensor
            - action: (BATCH_SIZE, SEQ_LEN, ACTION_DIM) torch tensor
            - reward: (BATCH_SIZE, SEQ_LEN) torch tensor
            - not_terminal: (BATCH_SIZE, SEQ_LEN) torch tensor
            - next_state: (BATCH_SIZE, SEQ_LEN, STATE_DIM) torch tensor
            the first two dimensions may be swapped depending on batch_first

        :param state_dim: the dimension of states. If provided, use it to normalize
            gmm loss

        :param batch_first: whether the data's first dimension is the batch
            size. If False, the first two dimensions of state, action, reward,
            not_terminal, and next_state are SEQ_LEN and BATCH_SIZE.

        :returns: dictionary of losses, containing the gmm, the mse, the bce,
            and the combined loss.
        """
        learning_input = training_batch.training_input
        assert isinstance(learning_input, rlt.PreprocessedMemoryNetworkInput)
        # mdnrnn's input should have seq_len as the first dimension
        if batch_first:
            state, action, next_state, reward, not_terminal = transpose(
                learning_input.state.float_features,
                learning_input.action,
                learning_input.next_state.float_features,
                learning_input.reward,
                learning_input.not_terminal,  # type: ignore
            )
            learning_input = rlt.PreprocessedMemoryNetworkInput(  # type: ignore
                state=rlt.PreprocessedFeatureVector(float_features=state),
                reward=reward,
                time_diff=torch.ones_like(reward).float(),
                action=action,
                not_terminal=not_terminal,
                next_state=rlt.PreprocessedFeatureVector(
                    float_features=next_state),
                step=None,
            )

        mdnrnn_input = rlt.PreprocessedStateAction(
            state=learning_input.state,  # type: ignore
            action=rlt.PreprocessedFeatureVector(
                float_features=learning_input.action),  # type: ignore
        )
        mdnrnn_output = self.mdnrnn(mdnrnn_input)
        mus, sigmas, logpi, rs, nts = (
            mdnrnn_output.mus,
            mdnrnn_output.sigmas,
            mdnrnn_output.logpi,
            mdnrnn_output.reward,
            mdnrnn_output.not_terminal,
        )

        next_state = learning_input.next_state.float_features
        not_terminal = learning_input.not_terminal  # type: ignore
        reward = learning_input.reward
        if self.params.fit_only_one_next_step:
            next_state, not_terminal, reward, mus, sigmas, logpi, nts, rs = tuple(
                map(
                    lambda x: x[-1:],
                    (next_state, not_terminal, reward, mus, sigmas, logpi, nts,
                     rs),
                ))

        gmm = (gmm_loss(next_state, mus, sigmas, logpi) *
               self.params.next_state_loss_weight)
        bce = (F.binary_cross_entropy_with_logits(nts, not_terminal) *
               self.params.not_terminal_loss_weight)
        mse = F.mse_loss(rs, reward) * self.params.reward_loss_weight
        if state_dim is not None:
            loss = gmm / (state_dim + 2) + bce + mse
        else:
            loss = gmm + bce + mse
        return {"gmm": gmm, "bce": bce, "mse": mse, "loss": loss}