def test_gmm_loss(self):
    """Check ``gmm_loss`` against a hand-computed two-component mixture.

    The mixture has two equally weighted (pi = 0.5) components in 2-D with
    means (0, 0) and (6, 6) and sigmas of 2 in every dimension. Evaluated at
    the point (3, 3), the loss returned by ``gmm_loss`` must match
    -log(p1 + p2), where pk is pi_k times the product of the per-dimension
    ``Normal`` densities of component k.
    """
    # seq_len x batch_size x gaussian_size x feature_size
    # 1 x 1 x 2 x 2
    mus = torch.Tensor([[[[0.0, 0.0], [6.0, 6.0]]]])
    sigmas = torch.Tensor([[[[2.0, 2.0], [2.0, 2.0]]]])
    # seq_len x batch_size x gaussian_size
    pi = torch.Tensor([[[0.5, 0.5]]])
    logpi = torch.log(pi)

    # seq_len x batch_size x feature_size
    batch = torch.Tensor([[[3.0, 3.0]]])
    gl = gmm_loss(batch, mus, sigmas, logpi)

    # first component, first dimension
    n11 = Normal(mus[0, 0, 0, 0], sigmas[0, 0, 0, 0])
    # first component, second dimension
    n12 = Normal(mus[0, 0, 0, 1], sigmas[0, 0, 0, 1])
    p1 = (
        pi[0, 0, 0]
        * torch.exp(n11.log_prob(batch[0, 0, 0]))
        * torch.exp(n12.log_prob(batch[0, 0, 1]))
    )

    # second component, first dimension
    n21 = Normal(mus[0, 0, 1, 0], sigmas[0, 0, 1, 0])
    # second component, second dimension
    n22 = Normal(mus[0, 0, 1, 1], sigmas[0, 0, 1, 1])
    p2 = (
        pi[0, 0, 1]
        * torch.exp(n21.log_prob(batch[0, 0, 0]))
        * torch.exp(n22.log_prob(batch[0, 0, 1]))
    )

    expected = -torch.log(p1 + p2)
    logger.info(
        "gmm loss={}, p1={}, p2={}, p1+p2={}, -log(p1+p2)={}".format(
            gl, p1, p2, p1 + p2, expected
        )
    )
    # gmm_loss need not evaluate the mixture with the same sequence of
    # floating-point operations as this reference, so compare with a
    # tolerance instead of exact equality.
    assert torch.allclose(gl, expected), (gl, expected)
def get_loss(
    self,
    training_batch: rlt.PreprocessedMemoryNetworkInput,
    state_dim: Optional[int] = None,
):
    """
    Compute the weighted MDN-RNN losses for one training batch.

    The total loss is

        w_gmm * GMMLoss(next_state, GMMPredicted) / (STATE_DIM + 2)
        + w_bce * BCE(not_terminal, logit_not_terminal)
        + w_mse * MSE(reward, predicted_reward)

    where w_gmm, w_bce and w_mse are ``self.params.next_state_loss_weight``,
    ``self.params.not_terminal_loss_weight`` and
    ``self.params.reward_loss_weight``. The division by (STATE_DIM + 2) is
    applied only when ``state_dim`` is given; it counteracts the fact that
    the GMM loss scales approximately linearly with STATE_DIM, the dim of
    states. All losses are averaged both on the batch and the sequence
    dimensions (the two first dimensions).

    If ``self.params.fit_only_one_next_step`` is set, only the last step of
    the sequence (first) dimension is used for every loss.

    :param training_batch: training_batch has these fields:
        - state: (SEQ_LEN, BATCH_SIZE, STATE_DIM) torch tensor
        - action: (SEQ_LEN, BATCH_SIZE, ACTION_DIM) torch tensor
        - reward: (SEQ_LEN, BATCH_SIZE) torch tensor
        - not_terminal: (SEQ_LEN, BATCH_SIZE) torch tensor
        - next_state: (SEQ_LEN, BATCH_SIZE, STATE_DIM) torch tensor,
          read through its ``float_features`` attribute
    :param state_dim: the dimension of states. If provided, use it to
        normalize the gmm loss in the total ``"loss"``.
    :returns: dictionary with keys ``"gmm"`` (weighted, not normalized by
        state_dim), ``"bce"`` (weighted), ``"mse"`` (weighted) and
        ``"loss"`` (the total defined above).
    """
    assert isinstance(training_batch, rlt.PreprocessedMemoryNetworkInput)
    # mdnrnn's input should have seq_len as the first dimension
    mdnrnn_output = self.memory_network(
        training_batch.state, rlt.FeatureData(training_batch.action)
    )
    # mus, sigmas: [seq_len, batch_size, num_gaussian, state_dim]
    mus, sigmas, logpi, rs, nts = (
        mdnrnn_output.mus,
        mdnrnn_output.sigmas,
        mdnrnn_output.logpi,
        mdnrnn_output.reward,
        mdnrnn_output.not_terminal,
    )

    next_state = training_batch.next_state.float_features
    not_terminal = training_batch.not_terminal
    reward = training_batch.reward
    if self.params.fit_only_one_next_step:
        # Keep only the last step of the leading (sequence) dimension;
        # slicing with [-1:] instead of [-1] keeps that dimension (size 1).
        next_state, not_terminal, reward, mus, sigmas, logpi, nts, rs = (
            x[-1:]
            for x in (next_state, not_terminal, reward, mus, sigmas, logpi, nts, rs)
        )

    gmm = (
        gmm_loss(next_state, mus, sigmas, logpi)
        * self.params.next_state_loss_weight
    )
    bce = (
        F.binary_cross_entropy_with_logits(nts, not_terminal)
        * self.params.not_terminal_loss_weight
    )
    mse = F.mse_loss(rs, reward) * self.params.reward_loss_weight
    if state_dim is not None:
        loss = gmm / (state_dim + 2) + bce + mse
    else:
        loss = gmm + bce + mse
    return {"gmm": gmm, "bce": bce, "mse": mse, "loss": loss}
