示例#1
0
    def from_float(cls, mod):
        r"""Create a dynamic quantized module from a float module or qparams_dict

        The float module's weight is passed through a weight observer to
        obtain a scale and zero point, quantized to ``torch.qint8`` with those
        parameters, and installed together with the original bias on a newly
        constructed ``Linear``.

        Args:
            mod (Module): a float module, either produced by torch.quantization
                          utilities or provided by the user. Its type must be
                          exactly ``NNLinear`` and it must have a ``qconfig``
                          attribute (which may be ``None``).

        Returns:
            The new ``Linear`` carrying the quantized weight and ``mod.bias``.

        Raises:
            AssertionError: if ``mod`` is not an ``NNLinear``, has no
                ``qconfig`` attribute, or the weight observer's dtype is not
                ``torch.qint8``.
        """
        assert type(
            mod
        ) == NNLinear, 'nn.quantized.dynamic.Linear.from_float only works for nn.Linear'
        assert hasattr(
            mod, 'qconfig'), 'Input float module must have qconfig defined'
        # `qconfig.weight` is a factory that builds the observer. Test the
        # factory itself for None: calling it first would raise TypeError on a
        # missing factory (so the fallback below could never be reached) and
        # would build a throwaway observer otherwise.
        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer = mod.qconfig.weight()
        else:
            # We have the circular import issues if we import the qconfig in the beginning of this file:
            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
            # import until we need it.
            from torch.quantization.QConfig import default_dynamic_qconfig
            weight_observer = default_dynamic_qconfig.weight()
        assert weight_observer.dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'

        # Let the observer see the weight, then read back the quantization
        # parameters it derived from it.
        weight_observer(mod.weight)
        wt_scale, wt_zp = weight_observer.calculate_qparams()
        qweight = torch.quantize_linear(mod.weight.float(), float(wt_scale),
                                        int(wt_zp), torch.qint8)

        qlinear = Linear(mod.in_features, mod.out_features)
        qlinear.set_weight_bias(qweight, mod.bias)
        return qlinear
