Example #1
def check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape):
    mod = Module(symbol=qsym, label_names=None, context=mx.current_context())
    mod.bind(for_training=False, data_shapes=[('data', data_shape)])
    mod.set_params(qarg_params, qaux_params)
    mod.forward(batch, is_train=False)
    for output in mod.get_outputs():
        output.wait_to_read()
    return mod.get_outputs()
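
A minimal sketch of how this helper might be driven, assuming `qsym`, `qarg_params` and `qaux_params` come from an earlier quantization step (all hypothetical names):

import mxnet as mx

# hypothetical inputs: qsym / qarg_params / qaux_params from a prior quantization step
data_shape = (32, 3, 224, 224)
data = mx.random.uniform(-1.0, 1.0, shape=data_shape)
batch = mx.io.DataBatch([data], [])  # inference only, so no labels
outputs = check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape)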
Example #2
def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape):
    mod = Module(symbol=qsym, label_names=None, context=mx.current_context())
    mod.bind(for_training=False,
             data_shapes=[('data', data_shape)])
    mod.set_params(qarg_params, qaux_params)
    data = [mx.random.uniform(-1.0, 1.0, shape=shape) for _, shape in mod.data_shapes]
    batch = mx.io.DataBatch(data, [])
    mod.forward(batch, is_train=False)
    for output in mod.get_outputs():
        output.wait_to_read()
Example #3
class Solver(object):
    def __init__(self,
                 symbol,
                 data_names,
                 label_names,
                 data_shapes,
                 label_shapes,
                 logger=logging,
                 context=mx.cpu(),
                 work_load_list=None,
                 fixed_param_names=None):
        self.symbol = symbol
        self.data_names = data_names
        self.label_names = label_names
        self.data_shapes = data_shapes
        self.label_shapes = label_shapes
        self.context = context
        self.work_load_list = work_load_list
        self.fixed_param_names = fixed_param_names

        if logger is None:
            logger = logging.getLogger()
            logger.setLevel(logging.INFO)
        self.logger = logger
        self.module = Module(symbol=self.symbol,
                             data_names=self.data_names,
                             label_names=self.label_names,
                             logger=self.logger,
                             context=self.context,
                             work_load_list=self.work_load_list,
                             fixed_param_names=self.fixed_param_names)

    def fit(self,
            train_data,
            eval_data=None,
            eval_metric='acc',
            validate_metric=None,
            work_load_list=None,
            epoch_end_callback=None,
            batch_end_callback=None,
            fixed_param_prefix=None,
            initializer=None,
            arg_params=None,
            aux_params=None,
            allow_missing=False,
            optimizer=None,
            optimizer_params=None,
            begin_epoch=0,
            num_epoch=None,
            kvstore='device',
            teacher_modules=None):
        if type(teacher_modules) is not list:
            teacher_modules = [teacher_modules]
        self.module.bind(data_shapes=self.data_shapes,
                         label_shapes=self.label_shapes,
                         for_training=True)
        self.module.init_params(initializer=initializer,
                                arg_params=arg_params,
                                aux_params=aux_params,
                                allow_missing=allow_missing)
        self.module.init_optimizer(kvstore=kvstore,
                                   optimizer=optimizer,
                                   optimizer_params=optimizer_params)

        if validate_metric is None:
            validate_metric = eval_metric
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)

        # training loop
        for epoch in range(begin_epoch, num_epoch):
            tic = time.time()
            eval_metric.reset()
            nbatch = 0
            data_iter = iter(train_data)
            end_of_batch = False
            next_data_batch = next(data_iter)
            while not end_of_batch:
                data_batch = next_data_batch

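                # distillation-style step: run each teacher network forward and
                # append its outputs to the batch labels as soft targets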
                if teacher_modules[0] is not None:
                    for teacher_module in teacher_modules:
                        teacher_module.forward(data_batch=data_batch,
                                               is_train=True)
                        transfer_label = teacher_module.get_outputs()
                        data_batch.label = data_batch.label + transfer_label
                self.module.forward(data_batch, is_train=True)
                self.module.backward()
                self.module.update()

                try:
                    next_data_batch = next(data_iter)
                except StopIteration:
                    end_of_batch = True

                self.module.update_metric(eval_metric, data_batch.label)

                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch,
                                                     nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)
                nbatch += 1

            for name, val in eval_metric.get_name_value():
                self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
            toc = time.time()
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

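            # copy the parameters back from the devices and re-broadcast them so
            # every executor holds a synchronized set before the epoch-end callbacks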
            arg_params, aux_params = self.module.get_params()
            self.module.set_params(arg_params, aux_params)

            if epoch_end_callback is not None:
                for callback in _as_list(epoch_end_callback):
                    callback(epoch, self.symbol, arg_params, aux_params)
            if eval_data:
                res = self.module.score(eval_data,
                                        validate_metric,
                                        score_end_callback=None,
                                        batch_end_callback=None,
                                        reset=True,
                                        epoch=epoch)
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name,
                                     val)

            train_data.reset()
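
A sketch of how this Solver might be used, assuming `net` is a symbol ending in a softmax output and `train_iter`/`val_iter` are `mx.io.DataIter` instances (all hypothetical names):

solver = Solver(symbol=net,
                data_names=('data',),
                label_names=('softmax_label',),
                data_shapes=[('data', (64, 3, 224, 224))],
                label_shapes=[('softmax_label', (64,))],
                context=mx.gpu(0))
solver.fit(train_data=train_iter,
           eval_data=val_iter,
           initializer=mx.init.Xavier(),
           optimizer='sgd',
           optimizer_params={'learning_rate': 0.01, 'momentum': 0.9},
           num_epoch=10)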
Example #4
def quantize_model(sym,
                   arg_params,
                   aux_params,
                   data_names=('data', ),
                   label_names=('softmax_label', ),
                   ctx=cpu(),
                   excluded_sym_names=None,
                   calib_mode='entropy',
                   calib_data=None,
                   num_calib_examples=None,
                   calib_layer=None,
                   quantized_dtype='int8',
                   logger=logging):
    """User-level API for generating a quantized model from a FP32 model w/ or w/o calibration.
    The backend quantized operators are only enabled for Linux systems. Please do not run
    inference using the quantized models on Windows for now.
    The quantization implementation adopts the TensorFlow's approach:
    https://www.tensorflow.org/performance/quantization.
    The calibration implementation borrows the idea of Nvidia's 8-bit Inference with TensorRT:
    http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf
    and adapts the method to MXNet.
    
    Parameters
    ----------
    sym : str or Symbol
        Defines the structure of a neural network for FP32 data types.
    arg_params : dict
        Dictionary of name to `NDArray`.
    aux_params : dict
        Dictionary of name to `NDArray`.
    data_names : a list of strs
        Data names required for creating a Module object to run forward propagation on the
        calibration dataset.
    label_names : a list of strs
        Label names required for creating a Module object to run forward propagation on the
        calibration dataset.
    ctx : Context
        Defines the device that users want to run forward propagation on the calibration
        dataset for collecting layer output statistics. Currently, only supports single context.
    excluded_sym_names : list of strings
        A list of strings representing the names of the symbols that users want to exclude
        from being quantized.
    calib_mode : str
        If calib_mode='none', no calibration will be used and the thresholds for
        requantization after the corresponding layers will be calculated at runtime by
        calling min and max operators. The quantized models generated in this
        mode are normally 10-20% slower at inference than their calibrated counterparts.
        If calib_mode='naive', the min and max values of the layer outputs from a calibration
        dataset will be directly taken as the thresholds for quantization.
        If calib_mode='entropy' (default mode), the thresholds for quantization will be
        derived such that the KL divergence between the distributions of FP32 layer outputs and
        quantized layer outputs is minimized based upon the calibration dataset.
    calib_data : DataIter
        A data iterator initialized by the calibration dataset.
    num_calib_examples : int or None
        The maximum number of examples the user would like to use for calibration. If not
        provided, the whole calibration dataset will be used.
    calib_layer : function
        Given a layer's output name in string, return True or False for deciding whether to
        calibrate this layer. If yes, the statistics of the layer's output will be collected;
        otherwise, no information of the layer's output will be collected. If not provided,
        all the layers' outputs that need requantization will be collected.
    quantized_dtype : str
        The quantized destination type for input data. Currently supports 'int8',
        'uint8' and 'auto'. 'auto' means the output type is selected automatically
        according to the calibration result. Default value is 'int8'.
    logger : Object
        A logging object for printing information during the process of quantization.
        
    Returns
    -------
    tuple
        A tuple of quantized symbol, quantized arg_params, and aux_params.
    """
    if excluded_sym_names is None:
        excluded_sym_names = []
    if not isinstance(excluded_sym_names, list):
        raise ValueError(
            'excluded_sym_names must be a list of strings representing'
            ' the names of the symbols that will not be quantized,'
            ' while received type %s' % str(type(excluded_sym_names)))

    logger.info('Quantizing symbol')
    if quantized_dtype not in ('int8', 'uint8', 'auto'):
        raise ValueError('unknown quantized_dtype %s received,'
                         ' expected `int8`, `uint8` or `auto`' %
                         quantized_dtype)
    qsym = _quantize_symbol(sym,
                            excluded_symbols=excluded_sym_names,
                            offline_params=list(arg_params.keys()),
                            quantized_dtype=quantized_dtype)

    th_dict = {}
    if calib_mode is not None and calib_mode != 'none':
        if not isinstance(ctx, Context):
            raise ValueError(
                'currently only supports single ctx, while received %s' %
                str(ctx))
        if calib_data is None:
            raise ValueError('calib_data must be provided when calib_mode=%s' %
                             calib_mode)
        if not isinstance(calib_data, DataIter):
            raise ValueError(
                'calib_data must be of DataIter type when calib_mode=%s,'
                ' while received type %s' %
                (calib_mode, str(type(calib_data))))

        mod = Module(symbol=sym,
                     data_names=data_names,
                     label_names=label_names,
                     context=ctx)
        if len(calib_data.provide_label) > 0:
            mod.bind(for_training=False,
                     data_shapes=calib_data.provide_data,
                     label_shapes=calib_data.provide_label)
        else:
            mod.bind(for_training=False, data_shapes=calib_data.provide_data)
        mod.set_params(arg_params, aux_params)
        if calib_mode == 'entropy':
            nd_dict, num_examples = _collect_layer_outputs(
                mod,
                calib_data,
                include_layer=calib_layer,
                max_num_examples=num_calib_examples,
                logger=logger)
            logger.info(
                'Collected layer outputs from FP32 model using %d examples' %
                num_examples)
            logger.info('Calculating optimal thresholds for quantization')
            th_dict = _get_optimal_thresholds(nd_dict,
                                              quantized_dtype,
                                              logger=logger)
        elif calib_mode == 'naive':
            th_dict, num_examples = _collect_layer_output_min_max(
                mod,
                calib_data,
                include_layer=calib_layer,
                max_num_examples=num_calib_examples,
                logger=logger)
            logger.info(
                'Collected layer output min/max values from FP32 model using %d examples'
                % num_examples)
        else:
            raise ValueError('unknown calibration mode %s received,'
                             ' expected `none`, `naive`, or `entropy`' %
                             calib_mode)
        logger.info('Calibrating quantized symbol')
        qsym = _calibrate_quantized_sym(qsym, th_dict)

    logger.info('Quantizing parameters')
    qarg_params = _quantize_params(qsym, arg_params, th_dict)

    return qsym, qarg_params, aux_params
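
A sketch of a typical calibration run with this API, assuming an FP32 checkpoint saved under the prefix `resnet` and a calibration record file `calib.rec` (both hypothetical):

sym, arg_params, aux_params = mx.model.load_checkpoint('resnet', 0)
calib_iter = mx.io.ImageRecordIter(path_imgrec='calib.rec',   # hypothetical record file
                                   batch_size=32,
                                   data_shape=(3, 224, 224),
                                   label_name='softmax_label')
qsym, qarg_params, qaux_params = quantize_model(sym, arg_params, aux_params,
                                                ctx=mx.cpu(),
                                                calib_mode='entropy',
                                                calib_data=calib_iter,
                                                num_calib_examples=320)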
Example #5
    data = mx.io.ImageRecordIter(path_imgrec=dataset_path,
                                 label_width=1,
                                 preprocess_threads=data_nthreads,
                                 batch_size=batch_size,
                                 data_shape=data_shape,
                                 label_name=label_name,
                                 rand_crop=False,
                                 rand_mirror=False,
                                 shuffle=shuffle_dataset,
                                 shuffle_chunk_seed=shuffle_seed,
                                 seed=shuffle_seed,
                                 **mean_args)
    mod = Module(symbol=convnet_code_sym, label_names=None, context=ctx)
    mod.bind(for_training=False, data_shapes=data.provide_data)
    mod.set_params(arg_params, aux_params)
    num_images = 0
    convnet_codes = None  # N * 1000
    resized_images = None  # NCHW
    labels = None
    for batch in data:
        if num_images >= args.max_num_images:
            break
        mod.forward(data_batch=batch, is_train=False)
        fc_output = mod.get_outputs()[0].flatten().copyto(mx.cpu(0))
        num_images += batch_size
        fc_output.wait_to_read()
        if convnet_codes is None:
            convnet_codes = fc_output
        else:
            convnet_codes = mx.nd.concat(*[convnet_codes, fc_output], dim=0)
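
The snippet breaks off before `resized_images` and `labels` are populated; presumably the omitted remainder accumulates them with the same `concat` pattern, e.g. (a hypothetical continuation sketch):

        # hypothetical sketch of how labels could be accumulated the same way
        label = batch.label[0].copyto(mx.cpu(0))
        if labels is None:
            labels = label
        else:
            labels = mx.nd.concat(*[labels, label], dim=0)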
Example #6
class Solver(object):
    def __init__(
            self,
            symbol,
            data_names,
            label_names,
            data_shapes,
            label_shapes,
            logger=logging,
            context=mx.cpu(),
            work_load_list=None,
            fixed_param_names=None,
            allow_missing=False,
            # for evaluate fold bn to create eval symbol
            config=None):
        self.symbol = symbol
        self.data_names = data_names
        self.label_names = label_names
        self.data_shapes = data_shapes
        self.label_shapes = label_shapes
        self.context = context
        self.work_load_list = work_load_list
        self.fixed_param_names = fixed_param_names

        if logger is None:
            logger = logging.getLogger()
            logger.setLevel(logging.INFO)
        self.logger = logger
        self.module = Module(symbol=self.symbol,
                             data_names=self.data_names,
                             label_names=self.label_names,
                             logger=self.logger,
                             context=self.context,
                             work_load_list=self.work_load_list,
                             fixed_param_names=self.fixed_param_names)
        # for fold bn
        self.config = config

    def fit(self,
            train_data,
            eval_data=None,
            eval_metric='acc',
            validate_metric=None,
            work_load_list=None,
            epoch_end_callback=None,
            batch_end_callback=None,
            fixed_param_prefix=None,
            initializer=None,
            arg_params=None,
            aux_params=None,
            allow_missing=False,
            optimizer=None,
            optimizer_params=None,
            begin_epoch=0,
            num_epoch=None,
            kvstore='device'):
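        # NOTE: the batch- and epoch-end callbacks below read kvstore.rank, so a
        # KVStore object (not just the default 'device' string) is expected here.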

        self.module.bind(data_shapes=self.data_shapes,
                         label_shapes=self.label_shapes,
                         for_training=True)
        self.module.init_params(initializer=initializer,
                                arg_params=arg_params,
                                aux_params=aux_params,
                                allow_missing=allow_missing)
        self.module.init_optimizer(kvstore=kvstore,
                                   optimizer=optimizer,
                                   optimizer_params=optimizer_params)

        if validate_metric is None:
            validate_metric = eval_metric
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)

        temp_count = 0

        # # test model size by saving params of model
        # arg_params, aux_params = self.module.get_params()
        # for callback in _as_list(epoch_end_callback):
        #     callback(0, self.symbol, arg_params, aux_params)
        # raise NotImplementedError

        # training loop
        for epoch in range(begin_epoch, num_epoch):

            train_time = AverageMeter()
            kvstore_sync_time = AverageMeter()
            get_data_time = AverageMeter()
            iter_total_time = AverageMeter()

            tic = time.time()
            eval_metric.reset()
            nbatch = 0
            data_iter = iter(train_data)
            end_of_batch = False
            next_data_batch = next(data_iter)
            while not end_of_batch:
                start_time = time.time()
                data_batch = next_data_batch

                self.module.forward(data_batch, is_train=True)
                self.module.backward()

                # ndarray.waitall()
                train_time.update(time.time() - start_time)

                self.module.update()

                # ndarray.waitall()
                kvstore_sync_time.update(time.time() - start_time)

                try:
                    next_data_batch = next(data_iter)
                except StopIteration:
                    end_of_batch = True

                # ndarray.waitall()
                get_data_time.update(time.time() - start_time)

                if isinstance(data_batch, list):
                    self.module.update_metric(eval_metric,
                                              [db.label for db in data_batch],
                                              pre_sliced=True)
                else:
                    self.module.update_metric(eval_metric, data_batch.label)

                # ndarray.waitall()
                iter_total_time.update(time.time() - start_time)

                if batch_end_callback is not None:
                    # batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
                    #                                  eval_metric=eval_metric,
                    #                                  locals=locals())

                    batch_end_params = BatchEndParam(
                        epoch=epoch,
                        nbatch=nbatch,
                        eval_metric=eval_metric,
                        locals=locals(),
                        rank=kvstore.rank,
                        total_iter=temp_count,
                        cur_data_time=get_data_time.val,
                        avg_data_time=get_data_time.avg,
                        cur_batch_time=train_time.val,
                        avg_batch_time=train_time.avg,
                        cur_kvstore_sync_time=kvstore_sync_time.val,
                        avg_kvstore_sync_time=kvstore_sync_time.avg,
                        cur_iter_total_time=iter_total_time.val,
                        avg_iter_total_time=iter_total_time.avg)
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)
                nbatch += 1
                temp_count += 1

            for name, val in eval_metric.get_name_value():
                self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
            toc = time.time()
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

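            # as in Example #3: sync the parameter copies across devices before
            # checkpointing and evaluation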
            arg_params, aux_params = self.module.get_params()
            self.module.set_params(arg_params, aux_params)

            if epoch_end_callback is not None and kvstore.rank == 0:
                for callback in _as_list(epoch_end_callback):
                    callback(epoch, self.symbol, arg_params, aux_params)
            if eval_data:
                if self.config.network == 'mobilenet_int8_foldbn':
                    # for fold bn to create inference symbol
                    total_params_path = "./model/%s-%04d.params" % (
                        self.config.model_prefix, epoch + 1)
                    # total_params_path = "./model/mobilenet_flodbn_0904/mobilenet_int8_flodbn_imagenet_retrain_80_pertensor-fold-0100.params"
                    # _, arg_params, aux_params = mx.model.load_checkpoint('./model/mobilenet_flodbn_0904/mobilenet_int8_flodbn_imagenet_retrain_80_pertensor-fold', 100)
                    import os
                    assert os.path.exists(
                        total_params_path
                    ), "please provide the correct total_params_path for foldbn eval"
                    eval_sym = eval(self.config.network)(
                        num_classes=self.config.num_classes,
                        quant_mod=self.config.quant_mod,
                        delay_quant=self.config.delay_quant,
                        is_weight_perchannel=self.config.is_weight_perchannel,
                        total_params_path=total_params_path,
                        quantize_flag=self.config.quantize_flag)
                    eval_module = Module(
                        symbol=eval_sym,
                        data_names=self.data_names,
                        label_names=self.label_names,
                        logger=self.logger,
                        context=self.context,
                        work_load_list=self.work_load_list,
                        fixed_param_names=self.fixed_param_names)
                    eval_module.bind(data_shapes=self.data_shapes,
                                     label_shapes=self.label_shapes,
                                     for_training=False)
                    eval_module.init_params(initializer=initializer,
                                            arg_params=arg_params,
                                            aux_params=aux_params)
                    res = eval_module.score(eval_data,
                                            validate_metric,
                                            score_end_callback=None,
                                            batch_end_callback=None,
                                            reset=True,
                                            epoch=epoch)
                else:
                    res = self.module.score(eval_data,
                                            validate_metric,
                                            score_end_callback=None,
                                            batch_end_callback=None,
                                            reset=True,
                                            epoch=epoch)
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name,
                                     val)

            train_data.reset()
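
Example #6 relies on an `AverageMeter` helper that is not shown above; a minimal sketch of the conventional implementation (latest value plus running average, matching the `.update()`, `.val` and `.avg` usage) would be:

class AverageMeter(object):
    """Track the most recent value and a running average."""

    def __init__(self):
        self.val = 0.0   # last value passed to update()
        self.sum = 0.0   # running sum of values
        self.count = 0   # number of updates
        self.avg = 0.0   # running average

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count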