Exemplo n.º 1
0
def inference(input_path='input_image.jpg',
              model_path='./pretrained/apple2orange.pb'):
    """Run a frozen image-translation graph on one JPEG and save the result.

    Args:
        input_path: JPEG file to translate. Defaults to the path that used to
            be hard-coded here.
        model_path: frozen ``GraphDef`` (.pb) whose ``input_image`` input is
            replaced by the preprocessed image and whose ``output_image:0``
            tensor is evaluated. Defaults to the previously hard-coded path.

    The bytes produced by ``output_image:0`` (presumably an encoded image --
    confirm against the exporter) are written to ``FLAGS.output``.
    """
    graph = tf.Graph()

    with graph.as_default():
        with tf.gfile.FastGFile(input_path, 'rb') as f:
            image_data = f.read()
        input_image = tf.image.decode_jpeg(image_data, channels=3)
        input_image = tf.image.resize_images(
            input_image, size=(FLAGS.image_size, FLAGS.image_size))
        input_image = utils.convert2float(input_image)
        input_image.set_shape([FLAGS.image_size, FLAGS.image_size, 3])

        with tf.gfile.FastGFile(model_path, 'rb') as model_file:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(model_file.read())

        # Splice the preprocessed tensor in place of the model's own input.
        [output_image] = tf.import_graph_def(
            graph_def,
            input_map={'input_image': input_image},
            return_elements=['output_image:0'],
            name='output')

    with tf.Session(graph=graph) as sess:
        generated = sess.run(output_image)

    with open(FLAGS.output, 'wb') as f:
        f.write(generated)
Exemplo n.º 2
0
def translate():
    """Translate ``FLAGS.input_img`` with the frozen graph ``FLAGS.model``.

    The JPEG is decoded to 3 channels, resized to a ``FLAGS.img_size`` square
    and passed through ``utils.convert2float``. It is then fed to the graph's
    ``input_image`` input, and the bytes of ``output_image:0`` are written to
    ``FLAGS.output_img``.
    """
    side = FLAGS.img_size
    graph = tf.Graph()

    with graph.as_default():
        with tf.gfile.FastGFile(FLAGS.input_img, 'rb') as img_file:
            raw_bytes = img_file.read()
        decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
        resized = tf.image.resize_images(decoded, size=(side, side))
        model_input = utils.convert2float(resized)
        model_input.set_shape([side, side, 3])

        frozen = tf.GraphDef()
        with tf.gfile.FastGFile(FLAGS.model, 'rb') as pb_file:
            frozen.ParseFromString(pb_file.read())

        (translated,) = tf.import_graph_def(
            frozen,
            input_map={'input_image': model_input},
            return_elements=['output_image:0'],
            name='output')

    with tf.Session(graph=graph) as sess:
        payload = sess.run(translated)
        with open(FLAGS.output_img, 'wb') as out_file:
            out_file.write(payload)
Exemplo n.º 3
0
def sample():
  """Translate the image at IMG_PATH with a checkpointed CycleGAN.

  Builds the CycleGAN graph and restores its weights from CKPT_PATH. It then
  runs ``CycleGAN.sample`` on the 128x128 preprocessed image and writes the
  resulting bytes to ``samples/sample.jpg``.
  """
  graph = tf.Graph()

  with graph.as_default():
    # Binary mode: in text mode ('r') gfile decodes the file as UTF-8, which
    # fails on JPEG bytes under Python 3.
    with tf.gfile.FastGFile(IMG_PATH, 'rb') as f:
      image_data = f.read()
    in_image = tf.image.decode_jpeg(image_data, channels=3)
    in_image = tf.image.resize_images(in_image, size=(128, 128))
    in_image = utils.convert2float(in_image)
    in_image.set_shape([128, 128, 3])

    # Build exactly one model; a second, discarded CycleGAN() used to be
    # constructed first.
    cycle_gan = CycleGAN()
    cycle_gan.model()
    # Add a leading batch dimension before sampling.
    out_image = cycle_gan.sample(tf.expand_dims(in_image, 0))

  with tf.Session(graph=graph) as sess:
    sess.run(tf.global_variables_initializer())
    cycle_gan.saver.restore(sess, CKPT_PATH)
    generated = sess.run(out_image)

  samples_dir = 'samples'
  os.makedirs(samples_dir, exist_ok=True)
  samples_file = os.path.join(samples_dir, 'sample.jpg')
  with open(samples_file, 'wb') as f:
    f.write(generated)
Exemplo n.º 4
0
def test(file):
    dataset = FLAGS.input.split("/")[1] + '/'
    test_name = FLAGS.input.split("/")[2] + '/'

    graph = tf.Graph()

    with graph.as_default():
        print('Reading in image: ' + file)
        with tf.gfile.FastGFile(FLAGS.input + file, 'rb') as f:
            image_data = f.read()
            input_image = tf.image.decode_jpeg(image_data, channels=3)
            input_image = tf.image.resize_images(input_image,
                                                 size=(FLAGS.image_size,
                                                       FLAGS.image_size))
            input_image = utils.convert2float(input_image)
            input_image.set_shape([FLAGS.image_size, FLAGS.image_size, 3])

        with tf.gfile.FastGFile(FLAGS.model, 'rb') as model_file:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(model_file.read())
        [output] = tf.import_graph_def(graph_def,
                                       input_map={'input_image': input_image},
                                       return_elements=['output_image:0'],
                                       name='output')

        with tf.Session(graph=graph) as sess:
            generated = output.eval()
            with open(FLAGS.output + dataset + test_name + file, 'wb') as f:
                f.write(generated)
