예제 #1
0
    def test_parametric_wrapper(self):
        """Check ParametricDqnPredictorWrapper against a direct critic call.

        A FullyConnectedCritic is built for state normalization parameters
        keyed 1-4 and action normalization parameters keyed 5-8. The wrapper
        is run on the input prototype and its Q-value is compared, element
        by element, with the critic applied to the preprocessed prototype.
        """
        state_norm = {fid: _cont_norm() for fid in range(1, 5)}
        action_norm = {fid: _cont_norm() for fid in range(5, 9)}
        state_preprocessor = Preprocessor(state_norm, False)
        action_preprocessor = Preprocessor(action_norm, False)

        critic = models.FullyConnectedCritic(
            state_dim=len(state_norm),
            action_dim=len(action_norm),
            sizes=[16],
            activations=["relu"],
        )
        critic_with_preprocessor = ParametricDqnWithPreprocessor(
            critic,
            state_preprocessor=state_preprocessor,
            action_preprocessor=action_preprocessor,
        )
        wrapper = ParametricDqnPredictorWrapper(critic_with_preprocessor)

        # Run the wrapper on the prototype and check names and output shape.
        prototype = critic_with_preprocessor.input_prototype()
        output_action_names, q_value = wrapper(*prototype)
        self.assertEqual(output_action_names, ["Q"])
        self.assertEqual(q_value.shape, (1, 1))

        # Reference value: preprocess each prototype entry by hand and call
        # the critic directly (state first, then action).
        state_features = rlt.FeatureData(state_preprocessor(*prototype[0]))
        action_features = rlt.FeatureData(action_preprocessor(*prototype[1]))
        expected_output = critic(state_features, action_features)
        matches = expected_output == q_value
        self.assertTrue(matches.all())
예제 #2
0
    def _test_seq2slate_model_with_preprocessor(
            self, model: str, output_arch: Seq2SlateOutputArch):
        """Compare Seq2SlateWithPreprocessor with its TorchScript version.

        The module is traced when ``can_be_traced()`` is true and scripted
        otherwise. Both versions are run on the input prototype, and again
        on a prototype resized to 20 candidates, and each pair of outputs is
        handed to ``self.verify_results``.

        Args:
            model: model type; only ``"transformer"`` is handled here.
            output_arch: output architecture forwarded to
                Seq2SlateTransformerNet.

        Raises:
            NotImplementedError: if ``model`` is not ``"transformer"``.
        """
        state_norm = {fid: _cont_norm() for fid in range(1, 5)}
        candidate_norm = {fid: _cont_norm() for fid in range(101, 106)}
        state_preprocessor = Preprocessor(state_norm, False)
        candidate_preprocessor = Preprocessor(candidate_norm, False)
        candidate_size, slate_size = 10, 4

        if model != "transformer":
            raise NotImplementedError(f"model type {model} is unknown")
        seq2slate = Seq2SlateTransformerNet(
            state_dim=len(state_norm),
            candidate_dim=len(candidate_norm),
            num_stacked_layers=2,
            num_heads=2,
            dim_model=10,
            dim_feedforward=10,
            max_src_seq_len=candidate_size,
            max_tgt_seq_len=slate_size,
            output_arch=output_arch,
            temperature=0.5,
        )

        eager_net = Seq2SlateWithPreprocessor(
            seq2slate, state_preprocessor, candidate_preprocessor, greedy=True)
        prototype = eager_net.input_prototype()

        # Tracing gets its own freshly requested prototype; scripting needs
        # no example input.
        if eager_net.can_be_traced():
            jit_net = torch.jit.trace(eager_net, eager_net.input_prototype())
        else:
            jit_net = torch.jit.script(eager_net)

        def check_jit_against_eager(inputs):
            # Eager module runs first, then the JIT module, on the same inputs.
            eager_output = eager_net(*inputs)
            jit_output = jit_net(*inputs)
            self.verify_results(eager_output, jit_output)

        check_jit_against_eager(prototype)

        # Test if the JIT model can handle variable lengths of input
        resized_prototype = change_cand_size_slate_ranking(prototype, 20)
        check_jit_against_eager(resized_prototype)
