Example #1
File: base.py  Project: MHGL/deepvac
def saveSQ(self):
    # statically quantize the TorchScript model using the calibration loader,
    # then save the quantized module to disk
    loader = self.getCalibrateLoader()
    LOG.logI("Pytorch model static quantize starting, will save model in {}".format(self.sq_output_file))
    with torch.no_grad():
        quantized_model = quantize_jit(self._jit(), self.s_qconfig_dict, calibrate, [loader], inplace=False, debug=False)
        torch.jit.save(quantized_model, self.sq_output_file)
    LOG.logI("Pytorch model static quantize succeeded, saved model in {}".format(self.sq_output_file))
Example #2

    def checkGraphModeOp(self,
                         module,
                         data,
                         quantized_op,
                         tracing=False,
                         debug=False,
                         check=True,
                         eval_mode=True,
                         dynamic=False):
        if debug:
            print('Testing:', str(module))
        qconfig_dict = {
            '': get_default_qconfig(torch.backends.quantized.engine)
        }

        if eval_mode:
            module = module.eval()
        if dynamic:
            qconfig_dict = {'': default_dynamic_qconfig}
            inputs = data
        else:
            *inputs, target = data[0]
        model = get_script_module(module, tracing, inputs).eval()
        if debug:
            print('input graph:', model.graph)
        models = {}
        outputs = {}
        for d in [True, False]:
            # TODO: _test_only_eval_fn --> default_eval_fn
            if dynamic:
                models[d] = quantize_dynamic_jit(model, qconfig_dict, debug=d)
                # make sure it runs
                outputs[d] = models[d](inputs)
            else:
                # module under test can contain in-place ops, and we depend on
                # input data staying constant for comparisons
                data_copy = copy.deepcopy(data)
                models[d] = quantize_jit(model,
                                         qconfig_dict,
                                         test_only_eval_fn, [data_copy],
                                         inplace=False,
                                         debug=d)
                # make sure it runs
                outputs[d] = models[d](*inputs)

        if debug:
            print('debug graph:', models[True].graph)
            print('non debug graph:', models[False].graph)

        if check:
            # debug and non-debug option should have the same numerics
            self.assertEqual(outputs[True], outputs[False])

            # non debug graph should produce quantized op
            FileCheck().check(quantized_op) \
                       .run(models[False].graph)

        return models[False]
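
Hypothetical usage of checkGraphModeOp from a test method, following the conventions of the PyTorch test suite this helper comes from (the test name, module, and shapes are illustrative only, not from the original file):

    def test_quantized_conv2d(self):
        module = torch.nn.Conv2d(3, 3, 3).float()
        # a single (input, target) batch; the target is unpacked but unused here
        data = [(torch.rand(1, 3, 10, 10), torch.randint(0, 2, (1,)))]
        self.checkGraphModeOp(module, data, 'quantized::conv2d', tracing=True)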
Example #3
    def saveSQ(self, loader, input_sample=None):
        if loader is None:
            LOG.logE("You enabled config.static_quantize_dir, but didn't provide self.test_loader in forward-only mode, or self.val_loader in train mode.", exit=True)
        if self.ts is None:
            self.export(input_sample)

        LOG.logI("Pytorch model static quantize starting, will save model in {}".format(self.sq_output_file))
        quantized_model = quantize_jit(self.ts, self.s_qconfig_dict, calibrate, [loader], inplace=False, debug=False)
        torch.jit.save(quantized_model, self.sq_output_file)
        LOG.logI("Pytorch model static quantize succeeded, saved model in {}".format(self.sq_output_file))
Example #4

# 3. ``convert_jit`` takes a calibrated model and produces a quantized model.
#
# If ``debug`` is False (default option), ``convert_jit`` will:
#
# - Calculate quantization parameters using the observers in the model
# - Insert quantization ops like ``aten::quantize_per_tensor`` and ``aten::dequantize`` into the model, and remove the observer modules afterwards.
# - Replace floating point ops with quantized ops
# - Freeze the model (remove constant attributes and turn them into Constant nodes in the graph).
# - Fold the quantize and prepack ops like ``quantized::conv2d_prepack`` into an attribute, so we don't need to quantize and prepack the weight every time we run the model.
#
# If ``debug`` is set to ``True``:
#
# - We can still access the attributes of the quantized model the same way as the original floating point model, e.g. ``model.conv1.weight`` (might be harder if you use a module list or sequential)
# - The arithmetic operations all occur in floating point with the numerics being identical to the final quantized model, allowing for debugging.
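#
# The one-call ``quantize_jit`` below performs all three steps (prepare,
# calibrate, convert) internally. A minimal sketch of the equivalent
# decomposed flow, reusing the same ``ts_model``, ``qconfig``, ``calibrate``
# and ``data_loader_test`` from this example:

from torch.quantization import prepare_jit, convert_jit

prepared_model = prepare_jit(ts_model, {'': qconfig})  # insert observers
calibrate(prepared_model, data_loader_test)            # collect activation statistics
converted_model = convert_jit(prepared_model)          # quantize, freeze, fold prepack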

quantized_model = quantize_jit(ts_model, {'': qconfig}, calibrate,
                               [data_loader_test])

print(quantized_model.graph)

######################################################################
# As we can see, ``aten::conv2d`` is changed to ``quantized::conv2d`` and the floating point weight has been quantized
# and packed into an attribute (``quantized._jit_pass_packed_weight_30``), so we don't need to quantize/pack at runtime.
# Note also that we can no longer access the weight attributes, since the model is frozen (they remain accessible in the ``debug=True`` variant).
#
# 7. Evaluation
# --------------
# We can now print the size and accuracy of the quantized model.

print('Size of model before quantization')
print_size_of_model(ts_model)
print('Size of model after quantization')
print_size_of_model(quantized_model)
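
######################################################################
# The accuracy evaluation is not part of this excerpt; a minimal top-1
# accuracy loop (assuming ``data_loader_test`` yields ``(image, target)``
# batches, as elsewhere in this tutorial) might look like this:

def evaluate_top1(model, data_loader):
    correct = total = 0
    with torch.no_grad():
        for image, target in data_loader:
            pred = model(image).argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.numel()
    return correct / total

print('Top-1 accuracy of quantized model:', evaluate_top1(quantized_model, data_loader_test))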
Example #5
for data in train_loader:
    bg, image, seg, multi_fr = data['bg'], data['image'], data['seg'], data['multi_fr']
    # Variable is a legacy no-op wrapper in modern PyTorch; kept as in the source
    bg, image, seg, multi_fr = Variable(bg.cuda()), Variable(image.cuda()), Variable(seg.cuda()), Variable(multi_fr.cuda())
    if args.trace:
        netB = torch.jit.trace(netB, (image, bg, seg, multi_fr))
        netG = torch.jit.trace(netG, (image, bg, seg, multi_fr))
        if args.quantize:
            input_data = (image, bg, seg, multi_fr)
            from torch.quantization import default_qconfig, quantize_jit

            def calibrate(model, data):
                # unpack the (image, bg, seg, multi_fr) tuple: the traced
                # modules expect four positional inputs, not one tuple
                model(*data)

            netB_quantized = quantize_jit(netB, {'': default_qconfig},
                                          calibrate, [input_data])
            netG_quantized = quantize_jit(netG, {'': default_qconfig},
                                          calibrate, [input_data])
    else:
        netB(image, bg, seg, multi_fr)
        netG(image, bg, seg, multi_fr)
    break

print('Starting training')
for epoch in range(0, args.epoch):

    netG.train()
    netD.train()

    lG, lD, GenL, DisL_r, DisL_f, alL, fgL, compL, elapse_run, elapse = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0