コード例 #1
0
    def __init__(self, prune_graph=True, **kwargs):
        """Validate the transformer class and wrap ``self.transform``.

        A warning is logged when the class has an empty ``APPLICABLE_LIBS``.
        ``self.transform`` is then replaced by a wrapper that:

        * returns the graph untouched when ``APPLICABLE_LIBS`` is not
          ``GENERIC_SENTINEL`` and the graph's ``lib_name`` is not listed
          in it;
        * otherwise runs the original transform, passes the result to
          ``topologic_order_graph`` and, if ``self.prune_graph`` is truthy,
          returns ``_prune_graph(result)`` instead of the result itself.

        :param prune_graph: stored on ``self.prune_graph``; read by the
            wrapper each time it runs.
        :param kwargs: accepted but not used by this method.
        :raises ValueError: if the class's ``KWARGS_NAMESCOPE`` is ``None``.
        """
        klass = type(self)
        if not klass.APPLICABLE_LIBS:
            empty_libs_msg = (
                "empty APPLICABLE_LIBS detected for {}, ".format(klass)
                + "such transformer will not be applied to any graph by default"
            )
            logger.warning(empty_libs_msg)
        if klass.KWARGS_NAMESCOPE is None:
            raise ValueError('kwargs namescope not found for %s' % klass)
        self.prune_graph = prune_graph
        # keep a handle on the original bound method before it is shadowed
        inner_transform = self.transform

        @wraps(inner_transform)
        def transform(ugraph):
            applicable = (
                self.APPLICABLE_LIBS is GENERIC_SENTINEL
                or ugraph.lib_name in self.APPLICABLE_LIBS
            )
            if not applicable:
                logger.info(
                    "%s is not applicable to ugraph with lib name %s, skipping",
                    self,
                    ugraph.lib_name,
                )
                return ugraph
            result = inner_transform(ugraph)
            topologic_order_graph(result)
            return _prune_graph(result) if self.prune_graph else result

        # instance attribute shadows the class-level method
        self.transform = transform
コード例 #2
0
 def __setitem__(self, entity, alloc):
     """Store ``alloc`` under ``entity`` in ``self._plan``.

     An existing entry for ``entity`` is overwritten after a warning is
     logged.

     :raises ValueError: if ``alloc`` is not a ``TimeSpaceAllocation``.
     """
     if not isinstance(alloc, TimeSpaceAllocation):
         message = 'the value should be of type {}, get {}'.format(
             TimeSpaceAllocation, type(alloc)
         )
         raise ValueError(message)
     already_planned = entity in self._plan
     if already_planned:
         logger.warning('duplicate entity detected: {}'.format(entity))
     self._plan[entity] = alloc
コード例 #3
0
    def parse(self, pb_file, output_nodes=None, model_name=None):
        """Build a ``uTensorGraph`` from a TensorFlow graph definition file.

        The graph definition is loaded with ``self._load_graph_def``,
        imported into a fresh ``tf.Graph`` and every node is mirrored as an
        ``OperationInfo`` (with ``TensorInfo`` objects for its inputs and
        outputs). The node's device string is stored in the op attributes
        under ``"tensorflow__device"``.

        :param pb_file: passed as-is to ``self._load_graph_def``.
        :param output_nodes: names of the output nodes; when ``None`` every
            node name in the graph definition is used and a warning is
            logged.
        :param model_name: if truthy, overrides the graph name returned by
            ``self._load_graph_def``.
        :returns: the result of ``Legalizer.legalize`` applied to the
            topologically ordered graph.
        :raises ValueError: if ``self._is_freeze_graph`` rejects the loaded
            graph definition.
        """
        graph_def, loaded_name = self._load_graph_def(pb_file)
        graph_name = model_name if model_name else loaded_name
        if not self._is_freeze_graph(graph_def):
            raise ValueError("Given graph_def is not freezed")
        if output_nodes is None:
            output_nodes = [node.name for node in graph_def.node]
            logger.warning(
                'output_nodes is not given, use all nodes instead (may cause unexpected behaviour)'
            )

        tf_graph = tf.Graph()
        with tf_graph.as_default():
            tf.import_graph_def(graph_def, name="")
        ugraph = uTensorGraph(
            name=graph_name,
            output_nodes=output_nodes,
            lib_name="tensorflow",
        )

        def make_tensor_info(tensor, owner_name):
            # shared constructor for input and output tensors; only the
            # owning op name differs between the two call sites
            return TensorInfo(
                name=tensor.name,
                ugraph=ugraph,
                op_name=owner_name,
                dtype=np.dtype(tensor.dtype.as_numpy_dtype),
                shape=self._tf_parse_tshape(tensor.shape),
            )

        for node in graph_def.node:
            tf_op = tf_graph.get_operation_by_name(node.name)
            # an input tensor belongs to the op that produced it
            inputs = [
                make_tensor_info(tensor, tensor.op.name)
                for tensor in tf_op.inputs
            ]
            outputs = [
                make_tensor_info(tensor, tf_op.name)
                for tensor in tf_op.outputs
            ]
            op_info = OperationInfo(
                name=node.name,
                input_tensors=inputs,
                n_inputs=len(inputs),
                output_tensors=outputs,
                n_outputs=len(outputs),
                op_type=node.op,
                lib_name="tensorflow",
                op_attr=node.attr,
                ugraph=ugraph,
            )
            op_info.op_attr["tensorflow__device"] = node.device
            ugraph.ops_info[node.name] = op_info
        topologic_order_graph(ugraph)
        return Legalizer.legalize(ugraph, {})
