def test_bf16_binary_broadcast_elemwise_mixed_input(function, dtype, ndim):
    dshape_0 = rand_shape_nd(ndim)
    dshape_1 = tuple()
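    # randomly build a second shape that broadcasts against dshape_0:
    # each axis either keeps its size from dshape_0 or collapses to 1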
    for i in range(ndim):
        if (randint(0, 1)):
            dshape_1 += (dshape_0[i], )
        else:
            dshape_1 += (1, )

    a = mx.np.random.uniform(-1, 1, dshape_0, dtype=np.float32)
    a_fp32 = mx.np.array(a, dtype=dtype)
    a_bf16 = a.astype('bfloat16')

    b = mx.np.random.uniform(-1, 1, dshape_1, dtype=np.float32)
    b_fp32 = mx.np.array(b, dtype=dtype)
    b_bf16 = b.astype('bfloat16')

    rtol = 1e-1
    atol = 5e-1
    etol = 0

    out_bf_16_1 = function(a_bf16, b_fp32)
    out_fp_32 = function(a_fp32, b_fp32)
    assert_almost_equal_with_err(out_bf_16_1,
                                 out_fp_32,
                                 rtol=rtol,
                                 atol=atol,
                                 etol=etol)

    out_bf_16_2 = function(a_fp32, b_bf16)
    assert_almost_equal_with_err(out_bf_16_2,
                                 out_fp_32,
                                 rtol=rtol,
                                 atol=atol,
                                 etol=etol)
def test_quantized_fc_bias_overflow(data_min, data_max, weight_min,
                                    weight_max):
    data_shape = (1, 32)
    data_nd = mx.np.random.uniform(data_min,
                                   data_max,
                                   size=data_shape,
                                   device=mx.cpu())
    weight_nd = mx.np.random.uniform(weight_min,
                                     weight_max,
                                     size=[64, 32],
                                     device=mx.cpu())
    bias_nd = mx.np.random.uniform(-1, +1, size=[64], device=mx.cpu())

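    # Minimal FC HybridBlock with deferred-init weight/bias; the test quantizes it
    # below with the given data/weight ranges and checks that the int8 result stays
    # within tolerance of the fp32 reference (the "bias overflow" scenario).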
    class FCBiasOverflow(nn.HybridBlock):
        def __init__(self, dtype='float32', **kwargs):
            super(FCBiasOverflow, self).__init__(**kwargs)
            self.weight = mx.gluon.Parameter('weight',
                                             dtype=dtype,
                                             allow_deferred_init=True)
            self.bias = mx.gluon.Parameter('bias',
                                           dtype=dtype,
                                           allow_deferred_init=True)

        def forward(self, x):
            conv1 = mx.npx.fully_connected(x,
                                           num_hidden=64,
                                           weight=self.weight.data(x.device),
                                           no_bias=False,
                                           bias=self.bias.data(x.device))
            return conv1

        def infer_shape(self, x, *args):
            self.weight.shape = (64, x.shape[x.ndim - 1])
            self.bias.shape = (64, )

    net = FCBiasOverflow()
    net.initialize()
    net(data_nd)  # dummy run

    net.weight.data()[:] = weight_nd
    net.bias.data()[:] = bias_nd
    out = net(data_nd)

    calib_data = mx.gluon.data.DataLoader(data_nd, batch_size=1)
    qnet = quantization.quantize_net(net,
                                     device=mx.cpu(),
                                     exclude_layers=None,
                                     exclude_operators=None,
                                     quantized_dtype='int8',
                                     calib_mode='naive',
                                     calib_data=calib_data,
                                     num_calib_batches=1,
                                     quantize_mode='full')
    out_quantized = qnet(data_nd)
    assert_almost_equal_with_err(out.asnumpy(),
                                 out_quantized.asnumpy(),
                                 rtol=1e-2,
                                 atol=1e-2,
                                 etol=0.01)
def test_quantized_conv_bias_overflow(data_min, data_max, weight_min,
                                      weight_max):
    data_shape = (1, 32, 2, 2)
    data_nd = mx.random.uniform(data_min,
                                data_max,
                                shape=data_shape,
                                ctx=mx.cpu())
    weight_nd = mx.random.uniform(weight_min,
                                  weight_max,
                                  shape=[64, 32, 1, 1],
                                  ctx=mx.cpu())
    bias_nd = mx.random.uniform(-1, +1, shape=[64], ctx=mx.cpu())

    class ConvBiasOverflow(nn.HybridBlock):
        def __init__(self, dtype='float32', **kwargs):
            super(ConvBiasOverflow, self).__init__(**kwargs)
            self.weight = mx.gluon.Parameter('weight',
                                             dtype=dtype,
                                             allow_deferred_init=True)
            self.bias = mx.gluon.Parameter('bias',
                                           dtype=dtype,
                                           allow_deferred_init=True)

        def hybrid_forward(self, F, x, weight, bias):
            conv1 = F.Convolution(x,
                                  num_filter=64,
                                  kernel=(1, 1),
                                  weight=weight,
                                  no_bias=False,
                                  bias=bias)
            return conv1

    net = ConvBiasOverflow()
    net.initialize()
    net(data_nd)  # dummy run

    net.weight.data()[:] = weight_nd
    net.bias.data()[:] = bias_nd
    out = net(data_nd)

    calib_data = mx.gluon.data.DataLoader(data_nd, batch_size=data_shape[0])
    qnet = quantization.quantize_net(net,
                                     ctx=mx.cpu(),
                                     exclude_layers=None,
                                     exclude_operators=None,
                                     quantized_dtype='int8',
                                     calib_mode='naive',
                                     calib_data=calib_data,
                                     num_calib_batches=1,
                                     quantize_mode='full')

    out_quantized = qnet(data_nd)
    assert_almost_equal_with_err(out.asnumpy(),
                                 out_quantized.asnumpy(),
                                 rtol=1e-2,
                                 atol=1e-2,
                                 etol=0.01)
Example #4
def check_quantize(sym, data_shape, out_type, name='conv',
                   check_calibration=True, gluon_forward=False, check_scale_align=False):
  if name in config:
    name = config[name][OP_NAME]
  sym_sg = sym.get_backend_symbol(QUANTIZE_SG_PASS_NAME)
  mod = Module(symbol=sym, label_names=None)
  mod.bind(for_training=False,
            data_shapes=[('data', data_shape)])
  mod.init_params(mx.init.Normal(0.5))
  arg_params, aux_params = mod.get_params()

  if out_type == 'uint8':
    data = [mx.random.uniform(0.0, 1.0, shape=shape, ctx=mx.current_context()) for _, shape in mod.data_shapes]
  else:
    data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=mx.current_context()) for _, shape in mod.data_shapes]
  batch = mx.io.DataBatch(data, [])

  mod.forward(batch, is_train=False)
  for output in mod.get_outputs():
      output.wait_to_read()
  ref_out = mod.get_outputs()

  excluded_sym_names = []
  excluded_op_names = []
  if mx.current_context() == mx.cpu() and gluon_forward:
    excluded_op_names += ['_sg_mkldnn_fully_connected']

  calib_data = CalibIter(batch, data_shape, 1)

  qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym_sg,
                                                                   arg_params=arg_params,
                                                                   aux_params=aux_params,
                                                                   ctx=mx.current_context(),
                                                                   excluded_sym_names=excluded_sym_names,
                                                                   excluded_op_names=excluded_op_names,
                                                                   quantized_dtype=out_type,
                                                                   calib_mode='naive',
                                                                   calib_data=calib_data,
                                                                   calib_layer=None,
                                                                   label_names=None,
                                                                   num_calib_examples=1)
  qsym = qsym.get_backend_symbol(QUANTIZE_SG_PASS_NAME)
  if check_calibration:
    check_qsym_calibrated(qsym, out_type, name=name)
  if check_scale_align:
    check_qsym_scale_align(qsym)
  if gluon_forward:
    check_qsym_gluon_forward(qsym, qarg_params, qaux_params, data_shape)
  else:
    quantized_out = check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape)
    for i in range(len(ref_out)):
      min_range = mx.nd.min(ref_out[i]).asscalar()
      max_range = mx.nd.max(ref_out[i]).asscalar()
      atol = 0.1 * max(abs(min_range), abs(max_range))
      assert_almost_equal_with_err(quantized_out[i].asnumpy(), ref_out[i].asnumpy(), rtol=0.1, atol=atol, etol=0.2)
    check_qsym_dummy_forward(qsym, batch, data_shape)
Example #5
def helper_quantized_conv_bias_overflow(data_min, data_max, weight_min,
                                        weight_max):
    data_shape = (1, 32, 2, 2)
    data = mx.symbol.Variable('data', shape=data_shape, dtype='float32')
    weight = mx.symbol.Variable('weight', dtype='float32')
    bias = mx.symbol.Variable('bias', dtype='float32')
    sym = mx.symbol.Convolution(data=data,
                                weight=weight,
                                bias=bias,
                                name='conv',
                                num_filter=64,
                                kernel=(1, 1),
                                stride=(1, 1))
    data_nd = mx.random.uniform(data_min,
                                data_max,
                                shape=data_shape,
                                ctx=mx.cpu())
    weight_nd = mx.random.uniform(weight_min,
                                  weight_max,
                                  shape=[64, 32, 1, 1],
                                  ctx=mx.cpu())
    bias_nd = mx.random.uniform(-1, +1, shape=[64], ctx=mx.cpu())
    arg_params = {'data': data_nd, 'weight': weight_nd, 'bias': bias_nd}

    ex = sym.bind(mx.cpu(), arg_params, args_grad=None)
    ex.forward()
    ex.outputs[0].wait_to_read()
    sym_sg = sym.get_backend_symbol(QUANTIZE_SG_PASS_NAME)
    batch = mx.io.DataBatch([data_nd], [])
    calib_data = CalibIter(batch, data_shape, 1)
    qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(
        sym=sym_sg,
        arg_params={
            'weight': weight_nd,
            'bias': bias_nd
        },
        aux_params={},
        ctx=mx.cpu(),
        excluded_sym_names=None,
        excluded_op_names=None,
        quantized_dtype='int8',
        calib_mode='naive',
        calib_data=calib_data,
        label_names=None,
        num_calib_examples=1,
        quantize_mode='full')
    qsym = qsym.get_backend_symbol(QUANTIZE_SG_PASS_NAME)
    qarg_params['data'] = data_nd
    qex = qsym.bind(mx.cpu(), qarg_params, args_grad=None)
    qex.forward()
    qex.outputs[0].wait_to_read()
    assert_almost_equal_with_err(ex.outputs[0].asnumpy(),
                                 qex.outputs[0].asnumpy(),
                                 rtol=1e-2,
                                 atol=1e-2,
                                 etol=0.01)
Example #6
def test_quantized_conv_bias_overflow(data_min, data_max, weight_min,
                                      weight_max):
    data_shape = (1, 32, 2, 2)
    data = mx.symbol.Variable('data', shape=data_shape, dtype='float32')
    weight = mx.symbol.Variable('weight', dtype='float32')
    bias = mx.symbol.Variable('bias', dtype='float32')
    sym = mx.symbol.Convolution(data=data,
                                weight=weight,
                                bias=bias,
                                name='conv',
                                num_filter=64,
                                kernel=(1, 1),
                                stride=(1, 1))
    data_nd = mx.random.uniform(data_min,
                                data_max,
                                shape=data_shape,
                                ctx=mx.cpu())
    weight_nd = mx.random.uniform(weight_min,
                                  weight_max,
                                  shape=[64, 32, 1, 1],
                                  ctx=mx.cpu())
    bias_nd = mx.random.uniform(-1, +1, shape=[64], ctx=mx.cpu())
    arg_params = {'weight': weight_nd, 'bias': bias_nd}

    ex = sym._bind(mx.cpu(), arg_params, args_grad=None)
    ex.forward(data=data_nd)
    ex.outputs[0].wait_to_read()
    sym_sg = sym.optimize_for(QUANTIZE_SG_PASS_NAME,
                              dedup_subgraph=True,
                              skip_infer=True)

    calib_data = mx.gluon.data.DataLoader(data_nd, batch_size=data_shape[0])
    qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(
        sym=sym_sg,
        arg_params=arg_params,
        aux_params={},
        ctx=mx.cpu(),
        excluded_sym_names=None,
        excluded_op_names=None,
        quantized_dtype='int8',
        calib_mode='naive',
        calib_data=calib_data,
        num_calib_batches=1,
        quantize_mode='full')
    qsym = qsym.optimize_for(QUANTIZE_SG_PASS_NAME,
                             dedup_subgraph=True,
                             skip_infer=True)
    qarg_params['data'] = data_nd
    qex = qsym._bind(mx.cpu(), qarg_params, args_grad=None)
    qex.forward()
    qex.outputs[0].wait_to_read()
    assert_almost_equal_with_err(ex.outputs[0].asnumpy(),
                                 qex.outputs[0].asnumpy(),
                                 rtol=1e-2,
                                 atol=1e-2,
                                 etol=0.01)
def test_batch_dot(batch_size, seq_length, units, num_heads):
    class BatchDotBlock(nn.HybridBlock):
        def __init__(self, **kwargs):
            super(BatchDotBlock, self).__init__(**kwargs)

        def forward(self, lhs, rhs):
            x = mx.npx.batch_dot(lhs, rhs)
            return x

    lhs_data = mx.np.random.uniform(low=-1,
                                    high=1,
                                    size=[batch_size, units, seq_length],
                                    dtype='float32')
    rhs_data = mx.np.random.uniform(low=-1,
                                    high=1,
                                    size=[batch_size, seq_length, seq_length],
                                    dtype='float32')

    net = BatchDotBlock()
    net.initialize()
    fused_net = net
    net.hybridize()
    ref_out = net(lhs_data, rhs_data)

    fused_net.optimize_for(lhs_data, rhs_data, backend="ONEDNN")
    out = fused_net(lhs_data, rhs_data)
    mx.nd.waitall()

    for i in range(len(out)):
        assert_almost_equal(out[i].asnumpy(), ref_out[i].asnumpy())

    calib_data = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(
        lhs_data, rhs_data),
                                          batch_size=1)
    qnet = mx.contrib.quant.quantize_net(net,
                                         quantized_dtype='auto',
                                         exclude_layers=None,
                                         exclude_layers_match=None,
                                         calib_data=calib_data,
                                         calib_mode='naive',
                                         num_calib_batches=1,
                                         ctx=mx.cpu())

    qout = qnet(lhs_data, rhs_data)
    mx.nd.waitall()

    for i in range(len(ref_out)):
        min_range = np.min(ref_out[i].asnumpy())
        max_range = np.max(ref_out[i].asnumpy())
        atol = 0.1 * max(abs(min_range), abs(max_range))
        assert_almost_equal_with_err(qout[i].asnumpy(),
                                     ref_out[i].asnumpy(),
                                     rtol=0.1,
                                     atol=atol,
                                     etol=0.1)
Example #8
def check_quantize(net_original,
                   data_shape,
                   out_type,
                   name='conv',
                   check_calibration=True,
                   check_scale_align=False):
    quantize_granularity_list = ['tensor-wise']
    if name == 'fc':
        quantize_granularity_list += ['channel-wise']

    if name in config:
        name = config[name][OP_NAME]

    net_original.initialize(init=mx.init.Normal(0.5), force_reinit=True)
    min_value = -1 if out_type != 'uint8' else 0
    data = mx.np.random.uniform(min_value,
                                1.0,
                                size=data_shape,
                                dtype='float32',
                                ctx=mx.current_device())

    outputs = net_original(data)
    for output in outputs:
        output.wait_to_read()
    ref_out = outputs

    calib_data = mx.gluon.data.DataLoader(data, batch_size=1)
    for quantize_granularity in quantize_granularity_list:
        qnet = quantization.quantize_net(
            net_original,
            ctx=mx.current_device(),
            exclude_layers=None,
            exclude_operators=None,
            quantized_dtype=out_type,
            calib_mode='naive',
            calib_data=calib_data,
            num_calib_batches=1,
            quantize_mode='full',
            quantize_granularity=quantize_granularity)
        qsym, _ = qnet.export(None)
        if check_calibration:
            check_qsym_calibrated(qsym, out_type, name=name)
        if check_scale_align:
            check_qsym_scale_align(qsym)

        quantized_out = qnet(data)
        for i in range(len(ref_out)):
            min_range = mx.np.min(ref_out[i]).item()
            max_range = mx.np.max(ref_out[i]).item()
            atol = 0.1 * max(abs(min_range), abs(max_range))
            assert_almost_equal_with_err(quantized_out[i].asnumpy(),
                                         ref_out[i].asnumpy(),
                                         rtol=0.1,
                                         atol=atol,
                                         etol=0.2)
Example #9
def test_quantize_whole_model_with_forward(qdtype):
    batch_size = 4
    data_shape = (batch_size, 4, 10, 10)
    data = mx.sym.Variable('data')
    conv0 = mx.sym.Convolution(data,
                               kernel=(1, 1),
                               num_filter=16,
                               name='conv0')
    sym = mx.sym.Convolution(conv0, kernel=(1, 1), num_filter=16, name='conv1')

    sym_block = mx.gluon.SymbolBlock(outputs=sym, inputs=data)
    initialize_block_params(sym_block, mx.init.Normal(0.5))

    in_data = mx.random.uniform(0.0 if qdtype == 'uint8' else -1.0,
                                1.0,
                                shape=data_shape)
    ref_out = sym_block(in_data)

    excluded_layers = []

    calib_data = mx.nd.random.uniform(0.0 if qdtype == 'uint8' else -1.0,
                                      1.0,
                                      shape=data_shape)
    calib_data = mx.gluon.data.DataLoader(calib_data, batch_size=batch_size)
    qsym = mx.contrib.quantization.quantize_net(sym_block,
                                                ctx=mx.current_context(),
                                                exclude_layers=excluded_layers,
                                                quantized_dtype=qdtype,
                                                calib_mode='naive',
                                                calib_data=calib_data,
                                                num_calib_batches=1,
                                                quantize_mode='full')

    outputs = qsym(in_data)
    for output in outputs:
        output.wait_to_read()

    for i in range(len(ref_out)):
        min_range = mx.nd.min(ref_out[i]).asscalar()
        max_range = mx.nd.max(ref_out[i]).asscalar()
        atol = 0.1 * max(abs(min_range), abs(max_range))
        assert_almost_equal_with_err(outputs[i].asnumpy(),
                                     ref_out[i].asnumpy(),
                                     rtol=0.1,
                                     atol=atol,
                                     etol=0.2)
def test_self_attention(batch_size, seq_length, units, num_heads):
    net = MultiHeadAttention(units, num_heads)
    in_data = mx.np.random.uniform(size=[batch_size, seq_length, units],
                                   dtype='float32')
    mask = mx.np.random.uniform(low=0,
                                high=2,
                                size=[batch_size, seq_length, seq_length],
                                dtype='int32')

    net.initialize()
    fused_net = net
    net.hybridize()
    ref_out = net(in_data, mask)

    fused_net.optimize_for(in_data, mask, backend="ONEDNN")
    out = fused_net(in_data, mask)
    mx.nd.waitall()

    for i in range(len(out)):
        assert_almost_equal(out[i].asnumpy(), ref_out[i].asnumpy())

    calib_data = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(
        in_data, mask),
                                          batch_size=1)
    qnet = mx.contrib.quant.quantize_net(net,
                                         quantized_dtype='auto',
                                         exclude_layers=None,
                                         exclude_layers_match=None,
                                         calib_data=calib_data,
                                         calib_mode='naive',
                                         num_calib_batches=1,
                                         ctx=mx.cpu())

    qout = qnet(in_data, mask)
    mx.nd.waitall()

    for i in range(len(ref_out)):
        min_range = np.min(ref_out[i].asnumpy())
        max_range = np.max(ref_out[i].asnumpy())
        atol = 0.1 * max(abs(min_range), abs(max_range))
        assert_almost_equal_with_err(qout[i].asnumpy(),
                                     ref_out[i].asnumpy(),
                                     rtol=0.1,
                                     atol=atol,
                                     etol=0.2)
Example #11
def check_operator_accuracy(sym_fp32,
                            sym_bf16,
                            data_shape,
                            num_input_data=1,
                            bf16_use_fp32_params=False,
                            rtol=1e-1,
                            atol=5e-1,
                            etol=0):
    """
    check accuracy for bfloat16 operators

    sym_fp32: Symbol
        fp32 operator
    sym_bf16: Symbol
        bf16 operator
    data_shape: tuple of int
        input data shape for fp32/bf16 symbol
    num_input_data: int
        number of input tensors; defaults to 1 and should be increased for
        operators with multiple inputs, such as concat or elemwise_add
    bf16_use_fp32_params: bool
        currently only BatchNorm sets this to True, since bf16 BatchNorm only
        accepts bf16 data together with fp32 mean/var/scale/shift
    rtol: float
        the relative tolerance
    atol: float
        the absolute tolerance
    etol: float
        the error-rate tolerance; allows a small fraction of values to differ
        between bf16 and fp32
    """
    if not isinstance(data_shape, tuple):
        data_shape = tuple(data_shape)
    data_range = (0.0, 10.0)
    data_list_fp32 = list()
    data_list_bf16 = list()
    for i in range(num_input_data):
        data_list_fp32.append(
            mx.nd.random.uniform(low=data_range[0],
                                 high=data_range[1],
                                 shape=data_shape))
        data_list_bf16.append(mx.nd.amp_cast(data_list_fp32[i],
                                             dtype=bfloat16))

    arg_shapes, _, aux_shapes = sym_fp32.infer_shape(data=data_shape)
    arg_names = sym_fp32.list_arguments()
    aux_names = sym_fp32.list_auxiliary_states()

    exe_fp32 = sym_fp32.simple_bind(ctx=mx.cpu(), data=data_shape)

    arg_params_fp32 = {}
    aux_params_fp32 = {}
    type_dict = {}
    for i, arg_name in enumerate(arg_names):
        if i < num_input_data:
            exe_fp32.arg_dict[arg_name][:] = data_list_fp32[i]
            continue
        arg_params_fp32[arg_name] = mx.nd.random.uniform(low=data_range[0],
                                                         high=data_range[1],
                                                         shape=arg_shapes[i])
        exe_fp32.arg_dict[arg_name][:] = arg_params_fp32[arg_name]
        # specify the dtype of arguments
        if not bf16_use_fp32_params:
            type_dict.update({arg_name: bfloat16})

    for i, aux_name in enumerate(aux_names):
        aux_params_fp32[aux_name] = mx.nd.random.uniform(low=data_range[0],
                                                         high=data_range[1],
                                                         shape=aux_shapes[i])
        exe_fp32.aux_dict[aux_name][:] = aux_params_fp32[aux_name]

    output_fp32 = exe_fp32.forward()[0]

    exe_bf16 = sym_bf16.simple_bind(ctx=mx.cpu(),
                                    data=data_shape,
                                    type_dict=type_dict)

    arg_params_bf16 = {}
    aux_params_bf16 = {}
    for i, arg_name in enumerate(arg_names):
        if i < num_input_data:
            exe_bf16.arg_dict[arg_name][:] = data_list_bf16[i]
            continue

        if bf16_use_fp32_params:
            exe_bf16.arg_dict[arg_name][:] = arg_params_fp32[arg_name]
        else:
            exe_bf16.arg_dict[arg_name][:] = mx.nd.amp_cast(
                arg_params_fp32[arg_name], dtype=bfloat16)

    for aux_name in aux_names:
        if bf16_use_fp32_params:
            exe_bf16.aux_dict[aux_name][:] = aux_params_fp32[aux_name]
        else:
            exe_bf16.aux_dict[aux_name][:] = mx.nd.amp_cast(
                aux_params_fp32[aux_name], dtype=bfloat16)

    output_bf16 = exe_bf16.forward()[0]
    output_bf16.wait_to_read()
    output_bf16_2_fp32 = mx.nd.amp_cast(output_bf16, dtype="float32")
    assert_almost_equal_with_err(output_bf16_2_fp32,
                                 output_fp32,
                                 rtol=rtol,
                                 atol=atol,
                                 etol=etol)
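
# Hypothetical usage sketch for check_operator_accuracy (not part of the original
# tests): it assumes the same names the helpers above rely on (`mxnet as mx` and
# the `bfloat16` dtype constant) and compares an fp32 Convolution against its
# bf16 counterpart built from the same graph.
def example_conv_bf16_accuracy():
    data_fp32 = mx.sym.Variable('data', dtype='float32')
    data_bf16 = mx.sym.Variable('data', dtype=bfloat16)
    sym_fp32 = mx.sym.Convolution(data_fp32, kernel=(3, 3), num_filter=16,
                                  pad=(1, 1), name='conv')
    sym_bf16 = mx.sym.Convolution(data_bf16, kernel=(3, 3), num_filter=16,
                                  pad=(1, 1), name='conv')
    # weights/bias are randomized inside the helper; only the data symbol differs
    check_operator_accuracy(sym_fp32, sym_bf16, data_shape=(4, 16, 8, 8))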
Example #12
def check_quantize(net_original,
                   data_shapes,
                   out_type,
                   name='conv',
                   check_calibration=True,
                   check_scale_align=False,
                   quantize_mode='full',
                   attrs_dict={}):
    quantize_granularity_list = ['tensor-wise']
    if name == 'fc':
        quantize_granularity_list += ['channel-wise']

    if name in config:
        name = config[name][OP_NAME]

    net_original.initialize(init=mx.init.Normal(0.5), force_reinit=True)
    min_value = -1 if out_type != 'uint8' else 0
    one_shape = isinstance(data_shapes, tuple)
    if one_shape:
        # wrap the single shape in a one-element list so the same code path handles one or many inputs
        data_shapes = [data_shapes]
    data = []
    for shape in data_shapes:
        data.append(
            mx.np.random.uniform(min_value,
                                 1.0,
                                 size=shape,
                                 dtype='float32',
                                 device=mx.cpu()))

    outputs = net_original(*data)
    for output in outputs:
        output.wait_to_read()
    ref_out = outputs
    one_output = not isinstance(ref_out, list)
    if one_output:
        # make a list to have a common path for one and multiple outputs
        ref_out = [ref_out]

    dataArray = mx.gluon.data.ArrayDataset(*data)

    calib_data = mx.gluon.data.DataLoader(dataArray, batch_size=1)
    for quantize_granularity in quantize_granularity_list:
        qnet = quantization.quantize_net(
            net_original,
            device=mx.cpu(),
            exclude_layers=None,
            exclude_operators=None,
            quantized_dtype=out_type,
            calib_mode='naive',
            calib_data=calib_data,
            num_calib_batches=1,
            quantize_mode=quantize_mode,
            quantize_granularity=quantize_granularity)
        qsym, _ = qnet.export(None)
        check_fusion_parameter(qsym, attrs_dict)
        if check_calibration:
            check_qsym_calibrated(qsym, out_type, name=name)
        if check_scale_align:
            check_qsym_scale_align(qsym)

        quantized_out = qnet(*data)
        if one_output:
            quantized_out = [quantized_out]
        for i in range(len(ref_out)):
            min_range = mx.np.min(ref_out[i]).item()
            max_range = mx.np.max(ref_out[i]).item()
            atol = 0.1 * max(abs(min_range), abs(max_range))
            assert_almost_equal_with_err(quantized_out[i].asnumpy(),
                                         ref_out[i].asnumpy(),
                                         rtol=0.1,
                                         atol=atol,
                                         etol=0.2)
Example #13
def check_quantize(net_original,
                   data_shapes,
                   out_type,
                   name='conv',
                   check_calibration=True,
                   check_scale_align=False,
                   quantize_mode='full',
                   attrs_dict={},
                   calib_mode='naive',
                   check_fusion=True):
    quantize_granularity_list = ['tensor-wise']
    if name == 'fc':
        quantize_granularity_list += ['channel-wise']

    if name in config:
        name = config[name][OP_NAME]

    sigma = 0.01 if getattr(net_original, 'alg', None) == 'exp' else 0.5
    if out_type == 'uint8':
        # Initialize weights and tensors only with positive values to be sure
        # that results are always positive
        init = CustomNormalInit(sigma=sigma, bounded=True)
        min_value = 0
    else:
        init = mx.init.Normal(sigma)
        min_value = -1

    net_original.initialize(init=init, force_reinit=True)

    one_shape = isinstance(data_shapes, tuple)
    if one_shape:
        # wrap the single shape in a one-element list so the same code path handles one or many inputs
        data_shapes = [data_shapes]
    data = []
    for shape in data_shapes:
        data.append(
            mx.np.random.uniform(min_value,
                                 1.0,
                                 size=shape,
                                 dtype='float32',
                                 device=mx.cpu()))

    outputs = net_original(*data)
    for output in outputs:
        output.wait_to_read()
    ref_out = outputs
    one_output = not isinstance(ref_out, list)
    if one_output:
        # make a list to have a common path for one and multiple outputs
        ref_out = [ref_out]

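    # Bare-bones DataLoader stand-in: it skips DataLoader.__init__ and simply
    # yields the already-built list of input arrays once per iteration pass.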
    class TestDataLoader(mx.gluon.data.DataLoader):
        def __init__(self, data):
            self.data = data
            self.finish = False

        def __iter__(self):
            self.finish = False
            return self

        def __next__(self):
            if self.finish:
                raise StopIteration
            self.finish = True
            return self.data

        def __del__(self):
            pass

    calib_data = TestDataLoader(data)

    for quantize_granularity in quantize_granularity_list:
        qnet = quantization.quantize_net(
            net_original,
            device=mx.cpu(),
            exclude_layers=None,
            exclude_operators=None,
            quantized_dtype=out_type,
            calib_mode=calib_mode,
            calib_data=calib_data,
            num_calib_batches=1,
            quantize_mode=quantize_mode,
            quantize_granularity=quantize_granularity)
        qsym, _ = qnet.export(None)
        if check_fusion:
            check_fusion_parameter(qsym, attrs_dict)
        if check_calibration:
            check_qsym_calibrated(qsym, out_type, name=name)
        if check_scale_align:
            check_qsym_scale_align(qsym)

        quantized_out = qnet(*data)
        if one_output:
            quantized_out = [quantized_out]
        for i in range(len(ref_out)):
            min_range = mx.np.min(ref_out[i]).item()
            max_range = mx.np.max(ref_out[i]).item()
            atol = 0.1 * max(abs(min_range), abs(max_range))
            assert_almost_equal_with_err(quantized_out[i].asnumpy(),
                                         ref_out[i].asnumpy(),
                                         rtol=0.1,
                                         atol=atol,
                                         etol=0.2)
Example #14
def test_self_attention(batch_size, seq_length, units, num_heads):
  class MultiHeadAttention(nn.HybridBlock):
    def __init__(self, units, num_heads, dtype='float32', **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self._units = units
        self._num_heads = num_heads
        self._fc = nn.Dense(in_units=self._units, units=3*self._units, flatten=False, dtype=dtype)
        self._scale = math.sqrt(self._units // self._num_heads)

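    # Scaled dot-product self-attention: one Dense layer produces Q/K/V, the
    # masked softmax is applied per head, and the heads are re-merged at the end.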
    def forward(self, x, mask):
        x = mx.np.copy(x)
        out = self._fc(x)
        query, key, value = mx.np.split(out, 3, axis=-1)
        query = mx.npx.reshape(query, (-2, -2, self._num_heads, -1))
        key = mx.npx.reshape(key, (-2, -2, self._num_heads, -1))
        value = mx.npx.reshape(value, (-2, -2, self._num_heads, -1))
        scores = mx.npx.batch_dot(mx.np.swapaxes(query, 1, 2), mx.np.swapaxes(key, 1, 2),
                               transpose_b=True)
        mask = mx.np.expand_dims(mask, axis=1).astype('bool')
        attn_weights = mx.npx.masked_softmax(scores, mask=mask,
                                             axis=-1, temperature=self._scale)
        attn_weights = mx.npx.dropout(attn_weights, p=0.1)
        context_vec = mx.npx.batch_dot(attn_weights,
                                     mx.np.swapaxes(value, 1, 2)).transpose((0, 2, 1, 3))
        context_vec = mx.npx.reshape(context_vec, (-2, -2, -1))

        return context_vec

  net = MultiHeadAttention(units, num_heads)
  in_data = mx.np.random.uniform(size=[batch_size, seq_length, units], dtype='float32')
  mask = mx.np.random.uniform(low=0, high=2, size=[batch_size, seq_length, seq_length], dtype='int32')

  net.initialize()
  fused_net = net
  net.hybridize()
  ref_out = net(in_data, mask)

  fused_net.optimize_for(in_data, mask, backend="MKLDNN")
  out = fused_net(in_data, mask)
  mx.nd.waitall()

  for i in range(len(out)):
    assert_almost_equal(out[i].asnumpy(), ref_out[i].asnumpy())


  calib_data = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(in_data, mask), batch_size=1)
  qnet = mx.contrib.quant.quantize_net(net, quantized_dtype='auto',
                                            exclude_layers=None,
                                            exclude_layers_match=None,
                                            calib_data=calib_data,
                                            calib_mode='naive',
                                            num_calib_batches=1,
                                            ctx=mx.cpu())

  qout = qnet(in_data, mask)
  mx.nd.waitall()

  for i in range(len(ref_out)):
      min_range = np.min(ref_out[i].asnumpy())
      max_range = np.max(ref_out[i].asnumpy())
      atol = 0.1 * max(abs(min_range), abs(max_range))
      assert_almost_equal_with_err(qout[i].asnumpy(), ref_out[i].asnumpy(), rtol=0.1, atol=atol, etol=0.2)
Example #15
def check_quantize(sym,
                   data_shape,
                   out_type,
                   name='conv',
                   check_calibration=True,
                   check_scale_align=False):
    quantize_granularity_list = ['tensor-wise']
    if name == 'fc':
        quantize_granularity_list += ['channel-wise']

    if name in config:
        name = config[name][OP_NAME]
    sym_sg = sym.optimize_for(QUANTIZE_SG_PASS_NAME,
                              dedup_subgraph=True,
                              skip_infer=True)

    inputs = mx.sym.var('data', dtype='float32')
    sym_block = mx.gluon.SymbolBlock(sym, inputs)
    initialize_block_params(sym_block, mx.init.Normal(0.5))

    min_value = -1 if out_type != 'uint8' else 0
    data = mx.random.uniform(min_value,
                             1.0,
                             shape=data_shape,
                             dtype='float32',
                             ctx=mx.current_context())

    outputs = sym_block(data)
    for output in outputs:
        output.wait_to_read()
    ref_out = outputs
    arg_params, aux_params = collect_block_args_aux(sym_block, sym)

    excluded_sym_names = []
    excluded_op_names = []

    calib_data = mx.gluon.data.DataLoader(data, batch_size=1)
    for quantize_granularity in quantize_granularity_list:
        qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(
            sym=sym_sg,
            arg_params=arg_params,
            aux_params=aux_params,
            ctx=mx.current_context(),
            excluded_sym_names=excluded_sym_names,
            excluded_op_names=excluded_op_names,
            quantized_dtype=out_type,
            calib_mode='naive',
            calib_data=calib_data,
            num_calib_batches=1,
            quantize_mode='full',
            quantize_granularity=quantize_granularity)
        qsym = qsym.optimize_for(QUANTIZE_SG_PASS_NAME,
                                 dedup_subgraph=True,
                                 skip_infer=True)
        if check_calibration:
            check_qsym_calibrated(qsym, out_type, name=name)
        if check_scale_align:
            check_qsym_scale_align(qsym)

        quantized_out = check_qsym_gluon_forward(qsym, qarg_params,
                                                 qaux_params, data)
        for i in range(len(ref_out)):
            min_range = mx.nd.min(ref_out[i]).asscalar()
            max_range = mx.nd.max(ref_out[i]).asscalar()
            atol = 0.1 * max(abs(min_range), abs(max_range))
            assert_almost_equal_with_err(quantized_out[i].asnumpy(),
                                         ref_out[i].asnumpy(),
                                         rtol=0.1,
                                         atol=atol,
                                         etol=0.2)
        check_qsym_dummy_forward(qsym, data, data_shape)