Example No. 1
    def __init__(self, seq2reward_network: Seq2RewardNetwork,
                 params: Seq2RewardTrainerParameters):
        super().__init__()
        self.seq2reward_network = seq2reward_network
        self.params = params

        # Turning off Q value output during training:
        self.view_q_value = params.view_q_value
        # permutations used to do planning
        self.all_permut = gen_permutations(params.multi_steps,
                                           len(self.params.action_names))
        self.mse_loss = nn.MSELoss(reduction="mean")

        # Predict how many steps are remaining from the current step
        self.step_predict_network = FullyConnectedNetwork(
            [
                self.seq2reward_network.state_dim,
                self.params.step_predict_net_size,
                self.params.step_predict_net_size,
                self.params.multi_steps,
            ],
            ["relu", "relu", "linear"],
            use_layer_norm=False,
        )
        self.step_loss = nn.CrossEntropyLoss(reduction="mean")
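
A minimal, standalone sketch of what the step predictor configured above computes, assuming (from the layer spec) a state_dim -> hidden -> hidden -> multi_steps stack trained with cross-entropy against the number of remaining steps. The dimensions and the remaining_steps targets below are illustrative, not taken from ReAgent.

import torch
import torch.nn as nn

STATE_DIM, HIDDEN, MULTI_STEPS, BATCH = 4, 64, 3, 2
step_net = nn.Sequential(
    nn.Linear(STATE_DIM, HIDDEN), nn.ReLU(),
    nn.Linear(HIDDEN, HIDDEN), nn.ReLU(),
    nn.Linear(HIDDEN, MULTI_STEPS),  # logits over how many steps remain
)
state = torch.randn(BATCH, STATE_DIM)
remaining_steps = torch.tensor([2, 0])  # illustrative class targets in [0, MULTI_STEPS)
loss = nn.CrossEntropyLoss(reduction="mean")(step_net(state), remaining_steps)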
Example No. 2
    def test_get_Q(self):
        NUM_ACTION = 2
        MULTI_STEPS = 3
        BATCH_SIZE = 2
        STATE_DIM = 4
        all_permut = gen_permutations(MULTI_STEPS, NUM_ACTION)
        seq2reward_network = FakeSeq2RewardNetwork()
        batch = rlt.MemoryNetworkInput(
            state=rlt.FeatureData(
                float_features=torch.zeros(MULTI_STEPS, BATCH_SIZE, STATE_DIM)
            ),
            next_state=rlt.FeatureData(
                float_features=torch.zeros(MULTI_STEPS, BATCH_SIZE, STATE_DIM)
            ),
            action=rlt.FeatureData(
                float_features=torch.zeros(MULTI_STEPS, BATCH_SIZE, NUM_ACTION)
            ),
            reward=torch.zeros(1),
            time_diff=torch.zeros(1),
            step=torch.zeros(1),
            not_terminal=torch.zeros(1),
        )
        q_values = get_Q(seq2reward_network, batch, all_permut)
        expected_q_values = torch.tensor([[11.0, 111.0], [11.0, 111.0]])
        logger.info(f"q_values: {q_values}")
        assert torch.all(expected_q_values == q_values)
Example No. 3
    def test_get_Q(self):
        NUM_ACTION = 2
        MULTI_STEPS = 3
        BATCH_SIZE = 2
        STATE_DIM = 4
        all_permut = gen_permutations(MULTI_STEPS, NUM_ACTION)
        seq2reward_network = FakeSeq2RewardNetwork()
        state = torch.zeros(BATCH_SIZE, STATE_DIM)
        q_values = get_Q(seq2reward_network, state, all_permut)
        expected_q_values = torch.tensor([[11.0, 111.0], [11.0, 111.0]])
        logger.info(f"q_values: {q_values}")
        assert torch.all(expected_q_values == q_values)
Example No. 4
    def __init__(
        self,
        compress_model_network: FullyConnectedNetwork,
        seq2reward_network: Seq2RewardNetwork,
        params: Seq2RewardTrainerParameters,
    ):
        super().__init__()
        self.compress_model_network = compress_model_network
        self.seq2reward_network = seq2reward_network
        self.params = params

        # permutations used to do planning
        self.all_permut = gen_permutations(params.multi_steps,
                                           len(self.params.action_names))
Example No. 5
    def __init__(
        self,
        model: ModelBase,  # acc_reward prediction model
        state_preprocessor: Preprocessor,
        seq_len: int,
        num_action: int,
    ):
        """
        Since TorchScript is unable to trace control flow, we have to
        generate the action enumerations as constants here so that
        trace can use them directly.
        """
        super().__init__(model, state_preprocessor, rlt.ModelFeatureConfig())
        self.seq_len = seq_len
        self.num_action = num_action
        self.all_permut = gen_permutations(seq_len, num_action)
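
A hedged sketch, not ReAgent's wrapper, of the pattern this docstring describes: torch.jit.trace freezes control flow, so the full action enumeration is computed once in __init__ and registered as a constant buffer, leaving the forward pass with nothing but tensor ops. TraceablePlanner and its forward signature are hypothetical.

import torch

class TraceablePlanner(torch.nn.Module):  # hypothetical class, for illustration only
    def __init__(self, all_permut: torch.Tensor):
        super().__init__()
        # all_permut: (seq_len, num_permut, num_action), e.g. from gen_permutations(seq_len, num_action).
        # Registering it as a buffer lets trace capture it as a constant.
        self.register_buffer("all_permut", all_permut)

    def forward(self, acc_reward: torch.Tensor) -> torch.Tensor:
        # acc_reward: (num_permut,) predicted return of each enumerated action sequence.
        # Pure tensor indexing -- no data-dependent Python control flow for trace to drop.
        return self.all_permut[0, torch.argmax(acc_reward)]

Tracing then reduces to torch.jit.trace(TraceablePlanner(gen_permutations(seq_len, num_action)), example_acc_reward).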
Example No. 6
    def __init__(self, seq2reward_network: Seq2RewardNetwork,
                 params: Seq2RewardTrainerParameters):
        self.seq2reward_network = seq2reward_network
        self.params = params
        self.optimizer = torch.optim.Adam(self.seq2reward_network.parameters(),
                                          lr=params.learning_rate)
        self.minibatch_size = self.params.batch_size
        self.loss_reporter = NoOpLossReporter()

        # PageHandler must use this to activate evaluator:
        self.calc_cpe_in_training = True
        # Turning off Q value output during training:
        self.view_q_value = params.view_q_value
        # permutations used to do planning
        device = get_device(self.seq2reward_network)
        self.all_permut = gen_permutations(
            params.multi_steps, len(self.params.action_names)).to(device)
Example No. 7
    def __init__(
        self,
        model: ModelBase,  # acc_reward prediction model
        step_model: ModelBase,  # step prediction model
        state_preprocessor: Preprocessor,
        seq_len: int,
        num_action: int,
    ):
        """
        The difference from Seq2RewardWithPreprocessor: this wrapper plans
        for different look_ahead steps (between 1 and seq_len) and merges
        the results according to the look_ahead step prediction probabilities.
        """
        super().__init__(model, state_preprocessor, rlt.ModelFeatureConfig())
        self.step_model = step_model
        self.seq_len = seq_len
        self.num_action = num_action
        # key: seq_len, value: all possible action sequences of length seq_len
        self.all_permut = {
            s + 1: gen_permutations(s + 1, num_action) for s in range(seq_len)
        }
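
The probability-weighted merge this docstring mentions is not shown in the snippet; the sketch below illustrates one way it could work, mixing per-horizon Q values with the step model's predicted horizon probabilities. The names q_per_step and step_probs are assumptions for illustration, not ReAgent's actual forward code.

import torch

def merge_by_step_probs(q_per_step: torch.Tensor, step_probs: torch.Tensor) -> torch.Tensor:
    # q_per_step: (seq_len, batch_size, num_action), Q values planned with look_ahead = 1 .. seq_len.
    # step_probs: (batch_size, seq_len), softmax output of the step prediction model.
    # Returns (batch_size, num_action): the probability-weighted mixture over horizons.
    return torch.einsum("bs,sba->ba", step_probs, q_per_step)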
Example No. 8
    def __init__(
        self,
        compress_model_network: FullyConnectedNetwork,
        seq2reward_network: Seq2RewardNetwork,
        params: Seq2RewardTrainerParameters,
    ):
        self.compress_model_network = compress_model_network
        self.seq2reward_network = seq2reward_network
        self.params = params
        self.optimizer = torch.optim.Adam(
            self.compress_model_network.parameters(),
            lr=params.compress_model_learning_rate,
        )
        self.minibatch_size = self.params.compress_model_batch_size
        self.loss_reporter = NoOpLossReporter()

        # PageHandler must use this to activate evaluator:
        self.calc_cpe_in_training = True
        # permutations used to do planning
        device = get_device(self.compress_model_network)
        self.all_permut = gen_permutations(
            params.multi_steps, len(self.params.action_names)).to(device)
Example No. 9
    def get_Q(
        self,
        batch: rlt.MemoryNetworkInput,
        batch_size: int,
        seq_len: int,
        num_action: int,
    ) -> torch.Tensor:
        if not self.view_q_value:
            return torch.zeros(batch_size, num_action)
        try:
            # pyre-fixme[16]: `Seq2RewardTrainer` has no attribute `all_permut`.
            self.all_permut
        except AttributeError:
            self.all_permut = gen_permutations(seq_len, num_action)
            # pyre-fixme[16]: `Seq2RewardTrainer` has no attribute `num_permut`.
            self.num_permut = self.all_permut.size(1)

        # pyre-fixme[16]: `Tensor` has no attribute `repeat_interleave`.
        preprocessed_state = batch.state.float_features.repeat_interleave(
            self.num_permut, dim=1)
        state_feature_vector = rlt.FeatureData(preprocessed_state)

        # expand action to match the expanded state sequence
        action = self.all_permut.repeat(1, batch_size, 1)
        reward = self.seq2reward_network(
            state_feature_vector, rlt.FeatureData(action)).acc_reward.reshape(
                batch_size, num_action, self.num_permut // num_action)

        # The permutations are generated in lexical order, so after reshaping the
        # accumulated rewards to (batch_size, num_action, num_permut // num_action),
        # each slice along dim 1 groups the sequences sharing the same first action;
        # taking the max over dim 2 yields a (BATCH_SIZE, ACT_DIM) Q-value tensor.
        max_reward = (
            # pyre-fixme[16]: `Tuple` has no attribute `values`.
            torch.max(reward,
                      2).values.cpu().detach().reshape(batch_size, num_action))

        return max_reward
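
As a shape walk-through of the get_Q above (using the constants from the tests, so MULTI_STEPS=3, NUM_ACTION=2, BATCH_SIZE=2, STATE_DIM=4 and num_permut = 2**3 = 8), the tensors line up as sketched below with zero-filled stand-ins; this is bookkeeping only, not an additional API.

import torch

MULTI_STEPS, NUM_ACTION, BATCH_SIZE, STATE_DIM = 3, 2, 2, 4
num_permut = NUM_ACTION ** MULTI_STEPS                         # 8

all_permut = torch.zeros(MULTI_STEPS, num_permut, NUM_ACTION)  # stand-in for gen_permutations(...)
state = torch.zeros(MULTI_STEPS, BATCH_SIZE, STATE_DIM)        # batch.state.float_features

expanded_state = state.repeat_interleave(num_permut, dim=1)    # (3, 16, 4)
expanded_action = all_permut.repeat(1, BATCH_SIZE, 1)          # (3, 16, 2)
acc_reward = torch.zeros(BATCH_SIZE * num_permut, 1)           # stand-in for the network output
q = acc_reward.reshape(BATCH_SIZE, NUM_ACTION, num_permut // NUM_ACTION).max(2).values
assert q.shape == (BATCH_SIZE, NUM_ACTION)                     # (2, 2), as in the tests above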
Example No. 10
    def __init__(
        self, seq2reward_network: Seq2RewardNetwork, params: Seq2RewardTrainerParameters
    ):
        self.seq2reward_network = seq2reward_network
        self.params = params
        self.mse_optimizer = torch.optim.Adam(
            self.seq2reward_network.parameters(), lr=params.learning_rate
        )
        self.minibatch_size = self.params.batch_size
        self.loss_reporter = NoOpLossReporter()

        # PageHandler must use this to activate evaluator:
        self.calc_cpe_in_training = True
        # Turning off Q value output during training:
        self.view_q_value = params.view_q_value
        # permutations used to do planning
        self.all_permut = gen_permutations(
            params.multi_steps, len(self.params.action_names)
        )
        self.mse_loss = nn.MSELoss(reduction="mean")

        # Predict how many steps are remaining from the current step
        self.step_predict_network = FullyConnectedNetwork(
            [
                self.seq2reward_network.state_dim,
                self.params.step_predict_net_size,
                self.params.step_predict_net_size,
                self.params.multi_steps,
            ],
            ["relu", "relu", "linear"],
            use_layer_norm=False,
        )
        self.step_loss = nn.CrossEntropyLoss(reduction="mean")
        self.step_optimizer = torch.optim.Adam(
            self.step_predict_network.parameters(), lr=params.learning_rate
        )
Example No. 11
    def _test_gen_permutations(self, SEQ_LEN, NUM_ACTION, expected_outcome):
        # expected shape: SEQ_LEN, PERM_NUM, ACTION_DIM
        result = gen_permutations(SEQ_LEN, NUM_ACTION)
        assert result.shape == (SEQ_LEN, NUM_ACTION**SEQ_LEN, NUM_ACTION)
        outcome = torch.argmax(result.transpose(0, 1), dim=-1)
        assert torch.all(outcome == expected_outcome)
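
Based on the shape and lexical-order assertions in this test, a gen_permutations-style helper could look like the sketch below (one-hot actions, lexical enumeration). This is an illustrative reconstruction, not ReAgent's implementation.

import torch
import torch.nn.functional as F

def gen_permutations_sketch(seq_len: int, num_action: int) -> torch.Tensor:
    # Enumerate all num_action ** seq_len index sequences in lexical order.
    grids = torch.cartesian_prod(*[torch.arange(num_action) for _ in range(seq_len)])
    if seq_len == 1:
        grids = grids.unsqueeze(1)  # cartesian_prod returns a 1-D tensor for a single input
    # grids: (num_permut, seq_len) -> one-hot actions: (num_permut, seq_len, num_action)
    one_hot = F.one_hot(grids, num_classes=num_action).float()
    # Match the asserted layout (SEQ_LEN, NUM_ACTION ** SEQ_LEN, NUM_ACTION).
    return one_hot.transpose(0, 1)

result = gen_permutations_sketch(3, 2)
assert result.shape == (3, 2 ** 3, 2)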