def test_nested3(self):
        r"""Quantize a nested model in which a child module's qconfig
        overrides the qconfig inherited from its parent.

        The explicit prepare / calibrate / convert flow is exercised first,
        then the one-line ``quantize`` API; both results go through the same
        checks.
        """
        def assert_prep_layout(net, before_calib=False):
            # Observers are only verified ahead of calibration.
            if before_calib:
                self.checkObservers(net)
            for unwrapped in (net, net.sub1, net.sub1.fc, net.sub1.relu,
                              net.sub2):
                self.checkNoPrepModules(unwrapped)
            for wrapped in (net.sub2.fc1, net.sub2.fc2, net.fc3):
                self.checkHasPrepModules(wrapped)

        def assert_quantized(net):
            assert_prep_layout(net)
            for linear in (net.sub2.fc1, net.sub2.fc2, net.fc3):
                self.checkWrappedQuantizedLinear(linear)
            test_only_eval_fn(net, self.calib_data)
            self.checkScriptable(net, self.calib_data)

        # Step-by-step flow.
        stepwise = AnnotatedCustomConfigNestedModel()
        prepare(stepwise)
        assert_prep_layout(stepwise, True)
        test_only_eval_fn(stepwise, self.calib_data)
        convert(stepwise)
        assert_quantized(stepwise)

        # One-line API.
        one_shot = quantize(AnnotatedCustomConfigNestedModel(),
                            test_only_eval_fn, self.calib_data)
        assert_quantized(one_shot)
Beispiel #2
0
    def test_nested1(self):
        r"""Quantize only the top-level 'fc3' and 'sub2.fc1' of a nested
        model, selected through a qconfig dict; 'sub2.fc2' must stay float.
        """
        qconfig_dict = {'fc3': default_qconfig, 'sub2.fc1': default_qconfig}

        def assert_prep_layout(net, before_calib=False):
            # Observers are only verified ahead of calibration.
            if before_calib:
                self.checkObservers(net)
            for untouched in (net, net.sub1, net.sub1.fc, net.sub1.relu,
                              net.sub2):
                self.checkNoPrepModules(untouched)
            self.checkHasPrepModules(net.sub2.fc1)
            self.checkNoPrepModules(net.sub2.fc2)
            self.checkHasPrepModules(net.fc3)

        def assert_quantized(net):
            assert_prep_layout(net)
            self.checkLinear(net.sub1.fc)
            for quantized in (net.fc3, net.sub2.fc1):
                self.checkQuantizedLinear(quantized)
            self.checkLinear(net.sub2.fc2)
            test_only_eval_fn(net, self.calib_data)

        # Step-by-step flow.
        stepwise = prepare(NestedModel().eval(), qconfig_dict)
        assert_prep_layout(stepwise, True)
        test_only_eval_fn(stepwise, self.calib_data)
        convert(stepwise)
        assert_quantized(stepwise)

        # One-line API.
        one_shot = quantize(NestedModel().eval(), test_only_eval_fn,
                            self.calib_data, qconfig_dict)
        assert_quantized(one_shot)
Beispiel #3
0
    def test_nested1(self):
        r"""Quantize the annotated nested model: top-level 'fc3' and
        'sub2.fc1' are quantized while 'sub2.fc2' is left as a float Linear.
        """
        def assert_prep_layout(net, before_calib=False):
            # Observers are only verified ahead of calibration.
            if before_calib:
                self.checkObservers(net)
            for untouched in (net, net.sub1, net.sub1.fc, net.sub1.relu,
                              net.sub2):
                self.checkNoPrepModules(untouched)
            self.checkHasPrepModules(net.sub2.fc1)
            self.checkNoPrepModules(net.sub2.fc2)
            self.checkHasPrepModules(net.fc3)

        def assert_quantized(net):
            assert_prep_layout(net)
            self.checkLinear(net.sub1.fc)
            for wrapped in (net.fc3, net.sub2.fc1):
                self.checkWrappedQuantizedLinear(wrapped)
            self.checkLinear(net.sub2.fc2)
            test_only_eval_fn(net, self.calib_data)
            self.checkScriptable(net, self.calib_data)

        # Step-by-step flow.
        stepwise = AnnotatedNestedModel()
        prepare(stepwise)
        assert_prep_layout(stepwise, True)
        test_only_eval_fn(stepwise, self.calib_data)
        convert(stepwise)
        assert_quantized(stepwise)

        # One-line API.
        one_shot = quantize(AnnotatedNestedModel(), test_only_eval_fn,
                            self.calib_data)
        assert_quantized(one_shot)
Beispiel #4
0
    def test_manual(self):
        r"""Quantize a model whose code already contains user-placed
        QuantStub / DeQuantStub modules, via the quantization utilities.
        """
        def assert_quantized(net):
            self.assertEqual(type(net.fc), nnq.Linear)
            test_only_eval_fn(net, self.calib_data)
            self.checkScriptable(net, self.calib_data)

        stepwise = QuantStubModel()
        # prepare() modifies the model in place, pushing each parent's
        # qconfig down to its children.
        prepare(stepwise)
        self.checkObservers(stepwise)

        test_only_eval_fn(stepwise, self.calib_data)
        convert(stepwise)
        assert_quantized(stepwise)

        # One-line API.
        one_shot = quantize(QuantStubModel(), test_only_eval_fn,
                            self.calib_data)
        assert_quantized(one_shot)
Beispiel #5
0
    def test_nested2(self):
        r"""Quantize a nested model where the whole 'sub2' submodule and the
        top-level 'fc3' are wrapped for quantization, while 'sub1' stays float.
        """
        def assert_prep_layout(net, before_calib=False):
            # Observers are only verified ahead of calibration.
            if before_calib:
                self.checkObservers(net)
            for untouched in (net, net.sub1, net.sub1.fc, net.sub1.relu):
                self.checkNoPrepModules(untouched)
            self.checkHasPrepModules(net.sub2)
            for inner in (net.sub2.module.fc1, net.sub2.module.fc2):
                self.checkNoPrepModules(inner)
            self.checkHasPrepModules(net.fc3)

        def assert_quantized(net):
            assert_prep_layout(net)
            self.checkLinear(net.sub1.fc)
            self.assertEqual(type(net.sub1.relu), torch.nn.ReLU)
            for inner in (net.sub2.module.fc1, net.sub2.module.fc2):
                self.checkQuantizedLinear(inner)
            self.checkWrappedQuantizedLinear(net.fc3)
            test_only_eval_fn(net, self.calib_data)
            self.checkScriptable(net, self.calib_data)

        # Step-by-step flow.
        stepwise = AnnotatedSubNestedModel()
        prepare(stepwise)
        assert_prep_layout(stepwise, True)
        test_only_eval_fn(stepwise, self.calib_data)
        convert(stepwise)
        assert_quantized(stepwise)

        # One-line API.
        one_shot = quantize(AnnotatedSubNestedModel(), test_only_eval_fn,
                            self.calib_data)
        assert_quantized(one_shot)
