def __init__(self, prune_graph=True, **kwargs):
    """Validate the transformer class config and install a guarded ``transform``.

    A warning is logged when the class declares an empty ``APPLICABLE_LIBS``.
    ``ValueError`` is raised when ``KWARGS_NAMESCOPE`` is ``None``. The bound
    ``transform`` method is then replaced, on this instance only, by a wrapper
    that:

    * returns the input graph unchanged (after an info log) when
      ``APPLICABLE_LIBS`` is not ``GENERIC_SENTINEL`` and the graph's
      ``lib_name`` is not in it;
    * otherwise runs the original ``transform``, calls
      ``topologic_order_graph`` on the result, and passes it through
      ``_prune_graph`` when ``self.prune_graph`` is true.

    Args:
        prune_graph: whether the wrapper prunes the transformed graph.
        **kwargs: accepted for interface compatibility; unused here.

    Raises:
        ValueError: if the class has no ``KWARGS_NAMESCOPE``.
    """
    transformer_cls = type(self)
    if not transformer_cls.APPLICABLE_LIBS:
        logger.warning(
            "empty APPLICABLE_LIBS detected for {}, ".format(transformer_cls)
            + "such transformer will not be applied to any graph by default"
        )
    if transformer_cls.KWARGS_NAMESCOPE is None:
        raise ValueError('kwargs namescope not found for %s' % transformer_cls)
    self.prune_graph = prune_graph
    wrapped_transform = self.transform

    @wraps(wrapped_transform)
    def _applicable_transform(ugraph):
        supported_libs = self.APPLICABLE_LIBS
        # GENERIC_SENTINEL marks a transformer that applies to every library.
        restricted = supported_libs is not GENERIC_SENTINEL
        if restricted and ugraph.lib_name not in supported_libs:
            logger.info(
                "%s is not applicable to ugraph with lib name %s, skipping",
                self,
                ugraph.lib_name,
            )
            return ugraph
        transformed = wrapped_transform(ugraph)
        topologic_order_graph(transformed)
        # self.prune_graph is read at call time, so later changes take effect.
        return _prune_graph(transformed) if self.prune_graph else transformed

    self.transform = _applicable_transform
def __setitem__(self, entity, alloc):
    """Record ``alloc`` as the allocation for ``entity`` in ``self._plan``.

    An existing entry for ``entity`` is overwritten after a duplicate warning
    is logged.

    Raises:
        ValueError: if ``alloc`` is not a ``TimeSpaceAllocation``.
    """
    is_allocation = isinstance(alloc, TimeSpaceAllocation)
    if not is_allocation:
        raise ValueError(
            'the value should be of type {}, get {}'.format(
                TimeSpaceAllocation, type(alloc)
            )
        )
    already_planned = entity in self._plan
    if already_planned:
        logger.warning('duplicate entity detected: {}'.format(entity))
    self._plan[entity] = alloc
