def train_step_gen(self, training_batch: rlt.DiscreteDqnInput, batch_idx: int):
    # TODO: calls to _maybe_run_optimizer removed, should be replaced with Trainer parameter
    assert isinstance(training_batch, rlt.DiscreteDqnInput)

    rewards = self.boost_rewards(training_batch.reward, training_batch.action)
    not_done_mask = training_batch.not_terminal.float()
    assert not_done_mask.dim() == 2

    discount_tensor = self.compute_discount_tensor(training_batch, rewards)
    td_loss = self.compute_td_loss(training_batch, rewards, discount_tensor)
    yield td_loss
    td_loss = td_loss.detach()

    # Get Q-values of next states, used in computing cpe
    all_next_action_scores = self.q_network(training_batch.next_state).detach()

    logged_action_idxs = torch.argmax(training_batch.action, dim=1, keepdim=True)
    yield from self._calculate_cpes(
        training_batch,
        training_batch.state,
        training_batch.next_state,
        # pyre-fixme[16]: `DQNTrainer` has no attribute `all_action_scores`.
        self.all_action_scores,
        all_next_action_scores,
        logged_action_idxs,
        discount_tensor,
        not_done_mask,
    )

    if self.maxq_learning:
        possible_actions_mask = training_batch.possible_actions_mask
        if self.bcq:
            action_on_policy = get_valid_actions_from_imitator(
                self.bcq_imitator, training_batch.state, self.bcq_drop_threshold
            )
            possible_actions_mask *= action_on_policy

    # Do we ever use model_action_idxs computed below?
    model_action_idxs = self.get_max_q_values(
        self.all_action_scores,
        possible_actions_mask if self.maxq_learning else training_batch.action,
    )[1]

    self._log_dqn(
        td_loss, logged_action_idxs, training_batch, rewards, model_action_idxs
    )

    # Use the soft update rule to update target network
    yield self.soft_update_result()
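# compute_discount_tensor is not shown in this section. A minimal sketch of
# what it plausibly computes, inferred from the inline discounting logic in
# train() and the legacy train_step_gen() below (an assumption, not the
# verbatim implementation):
def compute_discount_tensor(
    self, batch: rlt.DiscreteDqnInput, boosted_rewards: torch.Tensor
):
    # Default: a constant discount gamma for every transition
    discount_tensor = torch.full_like(boosted_rewards, self.gamma)
    if self.use_seq_num_diff_as_time_diff:
        # Discount by elapsed time rather than by step count
        assert self.multi_steps is None
        discount_tensor = torch.pow(self.gamma, batch.time_diff.float())
    if self.multi_steps is not None:
        # Multi-step returns: discount by the number of steps spanned
        assert batch.step is not None
        discount_tensor = torch.pow(self.gamma, batch.step.float())
    return discount_tensor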
def compute_td_loss(
    self,
    batch: rlt.DiscreteDqnInput,
    boosted_rewards: torch.Tensor,
    discount_tensor: torch.Tensor,
):
    not_done_mask = batch.not_terminal.float()
    all_next_q_values, all_next_q_values_target = self.get_detached_model_outputs(
        batch.next_state
    )

    if self.maxq_learning:
        # Compute max_a' Q(s', a') over all possible actions using target network
        possible_next_actions_mask = batch.possible_next_actions_mask.float()
        if self.bcq:
            action_on_policy = get_valid_actions_from_imitator(
                self.bcq_imitator,
                batch.next_state,
                self.bcq_drop_threshold,
            )
            possible_next_actions_mask *= action_on_policy
        next_q_values, max_q_action_idxs = self.get_max_q_values_with_target(
            all_next_q_values,
            all_next_q_values_target,
            possible_next_actions_mask,
        )
    else:
        # SARSA
        next_q_values, max_q_action_idxs = self.get_max_q_values_with_target(
            all_next_q_values,
            all_next_q_values_target,
            batch.next_action,
        )

    filtered_next_q_vals = next_q_values * not_done_mask
    target_q_values = boosted_rewards + (discount_tensor * filtered_next_q_vals)

    # Get Q-value of action taken
    all_q_values = self.q_network(batch.state)
    # pyre-fixme[16]: `DQNTrainer` has no attribute `all_action_scores`.
    self.all_action_scores = all_q_values.detach()
    q_values = torch.sum(all_q_values * batch.action, 1, keepdim=True)
    td_loss = self.q_network_loss(q_values, target_q_values.detach())
    return td_loss
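# get_max_q_values_with_target is defined on the trainer base class and not
# shown here. A minimal sketch of the double Q-learning behavior it is assumed
# to implement: actions are *selected* with the online network but *scored*
# with the target network, decoupling selection from evaluation to curb
# Q-value overestimation (a sketch, not the verbatim code):
def get_max_q_values_with_target(
    self, q_values, q_values_target, possible_actions_mask
):
    # Push Q-values of impossible actions to a very large negative number
    impossible_action_penalty = (1.0 - possible_actions_mask) * -1e10
    q_values = q_values + impossible_action_penalty
    # Select the greedy action with the online network ...
    max_q_action_idxs = torch.argmax(q_values, dim=1, keepdim=True)
    # ... but evaluate it with the target network
    max_q_values = torch.gather(q_values_target, 1, max_q_action_idxs)
    return max_q_values, max_q_action_idxs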
def train(self, training_batch: rlt.DiscreteDqnInput):
    if isinstance(training_batch, TrainingDataPage):
        training_batch = training_batch.as_discrete_maxq_training_batch()
    assert isinstance(training_batch, rlt.DiscreteDqnInput)
    boosted_rewards = self.boost_rewards(
        training_batch.reward, training_batch.action
    )

    self.minibatch += 1
    rewards = boosted_rewards
    discount_tensor = torch.full_like(rewards, self.gamma)
    not_done_mask = training_batch.not_terminal.float()
    assert not_done_mask.dim() == 2

    if self.use_seq_num_diff_as_time_diff:
        assert self.multi_steps is None
        discount_tensor = torch.pow(self.gamma, training_batch.time_diff.float())
    if self.multi_steps is not None:
        assert training_batch.step is not None
        # pyre-fixme[16]: `Optional` has no attribute `float`.
        discount_tensor = torch.pow(self.gamma, training_batch.step.float())

    all_next_q_values, all_next_q_values_target = self.get_detached_q_values(
        training_batch.next_state
    )

    if self.maxq_learning:
        # Compute max_a' Q(s', a') over all possible actions using target network
        possible_next_actions_mask = (
            training_batch.possible_next_actions_mask.float()
        )
        if self.bcq:
            action_on_policy = get_valid_actions_from_imitator(
                self.bcq_imitator,
                training_batch.next_state,
                self.bcq_drop_threshold,
            )
            possible_next_actions_mask *= action_on_policy
        next_q_values, max_q_action_idxs = self.get_max_q_values_with_target(
            all_next_q_values, all_next_q_values_target, possible_next_actions_mask
        )
    else:
        # SARSA
        next_q_values, max_q_action_idxs = self.get_max_q_values_with_target(
            all_next_q_values, all_next_q_values_target, training_batch.next_action
        )

    filtered_next_q_vals = next_q_values * not_done_mask
    target_q_values = rewards + (discount_tensor * filtered_next_q_vals)

    with torch.enable_grad():
        # Get Q-value of action taken
        all_q_values = self.q_network(training_batch.state)
        # pyre-fixme[16]: `DQNTrainer` has no attribute `all_action_scores`.
        self.all_action_scores = all_q_values.detach()
        q_values = torch.sum(all_q_values * training_batch.action, 1, keepdim=True)

        loss = self.q_network_loss(q_values, target_q_values)
        # pyre-fixme[16]: `DQNTrainer` has no attribute `loss`.
        self.loss = loss.detach()

        loss.backward()
        self._maybe_run_optimizer(
            self.q_network_optimizer, self.minibatches_per_step
        )

    # Use the soft update rule to update target network
    self._maybe_soft_update(
        self.q_network, self.q_network_target, self.tau, self.minibatches_per_step
    )

    # Get Q-values of next states, used in computing cpe
    all_next_action_scores = self.q_network(training_batch.next_state).detach()

    logged_action_idxs = torch.argmax(training_batch.action, dim=1, keepdim=True)
    reward_loss, model_rewards, model_propensities = self._calculate_cpes(
        training_batch,
        training_batch.state,
        training_batch.next_state,
        self.all_action_scores,
        all_next_action_scores,
        logged_action_idxs,
        discount_tensor,
        not_done_mask,
    )

    if self.maxq_learning:
        possible_actions_mask = training_batch.possible_actions_mask
        if self.bcq:
            action_on_policy = get_valid_actions_from_imitator(
                self.bcq_imitator, training_batch.state, self.bcq_drop_threshold
            )
            possible_actions_mask *= action_on_policy

    model_action_idxs = self.get_max_q_values(
        self.all_action_scores,
        possible_actions_mask if self.maxq_learning else training_batch.action,
    )[1]

    # pyre-fixme[16]: `DQNTrainer` has no attribute `notify_observers`.
    self.notify_observers(
        td_loss=self.loss,
        reward_loss=reward_loss,
        logged_actions=logged_action_idxs,
        logged_propensities=training_batch.extras.action_probability,
        logged_rewards=rewards,
        model_propensities=model_propensities,
        model_rewards=model_rewards,
        model_values=self.all_action_scores,
        model_action_idxs=model_action_idxs,
    )

    self.loss_reporter.report(
        td_loss=self.loss,
        reward_loss=reward_loss,
        logged_actions=logged_action_idxs,
        logged_propensities=training_batch.extras.action_probability,
        logged_rewards=rewards,
        logged_values=None,  # Compute at end of each epoch for CPE
        model_propensities=model_propensities,
        model_rewards=model_rewards,
        model_values=self.all_action_scores,
        model_values_on_logged_actions=None,  # Compute at end of each epoch for CPE
        model_action_idxs=model_action_idxs,
    )
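# _maybe_run_optimizer and _maybe_soft_update are base-class helpers not shown
# in this section. A minimal sketch of their assumed semantics: gradients are
# accumulated across minibatches and applied (and the target network softly
# updated) only once every minibatches_per_step calls:
def _maybe_run_optimizer(self, optimizer, minibatches_per_step):
    if self.minibatch % minibatches_per_step == 0:
        optimizer.step()
        optimizer.zero_grad()

def _maybe_soft_update(self, network, target_network, tau, minibatches_per_step):
    if self.minibatch % minibatches_per_step != 0:
        return
    # Polyak averaging: theta_target <- tau * theta + (1 - tau) * theta_target
    for target_param, param in zip(
        target_network.parameters(), network.parameters()
    ):
        target_param.data.copy_(
            tau * param.data + (1.0 - tau) * target_param.data
        )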
def train_step_gen(self, training_batch: rlt.DiscreteDqnInput, batch_idx: int):
    # TODO: calls to _maybe_run_optimizer removed, should be replaced with Trainer parameter
    assert isinstance(training_batch, rlt.DiscreteDqnInput)
    boosted_rewards = self.boost_rewards(
        training_batch.reward, training_batch.action
    )

    rewards = boosted_rewards
    discount_tensor = torch.full_like(rewards, self.gamma)
    not_done_mask = training_batch.not_terminal.float()
    assert not_done_mask.dim() == 2

    if self.use_seq_num_diff_as_time_diff:
        assert self.multi_steps is None
        discount_tensor = torch.pow(self.gamma, training_batch.time_diff.float())
    if self.multi_steps is not None:
        assert training_batch.step is not None
        # pyre-fixme[16]: `Optional` has no attribute `float`.
        discount_tensor = torch.pow(self.gamma, training_batch.step.float())

    all_next_q_values, all_next_q_values_target = self.get_detached_q_values(
        training_batch.next_state
    )

    if self.maxq_learning:
        # Compute max_a' Q(s', a') over all possible actions using target network
        possible_next_actions_mask = (
            training_batch.possible_next_actions_mask.float()
        )
        if self.bcq:
            action_on_policy = get_valid_actions_from_imitator(
                self.bcq_imitator,
                training_batch.next_state,
                self.bcq_drop_threshold,
            )
            possible_next_actions_mask *= action_on_policy
        next_q_values, max_q_action_idxs = self.get_max_q_values_with_target(
            all_next_q_values,
            all_next_q_values_target,
            possible_next_actions_mask,
        )
    else:
        # SARSA
        next_q_values, max_q_action_idxs = self.get_max_q_values_with_target(
            all_next_q_values,
            all_next_q_values_target,
            training_batch.next_action,
        )

    filtered_next_q_vals = next_q_values * not_done_mask
    target_q_values = rewards + (discount_tensor * filtered_next_q_vals)

    # Get Q-value of action taken
    all_q_values = self.q_network(training_batch.state)
    # pyre-fixme[16]: `DQNTrainer` has no attribute `all_action_scores`.
    self.all_action_scores = all_q_values.detach()
    q_values = torch.sum(all_q_values * training_batch.action, 1, keepdim=True)

    loss = self.q_network_loss(q_values, target_q_values)
    # pyre-fixme[16]: `DQNTrainer` has no attribute `loss`.
    self.loss = loss.detach()
    yield loss

    # Get Q-values of next states, used in computing cpe
    all_next_action_scores = self.q_network(training_batch.next_state).detach()

    logged_action_idxs = torch.argmax(training_batch.action, dim=1, keepdim=True)
    yield from self._calculate_cpes(
        training_batch,
        training_batch.state,
        training_batch.next_state,
        self.all_action_scores,
        all_next_action_scores,
        logged_action_idxs,
        discount_tensor,
        not_done_mask,
    )

    if self.maxq_learning:
        possible_actions_mask = training_batch.possible_actions_mask
        if self.bcq:
            action_on_policy = get_valid_actions_from_imitator(
                self.bcq_imitator, training_batch.state, self.bcq_drop_threshold
            )
            possible_actions_mask *= action_on_policy

    # Do we ever use model_action_idxs computed below?
    model_action_idxs = self.get_max_q_values(
        self.all_action_scores,
        possible_actions_mask if self.maxq_learning else training_batch.action,
    )[1]

    self.reporter.log(
        td_loss=self.loss,
        logged_actions=logged_action_idxs,
        logged_propensities=training_batch.extras.action_probability,
        logged_rewards=rewards,
        logged_values=None,  # Compute at end of each epoch for CPE
        model_values=self.all_action_scores,
        model_values_on_logged_actions=None,  # Compute at end of each epoch for CPE
        model_action_idxs=model_action_idxs,
    )

    # Use the soft update rule to update target network
    yield self.soft_update_result()
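# train_step_gen yields one loss per configured optimizer: the TD loss, the
# CPE losses from _calculate_cpes, and finally soft_update_result(), whose
# paired optimizer is assumed to perform the target-network soft update on
# step(). A hypothetical driver loop sketching how a trainer might consume
# the generator, one optimizer step per yielded loss:
def run_training_step(trainer, batch, batch_idx, optimizers):
    for loss, optimizer in zip(
        trainer.train_step_gen(batch, batch_idx), optimizers
    ):
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()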