コード例 #4
0
 def transform(self, ugraph):
     """Rename every ``GatherV2`` op in ``ugraph`` to ``Gather``.

     The ops are modified in place and a warning is logged on every call.

     :param ugraph: graph whose ``ops_info`` mapping is scanned.
     :returns: the same ``ugraph`` object.
     """
     notice = "enabling {} will force replacing GatherV2 with Gather".format(
         self.METHOD_NAME)
     logger.warning(notice)
     for op_name, op_info in ugraph.ops_info.items():
         if op_info.op_type != "GatherV2":
             continue
         op_info.op_type = "Gather"
         # write the entry back under its existing key
         ugraph.ops_info[op_name] = op_info
     return ugraph
コード例 #5
0
 def size(self):
   """Return the product of the dimensions in ``self.shape``.

   A ``None`` dimension is counted as 1 (a warning is logged when any
   is present). An empty shape yields 1.

   :raises RuntimeError: if ``self.shape`` itself is ``None``.
   """
   if self.shape is None:
     raise RuntimeError('nondeterministic shape has no size')
   if None in self.shape:
     logger.warning(
       'nondeterministic dimension detected, implicitly converting None to 1: %s, %s',
       self.name,
       self.shape,
     )
   n_elements = 1
   for dim in self.shape:
     n_elements = n_elements * (1 if dim is None else dim)
   return n_elements
コード例 #6
0
ファイル: utils.py プロジェクト: yuezha01/utensor_cgen
def save_idx(arr, fname):
    """Write ``arr`` to ``fname`` through ``idx2np.convert_to_file``.

    Before writing:

    * a value whose shape is ``()`` is wrapped into a one-element array
      of the same dtype;
    * ``np.int64`` data is cast to ``np.int32`` (a warning is logged);
    * the parent directory of ``fname`` is created if it is missing.

    :param arr: array-like exposing ``shape``, ``dtype`` and ``astype``.
    :param fname: destination file path.
    """
    if arr.shape == ():
        arr = np.array([arr], dtype=arr.dtype)
    if arr.dtype == np.int64:
        logger.warning(
            "unsupported int format for idx detected: %s, using int32 instead",
            arr.dtype)
        arr = arr.astype(np.int32)
    parent_dir = os.path.dirname(fname)
    # an empty dirname means fname is relative to the current directory
    if parent_dir:
        if not os.path.exists(parent_dir):
            os.makedirs(parent_dir)
    with open(fname, "wb") as out_file:
        idx2np.convert_to_file(out_file, arr)
    logger.info("%s saved", fname)