Beispiel #6
0
    def test_skip_quant(self):
        r"""Check that selected layers can be left out of quantization:
        'fc' stays a float Linear while the 'sub' branch is quantized.
        """
        def assert_quantized(net):
            self.checkLinear(net.fc)
            self.checkQuantDequant(net.sub)
            inner = net.sub.module
            self.checkQuantizedLinear(inner.fc1)
            self.checkQuantizedLinear(inner.fc2)
            self.assertEqual(type(net.sub.module.relu), nnq.ReLU)
            self.checkScriptable(net, self.calib_data)

        # Step-by-step flow.
        stepwise = SkipQuantModel()
        prepare(stepwise)
        self.checkObservers(stepwise)
        test_only_eval_fn(stepwise, self.calib_data)
        convert(stepwise)
        assert_quantized(stepwise)

        # One-line API.
        one_shot = quantize(SkipQuantModel(), test_only_eval_fn,
                            self.calib_data)
        assert_quantized(one_shot)
    def test_fuse_module_eval(self):
        r"""Fuse Conv+BN(+ReLU) groups on an eval-mode model, check the
        resulting module types, then quantize and check the converted types.
        """
        groups_to_fuse = [['conv1', 'bn1', 'relu1'], ['sub1.conv', 'sub1.bn']]

        def expect_type(module, expected, *extra):
            # Extra positional arguments (the failure text) are forwarded
            # untouched to assertEqual.
            self.assertEqual(type(module), expected, *extra)

        def assert_quantized(net):
            expect_type(net.conv1, nniq.ConvReLU2d)
            expect_type(net.bn1, nn.Identity)
            expect_type(net.relu1, nn.Identity)
            expect_type(net.sub1.conv, nnq.Conv2d)
            expect_type(net.sub1.bn, nn.Identity)
            expect_type(net.sub2.conv, nn.Conv2d)
            expect_type(net.sub2.relu, nn.ReLU)
            test_only_eval_fn(net, self.img_data)

        model = ModelForFusion(default_qconfig)
        model.eval()
        fuse_modules(model, groups_to_fuse)

        # Top-level group: conv1 becomes the fused module, bn1/relu1 become
        # Identity placeholders.
        expect_type(model.conv1, nni.ConvReLU2d,
                    "Fused Conv + BN + Relu first layer (BN is folded)")
        expect_type(model.conv1[0], nn.Conv2d,
                    "Fused Conv + BN + Relu (Conv + folded BN only)")
        expect_type(model.conv1[1], nn.ReLU,
                    "Fused Conv + BN + Relu second layer (Relu only)")
        expect_type(model.bn1, nn.Identity,
                    "Fused Conv + BN + Relu second layer (Skipped BN)")
        expect_type(model.relu1, nn.Identity,
                    "Fused Conv + BN + Relu second layer (Skipped Relu)")

        # Submodules: sub1 is fused, sub2 was not listed and stays as is.
        expect_type(model.sub1.conv, nn.Conv2d,
                    "Fused submodule Conv + folded BN")
        expect_type(model.sub1.bn, nn.Identity,
                    "Fused submodule (skipped BN)")
        expect_type(model.sub2.conv, nn.Conv2d,
                    "Non-fused submodule Conv")
        expect_type(model.sub2.relu, torch.nn.ReLU,
                    "Non-fused submodule ReLU")

        # Step-by-step quantization of the fused model.
        model = prepare(model)
        self.checkObservers(model)
        test_only_eval_fn(model, self.img_data)
        model = convert(model)
        assert_quantized(model)

        # One-line API on a freshly fused model.
        fresh = ModelForFusion(default_qconfig).eval()
        fuse_modules(fresh, groups_to_fuse)
        fresh = quantize(fresh, test_only_eval_fn, self.img_data)
        assert_quantized(fresh)
Beispiel #8
0
    def quantize(self):
        """Convert the learner's model to a quantized copy and persist it.

        The learner's model is moved to the CPU, converted out of place
        (``inplace=False``), stored on ``self.quantizedmodel`` and saved to
        ``quantizedmodel.pth`` inside ``self.model_dir``. Afterwards the
        learner's model is moved to ``"cuda"`` unconditionally.

        Returns
        -------
        The converted model, i.e. ``self.quantizedmodel``.
        """
        print("Quantizing model..")

        target_file = self.model_dir / "quantizedmodel.pth"

        self.learn.model.cpu()
        quantized = convert(self.learn.model, inplace=False)
        self.quantizedmodel = quantized
        torch.save(self.quantizedmodel, target_file)

        # NOTE(review): the model goes to CUDA regardless of where it lived
        # before this call - confirm that this is intended.
        self.learn.model.to("cuda")
        return self.quantizedmodel
    def test_relu(self):
        """Check that QAT prepare/convert leaves a plain ``nn.ReLU`` alone.

        A module containing only a ReLU is prepared for QAT and converted;
        the ReLU must still be a float ``nn.ReLU`` afterwards.
        """
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(x)
                return x

        m = M().train()
        m.qconfig = default_qconfig
        m = prepare_qat(m)
        # make sure no activation_post_process is inserted for relu
        # NOTE(review): this inspects the root module `m`, not `m.relu` -
        # confirm whether the child should be checked as well.
        self.assertFalse(hasattr(m, "activation_post_process"))
        m = convert(m)
        # make sure ReLU module is not changed
        # (assertTrue(a, b) treats `b` as the failure message and passes for
        # any truthy `a`, so the types have to be compared with assertEqual.)
        self.assertEqual(type(m.relu), nn.ReLU)
    def test_manual(self):
        """Run quantization-aware training on the manual linear model and
        check that both Linear layers end up as quantized Linear modules.
        """
        def assert_quantized(net):
            for layer in (net.fc1, net.fc2):
                self.assertEqual(type(layer), nnq.Linear)
            test_only_eval_fn(net, self.calib_data)
            self.checkScriptable(net, self.calib_data)

        # Step-by-step flow: prepare_qat -> train -> convert.
        stepwise = prepare_qat(ManualLinearQATModel())
        self.checkObservers(stepwise)
        test_only_train_fn(stepwise, self.train_data)
        stepwise = convert(stepwise)
        assert_quantized(stepwise)

        # One-line API.
        one_shot = quantize_qat(ManualLinearQATModel(), test_only_train_fn,
                                self.train_data)
        assert_quantized(one_shot)
Beispiel #11
0
def load(checkpoint_dir, model):
    """Rebuild a quantized model from a saved tuning checkpoint.

    Args:
        checkpoint_dir (dir): Checkpoint folder. It must contain
                              'best_configure.yaml' and
                              'best_model_weights.pt'; ``~`` is expanded and
                              the path is made absolute.
        model (object): fp32 model to quantize. It is switched to eval mode
                        and deep-copied; the copy is what gets converted.

    Returns:
        (object): quantized model with the saved weights loaded.

    Raises:
        AssertionError: if either required file is missing.
    """
    checkpoint_root = os.path.abspath(os.path.expanduser(checkpoint_dir))
    tune_cfg_file = os.path.join(checkpoint_root, 'best_configure.yaml')
    weights_file = os.path.join(checkpoint_root, 'best_model_weights.pt')

    assert os.path.exists(tune_cfg_file), \
        "tune configure file %s didn't exist" % tune_cfg_file
    assert os.path.exists(weights_file), \
        "weight file %s didn't exist" % weights_file

    q_model = copy.deepcopy(model.eval())

    # SECURITY NOTE: yaml.UnsafeLoader can construct arbitrary Python objects;
    # only point this function at checkpoint directories you trust.
    with open(tune_cfg_file, 'r') as cfg_stream:
        tune_cfg = yaml.load(cfg_stream, Loader=yaml.UnsafeLoader)

    op_cfgs = _cfg_to_qconfig(tune_cfg)
    _propagate_qconfig(q_model, op_cfgs)

    # sanity check common API misusage
    has_qconfig = any(
        hasattr(m, 'qconfig') and m.qconfig for m in q_model.modules())
    if not has_qconfig:
        logger.warn(
            "None of the submodule got qconfig applied. Make sure you "
            "passed correct configuration through `qconfig_dict` or "
            "by assigning the `.qconfig` attribute directly on submodules")

    add_observer_(q_model)
    q_model = convert(q_model, inplace=True)
    q_model.load_state_dict(torch.load(weights_file))
    return q_model
    def test_resnet_base(self, qconfig):
        r"""Quantize the bottleneck topology used in resnet/resnext, which
        also covers conversion of average pooling and the float functional.
        """
        def assert_quantized(net):
            wrapped = net.module
            self.assertEqual(type(wrapped.conv1),
                             nn._intrinsic.quantized.ConvReLU2d)
            self.assertEqual(type(net.module.myop), nn.quantized.QFunctional)
            self.assertEqual(type(net.module.avgpool), nn.AdaptiveAvgPool2d)
            test_only_eval_fn(net, self.img_data)

        model = QuantWrapper(ResNetBase().float().eval())
        model.qconfig = qconfig

        # Fuse conv1 + bn1 + relu1 of the wrapped module in place.
        fuse_modules(model,
                     ['module.conv1', 'module.bn1', 'module.relu1'],
                     inplace=True)

        model = prepare(model)
        self.checkObservers(model)
        test_only_eval_fn(model, self.img_data)
        model = convert(model)

        assert_quantized(model)
    def _test_activation_convert_numerics_impl(self, Act, data):
        """Check that an activation gives the same output before and after
        ``convert`` when wrapped between a QuantStub and a DeQuantStub.

        ``Act`` is called with no arguments to build the activation module;
        ``data`` is the input fed to the model on both sides of conversion.
        """
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.act = Act()
                self.quant = QuantStub()
                self.dequant = DeQuantStub()

            def forward(self, x):
                return self.dequant(self.act(self.quant(x)))

        qat_model = M().train()
        qat_model.qconfig = default_qat_qconfig
        qat_model = prepare_qat(qat_model)
        expected = qat_model(data)

        converted = convert(qat_model)
        actual = converted(data)

        self.assertEqual(expected, actual)
    def test_manual(self):
        """For every supported quantized engine, run QAT on the manual
        linear model and check the converted result.
        """
        def assert_quantized(net):
            for layer in (net.fc1, net.fc2):
                self.assertEqual(type(layer), nnq.Linear)
            test_only_eval_fn(net, self.calib_data)
            self.checkScriptable(net, self.calib_data)
            self.checkNoQconfig(net)

        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                # Step-by-step flow: prepare_qat -> train -> convert.
                stepwise = prepare_qat(ManualLinearQATModel(qengine))
                self.checkObservers(stepwise)
                test_only_train_fn(stepwise, self.train_data)
                stepwise = convert(stepwise)
                assert_quantized(stepwise)

                # One-line API; the training data is passed wrapped in a list.
                one_shot = quantize_qat(ManualLinearQATModel(qengine),
                                        test_only_train_fn,
                                        [self.train_data])
                assert_quantized(one_shot)
    def test_conv_linear(self):
        """Run QAT on a conv + linear model and check that the conv and both
        linear layers are swapped for their quantized counterparts.
        """
        def assert_quantized(net):
            self.assertEqual(type(net.conv), nnq.Conv2d)
            for layer in (net.fc1, net.fc2):
                self.assertEqual(type(layer), nnq.Linear)
            test_only_eval_fn(net, self.img_data)
            self.checkScriptable(net, self.img_data)

        # Step-by-step flow: prepare_qat -> train -> convert.
        stepwise = prepare_qat(ManualConvLinearQATModel())
        self.checkObservers(stepwise)
        test_only_train_fn(stepwise, self.img_data)
        stepwise = convert(stepwise)
        assert_quantized(stepwise)

        # One-line API.
        one_shot = quantize_qat(ManualConvLinearQATModel(),
                                test_only_train_fn, self.img_data)
        assert_quantized(one_shot)
Beispiel #16
0
    def test_single_layer(self, qconfig):
        r"""Quantize the annotated single-layer linear model and make sure
        its one Linear module is swapped for the wrapped quantized version.

        Covers the step-by-step flow plus the out-of-place and in-place
        variants of the one-line ``quantize`` API.
        """
        def fresh_model():
            net = AnnotatedSingleLayerLinearModel()
            net.qconfig = qconfig
            return net

        def assert_quantized(net):
            self.checkNoPrepModules(net)
            self.checkHasPrepModules(net.fc1)
            self.checkWrappedQuantizedLinear(net.fc1)
            test_only_eval_fn(net, self.calib_data)
            self.checkScriptable(net, self.calib_data)

        # Step-by-step flow.
        stepwise = prepare(fresh_model())
        # Observers and quant/dequant wrappers must be in place after prepare.
        self.checkNoPrepModules(stepwise)
        self.checkHasPrepModules(stepwise.fc1)
        self.checkObservers(stepwise)

        test_only_eval_fn(stepwise, self.calib_data)
        stepwise = convert(stepwise)
        assert_quantized(stepwise)

        # One-line API, out of place: the source model must keep its
        # state_dict keys.
        base = fresh_model()
        keys_before = set(base.state_dict().keys())
        out_of_place = quantize(base, test_only_eval_fn, self.calib_data)
        assert_quantized(out_of_place)
        keys_after = set(base.state_dict().keys())
        self.assertEqual(keys_before, keys_after)

        # One-line API, in place.
        in_place = fresh_model()
        quantize(in_place, test_only_eval_fn, self.calib_data, inplace=True)
        assert_quantized(in_place)
    def test_compare_model_stub_functional_static(self):
        r"""Compare the output of static quantized functional layer and its float shadow module
        """
        model = ModelWithFunctionals().eval()
        model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
        q_model = prepare(model, inplace=False)
        # One forward pass through the prepared model before converting.
        q_model(self.img_data_2d[0][0])
        q_model = convert(q_model)

        module_swap_list = [nnq.FloatFunctional]
        ob_dict = compare_model_stub(model, q_model, module_swap_list,
                                     self.img_data_2d[0][0])
        self.assertEqual(len(ob_dict), 6)

        # Each functional attribute of the quantized model is expected to be
        # a Shadow instance after compare_model_stub.
        shadowed_names = ("mycat", "myadd", "mymul", "myadd_relu",
                          "my_scalar_add", "my_scalar_mul")
        for attr_name in shadowed_names:
            self.assertIsInstance(
                getattr(q_model, attr_name), Shadow,
                "q_model.%s should be a Shadow module" % attr_name)

        # Float and quantized records must have matching shapes.
        for key, record in ob_dict.items():
            self.assertTrue(
                record["float"].shape == record["quantized"].shape,
                "float/quantized shape mismatch for %s" % key)