Exemplo n.º 5
0
def inference(url="",
              outputpath="output.jpg",
              isurl=True,
              modelpath="zebra2horse.pb"):
    """Translate a JPEG, downloaded or read from disk, with a frozen graph.

    Args:
        url: HTTP(S) URL of the image when ``isurl`` is true, otherwise a
            local file path.
        outputpath: where the bytes of the graph's ``output_image:0`` are
            written.
        isurl: choose between downloading ``url`` and opening it as a file.
        modelpath: frozen GraphDef (.pb) with an ``input_image`` input and an
            ``output_image:0`` tensor.

    Raises:
        requests.HTTPError: if the download answers with an error status.
    """
    if isurl:
        response = requests.get(url=url)
        # Fail loudly instead of handing an HTTP error page to decode_jpeg.
        response.raise_for_status()
        image_data = response.content
    else:
        with open(url, "rb") as f:
            image_data = f.read()

    graph = tf.Graph()
    with graph.as_default():
        # Preprocessing applies to both sources. It used to sit inside the
        # file branch only, so isurl=True raised NameError on input_image.
        input_image = tf.image.decode_jpeg(image_data, channels=3)
        input_image = tf.image.resize_images(
            input_image, size=(FLAGS.image_size, FLAGS.image_size))
        input_image = utils.convert2float(input_image)
        input_image.set_shape([FLAGS.image_size, FLAGS.image_size, 3])

        with tf.gfile.FastGFile(modelpath, 'rb') as model_file:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(model_file.read())
        [output_image] = tf.import_graph_def(
            graph_def,
            input_map={'input_image': input_image},
            return_elements=['output_image:0'],
            name='output')

        with tf.Session(graph=graph) as sess:
            generated = sess.run(output_image)

    with open(outputpath, 'wb') as f:
        f.write(generated)
Exemplo n.º 6
0
def inference():
    """Translate ``FLAGS.input`` with the frozen graph ``FLAGS.model``.

    The bytes of the graph's ``output_image:0`` tensor are written to
    ``FLAGS.output``.
    """
    graph = tf.Graph()  # create a dedicated computation graph

    with graph.as_default():
        with tf.gfile.FastGFile(FLAGS.input, 'rb') as f:
            image_data = f.read()
        input_image = tf.image.decode_jpeg(image_data, channels=3)  # decode
        input_image = tf.image.resize_images(
            input_image, size=(FLAGS.image_size, FLAGS.image_size))
        input_image = utils.convert2float(input_image)
        input_image.set_shape([FLAGS.image_size, FLAGS.image_size, 3])

        with tf.gfile.FastGFile(FLAGS.model, 'rb') as model_file:  # load the model
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(model_file.read())
        [output_image] = tf.import_graph_def(
            graph_def,
            input_map={'input_image': input_image},
            return_elements=['output_image:0'],
            name='output')

        with tf.Session(graph=graph) as sess:
            # Was `output_imag` (typo), a NameError before anything was written.
            generated = sess.run(output_image)  # compute output_image
            with open(FLAGS.output, 'wb') as f:  # write the result to the output file
                f.write(generated)
def inference(input_pattern=r'E:\data\after_spm\031419464625\%d.jpg',
              output_pattern='E:\\data\\after_spm\\OUTPUT_REAL\\output_histogram_%d.jpg',
              count=18):
    """Translate a numbered series of JPEGs with the frozen graph ``FLAGS.model``.

    Args:
        input_pattern: %-format path of the inputs; ``%d`` is replaced by the
            index. Defaults to the previously hard-coded path.
        output_pattern: %-format path for the outputs, formatted the same way.
        count: number of images; indices ``0 .. count - 1`` are processed.
    """
    # The model is identical for every image, so read and parse it once.
    with tf.gfile.FastGFile(FLAGS.model, 'rb') as model_file:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(model_file.read())

    for ind in range(count):
        # Use a fresh graph per image. Reusing one graph imported another full
        # copy of the model on every iteration.
        graph = tf.Graph()
        with graph.as_default():
            with tf.gfile.FastGFile(input_pattern % ind, 'rb') as f:
                image_data = f.read()
            input_image = tf.image.decode_jpeg(image_data, channels=3)
            input_image = tf.image.resize_images(
                input_image, size=(FLAGS.image_size, FLAGS.image_size))
            input_image = utils.convert2float(input_image)
            input_image.set_shape([FLAGS.image_size, FLAGS.image_size, 3])

            [output_image] = tf.import_graph_def(
                graph_def,
                input_map={'input_image': input_image},
                return_elements=['output_image:0'],
                name='output')

        with tf.Session(graph=graph) as sess:
            generated = sess.run(output_image)
        with open(output_pattern % ind, 'wb') as f:
            f.write(generated)
