def _init_from_proto(self, hparam_def):
  """Creates a new HParams from `HParamDef` protocol buffer.

  Every entry of `hparam_def.hparam` is added with `self.add_hparam`. `int64`
  values are converted to Python `int` and `bytes` values to `str` (UTF-8 is
  assumed) so the attribute types agree between Python 2 and Python 3.

  Args:
    hparam_def: `HParamDef` protocol buffer.

  Raises:
    AssertionError: If `hparam_def` is not an `HParamDef`.
  """
  # An explicit raise instead of `assert`: asserts are stripped under
  # `python -O`, which would silently disable this check. The exception type
  # is kept as AssertionError for callers that relied on the old assert.
  if not isinstance(hparam_def, hparam_pb2.HParamDef):
    raise AssertionError('Wrong "hparam_def" type')
  for name, value in hparam_def.hparam.items():
    kind = value.WhichOneof('kind')
    if kind.endswith('_value'):
      # Single value.
      raw = getattr(value, kind)
      if kind.startswith('int64'):
        # 'int' keeps the type compatible with both Python 2 and Python 3.
        self.add_hparam(name, int(raw))
      elif kind.startswith('bytes'):
        # 'str' keeps the type compatible with both Python 2 and Python 3.
        # UTF-8 encoding is assumed.
        self.add_hparam(name, compat.as_str(raw))
      else:
        self.add_hparam(name, raw)
    else:
      # List of values.
      raw_values = getattr(value, kind).value
      if kind.startswith('int64'):
        self.add_hparam(name, [int(v) for v in raw_values])
      elif kind.startswith('bytes'):
        self.add_hparam(name, [compat.as_str(v) for v in raw_values])
      else:
        self.add_hparam(name, list(raw_values))
def _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                     return_elements):
  """Populates the TF_ImportGraphDefOptions `options`.

  Sets the import prefix (with name and prefix uniquification enabled), then
  registers each `input_map` entry and each requested return element.
  `input_map` keys starting with '^' remap control dependencies; other keys
  map tensors. Return elements containing ':' are requested as outputs;
  the rest are requested as operations.
  """
  c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix)
  c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, True)
  c_api.TF_ImportGraphDefOptionsSetUniquifyPrefix(options, True)

  for raw_src, dst in input_map.items():
    src = compat.as_str(raw_src)
    if not src.startswith('^'):
      tensor_name, output_index = _ParseTensorName(src)
      c_api.TF_ImportGraphDefOptionsAddInputMapping(
          options, compat.as_str(tensor_name), output_index,
          dst._as_tf_output())  # pylint: disable=protected-access
      continue
    # Control-dependency remap: the leading '^' is dropped from the name.
    control_src = compat.as_bytes(src[1:])
    control_dst = dst._as_tf_output().oper  # pylint: disable=protected-access
    c_api.TF_ImportGraphDefOptionsRemapControlDependency(
        options, control_src, control_dst)

  for element in return_elements or []:
    if ':' not in element:
      c_api.TF_ImportGraphDefOptionsAddReturnOperation(
          options, compat.as_str(element))
      continue
    op_name, output_index = _ParseTensorName(element)
    c_api.TF_ImportGraphDefOptionsAddReturnOutput(
        options, compat.as_str(op_name), output_index)
def meta_graph_transform(
    base_meta_graph_def, input_names, output_names, transforms, tags,
    checkpoint_path=None):
  """Apply the Graph Transform tool to a MetaGraphDef.

  Args:
    base_meta_graph_def: A MetaGraphDef protocol buffer to transform.
    input_names: Names of input nodes.
    output_names: Names of output nodes.
    transforms: A list of strings naming the graph transforms to be applied in
      order.  These transform names are exactly those supported by the Graph
      Transform Tool, with the addition of the 'freeze_graph' transform.
    tags: A list of tags with which to annotate the transformed MetaGraphDef.
    checkpoint_path: A path to a checkpoint to restore during freezing,
      if needed (default None).

  Returns:
    A new transformed MetaGraphDef protocol buffer.
  """
  result = _meta_graph_pb2.MetaGraphDef()

  retain_op_names = _find_all_mandatory_retain_ops(base_meta_graph_def)
  new_graph_def = _do_transforms(
      base_meta_graph_def.graph_def, input_names, output_names,
      retain_op_names, transforms, base_meta_graph_def.saver_def,
      checkpoint_path)
  result.graph_def.CopyFrom(new_graph_def)

  # Start from the base meta info, but replace (not merge) its tags.
  result.meta_info_def.CopyFrom(base_meta_graph_def.meta_info_def)
  result.meta_info_def.ClearField('tags')
  for tag in tags:
    result.meta_info_def.tags.append(tag)

  # Ops present before the transform but missing afterwards.
  before = {compat.as_str(node.name)
            for node in base_meta_graph_def.graph_def.node}
  after = {compat.as_str(node.name) for node in result.graph_def.node}
  removed_op_names = before - after

  # Copy saver, collections and signature_defs, excluding any pruned nodes.
  _add_pruned_saver(base_meta_graph_def, result, removed_op_names)
  for collection_name in base_meta_graph_def.collection_def:
    _add_pruned_collection(
        base_meta_graph_def, result, collection_name, removed_op_names)
  for signature_name in base_meta_graph_def.signature_def:
    _add_pruned_signature(
        base_meta_graph_def, result, signature_name, removed_op_names)

  return result
def assert_equal_graph_def(actual, expected, checkpoint_v2=False):
  """Asserts that two `GraphDef`s are (mostly) the same.

  Compares two `GraphDef` protos for equality, ignoring versions and ordering of
  nodes, attrs, and control inputs.  Node names are used to match up nodes
  between the graphs, so the naming of nodes must be consistent.

  The arguments are never modified: when `checkpoint_v2` is set, the
  randomized values are stripped from private copies.

  Args:
    actual: The `GraphDef` we have.
    expected: The `GraphDef` we expected.
    checkpoint_v2: boolean determining whether to ignore randomized attribute
        values that appear in V2 checkpoints.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  if not isinstance(actual, graph_pb2.GraphDef):
    raise TypeError("Expected tf.GraphDef for actual, got %s" %
                    type(actual).__name__)
  if not isinstance(expected, graph_pb2.GraphDef):
    raise TypeError("Expected tf.GraphDef for expected, got %s" %
                    type(expected).__name__)

  if checkpoint_v2:
    # _strip_checkpoint_v2_randomized edits its argument in place. Work on
    # copies so that an assertion helper does not mutate the caller's protos.
    actual_copy = graph_pb2.GraphDef()
    actual_copy.CopyFrom(actual)
    expected_copy = graph_pb2.GraphDef()
    expected_copy.CopyFrom(expected)
    _strip_checkpoint_v2_randomized(actual_copy)
    _strip_checkpoint_v2_randomized(expected_copy)
    actual, expected = actual_copy, expected_copy

  diff = pywrap_tensorflow.EqualGraphDefWrapper(actual.SerializeToString(),
                                                expected.SerializeToString())
  if diff:
    raise AssertionError(compat.as_str(diff))
def _create_new_tf_function(func_graph):
  """Converts func_graph to a TF_Function and adds it to the current graph.

  Args:
    func_graph: function._FuncGraph

  Returns:
    The name of the new TF_Function (i.e. `func_graph.name`).
  """
  # pylint: disable=protected-access
  input_handles = [t._as_tf_output() for t in func_graph.inputs]
  output_handles = [t._as_tf_output() for t in func_graph.outputs]
  c_func = c_api.TF_GraphToFunction_wrapper(
      func_graph._c_graph,
      compat.as_str(func_graph.name),
      False,  # append_hash_to_fn_name
      None,  # opers
      input_handles,
      output_handles,
      [],
      None,  # opts
      None)  # description
  # Keep the scoped wrapper referenced until this function returns. It
  # presumably owns and eventually frees `c_func`, which must stay valid for
  # the conversion below.
  scoped_c_func = c_api_util.ScopedTFFunction(c_func)  # pylint: disable=unused-variable

  # TODO(b/109833212): this round-trips TF_Function* -> Python FunctionDef ->
  # a second TF_Function just to add it to the outer graph; avoid the copies.
  fdef = _function.function_def_from_tf_function(c_func)
  defined_func = _function._from_definition(fdef)
  defined_func._sub_functions = func_graph._functions
  defined_func.add_to_graph(func_graph._outer_graph)
  # pylint: enable=protected-access
  return func_graph.name
def _ProcessReturnElementsParam(return_elements): """Type-checks and possibly canonicalizes `return_elements`.""" if return_elements is None: return None if not all(isinstance(x, compat.bytes_or_text_types) for x in return_elements): raise TypeError('return_elements must be a list of strings.') return tuple(compat.as_str(x) for x in return_elements)
def encode_arg(arg, path):
  """A representation for this argument, for converting into signatures.

  Tensors become `TensorSpec`s. The spec is named by the op's
  `_user_specified_name` attr when that differs from `path[0]`, and otherwise
  by the '/'-joined `path`. Ints, floats, bools, None, `DType`s and
  `TensorSpec`s are returned unchanged. Anything else becomes an
  `UnknownArgument`.
  """
  if not isinstance(arg, ops.Tensor):
    passthrough_types = (
        int,
        float,
        bool,
        type(None),
        dtypes.DType,
        tensor_spec.TensorSpec,
    )
    return arg if isinstance(arg, passthrough_types) else UnknownArgument()

  try:
    user_specified_name = compat.as_str(
        arg.op.get_attr("_user_specified_name"))
  except ValueError:
    # Presumably raised when the op carries no such attr.
    user_specified_name = None

  renamed_by_user = (path and user_specified_name and
                     user_specified_name != path[0])
  if renamed_by_user:
    # The user named the argument differently from the function argument.
    spec_name = user_specified_name
  else:
    spec_name = "/".join(str(part) for part in path)
  return tensor_spec.TensorSpec(arg.shape, arg.dtype, spec_name)
def _clean_save_and_restore(graph_def, op, removed_op_names):
  """Clean the specified save and restore op.

  Drops, from the op's `dtypes` attr and from the associated
  `<op>/tensor_names` and `<op>/shape_and_slices` constant tensors, every
  entry whose variable name `_is_removed` reports as removed. The first
  dimension sizes of those tensors and of their `_output_shapes` are updated
  to the new length.

  Args:
    graph_def: A GraphDef proto to be transformed.
    op: The save or restore op to update.
    removed_op_names: Collection of op names that have been removed.
  """
  name_op = _find_op(graph_def, op.name + '/tensor_names')
  shape_op = _find_op(graph_def, op.name + '/shape_and_slices')
  names_tensor = name_op.attr['value'].tensor
  shapes_tensor = shape_op.attr['value'].tensor
  op_dtypes = op.attr['dtypes'].list.type

  # Positions of the entries whose variable survived pruning.
  kept = [i for i, tensor_name in enumerate(names_tensor.string_val)
          if not _is_removed(compat.as_str(tensor_name), removed_op_names)]
  kept_names = [names_tensor.string_val[i] for i in kept]
  kept_shapes = [shapes_tensor.string_val[i] for i in kept]
  kept_dtypes = [op_dtypes[i] for i in kept]

  names_tensor.string_val[:] = kept_names
  names_tensor.tensor_shape.dim[0].size = len(kept_names)
  shapes_tensor.string_val[:] = kept_shapes
  shapes_tensor.tensor_shape.dim[0].size = len(kept_shapes)
  op_dtypes[:] = kept_dtypes
  name_op.attr['_output_shapes'].list.shape[0].dim[0].size = len(kept_names)
  shape_op.attr['_output_shapes'].list.shape[0].dim[0].size = len(kept_shapes)
def __init__(self, name, graph, operations, inputs, outputs, attrs):
  """Initializes an eager defined function.

  Args:
    name: str, the name for the created function.
    graph: Graph, the graph containing the operations in the function
    operations: list of Operation; the subset of operations in the graph
      which will be in the function
    inputs: the tensors in the graph to be used as inputs to the function
    outputs: the tensors in the graph which will be outputs to the function
    attrs: dict mapping names of attributes to their AttrValue values
  """
  fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
      graph._c_graph,  # pylint: disable=protected-access
      compat.as_str(name),
      False,
      [o._c_op for o in operations],  # pylint: disable=protected-access
      [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
      [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
      [],
      None,
      compat.as_str(""))
  # Wrap `fn` in the same ScopedTFFunction this object ends up storing, right
  # away. If one of the calls below raises, the wrapper is still dropped and
  # `fn` does not leak.
  scoped_fn = c_api_util.ScopedTFFunction(fn)

  # `attr_name` (not `name`) so the `name` argument is not shadowed.
  for attr_name, attr_value in attrs.items():
    serialized = attr_value.SerializeToString()
    # TODO(iga): this creates and deletes a new TF_Status for every attr.
    # It might be worth creating a convenient way to re-use status.
    pywrap_tensorflow.TF_FunctionSetAttrValueProto(
        fn, compat.as_str(attr_name), serialized)

  # TODO(apassos) avoid creating a FunctionDef (specially to grab the
  # signature, but also in general it's nice not to depend on it.
  with c_api_util.tf_buffer() as buffer_:
    pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_)
    proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
  function_def = function_pb2.FunctionDef()
  function_def.ParseFromString(compat.as_bytes(proto_data))
  if context.executing_eagerly():
    _register(fn)
  self.definition = function_def
  self.name = function_def.signature.name
  self.signature = function_def.signature
  self.grad_func_name = None
  self.python_grad_func = None
  self._c_func = scoped_fn
  self._grad_func = None
def _node_def(from_node_def, export_scope, unbound_inputs, clear_devices=False):
  """Create a `NodeDef` proto with export_scope stripped.

  Inputs outside `export_scope` are renamed with the unbound-input prefix and
  appended to `unbound_inputs`. Inputs inside it have the scope stripped. The
  node name, `_class` colocation entries and `Enter`/`RefEnter` frame names
  inside the scope are stripped as well.

  Args:
    from_node_def: A `node_def_pb2.NodeDef` protocol buffer.
    export_scope: A `string` representing the name scope to remove.
    unbound_inputs: An array of unbound input names if they exist.
    clear_devices: Boolean which controls whether to clear device information
      from node_def. Default false.

  Returns:
    A `node_def_pb2.NodeDef` protocol buffer.
  """
  node_def = copy.deepcopy(from_node_def)
  for i, v in enumerate(node_def.input):
    if (export_scope and
        not node_def.input[i].lstrip("^").startswith(export_scope)):
      # Adds "$unbound_inputs_" prefix to the unbound name so they are easily
      # identifiable.
      node_def.input[i] = re.sub(r"([\^]|^)(.*)",
                                 r"\1" + _UNBOUND_INPUT_PREFIX + r"\2",
                                 compat.as_str(v))
      unbound_inputs.append(node_def.input[i])
    else:
      node_def.input[i] = ops.strip_name_scope(v, export_scope)
  node_def.name = compat.as_bytes(
      ops.strip_name_scope(from_node_def.name, export_scope))
  for k, v in six.iteritems(from_node_def.attr):
    if k == "_class":
      new_s = [compat.as_bytes(ops.strip_name_scope(s, export_scope))
               for s in v.list.s
               if not export_scope or
               compat.as_str(s).split("@")[1].startswith(export_scope)]
      node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(
          list=attr_value_pb2.AttrValue.ListValue(s=new_s)))
    elif node_def.op in ("Enter", "RefEnter") and k == "frame_name":
      # The stripped frame name is only computed, and only written, when the
      # frame lives inside `export_scope`. Otherwise the deep-copied value is
      # kept. Writing unconditionally would use a name that was never
      # assigned for this attr.
      if not export_scope or compat.as_str(v.s).startswith(export_scope):
        new_frame_name = compat.as_bytes(
            ops.strip_name_scope(v.s, export_scope))
        node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(s=new_frame_name))
    else:
      node_def.attr[k].CopyFrom(v)
  if clear_devices:
    node_def.device = ""
  return node_def
def save(self, sess, save_path, global_step=None, latest_filename=None):
  """Saves variables.

  This method runs the ops added by the constructor for saving variables.
  It requires a session in which the graph was launched.  The variables to
  save must also have been initialized.

  The method returns the path of the newly created checkpoint file.  This
  path can be passed directly to a call to `restore()`.

  Args:
    sess: A Session to use to save the variables.
    save_path: string.  Path to the checkpoint filename.  If the saver is
      `sharded`, this is the prefix of the sharded checkpoint filename.
    global_step: If provided the global step number is appended to
      `save_path` to create the checkpoint filename. The optional argument
      can be a `Tensor`, a `Tensor` name or an integer.
    latest_filename: Optional name for the protocol buffer file that will
      contain the list of most recent checkpoint filenames.  That file,
      kept in the same directory as the checkpoint files, is automatically
      managed by the saver to keep track of recent checkpoints.  Defaults to
      'checkpoint'.

  Returns:
    A string: path at which the variables were saved.  If the saver is
      sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
      is the number of shards created.

  Raises:
    TypeError: If `sess` is not a `Session`.
    ValueError: If `latest_filename` contains path components.
  """
  if latest_filename is None:
    latest_filename = "checkpoint"
  if os.path.split(latest_filename)[0]:
    raise ValueError(
        "'latest_filename' must not contain path components")
  # Check `sess` before it is used to evaluate a non-integer `global_step`.
  # Otherwise a bad session fails there with an unrelated error instead of
  # the documented TypeError.
  if not isinstance(sess, session.SessionInterface):
    raise TypeError("'sess' must be a Session; %s" % sess)

  if global_step is not None:
    if not isinstance(global_step, compat.integral_types):
      global_step = training_util.global_step(sess, global_step)
    checkpoint_file = "%s-%d" % (save_path, global_step)
  else:
    checkpoint_file = save_path
  checkpoint_dir = os.path.dirname(save_path)

  model_checkpoint_path = sess.run(
      self._save_tensor_name, {self._filename_tensor_name: checkpoint_file})
  model_checkpoint_path = compat.as_str(model_checkpoint_path)
  self._MaybeDeleteOldCheckpoints(model_checkpoint_path)
  update_checkpoint_state(checkpoint_dir, model_checkpoint_path,
                          self.last_checkpoints, latest_filename)
  return model_checkpoint_path
def _ProcessReturnElementsParam(return_elements): """Type-checks and possibly canonicalizes `return_elements`.""" if return_elements is None: return None if not all( isinstance(x, compat.bytes_or_text_types) for x in return_elements): raise TypeError('return_elements must be a list of strings.') return tuple(compat.as_str(x) for x in return_elements)
def save(self, sess, save_path, global_step=None, latest_filename=None):
  """Saves variables.

  This method runs the ops added by the constructor for saving variables.
  It requires a session in which the graph was launched.  The variables to
  save must also have been initialized.

  The method returns the path of the newly created checkpoint file.  This
  path can be passed directly to a call to `restore()`.

  Args:
    sess: A Session to use to save the variables.
    save_path: String.  Path to the checkpoint filename.  If the saver is
      `sharded`, this is the prefix of the sharded checkpoint filename.
    global_step: If provided the global step number is appended to
      `save_path` to create the checkpoint filename. The optional argument
      can be a `Tensor`, a `Tensor` name or an integer.
    latest_filename: Optional name for the protocol buffer file that will
      contain the list of most recent checkpoint filenames.  That file,
      kept in the same directory as the checkpoint files, is automatically
      managed by the saver to keep track of recent checkpoints.  Defaults to
      'checkpoint'.

  Returns:
    A string: path at which the variables were saved.  If the saver is
      sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
      is the number of shards created.

  Raises:
    TypeError: If `sess` is not a `Session`.
    ValueError: If `latest_filename` contains path components.
  """
  if latest_filename is None:
    latest_filename = "checkpoint"
  if os.path.split(latest_filename)[0]:
    raise ValueError("'latest_filename' must not contain path components")
  # Validate `sess` before training_util.global_step may call into it, so an
  # invalid session raises the documented TypeError.
  if not isinstance(sess, session.SessionInterface):
    raise TypeError("'sess' must be a Session; %s" % sess)

  if global_step is not None:
    if not isinstance(global_step, compat.integral_types):
      global_step = training_util.global_step(sess, global_step)
    checkpoint_file = "%s-%d" % (save_path, global_step)
  else:
    checkpoint_file = save_path
  checkpoint_dir = os.path.dirname(save_path)

  model_checkpoint_path = sess.run(
      self._save_tensor_name, {self._filename_tensor_name: checkpoint_file})
  model_checkpoint_path = compat.as_str(model_checkpoint_path)
  self._MaybeDeleteOldCheckpoints(model_checkpoint_path)
  update_checkpoint_state(checkpoint_dir, model_checkpoint_path,
                          self.last_checkpoints, latest_filename)
  return model_checkpoint_path
def _node_def(from_node_def, export_scope, unbound_inputs, clear_devices=False):
  """Create a `NodeDef` proto with export_scope stripped.

  Inputs outside `export_scope` get the unbound-input prefix (and are
  recorded in `unbound_inputs`). Everything else, including the node name and
  in-scope `_class` entries, has the scope stripped.

  Args:
    from_node_def: A `node_def_pb2.NodeDef` protocol buffer.
    export_scope: A `string` representing the name scope to remove.
    unbound_inputs: An array of unbound input names if they exist.
    clear_devices: Boolean which controls whether to clear device information
      from node_def. Default false.

  Returns:
    A `node_def_pb2.NodeDef` protocol buffer.
  """
  result = copy.deepcopy(from_node_def)

  for idx, inp in enumerate(result.input):
    in_scope = (not export_scope or
                result.input[idx].lstrip("^").startswith(export_scope))
    if in_scope:
      result.input[idx] = ops.strip_name_scope(inp, export_scope)
    else:
      # Adds "$unbound_inputs_" prefix to the unbound name so they are easily
      # identifiable.
      result.input[idx] = re.sub(r"([\^]|^)(.*)",
                                 r"\1" + _UNBOUND_INPUT_PREFIX + r"\2",
                                 compat.as_str(inp))
      unbound_inputs.append(result.input[idx])

  result.name = compat.as_bytes(
      ops.strip_name_scope(from_node_def.name, export_scope))

  for attr_name, attr_value in six.iteritems(from_node_def.attr):
    if attr_name != "_class":
      result.attr[attr_name].CopyFrom(attr_value)
      continue
    # Keep only colocation entries whose target (after '@') is in scope.
    kept = []
    for s in attr_value.list.s:
      if (not export_scope or
          compat.as_str(s).split("@")[1].startswith(export_scope)):
        kept.append(compat.as_bytes(ops.strip_name_scope(s, export_scope)))
    result.attr[attr_name].CopyFrom(attr_value_pb2.AttrValue(
        list=attr_value_pb2.AttrValue.ListValue(s=kept)))

  if clear_devices:
    result.device = ""
  return result
def __init__(self, name, graph, operations, inputs, outputs):
  """Initializes an eager defined function.

  Builds a TF_Function from `operations` of `graph`, converts it to a
  FunctionDef to read its signature, and registers it when executing
  eagerly.

  Args:
    name: str, the name for the created function.
    graph: Graph, the graph containing the operations in the function
    operations: list of Operation; the subset of operations in the graph
      which will be in the function
    inputs: the tensors in the graph to be used as inputs to the function
    outputs: the tensors in the graph which will be outputs to the function
  """
  with errors.raise_exception_on_not_ok_status() as status:
    # pylint: disable=protected-access
    c_graph = graph._c_graph
    fn_name = compat.as_str(name)
    op_handles = [o._c_op for o in operations]
    input_handles = [t._as_tf_output() for t in inputs]
    output_handles = [t._as_tf_output() for t in outputs]
    # pylint: enable=protected-access
    fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
        c_graph, fn_name, False, op_handles, input_handles, output_handles,
        [], None, compat.as_str(""), status)

  # TODO(apassos) avoid creating a FunctionDef (specially to grab the
  # signature, but also in general it's nice not to depend on it.
  with c_api_util.tf_buffer() as buffer_:
    with errors.raise_exception_on_not_ok_status() as status:
      pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_, status)
    proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
  function_def = function_pb2.FunctionDef()
  function_def.ParseFromString(compat.as_bytes(proto_data))

  if context.executing_eagerly():
    _register(fn)

  self.definition = function_def
  self.name = function_def.signature.name
  self.signature = function_def.signature
  self.grad_func_name = None
  self.python_grad_func = None
  self._c_func = fn
  self._grad_func = None
def __init__(self, name, graph, operations, inputs, outputs):
  """Initializes an eager defined function.

  Args:
    name: str, the name for the created function.
    graph: Graph, the graph containing the operations in the function
    operations: list of Operation; the subset of operations in the graph
      which will be in the function
    inputs: the tensors in the graph to be used as inputs to the function
    outputs: the tensors in the graph which will be outputs to the function
  """
  # Build the TF_Function; a non-OK status is raised as an exception.
  with errors.raise_exception_on_not_ok_status() as status:
    wrapper_args = (
        graph._c_graph,  # pylint: disable=protected-access
        compat.as_str(name),
        False,
        [o._c_op for o in operations],  # pylint: disable=protected-access
        [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
        [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
        [],
        None,
        compat.as_str(""),
        status,
    )
    fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(*wrapper_args)

  # TODO(apassos) avoid creating a FunctionDef (specially to grab the
  # signature, but also in general it's nice not to depend on it.
  with c_api_util.tf_buffer() as buffer_:
    with errors.raise_exception_on_not_ok_status() as status:
      pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_, status)
    serialized_fdef = pywrap_tensorflow.TF_GetBuffer(buffer_)
  function_def = function_pb2.FunctionDef()
  function_def.ParseFromString(compat.as_bytes(serialized_fdef))

  if context.in_eager_mode():
    _register(fn)

  self.definition = function_def
  self.name = function_def.signature.name
  self.signature = function_def.signature
  self.grad_func_name = None
  self.python_grad_func = None
  self._c_func = fn
  self._grad_func = None
def canonicalize_signatures(signatures):
  """Converts `signatures` into a dictionary of concrete functions.

  Args:
    signatures: None, a single function, or a mapping from signature keys to
      functions. A non-mapping value is stored under
      `signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`.

  Returns:
    A dict mapping each signature key to a concrete function whose outputs
    are normalized with `_normalize_outputs`. Returns `{}` for None.

  Raises:
    ValueError: If `_get_signature` finds no signature function for a value.
  """
  # `collections.Mapping` was a deprecated alias that Python 3.10 removed;
  # the ABC lives in `collections.abc`. Python 2 only has the old alias.
  try:
    from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top
  except ImportError:  # Python 2.
    collections_abc = collections
  if signatures is None:
    return {}
  if not isinstance(signatures, collections_abc.Mapping):
    signatures = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures
    }
  concrete_signatures = {}
  for signature_key, function in signatures.items():
    signature_function = _get_signature(function)
    if signature_function is None:
      raise ValueError((
          "Expected a TensorFlow function to generate a signature for, but "
          "got {}. Only `tf.functions` with an input signature or "
          "concrete functions can be used as a signature."
      ).format(function))

    # Re-wrap the function so that it returns a dictionary of Tensors. This
    # matches the format of 1.x-style signatures.
    # pylint: disable=cell-var-from-loop
    @def_function.function
    def signature_wrapper(**kwargs):
      structured_outputs = signature_function(**kwargs)
      return _normalize_outputs(
          structured_outputs, signature_function.name, signature_key)
    # TODO(b/123902469): Use ConcreteFunction.structured_inputs once their names
    # always match keyword arguments.
    tensor_spec_signature = {}
    for keyword, tensor in zip(
        signature_function._arg_keywords,  # pylint: disable=protected-access
        signature_function.inputs):
      keyword = compat.as_str(keyword)
      tensor_spec_signature[keyword] = tensor_spec.TensorSpec.from_tensor(
          tensor, name=keyword)
    final_concrete = signature_wrapper.get_concrete_function(
        **tensor_spec_signature)
    # pylint: disable=protected-access
    if len(final_concrete._arg_keywords) == 1:
      # If there is only one input to the signature, a very common case, then
      # ordering is unambiguous and we can let people pass a positional
      # argument. Since SignatureDefs are unordered (protobuf "map") multiple
      # arguments means we need to be keyword-only.
      final_concrete._num_positional_args = 1
    else:
      final_concrete._num_positional_args = 0
    # pylint: enable=protected-access
    concrete_signatures[signature_key] = final_concrete
    # pylint: enable=cell-var-from-loop
  return concrete_signatures
def _ReadAndCheckRowsUsingFeatures(self, num_rows):
  """Reads `num_rows` rows as parsed features and checks them against _ROWS.

  Each row's string column must be "s_<id>" when the expected row says so,
  and "s_default" otherwise. Every id in range(num_rows) must be seen. A
  further read must then fail because the queue is closed and empty.
  """
  self.server.handler.num_rows = num_rows

  with self.test_session() as sess:
    feature_configs = {
        "int64_col":
            parsing_ops.FixedLenFeature([1], dtype=dtypes.int64),
        "string_col":
            parsing_ops.FixedLenFeature(
                [1], dtype=dtypes.string, default_value="s_default"),
    }
    server_address = self.server.httpd.server_address
    end_point = "%s:%s" % (server_address[0], server_address[1])
    reader = cloud.BigQueryReader(
        project_id=_PROJECT,
        dataset_id=_DATASET,
        table_id=_TABLE,
        num_partitions=4,
        features=feature_configs,
        timestamp_millis=1,
        test_end_point=end_point)

    key, value = _SetUpQueue(reader)
    features = parsing_ops.parse_example(
        array_ops.reshape(value, [1]), feature_configs)
    fetches = [features["int64_col"], features["string_col"]]

    seen_rows = []
    for _ in range(num_rows):
      int_value, str_value = sess.run(fetches)

      # Parse values returned from the session.
      self.assertEqual(int_value.shape, (1, 1))
      self.assertEqual(str_value.shape, (1, 1))
      row_id = int_value[0][0]
      row_str = str_value[0][0]
      seen_rows.append(row_id)

      # Compare.
      expected_row = _ROWS[row_id]
      self.assertEqual(row_id, expected_row[0])
      expected_str = ("s_%d" % row_id) if expected_row[1] else "s_default"
      self.assertEqual(compat.as_str(row_str), expected_str)

    self.assertItemsEqual(seen_rows, range(num_rows))

    with self.assertRaisesOpError(
        "is closed and has insufficient elements "
        "\\(requested 1, current size 0\\)"):
      sess.run([key, value])
def request_stop(self, ex=None):
  """Request that the threads stop.

  After this is called, calls to `should_stop()` will return `True`.

  Args:
    ex: Optional `Exception`, or Python `exc_info` tuple as returned by
      `sys.exc_info()`.  If this is the first call to `request_stop()` the
      corresponding exception is recorded and re-raised from `join()`.
  """
  # `unicode` only exists on Python 2; fall back to `str` elsewhere instead of
  # failing with NameError while reporting an error.
  try:
    text_type = unicode  # pylint: disable=undefined-variable
  except NameError:
    text_type = str
  with self._lock:
    if not self._stop_event.is_set():
      if ex and self._exc_info_to_raise is None:
        if isinstance(ex, tuple):
          logging.info("Error reported to Coordinator: %s",
                       compat.as_str(text_type(ex[1])))
          self._exc_info_to_raise = ex
        else:
          logging.info("Error reported to Coordinator: %s",
                       compat.as_str(text_type(ex)))
          self._exc_info_to_raise = sys.exc_info()
        # When no exception is being handled, sys.exc_info() is
        # (None, None, None), and there would be nothing real to re-raise.
        # Record a ValueError describing the misuse instead.
        if (len(self._exc_info_to_raise) != 3 or
            not self._exc_info_to_raise[0] or
            not self._exc_info_to_raise[1]):
          try:
            raise ValueError(
                "ex must be a tuple or sys.exc_info must return the current "
                "exception: %s" % (self._exc_info_to_raise,))
          except ValueError:
            self._exc_info_to_raise = sys.exc_info()
      self._stop_event.set()
def _node_def_unbound(from_node_def, export_scope, unbound_inputs,
                      as_unbound_inputs, clear_devices=False):
  """Create a `NodeDef` proto with export_scope stripped given input names
  that are treated as unbound.

  Inputs listed in `as_unbound_inputs` are renamed with `_unbound_name` and
  appended to `unbound_inputs`. All other inputs, the node name and `_class`
  entries have `export_scope` stripped.

  Args:
    from_node_def: A `node_def_pb2.NodeDef` protocol buffer.
    export_scope: A `string` representing the name scope to remove.
    unbound_inputs: An array of unbound input names if they exist.
    as_unbound_inputs: A list of `String`s. Input names that are
      treated as unbound when exporting Operations.
    clear_devices: Boolean which controls whether to clear device information
      from node_def. Default false.

  Returns:
    A `node_def_pb2.NodeDef` protocol buffer.
  """
  result = copy.deepcopy(from_node_def)
  unbound_set = set(as_unbound_inputs)

  for idx, inp in enumerate(result.input):
    if result.input[idx] not in unbound_set:
      result.input[idx] = ops.strip_name_scope(inp, export_scope)
      continue
    # Adds "$unbound_inputs_" prefix to the unbound name so they are easily
    # identifiable.
    result.input[idx] = _unbound_name(inp)
    unbound_inputs.append(result.input[idx])

  result.name = compat.as_bytes(
      ops.strip_name_scope(from_node_def.name, export_scope))

  for attr_name, attr_value in six.iteritems(from_node_def.attr):
    if attr_name == "_class":
      rewritten = [
          compat.as_bytes(_unbound_name(s))
          if compat.as_str(s) in unbound_set
          else compat.as_bytes(ops.strip_name_scope(s, export_scope))
          for s in attr_value.list.s
      ]
      result.attr[attr_name].CopyFrom(
          attr_value_pb2.AttrValue(
              list=attr_value_pb2.AttrValue.ListValue(s=rewritten)))
    else:
      result.attr[attr_name].CopyFrom(attr_value)

  if clear_devices:
    result.device = ""
  return result
def _init_from_proto(self, hparam_def):
  """Creates a new HParams from `HParamDef` protocol buffer.

  Args:
    hparam_def: `HParamDef` protocol buffer.

  Raises:
    AssertionError: If `hparam_def` is not an `HParamDef`.
  """
  if not isinstance(hparam_def, hparam_pb2.HParamDef):
    raise AssertionError('Wrong "hparam_def" type')
  for name, value in hparam_def.hparam.items():
    kind = value.WhichOneof('kind')
    # int64 -> 'int' and bytes -> 'str' (UTF-8 assumed) so the attribute type
    # is compatible with both Python2 and Python3; other kinds pass through.
    if kind.startswith('int64'):
      convert = int
    elif kind.startswith('bytes'):
      convert = compat.as_str
    else:
      convert = None

    field = getattr(value, kind)
    if kind.endswith('_value'):
      # Single value.
      parsed = field if convert is None else convert(field)
    else:
      # List of values.
      items = field.value
      parsed = list(items) if convert is None else [convert(v) for v in items]
    self.add_hparam(name, parsed)
def _set_c_attrs(self, attrs): """Sets `attrs` as attributes of self._c_func. Requires that self._c_func is not None. Args: attrs: a dictionary from attribute name to attribute proto value """ for name, attr_value in attrs.items(): serialized = attr_value.SerializeToString() # TODO(skyewm): this creates and deletes a new TF_Status for every attr. # It might be worth creating a convenient way to re-use the same status. c_api.TF_FunctionSetAttrValueProto(self._c_func.func, compat.as_str(name), serialized)
def _validate_namespace_whitelist(namespace_whitelist): """Validates namespace whitelist argument.""" if namespace_whitelist is None: return [] if not isinstance(namespace_whitelist, list): raise TypeError("Namespace whitelist must be a list of strings.") processed = [] for namespace in namespace_whitelist: if not isinstance(namespace, six.string_types): raise ValueError("Whitelisted namespace must be a string. Got: {} of type" " {}.".format(namespace, type(namespace))) processed.append(compat.as_str(namespace)) return processed
def _ReadAndCheckRowsUsingFeatures(self, num_rows):
  """Checks that `num_rows` rows parse into the expected features.

  Every id in range(num_rows) must come back once with the string column
  matching _ROWS. Afterwards, reading from the exhausted queue must fail.
  """
  self.server.handler.num_rows = num_rows

  with self.test_session() as sess:
    int_feature = parsing_ops.FixedLenFeature([1], dtype=dtypes.int64)
    str_feature = parsing_ops.FixedLenFeature(
        [1], dtype=dtypes.string, default_value="s_default")
    feature_configs = {"int64_col": int_feature, "string_col": str_feature}
    httpd_address = self.server.httpd.server_address
    reader = cloud.BigQueryReader(
        project_id=_PROJECT,
        dataset_id=_DATASET,
        table_id=_TABLE,
        num_partitions=4,
        features=feature_configs,
        timestamp_millis=1,
        test_end_point=("%s:%s" % (httpd_address[0], httpd_address[1])))

    key, value = _SetUpQueue(reader)
    seen_rows = []
    parsed = parsing_ops.parse_example(
        array_ops.reshape(value, [1]), feature_configs)

    for _ in range(num_rows):
      ints, strs = sess.run([parsed["int64_col"], parsed["string_col"]])
      # Each fetch must be a single (1, 1) value.
      self.assertEqual(ints.shape, (1, 1))
      self.assertEqual(strs.shape, (1, 1))
      row_id, row_str = ints[0][0], strs[0][0]
      seen_rows.append(row_id)

      expected_row = _ROWS[row_id]
      self.assertEqual(row_id, expected_row[0])
      if expected_row[1]:
        expected_str = "s_%d" % row_id
      else:
        expected_str = "s_default"
      self.assertEqual(compat.as_str(row_str), expected_str)

    self.assertItemsEqual(seen_rows, range(num_rows))

    with self.assertRaisesOpError("is closed and has insufficient elements "
                                  "\\(requested 1, current size 0\\)"):
      sess.run([key, value])
def initialize_tpu_system(cluster_resolver=None):
  """Initialize the TPU devices.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.
  Returns:
    The tf.tpu.Topology object for the topology of the TPU cluster.
  """
  resolver = (cluster_resolver if cluster_resolver is not None
              else TPUClusterResolver(""))
  master = resolver.master()

  logging.info("Initializing the TPU system.")

  if not context.executing_eagerly():
    # Graph mode: run the initialization op in a fresh graph and session.
    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        serialized_topology = sess.run(tpu.initialize_system())
  else:
    # tpu.initialize_system creates a dummy op whose sole purpose is to
    # trigger DistributedTPURewritePass, which adds the real initialization
    # ops. It therefore cannot simply be run eagerly. It is wrapped in a defun
    # and run through TPUPartitionedCallOp so the rewrite passes fire.
    @function.defun
    def _tpu_init_fn():
      return tpu.initialize_system()

    # _tpu_init_fn is not called directly (it only contains the dummy op).
    # Defining it adds it to the eager context and gives it a name to call.
    # pylint: disable=protected-access
    graph_func = _tpu_init_fn._get_concrete_function_internal()
    func_name = compat.as_str(graph_func._inference_function.name)
    # pylint: enable=protected-access

    output = tpu_functional_ops.TPUPartitionedCall(
        args=[], device_ordinal=0, Tout=[dtypes.string], f=func_name)
    serialized_topology = output[0].numpy()

  logging.info("Finished initializing TPU system.")
  return topology.Topology(serialized=serialized_topology)
def lookup(self, name):
  """Returns the object registered under `name`.

  Args:
    name: The registry key; normalized to `str` with `compat.as_str`.

  Returns:
    The `_TYPE_TAG` entry stored for `name`.

  Raises:
    LookupError: If `name` has not been registered.
  """
  key = compat.as_str(name)
  if key not in self._registry:
    raise LookupError("%s registry has no entry for: %s" % (self._name, key))
  return self._registry[key][_TYPE_TAG]
def lookup(self, name):
  """Fetches the candidate registered under `name`.

  Args:
    name: A string (or bytes) registry key; it is passed through
      `compat.as_str` before the lookup.

  Returns:
    The registered object, i.e. the `_TYPE_TAG` field of the entry.

  Raises:
    LookupError: If no entry exists for `name`.
  """
  key = compat.as_str(name)
  registry = self._registry
  if key in registry:
    return registry[key][_TYPE_TAG]
  raise LookupError(
      "%s registry has no entry for: %s" % (self._name, key))
def _node_def(from_node_def, export_scope, unbound_inputs):
  """Returns a copy of `from_node_def` with `export_scope` stripped.

  An input is treated as unbound when `export_scope` is set and the input
  (ignoring leading "^" characters) does not start with it. Each unbound
  input gets a "$unbound_inputs_" prefix and is appended to `unbound_inputs`.
  Every other input, the node name, and the `_class` colocation entries have
  the scope stripped.

  Args:
    from_node_def: A `node_def_pb2.NodeDef` protocol buffer.
    export_scope: A `string` representing the name scope to remove.
    unbound_inputs: A list that receives the renamed unbound input names.

  Returns:
    A new `node_def_pb2.NodeDef` protocol buffer.
  """
  node_def = copy.deepcopy(from_node_def)

  for idx, input_name in enumerate(node_def.input):
    outside_scope = bool(export_scope) and not input_name.lstrip(
        "^").startswith(export_scope)
    if not outside_scope:
      node_def.input[idx] = ops.strip_name_scope(input_name, export_scope)
      continue
    # Insert "$unbound_inputs_" after an optional leading "^" so unbound
    # names are easy to identify later.
    node_def.input[idx] = re.sub(r"([\^]|^)(.*)", r"\1$unbound_inputs_\2",
                                 compat.as_str(input_name))
    unbound_inputs.append(node_def.input[idx])

  node_def.name = compat.as_bytes(
      ops.strip_name_scope(from_node_def.name, export_scope))

  for attr_name, attr_value in six.iteritems(from_node_def.attr):
    if attr_name != "_class":
      node_def.attr[attr_name].CopyFrom(attr_value)
      continue
    kept = []
    for loc in attr_value.list.s:
      # Drop colocation targets (the part after "@") outside export_scope.
      if export_scope and not compat.as_str(loc).split("@")[1].startswith(
          export_scope):
        continue
      kept.append(compat.as_bytes(ops.strip_name_scope(loc, export_scope)))
    node_def.attr[attr_name].CopyFrom(
        attr_value_pb2.AttrValue(
            list=attr_value_pb2.AttrValue.ListValue(s=kept)))

  return node_def
def _GetColocationNames(op): """Returns names of the ops that `op` should be colocated with.""" colocation_names = [] try: class_values = op.get_attr('_class') except ValueError: # No _class attr return for val in class_values: val = compat.as_str(val) if val.startswith('loc:@'): colocation_node_name = val[len('loc:@'):] if colocation_node_name != op.name: colocation_names.append(colocation_node_name) return colocation_names
def canonicalize_signatures(signatures):
  """Converts `signatures` into a dictionary of concrete functions.

  Args:
    signatures: `None`, a single function, or a mapping from signature key to
      function. A single function is stored under
      `DEFAULT_SERVING_SIGNATURE_DEF_KEY`.

  Returns:
    A dict mapping each signature key to a concrete function that returns a
    dictionary of Tensors. Single-input signatures also accept one positional
    argument; all others are keyword-only.

  Raises:
    ValueError: If a value cannot be used to generate a signature.
  """
  # `collections.Mapping` no longer exists on Python 3.10+. The ABC lives in
  # `collections.abc` (available since Python 3.3).
  from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top

  if signatures is None:
    return {}
  if not isinstance(signatures, collections_abc.Mapping):
    signatures = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures}
  concrete_signatures = {}
  for signature_key, function in signatures.items():
    signature_function = _get_signature(function)
    if signature_function is None:
      raise ValueError(
          ("Expected a TensorFlow function to generate a signature for, but "
           "got {}. Only `tf.functions` with an input signature or "
           "concrete functions can be used as a signature.").format(function))

    # Re-wrap the function so that it returns a dictionary of Tensors. This
    # matches the format of 1.x-style signatures. The wrapper closes over the
    # loop variables; get_concrete_function is called on it in this same
    # iteration.
    # pylint: disable=cell-var-from-loop
    @def_function.function
    def signature_wrapper(**kwargs):
      structured_outputs = signature_function(**kwargs)
      return _normalize_outputs(
          structured_outputs, signature_function.name, signature_key)

    # TODO(b/123902469): Use ConcreteFunction.structured_inputs once their
    # names always match keyword arguments.
    tensor_spec_signature = {}
    for keyword, tensor in zip(
        signature_function._arg_keywords,  # pylint: disable=protected-access
        signature_function.inputs):
      keyword = compat.as_str(keyword)
      tensor_spec_signature[keyword] = tensor_spec.TensorSpec.from_tensor(
          tensor, name=keyword)
    final_concrete = signature_wrapper.get_concrete_function(
        **tensor_spec_signature)
    # pylint: disable=protected-access
    if len(final_concrete._arg_keywords) == 1:
      # If there is only one input to the signature, a very common case, then
      # ordering is unambiguous and we can let people pass a positional
      # argument. Since SignatureDefs are unordered (protobuf "map") multiple
      # arguments means we need to be keyword-only.
      final_concrete._num_positional_args = 1
    else:
      final_concrete._num_positional_args = 0
    # pylint: enable=protected-access
    concrete_signatures[signature_key] = final_concrete
    # pylint: enable=cell-var-from-loop
  return concrete_signatures
def _validate_namespace_whitelist(namespace_whitelist): """Validates namespace whitelist argument.""" if namespace_whitelist is None: return None if not isinstance(namespace_whitelist, list): raise TypeError("`namespace_whitelist` must be a list of strings. Got: " f"{namespace_whitelist} with type " f"{type(namespace_whitelist)}.") processed = [] for namespace in namespace_whitelist: if not isinstance(namespace, six.string_types): raise ValueError("Whitelisted namespace must be a string. Got: " f"{namespace} of type {type(namespace)}.") processed.append(compat.as_str(namespace)) return processed
def _get_signature_name_changes(concrete_function): """Checks for user-specified signature input names that are normalized.""" # Map of {user-given name: normalized name} if the names are un-identical. name_changes = {} for signature_input_name, graph_input in zip( concrete_function.function_def.signature.input_arg, concrete_function.graph.inputs): try: user_specified_name = compat.as_str( graph_input.op.get_attr("_user_specified_name")) if signature_input_name.name != user_specified_name: name_changes[user_specified_name] = signature_input_name.name except ValueError: # Signature input does not have a user-specified name. pass return name_changes
def initialize_tpu_system(cluster_resolver=None):
  """Initializes the TPU devices.

  In graph mode this runs in a separate graph and session. In eager mode it
  goes through a TPUPartitionedCall.

  Args:
    cluster_resolver: A tf.contrib.cluster_resolver.TPUClusterResolver, which
      provides information about the TPU cluster. Defaults to
      `TPUClusterResolver("")`.

  Returns:
    The tf.contrib.tpu.Topology object for the topology of the TPU cluster.
  """
  resolver = cluster_resolver
  if resolver is None:
    resolver = TPUClusterResolver("")
  master = resolver.master()

  logging.info("Initializing the TPU system.")

  if not context.executing_eagerly():
    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        serialized_topology = sess.run(tpu.initialize_system())
  else:
    # tpu.initialize_system() emits only a dummy op whose purpose is to
    # trigger DistributedTPURewritePass. That pass adds the real ops that
    # initialize the TPU system, so the op cannot be executed eagerly as is.
    # Wrapping it in a defun and running that function via
    # TPUPartitionedCallOp triggers the rewrite passes.
    @function.defun
    def _tpu_init_fn():
      return tpu.initialize_system()

    # Not called directly (see above). Fetching the concrete function adds it
    # to the eager context and exposes its assigned name.
    # pylint: disable=protected-access
    init_concrete = _tpu_init_fn._get_concrete_function_internal()
    init_name = compat.as_str(init_concrete._inference_function.name)
    # pylint: enable=protected-access
    call_outputs = tpu_functional_ops.TPUPartitionedCall(
        args=[], device_ordinal=0, Tout=[dtypes.string], f=init_name)
    serialized_topology = call_outputs[0].numpy()

  logging.info("Finished initializing TPU system.")
  return topology.Topology(serialized=serialized_topology)
def imperative_grad(tape,
                    target,
                    sources,
                    output_gradients=None,
                    sources_raw=None,
                    unconnected_gradients=UnconnectedGradients.NONE):
  """Computes gradients from the imperatively defined tape on top of the stack.

  Works by filtering the tape, computing how many downstream usages are of each
  tensor and entry, and repeatedly applying backward functions until we have
  gradients for all sources.

  Args:
    tape: the gradient tape which stores the trace.
    target: either a Tensor or list of Tensors to be differentiated.
    sources: list of Tensors for which we want gradients.
    output_gradients: if not None, a list of gradient provided for each Target,
      or None if we are to use the target's computed downstream gradient.
    sources_raw: if not None, a list of the source python objects from which
      the sources were generated. Should have the same length as sources. Only
      needs to be populated if unconnected_gradients is 'zero'.
    unconnected_gradients: determines the value returned if the target and
      sources are unconnected. When 'none' the value returned is None, whereas
      when 'zero' a zero tensor in the same shape as the sources is returned.

  Returns:
    the gradient wrt each of the sources.

  Raises:
    ValueError: if the arguments are invalid.
    RuntimeError: if something goes wrong.
  """
  try:
    mode = UnconnectedGradients(unconnected_gradients)
  except ValueError:
    raise ValueError(
        "Unknown value for unconnected_gradients: %r" % unconnected_gradients)

  return pywrap_tfe.TFE_Py_TapeGradient(
      tape._tape,  # pylint: disable=protected-access
      target,
      sources,
      output_gradients,
      sources_raw,
      compat.as_str(mode.value))
def _revive_metric_from_config(self, metadata, node_id):
  """Tries to rebuild a metric from its saved class name and config.

  Args:
    metadata: Dict providing 'class_name' and, optionally, 'config' and
      'build_input_shape'.
    node_id: Not used by this method.

  Returns:
    The deserialized metric. Returns None when the config fails
    `generic_utils.validate_config`, or when serializing/deserializing it
    raises ValueError.
  """
  class_name = compat.as_str(metadata['class_name'])
  config = metadata.get('config')
  if not generic_utils.validate_config(config):
    return None

  try:
    metric = metrics.deserialize(
        generic_utils.serialize_keras_class_and_config(class_name, config))
  except ValueError:
    return None

  input_shape = metadata.get('build_input_shape')
  if input_shape is not None and hasattr(metric, '_build'):
    metric._build(input_shape)  # pylint: disable=protected-access
  return metric
def get_temp_export_dir(timestamped_export_dir):
  """Returns a sibling of `timestamped_export_dir` named 'temp-<basename>'.

  TensorFlow Serving ignores subdirectories of the base directory whose names
  cannot be parsed as integers, so a 'temp-' prefixed directory is invisible
  to it.

  Args:
    timestamped_export_dir: The eventual export directory, e.g.
      /foo/bar/<timestamp>.

  Returns:
    The sister directory /foo/bar/temp-<timestamp>, as bytes.
  """
  parent, leaf = os.path.split(timestamped_export_dir)
  return os.path.join(
      compat.as_bytes(parent),
      compat.as_bytes('temp-' + compat.as_str(leaf)))
def get_matching_files(filename):
  """Returns a list of files that match the given pattern.

  Args:
    filename: string, the pattern.

  Returns:
    A list of `str` filenames that match the given pattern.

  Raises:
    errors.OpError: If there are filesystem / directory listing errors.
  """
  with errors.raise_exception_on_not_ok_status() as status:
    matches = pywrap_tensorflow.GetMatchingFiles(
        compat.as_bytes(filename), status)
    # The vector of strings returned should be read as text, not bytes, so
    # decode each element before the status is checked on exit.
    return [compat.as_str(match) for match in matches]
def testReadWrite(self):
  """Writes a file over iptf:// and reads it back via its IPFS/IPNS paths."""
  with self.test_session() as sess:
    contents = "ASDASDASDASDASDAS"
    filename = "iptf://repo/root/foo"
    meta_filename = "iptf://meta/repo/root/foo"

    wf = io_ops.write_file(filename=constant_op.constant(filename),
                           contents=constant_op.constant(contents))
    reader = io_ops.WholeFileReader("test_reader")

    def read_path(path, after_write=True):
      # Enqueue `path` alone, close the queue, and read one (key, value).
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      queue.enqueue_many([[path]]).run()
      queue.close().run()
      if not after_write:
        return sess.run(reader.read(queue))
      # Make the read op depend on the write op so the write runs first.
      with sess.graph.control_dependencies([wf]):
        return sess.run(reader.read(queue))

    def check_contents(path):
      key, value = read_path(path)
      self.assertEqual(key, compat.as_bytes(path))
      self.assertEqual(value, compat.as_bytes(contents))

    check_contents(filename)

    # Reading the meta path yields a JSON document about the stored file.
    _, meta = read_path(meta_filename, after_write=False)
    meta_info = json.loads(compat.as_str(meta))

    check_contents(meta_info["IpfsPath"])

    with gfile.Open(meta_filename, "wb") as f:
      f.write(compat.as_bytes('{"command": "publish"}'))

    check_contents(meta_info["IpnsPath"])
def canonicalize_signatures(signatures):
  """Converts `signatures` into a dictionary of concrete functions.

  Args:
    signatures: `None`, a single function, or a mapping from signature key to
      function. A single function is stored under
      `DEFAULT_SERVING_SIGNATURE_DEF_KEY`.

  Returns:
    A dict mapping each signature key to a concrete function that takes
    keyword arguments and returns a dictionary of Tensors.

  Raises:
    ValueError: If a value cannot be used to generate a signature.
  """
  # `collections.Mapping` no longer exists on Python 3.10+. The ABC lives in
  # `collections.abc` (available since Python 3.3).
  from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top

  if signatures is None:
    return {}
  if not isinstance(signatures, collections_abc.Mapping):
    signatures = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures
    }
  concrete_signatures = {}
  for signature_key, function in signatures.items():
    signature_function = _get_signature(function)
    if signature_function is None:
      raise ValueError((
          "Expected a TensorFlow function to generate a signature for, but "
          "got {}. Only `tf.functions` with an input signature or "
          "concrete functions can be used as a signature."
      ).format(function))

    # Re-wrap the function so that it only takes keyword arguments and it
    # returns a dictionary of Tensors. This matches the format of 1.x-style
    # signatures. The wrapper closes over the loop variables;
    # get_concrete_function is called on it in this same iteration.
    # pylint: disable=cell-var-from-loop
    @def_function.function
    def signature_wrapper(**kwargs):
      structured_outputs = signature_function(**kwargs)
      return _normalize_outputs(structured_outputs, signature_function.name,
                                signature_key)

    # TODO(b/123902469): Use ConcreteFunction.structured_inputs once their
    # names always match keyword arguments.
    tensor_spec_signature = {}
    for keyword, tensor in zip(
        signature_function._arg_keywords,  # pylint: disable=protected-access
        signature_function.inputs):
      keyword = compat.as_str(keyword)
      tensor_spec_signature[keyword] = tensor_spec.TensorSpec.from_tensor(
          tensor, name=keyword)
    concrete_signatures[signature_key] = (
        signature_wrapper.get_concrete_function(**tensor_spec_signature))
    # pylint: enable=cell-var-from-loop
  return concrete_signatures
