Пример #1
0
  def testSequenceRecord(self):
    """Round-trips one float32 sequence through a TFRecord file.

    Writes ``vector`` with ``record_inputter.write_sequence_record``, reads it
    back through ``SequenceRecordInputter``, and checks the parsed features,
    the transformed tensor, and the serving input receiver.
    """
    import os  # local import: the module's import block is not visible here

    vector = np.array([[0.2, 0.3], [0.4, 0.5]], dtype=np.float32)

    # Write the record file under this test's temporary directory so the
    # test does not depend on any externally defined path.
    record_file = os.path.join(self.get_temp_dir(), "data.records")
    writer = tf.python_io.TFRecordWriter(record_file)
    record_inputter.write_sequence_record(vector, writer)
    writer.close()

    inputter = record_inputter.SequenceRecordInputter()
    data, transformed = _first_element(inputter, record_file)
    input_receiver = inputter.get_serving_input_receiver()

    # Both the dataset features and the serving features expose the same keys.
    self.assertIn("length", data)
    self.assertIn("tensor", data)
    self.assertIn("length", input_receiver.features)
    self.assertIn("tensor", input_receiver.features)

    # Static shapes: unknown batch and time dimensions, depth fixed at 2.
    self.assertAllEqual([None, None, 2], transformed.get_shape().as_list())
    self.assertAllEqual([None, None, 2], input_receiver.features["tensor"].get_shape().as_list())

    with self.test_session() as sess:
      sess.run(tf.tables_initializer())
      data, transformed = sess.run([data, transformed])
      self.assertEqual(2, data["length"])
      self.assertAllEqual(vector, data["tensor"])
      self.assertAllEqual([vector], transformed)
Пример #2
0
def ark_to_records(ark_filename, out_prefix, dtype=np.float32):
    """Converts an ARK dataset to TFRecords.

    Vectors are read one at a time from ``ark_filename`` with
    ``consume_next_vector`` and each is written as a sequence record to
    ``out_prefix + ".records"``. The number of saved records is printed once
    conversion succeeds.

    Args:
      ark_filename: Path to the ARK text file, opened as UTF-8.
      out_prefix: Output path prefix; ``".records"`` is appended to it.
      dtype: NumPy dtype forwarded to ``consume_next_vector``.
    """
    count = 0
    record_writer = tf.python_io.TFRecordWriter(out_prefix + ".records")
    try:
        with io.open(ark_filename, encoding="utf-8") as ark_file:
            while True:
                ark_idx, vector = consume_next_vector(ark_file, dtype=dtype)
                # Stop once consume_next_vector yields a falsy index
                # (presumably end of input).
                if not ark_idx:
                    break
                write_sequence_record(vector, record_writer)
                count += 1
    finally:
        # Close the writer even if reading/parsing fails part-way, so the
        # file handle is released and already-written records are flushed.
        record_writer.close()

    print("Saved {} records.".format(count))
Пример #3
0
  def testSequenceRecord(self):
    """Round-trips one float32 sequence through a TFRecord file.

    The writer class is chosen through ``compat.tf_compat`` so the test runs
    under both the TF1 (``python_io``) and TF2 (``io``) APIs.
    """
    vector = np.array([[0.2, 0.3], [0.4, 0.5]], dtype=np.float32)

    record_file = os.path.join(self.get_temp_dir(), "data.records")
    writer = compat.tf_compat(v2="io.TFRecordWriter", v1="python_io.TFRecordWriter")(record_file)
    record_inputter.write_sequence_record(vector, writer)
    writer.close()

    inputter = record_inputter.SequenceRecordInputter()
    features, transformed = self._makeDataset(
        inputter,
        record_file,
        shapes={"tensor": [None, None, 2], "length": [None]})

    # assertAllEqual compares element-wise. assertEqual would depend on the
    # truth value of an array comparison, which breaks for multi-element
    # results and gives unclear failure messages.
    self.assertAllEqual([2], features["length"])
    self.assertAllEqual([vector], features["tensor"])
    self.assertAllEqual([vector], transformed)
Пример #4
0
  def testSequenceRecord(self):
    """Round-trips one float32 sequence through a TFRecord file (graph mode).

    Writes ``vector`` as a sequence record, builds a batched dataset through
    ``self._makeDataset``, and evaluates it in a session to compare the
    parsed length, raw tensor, and transformed tensor with the input.
    """
    vector = np.array([[0.2, 0.3], [0.4, 0.5]], dtype=np.float32)

    record_file = os.path.join(self.get_temp_dir(), "data.records")
    writer = tf.python_io.TFRecordWriter(record_file)
    record_inputter.write_sequence_record(vector, writer)
    writer.close()

    inputter = record_inputter.SequenceRecordInputter()
    features, transformed = self._makeDataset(
        inputter,
        record_file,
        shapes={"tensor": [None, None, 2], "length": [None]})

    with self.test_session() as sess:
      # Initialize any lookup tables the input pipeline may have created.
      sess.run(tf.tables_initializer())
      features, transformed = sess.run([features, transformed])
      # sess.run returns NumPy arrays; compare them element-wise rather than
      # relying on the truth value of an array comparison.
      self.assertAllEqual([2], features["length"])
      self.assertAllEqual([vector], features["tensor"])
      self.assertAllEqual([vector], transformed)
Пример #5
0
 def _write_example(vector, text):
     """Write one example: ``vector`` as a sequence record, ``text`` as text.

     ``record_writer`` and ``text_writer`` are not parameters; they are taken
     from the enclosing scope (not visible here).

     Args:
       vector: Value passed to ``write_sequence_record`` (presumably a 2-D
         NumPy array of features — TODO confirm against callers).
       text: Value passed to ``write_text``.
     """
     # Writing both outputs in one call presumably keeps the record file and
     # the text file in lockstep, one entry per example.
     write_sequence_record(vector, record_writer)
     write_text(text, text_writer)