示例#2
0
文件: rnn.py 项目: yuk12/pytorch
    def from_float(cls, mod):
        r"""Create a dynamic quantized LSTM from a float ``torch.nn.LSTM``.

        For every layer and direction the input-hidden and hidden-hidden
        weights are observed, quantized to ``torch.qint8`` and packed together
        with their bias. The packed weights and the biases are registered as
        buffers on the new module under the names ``w_ih_l*``, ``w_hh_l*``,
        ``b_ih_l*`` and ``b_hh_l*`` (with a ``_reverse`` suffix for the second
        direction).

        Args:
            mod (Module): a float module whose type is exactly
                          ``torch.nn.LSTM``, created with ``bias=True``, and
                          carrying a ``qconfig`` attribute (which may be
                          ``None``).

        Returns:
            The new quantized ``LSTM``. ``_all_weights`` lists the registered
            buffer names, ``_packed_weights`` lists the names of the packed
            weight buffers, and ``_quantized_weights`` holds the quantized
            (unpacked) weight tensors.

        Raises:
            AssertionError: if ``mod`` is not a ``torch.nn.LSTM``, has no
                ``qconfig`` attribute, was built without bias, or the weight
                observer's dtype is not ``torch.qint8``.
            RuntimeError: if the RNN mode is not ``'LSTM'``.
        """
        assert type(
            mod
        ) == torch.nn.LSTM, 'nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM'
        assert hasattr(
            mod, 'qconfig'), 'Input float module must have qconfig defined'
        # Keep the observer *factory* rather than a single observer instance:
        # every weight below gets its own observer so that the statistics of
        # one weight cannot leak into the qparams of the next one. The factory
        # itself is tested for None; calling it first would raise TypeError
        # instead of falling back to the default qconfig.
        if mod.qconfig is not None and mod.qconfig.weight is not None:
            observer_factory = mod.qconfig.weight
        else:
            # We have the circular import issues if we import the qconfig in the beginning of this file:
            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
            # import until we need it.
            from torch.quantization.QConfig import default_dynamic_qconfig
            observer_factory = default_dynamic_qconfig.weight
        assert observer_factory().dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'

        # TODO: support more than just LSTM
        # Reject unsupported modes before anything is built; otherwise
        # `qRNNBase` would be unbound further down.
        if mod.mode != 'LSTM':
            raise RuntimeError('Only LSTM is supported for QuantizedRNN')

        qRNNBase = LSTM(mod.input_size, mod.hidden_size, mod.num_layers,
                        mod.bias, mod.batch_first, mod.dropout,
                        mod.bidirectional)

        num_directions = 2 if mod.bidirectional else 1

        # The bias_* attributes read below only exist when bias is enabled.
        assert mod.bias

        if qRNNBase.mode != 'LSTM':
            raise RuntimeError('Only LSTM is supported for QuantizedRNN')

        qRNNBase._all_weights = []
        packed_weights = []
        quantized_weights = []

        def process_weights(ihhh, layer, suffix):
            # Quantize and pack one weight/bias pair of `mod`; returns the
            # values and the buffer names they should be registered under.
            weight_name = 'weight_{}_l{}{}'.format(ihhh, layer, suffix)
            bias_name = 'bias_{}_l{}{}'.format(ihhh, layer, suffix)

            weight = getattr(mod, weight_name)
            bias = getattr(mod, bias_name)
            # for each layer, for each direction we need to quantize and pack
            # weights and pack parameters in this order:
            #
            #   w_ih, w_hh, b_ih, b_hh
            weight_observer = observer_factory()
            weight_observer(weight)
            wt_scale, wt_zp = weight_observer.calculate_qparams()
            qweight = torch.quantize_linear(weight.float(),
                                            float(wt_scale),
                                            int(wt_zp), torch.qint8)
            packed_weight = \
                torch.ops.quantized.linear_prepack(qweight, bias)

            params = [packed_weight, bias]
            pos_names = ['w', 'b']
            ret_name = [
                '{}_{}_l{}{}'.format(name, ihhh, layer, suffix)
                for name in pos_names
            ]
            quantized_weights.append(qweight)
            packed_weights.append(ret_name[0])
            return params, ret_name

        for layer in range(qRNNBase.num_layers):
            for direction in range(num_directions):
                suffix = '_reverse' if direction == 1 else ''
                ih_params, ih_param_names = process_weights(
                    'ih', layer, suffix)
                hh_params, hh_param_names = process_weights(
                    'hh', layer, suffix)

                # Interleave ih/hh so the registration order is
                # w_ih, w_hh, b_ih, b_hh.
                for (ih, ih_name), (hh, hh_name) in zip(
                        zip(ih_params, ih_param_names),
                        zip(hh_params, hh_param_names)):
                    qRNNBase.register_buffer(
                        ih_name,
                        torch.tensor(ih)
                        if not isinstance(ih, torch.Tensor) else ih)
                    qRNNBase.register_buffer(
                        hh_name,
                        torch.tensor(hh)
                        if not isinstance(hh, torch.Tensor) else hh)
                    qRNNBase._all_weights.extend([ih_name, hh_name])

        qRNNBase._packed_weights = packed_weights
        # DO WE NEED _quantized_weights? @jianyuh: will remove _quantized_weight as now we support the fbgemm_linear_unpack function
        qRNNBase._quantized_weights = quantized_weights

        return qRNNBase
