def _tito_out(tensor_in):
    """Freeze ``tito_in`` from ``checkpoint`` and splice it onto ``tensor_in``.

    The subgraph produced by ``tito_in`` is rebuilt in a scratch graph on a
    placeholder that mirrors ``tensor_in``. Its variables are restored from a
    local copy of ``checkpoint`` and converted to constants. The frozen graph
    is then imported into the current default graph, with the placeholder
    mapped to ``tensor_in``.

    Free variables read from the enclosing scope: ``tmp_dir``, ``tito_in``,
    ``exclude`` and ``checkpoint``.

    Args:
      tensor_in: Tensor that replaces the placeholder input of the frozen
        subgraph.

    Returns:
      The imported tensor corresponding to ``tito_in``'s output.
    """
    # Use the caller-provided scratch dir, or a private one removed below.
    checkpoint_dir = tmp_dir
    if tmp_dir is None:
        checkpoint_dir = tempfile.mkdtemp()
    checkpoint_tmp = os.path.join(checkpoint_dir, 'checkpoint')
    try:
        g = tf.Graph()
        with g.as_default():
            si = tf.placeholder(dtype=tensor_in.dtype, shape=tensor_in.shape,
                                name=tensor_in.op.name)
            so = tito_in(si)
            all_vars = tf.contrib.slim.get_variables_to_restore(exclude=exclude)
            saver = tf.train.Saver(all_vars)
            # Downloading the checkpoint from GCS to local speeds up
            # saver.restore() a lot. Checkpoints are binary, so copy raw bytes
            # rather than decoding them as text.
            with file_io.FileIO(checkpoint, 'rb') as f_in, \
                    file_io.FileIO(checkpoint_tmp, 'wb') as f_out:
                f_out.write(f_in.read())
            with tf.Session() as sess:
                saver.restore(sess, checkpoint_tmp)
                output_graph_def = tf.graph_util.convert_variables_to_constants(
                    sess, g.as_graph_def(), [so.op.name])
    finally:
        # Clean up the local copy even when building or restoring fails.
        if file_io.file_exists(checkpoint_tmp):
            file_io.delete_file(checkpoint_tmp)
        if tmp_dir is None:
            shutil.rmtree(checkpoint_dir, ignore_errors=True)
    tensors_out = tf.import_graph_def(output_graph_def,
                                      input_map={si.name: tensor_in},
                                      return_elements=[so.name])
    return tensors_out[0]
def create_object_test():
    """Verify file_io's object manipulation methods against a GCS bucket.

    Creates a uniquely named directory under ``FLAGS.gcs_bucket_url`` and
    writes one file into it. Checks that a glob matches exactly that file,
    then deletes the file and the directory.
    """
    timestamp_ms = int(round(time.time() * 1000))
    test_dir = "%s/tf_gcs_test_%s" % (FLAGS.gcs_bucket_url, timestamp_ms)
    print("Creating dir %s." % test_dir)
    file_io.create_dir(test_dir)

    # Create a file in this directory.
    test_file = "%s/test_file.txt" % test_dir
    print("Creating file %s." % test_file)
    file_io.write_string_to_file(test_file, "test file creation.")

    glob_pattern = "%s/test_file*.txt" % test_dir
    print("Getting files matching pattern %s." % glob_pattern)
    matches = file_io.get_matching_files(glob_pattern)
    print(matches)
    assert len(matches) == 1
    assert matches[0] == test_file

    # Clean up the test file, then the directory.
    print("Deleting file %s." % test_file)
    file_io.delete_file(test_file)
    print("Deleting directory %s." % test_dir)
    file_io.delete_recursively(test_dir)
def _sync(self):
    """Bring the FAISS index in line with the ids from ``self._get_ids()``.

    Reads the previously added ids from ``models/added_ids.txt`` and returns
    early if there are none. Otherwise it:
      * removes ids that are no longer targeted, in batches of 20000;
      * adds embeddings for newly targeted ids, in batches of 20000;
      * deletes the added-ids file and logs the index size.
    """
    added_ids_filepath = 'models/added_ids.txt'
    added_ids = set(self._read_added_ids(added_ids_filepath))
    if not added_ids:
        return

    target_ids = set(self._get_ids())
    logging.info('added_ids count: %d', len(added_ids))
    logging.info('target_ids count: %d', len(target_ids))

    stale_ids = list(added_ids - target_ids)
    fresh_ids = list(target_ids - added_ids)
    logging.info('remove_ids count: %d', len(stale_ids))
    logging.info('add_ids count: %d', len(fresh_ids))

    if stale_ids:
        for batch in chunks(stale_ids, 20000):
            self.faiss_index.remove_ids(np.array(batch, dtype=np.int64))
            logging.info("removed")

    if fresh_ids:
        for batch in chunks(fresh_ids, 20000):
            started_at = time.time()
            filepaths = [self._get_filepath(item_id) for item_id in batch]
            xb = self._path_to_xb(filepaths)
            self.faiss_index.add(xb, np.array(batch, dtype=np.int64))
            logging.info("%d embeddings added %.3f s",
                         xb.shape[0], time.time() - started_at)

    file_io.delete_file(added_ids_filepath)
    logging.info("Synced. ntotal: %d", self.faiss_index.ntotal())
def testFileWrite(self):
    """write_string_to_file creates a file whose contents read back intact."""
    target = os.path.join(self.get_temp_dir(), "temp_file")
    file_io.write_string_to_file(target, "testing")
    self.assertTrue(file_io.file_exists(target))
    self.assertEqual(b"testing", file_io.read_file_to_string(target))
    file_io.delete_file(target)
def _v1_asset_saved_model(self):
    """Export a TF1 SavedModel whose lookup table reads a vocab asset file.

    Writes a three-line vocab file and builds a HashTable from it that maps
    each line to its line number (default -1). Saves the graph with
    ``simple_save`` under a unique directory, then deletes the original vocab
    file.

    Returns:
      Path of the exported SavedModel directory.
    """
    graph = ops.Graph()
    vocab_path = os.path.join(self.get_temp_dir(), "vocab.txt")
    with open(vocab_path, "w") as vocab_file:
        vocab_file.write("alpha\nbeta\ngamma\n")

    with graph.as_default():
        table_init = lookup_ops.TextFileInitializer(
            vocab_path,
            key_dtype=dtypes.string,
            key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
            value_dtype=dtypes.int64,
            value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
        table = lookup_ops.HashTable(table_init, default_value=-1)
        keys = array_ops.placeholder(shape=None, dtype=dtypes.string, name="in")
        values = table.lookup(keys, name="out")
        with session_lib.Session() as sess:
            sess.run([table.initializer])
            export_dir = os.path.join(
                self.get_temp_dir(), "saved_model", str(ops.uid()))
            simple_save.simple_save(
                sess,
                export_dir,
                inputs={"start": keys},
                outputs={"output": values},
                legacy_init_op=table.initializer)

    file_io.delete_file(vocab_path)
    return export_dir
def test_libsvm_generator(self):
    """libsvm_generator yields the expected labels and dense features."""
    data_file = os.path.join(test.get_temp_dir(), "test_libvsvm.txt")
    if file_io.file_exists(data_file):
        file_io.delete_file(data_file)
    with open(data_file, "wt") as out:
        out.write(LIBSVM_DATA)

    expected_features = {
        "1": np.array([[0.1], [0.], [0.12], [0.]], dtype=np.float32),
        "2": np.array([[0.], [0.13], [0.], [0.]], dtype=np.float32),
        "3": np.array([[0.3], [0.], [0.], [0.]], dtype=np.float32),
        "4": np.array([[-0.4], [0.], [0.24], [0.]], dtype=np.float32),
        "5": np.array([[0.], [0.], [0.5], [0.]], dtype=np.float32),
    }
    generate = data_lib.libsvm_generator(data_file, 5, 4, seed=10)
    # Only the first yielded (features, labels) pair is checked.
    for features, labels in generate():
        self.assertAllEqual(labels, [2.0, 0., 1.0, -1.0])
        self.assertAllEqual(sorted(features.keys()),
                            sorted(expected_features.keys()))
        for name in sorted(expected_features):
            self.assertAllEqual(features.get(name), expected_features.get(name))
        break
def _v1_asset_saved_model(self):
    """Export a TF1 SavedModel whose lookup table reads a vocab asset file.

    Writes a three-line vocab file and builds a HashTable from it that maps
    each line to its line number (default -1). Saves the graph with
    ``simple_save`` under a unique directory, then deletes the original vocab
    file.

    Returns:
      Path of the exported SavedModel directory.
    """
    graph = ops.Graph()
    vocab_path = os.path.join(self.get_temp_dir(), "vocab.txt")
    with open(vocab_path, "w") as vocab_file:
        vocab_file.write("alpha\nbeta\ngamma\n")

    with graph.as_default():
        table_init = lookup_ops.TextFileInitializer(
            vocab_path,
            key_dtype=dtypes.string,
            key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
            value_dtype=dtypes.int64,
            value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
        table = lookup_ops.HashTable(table_init, default_value=-1)
        keys = array_ops.placeholder(shape=None, dtype=dtypes.string, name="in")
        values = table.lookup(keys, name="out")
        with session_lib.Session() as sess:
            sess.run([table.initializer])
            export_dir = os.path.join(
                self.get_temp_dir(), "saved_model", str(ops.uid()))
            simple_save.simple_save(
                sess,
                export_dir,
                inputs={"start": keys},
                outputs={"output": values},
                legacy_init_op=table.initializer)

    file_io.delete_file(vocab_path)
    return export_dir
def test_table(self):
    """A saved lookup table keeps working after its vocab and directory move."""
    table_init = lookup_ops.TextFileInitializer(
        self._vocab_path,
        key_dtype=dtypes.string,
        key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
        value_dtype=dtypes.int64,
        value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
    table = lookup_ops.HashTable(table_init, default_value=-1)
    root = util.Checkpoint(table=table)
    root.table_user = def_function.function(
        root.table.lookup,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.string)])
    self.assertEqual(
        2, self.evaluate(root.table_user(constant_op.constant("gamma"))))

    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(root, save_dir)
    # The source vocab is removed; the loaded model must not depend on it.
    file_io.delete_file(self._vocab_path)
    self.assertAllClose(
        {"output_0": [2, 0]},
        _import_and_infer(save_dir, {"keys": ["gamma", "alpha"]}))

    # Asset paths should track the location the SavedModel is loaded from.
    moved_dir = os.path.join(self.get_temp_dir(), "second_dir")
    file_io.rename(save_dir, moved_dir)
    self.assertAllClose(
        {"output_0": [2, 1]},
        _import_and_infer(moved_dir, {"keys": ["gamma", "beta"]}))
def test_table(self):
    """A saved lookup table keeps working after its vocab and directory move."""
    table_init = lookup_ops.TextFileInitializer(
        self._vocab_path,
        key_dtype=dtypes.string,
        key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
        value_dtype=dtypes.int64,
        value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
    table = lookup_ops.HashTable(table_init, default_value=-1)
    root = util.Checkpoint(table=table)
    root.table_user = def_function.function(
        root.table.lookup,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.string)])
    self.assertEqual(
        2, self.evaluate(root.table_user(constant_op.constant("gamma"))))

    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(root, save_dir)
    # The source vocab is removed; the loaded model must not depend on it.
    file_io.delete_file(self._vocab_path)
    self.assertAllClose(
        {"output_0": [2, 0]},
        _import_and_infer(save_dir, {"keys": ["gamma", "alpha"]}))

    # Asset paths should track the location the SavedModel is loaded from.
    moved_dir = os.path.join(self.get_temp_dir(), "second_dir")
    file_io.rename(save_dir, moved_dir)
    self.assertAllClose(
        {"output_0": [2, 1]},
        _import_and_infer(moved_dir, {"keys": ["gamma", "beta"]}))
def test_plot_model_with_wrapped_layers_and_models(self):
    """plot_model handles layers and models nested inside wrappers."""
    inputs = keras.Input(shape=(None, 3))
    x = keras.layers.LSTM(6, return_sequences=True, name='lstm')(inputs)
    # Add layer inside a Wrapper
    x = keras.layers.Bidirectional(
        keras.layers.LSTM(16, return_sequences=True, name='bilstm'))(x)
    # Add model inside a Wrapper
    submodel = keras.Sequential(
        [keras.layers.Dense(32, name='dense', input_shape=(None, 32))])
    x = keras.layers.TimeDistributed(submodel)(x)
    # Add shared submodel
    outputs = submodel(x)
    model = keras.Model(inputs, outputs)

    image_path = 'model_2.png'
    try:
        vis_utils.plot_model(
            model, to_file=image_path, show_shapes=True, expand_nested=True)
        self.assertTrue(file_io.file_exists(image_path))
        file_io.delete_file(image_path)
    except ImportError:
        # Presumably raised when optional plotting dependencies are missing.
        pass
def _tito_out(tensor_in):
    """Freeze ``tito_in`` from ``checkpoint`` and splice it onto ``tensor_in``.

    The subgraph produced by ``tito_in`` is rebuilt in a scratch graph on a
    placeholder that mirrors ``tensor_in``. Its variables are restored from a
    local copy of ``checkpoint`` and converted to constants. The frozen graph
    is then imported into the current default graph, with the placeholder
    mapped to ``tensor_in``.

    Free variables read from the enclosing scope: ``tmp_dir``, ``tito_in``,
    ``exclude`` and ``checkpoint``.

    Args:
      tensor_in: Tensor that replaces the placeholder input of the frozen
        subgraph.

    Returns:
      The imported tensor corresponding to ``tito_in``'s output.
    """
    # Use the caller-provided scratch dir, or a private one removed below.
    checkpoint_dir = tmp_dir
    if tmp_dir is None:
        checkpoint_dir = tempfile.mkdtemp()
    checkpoint_tmp = os.path.join(checkpoint_dir, 'checkpoint')
    try:
        g = tf.Graph()
        with g.as_default():
            si = tf.placeholder(dtype=tensor_in.dtype, shape=tensor_in.shape,
                                name=tensor_in.op.name)
            so = tito_in(si)
            all_vars = tf.contrib.slim.get_variables_to_restore(exclude=exclude)
            saver = tf.train.Saver(all_vars)
            # Downloading the checkpoint from GCS to local speeds up
            # saver.restore() a lot. Checkpoints are binary, so copy raw bytes
            # rather than decoding them as text.
            with file_io.FileIO(checkpoint, 'rb') as f_in, \
                    file_io.FileIO(checkpoint_tmp, 'wb') as f_out:
                f_out.write(f_in.read())
            with tf.Session() as sess:
                saver.restore(sess, checkpoint_tmp)
                output_graph_def = tf.graph_util.convert_variables_to_constants(
                    sess, g.as_graph_def(), [so.op.name])
    finally:
        # Clean up the local copy even when building or restoring fails.
        if file_io.file_exists(checkpoint_tmp):
            file_io.delete_file(checkpoint_tmp)
        if tmp_dir is None:
            shutil.rmtree(checkpoint_dir, ignore_errors=True)
    tensors_out = tf.import_graph_def(output_graph_def,
                                      input_map={si.name: tensor_in},
                                      return_elements=[so.name])
    return tensors_out[0]
def end(self, session):
    """After an EVAL run, delete eval artifacts of one older checkpoint.

    Reads the ``checkpoint`` index file in the writer's logdir, skips its
    first line, and parses the remaining lines into global steps. Only steps
    earlier than the current global step are kept. If there are more than
    ``self.keep_eval_results_max_epoch`` of them, the one at index
    ``-keep_eval_results_max_epoch`` has its eval tfrecords and its
    alignment/mel PNGs deleted. Does nothing outside EVAL mode.
    """
    if self.mode != tf.estimator.ModeKeys.EVAL:
        return

    current_global_step = session.run(self.global_step_tensor)
    logdir = self.writer.get_logdir()
    with open(os.path.join(logdir, "checkpoint")) as index_file:
        index_lines = [line for line in index_file]
    parsed_steps = [self.extract_global_step(line) for line in index_lines[1:]]
    earlier_steps = [step for step in parsed_steps if step < current_global_step]
    if len(earlier_steps) <= self.keep_eval_results_max_epoch:
        return

    step_to_delete = earlier_steps[-self.keep_eval_results_max_epoch]
    tf.logging.info("Deleting %s results at the step %d", self.mode,
                    step_to_delete)
    filespecs = [
        os.path.join(logdir, template.format(step_to_delete))
        for template in ("eval_result_step{:09d}_*.tfrecord",
                         "alignment_eval_result_step{:09d}_*.png",
                         "mel_eval_result_step{:09d}_*.png")
    ]
    for pathname in tf.gfile.Glob(filespecs):
        file_io.delete_file(pathname)
def testRename(self):
    """rename moves a file: the destination exists and the source is gone."""
    tmp_dir = self.get_temp_dir()
    source = os.path.join(tmp_dir, "temp_file")
    destination = os.path.join(tmp_dir, "rename_file")
    file_io.write_string_to_file(source, "testing")
    file_io.rename(source, destination)
    self.assertTrue(file_io.file_exists(destination))
    self.assertFalse(file_io.file_exists(source))
    file_io.delete_file(destination)
def remove(id, cloud_path=None):
    """Delete the embedding file for ``id`` and, optionally, its cloud copy.

    The local path is ``"<emb_path>/<id_to_path(id)>"``. When ``cloud_path``
    is truthy, ``"<cloud_path>/<local path>"`` is deleted first if it exists.
    The local file is deleted only if it exists.

    Returns:
      The result of ``file_io.delete_file`` for the local file, or ``None``
      when the local file does not exist.
    """
    local_path = "%s/%s" % (emb_path, id_to_path(id))
    if cloud_path:
        remote_path = '%s/%s' % (cloud_path, local_path)
        if file_io.file_exists(remote_path):
            file_io.delete_file(remote_path)
    if not file_io.file_exists(local_path):
        return None
    return file_io.delete_file(local_path)
def testCopyOverwriteFalse(self):
    """copy(overwrite=False) onto an existing file raises AlreadyExistsError."""
    tmp_dir = self.get_temp_dir()
    source = os.path.join(tmp_dir, "temp_file")
    existing = os.path.join(tmp_dir, "copy_file")
    file_io.write_string_to_file(source, "testing")
    file_io.write_string_to_file(existing, "copy")
    with self.assertRaises(errors.AlreadyExistsError):
        file_io.copy(source, existing, overwrite=False)
    for path in (source, existing):
        file_io.delete_file(path)
