def check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape):
    """Bind ``qsym`` for inference, run ``batch`` through it and return the outputs.

    The Module is bound on ``mx.current_context()`` with a single input named
    ``'data'`` of shape ``data_shape`` and no label inputs.

    Returns
    -------
    list
        The result of ``get_outputs()`` after every output has been waited on.
    """
    input_desc = [('data', data_shape)]
    runner = Module(symbol=qsym, label_names=None, context=mx.current_context())
    runner.bind(for_training=False, data_shapes=input_desc)
    runner.set_params(qarg_params, qaux_params)
    runner.forward(batch, is_train=False)
    # Block until every output NDArray is ready to read.
    for nd_out in runner.get_outputs():
        nd_out.wait_to_read()
    return runner.get_outputs()
def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape):
    """Smoke-test ``qsym`` with one inference forward pass over random inputs.

    Every data input declared by the bound Module is filled with values drawn
    uniformly between -1.0 and 1.0. Nothing is returned; the call simply
    blocks until all outputs are ready.
    """
    runner = Module(symbol=qsym, label_names=None, context=mx.current_context())
    runner.bind(for_training=False, data_shapes=[('data', data_shape)])
    runner.set_params(qarg_params, qaux_params)
    random_inputs = []
    for _unused_name, input_shape in runner.data_shapes:
        random_inputs.append(mx.random.uniform(-1.0, 1.0, shape=input_shape))
    runner.forward(mx.io.DataBatch(random_inputs, []), is_train=False)
    outputs = runner.get_outputs()
    for nd_out in outputs:
        nd_out.wait_to_read()
