Ejemplo n.º 1
0
def test_restore_checkpoint(preset_args,
                            clres,
                            framework,
                            timeout=Def.TimeOuts.test_time_limit):
    """
    Create checkpoints in a first run and restore them in a second run.

    The first run is started with ``--checkpoint_save_secs 5``. Once
    ``a_utils.is_reward_reached`` succeeds and checkpoint files exist, the
    run is interrupted, its experiment directory is moved to the test
    directory and a second run is started with
    ``-crd <checkpoint dir> --evaluate``.

    :param preset_args: name of the preset to run (replaced by
                        Def.Presets.mxnet_args_test[0] when framework is
                        "mxnet")
    :param clres: logs and csv files; ``clres.stdout`` receives stdout and
                  stderr of both runs
    :param framework: name of the test framework
    :param timeout: max time for test
    """
    def _create_cmd_and_run(flag):
        """
        Create default command with given flag and run it
        :param flag: name of the tested flag, this flag will be extended to the
                     running command line
        :return: active process
        """
        run_cmd = [
            'python3',
            'rl_coach/coach.py',
            '-p',
            str(preset_args),
            '-e',
            "ExpName_" + preset_args,
            '--seed',
            '4',
            '-f',
            str(framework),
        ]

        run_cmd.extend(a_utils.add_one_flag_value(flag=flag))
        print(str(run_cmd))
        return subprocess.Popen(run_cmd, stdout=clres.stdout,
                                stderr=clres.stdout)

    def _stop_process(proc):
        """
        Kill the given process (if any) and wait for it to exit, so that no
        run is left behind. Safe to call more than once.
        :param proc: process returned by _create_cmd_and_run, or None
        """
        if proc is not None:
            proc.kill()
            proc.wait()

    start_time = time.time()

    if framework == "mxnet":
        # update preset name - for mxnet framework we are using *_DQN
        preset_args = Def.Presets.mxnet_args_test[0]
        # update logs paths
        test_name = 'ExpName_{}'.format(preset_args)
        test_path = os.path.join(Def.Path.experiments, test_name)
        clres.experiment_path = test_path
        clres.stdout_path = 'test_log_{}.txt'.format(preset_args)

    p_valid_params = p_utils.validation_params(preset_args)
    create_cp_proc = _create_cmd_and_run(flag=['--checkpoint_save_secs', '5'])
    restore_cp_proc = None

    # Everything that can fail runs inside try/finally so that the spawned
    # processes are always stopped, even when an assertion fails.
    try:
        # wait for checkpoint files
        csv_list = a_utils.get_csv_path(clres=clres)
        assert len(csv_list) > 0
        exp_dir = os.path.dirname(csv_list[0])

        checkpoint_dir = os.path.join(exp_dir, Def.Path.checkpoint)

        # remove leftovers of a previous run before copying into this folder
        checkpoint_test_dir = os.path.join(Def.Path.experiments,
                                           Def.Path.test_dir)
        if os.path.exists(checkpoint_test_dir):
            shutil.rmtree(checkpoint_test_dir)

        res = a_utils.is_reward_reached(csv_path=csv_list[0],
                                        p_valid_params=p_valid_params,
                                        start_time=start_time,
                                        time_limit=timeout)
        if not res:
            # dump the run log to help debugging, then fail the test
            with open(clres.stdout.name) as log_file:
                screen.error(log_file.read(), crash=False)
            assert False

        entities = a_utils.get_files_from_dir(checkpoint_dir)

        assert len(entities) > 0
        assert any(".ckpt." in name for name in entities)

        # send CTRL+C to close experiment
        create_cp_proc.send_signal(signal.SIGINT)

        if os.path.isdir(checkpoint_dir):
            shutil.copytree(exp_dir, checkpoint_test_dir)
            shutil.rmtree(exp_dir)

        # make sure the first run is gone before the second one starts
        _stop_process(create_cp_proc)
        restore_dir = os.path.join(checkpoint_test_dir, Def.Path.checkpoint)
        # run second time with checkpoint folder  (restore)
        restore_cp_proc = _create_cmd_and_run(
            flag=['-crd', restore_dir, '--evaluate'])

        new_csv_list = test_utils.get_csv_path(clres=clres)
        assert len(new_csv_list) > 0
        # give the restored run some time to write results to the csv
        time.sleep(10)

        csv = pd.read_csv(new_csv_list[0])
        res = csv['Episode Length'].values[-1]
        expected_reward = 100
        assert res >= expected_reward, Def.Consts.ASSERT_MSG.format(
            str(expected_reward), str(res))
    finally:
        _stop_process(create_cp_proc)
        _stop_process(restore_cp_proc)

    # reached only when the test passed: clean up the copied experiment
    test_folder = os.path.join(Def.Path.experiments, Def.Path.test_dir)
    if os.path.exists(test_folder):
        shutil.rmtree(test_folder)
