def __init__(
    self,
    seq2reward_network: Seq2RewardNetwork,
    params: Seq2RewardTrainerParameters,
):
    super().__init__()
    self.seq2reward_network = seq2reward_network
    self.params = params

    # Turning off Q value output during training:
    self.view_q_value = params.view_q_value
    # permutations used to do planning
    self.all_permut = gen_permutations(
        params.multi_steps, len(self.params.action_names)
    )
    self.mse_loss = nn.MSELoss(reduction="mean")

    # Predict how many steps are remaining from the current step
    self.step_predict_network = FullyConnectedNetwork(
        [
            self.seq2reward_network.state_dim,
            self.params.step_predict_net_size,
            self.params.step_predict_net_size,
            self.params.multi_steps,
        ],
        ["relu", "relu", "linear"],
        use_layer_norm=False,
    )
    self.step_loss = nn.CrossEntropyLoss(reduction="mean")
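# --- Illustrative sketch (not ReAgent code): how a step-prediction head with `multi_steps`
# output logits pairs with CrossEntropyLoss to predict how many steps remain. All sizes and
# the plain nn.Sequential stand-in for FullyConnectedNetwork are assumptions for illustration.
import torch
import torch.nn as nn

state_dim, net_size, multi_steps, batch_size = 4, 64, 3, 8
step_predict_net = nn.Sequential(
    nn.Linear(state_dim, net_size), nn.ReLU(),
    nn.Linear(net_size, net_size), nn.ReLU(),
    nn.Linear(net_size, multi_steps),  # one logit per possible remaining-step count
)
step_loss = nn.CrossEntropyLoss(reduction="mean")

first_step_state = torch.randn(batch_size, state_dim)               # state at the current step
remaining_steps = torch.randint(1, multi_steps + 1, (batch_size,))  # ground truth in [1, multi_steps]
loss = step_loss(step_predict_net(first_step_state), remaining_steps - 1)  # class k <=> k + 1 steps left
loss.backward()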
def test_get_Q(self):
    NUM_ACTION = 2
    MULTI_STEPS = 3
    BATCH_SIZE = 2
    STATE_DIM = 4
    all_permut = gen_permutations(MULTI_STEPS, NUM_ACTION)
    seq2reward_network = FakeSeq2RewardNetwork()
    batch = rlt.MemoryNetworkInput(
        state=rlt.FeatureData(
            float_features=torch.zeros(MULTI_STEPS, BATCH_SIZE, STATE_DIM)
        ),
        next_state=rlt.FeatureData(
            float_features=torch.zeros(MULTI_STEPS, BATCH_SIZE, STATE_DIM)
        ),
        action=rlt.FeatureData(
            float_features=torch.zeros(MULTI_STEPS, BATCH_SIZE, NUM_ACTION)
        ),
        reward=torch.zeros(1),
        time_diff=torch.zeros(1),
        step=torch.zeros(1),
        not_terminal=torch.zeros(1),
    )
    q_values = get_Q(seq2reward_network, batch, all_permut)
    expected_q_values = torch.tensor([[11.0, 111.0], [11.0, 111.0]])
    logger.info(f"q_values: {q_values}")
    assert torch.all(expected_q_values == q_values)
def test_get_Q(self):
    NUM_ACTION = 2
    MULTI_STEPS = 3
    BATCH_SIZE = 2
    STATE_DIM = 4
    all_permut = gen_permutations(MULTI_STEPS, NUM_ACTION)
    seq2reward_network = FakeSeq2RewardNetwork()
    state = torch.zeros(BATCH_SIZE, STATE_DIM)
    q_values = get_Q(seq2reward_network, state, all_permut)
    expected_q_values = torch.tensor([[11.0, 111.0], [11.0, 111.0]])
    logger.info(f"q_values: {q_values}")
    assert torch.all(expected_q_values == q_values)
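# Note: the two test variants above exercise two signatures of get_Q. The first passes the
# full rlt.MemoryNetworkInput batch; the second passes only the current state tensor of shape
# (BATCH_SIZE, STATE_DIM). Both expect the same Q values from the fake network.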
def __init__(
    self,
    compress_model_network: FullyConnectedNetwork,
    seq2reward_network: Seq2RewardNetwork,
    params: Seq2RewardTrainerParameters,
):
    super().__init__()
    self.compress_model_network = compress_model_network
    self.seq2reward_network = seq2reward_network
    self.params = params
    # permutations used to do planning
    self.all_permut = gen_permutations(
        params.multi_steps, len(self.params.action_names)
    )
def __init__(
    self,
    model: ModelBase,  # acc_reward prediction model
    state_preprocessor: Preprocessor,
    seq_len: int,
    num_action: int,
):
    """
    Since TorchScript is unable to trace control flow, we have to generate the
    action enumerations as constants here so that the trace can use them directly.
    """
    super().__init__(model, state_preprocessor, rlt.ModelFeatureConfig())
    self.seq_len = seq_len
    self.num_action = num_action
    self.all_permut = gen_permutations(seq_len, num_action)
def __init__(
    self,
    seq2reward_network: Seq2RewardNetwork,
    params: Seq2RewardTrainerParameters,
):
    self.seq2reward_network = seq2reward_network
    self.params = params
    self.optimizer = torch.optim.Adam(
        self.seq2reward_network.parameters(), lr=params.learning_rate
    )
    self.minibatch_size = self.params.batch_size
    self.loss_reporter = NoOpLossReporter()

    # PageHandler must use this to activate evaluator:
    self.calc_cpe_in_training = True
    # Turning off Q value output during training:
    self.view_q_value = params.view_q_value
    # permutations used to do planning
    device = get_device(self.seq2reward_network)
    self.all_permut = gen_permutations(
        params.multi_steps, len(self.params.action_names)
    ).to(device)
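# Minimal sketch of the `get_device` helper used above (assumption: the actual helper may
# differ): read the device of the module's first parameter so that `all_permut` can be moved
# onto the same device as the network.
import torch
import torch.nn as nn

def get_device(module: nn.Module) -> torch.device:
    return next(module.parameters()).device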
def __init__(
    self,
    model: ModelBase,  # acc_reward prediction model
    step_model: ModelBase,  # step prediction model
    state_preprocessor: Preprocessor,
    seq_len: int,
    num_action: int,
):
    """
    The difference from Seq2RewardWithPreprocessor: this wrapper plans for
    different look_ahead steps (between 1 and seq_len) and merges the results
    according to the look_ahead step prediction probabilities.
    """
    super().__init__(model, state_preprocessor, rlt.ModelFeatureConfig())
    self.step_model = step_model
    self.seq_len = seq_len
    self.num_action = num_action
    # key: seq_len, value: all possible action sequences of length seq_len
    self.all_permut = {
        s + 1: gen_permutations(s + 1, num_action) for s in range(seq_len)
    }
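# Illustrative sketch only (not the wrapper's actual forward(); names, shapes, and the softmax
# weighting are assumptions) of what "merge results according to look_ahead step prediction
# probabilities" can look like: weight per-look-ahead Q estimates by a distribution over steps.
import torch
import torch.nn.functional as F

batch_size, seq_len, num_action = 2, 3, 2
# q_per_step[s]: Q estimates when planning s + 1 steps ahead, shape (batch_size, num_action)
q_per_step = torch.randn(seq_len, batch_size, num_action)
step_logits = torch.randn(batch_size, seq_len)   # from the step prediction model
step_probs = F.softmax(step_logits, dim=1)       # (batch_size, seq_len)
merged_q = torch.einsum("bs,sba->ba", step_probs, q_per_step)  # (batch_size, num_action)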
def __init__(
    self,
    compress_model_network: FullyConnectedNetwork,
    seq2reward_network: Seq2RewardNetwork,
    params: Seq2RewardTrainerParameters,
):
    self.compress_model_network = compress_model_network
    self.seq2reward_network = seq2reward_network
    self.params = params
    self.optimizer = torch.optim.Adam(
        self.compress_model_network.parameters(),
        lr=params.compress_model_learning_rate,
    )
    self.minibatch_size = self.params.compress_model_batch_size
    self.loss_reporter = NoOpLossReporter()

    # PageHandler must use this to activate evaluator:
    self.calc_cpe_in_training = True
    # permutations used to do planning
    device = get_device(self.compress_model_network)
    self.all_permut = gen_permutations(
        params.multi_steps, len(self.params.action_names)
    ).to(device)
def get_Q(
    self,
    batch: rlt.MemoryNetworkInput,
    batch_size: int,
    seq_len: int,
    num_action: int,
) -> torch.Tensor:
    if not self.view_q_value:
        return torch.zeros(batch_size, num_action)
    try:
        # pyre-fixme[16]: `Seq2RewardTrainer` has no attribute `all_permut`.
        self.all_permut
    except AttributeError:
        self.all_permut = gen_permutations(seq_len, num_action)
        # pyre-fixme[16]: `Seq2RewardTrainer` has no attribute `num_permut`.
        self.num_permut = self.all_permut.size(1)
    # pyre-fixme[16]: `Tensor` has no attribute `repeat_interleave`.
    preprocessed_state = batch.state.float_features.repeat_interleave(
        self.num_permut, dim=1
    )
    state_feature_vector = rlt.FeatureData(preprocessed_state)
    # expand action to match the expanded state sequence
    action = self.all_permut.repeat(1, batch_size, 1)
    reward = self.seq2reward_network(
        state_feature_vector, rlt.FeatureData(action)
    ).acc_reward.reshape(batch_size, num_action, self.num_permut // num_action)
    # The permutations are generated in lexical order, so all permutations sharing
    # the same first action are contiguous. After the reshape above, taking the max
    # over the last dim gives the best accumulated reward per first action, which
    # leaves a (batch_size, num_action) tensor of Q values.
    max_reward = (
        # pyre-fixme[16]: `Tuple` has no attribute `values`.
        torch.max(reward, 2)
        .values.cpu()
        .detach()
        .reshape(batch_size, num_action)
    )
    return max_reward
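# Self-contained numeric sketch (illustrative values only) of the reshape/max trick above.
# With num_action = 2 and seq_len = 2 there are 4 permutations in lexical order
# (00, 01, 10, 11), so the first two accumulated rewards per row share first action 0 and the
# last two share first action 1; the max over the last dim is the planned Q per first action.
import torch

batch_size, num_action, num_permut = 2, 2, 4
acc_reward = torch.tensor([[1.0, 3.0, 2.0, 4.0],
                           [0.5, 0.1, 0.9, 0.2]])
per_first_action = acc_reward.reshape(batch_size, num_action, num_permut // num_action)
max_reward = torch.max(per_first_action, 2).values  # tensor([[3.0, 4.0], [0.5, 0.9]])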
def __init__(
    self,
    seq2reward_network: Seq2RewardNetwork,
    params: Seq2RewardTrainerParameters,
):
    self.seq2reward_network = seq2reward_network
    self.params = params
    self.mse_optimizer = torch.optim.Adam(
        self.seq2reward_network.parameters(), lr=params.learning_rate
    )
    self.minibatch_size = self.params.batch_size
    self.loss_reporter = NoOpLossReporter()

    # PageHandler must use this to activate evaluator:
    self.calc_cpe_in_training = True
    # Turning off Q value output during training:
    self.view_q_value = params.view_q_value
    # permutations used to do planning
    self.all_permut = gen_permutations(
        params.multi_steps, len(self.params.action_names)
    )
    self.mse_loss = nn.MSELoss(reduction="mean")

    # Predict how many steps are remaining from the current step
    self.step_predict_network = FullyConnectedNetwork(
        [
            self.seq2reward_network.state_dim,
            self.params.step_predict_net_size,
            self.params.step_predict_net_size,
            self.params.multi_steps,
        ],
        ["relu", "relu", "linear"],
        use_layer_norm=False,
    )
    self.step_loss = nn.CrossEntropyLoss(reduction="mean")
    self.step_optimizer = torch.optim.Adam(
        self.step_predict_network.parameters(), lr=params.learning_rate
    )
def _test_gen_permutations(self, SEQ_LEN, NUM_ACTION, expected_outcome):
    # expected shape: SEQ_LEN, PERM_NUM, ACTION_DIM
    result = gen_permutations(SEQ_LEN, NUM_ACTION)
    assert result.shape == (SEQ_LEN, NUM_ACTION ** SEQ_LEN, NUM_ACTION)
    outcome = torch.argmax(result.transpose(0, 1), dim=-1)
    assert torch.all(outcome == expected_outcome)
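# Minimal sketch (assumption: the real gen_permutations may differ in details) of a generator
# satisfying the checks above: enumerate all action sequences of length seq_len in lexical
# order and one-hot encode them, returning shape (seq_len, num_action ** seq_len, num_action).
import itertools
import torch
import torch.nn.functional as F

def gen_permutations_sketch(seq_len: int, num_action: int) -> torch.Tensor:
    perms = torch.tensor(list(itertools.product(range(num_action), repeat=seq_len)))
    # perms: (num_action ** seq_len, seq_len) -> transpose to sequence-first layout
    return F.one_hot(perms.transpose(0, 1), num_action).float()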