Example #1
0
    def __init__(
        self,
        state_preprocessor: Preprocessor,
        value_network: Optional[nn.Module],
        action_names: List[str],
    ) -> None:
        """Trace the state preprocessor and value network with TorchScript.

        Args:
            state_preprocessor: Preprocessor whose ``sorted_features`` are
                recorded on this module. It is traced using its own
                ``input_prototype()`` as the example input.
            value_network: Network traced on the output of the traced state
                preprocessor. Although annotated ``Optional``, a network is
                required: ``torch.jit.trace`` cannot trace ``None``.
            action_names: Names stored as an attribute declared with the
                TorchScript type ``List[str]``.

        Raises:
            ValueError: If ``value_network`` is ``None``.
        """
        # Fail fast with a clear message instead of an opaque error raised
        # from deep inside torch.jit.trace.
        if value_network is None:
            raise ValueError("value_network must not be None")

        super().__init__()

        self.state_sorted_features_t = state_preprocessor.sorted_features

        # input_prototype() is passed directly as trace's example inputs; it is
        # unpacked with * below, so presumably it is a tuple of tensors.
        self.state_preprocessor = torch.jit.trace(
            state_preprocessor, state_preprocessor.input_prototype())

        # Run the *traced* preprocessor on a prototype so the value network's
        # example input has the structure the traced preprocessor produces.
        value_network_sample_input = self.state_preprocessor(
            *state_preprocessor.input_prototype())
        self.value_network = torch.jit.trace(value_network,
                                             value_network_sample_input)
        # torch.jit.Attribute declares the attribute's TorchScript type
        # explicitly as List[str].
        self.action_names = torch.jit.Attribute(action_names, List[str])
Example #2
0
    def __init__(
        self,
        state_preprocessor: Preprocessor,
        action_preprocessor: Preprocessor,
        value_network: Optional[nn.Module],
    ) -> None:
        """Trace the state/action preprocessors and the value network.

        Args:
            state_preprocessor: Preprocessor whose ``sorted_features`` are
                recorded on this module. It is traced using its own
                ``input_prototype()`` as the example input.
            action_preprocessor: Same as ``state_preprocessor``, for actions.
            value_network: Network traced on the pair
                ``(preprocessed_state, preprocessed_action)`` produced by the
                traced preprocessors. Although annotated ``Optional``, a
                network is required: ``torch.jit.trace`` cannot trace ``None``.

        Raises:
            ValueError: If ``value_network`` is ``None``.
        """
        # Fail fast with a clear message instead of an opaque error raised
        # from deep inside torch.jit.trace.
        if value_network is None:
            raise ValueError("value_network must not be None")

        super().__init__()

        # input_prototype() is passed directly as trace's example inputs; it is
        # unpacked with * below, so presumably it is a tuple of tensors.
        self.state_sorted_features_t = state_preprocessor.sorted_features
        self.state_preprocessor = torch.jit.trace(
            state_preprocessor, state_preprocessor.input_prototype())

        self.action_sorted_features_t = action_preprocessor.sorted_features
        self.action_preprocessor = torch.jit.trace(
            action_preprocessor, action_preprocessor.input_prototype())

        # Run the *traced* preprocessors on their prototypes. The value
        # network is traced with a 2-tuple of example inputs, so its forward
        # receives (state, action) as two positional arguments.
        value_network_sample_input = (
            self.state_preprocessor(*state_preprocessor.input_prototype()),
            self.action_preprocessor(*action_preprocessor.input_prototype()),
        )
        self.value_network = torch.jit.trace(value_network,
                                             value_network_sample_input)