示例#1
0
    def run_feature_identification(
            self, input_table_spec: TableSpec) -> Dict[str, NormalizationData]:
        """Identify dense normalization parameters for state and action columns.

        Preprocessing options come from ``self.get_state_preprocessing_options()``
        and ``self.get_action_preprocessing_options()`` respectively.

        Args:
            input_table_spec: The table to run feature identification against.

        Returns:
            A dict mapping ``NormalizationKey.STATE`` and
            ``NormalizationKey.ACTION`` to the identified ``NormalizationData``.
        """
        state_options = self.get_state_preprocessing_options()
        state_params = identify_normalization_parameters(
            input_table_spec, InputColumn.STATE_FEATURES, state_options
        )

        action_options = self.get_action_preprocessing_options()
        action_params = identify_normalization_parameters(
            input_table_spec, InputColumn.ACTION, action_options
        )

        state_data = NormalizationData(dense_normalization_parameters=state_params)
        action_data = NormalizationData(dense_normalization_parameters=action_params)
        return {
            NormalizationKey.STATE: state_data,
            NormalizationKey.ACTION: action_data,
        }
 def _test_parametric_dqn_net_builder(
     self, chooser: ParametricDQNNetBuilder__Union
 ) -> None:
     """Build a parametric DQN Q-network and serving module, then sanity-check them.

     The state has 3 continuous features and the action has 2. All of them use
     mean 0.0 and stddev 1.0. The Q-network must map its input prototype to a
     ``(1, 1)`` output, and the serving module must be a
     ``ParametricDqnPredictorWrapper``.
     """
     builder = chooser.value

     def _continuous(num_features):
         # Feature ids 0..num_features-1, each CONTINUOUS with mean 0 / stddev 1.
         return NormalizationData(
             dense_normalization_parameters={
                 feature_id: NormalizationParameters(
                     feature_type=CONTINUOUS, mean=0.0, stddev=1.0
                 )
                 for feature_id in range(num_features)
             }
         )

     state_norm = _continuous(3)
     action_norm = _continuous(2)

     q_network = builder.build_q_network(state_norm, action_norm)
     prototype = q_network.input_prototype()
     q_value = q_network(*prototype)
     self.assertEqual(q_value.shape, (1, 1))

     serving_module = builder.build_serving_module(
         q_network, state_norm, action_norm
     )
     self.assertIsInstance(serving_module, ParametricDqnPredictorWrapper)