def get_loss(
    self,
    training_batch: rlt.PreprocessedTrainingBatch,
    state_dim: Optional[int] = None,
    batch_first: bool = False,
):
    """
    Compute the weighted MDN-RNN losses for one training batch.

    The total loss is

        w_gmm * GMMLoss(next_state, GMMPredicted) / (STATE_DIM + 2)
        + w_bce * BCE(not_terminal, logit_not_terminal)
        + w_mse * MSE(reward, predicted_reward)

    where w_gmm, w_bce and w_mse are ``self.params.next_state_loss_weight``,
    ``self.params.not_terminal_loss_weight`` and
    ``self.params.reward_loss_weight``. The division by (STATE_DIM + 2) is
    applied only when ``state_dim`` is given; it counteracts the fact that
    the GMM loss scales approximately linearly with STATE_DIM, the feature
    size of states. All losses are averaged both on the batch and the
    sequence dimensions (the two first dimensions).

    If ``self.params.fit_only_one_next_step`` is set, only the last step of
    the sequence dimension is used for every loss.

    :param training_batch: training_batch.training_input must be a
        PreprocessedMemoryNetworkInput with these fields (shapes shown for
        the default batch_first=False; with batch_first=True the first two
        dimensions are swapped):
        - state.float_features: (SEQ_LEN, BATCH_SIZE, STATE_DIM) torch tensor
        - action: (SEQ_LEN, BATCH_SIZE, ACTION_DIM) torch tensor
        - reward: (SEQ_LEN, BATCH_SIZE) torch tensor
        - not_terminal: (SEQ_LEN, BATCH_SIZE) torch tensor
        - next_state.float_features: (SEQ_LEN, BATCH_SIZE, STATE_DIM) torch
          tensor
    :param state_dim: the dimension of states. If provided, use it to
        normalize the gmm loss in the total ``"loss"``.
    :param batch_first: whether the data's first dimension is BATCH_SIZE.
        If True, the inputs are transposed so the MDN-RNN receives
        sequence-first tensors. If False, state, action, reward,
        not_terminal and next_state are already (SEQ_LEN, BATCH_SIZE, ...).
    :returns: dictionary with keys ``"gmm"`` (weighted, not normalized by
        state_dim), ``"bce"`` (weighted), ``"mse"`` (weighted) and
        ``"loss"`` (the total defined above).
    """
    learning_input = training_batch.training_input
    assert isinstance(learning_input, rlt.PreprocessedMemoryNetworkInput)
    # mdnrnn's input should have seq_len as the first dimension
    if batch_first:
        state, action, next_state, reward, not_terminal = transpose(
            learning_input.state.float_features,
            learning_input.action,
            learning_input.next_state.float_features,
            learning_input.reward,
            learning_input.not_terminal,  # type: ignore
        )
        learning_input = rlt.PreprocessedMemoryNetworkInput(  # type: ignore
            state=rlt.PreprocessedFeatureVector(float_features=state),
            reward=reward,
            time_diff=torch.ones_like(reward).float(),
            action=action,
            not_terminal=not_terminal,
            next_state=rlt.PreprocessedFeatureVector(float_features=next_state),
            step=None,
        )

    mdnrnn_input = rlt.PreprocessedStateAction(
        state=learning_input.state,  # type: ignore
        action=rlt.PreprocessedFeatureVector(
            float_features=learning_input.action
        ),  # type: ignore
    )
    mdnrnn_output = self.mdnrnn(mdnrnn_input)
    mus, sigmas, logpi, rs, nts = (
        mdnrnn_output.mus,
        mdnrnn_output.sigmas,
        mdnrnn_output.logpi,
        mdnrnn_output.reward,
        mdnrnn_output.not_terminal,
    )

    next_state = learning_input.next_state.float_features
    not_terminal = learning_input.not_terminal  # type: ignore
    reward = learning_input.reward
    if self.params.fit_only_one_next_step:
        # Keep only the last step of the leading (sequence) dimension;
        # slicing with [-1:] instead of [-1] keeps that dimension (size 1).
        next_state, not_terminal, reward, mus, sigmas, logpi, nts, rs = (
            x[-1:]
            for x in (next_state, not_terminal, reward, mus, sigmas, logpi, nts, rs)
        )

    gmm = (
        gmm_loss(next_state, mus, sigmas, logpi)
        * self.params.next_state_loss_weight
    )
    bce = (
        F.binary_cross_entropy_with_logits(nts, not_terminal)
        * self.params.not_terminal_loss_weight
    )
    mse = F.mse_loss(rs, reward) * self.params.reward_loss_weight
    if state_dim is not None:
        loss = gmm / (state_dim + 2) + bce + mse
    else:
        loss = gmm + bce + mse
    return {"gmm": gmm, "bce": bce, "mse": mse, "loss": loss}