def tensor_to_inference_result(t):
  """Serializes the static shape information of `t` for C++ shape inference.

  Args:
    t: The tensor to describe. Its `get_shape()` result is converted to a
      proto, and its `_handle_data` attribute is copied when it is not None.

  Returns:
    The serialized form of a `CppShapeInferenceResult` message holding the
    shape of `t` and, if present, its handle data.
  """
  inference_result = cpp_shape_inference_pb2.CppShapeInferenceResult()
  shape_proto = t.get_shape().as_proto()
  inference_result.shape.CopyFrom(shape_proto)
  # pylint: disable=protected-access
  handle_data = t._handle_data
  # pylint: enable=protected-access
  if handle_data is not None:
    inference_result.handle_data.CopyFrom(handle_data)
  return inference_result.SerializeToString()
def _call_cpp_shape_fn_impl(op, input_tensors_needed,
                            input_tensors_as_shapes_needed, require_shape_fn):
  """Core implementation of call_cpp_shape_fn.

  Args:
    op: The node in the graph for which to compute output shapes.
    input_tensors_needed: Iterable of input indices whose constant value (when
      one can be computed) is passed to the C++ shape function.
    input_tensors_as_shapes_needed: Iterable of input indices whose
      constant-value-as-shape (when one can be computed) is passed to the C++
      shape function.
    require_shape_fn: If true, a missing C++ shape function raises
      RuntimeError; otherwise `unknown_shape(op)` is returned in that case.

  Returns:
    A dictionary with the keys "shapes", "handle_data" and "inputs_needed",
    or the result of `unknown_shape(op)` when no C++ shape function exists
    and `require_shape_fn` is false.

  Raises:
    ValueError: If the C++ shape function reported an invalid-argument error
      other than "no shape function exists".
    RuntimeError: If no C++ shape function exists and `require_shape_fn` is
      true.
  """
  producer_version = op.graph.graph_def_versions.producer
  serialized_node_def = op.node_def.SerializeToString()

  def _serialize_input(tensor):
    """Packs the shape and optional handle data of `tensor` into bytes."""
    packed = cpp_shape_inference_pb2.CppShapeInferenceResult()
    packed.shape.CopyFrom(tensor.get_shape().as_proto())
    # pylint: disable=protected-access
    if tensor._handle_data is not None:
      packed.handle_data.CopyFrom(tensor._handle_data)
    # pylint: enable=protected-access
    return packed.SerializeToString()

  input_shapes = [_serialize_input(tensor) for tensor in op.inputs]
  num_inputs = len(input_shapes)

  # Constant values are supplied only for the requested inputs; every other
  # slot stays None.
  input_tensors = [None] * num_inputs
  for idx in input_tensors_needed:
    constant = tensor_util.constant_value(op.inputs[idx])
    if constant is not None:
      input_tensors[idx] = np.asarray(constant)

  # Every slot defaults to a serialized unknown shape and is overwritten only
  # for requested inputs whose value-as-shape could be computed.
  unknown_shape_bytes = (
      tensor_shape.TensorShape(None).as_proto().SerializeToString())
  input_tensors_as_shapes = [unknown_shape_bytes] * num_inputs
  for idx in input_tensors_as_shapes_needed:
    as_shape = tensor_util.constant_value_as_shape(op.inputs[idx])
    if as_shape is not None:
      input_tensors_as_shapes[idx] = as_shape.as_proto().SerializeToString()

  shape_fn_missing = False
  try:
    with errors.raise_exception_on_not_ok_status() as status:
      output = pywrap_tensorflow.RunCppShapeInference(
          producer_version, serialized_node_def, input_shapes, input_tensors,
          input_tensors_as_shapes, status)
  except errors.InvalidArgumentError as err:
    if not err.message.startswith("No shape inference function exists for op"):
      raise ValueError(err.message)
    shape_fn_missing = True

  # The missing-shape-fn case is handled outside the `except` block so the
  # RuntimeError is not raised while the original error is being handled.
  if shape_fn_missing:
    if not require_shape_fn:
      return unknown_shape(op)
    raise RuntimeError(
        "No C++ shape function registered for standard op: %s" % op.type)

  # All entries but the last are serialized CppShapeInferenceResult messages;
  # the last entry is handed back to the caller under "inputs_needed".
  result_protos = [
      cpp_shape_inference_pb2.CppShapeInferenceResult().FromString(serialized)
      for serialized in output[:-1]
  ]
  shapes = [proto.shape for proto in result_protos]
  handle_data = []
  for proto in result_protos:
    if proto.handle_data.is_set:
      handle_data.append(proto.handle_data)
    else:
      handle_data.append(None)

  return {
      "shapes": shapes,
      "handle_data": handle_data,
      "inputs_needed": output[-1],
  }