Exemplo n.º 8
0
def inference(model, name, artist, img_in, img_out, size=256):
  """Translate ``img_in`` with a frozen graph, caption it and save as JPEG.

  Args:
    model: path to the frozen GraphDef (.pb).
    name: caption drawn with ``draw_text``'s default placement; '$' becomes ' '.
    artist: caption drawn with ``bottom=False``; '$' becomes ' '.
    img_in: input JPEG path.
    img_out: destination path, saved in "JPEG" format via PIL.
    size: side length the input is resized to before inference.
  """
  graph = tf.Graph()

  with graph.as_default():
    with tf.gfile.FastGFile(img_in, 'rb') as src:
      raw = src.read()
    image = tf.image.resize_images(tf.image.decode_jpeg(raw, channels=3),
                                   size=(size, size))
    image = utils.convert2float(image)
    image.set_shape([size, size, 3])

    frozen = tf.GraphDef()
    with tf.gfile.FastGFile(model, 'rb') as model_file:
      frozen.ParseFromString(model_file.read())
    (translated,) = tf.import_graph_def(frozen,
                                        input_map={'input_image': image},
                                        return_elements=['output_image:0'],
                                        name='output')

  with tf.Session(graph=graph) as sess:
    encoded = sess.run(translated)
    # The graph output is treated as encoded image bytes PIL can open.
    picture = Image.open(io.BytesIO(encoded))
    draw_text(picture, name.replace("$", " "))
    draw_text(picture, artist.replace("$", " "), bottom=False)
    picture.save(img_out, "JPEG")
Exemplo n.º 9
0
def inference():
  """Translate ``FLAGS.input`` with the frozen graph ``FLAGS.model``.

  The decoded JPEG is transposed (height and width swapped) and then resized
  to ``(FLAGS.image_length, FLAGS.image_height)``. The bytes of
  ``output_image:0`` are written to ``FLAGS.output``.
  """
  graph = tf.Graph()

  with graph.as_default():
    dims = (FLAGS.image_length, FLAGS.image_height)
    with tf.gfile.FastGFile(FLAGS.input, 'rb') as src:
      raw = src.read()
    image = tf.image.transpose_image(tf.image.decode_jpeg(raw, channels=3))
    image = utils.convert2float(tf.image.resize_images(image, size=dims))
    image.set_shape([dims[0], dims[1], 3])

    frozen = tf.GraphDef()
    with tf.gfile.FastGFile(FLAGS.model, 'rb') as model_file:
      frozen.ParseFromString(model_file.read())
    (translated,) = tf.import_graph_def(frozen,
                                        input_map={'input_image': image},
                                        return_elements=['output_image:0'],
                                        name='output')

  with tf.Session(graph=graph) as sess:
    payload = sess.run(translated)
    with open(FLAGS.output, 'wb') as dst:
      dst.write(payload)
Exemplo n.º 10
0
 def _preprocess(self, image):
     """Resize to ``self.image_size`` (height, width) and apply ``utils.convert2float``.

     The static shape is fixed to ``[height, width, 3]``.
     """
     height, width = self.image_size[0], self.image_size[1]
     resized = tf.image.resize_images(image, size=(height, width))
     result = utils.convert2float(resized)
     result.set_shape([height, width, 3])
     return result
Exemplo n.º 11
0
def color(src, dst, model='pretrained/sketch2render.pb', img_size=256):
  """Render the sketch at ``src`` with a frozen graph and write it to ``dst``.

  Args:
    src: input JPEG path.
    dst: output path; receives the bytes of the graph's ``output_image:0``.
    model: frozen GraphDef (.pb). Defaults to the previously hard-coded file.
    img_size: side length the input is resized to (previously fixed at 256).
  """
  graph = tf.Graph()
  with graph.as_default():
    with tf.io.gfile.GFile(src, 'rb') as f:
      image_data = f.read()
    input_image = tf.image.decode_jpeg(image_data, channels=3)
    input_image = tf.image.resize(input_image, size=(img_size, img_size))
    input_image = utils.convert2float(input_image)
    input_image.set_shape([img_size, img_size, 3])

    with tf.io.gfile.GFile(model, 'rb') as model_file:
      graph_def = tf.compat.v1.GraphDef()
      graph_def.ParseFromString(model_file.read())
    # Use the compat.v1 endpoint, matching the GraphDef/Session calls here.
    # The root `tf.import_graph_def` alias is presumably v1-only -- confirm
    # for the TF version in use.
    [output_image] = tf.compat.v1.import_graph_def(
        graph_def,
        input_map={'input_image': input_image},
        return_elements=['output_image:0'],
        name='output')

  with tf.compat.v1.Session(graph=graph) as sess:
    generated = sess.run(output_image)
  with open(dst, 'wb') as f:
    f.write(generated)
