Example #1
    def get_quantizer_output(self) -> QuantizerOutput:
        """
        Quantization information can be stored both in q_output attribute
        and in meta attributes
        TODO: Merge approaches
        """
        if not self.is_quantized():
            raise ValueError("No quantization output found. Quantize this"
                             " XGraph object before retrieving the"
                             " quantization output")

        if (self.quantizer_output is not None
                and "is_quantized" in self.meta_attrs
                and self.meta_attrs["is_quantized"]):
            warnings.warn("Quantization info found both in XGraph meta"
                          " attributes and q_output attribute")

        if self.quantizer_output is not None:
            return self.quantizer_output

        # Retrieve quantization output from meta attributes
        q_output = QuantizerOutput(self.get_name())
        if "quant_keys" not in self.meta_attrs:
            raise ValueError("Expected `quant_keys` attribute in meta"
                             " attributes")

        for q_key in self.meta_attrs["quant_keys"]:
            q_output.add(q_key=q_key,
                         q_file=self.meta_attrs[q_key]['q_file'],
                         q_info=self.meta_attrs[q_key]['q_info'],
                         orig_pb=self.meta_attrs[q_key]['orig_pb'])
            logger.debug("QOutput q_info: {}".format(
                self.meta_attrs[q_key]['q_info']))

        return q_output
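
The reader above expects a specific meta-attribute layout when no `quantizer_output` attribute is set: a `quant_keys` list plus, per key, a dictionary with `q_file`, `q_info` and `orig_pb` entries. A minimal standalone sketch of that layout (plain dictionaries stand in for the real XGraph meta attributes; the key name and file paths below are made up):

# Illustrative only: plain dicts standing in for XGraph meta attributes.
# get_quantizer_output() expects a 'quant_keys' list plus, per key, a dict
# with 'q_file', 'q_info' and 'orig_pb' entries (names/paths are made up)
meta_attrs = {
    "is_quantized": True,
    "quant_keys": ["xp0"],
    "xp0": {
        "q_file": "work/xp0_quant.json",
        "q_info": "work/quant_info_xp0.txt",
        "orig_pb": "work/xp0_frozen.pb",
    },
}

for q_key in meta_attrs["quant_keys"]:
    entry = meta_attrs[q_key]
    print(q_key, entry["q_file"], entry["q_info"], entry["orig_pb"])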
Example #2
    def quantize(self, stop=None, subgraphs_only=True):
        # type: (str, bool) -> XGraph
        """
        Start quantization of the executable graph model

        Arguments
        ---------
        stop: str (optional, default = None)
            the name of the operation at which to stop quantization
        """

        self._quantize(stop, subgraphs_only)

        q_output = QuantizerOutput(self.xgraph.get_name())
        for qkey in self._quant_layers.keys():
            if qkey != 'None':
                quant_file = os.path.join(self.work_dir, qkey + '_quant.json')
                self._quant_param.save_to_dpu_v1_json(self._quant_layers[qkey],
                                                      quant_file)
                q_output.add(qkey, quant_file, None, None)

        self.xgraph.set_quantizer_output(q_output)

        logger.info("QUANTIZATION DONE")

        return self.xgraph
Example #3
    def __init__(self,
                 xgraph,
                 inputs_func,
                 work_dir=os.path.join(os.getcwd(), 'work')):

        super(ExternalQuantizer, self).__init__(xgraph, inputs_func, work_dir)

        self.gen = TfGenerator()
        self.partition_graphs = {}
        self.res = {}
        self.q_output = QuantizerOutput(name=xgraph.get_name())
Example #4
    def __init__(self,
                 xgraph,
                 inputs_func,
                 work_dir=os.path.join(os.getcwd(), 'work'),
                 quant_iter=1,
                 **kwargs):

        super(DECENTQuantizer, self).__init__(xgraph, inputs_func, work_dir)

        self.quant_iter = quant_iter
        self.gen = TfGenerator()
        self.partition_graphs = {}
        self.res = {}
        self.kwargs = kwargs

        self.q_output = QuantizerOutput(name=xgraph.get_name())
Example #5
    def __init__(self,
                 xgraph,
                 inputs_func,
                 bitwidth=8,
                 work_dir=os.path.join(os.getcwd(), 'work'),
                 quant_iter=1,
                 mse_opt_num=50):
        super(XGraphMSEThresholdQuantizer, self).__init__(xgraph)

        self.inputs_func = inputs_func
        self.work_dir = work_dir
        self.bitwidth = bitwidth
        self.mse_opt_num = mse_opt_num

        self.quant_xgraph = None
        self.runtime = None

        self._quant_param = QuantParamFactory()
        self._quant_layers = {}

        self.q_output = QuantizerOutput(name=xgraph.get_name())
Example #6
class ExternalQuantizer(XGraphBaseSubgraphQuantizer, ABC):

    xgraph_factory = XGraphFactory()
    xgraph_partitioner = XGraphPartitioner()

    def __init__(self,
                 xgraph,
                 inputs_func,
                 work_dir=os.path.join(os.getcwd(), 'work')):

        super(ExternalQuantizer, self).__init__(xgraph, inputs_func, work_dir)

        self.gen = TfGenerator()
        self.partition_graphs = {}
        self.res = {}
        self.q_output = QuantizerOutput(name=xgraph.get_name())

    def _propagate_quant_info(self, xgraph):
        # setup empty vqi and vqo for every layer w/o vai_quant
        for layer in xgraph.get_layers():
            if 'vai_quant' not in layer.attrs:
                layer.attrs['vai_quant'] = ['vai_quant_in', 'vai_quant_out']
                layer.attrs['vai_quant_in'] = ''
                layer.attrs['vai_quant_out'] = ''
        # for every layer
        for layer in xgraph.get_layers():
            # if the layer has a non-empty vai_quant_out, propagate it to the output layers
            if layer.attrs['vai_quant_out'] != '':
                l_vqo = layer.attrs['vai_quant_out']
                # for every output layer
                for t_idx, t_name in enumerate(layer.tops):
                    t_layer = xgraph.get(t_name)
                    # if the input quant is not specified in the output layer
                    if t_layer.attrs['vai_quant_in'] == '':
                        # get quant info from the current layer, two values per
                        # output (see the pairing sketch after this example)
                        t_vqi = [l_vqo[2 * t_idx], l_vqo[2 * t_idx + 1]]
                        t_layer.attrs['vai_quant_in'] = t_vqi
            # if the layer has a non-empty vai_quant_in, propagate it back to the input layers
            if layer.attrs['vai_quant_in'] != '':
                l_vqi = layer.attrs['vai_quant_in']
                # for every input layer
                for b_idx, b_name in enumerate(layer.bottoms):
                    b_layer = xgraph.get(b_name)
                    if b_layer.attrs['vai_quant_out'] == '':
                        b_vqo = [l_vqi[2 * b_idx], l_vqi[2 * b_idx + 1]]
                        b_layer.attrs['vai_quant_out'] = b_vqo

    def quantize(self):
        # NOTE For Conv2Dtranspose layers we need the specific batch size in tensorflow 1.13
        batch_size = list(self.inputs_func(0).values())[0].shape[0]
        fs = self.gen.generate(
            self.xgraph,
            'graph',
            subgraphs_only=True,
            layout='NHWC',
            batch_size=batch_size)
        assert len(fs) == 1, 'Too many partitions'
        partition_key = list(fs.keys())[0]
        pb_path = list(fs.values())[0]
        self.partition_graphs[partition_key] = pb_path

        q_xgraph = super(ExternalQuantizer, self).quantize()

        self.xgraph.meta_attrs["is_quantized"] = True
        for qkey in self.q_output.keys():
            if 'quant_keys' not in self.xgraph.meta_attrs:
                self.xgraph.meta_attrs['quant_keys'] = [qkey]
            else:
                self.xgraph.meta_attrs['quant_keys'].append(qkey)
            quant_file = self.q_output.get_q_file(qkey)
            quant_info_file = self.q_output.get_q_info(qkey)
            quant_orig_pb = self.q_output.get_orig_pb(qkey)
            self.xgraph.meta_attrs[qkey] = {
                'q_file': quant_file,
                'q_info': quant_info_file,
                'orig_pb': quant_orig_pb}
        return q_xgraph
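
The propagation pass above stores two quantization values per connected layer, so the values for the top (or bottom) at index i sit at positions 2*i and 2*i + 1 of `vai_quant_out`/`vai_quant_in`. A small standalone sketch of that pairing, using plain lists instead of XLayer attributes (layer names and values are made up):

# Illustrative pairing logic only; plain lists stand in for the
# 'vai_quant_out' / 'vai_quant_in' layer attributes
layer_tops = ["conv2", "pool1"]          # hypothetical output layer names
vai_quant_out = [8, 6, 8, 5]             # two values per output layer

for t_idx, t_name in enumerate(layer_tops):
    # same slicing as in _propagate_quant_info: values 2*i and 2*i + 1
    t_vqi = [vai_quant_out[2 * t_idx], vai_quant_out[2 * t_idx + 1]]
    print(t_name, "gets vai_quant_in =", t_vqi)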
Example #7
class DECENTQuantizer(XGraphBaseSubgraphQuantizer):

    try:
        import tensorflow as tf
        if hasattr(tf, 'contrib') and hasattr(tf.contrib, 'decent_q'):
            from tensorflow.contrib import decent_q
        else:
            warnings.warn("Could not import decent_q module. Please check"
                          " if installed.")
    except ImportError:
        warnings.warn("Could not import decent_q module. Please check"
                      " if installed.")

    xgraph_factory = XGraphFactory()
    xgraph_partitioner = XGraphPartitioner()

    def __init__(self,
                 xgraph,
                 inputs_func,
                 work_dir=os.path.join(os.getcwd(), 'work'),
                 quant_iter=1,
                 **kwargs):

        super(DECENTQuantizer, self).__init__(xgraph, inputs_func, work_dir)

        self.quant_iter = quant_iter
        self.gen = TfGenerator()
        self.partition_graphs = {}
        self.res = {}
        self.kwargs = kwargs

        self.q_output = QuantizerOutput(name=xgraph.get_name())

    def quantize_subgraph(self, xgraph, inputs, input_names, output_names):
        # type: (XGraph, Dict[str, numpy.ndarray], List[str], List[str]) -> None
        """ Quantize subgraph with inputs """

        # Import Tensorflow only when needed to avoid strict dependency
        import tensorflow as tf

        frozen_graph = self.partition_graphs[xgraph.get_name()]
        logger.info("Load frozen graph from: {}".format(frozen_graph))
        input_graph_def = tf.compat.v1.GraphDef()
        with tf.io.gfile.GFile(frozen_graph, "rb") as f:
            input_graph_def.ParseFromString(f.read())

        logger.info("Quantization input: {} and output names: {}".format(
            input_names, output_names))
        input_shapes = [X.shapes.tolist() for X in xgraph.get_input_layers()]

        def inputs_func(iter):
            # decent_q calls this for every calibration iteration; the same
            # example inputs are returned each time
            return inputs

        logger.info("START decent quantization for graph partition: {}".format(
            xgraph.get_name()))
        q_config = self.decent_q.QuantizeConfig(input_nodes=input_names,
                                                output_nodes=output_names,
                                                input_shapes=input_shapes,
                                                output_dir=self.work_dir,
                                                method='1',
                                                calib_iter=self.quant_iter)
        self.decent_q.quantize_frozen(input_graph_def, inputs_func, q_config)

        netcfg = os.path.join(self.work_dir, "deploy_model.pb")
        q_eval_file = os.path.join(self.work_dir, "quantize_eval_model.pb")
        quant_info_file = os.path.join(
            self.work_dir, 'quant_info_{}.txt'.format(xgraph.get_name()))
        self._save_quant_info(netcfg, quant_info_file)

        self.q_output.add(xgraph.get_name(), netcfg, quant_info_file,
                          frozen_graph, q_eval_file)

        # TODO
        # Add quantization info to corresponding XLayers
        self._add_quant_info_to_xgraph(netcfg)

    def quantize(self) -> None:
        """Quantize the XGraph model using the decent_q quantizer"""

        # NOTE For Conv2Dtranspose layers we need the specific batch size in
        #   tensorflow 1.13
        batch_size = list(self.inputs_func(0).values())[0].shape[0]

        fs = self.gen.generate(self.xgraph,
                               'graph',
                               subgraphs_only=True,
                               layout='NHWC',
                               batch_size=batch_size,
                               out_dir=self.work_dir,
                               **self.kwargs)

        if len(fs) != 1:
            raise ValueError("DECENT quantization currently only supports"
                             " models with one DPU compatible partition,"
                             " but got: {}".format(len(fs)))

        partition_key = list(fs.keys())[0]
        pb_path = list(fs.values())[0]

        self.partition_graphs[partition_key] = pb_path

        q_xgraph = super(DECENTQuantizer, self).quantize()

        self.xgraph.meta_attrs["is_quantized"] = True
        for qkey in self.q_output.keys():
            if 'quant_keys' not in self.xgraph.meta_attrs:
                self.xgraph.meta_attrs['quant_keys'] = [qkey]
            else:
                self.xgraph.meta_attrs['quant_keys'].append(qkey)
            quant_file = self.q_output.get_q_file(qkey)
            quant_info_file = self.q_output.get_q_info(qkey)
            quant_orig_pb = self.q_output.get_orig_pb(qkey)
            quant_eval_file = self.q_output.get_q_eval(qkey)
            self.xgraph.meta_attrs[qkey] = {
                'q_file': quant_file,
                'q_info': quant_info_file,
                'orig_pb': quant_orig_pb,
                'q_eval': quant_eval_file
            }

        self.xgraph.set_quantizer_output(self.q_output)

        return q_xgraph

    def _add_quant_info_to_xgraph(self, deploy_frozen_graph: str) -> None:
        """
        Retrieve the quantization info from the provided quantized model and
        add the information to the corresponding XLayers
        """

        # Import tensorflow only when needed to avoid strict dependency
        import tensorflow as tf

        input_graph_def = tf.compat.v1.GraphDef()
        with tf.io.gfile.GFile(deploy_frozen_graph, "rb") as f:
            input_graph_def.ParseFromString(f.read())

            for idx, node in enumerate(input_graph_def.node):

                if node.name in self.xgraph:
                    X = self.xgraph.get(node.name)
                    X.attrs['vai_quant_idx'] = idx + 1

                    if 'ipos' in node.attr.keys():
                        X.attrs['vai_quant'] = ['vai_quant_in']
                        X.attrs['vai_quant_in'] = \
                            [int(v) for v in node.attr['ipos'].list.i]
                    if 'opos' in node.attr.keys():
                        X.attrs['vai_quant'].append('vai_quant_out')
                        X.attrs['vai_quant_out'] = \
                            [int(v) for v in node.attr['opos'].list.i]
                    if 'wpos' in node.attr.keys():
                        X.attrs['vai_quant'].append('vai_quant_weights')
                        X.attrs['vai_quant_weights'] = \
                            [int(v) for v in node.attr['wpos'].list.i]
                    if 'bpos' in node.attr.keys():
                        X.attrs['vai_quant'].append('vai_quant_biases')
                        X.attrs['vai_quant_biases'] = \
                            [int(v) for v in node.attr['bpos'].list.i]

    def _save_quant_info(self, deploy_frozen_graph, filename):
        # type: (str, str) -> None
        """
        Retrieve the quantization info from the provided quantized model and
        save it to the given file
        """
        quant_info = self._get_quant_info(deploy_frozen_graph)

        lines = [
            [q_op['idx'], q_op['name']]
            + [str(i) for i in q_op['ipos']]
            + [str(i) for i in q_op['opos']]
            + [str(i) for i in q_op['wpos']]
            + [str(i) for i in q_op['bpos']]
            for q_op in quant_info
        ]
        s = '\n'.join(' '.join(line) for line in lines)

        with open(filename, 'w') as f:
            f.write(s)

    def _get_quant_info(self, deploy_frozen_graph):
        # type: (str) -> List[dict]
        """
        Retrieve the quantization info from the provided quantized model
        """

        # import tensorflow only when needed to avoid strict dependency
        import tensorflow as tf

        quant_info = []

        input_graph_def = tf.compat.v1.GraphDef()
        with tf.io.gfile.GFile(deploy_frozen_graph, "rb") as f:
            input_graph_def.ParseFromString(f.read())

            for idx, node in enumerate(input_graph_def.node):

                q_op = {
                    'idx': str(idx + 1),
                    'name': node.name,
                    'ipos': [],
                    'opos': [],
                    'wpos': [],
                    'bpos': []
                }

                if 'ipos' in node.attr.keys():
                    q_op['ipos'].extend(
                        [int(v) for v in node.attr['ipos'].list.i])
                if 'opos' in node.attr.keys():
                    q_op['opos'].extend(
                        [int(v) for v in node.attr['opos'].list.i])
                if 'wpos' in node.attr.keys():
                    q_op['wpos'].extend(
                        [int(v) for v in node.attr['wpos'].list.i])
                if 'bpos' in node.attr.keys():
                    q_op['bpos'].extend(
                        [int(v) for v in node.attr['bpos'].list.i])

                quant_info.append(q_op)

        return quant_info

    def eval(self,
             val_dir,
             gold_file,
             synset_words,
             batch_size,
             nb_batches,
             class_num=1000,
             gpu=0):
        """
        Evaluate the quantized model on a validation set and report top-1 and
        top-5 classification accuracy
        """
        # Import TensorFlow only when needed to avoid a strict dependency
        import tensorflow as tf

        input_fn_data = {
            "prep_key": self.data_prep_key,
            "dir": val_dir,
            "batch": batch_size,
            "inputs": self.xgraph.get_input_names()
        }

        with open(os.path.join(FILE_PATH, 'calibration.json'), 'w') as f:
            json.dump(input_fn_data, f)

        with open(gold_file) as f:
            val_set = [line.strip('\n').split(' ') for line in f.readlines()]

        frozen_graph_file = os.path.join(self.output_dir,
                                         "quantize_eval_model.pb")
        # TODO: support multiple inputs/outputs
        assert (len(self.xgraph.get_input_names()) == 1)
        assert (len(self.xgraph.get_output_names()) == 1)
        input_node = self.xgraph.get_input_names()[0]
        output_node = self.xgraph.get_output_names()[0]

        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)
        input_graph_def = tf.Graph().as_graph_def()
        input_graph_def.ParseFromString(
            tf.gfile.FastGFile(frozen_graph_file, "rb").read())

        tf.import_graph_def(input_graph_def, name='')

        # Get input tensors
        input_tensor = tf.get_default_graph()\
            .get_tensor_by_name(input_node+':0')
        input_labels = tf.compat.v1.placeholder(tf.float32,
                                                shape=[None, class_num])

        # Calculate accuracy
        output = tf.get_default_graph().get_tensor_by_name(output_node + ':0')
        prediction = tf.reshape(output, [batch_size, class_num])
        # correct_labels = tf.argmax(input_labels, 1)
        # top1_prediction = tf.nn.in_top_k(prediction, correct_labels, k = 1)
        # top5_prediction = tf.nn.in_top_k(prediction, correct_labels, k = 5)
        # top1_accuracy = tf.reduce_mean(tf.cast(top1_prediction,'float'))
        # top5_accuracy = tf.reduce_mean(tf.cast(top5_prediction,'float'))

        # Start evaluation
        logger.info("Start Evaluation for {} Batches...".format(nb_batches))
        with tf.Session() as sess:
            progress = ProgressBar()

            top1_sum_acc = 0
            top5_sum_acc = 0

            for iter in progress(range(0, nb_batches)):
                input_data = decent_prepfn.input_fn(iter)
                images = input_data[input_node]
                # labels = input_data['labels']
                logger.debug("IMAGES", images)
                labels = [
                    elem[1] for elem in val_set[iter * batch_size:(iter + 1) *
                                                batch_size]
                ]
                feed_dict = {input_tensor: images}
                raw_predictions = sess.run(prediction, feed_dict)
                logger.debug(raw_predictions)

                # logger.debug("Predictions shape: {}"
                #              .format(raw_predictions.shape))
                # logger.debug("Labels length: {}".format(len(labels)))
                top_1 = classification.get_top_k_accuracy(
                    raw_predictions, synset_words, 1, labels)
                top_5 = classification.get_top_k_accuracy(
                    raw_predictions, synset_words, 5, labels)
                top1_sum_acc += top_1
                top5_sum_acc += top_5
                logger.debug("int: {}, {}".format(top_1, top_5))

        final_top1_acc = top1_sum_acc / nb_batches
        final_top5_acc = top5_sum_acc / nb_batches

        print("Accuracy: Top1: {}, Top5: {}".format(final_top1_acc,
                                                    final_top5_acc))

    def dump(self, img_dir, input_names, max_dump_batches=1, dump_float=0):
        """
        Dump (quantized) layer outputs with decent_q for the given calibration
        images

        TODO: input_names
        """
        input_fn_data = {
            "prep_key": self.data_prep_key,
            "dir": img_dir,
            "batch": 1,
            "inputs": input_names
        }

        with open(os.path.join(FILE_PATH, 'calibration.json'), 'w') as f:
            json.dump(input_fn_data, f)

        frozen_graph = os.path.join(self.output_dir, 'quantize_eval_model.pb')

        command = """
        decent_q dump \
            --input_frozen_graph {} \
            --input_fn decent_prepfn.input_fn \
            --max_dump_batches {} \
            --dump_float {} \
            --output_dir {}
        """.format(frozen_graph, max_dump_batches, dump_float, self.output_dir)

        print("COMMAND", command)

        process = subprocess.Popen(command.split(),
                                   cwd=FILE_PATH,
                                   stdout=subprocess.PIPE)
        output, error = process.communicate()

        print(output, error)
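
A hedged usage sketch for the DECENTQuantizer defined above. The `xgraph` instance is assumed to come from the surrounding pyxir partitioning flow; the input name, shape and layout of the calibration data are placeholders:

import numpy as np

# Placeholder calibration inputs: inputs_func receives an iteration index and
# returns a dict mapping graph input names to example data (NHWC assumed here)
def inputs_func(iter_idx):
    return {"data": np.random.rand(1, 224, 224, 3).astype(np.float32)}

# `xgraph` is assumed to be an already partitioned XGraph from the pyxir flow
quantizer = DECENTQuantizer(xgraph, inputs_func, work_dir="work", quant_iter=1)
q_xgraph = quantizer.quantize()

# The quantization results should afterwards be retrievable from the graph
q_output = q_xgraph.get_quantizer_output()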
Example #8
class XGraphMSEThresholdQuantizer(XGraphBaseQuantizer):
    """

    Attributes
    ----------
    xgraph: XGraph
        the XGraph instance to be quantized
    inputs_func: Function
        the inputs function to be used for quantization; it should accept an
        iterator and return a dictionary mapping input names to example
        input data
    bitwidth: int
        the bitwidth to be used for quantization
    work_dir: str
        the work directory to be used for storing quantization files
    quant_iter: int
        the number of iterations for quantization
    mse_opt_num: int
        the number of trials for optimizing the mean squared error (MSE)
        between full precision and quantized outputs
    """
    def __init__(self,
                 xgraph,
                 inputs_func,
                 bitwidth=8,
                 work_dir=os.path.join(os.getcwd(), 'work'),
                 quant_iter=1,
                 mse_opt_num=50):
        super(XGraphMSEThresholdQuantizer, self).__init__(xgraph)

        self.inputs_func = inputs_func
        self.work_dir = work_dir
        self.bitwidth = bitwidth
        self.mse_opt_num = mse_opt_num

        self.quant_xgraph = None
        self.runtime = None

        self._quant_param = QuantParamFactory()
        self._quant_layers = {}

        self.q_output = QuantizerOutput(name=xgraph.get_name())

    def quantize(self, stop=None, subgraphs_only=True):
        # type: (str, bool) -> XGraph
        """ Start MSE quantization """

        self._quantize(self.xgraph, stop, subgraphs_only=subgraphs_only)

        for qkey in self._quant_layers.keys():

            if qkey != 'None':
                quant_file = os.path.join(self.work_dir, qkey + '_quant.json')
                self._quant_param.save_to_dpu_v1_json(self._quant_layers[qkey],
                                                      quant_file)

                self.q_output.add(qkey, quant_file, q_info=None, orig_pb=None)

                # TODO Add scaling layers
                # TODO Move adding scaling layer to before optimization
                fancy_logger.banner(
                    "ADD QUANTIZATION SCALING LAYERS FOR: {}".format(qkey))

                quant_params = QuantParams(quant_file)
                graph_pass = XGraphQuantScalingPass(
                    quant_params,
                    quant_file,
                    output_png='tvm_quant_eltwise_scaling.png'
                    if logger.getEffectiveLevel() <= 10 else None)
                xgraph = graph_pass.execute(self.xgraph)

                self.xgraph = xgraph

        self.xgraph.set_quantizer_output(self.q_output)

        fancy_logger.banner("FINISHED QUANTIZATION")

        return self.xgraph

    def _quantize(self, xgraph, stop=None, subgraphs_only=True):
        # type: (XGraph, str, bool) -> None
        """ Start MSE quantization """

        # Graph pass to construct new graph with quantization layers
        graph_pass = XGraphPassAddMSEQuantLayers(
            bitwidth=self.bitwidth,
            mse_opt_num=self.mse_opt_num,
            subgraphs_only=subgraphs_only,
            output_png='tvm_mse_quant.png'
            if logger.getEffectiveLevel() <= 10 else None,
            name=xgraph.get_name())
        xgraph = graph_pass.execute(xgraph=xgraph)

        self.quant_xgraph = xgraph
        self.runtime = pyxir.build(self.quant_xgraph, target='cpu')

        # Run graph to set Variable layer thresholds in graph
        fancy_logger.banner("EXECUTE QUANTIZATION GRAPH")

        inpts = self.inputs_func(0)
        out, params = self.runtime.optimize(inpts)

        logger.info("Done executing graph")
        # logger.info(out.shape, out)
        # logger.info(thresholds)

        logger.info("Retrieving quantization parameters...")

        self._retrieve_quant_params(params, xgraph, subgraphs_only)

    def _retrieve_quant_params(self, thresholds, xgraph, subgraphs_only):
        # type: (dict, XGraph, bool) -> None
        """ Retrieve the quantization parameters (bitwidths and thresholds)
        for each layer from the computed thresholds """
        # TODO
        logger.debug("Thresholds: {}".format(thresholds))

        # TODO implement as a graph pass??
        for X in xgraph.get_layers():

            bottom_Xs = xgraph.get_bottom_layers(X.name)
            top_Xs = xgraph.get_top_layers(X.name)

            if subgraphs_only and X.subgraph is not None:
                qkey = X.subgraph
            elif subgraphs_only:
                qkey = "None"
            else:
                qkey = xgraph.get_name()

            if qkey not in self._quant_layers:
                self._quant_layers[qkey] = []

            # if 'Input' in X.type and len(top_Xs) == 1 and\
            #         'MSEQuantize' in top_Xs[0].type:

            #     self._quant_layers[qkey].append((X.name, 'Input', None))

            #     assert(len(top_Xs[0].bottoms) == 2)
            #     th_out = thresholds[top_Xs[0].bottoms[1]]

            #     self._quant_param.bw_layer_in[X.name] = self.bitwidth
            #     self._quant_param.th_layer_in[X.name] = th_out
            #     self._quant_param.bw_layer_out[X.name] = self.bitwidth
            #     self._quant_param.th_layer_out[X.name] = th_out

            if 'Convolution' in X.type and len(top_Xs) == 1 and\
                    'MSEQuantize' in top_Xs[0].type:

                self._quant_layers[qkey].append((X.name, 'Convolution', None))
                assert len(bottom_Xs) == 3
                assert 'MSEQuantize' in bottom_Xs[0].type or\
                       'MSEMockQuantize' in bottom_Xs[0].type
                assert 'MSEQuantize' in bottom_Xs[1].type

                assert (len(top_Xs[0].bottoms) == 2)

                th_in = thresholds[bottom_Xs[0].bottoms[1]]
                th_params = thresholds[bottom_Xs[1].bottoms[1]]
                th_out = thresholds[top_Xs[0].bottoms[1]]

                self._quant_param.bw_layer_in[X.name] = self.bitwidth
                self._quant_param.th_layer_in[X.name] = th_in
                self._quant_param.bw_params[X.name] = self.bitwidth
                self._quant_param.th_params[X.name] = th_params
                self._quant_param.bw_layer_out[X.name] = self.bitwidth
                self._quant_param.th_layer_out[X.name] = th_out

            elif 'Scale' in X.type and len(top_Xs) == 1 and\
                    'MSEQuantize' in top_Xs[0].type:

                gamma = X.data.gamma
                self._quant_layers[qkey].append(
                    (X.name, 'Scale', [LayerParams(gamma)]))

                assert (len(bottom_Xs) == 3)
                assert 'MSEQuantize' in bottom_Xs[0].type or\
                       'MSEMockQuantize' in bottom_Xs[0].type
                assert ('Input' in bottom_Xs[1].type)
                assert ('MSEQuantizeBias' in bottom_Xs[2].type)
                assert (len(top_Xs[0].bottoms) == 2)

                th_in = thresholds[bottom_Xs[0].bottoms[1]]
                th_params = X.data.gamma
                th_out = thresholds[top_Xs[0].bottoms[1]]

                self._quant_param.bw_layer_in[X.name] = self.bitwidth
                self._quant_param.th_layer_in[X.name] = th_in
                self._quant_param.bw_params[X.name] = self.bitwidth
                self._quant_param.th_params[X.name] = th_params
                self._quant_param.bw_layer_out[X.name] = self.bitwidth
                self._quant_param.th_layer_out[X.name] = th_out

            elif 'MSEQuantizeEltwise' in X.type and len(top_Xs) == 1 and\
                    'MSEQuantize' in top_Xs[0].type:
                # 'MSEQuantizeEltwise'

                self._quant_layers[qkey].append((X.name, 'Eltwise', None))

                assert (len(bottom_Xs) == 5)
                assert 'MSEQuantize' in bottom_Xs[1].type or\
                       'MSEMockQuantize' in bottom_Xs[1].type
                assert 'MSEQuantize' in bottom_Xs[3].type or\
                       'MSEMockQuantize' in bottom_Xs[3].type

                # th_in_1 = thresholds[bottom_Xs[0].bottoms[1]]
                # th_in_2 = thresholds[bottom_Xs[1].bottoms[1]]
                # th_in = np.maximum(th_in_1, th_in_2)
                th_in = thresholds[X.bottoms[4]]
                th_out = thresholds[top_Xs[0].bottoms[1]]
                assert (len(top_Xs[0].bottoms) in [2, 4])

                self._quant_param.bw_layer_in[X.name] = self.bitwidth
                self._quant_param.th_layer_in[X.name] = th_in
                self._quant_param.bw_layer_out[X.name] = self.bitwidth
                self._quant_param.th_layer_out[X.name] = th_out

            elif 'Concat' in X.type and len(top_Xs) == 1 and \
                    'MSEQuantize' in top_Xs[0].type:

                self._quant_layers[qkey].append((X.name, 'Concat', None))
                logger.debug("CONCAT!!")

                for bottom_X in bottom_Xs:
                    assert 'MSEQuantize' in bottom_X.type or\
                           'MSEMockQuantize' in bottom_X.type

                th_in = thresholds[top_Xs[0].bottoms[1]]
                th_out = thresholds[top_Xs[0].bottoms[1]]
                assert (len(top_Xs[0].bottoms) in [2, 4])

                self._quant_param.bw_layer_in[X.name] = self.bitwidth
                self._quant_param.th_layer_in[X.name] = th_in
                self._quant_param.bw_layer_out[X.name] = self.bitwidth
                self._quant_param.th_layer_out[X.name] = th_out

            elif 'Pooling' in X.type and len(top_Xs) == 1 and \
                    ('MSEMockQuantize' in top_Xs[0].type or
                     'MSEQuantize' in bottom_Xs[0].type):

                assert (len(top_Xs) == 1)
                assert (len(bottom_Xs) == 1)
                assert (len(top_Xs[0].bottoms) == 2)
                # assert('MSEQuantize' in bottom_Xs[0].type or\
                #     'MSEMockQuantize' in bottom_Xs[0].type)

                if X.attrs['pool_type'] == 'Max':
                    # Maxpool
                    pool_divisor = [1]
                elif X.attrs['pool_type'] == 'Avg':
                    # Avg pool
                    pool_divisor = [np.prod(X.attrs['kernel_size'])]

                self._quant_layers[qkey].append(
                    (X.name, 'Pooling', LayerParams(pool_divisor)))

                th_in = thresholds[bottom_Xs[0].bottoms[1]]
                th_out = thresholds[top_Xs[0].bottoms[1]]

                if X.attrs['pool_type'] == 'Max':
                    assert (th_in == th_out)

                self._quant_param.bw_layer_in[X.name] = self.bitwidth
                self._quant_param.th_layer_in[X.name] = th_in
                self._quant_param.bw_layer_out[X.name] = self.bitwidth
                self._quant_param.th_layer_out[X.name] = th_out
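
The thresholds collected above (`th_layer_in`, `th_params`, `th_layer_out`) define symmetric quantization ranges per layer. A small numpy sketch, not taken from the source, of how a threshold and bitwidth translate into a quantize/dequantize round trip and the mean squared error that an MSE-based threshold search (`mse_opt_num` trials) would try to minimize:

import numpy as np

def quantize_with_threshold(x, threshold, bitwidth=8):
    # Symmetric quantization: map [-threshold, threshold] onto the integer grid
    max_int = 2 ** (bitwidth - 1) - 1          # 127 for 8 bit
    scale = max_int / threshold
    q = np.clip(np.round(x * scale), -max_int, max_int)
    return q / scale                           # dequantized values

x = np.random.randn(1000).astype(np.float32)
for threshold in (1.0, 2.0, 4.0):
    mse = np.mean((x - quantize_with_threshold(x, threshold)) ** 2)
    print("threshold={:.1f}  MSE={:.6f}".format(threshold, mse))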