Example #1

        def _tito_out(tensor_in):
            checkpoint_dir = tmp_dir
            if tmp_dir is None:
                checkpoint_dir = tempfile.mkdtemp()

            g = tf.Graph()
            with g.as_default():
                si = tf.placeholder(dtype=tensor_in.dtype,
                                    shape=tensor_in.shape,
                                    name=tensor_in.op.name)
                so = tito_in(si)
                all_vars = tf.contrib.slim.get_variables_to_restore(
                    exclude=exclude)
                saver = tf.train.Saver(all_vars)
                # Downloading the checkpoint from GCS to local speeds up saver.restore() a lot.
                checkpoint_tmp = os.path.join(checkpoint_dir, 'checkpoint')
                with file_io.FileIO(checkpoint, 'r') as f_in, file_io.FileIO(
                        checkpoint_tmp, 'w') as f_out:
                    f_out.write(f_in.read())
                with tf.Session() as sess:
                    saver.restore(sess, checkpoint_tmp)
                    output_graph_def = tf.graph_util.convert_variables_to_constants(
                        sess, g.as_graph_def(), [so.op.name])
                file_io.delete_file(checkpoint_tmp)
                if tmp_dir is None:
                    shutil.rmtree(checkpoint_dir)

            tensors_out = tf.import_graph_def(output_graph_def,
                                              input_map={si.name: tensor_in},
                                              return_elements=[so.name])
            return tensors_out[0]
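
The helper above closes over tito_in, checkpoint, exclude and tmp_dir, none of which appear in the snippet. A minimal sketch of an outer function that could supply them; the name make_tito_out and its signature are assumptions, not part of the original code:

def make_tito_out(tito_in, checkpoint, exclude=None, tmp_dir=None):
    # Hypothetical wrapper: tito_in is a tensor-in/tensor-out callable,
    # checkpoint is the path to a checkpoint to restore, exclude lists
    # variable name patterns to skip, and tmp_dir is an optional scratch dir.
    def _tito_out(tensor_in):
        ...  # body as in the example above
    return _tito_out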
Example #2
def create_object_test():
  """Verifies file_io's object manipulation methods ."""
  starttime = int(round(time.time() * 1000))
  dir_name = "%s/tf_gcs_test_%s" % (FLAGS.gcs_bucket_url, starttime)
  print("Creating dir %s." % dir_name)
  file_io.create_dir(dir_name)

  # Create a file in this directory.
  file_name = "%s/test_file.txt" % dir_name
  print("Creating file %s." % file_name)
  file_io.write_string_to_file(file_name, "test file creation.")

  list_files_pattern = "%s/test_file*.txt" % dir_name
  print("Getting files matching pattern %s." % list_files_pattern)
  files_list = file_io.get_matching_files(list_files_pattern)
  print(files_list)

  assert len(files_list) == 1
  assert files_list[0] == file_name

  # Cleanup test files.
  print("Deleting file %s." % file_name)
  file_io.delete_file(file_name)

  # Delete directory.
  print("Deleting directory %s." % dir_name)
  file_io.delete_recursively(dir_name)
Example #3
    def _sync(self):
        added_ids_filepath = 'models/added_ids.txt'
        added_ids = set(self._read_added_ids(added_ids_filepath))
        if not added_ids:
            return
        target_ids = set(self._get_ids())
        logging.info('added_ids count: %d', len(added_ids))
        logging.info('target_ids count: %d', len(target_ids))

        remove_ids = list(added_ids - target_ids)
        add_ids = list(target_ids - added_ids)
        logging.info('remove_ids count: %d', len(remove_ids))
        logging.info('add_ids count: %d', len(add_ids))

        if remove_ids:
            for ids in chunks(remove_ids, 20000):
                ids = np.array(ids, dtype=np.int64)
                self.faiss_index.remove_ids(ids)
            logging.info("removed")

        if add_ids:
            for ids in chunks(add_ids, 20000):
                t0 = time.time()
                filepaths = [self._get_filepath(id) for id in ids]
                xb = self._path_to_xb(filepaths)
                ids = np.array(ids, dtype=np.int64)
                self.faiss_index.add(xb, ids)
                logging.info("%d embeddings added %.3f s", xb.shape[0], time.time() - t0)

        file_io.delete_file(added_ids_filepath)
        logging.info("Synced. ntotal: %d", self.faiss_index.ntotal())
Example #4
 def testFileWrite(self):
   file_path = os.path.join(self.get_temp_dir(), "temp_file")
   file_io.write_string_to_file(file_path, "testing")
   self.assertTrue(file_io.file_exists(file_path))
   file_contents = file_io.read_file_to_string(file_path)
   self.assertEqual(b"testing", file_contents)
   file_io.delete_file(file_path)
Example #5
 def testFileWrite(self):
     file_path = os.path.join(self.get_temp_dir(), "temp_file")
     file_io.write_string_to_file(file_path, "testing")
     self.assertTrue(file_io.file_exists(file_path))
     file_contents = file_io.read_file_to_string(file_path)
     self.assertEqual(b"testing", file_contents)
     file_io.delete_file(file_path)
Example #6
def create_object_test():
    """Verifies file_io's object manipulation methods ."""
    starttime = int(round(time.time() * 1000))
    dir_name = "%s/tf_gcs_test_%s" % (FLAGS.gcs_bucket_url, starttime)
    print("Creating dir %s." % dir_name)
    file_io.create_dir(dir_name)

    # Create a file in this directory.
    file_name = "%s/test_file.txt" % dir_name
    print("Creating file %s." % file_name)
    file_io.write_string_to_file(file_name, "test file creation.")

    list_files_pattern = "%s/test_file*.txt" % dir_name
    print("Getting files matching pattern %s." % list_files_pattern)
    files_list = file_io.get_matching_files(list_files_pattern)
    print(files_list)

    assert len(files_list) == 1
    assert files_list[0] == file_name

    # Cleanup test files.
    print("Deleting file %s." % file_name)
    file_io.delete_file(file_name)

    # Delete directory.
    print("Deleting directory %s." % dir_name)
    file_io.delete_recursively(dir_name)
Example #7
 def _v1_asset_saved_model(self):
     export_graph = ops.Graph()
     vocab_path = os.path.join(self.get_temp_dir(), "vocab.txt")
     with open(vocab_path, "w") as f:
         f.write("alpha\nbeta\ngamma\n")
     with export_graph.as_default():
         initializer = lookup_ops.TextFileInitializer(
             vocab_path,
             key_dtype=dtypes.string,
             key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
             value_dtype=dtypes.int64,
             value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
         table = lookup_ops.HashTable(initializer, default_value=-1)
         start = array_ops.placeholder(shape=None,
                                       dtype=dtypes.string,
                                       name="in")
         output = table.lookup(start, name="out")
         with session_lib.Session() as session:
             session.run([table.initializer])
             path = os.path.join(self.get_temp_dir(), "saved_model",
                                 str(ops.uid()))
             simple_save.simple_save(session,
                                     path,
                                     inputs={"start": start},
                                     outputs={"output": output},
                                     legacy_init_op=table.initializer)
     file_io.delete_file(vocab_path)
     return path
Example #8

    def test_libsvm_generator(self):
        data_dir = test.get_temp_dir()
        data_file = os.path.join(data_dir, "test_libvsvm.txt")
        if file_io.file_exists(data_file):
            file_io.delete_file(data_file)

        with open(data_file, "wt") as writer:
            writer.write(LIBSVM_DATA)

        want = {
            "1": np.array([[0.1], [0.], [0.12], [0.]], dtype=np.float32),
            "2": np.array([[0.], [0.13], [0.], [0.]], dtype=np.float32),
            "3": np.array([[0.3], [0.], [0.], [0.]], dtype=np.float32),
            "4": np.array([[-0.4], [0.], [0.24], [0.]], dtype=np.float32),
            "5": np.array([[0.], [0.], [0.5], [0.]], dtype=np.float32),
        }

        reader = data_lib.libsvm_generator(data_file, 5, 4, seed=10)

        for features, labels in reader():
            self.assertAllEqual(labels, [2.0, 0., 1.0, -1.0])
            self.assertAllEqual(sorted(features.keys()), sorted(want.keys()))
            for k in sorted(want):
                self.assertAllEqual(features.get(k), want.get(k))
            break
Example #9
 def _v1_asset_saved_model(self):
   export_graph = ops.Graph()
   vocab_path = os.path.join(self.get_temp_dir(), "vocab.txt")
   with open(vocab_path, "w") as f:
     f.write("alpha\nbeta\ngamma\n")
   with export_graph.as_default():
     initializer = lookup_ops.TextFileInitializer(
         vocab_path,
         key_dtype=dtypes.string,
         key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
         value_dtype=dtypes.int64,
         value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
     table = lookup_ops.HashTable(
         initializer, default_value=-1)
     start = array_ops.placeholder(
         shape=None, dtype=dtypes.string, name="in")
     output = table.lookup(start, name="out")
     with session_lib.Session() as session:
       session.run([table.initializer])
       path = os.path.join(self.get_temp_dir(), "saved_model", str(ops.uid()))
       simple_save.simple_save(
           session,
           path,
           inputs={"start": start},
           outputs={"output": output},
           legacy_init_op=table.initializer)
   file_io.delete_file(vocab_path)
   return path
Example #10
 def test_table(self):
     initializer = lookup_ops.TextFileInitializer(
         self._vocab_path,
         key_dtype=dtypes.string,
         key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
         value_dtype=dtypes.int64,
         value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
     root = util.Checkpoint(
         table=lookup_ops.HashTable(initializer, default_value=-1))
     root.table_user = def_function.function(
         root.table.lookup,
         input_signature=[tensor_spec.TensorSpec(None, dtypes.string)])
     self.assertEqual(
         2, self.evaluate(root.table_user(constant_op.constant("gamma"))))
     save_dir = os.path.join(self.get_temp_dir(), "saved_model")
     save.save(root, save_dir)
     file_io.delete_file(self._vocab_path)
     self.assertAllClose({"output_0": [2, 0]},
                         _import_and_infer(save_dir,
                                           {"keys": ["gamma", "alpha"]}))
     second_dir = os.path.join(self.get_temp_dir(), "second_dir")
     # Asset paths should track the location the SavedModel is loaded from.
     file_io.rename(save_dir, second_dir)
     self.assertAllClose({"output_0": [2, 1]},
                         _import_and_infer(second_dir,
                                           {"keys": ["gamma", "beta"]}))
Example #11
 def test_table(self):
   initializer = lookup_ops.TextFileInitializer(
       self._vocab_path,
       key_dtype=dtypes.string,
       key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
       value_dtype=dtypes.int64,
       value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
   root = util.Checkpoint(table=lookup_ops.HashTable(
       initializer, default_value=-1))
   root.table_user = def_function.function(
       root.table.lookup,
       input_signature=[tensor_spec.TensorSpec(None, dtypes.string)])
   self.assertEqual(
       2,
       self.evaluate(root.table_user(constant_op.constant("gamma"))))
   save_dir = os.path.join(self.get_temp_dir(), "saved_model")
   save.save(root, save_dir)
   file_io.delete_file(self._vocab_path)
   self.assertAllClose(
       {"output_0": [2, 0]},
       _import_and_infer(save_dir, {"keys": ["gamma", "alpha"]}))
   second_dir = os.path.join(self.get_temp_dir(), "second_dir")
   # Asset paths should track the location the SavedModel is loaded from.
   file_io.rename(save_dir, second_dir)
   self.assertAllClose(
       {"output_0": [2, 1]},
       _import_and_infer(second_dir, {"keys": ["gamma", "beta"]}))
Example #12
 def test_plot_model_with_wrapped_layers_and_models(self):
     inputs = keras.Input(shape=(None, 3))
     lstm = keras.layers.LSTM(6, return_sequences=True, name='lstm')
     x = lstm(inputs)
     # Add layer inside a Wrapper
     bilstm = keras.layers.Bidirectional(
         keras.layers.LSTM(16, return_sequences=True, name='bilstm'))
     x = bilstm(x)
     # Add model inside a Wrapper
     submodel = keras.Sequential(
         [keras.layers.Dense(32, name='dense', input_shape=(None, 32))])
     wrapped_dense = keras.layers.TimeDistributed(submodel)
     x = wrapped_dense(x)
     # Add shared submodel
     outputs = submodel(x)
     model = keras.Model(inputs, outputs)
     dot_img_file = 'model_2.png'
     try:
         vis_utils.plot_model(model,
                              to_file=dot_img_file,
                              show_shapes=True,
                              expand_nested=True)
         self.assertTrue(file_io.file_exists(dot_img_file))
         file_io.delete_file(dot_img_file)
     except ImportError:
         pass
Example #13
    def _tito_out(tensor_in):
      checkpoint_dir = tmp_dir
      if tmp_dir is None:
        checkpoint_dir = tempfile.mkdtemp()

      g = tf.Graph()
      with g.as_default():
        si = tf.placeholder(dtype=tensor_in.dtype, shape=tensor_in.shape, name=tensor_in.op.name)
        so = tito_in(si)
        all_vars = tf.contrib.slim.get_variables_to_restore(exclude=exclude)
        saver = tf.train.Saver(all_vars)
        # Downloading the checkpoint from GCS to local speeds up saver.restore() a lot.
        checkpoint_tmp = os.path.join(checkpoint_dir, 'checkpoint')
        with file_io.FileIO(checkpoint, 'r') as f_in, file_io.FileIO(checkpoint_tmp, 'w') as f_out:
          f_out.write(f_in.read())
        with tf.Session() as sess:
          saver.restore(sess, checkpoint_tmp)
          output_graph_def = tf.graph_util.convert_variables_to_constants(sess,
                                                                          g.as_graph_def(),
                                                                          [so.op.name])
        file_io.delete_file(checkpoint_tmp)
        if tmp_dir is None:
          shutil.rmtree(checkpoint_dir)

      tensors_out = tf.import_graph_def(output_graph_def,
                                        input_map={si.name: tensor_in},
                                        return_elements=[so.name])
      return tensors_out[0]
Example #14
 def end(self, session):
     if self.mode == tf.estimator.ModeKeys.EVAL:
         current_global_step = session.run(self.global_step_tensor)
         with open(os.path.join(self.writer.get_logdir(),
                                "checkpoint")) as f:
             checkpoints = [ckpt for ckpt in f]
             checkpoints = [
                 self.extract_global_step(ckpt) for ckpt in checkpoints[1:]
             ]
             checkpoints = list(
                 filter(lambda gs: gs < current_global_step, checkpoints))
             if len(checkpoints) > self.keep_eval_results_max_epoch:
                 checkpoint_to_delete = checkpoints[
                     -self.keep_eval_results_max_epoch]
                 tf.logging.info("Deleting %s results at the step %d",
                                 self.mode, checkpoint_to_delete)
                 tfrecord_filespec = os.path.join(
                     self.writer.get_logdir(),
                     "eval_result_step{:09d}_*.tfrecord".format(
                         checkpoint_to_delete))
                 alignment_filespec = os.path.join(
                     self.writer.get_logdir(),
                     "alignment_eval_result_step{:09d}_*.png".format(
                         checkpoint_to_delete))
                 mel_filespec = os.path.join(
                     self.writer.get_logdir(),
                     "mel_eval_result_step{:09d}_*.png".format(
                         checkpoint_to_delete))
                 for pathname in tf.gfile.Glob(
                     [tfrecord_filespec, alignment_filespec, mel_filespec]):
                     file_io.delete_file(pathname)
Example #15
 def testRename(self):
   file_path = os.path.join(self.get_temp_dir(), "temp_file")
   file_io.write_string_to_file(file_path, "testing")
   rename_path = os.path.join(self.get_temp_dir(), "rename_file")
   file_io.rename(file_path, rename_path)
   self.assertTrue(file_io.file_exists(rename_path))
   self.assertFalse(file_io.file_exists(file_path))
   file_io.delete_file(rename_path)
Example #16
 def testRename(self):
     file_path = os.path.join(self.get_temp_dir(), "temp_file")
     file_io.write_string_to_file(file_path, "testing")
     rename_path = os.path.join(self.get_temp_dir(), "rename_file")
     file_io.rename(file_path, rename_path)
     self.assertTrue(file_io.file_exists(rename_path))
     self.assertFalse(file_io.file_exists(file_path))
     file_io.delete_file(rename_path)
Example #17

def remove(id, cloud_path=None):
    filepath = "%s/%s" % (emb_path, id_to_path(id))
    if cloud_path:
        cloud_filepath = '%s/%s' % (cloud_path, filepath)
        if file_io.file_exists(cloud_filepath):
            file_io.delete_file(cloud_filepath)
    if file_io.file_exists(filepath):
        return file_io.delete_file(filepath)
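
remove depends on a module-level emb_path and an id_to_path helper, neither of which is shown. A hedged sketch of what they might look like; the base directory and sharding scheme below are pure assumptions:

emb_path = 'embeddings'  # hypothetical base directory for embedding files

def id_to_path(id):
    # Assumed layout: shard files into subdirectories keyed by the low digits of the id.
    return '%03d/%d.emb' % (id % 1000, id)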
Example #18
 def testCopyOverwriteFalse(self):
     file_path = os.path.join(self.get_temp_dir(), "temp_file")
     file_io.write_string_to_file(file_path, "testing")
     copy_path = os.path.join(self.get_temp_dir(), "copy_file")
     file_io.write_string_to_file(copy_path, "copy")
     with self.assertRaises(errors.AlreadyExistsError):
         file_io.copy(file_path, copy_path, overwrite=False)
     file_io.delete_file(file_path)
     file_io.delete_file(copy_path)
Example #19
 def testCopy(self):
     file_path = os.path.join(self.get_temp_dir(), "temp_file")
     file_io.write_string_to_file(file_path, "testing")
     copy_path = os.path.join(self.get_temp_dir(), "copy_file")
     file_io.copy(file_path, copy_path)
     self.assertTrue(file_io.file_exists(copy_path))
     self.assertEqual(b"testing", file_io.read_file_to_string(file_path))
     file_io.delete_file(file_path)
     file_io.delete_file(copy_path)
Example #20
def _delete_file_if_exists(filespec):
  """Deletes files matching `filespec`."""
  for pathname in file_io.get_matching_files(filespec):
    try:
      file_io.delete_file(pathname)
    except errors.NotFoundError:
      logging.warning(
          "Hit NotFoundError when deleting '%s', possibly because another "
          "process/thread is also deleting/moving the same file", pathname)
Example #21
 def testCopy(self):
   file_path = os.path.join(self.get_temp_dir(), "temp_file")
   file_io.write_string_to_file(file_path, "testing")
   copy_path = os.path.join(self.get_temp_dir(), "copy_file")
   file_io.copy(file_path, copy_path)
   self.assertTrue(file_io.file_exists(copy_path))
   self.assertEqual(b"testing", file_io.read_file_to_string(file_path))
   file_io.delete_file(file_path)
   file_io.delete_file(copy_path)
Example #22
 def testCopyOverwriteFalse(self):
   file_path = os.path.join(self.get_temp_dir(), "temp_file")
   file_io.write_string_to_file(file_path, "testing")
   copy_path = os.path.join(self.get_temp_dir(), "copy_file")
   file_io.write_string_to_file(copy_path, "copy")
   with self.assertRaises(errors.AlreadyExistsError):
     file_io.copy(file_path, copy_path, overwrite=False)
   file_io.delete_file(file_path)
   file_io.delete_file(copy_path)
Example #23
 def delete_file(self, filename):
     """Proxy for tensorflow.python.lib.io.file_io.delete_file function. Mocks
     the function if a real GCS bucket is not available for testing.
     """
     if not self.mock_gcs:
         tf_file_io.delete_file(filename)
     elif filename.startswith(self._gcs_prefix):
         self.local_objects.pop(filename)
     else:
         os.remove(filename)
Example #24
 def delete_file(self, filename):
     """Proxy for tensorflow.python.lib.io.file_io.delete_file function. Mocks
     the function if a real GCS bucket is not available for testing.
     """
     if not self.mock_gcs:
         tf_file_io.delete_file(filename)
     elif filename.startswith(self._gcs_prefix):
         self.local_objects.pop(filename)
     else:
         os.remove(filename)
Example #25

def save(model, params):
    basename = "model-{}.hdf5".format(params.current_date)
    model.save(basename)

    with open(basename, "r") as input:
        with file_io.FileIO(params.model_filename, 'w+') as f:
            f.write(input.read())

    file_io.delete_file(basename)
    print("model saved to GS: {}".format(params.model_filename))
Example #26
 def testRenameOverwriteFalse(self):
   file_path = os.path.join(self.get_temp_dir(), "temp_file")
   file_io.write_string_to_file(file_path, "testing")
   rename_path = os.path.join(self.get_temp_dir(), "rename_file")
   file_io.write_string_to_file(rename_path, "rename")
   with self.assertRaises(errors.AlreadyExistsError):
     file_io.rename(file_path, rename_path, overwrite=False)
   self.assertTrue(file_io.file_exists(rename_path))
   self.assertTrue(file_io.file_exists(file_path))
   file_io.delete_file(rename_path)
   file_io.delete_file(file_path)
Example #27
 def testAtomicWriteStringToFileOverwriteFalse(self):
   file_path = os.path.join(self._base_dir, "temp_file")
   file_io.atomic_write_string_to_file(file_path, "old", overwrite=False)
   with self.assertRaises(errors.AlreadyExistsError):
     file_io.atomic_write_string_to_file(file_path, "new", overwrite=False)
   file_contents = file_io.read_file_to_string(file_path)
   self.assertEqual("old", file_contents)
   file_io.delete_file(file_path)
   file_io.atomic_write_string_to_file(file_path, "new", overwrite=False)
   file_contents = file_io.read_file_to_string(file_path)
   self.assertEqual("new", file_contents)
Example #28
    def Remove(self, request, context):
        logging.debug('remove - id: %d', request.id)
        ids = np.array([request.id], dtype=np.int64)
        self.faiss_index.remove_ids(ids)

        filepath = self._get_filepath(request.id) 
        if file_io.file_exists(filepath):
            file_io.delete_file(filepath)
            return pb2.SimpleReponse(message='Removed, %s!' % request.id)

        return pb2.SimpleReponse(message='Not existed, %s!' % request.id)
Example #29
 def testRenameOverwriteFalse(self):
     file_path = os.path.join(self.get_temp_dir(), "temp_file")
     file_io.write_string_to_file(file_path, "testing")
     rename_path = os.path.join(self.get_temp_dir(), "rename_file")
     file_io.write_string_to_file(rename_path, "rename")
     with self.assertRaises(errors.AlreadyExistsError):
         file_io.rename(file_path, rename_path, overwrite=False)
     self.assertTrue(file_io.file_exists(rename_path))
     self.assertTrue(file_io.file_exists(file_path))
     file_io.delete_file(rename_path)
     file_io.delete_file(file_path)
Example #30
 def load_weights(self, path):
     if path.startswith("gs://"):
         from google.cloud import storage
         bucket_name, sub_folder = split_bucket(path)
         storage_client = storage.Client()
         tmp_file = download_weight(bucket_name, sub_folder, storage_client)
          self.model.load_weights(tmp_file)  # load_weights expects a file path
         file_io.delete_file(tmp_file)
     else:
         self.model.load_weights(path)
     self.compile()
Example #31
 def _fetch_embedding(self, emb_filepath):
     try:
         embedding = np.frombuffer(
             file_io.read_file_to_string(emb_filepath), dtype=np.float32)
         embedding = embedding.reshape(self.SHAPE)
     except ValueError as e:
         logging.warn('Could not load an embedding file from %s: %s',
                      emb_filepath, str(e))
         error_count.inc()
          if str(e).startswith('cannot reshape array of size 0 into'):
             file_io.delete_file(emb_filepath)
             return
         raise e
Example #32
 def test_plot_model_rnn(self):
   model = keras.Sequential()
   model.add(
       keras.layers.LSTM(
           16, return_sequences=True, input_shape=(2, 3), name='lstm'))
   model.add(keras.layers.TimeDistributed(keras.layers.Dense(5, name='dense')))
   dot_img_file = 'model_2.png'
   try:
     vis_utils.plot_model(model, to_file=dot_img_file, show_shapes=True)
     self.assertTrue(file_io.file_exists(dot_img_file))
     file_io.delete_file(dot_img_file)
   except ImportError:
     pass
Example #33
 def testGetMatchingFiles(self):
     dir_path = os.path.join(self.get_temp_dir(), "temp_dir")
     file_io.create_dir(dir_path)
     files = ["file1.txt", "file2.txt", "file3.txt"]
     for name in files:
         file_path = os.path.join(dir_path, name)
         file_io.write_string_to_file(file_path, "testing")
     expected_match = [os.path.join(dir_path, name) for name in files]
     self.assertItemsEqual(
         file_io.get_matching_files(os.path.join(dir_path, "file*.txt")),
         expected_match)
     for name in files:
         file_path = os.path.join(dir_path, name)
         file_io.delete_file(file_path)
Example #34
 def test_plot_model_cnn(self):
   model = keras.Sequential()
   model.add(
       keras.layers.Conv2D(
           filters=2, kernel_size=(2, 3), input_shape=(3, 5, 5), name='conv'))
   model.add(keras.layers.Flatten(name='flat'))
   model.add(keras.layers.Dense(5, name='dense'))
   dot_img_file = 'model_1.png'
   try:
     vis_utils.plot_model(model, to_file=dot_img_file, show_shapes=True)
     self.assertTrue(file_io.file_exists(dot_img_file))
     file_io.delete_file(dot_img_file)
   except ImportError:
     pass
Example #35
    def test_read_batched_sequence_example_dataset(self, sloppy_ordering):
        # Save protos in a sstable file in a temp folder.
        serialized_sequence_examples = [
            SEQ_EXAMPLE_PROTO_1.SerializeToString(),
            SEQ_EXAMPLE_PROTO_2.SerializeToString()
        ] * 100
        data_dir = test.get_temp_dir()
        data_file = os.path.join(data_dir, "test_sequence_example.tfrecord")
        if file_io.file_exists(data_file):
            file_io.delete_file(data_file)

        with tf_record.TFRecordWriter(data_file) as writer:
            for s in serialized_sequence_examples:
                writer.write(s)

        batched_dataset = data_lib.read_batched_sequence_example_dataset(
            file_pattern=data_file,
            batch_size=2,
            list_size=2,
            context_feature_spec=CONTEXT_FEATURE_SPEC,
            example_feature_spec=EXAMPLE_FEATURE_SPEC,
            reader=readers.TFRecordDataset,
            shuffle=False,
            sloppy_ordering=sloppy_ordering)

        features = batched_dataset.make_one_shot_iterator().get_next()
        self.assertAllEqual(sorted(features),
                            ["query_length", "unigrams", "utility"])
        # Check static shapes for dense tensors.
        self.assertAllEqual([2, 1],
                            features["query_length"].get_shape().as_list())
        self.assertAllEqual([2, 2, 1],
                            features["utility"].get_shape().as_list())

        with session.Session() as sess:
            sess.run(variables.local_variables_initializer())
            queue_runner.start_queue_runners()
            feature_map = sess.run(features)
            # Test dense_shape, indices and values for a SparseTensor.
            self.assertAllEqual(feature_map["unigrams"].dense_shape, [2, 2, 3])
            self.assertAllEqual(
                feature_map["unigrams"].indices,
                [[0, 0, 0], [0, 1, 0], [0, 1, 1], [0, 1, 2], [1, 0, 0]])
            self.assertAllEqual(
                feature_map["unigrams"].values,
                [b"tensorflow", b"learning", b"to", b"rank", b"gbdt"])
            # Check values directly for dense tensors.
            self.assertAllEqual(feature_map["query_length"], [[3], [2]])
            self.assertAllEqual(feature_map["utility"],
                                [[[0.], [1.0]], [[0.], [0.]]])
Example #36
 def testGetMatchingFiles(self):
   dir_path = os.path.join(self.get_temp_dir(), "temp_dir")
   file_io.create_dir(dir_path)
   files = ["file1.txt", "file2.txt", "file3.txt"]
   for name in files:
     file_path = os.path.join(dir_path, name)
     file_io.write_string_to_file(file_path, "testing")
   expected_match = [os.path.join(dir_path, name) for name in files]
   self.assertItemsEqual(file_io.get_matching_files(os.path.join(dir_path,
                                                                 "file*.txt")),
                         expected_match)
   for name in files:
     file_path = os.path.join(dir_path, name)
     file_io.delete_file(file_path)
Example #37
 def test_plot_model_cnn(self):
   model = keras.Sequential()
   model.add(
       keras.layers.Conv2D(
           filters=2, kernel_size=(2, 3), input_shape=(3, 5, 5), name='conv'))
   model.add(keras.layers.Flatten(name='flat'))
   model.add(keras.layers.Dense(5, name='dense'))
   dot_img_file = 'model_1.png'
   try:
     vis_utils.plot_model(model, to_file=dot_img_file, show_shapes=True)
     self.assertTrue(file_io.file_exists(dot_img_file))
     file_io.delete_file(dot_img_file)
   except ImportError:
     pass
Example #38
def io_write_from_temp(temp_pth, dest_pth):
    """
    Used to save any file with tensorflow lib io (gcloud compatibility)
    We first save it to temp path, then load it save it to dest path and delete temp file
    Args:
        temp_pth:
        dest_pth:

    Returns:

    """
    with file_io.FileIO(temp_pth, mode='r') as input_f:
        with file_io.FileIO(dest_pth, mode='w+') as output_f:
            output_f.write(input_f.read())
            file_io.delete_file(temp_pth)
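
A usage sketch with placeholder paths; note that the helper opens both files in text mode, so depending on the TensorFlow and Python version, binary artifacts may need mode='rb'/'wb' instead:

# Copy a locally written artifact to GCS; the local copy is deleted by the helper.
io_write_from_temp('/tmp/model.h5', 'gs://my-bucket/models/model.h5')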
Example #39
 def test_plot_model_rnn(self):
     model = keras.Sequential()
     model.add(
         keras.layers.LSTM(16,
                           return_sequences=True,
                           input_shape=(2, 3),
                           name='lstm'))
     model.add(
         keras.layers.TimeDistributed(keras.layers.Dense(5, name='dense')))
     dot_img_file = 'model_2.png'
     try:
         vis_utils.plot_model(model, to_file=dot_img_file, show_shapes=True)
         self.assertTrue(file_io.file_exists(dot_img_file))
         file_io.delete_file(dot_img_file)
     except ImportError:
         pass
Example #40
  def test_assets(self, cycles):
    file1 = self._make_asset("contents 1")
    file2 = self._make_asset("contents 2")

    root = tracking.AutoTrackable()
    root.asset1 = tracking.TrackableAsset(file1)
    root.asset2 = tracking.TrackableAsset(file2)

    save_dir = os.path.join(self.get_temp_dir(), "save_dir")
    save.save(root, save_dir)

    file_io.delete_file(file1)
    file_io.delete_file(file2)
    load_dir = os.path.join(self.get_temp_dir(), "load_dir")
    file_io.rename(save_dir, load_dir)

    imported = load.load(load_dir)
    with open(self.evaluate(imported.asset1.asset_path), "r") as f:
      self.assertEqual("contents 1", f.read())
    with open(self.evaluate(imported.asset2.asset_path), "r") as f:
      self.assertEqual("contents 2", f.read())
Example #41
def classify(image_data):
    file_name = str(uuid.uuid1())
    file_io.write_string_to_file(file_name, image_data)
    img = scipy.ndimage.imread(file_name, mode="RGB")
    # Scale it to 32x32
    img = scipy.misc.imresize(img, (32, 32),
                              interp="bicubic").astype(np.float32,
                                                       casting='unsafe')

    # Predict
    prediction = classsifier.model.predict([img])
    print(prediction[0])
    #print (prediction[0].index(max(prediction[0])))
    num = [
        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
        'horse', 'ship', 'truck'
    ]
    result = "This is a %s" % (num[prediction[0].tolist().index(
        max(prediction[0]))])
    file_io.delete_file(file_name)
    return result
Example #42
  def test_assets(self):
    file1 = self._make_asset("contents 1")
    file2 = self._make_asset("contents 2")

    root = tracking.Checkpointable()
    root.asset1 = tracking.TrackableAsset(file1)
    root.asset2 = tracking.TrackableAsset(file2)

    save_dir = os.path.join(self.get_temp_dir(), "save_dir")
    save.save(root, save_dir, signatures={})

    file_io.delete_file(file1)
    file_io.delete_file(file2)
    load_dir = os.path.join(self.get_temp_dir(), "load_dir")
    file_io.rename(save_dir, load_dir)

    imported = load.load(load_dir)
    with open(imported.asset1.asset_path.numpy(), "r") as f:
      self.assertEquals("contents 1", f.read())
    with open(imported.asset2.asset_path.numpy(), "r") as f:
      self.assertEquals("contents 2", f.read())
Example #43
  def test_assets(self):
    file1 = self._make_asset("contents 1")
    file2 = self._make_asset("contents 2")

    root = tracking.AutoCheckpointable()
    root.asset1 = tracking.TrackableAsset(file1)
    root.asset2 = tracking.TrackableAsset(file2)

    save_dir = os.path.join(self.get_temp_dir(), "save_dir")
    save.save(root, save_dir, signatures={})

    file_io.delete_file(file1)
    file_io.delete_file(file2)
    load_dir = os.path.join(self.get_temp_dir(), "load_dir")
    file_io.rename(save_dir, load_dir)

    imported = load.load(load_dir)
    with open(imported.asset1.asset_path.numpy(), "r") as f:
      self.assertEquals("contents 1", f.read())
    with open(imported.asset2.asset_path.numpy(), "r") as f:
      self.assertEquals("contents 2", f.read())
Example #44
    def test_assets_import(self):
        file1 = self._make_asset("contents 1")
        file2 = self._make_asset("contents 2")

        root = tracking.Checkpointable()
        root.f = def_function.function(
            lambda x: 2. * x,
            input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
        root.asset1 = tracking.TrackableAsset(file1)
        root.asset2 = tracking.TrackableAsset(file2)

        save_dir = os.path.join(self.get_temp_dir(), "save_dir")
        save.save(root, save_dir)

        file_io.delete_file(file1)
        file_io.delete_file(file2)
        load_dir = os.path.join(self.get_temp_dir(), "load_dir")
        file_io.rename(save_dir, load_dir)

        imported = load.load(load_dir)
        with open(imported.asset1.asset_path.numpy(), "r") as f:
            self.assertEquals("contents 1", f.read())
        with open(imported.asset2.asset_path.numpy(), "r") as f:
            self.assertEquals("contents 2", f.read())
Example #45
 def testFileDelete(self):
   file_path = os.path.join(self.get_temp_dir(), "temp_file")
   file_io.write_string_to_file(file_path, "testing")
   file_io.delete_file(file_path)
   self.assertFalse(file_io.file_exists(file_path))
Example #46
 def testFileDeleteFail(self):
   file_path = os.path.join(self._base_dir, "temp_file")
   with self.assertRaises(errors.NotFoundError):
     file_io.delete_file(file_path)
Example #47
 def testFileDelete(self):
   file_path = os.path.join(self._base_dir, "temp_file")
   file_io.FileIO(file_path, mode="w").write("testing")
   file_io.delete_file(file_path)
   self.assertFalse(file_io.file_exists(file_path))
Example #48

  def testAPIDefCompatibility(self):
    # Get base ApiDef
    name_to_base_api_def = self._GetBaseApiMap()
    # Extract Python API
    visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
    public_api_visitor = public_api.PublicAPIVisitor(visitor)
    public_api_visitor.do_not_descend_map['tf'].append('contrib')
    traverse.traverse(tf, public_api_visitor)
    proto_dict = visitor.GetProtos()

    # Map from first character of op name to Python ApiDefs.
    api_def_map = defaultdict(api_def_pb2.ApiDefs)
    # We need to override all endpoints even if 1 endpoint differs from base
    # ApiDef. So, we first create a map from an op to all its endpoints.
    op_to_endpoint_name = defaultdict(list)

    # Generate map from generated python op to endpoint names.
    for public_module, value in proto_dict.items():
      module_obj = _GetSymbol(public_module)
      for sym in value.tf_module.member_method:
        obj = getattr(module_obj, sym.name)

        # Check if object is defined in gen_* module. That is,
        # the object has been generated from OpDef.
        if hasattr(obj, '__module__') and _IsGenModule(obj.__module__):
          if obj.__name__ not in name_to_base_api_def:
            # Symbol might be defined only in Python and not generated from
            # C++ api.
            continue
          relative_public_module = public_module[len('tensorflow.'):]
          full_name = (relative_public_module + '.' + sym.name
                       if relative_public_module else sym.name)
          op_to_endpoint_name[obj].append(full_name)

    # Generate Python ApiDef overrides.
    for op, endpoint_names in op_to_endpoint_name.items():
      api_def = self._CreatePythonApiDef(
          name_to_base_api_def[op.__name__], endpoint_names)
      if api_def:
        api_defs = api_def_map[op.__name__[0].upper()]
        api_defs.op.extend([api_def])

    for key in _ALPHABET:
      # Get new ApiDef for the given key.
      new_api_defs_str = ''
      if key in api_def_map:
        new_api_defs = api_def_map[key]
        new_api_defs.op.sort(key=attrgetter('graph_op_name'))
        new_api_defs_str = str(new_api_defs)

      # Get current ApiDef for the given key.
      api_defs_file_path = os.path.join(
          _PYTHON_API_DIR, 'api_def_%s.pbtxt' % key)
      old_api_defs_str = ''
      if file_io.file_exists(api_defs_file_path):
        old_api_defs_str = file_io.read_file_to_string(api_defs_file_path)

      if old_api_defs_str == new_api_defs_str:
        continue

      if FLAGS.update_goldens:
        if not new_api_defs_str:
          logging.info('Deleting %s...' % api_defs_file_path)
          file_io.delete_file(api_defs_file_path)
        else:
          logging.info('Updating %s...' % api_defs_file_path)
          file_io.write_string_to_file(api_defs_file_path, new_api_defs_str)
      else:
        self.assertMultiLineEqual(
            old_api_defs_str, new_api_defs_str,
            'To update golden API files, run api_compatibility_test locally '
            'with --update_goldens=True flag.')
Example #49
 def delete_file(cls, filename):
     return file_io.delete_file(filename)
Example #50

  def testAPIDefCompatibility(self):
    # Get base ApiDef
    name_to_base_api_def = self._GetBaseApiMap()
    snake_to_camel_graph_op_names = {
        self._GenerateLowerCaseOpName(name): name
        for name in name_to_base_api_def.keys()}
    # Extract Python API
    visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
    public_api_visitor = public_api.PublicAPIVisitor(visitor)
    public_api_visitor.do_not_descend_map['tf'].append('contrib')
    traverse.traverse(tf, public_api_visitor)
    proto_dict = visitor.GetProtos()

    # Map from file path to Python ApiDefs.
    new_api_defs_map = defaultdict(api_def_pb2.ApiDefs)
    # We need to override all endpoints even if 1 endpoint differs from base
    # ApiDef. So, we first create a map from an op to all its endpoints.
    op_to_endpoint_name = defaultdict(list)

    # Generate map from generated python op to endpoint names.
    for public_module, value in proto_dict.items():
      module_obj = _GetSymbol(public_module)
      for sym in value.tf_module.member_method:
        obj = getattr(module_obj, sym.name)

        # Check if object is defined in gen_* module. That is,
        # the object has been generated from OpDef.
        if hasattr(obj, '__module__') and _IsGenModule(obj.__module__):
          if obj.__name__ not in snake_to_camel_graph_op_names:
            # Symbol might be defined only in Python and not generated from
            # C++ api.
            continue
          relative_public_module = public_module[len('tensorflow.'):]
          full_name = (relative_public_module + '.' + sym.name
                       if relative_public_module else sym.name)
          op_to_endpoint_name[obj].append(full_name)

    # Generate Python ApiDef overrides.
    for op, endpoint_names in op_to_endpoint_name.items():
      graph_op_name = snake_to_camel_graph_op_names[op.__name__]
      api_def = self._CreatePythonApiDef(
          name_to_base_api_def[graph_op_name], endpoint_names)

      if api_def:
        file_path = _GetApiDefFilePath(graph_op_name)
        api_defs = new_api_defs_map[file_path]
        api_defs.op.extend([api_def])

    self._AddHiddenOpOverrides(name_to_base_api_def, new_api_defs_map)

    old_api_defs_map = _GetGoldenApiDefs()
    for file_path, new_api_defs in new_api_defs_map.items():
      # Get new ApiDef string.
      new_api_defs_str = str(new_api_defs)

      # Get current ApiDef for the given file.
      old_api_defs_str = (
          old_api_defs_map[file_path] if file_path in old_api_defs_map else '')

      if old_api_defs_str == new_api_defs_str:
        continue

      if FLAGS.update_goldens:
        logging.info('Updating %s...' % file_path)
        file_io.write_string_to_file(file_path, new_api_defs_str)
      else:
        self.assertMultiLineEqual(
            old_api_defs_str, new_api_defs_str,
            'To update golden API files, run api_compatibility_test locally '
            'with --update_goldens=True flag.')

    for file_path in set(old_api_defs_map) - set(new_api_defs_map):
      if FLAGS.update_goldens:
        logging.info('Deleting %s...' % file_path)
        file_io.delete_file(file_path)
      else:
        self.fail(
            '%s file is no longer needed and should be removed.'
            'To update golden API files, run api_compatibility_test locally '
            'with --update_goldens=True flag.' % file_path)
Example #51

def _delete_file_if_exists(filespec):
  """Deletes files matching `filespec`."""
  for pathname in file_io.get_matching_files(filespec):
    file_io.delete_file(pathname)
Example #52
  def _AssertProtoDictEquals(self,
                             expected_dict,
                             actual_dict,
                             verbose=False,
                             update_goldens=False,
                             additional_missing_object_message='',
                             api_version=2):
    """Diff given dicts of protobufs and report differences a readable way.

    Args:
      expected_dict: a dict of TFAPIObject protos constructed from golden files.
      actual_dict: a dict of TFAPIObject protos constructed by reading from the
        TF package linked to the test.
      verbose: Whether to log the full diffs, or simply report which files were
        different.
      update_goldens: Whether to update goldens when there are diffs found.
      additional_missing_object_message: Message to print when a symbol is
        missing.
      api_version: TensorFlow API version to test.
    """
    diffs = []
    verbose_diffs = []

    expected_keys = set(expected_dict.keys())
    actual_keys = set(actual_dict.keys())
    only_in_expected = expected_keys - actual_keys
    only_in_actual = actual_keys - expected_keys
    all_keys = expected_keys | actual_keys

    # This will be populated below.
    updated_keys = []

    for key in all_keys:
      diff_message = ''
      verbose_diff_message = ''
      # First check if the key is not found in one or the other.
      if key in only_in_expected:
        diff_message = 'Object %s expected but not found (removed). %s' % (
            key, additional_missing_object_message)
        verbose_diff_message = diff_message
      elif key in only_in_actual:
        diff_message = 'New object %s found (added).' % key
        verbose_diff_message = diff_message
      else:
        # Do not truncate diff
        self.maxDiff = None  # pylint: disable=invalid-name
        # Now we can run an actual proto diff.
        try:
          self.assertProtoEquals(expected_dict[key], actual_dict[key])
        except AssertionError as e:
          updated_keys.append(key)
          diff_message = 'Change detected in python object: %s.' % key
          verbose_diff_message = str(e)

      # All difference cases covered above. If any difference found, add to the
      # list.
      if diff_message:
        diffs.append(diff_message)
        verbose_diffs.append(verbose_diff_message)

    # If diffs are found, handle them based on flags.
    if diffs:
      diff_count = len(diffs)
      logging.error(self._test_readme_message)
      logging.error('%d differences found between API and golden.', diff_count)
      messages = verbose_diffs if verbose else diffs
      for i in range(diff_count):
        print('Issue %d\t: %s' % (i + 1, messages[i]), file=sys.stderr)

      if update_goldens:
        # Write files if requested.
        logging.warning(self._update_golden_warning)

        # If the keys are only in expected, some objects are deleted.
        # Remove files.
        for key in only_in_expected:
          filepath = _KeyToFilePath(key, api_version)
          file_io.delete_file(filepath)

        # If the files are only in actual (current library), these are new
        # modules. Write them to files. Also record all updates in files.
        for key in only_in_actual | set(updated_keys):
          filepath = _KeyToFilePath(key, api_version)
          file_io.write_string_to_file(
              filepath, text_format.MessageToString(actual_dict[key]))
      else:
        # Fail if we cannot fix the test by updating goldens.
        self.fail('%d differences found between API and golden.' % diff_count)

    else:
      logging.info('No differences found between API and golden.')