Example #1
    def __init__(
        self,
        q_network,
        q_network_target,
        use_gpu: bool = False,
        # Start SlateQTrainerParameters
        rl: rlp.RLParameters = field(  # noqa: B008
            default_factory=lambda: rlp.RLParameters(maxq_learning=False)),
        optimizer: Optimizer__Union = field(  # noqa: B008
            default_factory=Optimizer__Union.default),
        single_selection: bool = True,
        minibatch_size: int = 1024,
        evaluation: rlp.EvaluationParameters = field(  # noqa: B008
            default_factory=lambda: rlp.EvaluationParameters(
                calc_cpe_in_training=False)),
    ) -> None:
        super().__init__(rl, use_gpu=use_gpu)
        # One minibatch per training step; the batch size comes from the constructor.
        self.minibatches_per_step = 1
        self.minibatch_size = minibatch_size
        self.single_selection = single_selection

        self.q_network = q_network
        self.q_network_target = q_network_target
        # Build the q-network optimizer from the Optimizer__Union configuration.
        self.q_network_optimizer = optimizer.make_optimizer(
            self.q_network.parameters())
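The `# noqa: B008` markers above suppress flake8-bugbear's B008 warning about calling a function (`field(...)`) in an argument default; the surrounding framework presumably resolves those `field` objects into fresh per-call defaults. Below is a minimal stdlib-only sketch of the conventional equivalent, using an `Optional` sentinel so each call builds its own default object. `Params` and `Trainer` are hypothetical stand-ins for illustration, not classes from the library.

from dataclasses import dataclass
from typing import Optional


@dataclass
class Params:
    # Hypothetical stand-in for rlp.RLParameters.
    maxq_learning: bool = True


class Trainer:
    def __init__(self, q_network, rl: Optional[Params] = None) -> None:
        # Resolve the default inside the body so every instance gets a fresh
        # Params object; a bare `rl: Params = Params(maxq_learning=False)`
        # default would be evaluated once at definition time and shared.
        self.rl_parameters = rl if rl is not None else Params(maxq_learning=False)
        self.q_network = q_network


trainer = Trainer(q_network=object())
print(trainer.rl_parameters)  # Params(maxq_learning=False)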
Example #2
@dataclass  # field() defaults below require the dataclass decorator
class SlateQTrainerParameters:
    rl: rlp.RLParameters = field(
        default_factory=lambda: rlp.RLParameters(maxq_learning=False))
    optimizer: str = "ADAM"
    learning_rate: float = 0.001
    minibatch_size: int = 1024
    evaluation: rlp.EvaluationParameters = field(
        default_factory=lambda: rlp.EvaluationParameters(calc_cpe_in_training=False))
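With class-level defaults, `field(default_factory=...)` makes each `SlateQTrainerParameters()` instance build its own fresh `RLParameters` / `EvaluationParameters` object instead of sharing one created when the class is defined. A minimal stdlib-only sketch of that pattern, assuming nothing from the library (`RLParams` and `TrainerParams` are hypothetical stand-ins):

from dataclasses import dataclass, field


@dataclass
class RLParams:
    maxq_learning: bool = True


@dataclass
class TrainerParams:
    # default_factory runs at instantiation time, so two TrainerParams()
    # instances never share the same RLParams object.
    rl: RLParams = field(default_factory=lambda: RLParams(maxq_learning=False))
    optimizer: str = "ADAM"
    learning_rate: float = 0.001
    minibatch_size: int = 1024


defaults = TrainerParams()
tuned = TrainerParams(learning_rate=3e-4, minibatch_size=512)
assert defaults.rl is not tuned.rl  # each instance gets its own RLParams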
Example #3
    def test_json_serialize_basic(self):
        # Round trip: serializing to JSON and back should yield an equal object.
        damp = rlp.DiscreteActionModelParameters(
            actions=["foo", "bar"],
            rl=rlp.RLParameters(),
            training=rlp.TrainingParameters(),
            rainbow=rlp.RainbowDQNParameters(double_q_learning=False, categorical=True),
            state_feature_params=None,
            target_action_distribution=[1.0, 2.0],
            evaluation=rlp.EvaluationParameters(),
        )
        self.assertEqual(
            damp,
            json_to_object(object_to_json(damp), rlp.DiscreteActionModelParameters),
        )
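The test above asserts a round-trip property: converting the parameter object to JSON and back yields an object equal to the original. Below is a minimal stdlib-only sketch of the same property for a flat dataclass, using `dataclasses.asdict` plus the `json` module rather than the library's `object_to_json` / `json_to_object` helpers (whose implementations are not shown here); `RainbowParams`, `to_json`, and `from_json` are hypothetical names.

import json
from dataclasses import asdict, dataclass


@dataclass
class RainbowParams:
    # Flat, hypothetical stand-in for rlp.RainbowDQNParameters.
    double_q_learning: bool = True
    categorical: bool = False


def to_json(obj) -> str:
    return json.dumps(asdict(obj))


def from_json(payload: str, cls):
    # Sufficient for flat dataclasses; nested dataclasses would need recursive
    # reconstruction, which is presumably what json_to_object handles.
    return cls(**json.loads(payload))


params = RainbowParams(double_q_learning=False, categorical=True)
assert params == from_json(to_json(params), RainbowParams)  # round-trip equality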
Example #4
    def __init__(
        self,
        q_network,
        q_network_target,
        # Start SlateQTrainerParameters
        rl: rlp.RLParameters = field(  # noqa: B008
            default_factory=lambda: rlp.RLParameters(maxq_learning=False)
        ),
        optimizer: Optimizer__Union = field(  # noqa: B008
            default_factory=Optimizer__Union.default
        ),
        single_selection: bool = True,
        minibatch_size: int = 1024,
        evaluation: rlp.EvaluationParameters = field(  # noqa: B008
            default_factory=lambda: rlp.EvaluationParameters(calc_cpe_in_training=False)
        ),
    ) -> None:
        """
        Args:
            q_network: states, action -> q-value
            rl (optional): an instance of the RLParameter class, which
                defines relevant hyperparameters
            optimizer (optional): the optimizer class and
                optimizer hyperparameters for the q network(s) optimizer
            single_selection (optional): TBD
            minibatch_size (optional): the size of the minibatch
            evaluation (optional): TBD
        """
        super().__init__()
        self.rl_parameters = rl

        self.single_selection = single_selection

        self.q_network = q_network
        self.q_network_target = q_network_target
        # The optimizer config is stored as-is here; the concrete optimizer is
        # presumably built later (compare Example #1, which calls make_optimizer
        # in the constructor).
        self.q_network_optimizer = optimizer