Exemplo n.º 1
0
def _move_quant_attributes_into_annotations(model):
    """Move quantization info in attributes into quantization_annotation.

    Works on a deep copy of ``model``; the argument itself is not modified.

    For each node attribute named ``"weight_qnt"``, the attribute's string
    payload (``a.s``) is recorded as a ``"finn_datatype"`` entry in a new
    ``TensorAnnotation`` for the node's second input. That input must be a
    graph initializer. For ``"activation_qnt"``, the annotation is attached
    to the node's first output instead. Every attribute consumed this way is
    removed from its node.

    Args:
        model: model object accessed through ``model.graph.node``,
            ``model.graph.initializer`` and
            ``model.graph.quantization_annotation``. Presumably an ONNX
            ``ModelProto``; confirm against callers.

    Returns:
        The annotated deep copy of ``model``.

    Raises:
        ModuleNotFoundError: if the module-level ``onnx`` name is ``None``.
        AssertionError: if a ``weight_qnt`` node's second input has no
            initializer.
    """
    if onnx is None:
        raise ModuleNotFoundError("Installation of ONNX is required.")

    model = copy.deepcopy(model)
    qaname = "finn_datatype"
    # The initializer list is not modified below, so build the lookup once.
    initializer_names = {x.name for x in model.graph.initializer}
    for n in model.graph.node:
        # Iterate over a snapshot: removing from n.attribute while iterating
        # it directly would skip the element following each removed one.
        for a in list(n.attribute):
            if a.name == "weight_qnt":
                # assume second input is weight, make sure it has an initializer
                tensor_name = n.input[1]
                assert tensor_name in initializer_names, (
                    "weight_qnt input %s has no initializer" % tensor_name
                )
            elif a.name == "activation_qnt":
                tensor_name = n.output[0]
            else:
                continue
            tq = onnx.StringStringEntryProto(key=qaname, value=a.s)
            ta = onnx.TensorAnnotation(tensor_name=tensor_name,
                                       quant_parameter_tensor_names=[tq])
            model.graph.quantization_annotation.append(ta)
            n.attribute.remove(a)
    return model
Exemplo n.º 2
0
 def set_tensor_layout(self, tensor_name, data_layout):
     """Sets the data layout annotation of tensor with given name. See
     get_tensor_layout for examples.

     The layout is stored as ``str(data_layout)`` under the
     ``"tensor_layout"`` key of the tensor's quantization annotation. If that
     entry already exists, its value is overwritten. Otherwise a new entry is
     appended, and a new ``TensorAnnotation`` is created when the tensor has
     none yet.

     Args:
         tensor_name: name of the tensor to annotate.
         data_layout: list with one item per tensor dimension.

     Raises:
         AssertionError: if ``data_layout`` is not a list, or if the tensor
             has a known shape whose rank differs from ``len(data_layout)``.
             These checks are asserts, so they are skipped under ``-O``.
     """
     tensor_shape = self.get_tensor_shape(tensor_name)
     assert isinstance(data_layout, list), "data_layout must be a list"
     if tensor_shape is not None:
         assert len(tensor_shape) == len(data_layout), (
             "Mismatch between number of dimensions of tensor shape and "
             "data layout annotation."
         )
     graph = self._model_proto.graph
     qnt_annotations = graph.quantization_annotation
     ret = util.get_by_name(qnt_annotations, tensor_name, "tensor_name")
     if ret is not None:
         ret_tl = util.get_by_name(
             ret.quant_parameter_tensor_names, "tensor_layout", "key"
         )
         if ret_tl is not None:
             ret_tl.value = str(data_layout)
         else:
             tl = onnx.StringStringEntryProto()
             tl.key = "tensor_layout"
             tl.value = str(data_layout)
             ret.quant_parameter_tensor_names.append(tl)
     else:
         qa = onnx.TensorAnnotation()
         dt = onnx.StringStringEntryProto()
         dt.key = "tensor_layout"
         dt.value = str(data_layout)
         qa.tensor_name = tensor_name
         qa.quant_parameter_tensor_names.append(dt)
         qnt_annotations.append(qa)
Exemplo n.º 3
0
 def set_tensor_datatype(self, tensor_name, datatype):
     """Sets the FINN DataType of tensor with given name.

     The type is stored as ``datatype.name`` under the ``"finn_datatype"``
     key of the tensor's quantization annotation. An existing entry is
     overwritten. Otherwise a new entry is added, together with a new
     ``TensorAnnotation`` if the tensor has none.

     Args:
         tensor_name: name of the tensor to annotate.
         datatype: object exposing ``.name`` (a FINN DataType), or ``None``.
             ``None`` removes an existing ``"finn_datatype"`` entry and
             otherwise does nothing.
     """
     graph = self._model_proto.graph
     qnt_annotations = graph.quantization_annotation
     ret = util.get_by_name(qnt_annotations, tensor_name, "tensor_name")
     if ret is not None:
         entries = ret.quant_parameter_tensor_names
         ret_dt = util.get_by_name(entries, "finn_datatype", "key")
         if ret_dt is not None:
             if datatype is None:
                 # Remove the entry itself; Clear() would only blank its
                 # key/value and leave an empty pair in the annotation.
                 entries.remove(ret_dt)
             else:
                 ret_dt.value = datatype.name
         elif datatype is not None:
             dt = onnx.StringStringEntryProto()
             dt.key = "finn_datatype"
             dt.value = datatype.name
             entries.append(dt)
     elif datatype is not None:
         qa = onnx.TensorAnnotation()
         dt = onnx.StringStringEntryProto()
         dt.key = "finn_datatype"
         dt.value = datatype.name
         qa.tensor_name = tensor_name
         qa.quant_parameter_tensor_names.append(dt)
         qnt_annotations.append(qa)
Exemplo n.º 4
0
 def set_tensor_sparsity(self, tensor_name, sparsity_dict):
     """Attach a sparsity annotation to the tensor called ``tensor_name``.

     The value written is ``str(sparsity_dict)``, stored under the
     ``"tensor_sparsity"`` key of the tensor's quantization annotation.
     If the tensor has no annotation yet, one is created. If it has one
     without a sparsity entry, the entry is appended. Otherwise the existing
     entry's value is replaced.
     """
     annotations = self._model_proto.graph.quantization_annotation
     encoded = str(sparsity_dict)
     existing = util.get_by_name(annotations, tensor_name, "tensor_name")

     if existing is None:
         # No annotation for this tensor yet: build one holding only the
         # sparsity entry.
         entry = onnx.StringStringEntryProto(key="tensor_sparsity", value=encoded)
         new_annotation = onnx.TensorAnnotation()
         new_annotation.tensor_name = tensor_name
         new_annotation.quant_parameter_tensor_names.append(entry)
         annotations.append(new_annotation)
         return

     entries = existing.quant_parameter_tensor_names
     current = util.get_by_name(entries, "tensor_sparsity", "key")
     if current is None:
         entries.append(
             onnx.StringStringEntryProto(key="tensor_sparsity", value=encoded)
         )
     else:
         current.value = encoded