def testCopy(self):
    """copy creates the destination file with the source's contents."""
    file_path = os.path.join(self.get_temp_dir(), "temp_file")
    file_io.write_string_to_file(file_path, "testing")
    copy_path = os.path.join(self.get_temp_dir(), "copy_file")
    file_io.copy(file_path, copy_path)
    self.assertTrue(file_io.file_exists(copy_path))
    # Verify the copy's contents; the original checked the source file,
    # which proves nothing about the copy.
    self.assertEqual(b"testing", file_io.read_file_to_string(copy_path))
    file_io.delete_file(file_path)
    file_io.delete_file(copy_path)
def _delete_file_if_exists(filespec):
    """Deletes files matching `filespec`.

    If a matched file is gone by the time it is deleted (NotFoundError), a
    warning is logged and the remaining matches are still processed.
    """
    matches = file_io.get_matching_files(filespec)
    for match in matches:
        try:
            file_io.delete_file(match)
        except errors.NotFoundError:
            logging.warning(
                "Hit NotFoundError when deleting '%s', possibly because another "
                "process/thread is also deleting/moving the same file", match)
def delete_file(self, filename):
    """Proxy for tensorflow.python.lib.io.file_io.delete_file function.

    When ``self.mock_gcs`` is false, the real ``tf_file_io.delete_file`` is
    used. Otherwise paths starting with ``self._gcs_prefix`` are popped from
    the in-memory ``self.local_objects`` dict (KeyError if absent), and any
    other path is removed from the local filesystem with ``os.remove``.
    """
    if self.mock_gcs:
        if filename.startswith(self._gcs_prefix):
            self.local_objects.pop(filename)
        else:
            os.remove(filename)
        return
    tf_file_io.delete_file(filename)
def save(model, params):
    """Save ``model`` as a local HDF5 file and copy it to ``params.model_filename``.

    The local file is named ``model-<params.current_date>.hdf5`` and is
    deleted after a successful copy.
    """
    basename = "model-{}.hdf5".format(params.current_date)
    model.save(basename)
    # HDF5 is a binary format: copy raw bytes. Text mode would try to decode
    # them and fail or corrupt the model.
    with open(basename, "rb") as local_file:
        with file_io.FileIO(params.model_filename, 'wb') as remote_file:
            remote_file.write(local_file.read())
    file_io.delete_file(basename)
    print("model saved to GS: {}".format(params.model_filename))
def testRenameOverwriteFalse(self):
    """rename(overwrite=False) onto an existing file fails and changes nothing."""
    tmp_dir = self.get_temp_dir()
    source = os.path.join(tmp_dir, "temp_file")
    existing = os.path.join(tmp_dir, "rename_file")
    file_io.write_string_to_file(source, "testing")
    file_io.write_string_to_file(existing, "rename")
    with self.assertRaises(errors.AlreadyExistsError):
        file_io.rename(source, existing, overwrite=False)
    # Both files must still be present after the failed rename.
    self.assertTrue(file_io.file_exists(existing))
    self.assertTrue(file_io.file_exists(source))
    for path in (existing, source):
        file_io.delete_file(path)
def testAtomicWriteStringToFileOverwriteFalse(self):
    """Non-overwriting atomic writes fail on an existing file, succeed otherwise."""
    target = os.path.join(self._base_dir, "temp_file")
    file_io.atomic_write_string_to_file(target, "old", overwrite=False)
    # A second non-overwriting write must fail and keep the old contents.
    with self.assertRaises(errors.AlreadyExistsError):
        file_io.atomic_write_string_to_file(target, "new", overwrite=False)
    self.assertEqual("old", file_io.read_file_to_string(target))

    # Once the file is removed, the same call succeeds.
    file_io.delete_file(target)
    file_io.atomic_write_string_to_file(target, "new", overwrite=False)
    self.assertEqual("new", file_io.read_file_to_string(target))
def Remove(self, request, context):
    """Remove ``request.id`` from the index and delete its embedding file.

    The id is always passed to ``faiss_index.remove_ids``. The response
    message reports whether an embedding file existed and was deleted.
    """
    logging.debug('remove - id: %d', request.id)
    self.faiss_index.remove_ids(np.array([request.id], dtype=np.int64))
    filepath = self._get_filepath(request.id)
    if not file_io.file_exists(filepath):
        return pb2.SimpleReponse(message='Not existed, %s!' % request.id)
    file_io.delete_file(filepath)
    return pb2.SimpleReponse(message='Removed, %s!' % request.id)
def load_weights(self, path):
    """Load model weights from a local path or a ``gs://`` URI, then compile.

    For ``gs://`` paths the weights are downloaded to a temporary local file
    via ``download_weight``. That file is deleted afterwards, even if loading
    fails.
    """
    if path.startswith("gs://"):
        # Imported lazily so the GCS client is only needed for gs:// paths.
        from google.cloud import storage
        bucket_name, sub_folder = split_bucket(path)
        storage_client = storage.Client()
        tmp_file = download_weight(bucket_name, sub_folder, storage_client)
        try:
            # load_weights requires a local filesystem path, hence the download.
            self.model.load_weights(tmp_file)
        finally:
            file_io.delete_file(tmp_file)
    else:
        self.model.load_weights(path)
    self.compile()
def _fetch_embedding(self, emb_filepath):
    """Read a float32 embedding from ``emb_filepath`` and reshape it to ``self.SHAPE``.

    Returns:
      The reshaped embedding array. Returns ``None`` when the file is empty;
      in that case the empty file is also deleted.

    Raises:
      ValueError: if the contents cannot be parsed or reshaped for any other
        reason. Every ValueError is logged and counted in ``error_count``
        first.
    """
    try:
        embedding = np.frombuffer(
            file_io.read_file_to_string(emb_filepath), dtype=np.float32)
        embedding = embedding.reshape(self.SHAPE)
    except ValueError as e:
        logging.warning('Could not load an embedding file from %s: %s',
                        emb_filepath, str(e))
        error_count.inc()
        # Python 3 exceptions have no `.message`; use str(e). An empty file
        # yields a size-0 array whose reshape fails with this message.
        if str(e).startswith('cannot reshape array of size 0 into'):
            file_io.delete_file(emb_filepath)
            return None
        raise
    return embedding