Ejemplo n.º 2
0
def test_restore_checkpoint(preset_args, clres, start_time=None):
    """
    Create checkpoint and restore them in second run.

    The first run is started with ``--checkpoint_save_secs 5``. The test
    waits (up to Def.TimeOuts.test_time_limit seconds, measured from
    start_time) for a checkpoint file whose name contains "10_Step", then
    interrupts the run, moves its experiment directory to the test directory
    and starts a second run with ``-crd <checkpoint dir> --evaluate``.

    :param preset_args: name of the preset to run
    :param clres: logs and csv files; ``clres.stdout`` receives stdout and
                  stderr of both runs
    :param start_time: reference time (as returned by time.time()) for the
                       checkpoint wait timeout; defaults to the moment the
                       test starts running
    """
    # A default of time.time() in the signature would be evaluated only once,
    # at definition time, so the timeout could already be over when the test
    # actually runs. Resolve the default at call time instead.
    if start_time is None:
        start_time = time.time()

    def _create_cmd_and_run(flag):
        """
        Create default command with given flag and run it
        :param flag: flag (and its value) appended to the command line
        :return: active process
        """
        run_cmd = [
            'python3',
            'rl_coach/coach.py',
            '-p',
            str(preset_args),
            '-e',
            "ExpName_" + preset_args,
        ]
        run_cmd.extend(a_utils.add_one_flag_value(flag=flag))

        return subprocess.Popen(run_cmd, stdout=clres.stdout,
                                stderr=clres.stdout)

    def _stop_process(proc):
        """
        Kill the given process (if any) and wait for it to exit, so that no
        run is left behind. Safe to call more than once.
        :param proc: process returned by _create_cmd_and_run, or None
        """
        if proc is not None:
            proc.kill()
            proc.wait()

    create_cp_proc = _create_cmd_and_run(flag=['--checkpoint_save_secs', '5'])
    restore_cp_proc = None

    # Everything that can fail runs inside try/finally so that the spawned
    # processes are always stopped, even when an assertion fails.
    try:
        # wait for checkpoint files
        csv_list = a_utils.get_csv_path(clres=clres)
        assert len(csv_list) > 0
        exp_dir = os.path.dirname(csv_list[0])

        checkpoint_dir = os.path.join(exp_dir, Def.Path.checkpoint)

        # remove leftovers of a previous run before copying into this folder
        checkpoint_test_dir = os.path.join(Def.Path.experiments,
                                           Def.Path.test_dir)
        if os.path.exists(checkpoint_test_dir):
            shutil.rmtree(checkpoint_test_dir)

        entities = a_utils.get_files_from_dir(checkpoint_dir)

        # poll once per second until the expected checkpoint shows up or the
        # time limit is exceeded
        while not any("10_Step" in name for name in entities) and \
                time.time() - start_time < Def.TimeOuts.test_time_limit:
            entities = a_utils.get_files_from_dir(checkpoint_dir)
            time.sleep(1)

        assert len(entities) > 0
        assert "checkpoint" in entities
        assert any(".ckpt." in name for name in entities)

        # send CTRL+C to close experiment
        create_cp_proc.send_signal(signal.SIGINT)

        # lowest non-NaN evaluation reward of the first run; used below as
        # the threshold for the restored run
        csv = pd.read_csv(csv_list[0])
        rewards = csv['Evaluation Reward'].values
        rewards = rewards[~np.isnan(rewards)]
        min_reward = np.amin(rewards)

        if os.path.isdir(checkpoint_dir):
            shutil.copytree(exp_dir, checkpoint_test_dir)
            shutil.rmtree(exp_dir)

        # make sure the first run is gone before the second one starts
        _stop_process(create_cp_proc)
        restore_dir = os.path.join(checkpoint_test_dir, Def.Path.checkpoint)
        # run second time with checkpoint folder  (restore)
        restore_cp_proc = _create_cmd_and_run(
            flag=['-crd', restore_dir, '--evaluate'])

        new_csv_list = test_utils.get_csv_path(clres=clres)
        assert len(new_csv_list) > 0
        # give the restored run some time to write results to the csv
        time.sleep(10)

        csv = pd.read_csv(new_csv_list[0])
        res = csv['Episode Length'].values[-1]
        assert res >= min_reward, \
            Def.Consts.ASSERT_MSG.format(str(res) + ">=" + str(min_reward),
                                         str(res) + " < " + str(min_reward))
    finally:
        _stop_process(create_cp_proc)
        _stop_process(restore_cp_proc)