def test_cartpole_reinforce(self):
    """Train a REINFORCE agent on CartPole-v0 and require it to clear the score bar.

    Builds a single-hidden-layer linear Q network as the policy scorer, samples
    actions via softmax, and trains with Adam; `run_test_episode_buffer` both
    trains and evaluates, asserting the eval score passes the bar.
    """
    # TODO(@badri) Parameterize this test
    from reagent.gym.policies.samplers.discrete_sampler import SoftmaxActionSampler
    from reagent.net_builder.discrete_dqn.fully_connected import FullyConnected
    from reagent.optimizer.union import classes
    from reagent.training.reinforce import Reinforce, ReinforceParams

    environment = Gym("CartPole-v0")
    normalizer = build_normalizer(environment)

    # Output dim = number of discrete actions, read off the action normalization.
    builder = FullyConnected(sizes=[8], activations=["linear"])
    q_network = builder.build_q_network(
        state_feature_config=None,
        state_normalization_data=normalizer["state"],
        output_dim=len(normalizer["action"].dense_normalization_parameters),
    )
    agent_policy = Policy(scorer=q_network, sampler=SoftmaxActionSampler())

    reinforce_trainer = Reinforce(
        agent_policy,
        ReinforceParams(
            gamma=0.995,
            optimizer=classes["Adam"](lr=5e-3, weight_decay=1e-3),
        ),
    )
    run_test_episode_buffer(
        environment,
        agent_policy,
        reinforce_trainer,
        num_train_episodes=500,
        passing_score_bar=180,
        num_eval_episodes=100,
    )
def test_toyvm(self):
    """Train a slate-ranking REINFORCE agent on ToyVM and require it to clear the bar.

    Seeds all RNGs for reproducibility, scores slate items with a small MLP,
    ranks them with Frechet sort, and trains with Adam (gamma=0, i.e. purely
    myopic rewards); `run_test_episode_buffer` trains then evaluates against
    the passing score bar.
    """
    pl.seed_everything(SEED)

    from reagent.models import MLPScorer
    from reagent.optimizer.union import classes
    from reagent.samplers import FrechetSort
    from reagent.training.reinforce import Reinforce, ReinforceParams

    slate_env = ToyVM(slate_size=5, initial_seed=SEED)

    item_scorer = MLPScorer(
        input_dim=3,
        log_transform=True,
        layer_sizes=[64],
        concat=False,
    )
    ranking_policy = Policy(
        item_scorer,
        FrechetSort(log_scores=True, topk=5, equiv_len=5),
    )

    reinforce_trainer = Reinforce(
        ranking_policy,
        ReinforceParams(
            gamma=0,
            optimizer=classes["Adam"](lr=1e-1, weight_decay=1e-3),
        ),
    )
    run_test_episode_buffer(
        slate_env,
        ranking_policy,
        reinforce_trainer,
        num_train_episodes=500,
        passing_score_bar=120,
        num_eval_episodes=100,
    )