Example #1
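Test fixture that assembles the pieces for a PPO trainer: PPOTrainerParameters, reward metrics, a DuelingQNetwork scorer wrapped in a softmax-sampled Policy, and a fully connected value network.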
    def setUp(self):
        # prepare components for PPO trainer initialization
        self.batch_size = 3
        self.state_dim = 10
        self.action_dim = 2
        self.num_layers = 2
        self.sizes = [20 for _ in range(self.num_layers)]
        self.activations = ["relu" for _ in range(self.num_layers)]
        self.use_layer_norm = False
        self.softmax_temperature = 1

        self.actions = [str(i) for i in range(self.action_dim)]
        self.params = PPOTrainerParameters(actions=self.actions, normalize=False)
        self.reward_options = RewardOptions()
        self.metrics_to_score = get_metrics_to_score(
            self.reward_options.metric_reward_values
        )

        self.policy_network = DuelingQNetwork.make_fully_connected(
            state_dim=self.state_dim,
            action_dim=self.action_dim,
            layers=self.sizes,
            activations=self.activations,
        )
        self.sampler = SoftmaxActionSampler(temperature=self.softmax_temperature)
        self.policy = Policy(scorer=self.policy_network, sampler=self.sampler)

        self.value_network = FloatFeatureFullyConnected(
            state_dim=self.state_dim,
            output_dim=1,
            sizes=self.sizes,
            activations=self.activations,
            use_layer_norm=self.use_layer_norm,
        )
Example #2
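End-to-end REINFORCE test on CartPole-v0: a fully connected Q-network scorer built from the env normalizer, a softmax-sampled Policy, and an Adam-optimized Reinforce trainer run through run_test_episode_buffer.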
    def test_cartpole_reinforce(self):
        # TODO(@badri) Parameterize this test
        env = Gym("CartPole-v0")
        norm = build_normalizer(env)

        from reagent.net_builder.discrete_dqn.fully_connected import FullyConnected

        net_builder = FullyConnected(sizes=[8], activations=["linear"])
        cartpole_scorer = net_builder.build_q_network(
            state_feature_config=None,
            state_normalization_data=norm["state"],
            output_dim=len(norm["action"].dense_normalization_parameters),
        )

        from reagent.gym.policies.samplers.discrete_sampler import SoftmaxActionSampler

        policy = Policy(scorer=cartpole_scorer, sampler=SoftmaxActionSampler())

        from reagent.training.reinforce import Reinforce, ReinforceParams
        from reagent.optimizer.union import classes

        trainer = Reinforce(
            policy,
            ReinforceParams(gamma=0.995,
                            optimizer=classes["Adam"](lr=5e-3,
                                                      weight_decay=1e-3)),
        )
        run_test_episode_buffer(
            env,
            policy,
            trainer,
            num_train_episodes=500,
            passing_score_bar=180,
            num_eval_episodes=100,
        )
Example #3
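Parametric DQN variant that reads the action dimension from the Q-network's input prototype, then returns either a serving predictor policy or a softmax-sampled Policy over parametric_dqn_scorer.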
    def create_policy(
        self,
        trainer_module: ReAgentLightningModule,
        serving: bool = False,
        normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
    ):
        """Create an online DiscreteDQN Policy from env."""

        # FIXME: this only works for one-hot encoded actions
        # pyre-fixme[16]: `Tensor` has no attribute `input_prototype`.
        action_dim = trainer_module.q_network.input_prototype()[1].float_features.shape[
            1
        ]
        if serving:
            assert normalization_data_map
            return create_predictor_policy_from_model(
                self.build_serving_module(trainer_module, normalization_data_map),
                max_num_actions=action_dim,
            )
        else:
            # pyre-fixme[16]: `ParametricDQNBase` has no attribute `rl_parameters`.
            sampler = SoftmaxActionSampler(temperature=self.rl_parameters.temperature)
            scorer = parametric_dqn_scorer(
                max_num_actions=action_dim,
                # pyre-fixme[6]: Expected `ModelBase` for 2nd param but got
                #  `Union[torch.Tensor, torch.nn.Module]`.
                q_network=trainer_module.q_network,
            )
            return Policy(scorer=scorer, sampler=sampler)
Example #4
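REINFORCE test on the ToyVM slate environment: an MLPScorer slate scorer paired with a FrechetSort sampler over slates of size five.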
    def test_toyvm(self):
        pl.seed_everything(SEED)
        env = ToyVM(slate_size=5, initial_seed=SEED)
        from reagent.models import MLPScorer

        slate_scorer = MLPScorer(input_dim=3,
                                 log_transform=True,
                                 layer_sizes=[64],
                                 concat=False)

        from reagent.samplers import FrechetSort

        policy = Policy(slate_scorer,
                        FrechetSort(log_scores=True, topk=5, equiv_len=5))
        from reagent.optimizer.union import classes
        from reagent.training.reinforce import Reinforce, ReinforceParams

        trainer = Reinforce(
            policy,
            ReinforceParams(gamma=0,
                            optimizer=classes["Adam"](lr=1e-1,
                                                      weight_decay=1e-3)),
        )

        run_test_episode_buffer(
            env,
            policy,
            trainer,
            num_train_episodes=500,
            passing_score_bar=120,
            num_eval_episodes=100,
        )
Example #5
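Lazily builds and caches a softmax-sampled Policy around self._policy_network, or returns a predictor policy when serving.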
 def create_policy(self, serving: bool = False):
     if serving:
         return create_predictor_policy_from_model(self.build_serving_module())
     else:
         if self._policy is None:
             sampler = SoftmaxActionSampler(temperature=self.sampler_temperature)
             # pyre-ignore
             self._policy = Policy(scorer=self._policy_network, sampler=sampler)
         return self._policy
Example #6
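DiscreteDQN factory: a predictor policy built from the serving module and rl_parameters when serving, otherwise a greedy-sampled Policy over discrete_dqn_scorer.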
 def create_policy(self, serving: bool) -> Policy:
     """Create an online DiscreteDQN Policy from env."""
     if serving:
         return create_predictor_policy_from_model(
             self.build_serving_module(), rl_parameters=self.rl_parameters)
     else:
         sampler = GreedyActionSampler()
         # pyre-fixme[16]: `RLTrainer` has no attribute `q_network`.
         scorer = discrete_dqn_scorer(self.trainer.q_network)
         return Policy(scorer=scorer, sampler=sampler)
Example #7
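QR-DQN variant: a greedy sampler over the unwrapped serving module when serving, otherwise a softmax sampler over discrete_qrdqn_scorer.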
 def create_policy(self, serving: bool) -> Policy:
     if serving:
         sampler = GreedyActionSampler()
         scorer = discrete_dqn_serving_scorer(
             DiscreteDqnPredictorUnwrapper(self.build_serving_module()))
     else:
         sampler = SoftmaxActionSampler(
             temperature=self.rl_parameters.temperature)
         # pyre-fixme[16]: `RLTrainer` has no attribute `q_network`.
         scorer = discrete_qrdqn_scorer(self.trainer.q_network)
     return Policy(scorer=scorer, sampler=sampler)
Example #8
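Minimal DiscreteDQN factory that imports the gym-dependent sampler and scorer lazily, so gym is only required when a policy is actually created.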
    def create_policy(self) -> Policy:
        """ Create an online DiscreteDQN Policy from env.
        """
        # Avoiding potentially importing gym when it's not installed

        from reagent.gym.policies.samplers.discrete_sampler import SoftmaxActionSampler
        from reagent.gym.policies.scorers.discrete_scorer import discrete_dqn_scorer

        sampler = SoftmaxActionSampler(temperature=self.rl_parameters.temperature)
        scorer = discrete_dqn_scorer(self.trainer.q_network)
        return Policy(scorer=scorer, sampler=sampler)
Example #9
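SlateQ factory: slate_q_scorer over num_candidates with a TopKSampler of slate_size, or a predictor policy when serving.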
 def create_policy(self, serving: bool) -> Policy:
     if serving:
         return create_predictor_policy_from_model(
             self.build_serving_module(),
             max_num_actions=self.num_candidates,
             slate_size=self.slate_size,
         )
     else:
         scorer = slate_q_scorer(num_candidates=self.num_candidates,
                                 q_network=self._q_network)
         sampler = TopKSampler(k=self.slate_size)
         return Policy(scorer=scorer, sampler=sampler)
Example #10
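Parametric DQN factory that derives the action dimension from the action normalization data instead of the network's input prototype.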
    def create_policy(self, serving: bool) -> Policy:
        """ Create an online DiscreteDQN Policy from env. """

        # FIXME: this only works for one-hot encoded actions
        action_dim = get_num_output_features(
            self.action_normalization_data.dense_normalization_parameters)
        if serving:
            return create_predictor_policy_from_model(
                self.build_serving_module(), max_num_actions=action_dim)
        else:
            sampler = SoftmaxActionSampler(
                temperature=self.rl_parameters.temperature)
            scorer = parametric_dqn_scorer(max_num_actions=action_dim,
                                           q_network=self._q_network)
            return Policy(scorer=scorer, sampler=sampler)
Example #11
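DiscreteDQN factory that reuses one SoftmaxActionSampler and only switches the scorer between the serving predictor and the in-training Q-network.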
    def create_policy(self, serving: bool) -> Policy:
        """ Create an online DiscreteDQN Policy from env. """

        from reagent.gym.policies.samplers.discrete_sampler import SoftmaxActionSampler
        from reagent.gym.policies.scorers.discrete_scorer import (
            discrete_dqn_scorer,
            discrete_dqn_serving_scorer,
        )

        sampler = SoftmaxActionSampler(temperature=self.rl_parameters.temperature)
        if serving:
            scorer = discrete_dqn_serving_scorer(
                DiscreteDqnPredictorUnwrapper(self.build_serving_module())
            )
        else:
            scorer = discrete_dqn_scorer(self.trainer.q_network)
        return Policy(scorer=scorer, sampler=sampler)
Example #12
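Lightning-module DiscreteDQN factory: a predictor policy when serving, otherwise a greedy-sampled Policy over the trainer module's Q-network.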
 def create_policy(
     self,
     trainer_module: ReAgentLightningModule,
     serving: bool = False,
     normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
 ) -> Policy:
     """Create an online DiscreteDQN Policy from env."""
     if serving:
         assert normalization_data_map
         return create_predictor_policy_from_model(
             self.build_serving_module(trainer_module, normalization_data_map),
             rl_parameters=self.rl_parameters,
         )
     else:
         sampler = GreedyActionSampler()
         # pyre-fixme[6]: Expected `ModelBase` for 1st param but got
         #  `Union[torch.Tensor, torch.nn.Module]`.
         scorer = discrete_dqn_scorer(trainer_module.q_network)
         return Policy(scorer=scorer, sampler=sampler)
Example #13
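Test fixture that builds a CartPole Policy and Agent, then wraps them in an EpisodicDataset of six episodes capped at three steps each.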
 def setUp(self):
     logging.getLogger().setLevel(logging.DEBUG)
     env = Gym("CartPole-v0")
     norm = build_normalizer(env)
     net_builder = FullyConnected(sizes=[8], activations=["linear"])
     cartpole_scorer = net_builder.build_q_network(
         state_feature_config=None,
         state_normalization_data=norm["state"],
         output_dim=len(norm["action"].dense_normalization_parameters),
     )
     policy = Policy(scorer=cartpole_scorer, sampler=SoftmaxActionSampler())
     agent = Agent.create_for_env(env, policy)
     self.max_steps = 3
     self.num_episodes = 6
     self.dataset = EpisodicDataset(
         env=env,
         agent=agent,
         num_episodes=self.num_episodes,
         seed=0,
         max_steps=self.max_steps,
     )
Example #14
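Lightning-module SlateQ factory, mirroring Example #9 but taking the Q-network from the trainer module.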
 def create_policy(
     self,
     trainer_module: ReAgentLightningModule,
     serving: bool = False,
     normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
 ):
     if serving:
         assert normalization_data_map
         return create_predictor_policy_from_model(
             self.build_serving_module(trainer_module, normalization_data_map),
             # pyre-fixme[16]: `SlateQBase` has no attribute `num_candidates`.
             max_num_actions=self.num_candidates,
             # pyre-fixme[16]: `SlateQBase` has no attribute `slate_size`.
             slate_size=self.slate_size,
         )
     else:
         scorer = slate_q_scorer(
             num_candidates=self.num_candidates,
             # pyre-fixme[6]: Expected `ModelBase` for 2nd param but got
             #  `Union[torch.Tensor, torch.nn.Module]`.
             q_network=trainer_module.q_network,
         )
         sampler = TopKSampler(k=self.slate_size)
         return Policy(scorer=scorer, sampler=sampler)
Example #15
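Helper that lazily builds and caches a softmax-sampled Policy around the given policy network.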
 def _create_policy(self, policy_network):
     if self._policy is None:
         sampler = SoftmaxActionSampler(temperature=self.sampler_temperature)
         self._policy = Policy(scorer=policy_network, sampler=sampler)
     return self._policy
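
Across these examples the pattern is the same: build a scorer network, pair it with a sampler inside a Policy, then hand the Policy to a trainer or an Agent. The sketch below restates that flow end to end; it is a minimal illustration only, and the reagent.gym import paths are assumptions (the tests above rely on module-level imports that are not shown), so they may need adjusting for a given ReAgent version.

    # Minimal sketch of the scorer + sampler -> Policy -> Agent flow.
    # The import paths below are assumptions; verify them against your ReAgent version.
    from reagent.gym.envs import Gym
    from reagent.gym.utils import build_normalizer
    from reagent.gym.agents.agent import Agent
    from reagent.gym.policies.policy import Policy
    from reagent.gym.policies.samplers.discrete_sampler import SoftmaxActionSampler
    from reagent.net_builder.discrete_dqn.fully_connected import FullyConnected

    env = Gym("CartPole-v0")
    norm = build_normalizer(env)

    # Scorer: a small fully connected Q-network sized from the env's normalizer,
    # exactly as in Example #2.
    scorer = FullyConnected(sizes=[8], activations=["linear"]).build_q_network(
        state_feature_config=None,
        state_normalization_data=norm["state"],
        output_dim=len(norm["action"].dense_normalization_parameters),
    )

    # Policy = scorer + sampler; Agent adapts the Policy to the gym environment,
    # as in Example #13.
    policy = Policy(scorer=scorer, sampler=SoftmaxActionSampler())
    agent = Agent.create_for_env(env, policy)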