def load_model(saved_model_path):
  """Rebuilds a keras.Model from the contents of a SavedModel directory.

  The model state is restored in two steps:
  1) the topology is deserialized from the json file stored under the assets
     directory (this will eventually come from metagraph).
  2) the weights are loaded from the checkpoint stored under the variables
     directory.

  Args:
    saved_model_path: a string specifying the path to an existing SavedModel.

  Returns:
    a keras.Model instance.
  """
  # Step 1: topology. The json path is assembled from byte strings.
  json_path_parts = [
      compat.as_bytes(saved_model_path),
      compat.as_bytes(constants.ASSETS_DIRECTORY),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON),
  ]
  restored = model_from_json(
      file_io.read_file_to_string(os.path.join(*json_path_parts)))

  # Step 2: weights. The checkpoint prefix is assembled from text strings.
  weights_prefix_parts = [
      compat.as_text(saved_model_path),
      compat.as_text(constants.VARIABLES_DIRECTORY),
      compat.as_text(constants.VARIABLES_FILENAME),
  ]
  restored.load_weights(os.path.join(*weights_prefix_parts))
  return restored
def testFormatOneTensorOneDimVarySummarize(self):
  """Formats range(6) with several `summarize` values and checks the text."""
  # (summarize, expected text) pairs, exercised in this order, each in its
  # own test session.
  cases = (
      (-1, "[0 1 2 3 4 5]"),
      (1, "[0 ... 5]"),
      (2, "[0 1 ... 4 5]"),
      (10, "[0 1 2 3 4 5]"),
  )
  for summarize, expected in cases:
    with self.test_session():
      tensor = math_ops.range(6)
      format_output = string_ops.string_format(
          "{}", tensor, summarize=summarize)
      out = self.evaluate(format_output)
      self.assertEqual(compat.as_text(out), expected)
def add_meta_graph_and_variables(self,
                                 sess,
                                 tags,
                                 signature_def_map=None,
                                 assets_collection=None,
                                 legacy_init_op=None):
  """Saves variables and adds the current meta graph to the SavedModel.

  A Saver is created to save the variables from the provided session, and the
  corresponding meta graph def is exported. The variables to be saved are
  assumed to have been initialized already.

  For a given `SavedModelBuilder`, this API must be called exactly once and
  for the first meta graph to save. Subsequent meta graph defs must be added
  with the `add_meta_graph()` API.

  Args:
    sess: The TensorFlow session from which to save the meta graph and
      variables.
    tags: The set of tags with which to save the meta graph.
    signature_def_map: The map of signature def map to add to the meta graph
      def.
    assets_collection: Assets collection to be saved with SavedModel.
    legacy_init_op: Op or group of ops to execute after the restore op upon a
      load.

  Raises:
    AssertionError: If variables have already been saved by this builder.
  """
  if self._has_saved_variables:
    raise AssertionError("Variables and assets have already been saved. "
                         "Please invoke `add_meta_graph()` instead.")

  # Assets (if any) go to disk before anything else.
  self._save_and_write_assets(assets_collection)

  # Make sure the variables sub-directory exists.
  export_root = compat.as_text(self._export_dir)
  variables_subdir = compat.as_text(constants.VARIABLES_DIRECTORY)
  variables_dir = os.path.join(export_root, variables_subdir)
  if not file_io.file_exists(variables_dir):
    file_io.recursive_create_dir(variables_dir)
  checkpoint_path = os.path.join(
      compat.as_text(variables_dir),
      compat.as_text(constants.VARIABLES_FILENAME))

  # Register the legacy init op with the SavedModel.
  self._maybe_add_legacy_init_op(legacy_init_op)

  # Write the variables (sharded, V2 format) and export the meta graph def.
  variable_saver = tf_saver.Saver(
      variables.all_variables(),
      sharded=True,
      write_version=saver_pb2.SaverDef.V2)
  variable_saver.save(sess, checkpoint_path, write_meta_graph=False)
  exported_meta_graph = variable_saver.export_meta_graph()

  # Tag the meta graph def and attach it to the SavedModel.
  self._tag_and_add_meta_graph(exported_meta_graph, tags, signature_def_map)

  # From now on, further attempts to save variables must fail.
  self._has_saved_variables = True
def _do_run(self, target_list, fetch_list, feed_dict):
  """Runs a step based on the given fetches and feeds.

  Before running, graph nodes added since the last call are sent to the
  runtime with `TF_ExtendGraph`. A `StatusNotOK` raised inside this method is
  converted to a more specific exception when its message names a node (as
  matched by `BaseSession._NODEDEF_NAME_RE`); otherwise it is re-raised
  unchanged with its original traceback.

  Args:
    target_list: A list of byte arrays corresponding to names of tensors
      or operations to be run to, but not fetched.
    fetch_list: A list of byte arrays corresponding to names of tensors to be
      fetched and operations to be run.
    feed_dict: A dictionary that maps tensor names (as byte arrays) to
      numpy ndarrays.

  Returns:
    A list of numpy ndarrays, corresponding to the elements of
    `fetch_list`.  If the ith element of `fetch_list` contains the
    name of an operation, the first Tensor output of that operation
    will be returned for that element.

  Raises:
    RuntimeError: If extending the graph reports a non-zero status code.
  """
  try:
    # Ensure any changes to the graph are reflected in the runtime.
    with self._extend_lock:
      if self._graph.version > self._current_version:
        # Only the nodes added after `_current_version` are serialized.
        graph_def = self._graph.as_graph_def(
            from_version=self._current_version)

        try:
          status = tf_session.TF_NewStatus()
          tf_session.TF_ExtendGraph(
              self._session, graph_def.SerializeToString(), status)
          if tf_session.TF_GetCode(status) != 0:
            raise RuntimeError(compat.as_text(tf_session.TF_Message(status)))
          # `close()` only closes the session once this flag is set.
          self._opened = True
        finally:
          tf_session.TF_DeleteStatus(status)

        # Record the version only after the extension succeeded.
        self._current_version = self._graph.version

    return tf_session.TF_Run(self._session, feed_dict, fetch_list,
                             target_list)

  except tf_session.StatusNotOK as e:
    # Capture the active exception so it can be re-raised as-is below.
    e_type, e_value, e_traceback = sys.exc_info()
    error_message = compat.as_text(e.error_message)
    m = BaseSession._NODEDEF_NAME_RE.search(error_message)
    if m is not None:
      node_name = m.group(1)
      node_def = None
      try:
        op = self._graph.get_operation_by_name(node_name)
        node_def = op.node_def
      except KeyError:
        # The named node is not in this graph; raise without op details.
        op = None
      # pylint: disable=protected-access
      raise errors._make_specific_exception(node_def, op, error_message,
                                            e.code)
      # pylint: enable=protected-access
    # No node name in the message: propagate the original error untouched.
    six.reraise(e_type, e_value, e_traceback)
def load_from_saved_model(saved_model_path, custom_objects=None):
  """Loads a keras Model from a SavedModel created by `export_saved_model()`.

  The model state is reinstantiated in two steps:
  1) the model topology is loaded from json (this will eventually come from
     metagraph).
  2) the model weights are loaded from the checkpoint.

  Example:

  ```python
  import tensorflow as tf

  # Create a tf.keras model.
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(1, input_shape=[10]))
  model.summary()

  # Save the tf.keras model in the SavedModel format.
  path = '/tmp/simple_keras_model'
  tf.keras.experimental.export_saved_model(model, path)

  # Load the saved keras model back.
  new_model = tf.keras.experimental.load_from_saved_model(path)
  new_model.summary()
  ```

  Args:
    saved_model_path: a string specifying the path to an existing SavedModel.
    custom_objects: Optional dictionary mapping names (strings) to custom
      classes or functions to be considered during deserialization.

  Returns:
    a keras.Model instance.
  """
  # Topology: the json file sits under the assets directory; its path is
  # built from byte strings.
  topology_path = os.path.join(*[
      compat.as_bytes(part) for part in (
          saved_model_path,
          constants.ASSETS_DIRECTORY,
          constants.SAVED_MODEL_FILENAME_JSON)
  ])
  topology_json = file_io.read_file_to_string(topology_path)
  loaded = model_from_json(topology_json, custom_objects=custom_objects)

  # Weights: the checkpoint prefix sits under the variables directory; its
  # path is built from text strings.
  weights_prefix = os.path.join(*[
      compat.as_text(part) for part in (
          saved_model_path,
          constants.VARIABLES_DIRECTORY,
          constants.VARIABLES_FILENAME)
  ])
  loaded.load_weights(weights_prefix)
  return loaded
def _start_local_server(self):
  """Starts a local gRPC server and records the coordinator port/address.

  The address part comes from the compute metadata entry
  'instance/network-interfaces/0/ip'; the port is parsed out of the started
  server's target string.
  """
  address = self._requestComputeMetadata('instance/network-interfaces/0/ip')
  local_cluster = {'local': ['0.0.0.0:0']}
  self._server = server_lib.Server(
      local_cluster, protocol='grpc', config=None, start=True)

  # self._server.target is of the form: grpc://ipaddress:port
  separator = compat.as_bytes(':')
  pieces = compat.as_bytes(self._server.target).split(separator)
  assert len(pieces) == 3, self._server.target
  assert pieces[0] == compat.as_bytes('grpc'), self._server.target

  port = compat.as_text(pieces[2])
  self._coordinator_port = port
  self._coordinator_address = '%s:%s' % (address, compat.as_text(port))
def testFormatOneTensorOneDim(self):
  """Formats range(10) passed both bare and wrapped in a list."""
  expected = "[0 1 2 ... 7 8 9]"
  # First the bare tensor, then the same tensor inside a one-element list.
  for wrap_in_list in (False, True):
    with self.test_session():
      tensor = math_ops.range(10)
      inputs = [tensor] if wrap_in_list else tensor
      format_output = string_ops.string_format("{}", inputs)
      out = self.evaluate(format_output)
      self.assertEqual(compat.as_text(out), expected)
def save_model(model, saved_model_path):
  """Save a `tf.keras.Model` into Tensorflow SavedModel format.

  The following files/folders are generated under `saved_model_path`:
  1) an asset folder containing the json string of the model's
     configuration (topology).
  2) a checkpoint containing the model weights.

  Note that subclassed models can not be saved via this function, unless you
  provide an implementation for get_config() and from_config().
  Also note that `tf.keras.optimizers.Optimizer` instances can not currently
  be saved to checkpoints. Use optimizers from `tf.train`.

  Args:
    model: A `tf.keras.Model` to be saved.
    saved_model_path: a string specifying the path to the SavedModel
      directory.

  Raises:
    NotImplementedError: If the passed in model is a subclassed model.
  """
  if not model._is_graph_network:
    raise NotImplementedError

  def _ensure_subdir(subdir_name):
    """Returns the bytes path of the sub-directory, creating it if missing."""
    subdir = os.path.join(
        compat.as_bytes(saved_model_path), compat.as_bytes(subdir_name))
    if not file_io.file_exists(subdir):
      file_io.recursive_create_dir(subdir)
    return subdir

  # 1) Model configuration, written as json under the assets folder.
  serialized_config = model.to_json()
  assets_dir = _ensure_subdir(constants.ASSETS_DIRECTORY)
  json_path = os.path.join(
      compat.as_bytes(assets_dir),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON))
  file_io.write_string_to_file(json_path, serialized_config)

  # 2) Model weights, written in checkpoint format under the variables folder.
  variables_dir = _ensure_subdir(constants.VARIABLES_DIRECTORY)
  weights_prefix = os.path.join(
      compat.as_text(variables_dir),
      compat.as_text(constants.VARIABLES_FILENAME))
  model.save_weights(weights_prefix, save_format='tf', overwrite=True)
