def calibration(net, dev_data_list, num_calib_batches, quantized_dtype, calib_mode):
    """Calibration function on the dev dataset."""
    assert len(dev_data_list) == 1, \
        'Currently, MNLI is not supported.'
    assert ctx == mx.cpu(), \
        'Currently only supports CPU with MKL-DNN backend.'
    logging.info('Now we are doing calibration on dev with %s.', ctx)
    for _, dev_data in dev_data_list:
        collector = BertLayerCollector(clip_min=-50, clip_max=10, logger=logging)
        num_calib_examples = dev_batch_size * num_calib_batches
        net = mx.contrib.quantization.quantize_net_v2(
            net, quantized_dtype=quantized_dtype,
            exclude_layers=[],
            quantize_mode='smart',
            quantize_granularity='channel-wise',
            calib_data=dev_data,
            calib_mode=calib_mode,
            num_calib_examples=num_calib_examples,
            ctx=ctx,
            LayerOutputCollector=collector,
            logger=logging)
    # save params
    ckpt_name = 'model_bert_{0}_quantized_{1}'.format(task_name, calib_mode)
    params_saved = os.path.join(output_dir, ckpt_name)
    net.export(params_saved, epoch=0)
    logging.info('Saving quantized model at %s', output_dir)

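# For context, a hedged sketch of how this classification-task variant might be
# invoked. The names below (`model`, the calibration settings) and the
# script-level globals the function reads (ctx, dev_batch_size, task_name,
# output_dir) are assumptions inferred from the identifiers it references.
#
#     ctx = mx.cpu()  # the quantization path above requires the MKL-DNN CPU backend
#     calibration(model,
#                 dev_data_list,          # a single (segment, DataLoader) pair
#                 num_calib_batches=10,   # hypothetical setting
#                 quantized_dtype='auto',
#                 calib_mode='customize')
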
def calibration(net, num_calib_batches, quantized_dtype, calib_mode):
    """Calibration function on the dev dataset."""
    log.info('Loading dev data...')
    if version_2:
        dev_data = SQuAD('dev', version='2.0')
    else:
        dev_data = SQuAD('dev', version='1.1')
    if args.debug:
        sampled_data = [dev_data[0], dev_data[1], dev_data[2]]
        dev_data = mx.gluon.data.SimpleDataset(sampled_data)
    log.info('Number of records in dev data: {}'.format(len(dev_data)))
    batchify_fn_calib = nlp.data.batchify.Tuple(
        nlp.data.batchify.Pad(axis=0, pad_val=vocab[vocab.padding_token],
                              round_to=args.round_to),
        nlp.data.batchify.Pad(axis=0, pad_val=vocab[vocab.padding_token],
                              round_to=args.round_to),
        nlp.data.batchify.Stack('float32'),
        nlp.data.batchify.Stack('float32'))
    dev_data_transform = preprocess_dataset(tokenizer, dev_data,
                                            max_seq_length=max_seq_length,
                                            doc_stride=doc_stride,
                                            max_query_length=max_query_length,
                                            input_features=True,
                                            for_calibration=True)
    dev_dataloader = mx.gluon.data.DataLoader(dev_data_transform,
                                              batchify_fn=batchify_fn_calib,
                                              num_workers=4,
                                              batch_size=test_batch_size,
                                              shuffle=False,
                                              last_batch='keep')
    assert ctx == mx.cpu(), \
        'Currently only supports CPU with MKL-DNN backend.'
    log.info('Now we are doing calibration on dev with %s.', ctx)
    collector = BertLayerCollector(clip_min=-50, clip_max=10, logger=log)
    num_calib_examples = test_batch_size * num_calib_batches
    net = mx.contrib.quantization.quantize_net_v2(
        net, quantized_dtype=quantized_dtype,
        exclude_layers=[],
        quantize_mode='smart',
        quantize_granularity='channel-wise',
        calib_data=dev_dataloader,
        calib_mode=calib_mode,
        num_calib_examples=num_calib_examples,
        ctx=ctx,
        LayerOutputCollector=collector,
        logger=log)
    # save params
    ckpt_name = 'model_bert_squad_quantized_{0}'.format(calib_mode)
    params_saved = os.path.join(output_dir, ckpt_name)
    net.export(params_saved, epoch=0)
    log.info('Saving quantized model at %s', output_dir)

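# A minimal, self-contained illustration of how the Tuple/Pad/Stack batchify
# functions used above assemble a batch (toy data, not the real SQuAD
# features): Pad right-pads variable-length sequences to a common length,
# Stack stacks per-sample scalars into one array.
def _batchify_demo():
    import mxnet as mx
    import gluonnlp as nlp
    batchify_fn = nlp.data.batchify.Tuple(
        nlp.data.batchify.Pad(axis=0, pad_val=0),   # pad token ids to max length
        nlp.data.batchify.Stack('float32'))         # stack valid lengths
    samples = [([1, 2, 3], 3.0), ([4, 5], 2.0)]
    inputs, lengths = batchify_fn(samples)
    # inputs is a (2, 3) NDArray with the second row zero-padded;
    # lengths is a (2,) float32 NDArray [3., 2.]
    return inputs, lengths
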
def calibration(net, dev_data, num_calib_batches, quantized_dtype, calib_mode):
    """Calibration function on the dev dataset."""
    print('Now we are doing calibration on dev with cpu.')
    collector = BertLayerCollector(clip_min=-50, clip_max=10, logger=None)
    num_calib_examples = dev_batch_size * num_calib_batches
    quantized_net = mx.contrib.quantization.quantize_net_v2(
        net, quantized_dtype=quantized_dtype,
        exclude_layers=[],
        quantize_mode='smart',
        quantize_granularity='channel-wise',
        calib_data=dev_data,
        calib_mode=calib_mode,
        num_calib_examples=num_calib_examples,
        ctx=mx.cpu(),
        LayerOutputCollector=collector,
        logger=None)
    print('Calibration completed successfully.')
    return quantized_net

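# A hedged usage sketch for this variant, which returns the quantized network
# instead of exporting it; `model`, `dev_dataloader`, and the export prefix
# are hypothetical names, and saving is left to the caller as shown.
#
#     quantized_net = calibration(model, dev_dataloader,
#                                 num_calib_batches=10,
#                                 quantized_dtype='auto',
#                                 calib_mode='customize')
#     quantized_net.export('model_bert_quantized', epoch=0)
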
def calibration(net, num_calib_batches, quantized_dtype, calib_mode):
    """Calibration function on the dev dataset."""
    log.info('Loading dev data...')
    if version_2:
        dev_data = SQuAD('dev', version='2.0')
    else:
        dev_data = SQuAD('dev', version='1.1')
    if args.debug:
        sampled_data = [dev_data[0], dev_data[1], dev_data[2]]
        dev_data = mx.gluon.data.SimpleDataset(sampled_data)
    log.info('Number of records in dev data: {}'.format(len(dev_data)))
    origin_dev_data_len = len(dev_data)
    num_calib_examples = test_batch_size * num_calib_batches
    # Randomly select the calibration samples from the full dataset
    # (np.random.choice samples with replacement by default).
    random_indices = np.random.choice(origin_dev_data_len, num_calib_examples)
    print('random_indices:', random_indices)
    dev_data = list(dev_data[i] for i in random_indices)
    log.info('Number of records in sampled calibration data: {}'.format(len(dev_data)))
    batchify_fn_calib = nlp.data.batchify.Tuple(
        nlp.data.batchify.Pad(axis=0, pad_val=vocab[vocab.padding_token],
                              round_to=args.round_to),
        nlp.data.batchify.Pad(axis=0, pad_val=vocab[vocab.padding_token],
                              round_to=args.round_to),
        nlp.data.batchify.Stack('float32'),
        nlp.data.batchify.Stack('float32'))
    dev_data_transform = preprocess_dataset(tokenizer, dev_data,
                                            max_seq_length=max_seq_length,
                                            doc_stride=doc_stride,
                                            max_query_length=max_query_length,
                                            input_features=True,
                                            for_calibration=True)
    dev_dataloader = mx.gluon.data.DataLoader(dev_data_transform,
                                              batchify_fn=batchify_fn_calib,
                                              num_workers=4,
                                              batch_size=test_batch_size,
                                              shuffle=True,
                                              last_batch='keep')
    net = run_pass(net, 'custom_pass')
    assert ctx == mx.cpu(), \
        'Currently only supports CPU with MKL-DNN backend.'
    log.info('Now we are doing calibration on dev with %s.', ctx)
    collector = BertLayerCollector(clip_min=-50, clip_max=10, logger=log)
    net = mx.contrib.quantization.quantize_net_v2(
        net, quantized_dtype=quantized_dtype,
        exclude_layers=[],
        quantize_mode='smart',
        quantize_granularity='tensor-wise',
        calib_data=dev_dataloader,
        calib_mode=calib_mode,
        num_calib_examples=num_calib_examples,
        ctx=ctx,
        LayerOutputCollector=collector,
        logger=log)
    if scenario == "offline":
        net = run_pass(net, 'softmax_mask')
    else:
        net = run_pass(net, 'normal_softmax')
    net = run_pass(net, 'bias_to_s32')
    # save params
    ckpt_name = 'model_bert_squad_quantized_{0}'.format(calib_mode)
    params_saved = os.path.join(output_dir, ckpt_name)
    net.hybridize(static_alloc=True, static_shape=True)
    # Run a dummy forward pass so shapes are inferred and the graph is
    # fully built before export.
    a = mx.nd.ones((test_batch_size, max_seq_length), dtype='float32')
    b = mx.nd.ones((test_batch_size, max_seq_length), dtype='float32')
    c = mx.nd.ones((test_batch_size,), dtype='float32')
    net(a, b, c)
    mx.nd.waitall()
    net.export(params_saved, epoch=0)
    log.info('Saving quantized model at %s', output_dir)

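# A sketch of reloading the exported quantized model for inference. The file
# names follow net.export()'s convention (<prefix>-symbol.json and
# <prefix>-0000.params); the 'data0'/'data1'/'data2' input names are an
# assumption matching the three-input dummy forward pass above -- check the
# exported symbol file if they differ.
def load_quantized_model(prefix, batch_size, seq_length):
    import mxnet as mx
    net = mx.gluon.SymbolBlock.imports(
        prefix + '-symbol.json',
        ['data0', 'data1', 'data2'],   # assumed input names
        prefix + '-0000.params',
        ctx=mx.cpu())  # quantized ops run on the MKL-DNN CPU backend
    # Warm up with dummy inputs shaped like the calibration batch.
    inputs = mx.nd.ones((batch_size, seq_length))
    token_types = mx.nd.ones((batch_size, seq_length))
    valid_length = mx.nd.ones((batch_size,))
    net(inputs, token_types, valid_length)
    mx.nd.waitall()
    return net
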