def extract(self, ws, input_record, extract_record):
    def fetch(b):
        data = ws.fetch_blob(str(b()))
        return torch.tensor(data)

    state = mt.FeatureVector(float_features=fetch(extract_record.state))
    if self.sorted_action_features is None:
        action = None
    else:
        action = mt.FeatureVector(float_features=fetch(extract_record.action))
    return mt.StateAction(state=state, action=action)

def input_prototype(self):
    if self.parametric_action:
        return rlt.StateAction(
            state=rlt.FeatureVector(float_features=torch.randn(1, self.state_dim)),
            action=rlt.FeatureVector(float_features=torch.randn(1, self.action_dim)),
        )
    else:
        return rlt.StateInput(
            state=rlt.FeatureVector(float_features=torch.randn(1, self.state_dim))
        )

def internal_reward_estimation(self, state, action):
    """ Only used by Gym """
    self.reward_network.eval()
    reward_estimates = self.reward_network(
        rlt.StateAction(
            state=rlt.FeatureVector(float_features=state),
            action=rlt.FeatureVector(float_features=action),
        )
    )
    self.reward_network.train()
    return reward_estimates.q_value.cpu()

def get_detached_q_values(
    self, state, action
) -> Tuple[rlt.SingleQValue, Optional[rlt.SingleQValue]]:
    """ Gets the q values from the model and target networks """
    with torch.no_grad():
        input = rlt.StateAction(state=state, action=action)
        q_values = self.q_network(input)
        if self.double_q_learning:
            q_values_target = self.q_network_target(input)
        else:
            q_values_target = None
    return q_values, q_values_target

def internal_reward_estimation(self, state, action):
    """ Only used by Gym """
    self.reward_network.eval()
    with torch.no_grad():
        state = torch.from_numpy(np.array(state)).type(self.dtype)
        action = torch.from_numpy(np.array(action)).type(self.dtype)
        reward_estimates = self.reward_network(
            rlt.StateAction(
                state=rlt.FeatureVector(float_features=state),
                action=rlt.FeatureVector(float_features=action),
            )
        )
    self.reward_network.train()
    return reward_estimates.q_value.cpu().data.numpy()

def get_max_q_values(self, tiled_next_state, possible_next_actions, double_q_learning):
    """
    :param double_q_learning: bool to use double q-learning
    """
    lengths = possible_next_actions.lengths
    row_nums = np.arange(len(lengths))
    row_idxs = np.repeat(row_nums, lengths.cpu().numpy())
    col_idxs = arange_expand(lengths).cpu().numpy()

    dense_idxs = torch.tensor(
        (row_idxs, col_idxs), device=lengths.device, dtype=torch.int64
    )
    q_network_input = rlt.StateAction(
        state=tiled_next_state, action=possible_next_actions.actions
    )

    if double_q_learning:
        q_values = self.q_network(q_network_input).q_value.squeeze().detach()
        q_values_target = (
            self.q_network_target(q_network_input).q_value.squeeze().detach()
        )
    else:
        q_values = self.q_network_target(q_network_input).q_value.squeeze().detach()

    dense_dim = [len(lengths), max(lengths)]
    # Add specific fingerprint to q-values so that after sparse -> dense we can
    # subtract the fingerprint to identify the 0's added in sparse -> dense
    q_values.add_(self.FINGERPRINT)
    sparse_q = torch.sparse_coo_tensor(dense_idxs, q_values, dense_dim)
    dense_q = sparse_q.to_dense()
    dense_q.add_(self.FINGERPRINT * -1)
    dense_q[dense_q == self.FINGERPRINT * -1] = self.ACTION_NOT_POSSIBLE_VAL
    max_q_values, max_indexes = torch.max(dense_q, dim=1)

    if double_q_learning:
        sparse_q_target = torch.sparse_coo_tensor(dense_idxs, q_values_target, dense_dim)
        dense_q_values_target = sparse_q_target.to_dense()
        max_q_values = torch.gather(dense_q_values_target, 1, max_indexes.unsqueeze(1))

    return max_q_values.squeeze()

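# A minimal, self-contained sketch (toy values, not part of the trainer) of the
# fingerprint trick used in get_max_q_values above: the padding slots introduced
# by to_dense() are pushed to ACTION_NOT_POSSIBLE_VAL so they never win the
# row-wise max, while genuine zero-valued Q-values survive.
import torch

FINGERPRINT = 12345.0
ACTION_NOT_POSSIBLE_VAL = -1e9

lengths = torch.tensor([2, 1, 3])  # number of possible actions per row
q_values = torch.tensor([1.0, 0.0, 2.0, 3.0, -1.0, 0.5])  # flattened Q-values

row_idxs = torch.repeat_interleave(torch.arange(len(lengths)), lengths)
col_idxs = torch.cat([torch.arange(int(n)) for n in lengths])
dense_idxs = torch.stack([row_idxs, col_idxs])

dense_q = torch.sparse_coo_tensor(
    dense_idxs, q_values + FINGERPRINT, (len(lengths), int(lengths.max()))
).to_dense()
dense_q = dense_q - FINGERPRINT
dense_q[dense_q == -FINGERPRINT] = ACTION_NOT_POSSIBLE_VAL  # pure padding slots

max_q_values, max_indexes = torch.max(dense_q, dim=1)
# max_q_values == tensor([1.0, 2.0, 3.0]); the real 0.0 in row 0 was not
# mistaken for padding.
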
def create_from_tensors(
    cls,
    trainer: RLTrainer,
    mdp_ids: np.ndarray,
    sequence_numbers: torch.Tensor,
    states: Union[mt.State, torch.Tensor],
    actions: Union[mt.Action, torch.Tensor],
    propensities: torch.Tensor,
    rewards: torch.Tensor,
    possible_actions_mask: torch.Tensor,
    possible_actions: Optional[mt.FeatureVector] = None,
    max_num_actions: Optional[int] = None,
    metrics: Optional[torch.Tensor] = None,
):
    with torch.no_grad():
        # Switch to evaluation mode for the network
        old_q_train_state = trainer.q_network.training
        old_reward_train_state = trainer.reward_network.training
        trainer.q_network.train(False)
        trainer.reward_network.train(False)

        if max_num_actions:
            # Parametric model CPE
            state_action_pairs = mt.StateAction(state=states, action=actions)
            tiled_state = mt.FeatureVector(
                states.float_features.repeat(1, max_num_actions).reshape(
                    -1, states.float_features.shape[1]
                )
            )
            # Get Q-value of action taken
            possible_actions_state_concat = mt.StateAction(
                state=tiled_state, action=possible_actions
            )

            # Parametric actions
            model_values = trainer.q_network(possible_actions_state_concat).q_value
            assert (
                model_values.shape[0] * model_values.shape[1]
                == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
            ), (
                "Invalid shapes: "
                + str(model_values.shape)
                + " != "
                + str(possible_actions_mask.shape)
            )
            model_values = model_values.reshape(possible_actions_mask.shape)

            model_rewards = trainer.reward_network(
                possible_actions_state_concat
            ).q_value
            assert (
                model_rewards.shape[0] * model_rewards.shape[1]
                == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
            ), (
                "Invalid shapes: "
                + str(model_rewards.shape)
                + " != "
                + str(possible_actions_mask.shape)
            )
            model_rewards = model_rewards.reshape(possible_actions_mask.shape)

            model_values_for_logged_action = trainer.q_network(
                state_action_pairs
            ).q_value
            model_rewards_for_logged_action = trainer.reward_network(
                state_action_pairs
            ).q_value

            action_mask = (
                torch.abs(model_values - model_values_for_logged_action) < 1e-3
            ).float()

            model_metrics = None
            model_metrics_for_logged_action = None
            model_metrics_values = None
            model_metrics_values_for_logged_action = None
        else:
            action_mask = actions.float()

            # Switch to evaluation mode for the network
            old_q_cpe_train_state = trainer.q_network_cpe.training
            trainer.q_network_cpe.train(False)

            # Discrete actions
            rewards = trainer.boost_rewards(rewards, actions)
            model_values = trainer.get_detached_q_values(states)[0]
            assert model_values.shape == actions.shape, (
                "Invalid shape: "
                + str(model_values.shape)
                + " != "
                + str(actions.shape)
            )
            assert model_values.shape == possible_actions_mask.shape, (
                "Invalid shape: "
                + str(model_values.shape)
                + " != "
                + str(possible_actions_mask.shape)
            )
            model_values_for_logged_action = torch.sum(
                model_values * action_mask, dim=1, keepdim=True
            )

            if isinstance(states, mt.State):
                states = mt.StateInput(state=states)
            rewards_and_metric_rewards = trainer.reward_network(states)
            # In case we reuse the module for Q-network
            if hasattr(rewards_and_metric_rewards, "q_values"):
                rewards_and_metric_rewards = rewards_and_metric_rewards.q_values

            num_actions = trainer.num_actions
            model_rewards = rewards_and_metric_rewards[:, 0:num_actions]
            assert model_rewards.shape == actions.shape, (
                "Invalid shape: "
                + str(model_rewards.shape)
                + " != "
                + str(actions.shape)
            )
            model_rewards_for_logged_action = torch.sum(
                model_rewards * action_mask, dim=1, keepdim=True
            )

            model_metrics = rewards_and_metric_rewards[:, num_actions:]
            assert model_metrics.shape[1] % num_actions == 0, (
                "Invalid metrics shape: "
                + str(model_metrics.shape)
                + " "
                + str(num_actions)
            )
            num_metrics = model_metrics.shape[1] // num_actions

            if num_metrics == 0:
                model_metrics_values = None
                model_metrics_for_logged_action = None
                model_metrics_values_for_logged_action = None
            else:
                model_metrics_values = trainer.q_network_cpe(states)
                # Backward compatibility
                if hasattr(model_metrics_values, "q_values"):
                    model_metrics_values = model_metrics_values.q_values
                model_metrics_values = model_metrics_values[:, num_actions:]
                assert model_metrics_values.shape[1] == num_actions * num_metrics, (
                    "Invalid shape: "
                    + str(model_metrics_values.shape[1])
                    + " != "
                    + str(actions.shape[1] * num_metrics)
                )

                model_metrics_for_logged_action_list = []
                model_metrics_values_for_logged_action_list = []
                for metric_index in range(num_metrics):
                    metric_start = metric_index * num_actions
                    metric_end = (metric_index + 1) * num_actions
                    model_metrics_for_logged_action_list.append(
                        torch.sum(
                            model_metrics[:, metric_start:metric_end] * action_mask,
                            dim=1,
                            keepdim=True,
                        )
                    )
                    model_metrics_values_for_logged_action_list.append(
                        torch.sum(
                            model_metrics_values[:, metric_start:metric_end]
                            * action_mask,
                            dim=1,
                            keepdim=True,
                        )
                    )
                model_metrics_for_logged_action = torch.cat(
                    model_metrics_for_logged_action_list, dim=1
                )
                model_metrics_values_for_logged_action = torch.cat(
                    model_metrics_values_for_logged_action_list, dim=1
                )

            # Switch back to the old mode
            trainer.q_network_cpe.train(old_q_cpe_train_state)

        # Switch back to the old mode
        trainer.q_network.train(old_q_train_state)
        trainer.reward_network.train(old_reward_train_state)

        return cls(
            mdp_id=mdp_ids,
            sequence_number=sequence_numbers,
            logged_propensities=propensities,
            logged_rewards=rewards,
            action_mask=action_mask,
            model_rewards=model_rewards,
            model_rewards_for_logged_action=model_rewards_for_logged_action,
            model_values=model_values,
            model_values_for_logged_action=model_values_for_logged_action,
            model_metrics_values=model_metrics_values,
            model_metrics_values_for_logged_action=model_metrics_values_for_logged_action,
            model_propensities=masked_softmax(
                model_values, possible_actions_mask, trainer.rl_temperature
            ),
            logged_metrics=metrics,
            model_metrics=model_metrics,
            model_metrics_for_logged_action=model_metrics_for_logged_action,
            # Will compute later
            logged_values=None,
            logged_metrics_values=None,
            possible_actions_mask=possible_actions_mask,
        )

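# Hedged sketch of a masked, temperature-scaled softmax in the spirit of the
# masked_softmax call above (an illustration under assumed semantics, not the
# project's exact implementation): disallowed actions get zero probability.
import torch
import torch.nn.functional as F

def masked_softmax_sketch(
    scores: torch.Tensor, mask: torch.Tensor, temperature: float
) -> torch.Tensor:
    """scores: (batch, num_actions); mask: same shape, 1 = allowed, 0 = not."""
    scores = scores / temperature
    # Push disallowed actions to -inf so softmax assigns them zero probability.
    scores = scores.masked_fill(mask == 0, float("-inf"))
    probs = F.softmax(scores, dim=1)
    # Rows where every action is masked out would produce NaNs; zero them.
    return probs.masked_fill(torch.isnan(probs), 0.0)
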
def input_prototype(self) -> rlt.StateAction:
    return rlt.StateAction(
        state=rlt.FeatureVector(float_features=torch.randn(1, self.state_dim)),
        action=rlt.FeatureVector(float_features=torch.randn(1, self.action_dim)),
    )

def train(self, training_batch, evaluator=None) -> None:
    if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
        training_batch = training_batch.as_parametric_sarsa_training_batch()

    learning_input = training_batch.training_input
    self.minibatch += 1

    reward = learning_input.reward
    discount_tensor = torch.full_like(reward, self.gamma)
    not_done_mask = learning_input.not_terminal

    if self.use_seq_num_diff_as_time_diff:
        # TODO: Implement this in another diff
        raise NotImplementedError

    if self.maxq_learning:
        # Compute max a' Q(s', a') over all possible actions using target network
        next_q_values = self.get_max_q_values(
            learning_input.tiled_next_state,
            learning_input.possible_next_actions,
            self.double_q_learning,
        )
    else:
        # SARSA
        next_q_values = self.get_next_action_q_values(
            learning_input.next_state, learning_input.next_action
        )

    filtered_max_q_vals = next_q_values.reshape(-1, 1) * not_done_mask

    if self.minibatch < self.reward_burnin:
        target_q_values = reward
    else:
        target_q_values = reward + (discount_tensor * filtered_max_q_vals)

    # Get Q-value of action taken
    current_state_action = rlt.StateAction(
        state=learning_input.state, action=learning_input.action
    )
    q_values = self.q_network(current_state_action).q_value
    self.all_action_scores = q_values.detach()

    value_loss = self.q_network_loss(q_values, target_q_values)
    self.loss = value_loss.detach()

    self.q_network_optimizer.zero_grad()
    value_loss.backward()
    if self.gradient_handler:
        self.gradient_handler(self.q_network.parameters())
    self.q_network_optimizer.step()

    # TODO: Maybe soft_update should belong to the target network
    if self.minibatch < self.reward_burnin:
        # Reward burnin: force target network
        self._soft_update(self.q_network, self.q_network_target, 1.0)
    else:
        # Use the soft update rule to update target network
        self._soft_update(self.q_network, self.q_network_target, self.tau)

    # get reward estimates
    reward_estimates = self.reward_network(current_state_action).q_value
    reward_loss = F.mse_loss(reward_estimates, reward)
    self.reward_network_optimizer.zero_grad()
    reward_loss.backward()
    self.reward_network_optimizer.step()

    self.loss_reporter.report(
        td_loss=float(self.loss), reward_loss=float(reward_loss)
    )

    if evaluator is not None:
        cpe_stats = BatchStatsForCPE(
            model_values_on_logged_actions=self.all_action_scores
        )
        evaluator.report(cpe_stats)

def train(self, training_batch) -> None:
    if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
        training_batch = training_batch.as_parametric_sarsa_training_batch()

    learning_input = training_batch.training_input
    self.minibatch += 1

    reward = learning_input.reward
    if self.multi_steps is not None:
        discount_tensor = torch.pow(self.gamma, learning_input.step.float())
    else:
        discount_tensor = torch.full_like(reward, self.gamma)
    not_done_mask = learning_input.not_terminal

    if self.use_seq_num_diff_as_time_diff:
        if self.multi_steps is not None:
            # TODO: Implement this in another diff
            pass
        else:
            discount_tensor = discount_tensor.pow(learning_input.time_diff.float())

    if self.maxq_learning:
        all_next_q_values, all_next_q_values_target = self.get_detached_q_values(
            learning_input.tiled_next_state, learning_input.possible_next_actions
        )
        # Compute max a' Q(s', a') over all possible actions using target network
        next_q_values, _ = self.get_max_q_values_with_target(
            all_next_q_values.q_value,
            all_next_q_values_target.q_value,
            learning_input.possible_next_actions_mask.float(),
        )
    else:
        # SARSA (Use the target network)
        _, next_q_values = self.get_detached_q_values(
            learning_input.next_state, learning_input.next_action
        )
        next_q_values = next_q_values.q_value

    filtered_max_q_vals = next_q_values * not_done_mask.float()

    if self.minibatch < self.reward_burnin:
        target_q_values = reward
    else:
        target_q_values = reward + (discount_tensor * filtered_max_q_vals)

    # Get Q-value of action taken
    current_state_action = rlt.StateAction(
        state=learning_input.state, action=learning_input.action
    )
    q_values = self.q_network(current_state_action).q_value
    self.all_action_scores = q_values.detach()

    value_loss = self.q_network_loss(q_values, target_q_values)
    self.loss = value_loss.detach()

    self.q_network_optimizer.zero_grad()
    value_loss.backward()
    if self.gradient_handler:
        self.gradient_handler(self.q_network.parameters())
    self.q_network_optimizer.step()

    # TODO: Maybe soft_update should belong to the target network
    if self.minibatch < self.reward_burnin:
        # Reward burnin: force target network
        self._soft_update(self.q_network, self.q_network_target, 1.0)
    else:
        # Use the soft update rule to update target network
        self._soft_update(self.q_network, self.q_network_target, self.tau)

    # get reward estimates
    reward_estimates = self.reward_network(current_state_action).q_value
    reward_loss = F.mse_loss(reward_estimates, reward)
    self.reward_network_optimizer.zero_grad()
    reward_loss.backward()
    self.reward_network_optimizer.step()

    self.loss_reporter.report(
        td_loss=self.loss,
        reward_loss=reward_loss,
        model_values_on_logged_actions=self.all_action_scores,
    )

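# Hedged sketch of the double Q-learning selection rule behind
# get_max_q_values_with_target above (assumed (batch, num_actions) inputs; the
# real helper may first reshape tiled parametric outputs): the online network
# picks the greedy possible action, the target network scores it.
import torch

def get_max_q_values_with_target_sketch(q_values, q_values_target, possible_actions_mask):
    impossible = possible_actions_mask == 0
    # Online network chooses the greedy action among possible ones ...
    best_actions = (
        q_values.masked_fill(impossible, -float("inf")).argmax(dim=1, keepdim=True)
    )
    # ... and the target network provides the value used in the TD target.
    max_q_values = q_values_target.gather(1, best_actions)
    return max_q_values, best_actions
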
def train(self, training_batch: rlt.TrainingBatch) -> None:
    if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
        training_batch = training_batch.as_parametric_sarsa_training_batch()

    learning_input = training_batch.training_input
    self.minibatch += 1

    state = learning_input.state
    # As far as ddpg is concerned all actions are [-1, 1] due to actor tanh
    action = rlt.FeatureVector(
        rescale_torch_tensor(
            learning_input.action.float_features,
            new_min=self.min_action_range_tensor_training,
            new_max=self.max_action_range_tensor_training,
            prev_min=self.min_action_range_tensor_serving,
            prev_max=self.max_action_range_tensor_serving,
        )
    )
    rewards = learning_input.reward
    next_state = learning_input.next_state
    time_diffs = learning_input.time_diff
    discount_tensor = torch.full_like(rewards, self.gamma)
    not_done_mask = learning_input.not_terminal

    # Optimize the critic network subject to mean squared error:
    # L = ([r + gamma * Q(s2, a2)] - Q(s1, a1)) ^ 2
    q_s1_a1 = self.critic.forward(
        rlt.StateAction(state=state, action=action)
    ).q_value
    next_action = rlt.FeatureVector(
        float_features=self.actor_target(
            rlt.StateAction(state=next_state, action=None)
        ).action
    )

    q_s2_a2 = self.critic_target.forward(
        rlt.StateAction(state=next_state, action=next_action)
    ).q_value
    filtered_q_s2_a2 = not_done_mask.float() * q_s2_a2

    if self.use_seq_num_diff_as_time_diff:
        discount_tensor = discount_tensor.pow(time_diffs)

    target_q_values = rewards + (discount_tensor * filtered_q_s2_a2)

    # compute loss and update the critic network
    critic_predictions = q_s1_a1
    loss_critic = self.q_network_loss(critic_predictions, target_q_values.detach())
    loss_critic_for_eval = loss_critic.detach()
    self.critic_optimizer.zero_grad()
    loss_critic.backward()
    self.critic_optimizer.step()

    # Optimize the actor network subject to the following:
    # max mean(Q(s1, a1)) or min -mean(Q(s1, a1))
    actor_output = self.actor(rlt.StateAction(state=state, action=None))
    loss_actor = -(
        self.critic.forward(
            rlt.StateAction(
                state=state,
                action=rlt.FeatureVector(float_features=actor_output.action),
            )
        ).q_value.mean()
    )

    # Zero out both the actor and critic gradients because we need
    # to backprop through the critic to get to the actor
    self.actor_optimizer.zero_grad()
    loss_actor.backward()
    self.actor_optimizer.step()

    # Use the soft update rule to update both target networks
    self._soft_update(self.actor, self.actor_target, self.tau)
    self._soft_update(self.critic, self.critic_target, self.tau)

    self.loss_reporter.report(
        td_loss=float(loss_critic_for_eval),
        reward_loss=None,
        model_values_on_logged_actions=critic_predictions,
    )

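# Hedged sketch of the linear rescaling assumed for rescale_torch_tensor above
# (an illustration of the mapping, not necessarily the exact source): map each
# action dimension from the serving range onto the training range, e.g. the
# actor's tanh range [-1, 1].
import torch

def rescale_torch_tensor_sketch(
    t: torch.Tensor,
    new_min: torch.Tensor,
    new_max: torch.Tensor,
    prev_min: torch.Tensor,
    prev_max: torch.Tensor,
) -> torch.Tensor:
    # Normalize into [0, 1] with the old range, then stretch into the new range.
    prev_range = prev_max - prev_min
    new_range = new_max - new_min
    return ((t - prev_min) / prev_range) * new_range + new_min
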
def create_embed_rl_dataset(
    gym_env: OpenAIGymEnvironment,
    trainer: MDNRNNTrainer,
    dataset: RLDataset,
    use_gpu: bool = False,
    seq_len: int = 5,
    num_state_embed_episodes: int = 100,
    max_steps: Optional[int] = None,
    **kwargs,
):
    old_mdnrnn_mode = trainer.mdnrnn.mdnrnn.training
    trainer.mdnrnn.mdnrnn.eval()
    num_transitions = num_state_embed_episodes * max_steps  # type: ignore
    device = torch.device("cuda") if use_gpu else torch.device("cpu")  # type: ignore

    (
        state_batch,
        action_batch,
        reward_batch,
        next_state_batch,
        next_action_batch,
        not_terminal_batch,
        step_batch,
        next_step_batch,
    ) = map(
        list,
        zip(
            *multi_step_sample_generator(
                gym_env=gym_env,
                num_transitions=num_transitions,
                max_steps=max_steps,
                # +1 because MDNRNN embeds the first seq_len steps and then
                # the embedded state will be concatenated with the last step
                multi_steps=seq_len + 1,
                include_shorter_samples_at_start=True,
                include_shorter_samples_at_end=False,
            )
        ),
    )

    def concat_batch(batch):
        return torch.cat(
            [
                torch.tensor(np.expand_dims(x, axis=1), dtype=torch.float, device=device)
                for x in batch
            ],
            dim=1,
        )

    # shape: seq_len x batch_size x feature_dim
    mdnrnn_state = concat_batch(state_batch)
    next_mdnrnn_state = concat_batch(next_state_batch)
    mdnrnn_action = concat_batch(action_batch)
    next_mdnrnn_action = concat_batch(next_action_batch)

    mdnrnn_input = rlt.StateAction(
        state=rlt.FeatureVector(float_features=mdnrnn_state),
        action=rlt.FeatureVector(float_features=mdnrnn_action),
    )
    next_mdnrnn_input = rlt.StateAction(
        state=rlt.FeatureVector(float_features=next_mdnrnn_state),
        action=rlt.FeatureVector(float_features=next_mdnrnn_action),
    )
    # batch-compute state embedding
    mdnrnn_output = trainer.mdnrnn(mdnrnn_input)
    next_mdnrnn_output = trainer.mdnrnn(next_mdnrnn_input)

    for i in range(len(state_batch)):
        # Embed the state as the hidden layer's output
        # until the previous step + current state
        hidden_idx = 0 if step_batch[i] == 1 else step_batch[i] - 2  # type: ignore
        next_hidden_idx = next_step_batch[i] - 2  # type: ignore
        hidden_embed = (
            mdnrnn_output.all_steps_lstm_hidden[hidden_idx, i, :].squeeze().detach().cpu()
        )
        state_embed = torch.cat(
            (hidden_embed, torch.tensor(state_batch[i][hidden_idx + 1]))  # type: ignore
        )
        next_hidden_embed = (
            next_mdnrnn_output.all_steps_lstm_hidden[next_hidden_idx, i, :]
            .squeeze()
            .detach()
            .cpu()
        )
        next_state_embed = torch.cat(
            (
                next_hidden_embed,
                torch.tensor(next_state_batch[i][next_hidden_idx + 1]),  # type: ignore
            )
        )

        logger.debug(
            "create_embed_rl_dataset:\nstate batch\n{}\naction batch\n{}\nlast "
            "action: {},reward: {}\nstate embed {}\nnext state embed {}\n".format(
                state_batch[i][: hidden_idx + 1],  # type: ignore
                action_batch[i][: hidden_idx + 1],  # type: ignore
                action_batch[i][hidden_idx + 1],  # type: ignore
                reward_batch[i][hidden_idx + 1],  # type: ignore
                state_embed,
                next_state_embed,
            )
        )

        terminal = 1 - not_terminal_batch[i][hidden_idx + 1]  # type: ignore
        possible_actions, possible_actions_mask = get_possible_actions(
            gym_env, ModelType.PYTORCH_PARAMETRIC_DQN.value, False
        )
        possible_next_actions, possible_next_actions_mask = get_possible_actions(
            gym_env, ModelType.PYTORCH_PARAMETRIC_DQN.value, terminal
        )
        dataset.insert(
            state=state_embed,
            action=torch.tensor(action_batch[i][hidden_idx + 1]),  # type: ignore
            reward=reward_batch[i][hidden_idx + 1],  # type: ignore
            next_state=next_state_embed,
            next_action=torch.tensor(
                next_action_batch[i][next_hidden_idx + 1]  # type: ignore
            ),
            terminal=torch.tensor(terminal),
            possible_next_actions=possible_next_actions,
            possible_next_actions_mask=possible_next_actions_mask,
            time_diff=torch.tensor(1),
            possible_actions=possible_actions,
            possible_actions_mask=possible_actions_mask,
            policy_id=0,
        )

    logger.info(
        "Insert {} transitions into a state embed dataset".format(len(state_batch))
    )
    trainer.mdnrnn.mdnrnn.train(old_mdnrnn_mode)
    return dataset

def train(self, training_batch) -> None:
    """
    IMPORTANT: the input action here is assumed to be preprocessed to match the
    range of the output of the actor.
    """
    if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
        training_batch = training_batch.as_parametric_sarsa_training_batch()

    learning_input = training_batch.training_input
    self.minibatch += 1

    state = learning_input.state
    action = learning_input.action
    next_state = learning_input.next_state
    reward = learning_input.reward
    not_done_mask = learning_input.not_terminal

    action = self._maybe_scale_action_in_train(action)

    # Compute current value estimates
    current_state_action = rlt.StateAction(state=state, action=action)
    q1_value = self.q1_network(current_state_action).q_value
    if self.q2_network:
        q2_value = self.q2_network(current_state_action).q_value
    actor_action = self.actor_network(rlt.StateInput(state=state)).action

    # Generate target = r + y * min (Q1(s',pi(s')), Q2(s',pi(s')))
    with torch.no_grad():
        next_actor = self.actor_network_target(rlt.StateInput(state=next_state)).action
        next_actor += (
            torch.randn_like(next_actor) * self.target_policy_smoothing
        ).clamp(-self.noise_clip, self.noise_clip)
        next_actor = torch.max(
            torch.min(next_actor, self.max_action_range_tensor_training),
            self.min_action_range_tensor_training,
        )
        next_state_actor = rlt.StateAction(
            state=next_state, action=rlt.FeatureVector(float_features=next_actor)
        )
        next_state_value = self.q1_network_target(next_state_actor).q_value

        if self.q2_network is not None:
            next_state_value = torch.min(
                next_state_value, self.q2_network_target(next_state_actor).q_value
            )

        target_q_value = (
            reward + self.gamma * next_state_value * not_done_mask.float()
        )

    # Optimize Q1 and Q2
    q1_loss = F.mse_loss(q1_value, target_q_value)
    q1_loss.backward()
    self._maybe_run_optimizer(self.q1_network_optimizer, self.minibatches_per_step)
    if self.q2_network:
        q2_loss = F.mse_loss(q2_value, target_q_value)
        q2_loss.backward()
        self._maybe_run_optimizer(self.q2_network_optimizer, self.minibatches_per_step)

    # Only update actor and target networks after a fixed number of Q updates
    if self.minibatch % self.delayed_policy_update == 0:
        actor_loss = -self.q1_network(
            rlt.StateAction(
                state=state,
                action=rlt.FeatureVector(float_features=actor_action),
            )
        ).q_value.mean()
        actor_loss.backward()
        self._maybe_run_optimizer(
            self.actor_network_optimizer, self.minibatches_per_step
        )

        # Use the soft update rule to update the target networks
        self._maybe_soft_update(
            self.q1_network,
            self.q1_network_target,
            self.tau,
            self.minibatches_per_step,
        )
        self._maybe_soft_update(
            self.actor_network,
            self.actor_network_target,
            self.tau,
            self.minibatches_per_step,
        )
        if self.q2_network is not None:
            self._maybe_soft_update(
                self.q2_network,
                self.q2_network_target,
                self.tau,
                self.minibatches_per_step,
            )

    # Logging at the end to schedule all the cuda operations first
    if (
        self.tensorboard_logging_freq is not None
        and self.minibatch % self.tensorboard_logging_freq == 0
    ):
        SummaryWriterContext.add_histogram("q1/logged_state_value", q1_value)
        if self.q2_network:
            SummaryWriterContext.add_histogram("q2/logged_state_value", q2_value)
        SummaryWriterContext.add_histogram("q_network/next_state_value", next_state_value)
        SummaryWriterContext.add_histogram("q_network/target_q_value", target_q_value)
        SummaryWriterContext.add_histogram("actor/loss", actor_loss)

    self.loss_reporter.report(
        td_loss=float(q1_loss),
        reward_loss=None,
        logged_rewards=reward,
        model_values_on_logged_actions=q1_value,
    )

def get_loss(
    self,
    training_batch: rlt.TrainingBatch,
    state_dim: Optional[int] = None,
    batch_first: bool = False,
):
    """
    Compute losses. The loss that is computed is:

        (GMMLoss(next_state, GMMPredicted)
         + MSE(reward, predicted_reward)
         + BCE(not_terminal, logit_not_terminal)) / (STATE_DIM + 2)

    The STATE_DIM + 2 factor is here to counteract the fact that the GMMLoss
    scales approximately linearly with STATE_DIM, the feature size of states.
    All losses are averaged over both the batch and the sequence dimensions
    (the first two dimensions).

    :param training_batch: training_batch.training_input has these fields:
        state: (BATCH_SIZE, SEQ_LEN, STATE_DIM) torch tensor
        action: (BATCH_SIZE, SEQ_LEN, ACTION_DIM) torch tensor
        reward: (BATCH_SIZE, SEQ_LEN) torch tensor
        not_terminal: (BATCH_SIZE, SEQ_LEN) torch tensor
        next_state: (BATCH_SIZE, SEQ_LEN, STATE_DIM) torch tensor
    :param state_dim: the dimension of states. If provided, use it to normalize
        the loss
    :param batch_first: whether the data's first dimension represents batch size.
        If False, the first two dimensions of state, action, reward, not_terminal,
        and next_state are SEQ_LEN and BATCH_SIZE.
    :returns: dictionary of losses, containing the gmm, the mse, the bce and
        the averaged loss.
    """
    learning_input = training_batch.training_input
    # mdnrnn's input should have seq_len as the first dimension
    if batch_first:
        state, action, next_state, reward, not_terminal = transpose(
            learning_input.state.float_features,
            learning_input.action.float_features,
            learning_input.next_state,
            learning_input.reward,
            learning_input.not_terminal,
        )
        learning_input = rlt.MemoryNetworkInput(
            state=rlt.FeatureVector(float_features=state),
            action=rlt.FeatureVector(float_features=action),
            next_state=next_state,
            reward=reward,
            not_terminal=not_terminal,
        )

    mdnrnn_input = rlt.StateAction(
        state=learning_input.state, action=learning_input.action
    )
    mdnrnn_output = self.mdnrnn(mdnrnn_input)
    mus, sigmas, logpi, rs, ds = (
        mdnrnn_output.mus,
        mdnrnn_output.sigmas,
        mdnrnn_output.logpi,
        mdnrnn_output.reward,
        mdnrnn_output.not_terminal,
    )

    gmm = gmm_loss(learning_input.next_state, mus, sigmas, logpi)
    bce = F.binary_cross_entropy_with_logits(ds, learning_input.not_terminal)
    mse = F.mse_loss(rs, learning_input.reward)
    if state_dim is not None:
        loss = (gmm + bce + mse) / (state_dim + 2)
    else:
        loss = mse + bce + gmm
    return {"gmm": gmm, "bce": bce, "mse": mse, "loss": loss}

def get_loss(
    self,
    training_batch: rlt.TrainingBatch,
    state_dim: Optional[int] = None,
    batch_first: bool = False,
):
    """
    Compute losses. The loss that is computed is:

        (GMMLoss(next_state, GMMPredicted)
         + MSE(reward, predicted_reward)
         + BCE(not_terminal, logit_not_terminal)) / (STATE_DIM + 2)

    The STATE_DIM + 2 factor is here to counteract the fact that the GMMLoss
    scales approximately linearly with STATE_DIM, the feature size of states.
    All losses are averaged over both the batch and the sequence dimensions
    (the first two dimensions).

    :param training_batch: training_batch.training_input has these fields:
        - state: (BATCH_SIZE, SEQ_LEN, STATE_DIM) torch tensor
        - action: (BATCH_SIZE, SEQ_LEN, ACTION_DIM) torch tensor
        - reward: (BATCH_SIZE, SEQ_LEN) torch tensor
        - not_terminal: (BATCH_SIZE, SEQ_LEN) torch tensor
        - next_state: (BATCH_SIZE, SEQ_LEN, STATE_DIM) torch tensor
        The first two dimensions may be swapped depending on batch_first.
    :param state_dim: the dimension of states. If provided, use it to normalize
        the gmm loss
    :param batch_first: whether the data's first dimension represents batch size.
        If False, the first two dimensions of state, action, reward, not_terminal,
        and next_state are SEQ_LEN and BATCH_SIZE.
    :returns: dictionary of losses, containing the gmm, the mse, the bce and
        the averaged loss.
    """
    learning_input = training_batch.training_input
    # mdnrnn's input should have seq_len as the first dimension
    if batch_first:
        state, action, next_state, reward, not_terminal = transpose(
            learning_input.state.float_features,
            learning_input.action.float_features,  # type: ignore
            learning_input.next_state,
            learning_input.reward,
            learning_input.not_terminal,  # type: ignore
        )
        learning_input = rlt.MemoryNetworkInput(  # type: ignore
            state=rlt.FeatureVector(float_features=state),
            reward=reward,
            time_diff=torch.ones_like(reward).float(),
            action=rlt.FeatureVector(float_features=action),
            not_terminal=not_terminal,
            next_state=next_state,
        )

    mdnrnn_input = rlt.StateAction(
        state=learning_input.state, action=learning_input.action  # type: ignore
    )
    mdnrnn_output = self.mdnrnn(mdnrnn_input)
    mus, sigmas, logpi, rs, nts = (
        mdnrnn_output.mus,
        mdnrnn_output.sigmas,
        mdnrnn_output.logpi,
        mdnrnn_output.reward,
        mdnrnn_output.not_terminal,
    )

    next_state = learning_input.next_state
    not_terminal = learning_input.not_terminal  # type: ignore
    reward = learning_input.reward
    if self.params.fit_only_one_next_step:
        next_state, not_terminal, reward, mus, sigmas, logpi, nts, rs = tuple(
            map(
                lambda x: x[-1:],
                (next_state, not_terminal, reward, mus, sigmas, logpi, nts, rs),
            )
        )

    gmm = (
        gmm_loss(next_state, mus, sigmas, logpi) * self.params.next_state_loss_weight
    )
    bce = (
        F.binary_cross_entropy_with_logits(nts, not_terminal)
        * self.params.not_terminal_loss_weight
    )
    mse = F.mse_loss(rs, reward) * self.params.reward_loss_weight

    if state_dim is not None:
        loss = gmm / (state_dim + 2) + bce + mse
    else:
        loss = gmm + bce + mse
    return {"gmm": gmm, "bce": bce, "mse": mse, "loss": loss}

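# Hedged sketch of a Gaussian-mixture negative log-likelihood in the spirit of
# the gmm_loss call above (an illustration, not the project's exact code).
# Assumed shapes: batch (SEQ_LEN, BATCH_SIZE, STATE_DIM),
# mus / sigmas (SEQ_LEN, BATCH_SIZE, NUM_GAUSSIANS, STATE_DIM),
# logpi (SEQ_LEN, BATCH_SIZE, NUM_GAUSSIANS), with logpi already log-normalized.
import torch
from torch.distributions.normal import Normal

def gmm_loss_sketch(batch, mus, sigmas, logpi):
    batch = batch.unsqueeze(-2)  # broadcast the target over the mixture dimension
    # log pi_k + sum_d log N(x_d | mu_{k,d}, sigma_{k,d})
    log_probs = logpi + Normal(mus, sigmas).log_prob(batch).sum(dim=-1)
    # negative mean over time and batch of log sum_k exp(...)
    return -torch.logsumexp(log_probs, dim=-1).mean()
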
def get_next_action_q_values(self, state, action):
    return self.q_network_target(rlt.StateAction(state=state, action=action)).q_value

def train(self, training_batch) -> None:
    if isinstance(training_batch, TrainingDataPage):
        if self.maxq_learning:
            training_batch = training_batch.as_parametric_maxq_training_batch()
        else:
            training_batch = training_batch.as_parametric_sarsa_training_batch()

    learning_input = training_batch.training_input
    self.minibatch += 1

    reward = learning_input.reward
    not_done_mask = learning_input.not_terminal

    discount_tensor = torch.full_like(reward, self.gamma)
    if self.use_seq_num_diff_as_time_diff:
        assert self.multi_steps is None
        discount_tensor = torch.pow(self.gamma, learning_input.time_diff.float())
    if self.multi_steps is not None:
        discount_tensor = torch.pow(self.gamma, learning_input.step.float())

    if self.maxq_learning:
        all_next_q_values, all_next_q_values_target = self.get_detached_q_values(
            learning_input.tiled_next_state, learning_input.possible_next_actions
        )
        # Compute max a' Q(s', a') over all possible actions using target network
        next_q_values, _ = self.get_max_q_values_with_target(
            all_next_q_values.q_value,
            all_next_q_values_target.q_value,
            learning_input.possible_next_actions_mask.float(),
        )
    else:
        # SARSA (Use the target network)
        _, next_q_values = self.get_detached_q_values(
            learning_input.next_state, learning_input.next_action
        )
        next_q_values = next_q_values.q_value

    filtered_max_q_vals = next_q_values * not_done_mask.float()
    target_q_values = reward + (discount_tensor * filtered_max_q_vals)

    # Get Q-value of action taken
    current_state_action = rlt.StateAction(
        state=learning_input.state, action=learning_input.action
    )
    q_values = self.q_network(current_state_action).q_value
    self.all_action_scores = q_values.detach()

    value_loss = self.q_network_loss(q_values, target_q_values)
    self.loss = value_loss.detach()

    self.q_network_optimizer.zero_grad()
    value_loss.backward()
    self.q_network_optimizer.step()

    # Use the soft update rule to update target network
    self._soft_update(self.q_network, self.q_network_target, self.tau)

    # get reward estimates
    reward_estimates = self.reward_network(current_state_action).q_value
    reward_loss = F.mse_loss(reward_estimates, reward)
    self.reward_network_optimizer.zero_grad()
    reward_loss.backward()
    self.reward_network_optimizer.step()

    self.loss_reporter.report(
        td_loss=self.loss,
        reward_loss=reward_loss,
        model_values_on_logged_actions=self.all_action_scores,
    )

def train(self, training_batch, evaluator=None) -> None:
    """
    IMPORTANT: the input action here is assumed to be preprocessed to match the
    range of the output of the actor.
    """
    if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
        training_batch = training_batch.as_parametric_sarsa_training_batch()

    learning_input = training_batch.training_input
    self.minibatch += 1

    state = learning_input.state
    action = learning_input.action
    reward = learning_input.reward
    discount = torch.full_like(reward, self.gamma)
    not_done_mask = learning_input.not_terminal

    if self._should_scale_action_in_train():
        action = rlt.FeatureVector(
            rescale_torch_tensor(
                action.float_features,
                new_min=self.min_action_range_tensor_training,
                new_max=self.max_action_range_tensor_training,
                prev_min=self.min_action_range_tensor_serving,
                prev_max=self.max_action_range_tensor_serving,
            )
        )

    current_state_action = rlt.StateAction(state=state, action=action)

    q1_value = self.q1_network(current_state_action).q_value
    min_q_value = q1_value

    if self.q2_network:
        q2_value = self.q2_network(current_state_action).q_value
        min_q_value = torch.min(q1_value, q2_value)

    # Use the minimum as target, ensure no gradient going through
    min_q_value = min_q_value.detach()

    #
    # First, optimize value network; minimizing MSE between
    # V(s) & Q(s, a) - log(pi(a|s))
    #

    state_value = self.value_network(state.float_features)  # .q_value

    if self.logged_action_uniform_prior:
        log_prob_a = torch.zeros_like(min_q_value)
        target_value = min_q_value
    else:
        with torch.no_grad():
            log_prob_a = self.actor_network.get_log_prob(state, action.float_features)
            log_prob_a = log_prob_a.clamp(-20.0, 20.0)
            target_value = min_q_value - self.entropy_temperature * log_prob_a

    value_loss = F.mse_loss(state_value, target_value)
    self.value_network_optimizer.zero_grad()
    value_loss.backward()
    self.value_network_optimizer.step()

    #
    # Second, optimize Q networks; minimizing MSE between
    # Q(s, a) & r + discount * V'(next_s)
    #

    with torch.no_grad():
        next_state_value = (
            self.value_network_target(learning_input.next_state.float_features)
            * not_done_mask
        )

        if self.minibatch < self.reward_burnin:
            target_q_value = reward
        else:
            target_q_value = reward + discount * next_state_value

    q1_loss = F.mse_loss(q1_value, target_q_value)
    self.q1_network_optimizer.zero_grad()
    q1_loss.backward()
    self.q1_network_optimizer.step()
    if self.q2_network:
        q2_loss = F.mse_loss(q2_value, target_q_value)
        self.q2_network_optimizer.zero_grad()
        q2_loss.backward()
        self.q2_network_optimizer.step()

    #
    # Lastly, optimize the actor; minimizing KL-divergence between action propensity
    # & softmax of value. Due to reparameterization trick, it ends up being
    # log_prob(actor_action) - Q(s, actor_action)
    #

    actor_output = self.actor_network(rlt.StateInput(state=state))

    state_actor_action = rlt.StateAction(
        state=state, action=rlt.FeatureVector(float_features=actor_output.action)
    )
    q1_actor_value = self.q1_network(state_actor_action).q_value
    min_q_actor_value = q1_actor_value
    if self.q2_network:
        q2_actor_value = self.q2_network(state_actor_action).q_value
        min_q_actor_value = torch.min(q1_actor_value, q2_actor_value)

    actor_loss = (
        self.entropy_temperature * actor_output.log_prob - min_q_actor_value
    )
    # Do this in 2 steps so we can log histogram of actor loss
    actor_loss_mean = actor_loss.mean()
    self.actor_network_optimizer.zero_grad()
    actor_loss_mean.backward()
    self.actor_network_optimizer.step()

    if self.minibatch < self.reward_burnin:
        # Reward burnin: force target network
        self._soft_update(self.value_network, self.value_network_target, 1.0)
    else:
        # Use the soft update rule to update both target networks
        self._soft_update(self.value_network, self.value_network_target, self.tau)

    # Logging at the end to schedule all the cuda operations first
    if (
        self.tensorboard_logging_freq is not None
        and self.minibatch % self.tensorboard_logging_freq == 0
    ):
        SummaryWriterContext.add_histogram("q1/logged_state_value", q1_value)
        if self.q2_network:
            SummaryWriterContext.add_histogram("q2/logged_state_value", q2_value)
        SummaryWriterContext.add_histogram("log_prob_a", log_prob_a)
        SummaryWriterContext.add_histogram("value_network/target", target_value)
        SummaryWriterContext.add_histogram("q_network/next_state_value", next_state_value)
        SummaryWriterContext.add_histogram("q_network/target_q_value", target_q_value)
        SummaryWriterContext.add_histogram("actor/min_q_actor_value", min_q_actor_value)
        SummaryWriterContext.add_histogram("actor/action_log_prob", actor_output.log_prob)
        SummaryWriterContext.add_histogram("actor/loss", actor_loss)

    if evaluator is not None:
        cpe_stats = BatchStatsForCPE(
            td_loss=q1_loss.detach().cpu().numpy(),
            logged_rewards=reward.detach().cpu().numpy(),
            model_values_on_logged_actions=q1_value.detach().cpu().numpy(),
            model_propensities=actor_output.log_prob.exp().detach().cpu().numpy(),
            model_values=min_q_actor_value.detach().cpu().numpy(),
        )
        evaluator.report(cpe_stats)

def forward(self, input):
    preprocessed_state = self.state_preprocessor(input.state)
    preprocessed_action = self.action_preprocessor(input.action)
    return self.q_network(
        rlt.StateAction(state=preprocessed_state, action=preprocessed_action)
    )

def input_prototype(self):
    return rlt.StateAction(
        state=self.state_preprocessor.input_prototype(),
        action=self.action_preprocessor.input_prototype(),
    )

def train(self, training_batch, evaluator=None) -> None:
    if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
        training_batch = training_batch.as_parametric_sarsa_training_batch()

    learning_input = training_batch.training_input
    self.minibatch += 1

    s = learning_input.state
    a = learning_input.action.float_features
    reward = learning_input.reward
    discount = torch.full_like(reward, self.gamma)
    not_done_mask = learning_input.not_terminal

    current_state_action = rlt.StateAction(
        state=learning_input.state, action=learning_input.action
    )

    q1_value = self.q1_network(current_state_action).q_value
    min_q_value = q1_value

    if self.q2_network:
        q2_value = self.q2_network(current_state_action).q_value
        min_q_value = torch.min(q1_value, q2_value)

    # Use the minimum as target, ensure no gradient going through
    min_q_value = min_q_value.detach()

    #
    # First, optimize value network; minimizing MSE between
    # V(s) & Q(s, a) - log(pi(a|s))
    #

    state_value = self.value_network(s.float_features)  # .q_value

    with torch.no_grad():
        log_prob_a = self.actor_network.get_log_prob(s, a)
        target_value = min_q_value - self.entropy_temperature * log_prob_a

    value_loss = F.mse_loss(state_value, target_value)
    self.value_network_optimizer.zero_grad()
    value_loss.backward()
    self.value_network_optimizer.step()

    #
    # Second, optimize Q networks; minimizing MSE between
    # Q(s, a) & r + discount * V'(next_s)
    #

    with torch.no_grad():
        next_state_value = (
            self.value_network_target(learning_input.next_state.float_features)
            * not_done_mask
        )

        if self.minibatch < self.reward_burnin:
            target_q_value = reward
        else:
            target_q_value = reward + discount * next_state_value

    q1_loss = F.mse_loss(q1_value, target_q_value)
    self.q1_network_optimizer.zero_grad()
    q1_loss.backward()
    self.q1_network_optimizer.step()
    if self.q2_network:
        q2_loss = F.mse_loss(q2_value, target_q_value)
        self.q2_network_optimizer.zero_grad()
        q2_loss.backward()
        self.q2_network_optimizer.step()

    #
    # Lastly, optimize the actor; minimizing KL-divergence between action propensity
    # & softmax of value. Due to reparameterization trick, it ends up being
    # log_prob(actor_action) - Q(s, actor_action)
    #

    actor_output = self.actor_network(rlt.StateInput(state=learning_input.state))

    state_actor_action = rlt.StateAction(
        state=s, action=rlt.FeatureVector(float_features=actor_output.action)
    )
    q1_actor_value = self.q1_network(state_actor_action).q_value
    min_q_actor_value = q1_actor_value
    if self.q2_network:
        q2_actor_value = self.q2_network(state_actor_action).q_value
        min_q_actor_value = torch.min(q1_actor_value, q2_actor_value)

    actor_loss = torch.mean(
        self.entropy_temperature * actor_output.log_prob - min_q_actor_value
    )
    self.actor_network_optimizer.zero_grad()
    actor_loss.backward()
    self.actor_network_optimizer.step()

    if self.minibatch < self.reward_burnin:
        # Reward burnin: force target network
        self._soft_update(self.value_network, self.value_network_target, 1.0)
    else:
        # Use the soft update rule to update both target networks
        self._soft_update(self.value_network, self.value_network_target, self.tau)

    if evaluator is not None:
        # FIXME
        self.evaluate(evaluator)

def train(self, training_batch) -> None:
    """
    IMPORTANT: the input action here is assumed to be preprocessed to match the
    range of the output of the actor.
    """
    if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
        training_batch = training_batch.as_parametric_sarsa_training_batch()

    learning_input = training_batch.training_input
    self.minibatch += 1

    state = learning_input.state
    action = learning_input.action
    reward = learning_input.reward
    discount = torch.full_like(reward, self.gamma)
    not_done_mask = learning_input.not_terminal

    if self._should_scale_action_in_train():
        action = rlt.FeatureVector(
            rescale_torch_tensor(
                action.float_features,
                new_min=self.min_action_range_tensor_training,
                new_max=self.max_action_range_tensor_training,
                prev_min=self.min_action_range_tensor_serving,
                prev_max=self.max_action_range_tensor_serving,
            )
        )

    with torch.enable_grad():
        #
        # First, optimize Q networks; minimizing MSE between
        # Q(s, a) & r + discount * V'(next_s)
        #

        current_state_action = rlt.StateAction(state=state, action=action)
        q1_value = self.q1_network(current_state_action).q_value
        if self.q2_network:
            q2_value = self.q2_network(current_state_action).q_value
        actor_output = self.actor_network(rlt.StateInput(state=state))

        # Optimize Alpha
        if self.alpha_optimizer is not None:
            alpha_loss = -(
                self.log_alpha
                * (actor_output.log_prob + self.target_entropy).detach()
            ).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            self.entropy_temperature = self.log_alpha.exp()

        with torch.no_grad():
            if self.value_network is not None:
                next_state_value = self.value_network_target(
                    learning_input.next_state.float_features
                )
            else:
                next_state_actor_output = self.actor_network(
                    rlt.StateInput(state=learning_input.next_state)
                )
                next_state_actor_action = rlt.StateAction(
                    state=learning_input.next_state,
                    action=rlt.FeatureVector(
                        float_features=next_state_actor_output.action
                    ),
                )
                next_state_value = self.q1_network_target(
                    next_state_actor_action
                ).q_value

                if self.q2_network is not None:
                    target_q2_value = self.q2_network_target(
                        next_state_actor_action
                    ).q_value
                    next_state_value = torch.min(next_state_value, target_q2_value)

                log_prob_a = self.actor_network.get_log_prob(
                    learning_input.next_state, next_state_actor_output.action
                )
                log_prob_a = log_prob_a.clamp(-20.0, 20.0)
                next_state_value -= self.entropy_temperature * log_prob_a

            target_q_value = (
                reward + discount * next_state_value * not_done_mask.float()
            )

        q1_loss = F.mse_loss(q1_value, target_q_value)
        q1_loss.backward()
        self._maybe_run_optimizer(self.q1_network_optimizer, self.minibatches_per_step)
        if self.q2_network:
            q2_loss = F.mse_loss(q2_value, target_q_value)
            q2_loss.backward()
            self._maybe_run_optimizer(
                self.q2_network_optimizer, self.minibatches_per_step
            )

        #
        # Second, optimize the actor; minimizing KL-divergence between action propensity
        # & softmax of value. Due to reparameterization trick, it ends up being
        # log_prob(actor_action) - Q(s, actor_action)
        #

        state_actor_action = rlt.StateAction(
            state=state,
            action=rlt.FeatureVector(float_features=actor_output.action),
        )
        q1_actor_value = self.q1_network(state_actor_action).q_value
        min_q_actor_value = q1_actor_value
        if self.q2_network:
            q2_actor_value = self.q2_network(state_actor_action).q_value
            min_q_actor_value = torch.min(q1_actor_value, q2_actor_value)

        actor_loss = (
            self.entropy_temperature * actor_output.log_prob - min_q_actor_value
        )
        # Do this in 2 steps so we can log histogram of actor loss
        actor_loss_mean = actor_loss.mean()
        actor_loss_mean.backward()
        self._maybe_run_optimizer(
            self.actor_network_optimizer, self.minibatches_per_step
        )

        #
        # Lastly, if applicable, optimize value network; minimizing MSE between
        # V(s) & E_a~pi(s) [ Q(s,a) - log(pi(a|s)) ]
        #

        if self.value_network is not None:
            state_value = self.value_network(state.float_features)

            if self.logged_action_uniform_prior:
                log_prob_a = torch.zeros_like(min_q_actor_value)
                target_value = min_q_actor_value
            else:
                with torch.no_grad():
                    log_prob_a = actor_output.log_prob
                    log_prob_a = log_prob_a.clamp(-20.0, 20.0)
                    target_value = (
                        min_q_actor_value - self.entropy_temperature * log_prob_a
                    )

            value_loss = F.mse_loss(state_value, target_value.detach())
            value_loss.backward()
            self._maybe_run_optimizer(
                self.value_network_optimizer, self.minibatches_per_step
            )

    # Use the soft update rule to update the target networks
    if self.value_network is not None:
        self._maybe_soft_update(
            self.value_network,
            self.value_network_target,
            self.tau,
            self.minibatches_per_step,
        )
    else:
        self._maybe_soft_update(
            self.q1_network,
            self.q1_network_target,
            self.tau,
            self.minibatches_per_step,
        )
        if self.q2_network is not None:
            self._maybe_soft_update(
                self.q2_network,
                self.q2_network_target,
                self.tau,
                self.minibatches_per_step,
            )

    # Logging at the end to schedule all the cuda operations first
    if (
        self.tensorboard_logging_freq is not None
        and self.minibatch % self.tensorboard_logging_freq == 0
    ):
        SummaryWriterContext.add_histogram("q1/logged_state_value", q1_value)
        if self.q2_network:
            SummaryWriterContext.add_histogram("q2/logged_state_value", q2_value)
        SummaryWriterContext.add_histogram("log_prob_a", log_prob_a)
        if self.value_network:
            SummaryWriterContext.add_histogram("value_network/target", target_value)
        SummaryWriterContext.add_histogram("q_network/next_state_value", next_state_value)
        SummaryWriterContext.add_histogram("q_network/target_q_value", target_q_value)
        SummaryWriterContext.add_histogram("actor/min_q_actor_value", min_q_actor_value)
        SummaryWriterContext.add_histogram("actor/action_log_prob", actor_output.log_prob)
        SummaryWriterContext.add_histogram("actor/loss", actor_loss)

    self.loss_reporter.report(
        td_loss=float(q1_loss),
        reward_loss=None,
        logged_rewards=reward,
        model_values_on_logged_actions=q1_value,
        model_propensities=actor_output.log_prob.exp(),
        model_values=min_q_actor_value,
    )

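# Hedged sketch of the soft (Polyak) target update behind _soft_update /
# _maybe_soft_update in the trainers above (the rule, not necessarily the exact
# source): theta_target <- tau * theta + (1 - tau) * theta_target.
import torch

def soft_update_sketch(
    network: torch.nn.Module, target_network: torch.nn.Module, tau: float
) -> None:
    with torch.no_grad():
        for param, target_param in zip(network.parameters(), target_network.parameters()):
            # Blend each target parameter toward the online parameter.
            target_param.mul_(1.0 - tau).add_(tau * param)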