def infer_shapes(
        graph,  # type: ONNXGraph
        source_shapes=None,  # type: typing.Union[typing.Dict[str, typing.List[int]], typing.List[int], int, None]
        custom_shapes=None,  # type: typing.Optional[typing.Dict[str, typing.Callable]]
):
    # type: (...)->None
    """Infer the output shapes and dtypes of every operation of the graph, in place.

    The graph is sorted, its input shapes are fixed with
    ``shape_fixer.fix_input_shapes``, then the shape function registered for
    each operation is called and its results are written to the operation's
    output tensors.  Finally ``graph_utils.remove_unreachable`` is called on
    the graph.

    :param graph: the graph to process; modified in place.
    :param source_shapes: input shape specification, forwarded unchanged to
        ``shape_fixer.fix_input_shapes``.
    :param custom_shapes: optional mapping from op name to shape function; its
        entries override the defaults taken from ``_DefaultShapes``.
    :raises utils.NNEFToolsException: if an operation has no shape function.
    """
    shape_functions = dict(_DefaultShapes)
    if custom_shapes:
        shape_functions.update(custom_shapes)

    graph.sort()
    shape_fixer.fix_input_shapes(graph, source_shapes)

    for op in graph.operations:
        # Shape prop
        if op.name not in shape_functions:
            # Raise instead of assert: the check must survive `python -O`, and this
            # matches how unsupported operations are reported elsewhere.
            raise utils.NNEFToolsException("No shape function for {}".format(op.name))

        # A shape function returns (list of shapes, list of dtypes), one per output.
        inferred_shapes, inferred_dtypes = shape_functions[op.name](op)
        assert not utils.has_le_0(inferred_shapes)
        assert len(inferred_shapes) == len(inferred_dtypes) == len(op.outputs)

        for new_shape, new_dtype, tensor in zip(inferred_shapes, inferred_dtypes, op.outputs):
            # Inferred values must agree with whatever was already recorded on the tensor.
            assert utils.compatible_shapes(tensor.shape, new_shape)
            tensor.shape = new_shape
            assert tensor.dtype is None or tensor.dtype == new_dtype
            tensor.dtype = new_dtype

    graph_utils.remove_unreachable(graph)
def fix_input_shapes(graph, source_shapes):
    # type: (BaseGraph, typing.Union[typing.Dict[str, typing.List[int]], typing.List[int], int, None])->None
    """Make the shapes of the graph's input tensors fully known, in place.

    Input tensors ("placeholders") are tensors with no producers that are
    neither constants nor variables.  ``source_shapes`` may be:

    * a dict mapping tensor name to the shape for that tensor; placeholders
      missing from the dict are treated as if ``source_shapes`` were ``None``,
    * a list, used as the full shape of every placeholder,
    * an int-like value (as accepted by ``utils.is_anyint``), substituted for
      every ``-1`` dimension,
    * ``None``: every ``-1`` dimension is replaced by 1, with a warning.

    A full shape given as a list must be compatible with the tensor's
    existing shape.

    :raises utils.NNEFToolsException: if a given shape is incompatible with the
        original one, if a placeholder's shape is unknown and no full shape was
        given, if the shape specification has an unsupported type, or if a
        placeholder has no dtype.
    """

    def get_shape_for(name):
        if isinstance(source_shapes, dict) and name in source_shapes:
            shape = source_shapes[name]
            # Copy so the tensor does not alias the caller's list (the list case copies too).
            return list(shape) if isinstance(shape, (list, tuple)) else shape
        elif isinstance(source_shapes, list):
            return list(source_shapes)
        elif utils.is_anyint(source_shapes):
            return utils.anyint_to_int(source_shapes)
        return None

    def fill_unknown_dims(shape, value):
        # Replace each unknown (-1) dimension with `value`.
        return [value if dim == -1 else dim for dim in shape]

    placeholders = [
        tensor for tensor in graph.tensors
        if len(tensor.producers) == 0 and not tensor.is_constant and not tensor.is_variable
    ]

    if source_shapes is None:
        # No user-provided shapes: show the original input shapes when some are incomplete.
        if any(tensor.shape is None or -1 in tensor.shape for tensor in placeholders):
            for tensor in placeholders:
                print("Info: Input shape: {}: {}".format(tensor.name, tensor.shape))

    for tensor in placeholders:  # type: BaseTensor
        shape_for_this = get_shape_for(tensor.name) if tensor.name else None
        if isinstance(shape_for_this, list):
            if not utils.compatible_shapes(tensor.shape, shape_for_this):
                raise utils.NNEFToolsException(
                    "The specified shape is incompatible with the original shape for {}. {} vs. {}"
                    .format(tensor.name, shape_for_this, tensor.shape))
            tensor.shape = shape_for_this
        elif shape_for_this is None or isinstance(shape_for_this, int):
            if tensor.shape is None:
                raise utils.NNEFToolsException(
                    "The full shape must be specified for {}, because it is unknown."
                    .format(tensor.name))
            elif -1 in tensor.shape:
                if shape_for_this is None:
                    shape_for_this = 1
                fixed_shape = fill_unknown_dims(tensor.shape, shape_for_this)
                print("Warning: Incomplete input shape is auto-fixed: {}. {} -> {}. "
                      "Use --input-shape if other shape is desired.".format(tensor.name, tensor.shape, fixed_shape))
                tensor.shape = fixed_shape
        else:
            # Previously `assert False`, which is a silent no-op under `python -O`.
            raise utils.NNEFToolsException(
                "Unsupported shape specification for {}: {!r}".format(tensor.name, shape_for_this))

        if tensor.dtype is None:
            raise utils.NNEFToolsException(
                "An input tensor has incomplete dtype, "
                "we have thought that this is impossible, "
                "please file a bug report to NNEF Tools.")
