示例#1
0
  def test_generator_graph(self):
    """Generator output has the same static shape as its input.

    Runs over several leading-dimension sizes, always with a last dimension
    of 3.
    """
    size_cases = ([4, 32, 32], [3, 128, 128], [2, 80, 400])
    for dims in size_cases:
      # Start from an empty default graph for every case; presumably this
      # keeps variables created by the previous generator call from
      # colliding with this one (TODO confirm).
      tf.reset_default_graph()
      full_shape = dims + [3]
      generated = networks.generator(tf.ones(full_shape))
      self.assertAllEqual(full_shape, generated.shape.as_list())
示例#2
0
  def test_generator_graph(self):
    """Check the generator's static output shape on a few input sizes."""
    for leading in ([4, 32, 32], [3, 128, 128], [2, 80, 400]):
      # Each size is built in a fresh default graph.
      tf.reset_default_graph()
      expected = leading + [3]
      ones_input = tf.ones(expected)
      result = networks.generator(ones_input)
      actual = result.shape.as_list()
      self.assertAllEqual(expected, actual)
示例#3
0
 def test_generator_invalid_channels(self):
   """An input whose last dimension is statically unknown is rejected.

   Expects `networks.generator` to raise ValueError with a message saying
   the last dimension must be known.
   """
   img = tf.placeholder(tf.float32, shape=[4, 32, 32, None])
   # Only the generator call sits inside the context, so a ValueError from
   # building the placeholder cannot accidentally satisfy the assertion.
   # `assertRaisesRegexp` is a deprecated alias (removed in Python 3.12).
   with self.assertRaisesRegex(
       ValueError, 'Last dimension shape must be known but is None'):
     networks.generator(img)
示例#4
0
 def test_generator_run_multi_channel(self):
   """Build and evaluate the generator on an all-zeros batch with 5 channels.

   Only checks that variable initialization and a forward pass run without
   raising; the output value is not inspected.
   """
   img_batch = tf.zeros([3, 128, 128, 5])
   model_output = networks.generator(img_batch)
   # `test_session()` is deprecated in tf.test.TestCase; `cached_session()`
   # is its supported replacement for graph-mode tests.
   with self.cached_session() as sess:
     sess.run(tf.global_variables_initializer())
     sess.run(model_output)
示例#5
0
 def test_generator_invalid_input(self):
   """A rank-3 input is rejected with a ValueError mentioning rank 4."""
   rank3_input = tf.zeros([28, 28, 3])
   # Keep only the call under test inside the context.
   # `assertRaisesRegexp` is a deprecated alias (removed in Python 3.12).
   with self.assertRaisesRegex(ValueError, 'must have rank 4'):
     networks.generator(rank3_input)
示例#6
0
  def test_generator_graph_unknown_batch_dim(self):
    """An unknown (None) batch dimension survives through the generator."""
    expected_shape = [None, 32, 32, 3]
    placeholder = tf.placeholder(tf.float32, shape=expected_shape)
    static_shape = networks.generator(placeholder).shape.as_list()
    self.assertAllEqual(expected_shape, static_shape)
示例#7
0
 def test_generator_invalid_input(self):
   """Passing a rank-3 tensor must raise ValueError ('must have rank 4')."""
   not_rank4 = tf.zeros([28, 28, 3])
   # Build the input outside the context so only the generator is tested.
   # `assertRaisesRegexp` is a deprecated alias (removed in Python 3.12).
   with self.assertRaisesRegex(ValueError, 'must have rank 4'):
     networks.generator(not_rank4)
示例#8
0
  def test_generator_graph_unknown_batch_dim(self):
    """Static output shape keeps a leading None when the batch is unknown."""
    inp = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])
    out = networks.generator(inp)
    self.assertAllEqual(
        [None, 32, 32, 3],
        out.shape.as_list(),
    )
示例#9
0
 def test_generator_run(self):
   """Build and evaluate the generator on an all-zeros [3, 128, 128, 3] batch.

   Only checks that variable initialization and a forward pass run without
   raising; the output value is not inspected.
   """
   img_batch = tf.zeros([3, 128, 128, 3])
   model_output = networks.generator(img_batch)
   # `test_session()` is deprecated in tf.test.TestCase; `cached_session()`
   # is its supported replacement for graph-mode tests.
   with self.cached_session() as sess:
     sess.run(tf.global_variables_initializer())
     sess.run(model_output)