def load_keras_model(saved_model_path):
  """Load a keras.Model from SavedModel.

  The model state is reinstantiated by:
  1) loading model topology from json (this will eventually come from
     metagraph).
  2) loading model weights from checkpoint.

  Example:

  ```python
  import tensorflow as tf

  # Create a tf.keras model.
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(1, input_shape=[10]))
  model.summary()

  # Save the tf.keras model in the SavedModel format.
  saved_to_path = tf.contrib.saved_model.save_keras_model(
        model, '/tmp/my_simple_tf_keras_saved_model')

  # Load the saved keras model back.
  model_prime = tf.contrib.saved_model.load_keras_model(saved_to_path)
  model_prime.summary()
  ```

  Args:
    saved_model_path: a string specifying the path to an existing SavedModel.

  Returns:
    a keras.Model instance.
  """
  # Topology first: read the json string from the assets directory.
  assets_dir = compat.as_bytes(constants.ASSETS_DIRECTORY)
  json_filename = compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON)
  json_path = os.path.join(
      compat.as_bytes(saved_model_path), assets_dir, json_filename)
  keras_model = model_from_json(file_io.read_file_to_string(json_path))

  # Then the weights: load them from the checkpoint in the variables
  # directory.
  variables_dir = compat.as_text(constants.VARIABLES_DIRECTORY)
  variables_filename = compat.as_text(constants.VARIABLES_FILENAME)
  keras_model.load_weights(
      os.path.join(
          compat.as_text(saved_model_path), variables_dir,
          variables_filename))
  return keras_model
def testWriteEvents(self):
  """Writes one event with EventsWriter and reads the file back."""
  file_prefix = os.path.join(self.get_temp_dir(), "events")
  writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(file_prefix))
  filename = compat.as_text(writer.FileName())

  summary_value = summary_pb2.Summary.Value(tag="foo", simple_value=89.0)
  event_written = event_pb2.Event(
      wall_time=123.45,
      step=67,
      summary=summary_pb2.Summary(value=[summary_value]))
  writer.WriteEvent(event_written)
  writer.Flush()
  writer.Close()

  # Iterating a file that does not exist must fail before any record shows up.
  missing_filename = filename + "DOES_NOT_EXIST"
  with self.assertRaises(IOError):
    for _ in tf_record.tf_record_iterator(missing_filename):
      self.assertTrue(False)

  records = tf_record.tf_record_iterator(filename)
  event_read = event_pb2.Event()

  # First record: expected to carry the file_version field.
  event_read.ParseFromString(next(records))
  self.assertTrue(event_read.HasField("file_version"))

  # Second record: the event written above.
  event_read.ParseFromString(next(records))
  self.assertProtoEquals(
      """ wall_time: 123.45 step: 67 summary { value { tag: 'foo' simple_value: 89.0 } } """,
      event_read)

  # Nothing else is in the file.
  with self.assertRaises(StopIteration):
    next(records)
def _export_model_json(model, saved_model_path):
  """Writes the model's json configuration into the SavedModel assets folder.

  Args:
    model: object exposing `to_json()`, whose result is written to disk.
    saved_model_path: SavedModel directory whose assets folder receives the
      json file.
  """
  serialized_config = model.to_json()
  assets_dir = saved_model_utils.get_or_create_assets_dir(saved_model_path)
  json_filename = compat.as_text(constants.SAVED_MODEL_FILENAME_JSON)
  file_io.write_string_to_file(
      os.path.join(assets_dir, json_filename), serialized_config)
def testZLibFlushRecord(self):
  """Reads a record back from a zlib stream containing many flush points.

  A single-record file is re-compressed one byte at a time with a full flush
  after every byte (plus extra trailing flushes), written over the original
  file, and then read through a ZLIB-configured TFRecordReader.
  """
  fn = self._WriteRecordsToFile([b"small record"], "small_record")
  with open(fn, "rb") as h:
    buff = h.read()

  # creating more blocks and trailing blocks shouldn't break reads
  compressor = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS)

  output = b""
  for c in buff:
    # Iterating over bytes yields ints on Python 3; convert each back to a
    # one-byte string before compressing.
    if isinstance(c, int):
      c = six.int2byte(c)
    output += compressor.compress(c)
    # Force a flush after every single input byte.
    output += compressor.flush(zlib.Z_FULL_FLUSH)

  # Extra trailing flushes, then terminate the stream.
  output += compressor.flush(zlib.Z_FULL_FLUSH)
  output += compressor.flush(zlib.Z_FULL_FLUSH)
  output += compressor.flush(zlib.Z_FINISH)

  # overwrite the original file with the compressed data
  with open(fn, "wb") as h:
    h.write(output)

  with self.test_session() as sess:
    options = tf_record.TFRecordOptions(
        compression_type=TFRecordCompressionType.ZLIB)
    reader = io_ops.TFRecordReader(name="test_reader", options=options)
    queue = data_flow_ops.FIFOQueue(1, [dtypes.string], shapes=())
    key, value = reader.read(queue)
    queue.enqueue(fn).run()
    queue.close().run()
    k, v = sess.run([key, value])
    # The key is expected to start with "<filename>:".
    self.assertTrue(compat.as_text(k).startswith("%s:" % fn))
    self.assertAllEqual(b"small record", v)
def testReadGzipFiles(self):
  """Gzips the generated record files and reads them with a GZIP reader."""

  def _gzip_copy(index, source_path):
    """Writes a gzip copy of `source_path` to the temp dir; returns its path."""
    with open(source_path, "rb") as source:
      payload = source.read()
    gzip_path = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % index)
    with gzip.GzipFile(gzip_path, "wb") as target:
      target.write(payload)
    return gzip_path

  files = self._CreateFiles()
  gzip_files = [_gzip_copy(i, fn) for i, fn in enumerate(files)]

  with self.test_session() as sess:
    options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)
    reader = io_ops.TFRecordReader(name="test_reader", options=options)
    queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
    key, value = reader.read(queue)

    queue.enqueue_many([gzip_files]).run()
    queue.close().run()
    for file_index in range(self._num_files):
      # Every key read from this file must start with "<gzip path>:".
      key_prefix = "%s:" % gzip_files[file_index]
      for record_index in range(self._num_records):
        k, v = sess.run([key, value])
        self.assertTrue(compat.as_text(k).startswith(key_prefix))
        self.assertAllEqual(self._Record(file_index, record_index), v)
def _TestOneEpochWithHopBytes(self,
                              files,
                              num_overlapped_records,
                              encoding=None):
  """Reads every overlapped record of `files` once, then expects exhaustion.

  Args:
    files: file names enqueued for the reader.
    num_overlapped_records: number of records read from each file.
    encoding: value forwarded to the reader's `encoding` argument.
  """
  with self.test_session() as sess:
    reader = io_ops.FixedLengthRecordReader(
        header_bytes=self._header_bytes,
        record_bytes=self._record_bytes,
        footer_bytes=self._footer_bytes,
        hop_bytes=self._hop_bytes,
        encoding=encoding,
        name="test_reader")
    queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
    key, value = reader.read(queue)
    fetches = [key, value]

    queue.enqueue_many([files]).run()
    queue.close().run()
    for file_index in range(self._num_files):
      for record_index in range(num_overlapped_records):
        k, v = sess.run(fetches)
        expected_key = "%s:%d" % (files[file_index], record_index)
        self.assertAllEqual(expected_key, compat.as_text(k))
        self.assertAllEqual(
            self._OverlappedRecord(file_index, record_index), v)

    # One more read must fail: the queue is closed and drained.
    with self.assertRaisesOpError("is closed and has insufficient elements "
                                  "\\(requested 1, current size 0\\)"):
      k, v = sess.run(fetches)
def testFormatOneTensorOneDimFloat(self):
  """Formats an eight-element float vector with the default summarization."""
  values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
  with self.test_session():
    format_output = string_ops.string_format(
        "{}", constant_op.constant(values))
    formatted = compat.as_text(self.evaluate(format_output))
    self.assertEqual(formatted, "[0 0.1 0.2 ... 0.5 0.6 0.7]")
def testFormatOneTensorOneDimAlmostSummarize(self):
  """Formats range(5) with summarize=3 and expects every element printed."""
  with self.test_session():
    format_output = string_ops.string_format(
        "{}", math_ops.range(5), summarize=3)
    formatted = compat.as_text(self.evaluate(format_output))
    self.assertEqual(formatted, "[0 1 2 3 4]")
def testComplexCodeView(self):
  """Profiles the full test model by Python code and checks the report."""
  ops.reset_default_graph()
  outfile = os.path.join(test.get_temp_dir(), 'dump')

  # Build the profiling options step by step.
  opts_builder = builder(builder.trainable_variables_parameter())
  opts_builder = opts_builder.with_file_output(outfile)
  opts_builder = opts_builder.with_accounted_types(['.*'])
  opts_builder = opts_builder.with_node_names(
      show_name_regexes=['.*model_analyzer_testlib.py.*'])
  opts_builder = opts_builder.account_displayed_op_only(False)
  opts = opts_builder.select(['params', 'float_ops']).build()

  # Expected names of the root's children, in order.
  expected_children = (
      'model_analyzer_testlib.py:63:BuildFullModel',
      'model_analyzer_testlib.py:63:BuildFullModel (gradient)',
      'model_analyzer_testlib.py:67:BuildFullModel',
      'model_analyzer_testlib.py:67:BuildFullModel (gradient)',
      'model_analyzer_testlib.py:69:BuildFullModel',
      'model_analyzer_testlib.py:70:BuildFullModel',
      'model_analyzer_testlib.py:70:BuildFullModel (gradient)',
      'model_analyzer_testlib.py:72:BuildFullModel',
  )

  with profile_context.ProfileContext(
      test.get_temp_dir(), trace_steps=[], dump_steps=[]) as pctx:
    with session.Session() as sess:
      x = lib.BuildFullModel()

      sess.run(variables.global_variables_initializer())
      pctx.trace_next_step()
      _ = sess.run(x)
      tfprof_node = pctx.profiler.profile_python(options=opts)

      # pylint: disable=line-too-long
      with gfile.Open(outfile, 'r') as f:
        lines = f.read().split('\n')
        self.assertGreater(len(lines), 5)
        # Only the first 80 characters of every line are compared.
        result = '\n'.join([line[:80] for line in lines])
        report = compat.as_text(lib.CheckAndRemoveDoc(result))
        self.assertTrue(
            report.startswith('node name | # parameters | # float_ops'))

      self.assertLess(0, tfprof_node.total_exec_micros)
      self.assertEqual(2844, tfprof_node.total_parameters)
      self.assertLess(145660, tfprof_node.total_float_ops)
      self.assertEqual(8, len(tfprof_node.children))
      self.assertEqual('_TFProfRoot', tfprof_node.name)
      for index, expected_name in enumerate(expected_children):
        self.assertEqual(expected_name, tfprof_node.children[index].name)