示例#3
0
    def from_float(cls, mod):
        r"""Create a dynamic quantized LSTM from a float ``torch.nn.LSTM``.

        The weight observer's dtype selects the conversion applied to every
        input-hidden and hidden-hidden weight of every layer and direction:

        * ``torch.qint8``: the weight is observed, quantized per tensor and
          packed together with its bias; only the packed weight is recorded.
        * ``torch.float16``: the weight is packed with
          ``torch.fbgemm_pack_gemm_matrix_fp16``; the packed weight and the
          bias are both recorded.

        Args:
            mod (Module): a float module whose type is exactly
                          ``torch.nn.LSTM``, created with ``bias=True``, and
                          carrying a ``qconfig`` attribute (which may be
                          ``None``).

        Returns:
            The new quantized ``LSTM`` with ``_all_weight_names`` and
            ``_all_weight_values`` filled in matching order.

        Raises:
            AssertionError: if ``mod`` is not a ``torch.nn.LSTM``, has no
                ``qconfig`` attribute, or was built without bias.
            RuntimeError: if the observer dtype is neither ``torch.qint8`` nor
                ``torch.float16``.
            NotImplementedError: if the RNN mode is not ``'LSTM'``.
        """
        assert type(
            mod
        ) == torch.nn.LSTM, 'nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM'
        assert hasattr(
            mod, 'qconfig'), 'Input float module must have qconfig defined'

        # Keep the observer *factory* rather than a single observer instance:
        # every weight below gets its own observer so that the statistics of
        # one weight cannot leak into the qparams of the next one.
        if mod.qconfig is not None and mod.qconfig.weight is not None:
            observer_factory = mod.qconfig.weight
        else:
            # We have the circular import issues if we import the qconfig in the beginning of this file:
            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
            # import until we need it.
            from torch.quantization.QConfig import default_dynamic_qconfig
            observer_factory = default_dynamic_qconfig.weight

        dtype = observer_factory().dtype
        supported_scalar_types = [torch.qint8, torch.float16]
        if dtype not in supported_scalar_types:
            raise RuntimeError(
                'Unsupported dtype for dynamic RNN quantization: {}'.format(
                    dtype))

        if mod.mode == 'LSTM':
            qRNNBase = LSTM(mod.input_size, mod.hidden_size, mod.num_layers,
                            mod.bias, mod.batch_first, mod.dropout,
                            mod.bidirectional, dtype)
        else:
            raise NotImplementedError(
                'Only LSTM is supported for QuantizedRNN for now')

        num_directions = 2 if mod.bidirectional else 1

        # The bias_* attributes read below only exist when bias is enabled.
        assert mod.bias

        qRNNBase._all_weight_names = []
        qRNNBase._all_weight_values = []

        def process_weights(ihhh, layer, suffix, dtype):
            # Convert one weight/bias pair of `mod`; returns the values to
            # record and the names to record them under.
            weight_name = 'weight_{}_l{}{}'.format(ihhh, layer, suffix)
            bias_name = 'bias_{}_l{}{}'.format(ihhh, layer, suffix)

            weight = getattr(mod, weight_name)
            bias = getattr(mod, bias_name)

            if dtype == torch.qint8:
                # for each layer, for each direction we need to quantize and pack
                # weights and pack parameters in this order:
                #
                #   w_ih, w_hh
                weight_observer = observer_factory()
                weight_observer(weight)
                wt_scale, wt_zp = weight_observer.calculate_qparams()
                qweight = torch.quantize_per_tensor(
                    weight.float(), float(wt_scale), int(wt_zp),
                    torch.qint8)
                packed_weight = \
                    torch.ops.quantized.linear_prepack(qweight, bias)

                params = [packed_weight]
                pos_names = ['w']
            else:
                # for each layer, for each direction we need to quantize and pack
                # weights and pack parameters in this order:
                #
                #   packed_ih, packed_hh, b_ih, b_hh
                packed_weight = torch.fbgemm_pack_gemm_matrix_fp16(
                    weight.float())

                params = [packed_weight, bias]
                pos_names = ['packed', 'b']

            ret_name = [
                '{}_{}_l{}{}'.format(name, ihhh, layer, suffix)
                for name in pos_names
            ]
            return params, ret_name

        for layer in range(qRNNBase.num_layers):
            for direction in range(num_directions):
                suffix = '_reverse' if direction == 1 else ''
                ih_params, ih_param_names = process_weights(
                    'ih', layer, suffix, dtype)
                hh_params, hh_param_names = process_weights(
                    'hh', layer, suffix, dtype)

                # Interleave ih/hh so each ih entry is immediately followed
                # by its hh counterpart.
                for (ih, ih_name), (hh, hh_name) in zip(
                        zip(ih_params, ih_param_names),
                        zip(hh_params, hh_param_names)):
                    qRNNBase._all_weight_names.extend([ih_name, hh_name])
                    qRNNBase._all_weight_values.extend([ih, hh])

        return qRNNBase