def imperative_grad(
    tape,
    target,
    sources,
    output_gradients=None,
    unconnected_gradients=gradients_impl.UnconnectedGradients.NONE):
  """Computes gradients from the imperatively defined tape on top of the stack.

  Works by filtering the tape, computing how many downstream usages are of each
  tensor and entry, and repeatedly applying backward functions until we have
  gradients for all sources.

  Args:
    tape: the gradient tape which stores the trace.
    target: either a Tensor or list of Tensors to be differentiated.
    sources: list of Tensors for which we want gradients.
    output_gradients: if not None, a list of gradient provided for each Target,
      or None if we are to use the target's computed downstream gradient.
    unconnected_gradients: determines the value returned if the target and
      sources are unconnected. When 'none' the value returned is None, whereas
      when 'zero' a zero tensor in the same shape as the sources is returned.

  Returns:
    the gradient wrt each of the sources.

  Raises:
    ValueError: if the arguments are invalid.
    RuntimeError: if something goes wrong.
  """
  try:
    mode = gradients_impl.UnconnectedGradients(unconnected_gradients)
  except ValueError:
    raise ValueError(
        "Unknown value for unconnected_gradients: %r" % unconnected_gradients)

  return pywrap_tensorflow.TFE_Py_TapeGradient(
      tape._tape,  # pylint: disable=protected-access
      target,
      sources,
      output_gradients,
      compat.as_str(mode.value))
def get_matching_files(filename):
  """Lists the files matching a glob-style pattern.

  Args:
    filename: string, the pattern.

  Returns:
    A list of `str` filenames that match the given pattern.

  Raises:
    errors.OpError: If there are filesystem / directory listing errors.
  """
  with errors.raise_exception_on_not_ok_status() as status:
    raw_matches = pywrap_tensorflow.GetMatchingFiles(
        compat.as_bytes(filename), status)
    # Decode every returned element so callers receive text, not bytes.
    return list(map(compat.as_str, raw_matches))
def _recreate_base_user_object(self, proto):
  """Recreates a Keras object from a SavedModel user-object proto.

  For a known Keras identifier, it builds a class named after the saved
  'class_name', deriving from the matching Revived* class and Keras base
  class (with the base class's `__setattr__`). The object is initialized from
  the proto's JSON metadata. Any other identifier is delegated to the parent
  loader.

  Returns:
    A (revived object, setter) tuple for Keras identifiers; otherwise the
    parent implementation's result.
  """
  revived_classes = {
      '_tf_keras_layer': (RevivedLayer, base_layer.Layer),
      '_tf_keras_network': (RevivedNetwork, network_lib.Network),
      '_tf_keras_model': (RevivedModel, training_lib.Model),
      '_tf_keras_sequential': (RevivedSequential, models_lib.Sequential)
  }
  parent_classes = revived_classes.get(proto.identifier)
  if parent_classes is None:
    return super(KerasObjectLoader, self)._recreate_base_user_object(proto)

  metadata = json.loads(proto.metadata)
  keras_base = parent_classes[1]
  revived_cls = type(compat.as_str(metadata['class_name']), parent_classes,
                     {'__setattr__': keras_base.__setattr__})
  # pylint: disable=protected-access
  revived_obj = revived_cls._init_from_metadata(metadata)
  return revived_obj, revived_cls._revive_setter
  # pylint: enable=protected-access
