Example #1
    def _prepare(self, model, qconfig_dict, prepare_custom_config_dict, is_standalone_module):
        """ standalone_module means it a submodule that is not inlined in parent module,
        and will be quantized separately as one unit.

        When we are preparing a standalone module:
        input of the module is observed in parent module, output of the module
        is observed in the standalone module.
        Returns:
            model(GraphModule): prepared standalone module with following attributes:
                _standalone_module_observed_input_idxs(List[Int]): a list of indexes for the graph inputs that
                                         need to be observed in the parent module
                _output_is_observed(Bool): a boolean variable indicating whether the output of the
                                   standalone module is observed or not
        """
        if prepare_custom_config_dict is None:
            prepare_custom_config_dict = {}

        additional_quant_patterns = prepare_custom_config_dict.get("additional_quant_pattern", {})
        self.patterns = get_combined_dict(get_default_quant_patterns(), additional_quant_patterns)

        flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
        # TODO: support regex as well
        propagate_qconfig_(model, flattened_qconfig_dict)
        if model.training:
            additional_qat_module_mapping = prepare_custom_config_dict.get("additioanl_qat_module_mapping", {})
            self._qat_swap_modules(model, additional_qat_module_mapping)

        self.modules = dict(model.named_modules())

        convert_dict_to_ordered_dict(qconfig_dict)
        # map from node name to qconfig, used in _find_matches
        self._generate_qconfig_map(model, model.graph, qconfig_dict)

        # match the patterns that will get quantized
        standalone_module_names = prepare_custom_config_dict.get("standalone_module_name", None)
        standalone_module_classes = prepare_custom_config_dict.get("standalone_module_class", None)
        custom_module_classes = get_custom_module_class_keys(prepare_custom_config_dict, "float_to_observed_custom_module_class")
        matches = self._find_matches(
            model.graph, self.modules, self.patterns, standalone_module_names, standalone_module_classes, custom_module_classes)

        # find _inputs_ to matched nodes that are not quantized; these
        # have to be quantized, which requires measuring stats, so we
        # initialize a DefaultQuantizeHandler object for each
        quants = self._find_quants(model.graph, matches)

        self.activation_post_process_map = dict()
        env = {}
        observed_graph = Graph()
        observed_node_names_set = set()

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        # indexes for the inputs that need to be observed
        standalone_module_observed_input_idxs = []
        graph_inputs = []
        for node in model.graph.nodes:
            if node.op == 'placeholder':
                graph_inputs.append(node.name)

        get_new_observer_name = get_new_attr_name_with_prefix('activation_post_process_')
        model_device = assert_and_get_unique_device(model)

        def insert_observer(node, observer):
            """Insert observer for node by modifying the observed_graph and
               attach observer module to the model
               Args:
                 node: Node
                 observer: observer/fake_quantize module instance
            """
            # respect device affinity when adding observers
            if model_device:
                observer.to(model_device)
            # add observer module as attribute
            prefix = node.name + '_activation_post_process_'
            get_new_observer_name = get_new_attr_name_with_prefix(prefix)
            observer_name = get_new_observer_name(model)
            setattr(model, observer_name, observer)
            # put the observer instance in the activation_post_process map
            self.activation_post_process_map[node.name] = observer
            # insert observer call
            env[node.name] = observed_graph.create_node('call_module', observer_name, (load_arg(node),), {})
            observed_node_names_set.add(node.name)

        def insert_observer_for_special_module(quantize_handler):
            """ Insert observer for custom module and standalone module
              Returns: standalone_module_input_idxs: the indexes for inputs that need
              to be observed by the parent module
            """
            standalone_module_input_idxs = None
            if isinstance(quantize_handler, CustomModuleQuantizeHandler):
                custom_module = self.modules[node.target]
                custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})
                observed_custom_module_class = \
                    get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig)
                observed_custom_module = \
                    observed_custom_module_class.from_float(custom_module)
                parent_name, name = _parent_name(node.target)
                setattr(self.modules[parent_name], name, observed_custom_module)
            elif isinstance(quantize_handler, StandaloneModuleQuantizeHandler):
                # observe standalone module
                standalone_module = self.modules[node.target]
                prepare = torch.quantization.quantize_fx._prepare_standalone_module_fx
                observed_standalone_module = prepare(standalone_module, {"": qconfig})
                observed_standalone_module.qconfig = qconfig
                standalone_module_input_idxs = observed_standalone_module._standalone_module_observed_input_idxs
                observed_standalone_module = mark_observed_standalone_module(observed_standalone_module)
                parent_name, name = _parent_name(node.target)
                setattr(self.modules[parent_name], name, observed_standalone_module)
                self.modules[node.target] = observed_standalone_module
            return standalone_module_input_idxs

        def insert_observer_for_output_of_the_node(
                node,
                quantize_handler,
                qconfig,
                standalone_module_input_idxs):
            """ Insert observer/fake_quantize module for output of the observed module
            if needed
            """
            # don't need to insert observer for output if activation does not
            # need to be statically quantized
            if activation_is_statically_quantized(qconfig):
                if isinstance(quantize_handler, FixedQParamsOpQuantizeHandler) and model.training:
                    # we only insert fake quantize module in qat
                    activation_post_process_ctr = \
                        get_default_output_activation_post_process_map().get(pattern, None)
                    assert activation_post_process_ctr is not None, \
                        "activation_post_process constructor not provided for " + \
                        "pattern:" + str(pattern)
                    insert_observer(node, activation_post_process_ctr())
                elif (isinstance(quantize_handler, FixedQParamsOpQuantizeHandler) and
                      not model.training) or isinstance(quantize_handler, CopyNode):
                    # insert observers for the output of the observed module, or mark the output
                    # as observed
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'

                    def is_observed(input_arg):
                        if isinstance(input_arg, Node):
                            return input_arg.name in observed_node_names_set
                        elif isinstance(input_arg, list):
                            return all(map(is_observed, input_arg))
                    # propagate observed property from input
                    if is_observed(node.args[0]):
                        observed_node_names_set.add(node.name)
                elif ((isinstance(quantize_handler, Add) or isinstance(quantize_handler, Mul)) and
                      quantize_handler.num_node_args == 1):
                    input_node = matched_nodes[-1]  # first node in the sequence

                    def input_is_observed(arg):
                        return isinstance(arg, Node) and arg.name in observed_node_names_set
                    # This checks whether one of the arguments of add/mul
                    # is an observed node.
                    # If both of the inputs are numbers,
                    # we will not consider the output to be observed
                    if input_is_observed(input_node.args[0]) or input_is_observed(input_node.args[1]):
                        observed_node_names_set.add(node.name)
                elif isinstance(quantize_handler, StandaloneModuleQuantizeHandler):
                    assert node.op == 'call_module'
                    output_is_observed = self.modules[node.target]._output_is_observed
                    if output_is_observed:
                        observed_node_names_set.add(node.name)
                elif quantize_handler.all_node_args:
                    # observer for outputs
                    new_observer = qconfig.activation()
                    insert_observer(node, new_observer)

            # insert observer for input of standalone module
            if standalone_module_input_idxs is not None:
                for idx in standalone_module_input_idxs:
                    if node.args[idx].name not in observed_node_names_set:
                        new_observer = qconfig.activation()
                        insert_observer(node.args[idx], new_observer)

        def insert_observer_for_input_arg_of_observed_node(arg):
            """
               Input:
                 arg: input arg node for another observed node, e.g.
                 the input activation for a functional linear node
            """
            if node.name not in observed_node_names_set and node.name in quants:
                if is_standalone_module and node.name in graph_inputs:
                    # we'll insert observer for input of standalone module
                    # in parent graph
                    standalone_module_observed_input_idxs.append(graph_inputs.index(node.name))
                    return
                _, activation_post_process_ctr = quants[node.name]
                if activation_post_process_ctr is not None:
                    insert_observer(node, activation_post_process_ctr())

        result_node : Optional[Node] = None
        for node in model.graph.nodes:
            if node.op == 'output':
                observed_graph.output(load_arg(node.args[0]))
                result_node = node
                continue
            if node.name in observed_node_names_set:
                continue

            root_node, matched_nodes, pattern, obj, qconfig = matches.get(node.name, (None, None, None, None, None))
            if root_node is None:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            elif root_node is node:
                env[node.name] = observed_graph.node_copy(node, load_arg)
                # indexes for inputs of the standalone module that need to be observed in the parent
                if qconfig is not None:
                    standalone_module_input_idxs = insert_observer_for_special_module(obj)
                    insert_observer_for_output_of_the_node(
                        node, obj, qconfig, standalone_module_input_idxs)
            else:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            insert_observer_for_input_arg_of_observed_node(node)


        model = GraphModule(model, observed_graph)
        self.save_state(model)
        model = mark_observed_module(model)
        if is_standalone_module:
            assert result_node is not None
            assert isinstance(result_node.args[0], Node), \
                'standalone module returning dict is not yet supported'
            # indicator for whether the output is observed or not;
            # this is used to correctly quantize standalone modules
            output_is_observed = result_node.args[0].name in observed_node_names_set
            model._standalone_module_observed_input_idxs = standalone_module_observed_input_idxs
            model._output_is_observed = output_is_observed
        return model
def prepare_dynamic(model, qconfig_dict=None):
    propagate_qconfig_(model, qconfig_dict)
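A hedged usage sketch (not part of the original module): it shows how a caller could assemble the prepare_custom_config_dict keys read by _prepare above, such as "standalone_module_name", and pass them through the public prepare_fx entry point. The tiny Model/Sub classes are placeholders, and the keyword-style signature prepare_fx(model, qconfig_dict, prepare_custom_config_dict=...) is assumed to match the PyTorch versions this code is taken from; later releases expect config tuples rather than plain names for standalone modules.

import torch
import torch.nn as nn
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx


class Sub(nn.Module):
    # placeholder submodule that we want quantized separately as one unit
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


class Model(nn.Module):
    # placeholder top-level model
    def __init__(self):
        super().__init__()
        self.sub = Sub()
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.sub(x))


float_model = Model().eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}
prepare_custom_config_dict = {
    # quantize `sub` as a standalone module: its inputs are observed in the
    # parent graph, its output inside the standalone module (see _prepare above)
    "standalone_module_name": ["sub"],
}
observed = prepare_fx(float_model, qconfig_dict,
                      prepare_custom_config_dict=prepare_custom_config_dict)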
Example #3
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--model_type",
                        default=None,
                        type=str,
                        required=True,
                        help="Model type selected in the list: " +
                        ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: "
        + ", ".join(ALL_MODELS))
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model checkpoints and predictions will be written."
    )

    # Other parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        help=
        "The input data dir. Should contain the .json files for the task. If not specified, will run with tensorflow_datasets."
    )
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")

    parser.add_argument(
        '--version_2_with_negative',
        action='store_true',
        help=
        'If true, the SQuAD examples contain some that do not have an answer.')
    parser.add_argument(
        '--null_score_diff_threshold',
        type=float,
        default=0.0,
        help=
        "If null_score - best_non_null is greater than the threshold predict null."
    )

    parser.add_argument(
        "--max_seq_length",
        default=384,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded."
    )
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help=
        "When splitting up a long document into chunks, how much stride to take between chunks."
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
        help=
        "The maximum number of tokens for the question. Questions longer than this will "
        "be truncated to this length.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--evaluate_during_training",
        action='store_true',
        help="Rul evaluation during training at each logging step.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")

    parser.add_argument("--per_gpu_train_batch_size",
                        default=8,
                        type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size",
                        default=8,
                        type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help=
        "If > 0: set total number of training steps to perform. Override num_train_epochs."
    )
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument(
        "--n_best_size",
        default=20,
        type=int,
        help=
        "The total number of n-best predictions to generate in the nbest_predictions.json output file."
    )
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
        help=
        "The maximum length of an answer that can be generated. This is needed because the start "
        "and end predictions are not conditioned on one another.")
    parser.add_argument(
        "--verbose_logging",
        action='store_true',
        help=
        "If true, all of the warnings related to data processing will be printed. "
        "A number of warnings are expected for a normal SQuAD evaluation.")

    parser.add_argument('--logging_steps',
                        type=int,
                        default=50,
                        help="Log every X updates steps.")
    parser.add_argument('--save_steps',
                        type=int,
                        default=50,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--eval_all_checkpoints",
        action='store_true',
        help=
        "Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number"
    )
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--overwrite_output_dir',
                        action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument(
        '--overwrite_cache',
        action='store_true',
        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument(
        '--fp16',
        action='store_true',
        help=
        "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit"
    )
    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument("--do_calibration",
                        action='store_true',
                        help="Whether to do calibration.")
    parser.add_argument("--do_int8_inference",
                        action='store_true',
                        help="Whether to run int8 inference.")
    parser.add_argument("--do_fp32_inference",
                        action='store_true',
                        help="Whether to run fp32 inference.")
    parser.add_argument("--mkldnn_eval",
                        action='store_true',
                        help="evaluation with MKLDNN")
    parser.add_argument("--tune",
                        action='store_true',
                        help="run ilit to tune int8 acc.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="SQuAD task")
    parser.add_argument("--warmup",
                        type=int,
                        default=5,
                        help="warmup for performance")

    args = parser.parse_args()

    args.predict_file = os.path.join(
        args.output_dir, 'predictions_{}_{}.txt'.format(
            list(filter(None, args.model_name_or_path.split('/'))).pop(),
            str(args.max_seq_length)))

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir
    ) and args.do_train and not args.overwrite_output_dir:
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome."
            .format(args.output_dir))

    mix_qkv = False
    if args.do_calibration or args.do_int8_inference or args.tune:
        mix_qkv = True

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank, device, args.n_gpu, bool(args.local_rank != -1),
        args.fp16)

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier(
        )  # Make sure only the first process in distributed training will download model & vocab

    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        cache_dir=args.cache_dir if args.cache_dir else None)
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name
        if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None)
    model = model_class.from_pretrained(
        args.model_name_or_path,
        from_tf=bool('.ckpt' in args.model_name_or_path),
        config=config,
        mix_qkv=mix_qkv,
        cache_dir=args.cache_dir if args.cache_dir else None)

    if args.local_rank == 0:
        torch.distributed.barrier(
        )  # Make sure only the first process in distributed training will download model & vocab

    model.to(args.device)

    logger.info("Training/evaluation parameters %s", args)

    # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
    # Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
    # remove the need for this code, but it is still valid.
    if args.fp16:
        try:
            import apex
            apex.amp.register_half_function(torch, 'einsum')
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )

    # Training
    if args.do_train:
        train_dataset = load_and_cache_examples(args,
                                                tokenizer,
                                                evaluate=False,
                                                output_examples=False)
        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step,
                    tr_loss)

    # Save the trained model and the tokenizer
    if args.do_train and (args.local_rank == -1
                          or torch.distributed.get_rank() == 0):
        # Create output directory if needed
        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = model.module if hasattr(
            model,
            'module') else model  # Take care of distributed/parallel training
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))

        # Load a trained model and vocabulary that you have fine-tuned
        model = model_class.from_pretrained(args.output_dir,
                                            force_download=True,
                                            mix_qkv=mix_qkv)
        tokenizer = tokenizer_class.from_pretrained(
            args.output_dir, do_lower_case=args.do_lower_case)
        model.to(args.device)

    # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(
                os.path.dirname(c) for c in sorted(
                    glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME,
                              recursive=True)))
            logging.getLogger("transformers.modeling_utils").setLevel(
                logging.WARN)  # Reduce model loading logs

        logger.info("Evaluate the following checkpoints: %s", checkpoints)

        for checkpoint in checkpoints:
            # Reload the model
            global_step = checkpoint.split(
                '-')[-1] if len(checkpoints) > 1 else ""
            if args.mkldnn_eval or args.do_fp32_inference:
                model = model_class.from_pretrained(checkpoint,
                                                    force_download=True)
                model.to(args.device)

                # Evaluate
                result, _ = evaluate(args,
                                     model,
                                     tokenizer,
                                     prefix=global_step)
                result = dict(
                    (k + ('_{}'.format(global_step) if global_step else ''), v)
                    for k, v in result.items())
                results.update(result)

            if args.tune:

                def eval_func_for_ilit(model):
                    result, _ = evaluate(args, model, tokenizer)
                    for key in sorted(result.keys()):
                        logger.info("  %s = %s", key, str(result[key]))
                    bert_task_acc_keys = [
                        'best_f1', 'f1', 'mcc', 'spearmanr', 'acc'
                    ]
                    for key in bert_task_acc_keys:
                        if key in result.keys():
                            logger.info("Finally Eval {}:{}".format(
                                key, result[key]))
                            acc = result[key]
                            break
                    return acc

                model = model_class.from_pretrained(checkpoint,
                                                    force_download=True,
                                                    mix_qkv=True)
                model.to(args.device)
                dataset = load_and_cache_examples(args,
                                                  tokenizer,
                                                  evaluate=True,
                                                  output_examples=False)
                args.eval_batch_size = args.per_gpu_eval_batch_size * max(
                    1, args.n_gpu)
                eval_task = "squad"
                import ilit
                dataset = ilit.data.DATASETS('pytorch')['bert'](
                    dataset=dataset, task=eval_task)
                test_dataloader = ilit.data.DataLoader(
                    'pytorch', dataset, batch_size=args.eval_batch_size)
                tuner = ilit.Tuner("./conf.yaml")
                tuner.tune(model,
                           test_dataloader,
                           eval_func=eval_func_for_ilit)
                exit(0)

            if args.do_calibration:
                model = model_class.from_pretrained(checkpoint,
                                                    force_download=True,
                                                    mix_qkv=True)
                model.to(args.device)
                model.qconfig = default_per_channel_qconfig
                propagate_qconfig_(model)
                add_observer_(model)
                # Evaluate
                evaluate(args,
                         model,
                         tokenizer,
                         prefix=global_step,
                         calibration=True)
                convert(model, inplace=True)
                quantized_model_path = "squad" + str(
                    global_step) + "_quantized_model"
                if not os.path.exists(quantized_model_path):
                    os.makedirs(quantized_model_path)
                model.save_pretrained(quantized_model_path)
                result, _ = evaluate(args,
                                     model,
                                     tokenizer,
                                     prefix=global_step)
                result = dict(
                    (k + ('_{}'.format(global_step) if global_step else ''), v)
                    for k, v in result.items())
                results.update(result)
            if args.do_int8_inference:
                model = model_class.from_pretrained(checkpoint,
                                                    force_download=True,
                                                    mix_qkv=True)
                model.to(args.device)
                model.qconfig = default_per_channel_qconfig
                propagate_qconfig_(model)
                add_observer_(model)
                convert(model, inplace=True)
                quantized_model_path = "squad" + str(
                    global_step) + "_quantized_model"
                if not os.path.exists(quantized_model_path):
                    logger.info("Please run calibration first!")
                    return
                model_bin_file = os.path.join(quantized_model_path,
                                              "pytorch_model.bin")
                state_dict = torch.load(model_bin_file)
                model.load_state_dict(state_dict)
                print(model)
                with torch.autograd.profiler.profile() as prof:
                    result, _ = evaluate(args,
                                         model,
                                         tokenizer,
                                         prefix=global_step)
                print(prof.key_averages().table(sort_by="cpu_time_total"))
                result = dict(
                    (k + ('_{}'.format(global_step) if global_step else ''), v)
                    for k, v in result.items())
                results.update(result)
    logger.info("Results: {}".format(results))

    return results
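The --do_calibration / --do_int8_inference branches above follow the eager-mode post-training static quantization recipe: propagate a qconfig, attach observers, run calibration data through the model, then convert. Below is a minimal self-contained sketch of that same flow on a toy module, using the public torch.quantization.prepare/convert wrappers in place of the lower-level propagate_qconfig_/add_observer_ calls; TinyModel and the random calibration batches are placeholders.

import torch
import torch.nn as nn
from torch.quantization import (DeQuantStub, QuantStub, convert,
                                default_per_channel_qconfig, prepare)


class TinyModel(nn.Module):
    # placeholder model with explicit quant/dequant stubs for eager-mode PTQ
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.linear = nn.Linear(16, 16)
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.linear(self.quant(x)))


torch.backends.quantized.engine = "fbgemm"   # per-channel qconfig assumes the fbgemm backend
model = TinyModel().eval()
model.qconfig = default_per_channel_qconfig  # same qconfig the script above uses

prepare(model, inplace=True)   # attach observers (propagate_qconfig_ + add_observer_)
with torch.no_grad():          # calibration: run representative data through the model
    for _ in range(8):
        model(torch.randn(4, 16))
convert(model, inplace=True)   # swap float modules for quantized ones

int8_out = model(torch.randn(2, 16))  # int8 inference with the converted model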
Example #4
def prepare(
        model: GraphModule,
        qconfig_dict: Any,
        node_name_to_scope: Dict[str, Tuple[str, type]],
        prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
        equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
        is_standalone_module: bool = False) -> ObservedGraphModule:
    """ standalone_module means it a submodule that is not inlined in
    parent module, and will be quantized separately as one unit.

    How the standalone module is observed is specified by `input_quantized_idxs` and
    `output_quantized_idxs` in the prepare_custom_config for the standalone module
    Args:
        node_name_to_scope: mapping from node name to the scope of the module which contains the node.
        The scope is a tuple of the fully qualified path of the module and the type of the module
    Returns:
        model(GraphModule): prepared standalone module
        attributes:
            _standalone_module_input_quantized_idxs(List[Int]): a list of
                indexes for the graph inputs that are expected to be quantized,
                same as the input_quantized_idxs configuration provided
                for the standalone module
            _standalone_module_output_quantized_idxs(List[Int]): a list of
                indexes for the graph outputs that are quantized,
                same as the output_quantized_idxs configuration provided
                for the standalone module
    """
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}
    if equalization_qconfig_dict is None:
        equalization_qconfig_dict = {}

    additional_quant_patterns = \
        prepare_custom_config_dict.get("additional_quant_pattern", {})
    # mapping from a tuple of nodes in reverse order to uninitialized
    #   QuantizeHandler subclass. For example,
    # {
    #   # match a single node
    #   (<class 'torch.nn.modules.conv.Conv3d'>:
    #     <class 'torch.quantization.fx.quantize.ConvRelu'>),
    #   # match multiple nodes in reverse order
    #   ((<function relu at 0x7f766a7360d0>, <built-in function add>):
    #     <class 'torch.quantization.fx.quantize.Add'>),
    # }
    patterns: Dict[Pattern, QuantizeHandler] = get_combined_dict(
        get_default_quant_patterns(), additional_quant_patterns)

    convert_dict_to_ordered_dict(qconfig_dict)
    convert_dict_to_ordered_dict(equalization_qconfig_dict)
    flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
    # TODO: support regex as well
    propagate_qconfig_(model, flattened_qconfig_dict)

    if model.training:
        additional_qat_module_mapping = prepare_custom_config_dict.get(
            "additional_qat_module_mapping", {})
        qat_swap_modules(model, additional_qat_module_mapping)
        qconfig_dict = update_qconfig_for_qat(qconfig_dict, additional_qat_module_mapping)

    qconfig_dict = update_qconfig_for_fusion(model, qconfig_dict)
    equalization_qconfig_dict = update_qconfig_for_fusion(model, equalization_qconfig_dict)

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    modules = dict(model.named_modules())

    # fill qconfig_map, a map from node name to qconfig, used in find_matches
    equalization_qconfig_map = generate_qconfig_map(model, modules, model.graph, equalization_qconfig_dict, node_name_to_scope)
    qconfig_map = generate_qconfig_map(model, modules, model.graph, qconfig_dict, node_name_to_scope)

    # match the patterns that will get quantized
    standalone_module_name_configs = prepare_custom_config_dict.get(
        "standalone_module_name", [])
    standalone_module_class_configs = prepare_custom_config_dict.get(
        "standalone_module_class", [])

    standalone_module_names = [config[0] for config in standalone_module_name_configs]
    standalone_module_classes = [config[0] for config in standalone_module_class_configs]
    custom_module_classes = get_custom_module_class_keys(
        prepare_custom_config_dict, "float_to_observed_custom_module_class")
    matches = find_matches(
        model.graph, modules, patterns, qconfig_map, standalone_module_names,
        standalone_module_classes, custom_module_classes)

    input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    run_prepare_fx_on_standalone_modules(
        model, modules, matches, prepare_custom_config_dict)

    result_node = insert_observers_for_model(
        model, modules, matches, qconfig_map,
        model.graph, prepare_custom_config_dict,
        equalization_qconfig_map,
        input_quantized_idxs, output_quantized_idxs)

    save_state(model, qconfig_map, node_name_to_scope, patterns,
               prepare_custom_config_dict, equalization_qconfig_map)
    preserved_attributes = set(prepare_custom_config_dict.get("preserved_attributes", []))
    model = ObservedGraphModule(model, model.graph, preserved_attributes)
    if is_standalone_module:
        assert result_node is not None
        assert isinstance(result_node.args[0], Node), \
            "standalone module only supports returning simple value currently"\
            "(not tuple, dict etc.)"
        # these inputs are observed in parent
        # converting List[int] to Tensor since module attribute is
        # Union[Tensor, Module]
        model._standalone_module_input_quantized_idxs = \
            torch.tensor(input_quantized_idxs)
        model._standalone_module_output_quantized_idxs = torch.tensor(output_quantized_idxs)
    return model
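Another hedged caller-side sketch: it shows how the remaining prepare_custom_config_dict keys consumed by prepare above ("input_quantized_idxs", "output_quantized_idxs", "preserved_attributes") could be supplied through the public prepare_fx entry point. The toy model is a placeholder, and the keyword-style prepare_fx signature is assumed to match the PyTorch versions this code targets.

import torch.nn as nn
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}
prepare_custom_config_dict = {
    # graph inputs the caller will already feed as quantized tensors
    # (e.g. [0] for the first positional input); empty means all inputs are fp32
    "input_quantized_idxs": [],
    # graph outputs that should be left quantized rather than dequantized
    "output_quantized_idxs": [],
    # attributes copied onto the resulting ObservedGraphModule (see its construction above)
    "preserved_attributes": [],
}
observed = prepare_fx(model, qconfig_dict,
                      prepare_custom_config_dict=prepare_custom_config_dict)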
def quantization_auto_tuning(model, run_fn, run_args, run_calibration,
                             calibration_args, metric = "top-1", relative_error = 0.01,
                             absolute_error = 0.01, relative_err_master = True,
                             fallback_op_types=DEFAULT_QUANTIZED_OP,
                             performance_fine_tuning=True):
    r"""
    The auto-tuning tool API for user.

    Args:
        model:    the model, which should already be prepared by the first two steps in
        run_fn:   evaluation function; the return value should be {accuracy_metric: value},
                  for example, {"acc": 0.62}
        run_args: the args of the evaluation function; using the object returned by
                  parser.parse_args() is recommended
        run_calibration: calibration function
        calibration_args: the args for the calibration function
        metric:   the accuracy metric, such as acc, f1, mcc, top-1, map and so on
        relative_error: the maximum tolerated ratio of relative error between the fp32 model
                        and the quantized model; the default value is 0.01 (1%)
        absolute_error: the maximum tolerated ratio of absolute error between the fp32 model
                        and the quantized model; the default value is 0.01 (1%)
        relative_err_master: whether relative_error (True) or absolute_error (False) is the
                             criterion that matters for you
        fallback_op_types: which types of quantized op should be considered for auto-tuning
                           fallback. There are generally several different quantized ops in a
                           quantized model, and sometimes you only want to fall back some types,
                           not all of them. For example, if conv/linear are in a CV model and you
                           only want to fall back linear, then set fallback_op_types={nnq.Linear}

    """
    #run fp32 evaluation to collect accuracy and fp32 tensor
    model_tmp = copy.deepcopy(model)
    propagate_qconfig_(model_tmp)
    add_save_observer_(model_tmp)
    result = run_fn(model_tmp, run_args)
    fp32_accuracy = result[metric]
    #run calibration
    model_tmp = copy.deepcopy(model)
    prepare(model_tmp, inplace = True)
    run_calibration(model_tmp, calibration_args)

    #run int8 to collect accuracy and int8 tensor
    convert(model_tmp, inplace=True)
    add_save_observer_(model_tmp)
    result = run_fn(model_tmp, run_args)
    int8_accuracy = result[metric]
    save_quantized_model(model_tmp, {},
                            save_directory="quantized_model", save_config = True)
    need_to_fallback = False
    if relative_err_master:
       need_to_fallback = True if int8_accuracy < fp32_accuracy * (1 - relative_error) else False
    else:
       need_to_fallback = True if int8_accuracy < fp32_accuracy * (1 - absolute_error) else False

    #begin to fallback auto-tuning
    if need_to_fallback:
       # compute the distance between fp32 tensors and int8 dequantized tensors
       layer_gap_dict = {}
       compute_fp32_and_int8_dequantize_gap(model, "", layer_gap_dict)
       # sort layers according to the above distance to construct the auto-tuning search order
       sorted_gap = sorted(layer_gap_dict.items(), key=lambda item:item[1], reverse=True)
       for item in sorted_gap:
           print(item)

       cur_int8_accuracy = int8_accuracy
       pre_int8_accuracy = int8_accuracy  # the current best accuracy
       len_gap_dict = len(layer_gap_dict)  # the maximum number of search iterations
       fallback_layers = {}  # bucket to save fallback layers
       accuracy_improvment_dict = {}
       count = 0
       #fallback auto-tuning
       while need_to_fallback and count < len_gap_dict:
             #fallback layers in the bucket
             model_tmp = copy.deepcopy(model)
             propagate_qconfig_(model_tmp)
             fallback_layers.update({sorted_gap[count % len_gap_dict][0]:False})
             fallback_layer(model_tmp, "", fallback_layers)

             # calibrate and validate the accuracy of the
             # partially fallen-back quantized model_tmp
             add_observer_(model_tmp)
             run_calibration(model_tmp, calibration_args)
             convert(model_tmp, inplace = True)
             result = run_fn(model_tmp, run_args)
             cur_int8_accuracy=result[metric]
             if cur_int8_accuracy > pre_int8_accuracy:
                accuracy_improvment_dict.update(
                       {sorted_gap[count % len_gap_dict][0]:
                        cur_int8_accuracy - pre_int8_accuracy })
                print("accuracy_improvment_dict", accuracy_improvment_dict)
                pre_int8_accuracy = cur_int8_accuracy
             else:
                del fallback_layers[sorted_gap[count % len_gap_dict][0]]
             count += 1
             if relative_err_master:
                need_to_fallback = True if pre_int8_accuracy < fp32_accuracy * (1 - relative_error) else False
             else:
                need_to_fallback = True if pre_int8_accuracy < fp32_accuracy * (1 - absolute_error) else False
       if performance_fine_tuning:
          # further search for a subset of fallback_layers to improve performance
          fined_fallback_layers = {}
          # sort layers by accuracy value difference
          fallback_layers = sorted(fallback_layers.items(),
                            key=lambda item:item[1], reverse=True)
          print("Candidate fallback_layers:")
          for item in fallback_layers:
              print(item)
          for layer in fallback_layers:
              print(type(layer))
              model_tmp = copy.deepcopy(model)
              propagate_qconfig_(model_tmp)
              fined_fallback_layers.update({layer[0]:layer[1]})
              fallback_layer(model_tmp, "", fined_fallback_layers)

              # calibrate and validate the accuracy of the
              # partially fallen-back quantized model_tmp
              add_observer_(model_tmp)
              run_calibration(model_tmp, calibration_args)
              convert(model_tmp, inplace = True)
              result = run_fn(model_tmp, run_args)
              cur_int8_accuracy=result[metric]
              if relative_err_master and cur_int8_accuracy >= fp32_accuracy * (1 - relative_error):
                 break
              elif not relative_err_master and cur_int8_accuracy >= fp32_accuracy * (1 - absolute_error):
                 break

       if performance_fine_tuning:
          fallback_layers = fined_fallback_layers

       propagate_qconfig_(model)
       fallback_layer(model, "", fallback_layers)

       # calibrate and validate the accuracy of the
       # partially fallen-back quantized model
       add_observer_(model)
       run_calibration(model, calibration_args)
       convert(model, inplace = True)
       result = run_fn(model, run_args)
       print("The fallback layers as following:")
       for layer in fallback_layers.keys():
           print(layer)
       print("The Int8 accuacy:", result)
       save_quantized_model(model, fallback_layers=fallback_layers,
                            save_directory="quantized_model", save_config = True)
                                    
Example #6
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--model_type",
                        default=None,
                        type=str,
                        required=True,
                        help="Model type selected in the list: " +
                        ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list;"
        + ", ".join(ALL_MODELS))
    parser.add_argument(
        "--task_name",
        default=None,
        type=str,
        required=True,
        help="The name of the task to train selected in the list: " +
        ", ".join(processors.keys()))
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--evaluate_during_training",
        action='store_true',
        help="Rul evaluation during training at each logging step.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")

    parser.add_argument("--per_gpu_train_batch_size",
                        default=8,
                        type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size",
                        default=8,
                        type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help=
        "If > 0: set total number of training steps to perform. Override num_train_epochs."
    )
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_steps.")

    parser.add_argument('--logging_steps',
                        type=int,
                        default=50,
                        help="Log every X updates steps.")
    parser.add_argument('--save_steps',
                        type=int,
                        default=50,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--eval_all_checkpoints",
        action='store_true',
        help=
        "Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number"
    )
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument("--mkldnn_eval",
                        action='store_true',
                        help="evaluation with MKLDNN")
    parser.add_argument("--mkldnn_train",
                        action='store_true',
                        help="training with MKLDNN")
    parser.add_argument('--overwrite_output_dir',
                        action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument(
        '--overwrite_cache',
        action='store_true',
        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    parser.add_argument(
        '--fp16',
        action='store_true',
        help=
        "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit"
    )
    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="For distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="For distant debugging.")
    parser.add_argument("--do_fp32_inference",
                        action='store_true',
                        help="Whether to run fp32 inference.")
    parser.add_argument("--do_calibration",
                        action='store_true',
                        help="Whether to do calibration.")
    parser.add_argument("--do_int8_inference",
                        action='store_true',
                        help="Whether to run int8 inference.")
    parser.add_argument("--do_bf16",
                        action='store_true',
                        help="run bf16 evaluation / training.")
    parser.add_argument(
        "--tune",
        action='store_true',
        help="run Low Precision Optimization Tool to tune int8 acc.")
    parser.add_argument("--warmup",
                        type=int,
                        default=2,
                        help="warmup for performance")
    parser.add_argument('-i',
                        "--iter",
                        default=0,
                        type=int,
                        help='For accuracy measurement only.')
    parser.add_argument('--benchmark',
                        dest='benchmark',
                        action='store_true',
                        help='run benchmark')
    parser.add_argument('-r',
                        "--accuracy_only",
                        dest='accuracy_only',
                        action='store_true',
                        help='For accuracy measurement only.')
    parser.add_argument(
        "--tuned_checkpoint",
        default='./',
        type=str,
        metavar='PATH',
        help=
        'path to checkpoint tuned by Low Precision Optimization Tool (default: ./)'
    )
    parser.add_argument('--int8',
                        dest='int8',
                        action='store_true',
                        help='run benchmark')

    args = parser.parse_args()

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir
    ) and args.do_train and not args.overwrite_output_dir:
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome."
            .format(args.output_dir))

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank, device, args.n_gpu, bool(args.local_rank != -1),
        args.fp16)

    # Set seed
    set_seed(args)

    # Prepare GLUE task
    args.task_name = args.task_name.lower()
    if args.task_name not in processors:
        raise ValueError("Task not found: %s" % (args.task_name))
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)
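    # mix_qkv is a model-specific constructor flag of this example's modified BERT:
    # it is enabled whenever calibration, int8 inference, or LPOT tuning is requested,
    # presumably so the query/key/value projections are built in a quantization-friendly form.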
    mix_qkv = False
    if args.do_calibration or args.do_int8_inference or args.tune:
        mix_qkv = True
    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier(
        )  # Make sure only the first process in distributed training will download model & vocab

    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=args.task_name,
        cache_dir=args.cache_dir if args.cache_dir else None)
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name
        if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None)
    model = model_class.from_pretrained(
        args.model_name_or_path,
        from_tf=bool('.ckpt' in args.model_name_or_path),
        config=config,
        mix_qkv=mix_qkv,
        bf16=args.do_bf16,
        mkldnn_train=args.mkldnn_train,
        cache_dir=args.cache_dir if args.cache_dir else None)

    if args.local_rank == 0:
        torch.distributed.barrier(
        )  # Make sure only the first process in distributed training will download model & vocab

    model.to(args.device)

    logger.info("Training/evaluation parameters %s", args)

    # Training
    if args.do_train:
        train_dataset = load_and_cache_examples(args,
                                                args.task_name,
                                                tokenizer,
                                                evaluate=False)
        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step,
                    tr_loss)

    # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
    if args.do_train and (args.local_rank == -1
                          or torch.distributed.get_rank() == 0):
        # Create output directory if needed
        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = model.module if hasattr(
            model,
            'module') else model  # Take care of distributed/parallel training
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))

        # Load a trained model and vocabulary that you have fine-tuned
        model = model_class.from_pretrained(args.output_dir)
        tokenizer = tokenizer_class.from_pretrained(args.output_dir)
        model.to(args.device)

    # Evaluation
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        tokenizer = tokenizer_class.from_pretrained(
            args.output_dir, do_lower_case=args.do_lower_case)
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(
                os.path.dirname(c) for c in sorted(
                    glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME,
                              recursive=True)))
            logging.getLogger("transformers.modeling_utils").setLevel(
                logging.WARN)  # Reduce logging
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split(
                '-')[-1] if len(checkpoints) > 1 else ""
            prefix = checkpoint.split(
                '/')[-1] if checkpoint.find('checkpoint') != -1 else ""

            logger.info("Evaluate:" + args.task_name)
            if args.mkldnn_eval or args.do_fp32_inference or args.do_bf16:
                model = model_class.from_pretrained(checkpoint)
                model.to(args.device)
                result = evaluate(args, model, tokenizer, prefix=prefix)
                result = dict((k + '_{}'.format(global_step), v)
                              for k, v in result.items())
                results.update(result)

            if args.tune:

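                # LPOT drives tuning through this callback: it is called with a
                # candidate (quantized) model and must return a single scalar
                # accuracy, so we pick the first GLUE metric the task reports.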
                def eval_func_for_lpot(model):
                    result, perf = evaluate(args,
                                            model,
                                            tokenizer,
                                            prefix=prefix)
                    bert_task_acc_keys = [
                        'acc_and_f1', 'f1', 'mcc', 'spearmanr', 'acc'
                    ]
                    acc = 0.0  # fallback if none of the expected metric keys is present
                    for key in bert_task_acc_keys:
                        if key in result:
                            logger.info("Final eval {}: {}".format(
                                key, result[key]))
                            acc = result[key]
                            break
                    return acc

                model = model_class.from_pretrained(checkpoint, mix_qkv=True)
                model.to(args.device)
                eval_task_names = (
                    "mnli", "mnli-mm") if args.task_name == "mnli" else (
                        args.task_name, )

                for eval_task in eval_task_names:
                    eval_dataset = load_and_cache_examples(args,
                                                           eval_task,
                                                           tokenizer,
                                                           evaluate=True)

                    args.eval_batch_size = args.per_gpu_eval_batch_size * max(
                        1, args.n_gpu)
                    # multi-gpu eval
                    if args.n_gpu > 1:
                        model = torch.nn.DataParallel(model)

                    if args.mkldnn_eval:
                        from torch.utils import mkldnn as mkldnn_utils
                        model = mkldnn_utils.to_mkldnn(model)
                        print(model)
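                    # Build the LPOT quantizer from conf.yaml, wrap the evaluation
                    # dataset/dataloader in LPOT's helpers, and let it search for an
                    # int8 configuration, scoring candidates with eval_func_for_lpot.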
                    from lpot import Quantization
                    quantizer = Quantization("./conf.yaml")
                    if eval_task != "squad":
                        eval_task = 'classifier'
                    eval_dataset = quantizer.dataset(
                        'bert',
                        dataset=eval_dataset,
                        task=eval_task,
                        model_type=args.model_type)
                    test_dataloader = quantizer.dataloader(
                        eval_dataset, batch_size=args.eval_batch_size)
                    quantizer(model,
                              test_dataloader,
                              eval_func=eval_func_for_lpot)
                exit(0)

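            # Benchmark / accuracy-only path: optionally reload the int8 model
            # previously tuned by LPOT (--int8 with --tuned_checkpoint), then run
            # a single evaluation pass.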
            if args.benchmark or args.accuracy_only:
                model = model_class.from_pretrained(checkpoint, mix_qkv=True)
                model.to(args.device)

                if args.int8:
                    from lpot.utils.pytorch import load
                    new_model = load(
                        os.path.abspath(
                            os.path.expanduser(args.tuned_checkpoint)), model)
                else:
                    new_model = model
                result, _ = evaluate(args, new_model, tokenizer, prefix=prefix)
                exit(0)

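            # Post-training static quantization (calibration) path: attach a
            # per-channel qconfig, mark the listed layers for fp32 fallback via
            # fallback_layer, insert observers, run one evaluation pass in
            # calibration mode to collect activation statistics, convert to int8,
            # and save the quantized model.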
            if args.do_calibration:
                model = model_class.from_pretrained(checkpoint, mix_qkv=True)
                model.to(args.device)
                model.qconfig = default_per_channel_qconfig
                fallback_layers = {}
                if args.model_name_or_path == "bert-base-uncased" and args.task_name == "mrpc":
                    fallback_layers = {"bert.encoder.layer.9.output.dense."}
                propagate_qconfig_(model)
                fallback_layer(model,
                               layer_name="",
                               exculde_layers=fallback_layers)
                add_observer_(model)
                result, _ = evaluate(args,
                                     model,
                                     tokenizer,
                                     prefix=global_step,
                                     calibration=True)
                convert(model, inplace=True)
                quantized_model_path = args.task_name + "_quantized_model"
                if not os.path.exists(quantized_model_path):
                    os.makedirs(quantized_model_path)
                model.save_pretrained(quantized_model_path)
                print(model)
                result, _ = evaluate(args, model, tokenizer, prefix=prefix)
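            # Int8 inference path: rebuild the same quantized module structure
            # (qconfig + observers + convert) and load the calibrated weights that
            # --do_calibration saved under <task_name>_quantized_model.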
            if args.do_int8_inference:
                model = model_class.from_pretrained(checkpoint, mix_qkv=True)
                model.to(args.device)
                model.qconfig = default_per_channel_qconfig
                fallback_layers = {}
                if args.model_name_or_path == "bert-base-uncased" and args.task_name == "mrpc":
                    fallback_layers = {"bert.encoder.layer.9.output.dense."}
                propagate_qconfig_(model)
                fallback_layer(model,
                               layer_name="",
                               exculde_layers=fallback_layers)
                add_observer_(model)
                convert(model, inplace=True)
                quantized_model_path = args.task_name + "_quantized_model"
                if not os.path.exists(quantized_model_path):
                    logger.error(
                        "please run calibration (--do_calibration) before int8 inference")
                    return
                prepare(model, inplace=True)
                convert(model, inplace=True)
                model_bin_file = os.path.join(quantized_model_path,
                                              "pytorch_model.bin")
                state_dict = torch.load(model_bin_file)
                model.load_state_dict(state_dict)
                result, _ = evaluate(args, model, tokenizer, prefix=prefix)

    return results
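
# Hypothetical usage sketch (the script name and paths below are assumptions; the
# flags are the ones defined above): first tune an int8 model for MRPC with LPOT,
# then benchmark the tuned checkpoint.
#
#   python run_glue_tune.py --model_type bert --model_name_or_path bert-base-uncased \
#       --task_name MRPC --do_eval --do_lower_case --tune --output_dir ./mrpc_output
#   python run_glue_tune.py --model_type bert --model_name_or_path bert-base-uncased \
#       --task_name MRPC --do_eval --do_lower_case --int8 --benchmark \
#       --tuned_checkpoint ./mrpc_output --output_dir ./mrpc_output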