Exemplo n.º 12
0
def sample():
    """Translate one image with a frozen graph (only 128x128 inputs are supported).

    Reads ``FLAGS.input``, resizes it to 128x128 and runs the frozen graph in
    ``FLAGS.model``, imported under the 'apple2orange' name scope. The bytes
    of ``output_image:0`` are written to ``FLAGS.output``.
    """
    graph = tf.Graph()

    with graph.as_default():
        # Binary mode: in text mode ('r') gfile decodes the file as UTF-8,
        # which fails on JPEG bytes under Python 3.
        with tf.gfile.FastGFile(FLAGS.input, 'rb') as f:
            image_data = f.read()
        input_image = tf.image.decode_jpeg(image_data, channels=3)
        input_image = tf.image.resize_images(input_image, size=(128, 128))
        input_image = utils.convert2float(input_image)
        input_image.set_shape([128, 128, 3])

        with tf.gfile.FastGFile(FLAGS.model, 'rb') as model_file:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(model_file.read())
        [output_image] = tf.import_graph_def(
            graph_def,
            input_map={'input_image': input_image},
            return_elements=['output_image:0'],
            name='apple2orange')

    with tf.Session(graph=graph) as sess:
        generated = sess.run(output_image)
    with open(FLAGS.output, 'wb') as f:
        f.write(generated)
Exemplo n.º 13
0
def main():
    """Debug helper: log one image from ``picF`` as TensorBoard summaries.

    Writes the raw image ('real') and the image after ``utils.convert2float``
    and ``utils.batch_convert2int`` ('out') as summaries. It also writes the
    ``global_step`` variable (initialised to 2501) as scalar 'test'. Output
    goes to ``C:\\log``, and the intermediate tensors are printed along the way.

    Raises:
        ValueError: if ``picF`` contains no files.
    """
    picF = "picF"
    files = os.listdir(picF)[:1]
    if not files:
        raise ValueError('no files found in %r' % picF)

    sess = tf.InteractiveSession()
    train_writer = None
    try:
        global_step = tf.Variable(2501, name="global_step", trainable=False)
        sess.run(tf.global_variables_initializer())

        img = tf.read_file(os.path.join(picF, files[0]))
        img = tf.image.decode_jpeg(img)
        # tf.summary.image expects a 4-D [batch, height, width, channels] tensor.
        img = tf.expand_dims(img, axis=0)

        tf.summary.image('real', img)
        tf.summary.scalar('test', global_step)
        outimg = utils.convert2float(img)
        tf.summary.image('out', utils.batch_convert2int(outimg))
        print(outimg.eval())
        print(utils.batch_convert2int(outimg).eval())
        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter('C:\\log')

        train_writer.add_summary(sess.run(summary_op))
        train_writer.flush()
    finally:
        # Release the event file and the session even if a step above fails.
        if train_writer is not None:
            train_writer.close()
        sess.close()
def inference():
    """Translate every file in ``FLAGS.test_path`` with the frozen graph ``FLAGS.model``.

    Each result (the bytes of ``output_image:0``) is written to
    ``FLAGS.output + <file name>``. This is plain string concatenation, so
    ``FLAGS.output`` presumably ends with a separator.
    """
    test = os.listdir(FLAGS.test_path)

    # The model does not change between images, so read and parse it once.
    with tf.gfile.FastGFile(FLAGS.model, 'rb') as model_file:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(model_file.read())

    for name in test:
        # Use a fresh graph per image. Importing into one shared graph added a
        # new copy of the model on every iteration.
        graph = tf.Graph()
        with graph.as_default():
            with tf.gfile.FastGFile(FLAGS.test_path + "/" + name, 'rb') as f:
                image_data = f.read()
            input_image = tf.image.decode_jpeg(image_data, channels=3)
            input_image = tf.image.resize_images(
                input_image, size=(FLAGS.image_size, FLAGS.image_size))
            input_image = utils.convert2float(input_image)
            input_image.set_shape([FLAGS.image_size, FLAGS.image_size, 3])

            [output_image] = tf.import_graph_def(
                graph_def,
                input_map={'input_image': input_image},
                return_elements=['output_image:0'],
                name='output')

        with tf.Session(graph=graph) as sess:
            generated = sess.run(output_image)
        with open(FLAGS.output + name, 'wb') as f:
            f.write(generated)