def run_benchmark(sess, init_op, add_op):
  """Runs the addition benchmark and logs one MB/s 'rate' event per step.

  For each of `FLAGS.iters` steps, `add_op.op` is run `FLAGS.iters_per_step`
  times and the resulting rate is written as an event under
  `FLAGS.logdir_prefix/FLAGS.name`.

  Args:
    sess: session used to run `init_op` and `add_op.op`.
    init_op: op run once before any timing starts.
    add_op: object whose `.op` attribute is the operation being timed.

  Returns:
    None. The rates are only written to the events file.
  """
  logdir = FLAGS.logdir_prefix + '/' + FLAGS.name
  # Create the directory through the os module rather than by shelling out
  # with a concatenated command string (which breaks on paths containing
  # spaces or shell metacharacters and silently ignores failures).
  if not os.path.isdir(logdir):
    os.makedirs(logdir)

  # TODO: make events follow same format as eager writer
  writer = pywrap_tensorflow.EventsWriter(
      compat.as_bytes(logdir + '/events'))
  try:
    training_util.get_or_create_global_step()
    sess.run(init_op)

    for step in range(FLAGS.iters):
      start_time = time.time()
      for _ in range(FLAGS.iters_per_step):
        sess.run(add_op.op)
      elapsed_time = time.time() - start_time

      # The timed window covers `iters_per_step` additions, so that count
      # (not the number of outer steps) determines the data volume.
      rate = float(FLAGS.iters_per_step) * FLAGS.data_mb / elapsed_time
      writer.WriteEvent(make_event('rate', rate, step))
      writer.Flush()
  finally:
    # Close the writer even when a run fails part-way through.
    writer.Close()
def save(self, as_text=False):
  """Writes a `SavedModel` protocol buffer to disk.

  The SavedModel protocol buffer is written into the export directory, which
  is created first when it does not exist yet.

  Args:
    as_text: Writes the SavedModel protocol buffer in text format to disk;
      otherwise the serialized binary form is written.

  Returns:
    The path to which the SavedModel protocol buffer was written.
  """
  if not file_io.file_exists(self._export_dir):
    file_io.recursive_create_dir(self._export_dir)

  # Pick the file name and the matching payload for the requested format.
  if as_text:
    filename = constants.SAVED_MODEL_FILENAME_PBTXT
    payload = str(self._saved_model)
  else:
    filename = constants.SAVED_MODEL_FILENAME_PB
    payload = self._saved_model.SerializeToString()

  path = os.path.join(
      compat.as_bytes(self._export_dir), compat.as_bytes(filename))
  file_io.write_string_to_file(path, payload)
  tf_logging.info("SavedModel written to: %s", compat.as_text(path))

  return path
def _save_and_write_assets(self, assets_collection_to_add=None):
  """Saves asset to the meta graph and writes asset files to disk.

  Args:
    assets_collection_to_add: The collection where the asset paths are setup.
  """
  asset_filename_map = _maybe_save_assets(assets_collection_to_add)

  # Nothing to copy: log it and stop here.
  if not asset_filename_map:
    tf_logging.info("No assets to write.")
    return

  assets_destination_dir = saved_model_utils.get_or_create_assets_dir(
      self._export_dir)
  destination_root = compat.as_bytes(assets_destination_dir)

  for basename, source_path in asset_filename_map.items():
    destination_path = os.path.join(
        destination_root, compat.as_bytes(basename))

    # An asset with the same name may be defined as part of multiple graphs;
    # copy it only the first time, i.e. when the destination is still absent.
    if file_io.file_exists(destination_path):
      continue
    file_io.copy(source_path, destination_path)

  tf_logging.info("Assets written to: %s",
                  compat.as_text(assets_destination_dir))
def load_file_system_library(library_filename):
  """Loads a TensorFlow plugin, containing file system implementation.

  Pass `library_filename` to a platform-specific mechanism for dynamically
  loading a library. The rules for determining the exact location of the
  library are platform-specific and are not documented here.

  Args:
    library_filename: Path to the plugin.
      Relative or absolute filesystem path to a dynamic library file.

  Returns:
    None.

  Raises:
    RuntimeError: when unable to load the library. The concrete exception is
      the one built by `errors_impl._make_specific_exception` from the
      reported status code and message.
  """
  status = py_tf.TF_NewStatus()
  try:
    # The load call sits inside the try block so the status object is always
    # released, even if the call itself raises. The returned library handle
    # is not needed here.
    py_tf.TF_LoadLibrary(library_filename, status)
    error_code = py_tf.TF_GetCode(status)
    if error_code != 0:
      error_msg = compat.as_text(py_tf.TF_Message(status))
      # pylint: disable=protected-access
      raise errors_impl._make_specific_exception(
          None, None, error_msg, error_code)
      # pylint: enable=protected-access
  finally:
    py_tf.TF_DeleteStatus(status)
def load_op_library(library_filename):
  """Loads a TensorFlow plugin, containing custom ops and kernels.

  Pass "library_filename" to a platform-specific mechanism for dynamically
  loading a library. The rules for determining the exact location of the
  library are platform-specific and are not documented here. When the
  library is loaded, ops and kernels registered in the library via the
  REGISTER_* macros are made available in the TensorFlow process. Note
  that ops with the same name as an existing op are rejected and not
  registered with the process.

  If the load reports ALREADY_EXISTS with a 'has already been loaded' message
  and this filename was memoized by an earlier call, the memoized module is
  returned instead of raising.

  Args:
    library_filename: Path to the plugin.
      Relative or absolute filesystem path to a dynamic library file.

  Returns:
    A python module containing the Python wrappers for Ops defined in
    the plugin.

  Raises:
    RuntimeError: when unable to load the library or get the python wrappers.
      The concrete exception is the one built by
      `errors._make_specific_exception` from the reported status.
  """
  # Imported locally so this function does not rely on a file-level import.
  # `types.ModuleType` replaces `imp.new_module`, which is deprecated and no
  # longer available on recent Python versions.
  import types  # pylint: disable=g-import-not-at-top

  status = py_tf.TF_NewStatus()
  try:
    # The load call sits inside the try block so the status object is always
    # released, even if the call itself raises.
    lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
    error_code = py_tf.TF_GetCode(status)
    if error_code != 0:
      error_msg = compat.as_text(py_tf.TF_Message(status))
      with _OP_LIBRARY_MAP_LOCK:
        if (error_code == error_codes_pb2.ALREADY_EXISTS and
            'has already been loaded' in error_msg and
            library_filename in _OP_LIBRARY_MAP):
          return _OP_LIBRARY_MAP[library_filename]
      # pylint: disable=protected-access
      raise errors._make_specific_exception(None, None, error_msg, error_code)
      # pylint: enable=protected-access
  finally:
    py_tf.TF_DeleteStatus(status)

  op_list_str = py_tf.TF_GetOpList(lib_handle)
  op_list = op_def_pb2.OpList()
  op_list.ParseFromString(compat.as_bytes(op_list_str))
  wrappers = py_tf.GetPythonWrappers(op_list_str)

  # Get a unique name for the module.
  module_name = hashlib.md5(wrappers).hexdigest()
  module = types.ModuleType(module_name)
  # NOTE(review): this executes wrapper code generated for the loaded
  # library, so only trusted libraries should be loaded through this function.
  # pylint: disable=exec-used
  exec(wrappers, module.__dict__)
  # pylint: enable=exec-used
  # Stash away the library handle for making calls into the dynamic library.
  module.LIB_HANDLE = lib_handle
  # OpDefs of the list of ops defined in the library.
  module.OP_LIST = op_list
  sys.modules[module_name] = module
  # Memoize the filename to module mapping.
  with _OP_LIBRARY_MAP_LOCK:
    _OP_LIBRARY_MAP[library_filename] = module
  return module
def cluster_spec(self):
  """Returns a ClusterSpec object based on the latest TPU information.

  We retrieve the information from the GCE APIs every time this method is
  called.

  Returns:
    A ClusterSpec containing host information returned from Cloud TPUs, or
    None when the TPU name is not resolved through the API and does not start
    with 'grpc://' (case 3 below).

  Raises:
    RuntimeError: If the provided TPU is not healthy.
  """
  ############################################################################
  # There are 5 potential cases this code must handle:
  #  1. [Normal case.] We should resolve the TPU name to a set of tasks, and
  #      a. Create a ClusterSpec that includes the coordinator job
  #      b. Create a ClusterSpec without the coordinator job.
  #  2. [GKE / No API Access.] We should not resolve the TPU name to a set of
  #     tasks and
  #      a. Create a ClusterSpec with the coordinator
  #      b. Create a ClusterSpec without the coordinator
  #  3. [Other (legacy non-gRPC).] We should return an empty ClusterSpec.
  ############################################################################

  if self._shouldResolve():
    # Case 1.
    full_name = 'projects/%s/locations/%s/nodes/%s' % (
        self._project, self._zone, compat.as_text(self._tpu))
    request = self._service.projects().locations().nodes().get(name=full_name)
    response = request.execute()

    if 'health' in response and response['health'] != 'HEALTHY':
      # `self._tpu` is compared against byte strings elsewhere in this
      # method, so convert it to text to keep the message free of a b'...'
      # wrapper on Python 3.
      raise RuntimeError('TPU "%s" is unhealthy: "%s"' %
                         (compat.as_text(self._tpu), response['health']))

    if 'networkEndpoints' in response:
      worker_list = [
          '%s:%s' % (endpoint['ipAddress'], endpoint['port'])
          for endpoint in response['networkEndpoints']
      ]
    else:
      # Fall back to the deprecated response format
      instance_url = '%s:%s' % (response['ipAddress'], response['port'])
      worker_list = [instance_url]

    cluster_spec = {self._job_name: worker_list}
  else:
    grpc_prefix = compat.as_bytes('grpc://')
    if not self._tpu.startswith(grpc_prefix):
      # Case 3.
      return None
    # Case 2. The worker address is the TPU target minus its scheme prefix.
    cluster_spec = {self._job_name: [self._tpu[len(grpc_prefix):]]}

  if self._coordinator_address:
    # {1, 2}.a
    cluster_spec[self._coordinator_name] = [self._coordinator_address]

  return server_lib.ClusterSpec(cluster_spec)
def raise_exception_on_not_ok_status():
  """Yields a status handle and raises afterwards if it is not OK.

  The generator creates a `ScopedTFStatus`, yields its `status` attribute to
  the caller, and once resumed checks the status code: a non-zero code is
  turned into the exception built by `_make_specific_exception` from the
  status message and code.
  """
  status = c_api_util.ScopedTFStatus()
  yield status.status
  if c_api.TF_GetCode(status) == 0:
    return
  message = compat.as_text(c_api.TF_Message(status))
  code = c_api.TF_GetCode(status)
  raise _make_specific_exception(None, None, message, code)
def testFormatOneTensorTwoDimLessThanSummarize(self):
  """Formats a 2x2 tensor with summarize=3 and expects every element."""
  with self.test_session():
    matrix = array_ops.reshape(math_ops.range(4), [2, 2])
    format_output = string_ops.string_format("{}", matrix, summarize=3)
    formatted = compat.as_text(self.evaluate(format_output))
    expected = ("[[0 1]\n"
                " [2 3]]")
    self.assertEqual(formatted, expected)
def testFormatOneVariableScalar(self):
  """Formats a scalar variable passed inside a one-element list."""
  with self.test_session():
    scalar_var = variables.Variable(3.34)
    format_output = string_ops.string_format("{}", [scalar_var])
    # In graph mode the variable has to be initialized explicitly.
    if not context.executing_eagerly():
      variables.global_variables_initializer().run()
    formatted = compat.as_text(self.evaluate(format_output))
    self.assertEqual(formatted, "3.34")
def testFormatOneVariableOneDim(self):
  """Formats a variable holding range(10), passed in a one-element list."""
  with self.test_session():
    vector_var = variables.Variable(math_ops.range(10))
    format_output = string_ops.string_format("{}", [vector_var])
    # In graph mode the variable has to be initialized explicitly.
    if not context.executing_eagerly():
      variables.global_variables_initializer().run()
    formatted = compat.as_text(self.evaluate(format_output))
    self.assertEqual(formatted, "[0 1 2 ... 7 8 9]")