class Solver(object):
    """Train a single student Module, optionally distilling from teacher modules.

    The student Module is created in ``__init__``; ``fit`` binds it with the
    shapes given here, trains it epoch by epoch and, when ``eval_data`` is
    given, scores it after every epoch.
    """

    def __init__(self, symbol, data_names, label_names, data_shapes,
                 label_shapes, logger=logging, context=mx.cpu(),
                 work_load_list=None, fixed_param_names=None):
        """Store the training configuration and create the student Module.

        Passing ``logger=None`` selects the root logger and sets it to INFO.
        """
        self.symbol = symbol
        self.data_names = data_names
        self.label_names = label_names
        self.data_shapes = data_shapes
        self.label_shapes = label_shapes
        self.context = context
        self.work_load_list = work_load_list
        self.fixed_param_names = fixed_param_names
        if logger is None:
            logger = logging.getLogger()
            logger.setLevel(logging.INFO)
        self.logger = logger
        self.module = Module(symbol=self.symbol,
                             data_names=self.data_names,
                             label_names=self.label_names,
                             logger=self.logger,
                             context=self.context,
                             work_load_list=self.work_load_list,
                             fixed_param_names=self.fixed_param_names)

    def fit(self, train_data, eval_data=None, eval_metric='acc',
            validate_metric=None, work_load_list=None,
            epoch_end_callback=None, batch_end_callback=None,
            fixed_param_prefix=None, initializer=None, arg_params=None,
            aux_params=None, allow_missing=False, optimizer=None,
            optimizer_params=None, begin_epoch=0, num_epoch=None,
            kvstore='device', teacher_modules=None):
        """Train the student module for epochs ``begin_epoch .. num_epoch - 1``.

        Parameters
        ----------
        train_data : DataIter
            Iterated with ``iter``/``next`` each epoch and ``reset()`` at the
            end of every epoch.
        eval_data : DataIter or None
            If truthy, scored with ``validate_metric`` after every epoch.
        eval_metric : str or EvalMetric
            Training metric; created with ``metric.create`` if not an
            ``EvalMetric`` instance.
        validate_metric : str or EvalMetric or None
            Validation metric; defaults to ``eval_metric`` as passed in.
        work_load_list, fixed_param_prefix
            Accepted for signature compatibility; not used by this method.
        num_epoch : int
            Exclusive upper bound of the epoch range. Required.
        kvstore : str or KVStore
            Forwarded to ``init_optimizer``.
        teacher_modules : None, a module, or a list/tuple of modules
            When the first entry is not None, each teacher is run on the
            batch and its outputs are appended to ``data_batch.label`` (the
            batch object is modified in place) before the student forward
            pass. An empty list disables distillation.

        Raises
        ------
        TypeError
            If ``num_epoch`` is None (checked before the module is bound).
        """
        if num_epoch is None:
            raise TypeError('fit() requires num_epoch to be specified')

        # Normalise to a list: sequences are copied, anything else is wrapped.
        if isinstance(teacher_modules, (list, tuple)):
            teacher_modules = list(teacher_modules)
        else:
            teacher_modules = [teacher_modules]
        # Distillation runs only when the first teacher is not None; the
        # length check makes an empty list mean "no teachers" instead of
        # raising IndexError.
        use_teachers = len(teacher_modules) > 0 and teacher_modules[0] is not None

        self.module.bind(data_shapes=self.data_shapes,
                         label_shapes=self.label_shapes, for_training=True)
        self.module.init_params(initializer=initializer, arg_params=arg_params,
                                aux_params=aux_params,
                                allow_missing=allow_missing)
        self.module.init_optimizer(kvstore=kvstore, optimizer=optimizer,
                                   optimizer_params=optimizer_params)
        # validate_metric is taken before eval_metric is converted, so it may
        # still be a metric name here.
        if validate_metric is None:
            validate_metric = eval_metric
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)

        # training loop
        for epoch in range(begin_epoch, num_epoch):
            tic = time.time()
            eval_metric.reset()
            nbatch = 0
            data_iter = iter(train_data)
            end_of_batch = False
            next_data_batch = next(data_iter)
            while not end_of_batch:
                data_batch = next_data_batch
                if use_teachers:
                    for teacher_module in teacher_modules:
                        # NOTE(review): teachers are run with is_train=True;
                        # confirm this is intended for the teacher networks.
                        teacher_module.forward(data_batch=data_batch,
                                               is_train=True)
                        transfer_label = teacher_module.get_outputs()
                        data_batch.label = data_batch.label + transfer_label
                self.module.forward(data_batch, is_train=True)
                self.module.backward()
                self.module.update()
                # Prefetch the next batch; the current one is still used for
                # the metric update below.
                try:
                    next_data_batch = next(data_iter)
                except StopIteration:
                    end_of_batch = True
                self.module.update_metric(eval_metric, data_batch.label)
                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch,
                                                     nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)
                nbatch += 1

            for name, val in eval_metric.get_name_value():
                self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
            toc = time.time()
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

            arg_params, aux_params = self.module.get_params()
            self.module.set_params(arg_params, aux_params)
            if epoch_end_callback is not None:
                for callback in _as_list(epoch_end_callback):
                    callback(epoch, self.symbol, arg_params, aux_params)

            if eval_data:
                res = self.module.score(eval_data, validate_metric,
                                        score_end_callback=None,
                                        batch_end_callback=None, reset=True,
                                        epoch=epoch)
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch,
                                     name, val)
            train_data.reset()