def test_plot_model_rnn(self):
    """plot_model writes an image for an LSTM + TimeDistributed(Dense) model."""
    lstm = keras.layers.LSTM(
        16, return_sequences=True, input_shape=(2, 3), name='lstm')
    per_step_dense = keras.layers.TimeDistributed(
        keras.layers.Dense(5, name='dense'))
    model = keras.Sequential()
    model.add(lstm)
    model.add(per_step_dense)

    image_path = 'model_2.png'
    try:
        vis_utils.plot_model(model, to_file=image_path, show_shapes=True)
        self.assertTrue(file_io.file_exists(image_path))
        file_io.delete_file(image_path)
    except ImportError:
        # Presumably raised when optional plotting dependencies are missing.
        pass
def testGetMatchingFiles(self):
    """get_matching_files returns every file that matches a glob."""
    dir_path = os.path.join(self.get_temp_dir(), "temp_dir")
    file_io.create_dir(dir_path)
    paths = [os.path.join(dir_path, name)
             for name in ("file1.txt", "file2.txt", "file3.txt")]
    for path in paths:
        file_io.write_string_to_file(path, "testing")
    self.assertItemsEqual(
        file_io.get_matching_files(os.path.join(dir_path, "file*.txt")), paths)
    for path in paths:
        file_io.delete_file(path)
def test_plot_model_cnn(self):
    """plot_model writes an image for a small Conv2D -> Flatten -> Dense model."""
    layers = [
        keras.layers.Conv2D(
            filters=2, kernel_size=(2, 3), input_shape=(3, 5, 5), name='conv'),
        keras.layers.Flatten(name='flat'),
        keras.layers.Dense(5, name='dense'),
    ]
    model = keras.Sequential()
    for layer in layers:
        model.add(layer)

    image_path = 'model_1.png'
    try:
        vis_utils.plot_model(model, to_file=image_path, show_shapes=True)
        self.assertTrue(file_io.file_exists(image_path))
        file_io.delete_file(image_path)
    except ImportError:
        # Presumably raised when optional plotting dependencies are missing.
        pass
def test_read_batched_sequence_example_dataset(self, sloppy_ordering):
    """read_batched_sequence_example_dataset yields correctly shaped batches."""
    # Write 200 serialized protos (the two fixtures alternating) to a
    # TFRecord file in a temp folder.
    records = [
        SEQ_EXAMPLE_PROTO_1.SerializeToString(),
        SEQ_EXAMPLE_PROTO_2.SerializeToString(),
    ] * 100
    data_file = os.path.join(test.get_temp_dir(),
                             "test_sequence_example.tfrecord")
    if file_io.file_exists(data_file):
        file_io.delete_file(data_file)
    with tf_record.TFRecordWriter(data_file) as writer:
        for record in records:
            writer.write(record)

    dataset = data_lib.read_batched_sequence_example_dataset(
        file_pattern=data_file,
        batch_size=2,
        list_size=2,
        context_feature_spec=CONTEXT_FEATURE_SPEC,
        example_feature_spec=EXAMPLE_FEATURE_SPEC,
        reader=readers.TFRecordDataset,
        shuffle=False,
        sloppy_ordering=sloppy_ordering)
    features = dataset.make_one_shot_iterator().get_next()
    self.assertAllEqual(sorted(features),
                        ["query_length", "unigrams", "utility"])
    # Static shapes of the dense tensors.
    self.assertAllEqual([2, 1],
                        features["query_length"].get_shape().as_list())
    self.assertAllEqual([2, 2, 1], features["utility"].get_shape().as_list())

    with session.Session() as sess:
        sess.run(variables.local_variables_initializer())
        queue_runner.start_queue_runners()
        values = sess.run(features)
        unigrams = values["unigrams"]
        # dense_shape, indices and values of the SparseTensor.
        self.assertAllEqual(unigrams.dense_shape, [2, 2, 3])
        self.assertAllEqual(
            unigrams.indices,
            [[0, 0, 0], [0, 1, 0], [0, 1, 1], [0, 1, 2], [1, 0, 0]])
        self.assertAllEqual(
            unigrams.values,
            [b"tensorflow", b"learning", b"to", b"rank", b"gbdt"])
        # Values of the dense tensors.
        self.assertAllEqual(values["query_length"], [[3], [2]])
        self.assertAllEqual(values["utility"],
                            [[[0.], [1.0]], [[0.], [0.]]])
def testGetMatchingFiles(self):
    """get_matching_files returns every file that matches a glob."""
    dir_path = os.path.join(self.get_temp_dir(), "temp_dir")
    file_io.create_dir(dir_path)
    paths = [os.path.join(dir_path, name)
             for name in ("file1.txt", "file2.txt", "file3.txt")]
    for path in paths:
        file_io.write_string_to_file(path, "testing")
    self.assertItemsEqual(
        file_io.get_matching_files(os.path.join(dir_path, "file*.txt")), paths)
    for path in paths:
        file_io.delete_file(path)
def io_write_from_temp(temp_pth, dest_pth):
    """Copy a file to ``dest_pth`` with TensorFlow's file_io, then delete the source.

    Lets callers save any file locally first (``temp_pth``) and then copy it
    to a destination that file_io understands, such as a ``gs://`` path
    (gcloud compatibility). The copy is byte-for-byte, so binary files (model
    weights, images, ...) work as well as text.

    Args:
      temp_pth: Path of the already-written temporary file.
      dest_pth: Destination path; it is created or truncated.
    """
    # Binary modes: text mode decodes on read and would break non-UTF-8 data.
    with file_io.FileIO(temp_pth, mode='rb') as input_f:
        with file_io.FileIO(dest_pth, mode='wb+') as output_f:
            output_f.write(input_f.read())
    file_io.delete_file(temp_pth)