def testFormatSummarizeOne(self):
  """Formats a 10x10 tensor with summarize=1 inside a longer template."""
  expected = ("tensor summary: [[0 ... 9]\n"
              " ...\n"
              " [90 ... 99]]")
  with self.test_session():
    grid = array_ops.reshape(math_ops.range(100), [10, 10])
    format_output = string_ops.string_format(
        "tensor summary: {}", grid, summarize=1)
    formatted = compat.as_text(self.evaluate(format_output))
    self.assertEqual(formatted, expected)
def add_meta_graph_and_variables(self,
                                 sess,
                                 tags,
                                 signature_def_map=None,
                                 assets_collection=None):
  """Saves variables and adds the current meta graph to the SavedModel.

  A Saver is created to save the variables from the provided session, and the
  corresponding meta graph def is exported. The variables to be saved are
  assumed to have been initialized already.

  For a given `SavedModelBuilder`, this API must be called exactly once and
  for the first meta graph to save. Subsequent meta graph defs must be added
  with the `add_meta_graph()` API.

  Args:
    sess: The TensorFlow session from which to save the meta graph and
      variables.
    tags: The set of tags with which to save the meta graph.
    signature_def_map: The map of signature def map to add to the meta graph
      def.
    assets_collection: Assets collection to be saved with SavedModel.

  Raises:
    AssertionError: If variables have already been saved by this builder.
  """
  if self._has_saved_variables:
    raise AssertionError("Variables and assets have already been saved. "
                         "Please invoke `add_meta_graph()` instead.")

  # Assets (if any) go to disk first.
  self._save_and_write_assets(assets_collection)

  # The variables file is placed directly under the export directory.
  export_dir_text = compat.as_text(self._export_dir)
  variables_filename = compat.as_text(constants.VARIABLES_FILENAME)
  checkpoint_path = os.path.join(export_dir_text, variables_filename)

  # Write the variables and export the meta graph def.
  variable_saver = tf_saver.Saver(variables.all_variables())
  variable_saver.save(sess, checkpoint_path, write_meta_graph=False)
  exported_meta_graph = variable_saver.export_meta_graph()

  # Tag the meta graph def and attach it to the SavedModel.
  self._tag_and_add_meta_graph(exported_meta_graph, tags, signature_def_map)

  # From now on, further attempts to save variables must fail.
  self._has_saved_variables = True
def _find_all_hints_in_graph_def(graphdef):
  """Collects the op hints found in `graphdef` into `_LiteFuncCall` objects.

  Every node carrying a non-empty `OpHint.FUNCTION_UUID_ATTR` attribute
  contributes to the call identified by that uuid: its function name, its
  position among the call's inputs and/or outputs, and any attribute whose
  name starts with "_tflite_attr_".

  Args:
    graphdef: A TensorFlow graph_def to look for LiteFuncCalls.

  Returns:
    A `defaultdict` mapping each hint uuid to its `_LiteFuncCall` object.
  """
  attr_prefix = "_tflite_attr_"

  def put_operand(stuff, index, sort, operand, aggregation):
    """Add a given index into the function structure."""
    if sort is None:
      stuff[index] = _LiteSingleOperand(operand)
    else:
      if index not in stuff:
        stuff[index] = _LiteAggregateOperand(aggregation)
      stuff[index].add(sort, operand)

  func_calls = _collections.defaultdict(_LiteFuncCall)

  for node in graphdef.node:
    attr = node.attr
    # This is an op hint if it has a FUNCTION_UUID_ATTR, otherwise skip.
    # Membership is tested before indexing: indexing a protobuf map with a
    # missing key may insert a default entry, which would modify the graph
    # def being inspected.
    if OpHint.FUNCTION_UUID_ATTR not in attr:
      continue
    uuid = attr[OpHint.FUNCTION_UUID_ATTR].s
    if not uuid:
      continue

    # Start building function
    call_def = func_calls[uuid]
    call_def.uuid = uuid
    call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s

    # Get sorting and aggregation information; a sort index of -1 means the
    # operand is not part of an aggregate.
    sort = (attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i
            if OpHint.FUNCTION_SORT_INDEX_ATTR in attr else None)
    if sort == -1:
      sort = None
    aggregation = None
    if OpHint.FUNCTION_AGGREGATE_ATTR in attr:
      aggregation = _compat.as_text(attr[OpHint.FUNCTION_AGGREGATE_ATTR].s)

    # Add the input or output
    if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr:
      put_operand(call_def.inputs, attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i,
                  sort, node, aggregation)
    if OpHint.FUNCTION_OUTPUT_INDEX_ATTR in attr:
      put_operand(call_def.outputs, attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i,
                  sort, node, aggregation)

    # Remember attributes, keyed by their name without the "_tflite_attr_"
    # prefix.
    for a in attr:
      if a.startswith(attr_prefix):
        call_def.params[a[len(attr_prefix):]] = attr[a].tensor

  return func_calls
def copy_assets_to_destination_dir(asset_filename_map, destination_dir):
  """Copy all assets from source path to destination path.

  Args:
    asset_filename_map: mapping from asset basename to the source file path
      of that asset.
    destination_dir: directory whose assets folder receives the copies.
  """
  assets_destination_dir = saved_model_utils.get_or_create_assets_dir(
      destination_dir)
  destination_root = compat.as_bytes(assets_destination_dir)

  for basename, source_path in asset_filename_map.items():
    destination_path = os.path.join(
        destination_root, compat.as_bytes(basename))

    # An asset with the same name may be defined as part of multiple graphs;
    # copy it only the first time, i.e. when the destination is still absent.
    if file_io.file_exists(destination_path):
      continue
    file_io.copy(source_path, destination_path)

  tf_logging.info("Assets written to: %s",
                  compat.as_text(assets_destination_dir))
def deserialize(self, encoded_accumulator):
  """Deserialize an accumulator received from 'serialize()'.

  Args:
    encoded_accumulator: json-encoded accumulator contents.

  Returns:
    A fresh accumulator populated with the decoded vocabulary counts and,
    when idf computation is enabled, with the decoded data and per-document
    counts as well.
  """
  decoded = json.loads(compat.as_text(encoded_accumulator))

  accumulator = self._create_accumulator()
  vocab_counts = dict(zip(decoded["vocab"], decoded["vocab_counts"]))
  accumulator.count_dict.update(vocab_counts)

  if not self._compute_idf:
    return accumulator

  accumulator.data = decoded["data"]
  # Every restored idf entry starts with last_doc_id set to -1.
  idf_entries = [
      {"count": count, "last_doc_id": -1} for count in decoded["idf_counts"]
  ]
  accumulator.per_doc_count_dict.update(
      dict(zip(decoded["idf_vocab"], idf_entries)))
  return accumulator
def _testOneEpoch(self, files):
  """Reads all lines of `files` once, then expects the closed queue to fail."""
  with self.test_session() as sess:
    line_reader = io_ops.TextLineReader(name="test_reader")
    filename_queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
    key, value = line_reader.read(filename_queue)

    filename_queue.enqueue_many([files]).run()
    filename_queue.close().run()

    for file_index in range(self._num_files):
      for line_index in range(self._num_lines):
        k, v = sess.run([key, value])
        # Keys are "<filename>:<1-based line number>".
        expected_key = "%s:%d" % (files[file_index], line_index + 1)
        self.assertAllEqual(expected_key, compat.as_text(k))
        self.assertAllEqual(self._LineText(file_index, line_index), v)

    # Everything has been consumed and the queue is closed.
    with self.assertRaisesOpError(
        "is closed and has insufficient elements "
        "\\(requested 1, current size 0\\)"):
      k, v = sess.run([key, value])
def close(self):
  """Closes this session.

  Calling this method frees all resources associated with the session.
  Nothing happens unless the session has been opened and not yet closed.

  Raises:
    RuntimeError: If an error occurs while closing the session.
  """
  with self._extend_lock:
    if self._opened and not self._closed:
      # The flag is set before calling into the runtime, so a failed close is
      # not attempted again by a later call.
      self._closed = True
      # Allocate the status before entering `try`: if allocation itself
      # failed inside the `try`, the `finally` clause would hit an unbound
      # `status` and mask the real error.
      status = tf_session.TF_NewStatus()
      try:
        tf_session.TF_CloseSession(self._session, status)
        if tf_session.TF_GetCode(status) != 0:
          raise RuntimeError(compat.as_text(tf_session.TF_Message(status)))
      finally:
        tf_session.TF_DeleteStatus(status)
def __init__(self, target='', graph=None, config=None):
  """Constructs a new TensorFlow session.

  Args:
    target: (Optional) The TensorFlow execution engine to connect to.
    graph: (Optional) The graph to be used. If this argument is None,
      the default graph will be used.
    config: (Optional) ConfigProto proto used to configure the session.

  Raises:
    RuntimeError: If an error occurs while creating the TensorFlow
      session.
  """
  self._graph = ops.get_default_graph() if graph is None else graph

  # Bookkeeping state; the native session is created further down.
  self._opened = False
  self._closed = False
  self._current_version = 0
  self._extend_lock = threading.Lock()
  self._target = target
  self._delete_lock = threading.Lock()
  self._dead_handles = []
  self._session = None

  session_options = tf_session.TF_NewSessionOptions(
      target=target, config=config)
  try:
    status = tf_session.TF_NewStatus()
    try:
      self._session = tf_session.TF_NewSession(session_options, status)
      error_code = tf_session.TF_GetCode(status)
      if error_code != 0:
        message = compat.as_text(tf_session.TF_Message(status))
        raise RuntimeError(message)
    finally:
      tf_session.TF_DeleteStatus(status)
  finally:
    # The options object is released whether or not creation succeeded.
    tf_session.TF_DeleteSessionOptions(session_options)
def testOneEpoch(self):
  """Reads every record of the created files once via TFRecordReader."""
  files = self._CreateFiles()
  with self.cached_session() as sess:
    record_reader = io_ops.TFRecordReader(name="test_reader")
    filename_queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
    key, value = record_reader.read(filename_queue)

    filename_queue.enqueue_many([files]).run()
    filename_queue.close().run()

    for file_index in range(self._num_files):
      for record_index in range(self._num_records):
        k, v = sess.run([key, value])
        # Only the "<filename>:" prefix of the key is checked.
        key_text = compat.as_text(k)
        self.assertTrue(key_text.startswith("%s:" % files[file_index]))
        self.assertAllEqual(self._Record(file_index, record_index), v)

    # Everything has been consumed and the queue is closed.
    with self.assertRaisesOpError(
        "is closed and has insufficient elements "
        "\\(requested 1, current size 0\\)"):
      k, v = sess.run([key, value])
def _extend_graph(self):
  """Sends graph nodes added since the last call to the runtime session.

  Does nothing when the graph version has not advanced past the version
  recorded by the previous successful extension.

  Raises:
    RuntimeError: If the runtime reports a non-zero status while extending
      the graph.
  """
  # Ensure any changes to the graph are reflected in the runtime.
  with self._extend_lock:
    if self._graph.version > self._current_version:
      graph_def = self._graph.as_graph_def(
          from_version=self._current_version)

      # Allocate the status before entering `try`: if allocation itself
      # failed inside the `try`, the `finally` clause would hit an unbound
      # `status` and mask the real error.
      status = tf_session.TF_NewStatus()
      try:
        tf_session.TF_ExtendGraph(
            self._session, graph_def.SerializeToString(), status)
        if tf_session.TF_GetCode(status) != 0:
          raise RuntimeError(compat.as_text(tf_session.TF_Message(status)))
        self._opened = True
      finally:
        tf_session.TF_DeleteStatus(status)

      # Only reached when the extension succeeded.
      self._current_version = self._graph.version
