    def test_pass_quantize_convolutions(self) -> None:
        """Check the dtypes left on the sample graph by ``pass_quantize_convolutions``.

        The sample graph is built from two random float32 tensors, the pass is
        applied to it in place, and the output dtypes of the activation
        quantizer ('aqtz1'), the kernel quantizer ('kqtz1') and the convolution
        ('conv2') are compared against the expected types.
        """
        data1 = np.float32(np.random.rand(1, 2, 2, 3))
        data2 = np.float32(np.random.rand(1, 2, 2, 3))
        graph1 = self.create_sample_graph(data1, data2)

        # Run the pass under test; without this call the assertions below
        # would only inspect the freshly built, untransformed graph.
        pass_quantize_convolutions(graph1)

        self.assertEqual(graph1.get_op('aqtz1').dtype, QUANTIZED_PACKED(),
                         '[Failed] Found output dtype of activation quantizer not proper')
        self.assertEqual(graph1.get_op('kqtz1').dtype, PackedUint32(),
                         '[Failed] Found output dtype of kernel quantizer not proper')
        self.assertEqual(graph1.get_op('conv2').dtype, Float32(),
                         '[Failed] Found output dtype of conv not proper')

        print("Test pass #5 quantize_convolutions passed!")
def _assign_packed_output(node, word_size: int) -> None:
    """Give ``node`` a QUANTIZED_PACKED output dtype and the matching packed shape.

    The channel count is rounded up to whole groups of ``word_size`` and the
    node's shape becomes ``[groups, height, width, 2, word_size]`` with the
    "ChHWBCl" dimension format.
    """
    node.dtype = QUANTIZED_PACKED()
    height = node.height
    width = node.width
    channels_upper = (node.channels + word_size - 1) // word_size
    node.update_shape([channels_upper, height, width, 2, word_size], "ChHWBCl")


def pass_quantize_convolutions(graph: Graph) -> None:
    """Mark quantized convolutions and assign output data types to them and their quantizers.

    A 'Conv' node is treated as quantized only when it has both a weight
    quantizer (``quantizer``) and activation quantizers (``a_quantizer``).
    For each such node:

    * ``is_quantized`` is set to True;
    * if the node has thresholds, its output dtype becomes QUANTIZED_PACKED and
      its shape is updated to the packed layout;
    * the weight quantizer's dtype becomes PackedUint32;
    * every activation quantizer that is not a ``Lookup`` gets dtype
      QUANTIZED_PACKED and the packed layout.

    Args:
        graph (Graph): The input graph. It will be modified in-place.
    """
    b = 32  # number of channels grouped together in the packed layout

    exec_list = [n for n in sort_graph(graph) if n.op_type == 'Conv']
    for conv_node in exec_list:
        # Skip convolutions that lack either the weight or the activation quantizer.
        if not conv_node.quantizer or not conv_node.a_quantizer:
            continue

        # Mark as quantized convolution
        conv_node.is_quantized = True

        # Change the output data type of the convolution if thresholds are available.
        # NOTE(review): the "ChHWBCl" format used here mirrors the quantizer branch — confirm.
        if conv_node.has_thresholds:
            _assign_packed_output(conv_node, b)

        # Change the output data type of the quantizers.
        conv_node.quantizer.dtype = PackedUint32()
        for qtz in conv_node.a_quantizer:
            # Lookup nodes keep their dtype and shape.
            if isinstance(qtz, Lookup):
                continue
            _assign_packed_output(qtz, b)
def pass_pack_weights(graph: Graph) -> None:
    """Given a Quantized convolution node C, it will pack the weights of C into 32 bit words.
       If the node Q that apply quantization to the weights of C quantizes, for example, into 1 bit values
       then one 32 bit word will contain 32 weights.

        graph (Graph): The input graph. It will be modified in-place.

    exec_list = [n for n in sort_graph(graph) if n.op_type == 'Conv']
    quantization_types = [
        'BinaryMeanScalingQuantizer', 'LinearMidTreadHalfQuantizer',

    word_size = 32
    weight_bitwidth = 1
    packer = Packer(weight_bitwidth, word_size)
    to_be_removed = []
    b = 32

    for m in exec_list:
        conv_node = m

        # check if this is a quantized convolution
        if not conv_node.quantizer or not conv_node.a_quantizer:

        # Check if we support this kind of quantizer
        weight_quantizer = conv_node.quantizer
        if weight_quantizer.op_type not in quantization_types:

        # Quantize the weights

        def pad_to_multiple_of_b(tensor, axis, b):
            """Build the zero block that rounds ``tensor`` up to a multiple of ``b`` along ``axis``.

            Returns an all-zero array shaped like ``tensor`` except that its
            extent along ``axis`` equals the number of missing entries, or
            ``None`` when no padding is needed.
            """
            dims = list(tensor.shape)
            current = dims[axis]
            rounded_up = ((current + b - 1) // b) * b
            missing = rounded_up - current
            if not missing:
                return None
            dims[axis] = missing
            return np.zeros(dims)

        padded_data = np.copy(weight_quantizer.data)

        for axis in [0, 3]:
            pad_tensor = pad_to_multiple_of_b(padded_data, axis, b)
            if pad_tensor is not None:
                padded_data = np.append(padded_data, pad_tensor, axis=axis)

        tca_output = np.copy(padded_data)
        oc, kh, kw, kd = padded_data.shape[:]
        padded_data = padded_data.flatten()
        tca_output = tca_output.flatten()

        out_index = 0
        for g in range(oc // b):
            for p in range(kd // b):
                for h in range(kh):
                    for w in range(kw):
                        for o in range(b):
                            for d in range(b):
                                idx = g * (kw * kh * kd * b) + p * b + h * (
                                    kw * kd) + w * kd + o * (kw * kh * kd) + d
                                tca_output[out_index] = padded_data[idx]
                                out_index += 1

        kn2row_output = np.zeros(oc * kh * kw * kd)
        out_index = 0
        for h in range(kh):
            for w in range(kw):
                for o in range(oc):
                    for i in range(kd):
                        idx = o * kh * kw * kd + h * kw * kd + w * kd + i
                        kn2row_output[out_index] = padded_data[idx]
                        out_index += 1

        op_data = weight_quantizer.binarizer(padded_data)
        data = packer.run(op_data.astype(np.float32),

        tca_binarized_data = weight_quantizer.binarizer(tca_output)
        tca_packed_data = packer.run(tca_binarized_data.astype(np.float32),

        kn2row_binarized_data = weight_quantizer.binarizer(kn2row_output)
        kn2row_data = packer.run(kn2row_binarized_data.astype(np.float32),

        shape = [oc, kh, kw, kd]
        tca_shape = [oc // b, kd // b, kh, kw, b, b]
        kn2row_shape = [kh, kw, oc, kd]

        # Create the new constant with the quantized weights
        quantized_constant = Constant(
            weight_quantizer.name + '_new',
            data=np.vectorize(lambda k: (~k) & ((0x1 << 32) - 1))(data),
            transposed_data=[(~k) & ((0x1 << 32) - 1)
                             for k in tca_packed_data.flatten()],
            kn2row_data=[k for k in kn2row_data.flatten()],

        # get nodes to be removed after being disconnected
        get_nodes_in_branch(weight_quantizer, None, to_be_removed)

        # Add the constant to the graph and connect the new constant
        for output_name, consumer_list in weight_quantizer.output_ops.items():
            for consumer_node in consumer_list:
                for input_name, input_node in consumer_node.input_ops.items():
                    if input_node == weight_quantizer:
                        consumer_node.add_input(input_name, quantized_constant)

    for op in to_be_removed: