from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F

# rlt (the project's types module), gmm_loss, and transpose are assumed to be
# imported from elsewhere in the project.


def sample_memories(self, batch_size, batch_first=False):
    """
    :param batch_size: number of samples to return
    :param batch_first: If True, the first dimension of data is batch_size.
        If False (default), the first dimension is SEQ_LEN. Therefore,
        state's shape is SEQ_LEN x BATCH_SIZE x STATE_DIM, for example.
        By default, MDN-RNN consumes data with SEQ_LEN as the first
        dimension.
    """
    sample_indices = np.random.randint(self.memory_size, size=batch_size)
    # state/next_state shape: batch_size x seq_len x state_dim
    # action shape: batch_size x seq_len x action_dim
    # reward/not_terminal shape: batch_size x seq_len
    state, action, next_state, reward, not_terminal = map(
        lambda x: torch.tensor(x, dtype=torch.float),
        zip(*self.deque_sample(sample_indices)),
    )
    if not batch_first:
        state, action, next_state, reward, not_terminal = transpose(
            state, action, next_state, reward, not_terminal
        )
    training_input = rlt.MemoryNetworkInput(
        state=rlt.FeatureVector(float_features=state),
        action=rlt.FeatureVector(float_features=action),
        next_state=next_state,
        reward=reward,
        not_terminal=not_terminal,
    )
    return rlt.TrainingBatch(training_input=training_input, extras=None)
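# Both methods rely on a `transpose` helper that swaps the batch and sequence
# dimensions. A minimal sketch of what it could look like (hypothetical; the
# project defines its own version):


def transpose(*args):
    # Swap dim 0 and dim 1 of every tensor, e.g. turning
    # BATCH_SIZE x SEQ_LEN x STATE_DIM into SEQ_LEN x BATCH_SIZE x STATE_DIM.
    return tuple(arg.transpose(0, 1) for arg in args)


# Hypothetical usage of sample_memories, illustrating the default layout:
#     batch = trainer.sample_memories(batch_size=32)  # batch_first=False
#     state = batch.training_input.state.float_features
#     state.shape  # SEQ_LEN x 32 x STATE_DIM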
def get_loss(
    self,
    training_batch: rlt.TrainingBatch,
    state_dim: Optional[int] = None,
    batch_first: bool = False,
):
    """
    Compute losses. The loss that is computed is:

        GMMLoss(next_state, GMMPredicted) / (STATE_DIM + 2)
        + MSE(reward, predicted_reward)
        + BCE(not_terminal, logit_not_terminal)

    The STATE_DIM + 2 factor counteracts the fact that the GMM loss scales
    approximately linearly with STATE_DIM, the feature size of states. All
    losses are averaged over both the batch and the sequence dimensions (the
    first two dimensions).

    :param training_batch: training_batch.training_input has these fields:
        - state: (BATCH_SIZE, SEQ_LEN, STATE_DIM) torch tensor
        - action: (BATCH_SIZE, SEQ_LEN, ACTION_DIM) torch tensor
        - reward: (BATCH_SIZE, SEQ_LEN) torch tensor
        - not_terminal: (BATCH_SIZE, SEQ_LEN) torch tensor
        - next_state: (BATCH_SIZE, SEQ_LEN, STATE_DIM) torch tensor
        The first two dimensions may be swapped depending on batch_first.
    :param state_dim: the dimension of states. If provided, use it to
        normalize the GMM loss.
    :param batch_first: whether data's first dimension represents batch
        size. If False, the first two dimensions of state, action, reward,
        not_terminal, and next_state are SEQ_LEN and BATCH_SIZE.
    :returns: dictionary of losses, containing the gmm, the mse, the bce,
        and the averaged loss.
    """
    learning_input = training_batch.training_input
    # mdnrnn's input should have seq_len as the first dimension
    if batch_first:
        state, action, next_state, reward, not_terminal = transpose(
            learning_input.state.float_features,
            learning_input.action.float_features,  # type: ignore
            learning_input.next_state,
            learning_input.reward,
            learning_input.not_terminal,  # type: ignore
        )
        learning_input = rlt.MemoryNetworkInput(  # type: ignore
            state=rlt.FeatureVector(float_features=state),
            reward=reward,
            time_diff=torch.ones_like(reward).float(),
            action=rlt.FeatureVector(float_features=action),
            not_terminal=not_terminal,
            next_state=next_state,
        )

    mdnrnn_input = rlt.StateAction(
        state=learning_input.state, action=learning_input.action  # type: ignore
    )
    mdnrnn_output = self.mdnrnn(mdnrnn_input)
    mus, sigmas, logpi, rs, nts = (
        mdnrnn_output.mus,
        mdnrnn_output.sigmas,
        mdnrnn_output.logpi,
        mdnrnn_output.reward,
        mdnrnn_output.not_terminal,
    )

    next_state = learning_input.next_state
    not_terminal = learning_input.not_terminal  # type: ignore
    reward = learning_input.reward
    if self.params.fit_only_one_next_step:
        # Keep only the last step of each sequence, for both targets and
        # predictions.
        next_state, not_terminal, reward, mus, sigmas, logpi, nts, rs = tuple(
            map(
                lambda x: x[-1:],
                (next_state, not_terminal, reward, mus, sigmas, logpi, nts, rs),
            )
        )

    gmm = (
        gmm_loss(next_state, mus, sigmas, logpi)
        * self.params.next_state_loss_weight
    )
    bce = (
        F.binary_cross_entropy_with_logits(nts, not_terminal)
        * self.params.not_terminal_loss_weight
    )
    mse = F.mse_loss(rs, reward) * self.params.reward_loss_weight
    if state_dim is not None:
        loss = gmm / (state_dim + 2) + bce + mse
    else:
        loss = gmm + bce + mse
    return {"gmm": gmm, "bce": bce, "mse": mse, "loss": loss}
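# gmm_loss above is the negative log-likelihood of next_state under the
# predicted Gaussian mixture. A minimal sketch, following the common
# World Models reference implementation (the project's actual version may
# differ), using log-sum-exp over mixture components for numerical stability:

from torch.distributions.normal import Normal


def gmm_loss(batch, mus, sigmas, logpi, reduce=True):
    # batch:       SEQ_LEN x BATCH_SIZE x STATE_DIM target next states
    # mus, sigmas: SEQ_LEN x BATCH_SIZE x N_GAUSS x STATE_DIM mixture params
    # logpi:       SEQ_LEN x BATCH_SIZE x N_GAUSS log mixture weights
    batch = batch.unsqueeze(-2)  # broadcast targets against each component
    g_log_probs = Normal(mus, sigmas).log_prob(batch).sum(dim=-1) + logpi
    log_prob = torch.logsumexp(g_log_probs, dim=-1)
    if reduce:
        # average over both the sequence and batch dimensions, as described
        # in the docstring of get_loss
        return -torch.mean(log_prob)
    return -log_prob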