Beispiel #1
0
def sparse_input_prototype(
    model: ModelBase,
    state_preprocessor: Preprocessor,
    state_feature_config: rlt.ModelFeatureConfig,
):
    """Build the example input tuple used when tracing a serving model.

    The float part always comes from ``state_preprocessor.input_prototype()``.
    The sparse parts (id-list and id-score-list features) are taken from
    ``model.input_prototype()`` when that is an ``rlt.FeatureData`` with
    non-empty sparse features, re-keyed through
    ``state_feature_config.name2id``. Otherwise a single placeholder entry
    (key ``42``) is used so the JIT tracer can infer the dict types.

    Args:
        model: Model whose ``input_prototype()`` may supply sparse features.
        state_preprocessor: Supplies the ``float_features_with_presence``
            prototype.
        state_feature_config: Provides ``name2id``. It maps the model
            prototype's sparse-feature keys (presumably feature names) to
            the ids used by the serving input.

    Returns:
        A 1-tuple containing the ``rlt.ServingFeatureData`` prototype.
    """
    name2id = state_feature_config.name2id
    model_prototype = model.input_prototype()

    # HACK: a Python dict carries no element type, so an empty dict would
    # leave the JIT tracer unable to infer key/value types. Seed each dict
    # with one placeholder entry (id 42) whose tensors have the expected
    # dtypes. The exact meaning of each tuple slot is defined by
    # rlt.ServingFeatureData (not visible here).
    state_id_list_features = {
        42: (torch.zeros(1, dtype=torch.long), torch.tensor([], dtype=torch.long))
    }
    state_id_score_list_features = {
        42: (
            torch.zeros(1, dtype=torch.long),
            torch.tensor([], dtype=torch.long),
            torch.tensor([], dtype=torch.float),
        )
    }

    if isinstance(model_prototype, rlt.FeatureData):
        # Only replace the placeholders when the model actually declares
        # sparse features; empty dicts would defeat the tracing hack above.
        if model_prototype.id_list_features:
            state_id_list_features = {
                name2id[k]: v
                for k, v in model_prototype.id_list_features.items()
            }
        if model_prototype.id_score_list_features:
            state_id_score_list_features = {
                name2id[k]: v
                for k, v in model_prototype.id_score_list_features.items()
            }

    # Named serving_input (not `input`) to avoid shadowing the builtin.
    serving_input = rlt.ServingFeatureData(
        float_features_with_presence=state_preprocessor.input_prototype(),
        id_list_features=state_id_list_features,
        id_score_list_features=state_id_score_list_features,
    )
    return (serving_input,)
Beispiel #2
0
def sparse_input_prototype(
    model: ModelBase,
    state_preprocessor: Preprocessor,
    state_feature_config: rlt.ModelFeatureConfig,
):
    """Build the example input tuple used when tracing a serving model.

    The float part always comes from ``state_preprocessor.input_prototype()``.
    The sparse parts (id-list and id-score-list features) are taken from
    ``model.input_prototype()`` when that is an ``rlt.FeatureData`` with
    non-empty sparse features, re-keyed through
    ``state_feature_config.name2id``. Otherwise the module-level
    ``FAKE_STATE_ID_LIST_FEATURES`` / ``FAKE_STATE_ID_SCORE_LIST_FEATURES``
    placeholders are used.

    Args:
        model: Model whose ``input_prototype()`` may supply sparse features.
        state_preprocessor: Supplies the ``float_features_with_presence``
            prototype.
        state_feature_config: Provides ``name2id``. It maps the model
            prototype's sparse-feature keys (presumably feature names) to
            the ids used by the serving input.

    Returns:
        A 1-tuple containing the ``rlt.ServingFeatureData`` prototype.
    """
    name2id = state_feature_config.name2id
    model_prototype = model.input_prototype()

    # HACK: a Python dict carries no element type, so an empty dict would
    # leave the JIT tracer unable to infer key/value types. Start from the
    # non-empty FAKE_* placeholder dicts (defined elsewhere in this module)
    # so the tracer always sees typed entries.
    state_id_list_features = FAKE_STATE_ID_LIST_FEATURES
    state_id_score_list_features = FAKE_STATE_ID_SCORE_LIST_FEATURES

    if isinstance(model_prototype, rlt.FeatureData):
        # Only replace the placeholders when the model actually declares
        # sparse features.
        if model_prototype.id_list_features:
            state_id_list_features = {
                name2id[k]: v
                for k, v in model_prototype.id_list_features.items()
            }
        if model_prototype.id_score_list_features:
            state_id_score_list_features = {
                name2id[k]: v
                for k, v in model_prototype.id_score_list_features.items()
            }

    # Named serving_input (not `input`) to avoid shadowing the builtin.
    serving_input = rlt.ServingFeatureData(
        float_features_with_presence=state_preprocessor.input_prototype(),
        id_list_features=state_id_list_features,
        id_score_list_features=state_id_score_list_features,
    )
    return (serving_input,)
Beispiel #3
0
    def __init__(
        self,
        *,
        shared_network: ModelBase,
        advantage_network: ModelBase,
        value_network: ModelBase,
    ) -> None:
        """
        Dueling Q-Network Architecture: https://arxiv.org/abs/1511.06581

        Args:
            shared_network: Sub-network stored as ``self.shared_network``.
                Its ``input_prototype()`` must be an ``rlt.FeatureData``.
            advantage_network: Sub-network stored as
                ``self.advantage_network`` (presumably the advantage stream).
            value_network: Sub-network stored as ``self.value_network``
                (presumably the value stream).

        Raises:
            AssertionError: If ``shared_network.input_prototype()`` is not an
                ``rlt.FeatureData``. Asserts are skipped under ``python -O``.
        """
        super().__init__()
        self.shared_network = shared_network
        input_prototype = shared_network.input_prototype()
        # Plain string: the message has no placeholders, so no f-prefix.
        assert isinstance(
            input_prototype, rlt.FeatureData
        ), "shared_network should expect FeatureData as input"
        self.advantage_network = advantage_network
        self.value_network = value_network

        # Defined elsewhere in this module; presumably validates that the
        # three sub-networks are wired compatibly — confirm at definition.
        _check_connection(self)
        self._name = "unnamed"
Beispiel #4
0
    def __init__(
        self,
        *,
        shared_network: ModelBase,
        advantage_network: ModelBase,
        value_network: ModelBase,
    ) -> None:
        """
        Dueling Q-Network Architecture: https://arxiv.org/abs/1511.06581

        Args:
            shared_network: Sub-network stored as ``self.shared_network``.
            advantage_network: Sub-network stored as
                ``self.advantage_network``. Its ``input_prototype()`` must be
                a 2-tuple whose first element reports
                ``has_float_features_only``.
            value_network: Sub-network stored as ``self.value_network``.

        Raises:
            AssertionError: If ``advantage_network``'s input prototype does
                not have the expected shape. Asserts are skipped under
                ``python -O``.
        """
        super().__init__()
        advantage_network_input = advantage_network.input_prototype()
        assert (
            isinstance(advantage_network_input, tuple)
            and len(advantage_network_input) == 2
        ), "advantage_network.input_prototype() should return a tuple of length 2"
        assert advantage_network_input[0].has_float_features_only, (
            "the first element of advantage_network.input_prototype() "
            "should have float features only"
        )

        self.shared_network = shared_network
        self.advantage_network = advantage_network
        self.value_network = value_network

        # Defined elsewhere in this module; presumably validates that the
        # three sub-networks are wired compatibly — confirm at definition.
        _check_connection(self)
        self._name = "unnamed"