示例#3
0
def build_normalizer(env: EnvWrapper) -> Dict[str, NormalizationData]:
    """Return normalization data for ``env``.

    If ``env`` exposes a ``normalization_data`` attribute, that value is returned
    unchanged. Otherwise the data is derived here:

    * RecSim environments (only when ``HAS_RECSIM``): a STATE entry sized from
      ``observation_space["user"]`` and an ITEM entry sized from
      ``observation_space["doc"]["0"]``. Both are built with
      ``only_continuous_normalizer`` over feature ids ``0..dim-1``.
    * Any other environment: STATE from ``build_state_normalizer(env)`` and
      ACTION from ``build_action_normalizer(env)``.
    """
    try:
        return env.normalization_data
    except AttributeError:
        pass

    # TODO: make this a property of EnvWrapper?
    # pyre-fixme[16]: Module `envs` has no attribute `RecSim`.
    if HAS_RECSIM and isinstance(env, RecSim):
        user_dim = env.observation_space["user"].shape[0]
        doc_dim = env.observation_space["doc"]["0"].shape[0]
        user_params = only_continuous_normalizer(list(range(user_dim)))
        doc_params = only_continuous_normalizer(list(range(doc_dim)))
        return {
            NormalizationKey.STATE: NormalizationData(
                dense_normalization_parameters=user_params
            ),
            NormalizationKey.ITEM: NormalizationData(
                dense_normalization_parameters=doc_params
            ),
        }

    state_params = build_state_normalizer(env)
    action_params = build_action_normalizer(env)
    return {
        NormalizationKey.STATE: NormalizationData(
            dense_normalization_parameters=state_params
        ),
        NormalizationKey.ACTION: NormalizationData(
            dense_normalization_parameters=action_params
        ),
    }
 def _test_actor_net_builder(
         self, chooser: ContinuousActorNetBuilder__Union) -> None:
     """Build a continuous actor and its serving module, then sanity-check them.

     The state has 3 CONTINUOUS features (mean 0.0, stddev 1.0). The action has 2
     features of type ``builder.default_action_preprocessing`` with range
     [0.0, 1.0]. The actor output must expose an ``action`` of shape
     ``(1, action_dim)`` and a ``log_prob`` of shape ``(1, 1)``.
     """
     builder = chooser.value
     state_dim, action_dim = 3, 2

     state_normalization_data = NormalizationData(
         dense_normalization_parameters={
             feature_id: NormalizationParameters(
                 feature_type=CONTINUOUS, mean=0.0, stddev=1.0)
             for feature_id in range(state_dim)
         })
     action_normalization_data = NormalizationData(
         dense_normalization_parameters={
             feature_id: NormalizationParameters(
                 feature_type=builder.default_action_preprocessing,
                 min_value=0.0,
                 max_value=1.0,
             )
             for feature_id in range(action_dim)
         })

     actor_network = builder.build_actor(
         state_normalization_data, action_normalization_data)
     actor_output = actor_network(actor_network.input_prototype())
     action, log_prob = actor_output.action, actor_output.log_prob
     self.assertEqual(action.shape, (1, action_dim))
     self.assertEqual(log_prob.shape, (1, 1))

     serving_module = builder.build_serving_module(
         actor_network, state_normalization_data, action_normalization_data)
     self.assertIsInstance(serving_module, ActorPredictorWrapper)
示例#5
0
    def _test_discrete_dqn_net_builder(
        self,
        chooser: DiscreteDQNNetBuilder__Union,
        state_feature_config: Optional[rlt.ModelFeatureConfig] = None,
        serving_module_class=DiscreteDqnPredictorWrapper,
    ) -> None:
        """Build a discrete DQN Q-network over two actions and check its outputs.

        Args:
            chooser: Union whose ``.value`` is the builder under test.
            state_feature_config: Feature config to use. When falsy, a config
                with three float features ``f0``..``f2`` (ids 0..2) is used.
            serving_module_class: Expected type of the built serving module.
        """
        builder = chooser.value
        if not state_feature_config:
            state_feature_config = rlt.ModelFeatureConfig(
                float_feature_infos=[
                    rlt.FloatFeatureInfo(name=f"f{i}", feature_id=i)
                    for i in range(3)
                ])

        # One CONTINUOUS (mean 0, stddev 1) entry per configured float feature id.
        state_normalization_data = NormalizationData(
            dense_normalization_parameters={
                info.feature_id: NormalizationParameters(
                    feature_type=CONTINUOUS, mean=0.0, stddev=1.0)
                for info in state_feature_config.float_feature_infos
            })

        action_names = ["L", "R"]
        num_actions = len(action_names)
        q_network = builder.build_q_network(
            state_feature_config, state_normalization_data, num_actions)
        q_values = q_network(q_network.input_prototype())
        self.assertEqual(q_values.shape, (1, 2))

        serving_module = builder.build_serving_module(
            q_network, state_normalization_data, action_names,
            state_feature_config)
        self.assertIsInstance(serving_module, serving_module_class)
示例#6
0
 def normalization_data(self):
     """Return STATE normalization data with one feature per arm.

     Feature ids are ``0 .. self.num_arms - 1``. They are passed to
     ``only_continuous_normalizer`` together with ``MU_LOW`` and ``MU_HIGH``
     (presumably the value bounds; confirm against that helper).
     """
     arm_ids = list(range(self.num_arms))
     arm_params = only_continuous_normalizer(arm_ids, MU_LOW, MU_HIGH)
     state_data = NormalizationData(dense_normalization_parameters=arm_params)
     return {NormalizationKey.STATE: state_data}
