def make_inference_graph(model_name, patch_dim):
    """Construct the generator inference graph for the X2Y or Y2X GAN.

    A single HWC image is fed through a float32 placeholder with three
    channels. It is turned into a patch by `data_provider.full_image_to_patch`,
    given a leading batch axis, and passed to `networks.generator` inside the
    `<model_name>/Generator` variable scope.

    Args:
        model_name: The var scope name, 'ModelX2Y' or 'ModelY2X'.
        patch_dim: An integer size of patches to feed to the generator.

    Returns:
        Tuple of (input_placeholder, generated_tensor).
    """
    image_placeholder = tf.placeholder(tf.float32, [None, None, 3])

    # The patch is a single HWC image; add a batch axis to get NHWC.
    patch = data_provider.full_image_to_patch(image_placeholder, patch_dim)
    batched_patch = tf.expand_dims(patch, 0)

    # Same nesting as two `with` statements: outer model scope, inner
    # 'Generator' scope.
    with tf.variable_scope(model_name), tf.variable_scope('Generator'):
        generated = networks.generator(batched_patch)

    return image_placeholder, generated
def test_generator_run_multi_channel(self):
    """Smoke test: the generator builds and runs on 5-channel images."""
    five_channel_batch = tf.zeros([3, 128, 128, 5])
    output_tensor = networks.generator(five_channel_batch)
    with self.cached_session() as session:
        # The initializer is created inside the session context, as before.
        init_op = tf.global_variables_initializer()
        session.run(init_op)
        # Only checks that evaluation does not raise; the value is unused.
        session.run(output_tensor)
def test_generator_invalid_input(self):
    """The generator rejects a rank-3 input with a 'must have rank 4' error."""
    # `assertRaisesRegexp` was a deprecated alias removed in Python 3.12;
    # `assertRaisesRegex` is the supported spelling.
    with self.assertRaisesRegex(ValueError, 'must have rank 4'):
        networks.generator(tf.zeros([28, 28, 3]))