def _call_cpp_shape_fn_impl(
    op, input_tensors_needed, input_tensors_as_shapes_needed,
    debug_python_shape_fn, require_shape_fn):
  """Core implementation of call_cpp_shape_fn.

  Args:
    op: The node in the graph for which to compute output shapes.
    input_tensors_needed: Iterable of input indices whose constant value (when
      one can be computed) is passed to the C++ shape function.
    input_tensors_as_shapes_needed: Iterable of input indices whose
      constant-value-as-shape (when one can be computed) is passed to the C++
      shape function.
    debug_python_shape_fn: If truthy, a Python shape function that is called
      on `op`; its output is compared against the C++ result.
    require_shape_fn: If true, a missing C++ shape function raises
      RuntimeError; otherwise `unknown_shape(op)` is returned in that case.

  Returns:
    A dictionary with the keys "shapes", "handle_shapes", "handle_dtypes" and
    "inputs_needed", or the result of `unknown_shape(op)` when no C++ shape
    function exists and `require_shape_fn` is false.

  Raises:
    ValueError: If the C++ shape function reported an invalid-argument error
      other than "no shape function exists", or if the Python and C++ shape
      functions disagree.
    RuntimeError: If no C++ shape function exists and `require_shape_fn` is
      true.
    AssertionError: If `debug_python_shape_fn` raised although the C++ shape
      function succeeded.
  """
  serialized_node_def = op.node_def.SerializeToString()

  def _serialize_input(tensor):
    """Packs the shape, handle shape and handle dtype of `tensor` into bytes."""
    packed = cpp_shape_inference_pb2.CppShapeInferenceResult()
    packed.shape.CopyFrom(tensor.get_shape().as_proto())
    # pylint: disable=protected-access
    packed.handle_shape.CopyFrom(tensor._handle_shape)
    packed.handle_dtype = tensor._handle_dtype
    # pylint: enable=protected-access
    return packed.SerializeToString()

  input_shapes = [_serialize_input(tensor) for tensor in op.inputs]
  num_inputs = len(input_shapes)

  # Constant values are supplied only for the requested inputs; every other
  # slot stays None.
  input_tensors = [None] * num_inputs
  for idx in input_tensors_needed:
    constant = tensor_util.constant_value(op.inputs[idx])
    if constant is not None:
      input_tensors[idx] = np.asarray(constant)

  # Every slot defaults to a serialized unknown shape and is overwritten only
  # for requested inputs whose value-as-shape could be computed.
  unknown_shape_bytes = (
      tensor_shape.TensorShape(None).as_proto().SerializeToString())
  input_tensors_as_shapes = [unknown_shape_bytes] * num_inputs
  for idx in input_tensors_as_shapes_needed:
    as_shape = tensor_util.constant_value_as_shape(op.inputs[idx])
    if as_shape is not None:
      input_tensors_as_shapes[idx] = as_shape.as_proto().SerializeToString()

  shape_fn_missing = False
  try:
    with errors.raise_exception_on_not_ok_status() as status:
      output = pywrap_tensorflow.RunCppShapeInference(
          serialized_node_def, input_shapes, input_tensors,
          input_tensors_as_shapes, status)
  except errors.InvalidArgumentError as err:
    if not err.message.startswith("No shape inference function exists for op"):
      raise ValueError(err.message)
    shape_fn_missing = True

  # The missing-shape-fn case is handled outside the `except` block so the
  # RuntimeError is not raised while the original error is being handled.
  if shape_fn_missing:
    if not require_shape_fn:
      return unknown_shape(op)
    raise RuntimeError(
        "No C++ shape function registered for standard op: %s" % op.type)

  # All entries but the last are serialized CppShapeInferenceResult messages;
  # the last entry is handed back to the caller under "inputs_needed".
  result_protos = [
      cpp_shape_inference_pb2.CppShapeInferenceResult().FromString(serialized)
      for serialized in output[:-1]
  ]
  result = [proto.shape for proto in result_protos]
  result_handle_shapes = [proto.handle_shape for proto in result_protos]
  result_handle_dtypes = [proto.handle_dtype for proto in result_protos]

  if debug_python_shape_fn:
    try:
      python_result = [
          tensor_shape.as_shape(shape) for shape in debug_python_shape_fn(op)
      ]
    except Exception as err:  # pylint: disable=broad-except
      raise AssertionError("Python shape function return error but "
                           "C++ shape functon did not: %s" % str(err))
    # Both sides are converted with as_shape and compared via their string
    # forms.
    result_as_shapes = [tensor_shape.as_shape(shape) for shape in result]
    cpp_text = str(result_as_shapes)
    python_text = str(python_result)
    if cpp_text != python_text:
      input_shapes_text = ",".join(
          str(tensor.get_shape()) for tensor in op.inputs)
      raise ValueError(
          ("Python vs CPP shape mismatch. "
           "CPP: %s vs python: %s on node %s "
           "with input shapes %s") % (
               cpp_text, python_text, str(op.node_def), input_shapes_text))

  return {
      "shapes": result,
      "handle_shapes": result_handle_shapes,
      "handle_dtypes": result_handle_dtypes,
      "inputs_needed": output[-1],
  }