def test_plot_model_rnn(self):
    """plot_model writes an image for an LSTM + TimeDistributed(Dense) model."""
    lstm = keras.layers.LSTM(
        16, return_sequences=True, input_shape=(2, 3), name='lstm')
    per_step_dense = keras.layers.TimeDistributed(
        keras.layers.Dense(5, name='dense'))
    model = keras.Sequential()
    model.add(lstm)
    model.add(per_step_dense)

    image_path = 'model_2.png'
    try:
        vis_utils.plot_model(model, to_file=image_path, show_shapes=True)
        self.assertTrue(file_io.file_exists(image_path))
        file_io.delete_file(image_path)
    except ImportError:
        # Presumably raised when optional plotting dependencies are missing.
        pass
def test_assets(self, cycles):
    """Loaded assets resolve inside a moved SavedModel after sources are deleted."""
    first = self._make_asset("contents 1")
    second = self._make_asset("contents 2")
    root = tracking.AutoTrackable()
    root.asset1 = tracking.TrackableAsset(first)
    root.asset2 = tracking.TrackableAsset(second)
    save_dir = os.path.join(self.get_temp_dir(), "save_dir")
    save.save(root, save_dir)

    # Delete the originals and move the SavedModel, so the reads below only
    # succeed if the loaded asset paths point into load_dir.
    for source in (first, second):
        file_io.delete_file(source)
    load_dir = os.path.join(self.get_temp_dir(), "load_dir")
    file_io.rename(save_dir, load_dir)

    imported = load.load(load_dir)
    for asset, expected in ((imported.asset1, "contents 1"),
                            (imported.asset2, "contents 2")):
        with open(self.evaluate(asset.asset_path), "r") as f:
            self.assertEqual(expected, f.read())
def classify(image_data):
    """Classify an encoded image and return a human-readable label sentence.

    The bytes are written to a uniquely named temporary file so that
    ``scipy.ndimage.imread`` can decode them. The image is resized to 32x32
    (bicubic) and passed to ``classsifier.model.predict``. The temporary file
    is always deleted, even if decoding fails.

    Args:
      image_data: Encoded image contents.

    Returns:
      A string of the form ``"This is a <label>"``.
    """
    file_name = str(uuid.uuid1())
    file_io.write_string_to_file(file_name, image_data)
    try:
        img = scipy.ndimage.imread(file_name, mode="RGB")
    finally:
        # The file is only needed for decoding; never leave it behind.
        file_io.delete_file(file_name)

    # Scale it to 32x32
    img = scipy.misc.imresize(img, (32, 32), interp="bicubic").astype(
        np.float32, casting='unsafe')

    # Predict
    prediction = classsifier.model.predict([img])
    print(prediction[0])
    # This list matches the CIFAR-10 label names and order.
    labels = [
        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
        'horse', 'ship', 'truck'
    ]
    scores = prediction[0].tolist()
    return "This is a %s" % labels[scores.index(max(scores))]
def test_assets(self):
    """Loaded assets resolve inside a moved SavedModel after sources are deleted."""
    file1 = self._make_asset("contents 1")
    file2 = self._make_asset("contents 2")
    root = tracking.Checkpointable()
    root.asset1 = tracking.TrackableAsset(file1)
    root.asset2 = tracking.TrackableAsset(file2)
    save_dir = os.path.join(self.get_temp_dir(), "save_dir")
    save.save(root, save_dir, signatures={})
    # Delete the originals and move the SavedModel, so the reads below only
    # succeed if the loaded asset paths point into load_dir.
    file_io.delete_file(file1)
    file_io.delete_file(file2)
    load_dir = os.path.join(self.get_temp_dir(), "load_dir")
    file_io.rename(save_dir, load_dir)
    imported = load.load(load_dir)
    # assertEqual: the deprecated assertEquals alias was removed in Python 3.12.
    with open(imported.asset1.asset_path.numpy(), "r") as f:
        self.assertEqual("contents 1", f.read())
    with open(imported.asset2.asset_path.numpy(), "r") as f:
        self.assertEqual("contents 2", f.read())
def test_assets(self):
    """Loaded assets resolve inside a moved SavedModel after sources are deleted."""
    file1 = self._make_asset("contents 1")
    file2 = self._make_asset("contents 2")
    root = tracking.AutoCheckpointable()
    root.asset1 = tracking.TrackableAsset(file1)
    root.asset2 = tracking.TrackableAsset(file2)
    save_dir = os.path.join(self.get_temp_dir(), "save_dir")
    save.save(root, save_dir, signatures={})
    # Delete the originals and move the SavedModel, so the reads below only
    # succeed if the loaded asset paths point into load_dir.
    file_io.delete_file(file1)
    file_io.delete_file(file2)
    load_dir = os.path.join(self.get_temp_dir(), "load_dir")
    file_io.rename(save_dir, load_dir)
    imported = load.load(load_dir)
    # assertEqual: the deprecated assertEquals alias was removed in Python 3.12.
    with open(imported.asset1.asset_path.numpy(), "r") as f:
        self.assertEqual("contents 1", f.read())
    with open(imported.asset2.asset_path.numpy(), "r") as f:
        self.assertEqual("contents 2", f.read())