def train_subject_specific_quant(subject,
                                 epochs=500,
                                 batch_size=32,
                                 lr=0.001,
                                 silent=False,
                                 plot=True,
                                 **kwargs):
    """
    Trains a quantization-aware, subject specific model and converts it into
    a quantized model.

    Parameters:
     - subject:    Integer in the Range 1 <= subject <= 9
     - epochs:     Number of epochs to train
     - batch_size: Batch Size, used for both the training and the test loader
     - lr:         Learning Rate
     - silent:     bool, if True, disable the progress bar and the final
                   accuracy line
     - plot:       bool, forwarded to the training routine
     - kwargs:     Remaining arguments passed to the EEGNetQuant model

    Returns: (model, metrics, history)
     - model:   the trained model after conversion (on CPU)
     - metrics: metrics of the converted model on the test loader;
                metrics[0, 0] is reported as the accuracy
     - history: training history as returned by the training routine
    """
    # data for this subject
    train_x, train_y = get_data(subject, training=True)
    test_x, test_y = get_data(subject, training=False)
    train_loader = as_data_loader(train_x, train_y, batch_size=batch_size)
    test_loader = as_data_loader(test_x, test_y, batch_size=batch_size)

    # quantization configuration: min/max observers for activations (quint8)
    # and weights (qint8)
    activation_observer = tq.MinMaxObserver.with_args(dtype=t.quint8)
    weight_observer = tq.MinMaxObserver.with_args(dtype=t.qint8)
    qconfig = tq.QConfig(activation=activation_observer,
                         weight=weight_observer)

    # model setup
    model = EEGNetQuant(T=train_x.shape[2], qconfig=qconfig, **kwargs)
    model.initialize_params()
    if t.cuda.is_available():
        model = model.cuda()

    # insert the quantization-aware-training modules in place
    tq.prepare_qat(model, inplace=True)

    # loss, optimizer; no learning rate scheduler is used
    loss_function = t.nn.CrossEntropyLoss()
    optimizer = t.optim.Adam(model.parameters(), lr=lr, eps=1e-7)
    scheduler = None

    print_summary(model, optimizer, loss_function, scheduler)

    progress_options = dict(desc=f"Subject {subject}",
                            total=epochs,
                            leave=False,
                            disable=silent,
                            unit='epoch',
                            ascii=True)
    with tqdm(**progress_options) as pbar:
        # Early stopping is not allowed in this mode, because the testing
        # data cannot be used for training!
        model, metrics, _, history = _train_net(subject,
                                                model,
                                                train_loader,
                                                test_loader,
                                                loss_function,
                                                optimizer,
                                                scheduler=scheduler,
                                                epochs=epochs,
                                                early_stopping=False,
                                                plot=plot,
                                                pbar=pbar)

    # conversion happens on the CPU, in place
    model = model.cpu()
    tq.convert(model, inplace=True)

    # the reported metrics are those of the converted model
    metrics = get_metrics_from_model(model, test_loader)

    if not silent:
        print(f"Subject {subject}: accuracy = {metrics[0, 0]}")
    return model, metrics, history
Beispiel #19
0
 def _qat_swap_modules(
         self, root: torch.nn.Module,
         additional_qat_module_mapping: Dict[Callable, Callable]) -> None:
     """Swap modules of ``root`` in place using the default QAT module
     mappings combined with ``additional_qat_module_mapping``.

     The qconfig attributes are kept (``remove_qconfig=False``).
     """
     default_mapping = get_default_qat_module_mappings()
     qat_mapping = get_combined_dict(
         default_mapping, additional_qat_module_mapping)
     convert(root, mapping=qat_mapping, inplace=True, remove_qconfig=False)
Beispiel #20
0
    def _test_model_impl(self,
                         mode,
                         name,
                         model,
                         eager_quantizable_model,
                         check_with_eager=True,
                         diff_of_quant=None,
                         diff_from_eager=None):
        """Quantize ``model`` through the FX graph mode flow and, when an
        eager-quantizable model is supplied, compare against eager mode.

        Args:
            mode: 'static' selects ``default_qconfig``; every other value
                selects ``default_qat_qconfig``. 'ddp' and 'qat' have
                dedicated calibration/training branches; anything else is
                calibrated with ten forward passes.
            name: model name; 'inception_v3' gets a 1x3x299x299 input, all
                other names a 1x3x224x224 input. Also used as result key.
            model: model that is symbolically traced and quantized.
            eager_quantizable_model: optional model quantized with the eager
                API for comparison; it must provide ``fuse_model()``. Pass
                None to skip the comparison.
            check_with_eager: when True, assert that graph mode and eager
                mode outputs have zero maximum absolute difference.
            diff_of_quant: optional dict; ``[mode][name]`` receives the max
                absolute difference between float and quantized graph output.
            diff_from_eager: optional dict; ``[mode][name]`` receives the max
                absolute difference between eager and graph quantized output.

        Note: if either dict argument is None, BOTH are replaced by fresh
        local dicts, so results are then not visible to the caller. The same
        applies per mode: if ``mode`` is missing from either dict, the
        ``[mode]`` entry of both is reset.
        """
        if diff_of_quant is None or diff_from_eager is None:
            diff_of_quant = {}
            diff_from_eager = {}

        if mode not in diff_of_quant or mode not in diff_from_eager:
            diff_of_quant[mode] = {}
            diff_from_eager[mode] = {}

        input_tensor = torch.rand(1, 3, 224, 224)
        input_tensor_inception = torch.rand(1, 3, 299, 299)
        # randint's upper bound is exclusive, so the target label is always 0.
        output_value = torch.randint(0, 1, (1, ))

        # print('quantizing:', name, ' mode:', mode)
        if name == 'inception_v3':
            input_value = input_tensor_inception
        else:
            input_value = input_tensor

        qconfig = default_qconfig if mode == 'static' else default_qat_qconfig
        qconfig_dict = {'': qconfig}
        graph_module = symbolic_trace(model)
        # print('graph module:', graph_module.src)
        script = torch.jit.script(graph_module)

        # make sure graph module and script module are both runanble
        original_out = graph_module(input_value)
        is_not_tuple_out = not isinstance(original_out, tuple)
        script_out = script(input_value)
        self.assertEqual(
            (original_out - script_out).abs().max(), 0,
            'Reslut of original graph module and script module does not match')

        # set to train just before quantization
        # NOTE(review): train() is called on `model`, while the object that
        # gets fused/prepared below is `graph_module` traced earlier -
        # confirm this propagates the training flag as intended.
        if mode != 'static':
            model.train()

        graph_module = fuse_fx(graph_module)
        prepared = prepare_fx(graph_module, qconfig_dict)

        # `run_ddp` and `world_size` are not defined in this method; they are
        # expected to come from the enclosing module.
        if mode == 'ddp':
            mp.spawn(run_ddp,
                     args=(world_size, prepared),
                     nprocs=world_size,
                     join=True)
        elif mode == 'qat':
            assert prepared.training, 'prepared must be in training mode for qat'
            optimizer = torch.optim.SGD(prepared.parameters(), lr=0.0001)
            # `criterion` is reused by the eager 'qat' branch further down.
            criterion = nn.CrossEntropyLoss()
            train_one_epoch(prepared, criterion,
                            optimizer, [(input_value, output_value)],
                            torch.device('cpu'), 1)
        else:
            # calibration: ten forward passes on the same input
            for i in range(10):
                prepared(input_value)

        # print('after observation root:', prepared.root)

        qgraph = convert_fx(prepared)
        # print('after quantization root:', qgraph.root)
        # print('after quantization code:', qgraph.src)
        qgraph.eval()
        qgraph_script = torch.jit.script(qgraph)
        # print('quantized and scripted:', qgraph_script.graph)

        qgraph_out = qgraph(input_value)
        # From here on `qgraph_script` holds the scripted module's OUTPUT,
        # not the scripted module itself.
        qgraph_script = qgraph_script(input_value)

        if is_not_tuple_out:
            diff_of_quant[mode][name] = (original_out - qgraph_out).abs().max()
            assert torch.allclose(qgraph_out,
                                  qgraph_script), 'graph, scripted graph'
        else:
            print('tuple output')

        if eager_quantizable_model is not None:
            # comparing to eager mode quantization
            qeager = eager_quantizable_model
            ref_out = qeager(input_value)
            qeager.qconfig = qconfig
            if mode == 'static':
                qeager.fuse_model()
                prepare(qeager, inplace=True)
            else:
                qeager.train()
                qeager.fuse_model()
                prepare_qat(qeager, inplace=True)

            # calibration
            if mode == 'ddp':
                mp.spawn(run_ddp,
                         args=(world_size, qeager),
                         nprocs=world_size,
                         join=True)
            elif mode == 'qat':
                assert qeager.training, 'qeager should be in training mode for qat'
                optimizer = torch.optim.SGD(qeager.parameters(), lr=0.0001)
                train_one_epoch(qeager, criterion, optimizer,
                                [(input_value, output_value)],
                                torch.device('cpu'), 1)
            else:
                for i in range(10):
                    qeager(input_value)

            # print('ref after observation:', qeager)

            convert(qeager, inplace=True)
            qeager.eval()

            # print('ref after quantization:', qeager)
            qeager_out = qeager(input_value)
            # The eager quantized model is scripted and run; its output is
            # not compared against anything below.
            qeager_script = torch.jit.script(qeager)
            qscript_out = qeager_script(input_value)
            if is_not_tuple_out:
                diff_from_eager[mode][name] = (qeager_out -
                                               qgraph_out).abs().max()
                if check_with_eager:
                    self.assertEqual(
                        diff_from_eager[mode][name], 0,
                        'Result of graph mode quantization and ' +
                        'eager mode quantization on model: ' + name +
                        ' should match. Mode: ' + mode + ' diff:' +
                        str(diff_from_eager[mode][name]))
