def handle(self, tdp: PreprocessedTrainingBatch) -> None:
    assert isinstance(tdp.training_input, PreprocessedMemoryNetworkInput)
    batch_size, _, _ = tdp.training_input.next_state.float_features.size()
    tdp = PreprocessedTrainingBatch(
        training_input=PreprocessedMemoryNetworkInput(
            state=tdp.training_input.state,
            action=tdp.training_input.action,  # type: ignore
            time_diff=torch.ones_like(
                tdp.training_input.reward[torch.randperm(batch_size)]
            ).float(),
            # shuffle the data
            next_state=tdp.training_input.next_state._replace(
                float_features=tdp.training_input.next_state.float_features[
                    torch.randperm(batch_size)
                ]
            ),
            reward=tdp.training_input.reward[torch.randperm(batch_size)],
            not_terminal=tdp.training_input.not_terminal[  # type: ignore
                torch.randperm(batch_size)
            ],
            step=None,
        ),
        extras=ExtraData(),
    )
    losses = self.trainer_or_evaluator.train(tdp, batch_first=True)
    self.results.append(losses)
def handle(self, tdp: TrainingBatch) -> None:
    batch_size, _, _ = tdp.training_input.next_state.size()
    tdp = TrainingBatch(
        training_input=MemoryNetworkInput(
            state=tdp.training_input.state,
            action=tdp.training_input.action,
            # shuffle the data
            next_state=tdp.training_input.next_state[torch.randperm(batch_size)],
            reward=tdp.training_input.reward[torch.randperm(batch_size)],
            not_terminal=tdp.training_input.not_terminal[torch.randperm(batch_size)],
        ),
        extras=ExtraData(),
    )
    losses = self.trainer_or_evaluator.train(tdp, batch_first=True)
    self.results.append(losses)
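# Both handle() variants above build a "shuffled" batch: next_state, reward, and
# not_terminal are permuted independently of state/action, so the losses they collect
# form a baseline that a useful memory network should beat on unshuffled data.
# A minimal standalone sketch of that idea follows; the toy tensors and the
# hypothetical compute_loss(state, target) callable stand in for the trainer API,
# which is not shown in this file.
import torch


def shuffled_baseline_loss(state, target, compute_loss):
    """Return the loss on targets permuted independently of the inputs."""
    batch_size = state.size(0)
    permuted_target = target[torch.randperm(batch_size)]  # break the state/target pairing
    return compute_loss(state, permuted_target)


# Example usage with a dummy squared-error loss:
# state = torch.randn(8, 4); target = torch.randn(8, 1)
# baseline = shuffled_baseline_loss(
#     state, target, lambda s, t: ((s.mean(dim=1, keepdim=True) - t) ** 2).mean()
# )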
def preprocess(self, batch) -> TrainingBatch:
    state_features_dense, state_features_dense_presence = self.sparse_to_dense_processor(
        batch["state_features"]
    )
    next_state_features_dense, next_state_features_dense_presence = self.sparse_to_dense_processor(
        batch["next_state_features"]
    )
    mdp_ids = np.array(batch["mdp_id"]).reshape(-1, 1)
    sequence_numbers = torch.tensor(
        batch["sequence_number"], dtype=torch.int32
    ).reshape(-1, 1)
    rewards = torch.tensor(batch["reward"], dtype=torch.float32).reshape(-1, 1)
    time_diffs = torch.tensor(batch["time_diff"], dtype=torch.int32).reshape(-1, 1)
    if "action_probability" in batch:
        propensities = torch.tensor(
            batch["action_probability"], dtype=torch.float32
        ).reshape(-1, 1)
    else:
        propensities = torch.ones(rewards.shape, dtype=torch.float32)

    return TrainingBatch(
        training_input=BaseInput(
            state=FeatureVector(
                float_features=ValuePresence(
                    value=state_features_dense,
                    presence=state_features_dense_presence,
                )
            ),
            next_state=FeatureVector(
                float_features=ValuePresence(
                    value=next_state_features_dense,
                    presence=next_state_features_dense_presence,
                )
            ),
            reward=rewards,
            time_diff=time_diffs,
        ),
        extras=ExtraData(
            mdp_id=mdp_ids,
            sequence_number=sequence_numbers,
            action_probability=propensities,
        ),
    )
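# preprocess() relies on self.sparse_to_dense_processor to turn per-row sparse feature
# maps into a dense value tensor plus a presence mask. The real processor is not shown
# here; the sketch below only illustrates that conversion under the assumption that each
# row is a {feature_id: value} dict and that a fixed feature ordering is given.
import torch


def sparse_to_dense(rows, sorted_feature_ids):
    """Return (values, presence) tensors of shape (len(rows), len(sorted_feature_ids))."""
    values = torch.zeros(len(rows), len(sorted_feature_ids))
    presence = torch.zeros(len(rows), len(sorted_feature_ids), dtype=torch.bool)
    column_of = {fid: j for j, fid in enumerate(sorted_feature_ids)}
    for i, row in enumerate(rows):
        for fid, val in row.items():
            if fid in column_of:
                values[i, column_of[fid]] = float(val)
                presence[i, column_of[fid]] = True
    return values, presence


# Example: sparse_to_dense([{1: 0.5}, {2: -1.0, 1: 3.0}], sorted_feature_ids=[1, 2])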
def evaluate(self, tdp: TrainingBatch):
    """
    Calculate feature importance: mask each state/action feature to its median
    value (or, for discrete actions, set all actions to a single action) and
    observe the loss increase.
    """
    self.trainer.mdnrnn.mdnrnn.eval()
    state_features = tdp.training_input.state.float_features
    action_features = tdp.training_input.action.float_features  # type: ignore
    batch_size, seq_len, state_dim = state_features.size()  # type: ignore
    action_dim = action_features.size()[2]  # type: ignore
    action_feature_num = self.action_feature_num
    state_feature_num = self.state_feature_num
    feature_importance = torch.zeros(action_feature_num + state_feature_num)

    orig_losses = self.trainer.get_loss(tdp, state_dim=state_dim, batch_first=True)
    orig_loss = orig_losses["loss"].cpu().detach().item()
    del orig_losses

    action_feature_boundaries = self.sorted_action_feature_start_indices + [
        action_dim
    ]
    state_feature_boundaries = self.sorted_state_feature_start_indices + [state_dim]

    for i in range(action_feature_num):
        action_features = tdp.training_input.action.float_features.reshape(  # type: ignore
            (batch_size * seq_len, action_dim)
        ).data.clone()

        # if actions are discrete, an action's feature importance is the loss
        # increase due to setting all actions to this action
        if self.discrete_action:
            assert action_dim == action_feature_num
            action_vec = torch.zeros(action_dim)
            action_vec[i] = 1
            action_features[:] = action_vec  # type: ignore
        # if actions are continuous, an action's feature importance is the loss
        # increase due to masking this action feature to its median value
        else:
            boundary_start, boundary_end = (
                action_feature_boundaries[i],
                action_feature_boundaries[i + 1],
            )
            action_features[  # type: ignore
                :, boundary_start:boundary_end
            ] = self.compute_median_feature_value(  # type: ignore
                action_features[:, boundary_start:boundary_end]  # type: ignore
            )

        action_features = action_features.reshape(  # type: ignore
            (batch_size, seq_len, action_dim)
        )
        new_tdp = TrainingBatch(
            training_input=MemoryNetworkInput(  # type: ignore
                state=tdp.training_input.state,
                action=FeatureVector(float_features=action_features),  # type: ignore
                next_state=tdp.training_input.next_state,
                reward=tdp.training_input.reward,
                not_terminal=tdp.training_input.not_terminal,
            ),
            extras=ExtraData(),
        )
        losses = self.trainer.get_loss(new_tdp, state_dim=state_dim, batch_first=True)
        feature_importance[i] = losses["loss"].cpu().detach().item() - orig_loss
        del losses

    for i in range(state_feature_num):
        state_features = tdp.training_input.state.float_features.reshape(  # type: ignore
            (batch_size * seq_len, state_dim)
        ).data.clone()
        boundary_start, boundary_end = (
            state_feature_boundaries[i],
            state_feature_boundaries[i + 1],
        )
        state_features[  # type: ignore
            :, boundary_start:boundary_end
        ] = self.compute_median_feature_value(
            state_features[:, boundary_start:boundary_end]  # type: ignore
        )
        state_features = state_features.reshape(  # type: ignore
            (batch_size, seq_len, state_dim)
        )
        new_tdp = TrainingBatch(
            training_input=MemoryNetworkInput(  # type: ignore
                state=FeatureVector(float_features=state_features),  # type: ignore
                action=tdp.training_input.action,
                next_state=tdp.training_input.next_state,
                reward=tdp.training_input.reward,
                not_terminal=tdp.training_input.not_terminal,
            ),
            extras=ExtraData(),
        )
        losses = self.trainer.get_loss(new_tdp, state_dim=state_dim, batch_first=True)
        feature_importance[i + action_feature_num] = (
            losses["loss"].cpu().detach().item() - orig_loss
        )
        del losses

    self.trainer.mdnrnn.mdnrnn.train()
    logger.info(
        "**** Debug tool feature importance ****: {}".format(feature_importance)
    )
    return {"feature_loss_increase": feature_importance.numpy()}
def evaluate(self, tdp: TrainingBatch):
    """
    Calculate feature importance: setting each state/action feature to the mean
    value and observe loss increase.
    """
    self.trainer.mdnrnn.mdnrnn.eval()
    state_features = tdp.training_input.state.float_features
    action_features = tdp.training_input.action.float_features
    batch_size, seq_len, state_dim = state_features.size()
    action_dim = action_features.size()[2]
    action_feature_num = self.feature_extractor.action_feature_num
    state_feature_num = self.feature_extractor.state_feature_num
    feature_importance = torch.zeros(action_feature_num + state_feature_num)

    orig_losses = self.trainer.get_loss(tdp, state_dim=state_dim, batch_first=True)
    orig_loss = orig_losses["loss"].cpu().detach().item()
    del orig_losses

    action_feature_boundaries = (
        self.feature_extractor.sorted_action_feature_start_indices + [action_dim]
    )
    state_feature_boundaries = (
        self.feature_extractor.sorted_state_feature_start_indices + [state_dim]
    )

    for i in range(action_feature_num):
        action_features = tdp.training_input.action.float_features.reshape(
            (batch_size * seq_len, action_dim)
        ).data.clone()
        boundary_start, boundary_end = (
            action_feature_boundaries[i],
            action_feature_boundaries[i + 1],
        )
        action_features[:, boundary_start:boundary_end] = action_features[
            :, boundary_start:boundary_end
        ].mean(dim=0)
        action_features = action_features.reshape((batch_size, seq_len, action_dim))
        new_tdp = TrainingBatch(
            training_input=MemoryNetworkInput(
                state=tdp.training_input.state,
                action=FeatureVector(float_features=action_features),
                next_state=tdp.training_input.next_state,
                reward=tdp.training_input.reward,
                not_terminal=tdp.training_input.not_terminal,
            ),
            extras=ExtraData(),
        )
        losses = self.trainer.get_loss(new_tdp, state_dim=state_dim, batch_first=True)
        feature_importance[i] = losses["loss"].cpu().detach().item() - orig_loss
        del losses

    for i in range(state_feature_num):
        state_features = tdp.training_input.state.float_features.reshape(
            (batch_size * seq_len, state_dim)
        ).data.clone()
        boundary_start, boundary_end = (
            state_feature_boundaries[i],
            state_feature_boundaries[i + 1],
        )
        state_features[:, boundary_start:boundary_end] = state_features[
            :, boundary_start:boundary_end
        ].mean(dim=0)
        state_features = state_features.reshape((batch_size, seq_len, state_dim))
        new_tdp = TrainingBatch(
            training_input=MemoryNetworkInput(
                state=FeatureVector(float_features=state_features),
                action=tdp.training_input.action,
                next_state=tdp.training_input.next_state,
                reward=tdp.training_input.reward,
                not_terminal=tdp.training_input.not_terminal,
            ),
            extras=ExtraData(),
        )
        losses = self.trainer.get_loss(new_tdp, state_dim=state_dim, batch_first=True)
        feature_importance[i + action_feature_num] = (
            losses["loss"].cpu().detach().item() - orig_loss
        )
        del losses

    self.trainer.mdnrnn.mdnrnn.train()
    logger.info(
        "**** Debug tool feature importance ****: {}".format(feature_importance)
    )
    return {"feature_loss_increase": feature_importance.numpy()}
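# Both evaluate() variants above score a feature by how much the world-model loss grows
# after masking that feature's column block to a batch statistic (mean or median). A
# self-contained sketch of the same loss-increase recipe follows; the hypothetical
# loss_fn(features) callable and the toy 2-D shapes stand in for trainer.get_loss and
# the real (batch, seq_len, dim) inputs.
import torch


def mean_mask_importance(features, feature_boundaries, loss_fn):
    """Return the loss increase per feature when each column block is set to its mean."""
    base_loss = loss_fn(features).item()
    importance = torch.zeros(len(feature_boundaries) - 1)
    for i in range(len(feature_boundaries) - 1):
        start, end = feature_boundaries[i], feature_boundaries[i + 1]
        masked = features.clone()
        masked[:, start:end] = masked[:, start:end].mean(dim=0)  # mask block to batch mean
        importance[i] = loss_fn(masked).item() - base_loss
    return importance


# Example with a dummy loss that penalizes deviation from the original features:
# x = torch.randn(16, 3)
# print(mean_mask_importance(x, [0, 1, 2, 3], lambda f: ((f - x) ** 2).mean()))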