def test_assets_import(self):
    """Assets on an object that also has a tf.function survive save/move/load."""
    file1 = self._make_asset("contents 1")
    file2 = self._make_asset("contents 2")
    root = tracking.Checkpointable()
    root.f = def_function.function(
        lambda x: 2. * x,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    root.asset1 = tracking.TrackableAsset(file1)
    root.asset2 = tracking.TrackableAsset(file2)
    save_dir = os.path.join(self.get_temp_dir(), "save_dir")
    save.save(root, save_dir)
    # Delete the originals and move the SavedModel, so the reads below only
    # succeed if the loaded asset paths point into load_dir.
    file_io.delete_file(file1)
    file_io.delete_file(file2)
    load_dir = os.path.join(self.get_temp_dir(), "load_dir")
    file_io.rename(save_dir, load_dir)
    imported = load.load(load_dir)
    # assertEqual: the deprecated assertEquals alias was removed in Python 3.12.
    with open(imported.asset1.asset_path.numpy(), "r") as f:
        self.assertEqual("contents 1", f.read())
    with open(imported.asset2.asset_path.numpy(), "r") as f:
        self.assertEqual("contents 2", f.read())
def testFileDelete(self):
    """delete_file removes an existing file."""
    doomed = os.path.join(self.get_temp_dir(), "temp_file")
    file_io.write_string_to_file(doomed, "testing")
    file_io.delete_file(doomed)
    self.assertFalse(file_io.file_exists(doomed))
def testFileDeleteFail(self):
    """Deleting a file that was never created raises NotFoundError."""
    missing_path = os.path.join(self._base_dir, "temp_file")
    with self.assertRaises(errors.NotFoundError):
        file_io.delete_file(missing_path)
def testFileDelete(self):
    """delete_file removes a file created through FileIO."""
    file_path = os.path.join(self._base_dir, "temp_file")
    # Close the writer so the contents are flushed and the handle released
    # before the file is deleted.
    with file_io.FileIO(file_path, mode="w") as f:
        f.write("testing")
    file_io.delete_file(file_path)
    self.assertFalse(file_io.file_exists(file_path))
def testAPIDefCompatibility(self):
    """Compare generated Python ApiDef overrides with golden api_def_*.pbtxt files.

    Builds Python ApiDef overrides for every public symbol that comes from a
    generated ``gen_*`` module, grouped by the upper-cased first letter of the
    op name. For each key in ``_ALPHABET``, the result is compared with
    ``_PYTHON_API_DIR/api_def_<key>.pbtxt``. With ``--update_goldens`` the
    file is rewritten, or deleted when no overrides remain. Otherwise a
    mismatch fails the test.
    """
    # Get base ApiDef
    name_to_base_api_def = self._GetBaseApiMap()
    # Extract Python API
    visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
    public_api_visitor = public_api.PublicAPIVisitor(visitor)
    public_api_visitor.do_not_descend_map['tf'].append('contrib')
    traverse.traverse(tf, public_api_visitor)
    proto_dict = visitor.GetProtos()

    # Map from first character of op name to Python ApiDefs.
    api_def_map = defaultdict(api_def_pb2.ApiDefs)
    # We need to override all endpoints even if 1 endpoint differs from base
    # ApiDef. So, we first create a map from an op to all its endpoints.
    op_to_endpoint_name = defaultdict(list)

    # Generate map from generated python op to endpoint names.
    for public_module, value in proto_dict.items():
        module_obj = _GetSymbol(public_module)
        for sym in value.tf_module.member_method:
            obj = getattr(module_obj, sym.name)

            # Check if object is defined in gen_* module. That is,
            # the object has been generated from OpDef.
            if hasattr(obj, '__module__') and _IsGenModule(obj.__module__):
                if obj.__name__ not in name_to_base_api_def:
                    # Symbol might be defined only in Python and not generated from
                    # C++ api.
                    continue
                # Strip the 'tensorflow.' prefix so endpoint names are relative
                # to the tf namespace.
                relative_public_module = public_module[len('tensorflow.'):]
                full_name = (relative_public_module + '.' + sym.name
                             if relative_public_module else sym.name)
                op_to_endpoint_name[obj].append(full_name)

    # Generate Python ApiDef overrides.
    for op, endpoint_names in op_to_endpoint_name.items():
        api_def = self._CreatePythonApiDef(
            name_to_base_api_def[op.__name__], endpoint_names)
        # Only ops that need an override (non-empty api_def) are recorded.
        if api_def:
            api_defs = api_def_map[op.__name__[0].upper()]
            api_defs.op.extend([api_def])

    for key in _ALPHABET:
        # Get new ApiDef for the given key.
        new_api_defs_str = ''
        if key in api_def_map:
            new_api_defs = api_def_map[key]
            # Sort so the golden text is deterministic.
            new_api_defs.op.sort(key=attrgetter('graph_op_name'))
            new_api_defs_str = str(new_api_defs)

        # Get current ApiDef for the given key.
        api_defs_file_path = os.path.join(
            _PYTHON_API_DIR, 'api_def_%s.pbtxt' % key)
        old_api_defs_str = ''
        if file_io.file_exists(api_defs_file_path):
            old_api_defs_str = file_io.read_file_to_string(api_defs_file_path)

        if old_api_defs_str == new_api_defs_str:
            continue

        if FLAGS.update_goldens:
            # An empty result means no overrides remain: remove the golden file.
            if not new_api_defs_str:
                logging.info('Deleting %s...' % api_defs_file_path)
                file_io.delete_file(api_defs_file_path)
            else:
                logging.info('Updating %s...' % api_defs_file_path)
                file_io.write_string_to_file(api_defs_file_path, new_api_defs_str)
        else:
            self.assertMultiLineEqual(
                old_api_defs_str, new_api_defs_str,
                'To update golden API files, run api_compatibility_test locally '
                'with --update_goldens=True flag.')
def delete_file(cls, filename):
    """Delete ``filename`` through ``file_io.delete_file`` and return its result."""
    outcome = file_io.delete_file(filename)
    return outcome
def testAPIDefCompatibility(self):
    """Compare generated Python ApiDef overrides with the golden ApiDef files.

    Builds Python ApiDef overrides for every public symbol that comes from a
    generated ``gen_*`` module. Snake-case Python names are mapped back to
    CamelCase graph op names, and each override is grouped by
    ``_GetApiDefFilePath``. The result is compared with ``_GetGoldenApiDefs()``:
      * With ``--update_goldens``, changed files are rewritten and obsolete
        ones deleted.
      * Otherwise any difference fails the test.
    """
    # Get base ApiDef
    name_to_base_api_def = self._GetBaseApiMap()
    snake_to_camel_graph_op_names = {
        self._GenerateLowerCaseOpName(name): name
        for name in name_to_base_api_def}
    # Extract Python API
    visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
    public_api_visitor = public_api.PublicAPIVisitor(visitor)
    public_api_visitor.do_not_descend_map['tf'].append('contrib')
    traverse.traverse(tf, public_api_visitor)
    proto_dict = visitor.GetProtos()

    # Map from file path to Python ApiDefs.
    new_api_defs_map = defaultdict(api_def_pb2.ApiDefs)
    # We need to override all endpoints even if 1 endpoint differs from base
    # ApiDef. So, we first create a map from an op to all its endpoints.
    op_to_endpoint_name = defaultdict(list)

    # Generate map from generated python op to endpoint names.
    for public_module, value in proto_dict.items():
        module_obj = _GetSymbol(public_module)
        for sym in value.tf_module.member_method:
            obj = getattr(module_obj, sym.name)

            # Check if object is defined in gen_* module. That is,
            # the object has been generated from OpDef.
            if hasattr(obj, '__module__') and _IsGenModule(obj.__module__):
                if obj.__name__ not in snake_to_camel_graph_op_names:
                    # Symbol might be defined only in Python and not generated from
                    # C++ api.
                    continue
                relative_public_module = public_module[len('tensorflow.'):]
                full_name = (relative_public_module + '.' + sym.name
                             if relative_public_module else sym.name)
                op_to_endpoint_name[obj].append(full_name)

    # Generate Python ApiDef overrides.
    for op, endpoint_names in op_to_endpoint_name.items():
        graph_op_name = snake_to_camel_graph_op_names[op.__name__]
        api_def = self._CreatePythonApiDef(
            name_to_base_api_def[graph_op_name], endpoint_names)

        if api_def:
            file_path = _GetApiDefFilePath(graph_op_name)
            api_defs = new_api_defs_map[file_path]
            api_defs.op.extend([api_def])

    self._AddHiddenOpOverrides(name_to_base_api_def, new_api_defs_map)

    old_api_defs_map = _GetGoldenApiDefs()
    for file_path, new_api_defs in new_api_defs_map.items():
        # Get new ApiDef string.
        new_api_defs_str = str(new_api_defs)

        # Get current ApiDef for the given file.
        old_api_defs_str = (
            old_api_defs_map[file_path] if file_path in old_api_defs_map else '')

        if old_api_defs_str == new_api_defs_str:
            continue

        if FLAGS.update_goldens:
            logging.info('Updating %s...', file_path)
            file_io.write_string_to_file(file_path, new_api_defs_str)
        else:
            self.assertMultiLineEqual(
                old_api_defs_str, new_api_defs_str,
                'To update golden API files, run api_compatibility_test locally '
                'with --update_goldens=True flag.')

    # Golden files with no corresponding generated overrides are obsolete.
    for file_path in set(old_api_defs_map) - set(new_api_defs_map):
        if FLAGS.update_goldens:
            logging.info('Deleting %s...', file_path)
            file_io.delete_file(file_path)
        else:
            # Note the trailing space: adjacent literals are concatenated.
            self.fail(
                '%s file is no longer needed and should be removed. '
                'To update golden API files, run api_compatibility_test locally '
                'with --update_goldens=True flag.' % file_path)
def _delete_file_if_exists(filespec):
    """Deletes files matching `filespec`.

    Any error raised by ``file_io.delete_file`` propagates to the caller.
    """
    matched_paths = file_io.get_matching_files(filespec)
    for path in matched_paths:
        file_io.delete_file(path)
def _AssertProtoDictEquals(self,
                           expected_dict,
                           actual_dict,
                           verbose=False,
                           update_goldens=False,
                           additional_missing_object_message='',
                           api_version=2):
    """Diff given dicts of protobufs and report differences in a readable way.

    Args:
      expected_dict: a dict of TFAPIObject protos constructed from golden files.
      actual_dict: a dict of TFAPIObject protos constructed by reading from the
        TF package linked to the test.
      verbose: Whether to log the full diffs, or simply report which files were
        different.
      update_goldens: Whether to update goldens when there are diffs found.
      additional_missing_object_message: Message to print when a symbol is
        missing.
      api_version: TensorFlow API version to test.

    When differences exist and ``update_goldens`` is False, the test fails via
    ``self.fail``. When it is True, golden files for removed keys are deleted
    and files for added or changed keys are rewritten.
    """
    diffs = []
    verbose_diffs = []

    expected_keys = set(expected_dict.keys())
    actual_keys = set(actual_dict.keys())
    only_in_expected = expected_keys - actual_keys
    only_in_actual = actual_keys - expected_keys
    all_keys = expected_keys | actual_keys

    # This will be populated below.
    updated_keys = []

    for key in all_keys:
        diff_message = ''
        verbose_diff_message = ''
        # First check if the key is not found in one or the other.
        if key in only_in_expected:
            diff_message = 'Object %s expected but not found (removed). %s' % (
                key, additional_missing_object_message)
            verbose_diff_message = diff_message
        elif key in only_in_actual:
            diff_message = 'New object %s found (added).' % key
            verbose_diff_message = diff_message
        else:
            # Do not truncate diff
            self.maxDiff = None  # pylint: disable=invalid-name
            # Now we can run an actual proto diff.
            try:
                self.assertProtoEquals(expected_dict[key], actual_dict[key])
            except AssertionError as e:
                # Key exists on both sides but its proto changed.
                updated_keys.append(key)
                diff_message = 'Change detected in python object: %s.' % key
                verbose_diff_message = str(e)

        # All difference cases covered above. If any difference found, add to the
        # list.
        if diff_message:
            diffs.append(diff_message)
            verbose_diffs.append(verbose_diff_message)

    # If diffs are found, handle them based on flags.
    if diffs:
        diff_count = len(diffs)
        logging.error(self._test_readme_message)
        logging.error('%d differences found between API and golden.', diff_count)
        messages = verbose_diffs if verbose else diffs
        for i in range(diff_count):
            print('Issue %d\t: %s' % (i + 1, messages[i]), file=sys.stderr)

        if update_goldens:
            # Write files if requested.
            logging.warning(self._update_golden_warning)

            # If the keys are only in expected, some objects are deleted.
            # Remove files.
            for key in only_in_expected:
                filepath = _KeyToFilePath(key, api_version)
                file_io.delete_file(filepath)

            # If the files are only in actual (current library), these are new
            # modules. Write them to files. Also record all updates in files.
            for key in only_in_actual | set(updated_keys):
                filepath = _KeyToFilePath(key, api_version)
                file_io.write_string_to_file(
                    filepath, text_format.MessageToString(actual_dict[key]))
        else:
            # Fail if we cannot fix the test by updating goldens.
            self.fail('%d differences found between API and golden.' % diff_count)

    else:
        logging.info('No differences found between API and golden.')