Exemplo n.º 15
0
def inference():
  """Translate ``FLAGS.input`` with a frozen graph that needs op patching first.

  Before import, every ``RefSwitch`` node becomes ``Switch``, and its inputs
  whose name contains 'moving_' are redirected to '<name>/read'. Every
  ``AssignSub`` becomes ``Sub`` and loses its ``use_locking`` attr. This
  presumably strips ref-typed ops left behind by batch-norm moving-average
  updates that would otherwise stop the frozen graph from importing --
  confirm. The bytes of ``output_image:0`` are written to ``FLAGS.output``.
  """
  graph = tf.Graph()

  with graph.as_default():
    size = FLAGS.image_size
    with tf.gfile.FastGFile(FLAGS.input, 'rb') as image_file:
      raw = image_file.read()
    image = tf.image.decode_jpeg(raw, channels=3)
    image = tf.image.resize_images(image, size=(size, size))
    image = utils.convert2float(image)
    image.set_shape([size, size, 3])

    frozen = tf.GraphDef()
    with tf.gfile.FastGFile(FLAGS.model, 'rb') as model_file:
      frozen.ParseFromString(model_file.read())

    for node in frozen.node:
      if node.op == 'AssignSub':
        node.op = 'Sub'
        if 'use_locking' in node.attr:
          del node.attr['use_locking']
        continue
      if node.op != 'RefSwitch':
        continue
      node.op = 'Switch'
      # Iterate over a snapshot so in-place edits cannot disturb the loop.
      for position, source in enumerate(list(node.input)):
        if 'moving_' in source:
          node.input[position] = source + '/read'

    (translated,) = tf.import_graph_def(frozen,
                                        input_map={'input_image': image},
                                        return_elements=['output_image:0'],
                                        name='output')

  with tf.Session(graph=graph) as sess:
    payload = sess.run(translated)
    with open(FLAGS.output, 'wb') as dst:
      dst.write(payload)
Exemplo n.º 16
0
 def _preprocess(self, image):
     """Resize to (image_length, image_height) and apply ``utils.convert2float``.

     The static shape is fixed to ``[image_length, image_height, 3]``.
     """
     target = (self.image_length, self.image_height)
     image = utils.convert2float(tf.image.resize_images(image, size=target))
     image.set_shape(list(target) + [3])
     return image
Exemplo n.º 17
0
def inference():
    """Translate every path in the module-level ``imgs`` with the graph at ``model``.

    Each result (the bytes of ``output_image:0``) is written next to its input
    as ``<path without extension>_f.jpg``. Inputs are resized to the
    module-level ``image_size``.
    """
    import os

    # The model does not change between images, so read and parse it once.
    with tf.gfile.FastGFile(model, 'rb') as model_file:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(model_file.read())

    for input_path in imgs:
        # splitext handles any extension. Slicing off 4 characters assumed
        # '.jpg' and mangled names such as 'x.jpeg'.
        output_path = os.path.splitext(input_path)[0] + '_f.jpg'

        # Use a fresh graph per image. Importing into the one graph held by a
        # single session added a new copy of the model on every iteration.
        graph = tf.Graph()
        with graph.as_default():
            with tf.gfile.FastGFile(input_path, 'rb') as f:
                image_data = f.read()
            input_image = tf.image.decode_jpeg(image_data, channels=3)
            input_image = tf.image.resize_images(input_image, size=(image_size, image_size))
            input_image = utils.convert2float(input_image)
            input_image.set_shape([image_size, image_size, 3])

            [output_image] = tf.import_graph_def(graph_def,
                                                 input_map={'input_image': input_image},
                                                 return_elements=['output_image:0'],
                                                 name='output')

        with tf.Session(graph=graph) as sess:
            generated = sess.run(output_image)
        with open(output_path, 'wb') as f:
            f.write(generated)
Exemplo n.º 18
0
def inference():
    """Translate every file in ``FLAGS.input`` in direction ``FLAGS.direction``.

    'XtoY' decodes 1 channel and feeds the model's ``input_image_X`` input.
    'YtoX' decodes 3 channels and feeds ``input_image_Y``. Each result (the
    bytes of ``output_image:0``) is written to ``FLAGS.output + '/' + <file name>``.

    Raises:
        ValueError: if ``FLAGS.direction`` is neither 'XtoY' nor 'YtoX'.
    """
    # direction -> (JPEG channels, name of the graph input to replace)
    settings = {'XtoY': (1, 'input_image_X'), 'YtoX': (3, 'input_image_Y')}
    if FLAGS.direction not in settings:
        # Previously this surfaced later as a NameError on input_image.
        raise ValueError('unknown direction: %r' % (FLAGS.direction,))
    channels, input_name = settings[FLAGS.direction]

    with tf.gfile.FastGFile(FLAGS.model, 'rb') as model_file:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(model_file.read())

    for filename in os.listdir(FLAGS.input):
        # Use one small graph per file instead of one graph holding a full
        # model copy for every file, whose memory grew with the file count.
        graph = tf.Graph()
        with graph.as_default():
            with tf.gfile.FastGFile(FLAGS.input + '/' + filename, 'rb') as f:
                image_data = f.read()
            input_image = tf.image.decode_jpeg(image_data, channels=channels)
            input_image = tf.image.resize_images(
                input_image, size=(FLAGS.image_size, FLAGS.image_size))
            input_image = utils.convert2float(input_image)
            input_image.set_shape(
                [FLAGS.image_size, FLAGS.image_size, channels])

            [output_image] = tf.import_graph_def(
                graph_def,
                input_map={input_name: input_image},
                return_elements=['output_image:0'],
                name='output')

        with tf.Session(graph=graph) as sess:
            generated = sess.run(output_image)
        with open(FLAGS.output + '/' + filename, 'wb') as f:
            f.write(generated)