def evaluate_and_convert(tf_graph, source_shapes=None):
    # type: (TFGraph, typing.Union[typing.Dict[str, typing.List[int]], typing.List[int], int, None])->None
    """Propagate shapes, evaluate constants and convert a TF protobuf graph, in place.

    1. Sort the graph and fix its input shapes.  Keys of a ``source_shapes``
       dict that contain no ``':'`` get ``':0'`` appended first.
    2. Collect the values of constant tensors (via
       ``tf_pb_eval._evaluate_constant``) and variable tensors (``tensor.data``).
    3. For each operation, in sorted order:
       a. propagate output shapes and dtypes;
       b. evaluate it if ``tf_pb_eval._DefaultOpEvaluators`` has an entry for it;
       c. apply its converter from ``DefaultConverters``.
    4. Map every tensor's dtype through ``_tf_py_dtype_by_tf_pb_dtype``.
       Dtypes missing from the map become ``None``.
    5. Give each variable a label derived from its name.

    :raises utils.NNEFToolsException: if an operation has no shape propagator
        or no converter.
    """
    tf_graph.sort()
    if isinstance(source_shapes, dict):
        # Names without an output index refer to output 0.
        source_shapes = {(k + ':0' if ':' not in k else k): v for k, v in six.iteritems(source_shapes)}
    shape_fixer.fix_input_shapes(tf_graph, source_shapes)

    const_value_by_tensor = {}
    for tensor in tf_graph.tensors:
        if tensor.is_constant:
            const_value_by_tensor[tensor] = tf_pb_eval._evaluate_constant(tensor)
        elif tensor.is_variable:
            const_value_by_tensor[tensor] = tensor.data

    for op in tf_graph.operations:
        # Shape prop
        if op.name not in tf_pb_shape_inference._DefaultPropagators:
            raise utils.NNEFToolsException("Operation '{}' is not supported".format(op.name))
        propagated_shapes, propagated_dtypes = \
            tf_pb_shape_inference._DefaultPropagators[op.name](op, const_value_by_tensor)
        assert not utils.has_le_0(propagated_shapes)
        assert len(propagated_shapes) == len(propagated_dtypes) == len(op.outputs)
        for new_shape, new_dtype, tensor in zip(propagated_shapes, propagated_dtypes, op.outputs):
            assert utils.compatible_shapes(tensor.shape, new_shape)
            tensor.shape = new_shape
            assert tensor.dtype is None or tensor.dtype == new_dtype
            tensor.dtype = new_dtype

        # Evaluation: optional, only for operations that have a registered evaluator.
        if op.name in tf_pb_eval._DefaultOpEvaluators:
            tf_pb_eval._DefaultOpEvaluators[op.name](op, const_value_by_tensor)

        # Conversion: raise (not assert) like the propagator check above, so a missing
        # converter is reported clearly even under `python -O`.
        if op.name not in DefaultConverters:
            raise utils.NNEFToolsException("No tf_pb_to_tf_py converter for {}".format(op.name))
        DefaultConverters[op.name](op, const_value_by_tensor)

    for tensor in tf_graph.tensors:
        tensor.dtype = _tf_py_dtype_by_tf_pb_dtype.get(tensor.dtype, None)

    for tensor in tf_graph.tensors:
        if tensor.is_variable:
            # Label = name without a trailing ':0', with remaining ':' replaced by '_'.
            # A variable without a name gets a None label.
            label = tensor.name
            if label is not None:
                if label.endswith(':0'):
                    label = label[:-2]
                label = label.replace(':', '_')
            tensor.label = label
