コード例 #1
0
def make_inference_graph(model_name, patch_dim):
  """Construct the inference graph for the X2Y or Y2X GAN generator.

  Args:
    model_name: Variable scope name, either 'ModelX2Y' or 'ModelY2X'.
    patch_dim: Integer patch size passed to the patch extractor before the
      generator.

  Returns:
    A `(input_placeholder, generated_tensor)` tuple.
  """
  # Single image of unknown height/width with 3 channels.
  image_pl = tf.placeholder(tf.float32, [None, None, 3])

  # Cut the image into a patch (presumably still HWC -- confirm against
  # data_provider), then prepend a batch axis: HWC -> NHWC.
  patch = data_provider.full_image_to_patch(image_pl, patch_dim)
  batched_patch = tf.expand_dims(patch, 0)

  # Generator variables are created under '<model_name>/Generator'.
  with tf.variable_scope(model_name), tf.variable_scope('Generator'):
    generated_tensor = networks.generator(batched_patch)

  return image_pl, generated_tensor
コード例 #2
0
 def test_generator_run_multi_channel(self):
     """The generator builds and executes on a zero batch with 5 channels."""
     zeros_input = tf.zeros([3, 128, 128, 5])
     output_tensor = networks.generator(zeros_input)
     with self.cached_session() as session:
         # Variables must be initialized before the generator can be evaluated.
         init_op = tf.global_variables_initializer()
         session.run(init_op)
         session.run(output_tensor)
コード例 #3
0
 def test_generator_invalid_input(self):
     """generator() rejects a rank-3 input with a 'must have rank 4' ValueError."""
     # `assertRaisesRegexp` is a deprecated alias that was removed in
     # Python 3.12; `assertRaisesRegex` is the supported, equivalent spelling.
     with self.assertRaisesRegex(ValueError, 'must have rank 4'):
         networks.generator(tf.zeros([28, 28, 3]))