示例#1
0
def bert_pretrained_initialisers(config, args):
    """Return pretrained initialisers selected by ``args``, or ``None``.

    ``None`` is returned when synthetic or generated data is requested,
    when ``popdist_root(args)`` is false, or when neither
    ``args.onnx_checkpoint`` nor ``args.tf_checkpoint`` is set.

    If both checkpoints are given, the TF checkpoint is loaded last and
    therefore replaces whatever was loaded from the ONNX checkpoint.
    """
    if args.synthetic_data:
        logger.info("Initialising from synthetic_data")
        return None
    if args.generated_data:
        logger.info("Initialising from generated_data")
        return None
    # The initialised weights will be broadcast after the session has been
    # created, so only the process accepted by popdist_root loads them here.
    if not popdist_root(args):
        return None

    initializers = None

    onnx_path = args.onnx_checkpoint
    if onnx_path:
        logger.info(f"Initialising from ONNX checkpoint: {onnx_path}")
        initializers = utils.load_initializers_from_onnx(onnx_path)

    tf_path = args.tf_checkpoint
    if tf_path:
        logger.info(f"Initialising from TF checkpoint: {tf_path}")
        initializers = load_initializers_from_tf(
            tf_path, True, config, args.task)

    if initializers is None:
        return None

    # Merge in the entries derived by get_phased_initializers_from_default.
    phased = get_phased_initializers_from_default(args, initializers)
    initializers.update(**phased)
    return initializers
示例#2
0
def save_model(args, session, step, epoch=None, step_in_filename=False):
    """Save the session's model as an ONNX file under ``args.checkpoint_dir``.

    Nothing happens when ``args.no_model_save`` is set or when
    ``popdist_root(args)`` is false.

    The file is named ``model[_<epoch>][:<step>].onnx``. When
    ``args.save_initializers_externally`` is set, the file is placed in a
    sub-directory named after the file stem (created if missing), and the
    variables path is the model path with its last ``model`` replaced by
    ``vars``.

    Args:
        args: Run options; reads ``no_model_save``, ``checkpoint_dir`` and
            ``save_initializers_externally``, and reads/writes
            ``save_vars_prev``.
        session: Object providing ``modelToHost`` and
            ``updateExternallySavedTensorLocations`` (presumably a PopART
            session - confirm against the caller).
        step: Step number; only used in the filename when
            ``step_in_filename`` is true.
        epoch: Optional epoch number appended to the filename.
        step_in_filename: Whether to append ``:<step>`` to the filename.
    """
    if args.no_model_save or not popdist_root(args):
        return

    save_file = "model"
    if epoch is not None:
        save_file += f"_{epoch}"
    if step_in_filename:
        save_file += f":{step}"

    save_externally = args.save_initializers_externally
    if save_externally:
        save_dir = Path(args.checkpoint_dir, save_file)
        save_dir.mkdir(parents=True, exist_ok=True)
    else:
        save_dir = args.checkpoint_dir

    save_file += '.onnx'
    save_path = os.path.join(save_dir, save_file)
    # Swap only the last "model" (the one in the file name) for "vars".
    save_vars = 'vars'.join(save_path.rsplit('model', 1))

    if (save_externally and hasattr(args, 'save_vars_prev')
            and os.path.exists(args.save_vars_prev)):
        # Tensors currently stored at the previous vars file are redirected
        # to the new one before the model is written.
        logger.debug(
            f'Updating external location for vars from '
            f'{args.save_vars_prev} to {save_vars}.'
        )
        session.updateExternallySavedTensorLocations(
            args.save_vars_prev, save_vars)

    session.modelToHost(save_path)
    # Remember this location so the next save can redirect from it.
    args.save_vars_prev = save_vars
    logger.info(f"Saved model to: {save_path}.")
    if save_externally:
        logger.info(
            f"Saved variables(weights and optimizer state) to: {save_vars}."
        )
示例#3
0
def bert_writer(args):
    """Create a ``SummaryWriter`` for this run, or return ``None``.

    ``None`` is returned when ``args.log_dir`` is ``None`` or when
    ``popdist_root(args)`` is false. Otherwise the writer logs to
    ``<args.log_dir>/<checkpoint dir basename>.<ISO timestamp>``.
    """
    if args.log_dir is None or not popdist_root(args):
        return None

    run_name = os.path.basename(args.checkpoint_dir)
    timestamp = datetime.datetime.now().isoformat()
    target_dir = os.path.join(args.log_dir, f"{run_name}.{timestamp}")
    return SummaryWriter(log_dir=target_dir)
示例#4
0
def save_model_and_stats(args,
                         session,
                         writer,
                         step,
                         epoch=None,
                         step_in_filename=False):
    """Save the model to ``args.checkpoint_dir`` and record its statistics.

    Nothing happens when ``args.no_model_save`` is set or when
    ``popdist_root(args)`` is false. The file is named
    ``model[_<epoch>][:<step>].onnx``; after ``session.modelToHost`` writes
    it, the path, ``writer`` and ``step`` are handed to
    ``utils.save_model_statistics``.
    """
    if args.no_model_save or not popdist_root(args):
        return

    name_parts = ["model"]
    if epoch is not None:
        name_parts.append(f"_{epoch}")
    if step_in_filename:
        name_parts.append(f":{step}")
    name_parts.append('.onnx')

    save_path = os.path.join(args.checkpoint_dir, "".join(name_parts))
    logger.info(f"Saving model to: {save_path}")
    session.modelToHost(save_path)
    utils.save_model_statistics(save_path, writer, step)
示例#5
0
        path = args.profile_dir
        if args.use_popdist:
            path += f"_rank{args.popdist_rank}"
        popvision.set_profiling_vars(path, args.profile_instrument)
        popvision.set_logging_vars()
        args_dict = vars(args)
        args_dict["hostname"] = socket.gethostname()
        args_dict["command"] = ' '.join(sys.argv)
        popvision.save_app_info(args_dict)
        logging_handler = popvision.get_profile_logging_handler()
    else:
        logging_handler = None

    setup_logger(logging.getLevelName(args.log_level), logging_handler)

    if args.wandb and popdist_root(args):
        import wandb
        wandb.init(project="popart-bert", sync_tensorboard=True)
        wandb_config = vars(args)
        wandb.config.update(args)

    logger.info("Program Start")
    logger.info("Hostname: " + socket.gethostname())
    logger.info("Command Executed: " + str(sys.argv))

    # Run the main inference/training session by default
    if args.inference or not args.no_training:
        main(args)

    # If this was a training session and validation isn't disabled; validate.
    if not args.inference and not args.no_validation and not args.no_model_save and popdist_root(