def _clean_save_and_restore(graph_def, op, removed_op_names):
  """Clean the specified save and restore op.

  Updates the dtypes attribute of the save / restore op and the associated name
  and shape tensors to remove entries for variables that have been removed.

  Args:
    graph_def: A GraphDef proto to be transformed.
    op: The save or restore op to update.
    removed_op_names: List of op names that have been removed.
  """
  name_op = _find_op(graph_def, op.name + '/tensor_names')
  shape_op = _find_op(graph_def, op.name + '/shape_and_slices')
  names_tensor = name_op.attr['value'].tensor
  shapes_tensor = shape_op.attr['value'].tensor

  # Keep the parallel (name, shape_and_slice, dtype) entries of every tensor
  # whose name is not reported removed by _is_removed.
  kept_names = []
  kept_shapes = []
  kept_dtypes = []  # Named to avoid shadowing the `dtypes` module.
  for idx, tensor_name in enumerate(names_tensor.string_val):
    if _is_removed(compat.as_str(tensor_name), removed_op_names):
      continue
    kept_names.append(tensor_name)
    kept_shapes.append(shapes_tensor.string_val[idx])
    kept_dtypes.append(op.attr['dtypes'].list.type[idx])

  names_tensor.string_val[:] = kept_names
  names_tensor.tensor_shape.dim[0].size = len(kept_names)
  shapes_tensor.string_val[:] = kept_shapes
  shapes_tensor.tensor_shape.dim[0].size = len(kept_shapes)
  op.attr['dtypes'].list.type[:] = kept_dtypes

  def _set_output_length(const_op, length):
    # If no output shape is recorded yet, add one with a single dim; then set
    # the first dim of the first shape to `length`.
    shapes = const_op.attr['_output_shapes'].list.shape
    if not shapes:
      shapes.add()
      shapes[0].dim.add()
    shapes[0].dim[0].size = length

  _set_output_length(name_op, len(kept_names))
  _set_output_length(shape_op, len(kept_shapes))
def revive_custom_object(identifier, metadata):
  """Revives object from SavedModel.

  Args:
    identifier: The saved object's identifier, e.g. '_tf_keras_layer'.
    metadata: Decoded metadata dict; 'class_name' names the revived class.

  Returns:
    An object initialized by `_init_from_metadata` on a class named after
    'class_name' and derived from the matching Revived*/Keras classes. Returns
    None when `identifier` is not one of the known Keras identifiers.
  """
  model_class = (training_lib.Model
                 if ops.executing_eagerly_outside_functions()
                 else training_lib_v1.Model)

  revived_classes = {
      '_tf_keras_layer': (RevivedLayer, base_layer.Layer),
      '_tf_keras_input_layer': (RevivedInputLayer, input_layer.InputLayer),
      '_tf_keras_network': (RevivedNetwork, network_lib.Network),
      '_tf_keras_model': (RevivedNetwork, model_class),
      '_tf_keras_sequential': (RevivedNetwork, models_lib.Sequential)
  }
  bases = revived_classes.get(identifier)
  if bases is None:
    return None

  revived_cls = type(compat.as_str(metadata['class_name']), bases, {})
  return revived_cls._init_from_metadata(metadata)  # pylint: disable=protected-access
def _revive_graph_network(self, identifier, metadata, node_id):
  """Revives a graph network from config.

  Only functional/Sequential models with a valid config, whose class is not a
  registered custom object, are handled here. Each one is created as a blank
  model, and its reconstruction is deferred until its layers are revived.

  Returns:
    The blank model, or None if this path does not apply.
  """
  config = metadata.get('config')
  if not generic_utils.validate_config(config):
    return None

  class_name = compat.as_str(metadata['class_name'])
  # A class registered as a custom object is not revived here.
  if generic_utils.get_registered_object(class_name) is not None:
    return None

  # Determine whether the metadata describes a functional or Sequential model.
  is_graph_network = metadata.get('is_graph_network', False)
  if not (is_graph_network or class_name in ('Sequential', 'Functional')):
    return None

  # Revive functional and sequential models as blank model objects for now
  # (must be initialized to enable setattr tracking and attribute caching).
  # Reconstruction of the network is deferred until all of the model's layers
  # have been revived.
  if class_name == 'Sequential':
    model = models_lib.Sequential(name=config['name'])
  elif identifier == constants.SEQUENTIAL_IDENTIFIER:
    # A custom Sequential model: use the custom class name, since the config
    # does not have one.
    model = models_lib.Sequential(name=class_name)
  else:
    model = models_lib.Functional(
        inputs=[], outputs=[], name=config['name'])

  # Record this model and its layers for later reconstruction.
  child_layers = self._get_child_layer_node_ids(node_id)
  self.model_layer_dependencies[node_id] = (model, child_layers)
  if not child_layers:
    self._models_to_reconstruct.append(node_id)
  return model
def _revive_graph_network(self, metadata, node_id):
  """Revives a graph network from config.

  Only functional/Sequential models with a valid config, whose class is not a
  registered custom object, are handled here. Each one is created as a blank
  model, and its reconstruction is deferred until its layers are revived.

  Args:
    metadata: Decoded metadata dict with 'class_name', optional 'config' and
      optional 'is_graph_network'.
    node_id: Id of the object being revived.

  Returns:
    The blank model, or None if the model should be revived another way.
  """
  # Normalize once and use the normalized name everywhere. Previously the
  # functional/Sequential check compared the raw metadata value while the
  # registry lookup and the Sequential branch used the normalized one.
  class_name = compat.as_str(metadata['class_name'])
  config = metadata.get('config')

  if not generic_utils.validate_config(config):
    return None
  # Determine whether the metadata contains information for reviving a
  # functional or Sequential model.
  model_is_functional_or_sequential = (
      metadata.get('is_graph_network', False) or
      class_name == 'Sequential' or
      class_name == 'Functional')
  if not model_is_functional_or_sequential:
    return None
  if generic_utils.get_registered_object(class_name) is not None:
    # Model should not be revived as a graph network. Try reviving directly
    # from config or as a custom model.
    return None

  # Revive functional and sequential models as blank model objects for now
  # (must be initialized to enable setattr tracking and attribute caching).
  # Reconstruction of the network is deferred until all of the model's layers
  # have been revived.
  if class_name == 'Sequential':
    model = models_lib.Sequential(name=config['name'])
  else:
    model = models_lib.Functional(inputs=[], outputs=[], name=config['name'])

  # Record this model and its layers. This will later be used to reconstruct
  # the model.
  layers = self._get_child_layer_node_ids(node_id, model.name)
  self.model_layer_dependencies[node_id] = (model, layers)
  if not layers:
    self._models_to_reconstruct.append(node_id)
  return model
def get_timestamped_export_dir(export_dir_base):
  """Builds a path to a new subdirectory within the base directory.

  Each export is written into a new subdirectory named using the
  current time. This guarantees monotonically increasing version
  numbers even across multiple runs of the pipeline.
  The timestamp used is the number of seconds since epoch UTC.

  Args:
    export_dir_base: A string containing a directory to write the exported
        graph and checkpoints.

  Returns:
    The full path of the new subdirectory (which is not actually created yet),
    as bytes.

  Raises:
    RuntimeError: if repeated attempts fail to obtain a unique timestamped
      directory name.
  """
  attempts = 0
  while attempts < MAX_DIRECTORY_CREATION_ATTEMPTS:
    timestamp = int(time.time())

    result_dir = os.path.join(
        compat.as_bytes(export_dir_base), compat.as_bytes(str(timestamp)))
    if not gfile.Exists(result_dir):
      # Collisions are still possible (though extremely unlikely): this
      # directory is not actually created yet, but it will be almost
      # instantly on return from this function.
      return result_dir
    # Wait so the next attempt sees a different whole-second timestamp.
    time.sleep(1)
    attempts += 1
    # `logging.warn` is a deprecated alias of `warning`. Lazy %-style args
    # leave the formatting to the logging framework.
    logging.warning('Directory %s already exists; retrying (attempt %s/%s)',
                    compat.as_str(result_dir), attempts,
                    MAX_DIRECTORY_CREATION_ATTEMPTS)
  raise RuntimeError('Failed to obtain a unique export directory name after '
                     '{} attempts.'.format(MAX_DIRECTORY_CREATION_ATTEMPTS))