Beispiel #21
0
    def __init__(
        self,
        config,
        model,
        length_prediction_model,
        trg_vocab,
        beam_size,
        use_gold_length,
        beam_ranking_algorithm,
        quantize,
        embed_quantize,
    ):
        """Store the (optionally quantized) model, the length prediction
        model and the decoding settings.

        When ``quantize`` is true, the Linear layers of both models are
        dynamically quantized to int8 (out of place). ``embed_quantize``
        additionally selects embedding quantization on the main model: only
        8-bit is implemented; 4-bit and unknown values raise
        ``NotImplementedError``.
        """
        super().__init__()

        def quantize_linears(module):
            # Out-of-place dynamic int8 quantization limited to nn.Linear.
            linear_spec = {
                torch.nn.Linear:
                torch.quantization.per_channel_dynamic_qconfig
            }
            return torch.quantization.quantize_dynamic(
                module,
                linear_spec,
                dtype=torch.qint8,
                inplace=False,
            )

        length_prediction_model = length_prediction_model.create_eval_module()

        if not quantize:
            self.model = model
            self.length_prediction_model = length_prediction_model
        else:
            self.model = quantize_linears(model)

            # embedding quantization
            if embed_quantize != EmbedQuantizeType.NONE:
                if embed_quantize == EmbedQuantizeType.BIT_4:
                    raise NotImplementedError(
                        "4bit embedding quantization not yet supported")
                if embed_quantize != EmbedQuantizeType.BIT_8:
                    raise NotImplementedError(
                        "Embedding Quantization should be either 8bit or 4bit")

                # 8-bit: tag every nn.Embedding, then prepare/convert in place
                for submodule in self.model.modules():
                    if isinstance(submodule, torch.nn.Embedding):
                        submodule.qconfig = float_qparams_weight_only_qconfig
                prepare(self.model, inplace=True)
                convert(self.model, inplace=True)

            self.length_prediction_model = quantize_linears(
                length_prediction_model)

        self.trg_vocab = ScriptVocabulary(
            list(trg_vocab),
            pad_idx=trg_vocab.get_pad_index(),
            bos_idx=trg_vocab.get_bos_index(-1),
            eos_idx=trg_vocab.get_eos_index(-1),
            mask_idx=trg_vocab.get_mask_index(),
        )

        # decoding settings
        self.length_beam_size = beam_size
        self.use_gold_length = use_gold_length
        self.beam_ranking_algorithm = get_beam_ranking_function(
            ranking_algorithm=beam_ranking_algorithm)

        # target length settings copied from the config
        self.clip_target_length = config.clip_target_length
        self.targetlen_cap = config.targetlen_cap
        self.targetlen_a = config.targetlen_a
        self.targetlen_b = config.targetlen_b
        self.targetlen_c = config.targetlen_c
Beispiel #22
0
 def _qat_swap_modules(self, root):
     """Swap modules of ``root`` in place using the QAT module mappings,
     keeping the qconfig attributes (``remove_qconfig=False``).
     """
     qat_mapping = get_qat_module_mappings()
     convert(root, mapping=qat_mapping, inplace=True, remove_qconfig=False)
def convert(model, path):
    """Load weights from ``path`` into ``model`` and return the converted model.

    The state dict returned by ``load_model(path)`` is loaded into ``model``,
    the model is switched to eval mode, and the result of
    ``quantization.convert(model)`` is returned.
    """
    model.load_state_dict(load_model(path))
    model.eval()
    return quantization.convert(model)
Beispiel #24
0
 def _qat_swap_modules(self, root):
     """Run ``convert`` on ``root`` in place with ``DEFAULT_QAT_MODULE_MAPPING``.

     The call passes ``remove_qconfig=False``.
     """
     swap_kwargs = {
         'mapping': DEFAULT_QAT_MODULE_MAPPING,
         'inplace': True,
         'remove_qconfig': False,
     }
     convert(root, **swap_kwargs)
Beispiel #25
0
 def _qat_swap_modules(self, root, additional_qat_module_mapping):
     """Run ``convert`` on ``root`` in place with default plus extra QAT mappings.

     The mapping passed to ``convert`` is the result of
     ``get_combined_dict(get_default_qat_module_mappings(),
     additional_qat_module_mapping)``; the call passes
     ``remove_qconfig=False``.
     """
     default_mappings = get_default_qat_module_mappings()
     combined = get_combined_dict(default_mappings,
                                  additional_qat_module_mapping)
     convert(root, mapping=combined, inplace=True, remove_qconfig=False)
            print('6nd shape before reshape{}'.format(x.shape))
        x = self.pool(F.relu(self.conv6(x)))
        if DEBUG:
            print('6nd shape before reshape{}'.format(x.shape))

        x = x.view(-1, 875)
        if DEBUG:
            print('7nd_shape after reshape{}'.format(x.shape))
        x = F.relu(self.fc1(x))
        if DEBUG:
            print('8nd_shape {}'.format(x.shape))
        x = F.relu(self.fc2(x))
        if DEBUG:
            print('9nd_shape {}'.format(x.shape))
        x = F.relu(self.fc3(x))
        if DEBUG:
            print('10nd_shape {}'.format(x.shape))
        x = self.fc4(x)
        if DEBUG:
            print('11nd_shape {}'.format(x.shape))
        return x


# Script section: load a saved model, convert it, trace it with a dummy input
# and save the traced module to disk.