def _write_value_event(self, event):
  """Writes a serialized value event to its dump file.

  The dump path is derived from `self._dump_dir`, the device name found in
  the JSON plugin-data content of the first summary value, and that value's
  node name.

  Args:
    event: The event to dump; its first summary value is used.

  Raises:
    ValueError: If the value carries no plugin data, or if the plugin data
      content cannot be parsed as JSON.
  """
  value = event.summary.value[0]

  # Obtain the device name from the metadata.
  summary_metadata = value.metadata
  if not summary_metadata.plugin_data:
    raise ValueError("The value lacks plugin data.")

  raw_content = summary_metadata.plugin_data.content
  try:
    content = json.loads(compat.as_text(raw_content))
  except ValueError as err:
    # Report the raw payload: `content` is never bound when parsing fails, so
    # referencing it here (as the original did) raised a NameError instead of
    # the intended ValueError.
    raise ValueError("Could not parse content into JSON: %r, %r" %
                     (raw_content, err))

  device_name = content["device"]
  dump_full_path = _get_dump_file_path(
      self._dump_dir, device_name, value.node_name)
  self._try_makedirs(os.path.dirname(dump_full_path))
  with open(dump_full_path, "wb") as f:
    f.write(event.SerializeToString())
def _serialize_object_graph(saveable_view, asset_file_def_index):
  """Builds the SavedObjectGraph proto for the objects in `saveable_view`."""
  # SavedObjectGraph is similar to the TrackableObjectGraph proto in the
  # checkpoint. It will eventually go into the SavedModel.
  object_graph = saved_object_graph_pb2.SavedObjectGraph()
  saveable_view.fill_object_graph_proto(object_graph)

  structure_coder = nested_structure_coder.StructureCoder()
  for fn in saveable_view.concrete_functions:
    # Use the mapped name when one exists, the function's own name otherwise.
    fn_name = compat.as_text(fn.name)
    fn_name = saveable_view.function_name_map.get(fn_name, fn_name)
    serialized_fn = function_serialization.serialize_concrete_function(
        fn, saveable_view.captured_tensor_node_ids, structure_coder)
    if serialized_fn is None:
      continue
    object_graph.concrete_functions[fn_name].CopyFrom(serialized_fn)

  for node, node_proto in zip(saveable_view.nodes, object_graph.nodes):
    _write_object_proto(node, node_proto, asset_file_def_index,
                        saveable_view.function_name_map)
  return object_graph
def _do_call(self, fn, *args):
  """Calls `fn(*args)`, translating runtime status errors.

  When the error message names a node (per `BaseSession._NODEDEF_NAME_RE`),
  a specific exception is raised that carries the matching op and node_def
  if the graph knows the node. Otherwise the original exception is
  re-raised with its original traceback.
  """
  try:
    return fn(*args)
  except tf_session.StatusNotOK as e:
    exc_info = sys.exc_info()
    message = compat.as_text(e.error_message)
    match = BaseSession._NODEDEF_NAME_RE.search(message)
    if match is None:
      six.reraise(*exc_info)

    failed_node_name = match.group(1)
    try:
      op = self._graph.get_operation_by_name(failed_node_name)
      node_def = op.node_def
    except KeyError:
      op = None
      node_def = None
    # pylint: disable=protected-access
    raise errors._make_specific_exception(node_def, op, message, e.code)
    # pylint: enable=protected-access
def cluster_spec(self):
  """Builds a ClusterSpec from the TPU node information fetched on each call.

  The node description is requested from the service API every time this
  method runs; nothing is cached here.

  Returns:
    A ClusterSpec containing host information returned from Cloud TPUs, or
    an empty ClusterSpec when `self._shouldResolve()` is false.

  Raises:
    RuntimeError: If the provided TPU is not healthy.
  """
  if not self._shouldResolve():
    return server_lib.ClusterSpec({})

  node_path = 'projects/%s/locations/%s/nodes/%s' % (
      self._project, self._zone, compat.as_text(self._tpu))
  nodes_api = self._service.projects().locations().nodes()
  response = nodes_api.get(name=node_path).execute()

  if 'health' in response:
    if response['health'] != 'HEALTHY':
      raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu,
                                                          response['health']))

  if 'networkEndpoints' not in response:
    # Fall back to the deprecated response format
    worker_list = ['%s:%s' % (response['ipAddress'], response['port'])]
  else:
    worker_list = []
    for endpoint in response['networkEndpoints']:
      worker_list.append('%s:%s' % (endpoint['ipAddress'], endpoint['port']))

  jobs = {self._job_name: worker_list}
  if self._coordinator_address:
    jobs[self._coordinator_name] = [self._coordinator_address]

  return server_lib.ClusterSpec(jobs)
def Cleanse(obj, encoding='utf-8'):
  """Converts a Python object into a JSON-serializable equivalent.

  - Infinity/-Infinity/NaN floats become the strings 'Infinity',
    '-Infinity' and 'NaN'.
  - Byte strings are decoded to unicode strings.
  - Sets become sorted lists.
  - Tuples become lists.
  Containers are processed recursively; anything else is returned unchanged.

  Args:
    obj: Python data structure.
    encoding: Charset used to decode byte strings.

  Returns:
    Unicode JSON data structure.
  """
  if isinstance(obj, int):
    return obj

  if isinstance(obj, float):
    if obj == _INFINITY:
      return 'Infinity'
    if obj == _NEGATIVE_INFINITY:
      return '-Infinity'
    if math.isnan(obj):
      return 'NaN'
    return obj

  if isinstance(obj, bytes):
    return compat.as_text(obj, encoding)

  if isinstance(obj, (list, tuple)):
    return [Cleanse(item, encoding) for item in obj]

  if isinstance(obj, set):
    return [Cleanse(item, encoding) for item in sorted(obj)]

  if isinstance(obj, dict):
    return {
        Cleanse(key, encoding): Cleanse(val, encoding)
        for key, val in obj.items()
    }

  return obj
def testFormatMultiTensor(self):
  """Formats two 10x10 tensors into one template with two placeholders."""
  with self.test_session():
    first = array_ops.reshape(math_ops.range(100), [10, 10])
    second = first * 10
    formatted = string_ops.string_format("One: {},\nTwo: {}",
                                         (first, second))
    actual = self.evaluate(formatted)
    # Both tensors are expected to be summarized with "..." in each dimension.
    expected = ("One: [[0 1 2 ... 7 8 9]\n"
                " [10 11 12 ... 17 18 19]\n"
                " [20 21 22 ... 27 28 29]\n"
                " ...\n"
                " [70 71 72 ... 77 78 79]\n"
                " [80 81 82 ... 87 88 89]\n"
                " [90 91 92 ... 97 98 99]],\n"
                "Two: [[0 10 20 ... 70 80 90]\n"
                " [100 110 120 ... 170 180 190]\n"
                " [200 210 220 ... 270 280 290]\n"
                " ...\n"
                " [700 710 720 ... 770 780 790]\n"
                " [800 810 820 ... 870 880 890]\n"
                " [900 910 920 ... 970 980 990]]")
    self.assertEqual(compat.as_text(actual), expected)
def _TestOneEpoch(self, files, num_records, gap_bytes, encoding=None):
  """Reads `num_records` fixed-length records from each file exactly once.

  Args:
    files: Filenames to enqueue for the reader.
    num_records: Number of records expected per file.
    gap_bytes: Bytes between consecutive records; 0 disables hopping.
    encoding: Passed through to `FixedLengthRecordReader`.
  """
  if gap_bytes == 0:
    hop_bytes = 0
  else:
    hop_bytes = self._record_bytes + gap_bytes

  record_reader = io_ops.FixedLengthRecordReader(
      header_bytes=self._header_bytes,
      record_bytes=self._record_bytes,
      footer_bytes=self._footer_bytes,
      hop_bytes=hop_bytes,
      encoding=encoding,
      name="test_reader")
  filename_queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
  key, value = record_reader.read(filename_queue)

  self.evaluate(filename_queue.enqueue_many([files]))
  self.evaluate(filename_queue.close())

  for file_index in range(self._num_files):
    for record_index in range(num_records):
      k, v = self.evaluate([key, value])
      # Keys are "<filename>:<0-based record index>".
      expected_key = "%s:%d" % (files[file_index], record_index)
      self.assertAllEqual(expected_key, compat.as_text(k))
      self.assertAllEqual(self._Record(file_index, record_index), v)

  # Everything has been consumed and the queue is closed.
  with self.assertRaisesOpError("is closed and has insufficient elements "
                                "\\(requested 1, current size 0\\)"):
    k, v = self.evaluate([key, value])
def testOneEpoch(self):
  """Reads every fixed-length record of the created files exactly once."""
  files = self._CreateFiles()
  with self.test_session() as sess:
    record_reader = io_ops.FixedLengthRecordReader(
        header_bytes=self._header_bytes,
        record_bytes=self._record_bytes,
        footer_bytes=self._footer_bytes,
        hop_bytes=0,
        name="test_reader")
    filename_queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
    key, value = record_reader.read(filename_queue)

    filename_queue.enqueue_many([files]).run()
    filename_queue.close().run()

    for file_index in range(self._num_files):
      for record_index in range(self._num_records):
        k, v = sess.run([key, value])
        # Keys are "<filename>:<0-based record index>".
        expected_key = "%s:%d" % (files[file_index], record_index)
        self.assertAllEqual(expected_key, compat.as_text(k))
        self.assertAllEqual(self._Record(file_index, record_index), v)

    # Everything has been consumed and the queue is closed.
    with self.assertRaisesOpError("is closed and has insufficient elements "
                                  "\\(requested 1, current size 0\\)"):
      k, v = sess.run([key, value])
def _save_and_write_assets(self, assets_collection_to_add=None):
  """Saves asset to the meta graph and writes asset files to disk.

  The asset source paths come from `_maybe_save_assets`. Each file is copied
  into the assets directory under `self._export_dir`, which is created when
  missing. Files already present at the destination are left untouched.

  Args:
    assets_collection_to_add: The collection where the asset paths are setup.
  """
  asset_source_filepath_list = _maybe_save_assets(assets_collection_to_add)

  # Return if there are no assets to write. Compare by value: the original
  # `is 0` tested object identity, which only works by accident of
  # small-integer caching and triggers a SyntaxWarning on newer Pythons.
  if len(asset_source_filepath_list) == 0:
    tf_logging.info("No assets to write.")
    return

  assets_destination_dir = os.path.join(
      compat.as_bytes(self._export_dir),
      compat.as_bytes(constants.ASSETS_DIRECTORY))

  if not file_io.file_exists(assets_destination_dir):
    file_io.recursive_create_dir(assets_destination_dir)

  # Copy each asset from source path to destination path.
  for asset_source_filepath in asset_source_filepath_list:
    asset_source_filename = os.path.basename(asset_source_filepath)

    asset_destination_filepath = os.path.join(
        compat.as_bytes(assets_destination_dir),
        compat.as_bytes(asset_source_filename))

    # Only copy the asset file to the destination if it does not already
    # exist. This is to ensure that an asset with the same name defined as
    # part of multiple graphs is only copied the first time.
    if not file_io.file_exists(asset_destination_filepath):
      file_io.copy(asset_source_filepath, asset_destination_filepath)

  tf_logging.info("Assets written to: %s",
                  compat.as_text(assets_destination_dir))
