def test_is_user_error(error, expected):
    """This function checks the functionality of is_user_error function
    in markov/utils.py

    <is_user_error> determines whether a ValueError was caused by an invalid checkpoint or model_metadata
    by looking for the keywords 'tensor', 'shape', 'checksum', 'checkpoint', 'model_metadata' in the exception message

    Args:
        error (String): Error message to be parsed
        expected (Boolean): Expected return from function
    """
    assert utils.is_user_error(error) == expected
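

# A hypothetical pytest parametrization for the test above. The error messages are
# made-up examples chosen to exercise the keyword match described in the docstring,
# and the import path is assumed to be markov.utils as referenced there.
import pytest
from markov import utils


@pytest.mark.parametrize("error, expected", [
    ("graph tensor shape mismatch while restoring checkpoint", True),
    ("model_metadata checksum does not match", True),
    ("unrelated failure while parsing configuration", False),
])
def test_is_user_error_examples(error, expected):
    assert utils.is_user_error(error) == expected
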
Example #2
              s3_bucket=s3_bucket,
              s3_prefix=s3_prefix,
              aws_region=aws_region)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--s3_bucket', help='(string) S3 bucket', type=str)
    parser.add_argument('--s3_prefix', help='(string) S3 prefix', type=str)
    parser.add_argument('--aws_region', help='(string) AWS region', type=str)
    args = parser.parse_args()

    try:
        validate(s3_bucket=args.s3_bucket,
                 s3_prefix=args.s3_prefix,
                 aws_region=args.aws_region)
    except ValueError as err:
        if utils.is_user_error(err):
            log_and_exit("User modified model/model_metadata: {}".format(err),
                         SIMAPP_VALIDATION_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_400)
        else:
            log_and_exit("Validation worker value error: {}".format(err),
                         SIMAPP_VALIDATION_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_500)
    except Exception as ex:
        log_and_exit("Validation worker exited with exception: {}".format(ex),
                     SIMAPP_VALIDATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_500)
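

# For reference, a minimal sketch of the keyword-based classification that the
# docstring of test_is_user_error attributes to markov.utils.is_user_error. This is
# an illustrative re-implementation (hypothetical name), not the code in markov/utils.py;
# it shows why a ValueError above is routed to a 400 (user error) versus a 500.
def is_user_error_sketch(error):
    """Return True when the error message points at a user-modified checkpoint
    or model_metadata, based on the keyword list from the docstring above."""
    keywords = ('tensor', 'shape', 'checksum', 'checkpoint', 'model_metadata')
    message = str(error).lower()
    return any(keyword in message for keyword in keywords)
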
    def rename(self, coach_checkpoint_state_file, agent_name):
        '''Rename the TensorFlow model specified in the RL Coach checkpoint state file to include
        the agent name

        Args:
            coach_checkpoint_state_file (CheckpointStateFile): CheckpointStateFile instance
            agent_name (str): agent name
        '''
        try:
            LOG.info(
                "Renaming checkpoint from checkpoint_dir: {} for agent: {}".
                format(self._local_dir, agent_name))
            checkpoint_name = str(coach_checkpoint_state_file.read())
            tf_checkpoint_file = os.path.join(self._local_dir, "checkpoint")
            with open(tf_checkpoint_file, "w") as outfile:
                outfile.write(
                    "model_checkpoint_path: \"{}\"".format(checkpoint_name))

            with tf.Session() as sess:
                for var_name, _ in tf.contrib.framework.list_variables(
                        self._local_dir):
                    # Load the variable
                    var = tf.contrib.framework.load_variable(
                        self._local_dir, var_name)
                    new_name = var_name
                    # Set the new name
                    # Replace agent/ or agent_#/ with {agent_name}/
                    new_name = re.sub(r'agent/|agent_\d+/',
                                      '{}/'.format(agent_name), new_name)
                    # Rename the variable
                    var = tf.Variable(var, name=new_name)
                saver = tf.train.Saver()
                sess.run(tf.global_variables_initializer())
                renamed_checkpoint_path = os.path.join(TEMP_RENAME_FOLDER,
                                                       checkpoint_name)
                LOG.info('Saving updated checkpoint to {}'.format(
                    renamed_checkpoint_path))
                saver.save(sess, renamed_checkpoint_path)
            # Remove the tensorflow 'checkpoint' file
            os.remove(tf_checkpoint_file)
            # Remove the old checkpoint from the checkpoint dir
            for file_name in os.listdir(self._local_dir):
                if checkpoint_name in file_name:
                    os.remove(os.path.join(self._local_dir, file_name))
            # Copy the new checkpoint with renamed variable to the checkpoint dir
            for file_name in os.listdir(TEMP_RENAME_FOLDER):
                full_file_name = os.path.join(
                    os.path.abspath(TEMP_RENAME_FOLDER), file_name)
                if os.path.isfile(
                        full_file_name) and file_name != "checkpoint":
                    shutil.copy(full_file_name, self._local_dir)
            # Remove files from temp_rename_folder
            shutil.rmtree(TEMP_RENAME_FOLDER)
            tf.reset_default_graph()
        # If any of the checkpoint files (index, meta, or data) is not found
        except tf.errors.NotFoundError as err:
            log_and_exit("No checkpoint found: {}".format(err),
                         SIMAPP_SIMULATION_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_400)
        # Thrown when the user modifies the model and the checkpoints become corrupted/truncated
        except tf.errors.DataLossError as err:
            log_and_exit(
                "User modified ckpt, unrecoverable dataloss or corruption: {}".
                format(err), SIMAPP_SIMULATION_WORKER_EXCEPTION,
                SIMAPP_EVENT_ERROR_CODE_400)
        except tf.errors.OutOfRangeError as err:
            log_and_exit("User modified ckpt: {}".format(err),
                         SIMAPP_SIMULATION_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_400)
        except ValueError as err:
            if utils.is_user_error(err):
                log_and_exit(
                    "Couldn't find 'checkpoint' file or checkpoints in given "
                    "directory ./checkpoint: {}".format(err),
                    SIMAPP_SIMULATION_WORKER_EXCEPTION,
                    SIMAPP_EVENT_ERROR_CODE_400)
            else:
                log_and_exit("ValueError in rename checkpoint: {}".format(err),
                             SIMAPP_SIMULATION_WORKER_EXCEPTION,
                             SIMAPP_EVENT_ERROR_CODE_500)
        except Exception as ex:
            log_and_exit("Exception in rename checkpoint: {}".format(ex),
                         SIMAPP_SIMULATION_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_500)
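

# A standalone illustration (separate from the class above) of the variable-renaming
# rule used in rename(): a leading 'agent/' or 'agent_<n>/' scope is replaced with the
# target agent name. The variable names below are made-up examples.
import re


def scope_with_agent_name(var_name, agent_name):
    """Replace an 'agent/' or 'agent_#/' scope prefix with '{agent_name}/'."""
    return re.sub(r'agent/|agent_\d+/', '{}/'.format(agent_name), var_name)


assert scope_with_agent_name('agent/main/policy/dense/kernel', 'agent_0') == \
    'agent_0/main/policy/dense/kernel'
assert scope_with_agent_name('agent_1/main/value/dense/bias', 'racer') == \
    'racer/main/value/dense/bias'
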
def training_worker(graph_manager, task_parameters, user_batch_size,
                    user_episode_per_rollout, training_algorithm):
    try:
        # initialize graph
        graph_manager.create_graph(task_parameters)

        # save initial checkpoint
        graph_manager.save_checkpoint()

        # training loop
        steps = 0

        graph_manager.setup_memory_backend()
        graph_manager.signal_ready()

        # To handle SIGTERM
        door_man = utils.DoorMan()

        while steps < graph_manager.improve_steps.num_steps:
            # Collect profiler information only if IS_PROFILER_ON is true
            with utils.Profiler(
                    s3_bucket=PROFILER_S3_BUCKET,
                    s3_prefix=PROFILER_S3_PREFIX,
                    output_local_path=TRAINING_WORKER_PROFILER_PATH,
                    enable_profiling=IS_PROFILER_ON):
                graph_manager.phase = core_types.RunPhase.TRAIN
                graph_manager.fetch_from_worker(
                    graph_manager.agent_params.algorithm.num_consecutive_playing_steps)
                graph_manager.phase = core_types.RunPhase.UNDEFINED

                episodes_in_rollout = \
                    graph_manager.memory_backend.get_total_episodes_in_rollout()

                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        agent.ap.algorithm.num_consecutive_playing_steps.num_steps = episodes_in_rollout
                        agent.ap.algorithm.num_steps_between_copying_online_weights_to_target.num_steps = episodes_in_rollout

                # TODO: Refactor the flow to remove conditional checks for specific algorithms
                # ------------------------sac only---------------------------------------------
                if training_algorithm == TrainingAlgorithm.SAC.value:
                    rollout_steps = graph_manager.memory_backend.get_rollout_steps()

                    # NOTE: you can train even more iterations than rollout_steps by increasing the number below for SAC
                    agent.ap.algorithm.num_consecutive_training_steps = \
                        list(rollout_steps.values())[0]  # rollout_steps[agent]
                # -------------------------------------------------------------------------------
                if graph_manager.should_train():
                    # Make sure we have enough data for the requested batches
                    rollout_steps = graph_manager.memory_backend.get_rollout_steps()
                    if not any(rollout_steps.values()):
                        log_and_exit(
                            "No rollout data retrieved from the rollout worker",
                            SIMAPP_TRAINING_WORKER_EXCEPTION,
                            SIMAPP_EVENT_ERROR_CODE_500)

                    # TODO: Refactor the flow to remove conditional checks for specific algorithms
                    # DH: for SAC, check if experience replay memory has enough transitions
                    logger.info("setting training algorithm")
                    if training_algorithm == TrainingAlgorithm.SAC.value:
                        replay_mem_size = min([
                            agent.memory.num_transitions()
                            for level in graph_manager.level_managers
                            for agent in level.agents.values()
                        ])
                        episode_batch_size = user_batch_size if replay_mem_size > user_batch_size \
                            else 2**math.floor(math.log(min(rollout_steps.values()), 2))
                    else:
                        logger.info("Using clipped PPO training algorithm")
                        episode_batch_size = user_batch_size if min(rollout_steps.values()) > user_batch_size \
                            else 2**math.floor(math.log(min(rollout_steps.values()), 2))
                    # Set the batch size to the closest power of 2 such that we have at least two batches; this prevents coach from crashing,
                    # as a batch size of less than 2 causes the batch list to become a scalar, which raises an exception
                    for level in graph_manager.level_managers:
                        for agent in level.agents.values():
                            for net_key in agent.ap.network_wrappers:
                                agent.ap.network_wrappers[
                                    net_key].batch_size = episode_batch_size

                    steps += 1

                    graph_manager.phase = core_types.RunPhase.TRAIN
                    graph_manager.train()
                    graph_manager.phase = core_types.RunPhase.UNDEFINED

                    # Check for Nan's in all agents
                    rollout_has_nan = False
                    for level in graph_manager.level_managers:
                        for agent in level.agents.values():
                            if np.isnan(agent.loss.get_mean()):
                                rollout_has_nan = True
                    if rollout_has_nan:
                        log_and_exit(
                            "NaN detected in loss function, aborting training.",
                            SIMAPP_TRAINING_WORKER_EXCEPTION,
                            SIMAPP_EVENT_ERROR_CODE_500)

                    if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
                        graph_manager.save_checkpoint()
                    else:
                        graph_manager.occasionally_save_checkpoint()

                    # Clear any data stored in signals that is no longer necessary
                    graph_manager.reset_internal_state()

                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        agent.ap.algorithm.num_consecutive_playing_steps.num_steps = user_episode_per_rollout
                        agent.ap.algorithm.num_steps_between_copying_online_weights_to_target.num_steps = user_episode_per_rollout

                if door_man.terminate_now:
                    # Checkpoint before exiting so training progress is not lost on SIGTERM
                    graph_manager.save_checkpoint()
                    log_and_exit(
                        "Received SIGTERM. Checkpointing before exiting.",
                        SIMAPP_TRAINING_WORKER_EXCEPTION,
                        SIMAPP_EVENT_ERROR_CODE_500)
                    break

    except ValueError as err:
        if utils.is_user_error(err):
            log_and_exit("User modified model: {}".format(err),
                         SIMAPP_TRAINING_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_500)
        else:
            log_and_exit("An error occurred while training: {}".format(err),
                         SIMAPP_TRAINING_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_500)
    except Exception as ex:
        log_and_exit("An error occurred while training: {}".format(ex),
                     SIMAPP_TRAINING_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_500)
    finally:
        graph_manager.data_store.upload_finished_file()
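

# A standalone sketch of the batch-size fallback computed above: if the requested
# batch size is larger than the smallest rollout, fall back to the largest power of
# two that is not larger than that rollout. The sample rollout counts are made up.
import math


def fallback_batch_size(user_batch_size, rollout_steps):
    min_steps = min(rollout_steps.values())
    if min_steps > user_batch_size:
        return user_batch_size
    return 2 ** math.floor(math.log(min_steps, 2))


assert fallback_batch_size(64, {'agent': 200}) == 64
assert fallback_batch_size(64, {'agent': 40}) == 32  # 2**5 is the largest power of two <= 40
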
Example #5
def rename_checkpoints(checkpoint_dir, agent_name):
    ''' Helper method that renames the checkpoint specified in the CheckpointStateFile
        so that its variables are scoped with agent_name
        checkpoint_dir - local checkpoint folder where the checkpoints and the .checkpoint file are stored
        agent_name - name of the agent
    '''
    try:
        logger.info(
            "Renaming checkpoint from checkpoint_dir: {} for agent: {}".format(
                checkpoint_dir, agent_name))
        state_file = CheckpointStateFile(os.path.abspath(checkpoint_dir))
        checkpoint_name = str(state_file.read())
        tf_checkpoint_file = os.path.join(checkpoint_dir, "checkpoint")
        with open(tf_checkpoint_file, "w") as outfile:
            outfile.write(
                "model_checkpoint_path: \"{}\"".format(checkpoint_name))

        config = tf.ConfigProto()
        config.allow_soft_placement = True  # allow placing ops on the CPU when they cannot run on the GPU
        config.gpu_options.allow_growth = True  # allow the gpu memory allocated for the worker to grow if needed
        config.gpu_options.per_process_gpu_memory_fraction = 0.2
        config.intra_op_parallelism_threads = 1
        config.inter_op_parallelism_threads = 1

        with tf.Session(config=config) as sess:
            for var_name, _ in tf.contrib.framework.list_variables(
                    checkpoint_dir):
                # Load the variable
                var = tf.contrib.framework.load_variable(
                    checkpoint_dir, var_name)
                new_name = var_name
                # Set the new name
                # Replace agent/ or agent_#/ with {agent_name}/
                new_name = re.sub(r'agent/|agent_\d+/',
                                  '{}/'.format(agent_name), new_name)
                # Rename the variable
                var = tf.Variable(var, name=new_name)
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            renamed_checkpoint_path = os.path.join(TEMP_RENAME_FOLDER,
                                                   checkpoint_name)
            logger.info('Saving updated checkpoint to {}'.format(
                renamed_checkpoint_path))
            saver.save(sess, renamed_checkpoint_path)
        # Remove the tensorflow 'checkpoint' file
        os.remove(tf_checkpoint_file)
        # Remove the old checkpoint from the checkpoint dir
        for file_name in os.listdir(checkpoint_dir):
            if checkpoint_name in file_name:
                os.remove(os.path.join(checkpoint_dir, file_name))
        # Copy the new checkpoint with renamed variable to the checkpoint dir
        for file_name in os.listdir(TEMP_RENAME_FOLDER):
            full_file_name = os.path.join(os.path.abspath(TEMP_RENAME_FOLDER),
                                          file_name)
            if os.path.isfile(full_file_name) and file_name != "checkpoint":
                shutil.copy(full_file_name, checkpoint_dir)
        # Remove files from temp_rename_folder
        shutil.rmtree(TEMP_RENAME_FOLDER)
        tf.reset_default_graph()
    # If any of the checkpoint files (index, meta, or data) is not found
    except tf.errors.NotFoundError as err:
        log_and_exit("No checkpoint found: {}".format(err),
                     SIMAPP_SIMULATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_400)
    # Thrown when the user modifies the model and the checkpoints become corrupted/truncated
    except tf.errors.DataLossError as err:
        log_and_exit(
            "User modified ckpt, unrecoverable dataloss or corruption: {}".
            format(err), SIMAPP_SIMULATION_WORKER_EXCEPTION,
            SIMAPP_EVENT_ERROR_CODE_400)
    except tf.errors.OutOfRangeError as err:
        log_and_exit("User modified ckpt: {}".format(err),
                     SIMAPP_SIMULATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_400)
    except ValueError as err:
        if utils.is_user_error(err):
            log_and_exit(
                "Couldn't find 'checkpoint' file or checkpoints in given "
                "directory ./checkpoint: {}".format(err),
                SIMAPP_SIMULATION_WORKER_EXCEPTION,
                SIMAPP_EVENT_ERROR_CODE_400)
        else:
            log_and_exit("ValueError in rename checkpoint: {}".format(err),
                         SIMAPP_SIMULATION_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_500)
    except Exception as ex:
        log_and_exit("Exception in rename checkpoint: {}".format(ex),
                     SIMAPP_SIMULATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_500)
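

# Hypothetical usage of rename_checkpoints(); the checkpoint directory and agent
# name below are made-up example values. After the call, variables stored under
# 'agent/...' or 'agent_#/...' in the checkpoint are scoped as 'agent_0/...'.
if __name__ == '__main__':
    rename_checkpoints(checkpoint_dir='./checkpoint/agent', agent_name='agent_0')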