def parse(self, pb_file, output_nodes=None, model_name=None):
    """Parse a TensorFlow graph definition into a ``uTensorGraph``.

    Args:
        pb_file: source handed to ``self._load_graph_def`` (presumably a
            frozen ``.pb`` path; verify against ``_load_graph_def``).
        output_nodes: names of the output nodes. When ``None``, every node in
            the graph def is used and a warning is logged.
        model_name: if truthy, overrides the graph name returned by
            ``_load_graph_def``.

    Returns:
        The ``uTensorGraph`` after ``topologic_order_graph`` and
        ``Legalizer.legalize``.

    Raises:
        ValueError: if ``self._is_freeze_graph`` rejects the graph def.
    """
    graph_def, loaded_name = self._load_graph_def(pb_file)
    graph_name = model_name if model_name else loaded_name
    if not self._is_freeze_graph(graph_def):
        raise ValueError("Given graph_def is not freezed")
    if output_nodes is None:
        output_nodes = [node.name for node in graph_def.node]
        logger.warning(
            'output_nodes is not given, use all nodes instead (may cause unexpected behaviour)'
        )

    tf_graph = tf.Graph()
    with tf_graph.as_default():
        tf.import_graph_def(graph_def, name="")

    ugraph = uTensorGraph(
        name=graph_name,
        output_nodes=output_nodes,
        lib_name="tensorflow",
    )

    def _tensor_infos(tensors, owner_op_name=None):
        # Build one TensorInfo per TF tensor. Input tensors are attributed to
        # the op that produces them (tensor.op.name); output tensors are
        # attributed to the op being parsed, passed in as owner_op_name.
        return [
            TensorInfo(
                name=tensor.name,
                ugraph=ugraph,
                op_name=tensor.op.name if owner_op_name is None else owner_op_name,
                dtype=np.dtype(tensor.dtype.as_numpy_dtype),
                shape=self._tf_parse_tshape(tensor.shape),
            )
            for tensor in tensors
        ]

    for node in graph_def.node:
        tf_op = tf_graph.get_operation_by_name(node.name)
        inputs = _tensor_infos(tf_op.inputs)
        outputs = _tensor_infos(tf_op.outputs, owner_op_name=tf_op.name)
        op_info = OperationInfo(
            name=node.name,
            input_tensors=inputs,
            n_inputs=len(inputs),
            output_tensors=outputs,
            n_outputs=len(outputs),
            op_type=node.op,
            lib_name="tensorflow",
            op_attr=node.attr,
            ugraph=ugraph,
        )
        # Keep the TF device placement alongside the converted attributes.
        op_info.op_attr["tensorflow__device"] = node.device
        ugraph.ops_info[node.name] = op_info

    topologic_order_graph(ugraph)
    return Legalizer.legalize(ugraph, {})
def transform(self, ugraph):
    """Rename every ``GatherV2`` op in ``ugraph`` to ``Gather``, in place.

    A warning naming this transformer's ``METHOD_NAME`` is always logged,
    because the replacement is forced whenever the transformer is enabled.

    Args:
        ugraph: graph whose ``ops_info`` values are updated in place.

    Returns:
        The same ``ugraph`` object.
    """
    logger.warning(
        "enabling {} will force replacing GatherV2 with Gather".format(
            self.METHOD_NAME))
    # Ops are mutated in place, so there is no need to write them back into
    # ops_info while iterating over it.
    for op in ugraph.ops_info.values():
        if op.op_type == "GatherV2":
            op.op_type = "Gather"
    return ugraph
def size(self):
    """Return the number of elements implied by ``self.shape``.

    Unknown (``None``) dimensions are counted as 1, and a warning naming the
    tensor and its shape is logged. An empty shape yields 1.

    Raises:
        RuntimeError: if ``self.shape`` is ``None``.
    """
    if self.shape is None:
        raise RuntimeError('nondeterministic shape has no size')
    if None in self.shape:
        logger.warning(
            'nondeterministic dimension detected, implicitly converting None to 1: %s, %s',
            self.name,
            self.shape,
        )
    # Use a conditional expression rather than the fragile `cond and a or b`.
    return reduce(lambda acc, dim: acc * (1 if dim is None else dim), self.shape, 1)
def save_idx(arr, fname):
    """Write ``arr`` to ``fname`` via ``idx2np.convert_to_file``.

    A 0-d array is promoted to a one-element 1-D array. ``int64`` data is
    cast to ``int32``, with a warning, before writing. Missing parent
    directories of ``fname`` are created.

    Args:
        arr: numpy array (or numpy scalar) to save; must expose ``shape``,
            ``dtype`` and ``astype``.
        fname: destination file path.
    """
    if arr.shape == ():
        # Promote scalars to shape (1,); presumably the IDX writer needs at
        # least one dimension - confirm against idx2np.
        arr = np.array([arr], dtype=arr.dtype)
    if arr.dtype in [np.int64]:
        logger.warning(
            "unsupported int format for idx detected: %s, using int32 instead",
            arr.dtype)
        arr = arr.astype(np.int32)
    out_dir = os.path.dirname(fname)
    if out_dir:
        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(out_dir, exist_ok=True)
    with open(fname, "wb") as fid:
        idx2np.convert_to_file(fid, arr)
    logger.info("%s saved", fname)
