def test_restore_checkpoint(preset_args, clres, framework,
                            timeout=Def.TimeOuts.test_time_limit):
    """
    Create checkpoints in a first run and restore them in a second run.

    A first experiment is started with ``--checkpoint_save_secs 5``. Once
    ``a_utils.is_reward_reached`` reports success and checkpoint files are
    found, the experiment directory is copied to the test directory and a
    second experiment is started with ``-crd <copied checkpoint dir>
    --evaluate``. The last 'Episode Length' value written by the second run
    must be at least 100.

    Both child processes are killed even when an assertion fails, so a failing
    test does not leave an experiment running in the background.

    :param preset_args: all preset that can be tested for argument tests
    :param clres: logs and csv files
    :param framework: name of the test framework
    :param timeout: max time for test
    """

    def _create_cmd_and_run(flag):
        """
        Create default command with given flag and run it
        :param flag: name of the tested flag, this flag will be extended to the
                     running command line
        :return: active process
        """
        # preset_args is read when the helper is called, so the mxnet
        # override below is picked up by both runs.
        run_cmd = [
            'python3', 'rl_coach/coach.py',
            '-p', '{}'.format(preset_args),
            '-e', '{}'.format("ExpName_" + preset_args),
            '--seed', '{}'.format(4),
            '-f', '{}'.format(framework),
        ]
        run_cmd.extend(a_utils.add_one_flag_value(flag=flag))
        print(str(run_cmd))
        return subprocess.Popen(run_cmd, stdout=clres.stdout,
                                stderr=clres.stdout)

    start_time = time.time()

    if framework == "mxnet":
        # update preset name - for mxnet framework we are using *_DQN
        preset_args = Def.Presets.mxnet_args_test[0]
        # update logs paths
        test_name = 'ExpName_{}'.format(preset_args)
        clres.experiment_path = os.path.join(Def.Path.experiments, test_name)
        clres.stdout_path = 'test_log_{}.txt'.format(preset_args)

    p_valid_params = p_utils.validation_params(preset_args)
    checkpoint_test_dir = os.path.join(Def.Path.experiments,
                                       Def.Path.test_dir)

    # first run: train and save checkpoints
    create_cp_proc = _create_cmd_and_run(
        flag=['--checkpoint_save_secs', '5'])
    try:
        # wait for checkpoint files
        csv_list = a_utils.get_csv_path(clres=clres)
        assert len(csv_list) > 0, "no csv file found for the first run"
        exp_dir = os.path.dirname(csv_list[0])
        checkpoint_dir = os.path.join(exp_dir, Def.Path.checkpoint)

        # start from a clean copy target; copytree needs it to be absent
        if os.path.exists(checkpoint_test_dir):
            shutil.rmtree(checkpoint_test_dir)

        res = a_utils.is_reward_reached(csv_path=csv_list[0],
                                        p_valid_params=p_valid_params,
                                        start_time=start_time,
                                        time_limit=timeout)
        if not res:
            # dump the experiment log to help diagnose the failure
            with open(clres.stdout.name) as log_file:
                screen.error(log_file.read(), crash=False)
            assert False, "reward was not reached within the time limit"

        entities = a_utils.get_files_from_dir(checkpoint_dir)
        assert len(entities) > 0
        assert any(".ckpt." in file for file in entities)

        # send CTRL+C to close experiment
        create_cp_proc.send_signal(signal.SIGINT)

        if os.path.isdir(checkpoint_dir):
            shutil.copytree(exp_dir, checkpoint_test_dir)
            shutil.rmtree(exp_dir)
    finally:
        # always stop the first experiment, also when a check above failed
        create_cp_proc.kill()

    restore_dir = os.path.join(checkpoint_test_dir, Def.Path.checkpoint)

    # run second time with checkpoint folder (restore)
    restore_cp_proc = _create_cmd_and_run(
        flag=['-crd', restore_dir, '--evaluate'])
    try:
        new_csv_list = test_utils.get_csv_path(clres=clres)
        assert len(new_csv_list) > 0, "no csv file found for the restore run"
        # give the restored run time to write results into the csv
        time.sleep(10)

        csv = pd.read_csv(new_csv_list[0])
        res = csv['Episode Length'].values[-1]
        expected_reward = 100
        assert res >= expected_reward, Def.Consts.ASSERT_MSG.format(
            str(expected_reward), str(res))
    finally:
        # always stop the restored experiment, also when a check above failed
        restore_cp_proc.kill()

    test_folder = os.path.join(Def.Path.experiments, Def.Path.test_dir)
    if os.path.exists(test_folder):
        shutil.rmtree(test_folder)
