コード例 #1
0
ファイル: test_gym.py プロジェクト: kevin3062/ReAgent
    def test_cartpole_reinforce(self):
        """Train a REINFORCE policy on CartPole-v0 and check its eval score.

        A ``FullyConnected`` Q-network builder (``sizes=[8]``, linear
        activation) produces the scorer, with one output per action
        normalization parameter. A ``SoftmaxActionSampler`` picks actions
        from those scores. The ``Reinforce`` trainer (gamma 0.995, Adam
        lr 5e-3, weight decay 1e-3) is driven by ``run_test_episode_buffer``
        for 500 training episodes and 100 evaluation episodes against
        ``passing_score_bar=180``. Presumably that helper asserts the
        evaluation reward reaches the bar; confirm in its definition.
        """
        # TODO(@badri) Parameterize this test
        # Seed the global RNGs before anything random is created, as
        # test_toyvm does. This makes network initialisation and softmax
        # sampling reproducible; unseeded, the pass/fail outcome of this
        # RL test can change from run to run.
        # NOTE(review): CartPole's own environment RNG may need separate
        # seeding through the Gym wrapper — confirm what Gym exposes.
        pl.seed_everything(SEED)
        env = Gym("CartPole-v0")
        norm = build_normalizer(env)

        from reagent.net_builder.discrete_dqn.fully_connected import FullyConnected

        net_builder = FullyConnected(sizes=[8], activations=["linear"])
        cartpole_scorer = net_builder.build_q_network(
            state_feature_config=None,
            state_normalization_data=norm["state"],
            # One score per action normalization parameter.
            output_dim=len(norm["action"].dense_normalization_parameters),
        )

        from reagent.gym.policies.samplers.discrete_sampler import SoftmaxActionSampler

        policy = Policy(scorer=cartpole_scorer, sampler=SoftmaxActionSampler())

        from reagent.training.reinforce import Reinforce, ReinforceParams
        from reagent.optimizer.union import classes

        trainer = Reinforce(
            policy,
            ReinforceParams(
                gamma=0.995,
                optimizer=classes["Adam"](lr=5e-3, weight_decay=1e-3),
            ),
        )
        run_test_episode_buffer(
            env,
            policy,
            trainer,
            num_train_episodes=500,
            passing_score_bar=180,
            num_eval_episodes=100,
        )
コード例 #2
0
    def test_toyvm(self):
        """Train a REINFORCE slate policy on ToyVM and check its eval score.

        Uses an ``MLPScorer`` (3 inputs, ``layer_sizes=[64]``, log transform)
        with a ``FrechetSort`` sampler over 5-item slates. The ``Reinforce``
        trainer (gamma 0, Adam lr 1e-1, weight decay 1e-3) is driven by
        ``run_test_episode_buffer`` for 500 training episodes and 100
        evaluation episodes against ``passing_score_bar=120``.
        """
        # Seed before constructing anything, so whatever below draws random
        # numbers (presumably the scorer's weight init) is reproducible.
        # Keep the construction order unchanged so RNG consumption stays
        # identical.
        pl.seed_everything(SEED)
        toy_env = ToyVM(slate_size=5, initial_seed=SEED)

        from reagent.models import MLPScorer

        mlp_scorer = MLPScorer(
            input_dim=3,
            log_transform=True,
            layer_sizes=[64],
            concat=False,
        )

        from reagent.samplers import FrechetSort

        frechet_sampler = FrechetSort(log_scores=True, topk=5, equiv_len=5)
        slate_policy = Policy(mlp_scorer, frechet_sampler)

        from reagent.optimizer.union import classes
        from reagent.training.reinforce import Reinforce, ReinforceParams

        adam = classes["Adam"](lr=1e-1, weight_decay=1e-3)
        # gamma=0: presumably each slate is credited only with its immediate
        # reward (no discounted future return) — confirm in Reinforce.
        reinforce_params = ReinforceParams(gamma=0, optimizer=adam)
        reinforce_trainer = Reinforce(slate_policy, reinforce_params)

        run_test_episode_buffer(
            toy_env,
            slate_policy,
            reinforce_trainer,
            num_train_episodes=500,
            passing_score_bar=120,
            num_eval_episodes=100,
        )