def call_cpp_shape_fn(op,
                      input_tensors_needed=None,
                      input_tensors_as_shapes_needed=None,
                      debug_python_shape_fn=None,
                      require_shape_fn=True):
  """A shape function that delegates to the registered C++ shape function.

  Args:
    op: the node in the graph for which to compute output shapes.
    input_tensors_needed: a list of input tensor indices for which to compute
      the input tensor's value and pass to the C++ shape function.
    input_tensors_as_shapes_needed: a list of input tensor indices for which
      to compute the constant_value_as_shape and pass to the C++ shape
      function.
    debug_python_shape_fn: For testing only during migration to using
      call_cpp_shape_fn. Do not submit calls that set this, as the comparison
      is slow. If non-None, the python shape function; this function will be
      called and its output compared to that of the C++ shape function.
    require_shape_fn: If true, and the C++ shape function is not registered
      in the current binary then an exception is raised; otherwise, if the
      C++ shape function is not registered then unknown_shape is used.

  Returns:
    A dictionary with the following keys:
      shapes: The output shapes of the op. For "Const" ops this is a single
        TensorShape built from the "value" attr; otherwise it is the `shape`
        field of each result returned by the C++ shape function.
      handle_shapes: The `handle_shape` field of each result (an unknown
        shape proto for "Const" ops).
      handle_dtypes: The `handle_dtype` field of each result (DT_INVALID for
        "Const" ops).
    When the C++ shape function is not registered and `require_shape_fn` is
    false, the result of `unknown_shape(op)` is returned directly instead.

  Raises:
    ValueError: If the C++ shape function returned an error (e.g. because the
      shapes of the inputs are of the wrong rank or otherwise incompatible
      according to the shape function), or if the Python and C++ shape
      functions disagree.
    RuntimeError: If the C++ shape function is not registered and
      <require_shape_fn> is True.
    AssertionError: If `debug_python_shape_fn` raised although the C++ shape
      function succeeded.
  """
  if op.type == "Const":
    # To avoid serializing large constants, we special-case constant
    # here, even though it has a C++ shape function.  When Python
    # calls the C / C-API directly, we should be able to remove this.
    const_shape = tensor_shape.TensorShape(op.get_attr("value").tensor_shape)
    unknown_handle_shape = tensor_shape.TensorShape(None).as_proto()
    return {
        "shapes": [const_shape],
        "handle_shapes": [unknown_handle_shape],
        "handle_dtypes": [types_pb2.DT_INVALID],
    }

  serialized_node_def = op.node_def.SerializeToString()

  def _serialize_input(tensor):
    """Packs the shape, handle shape and handle dtype of `tensor` into bytes."""
    packed = cpp_shape_inference_pb2.CppShapeInferenceResult()
    packed.shape.CopyFrom(tensor.get_shape().as_proto())
    # pylint: disable=protected-access
    packed.handle_shape.CopyFrom(tensor._handle_shape)
    packed.handle_dtype = tensor._handle_dtype
    # pylint: enable=protected-access
    return packed.SerializeToString()

  input_shapes = [_serialize_input(tensor) for tensor in op.inputs]
  num_inputs = len(input_shapes)

  # Constant values are supplied only for the requested inputs; every other
  # slot stays None. A falsy argument (None or empty) requests nothing.
  input_tensors = [None] * num_inputs
  for idx in input_tensors_needed or ():
    constant = tensor_util.constant_value(op.inputs[idx])
    if constant is not None:
      input_tensors[idx] = np.asarray(constant)

  # Every slot defaults to a serialized unknown shape and is overwritten only
  # for requested inputs whose value-as-shape could be computed.
  unknown_shape_bytes = (
      tensor_shape.TensorShape(None).as_proto().SerializeToString())
  input_tensors_as_shapes = [unknown_shape_bytes] * num_inputs
  for idx in input_tensors_as_shapes_needed or ():
    as_shape = tensor_util.constant_value_as_shape(op.inputs[idx])
    if as_shape is not None:
      input_tensors_as_shapes[idx] = as_shape.as_proto().SerializeToString()

  shape_fn_missing = False
  try:
    with errors.raise_exception_on_not_ok_status() as status:
      output_shapes = pywrap_tensorflow.RunCppShapeInference(
          serialized_node_def, input_shapes, input_tensors,
          input_tensors_as_shapes, status)
  except errors.InvalidArgumentError as err:
    if not err.message.startswith("No shape inference function exists for op"):
      raise ValueError(err.message)
    shape_fn_missing = True

  # The missing-shape-fn case is handled outside the `except` block so the
  # RuntimeError is not raised while the original error is being handled.
  if shape_fn_missing:
    if not require_shape_fn:
      return unknown_shape(op)
    raise RuntimeError(
        "No C++ shape function registered for standard op: %s" % op.type)

  # Each returned entry is a serialized CppShapeInferenceResult message.
  result_protos = [
      cpp_shape_inference_pb2.CppShapeInferenceResult().FromString(serialized)
      for serialized in output_shapes
  ]
  result = [proto.shape for proto in result_protos]
  result_handle_shapes = [proto.handle_shape for proto in result_protos]
  result_handle_dtypes = [proto.handle_dtype for proto in result_protos]

  if debug_python_shape_fn:
    try:
      python_result = [
          tensor_shape.as_shape(shape) for shape in debug_python_shape_fn(op)
      ]
    except Exception as err:  # pylint: disable=broad-except
      raise AssertionError("Python shape function return error but "
                           "C++ shape functon did not: %s" % str(err))
    # Both sides are converted with as_shape and compared via their string
    # forms.
    result_as_shapes = [tensor_shape.as_shape(shape) for shape in result]
    cpp_text = str(result_as_shapes)
    python_text = str(python_result)
    if cpp_text != python_text:
      input_shapes_text = ",".join(
          str(tensor.get_shape()) for tensor in op.inputs)
      raise ValueError(
          ("Python vs CPP shape mismatch. "
           "CPP: %s vs python: %s on node %s "
           "with input shapes %s") % (
               cpp_text, python_text, str(op.node_def), input_shapes_text))

  return {
      "shapes": result,
      "handle_shapes": result_handle_shapes,
      "handle_dtypes": result_handle_dtypes,
  }