Exemplo n.º 19
0
 def _preprocess(self, image):
     """Resize to a square ``image_size`` and apply ``utils.convert2float``.

     When ``self.name`` is 'X' the static shape gets 1 channel, and when it is
     'Y' it gets 3. Any other name leaves the static shape unset.
     """
     side = self.image_size
     image = utils.convert2float(tf.image.resize_images(image, size=(side, side)))
     if self.name == 'X':
         depth = 1
     elif self.name == 'Y':
         depth = 3
     else:
         return image
     image.set_shape([side, side, depth])
     return image
Exemplo n.º 20
0
 def _preprocess(self, image):
     """Resize to (image_height, image_width) and apply ``utils.convert2float``.

     The static shape is fixed to ``[image_height, image_width, 3]``.
     """
     height, width = self.image_height, self.image_width
     image = utils.convert2float(
         tf.image.resize_images(image, size=(height, width)))
     image.set_shape([height, width, 3])
     return image
Exemplo n.º 21
0
def inference():
    """Translate the image pair ``FLAGS.input1`` / ``FLAGS.input2`` jointly.

    Both images are preprocessed the same way and fed to the frozen graph's
    ``input_image1`` / ``input_image2`` inputs. The bytes of
    ``output_image1:0`` and ``output_image2:0`` are written to
    ``FLAGS.output1`` and ``FLAGS.output2``.
    """
    def load(path):
        # Decode, resize and rescale one JPEG with static shape [size, size, 3].
        with tf.gfile.FastGFile(path, 'rb') as handle:
            data = handle.read()
        tensor = tf.image.decode_jpeg(data, channels=3)
        tensor = tf.image.resize_images(tensor,
                                        size=(FLAGS.image_size,
                                              FLAGS.image_size))
        tensor = utils.convert2float(tensor)
        tensor.set_shape([FLAGS.image_size, FLAGS.image_size, 3])
        return tensor

    graph = tf.Graph()
    with graph.as_default():
        first = load(FLAGS.input1)
        second = load(FLAGS.input2)

        frozen = tf.GraphDef()
        with tf.gfile.FastGFile(FLAGS.model, 'rb') as model_file:
            frozen.ParseFromString(model_file.read())

        out1, out2 = tf.import_graph_def(
            frozen,
            input_map={'input_image1': first, 'input_image2': second},
            return_elements=['output_image1:0', 'output_image2:0'],
            name='output')

    with tf.Session(graph=graph) as sess:
        for tensor, destination in ((out1, FLAGS.output1),
                                    (out2, FLAGS.output2)):
            result = tensor.eval()
            with open(destination, 'wb') as f:
                f.write(result)
Exemplo n.º 22
0
  def _preprocess(self, image):
    """Random-crop, rotate and rescale one training image.

    Swaps the first two axes, takes a random ``image_size`` square crop, and
    rotates it by -pi/2 radians with ``tf.contrib.image.rotate``. The result
    is cast to float32 and rescaled with ``utils.convert2float``.
    """
    swapped = tf.transpose(image, [1, 0, 2])
    # Random crop offsets with hard-coded exclusive upper bounds. The swapped
    # image must therefore be large enough for any offset below 1120 (rows) /
    # 640 (cols) plus image_size -- TODO confirm against the dataset.
    offset_rows = tf.random_uniform([], 0, 1120, dtype=tf.int32)
    offset_cols = tf.random_uniform([], 0, 640, dtype=tf.int32)
    crop = tf.image.crop_to_bounding_box(swapped, offset_rows, offset_cols,
                                         self.image_size, self.image_size)

    rotated = tf.contrib.image.rotate(crop, -math.pi / 2)
    result = utils.convert2float(tf.cast(rotated, tf.float32))
    result.set_shape([self.image_size, self.image_size, 3])
    return result
Exemplo n.º 23
0
 def _preprocess(self, image):
     """Reshape a decoded TFRecord image and rescale it with ``utils.convert2float``.

     Only 'L', 'RGB' and 'RGBA' modes (1, 3 and 4 channels) are handled; add
     code here for other image modes.

     Note: images read by Python's Image are laid out as [height, width,
     depth], yet the reshape below uses [image_width, image_height, depth].
     The two agree only when width == height.
     NOTE(review): confirm the intended order.

     Returns:
         image: 3-D tensor of shape [image_width, image_height, image_depth].

     Raises:
         ValueError: if ``self.image_mode`` is not 'L', 'RGB' or 'RGBA'.
     """
     channels_by_mode = {'L': 1, 'RGB': 3, 'RGBA': 4}
     if self.image_mode not in channels_by_mode:
         # Previously print + sys.exit(), which ended the whole program with
         # exit status 0 and so hid the failure.
         raise ValueError('The image mode must be L/RGB/RGBA!')
     channels = channels_by_mode[self.image_mode]
     image = tf.reshape(image, [self.image_width, self.image_height, channels])
     return utils.convert2float(image)