예제 #3
0
    def test_discrete_wrapper(self):
        """Check DiscreteDqnPredictorWrapper against a direct DQN call.

        A FullyConnectedDQN with two actions ("L", "R") is built over state
        normalization parameters keyed 1-4. The wrapper's action names and
        Q-values on the input prototype are compared with the DQN applied to
        the preprocessed float features of that prototype.
        """
        feature_ids = range(1, 5)
        state_norm = {fid: _cont_norm() for fid in feature_ids}
        state_preprocessor = Preprocessor(state_norm, False)

        q_network = models.FullyConnectedDQN(
            state_dim=len(state_norm),
            action_dim=2,
            sizes=[16],
            activations=["relu"],
        )

        # One float feature description per normalized feature id.
        float_feature_infos = [
            rlt.FloatFeatureInfo(feature_id=fid, name=f"feat_{fid}")
            for fid in feature_ids
        ]
        state_feature_config = rlt.ModelFeatureConfig(
            float_feature_infos=float_feature_infos)

        q_network_with_preprocessor = DiscreteDqnWithPreprocessor(
            q_network, state_preprocessor, state_feature_config)
        action_names = ["L", "R"]
        wrapper = DiscreteDqnPredictorWrapper(
            q_network_with_preprocessor, action_names, state_feature_config)

        # Only the first element of the prototype is fed to the wrapper.
        state_prototype = q_network_with_preprocessor.input_prototype()[0]
        output_action_names, q_values = wrapper(state_prototype)
        self.assertEqual(action_names, output_action_names)
        self.assertEqual(q_values.shape, (1, 2))

        # Reference value: preprocess the float features by hand and call
        # the DQN directly.
        preprocessed_state = state_preprocessor(
            *state_prototype.float_features_with_presence)
        expected_output = q_network(rlt.FeatureData(preprocessed_state))
        matches = expected_output == q_values
        self.assertTrue(matches.all())
예제 #4
0
    def test_actor_wrapper(self):
        """Check ActorPredictorWrapper against a direct actor call.

        State normalization parameters are keyed 1-4 and action normalization
        parameters (from ``_cont_action_norm``) are keyed 101-104. The action
        returned by the wrapper on the input prototype is compared with the
        postprocessed ``action`` of the actor applied to the preprocessed
        prototype state.
        """
        state_norm = {fid: _cont_norm() for fid in range(1, 5)}
        action_norm = {fid: _cont_action_norm() for fid in range(101, 105)}
        state_preprocessor = Preprocessor(state_norm, False)
        postprocessor = Postprocessor(action_norm, False)
        num_actions = len(action_norm)

        # Test with FullyConnectedActor to make behavior deterministic
        actor = models.FullyConnectedActor(
            state_dim=len(state_norm),
            action_dim=num_actions,
            sizes=[16],
            activations=["relu"],
        )
        actor_with_preprocessor = ActorWithPreprocessor(
            actor, state_preprocessor, postprocessor)
        wrapper = ActorPredictorWrapper(actor_with_preprocessor)

        prototype = actor_with_preprocessor.input_prototype()
        action, _log_prob = wrapper(*prototype)
        self.assertEqual(action.shape, (1, num_actions))

        # Reference value: preprocess -> actor -> postprocess, step by step.
        preprocessed_state = state_preprocessor(*prototype[0])
        actor_output = actor(rlt.FeatureData(preprocessed_state))
        expected_output = postprocessor(actor_output.action)
        matches = expected_output == action
        self.assertTrue(matches.all())