def quantize_model(sym, arg_params, aux_params,
                   data_names=('data', ), label_names=('softmax_label', ),
                   ctx=cpu(), excluded_sym_names=None, calib_mode='entropy',
                   calib_data=None, num_calib_examples=None, calib_layer=None,
                   quantized_dtype='int8', logger=logging):
    """User-level API for generating a quantized model from a FP32 model w/ or w/o calibration.
    The backend quantized operators are only enabled for Linux systems. Please do not run
    inference using the quantized models on Windows for now.
    The quantization implementation adopts the TensorFlow's approach:
    https://www.tensorflow.org/performance/quantization.
    The calibration implementation borrows the idea of Nvidia's 8-bit Inference with TensorRT:
    http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf
    and adapts the method to MXNet.

    Parameters
    ----------
    sym : str or Symbol
        Defines the structure of a neural network for FP32 data types.
    arg_params : dict
        Dictionary of name to `NDArray`.
    aux_params : dict
        Dictionary of name to `NDArray`.
    data_names : a list of strs
        Data names required for creating a Module object to run forward propagation on the
        calibration dataset.
    label_names : a list of strs
        Label names required for creating a Module object to run forward propagation on the
        calibration dataset.
    ctx : Context
        Defines the device that users want to run forward propagation on the calibration
        dataset for collecting layer output statistics. Currently, only supports single context.
    excluded_sym_names : list or tuple of strings
        Names of the symbols that users want to exclude from being quantized.
    calib_mode : str or None
        If calib_mode='none' (or None), no calibration will be used and the thresholds for
        requantization after the corresponding layers will be calculated at runtime by
        calling min and max operators. The quantized models generated in this
        mode are normally 10-20% slower than those with calibrations during inference.
        If calib_mode='naive', the min and max values of the layer outputs from a calibration
        dataset will be directly taken as the thresholds for quantization.
        If calib_mode='entropy' (default mode), the thresholds for quantization will be
        derived such that the KL divergence between the distributions of FP32 layer outputs and
        quantized layer outputs is minimized based upon the calibration dataset.
    calib_data : DataIter
        A data iterator initialized by the calibration dataset.
    num_calib_examples : int or None
        The maximum number of examples that user would like to use for calibration. If not provided,
        the whole calibration dataset will be used.
    calib_layer : function
        Given a layer's output name in string, return True or False for deciding whether to
        calibrate this layer. If yes, the statistics of the layer's output will be collected;
        otherwise, no information of the layer's output will be collected. If not provided,
        all the layers' outputs that need requantization will be collected.
    quantized_dtype : str
        The quantized destination type for input data. Currently supports 'int8', 'uint8'
        and 'auto'. 'auto' means automatically select output type according to calibration
        result. Default value is 'int8'.
    logger : Object
        A logging object for printing information during the process of quantization.

    Returns
    -------
    tuple
        A tuple of quantized symbol, quantized arg_params, and aux_params
        (aux_params is returned unchanged).

    Raises
    ------
    ValueError
        If excluded_sym_names is not a list or tuple, quantized_dtype or
        calib_mode is unknown, or calibration is requested with a ctx that is
        not a single Context or with calib_data that is missing or not a
        DataIter.
    """
    if excluded_sym_names is None:
        excluded_sym_names = []
    elif isinstance(excluded_sym_names, tuple):
        # Tuples are accepted too; downstream code still receives a list.
        excluded_sym_names = list(excluded_sym_names)
    if not isinstance(excluded_sym_names, list):
        raise ValueError(
            'excluded_sym_names must be a list of strings representing'
            ' the names of the symbols that will not be quantized,'
            ' while received type %s' % str(type(excluded_sym_names)))
    if quantized_dtype not in ('int8', 'uint8', 'auto'):
        raise ValueError('unknown quantized_dtype %s received,'
                         ' expected `int8`, `uint8` or `auto`' % quantized_dtype)
    # Validate the mode before any Module is built, so an unknown mode is
    # reported as such instead of as a calib_data/ctx problem.
    if calib_mode not in (None, 'none', 'naive', 'entropy'):
        raise ValueError('unknown calibration mode %s received,'
                         ' expected `none`, `naive`, or `entropy`' % calib_mode)

    logger.info('Quantizing symbol')
    qsym = _quantize_symbol(sym, excluded_symbols=excluded_sym_names,
                            offline_params=list(arg_params.keys()),
                            quantized_dtype=quantized_dtype)

    th_dict = {}
    if calib_mode is not None and calib_mode != 'none':
        if not isinstance(ctx, Context):
            raise ValueError(
                'currently only supports single ctx, while received %s' % str(ctx))
        if calib_data is None:
            raise ValueError('calib_data must be provided when calib_mode=%s' % calib_mode)
        if not isinstance(calib_data, DataIter):
            raise ValueError(
                'calib_data must be of DataIter type when calib_mode=%s,'
                ' while received type %s' % (calib_mode, str(type(calib_data))))

        mod = Module(symbol=sym, data_names=data_names,
                     label_names=label_names, context=ctx)
        # Truthiness covers both an empty list and None for label-less iterators.
        if calib_data.provide_label:
            mod.bind(for_training=False, data_shapes=calib_data.provide_data,
                     label_shapes=calib_data.provide_label)
        else:
            mod.bind(for_training=False, data_shapes=calib_data.provide_data)
        mod.set_params(arg_params, aux_params)

        if calib_mode == 'entropy':
            nd_dict, num_examples = _collect_layer_outputs(
                mod, calib_data, include_layer=calib_layer,
                max_num_examples=num_calib_examples, logger=logger)
            logger.info(
                'Collected layer outputs from FP32 model using %d examples' % num_examples)
            logger.info('Calculating optimal thresholds for quantization')
            th_dict = _get_optimal_thresholds(nd_dict, quantized_dtype, logger=logger)
        else:  # calib_mode == 'naive' (validated above)
            th_dict, num_examples = _collect_layer_output_min_max(
                mod, calib_data, include_layer=calib_layer,
                max_num_examples=num_calib_examples, logger=logger)
            logger.info(
                'Collected layer output min/max values from FP32 model using %d examples'
                % num_examples)

        logger.info('Calibrating quantized symbol')
        qsym = _calibrate_quantized_sym(qsym, th_dict)

    logger.info('Quantizing parameters')
    qarg_params = _quantize_params(qsym, arg_params, th_dict)

    return qsym, qarg_params, aux_params
