def from_float(cls, mod): r"""Create a dynamic quantized module from a float module or qparams_dict Args: mod (Module): a float module, either produced by torch.quantization utilities or provided by the user """ assert type( mod ) == NNLinear, 'nn.quantized.dynamic.Linear.from_float only works for nn.Linear' assert hasattr( mod, 'qconfig'), 'Input float module must have qconfig defined' if mod.qconfig is not None and mod.qconfig.weight() is not None: weight_observer = mod.qconfig.weight() else: # We have the circular import issues if we import the qconfig in the beginning of this file: # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the # import until we need it. from torch.quantization.QConfig import default_dynamic_qconfig weight_observer = default_dynamic_qconfig.weight() assert weight_observer.dtype == torch.qint8, 'Weight observer must have dtype torch.qint8' weight_observer(mod.weight) wt_scale, wt_zp = weight_observer.calculate_qparams() qweight = torch.quantize_linear(mod.weight.float(), float(wt_scale), int(wt_zp), torch.qint8) qlinear = Linear(mod.in_features, mod.out_features) qlinear.set_weight_bias(qweight, mod.bias) return qlinear
def from_float(cls, mod):
    assert type(mod) == torch.nn.LSTM, \
        'nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM'
    assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
    if mod.qconfig is not None and mod.qconfig.weight() is not None:
        weight_observer = mod.qconfig.weight()
    else:
        # We have the circular import issues if we import the qconfig in the beginning of this file:
        # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
        # import until we need it.
        from torch.quantization.QConfig import default_dynamic_qconfig
        weight_observer = default_dynamic_qconfig.weight()
    assert weight_observer.dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'

    if mod.mode == 'LSTM':
        qRNNBase = LSTM(mod.input_size, mod.hidden_size, mod.num_layers,
                        mod.bias, mod.batch_first, mod.dropout, mod.bidirectional)
    else:
        # TODO: support more than just LSTM
        raise RuntimeError('Only LSTM is supported for QuantizedRNN')

    num_directions = 2 if mod.bidirectional else 1

    assert mod.bias

    qRNNBase._all_weights = []
    packed_weights = []
    quantized_weights = []
    for layer in range(qRNNBase.num_layers):
        for direction in range(num_directions):
            layer_input_size = qRNNBase.input_size if layer == 0 else \
                qRNNBase.hidden_size * num_directions

            def process_weights(ihhh, layer, suffix):
                weight_name = 'weight_{}_l{}{}'.format(ihhh, layer, suffix)
                bias_name = 'bias_{}_l{}{}'.format(ihhh, layer, suffix)

                weight = getattr(mod, weight_name)
                bias = getattr(mod, bias_name)

                # for each layer, for each direction we need to quantize and pack
                # weights and pack parameters in this order:
                #
                #   w_ih, w_hh, b_ih, b_hh
                weight_observer(weight)
                wt_scale, wt_zp = weight_observer.calculate_qparams()
                qweight = torch.quantize_linear(
                    weight.float(), float(wt_scale), int(wt_zp), torch.qint8)
                packed_weight = torch.ops.quantized.linear_prepack(qweight, bias)

                params = [packed_weight, bias]
                pos_names = ['w', 'b']
                ret_name = ['{}_{}_l{}{}'.format(name, ihhh, layer, suffix)
                            for name in pos_names]
                quantized_weights.append(qweight)
                packed_weights.append(ret_name[0])
                return params, ret_name

            suffix = '_reverse' if direction == 1 else ''
            ih_params, ih_param_names = process_weights('ih', layer, suffix)
            hh_params, hh_param_names = process_weights('hh', layer, suffix)

            for (ih, ih_name), (hh, hh_name) in zip(zip(ih_params, ih_param_names),
                                                    zip(hh_params, hh_param_names)):
                qRNNBase.register_buffer(
                    ih_name, torch.tensor(ih) if not isinstance(ih, torch.Tensor) else ih)
                qRNNBase.register_buffer(
                    hh_name, torch.tensor(hh) if not isinstance(hh, torch.Tensor) else hh)
                qRNNBase._all_weights.extend([ih_name, hh_name])

    qRNNBase._packed_weights = packed_weights
    # TODO(@jianyuh): remove _quantized_weights now that the fbgemm_linear_unpack
    # function is supported.
    qRNNBase._quantized_weights = quantized_weights

    return qRNNBase
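# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the source above): driving the LSTM
# from_float path through torch.quantization.quantize_dynamic with qint8
# weights. Each layer/direction pair contributes packed entries named like
# 'w_ih_l0', 'b_ih_l0', 'w_ih_l0_reverse', ..., as built by process_weights
# above. The wrapper module, tensor shapes, and sizes here are assumptions made
# for the example only.
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.quantization


class SeqModel(nn.Module):
    def __init__(self):
        super(SeqModel, self).__init__()
        self.lstm = nn.LSTM(input_size=10, hidden_size=20,
                            num_layers=2, bidirectional=True)

    def forward(self, x, hidden):
        return self.lstm(x, hidden)


model = SeqModel()
qmodel = torch.quantization.quantize_dynamic(model, {nn.LSTM}, dtype=torch.qint8)

seq_len, batch = 5, 3
num_layers, num_directions, hidden_size = 2, 2, 20
x = torch.randn(seq_len, batch, 10)
h0 = torch.randn(num_layers * num_directions, batch, hidden_size)
c0 = torch.randn(num_layers * num_directions, batch, hidden_size)
out, (hn, cn) = qmodel(x, (h0, c0))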
def from_float(cls, mod):
    assert type(mod) == torch.nn.LSTM, \
        'nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM'
    assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
    if mod.qconfig is not None and mod.qconfig.weight is not None:
        weight_observer = mod.qconfig.weight()
    else:
        # We have the circular import issues if we import the qconfig in the beginning of this file:
        # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
        # import until we need it.
        from torch.quantization.QConfig import default_dynamic_qconfig
        weight_observer = default_dynamic_qconfig.weight()

    dtype = weight_observer.dtype
    supported_scalar_types = [torch.qint8, torch.float16]
    if dtype not in supported_scalar_types:
        raise RuntimeError(
            'Unsupported dtype for dynamic RNN quantization: {}'.format(dtype))

    if mod.mode == 'LSTM':
        qRNNBase = LSTM(mod.input_size, mod.hidden_size, mod.num_layers,
                        mod.bias, mod.batch_first, mod.dropout, mod.bidirectional, dtype)
    else:
        raise NotImplementedError('Only LSTM is supported for QuantizedRNN for now')

    num_directions = 2 if mod.bidirectional else 1

    assert mod.bias

    qRNNBase._all_weight_names = []
    qRNNBase._all_weight_values = []
    for layer in range(qRNNBase.num_layers):
        for direction in range(num_directions):
            layer_input_size = qRNNBase.input_size if layer == 0 else \
                qRNNBase.hidden_size * num_directions

            def process_weights(ihhh, layer, suffix, dtype):
                weight_name = 'weight_{}_l{}{}'.format(ihhh, layer, suffix)
                bias_name = 'bias_{}_l{}{}'.format(ihhh, layer, suffix)

                weight = getattr(mod, weight_name)
                bias = getattr(mod, bias_name)

                if dtype == torch.qint8:
                    # for each layer, for each direction we need to quantize and pack
                    # weights and pack parameters in this order:
                    #
                    #   w_ih, w_hh
                    weight_observer(weight)
                    wt_scale, wt_zp = weight_observer.calculate_qparams()
                    qweight = torch.quantize_per_tensor(
                        weight.float(), float(wt_scale), int(wt_zp), torch.qint8)
                    packed_weight = torch.ops.quantized.linear_prepack(qweight, bias)

                    params = [packed_weight]
                    pos_names = ['w']
                    ret_name = ['{}_{}_l{}{}'.format(name, ihhh, layer, suffix)
                                for name in pos_names]
                    return params, ret_name
                else:
                    # for each layer, for each direction we need to quantize and pack
                    # weights and pack parameters in this order:
                    #
                    #   packed_ih, packed_hh, b_ih, b_hh
                    packed_weight = torch.fbgemm_pack_gemm_matrix_fp16(weight.float())

                    params = [packed_weight, bias]
                    pos_names = ['packed', 'b']
                    ret_name = ['{}_{}_l{}{}'.format(name, ihhh, layer, suffix)
                                for name in pos_names]
                    return params, ret_name

            suffix = '_reverse' if direction == 1 else ''
            ih_params, ih_param_names = process_weights('ih', layer, suffix, dtype)
            hh_params, hh_param_names = process_weights('hh', layer, suffix, dtype)

            for (ih, ih_name), (hh, hh_name) in zip(zip(ih_params, ih_param_names),
                                                    zip(hh_params, hh_param_names)):
                qRNNBase._all_weight_names.extend([ih_name, hh_name])
                qRNNBase._all_weight_values.extend([ih, hh])

    return qRNNBase
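# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the source above): this revision adds
# a float16 path next to qint8, so the same conversion can also be requested
# with dtype=torch.float16, which takes the fbgemm_pack_gemm_matrix_fp16 branch
# of process_weights. This assumes a PyTorch build whose quantize_dynamic
# accepts dtype=torch.float16 for nn.LSTM; the model below is a made-up example.
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.quantization


class Tagger(nn.Module):
    def __init__(self):
        super(Tagger, self).__init__()
        self.lstm = nn.LSTM(input_size=16, hidden_size=32, num_layers=1)

    def forward(self, x):
        return self.lstm(x)


model = Tagger()

# qint8 weights: per-tensor affine quantization + linear_prepack (first branch).
qmodel_int8 = torch.quantization.quantize_dynamic(model, {nn.LSTM}, dtype=torch.qint8)

# float16 weights: fp16 GEMM packing (second branch).
qmodel_fp16 = torch.quantization.quantize_dynamic(model, {nn.LSTM}, dtype=torch.float16)

out_int8, _ = qmodel_int8(torch.randn(7, 4, 16))
out_fp16, _ = qmodel_fp16(torch.randn(7, 4, 16))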