示例#1
0
  def _makeDataset(self, inputter, data_file, metadata=None, dataset_size=1, shapes=None):
    """Builds a batch-of-one pipeline for ``inputter`` and checks its static shapes.

    Args:
      inputter: The inputter under test.
      data_file: Data file given to ``get_dataset_size`` and ``make_dataset``.
      metadata: Optional metadata; when set, ``inputter.initialize`` is called
        with it before anything else.
      dataset_size: Expected value of ``inputter.get_dataset_size(data_file)``.
      shapes: Optional mapping from feature name to the exact static shape
        (as a list) that feature must have.

    Returns:
      A tuple ``(next_element, transformed)`` where ``next_element`` comes from
      the dataset iterator and ``transformed`` is
      ``inputter.transform_data(next_element)``.
    """
    if metadata is not None:
      inputter.initialize(metadata)

    self.assertEqual(dataset_size, inputter.get_dataset_size(data_file))

    examples = inputter.make_dataset(data_file).map(
        lambda *example: inputter.process(item_or_tuple(example)))
    batches = examples.padded_batch(1, padded_shapes=data.get_padded_shapes(examples))

    batch_iterator = batches.make_initializable_iterator()
    # Registered with the table initializers; presumably so callers that run
    # tf.tables_initializer() also initialize the iterator -- confirm with callers.
    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, batch_iterator.initializer)
    element = batch_iterator.get_next()

    if shapes is not None:
      candidates = [element]
      if inputter.is_target:
        # A target inputter is expected to refuse building a serving receiver.
        with self.assertRaises(ValueError):
          inputter.get_serving_input_receiver()
      else:
        candidates.append(inputter.get_serving_input_receiver().features)

      for candidate in candidates:
        self.assertNotIn("raw", candidate)
        for name, expected_shape in six.iteritems(shapes):
          self.assertIn(name, candidate)
          self.assertAllEqual(expected_shape, candidate[name].get_shape().as_list())

    return element, inputter.transform_data(element)
示例#2
0
  def testSequenceRecord(self):
    """Writes one sequence record and checks what SequenceRecordInputter reads back.

    The test covers the feature keys exposed by both the dataset element and
    the serving input receiver, their static shapes, and the evaluated values.
    """
    import os  # Local import: only needed to build the temporary record path.

    vector = np.array([[0.2, 0.3], [0.4, 0.5]], dtype=np.float32)

    # Write the record into the per-test temporary directory.
    record_file = os.path.join(self.get_temp_dir(), "data.records")
    # The context manager closes the writer even if writing raises.
    with tf.python_io.TFRecordWriter(record_file) as writer:
      record_inputter.write_sequence_record(vector, writer)

    inputter = record_inputter.SequenceRecordInputter()
    data, transformed = _first_element(inputter, record_file)
    input_receiver = inputter.get_serving_input_receiver()

    self.assertIn("length", data)
    self.assertIn("tensor", data)
    self.assertIn("length", input_receiver.features)
    self.assertIn("tensor", input_receiver.features)

    # Only the last dimension (the vector depth, 2) is statically known.
    self.assertAllEqual([None, None, 2], transformed.get_shape().as_list())
    self.assertAllEqual([None, None, 2], input_receiver.features["tensor"].get_shape().as_list())

    with self.test_session() as sess:
      sess.run(tf.tables_initializer())
      # From here on `data` and `transformed` hold evaluated values, not tensors.
      data, transformed = sess.run([data, transformed])
      self.assertEqual(2, data["length"])
      self.assertAllEqual(vector, data["tensor"])
      self.assertAllEqual([vector], transformed)
示例#3
0
  def _makeDataset(self, inputter, data_file, metadata=None, dataset_size=1, shapes=None):
    """Builds a batch-of-one pipeline for ``inputter`` and evaluates its first element.

    Args:
      inputter: The inputter under test.
      data_file: Data file given to ``get_dataset_size`` and ``make_dataset``.
      metadata: Optional metadata; when set, ``inputter.initialize`` is called
        with it before anything else.
      dataset_size: Expected value of ``inputter.get_dataset_size(data_file)``.
      shapes: Optional mapping from feature name to a shape the feature's
        static shape must be compatible with.

    Returns:
      The result of ``self.evaluate((features, inputs))`` where ``inputs`` is
      ``inputter.make_inputs(features, training=True)``.
    """
    if metadata is not None:
      inputter.initialize(metadata)

    self.assertEqual(dataset_size, inputter.get_dataset_size(data_file))
    dataset = inputter.make_dataset(data_file)
    dataset = dataset.map(lambda *arg: inputter.process(item_or_tuple(arg)))
    dataset = dataset.padded_batch(1, padded_shapes=data.get_padded_shapes(dataset))

    if compat.is_tf2():
      iterator = None
      # Builtin next() works on both Python 2 and 3; calling the `.next()`
      # method directly depends on the iterator class keeping a Python 2 shim.
      features = next(iter(dataset))
    else:
      iterator = dataset.make_initializable_iterator()
      features = iterator.get_next()

    if shapes is not None:
      all_features = [features]
      # The serving input receiver is only checked outside TF2 and for
      # non-target inputters.
      if not compat.is_tf2() and not inputter.is_target:
        all_features.append(inputter.get_serving_input_receiver().features)
      for f in all_features:
        for field, shape in six.iteritems(shapes):
          self.assertIn(field, f)
          self.assertTrue(f[field].shape.is_compatible_with(shape))

    inputs = inputter.make_inputs(features, training=True)
    if not compat.is_tf2():
      # Outside TF2, tables, variables and the iterator are initialized
      # explicitly before the tensors are evaluated.
      # NOTE(review): assumes self.evaluate() reuses this test session so the
      # initialization is still in effect -- confirm.
      with self.test_session() as sess:
        sess.run(tf.tables_initializer())
        sess.run(tf.global_variables_initializer())
        sess.run(iterator.initializer)
    return self.evaluate((features, inputs))