Example #1
    def test_ac_algorithm_with_global_counter(self):
        env = MyEnv(batch_size=3)
        alg2 = create_algorithm(env)
        new_iter_num = 3
        for _ in range(new_iter_num):
            alg2.train_iter()
        # alg2 has run new_iter_num training iterations, so the global
        # counter should equal new_iter_num.
        self.assertTrue(alf.summary.get_global_counter() == new_iter_num)
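Example #1 relies on the counter returned by alf.summary.get_global_counter() advancing once per train_iter() call, no matter which algorithm instance runs the iteration. The standalone sketch below is plain Python rather than ALF code; ToyAlgorithm and the module-level get_global_counter are invented names used only to illustrate that shared-counter behavior.

# Illustrative sketch only, not ALF's implementation. It mimics the behavior
# the assertion above relies on: a process-wide counter that every algorithm
# instance advances once per train_iter(), so the counter reflects the total
# number of training iterations across all instances.
_global_counter = 0


def get_global_counter():
    """Return the shared training-iteration count."""
    return _global_counter


class ToyAlgorithm:
    def train_iter(self):
        """Run one (dummy) training iteration and bump the shared counter."""
        global _global_counter
        _global_counter += 1


alg_a = ToyAlgorithm()
alg_b = ToyAlgorithm()
for _ in range(3):
    alg_a.train_iter()
alg_b.train_iter()
assert get_global_counter() == 4  # iterations from both instances are counted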
Example #2
    def _create_environment(self,
                            nonparallel=False,
                            random_seed=None,
                            register=True):
        # nonparallel and random_seed are accepted to match the expected
        # signature but are not used by this dummy environment.
        env = MyEnv(3)
        if register:
            self._register_env(env)
        return env
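The register flag in Example #2 implies that created environments are recorded somewhere so the test fixture can clean them up later. The sketch below is a hypothetical illustration of such a registration pattern; EnvTestBase, _registered_envs, and the close() call are assumptions, and the real ALF test base class may manage environments differently.

import unittest


class EnvTestBase(unittest.TestCase):
    """Hypothetical base class illustrating the registration pattern."""

    def setUp(self):
        self._registered_envs = []

    def _register_env(self, env):
        # Remember every environment so tearDown() can close it.
        self._registered_envs.append(env)

    def tearDown(self):
        for env in self._registered_envs:
            env.close()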
Example #3
    def test_trac_algorithm(self):
        config = TrainerConfig(root_dir="dummy", unroll_length=5)
        env = MyEnv(batch_size=3)
        alg = TracAlgorithm(observation_spec=env.observation_spec(),
                            action_spec=env.action_spec(),
                            ac_algorithm_cls=create_ac_algorithm,
                            env=env,
                            config=config)

        for _ in range(50):
            alg.train_iter()

        time_step = common.get_initial_time_step(env)
        state = alg.get_initial_predict_state(env.batch_size)
        policy_step = alg.rollout_step(time_step, state)
        # Despite its name, "logits" holds log-probabilities: row i is
        # log_prob(action=i) for each of the 3 environments in the batch.
        logits = policy_step.info.action_distribution.log_prob(
            torch.arange(3).reshape(3, 1))
        print("logits: ", logits)
        # Action 1 yields the most reward in MyEnv, so after training its
        # probability should be higher than that of the other actions.
        self.assertTrue(torch.all(logits[1, :] > logits[0, :]))
        self.assertTrue(torch.all(logits[1, :] > logits[2, :]))
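The two assertions above rest on how PyTorch broadcasts the queried action indices against the batched action distribution. Assuming the rollout's action_distribution behaves like a torch.distributions.Categorical over 3 actions for a batch of 3 environments (which the batch_size=3 setup suggests), the sketch below shows why row i of the result can be read as the log-probability of action i for every environment; the concrete logits are made up for illustration.

import torch
import torch.distributions as td

# A batch of 3 identical categorical distributions over 3 actions, where
# action 1 has the largest logit (standing in for the trained policy).
dist = td.Categorical(logits=torch.tensor([[0.1, 2.0, -1.0]] * 3))

# Querying actions 0..2 with shape (3, 1) broadcasts against the batch of 3
# distributions, giving a (num_actions, batch_size) tensor of log-probs.
log_probs = dist.log_prob(torch.arange(3).reshape(3, 1))
print(log_probs.shape)  # torch.Size([3, 3])

# Row comparisons therefore compare actions across the whole batch.
assert torch.all(log_probs[1, :] > log_probs[0, :])
assert torch.all(log_probs[1, :] > log_probs[2, :])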
Example #4
    def test_ac_algorithm(self):
        env = MyEnv(batch_size=3)
        alg1 = create_algorithm(env)

        iter_num = 50
        for _ in range(iter_num):
            alg1.train_iter()

        time_step = common.get_initial_time_step(env)
        state = alg1.get_initial_predict_state(env.batch_size)
        policy_step = alg1.rollout_step(time_step, state)
        logits = policy_step.info.action_distribution.log_prob(
            torch.arange(3).reshape(3, 1))
        print("logits: ", logits)
        # Action 1 yields the most reward, so it should have the highest
        # log-probability after training.
        self.assertTrue(torch.all(logits[1, :] > logits[0, :]))
        self.assertTrue(torch.all(logits[1, :] > logits[2, :]))

        # The global counter should equal iter_num because alg1 has run
        # iter_num training iterations.
        self.assertTrue(alf.summary.get_global_counter() == iter_num)
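All four snippets are methods of a unittest-style test case, as the self parameter and the assert* calls indicate. The harness below is a hypothetical sketch of how such methods are usually collected and run; AlgorithmTest is an invented name, and the actual ALF test module may instead subclass alf.test.TestCase and define MyEnv, create_algorithm, and create_ac_algorithm alongside these methods.

import unittest


class AlgorithmTest(unittest.TestCase):
    """Hypothetical container for the test methods shown above."""

    # test_ac_algorithm, test_ac_algorithm_with_global_counter,
    # test_trac_algorithm, and _create_environment would live here.


if __name__ == "__main__":
    unittest.main()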