def _recreate_base_user_object(self, proto):
  """Recreates a Keras object from a SavedModel user-object proto.

  For a known Keras identifier, it builds a class named after the saved
  'class_name', deriving from the matching Revived* class and Keras base
  class, and initializes it from the proto's JSON metadata. The Keras model
  base is `training_lib.Model` when executing eagerly outside functions, and
  `training_lib_v1.Model` otherwise. Other identifiers go to the parent
  loader.
  """
  if ops.executing_eagerly_outside_functions():
    model_class = training_lib.Model
  else:
    model_class = training_lib_v1.Model

  revived_classes = {
      '_tf_keras_layer': (RevivedLayer, base_layer.Layer),
      '_tf_keras_input_layer': (RevivedInputLayer, input_layer.InputLayer),
      '_tf_keras_network': (RevivedNetwork, network_lib.Network),
      '_tf_keras_model': (RevivedNetwork, model_class),
      '_tf_keras_sequential': (RevivedNetwork, models_lib.Sequential)
  }
  bases = revived_classes.get(proto.identifier)
  if bases is None:
    return super(KerasObjectLoader, self)._recreate_base_user_object(proto)

  metadata = json.loads(proto.metadata)
  revived_cls = type(compat.as_str(metadata['class_name']), bases, {})
  return revived_cls._init_from_metadata(metadata)  # pylint: disable=protected-access
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer, it
  constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates the desired collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.

  Returns:
    A dictionary of all the `Variables` imported into the name scope.

  Raises:
    ValueError: If the graph_def contains unbound inputs that `input_map` does
      not match exactly, or if eager execution is enabled.
  """
  if context.executing_eagerly():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "eager execution is enabled.")
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        field = getattr(col_def, kind)
        # Collection values may be bytes while `input_map` keys are strings.
        # Normalize them before any comparison with `input_map`; otherwise
        # every value looks "missing" and the error reports inputs the caller
        # did provide.
        unbound_names = [compat.as_str(v) for v in field.value]
        if field.value and (
            not input_map or sorted(unbound_names) != sorted(input_map)):
          missing = [name for name in unbound_names
                     if not input_map or name not in input_map]
          # If every unbound input is mapped (the mismatch comes from extra
          # input_map keys), list all unbound inputs, as before.
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join(missing or unbound_names))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""

    scope_to_prepend_to_names = graph.unique_name(
        import_scope or "", mark_as_used=False)

    importer.import_graph_def(
        input_graph_def,
        name=(import_scope or scope_to_prepend_to_names),
        input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections.
    variable_objects = {}
    for key, col_def in sorted(meta_graph_def.collection_def.items()):
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue
      if not restore_collections_predicate(key):
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
          # The same serialized variable may appear in several collections;
          # deserialize it once and reuse the object.
          for value in col_def.bytes_list.value:
            variable = variable_objects.get(value, None)
            if variable is None:
              proto = proto_type()
              proto.ParseFromString(value)
              variable = from_proto(
                  proto, import_scope=scope_to_prepend_to_names)
              variable_objects[value] = variable
            graph.add_to_collection(key, variable)
        else:
          for value in col_def.bytes_list.value:
            proto = proto_type()
            proto.ParseFromString(value)
            graph.add_to_collection(
                key, from_proto(
                    proto, import_scope=scope_to_prepend_to_names))
      else:
        field = getattr(col_def, kind)
        if key in _COMPAT_COLLECTION_LIST:
          logging.warning(
              "The saved meta_graph is possibly from an older release:\n"
              "'%s' collection should be of type 'byte_list', but instead "
              "is of type '%s'.", key, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, scope_to_prepend_to_names))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))

    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=scope_to_prepend_to_names)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v

  return var_list
def meta_graph_transform(
    base_meta_graph_def, input_names, output_names, transforms, tags,
    checkpoint_path=None):
  """Runs the Graph Transform tool over a MetaGraphDef.

  The graph held by `base_meta_graph_def` is rewritten by `transforms` and
  wrapped in a brand-new MetaGraphDef. The meta-info is copied over with its
  tags replaced by `tags`. The saver (only when the graph was not frozen),
  the collections and the signature defs are copied with references to ops
  removed by the transforms pruned away.

  Args:
    base_meta_graph_def: The MetaGraphDef protocol buffer to transform.
    input_names: Names of input nodes.
    output_names: Names of output nodes.
    transforms: Ordered list of strings naming the graph transforms to apply.
      These are exactly the transform names supported by the Graph Transform
      Tool, plus the 'freeze_graph' transform.
    tags: Tags to put on the resulting MetaGraphDef; they replace whatever
      tags were copied from `base_meta_graph_def`.
    checkpoint_path: Optional path of a checkpoint to restore during
      freezing, if needed (default None).

  Returns:
    A new, transformed MetaGraphDef protocol buffer.
  """
  result = _meta_graph_pb2.MetaGraphDef()

  # Names of ops the transforms are required to keep.
  must_retain = _find_all_mandatory_retain_ops(base_meta_graph_def)
  new_graph_def, new_initializer_names = _do_transforms(
      base_meta_graph_def.graph_def,
      input_names,
      output_names,
      must_retain,
      transforms,
      base_meta_graph_def.saver_def,
      checkpoint_path)

  result.graph_def.CopyFrom(new_graph_def)
  result.meta_info_def.CopyFrom(base_meta_graph_def.meta_info_def)
  result.meta_info_def.ClearField('tags')
  for tag in tags:
    result.meta_info_def.tags.append(tag)

  # Any op name present before the transforms but absent afterwards was pruned.
  names_before = {
      compat.as_str(node.name) for node in base_meta_graph_def.graph_def.node
  }
  names_after = {compat.as_str(node.name) for node in result.graph_def.node}
  pruned = names_before - names_after

  # The saver is copied (minus pruned nodes) only when the graph was not frozen.
  # TODO(b/63447631): Revisit this once the problem is addressed. Currently
  # _add_pruned_saver assumes that the save and restore nodes have not been
  # removed but freeze_graph (correctly) removes them.
  if _FREEZE_GRAPH_TRANSFORM not in transforms:
    _add_pruned_saver(base_meta_graph_def, result, pruned)

  # Collections, minus references to pruned nodes.
  for collection_name in base_meta_graph_def.collection_def:
    _add_pruned_collection(
        base_meta_graph_def, result, collection_name, pruned)

  # Record initializers that were introduced by the transforms.
  _add_new_inits_to_collection(result, new_initializer_names)

  # Signature defs, minus references to pruned nodes.
  for signature_name in base_meta_graph_def.signature_def:
    _add_pruned_signature(base_meta_graph_def, result, signature_name, pruned)

  return result
def __init__(self,
             tpu=None,
             zone=None,
             project=None,
             job_name='worker',
             coordinator_name=None,
             coordinator_address=None,
             credentials='default',
             service=None,
             discovery_url=None):
  """Creates a new TPUClusterResolver object.

  The ClusterResolver will then use the parameters to query the Cloud TPU APIs
  for the IP addresses and ports of each Cloud TPU listed.

  Args:
    tpu: Either a string, or a list of strings corresponding to the TPUs to
      use. A list must contain exactly one element. If None, the name is
      looked up via the GKE endpoint (when running in GKE) or an environment
      variable fallback. If the single string is the empty string, the string
      'local', or a string that begins with 'grpc://' or '/bns', then it is
      assumed to not correspond with a Cloud TPU and will instead be passed as
      the session master and no ClusterSpec propagation will be done.
    zone: Zone where the TPUs are located. If omitted or empty, we will assume
      that the zone of the TPU is the same as the zone of the GCE VM, which we
      will try to discover from the GCE metadata service.
    project: Name of the GCP project containing Cloud TPUs. If omitted or
      empty, we will try to discover the project name of the GCE VM from the
      GCE metadata service.
    job_name: Name of the TensorFlow job the TPUs belong to.
    coordinator_name: The name to use for the coordinator. Set to None if the
      coordinator should not be included in the computed ClusterSpec.
    coordinator_address: The address of the coordinator (typically an ip:port
      pair). If None while `coordinator_name` is set, a local TF server is
      started, but only when metadata resolution is enabled or the process
      runs in GKE; otherwise the value (possibly None) is stored as-is. If
      coordinator_name is None, a TF server will not be started even if
      coordinator_address is None.
    credentials: GCE Credentials. If left as the string 'default' and
      metadata resolution is enabled, application-default credentials are
      fetched via oauth2client's `GoogleCredentials.get_application_default()`
      (only when googleapiclient/oauth2client are installed). Any other value
      is passed unchanged to the discovery client.
    service: The GCE API object returned by the googleapiclient.discovery
      function. If you specify a custom service object, then the credentials
      parameter will be ignored.
    discovery_url: A URL template that points to the location of
      the discovery service. It should have two parameters {api} and
      {apiVersion} that when filled in produce an absolute URL to the
      discovery document for that service. The environment variable
      'TPU_API_DISCOVERY_URL' will override this.

  Raises:
    ImportError: If no `service` is given, metadata resolution is enabled,
      and googleapiclient/oauth2client are not installed.
    ValueError: If an empty list of TPUs is given, or if no TPU name is given
      and none can be found via the GKE / environment-variable fallbacks.
    NotImplementedError: If a list of more than one TPU is given.
  """
  # Only a single-element list is supported; unwrap it to a plain name.
  if isinstance(tpu, list):
    if not tpu:
      raise ValueError('At least one TPU must be specified.')
    if len(tpu) != 1:
      raise NotImplementedError(
          'Using multiple TPUs in a single session is not yet implemented')
    tpu = tpu[0]

  in_gke = self._inGke()
  # When using GKE with Cloud TPUs, the env variable will be set.
  if tpu is None:
    if in_gke:
      tpu = self._gkeEndpoints()
    else:
      tpu = self._envVarFallback()

  if tpu is None:
    raise ValueError('Please provide a TPU Name to connect to.')

  self._tpu = compat.as_bytes(tpu)  # self._tpu is always bytes
  self._job_name = job_name
  self._credentials = credentials

  # Gates every metadata-server / Cloud API lookup below.
  should_resolve = self._shouldResolve()

  if not project and should_resolve:
    project = compat.as_str(
        self._requestComputeMetadata('project/project-id'))

  if not zone and should_resolve:
    # The metadata value is a path; only its last component is the zone name.
    zone_path = compat.as_str(self._requestComputeMetadata('instance/zone'))
    zone = zone_path.split('/')[-1]

  self._project = project
  self._zone = zone

  if credentials == 'default' and should_resolve:
    if _GOOGLE_API_CLIENT_INSTALLED:
      self._credentials = GoogleCredentials.get_application_default()

  if service is None and should_resolve:
    if not _GOOGLE_API_CLIENT_INSTALLED:
      raise ImportError('googleapiclient and oauth2client must be installed '
                        'before using the TPU cluster resolver. Execute: '
                        '`pip install --upgrade google-api-python-client` '
                        'and `pip install --upgrade oauth2client` to '
                        'install with pip.')

    # The environment-provided discovery URL (if any) takes precedence over
    # the `discovery_url` argument.
    final_discovery_url = self._discoveryUrl() or discovery_url
    if final_discovery_url:
      self._service = discovery.build(
          'tpu', 'v1alpha1',
          credentials=self._credentials,
          discoveryServiceUrl=final_discovery_url)
    else:
      self._service = discovery.build(
          'tpu', 'v1alpha1',
          credentials=self._credentials)
  else:
    self._service = service

  self._coordinator_name = coordinator_name
  # NOTE(review): this branch does not assign self._coordinator_address;
  # presumably _start_local_server() sets it -- confirm.
  if coordinator_name and not coordinator_address and (should_resolve or
                                                       in_gke):
    self._start_local_server()
  else:
    self._coordinator_address = coordinator_address
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  The import runs in two passes over `graph_def.node` so that cycles are
  supported:

  1. Every operation is created without inputs.
  2. Data and control inputs are wired up, `_class` colocation attributes
     are rewritten to the imported op names, shape inference runs and
     device functions are applied.

  Note: if `graph_def` is already a `GraphDef` instance, its nodes are
  modified in place. Attrs missing from a node are filled in with their
  `OpDef` defaults and, when `producer_op_list` is given, unknown attrs
  holding the producer's default value are deleted. Any other message is
  first copied into a fresh `GraphDef` via `MergeFrom`.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
      Non-`Tensor` values are converted with `convert_to_tensor` under an
      `_inputs` name scope.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.
      It is copied (not modified) when `graph_def` defines functions.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided, attrs
      for ops in `graph_def` that are not in `op_dict` that have their default
      value according to `producer_op_list` will be removed. This will allow
      some more `GraphDef`s produced by later binaries to be accepted by
      earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`, or None if
    `return_elements` is None.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor or op type, a colocation target that
      does not exist, or inputs whose types do not match the op), or if
      `name` is empty while `input_map` contains non-`Tensor` values.
  """
  # Type checks for inputs.
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach
    try:
      old_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.MergeFrom(old_graph_def)
    except TypeError:
      raise TypeError('graph_def must be a GraphDef proto.')
  if input_map is None:
    input_map = {}
  else:
    if not (isinstance(input_map, dict)
            and all(isinstance(k, compat.bytes_or_text_types)
                    for k in input_map.keys())):
      raise TypeError('input_map must be a dictionary mapping strings to '
                      'Tensor objects.')
  if return_elements is not None:
    # Materialized once so it can be both validated and iterated later.
    return_elements = tuple(return_elements)
    if not all(isinstance(x, compat.bytes_or_text_types)
               for x in return_elements):
      raise TypeError('return_elements must be a list of strings.')

  # Use a canonical representation for all tensor names.
  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
  # Canonical input_map keys actually consumed while wiring inputs; anything
  # left over at the end is reported as an error.
  used_input_keys = set()

  # Original node name in `graph_def` -> the `Operation` created for it.
  name_to_op = {}

  if op_dict is None:
    op_dict = op_def_registry.get_registered_ops()

  if producer_op_list is None:
    producer_op_dict = None
  else:
    producer_op_dict = {op.name: op for op in producer_op_list.op}

  g = ops.get_default_graph()

  # Add any functions defined in `graph_def` to `g`
  if graph_def.library and graph_def.library.function:
    # Copy op_dict so we don't clobber the original
    op_dict = copy.copy(op_dict)
    # pylint: disable=protected-access
    # Note that we do not prepend `name` to the function name. The reasoning is
    # that function names are similar to op definition names, which currently do
    # not have a scoped name or namespace scheme.
    functions = function._from_library(graph_def.library)
    for f in functions:
      g._add_function(f)
      # Lets nodes that call this function be resolved like ordinary ops.
      op_dict[f.name] = f.definition.signature
    # pylint: enable=protected-access

  # LINT.IfChange
  with ops.name_scope(name, 'import', input_map.values()) as scope:
    # TODO(ashankar): Should this just copy over or should it do some
    # more nuanced merging? For example, the graph may already have some
    # marked "bad versions" and we don't want to lose those because of
    # what's in graph_def.versions? The C++ ImporGraphDef does something
    # more nuanced.
    g.graph_def_versions.CopyFrom(graph_def.versions)

    if not all(isinstance(v, ops.Tensor) for v in input_map.values()):
      if not scope:
        # The caller must have passed `name=''`.
        raise ValueError(
            'tf.import_graph_def() requires a non-empty `name` if `input_map` '
            'contains non-Tensor values. Try calling tf.convert_to_tensor() on '
            '`input_map` values before calling tf.import_graph_def().')
      with ops.name_scope('_inputs'):
        input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}

    # NOTE(mrry): We do this in two passes, because there may be a cycle in
    # `graph_def`.

    # 1. Add operations without their inputs.
    for node in graph_def.node:
      # Set any default attr values that aren't present.
      if node.op not in op_dict:
        raise ValueError('No op named %s in defined operations.' % node.op)
      op_def = op_dict[node.op]
      for attr_def in op_def.attr:
        key = attr_def.name
        if attr_def.HasField('default_value'):
          value = node.attr[key]
          # An attr counts as unset when no member of its `value` oneof is set.
          if value is None or value.WhichOneof('value') is None:
            node.attr[key].CopyFrom(attr_def.default_value)
      if producer_op_dict:
        # Remove any default attr values that aren't in op_def.
        if node.op in producer_op_dict:
          producer_op_def = producer_op_dict[node.op]
          # We make a copy of node.attr to iterate through since we
          # may modify node.attr inside the loop.
          for key in list(node.attr):
            if _FindAttrInOpDef(key, op_def) is None:
              # No attr_def in consumer, look in producer.
              attr_def = _FindAttrInOpDef(key, producer_op_def)
              if (attr_def and attr_def.HasField('default_value') and
                  node.attr[key] == attr_def.default_value):
                # Unknown attr had default value in producer, delete it
                # so it can be understood by consumer.
                del node.attr[key]

      output_types = _OutputTypes(node, op_dict)
      # Shape inference and device placement are deferred to pass 2, once
      # inputs are known.
      name_to_op[node.name] = g.create_op(
          node.op, [], output_types, name=node.name, attrs=node.attr,
          compute_shapes=False, compute_device=False,
          op_def=op_def)

    # 2. Add inputs to the operations.
    for node in graph_def.node:
      op = name_to_op[node.name]
      input_types = _InputTypes(node, op_dict)

      # Rewrite the colocation attributes in the graph, since the
      # names of new ops may have changed.
      for key, value in op.node_def.attr.items():
        if key == '_class':
          class_values = value.list
          new_class_values = []
          for class_value in class_values.s:
            if class_value.startswith(b'loc:@'):
              # Strip the 5-byte b'loc:@' prefix to get the target op name.
              op_to_bind_to = class_value[5:].decode()
              # Find the op by its original name.
              if op_to_bind_to not in name_to_op:
                raise ValueError('Specified colocation to an op that '
                                 'does not exist during import: %s in %s' % (
                                     op_to_bind_to, node.name))
              original_op = name_to_op[op_to_bind_to]
              new_class_values.append(compat.as_bytes(
                  'loc:@' + original_op.name))
            else:
              new_class_values.append(class_value)
          value.list.CopyFrom(attr_value_pb2.AttrValue.ListValue(
              s=new_class_values))

      # NOTE(mrry): We cannot use zip here because control inputs do not appear
      # in the list of input_types.
      for i, input_name in enumerate(
          [_CanonicalInputName(x) for x in node.input]):

        if _IsControlInput(input_name):
          # (a) Input is a control input that should be taken from an op
          #     in "graph_def".
          try:
            source_op = name_to_op[input_name[1:]]
          except KeyError:
            raise ValueError(
                _InvalidNodeMessage(
                    node,
                    'Control input %r not found in graph_def.' % (input_name,)))
          # pylint: disable=protected-access
          op._add_control_input(source_op)
          # pylint: enable=protected-access

        else:
          try:
            input_type = input_types[i]
          except IndexError:
            raise ValueError(_InvalidNodeMessage(
                node, 'More inputs specified (%r) than the op expects.'
                % (input_name,)))

          if input_name in input_map:
            # (b) Input should be replaced by a tensor from the caller.
            source_tensor = input_map[input_name]
            used_input_keys.add(input_name)

          else:
            # (c) Input should be taken from an op in `graph_def`.
            operation_name, output_index = _ParseTensorName(input_name)
            try:
              source_op = name_to_op[operation_name]
              source_tensor = list(source_op.values())[output_index]
            except (KeyError, IndexError):
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Input tensor %r not found in graph_def.'
                      % (input_name,)))

          try:
            # pylint: disable=protected-access
            op._add_input(source_tensor, dtype=input_type)
            # pylint: enable=protected-access
          except TypeError as te:
            raise ValueError(_InvalidNodeMessage(
                node, 'Input tensor %r %s' % (input_name, te)))

      # pylint: disable=protected-access
      if op._input_dtypes != input_types:
        raise ValueError(
            _InvalidNodeMessage(
                node,
                'Input types mismatch (expected %r but got %r)'
                % (', '.join(dtypes.as_dtype(x).name for x in input_types),
                   ', '.join(x.name for x in op._input_dtypes))))
      # pylint: enable=protected-access

      if not g._is_function(op.type):  # pylint: disable=protected-access
        # Execute shape inference for this op.
        # NOTE(mrry): If the graph contains a cycle, the full shape information
        # may not be available for this op's inputs.
        ops.set_shapes_for_outputs(op)
      # For nodes with _output_shapes set, set the output shapes.
      if '_output_shapes' in op.node_def.attr:
        for i, output in enumerate(op.outputs):
          dims = op.node_def.attr['_output_shapes'].list.shape[i]
          # Negative dimension sizes denote unknown dimensions.
          output_shape = tensor_shape.TensorShape(
              None if dims.unknown_rank else
              [dim.size if dim.size >= 0 else None for dim in dims.dim])

          try:
            output.set_shape(output_shape)
          except ValueError as e:
            # If the output shape is incompatible with what is inferred
            # by the graph for a very specific whitelist of ops, then we
            # ignore this output shape.  This can happen if there is a
            # bug in the shape function for some operation, and the
            # serialized graph def has the incorrect shape set when
            # running on a newer binary with the fixed shape function.
            # This is an escape hatch that allows us to correct shape
            # functions that are not critical to correct execution but
            # would cause graphs to fail if imported after correcting.
            #
            # This can be removed after 2017/03/08.
            if op.type in ['RandomShuffleQueue', 'PaddingFIFOQueue',
                           'FIFOQueue', 'PriorityQueue', 'QueueSize',
                           'Stack', 'Barrier', 'BarrierReadySize',
                           'BarrierIncompleteSize', 'HashTable',
                           'MutableHashTable',
                           'MutableHashTableOfTensors', 'Mutex',
                           'CuckooTable', 'IndexTable',
                           'WholeFileReader', 'TextLineReader',
                           'FixedLengthRecordReader',
                           'TFRecordReader', 'IdentityReader',
                           'RefSwitch', 'RefEnter', 'RefNextIteration',
                           'RefMerge', 'RefIdentity']:
              pass
            elif op.type in [
                'ConditionalAccumulator', 'SparseConditionalAccumulator',
                'Table'
            ]:
              # This can be removed after 2017/04/24.
              pass
            else:
              raise e

        # The attr has been applied to the outputs; drop it from the op.
        del op.node_def.attr['_output_shapes']

      # Apply device functions for this op.
      # NOTE(mrry): We do this after configuring the inputs, because
      # the result of the device functions may depend on the inputs.
      with _MaybeDevice(node.device):
        g._apply_device_functions(op)  # pylint: disable=protected-access

    # Treat unused input mappings as an error, because they are likely to be
    # due to a typo.
    unused_input_keys = frozenset(input_map.keys()).difference(used_input_keys)
    if unused_input_keys:
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]'
          % ', '.join(unused_input_keys))

    if return_elements is None:
      return None
    else:
      ret = []
      # Names containing ':' are tensors ("op:index"); others are operations.
      for name in return_elements:
        name = compat.as_str(name)
        if ':' in name:
          try:
            operation_name, output_index = _ParseTensorName(name)
            ret.append(name_to_op[operation_name].outputs[output_index])
          except (ValueError, KeyError, IndexError):
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
        else:
          try:
            ret.append(name_to_op[name])
          except KeyError:
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
      return ret
