def evaluate(params):
    """Restore a checkpointed Clipped PPO agent and evaluate it on ``TSP_env:TSPEasyEnv``.

    The checkpoint is read from ``<params.input_data_dir>/checkpoint``.
    Evaluation output (with GIF dumping enabled) goes to
    ``<params.output_data_dir>/evaluation``.

    Side effects:
        * sets ``logger.experiment_path``;
        * sets ``params.checkpoint_restore_dir``;
        * rewrites the checkpoint index file in place through ``inplace_change``.

    Args:
        params: configuration object. Its ``output_data_dir`` and
            ``input_data_dir`` attributes are read, and all of its attributes
            are merged into the ``TaskParameters`` via ``add_items_to_dict``.
    """
    output_root = os.path.join(params.output_data_dir)
    logger.experiment_path = os.path.join(output_root, 'evaluation')

    restore_dir = os.path.join(params.input_data_dir, 'checkpoint')
    params.checkpoint_restore_dir = restore_dir

    # TensorFlow stores absolute paths in the checkpoint index file
    # (https://github.com/tensorflow/tensorflow/issues/9146). Replace the
    # absolute prefix with "." so the model can be restored from this location.
    inplace_change(os.path.join(restore_dir, 'checkpoint'),
                   "/opt/ml/output/data/checkpoint", ".")

    visualization = VisualizationParameters()
    visualization.dump_gifs = True

    task = TaskParameters(evaluate_only=True,
                          experiment_path=logger.experiment_path)
    task.__dict__ = add_items_to_dict(task.__dict__, params.__dict__)

    manager = BasicRLGraphManager(
        agent_params=ClippedPPOAgentParameters(),
        env_params=GymVectorEnvironment(level='TSP_env:TSPEasyEnv'),
        schedule_params=ScheduleParameters(),
        vis_params=visualization,
    )
    manager.create_graph(task_parameters=task).evaluate(EnvironmentSteps(5))
# Example No. 2
# 0
def evaluate(params, *, num_steps=5,
             original_checkpoint_dir="/opt/ml/output/data/checkpoint"):
    """Restore a checkpointed Clipped PPO agent and evaluate it on ``TSP_env:TSPEasyEnv``.

    The checkpoint is read from ``<params.input_data_dir>/checkpoint``.
    Evaluation output (with GIF dumping enabled) goes to
    ``<params.output_data_dir>/evaluation``.

    Side effects:
        * sets ``logger.experiment_path``;
        * sets ``params.checkpoint_restore_dir``, which is then copied into the
          ``TaskParameters`` together with every other attribute of ``params``;
        * rewrites the checkpoint index file in place through ``inplace_change``.

    Args:
        params: configuration object exposing ``output_data_dir`` and
            ``input_data_dir``.
        num_steps: number of environment steps to evaluate for (default 5,
            the previously hard-coded value).
        original_checkpoint_dir: absolute directory prefix recorded in the
            checkpoint index file at save time. It is replaced with ``"."``
            before restoring. The default is the previously hard-coded path.
    """
    logger.experiment_path = os.path.join(params.output_data_dir, 'evaluation')
    params.checkpoint_restore_dir = os.path.join(params.input_data_dir,
                                                 'checkpoint')
    checkpoint_file = os.path.join(params.checkpoint_restore_dir, 'checkpoint')

    # Due to a TensorFlow issue
    # (https://github.com/tensorflow/tensorflow/issues/9146) the checkpoint
    # index stores absolute paths. Replace them with a relative one so that
    # evaluation from a relocated checkpointed model works.
    inplace_change(checkpoint_file, original_checkpoint_dir, ".")

    vis_params = VisualizationParameters()
    vis_params.dump_gifs = True

    task_params = TaskParameters(evaluate_only=True,
                                 experiment_path=logger.experiment_path)
    task_params.__dict__ = add_items_to_dict(task_params.__dict__,
                                             params.__dict__)

    graph_manager = BasicRLGraphManager(
        agent_params=ClippedPPOAgentParameters(),
        env_params=GymVectorEnvironment(level='TSP_env:TSPEasyEnv'),
        schedule_params=ScheduleParameters(),
        vis_params=vis_params)
    graph_manager = graph_manager.create_graph(task_parameters=task_params)
    graph_manager.evaluate(EnvironmentSteps(num_steps))
level = 'gym_dynamic_multi_armed_bandit.envs:BasicEnv2'
env_params = GymVectorEnvironment(level)

########################
# Create Graph Manager #
########################

# NOTE(review): `agent_params` and `schedule` are not defined in this
# fragment; they must be defined earlier in the script before this point.
graph_manager = BasicRLGraphManager(agent_params=agent_params,
                                    env_params=env_params,
                                    schedule_params=schedule)

#######################
# add task parameters #
#######################

log_path = './experiments_v2/log'  # experiment directory where training logs are saved
checkpoint_sec = 60  # passed as checkpoint_save_secs (presumably the save interval in seconds)
# exist_ok=True creates the directory (and parents) if missing. It avoids the
# race of a separate os.path.exists check followed by os.makedirs.
os.makedirs(log_path, exist_ok=True)

task_parameters = TaskParameters(evaluate_only=False,
                                 experiment_path=log_path,
                                 checkpoint_save_secs=checkpoint_sec)

graph_manager.create_graph(task_parameters)

##################
# start training #
##################

graph_manager.improve()