# Run `convnet_code_sym` in inference mode over a RecordIO image dataset and
# stack the flattened first output of each batch into `convnet_codes`,
# stopping once at least args.max_num_images images have been processed.
data = mx.io.ImageRecordIter(path_imgrec=dataset_path, label_width=1, preprocess_threads=data_nthreads, batch_size=batch_size, data_shape=data_shape, label_name=label_name, rand_crop=False, rand_mirror=False, shuffle=shuffle_dataset, shuffle_chunk_seed=shuffle_seed, seed=shuffle_seed, **mean_args)
mod = Module(symbol=convnet_code_sym, label_names=None, context=ctx)
mod.bind(for_training=False, data_shapes=data.provide_data)
mod.set_params(arg_params, aux_params)
num_images = 0
convnet_codes = None # one row per image; original note says N * 1000 (presumably a 1000-wide output — confirm)
resized_images = None # NCHW; not filled in within this fragment
labels = None # not filled in within this fragment
for batch in data:
    # Checked before each batch, so the final count can exceed
    # args.max_num_images by up to batch_size - 1.
    if num_images >= args.max_num_images:
        break
    mod.forward(data_batch=batch, is_train=False)
    # Flatten the first output and copy it to CPU memory for accumulation.
    fc_output = mod.get_outputs()[0].flatten().copyto(mx.cpu(0))
    # Counts a full batch every time. NOTE(review): if the iterator pads its
    # last batch, the padded rows are also counted and concatenated — confirm.
    num_images += batch_size
    # Block until the copied output is ready before concatenating it.
    fc_output.wait_to_read()
    if convnet_codes is None:
        convnet_codes = fc_output
    else:
        # Each concat builds a new array holding every row so far; collecting
        # outputs in a list and concatenating once would avoid the re-copying.
        convnet_codes = mx.nd.concat(*[convnet_codes, fc_output], dim=0)