def test_restore_checkpoint(preset_args, clres, start_time=None):
    """
    Create checkpoint and restore them in second run.

    A first experiment is started with ``--checkpoint_save_secs 5`` and the
    test waits until a file whose name contains "10_Step" shows up in the
    checkpoint directory (or the time limit passes). The minimum non-NaN
    'Evaluation Reward' of that run is recorded, the experiment directory is
    copied to the test directory, and a second experiment is started with
    ``-crd <copied checkpoint dir> --evaluate``. The last 'Episode Length'
    value written by the second run must be at least that minimum reward.

    Both child processes are killed even when an assertion fails.

    :param preset_args: preset name passed to coach with ``-p``
    :param clres: logs and csv files
    :param start_time: reference time (seconds since epoch) for the
                       checkpoint wait limit; defaults to the time of the call
    """
    # A ``time.time()`` default would be evaluated once, when the function is
    # defined, and the wait limit would then be measured from module import.
    if start_time is None:
        start_time = time.time()

    def _create_cmd_and_run(flag):
        """
        Create default command with given flag and run it
        :param flag: flag (and value) appended to the default command line
        :return: active process
        """
        run_cmd = [
            'python3', 'rl_coach/coach.py',
            '-p', '{}'.format(preset_args),
            '-e', '{}'.format("ExpName_" + preset_args),
        ]
        run_cmd.extend(a_utils.add_one_flag_value(flag=flag))
        return subprocess.Popen(run_cmd, stdout=clres.stdout,
                                stderr=clres.stdout)

    checkpoint_test_dir = os.path.join(Def.Path.experiments,
                                       Def.Path.test_dir)

    # first run: train and save checkpoints
    create_cp_proc = _create_cmd_and_run(
        flag=['--checkpoint_save_secs', '5'])
    try:
        # wait for checkpoint files
        csv_list = a_utils.get_csv_path(clres=clres)
        assert len(csv_list) > 0, "no csv file found for the first run"
        exp_dir = os.path.dirname(csv_list[0])
        checkpoint_dir = os.path.join(exp_dir, Def.Path.checkpoint)

        # start from a clean copy target; copytree needs it to be absent
        if os.path.exists(checkpoint_test_dir):
            shutil.rmtree(checkpoint_test_dir)

        # poll once per second until the "10_Step" checkpoint exists or the
        # time limit is exceeded
        entities = a_utils.get_files_from_dir(checkpoint_dir)
        while not any("10_Step" in file for file in entities) and \
                time.time() - start_time < Def.TimeOuts.test_time_limit:
            time.sleep(1)
            entities = a_utils.get_files_from_dir(checkpoint_dir)

        assert len(entities) > 0
        assert "checkpoint" in entities
        assert any(".ckpt." in file for file in entities)

        # send CTRL+C to close experiment
        create_cp_proc.send_signal(signal.SIGINT)

        csv = pd.read_csv(csv_list[0])
        rewards = csv['Evaluation Reward'].values
        rewards = rewards[~np.isnan(rewards)]
        assert rewards.size > 0, \
            "no evaluation reward was recorded in the first run"
        min_reward = np.amin(rewards)

        if os.path.isdir(checkpoint_dir):
            shutil.copytree(exp_dir, checkpoint_test_dir)
            shutil.rmtree(exp_dir)
    finally:
        # always stop the first experiment, also when a check above failed
        create_cp_proc.kill()

    restore_dir = os.path.join(checkpoint_test_dir, Def.Path.checkpoint)

    # run second time with checkpoint folder (restore)
    restore_cp_proc = _create_cmd_and_run(
        flag=['-crd', restore_dir, '--evaluate'])
    try:
        new_csv_list = test_utils.get_csv_path(clres=clres)
        assert len(new_csv_list) > 0, "no csv file found for the restore run"
        # give the restored run time to write results into the csv
        time.sleep(10)

        csv = pd.read_csv(new_csv_list[0])
        res = csv['Episode Length'].values[-1]
        assert res >= min_reward, \
            Def.Consts.ASSERT_MSG.format(str(res) + ">=" + str(min_reward),
                                         str(res) + " < " + str(min_reward))
    finally:
        # always stop the restored experiment, also when a check above failed
        restore_cp_proc.kill()