class Seq2RewardTrainer(ReAgentLightningModule):
    """Trainer for Seq2Reward"""

    def __init__(
        self,
        seq2reward_network: Seq2RewardNetwork,
        params: Seq2RewardTrainerParameters,
    ):
        super().__init__()
        self.seq2reward_network = seq2reward_network
        self.params = params

        # Turning off Q value output during training:
        self.view_q_value = params.view_q_value
        # permutations used to do planning
        self.all_permut = gen_permutations(
            params.multi_steps, len(self.params.action_names)
        )
        self.mse_loss = nn.MSELoss(reduction="mean")

        # Predict how many steps are remaining from the current step
        self.step_predict_network = FullyConnectedNetwork(
            [
                self.seq2reward_network.state_dim,
                self.params.step_predict_net_size,
                self.params.step_predict_net_size,
                self.params.multi_steps,
            ],
            ["relu", "relu", "linear"],
            use_layer_norm=False,
        )
        self.step_loss = nn.CrossEntropyLoss(reduction="mean")

    def configure_optimizers(self):
        optimizers = []
        optimizers.append(
            {
                "optimizer": torch.optim.Adam(
                    self.seq2reward_network.parameters(),
                    lr=self.params.learning_rate,
                )
            }
        )
        optimizers.append(
            {
                "optimizer": torch.optim.Adam(
                    self.step_predict_network.parameters(),
                    lr=self.params.learning_rate,
                )
            }
        )
        return optimizers

    def train_step_gen(self, training_batch: rlt.MemoryNetworkInput, batch_idx: int):
        mse_loss = self.get_mse_loss(training_batch)
        detached_mse_loss = mse_loss.cpu().detach().item()
        yield mse_loss

        step_entropy_loss = self.get_step_entropy_loss(training_batch)
        detached_step_entropy_loss = step_entropy_loss.cpu().detach().item()

        if self.view_q_value:
            state_first_step = training_batch.state.float_features[0]
            q_values = (
                get_Q(
                    self.seq2reward_network,
                    state_first_step,
                    self.all_permut,
                )
                .cpu()
                .mean(0)
                .tolist()
            )
        else:
            q_values = [0] * len(self.params.action_names)

        step_probability = (
            get_step_prediction(self.step_predict_network, training_batch)
            .cpu()
            .mean(dim=0)
            .numpy()
        )
        logger.info(
            f"Seq2Reward trainer output: mse_loss={detached_mse_loss}, "
            f"step_entropy_loss={detached_step_entropy_loss}, q_values={q_values}, "
            f"step_probability={step_probability}"
        )
        self.reporter.log(
            mse_loss=detached_mse_loss,
            step_entropy_loss=detached_step_entropy_loss,
            q_values=[q_values],
        )
        yield step_entropy_loss

    # pyre-ignore inconsistent override because lightning doesn't use types
    def validation_step(self, batch: rlt.MemoryNetworkInput, batch_idx: int):
        detached_mse_loss = self.get_mse_loss(batch).cpu().detach().item()
        detached_step_entropy_loss = (
            self.get_step_entropy_loss(batch).cpu().detach().item()
        )
        state_first_step = batch.state.float_features[0]
        # shape: batch_size, action_dim
        q_values_all_action_all_data = get_Q(
            self.seq2reward_network,
            state_first_step,
            self.all_permut,
        ).cpu()
        q_values = q_values_all_action_all_data.mean(0).tolist()

        action_distribution = torch.bincount(
            torch.argmax(q_values_all_action_all_data, dim=1),
            minlength=len(self.params.action_names),
        )
        # normalize
        action_distribution = (
            action_distribution.float() / torch.sum(action_distribution)
        ).tolist()

        self.reporter.log(
            eval_mse_loss=detached_mse_loss,
            eval_step_entropy_loss=detached_step_entropy_loss,
            eval_q_values=[q_values],
            eval_action_distribution=[action_distribution],
        )
        return (
            detached_mse_loss,
            detached_step_entropy_loss,
            q_values,
            action_distribution,
        )

    def get_mse_loss(self, training_batch: rlt.MemoryNetworkInput):
        """
        Compute losses:
            MSE(predicted_acc_reward, target_acc_reward)

        :param training_batch:
            training_batch has these fields:
            - state: (SEQ_LEN, BATCH_SIZE, STATE_DIM) torch tensor
            - action: (SEQ_LEN, BATCH_SIZE, ACTION_DIM) torch tensor
            - reward: (SEQ_LEN, BATCH_SIZE) torch tensor

        :returns: mse loss on reward
        """
        # pyre-fixme[16]: Optional type has no attribute `flatten`.
        valid_step = training_batch.valid_step.flatten()

        seq2reward_output = self.seq2reward_network(
            training_batch.state,
            rlt.FeatureData(training_batch.action),
            valid_step,
        )
        predicted_acc_reward = seq2reward_output.acc_reward

        seq_len, batch_size = training_batch.reward.size()
        gamma = self.params.gamma
        gamma_mask = (
            torch.Tensor(
                [[gamma ** i for i in range(seq_len)] for _ in range(batch_size)]
            )
            .transpose(0, 1)
            .to(training_batch.reward.device)
        )

        target_acc_rewards = torch.cumsum(training_batch.reward * gamma_mask, dim=0)
        target_acc_reward = target_acc_rewards[
            valid_step - 1, torch.arange(batch_size)
        ].unsqueeze(1)

        # make sure the prediction and target tensors have the same size
        # the size should both be (BATCH_SIZE, 1) in this case.
        assert (
            predicted_acc_reward.size() == target_acc_reward.size()
        ), f"{predicted_acc_reward.size()}!={target_acc_reward.size()}"
        return self.mse_loss(predicted_acc_reward, target_acc_reward)

    def get_step_entropy_loss(self, training_batch: rlt.MemoryNetworkInput):
        """
        Compute cross-entropy losses of step predictions

        :param training_batch:
            training_batch has these fields:
            - state: (SEQ_LEN, BATCH_SIZE, STATE_DIM) torch tensor
            - action: (SEQ_LEN, BATCH_SIZE, ACTION_DIM) torch tensor
            - reward: (SEQ_LEN, BATCH_SIZE) torch tensor

        :returns: step_entropy_loss on step prediction
        """
        # pyre-fixme[16]: Optional type has no attribute `flatten`.
        valid_step = training_batch.valid_step.flatten()

        first_step_state = training_batch.state.float_features[0]
        valid_step_output = self.step_predict_network(first_step_state)

        # step loss's target is zero-based indexed, so subtract 1 from valid_step
        return self.step_loss(valid_step_output, valid_step - 1)

    def warm_start_components(self):
        components = ["seq2reward_network"]
        return components
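
# Illustrative sketch (not part of the trainers in this file): how `get_mse_loss`
# builds its regression target. Rewards of shape (SEQ_LEN, BATCH_SIZE) are
# discounted by gamma**t, cumulatively summed over time, and the running sum at
# row `valid_step - 1` is read out per sequence, giving a (BATCH_SIZE, 1) target.
# The helper name below is hypothetical, and it assumes `torch` is imported at
# module level, as the classes here already require.
def _example_discounted_target(
    reward: torch.Tensor, valid_step: torch.Tensor, gamma: float
) -> torch.Tensor:
    seq_len, batch_size = reward.size()
    # (SEQ_LEN, 1) column of discount factors, broadcast across the batch
    discount = (gamma ** torch.arange(seq_len, dtype=torch.float32)).unsqueeze(1)
    # running discounted return per sequence, shape (SEQ_LEN, BATCH_SIZE)
    acc_reward = torch.cumsum(reward * discount, dim=0)
    # read out the return accumulated up to each sequence's last valid step
    return acc_reward[valid_step - 1, torch.arange(batch_size)].unsqueeze(1)


# For example, with gamma=0.5, a reward column [1, 1, 1], and valid_step=3,
# the target for that sequence is 1 + 0.5 + 0.25 = 1.75.
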
class Seq2RewardTrainer(Trainer):
    """Trainer for Seq2Reward"""

    def __init__(
        self,
        seq2reward_network: Seq2RewardNetwork,
        params: Seq2RewardTrainerParameters,
    ):
        self.seq2reward_network = seq2reward_network
        self.params = params
        self.mse_optimizer = torch.optim.Adam(
            self.seq2reward_network.parameters(), lr=params.learning_rate
        )
        self.minibatch_size = self.params.batch_size
        self.loss_reporter = NoOpLossReporter()

        # PageHandler must use this to activate evaluator:
        self.calc_cpe_in_training = True
        # Turning off Q value output during training:
        self.view_q_value = params.view_q_value
        # permutations used to do planning
        self.all_permut = gen_permutations(
            params.multi_steps, len(self.params.action_names)
        )
        self.mse_loss = nn.MSELoss(reduction="mean")

        # Predict how many steps are remaining from the current step
        self.step_predict_network = FullyConnectedNetwork(
            [
                self.seq2reward_network.state_dim,
                self.params.step_predict_net_size,
                self.params.step_predict_net_size,
                self.params.multi_steps,
            ],
            ["relu", "relu", "linear"],
            use_layer_norm=False,
        )
        self.step_loss = nn.CrossEntropyLoss(reduction="mean")
        self.step_optimizer = torch.optim.Adam(
            self.step_predict_network.parameters(), lr=params.learning_rate
        )

    def train(self, training_batch: rlt.MemoryNetworkInput):
        mse_loss, step_entropy_loss = self.get_loss(training_batch)

        self.mse_optimizer.zero_grad()
        mse_loss.backward()
        self.mse_optimizer.step()

        self.step_optimizer.zero_grad()
        step_entropy_loss.backward()
        self.step_optimizer.step()

        detached_mse_loss = mse_loss.cpu().detach().item()
        detached_step_entropy_loss = step_entropy_loss.cpu().detach().item()

        if self.view_q_value:
            state_first_step = training_batch.state.float_features[0]
            q_values = (
                get_Q(
                    self.seq2reward_network,
                    state_first_step,
                    self.all_permut,
                )
                .cpu()
                .mean(0)
                .tolist()
            )
        else:
            q_values = [0] * len(self.params.action_names)

        step_probability = (
            get_step_prediction(self.step_predict_network, training_batch)
            .cpu()
            .mean(dim=0)
            .numpy()
        )

        logger.info(
            f"Seq2Reward trainer output: mse_loss={detached_mse_loss}, "
            f"step_entropy_loss={detached_step_entropy_loss}, q_values={q_values}, "
            f"step_probability={step_probability}"
        )
        # pyre-fixme[16]: `Seq2RewardTrainer` has no attribute `notify_observers`.
        self.notify_observers(
            mse_loss=detached_mse_loss,
            step_entropy_loss=detached_step_entropy_loss,
            q_values=[q_values],
        )
        return (detached_mse_loss, detached_step_entropy_loss, q_values)

    def get_loss(self, training_batch: rlt.MemoryNetworkInput):
        """
        Compute losses:
            MSE(predicted_acc_reward, target_acc_reward)

        :param training_batch:
            training_batch has these fields:
            - state: (SEQ_LEN, BATCH_SIZE, STATE_DIM) torch tensor
            - action: (SEQ_LEN, BATCH_SIZE, ACTION_DIM) torch tensor
            - reward: (SEQ_LEN, BATCH_SIZE) torch tensor

        :returns:
            mse loss on reward
            step_entropy_loss on step prediction
        """
        # pyre-fixme[16]: Optional type has no attribute `flatten`.
        valid_reward_len = training_batch.valid_next_seq_len.flatten()

        first_step_state = training_batch.state.float_features[0]
        valid_reward_len_output = self.step_predict_network(first_step_state)
        step_entropy_loss = self.step_loss(
            valid_reward_len_output, valid_reward_len - 1
        )

        seq2reward_output = self.seq2reward_network(
            training_batch.state,
            rlt.FeatureData(training_batch.action),
            valid_reward_len,
        )
        predicted_acc_reward = seq2reward_output.acc_reward

        seq_len, batch_size = training_batch.reward.size()
        gamma = self.params.gamma
        gamma_mask = (
            torch.Tensor(
                [[gamma ** i for i in range(seq_len)] for _ in range(batch_size)]
            )
            .transpose(0, 1)
            .to(training_batch.reward.device)
        )

        target_acc_rewards = torch.cumsum(training_batch.reward * gamma_mask, dim=0)
        target_acc_reward = target_acc_rewards[
            valid_reward_len - 1, torch.arange(batch_size)
        ].unsqueeze(1)

        # make sure the prediction and target tensors have the same size
        # the size should both be (BATCH_SIZE, 1) in this case.
        assert (
            predicted_acc_reward.size() == target_acc_reward.size()
        ), f"{predicted_acc_reward.size()}!={target_acc_reward.size()}"
        mse = self.mse_loss(predicted_acc_reward, target_acc_reward)
        return mse, step_entropy_loss

    def warm_start_components(self):
        components = ["seq2reward_network"]
        return components
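
# Illustrative sketch (not part of the trainers in this file): the step-prediction
# objective shared by both classes above. `step_predict_network` maps the
# first-step state to `multi_steps` logits, and cross entropy is taken against the
# zero-based index of each sequence's last valid step (`valid_step - 1` in the
# Lightning trainer, `valid_reward_len - 1` in the Trainer-based variant). The
# helper name is hypothetical; it assumes `torch` and `torch.nn as nn` are
# imported at module level, as the classes here already require.
def _example_step_entropy_loss(
    logits: torch.Tensor, valid_step: torch.Tensor
) -> torch.Tensor:
    # logits: (BATCH_SIZE, multi_steps) float tensor
    # valid_step: (BATCH_SIZE,) long tensor with values in [1, multi_steps]
    return nn.CrossEntropyLoss(reduction="mean")(logits, valid_step - 1)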