class Solver(object):
    """Train a Module while collecting per-batch timing statistics.

    Validation normally scores the training Module directly. When
    ``config.network == 'mobilenet_int8_foldbn'``, a separate inference
    Module is built every epoch from the ``.params`` file expected for that
    epoch and scored instead.
    """

    def __init__(
            self,
            symbol,
            data_names,
            label_names,
            data_shapes,
            label_shapes,
            logger=logging,
            context=mx.cpu(),
            work_load_list=None,
            fixed_param_names=None,
            allow_missing=False,  # for evaluate fold bn to create eval symbol
            config=None):
        """Store the training configuration and create the training Module.

        ``allow_missing`` is accepted but not used by this class. ``config``
        is only read during validation and may be None, in which case the
        training Module is scored directly. Passing ``logger=None`` selects
        the root logger and sets it to INFO.
        """
        self.symbol = symbol
        self.data_names = data_names
        self.label_names = label_names
        self.data_shapes = data_shapes
        self.label_shapes = label_shapes
        self.context = context
        self.work_load_list = work_load_list
        self.fixed_param_names = fixed_param_names
        if logger is None:
            logger = logging.getLogger()
            logger.setLevel(logging.INFO)
        self.logger = logger
        self.module = Module(symbol=self.symbol,
                             data_names=self.data_names,
                             label_names=self.label_names,
                             logger=self.logger,
                             context=self.context,
                             work_load_list=self.work_load_list,
                             fixed_param_names=self.fixed_param_names)
        # for fold bn
        self.config = config

    def _kvstore_rank(self, kvstore):
        """Return this worker's rank, or 0 when no KVStore object is available."""
        rank = getattr(kvstore, 'rank', None)
        if rank is not None:
            return rank
        # kvstore was given by name (e.g. the default 'device'); init_optimizer
        # may have created a KVStore from it on the module.
        created = getattr(self.module, '_kvstore', None)
        return created.rank if created is not None else 0

    def _uses_foldbn_eval(self):
        """Return True when validation must use the fold-BN inference symbol."""
        # A missing config (the default None) means plain evaluation.
        return getattr(self.config, 'network', None) == 'mobilenet_int8_foldbn'

    def _build_foldbn_eval_module(self, epoch, initializer, arg_params,
                                  aux_params):
        """Build and initialise the fold-BN inference Module for ``epoch``.

        Raises
        ------
        IOError
            If ``./model/<model_prefix>-<epoch+1>.params`` does not exist.
        """
        import os

        total_params_path = "./model/%s-%04d.params" % (
            self.config.model_prefix, epoch + 1)
        if not os.path.exists(total_params_path):
            raise IOError(
                "please provide the correct total_params_path for foldbn eval")
        eval_sym = mobilenet_int8_foldbn(
            num_classes=self.config.num_classes,
            quant_mod=self.config.quant_mod,
            delay_quant=self.config.delay_quant,
            is_weight_perchannel=self.config.is_weight_perchannel,
            total_params_path=total_params_path,
            quantize_flag=self.config.quantize_flag)
        eval_module = Module(symbol=eval_sym,
                             data_names=self.data_names,
                             label_names=self.label_names,
                             logger=self.logger,
                             context=self.context,
                             work_load_list=self.work_load_list,
                             fixed_param_names=self.fixed_param_names)
        eval_module.bind(data_shapes=self.data_shapes,
                         label_shapes=self.label_shapes,
                         for_training=False)
        eval_module.init_params(initializer=initializer,
                                arg_params=arg_params,
                                aux_params=aux_params)
        return eval_module

    def fit(self, train_data, eval_data=None, eval_metric='acc',
            validate_metric=None, work_load_list=None,
            epoch_end_callback=None, batch_end_callback=None,
            fixed_param_prefix=None, initializer=None, arg_params=None,
            aux_params=None, allow_missing=False, optimizer=None,
            optimizer_params=None, begin_epoch=0, num_epoch=None,
            kvstore='device'):
        """Train for epochs ``begin_epoch .. num_epoch - 1``.

        Batch-end callbacks receive a ``BatchEndParam`` that also carries the
        worker rank, the global iteration counter and the timing meters.
        Epoch-end callbacks run only on rank 0. ``kvstore`` may be a KVStore
        object or a name; for a name, the rank is taken from the KVStore the
        module created (if any), otherwise 0. ``work_load_list`` and
        ``fixed_param_prefix`` are accepted but unused here.

        Raises
        ------
        TypeError
            If ``num_epoch`` is None (checked before the module is bound).
        """
        if num_epoch is None:
            raise TypeError('fit() requires num_epoch to be specified')

        self.module.bind(data_shapes=self.data_shapes,
                         label_shapes=self.label_shapes, for_training=True)
        self.module.init_params(initializer=initializer, arg_params=arg_params,
                                aux_params=aux_params,
                                allow_missing=allow_missing)
        self.module.init_optimizer(kvstore=kvstore, optimizer=optimizer,
                                   optimizer_params=optimizer_params)
        # Resolved after init_optimizer so a KVStore created from a name is
        # visible on the module.
        rank = self._kvstore_rank(kvstore)

        if validate_metric is None:
            validate_metric = eval_metric
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)

        temp_count = 0  # iterations across all epochs

        # training loop
        for epoch in range(begin_epoch, num_epoch):
            train_time = AverageMeter()
            kvstore_sync_time = AverageMeter()
            get_data_time = AverageMeter()
            iter_total_time = AverageMeter()
            tic = time.time()
            eval_metric.reset()
            nbatch = 0
            data_iter = iter(train_data)
            end_of_batch = False
            next_data_batch = next(data_iter)
            while not end_of_batch:
                # Every meter records time elapsed since start_time, so each
                # reading includes all phases before it.
                # NOTE(review): no ndarray.waitall() is issued between phases;
                # if execution is asynchronous these may reflect dispatch time.
                start_time = time.time()
                data_batch = next_data_batch
                self.module.forward(data_batch, is_train=True)
                self.module.backward()
                train_time.update(time.time() - start_time)
                self.module.update()
                kvstore_sync_time.update(time.time() - start_time)
                try:
                    next_data_batch = next(data_iter)
                except StopIteration:
                    end_of_batch = True
                get_data_time.update(time.time() - start_time)
                if isinstance(data_batch, list):
                    self.module.update_metric(
                        eval_metric, [db.label for db in data_batch],
                        pre_sliced=True)
                else:
                    self.module.update_metric(eval_metric, data_batch.label)
                iter_total_time.update(time.time() - start_time)
                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(
                        epoch=epoch,
                        nbatch=nbatch,
                        eval_metric=eval_metric,
                        locals=locals(),
                        rank=rank,
                        total_iter=temp_count,
                        cur_data_time=get_data_time.val,
                        avg_data_time=get_data_time.avg,
                        cur_batch_time=train_time.val,
                        avg_batch_time=train_time.avg,
                        cur_kvstore_sync_time=kvstore_sync_time.val,
                        avg_kvstore_sync_time=kvstore_sync_time.avg,
                        cur_iter_total_time=iter_total_time.val,
                        avg_iter_total_time=iter_total_time.avg)
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)
                nbatch += 1
                temp_count += 1

            for name, val in eval_metric.get_name_value():
                self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
            toc = time.time()
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

            arg_params, aux_params = self.module.get_params()
            self.module.set_params(arg_params, aux_params)
            # Only rank 0 runs the epoch-end callbacks.
            if epoch_end_callback is not None and rank == 0:
                for callback in _as_list(epoch_end_callback):
                    callback(epoch, self.symbol, arg_params, aux_params)

            if eval_data:
                if self._uses_foldbn_eval():
                    # for fold bn: score a separately built inference symbol.
                    # NOTE(review): it expects the params file for epoch + 1
                    # to exist already (only rank 0 runs epoch-end callbacks).
                    scorer = self._build_foldbn_eval_module(
                        epoch, initializer, arg_params, aux_params)
                else:
                    scorer = self.module
                res = scorer.score(eval_data, validate_metric,
                                   score_end_callback=None,
                                   batch_end_callback=None, reset=True,
                                   epoch=epoch)
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch,
                                     name, val)
            train_data.reset()