Exemplo n.º 24
0
def main():
    """Debug helper: show a tiny image going through ``utils.convert2float`` and back.

    A 3x3 array is converted to a uint8 tensor with a trailing channel axis,
    passed through ``utils.convert2float``, mapped with ``(x + 1) / 2`` and
    converted back to uint8. The tensor and its value are printed at each step.
    """
    sess = tf.InteractiveSession()
    try:
        # np.float was only an alias of the builtin float (i.e. float64) and
        # was removed in NumPy 1.24; np.float64 is the same dtype.
        testimg = np.asarray([[0, 127, 255],
                              [20, 128, 127],
                              [255, 100, 100]], dtype=np.float64)
        print(testimg)
        img = tf.convert_to_tensor(testimg, dtype=tf.uint8)
        print(img)

        img = tf.expand_dims(img, axis=2)
        print(img)
        print(img.eval())
        img = utils.convert2float(img)
        print(img)
        print(img.eval())
        img = (img + 1.0) / 2.0
        img = tf.image.convert_image_dtype(img, dtype=tf.uint8)
        print(img)
        print(img.eval())
    finally:
        sess.close()
def inference():
    """Translate every image in ``FLAGS.input`` and build an HTML comparison page.

    Each result (the bytes of ``output_image:0``) is written to
    ``FLAGS.output + "/" + FLAGS.direction + <file>``. The page
    ``<FLAGS.direction>index.html`` is written into ``FLAGS.input`` with one
    row (name, input, output) per image.
    """
    import html
    import os

    index_name = FLAGS.direction + "index.html"
    # List the inputs before writing the index and skip it by name. It lives
    # in the same directory, and feeding it to decode_jpeg aborted the run.
    files = [name for name in os.listdir(FLAGS.input) if name != index_name]

    # The model is the same for every image, so read and parse it once.
    with tf.gfile.FastGFile(FLAGS.model, 'rb') as model_file:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(model_file.read())

    # visualization stuff (yw3025)
    with open(FLAGS.input + "/" + index_name, "w") as index:
        index.write("<html><body><table><tr>")
        index.write("<th>name</th><th>input</th><th>output</th></tr>")
        # batch inference stuff (yw3025)
        for name in files:
            input_path = FLAGS.input + "/" + name
            print(input_path)
            # Use a fresh graph per image. A shared graph gained a new model
            # copy on every iteration.
            graph = tf.Graph()
            with graph.as_default():
                with tf.gfile.FastGFile(input_path, 'rb') as f:
                    image_data = f.read()
                input_image = tf.image.decode_jpeg(image_data, channels=3)
                input_image = tf.image.resize_images(
                    input_image, size=(FLAGS.image_size, FLAGS.image_size))
                input_image = utils.convert2float(input_image)
                input_image.set_shape([FLAGS.image_size, FLAGS.image_size, 3])

                [output_image] = tf.import_graph_def(
                    graph_def,
                    input_map={'input_image': input_image},
                    return_elements=['output_image:0'],
                    name='output')

            with tf.Session(graph=graph) as sess:
                generated = sess.run(output_image)
            with open(FLAGS.output + "/" + FLAGS.direction + name, 'wb') as f:
                f.write(generated)

            # Each row needs its own opening <tr>; values are HTML-escaped.
            # NOTE(review): the output src is relative to FLAGS.input, so it
            # only resolves when FLAGS.output is the same directory -- confirm.
            index.write("<tr><td>%s</td>" % html.escape(input_path))
            index.write("<td><img src='%s'></td>" % html.escape(name))
            index.write("<td><img src='%s'></td>"
                        % html.escape(FLAGS.direction + name))
            index.write("</tr>")

            print("processing" + input_path)
        index.write("</table></body></html>")
Exemplo n.º 26
0
def main():
    """Restore a ``Generator`` checkpoint and translate every image in ``picF``.

    Each result is JPEG-encoded and written to ``out\\picF`` under the input's
    file name. The session is closed even if a step fails.
    """
    modelfile = 'model\\model.ckpt-2000'

    picF = "picF"
    files = os.listdir(picF)

    sess = tf.InteractiveSession()
    try:
        ge = Generator('G', is_training=False)
        x = tf.placeholder(dtype=tf.float32, shape=(5, 270, 480, 3))
        # Calling the generator once presumably creates its variables, which
        # the initializer and the Saver's var_list below rely on -- confirm.
        ge(x)

        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(var_list=ge.variables)
        v = ge.variables[0]
        print(v.name)
        saver.restore(sess, modelfile)
        print(v.eval())

        for pic in files:
            img = tf.read_file(os.path.join(picF, pic))
            img = tf.image.decode_jpeg(img)
            img = utils.convert2float(img)
            img = tf.expand_dims(img, axis=0)
            # Pin the static shape to this image's concrete dimensions.
            img.set_shape(tf.shape(img).eval())
            # NOTE(review): each iteration adds new ops to the default graph.
            # A placeholder-fed pipeline would keep the graph from growing.
            out = ge(img)

            out = tf.unstack(out)[0]
            out = utils.convert2int(out)
            encoded = tf.image.encode_jpeg(out).eval()
            # Leaving the `with` closes (and thereby flushes) the file.
            with tf.gfile.GFile(os.path.join('out\\picF', pic), 'wb') as fw:
                fw.write(encoded)
    finally:
        sess.close()