def call_cpp_shape_fn(op,
                      input_tensors_needed=None,
                      input_tensors_as_shapes_needed=None,
                      debug_python_shape_fn=None):
  """A shape function that delegates to the registered C++ shape function.

  Args:
    op: the node in the graph for which to compute output shapes.
    input_tensors_needed: a list of input tensor indices for which to compute
      the input tensor's value and pass to the C++ shape function.
    input_tensors_as_shapes_needed: a list of input tensor indices for which
      to compute the constant_value_as_shape and pass to the C++ shape
      function.
    debug_python_shape_fn: For testing only during migration to using
      call_cpp_shape_fn. Do not submit calls that set this, as the comparison
      is slow. If non-None, the python shape function; this function will be
      called and its output compared to that of the C++ shape function.

  Returns:
    A dictionary with the following keys:
      shapes: The `shape` field of each result returned by the C++ shape
        inference function registered for the op.
      handle_shapes: The `handle_shape` field of each result.
      handle_dtypes: The `handle_dtype` field of each result.

  Raises:
    ValueError: If the C++ shape function returned an error (e.g. because the
      shapes of the inputs are of the wrong rank or otherwise incompatible
      according to the shape function), or if the Python and C++ shape
      functions disagree.
    AssertionError: If `debug_python_shape_fn` raised although the C++ shape
      function succeeded.
  """
  node_def_str = op.node_def.SerializeToString()

  def tensor_to_inference_result(t):
    """Packs the shape, handle shape and handle dtype of `t` into bytes."""
    r = cpp_shape_inference_pb2.CppShapeInferenceResult()
    r.shape.CopyFrom(t.get_shape().as_proto())
    # pylint: disable=protected-access
    r.handle_shape.CopyFrom(t._handle_shape)
    r.handle_dtype = t._handle_dtype
    # pylint: enable=protected-access
    return r.SerializeToString()

  input_shapes = [tensor_to_inference_result(i) for i in op.inputs]

  # Constant values are supplied only for the requested inputs; every other
  # slot stays None.
  input_tensors = [None] * len(input_shapes)
  if input_tensors_needed:
    for idx in input_tensors_needed:
      v = tensor_util.constant_value(op.inputs[idx])
      if v is not None:
        input_tensors[idx] = np.asarray(v)

  # Every slot defaults to a serialized unknown shape and is overwritten only
  # for requested inputs whose value-as-shape could be computed.
  serialized_unknown_shape = (
      tensor_shape.TensorShape(None).as_proto().SerializeToString())
  input_tensors_as_shapes = [serialized_unknown_shape] * len(input_shapes)
  if input_tensors_as_shapes_needed:
    for idx in input_tensors_as_shapes_needed:
      s = tensor_util.constant_value_as_shape(op.inputs[idx])
      if s is not None:
        input_tensors_as_shapes[idx] = s.as_proto().SerializeToString()

  try:
    with errors.raise_exception_on_not_ok_status() as status:
      output_shapes = pywrap_tensorflow.RunCppShapeInference(
          node_def_str, input_shapes, input_tensors, input_tensors_as_shapes,
          status)
  except errors.InvalidArgumentError as err:
    raise ValueError(err.message)

  # Each returned entry is a serialized CppShapeInferenceResult message.
  result_protos = [
      cpp_shape_inference_pb2.CppShapeInferenceResult().FromString(s)
      for s in output_shapes
  ]
  result = [r.shape for r in result_protos]
  result_handle_shapes = [r.handle_shape for r in result_protos]
  result_handle_dtypes = [r.handle_dtype for r in result_protos]

  if debug_python_shape_fn:
    try:
      python_result = [
          tensor_shape.as_shape(s) for s in debug_python_shape_fn(op)
      ]
    except Exception as err:  # pylint: disable=broad-except
      raise AssertionError("Python shape function return error but "
                           "C++ shape functon did not: %s" % str(err))
    # `result` holds shape protos while `python_result` holds as_shape()
    # conversions; comparing their string forms directly would always report
    # a mismatch. Convert the C++ shapes the same way before comparing.
    result_as_shapes = [tensor_shape.as_shape(s) for s in result]
    if str(result_as_shapes) != str(python_result):
      raise ValueError(
          ("Python vs CPP shape mismatch. "
           "CPP: %s vs python: %s on node %s "
           "with input shapes %s") % (
               str(result_as_shapes), str(python_result), str(op.node_def),
               ",".join([str(i.get_shape()) for i in op.inputs])))

  return {
      "shapes": result,
      "handle_shapes": result_handle_shapes,
      "handle_dtypes": result_handle_dtypes,
  }