def infer_shapes(
        graph,  # type: ONNXGraph
        source_shapes=None,  # type: typing.Union[typing.Dict[str, typing.List[int]], typing.List[int], int, None]
        custom_shapes=None,  # type: typing.Optional[typing.Dict[str, typing.Callable]]
):
    # type: (...)->None
    shape_functions = dict(_DefaultShapes)
    if custom_shapes:
        shape_functions.update(custom_shapes)

    graph.sort()

    shape_fixer.fix_input_shapes(graph, source_shapes)

    for op in graph.operations:
        # Shape prop
        assert op.name in shape_functions, "No shape function for {}".format(op.name)
        inferred_shapes, inferred_dtypes = shape_functions[op.name](op)
        assert not utils.has_le_0(inferred_shapes)
        assert len(inferred_shapes) == len(inferred_dtypes) == len(op.outputs)
        for new_shape, new_dtype, tensor in zip(inferred_shapes, inferred_dtypes, op.outputs):
            assert utils.compatible_shapes(tensor.shape, new_shape)
            tensor.shape = new_shape
            assert tensor.dtype is None or tensor.dtype == new_dtype
            tensor.dtype = new_dtype

    graph_utils.remove_unreachable(graph)
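
# --- Usage sketch (illustrative, not part of the original source) ---
# A minimal, hypothetical call to infer_shapes, assuming 'graph' is an
# ONNXGraph produced by the ONNX reader elsewhere in the repo. The op name
# 'Upsample' and the helper 'my_upsample_shape' are assumptions for the sake
# of the example; per the assertions in infer_shapes, a custom shape function
# receives the op and must return one shape and one dtype per output.

def my_upsample_shape(op):
    # Doubles the two spatial dims of an NCHW input; one (shape, dtype) pair per output.
    n, c, h, w = op.inputs[0].shape
    return [[n, c, 2 * h, 2 * w]], [op.inputs[0].dtype]

infer_shapes(graph,
             source_shapes={'input': [1, 3, 224, 224]},  # pin any unknown input dims
             custom_shapes={'Upsample': my_upsample_shape})  # op name -> shape function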
def evaluate_and_convert(tf_graph, source_shapes=None):
    # type: (TFGraph, typing.Union[typing.Dict[str, typing.List[int]], typing.List[int], int, None])->None
    tf_graph.sort()

    if isinstance(source_shapes, dict):
        source_shapes = {(k + ':0' if ':' not in k else k): v for k, v in six.iteritems(source_shapes)}
    shape_fixer.fix_input_shapes(tf_graph, source_shapes)

    const_value_by_tensor = {}

    for tensor in tf_graph.tensors:
        if tensor.is_constant:
            const_value_by_tensor[tensor] = tf_pb_eval._evaluate_constant(tensor)
        elif tensor.is_variable:
            const_value_by_tensor[tensor] = tensor.data

    for op in tf_graph.operations:
        # Shape prop
        if op.name not in tf_pb_shape_inference._DefaultPropagators:
            raise utils.NNEFToolsException("Operation '{}' is not supported".format(op.name))
        propagated_shapes, propagated_dtypes = \
            tf_pb_shape_inference._DefaultPropagators[op.name](op, const_value_by_tensor)
        assert not utils.has_le_0(propagated_shapes)
        assert len(propagated_shapes) == len(propagated_dtypes) == len(op.outputs)
        for new_shape, new_dtype, tensor in zip(propagated_shapes, propagated_dtypes, op.outputs):
            assert utils.compatible_shapes(tensor.shape, new_shape)
            tensor.shape = new_shape
            assert tensor.dtype is None or tensor.dtype == new_dtype
            tensor.dtype = new_dtype

        # Evaluation
        if op.name in tf_pb_eval._DefaultOpEvaluators:
            tf_pb_eval._DefaultOpEvaluators[op.name](op, const_value_by_tensor)

        # Conversion
        assert op.name in DefaultConverters, "No tf_pb_to_tf_py converter for {}".format(op.name)
        DefaultConverters[op.name](op, const_value_by_tensor)

    for tensor in tf_graph.tensors:
        tensor.dtype = _tf_py_dtype_by_tf_pb_dtype.get(tensor.dtype, None)

    for tensor in tf_graph.tensors:
        if tensor.is_variable:
            label = tensor.name
            if label is not None:
                if label.endswith(':0'):
                    label = label[:-2]
                label = label.replace(':', '_')
            tensor.label = label
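
# --- Usage sketch (illustrative, not part of the original source) ---
# A hypothetical call, assuming 'tf_graph' is a TFGraph parsed from a GraphDef.
# Dict keys without a ':N' suffix are normalized to ':0' by the comprehension
# above, so 'input' and 'input:0' name the same tensor here.

evaluate_and_convert(tf_graph, source_shapes={'input': [1, 224, 224, 3]})
# Per the Union type comment, source_shapes may also be a single shape list or
# an int; the exact semantics live in shape_fixer.fix_input_shapes (assumption:
# one shape for all inputs / one value for all unknown dims).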
def propagate(graph, source_shapes=None, source_dtypes=None):
    # type: (ONNXGraph, typing.Dict[str, typing.List[int]], typing.Dict[str, str])->None
    if source_shapes is None:
        source_shapes = {}
    if source_dtypes is None:
        source_dtypes = {}

    graph.sort()

    for tensor in graph.tensors:
        if tensor.name and tensor.name in source_shapes:
            if not _is_compatible(tensor.shape, source_shapes[tensor.name]):
                raise utils.NNEFToolsException(
                    "The specified shape is incompatible with the original shape for {}. {} vs. {}"
                    .format(tensor.name, source_shapes[tensor.name], tensor.shape))
            tensor.shape = source_shapes[tensor.name]
        if tensor.name and tensor.name in source_dtypes:
            if not (tensor.dtype is None or tensor.dtype == source_dtypes[tensor.name]):
                raise utils.NNEFToolsException(
                    "The specified dtype is incompatible with the original dtype for {}. {} vs. {}"
                    .format(tensor.name, source_dtypes[tensor.name], tensor.dtype))
            tensor.dtype = source_dtypes[tensor.name]
        if len(tensor.producers) == 0 and (tensor.shape is None or -1 in tensor.shape or tensor.dtype is None):
            raise utils.NNEFToolsException(
                "Source tensor '{}' has incomplete dtype or shape: {} {}\n"
                "Please specify it in --input-shape or through the corresponding API."
                .format(tensor.name, tensor.dtype, tensor.shape))

    for op in graph.operations:
        # Shape prop
        assert op.name in _DefaultPropagators, "No shape propagator for {}".format(op.name)
        propagated_shapes, propagated_dtypes = _DefaultPropagators[op.name](op)
        assert not utils.has_le_0(propagated_shapes)
        assert len(propagated_shapes) == len(propagated_dtypes) == len(op.outputs)
        for new_shape, new_dtype, tensor in zip(propagated_shapes, propagated_dtypes, op.outputs):
            assert _is_compatible(tensor.shape, new_shape)
            tensor.shape = new_shape
            assert tensor.dtype is None or tensor.dtype == new_dtype
            tensor.dtype = new_dtype
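
# --- Usage sketch (illustrative, not part of the original source) ---
# A hypothetical call; 'data' must name a source tensor (one with no
# producers) of the ONNXGraph, and 'FLOAT' assumes this module uses ONNX's
# upper-case dtype names. Overrides that conflict with already-known shapes
# or dtypes raise NNEFToolsException, as coded above.

propagate(graph,
          source_shapes={'data': [1, 3, 224, 224]},
          source_dtypes={'data': 'FLOAT'})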
def propagate(graph, source_shapes=None):
    # type: (ONNXGraph, typing.Union[typing.Dict[str, typing.List[int]], typing.List[int], int, None])->None
    graph.sort()

    shape_fixer.fix_input_shapes(graph, source_shapes)

    for op in graph.operations:
        # Shape prop
        assert op.name in _DefaultPropagators, "No shape propagator for {}".format(op.name)
        propagated_shapes, propagated_dtypes = _DefaultPropagators[op.name](op)
        assert not utils.has_le_0(propagated_shapes)
        assert len(propagated_shapes) == len(propagated_dtypes) == len(op.outputs)
        for new_shape, new_dtype, tensor in zip(propagated_shapes, propagated_dtypes, op.outputs):
            assert utils.compatible_shapes(tensor.shape, new_shape)
            tensor.shape = new_shape
            assert tensor.dtype is None or tensor.dtype == new_dtype
            tensor.dtype = new_dtype

    graph_utils.remove_unreachable(graph)
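
# --- Propagator sketch (illustrative, not part of the original source) ---
# Entries of _DefaultPropagators are keyed by op name and, per the loop above,
# must return (shapes, dtypes) with one entry per output. A hypothetical
# propagator for an elementwise op could look like this:

def _propagate_relu(op):
    # Elementwise: the output mirrors the input's shape and dtype.
    return [list(op.inputs[0].shape)], [op.inputs[0].dtype]

# e.g. _DefaultPropagators['Relu'] = _propagate_relu  (registration assumed)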
def evaluate_and_convert(tf_graph, source_shapes=None, source_dtypes=None):
    # type: (TFGraph, typing.Dict[str, typing.List[int]], typing.Dict[str, str])->None
    if source_shapes is None:
        source_shapes = {}
    if source_dtypes is None:
        source_dtypes = {}

    source_shapes = {(k + ':0' if ':' not in k else k): v for k, v in six.iteritems(source_shapes)}
    source_dtypes = {(k + ':0' if ':' not in k else k): v for k, v in six.iteritems(source_dtypes)}

    tf_graph.sort()

    for tensor in tf_graph.tensors:
        tensor.dtype = _tf_py_dtype_by_tf_pb_dtype.get(tensor.dtype, None)
        if tensor.name and tensor.name in source_shapes:
            if not _is_compatible(tensor.shape, source_shapes[tensor.name]):
                raise utils.NNEFToolsException(
                    "The specified shape is incompatible with the original shape for {}. {} vs. {}"
                    .format(tensor.name, source_shapes[tensor.name], tensor.shape))
            tensor.shape = source_shapes[tensor.name]
        if tensor.name and tensor.name in source_dtypes:
            if not (tensor.dtype is None or tensor.dtype == source_dtypes[tensor.name]):
                raise utils.NNEFToolsException(
                    "The specified dtype is incompatible with the original dtype for {}. {} vs. {}"
                    .format(tensor.name, source_dtypes[tensor.name], tensor.dtype))
            tensor.dtype = source_dtypes[tensor.name]
        if len(tensor.producers) == 0 and (tensor.shape is None or -1 in tensor.shape or tensor.dtype is None):
            raise utils.NNEFToolsException(
                "Source tensor '{}' has incomplete dtype or shape: {} {}\n"
                "Please specify it in --input-shape or through the corresponding API."
                .format(tensor.name, tensor.dtype, tensor.shape))

    const_value_by_tensor = {}

    for tensor in tf_graph.tensors:
        if tensor.is_constant:
            const_value_by_tensor[tensor] = tf_pb_eval._evaluate_constant(tensor)
        elif tensor.is_variable:
            const_value_by_tensor[tensor] = tensor.data

    for op in tf_graph.operations:
        # Shape prop
        assert op.name in tf_pb_shape_inference._DefaultPropagators, \
            "No shape propagator for {}".format(op.name)
        propagated_shapes, propagated_dtypes = \
            tf_pb_shape_inference._DefaultPropagators[op.name](op, const_value_by_tensor)
        assert not utils.has_le_0(propagated_shapes)
        assert len(propagated_shapes) == len(propagated_dtypes) == len(op.outputs)
        for new_shape, new_dtype, tensor in zip(propagated_shapes, propagated_dtypes, op.outputs):
            assert _is_compatible(tensor.shape, new_shape)
            tensor.shape = new_shape
            assert tensor.dtype is None or tensor.dtype == new_dtype
            tensor.dtype = new_dtype

        # Evaluation
        if op.name in tf_pb_eval._DefaultOpEvaluators:
            tf_pb_eval._DefaultOpEvaluators[op.name](op, const_value_by_tensor)

        # Conversion
        assert op.name in _DefaultConverters, \
            "No tf_pb_to_tf_py converter for {}".format(op.name)
        _DefaultConverters[op.name](op, const_value_by_tensor)

    for tensor in tf_graph.tensors:
        if tensor.is_variable:
            label = tensor.name
            if label is not None:
                if label.endswith(':0'):
                    label = label[:-2]
                label = label.replace(':', '_')
            tensor.label = label
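
# --- Label derivation sketch (illustrative, not part of the original source) ---
# The final loop derives variable labels from tensor names by pure string
# manipulation; a hypothetical helper makes the rule explicit:

def _label_for(name):
    # Drop a trailing ':0', then replace any remaining ':' with '_'.
    if name.endswith(':0'):
        name = name[:-2]
    return name.replace(':', '_')

assert _label_for('model/conv1/weights:0') == 'model/conv1/weights'
assert _label_for('rnn/kernel:2') == 'rnn/kernel_2'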