def get_opertor(cls, op_info):
    """Instantiate the registered operator class for ``op_info``.

    The lookup key is ``(namespaces, op_type)``. ``namespaces`` comes from
    ``op_info.code_gen_attributes['namespaces']`` and defaults to an empty
    tuple. When no class is registered, the ``'_MissingOperator'`` entry is
    used instead, and a warning is logged the first time each missing
    ``op_type`` is seen (tracked in ``cls._warned_missing_ops``).
    """
    op_type = op_info.op_type
    namespaces = op_info.code_gen_attributes.get('namespaces', tuple())
    registered_cls = cls._operators.get((namespaces, op_type))
    if registered_cls is not None:
        return registered_cls(op_info)

    fallback_cls = cls._operators['_MissingOperator']
    if op_type not in cls._warned_missing_ops:
        logger.warning(
            '{} is missing, no code will be generated for it'.format(op_type))
        cls._warned_missing_ops.add(op_type)
    return fallback_cls(op_info)
def __attrs_post_init__(self):
    """Validate tensor counts, normalize ``op_attr``, then register the op.

    Each ``op_attr`` value is converted with
    ``ConverterDispatcher.get_generic_value``, except values whose key matches
    ``_utensor_...``, which are kept as-is. Values that raise ``ValueError``
    during conversion are also kept as-is, with a warning. The op is added to
    ``self._ugraph.ops_info`` only after validation passes, so a rejected op
    is never left in the graph.

    Raises:
        ValueError: if ``n_inputs``/``n_outputs`` differ from the lengths of
            ``input_tensors``/``output_tensors``.
    """
    if self.n_inputs != len(self.input_tensors):
        raise ValueError(
            'n_inputs is not equal to the length of input_tensors: {}'.format(self.name)
        )
    if self.n_outputs != len(self.output_tensors):
        raise ValueError(
            'n_outputs is not equal to the length of output_tensors: {}'.format(self.name)
        )
    if self.op_attr:
        # uTensor-internal attributes are passed through without conversion.
        skip_pattern = re.compile(r'_utensor_[^_]*')
        op_attr = {}
        for k, v in self.op_attr.items():
            if skip_pattern.match(k):
                op_attr[k] = v
                continue
            try:
                op_attr[k] = ConverterDispatcher.get_generic_value(v)
            except ValueError:
                logger.warning('cannot convert %s to generic value: %s(%s)', k, v, type(v))
                op_attr[k] = v
        self.op_attr = op_attr
    self._ugraph.ops_info[self.name] = self
def handle_default(self, ugraph):
    """Fallback lowering: log a warning and hand back ``ugraph`` unmodified."""
    fallback_message = 'fall back to default graph lowering (do nothing)'
    logger.warning(fallback_message)
    return ugraph
def default_op_data(op, fb_mdel):
    """Fallback op-data parser: warn that no parser exists and return ``{}``.

    Args:
        op: the op being parsed; its type is resolved with ``_get_op_type``.
        fb_mdel: model object forwarded to ``_get_op_type`` (presumably the
            flatbuffer model; verify against callers).

    Returns:
        An empty dict.
    """
    unparsed_type = _get_op_type(op, fb_mdel)
    logger.warning('the op data parser is missing for %s', unparsed_type)
    no_data = {}
    return no_data
"""Legacy, DON'T USE """ from utensor_cgen.frontend.tensorflow import GraphDefParser from utensor_cgen.logger import logger from .base import Transformer from .pipeline import TransformerPipeline try: from tensorflow.tools.graph_transforms import TransformGraph except ImportError: logger.warning("trying to import deprecated quantization transformer") TransformGraph = None __all__ = ['QuantizeTransformer'] @TransformerPipeline.register_transformer class QuantizeTransformer(Transformer): METHOD_NAME = 'quantize' KWARGS_NAMESCOPE = '_quantize' APPLICABLE_LIBS = set(["tensorflow"]) def transform(self, ugraph): if ugraph.lib_name != 'tensorflow': raise ValueError('only support tensorflow graph') graph_def = ugraph.graph_def if TransformGraph is None: raise RuntimeError("quantization is temporary not supported") quant_graph_def = TransformGraph( input_graph_def=graph_def,