def quantize_for_train(data, config):
    """Quantize ``data`` during training with adaptively updated parameters.

    The stored quantization parameters (shift, scale, offset, bit width,
    next update step and ``m``) come from ``_init_quant_param``.

    - They are recomputed by the strategy
      ``quantize_strategy[config["adaptive_strategy"]]`` when the global step
      equals the stored update step, or when the global step is below 2.
    - Otherwise the stored values are kept unchanged.
    - The resulting values are assigned back to the stored parameters.
    - ``data`` is then quantized with the updated values, including the
      updated bit width.

    :param data: tensor to quantize
    :param config: quantization configuration (dict). Keys ``save_name``,
        ``bitnum`` and ``adaptive_strategy`` are read here; the whole dict is
        also passed to the selected strategy.
    :return: the quantized tensor produced by ``float2fix``
    """
    shift, scale, offset, bitnum, update_step, m = _init_quant_param(
        config["save_name"], True, bitnum=config["bitnum"])
    global_step = tf.train.get_or_create_global_step()

    # Re-estimate on the scheduled step, and always on the first two steps.
    need_update = tf.equal(global_step, update_step) | tf.less(global_step, 2)
    (new_shift, new_scale, new_offset,
     new_update_step, new_bitnum, new_m) = tf.cond(
        need_update,
        lambda: quantize_strategy[config["adaptive_strategy"]](
            data, global_step, bitnum, m, config),
        lambda: (shift, scale, offset, update_step, bitnum, m))

    # Write the (possibly unchanged) parameters back to their storage.
    assign_ops = [
        tf.assign(shift, new_shift),
        tf.assign(scale, new_scale),
        tf.assign(offset, new_offset),
        tf.assign(bitnum, new_bitnum),
        tf.assign(update_step, new_update_step),
        tf.assign(m, new_m),
    ]
    with tf.control_dependencies(assign_ops):
        # control_dependencies only affects ops *created* in this context.
        # Passing the variables directly would reuse reads created outside it,
        # which may run before the assignments. tf.identity creates fresh
        # reads that are ordered after every assign op above.
        quantize_data = float2fix(
            data,
            tf.identity(shift),
            tf.identity(scale),
            tf.identity(offset),
            # Quantize with the stored (adaptive) bit width, as every other
            # float2fix call site does, instead of float2fix's default.
            bitnum=tf.identity(bitnum))
    return quantize_data
def _compute_bitnum(data, bitnum, config):
    """Compute a new quantization bit width for ``data``.

    Starting from ``bitnum``, the width is widened repeatedly by
    ``_loop_body``. Widening continues while both hold:

    - the mean quantization error (``_compute_mean_diff`` between ``data``
      and its fixed-point version) is greater than ``config["ths"]``;
    - the width is still below the cap ``max_bitnum``.

    :param data: tensor being quantized
    :param bitnum: starting bit width
    :param config: dict. Reads ``ths``, and optionally ``max_bitnum``
        (default 32), the width at which widening stops.
    :return: the selected bit width (tensor produced by ``tf.while_loop``)
    """
    max_bitnum = config.get("max_bitnum", 32)

    shift, scale, offset = compute_quant_param(data, bitnum)
    outdata = float2fix(data, shift, scale, offset, bitnum=bitnum)
    diff = _compute_mean_diff(data, outdata)

    def cond(diff, bitnum, data):
        # Strict `<` against the cap: a width that already equals the cap must
        # not be widened again. The former `< 33` let 32 bits grow to 40.
        return tf.greater(diff, config["ths"]) & tf.less(bitnum, max_bitnum)

    # `data` is carried through unchanged because _loop_body takes and
    # returns it as a loop variable.
    _, bitnum, _ = tf.while_loop(cond, _loop_body, [diff, bitnum, data])
    return bitnum
def _compute_interval(data, shift, scale, offset, bitnum, m, config):
    """Compute the quantization update interval (number of steps).

    The interval is computed as
    ``beta / max(alpha * |shift - m|, delta * err ** 2) - gamma``, where
    ``err`` is ``_compute_mean_diff`` between ``data`` and its fixed-point
    version. The value is raised to at least 1, capped at
    ``config["step_per_epoch"]``, and truncated to int64.

    :param data: tensor being quantized
    :param shift: quantization shift
    :param scale: quantization scale
    :param offset: quantization offset
    :param bitnum: quantization bit width
    :param m: value compared against ``shift`` (semantics defined by the
        calling strategy)
    :param config: dict. Reads ``alpha``, ``beta``, ``gamma``, ``delta`` and
        ``step_per_epoch``.
    :return: int64 tensor with the interval
    """
    diff1 = config["alpha"] * tf.abs(shift - m)
    quant_data = float2fix(data, shift, scale, offset, bitnum=bitnum)
    metrics = _compute_mean_diff(data, quant_data)
    diff2 = config["delta"] * metrics ** 2
    diff = tf.maximum(diff1, diff2)

    # diff can be exactly 0 (shift == m and lossless quantization), making
    # beta / diff infinite. Casting inf to int64 is undefined, so clamp while
    # still floating point. float64 keeps step_per_epoch exact; cast last.
    interval = tf.cast(config["beta"] / diff - config["gamma"], tf.float64)
    interval = tf.maximum(interval, 1.0)
    interval = tf.minimum(
        interval, tf.cast(config["step_per_epoch"], tf.float64))
    return tf.cast(interval, tf.int64)
def quantize_for_eval(data, config):
    """Quantize ``data`` for evaluation.

    The source of the quantization parameters is chosen by
    ``config["offline"]``:

    - ``"ckpt"``: stored parameters from ``_init_quant_param``.
    - ``"pkl"``: parameters from ``restore_from_pkl(config)``.
    - anything else: parameters computed from ``data`` itself with
      ``compute_quant_param`` at width ``config["bitnum"]``.

    :param data: tensor to quantize
    :param config: quantization configuration (dict)
    :return: the quantized tensor produced by ``float2fix``
    """
    offline = config["offline"]
    if offline == "ckpt":
        # The identical call in quantize_for_train unpacks six values
        # (shift, scale, offset, bitnum, update_step, m). Evaluation only
        # needs the first four; a plain 4-way unpack would raise ValueError.
        shift, scale, offset, bitnum = _init_quant_param(
            config["save_name"], True, bitnum=config["bitnum"])[:4]
    elif offline == "pkl":
        shift, scale, offset, bitnum = restore_from_pkl(config)
    else:
        bitnum = config["bitnum"]
        shift, scale, offset = compute_quant_param(
            data,
            bitnum=bitnum,
            ifscale=config["ifscale"],
            ifoffset=config["ifoffset"],
            ifchannel=config["ifchannel"])
    return float2fix(data, shift, scale, offset, bitnum)
def _loop_body(diff, bitnum, data):
    """One ``tf.while_loop`` step of the bit-width search.

    Widens the bit width by 8, re-quantizes ``data`` at the new width, and
    measures the resulting error.

    :param diff: error from the previous step (not used; present because
        ``tf.while_loop`` passes every loop variable)
    :param bitnum: current bit width
    :param data: tensor being quantized (passed through unchanged)
    :return: ``(new_error, widened_bitnum, data)``
    """
    widened = bitnum + 8
    q_shift, q_scale, q_offset = compute_quant_param(data, widened)
    requantized = float2fix(data, q_shift, q_scale, q_offset, bitnum=widened)
    return _compute_mean_diff(data, requantized), widened, data