예제 #5
0
    def test_discrete_wrapper_with_id_list(self):
        """Check DiscreteDqnPredictorWrapper on a model with an id-list feature.

        The model is an EmbeddingBagConcat followed by a FullyConnectedDQN
        with two actions ("L", "R"). The feature config declares float
        features 1-4 and one id-list feature "A" (feature id 10). The
        wrapper output on the input prototype is compared with the model
        applied to hand-built FeatureData whose id-list features are keyed
        by name instead of feature id.
        """
        state_norm = {fid: _cont_norm() for fid in range(1, 5)}
        state_preprocessor = Preprocessor(state_norm, False)

        float_feature_infos = [
            rlt.FloatFeatureInfo(name=str(fid), feature_id=fid)
            for fid in range(1, 5)
        ]
        id_list_feature_configs = [
            rlt.IdListFeatureConfig(
                name="A", feature_id=10, id_mapping_name="A_mapping"),
        ]
        state_feature_config = rlt.ModelFeatureConfig(
            float_feature_infos=float_feature_infos,
            id_list_feature_configs=id_list_feature_configs,
            id_mapping_config={"A_mapping": rlt.IdMapping(ids=[0, 1, 2])},
        )

        embedding_concat = models.EmbeddingBagConcat(
            state_dim=len(state_norm),
            model_feature_config=state_feature_config,
            embedding_dim=8,
        )
        q_head = models.FullyConnectedDQN(
            embedding_concat.output_dim,
            action_dim=2,
            sizes=[16],
            activations=["relu"],
        )
        q_network = models.Sequential(
            embedding_concat,
            rlt.TensorFeatureData(),
            q_head,
        )

        q_network_with_preprocessor = DiscreteDqnWithPreprocessor(
            q_network, state_preprocessor, state_feature_config)
        action_names = ["L", "R"]
        wrapper = DiscreteDqnPredictorWrapper(
            q_network_with_preprocessor, action_names, state_feature_config)

        # Only the first element of the prototype is fed to the wrapper.
        state_prototype = q_network_with_preprocessor.input_prototype()[0]
        output_action_names, q_values = wrapper(state_prototype)
        self.assertEqual(action_names, output_action_names)
        self.assertEqual(q_values.shape, (1, 2))

        # Re-key the prototype's id-list features from feature id to the
        # configured feature name before building the reference input.
        name_by_feature_id = {}
        for id_list_config in state_feature_config.id_list_feature_configs:
            name_by_feature_id[id_list_config.feature_id] = id_list_config.name
        named_id_list_features = {
            name_by_feature_id[fid]: value
            for fid, value in state_prototype.id_list_features.items()
        }

        preprocessed_state = state_preprocessor(
            *state_prototype.float_features_with_presence)
        reference_input = rlt.FeatureData(
            float_features=preprocessed_state,
            id_list_features=named_id_list_features,
        )
        expected_output = q_network(reference_input)
        matches = expected_output == q_values
        self.assertTrue(matches.all())
예제 #6
0
    def _test_seq2slate_wrapper(self, model: str,
                                output_arch: Seq2SlateOutputArch):
        """Compare Seq2SlatePredictorWrapper with a direct seq2slate call.

        The comparison runs twice: once on the input prototype with
        ``tgt_seq_len`` equal to the candidate size (10), and once on a
        prototype resized to a random candidate count drawn from 11..20 with
        ``tgt_seq_len`` equal to that count. Each pair of outputs is handed
        to ``self.validate_seq2slate_output``.

        Args:
            model: model type; only ``"transformer"`` is handled here.
            output_arch: output architecture forwarded to
                Seq2SlateTransformerNet.

        Raises:
            NotImplementedError: if ``model`` is not ``"transformer"``.
        """
        state_norm = {fid: _cont_norm() for fid in range(1, 5)}
        candidate_norm = {fid: _cont_norm() for fid in range(101, 106)}
        state_preprocessor = Preprocessor(state_norm, False)
        candidate_preprocessor = Preprocessor(candidate_norm, False)
        candidate_size, slate_size = 10, 4

        if model != "transformer":
            raise NotImplementedError(f"model type {model} is unknown")
        seq2slate = Seq2SlateTransformerNet(
            state_dim=len(state_norm),
            candidate_dim=len(candidate_norm),
            num_stacked_layers=2,
            num_heads=2,
            dim_model=10,
            dim_feedforward=10,
            max_src_seq_len=candidate_size,
            max_tgt_seq_len=slate_size,
            output_arch=output_arch,
            temperature=0.5,
        )

        net_with_preprocessor = Seq2SlateWithPreprocessor(
            seq2slate, state_preprocessor, candidate_preprocessor, greedy=True)
        wrapper = Seq2SlatePredictorWrapper(net_with_preprocessor)

        def check_wrapper(state_proto, candidate_proto, tgt_seq_len):
            # The wrapper runs first; the reference is then built by
            # converting the same prototypes to a ranking input and calling
            # the model directly in rank mode.
            wrapper_output = wrapper(state_proto, candidate_proto)
            ranking_input = seq2slate_input_prototype_to_ranking_input(
                state_proto,
                candidate_proto,
                state_preprocessor,
                candidate_preprocessor,
            )
            expected_output = seq2slate(
                ranking_input,
                mode=Seq2SlateMode.RANK_MODE,
                tgt_seq_len=tgt_seq_len,
                greedy=True,
            )
            self.validate_seq2slate_output(expected_output, wrapper_output)

        state_proto, candidate_proto = net_with_preprocessor.input_prototype()
        check_wrapper(state_proto, candidate_proto, candidate_size)

        # Test Seq2SlatePredictorWrapper can handle variable lengths of inputs
        random_length = random.randint(candidate_size + 1, candidate_size * 2)
        state_proto, candidate_proto = change_cand_size_slate_ranking(
            net_with_preprocessor.input_prototype(), random_length)
        check_wrapper(state_proto, candidate_proto, random_length)