示例#7
0
    def run_feature_identification(
        self, input_table_spec: TableSpec
    ) -> Dict[str, NormalizationData]:
        """Identify normalization parameters for state and, if needed, action features.

        For each feature group, the model manager's preprocessing options are
        used, falling back to ``PreprocessingOptions()`` when unset. Those options
        are restricted to the float feature ids in the matching feature config
        before identification runs.

        Returns:
            Only ``{NormalizationKey.STATE: ...}`` when the model manager has
            ``discrete_action_names``. Otherwise, both STATE and ACTION entries.
        """
        manager = self.model_manager

        state_options = manager.state_preprocessing_options or PreprocessingOptions()
        state_ids = [
            info.feature_id
            for info in manager.state_feature_config.float_feature_infos
        ]
        logger.info(f"state allowedlist_features: {state_ids}")
        state_params = identify_normalization_parameters(
            input_table_spec,
            InputColumn.STATE_FEATURES,
            replace(state_options, allowedlist_features=state_ids),
        )

        # Discrete-action models only need state normalization.
        if manager.discrete_action_names:
            return {
                NormalizationKey.STATE: NormalizationData(
                    dense_normalization_parameters=state_params
                )
            }

        action_options = manager.action_preprocessing_options or PreprocessingOptions()
        action_ids = [
            info.feature_id
            for info in manager.action_feature_config.float_feature_infos
        ]
        logger.info(f"action allowedlist_features: {action_ids}")
        action_params = identify_normalization_parameters(
            input_table_spec,
            InputColumn.ACTION,
            replace(action_options, allowedlist_features=action_ids),
        )

        result = {}
        result[NormalizationKey.STATE] = NormalizationData(
            dense_normalization_parameters=state_params
        )
        result[NormalizationKey.ACTION] = NormalizationData(
            dense_normalization_parameters=action_params
        )
        return result
def _create_norm(dim, offset=0):
    """Build ``NormalizationData`` with ``dim`` CONTINUOUS features.

    Feature ids run from ``offset`` through ``offset + dim - 1``. Each one uses
    mean 0.0 and stddev 1.0.
    """
    feature_ids = range(offset, offset + dim)
    params = {
        feature_id: NormalizationParameters(
            feature_type=CONTINUOUS, mean=0.0, stddev=1.0
        )
        for feature_id in feature_ids
    }
    return NormalizationData(dense_normalization_parameters=params)
示例#9
0
 def run_feature_identification(
     self, input_table_spec: TableSpec
 ) -> Dict[str, NormalizationData]:
     """Identify normalization parameters for state features and item features.

     State features are read from ``InputColumn.STATE_FEATURES``. Item features
     are read from ``InputColumn.STATE_SEQUENCE_FEATURES`` with
     ``sequence_feature_id`` set to ``self.slate_feature_id``. Each group's
     options (falling back to ``PreprocessingOptions()`` when unset) are
     restricted to the float feature ids in its feature config.

     Returns:
         A dict with ``NormalizationKey.STATE`` and ``NormalizationKey.ITEM``
         entries.
     """
     state_options = self._state_preprocessing_options or PreprocessingOptions()
     state_ids = [
         info.feature_id for info in self.state_feature_config.float_feature_infos
     ]
     logger.info(f"state allowedlist_features: {state_ids}")
     state_params = identify_normalization_parameters(
         input_table_spec,
         InputColumn.STATE_FEATURES,
         state_options._replace(allowedlist_features=state_ids),
     )

     item_options = self._item_preprocessing_options or PreprocessingOptions()
     item_ids = [
         info.feature_id for info in self.item_feature_config.float_feature_infos
     ]
     logger.info(f"item allowedlist_features: {item_ids}")
     item_params = identify_normalization_parameters(
         input_table_spec,
         InputColumn.STATE_SEQUENCE_FEATURES,
         item_options._replace(
             allowedlist_features=item_ids,
             sequence_feature_id=self.slate_feature_id,
         ),
     )

     state_data = NormalizationData(dense_normalization_parameters=state_params)
     item_data = NormalizationData(dense_normalization_parameters=item_params)
     return {NormalizationKey.STATE: state_data, NormalizationKey.ITEM: item_data}
 def test_fully_connected(self):
     """Check that a FullyConnected value network maps (batch, state_dim) to (batch, 1)."""
     builder = ValueNetBuilder__Union(
         FullyConnected=value.fully_connected.FullyConnected()).value
     state_dim, batch_size = 3, 5

     normalization_data = NormalizationData(
         dense_normalization_parameters={
             feature_id: NormalizationParameters(feature_type=CONTINUOUS)
             for feature_id in range(state_dim)
         })
     value_network = builder.build_value_network(normalization_data)

     # Random input is drawn after the network is built, same as before.
     output = value_network(torch.randn(batch_size, state_dim))
     self.assertEqual(output.shape, (batch_size, 1))
