def as_cem_training_batch(self, batch_first=False):
    """
    Generate one-step samples needed by the CEM trainer. The samples will be
    used to train an ensemble of world models used by CEM.

    If batch_first = True:
        state/next_state shape: batch_size x 1 x state_dim
        action shape: batch_size x 1 x action_dim
        reward/terminal shape: batch_size x 1
    else (default):
        state/next_state shape: 1 x batch_size x state_dim
        action shape: 1 x batch_size x action_dim
        reward/terminal shape: 1 x batch_size
    """
    if batch_first:
        seq_len_dim = 1
        reward, not_terminal = self.rewards, self.not_terminal
    else:
        seq_len_dim = 0
        reward, not_terminal = transpose(self.rewards, self.not_terminal)
    training_input = rlt.PreprocessedMemoryNetworkInput(
        state=rlt.PreprocessedFeatureVector(
            float_features=self.states.unsqueeze(seq_len_dim)
        ),
        action=self.actions.unsqueeze(seq_len_dim),
        next_state=rlt.PreprocessedFeatureVector(
            float_features=self.next_states.unsqueeze(seq_len_dim)
        ),
        reward=reward,
        not_terminal=not_terminal,
        step=self.step,
        time_diff=self.time_diffs,
    )
    return rlt.PreprocessedTrainingBatch(
        training_input=training_input,
        extras=rlt.ExtraData(
            mdp_id=self.mdp_ids,
            sequence_number=self.sequence_numbers,
            action_probability=self.propensities,
            max_num_actions=self.max_num_actions,
            metrics=self.metrics,
        ),
    )
def as_cem_training_batch(self):
    """
    Generate one-step samples needed by the CEM trainer. The samples will be
    used to train an ensemble of world models used by CEM.

    state/next_state shape: 1 x batch_size x state_dim
    action shape: 1 x batch_size x action_dim
    reward/terminal shape: 1 x batch_size
    """
    seq_len_dim = 0
    reward, not_terminal = transpose(self.rewards, self.not_terminal)
    return rlt.PreprocessedMemoryNetworkInput(
        state=rlt.FeatureData(self.states.unsqueeze(seq_len_dim)),
        action=self.actions.unsqueeze(seq_len_dim),
        next_state=rlt.FeatureData(
            float_features=self.next_states.unsqueeze(seq_len_dim)
        ),
        reward=reward,
        not_terminal=not_terminal,
        step=self.step,
        time_diff=self.time_diffs,
    )
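# --- Hedged usage sketch (illustration only, not part of the classes above) ---
# Shows, on plain tensors, the shape convention the CEM batch builders produce:
# the world model ensemble consumes length-1 sequences with seq_len as the
# leading dimension, so flat one-step tensors of shape (batch_size, dim) become
# (1, batch_size, dim). Function and variable names below are made up for the
# example; the real builders wrap these tensors in rlt.* dataclasses.
import torch


def to_seq_first_one_step(states, actions, next_states, rewards, not_terminal):
    """Reshape flat one-step tensors into seq-len-first, length-1 sequences."""
    seq_len_dim = 0
    return (
        states.unsqueeze(seq_len_dim),        # 1 x batch_size x state_dim
        actions.unsqueeze(seq_len_dim),       # 1 x batch_size x action_dim
        next_states.unsqueeze(seq_len_dim),   # 1 x batch_size x state_dim
        rewards.unsqueeze(seq_len_dim),       # 1 x batch_size
        not_terminal.unsqueeze(seq_len_dim),  # 1 x batch_size
    )


if __name__ == "__main__":
    batch_size, state_dim, action_dim = 4, 3, 2
    s, a, ns, r, nt = to_seq_first_one_step(
        torch.randn(batch_size, state_dim),
        torch.randn(batch_size, action_dim),
        torch.randn(batch_size, state_dim),
        torch.randn(batch_size),
        torch.ones(batch_size),
    )
    assert s.shape == (1, batch_size, state_dim)
    assert r.shape == (1, batch_size)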
def evaluate(self, tdp: PreprocessedTrainingBatch):
    """
    Calculate state feature sensitivity due to actions: randomly permute the
    actions and measure how much the prediction of the next state features
    deviates.
    """
    mdnrnn_training_input = tdp.training_input
    assert isinstance(mdnrnn_training_input, PreprocessedMemoryNetworkInput)

    self.trainer.mdnrnn.mdnrnn.eval()

    batch_size, seq_len, state_dim = (
        mdnrnn_training_input.next_state.float_features.size()
    )
    state_feature_num = self.state_feature_num
    feature_sensitivity = torch.zeros(state_feature_num)

    state, action, next_state, reward, not_terminal = transpose(
        mdnrnn_training_input.state.float_features,
        mdnrnn_training_input.action,
        mdnrnn_training_input.next_state.float_features,
        mdnrnn_training_input.reward,
        mdnrnn_training_input.not_terminal,
    )
    mdnrnn_input = PreprocessedStateAction(
        state=PreprocessedFeatureVector(float_features=state),
        action=PreprocessedFeatureVector(float_features=action),
    )
    # the input of mdnrnn has seq_len as the first dimension
    mdnrnn_output = self.trainer.mdnrnn(mdnrnn_input)
    predicted_next_state_means = mdnrnn_output.mus

    shuffled_mdnrnn_input = PreprocessedStateAction(
        state=PreprocessedFeatureVector(float_features=state),
        # shuffle the actions along the batch dimension
        action=PreprocessedFeatureVector(
            float_features=action[:, torch.randperm(batch_size), :]
        ),
    )
    shuffled_mdnrnn_output = self.trainer.mdnrnn(shuffled_mdnrnn_input)
    shuffled_predicted_next_state_means = shuffled_mdnrnn_output.mus

    assert (
        predicted_next_state_means.size()
        == shuffled_predicted_next_state_means.size()
        == (seq_len, batch_size, self.trainer.params.num_gaussians, state_dim)
    )

    state_feature_boundaries = self.sorted_state_feature_start_indices + [state_dim]
    for i in range(state_feature_num):
        boundary_start, boundary_end = (
            state_feature_boundaries[i],
            state_feature_boundaries[i + 1],
        )
        abs_diff = torch.mean(
            torch.sum(
                torch.abs(
                    shuffled_predicted_next_state_means[
                        :, :, :, boundary_start:boundary_end
                    ]
                    - predicted_next_state_means[
                        :, :, :, boundary_start:boundary_end
                    ]
                ),
                dim=3,
            )
        )
        feature_sensitivity[i] = abs_diff.cpu().detach().item()

    self.trainer.mdnrnn.mdnrnn.train()
    logger.info(
        "**** Debug tool feature sensitivity ****: {}".format(feature_sensitivity)
    )
    return {"feature_sensitivity": feature_sensitivity.numpy()}
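# --- Hedged sketch (illustration only, not the evaluator above) ---
# The core of the permutation test on plain tensors: a state feature's
# sensitivity to actions is the mean absolute change in the model's next-state
# prediction when the actions are shuffled across the batch, so features the
# model ignores score near zero. `toy_model` and all names here are made up;
# the real evaluator runs the MDN-RNN and aggregates per mixture component.
import torch


def action_permutation_sensitivity(model, state, action):
    """state: (seq_len, batch, state_dim); action: (seq_len, batch, action_dim)."""
    batch_size = state.shape[1]
    with torch.no_grad():
        predicted = model(state, action)
        shuffled = model(state, action[:, torch.randperm(batch_size), :])
    # mean absolute deviation per state feature, averaged over seq_len and batch
    return torch.abs(shuffled - predicted).mean(dim=(0, 1))


if __name__ == "__main__":
    seq_len, batch, state_dim, action_dim = 1, 8, 3, 2

    def toy_model(s, a):
        # toy dynamics: the action only affects state feature 0
        out = s.clone()
        out[..., 0] = out[..., 0] + a.sum(dim=-1)
        return out

    sens = action_permutation_sensitivity(
        toy_model,
        torch.randn(seq_len, batch, state_dim),
        torch.randn(seq_len, batch, action_dim),
    )
    print(sens)  # feature 0 should dominate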
def get_loss(
    self,
    training_batch: rlt.PreprocessedTrainingBatch,
    state_dim: Optional[int] = None,
    batch_first: bool = False,
):
    """
    Compute losses:
        GMMLoss(next_state, GMMPredicted) / (STATE_DIM + 2)
        + MSE(reward, predicted_reward)
        + BCE(not_terminal, logit_not_terminal)

    The STATE_DIM + 2 factor counteracts the fact that the GMM loss scales
    approximately linearly with STATE_DIM, the feature size of states. All
    losses are averaged over both the batch and the sequence dimensions (the
    first two dimensions).

    :param training_batch: training_batch.training_input has these fields:
        - state: (BATCH_SIZE, SEQ_LEN, STATE_DIM) torch tensor
        - action: (BATCH_SIZE, SEQ_LEN, ACTION_DIM) torch tensor
        - reward: (BATCH_SIZE, SEQ_LEN) torch tensor
        - not_terminal: (BATCH_SIZE, SEQ_LEN) torch tensor
        - next_state: (BATCH_SIZE, SEQ_LEN, STATE_DIM) torch tensor
        The first two dimensions may be swapped depending on batch_first.
    :param state_dim: the dimension of states. If provided, use it to
        normalize the GMM loss.
    :param batch_first: whether the data's first dimension is the batch size.
        If False, the first two dimensions of state, action, reward,
        not_terminal, and next_state are SEQ_LEN and BATCH_SIZE.

    :returns: dictionary of losses, containing the gmm, the mse, the bce and
        the averaged loss.
    """
    learning_input = training_batch.training_input
    assert isinstance(learning_input, rlt.PreprocessedMemoryNetworkInput)
    # mdnrnn's input should have seq_len as the first dimension
    if batch_first:
        state, action, next_state, reward, not_terminal = transpose(
            learning_input.state.float_features,
            learning_input.action,
            learning_input.next_state.float_features,
            learning_input.reward,
            learning_input.not_terminal,  # type: ignore
        )
        learning_input = rlt.PreprocessedMemoryNetworkInput(  # type: ignore
            state=rlt.PreprocessedFeatureVector(float_features=state),
            reward=reward,
            time_diff=torch.ones_like(reward).float(),
            action=action,
            not_terminal=not_terminal,
            next_state=rlt.PreprocessedFeatureVector(float_features=next_state),
            step=None,
        )

    mdnrnn_input = rlt.PreprocessedStateAction(
        state=learning_input.state,  # type: ignore
        action=rlt.PreprocessedFeatureVector(
            float_features=learning_input.action  # type: ignore
        ),
    )
    mdnrnn_output = self.mdnrnn(mdnrnn_input)
    mus, sigmas, logpi, rs, nts = (
        mdnrnn_output.mus,
        mdnrnn_output.sigmas,
        mdnrnn_output.logpi,
        mdnrnn_output.reward,
        mdnrnn_output.not_terminal,
    )

    next_state = learning_input.next_state.float_features
    not_terminal = learning_input.not_terminal  # type: ignore
    reward = learning_input.reward
    if self.params.fit_only_one_next_step:
        next_state, not_terminal, reward, mus, sigmas, logpi, nts, rs = tuple(
            map(
                lambda x: x[-1:],
                (next_state, not_terminal, reward, mus, sigmas, logpi, nts, rs),
            )
        )

    gmm = (
        gmm_loss(next_state, mus, sigmas, logpi)
        * self.params.next_state_loss_weight
    )
    bce = (
        F.binary_cross_entropy_with_logits(nts, not_terminal)
        * self.params.not_terminal_loss_weight
    )
    mse = F.mse_loss(rs, reward) * self.params.reward_loss_weight

    if state_dim is not None:
        loss = gmm / (state_dim + 2) + bce + mse
    else:
        loss = gmm + bce + mse

    return {"gmm": gmm, "bce": bce, "mse": mse, "loss": loss}
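# --- Hedged sketch (illustration only, not the library's gmm_loss) ---
# One common way to compute the negative log-likelihood of targets under a
# mixture of diagonal Gaussians, which is the role gmm_loss plays in the total
# loss above. Shapes follow the trainer's convention: targets are
# (seq_len, batch, state_dim), mus/sigmas are
# (seq_len, batch, num_gaussians, state_dim), and logpi is
# (seq_len, batch, num_gaussians), assumed already log-normalized. The function
# name and exact reduction are assumptions for this example.
import torch
from torch.distributions import Normal


def gmm_nll(targets, mus, sigmas, logpi):
    targets = targets.unsqueeze(-2)  # broadcast over the mixture dimension
    # per-component log density, summed over independent state features
    log_prob = Normal(mus, sigmas).log_prob(targets).sum(dim=-1)
    # log-sum-exp over mixture components, weighted by logpi
    log_mix = torch.logsumexp(logpi + log_prob, dim=-1)
    return -log_mix.mean()  # average over seq_len and batch


if __name__ == "__main__":
    seq_len, batch, k, state_dim = 1, 4, 3, 5
    targets = torch.randn(seq_len, batch, state_dim)
    mus = torch.randn(seq_len, batch, k, state_dim)
    sigmas = torch.rand(seq_len, batch, k, state_dim) + 0.1
    logpi = torch.log_softmax(torch.randn(seq_len, batch, k), dim=-1)
    print(gmm_nll(targets, mus, sigmas, logpi))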