def load_file_system_library(library_filename):
  """Loads a TensorFlow plugin, containing file system implementation.

  Pass `library_filename` to a platform-specific mechanism for dynamically
  loading a library. The rules for determining the exact location of the
  library are platform-specific and are not documented here.

  Loading a library that this function has already registered is not an
  error: when the runtime reports ALREADY_EXISTS with a "has already been
  loaded" message and the filename is in the module's library map, the call
  returns quietly.

  Args:
    library_filename: Path to the plugin.
      Relative or absolute filesystem path to a dynamic library file.

  Returns:
    None.

  Raises:
    RuntimeError: when unable to load the library. (NOTE(review): the
      concrete exception is whatever `errors._make_specific_exception`
      builds for the status code -- confirm it is a RuntimeError subclass.)
  """
  status = py_tf.TF_NewStatus()
  try:
    # The load happens inside the `try` so the status object is released
    # even if the load call itself raises; previously it would leak.
    lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
    error_code = py_tf.TF_GetCode(status)
    if error_code != 0:
      error_msg = compat.as_text(py_tf.TF_Message(status))
      with _FILE_SYSTEM_LIBRARY_MAP_LOCK:
        if (error_code == error_codes_pb2.ALREADY_EXISTS and
            'has already been loaded' in error_msg and
            library_filename in _FILE_SYSTEM_LIBRARY_MAP):
          return
      # pylint: disable=protected-access
      raise errors._make_specific_exception(None, None, error_msg, error_code)
      # pylint: enable=protected-access
  finally:
    py_tf.TF_DeleteStatus(status)

  with _FILE_SYSTEM_LIBRARY_MAP_LOCK:
    _FILE_SYSTEM_LIBRARY_MAP[library_filename] = lib_handle
def deserialize(self, encoded_accumulator):
  """Rebuilds an accumulator from the JSON payload made by 'serialize()'."""
  decoded = json.loads(compat.as_text(encoded_accumulator))
  # Each stored field is turned back into a numpy array.
  count = np.array(decoded[_COUNT_NAME])
  mean = np.array(decoded[_MEAN_NAME])
  variance = np.array(decoded[_VARIANCE_NAME])
  return self._create_accumulator(count, mean, variance)
def initialize_tpu_system(cluster_resolver=None):
  """Initialize the TPU devices.

  The resulting topology is recorded in `_INITIALIZED_TPU_SYSTEMS` under the
  resolver's TPU name. Initializing a TPU that is already recorded there
  logs a warning and proceeds.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster. When None, a
        resolver for the TPU named "" is created.

  Returns:
    The tf.tpu.Topology object for the topology of the TPU cluster.

  Raises:
    RuntimeError: If no TPU devices found for eager execution.
  """
  if cluster_resolver is None:
    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    # The TPU name must be supplied for the %s placeholder; without it the
    # message was logged with a literal "%s".
    logging.warning("TPU system %s has already been initialized. "
                    "Reinitializing the TPU can cause previously created "
                    "variables on TPU to be lost.", tpu_name)

  logging.info("Initializing the TPU system.")

  if context.executing_eagerly():
    # This function looks as it is for the following non-intuitive reasons.
    # tpu.initialize_system creates a dummy op whose sole purpose is to trigger
    # DistributedTPURewritePass. This pass actually adds real ops that
    # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
    # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
    @function.defun
    def _tpu_init_fn():
      return tpu.initialize_system()

    tpu_devices = sorted(
        [x for x in context.list_devices() if "device:TPU:" in x])

    if not tpu_devices:
      raise RuntimeError("Could not find any TPU devices")

    # Replace the remote TPU device with the remote TPU_SYSTEM system device. As
    # in the remote TPU device case, we will try to compile it instead of
    # running through optimization passes and TF Executor, but TPU_SYSTEM should
    # work.
    tpu_system_device = tpu_devices[0].replace("TPU", "TPU_SYSTEM")

    with ops.device(tpu_system_device):
      output = _tpu_init_fn()
    serialized_topology = output.numpy()
  else:
    master = cluster_resolver.master()
    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        serialized_topology = sess.run(tpu.initialize_system())

  logging.info("Finished initializing TPU system.")
  tpu_topology = topology.Topology(serialized=serialized_topology)
  _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology

  return tpu_topology