def _CanonicalInputName(input_name):
  """Returns `input_name` as a `str` in canonical form.

  Control-input names (as recognized by `_IsControlInput`) are only converted
  to `str`. Every other name is rebuilt as '<op_name>:<output_index>' from the
  pieces returned by `_ParseTensorName`, so that equivalent spellings of the
  same tensor compare equal.
  """
  canonical = compat.as_str(input_name)
  if not _IsControlInput(canonical):
    op_name, output_index = _ParseTensorName(canonical)
    canonical = '%s:%d' % (op_name, output_index)
  return canonical
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  There are two implementations:

  * If the default graph has a C graph (`graph._c_graph`), the import is
    delegated to the C API (`TF_GraphImportGraphDefWithResults`).
  * Otherwise a Python two-pass importer is used. It first creates every op,
    then wires inputs, rewrites colocation attrs, runs shape inference and
    applies device functions, and finally propagates devices to colocated
    ops.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use. The value passed is ignored:
      the op definitions from the global registry are always used.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.
      Note that this removal may mutate `graph_def` itself (see the TODO in
      the body).

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`, or None if
    `return_elements` is None.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor). On the C API path, import failures
      reported as `InvalidArgumentError` are re-raised as `ValueError`.
  """
  graph_def = _ProcessGraphDefParam(graph_def)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  # The `op_dict` argument is deprecated and is overwritten here
  # unconditionally.
  op_dict = op_def_registry.get_registered_ops()

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

  graph = ops.get_default_graph()

  if graph._c_graph:  # pylint: disable=protected-access
    # ---- C API path ----
    with ops.name_scope(name, 'import', input_map.values()) as scope:
      # Save unique prefix generated by name_scope
      if scope:
        assert scope.endswith('/')
        prefix = scope[:-1]
      else:
        prefix = ''

      # Generate any input map tensors inside name scope
      input_map = _ConvertInputMapValues(name, input_map)

    scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
    options = scoped_options.options
    # Configures the name prefix, name/prefix uniquification, input remappings
    # and the requested return outputs/operations on `options`.
    _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                     return_elements)

    with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
      try:
        with errors.raise_exception_on_not_ok_status() as status:
          results = c_api.TF_GraphImportGraphDefWithResults(
              graph._c_graph, serialized, options, status)  # pylint: disable=protected-access
      except errors.InvalidArgumentError as e:
        # Convert to ValueError for backwards compatibility.
        raise ValueError(str(e))

    # NOTE(review): presumably creates the Python-side state for the ops just
    # added to the C graph -- confirm against _ProcessNewOps.
    _ProcessNewOps(graph)

    # Create _DefinedFunctions for any imported functions.
    #
    # We do this by creating _DefinedFunctions directly from `graph_def`, and
    # adding them to `graph`. Adding an existing function to a TF_Graph is a
    # no-op, so this only has the effect of updating the Python state (usually
    # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
    #
    # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
    # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
    if graph_def.library and graph_def.library.function:
      # pylint: disable=protected-access
      functions = function._from_library(graph_def.library)
      for f in functions:
        f.add_to_graph(graph)
      # pylint: enable=protected-access

    # Treat input mappings that don't appear in the graph as an error, because
    # they are likely to be due to a typo.
    missing_unused_input_keys = (
        c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
            results))
    if missing_unused_input_keys:
      missing_unused_input_keys = [compat.as_str(s)
                                   for s in missing_unused_input_keys]
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]'
          % ', '.join(missing_unused_input_keys))

    if return_elements is None:
      return None
    else:
      return _GatherReturnElements(return_elements, graph, results)

  else:
    # ---- Pure-Python path ----
    g = graph

    # Use a canonical representation for all tensor names.
    input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
    # Canonical input_map keys actually consumed while wiring inputs.
    used_input_keys = set()
    # Original node name in `graph_def` -> the `Operation` created for it.
    name_to_op = {}

    # Add any functions defined in `graph_def` to `g`
    if graph_def.library and graph_def.library.function:
      # Copy op_dict so we don't clobber the original
      op_dict = copy.copy(op_dict)
      # pylint: disable=protected-access
      # Note that we do not prepend `name` to the function name. The reasoning
      # is that function names are similar to op definition names, which
      # currently do not have a scoped name or namespace scheme.
      functions = function._from_library(graph_def.library)
      for f in functions:
        f.add_to_graph(g)
        # Lets nodes that call this function be resolved like ordinary ops.
        op_dict[f.name] = f.definition.signature
      # pylint: enable=protected-access

    # LINT.IfChange
    with ops.name_scope(name, 'import', input_map.values()) as scope:
      # TODO(ashankar): Should this just copy over or should it do some
      # more nuanced merging? For example, the graph may already have some
      # marked "bad versions" and we don't want to lose those because of
      # what's in graph_def.versions? The C++ ImporGraphDef does something
      # more nuanced.
      g.graph_def_versions.CopyFrom(graph_def.versions)

      input_map = _ConvertInputMapValues(name, input_map)

      # NOTE(mrry): We do this in two passes, because there may be a cycle in
      # `graph_def`.

      # 1. Add operations without their inputs.
      for node in graph_def.node:
        # Check to see if this op's name matches a previously seen op
        if node.name in name_to_op:
          raise ValueError('Duplicate name \'%s\' in GraphDef.' % node.name)
        # Set any default attr values that aren't present.
        if node.op not in op_dict:
          raise ValueError('No op named %s in defined operations.' % node.op)
        op_def = op_dict[node.op]
        for attr_def in op_def.attr:
          key = attr_def.name
          if attr_def.HasField('default_value'):
            value = node.attr[key]
            # An attr counts as unset when no member of its `value` oneof is
            # set.
            if value is None or value.WhichOneof('value') is None:
              node.attr[key].CopyFrom(attr_def.default_value)

        output_types = _OutputTypes(node, op_dict)
        # Shape inference and device placement are deferred to pass 2.
        name_to_op[node.name] = g.create_op(
            node.op, [], output_types, name=node.name, attrs=node.attr,
            compute_shapes=False, compute_device=False,
            op_def=op_def)

      # Maps from a node to the ops it is colocated with, if colocation
      # is specified in the attributes.
      colocation_pairs = collections.defaultdict(list)

      # 2. Add inputs to the operations.
      for node in graph_def.node:
        op = name_to_op[node.name]
        input_types = _InputTypes(node, op_dict)
        apply_device_function = True

        # Rewrite the colocation attributes in the graph, since the
        # names of new ops may have changed.
        for key, value in op.node_def.attr.items():
          if key == '_class':
            class_values = value.list
            new_class_values = []
            for class_value in class_values.s:
              if class_value.startswith(b'loc:@'):
                # Strip the 5-byte b'loc:@' prefix to get the target op name.
                op_to_bind_to = class_value[5:].decode()
                # Find the op by its original name.
                if op_to_bind_to not in name_to_op:
                  raise ValueError('Specified colocation to an op that '
                                   'does not exist during import: %s in %s' % (
                                       op_to_bind_to, node.name))
                original_op = name_to_op[op_to_bind_to]
                new_class_values.append(compat.as_bytes(
                    'loc:@' + original_op.name))
                if op_to_bind_to != node.name:
                  # Keep track of this mapping for a later phase.
                  colocation_pairs[op].append(original_op)
                  # Don't apply this op's device function,
                  # the colocation constraint will ensure
                  # the proper device gets assigned at runtime.
                  apply_device_function = False

              else:
                new_class_values.append(class_value)
            value.list.CopyFrom(attr_value_pb2.AttrValue.ListValue(
                s=new_class_values))

        # NOTE(mrry): We cannot use zip here because control inputs do not
        # appear in the list of input_types.
        for i, input_name in enumerate(
            [_CanonicalInputName(x) for x in node.input]):

          if _IsControlInput(input_name):
            # (a) Input is a control input that should be taken from an op
            #     in "graph_def".
            try:
              source_op = name_to_op[input_name[1:]]
            except KeyError:
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Control input %r not found in graph_def.'
                      % (input_name,)))
            # pylint: disable=protected-access
            op._add_control_input(source_op)
            # pylint: enable=protected-access

          else:
            try:
              input_type = input_types[i]
            except IndexError:
              raise ValueError(_InvalidNodeMessage(
                  node, 'More inputs specified (%r) than the op expects.'
                  % (input_name,)))

            if input_name in input_map:
              # (b) Input should be replaced by a tensor from the caller.
              source_tensor = input_map[input_name]
              used_input_keys.add(input_name)

            else:
              # (c) Input should be taken from an op in `graph_def`.
              operation_name, output_index = _ParseTensorName(input_name)
              try:
                source_op = name_to_op[operation_name]
                source_tensor = list(source_op.values())[output_index]
              except (KeyError, IndexError):
                raise ValueError(
                    _InvalidNodeMessage(
                        node,
                        'Input tensor %r not found in graph_def.'
                        % (input_name,)))

            try:
              # pylint: disable=protected-access
              op._add_input(source_tensor, dtype=input_type)
              # pylint: enable=protected-access
            except TypeError as te:
              raise ValueError(_InvalidNodeMessage(
                  node, 'Input tensor %r %s' % (input_name, te)))

        # pylint: disable=protected-access
        if op._input_types != input_types:
          raise ValueError(
              _InvalidNodeMessage(
                  node,
                  'Input types mismatch (expected %r but got %r)'
                  % (', '.join(dtypes.as_dtype(x).name for x in input_types),
                     ', '.join(x.name for x in op._input_types))))
        # pylint: enable=protected-access

        if not g._is_function(op.type):  # pylint: disable=protected-access
          # Execute shape inference for this op.
          # NOTE(mrry): If the graph contains a cycle, the full shape
          # information may not be available for this op's inputs.
          ops.set_shapes_for_outputs(op)
        # For nodes with _output_shapes set, set the output shapes.
        if '_output_shapes' in op.node_def.attr:
          for i, output in enumerate(op.outputs):
            dims = op.node_def.attr['_output_shapes'].list.shape[i]
            # Negative dimension sizes denote unknown dimensions.
            output_shape = tensor_shape.TensorShape(
                None if dims.unknown_rank else
                [dim.size if dim.size >= 0 else None for dim in dims.dim])

            try:
              output.set_shape(output_shape)
            except ValueError as e:
              # If the output shape is incompatible with what is inferred
              # by the graph for a very specific whitelist of ops, then we
              # ignore this output shape.  This can happen if there is a
              # bug in the shape function for some operation, and the
              # serialized graph def has the incorrect shape set when
              # running on a newer binary with the fixed shape function.
              # This is an escape hatch that allows us to correct shape
              # functions that are not critical to correct execution but
              # would cause graphs to fail if imported after correcting.
              #
              # This can be removed after 2017/03/08.
              if op.type in ['RandomShuffleQueue', 'PaddingFIFOQueue',
                             'FIFOQueue', 'PriorityQueue', 'QueueSize',
                             'Stack', 'Barrier', 'BarrierReadySize',
                             'BarrierIncompleteSize', 'HashTable',
                             'MutableHashTable',
                             'MutableHashTableOfTensors', 'Mutex',
                             'CuckooTable', 'IndexTable',
                             'WholeFileReader', 'TextLineReader',
                             'FixedLengthRecordReader',
                             'TFRecordReader', 'IdentityReader',
                             'LMDBReader',
                             'RefSwitch', 'RefEnter', 'RefNextIteration',
                             'RefMerge', 'RefIdentity']:
                pass
              elif op.type in [
                  'ConditionalAccumulator', 'SparseConditionalAccumulator',
                  'Table'
              ]:
                # This can be removed after 2017/04/24.
                pass
              else:
                raise e

          # The attr has been applied to the outputs; drop it from the op.
          del op.node_def.attr['_output_shapes']

        # NOTE(mrry): We do this after configuring the inputs, because
        # the result of the device functions may depend on the inputs.
        if apply_device_function:
          with _MaybeDevice(node.device):
            g._apply_device_functions(op)  # pylint: disable=protected-access

      # The following loop populates the device field of ops that are
      # colocated with another op.  This is implied by the colocation
      # attribute, but we propagate the device field for completeness.
      for op, coloc_op_list in colocation_pairs.items():
        coloc_device = None
        # Find any device in the list of colocated ops that have a
        # device, if it exists.  We assume that if multiple ops
        # have devices, they refer to the same device.  Otherwise, a
        # runtime error will occur since the colocation property
        # cannot be guaranteed.
        #
        # One possible improvement is to try to check for compatibility
        # of all devices in this list at import time here, which would
        # require implementing a compatibility function for device specs
        # in python.
        for coloc_op in coloc_op_list:
          if coloc_op.device:
            coloc_device = pydev.DeviceSpec.from_string(coloc_op.device)
            break
        if coloc_device:
          op._set_device(coloc_device)  # pylint: disable=protected-access

      # Treat input mappings that don't appear in the graph as an error,
      # because they are likely to be due to a typo.
      # A mapping that was never consumed is tolerated as long as it names an
      # existing output of an imported node (e.g. an output nothing reads).
      def _IsImportedNodeOutput(tensor_name):
        operation_name, output_index = _ParseTensorName(tensor_name)
        try:
          return output_index < len(name_to_op[operation_name].outputs)
        except KeyError:
          return False
      absent_input_keys = [
          k for k in frozenset(input_map.keys()).difference(used_input_keys)
          if not _IsImportedNodeOutput(k)]
      if absent_input_keys:
        raise ValueError(
            'Attempted to map inputs that were not found in graph_def: [%s]'
            % ', '.join(absent_input_keys))

      if return_elements is None:
        return None
      else:
        ret = []
        # Names containing ':' are tensors ("op:index"); others are ops.
        for name in return_elements:
          name = compat.as_str(name)
          if ':' in name:
            try:
              operation_name, output_index = _ParseTensorName(name)
              ret.append(name_to_op[operation_name].outputs[output_index])
            except (ValueError, KeyError, IndexError):
              raise ValueError(
                  'Requested return_element %r not found in graph_def.' % name)
          else:
            try:
              ret.append(name_to_op[name])
            except KeyError:
              raise ValueError(
                  'Requested return_element %r not found in graph_def.' % name)
        return ret
def _create_definition_if_needed_impl(self):
  """Builds and caches this function's definition; a no-op once built.

  This is an implementation detail; callers should use
  `_create_definition_if_needed` instead. If either `self._definition` or
  `self._c_func` is already populated, nothing happens.

  Otherwise a temporary graph is produced from `self._func` via
  `func_graph_from_py_func`, and then:
    * when that graph has no C graph, a `FunctionDef` proto is built in
      Python and stored in `self._definition` (and `self._hash_str` is set);
    * otherwise the function is built through the C API and stored in
      `self._c_func`.
  In both cases `self._extra_inputs`, `self._sub_functions`, `self._op_def`,
  `self._func_name` (when not already set) and `self._stateful_ops` are
  filled in.
  """
  already_built = self._definition is not None or self._c_func is not None
  if already_built:
    return

  graph = func_graph_from_py_func(
      self._func, self._arg_names, self._arg_types, self._func_name,
      self._capture_by_value, self._caller_device)

  self._extra_inputs = graph.extra_inputs
  self._sub_functions = graph._functions  # pylint: disable=protected-access

  # The base name seeds both the attr parsing and the final function name.
  # An explicit name is used verbatim; otherwise it is derived from the
  # Python function, suffixed with the gradient function's name if any.
  if self._func_name:
    base_name = self._func_name
  else:
    base_name = _get_func_name(self._func)
    if self._grad_func:
      base_name += ("_%s" % self._grad_func.name)
  # Extra kwargs are treated as attrs on the function def.
  attrs = _parse_kwargs_as_attrs(base_name, **self._extra_kwargs)

  if graph._c_graph:  # pylint: disable=protected-access
    # C API path: the runtime builds the function; we read its signature back.
    c_output_names = [compat.as_bytes(name)
                      for name in (self._out_names or [])]
    doc = self._func.__doc__ or None
    # pylint: disable=protected-access
    input_outputs = [t._as_tf_output() for t in graph.inputs]
    output_outputs = [t._as_tf_output() for t in graph.outputs]
    self._c_func = c_api_util.ScopedTFFunction(
        c_api.TF_GraphToFunction_wrapper(
            graph._c_graph,
            base_name,
            self._func_name is None,  # append_hash_to_fn_name
            None,  # opers
            input_outputs,
            output_outputs,
            c_output_names,
            None,  # opts
            doc))
    # pylint: enable=protected-access
    self._set_c_attrs(attrs)

    # Cache the signature; adopt the runtime-chosen name if none was given.
    self._op_def = self.definition.signature
    if not self._func_name:
      self._func_name = compat.as_str(self._op_def.name)
    else:
      assert self._func_name == self._op_def.name
  else:
    # Pure-Python path: build the FunctionDef proto directly.
    fdef = graph_to_function_def.graph_to_function_def(
        graph,
        graph.get_operations(),
        graph.inputs,
        graph.outputs,
        out_names=self._out_names)
    self._definition = fdef
    for attr_name in attrs:
      fdef.attr[attr_name].CopyFrom(attrs[attr_name])

    # Hash the definition and its dependencies.
    self._hash_str = self._create_hash_str(
        fdef.signature.input_arg, fdef.signature.output_arg, fdef.node_def)

    # Without an explicit name, append the hash so the name is almost
    # certainly unique but still deterministic.
    if not self._func_name:
      self._func_name = "_".join([base_name, self._hash_str])
    fdef.signature.name = self._func_name
    if self._func.__doc__:
      fdef.signature.description = self._func.__doc__

    self._op_def = fdef.signature

  self._stateful_ops = [
      (op.name, op.type)
      for op in graph.get_operations()
      if op.op_def.is_stateful
  ]
def meta_graph_transform(
    base_meta_graph_def, input_names, output_names, transforms, tags,
    checkpoint_path=None):
  """Apply the Graph Transform tool to a MetaGraphDef.

  Args:
    base_meta_graph_def: A MetaGraphDef protocol buffer to transform.
    input_names: Names of input nodes.
    output_names: Names of output nodes.
    transforms: A list of strings naming the graph transforms to be applied in
      order. These transform names are exactly those supported by the Graph
      Transform Tool, with the addition of the 'freeze_graph' and
      'sparsify_gather' transforms.
    tags: A list of tags with which to annotate the transformed MetaGraphDef.
      These replace the tags carried over from the base meta info.
    checkpoint_path: A path to a checkpoint to restore during freezing,
      if needed (default None).

  Returns:
    A new transformed MetaGraphDef protocol buffer. Its saver (unless the graph
    was frozen), collections and signature_defs are copied from the base with
    references to pruned ops removed.
  """
  result = _meta_graph_pb2.MetaGraphDef()

  mandatory_inits = _find_all_mandatory_retain_ops(base_meta_graph_def)
  new_graph_def, new_inits = _do_transforms(
      base_meta_graph_def.graph_def, input_names, output_names,
      mandatory_inits, transforms, base_meta_graph_def.saver_def,
      checkpoint_path)

  result.graph_def.CopyFrom(new_graph_def)

  # Carry over the base meta info, but replace its tags with `tags`.
  info = result.meta_info_def
  info.CopyFrom(base_meta_graph_def.meta_info_def)
  info.ClearField('tags')
  for tag in tags:
    info.tags.append(tag)

  # Ops present in the base graph that the transforms dropped.
  base_names = {compat.as_str(node.name)
                for node in base_meta_graph_def.graph_def.node}
  kept_names = {compat.as_str(node.name) for node in result.graph_def.node}
  removed_op_names = base_names - kept_names

  # Copy saver, excluding any pruned nodes if graph was not frozen.
  # TODO(b/63447631): Revisit this once the problem is addressed. Currently
  # _add_pruned_saver assumes that the save and restore nodes have not been
  # removed but freeze_graph (correctly) removes them.
  if _FREEZE_GRAPH_TRANSFORM not in transforms:
    _add_pruned_saver(base_meta_graph_def, result, removed_op_names)

  # Copy collections, excluding any pruned nodes.
  for collection_name in base_meta_graph_def.collection_def:
    _add_pruned_collection(
        base_meta_graph_def, result, collection_name, removed_op_names)

  # Append newly added initializers to collection.
  _add_new_inits_to_collection(result, new_inits)

  # Copy signature_defs, excluding any pruned nodes.
  for signature_name in base_meta_graph_def.signature_def:
    _add_pruned_signature(
        base_meta_graph_def, result, signature_name, removed_op_names)

  return result
def __init__(self,
             tpu=None,
             zone=None,
             project=None,
             job_name='worker',
             coordinator_name=None,
             coordinator_address=None,
             credentials='default',
             service=None,
             discovery_url=None):
  """Creates a new TPUClusterResolver object.

  The ClusterResolver will then use the parameters to query the Cloud TPU APIs
  for the IP addresses and ports of each Cloud TPU listed.

  Args:
    tpu: A string (or bytes) corresponding to the TPU to use. If the string is
      the empty string, the string 'local', or a string that begins with
      'grpc://', '/bns' or 'uptc://', then it is assumed to not correspond
      with a Cloud TPU and will instead be passed as the session master and no
      ClusterSpec propagation will be done. A single-element list is also
      accepted. In the future, this may also support a list of strings when
      multiple Cloud TPUs are used. If None, the name is taken from
      `self._gkeEndpoints()` when running in GKE, otherwise from
      `self._envVarFallback()`.
    zone: Zone where the TPUs are located. If omitted or empty, we will assume
      that the zone of the TPU is the same as the zone of the GCE VM, which we
      will try to discover from the GCE metadata service.
    project: Name of the GCP project containing Cloud TPUs. If omitted or
      empty, we will try to discover the project name of the GCE VM from the
      GCE metadata service.
    job_name: Name of the TensorFlow job the TPUs belong to. Also stored as
      `self.task_type`.
    coordinator_name: The name to use for the coordinator. Set to None if the
      coordinator should not be included in the computed ClusterSpec.
    coordinator_address: The address of the coordinator (typically an ip:port
      pair). If set to None, a TF server will be started. If coordinator_name
      is None, a TF server will not be started even if coordinator_address is
      None.
    credentials: GCE Credentials. The default sentinel value 'default' selects
      default credentials: it is stored as None when Cloud APIs will be
      contacted and googleapiclient is installed. Any other value is stored
      as given.
    service: The GCE API object returned by the googleapiclient.discovery
      function. If you specify a custom service object, then the credentials
      parameter will be ignored.
    discovery_url: A URL template that points to the location of
      the discovery service. It should have two parameters {api} and
      {apiVersion} that when filled in produce an absolute URL to the
      discovery document for that service. The environment variable
      'TPU_API_DISCOVERY_URL' will override this.

  Raises:
    ImportError: If Cloud APIs need to be contacted, no `service` is passed,
      and the googleapiclient is not installed.
    ValueError: If an empty list is passed, or no TPU name is given and none
      can be discovered.
    NotImplementedError: If a list of more than one TPU is passed.
    RuntimeError: If an empty TPU name is specified and this is running in a
      Google Cloud environment.
  """
  if isinstance(tpu, list):
    if not tpu:
      raise ValueError('At least one TPU must be specified.')
    if len(tpu) != 1:
      raise NotImplementedError(
          'Using multiple TPUs in a single session is not yet implemented')
    tpu = tpu[0]

  in_gke = self._inGke()
  # When using GKE with Cloud TPUs, the env variable will be set.
  if tpu is None:
    if in_gke:
      tpu = self._gkeEndpoints()
    else:
      tpu = self._envVarFallback()

  if tpu is None:
    raise ValueError('Please provide a TPU Name to connect to.')

  self._tpu = compat.as_bytes(tpu)  # bytes here; see the stripping below.

  # If we are running in Cloud and don't specify a TPU name
  if self._isRunningInGCE() and not self._tpu:
    raise RuntimeError('You need to specify a TPU Name if you are running in '
                       'the Google Cloud environment.')

  # Use a text copy of the name for the prefix checks below. Without this, a
  # bytes name raises TypeError on Python 3 (bytes.startswith(str)).
  tpu = compat.as_str(self._tpu)

  # By default the task_type is 'worker' and the task_id is 0 (which is the
  # first worker in the task).
  self.task_type = job_name
  self.task_id = 0

  # A TPU attached to the host ('local' or the empty string), or one reached
  # through BNS / uptc, lives in the Google environment and uses no RPC layer.
  # Everything else ('grpc://' addresses and Cloud TPU names) is a Cloud
  # environment reached over gRPC. Cloud TPU names are included so that
  # `_environment` is always defined.
  in_google_env = (not tpu or tpu == 'local' or
                   tpu.startswith('/bns') or tpu.startswith('uptc://'))
  if in_google_env:
    self._environment = 'google'
    self.rpc_layer = None
  else:
    self._environment = ''
    self.rpc_layer = 'grpc'

  # Setting this overrides the return value of self._shouldResolve()
  self._should_resolve_override = None

  # We strip out the protocol if it is included, and override the
  # shouldResolve function to never resolve. We are adding the protocol back
  # in later in self.master().
  if self.rpc_layer is not None and tpu.startswith(self.rpc_layer + '://'):
    tpu = tpu[len(self.rpc_layer + '://'):]
    self._tpu = tpu
    self._should_resolve_override = False

  # Whether we should actually attempt to contact Cloud APIs
  should_resolve = self._shouldResolve()

  # We error out if we are in a non-Cloud environment which cannot talk to the
  # Cloud APIs using the standard class and a special object is not passed in.
  self._service = service
  if (self._service is None and should_resolve and
      not _GOOGLE_API_CLIENT_INSTALLED):
    raise ImportError('googleapiclient and oauth2client must be installed '
                      'before using the TPU cluster resolver. Execute: '
                      '`pip install --upgrade google-api-python-client` '
                      'and `pip install --upgrade oauth2client` to '
                      'install with pip.')

  # We save user-passed credentials, unless the user didn't pass in anything.
  self._credentials = credentials
  if (credentials == 'default' and should_resolve and
      _GOOGLE_API_CLIENT_INSTALLED):
    self._credentials = None

  # Automatically detect project and zone if unspecified.
  if not project and should_resolve:
    project = compat.as_str(
        self._requestComputeMetadata('project/project-id'))
  if not zone and should_resolve:
    zone_path = compat.as_str(self._requestComputeMetadata('instance/zone'))
    zone = zone_path.split('/')[-1]
  self._project = project
  self._zone = zone

  self._discovery_url = self._environmentDiscoveryUrl() or discovery_url

  self._coordinator_name = coordinator_name
  if (coordinator_name and not coordinator_address and
      (should_resolve or in_gke)):
    self._start_local_server()
  else:
    self._coordinator_address = coordinator_address