def evaluate_shape_of_operation(op, const_value_by_tensor):
    # type: (TFOperation, typing.Dict[TFTensor, np.ndarray])->None
    """Fill in the output shapes of ``op`` that are not fully known, in place.

    Returns immediately when every output shape is already fully known.
    Otherwise the shapes are taken from the values of the outputs when all of
    them are present in ``const_value_by_tensor``.  Failing that, the evaluator
    registered for ``op.name`` in ``_DefaultOpShapeEvaluators`` (if any) is
    called.  Each output shape is then asserted to be fully known and
    compatible with the shape it had before.
    """

    def is_fully_known(shape):
        # A shape is fully known when it exists and none of its dimensions is None.
        return shape is not None and all(dim is not None for dim in shape)

    if all(is_fully_known(output.shape) for output in op.outputs):
        return

    shapes_before = [output.shape for output in op.outputs]

    all_outputs_constant = all(output in const_value_by_tensor for output in op.outputs)
    if all_outputs_constant:
        for output in op.outputs:
            output.shape = list(np.shape(const_value_by_tensor[output]))
    elif op.name in _DefaultOpShapeEvaluators:
        _DefaultOpShapeEvaluators[op.name](op)

    for shape_before, tensor in zip(shapes_before, op.outputs):
        assert is_fully_known(tensor.shape)
        assert utils.compatible_shapes(shape_before, tensor.shape), \
            "{}: Evaluated shape ({}) not compatible with original shape ({})".format(tensor, tensor.shape, shape_before)
def propagate(graph, source_shapes=None):
    # type: (ONNXGraph, typing.Union[typing.Dict[str, typing.List[int]], typing.List[int], int, None])->None
    """Propagate output shapes and dtypes through every operation of the graph, in place.

    The graph is sorted, its input shapes are fixed with
    ``shape_fixer.fix_input_shapes``, then the propagator registered in
    ``_DefaultPropagators`` for each operation is called and its results are
    written to the operation's output tensors.  Finally
    ``graph_utils.remove_unreachable`` is called on the graph.

    :param graph: the graph to process; modified in place.
    :param source_shapes: input shape specification, forwarded unchanged to
        ``shape_fixer.fix_input_shapes``.
    :raises utils.NNEFToolsException: if an operation has no shape propagator.
    """
    graph.sort()
    shape_fixer.fix_input_shapes(graph, source_shapes)

    for op in graph.operations:
        # Shape prop
        if op.name not in _DefaultPropagators:
            # Raise instead of assert: the check must survive `python -O`, and this
            # matches how unsupported operations are reported elsewhere.
            raise utils.NNEFToolsException("No shape propagator for {}".format(op.name))

        # A propagator returns (list of shapes, list of dtypes), one per output.
        propagated_shapes, propagated_dtypes = _DefaultPropagators[op.name](op)
        assert not utils.has_le_0(propagated_shapes)
        assert len(propagated_shapes) == len(propagated_dtypes) == len(op.outputs)

        for new_shape, new_dtype, tensor in zip(propagated_shapes, propagated_dtypes, op.outputs):
            # Propagated values must agree with whatever was already recorded on the tensor.
            assert utils.compatible_shapes(tensor.shape, new_shape)
            tensor.shape = new_shape
            assert tensor.dtype is None or tensor.dtype == new_dtype
            tensor.dtype = new_dtype

    graph_utils.remove_unreachable(graph)