示例#11
0
 def run_feature_identification(
         self, input_table_spec: TableSpec) -> Dict[str, NormalizationData]:
     """Identify state normalization parameters, restricted to configured features.

     Starts from the model manager's ``preprocessing_options``, falling back to
     ``PreprocessingOptions()`` when unset. The allowlist is overridden with the
     float feature ids from ``state_feature_config``.

     Returns:
         A dict with a single ``NormalizationKey.STATE`` entry.
     """
     manager = self.model_manager
     options = manager.preprocessing_options or PreprocessingOptions()
     logger.info("Overriding allowedlist_features")
     allowed_ids = [
         info.feature_id
         for info in manager.state_feature_config.float_feature_infos
     ]
     options = options._replace(allowedlist_features=allowed_ids)
     state_params = identify_normalization_parameters(
         input_table_spec, InputColumn.STATE_FEATURES, options)
     return {
         NormalizationKey.STATE:
         NormalizationData(dense_normalization_parameters=state_params)
     }
示例#12
0
    def run_feature_identification(
            self, input_table_spec: TableSpec) -> Dict[str, NormalizationData]:
        """Identify normalization parameters for state features only.

        Starts from a default ``PreprocessingOptions()``; no preprocessing options
        are read from the model manager here. The allowlist is overridden with
        the float feature ids from ``state_feature_config``.

        Returns:
            A dict with a single ``NormalizationKey.STATE`` entry.
        """
        defaults = PreprocessingOptions()
        allowed_ids = [
            info.feature_id
            for info in self.model_manager.state_feature_config.float_feature_infos
        ]
        logger.info(f"Overriding state allowedlist_features: {allowed_ids}")
        options = replace(defaults, allowedlist_features=allowed_ids)

        state_params = identify_normalization_parameters(
            input_table_spec, InputColumn.STATE_FEATURES, options)
        state_data = NormalizationData(dense_normalization_parameters=state_params)
        return {NormalizationKey.STATE: state_data}