コード例 #7
0
 def get_opertor(cls, op_info):
     """Instantiate the operator registered for ``op_info``.

     The registry ``cls._operators`` is queried with the key
     ``(namespaces, op_type)``, where ``namespaces`` comes from
     ``op_info.code_gen_attributes`` (default: empty tuple). When no
     entry exists, the class stored under ``'_MissingOperator'`` is used
     instead and a warning is logged the first time each missing op type
     is seen.

     :param op_info: object exposing ``op_type`` and ``code_gen_attributes``.
     :returns: the selected class called with ``op_info``.
     """
     op_type = op_info.op_type
     namespaces = op_info.code_gen_attributes.get('namespaces', tuple())
     registered_cls = cls._operators.get((namespaces, op_type))
     if registered_cls is not None:
         return registered_cls(op_info)
     fallback_cls = cls._operators['_MissingOperator']
     already_warned = op_info.op_type in cls._warned_missing_ops
     if not already_warned:
         logger.warning(
             '{} is missing, no code will be generated for it'.format(
                 op_info.op_type))
         cls._warned_missing_ops.add(op_info.op_type)
     return fallback_cls(op_info)
コード例 #8
0
 def __attrs_post_init__(self):
   """Normalise ``op_attr``, register the op and validate tensor counts.

   When ``self.op_attr`` is non-empty it is rebuilt as a plain dict:
   keys matching the ``_utensor_`` prefix pattern keep their value
   unchanged; every other value goes through
   ``ConverterDispatcher.get_generic_value``, and the raw value is kept
   (with a warning) if that raises ``ValueError``.

   The op is then stored in ``self._ugraph.ops_info`` under its name.
   This happens before the count checks below, so a failing check
   leaves the op registered.

   :raises ValueError: if ``n_inputs`` / ``n_outputs`` disagree with the
       lengths of ``input_tensors`` / ``output_tensors``.
   """
   utensor_key = re.compile(r'_utensor_[^_]*')
   if self.op_attr:
     generic_attrs = {}
     for key, value in self.op_attr.items():
       if utensor_key.match(key):
         generic_attrs[key] = value
         continue
       try:
         generic_attrs[key] = ConverterDispatcher.get_generic_value(value)
       except ValueError:
         logger.warning('cannot convert %s to generic value: %s(%s)', key, value, type(value))
         generic_attrs[key] = value
     self.op_attr = generic_attrs
   self._ugraph.ops_info[self.name] = self
   if self.n_inputs != len(self.input_tensors):
     raise ValueError(
       'n_inputs is not equal to the length of input_tensors: {}'.format(self.name)
     )
   if self.n_outputs != len(self.output_tensors):
     raise ValueError(
       'n_outputs is not equal to the length of output_tensors: {}'.format(self.name)
     )
コード例 #9
0
 def handle_default(self, ugraph):
     """Fallback handler: log a warning and return ``ugraph`` unchanged."""
     logger.warning("fall back to default graph lowering (do nothing)")
     return ugraph
コード例 #10
0
def default_op_data(op, fb_mdel):
    """Fallback op-data parser: warn about the op type and return ``{}``.

    The op type used in the warning comes from ``_get_op_type(op, fb_mdel)``.
    """
    logger.warning(
        'the op data parser is missing for %s',
        _get_op_type(op, fb_mdel),
    )
    return {}
コード例 #11
0
"""Legacy, DON'T USE
"""
from utensor_cgen.frontend.tensorflow import GraphDefParser
from utensor_cgen.logger import logger

from .base import Transformer
from .pipeline import TransformerPipeline

# TransformGraph is imported on a best-effort basis: when the import fails a
# warning is logged and the name is bound to None, so code using it must
# check for None first.
try:
    from tensorflow.tools.graph_transforms import TransformGraph
except ImportError:
    logger.warning("trying to import deprecated quantization transformer")
    TransformGraph = None

__all__ = ['QuantizeTransformer']


@TransformerPipeline.register_transformer
class QuantizeTransformer(Transformer):
    METHOD_NAME = 'quantize'
    KWARGS_NAMESCOPE = '_quantize'
    APPLICABLE_LIBS = set(["tensorflow"])

    def transform(self, ugraph):
        if ugraph.lib_name != 'tensorflow':
            raise ValueError('only support tensorflow graph')
        graph_def = ugraph.graph_def
        if TransformGraph is None:
            raise RuntimeError("quantization is temporary not supported")
        quant_graph_def = TransformGraph(
            input_graph_def=graph_def,