def handle(self, tdp: TrainingBatch) -> None:
    batch_size, _, _ = tdp.training_input.next_state.size()
    tdp = TrainingBatch(
        training_input=MemoryNetworkInput(
            state=tdp.training_input.state,
            action=tdp.training_input.action,
            # shuffle the data
            next_state=tdp.training_input.next_state[torch.randperm(batch_size)],
            reward=tdp.training_input.reward[torch.randperm(batch_size)],
            not_terminal=tdp.training_input.not_terminal[torch.randperm(batch_size)],
        ),
        extras=ExtraData(),
    )
    losses = self.trainer_or_evaluator.train(tdp, batch_first=True)
    self.results.append(losses)

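# A minimal sketch of the shuffling step above, assuming the intent is to break
# the pairing between (state, action) and (next_state, reward, not_terminal):
# each torch.randperm call draws an independent permutation, so the shuffled
# fields are decoupled from the inputs and from each other. The tensors below
# are toy values for illustration only, not from this codebase.
import torch

batch_size = 4
next_state = torch.arange(float(batch_size)).unsqueeze(1)  # shape (4, 1)
reward = torch.arange(float(batch_size))                   # shape (4,)

# Independent permutations along the batch dimension.
shuffled_next_state = next_state[torch.randperm(batch_size)]
shuffled_reward = reward[torch.randperm(batch_size)]
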
def __call__(self, batch: TrainingBatch) -> TrainingBatch:
    training_input = cast(
        Union[DiscreteDqnInput, ParametricDqnInput], batch.training_input
    )
    preprocessed_state = self.state_preprocessor(
        training_input.state.float_features.value,
        training_input.state.float_features.presence,
    )
    preprocessed_next_state = self.state_preprocessor(
        training_input.next_state.float_features.value,
        training_input.next_state.float_features.presence,
    )
    new_training_input = training_input._replace(
        state=training_input.state._replace(float_features=preprocessed_state),
        next_state=training_input.next_state._replace(
            float_features=preprocessed_next_state
        ),
    )
    return batch._replace(training_input=new_training_input)

def preprocess(self, batch) -> TrainingBatch:
    state_features_dense, state_features_dense_presence = self.sparse_to_dense_processor(
        batch["state_features"]
    )
    next_state_features_dense, next_state_features_dense_presence = self.sparse_to_dense_processor(
        batch["next_state_features"]
    )
    mdp_ids = np.array(batch["mdp_id"]).reshape(-1, 1)
    sequence_numbers = torch.tensor(
        batch["sequence_number"], dtype=torch.int32
    ).reshape(-1, 1)
    rewards = torch.tensor(batch["reward"], dtype=torch.float32).reshape(-1, 1)
    time_diffs = torch.tensor(batch["time_diff"], dtype=torch.int32).reshape(-1, 1)
    if "action_probability" in batch:
        propensities = torch.tensor(
            batch["action_probability"], dtype=torch.float32
        ).reshape(-1, 1)
    else:
        propensities = torch.ones(rewards.shape, dtype=torch.float32)

    return TrainingBatch(
        training_input=BaseInput(
            state=FeatureVector(
                float_features=ValuePresence(
                    value=state_features_dense,
                    presence=state_features_dense_presence,
                )
            ),
            next_state=FeatureVector(
                float_features=ValuePresence(
                    value=next_state_features_dense,
                    presence=next_state_features_dense_presence,
                )
            ),
            reward=rewards,
            time_diff=time_diffs,
        ),
        extras=ExtraData(
            mdp_id=mdp_ids,
            sequence_number=sequence_numbers,
            action_probability=propensities,
        ),
    )

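# A minimal sketch, assuming the sparse-to-dense step maps a list of
# {feature_id: value} dicts onto a fixed feature ordering and returns a dense
# value tensor plus a presence mask (the ValuePresence convention used above).
# This is an illustration of the idea, not the actual sparse_to_dense_processor
# implementation; the helper name and dtypes are assumptions.
import torch

def sparse_to_dense_sketch(rows, sorted_feature_ids):
    # rows: list of dicts mapping feature id -> float value
    value = torch.zeros(len(rows), len(sorted_feature_ids))
    presence = torch.zeros(len(rows), len(sorted_feature_ids), dtype=torch.bool)
    col = {fid: j for j, fid in enumerate(sorted_feature_ids)}
    for i, row in enumerate(rows):
        for fid, v in row.items():
            value[i, col[fid]] = v
            presence[i, col[fid]] = True
    return value, presence

# Example: two rows over features 10 and 20; feature 20 is missing in row 0.
values, presence = sparse_to_dense_sketch([{10: 1.5}, {10: 0.0, 20: 2.0}], [10, 20])
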
def __call__(self, batch: TrainingBatch) -> TrainingBatch:
    batch = super().__call__(batch)
    training_input = cast(PolicyNetworkInput, batch.training_input)
    action_before_preprocessing = cast(FeatureVector, training_input.action)
    preprocessed_action = self.action_preprocessor(
        action_before_preprocessing.float_features.value,
        action_before_preprocessing.float_features.presence,
    )
    next_action_before_preprocessing = cast(FeatureVector, training_input.next_action)
    preprocessed_next_action = self.action_preprocessor(
        next_action_before_preprocessing.float_features.value,
        next_action_before_preprocessing.float_features.presence,
    )
    return batch._replace(
        training_input=training_input._replace(
            action=action_before_preprocessing._replace(
                float_features=preprocessed_action
            ),
            next_action=next_action_before_preprocessing._replace(
                float_features=preprocessed_next_action
            ),
        )
    )

def train(self, training_batch: rlt.TrainingBatch) -> None:
    if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
        training_batch = training_batch.as_parametric_sarsa_training_batch()

    learning_input = training_batch.training_input
    self.minibatch += 1

    state = learning_input.state
    # As far as DDPG is concerned, all actions lie in [-1, 1] due to the actor's
    # tanh output, so rescale logged actions from the serving range to the
    # training range.
    action = rlt.FeatureVector(
        rescale_torch_tensor(
            learning_input.action.float_features,
            new_min=self.min_action_range_tensor_training,
            new_max=self.max_action_range_tensor_training,
            prev_min=self.min_action_range_tensor_serving,
            prev_max=self.max_action_range_tensor_serving,
        )
    )
    rewards = learning_input.reward
    next_state = learning_input.next_state
    time_diffs = learning_input.time_diff
    discount_tensor = torch.full_like(rewards, self.gamma)
    not_done_mask = learning_input.not_terminal

    # Optimize the critic network subject to mean squared error:
    # L = ([r + gamma * Q(s2, a2)] - Q(s1, a1)) ^ 2
    q_s1_a1 = self.critic.forward(rlt.StateAction(state=state, action=action)).q_value
    next_action = rlt.FeatureVector(
        float_features=self.actor_target(
            rlt.StateAction(state=next_state, action=None)
        ).action
    )

    q_s2_a2 = self.critic_target.forward(
        rlt.StateAction(state=next_state, action=next_action)
    ).q_value
    filtered_q_s2_a2 = not_done_mask.float() * q_s2_a2

    if self.use_seq_num_diff_as_time_diff:
        discount_tensor = discount_tensor.pow(time_diffs)

    target_q_values = rewards + (discount_tensor * filtered_q_s2_a2)

    # Compute the loss and update the critic network
    critic_predictions = q_s1_a1
    loss_critic = self.q_network_loss(critic_predictions, target_q_values.detach())
    loss_critic_for_eval = loss_critic.detach()
    self.critic_optimizer.zero_grad()
    loss_critic.backward()
    self.critic_optimizer.step()

    # Optimize the actor network subject to the following:
    # max mean(Q(s1, a1)) or min -mean(Q(s1, a1))
    actor_output = self.actor(rlt.StateAction(state=state, action=None))
    loss_actor = -(
        self.critic.forward(
            rlt.StateAction(
                state=state,
                action=rlt.FeatureVector(float_features=actor_output.action),
            )
        ).q_value.mean()
    )

    # Zero out the actor gradients before backprop. The backward pass also goes
    # through the critic (that is how gradients reach the actor); the critic's
    # gradients are cleared again by critic_optimizer.zero_grad() at the start
    # of its next update.
    self.actor_optimizer.zero_grad()
    loss_actor.backward()
    self.actor_optimizer.step()

    # Use the soft update rule to update both target networks
    self._soft_update(self.actor, self.actor_target, self.tau)
    self._soft_update(self.critic, self.critic_target, self.tau)

    self.loss_reporter.report(
        td_loss=float(loss_critic_for_eval),
        reward_loss=None,
        model_values_on_logged_actions=critic_predictions,
    )

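# Hedged sketches of the two helpers used above. rescale_torch_tensor
# presumably performs an element-wise min-max rescale from the serving action
# range to the training range, and _soft_update presumably performs a Polyak
# (soft) target-network update with rate tau. These are illustrative sketches
# under those assumptions, not the actual implementations from this codebase.
import torch
import torch.nn as nn


def rescale_torch_tensor_sketch(x, new_min, new_max, prev_min, prev_max):
    # Linearly map x from [prev_min, prev_max] to [new_min, new_max].
    prev_range = prev_max - prev_min
    new_range = new_max - new_min
    return (x - prev_min) / prev_range * new_range + new_min


def soft_update_sketch(network: nn.Module, target_network: nn.Module, tau: float):
    # theta_target <- tau * theta + (1 - tau) * theta_target
    for param, target_param in zip(network.parameters(), target_network.parameters()):
        target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)
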
def evaluate(self, tdp: TrainingBatch):
    """Calculate feature importance: mask each state/action feature to a
    reference value (its median for continuous features, or a fixed one-hot
    action in the discrete case) and observe the resulting loss increase."""
    self.trainer.mdnrnn.mdnrnn.eval()

    state_features = tdp.training_input.state.float_features
    action_features = tdp.training_input.action.float_features  # type: ignore
    batch_size, seq_len, state_dim = state_features.size()  # type: ignore
    action_dim = action_features.size()[2]  # type: ignore
    action_feature_num = self.action_feature_num
    state_feature_num = self.state_feature_num
    feature_importance = torch.zeros(action_feature_num + state_feature_num)

    orig_losses = self.trainer.get_loss(tdp, state_dim=state_dim, batch_first=True)
    orig_loss = orig_losses["loss"].cpu().detach().item()
    del orig_losses

    action_feature_boundaries = self.sorted_action_feature_start_indices + [action_dim]
    state_feature_boundaries = self.sorted_state_feature_start_indices + [state_dim]

    for i in range(action_feature_num):
        action_features = tdp.training_input.action.float_features.reshape(  # type: ignore
            (batch_size * seq_len, action_dim)
        ).data.clone()

        # If actions are discrete, an action's feature importance is the loss
        # increase due to setting all actions to this action.
        if self.discrete_action:
            assert action_dim == action_feature_num
            action_vec = torch.zeros(action_dim)
            action_vec[i] = 1
            action_features[:] = action_vec  # type: ignore
        # If actions are continuous, an action feature's importance is the loss
        # increase due to masking that feature to its median value.
        else:
            boundary_start, boundary_end = (
                action_feature_boundaries[i],
                action_feature_boundaries[i + 1],
            )
            action_features[  # type: ignore
                :, boundary_start:boundary_end
            ] = self.compute_median_feature_value(  # type: ignore
                action_features[:, boundary_start:boundary_end]  # type: ignore
            )

        action_features = action_features.reshape(  # type: ignore
            (batch_size, seq_len, action_dim)
        )
        new_tdp = TrainingBatch(
            training_input=MemoryNetworkInput(  # type: ignore
                state=tdp.training_input.state,
                action=FeatureVector(float_features=action_features),  # type: ignore
                next_state=tdp.training_input.next_state,
                reward=tdp.training_input.reward,
                not_terminal=tdp.training_input.not_terminal,
            ),
            extras=ExtraData(),
        )
        losses = self.trainer.get_loss(new_tdp, state_dim=state_dim, batch_first=True)
        feature_importance[i] = losses["loss"].cpu().detach().item() - orig_loss
        del losses

    for i in range(state_feature_num):
        state_features = tdp.training_input.state.float_features.reshape(  # type: ignore
            (batch_size * seq_len, state_dim)
        ).data.clone()
        boundary_start, boundary_end = (
            state_feature_boundaries[i],
            state_feature_boundaries[i + 1],
        )
        state_features[  # type: ignore
            :, boundary_start:boundary_end
        ] = self.compute_median_feature_value(
            state_features[:, boundary_start:boundary_end]  # type: ignore
        )
        state_features = state_features.reshape(  # type: ignore
            (batch_size, seq_len, state_dim)
        )
        new_tdp = TrainingBatch(
            training_input=MemoryNetworkInput(  # type: ignore
                state=FeatureVector(float_features=state_features),  # type: ignore
                action=tdp.training_input.action,
                next_state=tdp.training_input.next_state,
                reward=tdp.training_input.reward,
                not_terminal=tdp.training_input.not_terminal,
            ),
            extras=ExtraData(),
        )
        losses = self.trainer.get_loss(new_tdp, state_dim=state_dim, batch_first=True)
        feature_importance[i + action_feature_num] = (
            losses["loss"].cpu().detach().item() - orig_loss
        )
        del losses

    self.trainer.mdnrnn.mdnrnn.train()
    logger.info(
        "**** Debug tool feature importance ****: {}".format(feature_importance)
    )
    return {"feature_loss_increase": feature_importance.numpy()}

def __call__(self, batch: TrainingBatch) -> TrainingBatch:
    batch = super().__call__(batch)
    if isinstance(batch.training_input, ParametricDqnInput):
        training_input = cast(ParametricDqnInput, batch.training_input)
        preprocessed_tiled_next_state = self.state_preprocessor(
            training_input.tiled_next_state.float_features.value,
            training_input.tiled_next_state.float_features.presence,
        )
        preprocessed_action = self.action_preprocessor(
            training_input.action.float_features.value,
            training_input.action.float_features.presence,
        )
        preprocessed_next_action = self.action_preprocessor(
            training_input.next_action.float_features.value,
            training_input.next_action.float_features.presence,
        )
        preprocessed_possible_actions = self.action_preprocessor(
            training_input.possible_actions.float_features.value,
            training_input.possible_actions.float_features.presence,
        )
        preprocessed_possible_next_actions = self.action_preprocessor(
            training_input.possible_next_actions.float_features.value,
            training_input.possible_next_actions.float_features.presence,
        )
        return batch._replace(
            training_input=training_input._replace(
                action=training_input.action._replace(
                    float_features=preprocessed_action
                ),
                next_action=training_input.next_action._replace(
                    float_features=preprocessed_next_action
                ),
                possible_actions=training_input.possible_actions._replace(
                    float_features=preprocessed_possible_actions
                ),
                possible_next_actions=training_input.possible_next_actions._replace(
                    float_features=preprocessed_possible_next_actions
                ),
                tiled_next_state=training_input.tiled_next_state._replace(
                    float_features=preprocessed_tiled_next_state
                ),
            )
        )
    elif isinstance(batch.training_input, SARSAInput):
        training_input_sarsa = cast(SARSAInput, batch.training_input)
        preprocessed_tiled_next_state = self.state_preprocessor(
            training_input_sarsa.tiled_next_state.float_features.value,  # type: ignore
            training_input_sarsa.tiled_next_state.float_features.presence,  # type: ignore
        )
        preprocessed_action = self.action_preprocessor(
            training_input_sarsa.action.float_features.value,  # type: ignore
            training_input_sarsa.action.float_features.presence,  # type: ignore
        )
        preprocessed_next_action = self.action_preprocessor(
            training_input_sarsa.next_action.float_features.value,  # type: ignore
            training_input_sarsa.next_action.float_features.presence,  # type: ignore
        )
        return batch._replace(
            training_input=training_input_sarsa._replace(
                action=training_input_sarsa.action._replace(  # type: ignore
                    float_features=preprocessed_action
                ),
                next_action=training_input_sarsa.next_action._replace(  # type: ignore
                    float_features=preprocessed_next_action
                ),
                tiled_next_state=training_input_sarsa.tiled_next_state._replace(  # type: ignore
                    float_features=preprocessed_tiled_next_state
                ),
            )
        )
    else:
        assert False, "Invalid training_input type: " + str(type(batch.training_input))

def evaluate(self, tdp: TrainingBatch):
    """Calculate feature importance: set each state/action feature to its mean
    value and observe the resulting loss increase."""
    self.trainer.mdnrnn.mdnrnn.eval()

    state_features = tdp.training_input.state.float_features
    action_features = tdp.training_input.action.float_features
    batch_size, seq_len, state_dim = state_features.size()
    action_dim = action_features.size()[2]
    action_feature_num = self.feature_extractor.action_feature_num
    state_feature_num = self.feature_extractor.state_feature_num
    feature_importance = torch.zeros(action_feature_num + state_feature_num)

    orig_losses = self.trainer.get_loss(tdp, state_dim=state_dim, batch_first=True)
    orig_loss = orig_losses["loss"].cpu().detach().item()
    del orig_losses

    action_feature_boundaries = (
        self.feature_extractor.sorted_action_feature_start_indices + [action_dim]
    )
    state_feature_boundaries = (
        self.feature_extractor.sorted_state_feature_start_indices + [state_dim]
    )

    for i in range(action_feature_num):
        action_features = tdp.training_input.action.float_features.reshape(
            (batch_size * seq_len, action_dim)
        ).data.clone()
        boundary_start, boundary_end = (
            action_feature_boundaries[i],
            action_feature_boundaries[i + 1],
        )
        action_features[:, boundary_start:boundary_end] = action_features[
            :, boundary_start:boundary_end
        ].mean(dim=0)
        action_features = action_features.reshape((batch_size, seq_len, action_dim))
        new_tdp = TrainingBatch(
            training_input=MemoryNetworkInput(
                state=tdp.training_input.state,
                action=FeatureVector(float_features=action_features),
                next_state=tdp.training_input.next_state,
                reward=tdp.training_input.reward,
                not_terminal=tdp.training_input.not_terminal,
            ),
            extras=ExtraData(),
        )
        losses = self.trainer.get_loss(new_tdp, state_dim=state_dim, batch_first=True)
        feature_importance[i] = losses["loss"].cpu().detach().item() - orig_loss
        del losses

    for i in range(state_feature_num):
        state_features = tdp.training_input.state.float_features.reshape(
            (batch_size * seq_len, state_dim)
        ).data.clone()
        boundary_start, boundary_end = (
            state_feature_boundaries[i],
            state_feature_boundaries[i + 1],
        )
        state_features[:, boundary_start:boundary_end] = state_features[
            :, boundary_start:boundary_end
        ].mean(dim=0)
        state_features = state_features.reshape((batch_size, seq_len, state_dim))
        new_tdp = TrainingBatch(
            training_input=MemoryNetworkInput(
                state=FeatureVector(float_features=state_features),
                action=tdp.training_input.action,
                next_state=tdp.training_input.next_state,
                reward=tdp.training_input.reward,
                not_terminal=tdp.training_input.not_terminal,
            ),
            extras=ExtraData(),
        )
        losses = self.trainer.get_loss(new_tdp, state_dim=state_dim, batch_first=True)
        feature_importance[i + action_feature_num] = (
            losses["loss"].cpu().detach().item() - orig_loss
        )
        del losses

    self.trainer.mdnrnn.mdnrnn.train()
    logger.info(
        "**** Debug tool feature importance ****: {}".format(feature_importance)
    )
    return {"feature_loss_increase": feature_importance.numpy()}

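# A minimal sketch of how the boundary lists used in both evaluate() variants
# carve the flattened feature tensor into per-feature column slices: the sorted
# start indices plus the total dimension give [start, end) ranges, and each
# slice is masked to its per-column mean. The widths and values below are
# illustrative assumptions, not taken from this codebase.
import torch

state_dim = 5
sorted_state_feature_start_indices = [0, 2, 3]  # three features of widths 2, 1, 2
boundaries = sorted_state_feature_start_indices + [state_dim]

features = torch.arange(20.0).reshape(4, state_dim)
i = 1  # mask the second feature (columns 2:3)
start, end = boundaries[i], boundaries[i + 1]
masked = features.clone()
masked[:, start:end] = masked[:, start:end].mean(dim=0)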