示例#1
0
 def input_fn(is_training, data_dir, batch_size, *args, **kwargs):  # pylint: disable=unused-argument
     """Return a synthetic dataset of float32 inputs and int32 labels.

     Only ``batch_size`` is used. The remaining arguments are accepted so the
     signature can stand in for a real input function, and they are ignored.
     ``height``, ``width`` and ``num_channels`` are not parameters. They are
     read from an enclosing scope (presumably the function that builds this
     closure; confirm against the caller).
     """
     features_shape = tf.TensorShape([batch_size, height, width, num_channels])
     labels_shape = tf.TensorShape([batch_size])  # one scalar label per example
     return model_helpers.generate_synthetic_data(
         input_shape=features_shape,
         input_dtype=tf.float32,
         label_shape=labels_shape,
         label_dtype=tf.int32)
示例#2
0
def _generate_synthetic_data(params):
    """Create synthetic data based on the parameter batch size.

    The truncated integer square root of ``params["batch_size"]`` is used for
    both dimensions of a 2-D shape. The input and label tensors therefore
    each hold approximately ``batch_size`` elements. Inputs and labels are
    both int32 and use the constant value 1.

    Args:
        params: mapping that must contain a numeric ``"batch_size"`` entry.

    Returns:
        The dataset returned by ``model_helpers.generate_synthetic_data``.
    """
    side = int(math.sqrt(params["batch_size"]))
    return model_helpers.generate_synthetic_data(
        input_shape=tf.TensorShape([side, side]),
        input_value=1,
        input_dtype=tf.int32,
        label_shape=tf.TensorShape([side, side]),
        label_value=1,
        label_dtype=tf.int32,
    )
示例#3
0
def generate_synthetic_input_dataset(model, batch_size):
  """Generate synthetic dataset.

  Args:
    model: value passed to `_get_default_image_size` to look up the spatial
      dimensions of each input. That helper is assumed to return a tuple,
      since the result is concatenated with other tuples.
    batch_size: leading dimension of both the input and label shapes.

  Returns:
    The dataset from `model_helpers.generate_synthetic_data`. Its inputs have
    shape `(batch_size,) + image_size + (_NUM_CHANNELS,)` and its labels have
    shape `(batch_size, _NUM_CLASSES)`. Values and dtypes are left at that
    helper's defaults.
  """
  spatial_dims = _get_default_image_size(model)
  return model_helpers.generate_synthetic_data(
      input_shape=tf.TensorShape(
          (batch_size,) + spatial_dims + (_NUM_CHANNELS,)),
      label_shape=tf.TensorShape((batch_size, _NUM_CLASSES)),
  )
示例#4
0
    def test_generate_only_input_data(self):
        """With no label arguments, the dataset yields bare constant inputs."""
        dataset = model_helpers.generate_synthetic_data(
            input_shape=tf.TensorShape([4]),
            input_value=43.5,
            input_dtype=tf.float32)

        next_element = tf.compat.v1.data.make_one_shot_iterator(
            dataset).get_next()
        # No label_shape was given, so the test expects a single tensor rather
        # than an (input, label) tuple.
        self.assertNotIsInstance(next_element, tuple)

        with self.session() as sess:
            value = sess.run(next_element)
            self.assertAllClose(value, [43.5] * 4)
  def test_generate_synethetic_data(self):
    """Inputs and labels are constants with the requested shapes and dtypes."""
    dataset = model_helpers.generate_synthetic_data(
        input_shape=tf.TensorShape([5]),
        input_value=123,
        input_dtype=tf.float32,
        label_shape=tf.TensorShape([]),
        label_value=456,
        label_dtype=tf.int32)
    # Dataset.make_one_shot_iterator() exists only on TF1 datasets. The
    # compat.v1 function works on both TF1 and TF2 datasets.
    input_element, label_element = tf.compat.v1.data.make_one_shot_iterator(
        dataset).get_next()

    # self.test_session() is deprecated in favour of self.session().
    with self.session() as sess:
      # Pull several elements and check that each carries the same constants.
      for _ in range(5):
        inp, lab = sess.run((input_element, label_element))
        self.assertAllClose(inp, [123., 123., 123., 123., 123.])
        # assertEquals is a deprecated alias that was removed in Python 3.12.
        self.assertEqual(lab, 456)
  def test_generate_nested_data(self):
    """A nested dict of shapes yields a matching nested dict of constants."""
    d = model_helpers.generate_synthetic_data(
        input_shape={'a': tf.TensorShape([2]),
                     'b': {'c': tf.TensorShape([3]), 'd': tf.TensorShape([])}},
        input_value=1.1)

    # Dataset.make_one_shot_iterator() exists only on TF1 datasets. The
    # compat.v1 function works on both TF1 and TF2 datasets.
    element = tf.compat.v1.data.make_one_shot_iterator(d).get_next()
    self.assertIn('a', element)
    self.assertIn('b', element)
    # assertEquals is a deprecated alias that was removed in Python 3.12.
    self.assertEqual(len(element['b']), 2)
    self.assertIn('c', element['b'])
    self.assertIn('d', element['b'])
    # Nested keys must stay nested instead of being flattened to the top level.
    self.assertNotIn('c', element)

    # self.test_session() is deprecated in favour of self.session().
    with self.session() as sess:
      inp = sess.run(element)
      self.assertAllClose(inp['a'], [1.1, 1.1])
      self.assertAllClose(inp['b']['c'], [1.1, 1.1, 1.1])
      self.assertAllClose(inp['b']['d'], 1.1)