# Silence all warnings for the remainder of the script.
warnings.filterwarnings('ignore')
device = 'cuda'
# The second positional argument of torch.load is map_location, so the
# checkpoint's storages are mapped onto CUDA.
model = torch.load('last_cnn.pt', 'cuda')
# NOTE(review): `quantization` is presumably torch.quantization, and no
# prepare/calibration step is visible here — confirm the checkpoint was saved
# from an already-prepared model and that the converted model runs on CUDA.
modelq = quantization.convert(model)
# Random dummy input of shape (1, 3, 512, 512), used only to drive tracing.
example = torch.rand(1, 3, 512, 512)
ex = example.to(device)
traced_script_module = torch.jit.trace(modelq, ex)
traced_script_module.save('last_jit_model_moda.pt')
Beispiel #27
0
def main():
    """Command-line entry point for SQuAD fine-tuning, evaluation and int8 flows.

    Parses the command-line arguments, sets up the device, logging and seed,
    loads the config/tokenizer/model for ``--model_type``, optionally trains
    and saves the model, then, for every selected checkpoint, runs whichever
    of the following were requested: fp32/MKLDNN evaluation, LPOT tuning,
    benchmark/accuracy-only evaluation, calibration + conversion, or int8
    inference from a previously calibrated model.

    Returns:
        dict: evaluation metrics collected across checkpoints (metric names
        are suffixed with ``_<global_step>`` when several checkpoints are
        evaluated). ``None`` is returned from the int8 inference branch when
        the quantized model directory does not exist.

    Raises:
        ValueError: if the output directory is non-empty while training
            without ``--overwrite_output_dir``, or if the LPOT evaluation
            callback finds none of the expected accuracy metrics.
        ImportError: if ``--fp16`` is set and apex is not installed.
        SystemExit: the ``--tune`` and ``--benchmark``/``--accuracy_only``
            branches terminate the process with status 0 once finished.
    """
    # Local import: used for sys.exit below instead of the `exit` helper that
    # the `site` module installs for interactive sessions.
    import sys

    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--model_type",
                        default=None,
                        type=str,
                        required=True,
                        help="Model type selected in the list: " +
                        ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: "
        + ", ".join(ALL_MODELS))
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model checkpoints and predictions will be written."
    )

    # Other parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        help=
        "The input data dir. Should contain the .json files for the task. If not specified, will run with tensorflow_datasets."
    )
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")

    parser.add_argument(
        '--version_2_with_negative',
        action='store_true',
        help=
        'If true, the SQuAD examples contain some that do not have an answer.')
    parser.add_argument(
        '--null_score_diff_threshold',
        type=float,
        default=0.0,
        help=
        "If null_score - best_non_null is greater than the threshold predict null."
    )

    parser.add_argument(
        "--max_seq_length",
        default=384,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded."
    )
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help=
        "When splitting up a long document into chunks, how much stride to take between chunks."
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
        help=
        "The maximum number of tokens for the question. Questions longer than this will "
        "be truncated to this length.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--evaluate_during_training",
        action='store_true',
        help="Rul evaluation during training at each logging step.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")

    parser.add_argument("--per_gpu_train_batch_size",
                        default=8,
                        type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size",
                        default=8,
                        type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help=
        "If > 0: set total number of training steps to perform. Override num_train_epochs."
    )
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument(
        "--n_best_size",
        default=20,
        type=int,
        help=
        "The total number of n-best predictions to generate in the nbest_predictions.json output file."
    )
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
        help=
        "The maximum length of an answer that can be generated. This is needed because the start "
        "and end predictions are not conditioned on one another.")
    parser.add_argument(
        "--verbose_logging",
        action='store_true',
        help=
        "If true, all of the warnings related to data processing will be printed. "
        "A number of warnings are expected for a normal SQuAD evaluation.")

    parser.add_argument('--logging_steps',
                        type=int,
                        default=50,
                        help="Log every X updates steps.")
    parser.add_argument('--save_steps',
                        type=int,
                        default=50,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--eval_all_checkpoints",
        action='store_true',
        help=
        "Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number"
    )
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--overwrite_output_dir',
                        action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument(
        '--overwrite_cache',
        action='store_true',
        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument(
        '--fp16',
        action='store_true',
        help=
        "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit"
    )
    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument("--do_calibration",
                        action='store_true',
                        help="Whether to do calibration.")
    parser.add_argument("--do_int8_inference",
                        action='store_true',
                        help="Whether to run int8 inference.")
    parser.add_argument("--do_fp32_inference",
                        action='store_true',
                        help="Whether to run fp32 inference.")
    parser.add_argument("--mkldnn_eval",
                        action='store_true',
                        help="evaluation with MKLDNN")
    parser.add_argument(
        "--tune",
        action='store_true',
        help="run Low Precision Optimization Tool to tune int8 acc.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="SQuAD task")
    parser.add_argument("--warmup",
                        type=int,
                        default=5,
                        help="warmup for performance")
    parser.add_argument('-i',
                        "--iter",
                        default=0,
                        type=int,
                        help='For accuracy measurement only.')
    parser.add_argument('--benchmark',
                        dest='benchmark',
                        action='store_true',
                        help='run benchmark')
    parser.add_argument('-r',
                        "--accuracy_only",
                        dest='accuracy_only',
                        action='store_true',
                        help='For accuracy measurement only.')
    parser.add_argument(
        "--tuned_checkpoint",
        default='./',
        type=str,
        metavar='PATH',
        help=
        'path to checkpoint tuned by Low Precision Optimization Tool (default: ./)'
    )
    parser.add_argument('--int8',
                        dest='int8',
                        action='store_true',
                        help='run benchmark')

    args = parser.parse_args()

    # Prediction file name is built from the last non-empty path component of
    # the model name/path and the maximum sequence length.
    args.predict_file = os.path.join(
        args.output_dir, 'predictions_{}_{}.txt'.format(
            list(filter(None, args.model_name_or_path.split('/'))).pop(),
            str(args.max_seq_length)))

    # Refuse to train into a non-empty output directory unless explicitly
    # allowed to overwrite it.
    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir
    ) and args.do_train and not args.overwrite_output_dir:
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome."
            .format(args.output_dir))

    # mix_qkv is forwarded to from_pretrained and is enabled for every
    # quantization-related flow (calibration, int8 inference, tuning).
    mix_qkv = False
    if args.do_calibration or args.do_int8_inference or args.tune:
        mix_qkv = True

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank, device, args.n_gpu, bool(args.local_rank != -1),
        args.fp16)

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier(
        )  # Make sure only the first process in distributed training will download model & vocab

    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        cache_dir=args.cache_dir if args.cache_dir else None)
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name
        if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None)
    model = model_class.from_pretrained(
        args.model_name_or_path,
        from_tf=bool('.ckpt' in args.model_name_or_path),
        config=config,
        mix_qkv=mix_qkv,
        cache_dir=args.cache_dir if args.cache_dir else None)

    if args.local_rank == 0:
        torch.distributed.barrier(
        )  # Make sure only the first process in distributed training will download model & vocab

    model.to(args.device)

    logger.info("Training/evaluation parameters %s", args)

    # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
    # Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
    # remove the need for this code, but it is still valid.
    if args.fp16:
        try:
            import apex
            apex.amp.register_half_function(torch, 'einsum')
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )

    # Training
    if args.do_train:
        train_dataset = load_and_cache_examples(args,
                                                tokenizer,
                                                evaluate=False,
                                                output_examples=False)
        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step,
                    tr_loss)

    # Save the trained model and the tokenizer
    if args.do_train and (args.local_rank == -1
                          or torch.distributed.get_rank() == 0):
        # Create output directory if needed
        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = model.module if hasattr(
            model,
            'module') else model  # Take care of distributed/parallel training
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))

        # Load a trained model and vocabulary that you have fine-tuned
        model = model_class.from_pretrained(args.output_dir,
                                            force_download=True,
                                            mix_qkv=mix_qkv)
        tokenizer = tokenizer_class.from_pretrained(
            args.output_dir, do_lower_case=args.do_lower_case)
        model.to(args.device)

    # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            # Every directory under output_dir that contains a weights file.
            checkpoints = list(
                os.path.dirname(c) for c in sorted(
                    glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME,
                              recursive=True)))
            logging.getLogger("transformers.modeling_utils").setLevel(
                logging.WARN)  # Reduce model loading logs

        logger.info("Evaluate the following checkpoints: %s", checkpoints)

        for checkpoint in checkpoints:
            # Reload the model
            # The step suffix is only used when several checkpoints are
            # evaluated; it is taken from the text after the last '-'.
            global_step = checkpoint.split(
                '-')[-1] if len(checkpoints) > 1 else ""
            if args.mkldnn_eval or args.do_fp32_inference:
                model = model_class.from_pretrained(checkpoint,
                                                    force_download=True)
                model.to(args.device)

                # Evaluate
                result, _ = evaluate(args,
                                     model,
                                     tokenizer,
                                     prefix=global_step)
                result = dict(
                    (k + ('_{}'.format(global_step) if global_step else ''), v)
                    for k, v in result.items())
                results.update(result)

            if args.tune:

                def eval_func_for_lpot(model):
                    # Evaluation callback handed to the LPOT quantizer: logs
                    # all metrics and returns the first one found among the
                    # known accuracy keys, in priority order.
                    result, _ = evaluate(args, model, tokenizer)
                    for key in sorted(result.keys()):
                        logger.info("  %s = %s", key, str(result[key]))
                    bert_task_acc_keys = [
                        'best_f1', 'f1', 'mcc', 'spearmanr', 'acc'
                    ]
                    for key in bert_task_acc_keys:
                        if key in result.keys():
                            logger.info("Finally Eval {}:{}".format(
                                key, result[key]))
                            return result[key]
                    # Fail with an explicit message rather than an
                    # UnboundLocalError when no known metric is present.
                    raise ValueError(
                        "Evaluation result contains none of the expected "
                        "accuracy metrics {}; got keys {}".format(
                            bert_task_acc_keys, sorted(result.keys())))

                model = model_class.from_pretrained(checkpoint,
                                                    force_download=True,
                                                    mix_qkv=True)
                model.to(args.device)
                dataset = load_and_cache_examples(args,
                                                  tokenizer,
                                                  evaluate=True,
                                                  output_examples=False)
                args.eval_batch_size = args.per_gpu_eval_batch_size * max(
                    1, args.n_gpu)
                eval_task = "squad"
                from lpot import Quantization
                quantizer = Quantization("./conf.yaml")
                dataset = quantizer.dataset('bert',
                                            dataset=dataset,
                                            task=eval_task,
                                            model_type=args.model_type)
                test_dataloader = quantizer.dataloader(
                    dataset, batch_size=args.eval_batch_size)
                quantizer(model, test_dataloader, eval_func=eval_func_for_lpot)
                # Tuning is terminal: the process ends after the first
                # checkpoint has been tuned.
                sys.exit(0)

            if args.benchmark or args.accuracy_only:
                model = model_class.from_pretrained(checkpoint, mix_qkv=True)
                model.to(args.device)
                if args.int8:
                    # Restore the tuned int8 model on top of the fp32 model.
                    from lpot.utils.pytorch import load
                    new_model = load(
                        os.path.abspath(
                            os.path.expanduser(args.tuned_checkpoint)), model)
                else:
                    new_model = model
                result, _ = evaluate(args,
                                     new_model,
                                     tokenizer,
                                     prefix=global_step)
                # Benchmark/accuracy-only runs are terminal as well.
                sys.exit(0)

            if args.do_calibration:
                model = model_class.from_pretrained(checkpoint,
                                                    force_download=True,
                                                    mix_qkv=True)
                model.to(args.device)
                # Attach the qconfig and observers, run one evaluation pass in
                # calibration mode, then convert the model in place.
                model.qconfig = default_per_channel_qconfig
                propagate_qconfig_(model)
                add_observer_(model)
                # Evaluate
                evaluate(args,
                         model,
                         tokenizer,
                         prefix=global_step,
                         calibration=True)
                convert(model, inplace=True)
                quantized_model_path = "squad" + str(
                    global_step) + "_quantized_model"
                if not os.path.exists(quantized_model_path):
                    os.makedirs(quantized_model_path)
                model.save_pretrained(quantized_model_path)
                result, _ = evaluate(args,
                                     model,
                                     tokenizer,
                                     prefix=global_step)
                result = dict(
                    (k + ('_{}'.format(global_step) if global_step else ''), v)
                    for k, v in result.items())
                results.update(result)
            if args.do_int8_inference:
                model = model_class.from_pretrained(checkpoint,
                                                    force_download=True,
                                                    mix_qkv=True)
                model.to(args.device)
                # Rebuild the converted module structure first so that the
                # state dict saved by the calibration run can be loaded.
                model.qconfig = default_per_channel_qconfig
                propagate_qconfig_(model)
                add_observer_(model)
                convert(model, inplace=True)
                quantized_model_path = "squad" + str(
                    global_step) + "_quantized_model"
                if not os.path.exists(quantized_model_path):
                    logger.info("Please run calibration first!")
                    return
                model_bin_file = os.path.join(quantized_model_path,
                                              "pytorch_model.bin")
                state_dict = torch.load(model_bin_file)
                model.load_state_dict(state_dict)
                print(model)
                # Profile the int8 evaluation and print a per-operator summary.
                with torch.autograd.profiler.profile() as prof:
                    result, _ = evaluate(args,
                                         model,
                                         tokenizer,
                                         prefix=global_step)
                print(prof.key_averages().table(sort_by="cpu_time_total"))
                result = dict(
                    (k + ('_{}'.format(global_step) if global_step else ''), v)
                    for k, v in result.items())
                results.update(result)
    logger.info("Results: {}".format(results))

    return results
