Example 1
    def test_merlin_algorithm(self):
        batch_size = 100
        steps_per_episode = 15
        gap = 10
        env = RNNPolicyUnittestEnv(
            batch_size, steps_per_episode, gap, obs_dim=3)
        env = TFPyEnvironment(env)

        common.set_global_env(env)

        algorithm = _create_merlin_algorithm(
            learning_rate=1e-3, debug_summaries=False)
        driver = OnPolicyDriver(env, algorithm, train_interval=6)

        eval_driver = OnPolicyDriver(env, algorithm, training=False)

        # Process handle used to log the test's memory usage (RSS) each iteration.
        proc = psutil.Process(os.getpid())

        policy_state = driver.get_initial_policy_state()
        time_step = driver.get_initial_time_step()
        for i in range(100):
            t0 = time.time()
            time_step, policy_state, _ = driver.run(
                max_num_steps=150 * batch_size,
                time_step=time_step,
                policy_state=policy_state)
            mem = proc.memory_info().rss // 1e6
            logging.info('%s time=%.3f mem=%s' % (i, time.time() - t0, mem))

        # Evaluate the trained policy; the mean reward should be close to 1.
        env.reset()
        time_step, _ = eval_driver.run(max_num_steps=14 * batch_size)
        logging.info("eval reward=%.3f" % tf.reduce_mean(time_step.reward))
        self.assertAlmostEqual(
            1.0, float(tf.reduce_mean(time_step.reward)), delta=1e-2)
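The test above is a method excerpted from a larger test class, so its imports and the `_create_merlin_algorithm` helper are not shown. Below is a minimal sketch of the imports it would rely on; the `alf.*` module paths are assumptions about the package layout rather than something stated in the example, and `_create_merlin_algorithm` is a module-level helper defined alongside the test, not a public API.

# Imports assumed for this example (paths are guesses based on a typical ALF layout):
import os
import time

import psutil
import tensorflow as tf
from absl import logging

from tf_agents.environments.tf_py_environment import TFPyEnvironment

from alf.drivers.on_policy_driver import OnPolicyDriver           # assumed path
from alf.environments.suite_unittest import RNNPolicyUnittestEnv  # assumed path
from alf.utils import common                                      # assumed path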
Example 2
def play(root_dir,
         env,
         algorithm,
         checkpoint_name=None,
         greedy_predict=True,
         random_seed=None,
         num_episodes=10,
         sleep_time_per_step=0.01,
         record_file=None,
         use_tf_functions=True):
    """Play using the latest checkpoint under `train_dir`.

    The following example records the play of a trained model to an mp4 video:
    ```bash
    python -m alf.bin.play \
    --root_dir=~/tmp/bullet_humanoid/ppo2/ppo2-11 \
    --num_episodes=1 \
    --record_file=ppo_bullet_humanoid.mp4
    ```
    Args:
        root_dir (str): same as the root_dir used for `train()`
        env (TFEnvironment): the environment
        algorithm (OnPolicyAlgorithm): the training algorithm
        checkpoint_name (str): name of the checkpoint (e.g. 'ckpt-12800').
            If None, the latest checkpoint under train_dir will be used.
        greedy_predict (bool): use greedy action for evaluation.
        random_seed (None|int): random seed; if None, no seed is set and the
            run is not deterministic
        num_episodes (int): number of episodes to play
        sleep_time_per_step (float): number of seconds to sleep after each step
        record_file (str): if provided, video will be recorded to a file
            instead of shown on the screen.
        use_tf_functions (bool): whether to use tf.function
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')

    if random_seed is not None:
        random.seed(random_seed)
        np.random.seed(random_seed)
        tf.random.set_seed(random_seed)

    global_step = get_global_counter()

    driver = OnPolicyDriver(env=env,
                            algorithm=algorithm,
                            training=False,
                            greedy_predict=greedy_predict)

    ckpt_dir = os.path.join(train_dir, 'algorithm')
    checkpoint = tf.train.Checkpoint(algorithm=algorithm,
                                     metrics=metric_utils.MetricsGroup(
                                         driver.get_metrics(), 'metrics'),
                                     global_step=global_step)
    if checkpoint_name is not None:
        ckpt_path = os.path.join(ckpt_dir, checkpoint_name)
    else:
        ckpt_path = tf.train.latest_checkpoint(ckpt_dir)
    if ckpt_path is not None:
        logging.info("Restore from checkpoint %s" % ckpt_path)
        checkpoint.restore(ckpt_path)
    else:
        logging.info("Checkpoint is not found at %s" % ckpt_dir)

    if not use_tf_functions:
        tf.config.experimental_run_functions_eagerly(True)

    recorder = None
    if record_file is not None:
        recorder = VideoRecorder(env.pyenv.envs[0], path=record_file)
    else:
        # pybullet_envs need to render() before reset() to enable mode='human'
        env.pyenv.envs[0].render(mode='human')
    env.reset()
    if recorder:
        recorder.capture_frame()
    time_step = driver.get_initial_time_step()
    policy_state = driver.get_initial_policy_state()
    episode_reward = 0.
    episode_length = 0
    episodes = 0
    # Step one environment step at a time until the requested number of episodes finishes.
    while episodes < num_episodes:
        time_step, policy_state = driver.run(max_num_steps=1,
                                             time_step=time_step,
                                             policy_state=policy_state)
        if recorder:
            recorder.capture_frame()
        else:
            env.pyenv.envs[0].render(mode='human')
            time.sleep(sleep_time_per_step)

        episode_reward += float(time_step.reward)

        if time_step.is_last():
            logging.info("episode_length=%s episode_reward=%s" %
                         (episode_length, episode_reward))
            episode_reward = 0.
            episode_length = 0
            episodes += 1
        else:
            episode_length += 1
    if recorder:
        recorder.close()
    env.reset()
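As with the first example, `play` is shown without its module-level imports. The sketch below lists plausible ones; the `alf.*` paths, and the modules assumed to provide `metric_utils.MetricsGroup` and `VideoRecorder`, are guesses rather than facts taken from the snippet.

# Plausible imports for play() (module paths are assumptions):
import os
import random
import time

import numpy as np
import tensorflow as tf
from absl import logging
from gym.wrappers.monitoring.video_recorder import VideoRecorder  # assumed source
from tf_agents.eval import metric_utils                           # assumed source

from alf.drivers.on_policy_driver import OnPolicyDriver   # assumed path
from alf.utils.common import get_global_counter           # assumed path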