def apply_op(self, op_type_name, name=None, **keywords):
  # pylint: disable=g-doc-args
  """Add a node invoking a registered Op to a graph.

  Example usage:
     # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
     # will convert to a Tensor.
     op_def_library.apply_op("op", input1=input1, input2=input2)
     # Can specify a node name.
     op_def_library.apply_op("op", input1=input1, name="node_name")
     # Must use keyword arguments, with the names specified in the OpDef.
     op_def_library.apply_op("op", input_name=input, attr_name=attr)

  All attrs must either be inferred from an input or specified.
  (If inferred, the attr must not be specified.)  If an attr has a default
  value specified in the Op's OpDef, then you may pass None as the value
  of that attr to get the default.

  Args:
    op_type_name: string. Must match the name field of a registered Op.
    name: string. Optional name of the created op.
    **keywords: input Tensor and attr arguments specified by name,
      and optional parameters to pass when constructing the Operation.

  Returns:
    The Tensor(s) representing the output of the operation, or the Operation
    itself if there are no outputs.

  Raises:
    RuntimeError: If `op_type_name` is not registered, or the graph cannot
      be determined from the inputs.
    NotImplementedError: If the Op is deprecated as of the graph's producer
      version.
    TypeError: On missing, unexpected or wrongly-typed inputs and attrs.
    ValueError: On list lengths or attr values that violate the OpDef.
  """
  op_info = self._ops.get(op_type_name, None)
  if op_info is None:
    raise RuntimeError("Unrecognized Op name " + op_type_name)
  op_def = op_info.op_def

  # Determine the graph context.
  try:
    # Need to flatten all the arguments into a list.
    # pylint: disable=protected-access
    g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
    # pylint: enable=protected-access
  except AssertionError as e:
    # Format the exception itself: `e.message` does not exist on Python 3
    # exceptions and would turn this into an AttributeError.
    raise RuntimeError(
        "Cannot determine graph for Op '%s' due to: %s"
        % (op_type_name, e))

  # Default name if not specified.
  if name is None:
    name = op_type_name

  # Check for deprecation
  deprecation_version = op_def.deprecation.version
  if deprecation_version:
    producer = g.graph_def_versions.producer
    if producer >= deprecation_version:
      raise NotImplementedError(
          ("Op %s is not available in GraphDef version %d. "
           "It has been removed in version %d. %s.")
          % (op_type_name, producer, deprecation_version,
             op_def.deprecation.explanation))

  # Fill in the list of default types for all "type" attrs.  This
  # will be used to choose a preferred dtype to convert to in the
  # absence of input type information.
  #
  # TODO(b/31302892): Currently the defaults don't work in the right
  # way if you have two inputs, one of whose type resolution depends
  # on the other.  Handling this will require restructuring this code
  # significantly.
  default_type_attr_map = {}
  for attr_def in op_def.attr:
    if attr_def.type != "type":
      continue
    key = attr_def.name
    if attr_def.HasField("default_value"):
      default_type_attr_map[key] = dtypes.as_dtype(
          attr_def.default_value.type)

  # Requires that op_def has passed validation (using the C++
  # ValidateOpDef() from ../framework/op_def_util.h).
  attrs = {}
  inputs = []
  input_types = []
  with g.as_default(), ops.name_scope(name) as scope:

    # Perform input type inference
    inferred_from = {}
    for input_arg in op_def.input_arg:
      input_name = input_arg.name
      if input_name in keywords:
        values = keywords.pop(input_name)
      elif input_name + "_" in keywords:
        # Handle the case where the name is a keyword or built-in
        # for Python so we use the name + _ instead.
        input_name += "_"
        values = keywords.pop(input_name)
      else:
        raise TypeError("No argument for input " + input_name)

      # Goals:
      # * Convert values to Tensors if it contains constants.
      # * Verify that values is a list if that matches the input_arg's
      #   type.
      # * If the input_arg's type is determined by attrs, either set
      #   those attrs and validate those attr values are legal (if
      #   they have not yet been set) or validate the input matches
      #   the type indicated by the attrs (if they have already been
      #   inferred via an earlier input).
      # * If the input_arg has an explicit type, make sure the input
      #   conforms.

      if _IsListParameter(input_arg):
        if not _IsListValue(values):
          raise TypeError(
              "Expected list for '%s' argument to '%s' Op, not %s." %
              (input_name, op_type_name, values))
        # In cases where we expect all elements of the list to have the
        # same dtype, try to cast non-Tensor elements to that type.
        dtype = None
        default_dtype = None
        if input_arg.type != types_pb2.DT_INVALID:
          dtype = input_arg.type
        elif input_arg.number_attr:
          if input_arg.type_attr in attrs:
            dtype = attrs[input_arg.type_attr]
          else:
            for t in values:
              if isinstance(t, ops.Tensor):
                dtype = t.dtype
                break

          # dtype still not found, prefer using the default dtype
          # from the attr.
          if dtype is None and input_arg.type_attr in default_type_attr_map:
            default_dtype = default_type_attr_map[input_arg.type_attr]

        try:
          if not input_arg.is_ref and dtype:
            dtype = dtypes.as_dtype(dtype).base_dtype
          values = ops.convert_n_to_tensor(
              values,
              name=input_arg.name,
              dtype=dtype if dtype else None,
              preferred_dtype=default_dtype,
              as_ref=input_arg.is_ref)
          if input_arg.number_attr and len(
              set(v.dtype.base_dtype for v in values)) > 1:
            raise TypeError()  # All types should match.
        except (TypeError, ValueError):
          # What types does the conversion function think values have?
          observed_types = []
          for value in values:
            try:
              converted_value = ops.convert_to_tensor(
                  value, as_ref=input_arg.is_ref)
              observed_types.append(converted_value.dtype.base_dtype.name)
            except (TypeError, ValueError):
              observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
          observed = ", ".join(observed_types)

          prefix = (
              "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
              (input_name, op_type_name, observed))
          if input_arg.number_attr:
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError("%s that do not match expected type %s." %
                              (prefix, dtype.name))
            elif input_arg.type_attr in attrs:
              raise TypeError("%s that do not match type %s inferred from "
                              "earlier arguments." %
                              (prefix, dtype.name))
            else:
              raise TypeError("%s that don't all match." % prefix)
          else:
            raise TypeError("%s that are invalid." % prefix)

        types = [x.dtype for x in values]
        inputs.extend(values)
      else:
        # In cases where we have an expected type, try to convert non-Tensor
        # arguments to that type.
        dtype = None
        default_dtype = None
        if input_arg.type != types_pb2.DT_INVALID:
          dtype = input_arg.type
        elif input_arg.type_attr in attrs:
          dtype = attrs[input_arg.type_attr]
        elif input_arg.type_attr in default_type_attr_map:
          # The dtype could not be inferred solely from the inputs,
          # so we prefer the attr's default, so code that adds a new attr
          # with a default is backwards compatible.
          default_dtype = default_type_attr_map[input_arg.type_attr]

        try:
          values = ops.convert_to_tensor(
              values,
              name=input_arg.name,
              dtype=dtype,
              as_ref=input_arg.is_ref,
              preferred_dtype=default_dtype)
        except ValueError:
          # What type does convert_to_tensor think it has?
          observed = ops.convert_to_tensor(values,
                                           as_ref=input_arg.is_ref).dtype.name
          prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                    (input_name, op_type_name, observed))
          if input_arg.type != types_pb2.DT_INVALID:
            raise TypeError("%s expected type of %s." %
                            (prefix, dtypes.as_dtype(input_arg.type).name))
          else:
            # Update the maps with the default, if needed.
            k = input_arg.type_attr
            if k in default_type_attr_map:
              if k not in attrs:
                attrs[k] = default_type_attr_map[k]
                if k not in inferred_from:
                  inferred_from[k] = "Default in OpDef"

            raise TypeError(
                "%s type %s of argument '%s'." %
                (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name,
                 inferred_from[input_arg.type_attr]))

        types = [values.dtype]
        inputs.append(values)
      base_types = [x.base_dtype for x in types]

      if input_arg.number_attr:
        # <number-attr> * <type> or <number-attr> * <type-attr>
        if input_arg.number_attr in attrs:
          if len(values) != attrs[input_arg.number_attr]:
            raise ValueError(
                "List argument '%s' to '%s' Op with length %d must match "
                "length %d of argument '%s'." %
                (input_name, op_type_name, len(values),
                 attrs[input_arg.number_attr],
                 inferred_from[input_arg.number_attr]))
        else:
          attrs[input_arg.number_attr] = len(values)
          inferred_from[input_arg.number_attr] = input_name
          num_attr = _Attr(op_def, input_arg.number_attr)
          if num_attr.has_minimum and len(values) < num_attr.minimum:
            raise ValueError(
                "List argument '%s' to '%s' Op with length %d shorter "
                "than minimum length %d." %
                (input_name, op_type_name, len(values), num_attr.minimum))
        # All tensors must have the same base type.
        if any([bt != base_types[0] for bt in base_types]):
          raise TypeError(
              "All tensors passed to '%s' of '%s' Op "
              "must have the same type." %
              (input_name, op_type_name))
        if input_arg.type != types_pb2.DT_INVALID:
          # <number-attr> * <type> case
          if base_types and base_types[0] != input_arg.type:
            assert False, "Unreachable"
        elif input_arg.type_attr in attrs:
          # <number-attr> * <type-attr> case, where <type-attr> already
          # has an inferred value.
          if base_types and base_types[0] != attrs[input_arg.type_attr]:
            assert False, "Unreachable"
        else:
          # <number-attr> * <type-attr> case, where we are now setting
          # the <type-attr> based on this input
          if not base_types:
            raise TypeError(
                "Don't know how to infer type variable from empty input "
                "list passed to input '%s' of '%s' Op." %
                (input_name, op_type_name))
          attrs[input_arg.type_attr] = base_types[0]
          inferred_from[input_arg.type_attr] = input_name
          type_attr = _Attr(op_def, input_arg.type_attr)
          _SatisfiesTypeConstraint(base_types[0], type_attr)
      elif input_arg.type_attr:
        # <type-attr>
        attr_value = base_types[0]
        if input_arg.type_attr in attrs:
          if attrs[input_arg.type_attr] != attr_value:
            assert False, "Unreachable"
        else:
          for base_type in base_types:
            _SatisfiesTypeConstraint(base_type,
                                     _Attr(op_def, input_arg.type_attr))
          attrs[input_arg.type_attr] = attr_value
          inferred_from[input_arg.type_attr] = input_name
      elif input_arg.type_list_attr:
        # <type-list-attr>
        attr_value = base_types
        if input_arg.type_list_attr in attrs:
          if attrs[input_arg.type_list_attr] != attr_value:
            raise TypeError(
                "Input '%s' of '%s' Op has type list of %s that does not "
                "match type list %s of argument '%s'." %
                (input_name, op_type_name,
                 ", ".join(dtypes.as_dtype(x).name for x in attr_value),
                 ", ".join(dtypes.as_dtype(x).name
                           for x in attrs[input_arg.type_list_attr]),
                 inferred_from[input_arg.type_list_attr]))
        else:
          for base_type in base_types:
            _SatisfiesTypeConstraint(base_type,
                                     _Attr(op_def, input_arg.type_list_attr))
          attrs[input_arg.type_list_attr] = attr_value
          inferred_from[input_arg.type_list_attr] = input_name
      else:
        # single Tensor with specified type
        if base_types[0] != input_arg.type:
          assert False, "Unreachable"

      if input_arg.is_ref:
        if not all(x.is_ref_dtype for x in types):
          raise TypeError(
              "Input '%s' of '%s' Op requires l-value input" %
              (input_name, op_type_name))
        input_types.extend(types)
      else:
        input_types.extend(base_types)

    # Process remaining attrs
    for attr in op_def.attr:
      # Skip attrs that have already had their values inferred
      if attr.name in attrs:
        if attr.name in keywords:
          raise TypeError(
              "Should not specify value for inferred attr '%s'." % attr.name)
        continue
      if attr.name in keywords:
        attrs[attr.name] = keywords.pop(attr.name)
      elif attr.name + "_" in keywords:
        # Attrs whose names match Python keywords have an extra '_'
        # appended, so we must check for that as well.
        attrs[attr.name] = keywords.pop(attr.name + "_")
      else:
        raise TypeError("No argument for attr " + attr.name)

    # Convert attr values to AttrValue protos.
    attr_protos = {}
    for attr_def in op_def.attr:
      key = attr_def.name
      value = attrs[key]
      attr_value = attr_value_pb2.AttrValue()
      if attr_def.HasField("default_value") and value is None:
        attr_value.CopyFrom(attr_def.default_value)
        attr_protos[key] = attr_value
        continue
      if attr_def.type.startswith("list("):
        if not _IsListValue(value):
          raise TypeError("Expected list for attr " + key)
        if attr_def.has_minimum:
          if len(value) < attr_def.minimum:
            raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
                             "less than minimum %d." %
                             (key, op_type_name, len(value),
                              attr_def.minimum))
        attr_value.list.SetInParent()
      if attr_def.type == "string":
        attr_value.s = _MakeStr(value, key)
        if attr_def.HasField("allowed_values"):
          if attr_value.s not in attr_def.allowed_values.list.s:
            raise ValueError(
                "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                (key, op_type_name, compat.as_text(attr_value.s),
                 '", "'.join(map(compat.as_text,
                                 attr_def.allowed_values.list.s))))
      elif attr_def.type == "list(string)":
        attr_value.list.s.extend([_MakeStr(x, key) for x in value])
        if attr_def.HasField("allowed_values"):
          for x in attr_value.list.s:
            if x not in attr_def.allowed_values.list.s:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                  (key, op_type_name, compat.as_text(x),
                   '", "'.join(map(compat.as_text,
                                   attr_def.allowed_values.list.s))))
      elif attr_def.type == "int":
        attr_value.i = _MakeInt(value, key)
        if attr_def.has_minimum:
          if attr_value.i < attr_def.minimum:
            raise ValueError(
                "Attr '%s' of '%s' Op passed %d less than minimum %d." %
                (key, op_type_name, attr_value.i, attr_def.minimum))
      elif attr_def.type == "list(int)":
        attr_value.list.i.extend([_MakeInt(x, key) for x in value])
      elif attr_def.type == "float":
        attr_value.f = _MakeFloat(value, key)
      elif attr_def.type == "list(float)":
        attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
      elif attr_def.type == "bool":
        attr_value.b = _MakeBool(value, key)
      elif attr_def.type == "list(bool)":
        attr_value.list.b.extend([_MakeBool(x, key) for x in value])
      elif attr_def.type == "type":
        attr_value.type = _MakeType(value, attr_def)
      elif attr_def.type == "list(type)":
        attr_value.list.type.extend(
            [_MakeType(x, attr_def) for x in value])
      elif attr_def.type == "shape":
        attr_value.shape.CopyFrom(_MakeShape(value, key))
      elif attr_def.type == "list(shape)":
        attr_value.list.shape.extend(
            [_MakeShape(x, key) for x in value])
      elif attr_def.type == "tensor":
        attr_value.tensor.CopyFrom(_MakeTensor(value, key))
      elif attr_def.type == "list(tensor)":
        attr_value.list.tensor.extend(
            [_MakeTensor(x, key) for x in value])
      elif attr_def.type == "func":
        if isinstance(value, compat.bytes_or_text_types):
          attr_value.func.name = value
        else:
          value.add_to_graph(ops.get_default_graph())
          attr_value.func.name = value.name
      else:
        raise TypeError("Unrecognized Attr type " + attr_def.type)

      attr_protos[key] = attr_value
    del attrs  # attrs is no longer authoritative, use attr_protos instead

    # Determine output types (possibly using attrs)
    output_types = []
    output_structure = []
    for arg in op_def.output_arg:
      types = []
      if arg.number_attr:
        n = _AttrValue(attr_protos, arg.number_attr).i
        if arg.type_attr:
          types = [_AttrValue(attr_protos, arg.type_attr).type] * n
        else:
          types = [arg.type] * n
        output_structure.append(n)
      elif arg.type_attr:
        t = _AttrValue(attr_protos, arg.type_attr)
        types = [t.type]
        output_structure.append(None)
      elif arg.type_list_attr:
        t = _AttrValue(attr_protos, arg.type_list_attr)
        types = t.list.type
        output_structure.append(len(types))
      else:
        types = [arg.type]
        output_structure.append(None)
      if arg.is_ref:
        types = [dtypes.as_dtype(x).as_ref for x in types]
      output_types.extend(types)

    # Every input and attr keyword has been popped by now; leftovers are
    # arguments the OpDef does not know about.
    if keywords:
      raise TypeError("apply_op() got unexpected keyword arguments: " +
                      ", ".join(sorted(keywords.keys())))

    # NOTE(mrry): We add an explicit colocation constraint between
    # the newly created op and any of its reference-typed inputs.
    must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)
                            if arg.is_ref]
    with _MaybeColocateWith(must_colocate_inputs):
      # Add Op to graph
      op = g.create_op(op_type_name, inputs, output_types, name=scope,
                       input_types=input_types, attrs=attr_protos,
                       op_def=op_def)
      if output_structure:
        outputs = op.outputs
        res = _Restructure(ops.convert_n_to_tensor(outputs),
                           output_structure)
        if isinstance(res, list) and not res and op_def.is_stateful:
          return op
        else:
          return res
      else:
        return op
def deserialize(self, encoded_accumulator):
  """Decodes the JSON payload made by 'serialize()' back into an accumulator."""
  json_text = compat.as_text(encoded_accumulator)
  return json.loads(json_text)
def initialize_tpu_system(cluster_resolver=None):
  """Initialize the TPU devices.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
      which provides information about the TPU cluster. When None, a
      `TPUClusterResolver("")` is created and used instead.

  Returns:
    The tf.tpu.Topology object for the topology of the TPU cluster. The same
    object is also recorded in `_INITIALIZED_TPU_SYSTEMS` under the TPU name.

  Raises:
    RuntimeError: If no TPU devices found for eager execution.
    AssertionError: If `cluster_resolver` is not a `TPUClusterResolver`.
  """
  if cluster_resolver is None:
    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    # The TPU name must be supplied as the format argument; without it the
    # log line would contain a literal "%s" instead of the TPU name.
    logging.warning("TPU system %s has already been initialized. "
                    "Reinitializing the TPU can cause previously created "
                    "variables on TPU to be lost.", tpu_name)

  logging.info("Initializing the TPU system: %s", tpu_name)

  if context.executing_eagerly():
    # This function looks as it is for the following non-intuitive reasons.
    # tpu.initialize_system creates a dummy op whose sole purpose is to trigger
    # DistributedTPURewritePass. This pass actually adds real ops that
    # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
    # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
    job = None
    if tpu_name not in _LOCAL_MASTERS:
      # Explicitly place the tpu.initialize_system in the first worker to
      # avoid the output node match multiple devices error.
      job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

    @function.defun
    def _tpu_init_fn():
      return tpu.initialize_system(job=job)

    # The TPU_SYSTEM device must match the device used in tpu.initialize_system
    # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
    # devices available.
    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
      output = _tpu_init_fn()

    # Clear out the eager context caches since the memory is invalid now.
    logging.info("Clearing out eager caches")
    context.context()._clear_caches()  # pylint: disable=protected-access

    serialized_topology = output.numpy()
  else:
    # Graph mode: run the initialization op in a fresh graph and a session
    # that targets the resolver's master.
    master = cluster_resolver.master()
    cluster_spec = cluster_resolver.cluster_spec()

    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    if cluster_spec:
      session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        serialized_topology = sess.run(tpu.initialize_system())

  logging.info("Finished initializing TPU system.")
  tpu_topology = topology.Topology(serialized=serialized_topology)
  # Remember the topology so a later call can warn about re-initialization.
  _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology

  return tpu_topology