Beispiel #28
0
 def _qat_swap_modules(self, root, additional_qat_module_mapping):
     """Run ``convert`` on ``root`` in place with default plus extra QAT mappings.

     A copy of ``get_default_qat_module_mappings()`` is overlaid with the
     entries of ``additional_qat_module_mapping`` (entries from the latter
     win on key collisions) and passed to ``convert`` with
     ``remove_qconfig=False``.
     """
     merged = get_default_qat_module_mappings().copy()
     merged.update(additional_qat_module_mapping.items())
     convert(root, mapping=merged, inplace=True, remove_qconfig=False)
Beispiel #29
0
def convert_dynamic(module):
    """Run ``convert`` on ``module`` in place with ``DEFAULT_DYNAMIC_MODULE_MAPPING``."""
    dynamic_mapping = DEFAULT_DYNAMIC_MODULE_MAPPING
    convert(module, dynamic_mapping, inplace=True)
Beispiel #30
0
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--model_type",
                        default=None,
                        type=str,
                        required=True,
                        help="Model type selected in the list: " +
                        ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list;"
        + ", ".join(ALL_MODELS))
    parser.add_argument(
        "--task_name",
        default=None,
        type=str,
        required=True,
        help="The name of the task to train selected in the list: " +
        ", ".join(processors.keys()))
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--evaluate_during_training",
        action='store_true',
        help="Rul evaluation during training at each logging step.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")

    parser.add_argument("--per_gpu_train_batch_size",
                        default=8,
                        type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size",
                        default=8,
                        type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help=
        "If > 0: set total number of training steps to perform. Override num_train_epochs."
    )
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_steps.")

    parser.add_argument('--logging_steps',
                        type=int,
                        default=50,
                        help="Log every X updates steps.")
    parser.add_argument('--save_steps',
                        type=int,
                        default=50,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--eval_all_checkpoints",
        action='store_true',
        help=
        "Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number"
    )
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument("--mkldnn_eval",
                        action='store_true',
                        help="evaluation with MKLDNN")
    parser.add_argument("--mkldnn_train",
                        action='store_true',
                        help="training with MKLDNN")
    parser.add_argument('--overwrite_output_dir',
                        action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument(
        '--overwrite_cache',
        action='store_true',
        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    parser.add_argument(
        '--fp16',
        action='store_true',
        help=
        "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit"
    )
    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="For distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="For distant debugging.")
    parser.add_argument("--do_fp32_inference",
                        action='store_true',
                        help="Whether to run fp32 inference.")
    parser.add_argument("--do_calibration",
                        action='store_true',
                        help="Whether to do calibration.")
    parser.add_argument("--do_int8_inference",
                        action='store_true',
                        help="Whether to run int8 inference.")
    parser.add_argument("--do_bf16",
                        action='store_true',
                        help="run bf16 evaluation / training.")
    parser.add_argument("--tune",
                        action='store_true',
                        help="run ilit to tune int8 acc.")
    parser.add_argument("--warmup",
                        type=int,
                        default=2,
                        help="warmup for performance")

    args = parser.parse_args()

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir
    ) and args.do_train and not args.overwrite_output_dir:
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome."
            .format(args.output_dir))

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank, device, args.n_gpu, bool(args.local_rank != -1),
        args.fp16)

    # Set seed
    set_seed(args)

    # Prepare GLUE task
    args.task_name = args.task_name.lower()
    if args.task_name not in processors:
        raise ValueError("Task not found: %s" % (args.task_name))
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)
    mix_qkv = False
    if args.do_calibration or args.do_int8_inference or args.tune:
        mix_qkv = True
    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier(
        )  # Make sure only the first process in distributed training will download model & vocab

    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=args.task_name,
        cache_dir=args.cache_dir if args.cache_dir else None)
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name
        if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None)
    model = model_class.from_pretrained(
        args.model_name_or_path,
        from_tf=bool('.ckpt' in args.model_name_or_path),
        config=config,
        mix_qkv=mix_qkv,
        bf16=args.do_bf16,
        mkldnn_train=args.mkldnn_train,
        cache_dir=args.cache_dir if args.cache_dir else None)

    if args.local_rank == 0:
        torch.distributed.barrier(
        )  # Make sure only the first process in distributed training will download model & vocab

    model.to(args.device)

    logger.info("Training/evaluation parameters %s", args)

    # Training
    if args.do_train:
        train_dataset = load_and_cache_examples(args,
                                                args.task_name,
                                                tokenizer,
                                                evaluate=False)
        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step,
                    tr_loss)

    # Saving best practices: if you use the default names for the model, you can reload it using from_pretrained()
    if args.do_train and (args.local_rank == -1
                          or torch.distributed.get_rank() == 0):
        # Create output directory if needed
        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = model.module if hasattr(
            model,
            'module') else model  # Take care of distributed/parallel training
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))

        # Load a trained model and vocabulary that you have fine-tuned
        model = model_class.from_pretrained(args.output_dir)
        tokenizer = tokenizer_class.from_pretrained(args.output_dir)
        model.to(args.device)

    # Evaluation
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        tokenizer = tokenizer_class.from_pretrained(
            args.output_dir, do_lower_case=args.do_lower_case)
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(
                os.path.dirname(c) for c in sorted(
                    glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME,
                              recursive=True)))
            logging.getLogger("transformers.modeling_utils").setLevel(
                logging.WARN)  # Reduce logging
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split(
                '-')[-1] if len(checkpoints) > 1 else ""
            prefix = checkpoint.split(
                '/')[-1] if checkpoint.find('checkpoint') != -1 else ""

            logger.info("Evaluate:" + args.task_name)
            if args.mkldnn_eval or args.do_fp32_inference or args.do_bf16:
                model = model_class.from_pretrained(checkpoint)
                model.to(args.device)
                result = evaluate(args, model, tokenizer, prefix=prefix)
                result = dict((k + '_{}'.format(global_step), v)
                              for k, v in result.items())
                results.update(result)

            if args.tune:

                def eval_func_for_ilit(model):
                    """Evaluate ``model`` and return one scalar score for the tuner.

                    Runs ``evaluate`` with the enclosing ``args``, ``tokenizer``
                    and ``prefix`` and picks the first metric present in the
                    result, in the priority order listed below.

                    Args:
                        model: the candidate model handed in by the tuner.

                    Returns:
                        The value stored under the first matching metric key.

                    Raises:
                        ValueError: if the result contains none of the known
                            metric keys (previously this surfaced as an
                            UnboundLocalError on ``acc``).
                    """
                    # evaluate() yields a pair here; only the metrics mapping
                    # (first element) is needed for tuning.
                    result, _ = evaluate(args, model, tokenizer, prefix=prefix)

                    # Priority order: the first key found in `result` wins.
                    bert_task_acc_keys = ('acc_and_f1', 'f1', 'mcc',
                                          'spearmanr', 'acc')
                    for key in bert_task_acc_keys:
                        if key in result:
                            logger.info("Finally Eval %s:%s", key, result[key])
                            return result[key]

                    raise ValueError(
                        "No accuracy metric found in evaluation result; "
                        "expected one of {} but got keys {}".format(
                            list(bert_task_acc_keys), sorted(result)))

                model = model_class.from_pretrained(checkpoint, mix_qkv=True)
                model.to(args.device)
                eval_task_names = (
                    "mnli", "mnli-mm") if args.task_name == "mnli" else (
                        args.task_name, )

                for eval_task in eval_task_names:
                    eval_dataset = load_and_cache_examples(args,
                                                           eval_task,
                                                           tokenizer,
                                                           evaluate=True)

                    args.eval_batch_size = args.per_gpu_eval_batch_size * max(
                        1, args.n_gpu)
                    # multi-gpu eval
                    if args.n_gpu > 1:
                        model = torch.nn.DataParallel(model)

                    if args.mkldnn_eval:
                        from torch.utils import mkldnn as mkldnn_utils
                        model = mkldnn_utils.to_mkldnn(model)
                        print(model)
                    import ilit
                    tuner = ilit.Tuner("./conf.yaml")
                    if eval_task != "squad":
                        eval_task = 'classifier'
                    eval_dataset = tuner.dataset('bert',
                                                 dataset=eval_dataset,
                                                 task=eval_task)
                    test_dataloader = tuner.dataloader(
                        eval_dataset, batch_size=args.eval_batch_size)
                    tuner.tune(model,
                               test_dataloader,
                               eval_func=eval_func_for_ilit)
                exit(0)

            if args.do_calibration:
                model = model_class.from_pretrained(checkpoint, mix_qkv=True)
                model.to(args.device)
                model.qconfig = default_per_channel_qconfig
                fallback_layers = {}
                if args.model_name_or_path == "bert-base-uncased" and args.task_name == "mrpc":
                    fallback_layers = {"bert.encoder.layer.9.output.dense."}
                propagate_qconfig_(model)
                fallback_layer(model,
                               layer_name="",
                               exculde_layers=fallback_layers)
                add_observer_(model)
                result, _ = evaluate(args,
                                     model,
                                     tokenizer,
                                     prefix=global_step,
                                     calibration=True)
                convert(model, inplace=True)
                quantized_model_path = args.task_name + "_quantized_model"
                if not os.path.exists(quantized_model_path):
                    os.makedirs(quantized_model_path)
                model.save_pretrained(quantized_model_path)
                print(model)
                result, _ = evaluate(args, model, tokenizer, prefix=prefix)
            if args.do_int8_inference:
                model = model_class.from_pretrained(checkpoint, mix_qkv=True)
                model.to(args.device)
                model.qconfig = default_per_channel_qconfig
                fallback_layers = {}
                if args.model_name_or_path == "bert-base-uncased" and args.task_name == "mrpc":
                    fallback_layers = {"bert.encoder.layer.9.output.dense."}
                propagate_qconfig_(model)
                fallback_layer(model,
                               layer_name="",
                               exculde_layers=fallback_layers)
                add_observer_(model)
                convert(model, inplace=True)
                quantized_model_path = args.task_name + "_quantized_model"
                if not os.path.exists(quantized_model_path):
                    logger.error(
                        "please do calibrantion befor run int8 inference")
                    return
                prepare(model, inplace=True)
                convert(model, inplace=True)
                model_bin_file = os.path.join(quantized_model_path,
                                              "pytorch_model.bin")
                state_dict = torch.load(model_bin_file)
                model.load_state_dict(state_dict)
                result, _ = evaluate(args, model, tokenizer, prefix=prefix)

    return results