示例#1
0
def bert_pretrained_initialisers(config, args):
    """Load pretrained BERT weight initialisers from a configured checkpoint.

    Args:
        config: Model configuration, forwarded to ``load_initializers_from_tf``.
        args: Run options. Reads ``synthetic_data``, ``generated_data``,
            ``onnx_checkpoint``, ``tf_checkpoint`` and ``task``. It is also
            passed to ``popdist_root`` and
            ``get_phased_initializers_from_default``.

    Returns:
        ``None`` in any of these cases:

        * synthetic or generated data is requested;
        * ``popdist_root(args)`` is false;
        * no checkpoint is given.

        Otherwise, the initialiser object loaded from the checkpoint
        (presumably a dict keyed by tensor name; confirm against the loaders),
        updated in place with the entries returned by
        ``get_phased_initializers_from_default``.
    """
    if args.synthetic_data:
        logger.info("Initialising from synthetic_data")
        return None

    if args.generated_data:
        logger.info("Initialising from generated_data")
        return None

    # The initialised weights will be broadcast after the session has been created
    if not popdist_root(args):
        return None

    init = None
    if args.onnx_checkpoint:
        logger.info(
            f"Initialising from ONNX checkpoint: {args.onnx_checkpoint}")
        init = utils.load_initializers_from_onnx(args.onnx_checkpoint)

    # NOTE(review): if both checkpoints are set, the TF result replaces the
    # ONNX initialisers loaded above, so the ONNX load is wasted work.
    if args.tf_checkpoint:
        logger.info(f"Initialising from TF checkpoint: {args.tf_checkpoint}")
        init = load_initializers_from_tf(args.tf_checkpoint, True, config,
                                         args.task)

    if init is not None:
        # Pass the mapping positionally. ``update(**mapping)`` rejects
        # non-string keys and builds an intermediate kwargs dict for nothing.
        init.update(get_phased_initializers_from_default(args, init))

    return init
示例#2
0
def bert_pretrained_initialisers(config, args):
    """Load pretrained BERT weight initialisers from a configured checkpoint.

    Args:
        config: Model configuration, forwarded to ``load_initializers_from_tf``.
        args: Run options. Reads ``synthetic_data``, ``generated_data``,
            ``onnx_checkpoint``, ``tf_checkpoint`` and ``task``. It is also
            passed to ``get_phased_initializers_from_default``.

    Returns:
        ``None`` when synthetic or generated data is requested, or when no
        checkpoint is given. Otherwise, the initialiser object loaded from
        the checkpoint (presumably a dict keyed by tensor name; confirm
        against the loaders), updated in place with the entries returned by
        ``get_phased_initializers_from_default``.
    """

    if args.synthetic_data:
        logger.info("Initialising from synthetic_data")
        return None

    if args.generated_data:
        logger.info("Initialising from generated_data")
        return None

    init = None
    if args.onnx_checkpoint:
        logger.info(
            f"Initialising from ONNX checkpoint: {args.onnx_checkpoint}")
        init = utils.load_initializers_from_onnx(args.onnx_checkpoint)

    # NOTE(review): if both checkpoints are set, the TF result replaces the
    # ONNX initialisers loaded above, so the ONNX load is wasted work.
    if args.tf_checkpoint:
        logger.info(f"Initialising from TF checkpoint: {args.tf_checkpoint}")
        init = load_initializers_from_tf(args.tf_checkpoint, True, config,
                                         args.task)

    if init is not None:
        # Pass the mapping positionally. ``update(**mapping)`` rejects
        # non-string keys and builds an intermediate kwargs dict for nothing.
        init.update(get_phased_initializers_from_default(args, init))

    return init
示例#3
0
def bert_pretrained_initialisers(config, args):
    """Return pretrained BERT initialisers from a checkpoint, or ``None``.

    With synthetic data no weights are loaded. Otherwise an ONNX checkpoint
    is used if ``args.onnx_checkpoint`` is set. It takes precedence: the TF
    checkpoint is not consulted at all in that case. Failing that, a TF
    checkpoint is loaded with ``config`` when ``args.tf_checkpoint`` is set.

    Args:
        config: Model configuration, forwarded to ``load_initializers_from_tf``.
        args: Run options. Reads ``synthetic_data``, ``onnx_checkpoint`` and
            ``tf_checkpoint``.

    Returns:
        Whatever the selected loader returns, or ``None`` when synthetic data
        is requested or no checkpoint is configured.
    """
    if not args.synthetic_data:
        onnx_path = args.onnx_checkpoint
        if onnx_path:
            logger.info(f"Initialising from ONNX checkpoint: {onnx_path}")
            return utils.load_initializers_from_onnx(onnx_path)

        tf_path = args.tf_checkpoint
        if tf_path:
            logger.info(f"Initialising from TF checkpoint: {tf_path}")
            return load_initializers_from_tf(tf_path, True, config)

    # Synthetic data, or nothing to load from.
    return None