示例#13
0
def train_seq2reward_compress_model(
    training_data, seq2reward_network, learning_rate=0.1, num_epochs=5
):
    """Fit a ``CompressModelTrainer`` for ``seq2reward_network`` and return it.

    A fresh ``FullyConnected(sizes=[8, 8])`` value network with ``NUM_ACTION``
    outputs serves as the compress model. It is trained on ``training_data``
    with a PyTorch Lightning trainer for ``num_epochs`` epochs.

    Args:
        training_data: Iterable of batches. The first batch's ``action`` must
            unpack as ``(seq_len, batch_size, num_action)`` with seq_len == 6
            and num_action == 2.
        seq2reward_network: Network handed to the trainer as ``seq2reward_network``.
        learning_rate: Used as ``compress_model_learning_rate``. The trainer's
            own ``learning_rate`` is fixed at 0.0.
        num_epochs: Number of Lightning training epochs.

    Returns:
        The fitted ``CompressModelTrainer``.
    """
    # Only sequence length and action count are constrained; batch size is free.
    SEQ_LEN, _, NUM_ACTION = next(iter(training_data)).action.shape
    assert SEQ_LEN == 6 and NUM_ACTION == 2, (
        f"expected action shape (6, batch_size, 2), "
        f"got ({SEQ_LEN}, batch_size, {NUM_ACTION})"
    )

    compress_net_builder = FullyConnected(sizes=[8, 8])
    # Two state features (ids 0 and 1), both passed through with DO_NOT_PREPROCESS.
    state_normalization_data = NormalizationData(
        dense_normalization_parameters={
            0: NormalizationParameters(feature_type=DO_NOT_PREPROCESS),
            1: NormalizationParameters(feature_type=DO_NOT_PREPROCESS),
        }
    )
    compress_model_network = compress_net_builder.build_value_network(
        state_normalization_data,
        output_dim=NUM_ACTION,
    )

    trainer_param = Seq2RewardTrainerParameters(
        learning_rate=0.0,
        multi_steps=SEQ_LEN,
        action_names=["0", "1"],
        compress_model_learning_rate=learning_rate,
        gamma=1.0,
        view_q_value=True,
    )

    trainer = CompressModelTrainer(
        compress_model_network=compress_model_network,
        seq2reward_network=seq2reward_network,
        params=trainer_param,
    )

    # NOTE(review): seeding happens after compress_model_network is built, so
    # its weight initialization is not covered by SEED. Left as-is to avoid
    # changing results that existing tests may depend on.
    pl.seed_everything(SEED)
    pl_trainer = pl.Trainer(max_epochs=num_epochs, deterministic=True)
    pl_trainer.fit(trainer, training_data)

    return trainer
    def test_seq2slate_scriptable(self):
        """Check that Seq2Slate models can be compiled with ``torch.jit.script``.

        Scripts the raw ``Seq2SlateTransformerModel``, then a
        ``Seq2SlateWithPreprocessor`` wrapping an eval-mode
        ``Seq2SlateTransformerNet``. The test passes if neither scripting call
        raises.
        """
        candidate_size = 8
        slate_size = 8
        greedy_serving = True
        # Both the raw model and the net are constructed with identical arguments.
        model_kwargs = dict(
            state_dim=2,
            candidate_dim=3,
            num_stacked_layers=2,
            num_heads=2,
            dim_model=128,
            dim_feedforward=128,
            max_src_seq_len=candidate_size,
            max_tgt_seq_len=slate_size,
            output_arch=Seq2SlateOutputArch.AUTOREGRESSIVE,
            temperature=1.0,
        )

        # The raw Seq2Slate model on its own must be scriptable.
        torch.jit.script(Seq2SlateTransformerModel(**model_kwargs))

        seq2slate_net = Seq2SlateTransformerNet(**model_kwargs)

        def _unprocessed(feature_ids):
            # Every listed feature id is tagged DO_NOT_PREPROCESS.
            return NormalizationData(
                dense_normalization_parameters={
                    feature_id: NormalizationParameters(
                        feature_type=DO_NOT_PREPROCESS)
                    for feature_id in feature_ids
                })

        state_normalization_data = _unprocessed((0, 1))
        candidate_normalization_data = _unprocessed((5, 6, 7))
        state_preprocessor = Preprocessor(
            state_normalization_data.dense_normalization_parameters, False)
        candidate_preprocessor = Preprocessor(
            candidate_normalization_data.dense_normalization_parameters, False)

        # Seq2Slate combined with the preprocessors must also be scriptable.
        torch.jit.script(
            Seq2SlateWithPreprocessor(
                seq2slate_net.eval(),
                state_preprocessor,
                candidate_preprocessor,
                greedy_serving,
            )
        )