def test_bf16_binary_broadcast_elemwise_mixed_input(function, dtype, ndim):
    dshape_0 = rand_shape_nd(ndim)
    dshape_1 = tuple()
    for i in range(ndim):
        if randint(0, 1):
            dshape_1 += (dshape_0[i], )
        else:
            dshape_1 += (1, )

    a = mx.np.random.uniform(-1, 1, dshape_0, dtype=np.float32)
    a_fp32 = mx.np.array(a, dtype=dtype)
    a_bf16 = a.astype('bfloat16')
    b = mx.np.random.uniform(-1, 1, dshape_1, dtype=np.float32)
    b_fp32 = mx.np.array(b, dtype=dtype)
    b_bf16 = b.astype('bfloat16')

    rtol = 1e-1
    atol = 5e-1
    etol = 0

    out_bf_16_1 = function(a_bf16, b_fp32)
    out_fp_32 = function(a_fp32, b_fp32)
    assert_almost_equal_with_err(out_bf_16_1, out_fp_32, rtol=rtol, atol=atol, etol=etol)

    out_bf_16_2 = function(a_fp32, b_bf16)
    assert_almost_equal_with_err(out_bf_16_2, out_fp_32, rtol=rtol, atol=atol, etol=etol)

def test_quantized_fc_bias_overflow(data_min, data_max, weight_min, weight_max):
    data_shape = (1, 32)
    data_nd = mx.np.random.uniform(data_min, data_max, size=data_shape, device=mx.cpu())
    weight_nd = mx.np.random.uniform(weight_min, weight_max, size=[64, 32], device=mx.cpu())
    bias_nd = mx.np.random.uniform(-1, +1, size=[64], device=mx.cpu())

    class FCBiasOverflow(nn.HybridBlock):
        def __init__(self, dtype='float32', **kwargs):
            super(FCBiasOverflow, self).__init__(**kwargs)
            self.weight = mx.gluon.Parameter('weight', dtype=dtype, allow_deferred_init=True)
            self.bias = mx.gluon.Parameter('bias', dtype=dtype, allow_deferred_init=True)

        def forward(self, x):
            fc = mx.npx.fully_connected(x, num_hidden=64,
                                        weight=self.weight.data(x.device),
                                        no_bias=False,
                                        bias=self.bias.data(x.device))
            return fc

        def infer_shape(self, x, *args):
            self.weight.shape = (64, x.shape[x.ndim - 1])
            self.bias.shape = (64, )

    net = FCBiasOverflow()
    net.initialize()
    net(data_nd)  # dummy run to trigger deferred parameter initialization
    net.weight.data()[:] = weight_nd
    net.bias.data()[:] = bias_nd
    out = net(data_nd)

    calib_data = mx.gluon.data.DataLoader(data_nd, batch_size=1)
    qnet = quantization.quantize_net(net,
                                     device=mx.cpu(),
                                     exclude_layers=None,
                                     exclude_operators=None,
                                     quantized_dtype='int8',
                                     calib_mode='naive',
                                     calib_data=calib_data,
                                     num_calib_batches=1,
                                     quantize_mode='full')
    out_quantized = qnet(data_nd)
    assert_almost_equal_with_err(out.asnumpy(), out_quantized.asnumpy(),
                                 rtol=1e-2, atol=1e-2, etol=0.01)

def test_quantized_conv_bias_overflow(data_min, data_max, weight_min, weight_max):
    data_shape = (1, 32, 2, 2)
    data_nd = mx.random.uniform(data_min, data_max, shape=data_shape, ctx=mx.cpu())
    weight_nd = mx.random.uniform(weight_min, weight_max, shape=[64, 32, 1, 1], ctx=mx.cpu())
    bias_nd = mx.random.uniform(-1, +1, shape=[64], ctx=mx.cpu())

    class ConvBiasOverflow(nn.HybridBlock):
        def __init__(self, dtype='float32', **kwargs):
            super(ConvBiasOverflow, self).__init__(**kwargs)
            self.weight = mx.gluon.Parameter('weight', dtype=dtype, allow_deferred_init=True)
            self.bias = mx.gluon.Parameter('bias', dtype=dtype, allow_deferred_init=True)

        def hybrid_forward(self, F, x, weight, bias):
            conv1 = F.Convolution(x, num_filter=64, kernel=(1, 1),
                                  weight=weight, no_bias=False, bias=bias)
            return conv1

    net = ConvBiasOverflow()
    net.initialize()
    net(data_nd)  # dummy run
    net.weight.data()[:] = weight_nd
    net.bias.data()[:] = bias_nd
    out = net(data_nd)

    calib_data = mx.gluon.data.DataLoader(data_nd, batch_size=data_shape[0])
    qnet = quantization.quantize_net(net,
                                     ctx=mx.cpu(),
                                     exclude_layers=None,
                                     exclude_operators=None,
                                     quantized_dtype='int8',
                                     calib_mode='naive',
                                     calib_data=calib_data,
                                     num_calib_batches=1,
                                     quantize_mode='full')
    out_quantized = qnet(data_nd)
    assert_almost_equal_with_err(out.asnumpy(), out_quantized.asnumpy(),
                                 rtol=1e-2, atol=1e-2, etol=0.01)

def check_quantize(sym, data_shape, out_type, name='conv',
                   check_calibration=True, gluon_forward=False, check_scale_align=False):
    if name in config:
        name = config[name][OP_NAME]

    sym_sg = sym.get_backend_symbol(QUANTIZE_SG_PASS_NAME)
    mod = Module(symbol=sym, label_names=None)
    mod.bind(for_training=False, data_shapes=[('data', data_shape)])
    mod.init_params(mx.init.Normal(0.5))
    arg_params, aux_params = mod.get_params()

    if out_type == 'uint8':
        data = [mx.random.uniform(0.0, 1.0, shape=shape, ctx=mx.current_context())
                for _, shape in mod.data_shapes]
    else:
        data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=mx.current_context())
                for _, shape in mod.data_shapes]
    batch = mx.io.DataBatch(data, [])

    mod.forward(batch, is_train=False)
    for output in mod.get_outputs():
        output.wait_to_read()
    ref_out = mod.get_outputs()

    excluded_sym_names = []
    excluded_op_names = []
    if mx.current_context() == mx.cpu() and gluon_forward:
        excluded_op_names += ['_sg_mkldnn_fully_connected']

    calib_data = CalibIter(batch, data_shape, 1)

    qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(
        sym=sym_sg,
        arg_params=arg_params,
        aux_params=aux_params,
        ctx=mx.current_context(),
        excluded_sym_names=excluded_sym_names,
        excluded_op_names=excluded_op_names,
        quantized_dtype=out_type,
        calib_mode='naive',
        calib_data=calib_data,
        calib_layer=None,
        label_names=None,
        num_calib_examples=1)
    qsym = qsym.get_backend_symbol(QUANTIZE_SG_PASS_NAME)
    if check_calibration:
        check_qsym_calibrated(qsym, out_type, name=name)
    if check_scale_align:
        check_qsym_scale_align(qsym)
    if gluon_forward:
        check_qsym_gluon_forward(qsym, qarg_params, qaux_params, data_shape)
    else:
        quantized_out = check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape)
        for i in range(len(ref_out)):
            min_range = mx.nd.min(ref_out[i]).asscalar()
            max_range = mx.nd.max(ref_out[i]).asscalar()
            atol = 0.1 * max(abs(min_range), abs(max_range))
            assert_almost_equal_with_err(quantized_out[i].asnumpy(), ref_out[i].asnumpy(),
                                         rtol=0.1, atol=atol, etol=0.2)
        check_qsym_dummy_forward(qsym, batch, data_shape)

def helper_quantized_conv_bias_overflow(data_min, data_max, weight_min, weight_max):
    data_shape = (1, 32, 2, 2)
    data = mx.symbol.Variable('data', shape=data_shape, dtype='float32')
    weight = mx.symbol.Variable('weight', dtype='float32')
    bias = mx.symbol.Variable('bias', dtype='float32')
    sym = mx.symbol.Convolution(data=data, weight=weight, bias=bias, name='conv',
                                num_filter=64, kernel=(1, 1), stride=(1, 1))
    data_nd = mx.random.uniform(data_min, data_max, shape=data_shape, ctx=mx.cpu())
    weight_nd = mx.random.uniform(weight_min, weight_max, shape=[64, 32, 1, 1], ctx=mx.cpu())
    bias_nd = mx.random.uniform(-1, +1, shape=[64], ctx=mx.cpu())
    arg_params = {'data': data_nd, 'weight': weight_nd, 'bias': bias_nd}

    ex = sym.bind(mx.cpu(), arg_params, args_grad=None)
    ex.forward()
    ex.outputs[0].wait_to_read()
    sym_sg = sym.get_backend_symbol(QUANTIZE_SG_PASS_NAME)
    batch = mx.io.DataBatch([data_nd], [])
    calib_data = CalibIter(batch, data_shape, 1)

    qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(
        sym=sym_sg,
        arg_params={'weight': weight_nd, 'bias': bias_nd},
        aux_params={},
        ctx=mx.cpu(),
        excluded_sym_names=None,
        excluded_op_names=None,
        quantized_dtype='int8',
        calib_mode='naive',
        calib_data=calib_data,
        label_names=None,
        num_calib_examples=1,
        quantize_mode='full')
    qsym = qsym.get_backend_symbol(QUANTIZE_SG_PASS_NAME)
    qarg_params['data'] = data_nd
    qex = qsym.bind(mx.cpu(), qarg_params, args_grad=None)
    qex.forward()
    qex.outputs[0].wait_to_read()
    assert_almost_equal_with_err(ex.outputs[0].asnumpy(), qex.outputs[0].asnumpy(),
                                 rtol=1e-2, atol=1e-2, etol=0.01)

def test_quantized_conv_bias_overflow(data_min, data_max, weight_min, weight_max):
    data_shape = (1, 32, 2, 2)
    data = mx.symbol.Variable('data', shape=data_shape, dtype='float32')
    weight = mx.symbol.Variable('weight', dtype='float32')
    bias = mx.symbol.Variable('bias', dtype='float32')
    sym = mx.symbol.Convolution(data=data, weight=weight, bias=bias, name='conv',
                                num_filter=64, kernel=(1, 1), stride=(1, 1))
    data_nd = mx.random.uniform(data_min, data_max, shape=data_shape, ctx=mx.cpu())
    weight_nd = mx.random.uniform(weight_min, weight_max, shape=[64, 32, 1, 1], ctx=mx.cpu())
    bias_nd = mx.random.uniform(-1, +1, shape=[64], ctx=mx.cpu())
    arg_params = {'weight': weight_nd, 'bias': bias_nd}

    ex = sym._bind(mx.cpu(), arg_params, args_grad=None)
    ex.forward(data=data_nd)
    ex.outputs[0].wait_to_read()
    sym_sg = sym.optimize_for(QUANTIZE_SG_PASS_NAME, dedup_subgraph=True, skip_infer=True)

    calib_data = mx.gluon.data.DataLoader(data_nd, batch_size=data_shape[0])
    qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(
        sym=sym_sg,
        arg_params=arg_params,
        aux_params={},
        ctx=mx.cpu(),
        excluded_sym_names=None,
        excluded_op_names=None,
        quantized_dtype='int8',
        calib_mode='naive',
        calib_data=calib_data,
        num_calib_batches=1,
        quantize_mode='full')
    qsym = qsym.optimize_for(QUANTIZE_SG_PASS_NAME, dedup_subgraph=True, skip_infer=True)
    qarg_params['data'] = data_nd
    qex = qsym._bind(mx.cpu(), qarg_params, args_grad=None)
    qex.forward()
    qex.outputs[0].wait_to_read()
    assert_almost_equal_with_err(ex.outputs[0].asnumpy(), qex.outputs[0].asnumpy(),
                                 rtol=1e-2, atol=1e-2, etol=0.01)

def test_batch_dot(batch_size, seq_length, units, num_heads):
    class BatchDotBlock(nn.HybridBlock):
        def __init__(self, **kwargs):
            super(BatchDotBlock, self).__init__(**kwargs)

        def forward(self, lhs, rhs):
            x = mx.npx.batch_dot(lhs, rhs)
            return x

    lhs_data = mx.np.random.uniform(low=-1, high=1,
                                    size=[batch_size, units, seq_length], dtype='float32')
    rhs_data = mx.np.random.uniform(low=-1, high=1,
                                    size=[batch_size, seq_length, seq_length], dtype='float32')

    net = BatchDotBlock()
    net.initialize()
    fused_net = net
    net.hybridize()
    ref_out = net(lhs_data, rhs_data)

    fused_net.optimize_for(lhs_data, rhs_data, backend="ONEDNN")
    out = fused_net(lhs_data, rhs_data)
    mx.nd.waitall()
    for i in range(len(out)):
        assert_almost_equal(out[i].asnumpy(), ref_out[i].asnumpy())

    calib_data = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(lhs_data, rhs_data),
                                          batch_size=1)
    qnet = mx.contrib.quant.quantize_net(net, quantized_dtype='auto',
                                         exclude_layers=None,
                                         exclude_layers_match=None,
                                         calib_data=calib_data,
                                         calib_mode='naive',
                                         num_calib_batches=1,
                                         ctx=mx.cpu())
    qout = qnet(lhs_data, rhs_data)
    mx.nd.waitall()
    for i in range(len(ref_out)):
        min_range = np.min(ref_out[i].asnumpy())
        max_range = np.max(ref_out[i].asnumpy())
        atol = 0.1 * max(abs(min_range), abs(max_range))
        assert_almost_equal_with_err(qout[i].asnumpy(), ref_out[i].asnumpy(),
                                     rtol=0.1, atol=atol, etol=0.1)

def check_quantize(net_original, data_shape, out_type, name='conv',
                   check_calibration=True, check_scale_align=False):
    quantize_granularity_list = ['tensor-wise']
    if name == 'fc':
        quantize_granularity_list += ['channel-wise']

    if name in config:
        name = config[name][OP_NAME]

    net_original.initialize(init=mx.init.Normal(0.5), force_reinit=True)
    min_value = -1 if out_type != 'uint8' else 0
    data = mx.np.random.uniform(min_value, 1.0, size=data_shape,
                                dtype='float32', ctx=mx.current_device())

    outputs = net_original(data)
    for output in outputs:
        output.wait_to_read()
    ref_out = outputs

    calib_data = mx.gluon.data.DataLoader(data, batch_size=1)

    for quantize_granularity in quantize_granularity_list:
        qnet = quantization.quantize_net(net_original,
                                         ctx=mx.current_device(),
                                         exclude_layers=None,
                                         exclude_operators=None,
                                         quantized_dtype=out_type,
                                         calib_mode='naive',
                                         calib_data=calib_data,
                                         num_calib_batches=1,
                                         quantize_mode='full',
                                         quantize_granularity=quantize_granularity)
        qsym, _ = qnet.export(None)
        if check_calibration:
            check_qsym_calibrated(qsym, out_type, name=name)
        if check_scale_align:
            check_qsym_scale_align(qsym)

        quantized_out = qnet(data)
        for i in range(len(ref_out)):
            min_range = mx.np.min(ref_out[i]).item()
            max_range = mx.np.max(ref_out[i]).item()
            atol = 0.1 * max(abs(min_range), abs(max_range))
            assert_almost_equal_with_err(quantized_out.asnumpy(), ref_out.asnumpy(),
                                         rtol=0.1, atol=atol, etol=0.2)

def test_quantize_whole_model_with_forward(qdtype):
    batch_size = 4
    data_shape = (batch_size, 4, 10, 10)
    data = mx.sym.Variable('data')
    conv0 = mx.sym.Convolution(data, kernel=(1, 1), num_filter=16, name='conv0')
    sym = mx.sym.Convolution(conv0, kernel=(1, 1), num_filter=16, name='conv1')

    sym_block = mx.gluon.SymbolBlock(outputs=sym, inputs=data)
    initialize_block_params(sym_block, mx.init.Normal(0.5))

    in_data = mx.random.uniform(0.0 if qdtype == 'uint8' else -1.0, 1.0, shape=data_shape)
    ref_out = sym_block(in_data)

    excluded_layers = []

    calib_data = mx.nd.random.uniform(0.0 if qdtype == 'uint8' else -1.0, 1.0, shape=data_shape)
    calib_data = mx.gluon.data.DataLoader(calib_data, batch_size=batch_size)
    qsym = mx.contrib.quantization.quantize_net(sym_block,
                                                ctx=mx.current_context(),
                                                exclude_layers=excluded_layers,
                                                quantized_dtype=qdtype,
                                                calib_mode='naive',
                                                calib_data=calib_data,
                                                num_calib_batches=1,
                                                quantize_mode='full')
    outputs = qsym(in_data)
    for output in outputs:
        output.wait_to_read()

    for i in range(len(ref_out)):
        min_range = mx.nd.min(ref_out[i]).asscalar()
        max_range = mx.nd.max(ref_out[i]).asscalar()
        atol = 0.1 * max(abs(min_range), abs(max_range))
        assert_almost_equal_with_err(outputs[i].asnumpy(), ref_out[i].asnumpy(),
                                     rtol=0.1, atol=atol, etol=0.2)

def test_self_attention(batch_size, seq_length, units, num_heads):
    net = MultiHeadAttention(units, num_heads)
    in_data = mx.np.random.uniform(size=[batch_size, seq_length, units], dtype='float32')
    mask = mx.np.random.uniform(low=0, high=2,
                                size=[batch_size, seq_length, seq_length], dtype='int32')

    net.initialize()
    fused_net = net
    net.hybridize()
    ref_out = net(in_data, mask)

    fused_net.optimize_for(in_data, mask, backend="ONEDNN")
    out = fused_net(in_data, mask)
    mx.nd.waitall()
    for i in range(len(out)):
        assert_almost_equal(out[i].asnumpy(), ref_out[i].asnumpy())

    calib_data = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(in_data, mask),
                                          batch_size=1)
    qnet = mx.contrib.quant.quantize_net(net, quantized_dtype='auto',
                                         exclude_layers=None,
                                         exclude_layers_match=None,
                                         calib_data=calib_data,
                                         calib_mode='naive',
                                         num_calib_batches=1,
                                         ctx=mx.cpu())
    qout = qnet(in_data, mask)
    mx.nd.waitall()
    for i in range(len(ref_out)):
        min_range = np.min(ref_out[i].asnumpy())
        max_range = np.max(ref_out[i].asnumpy())
        atol = 0.1 * max(abs(min_range), abs(max_range))
        assert_almost_equal_with_err(qout[i].asnumpy(), ref_out[i].asnumpy(),
                                     rtol=0.1, atol=atol, etol=0.2)

def check_operator_accuracy(sym_fp32, sym_bf16, data_shape, num_input_data=1,
                            bf16_use_fp32_params=False, rtol=1e-1, atol=5e-1, etol=0):
    """Check accuracy of bfloat16 operators against their fp32 counterparts.

    sym_fp32: Symbol
        fp32 operator
    sym_bf16: Symbol
        bf16 operator
    data_shape: tuple of int
        input data shape for the fp32/bf16 symbol
    num_input_data: int
        number of inputs; defaults to 1. Set a different value for operators
        with multiple inputs, such as concat, elemwise_add, etc.
    bf16_use_fp32_params: bool
        currently only BatchNorm sets this to True, since bf16 BatchNorm only
        accepts bf16 data together with fp32 mean/var/scale/shift
    rtol: float
        relative tolerance
    atol: float
        absolute tolerance
    etol: float
        error-rate tolerance; allows a small fraction of values to differ
        between bf16 and fp32
    """
    if not isinstance(data_shape, tuple):
        data_shape = tuple(data_shape)
    data_range = (0.0, 10.0)
    data_list_fp32 = list()
    data_list_bf16 = list()
    for i in range(num_input_data):
        data_list_fp32.append(mx.nd.random.uniform(low=data_range[0], high=data_range[1],
                                                   shape=data_shape))
        data_list_bf16.append(mx.nd.amp_cast(data_list_fp32[i], dtype=bfloat16))

    arg_shapes, _, aux_shapes = sym_fp32.infer_shape(data=data_shape)
    arg_names = sym_fp32.list_arguments()
    aux_names = sym_fp32.list_auxiliary_states()

    exe_fp32 = sym_fp32.simple_bind(ctx=mx.cpu(), data=data_shape)

    arg_params_fp32 = {}
    aux_params_fp32 = {}
    type_dict = {}
    for i, arg_name in enumerate(arg_names):
        if i < num_input_data:
            exe_fp32.arg_dict[arg_name][:] = data_list_fp32[i]
            continue
        arg_params_fp32[arg_name] = mx.nd.random.uniform(low=data_range[0], high=data_range[1],
                                                         shape=arg_shapes[i])
        exe_fp32.arg_dict[arg_name][:] = arg_params_fp32[arg_name]
        # specify the dtype of arguments
        if not bf16_use_fp32_params:
            type_dict.update({arg_name: bfloat16})

    for i, aux_name in enumerate(aux_names):
        aux_params_fp32[aux_name] = mx.nd.random.uniform(low=data_range[0], high=data_range[1],
                                                         shape=aux_shapes[i])
        exe_fp32.aux_dict[aux_name][:] = aux_params_fp32[aux_name]

    output_fp32 = exe_fp32.forward()[0]

    exe_bf16 = sym_bf16.simple_bind(ctx=mx.cpu(), data=data_shape, type_dict=type_dict)

    arg_params_bf16 = {}
    aux_params_bf16 = {}
    for i, arg_name in enumerate(arg_names):
        if i < num_input_data:
            exe_bf16.arg_dict[arg_name][:] = data_list_bf16[i]
            continue
        if bf16_use_fp32_params:
            exe_bf16.arg_dict[arg_name][:] = arg_params_fp32[arg_name]
        else:
            exe_bf16.arg_dict[arg_name][:] = mx.nd.amp_cast(arg_params_fp32[arg_name],
                                                            dtype=bfloat16)

    for aux_name in aux_names:
        if bf16_use_fp32_params:
            exe_bf16.aux_dict[aux_name][:] = aux_params_fp32[aux_name]
        else:
            exe_bf16.aux_dict[aux_name][:] = mx.nd.amp_cast(aux_params_fp32[aux_name],
                                                            dtype=bfloat16)

    output_bf16 = exe_bf16.forward()[0]
    output_bf16.wait_to_read()
    output_bf16_2_fp32 = mx.nd.amp_cast(output_bf16, dtype="float32")
    assert_almost_equal_with_err(output_bf16_2_fp32, output_fp32,
                                 rtol=rtol, atol=atol, etol=etol)

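# Hedged usage sketch (not part of the original suite): the operator choice and
# shape below are illustrative assumptions, shown only to make the expected
# fp32/bf16 symbol pairing of check_operator_accuracy concrete.
def example_check_square_bf16():
    data_sym_fp32 = mx.sym.Variable(name='data')
    data_sym_bf16 = mx.sym.Variable(name='data', dtype=bfloat16)
    sym_fp32 = mx.sym.square(data_sym_fp32)
    sym_bf16 = mx.sym.square(data_sym_bf16)
    # compare the bf16 output against the fp32 reference with the default tolerances
    check_operator_accuracy(sym_fp32, sym_bf16, data_shape=(4, 16), num_input_data=1)
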
def check_quantize(net_original, data_shapes, out_type, name='conv',
                   check_calibration=True, check_scale_align=False,
                   quantize_mode='full', attrs_dict={}):
    quantize_granularity_list = ['tensor-wise']
    if name == 'fc':
        quantize_granularity_list += ['channel-wise']

    if name in config:
        name = config[name][OP_NAME]

    net_original.initialize(init=mx.init.Normal(0.5), force_reinit=True)
    min_value = -1 if out_type != 'uint8' else 0

    one_shape = isinstance(data_shapes, tuple)
    if one_shape:
        # wrap a single shape in a one-element list so the code below follows
        # the same schema for one and multiple inputs
        data_shapes = [data_shapes]
    data = []
    for shape in data_shapes:
        data.append(mx.np.random.uniform(min_value, 1.0, size=shape,
                                         dtype='float32', device=mx.cpu()))

    outputs = net_original(*data)
    for output in outputs:
        output.wait_to_read()
    ref_out = outputs

    one_output = not isinstance(ref_out, list)
    if one_output:
        # make a list to have a common path for one and multiple outputs
        ref_out = [ref_out]

    dataArray = mx.gluon.data.ArrayDataset(*data)
    calib_data = mx.gluon.data.DataLoader(dataArray, batch_size=1)

    for quantize_granularity in quantize_granularity_list:
        qnet = quantization.quantize_net(net_original,
                                         device=mx.cpu(),
                                         exclude_layers=None,
                                         exclude_operators=None,
                                         quantized_dtype=out_type,
                                         calib_mode='naive',
                                         calib_data=calib_data,
                                         num_calib_batches=1,
                                         quantize_mode=quantize_mode,
                                         quantize_granularity=quantize_granularity)
        qsym, _ = qnet.export(None)
        check_fusion_parameter(qsym, attrs_dict)
        if check_calibration:
            check_qsym_calibrated(qsym, out_type, name=name)
        if check_scale_align:
            check_qsym_scale_align(qsym)

        quantized_out = qnet(*data)
        if one_output:
            quantized_out = [quantized_out]
        for i in range(len(ref_out)):
            min_range = mx.np.min(ref_out[i]).item()
            max_range = mx.np.max(ref_out[i]).item()
            atol = 0.1 * max(abs(min_range), abs(max_range))
            assert_almost_equal_with_err(quantized_out[i].asnumpy(), ref_out[i].asnumpy(),
                                         rtol=0.1, atol=atol, etol=0.2)

def check_quantize(net_original, data_shapes, out_type, name='conv',
                   check_calibration=True, check_scale_align=False,
                   quantize_mode='full', attrs_dict={}, calib_mode='naive',
                   check_fusion=True):
    quantize_granularity_list = ['tensor-wise']
    if name == 'fc':
        quantize_granularity_list += ['channel-wise']

    if name in config:
        name = config[name][OP_NAME]

    sigma = 0.01 if hasattr(net_original, 'alg') and net_original.alg == 'exp' else 0.5

    if out_type == 'uint8':
        # Initialize weights and tensors only with positive values to be sure
        # that results are always positive
        init = CustomNormalInit(sigma=sigma, bounded=True)
        min_value = 0
    else:
        init = mx.init.Normal(sigma)
        min_value = -1
    net_original.initialize(init=init, force_reinit=True)

    one_shape = isinstance(data_shapes, tuple)
    if one_shape:
        # wrap a single shape in a one-element list so the code below follows
        # the same schema for one and multiple inputs
        data_shapes = [data_shapes]
    data = []
    for shape in data_shapes:
        data.append(mx.np.random.uniform(min_value, 1.0, size=shape,
                                         dtype='float32', device=mx.cpu()))

    outputs = net_original(*data)
    for output in outputs:
        output.wait_to_read()
    ref_out = outputs

    one_output = not isinstance(ref_out, list)
    if one_output:
        # make a list to have a common path for one and multiple outputs
        ref_out = [ref_out]

    class TestDataLoader(mx.gluon.data.DataLoader):
        def __init__(self, data):
            self.data = data
            self.finish = False

        def __iter__(self):
            self.finish = False
            return self

        def __next__(self):
            if self.finish:
                raise StopIteration
            self.finish = True
            return self.data

        def __del__(self):
            pass

    calib_data = TestDataLoader(data)

    for quantize_granularity in quantize_granularity_list:
        qnet = quantization.quantize_net(net_original,
                                         device=mx.cpu(),
                                         exclude_layers=None,
                                         exclude_operators=None,
                                         quantized_dtype=out_type,
                                         calib_mode=calib_mode,
                                         calib_data=calib_data,
                                         num_calib_batches=1,
                                         quantize_mode=quantize_mode,
                                         quantize_granularity=quantize_granularity)
        qsym, _ = qnet.export(None)
        if check_fusion:
            check_fusion_parameter(qsym, attrs_dict)
        if check_calibration:
            check_qsym_calibrated(qsym, out_type, name=name)
        if check_scale_align:
            check_qsym_scale_align(qsym)

        quantized_out = qnet(*data)
        if one_output:
            quantized_out = [quantized_out]
        for i in range(len(ref_out)):
            min_range = mx.np.min(ref_out[i]).item()
            max_range = mx.np.max(ref_out[i]).item()
            atol = 0.1 * max(abs(min_range), abs(max_range))
            assert_almost_equal_with_err(quantized_out[i].asnumpy(), ref_out[i].asnumpy(),
                                         rtol=0.1, atol=atol, etol=0.2)

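# Hedged usage sketch (not part of the original suite): the block and shapes are
# illustrative assumptions showing how this check_quantize variant is typically
# invoked for a single-input, single-output operator.
def example_check_quantize_conv():
    class SampleConv(nn.HybridBlock):
        def __init__(self, **kwargs):
            super(SampleConv, self).__init__(**kwargs)
            self.conv = nn.Conv2D(channels=8, kernel_size=(3, 3))

        def forward(self, x):
            return self.conv(x)

    # a plain tuple is accepted for data_shapes and is wrapped into a one-element list internally
    check_quantize(SampleConv(), data_shapes=(1, 3, 16, 16), out_type='int8', name='conv')
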
def test_self_attention(batch_size, seq_length, units, num_heads):
    class MultiHeadAttention(nn.HybridBlock):
        def __init__(self, units, num_heads, dtype='float32', **kwargs):
            super(MultiHeadAttention, self).__init__(**kwargs)
            self._units = units
            self._num_heads = num_heads
            self._fc = nn.Dense(in_units=self._units, units=3*self._units,
                                flatten=False, dtype=dtype)
            self._scale = math.sqrt(self._units // self._num_heads)

        def forward(self, x, mask):
            x = mx.np.copy(x)
            out = self._fc(x)
            query, key, value = mx.np.split(out, 3, axis=-1)
            query = mx.npx.reshape(query, (-2, -2, self._num_heads, -1))
            key = mx.npx.reshape(key, (-2, -2, self._num_heads, -1))
            value = mx.npx.reshape(value, (-2, -2, self._num_heads, -1))
            scores = mx.npx.batch_dot(mx.np.swapaxes(query, 1, 2),
                                      mx.np.swapaxes(key, 1, 2),
                                      transpose_b=True)
            mask = mx.np.expand_dims(mask, axis=1).astype(np.bool)
            attn_weights = mx.npx.masked_softmax(scores, mask=mask.astype(np.bool),
                                                 axis=-1, temperature=self._scale)
            attn_weights = mx.npx.dropout(attn_weights, p=0.1)
            context_vec = mx.npx.batch_dot(attn_weights,
                                           mx.np.swapaxes(value, 1, 2)).transpose((0, 2, 1, 3))
            context_vec = mx.npx.reshape(context_vec, (-2, -2, -1))
            return context_vec

    net = MultiHeadAttention(units, num_heads)
    in_data = mx.np.random.uniform(size=[batch_size, seq_length, units], dtype='float32')
    mask = mx.np.random.uniform(low=0, high=2,
                                size=[batch_size, seq_length, seq_length], dtype='int32')

    net.initialize()
    fused_net = net
    net.hybridize()
    ref_out = net(in_data, mask)

    fused_net.optimize_for(in_data, mask, backend="MKLDNN")
    out = fused_net(in_data, mask)
    mx.nd.waitall()
    for i in range(len(out)):
        assert_almost_equal(out[i].asnumpy(), ref_out[i].asnumpy())

    calib_data = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(in_data, mask),
                                          batch_size=1)
    qnet = mx.contrib.quant.quantize_net(net, quantized_dtype='auto',
                                         exclude_layers=None,
                                         exclude_layers_match=None,
                                         calib_data=calib_data,
                                         calib_mode='naive',
                                         num_calib_batches=1,
                                         ctx=mx.cpu())
    qout = qnet(in_data, mask)
    mx.nd.waitall()
    for i in range(len(ref_out)):
        min_range = np.min(ref_out[i].asnumpy())
        max_range = np.max(ref_out[i].asnumpy())
        atol = 0.1 * max(abs(min_range), abs(max_range))
        assert_almost_equal_with_err(qout[i].asnumpy(), ref_out[i].asnumpy(),
                                     rtol=0.1, atol=atol, etol=0.2)

def check_quantize(sym, data_shape, out_type, name='conv',
                   check_calibration=True, check_scale_align=False):
    quantize_granularity_list = ['tensor-wise']
    if name == 'fc':
        quantize_granularity_list += ['channel-wise']

    if name in config:
        name = config[name][OP_NAME]
    sym_sg = sym.optimize_for(QUANTIZE_SG_PASS_NAME, dedup_subgraph=True, skip_infer=True)

    inputs = mx.sym.var('data', dtype='float32')
    sym_block = mx.gluon.SymbolBlock(sym, inputs)
    initialize_block_params(sym_block, mx.init.Normal(0.5))

    min_value = -1 if out_type != 'uint8' else 0
    data = mx.random.uniform(min_value, 1.0, shape=data_shape,
                             dtype='float32', ctx=mx.current_context())

    outputs = sym_block(data)
    for output in outputs:
        output.wait_to_read()
    ref_out = outputs

    arg_params, aux_params = collect_block_args_aux(sym_block, sym)

    excluded_sym_names = []
    excluded_op_names = []

    calib_data = mx.gluon.data.DataLoader(data, batch_size=1)

    for quantize_granularity in quantize_granularity_list:
        qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(
            sym=sym_sg,
            arg_params=arg_params,
            aux_params=aux_params,
            ctx=mx.current_context(),
            excluded_sym_names=excluded_sym_names,
            excluded_op_names=excluded_op_names,
            quantized_dtype=out_type,
            calib_mode='naive',
            calib_data=calib_data,
            num_calib_batches=1,
            quantize_mode='full',
            quantize_granularity=quantize_granularity)
        qsym = qsym.optimize_for(QUANTIZE_SG_PASS_NAME, dedup_subgraph=True, skip_infer=True)
        if check_calibration:
            check_qsym_calibrated(qsym, out_type, name=name)
        if check_scale_align:
            check_qsym_scale_align(qsym)

        quantized_out = check_qsym_gluon_forward(qsym, qarg_params, qaux_params, data)
        for i in range(len(ref_out)):
            min_range = mx.nd.min(ref_out[i]).asscalar()
            max_range = mx.nd.max(ref_out[i]).asscalar()
            atol = 0.1 * max(abs(min_range), abs(max_range))
            assert_almost_equal_with_err(quantized_out[i].asnumpy(), ref_out[i].asnumpy(),
                                         rtol=0.1, atol=atol, etol=0.2)
        check_qsym_dummy_forward(qsym, data, data_shape)