def tf_get_input_shapes(input_shape=None):
    """Compute fully known shapes for the placeholders returned by ``tf_get_placeholders()``.

    ``input_shape`` may be:

    * a dict mapping tensor name to a shape; names without ``':'`` get
      ``':0'`` appended, and placeholders missing from the dict are treated as
      if ``input_shape`` were ``None``,
    * a list, used as the full shape of every placeholder,
    * an int-like value (as accepted by ``utils.is_anyint``), substituted for
      every ``-1`` dimension,
    * ``None``: every ``-1`` dimension is replaced by 1, with a warning.

    The placeholder tensors themselves are not modified.

    :return: dict mapping placeholder name to ``(tensor.dtype.name, shape)``.
    :raises utils.NNEFToolsException: if a given shape is incompatible with the
        original one, if a placeholder's shape is unknown and no full shape was
        given, if the shape specification has an unsupported type, or if a
        placeholder has no dtype.
    """
    if isinstance(input_shape, dict):
        # Names without an output index refer to output 0.
        input_shape = {(k + ':0' if ':' not in k else k): v for k, v in six.iteritems(input_shape)}

    def get_shape_for(name):
        if isinstance(input_shape, dict) and name in input_shape:
            shape = input_shape[name]
            # Copy so the result does not alias the caller's list (the list case copies too).
            return list(shape) if isinstance(shape, (list, tuple)) else shape
        elif isinstance(input_shape, list):
            return list(input_shape)
        elif utils.is_anyint(input_shape):
            return utils.anyint_to_int(input_shape)
        return None

    def fill_unknown_dims(shape, value):
        # Replace each unknown (-1) dimension with `value`.
        return [value if dim == -1 else dim for dim in shape]

    placeholders = list(tf_get_placeholders())
    # Normalize each shape once; it is needed both for the info print and in the loop below.
    normalized_shapes = [tf_shape_normalize(tensor.shape) for tensor in placeholders]

    if input_shape is None:
        # No user-provided shapes: show the original input shapes when some are incomplete.
        if any(shape is None or -1 in shape for shape in normalized_shapes):
            for tensor, shape in zip(placeholders, normalized_shapes):
                print("Info: Input shape: {}: {}".format(tensor.name, shape))

    new_input_shapes = {}
    for tensor, tensor_shape in zip(placeholders, normalized_shapes):
        shape_for_this = get_shape_for(tensor.name) if tensor.name else None
        if isinstance(shape_for_this, list):
            if not utils.compatible_shapes(tensor_shape, shape_for_this):
                raise utils.NNEFToolsException(
                    "The specified shape is incompatible with the original shape for {}. {} vs. {}"
                    .format(tensor.name, shape_for_this, tensor_shape))
            tensor_shape = shape_for_this
        elif shape_for_this is None or isinstance(shape_for_this, int):
            if tensor_shape is None:
                raise utils.NNEFToolsException(
                    "The full shape must be specified for {}, because it is unknown."
                    .format(tensor.name))
            elif -1 in tensor_shape:
                if shape_for_this is None:
                    shape_for_this = 1
                fixed_shape = fill_unknown_dims(tensor_shape, shape_for_this)
                print("Warning: Incomplete input shape is auto-fixed: {}. {} -> {}. "
                      "Use --input-shape if other shape is desired.".format(tensor.name, tensor_shape, fixed_shape))
                tensor_shape = fixed_shape
        else:
            # Previously `assert False`, which is a silent no-op under `python -O`.
            raise utils.NNEFToolsException(
                "Unsupported shape specification for {}: {!r}".format(tensor.name, shape_for_this))

        if tensor.dtype is None:
            raise utils.NNEFToolsException(
                "An input tensor has incomplete dtype, "
                "we have thought that this is impossible, "
                "please file a bug report to NNEF Tools.")
        new_input_shapes[tensor.name] = (tensor.dtype.name, tensor_shape)
    return new_input_shapes