def add_meta_graph_and_variables(self,
                                 sess,
                                 tags,
                                 signature_def_map=None,
                                 assets_collection=None,
                                 legacy_init_op=None,
                                 clear_devices=False,
                                 main_op=None):
  """Saves the session's variables and adds the first meta graph.

  A new Saver is built for the variables of the provided session, the
  variables are written under the export directory, and the matching meta
  graph def is exported and tagged. The variables are assumed to have been
  initialized already.

  This must be the first (and only) variable-saving call on a given
  `SavedModelBuilder`; later meta graph defs are added through
  `add_meta_graph()`.

  Args:
    sess: The TensorFlow session from which to save the meta graph and
      variables.
    tags: The set of tags with which to save the meta graph.
    signature_def_map: The map of signature def map to add to the meta graph
      def.
    assets_collection: Assets collection to be saved with SavedModel.
    legacy_init_op: Legacy support for op or group of ops to execute after the
      restore op upon a load. Only consulted when `main_op` is None.
    clear_devices: Set to true if the device info on the default graph should
      be cleared.
    main_op: Op or group of ops to execute when the graph is loaded.

  Raises:
    AssertionError: If variables have already been saved by this builder.
  """
  if self._has_saved_variables:
    already_saved_message = (
        "Graph state including variables and assets has already been saved. "
        "Please invoke `add_meta_graph()` instead.")
    raise AssertionError(already_saved_message)

  # Check that every TensorInfo in the signature def map is properly
  # populated before anything is written.
  self._validate_signature_def_map(signature_def_map)

  # Persist asset files, if there are any.
  self._save_and_write_assets(assets_collection)

  # Make sure the variables sub-directory exists.
  export_root = compat.as_text(self._export_dir)
  variables_subdir = compat.as_text(constants.VARIABLES_DIRECTORY)
  variables_dir = os.path.join(export_root, variables_subdir)
  if not file_io.file_exists(variables_dir):
    file_io.recursive_create_dir(variables_dir)

  checkpoint_prefix = os.path.join(
      compat.as_text(variables_dir),
      compat.as_text(constants.VARIABLES_FILENAME))

  # A main op, when given, takes precedence over the legacy init op.
  if main_op is not None:
    self._add_main_op(main_op)
  else:
    self._maybe_add_legacy_init_op(legacy_init_op)

  # Build a saver that emits sharded output for every saveable in the current
  # scope.
  all_saveables = variables._all_saveable_objects()  # pylint: disable=protected-access
  saver = tf_saver.Saver(
      all_saveables,
      sharded=True,
      write_version=saver_pb2.SaverDef.V2,
      allow_empty=True)

  # Write the variables but skip the checkpoint state proto: it is not read
  # when a SavedModel is loaded, and because a SavedModel may be copied or
  # moved, omitting it keeps stale checkpoint state from lingering.
  saver.save(
      sess, checkpoint_prefix, write_meta_graph=False, write_state=False)

  # Export the meta graph def.
  # The graph most likely already held one or more Savers (for instance one
  # for a pretrained embedding and one for the model weights). The Saver
  # created above covers all variables and the checkpoint it just wrote holds
  # all their values, so it is the only Saver the SavedModel needs. The older
  # Savers are redundant at best and can break downstream graph-processing
  # tools or confuse debugging, hence `clear_extraneous_savers=True`, which
  # drops the extra SaverDefs together with their Save/Restore ops.
  meta_graph_def = saver.export_meta_graph(
      clear_devices=clear_devices, clear_extraneous_savers=True)

  # Tag the exported meta graph def and register it with the SavedModel.
  self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

  # Record that variables were saved so that another attempt fails.
  self._has_saved_variables = True
def cluster_spec(self):
  """Builds a ClusterSpec from the most recent TPU information.

  The information is fetched from the GCE APIs on every call.

  Returns:
    A ClusterSpec containing host information returned from Cloud TPUs, or
    None when the TPU name is neither resolvable nor a `grpc://` address.

  Raises:
    RuntimeError: If the provided TPU is not ready or not healthy.
  """
  ############################################################################
  # There are 5 potential cases this code must handle:
  #  1. [Normal case.] We should resolve the TPU name to a set of tasks, and
  #      a. Create a ClusterSpec that includes the coordinator job
  #      b. Create a ClusterSpec without the coordinator job.
  #  2. [GKE / No API Access.] We should not resolve the TPU name to a set of
  #     tasks and
  #      a. Create a ClusterSpec with the coordinator
  #      b. Create a ClusterSpec without the coordinator
  #  3. [Other (legacy non-gRPC).] We should return an empty ClusterSpec.
  ############################################################################

  if not self._shouldResolve():
    grpc_prefix = compat.as_bytes('grpc://')
    if not self._tpu.startswith(grpc_prefix):
      # Case 3.
      return None

    # Case 2: the TPU string already holds the endpoints; drop the protocol
    # prefix from each of them.
    separator = compat.as_bytes(_ENDPOINTS_SEPARATOR)
    hosts = [
        endpoint[len(grpc_prefix):] for endpoint in self._tpu.split(separator)
    ]
    jobs = {self._job_name: hosts}
  else:
    # Case 1: look the node up through the Cloud TPU API.
    tpu_name = compat.as_text(self._tpu)
    node_path = 'projects/%s/locations/%s/nodes/%s' % (
        self._project, self._zone, tpu_name)
    nodes_api = self._service.projects().locations().nodes()
    response = nodes_api.get(name=node_path).execute()

    if 'state' in response and response['state'] != 'READY':
      raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
                         (tpu_name, response['state']))

    if 'health' in response and response['health'] != 'HEALTHY':
      raise RuntimeError('TPU "%s" is unhealthy: "%s"' %
                         (tpu_name, response['health']))

    if 'networkEndpoints' not in response:
      # Fall back to the deprecated response format
      workers = ['%s:%s' % (response['ipAddress'], response['port'])]
    else:
      workers = []
      for endpoint in response['networkEndpoints']:
        workers.append('%s:%s' % (endpoint['ipAddress'], endpoint['port']))

    jobs = {self._job_name: workers}

  if self._coordinator_address:
    # {1, 2}.a
    jobs[self._coordinator_name] = [self._coordinator_address]

  return server_lib.ClusterSpec(jobs)
def cluster_spec(self):
  """Builds a ClusterSpec from the most recent TPU information.

  The information is fetched from the GCE APIs on every call.

  Returns:
    A ClusterSpec containing host information returned from Cloud TPUs, or
    None when the TPU is not resolved and `rpc_layer` is None.

  Raises:
    RuntimeError: If the TPU metadata reports a state other than 'READY'.
  """
  ############################################################################
  # There are 5 potential cases this code must handle:
  #  1. [Normal case.] We should resolve the TPU name to a set of tasks, and
  #      a. Create a ClusterSpec that includes the coordinator job
  #      b. Create a ClusterSpec without the coordinator job.
  #  2. [GKE / No API Access.] We should not resolve the TPU name to a set of
  #     tasks and
  #      a. Create a ClusterSpec with the coordinator
  #      b. Create a ClusterSpec without the coordinator
  #  3. [Other (legacy non-gRPC).] We should return an empty ClusterSpec.
  ############################################################################

  if not self._should_resolve():
    if self.rpc_layer is None:
      # Case 3.
      return None

    # Case 2.
    hosts = []
    for endpoint in compat.as_text(self._tpu).split(_ENDPOINTS_SEPARATOR):
      # The GKE environment variable handed to us embeds the protocol string,
      # which must be stripped before the address goes into the ClusterSpec.
      has_protocol = (self.rpc_layer is not None and
                      endpoint.startswith(self.rpc_layer + '://'))
      hosts.append(
          endpoint[len(self.rpc_layer + '://'):] if has_protocol else endpoint)
    jobs = {self.task_type: hosts}
  else:
    # Case 1.
    metadata = self._fetch_cloud_tpu_metadata()  # pylint: disable=protected-access

    if 'state' in metadata and metadata['state'] != 'READY':
      raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
                         (compat.as_text(self._tpu), metadata['state']))

    if 'networkEndpoints' not in metadata:
      # Fall back to the deprecated response format
      workers = ['%s:%s' % (metadata['ipAddress'], metadata['port'])]
    else:
      workers = []
      for endpoint in metadata['networkEndpoints']:
        workers.append('%s:%s' % (endpoint['ipAddress'], endpoint['port']))

    jobs = {self.task_type: workers}

  if self._coordinator_address:
    # {1, 2}.a
    jobs[self._coordinator_name] = [self._coordinator_address]

  return server_lib.ClusterSpec(jobs)
def get_variables_dir(export_dir):
  """Builds the path of the variables sub-directory of a SavedModel.

  Args:
    export_dir: Root directory of the SavedModel.

  Returns:
    `export_dir` joined with `constants.VARIABLES_DIRECTORY`, as text.
  """
  root = compat.as_text(export_dir)
  subdirectory = compat.as_text(constants.VARIABLES_DIRECTORY)
  return os.path.join(root, subdirectory)
def get_variables_path(export_dir):
  """Builds the variables path, used as the prefix for checkpoint files.

  Args:
    export_dir: Root directory of the SavedModel.

  Returns:
    The variables directory of `export_dir` joined with
    `constants.VARIABLES_FILENAME`, as text.
  """
  variables_dir = compat.as_text(get_variables_dir(export_dir))
  filename = compat.as_text(constants.VARIABLES_FILENAME)
  return os.path.join(variables_dir, filename)
def get_assets_dir(export_dir):
  """Builds the path of the assets directory of a SavedModel.

  Args:
    export_dir: Root directory of the SavedModel.

  Returns:
    `export_dir` joined with `constants.ASSETS_DIRECTORY`, as text.
  """
  root = compat.as_text(export_dir)
  subdirectory = compat.as_text(constants.ASSETS_DIRECTORY)
  return os.path.join(root, subdirectory)
def get_debug_dir(export_dir):
  """Builds the path of the debug sub-directory of a SavedModel.

  Args:
    export_dir: Root directory of the SavedModel.

  Returns:
    `export_dir` joined with `constants.DEBUG_DIRECTORY`, as text.
  """
  root = compat.as_text(export_dir)
  subdirectory = compat.as_text(constants.DEBUG_DIRECTORY)
  return os.path.join(root, subdirectory)
def testFormatNoTensor(self):
  """Formatting a template with no placeholders yields the template text."""
  with self.test_session():
    expected = "No tensor."
    # An empty tuple of inputs: there is nothing to substitute.
    formatted = string_ops.string_format("No tensor.", ())
    actual = compat.as_text(self.evaluate(formatted))
    self.assertEqual(actual, expected)