def test_is_error_bad_ckpnt(error, expected):
    """This function checks the functionality of is_error_bad_ckpnt function
    in markov/utils.py

    <is_error_bad_ckpnt> determines whether a value error is caused by an invalid checkpoint
    by looking for keywords 'tensor', 'shape', 'checksum', 'checkpoint' in the exception message

    Args:
        error (String): Error message to be parsed
        expected (Boolean): Expected return from function
    """
    assert utils.is_error_bad_ckpnt(error) == expected
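
# A minimal sketch (an assumption, not the actual markov/utils implementation) of the
# keyword check described in the docstring above, together with the kind of cases a
# pytest.mark.parametrize decorator could feed into test_is_error_bad_ckpnt.
def _is_error_bad_ckpnt_sketch(error):
    """Return True if the error message mentions a checkpoint-related keyword."""
    message = str(error).lower()
    return any(keyword in message for keyword in ("tensor", "shape", "checksum", "checkpoint"))

# Illustrative parametrize cases (assumed values, not the project's fixtures):
#   ("Tensor shape mismatch while restoring the checkpoint", True)
#   ("invalid literal for int() with base 10: 'abc'", False)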
Example #2
                      simtrace_s3_bucket_map=simtrace_s3_bucket_dict,
                      simtrace_s3_prefix_map=simtrace_s3_prefix_dict,
                      mp4_s3_bucket_map=mp4_s3_bucket_dict,
                      mp4_s3_prefix_map=mp4_s3_object_prefix_dict,
                      display_names=display_names)

    # tournament_worker: terminate tournament_race_node.
    terminate_tournament_race()


if __name__ == '__main__':
    try:
        rospy.init_node('rl_coach', anonymous=True)
        main()
    except ValueError as err:
        if utils.is_error_bad_ckpnt(err):
            utils.log_and_exit("User modified model: {}".format(err),
                               utils.SIMAPP_SIMULATION_WORKER_EXCEPTION,
                               utils.SIMAPP_EVENT_ERROR_CODE_400)
        else:
            utils.log_and_exit("Eval worker value error: {}".format(err),
                               utils.SIMAPP_SIMULATION_WORKER_EXCEPTION,
                               utils.SIMAPP_EVENT_ERROR_CODE_500)
    except GenericRolloutError as ex:
        ex.log_except_and_exit()
    except GenericRolloutException as ex:
        ex.log_except_and_exit()
    except Exception as ex:
        utils.log_and_exit("Eval worker error: {}".format(ex),
                           utils.SIMAPP_SIMULATION_WORKER_EXCEPTION,
                           utils.SIMAPP_EVENT_ERROR_CODE_500)
Example #3
def training_worker(graph_manager, task_parameters, user_batch_size,
                    user_episode_per_rollout):
    try:
        # initialize graph
        graph_manager.create_graph(task_parameters)

        # save initial checkpoint
        graph_manager.save_checkpoint()

        # training loop
        steps = 0

        graph_manager.setup_memory_backend()
        graph_manager.signal_ready()

        # To handle SIGTERM
        door_man = utils.DoorMan()

        while steps < graph_manager.improve_steps.num_steps:
            # Collect profiler information only if IS_PROFILER_ON is true
            with utils.Profiler(
                    s3_bucket=PROFILER_S3_BUCKET,
                    s3_prefix=PROFILER_S3_PREFIX,
                    output_local_path=TRAINING_WORKER_PROFILER_PATH,
                    enable_profiling=IS_PROFILER_ON):
                graph_manager.phase = core_types.RunPhase.TRAIN
                graph_manager.fetch_from_worker(
                    graph_manager.agent_params.algorithm.
                    num_consecutive_playing_steps)
                graph_manager.phase = core_types.RunPhase.UNDEFINED

                episodes_in_rollout = graph_manager.memory_backend.get_total_episodes_in_rollout()

                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        agent.ap.algorithm.num_consecutive_playing_steps.num_steps = episodes_in_rollout
                        agent.ap.algorithm.num_steps_between_copying_online_weights_to_target.num_steps = episodes_in_rollout

                if graph_manager.should_train():
                    # Make sure we have enough data for the requested batches
                    rollout_steps = graph_manager.memory_backend.get_rollout_steps()
                    if any(step <= 0 for step in rollout_steps.values()):
                        log_and_exit(
                            "No rollout data retrieved from the rollout worker",
                            SIMAPP_TRAINING_WORKER_EXCEPTION,
                            SIMAPP_EVENT_ERROR_CODE_500)

                    # Set the batch size to the closest power of 2 such that we have at least
                    # two batches; this prevents coach from crashing, as a batch size of less
                    # than 2 causes the batch list to become a scalar, which raises an exception.
                    episode_batch_size = (user_batch_size
                                          if min(rollout_steps.values()) > user_batch_size
                                          else 2**math.floor(math.log(min(rollout_steps.values()), 2)))
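                    # Illustrative arithmetic (assumed values): with user_batch_size=64 and
                    # min(rollout_steps.values())=40, 40 < 64, so the batch size becomes
                    # 2**math.floor(math.log(40, 2)) = 2**5 = 32.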
                    for level in graph_manager.level_managers:
                        for agent in level.agents.values():
                            agent.ap.network_wrappers[
                                'main'].batch_size = episode_batch_size

                    steps += 1

                    graph_manager.phase = core_types.RunPhase.TRAIN

                    # train() is defined in multi_agent_graph_manager.py
                    graph_manager.train()
                    graph_manager.phase = core_types.RunPhase.UNDEFINED

                    # Check for NaNs in all agents
                    rollout_has_nan = False
                    for level in graph_manager.level_managers:
                        for agent in level.agents.values():
                            if np.isnan(agent.loss.get_mean()):
                                rollout_has_nan = True
                    if rollout_has_nan:
                        log_and_exit(
                            "NaN detected in loss function, aborting training.",
                            SIMAPP_TRAINING_WORKER_EXCEPTION,
                            SIMAPP_EVENT_ERROR_CODE_500)

                    if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
                        graph_manager.save_checkpoint()
                    else:
                        graph_manager.occasionally_save_checkpoint()

                    # Clear any data stored in signals that is no longer necessary
                    graph_manager.reset_internal_state()

                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        agent.ap.algorithm.num_consecutive_playing_steps.num_steps = user_episode_per_rollout
                        agent.ap.algorithm.num_steps_between_copying_online_weights_to_target.num_steps = user_episode_per_rollout

                if door_man.terminate_now:
                    graph_manager.save_checkpoint()
                    log_and_exit(
                        "Received SIGTERM. Checkpointing before exiting.",
                        SIMAPP_TRAINING_WORKER_EXCEPTION,
                        SIMAPP_EVENT_ERROR_CODE_500)
                    break

    except ValueError as err:
        if utils.is_error_bad_ckpnt(err):
            log_and_exit("User modified model: {}".format(err),
                         SIMAPP_TRAINING_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_400)
        else:
            log_and_exit("An error occured while training: {}".format(err),
                         SIMAPP_TRAINING_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_500)
    except Exception as ex:
        log_and_exit("An error occured while training: {}".format(ex),
                     SIMAPP_TRAINING_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_500)
    finally:
        graph_manager.data_store.upload_finished_file()
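
# A minimal sketch of the kind of SIGTERM latch used above (an assumption; the real
# utils.DoorMan may differ). It only records that SIGTERM arrived, so the training
# loop can checkpoint and exit at a safe point instead of being killed mid-step.
import signal

class DoorManSketch:
    def __init__(self):
        self.terminate_now = False
        signal.signal(signal.SIGTERM, self._on_sigterm)

    def _on_sigterm(self, signum, frame):
        # Set a flag only; the loop polls terminate_now between training steps.
        self.terminate_now = True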
Example #4
def rename_checkpoints(checkpoint_dir, agent_name):
    ''' Helper method that renames the specific checkpoint in the CheckpointStateFile
        to be scoped with agent_name
        checkpoint_dir - local checkpoint folder where the checkpoints and .checkpoint file are stored
        agent_name - name of the agent
    '''
    try:
        logger.info("Renaming checkpoint from checkpoint_dir: {} for agent: {}".format(checkpoint_dir, agent_name))
        state_file = CheckpointStateFile(os.path.abspath(checkpoint_dir))
        checkpoint_name = str(state_file.read())
        tf_checkpoint_file = os.path.join(checkpoint_dir, "checkpoint")
        with open(tf_checkpoint_file, "w") as outfile:
            outfile.write("model_checkpoint_path: \"{}\"".format(checkpoint_name))

        with tf.Session() as sess:
            for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
                # Load the variable
                var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
                new_name = var_name
                # Set the new name
                # Replace agent/ or agent_#/ with {agent_name}/
                new_name = re.sub(r'agent/|agent_\d+/', '{}/'.format(agent_name), new_name)
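                # Illustrative mapping (assumed variable names): with agent_name='agent_0',
                #   'agent/main/online/network_0/fc0/weights'  -> 'agent_0/main/online/network_0/fc0/weights'
                #   'agent_1/main/online/network_0/fc0/biases' -> 'agent_0/main/online/network_0/fc0/biases'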
                # Rename the variable
                var = tf.Variable(var, name=new_name)
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            renamed_checkpoint_path = os.path.join(TEMP_RENAME_FOLDER, checkpoint_name)
            logger.info('Saving updated checkpoint to {}'.format(renamed_checkpoint_path))
            saver.save(sess, renamed_checkpoint_path)
        # Remove the tensorflow 'checkpoint' file
        os.remove(tf_checkpoint_file)
        # Remove the old checkpoint from the checkpoint dir
        for file_name in os.listdir(checkpoint_dir):
            if checkpoint_name in file_name:
                os.remove(os.path.join(checkpoint_dir, file_name))
        # Copy the new checkpoint with renamed variable to the checkpoint dir
        for file_name in os.listdir(TEMP_RENAME_FOLDER):
            full_file_name = os.path.join(os.path.abspath(TEMP_RENAME_FOLDER), file_name)
            if os.path.isfile(full_file_name) and file_name != "checkpoint":
                shutil.copy(full_file_name, checkpoint_dir)
        # Remove files from temp_rename_folder
        shutil.rmtree(TEMP_RENAME_FOLDER)
        tf.reset_default_graph()
    # If any of the checkpoint files (index, meta, or data) is not found
    except tf.errors.NotFoundError as err:
        log_and_exit("No checkpoint found: {}".format(err),
                     SIMAPP_SIMULATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_400)
    # Thrown when user modifies model, checkpoints get corrupted/truncated
    except tf.errors.DataLossError as err:
        log_and_exit("User modified ckpt, unrecoverable dataloss or corruption: {}"
                     .format(err),
                     SIMAPP_SIMULATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_400)
    except ValueError as err:
        if utils.is_error_bad_ckpnt(err):
            log_and_exit("Couldn't find 'checkpoint' file or checkpoints in given \
                            directory ./checkpoint: {}".format(err),
                         SIMAPP_SIMULATION_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_400)
        else:
            log_and_exit("ValueError in rename checkpoint: {}".format(err),
                         SIMAPP_SIMULATION_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_500)
    except Exception as ex:
        log_and_exit("Exception in rename checkpoint: {}".format(ex),
                     SIMAPP_SIMULATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_500)
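
# Illustrative usage (the path and agent name below are assumptions, not project defaults):
# rename_checkpoints(checkpoint_dir='./checkpoint/agent', agent_name='agent_0')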