Exemplo n.º 27
0
def inference():
    """Translate every file listed by ``data_reader(FLAGS.input)``.

    Each result (the bytes of the frozen graph's ``output_image:0``) is
    written to ``FLAGS.output + '/fake_<basename>'``. Progress is printed each
    time another 5% of the files has been processed. Runs on Python 2 and 3.
    """
    import os

    # The frozen model is the same for every file, so read and parse it once.
    with tf.gfile.FastGFile(FLAGS.model, 'rb') as model_file:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(model_file.read())

    file_list = data_reader(FLAGS.input)
    whole = len(file_list)
    last_step = 0
    for cnt, path in enumerate(file_list, 1):
        # Use a fresh graph per file. Importing into one shared graph inside a
        # single session added another model copy on every iteration.
        graph = tf.Graph()
        with graph.as_default():
            with tf.gfile.FastGFile(path, 'rb') as f:
                image_data = f.read()
            input_image = tf.image.decode_jpeg(image_data, channels=3)
            input_image = tf.image.resize_images(
                input_image, size=(FLAGS.image_size, FLAGS.image_size))
            input_image = utils.convert2float(input_image)
            input_image.set_shape(
                [FLAGS.image_size, FLAGS.image_size, 3])

            [output_image] = tf.import_graph_def(
                graph_def,
                input_map={'input_image': input_image},
                return_elements=['output_image:0'],
                name='output')

        with tf.Session(graph=graph) as sess:
            generated = sess.run(output_image)

        output_file_name = os.path.basename(path)
        with open(FLAGS.output + '/fake_{}'.format(output_file_name),
                  'wb') as f:
            f.write(generated)

        # Report in integer 5% steps. The old `cnt / whole > 0.05` was floor
        # division on Python 2 (nothing until the end) and fired on every file
        # on Python 3.
        step = cnt * 20 // whole
        if step > last_step:
            last_step = step
            print('{} done'.format(float(cnt) / whole))
Exemplo n.º 28
0
 def _preprocess(self, image):
     """Resize to a square ``image_size`` and apply ``convert2float``.

     The static shape is fixed to ``[image_size, image_size, 3]``.
     """
     side = self.image_size
     resized = tf.image.resize_images(image, size=(side, side))
     result = convert2float(resized)
     result.set_shape([side, side, 3])
     return result
Exemplo n.º 29
0
def inference():
    """Run every model in ``FLAGS.Model_dir`` over every image in ``FLAGS.Test_input``.

    Results for model ``<stem>.<ext>`` go to
    ``FLAGS.Test_output/<stem>/<image name>``; the directory is created if it
    does not exist. Per-model and per-image progress is printed, along with
    the wall time each image took.
    """
    path_src_test_pic = FLAGS.Test_input
    path_dst_test_dir = FLAGS.Test_output
    temp_img_list = os.listdir(path_src_test_pic)
    image_size = FLAGS.image_size
    path_model_dir = FLAGS.Model_dir
    model_name_list = os.listdir(path_model_dir)
    total_count_model = len(model_name_list)
    total_count_img = len(temp_img_list)

    for count_model, model_name in enumerate(model_name_list, 1):
        path_model_now_use = os.path.join(path_model_dir, model_name)
        print("Model Progress :" + str(count_model) + "/" +
              str(total_count_model))
        print("Now use Model is " + model_name)

        # Strip only the last extension. split(".")[0] collapsed dotted names
        # ("net.v1.pb", "net.v2.pb") into one shared output directory.
        model_stem = os.path.splitext(model_name)[0]
        output_by_model_dir = os.path.join(path_dst_test_dir, model_stem)

        if not os.path.isdir(output_by_model_dir):
            os.mkdir(output_by_model_dir)

        print("Image will save in the path of directory :  " +
              output_by_model_dir)

        # Parse each model once, not once per image.
        with tf.gfile.FastGFile(path_model_now_use, 'rb') as model_file:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(model_file.read())

        for count_img, img in enumerate(temp_img_list, 1):
            graph = tf.Graph()
            start = time()

            with graph.as_default():
                temp_input = os.path.join(path_src_test_pic, img)
                temp_output = os.path.join(output_by_model_dir, img)

                with tf.gfile.FastGFile(temp_input, 'rb') as f:
                    image_data = f.read()
                input_image = tf.image.decode_jpeg(image_data, channels=3)
                input_image = tf.image.resize_images(
                    input_image, size=(image_size, image_size))
                input_image = utils.convert2float(input_image)
                input_image.set_shape([image_size, image_size, 3])

                [output_image] = tf.import_graph_def(
                    graph_def,
                    input_map={'input_image': input_image},
                    return_elements=['output_image:0'],
                    name='output')

            with tf.Session(graph=graph) as sess:
                generated = sess.run(output_image)
            with open(temp_output, 'wb') as f:
                f.write(generated)

            print(str(count_img) + "/" + str(total_count_img))

            elapsed = time() - start
            print(str(elapsed) + " sec")