예제 #1
0
 def process_func(self, example_line):
   """Load one example, treating ``example_line`` as the path to a single image.

   This is the default behaviour (used when training a VAE): the image is
   returned exactly as read, with no resizing or other preprocessing.
   """
   image = imread(example_line)
   return image
예제 #2
0
파일: datasets.py 프로젝트: rootkit/gan-lib
 def images_paths(self, set):
     """Fetch the horse2zebra archive and return the usable image paths.

     Calls ``download`` for the CycleGAN horse2zebra zip into
     ``DATA_FOLDER/horse2zebra`` (whether it skips an existing file depends on
     that helper — not visible here). It then extracts the archive and collects
     every ``*.jpg`` in the sub-folder ``set + self.name``. Images whose shape
     differs from ``self.image_shape`` are deleted from disk and excluded from
     the result.

     Args:
         set: Prefix joined with ``self.name`` to form the image sub-folder
             (presumably a split such as ``'train'``/``'test'`` with
             ``self.name`` a domain suffix — confirm with callers). The name
             shadows the builtin but is kept for keyword-call compatibility.

     Returns:
         np.ndarray of the kept image paths, in sorted (deterministic) order.
     """
     origin = 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip'
     origin_file_name = os.path.basename(origin)
     download_folder = os.path.join(DATA_FOLDER, 'horse2zebra')
     download_path = os.path.join(download_folder, origin_file_name)
     download(origin, download_path)
     extract_path = extract_all(download_path)
     extract_path = os.path.join(extract_path, set + self.name)
     # glob() yields files in arbitrary, filesystem-dependent order; sort so
     # the dataset order is reproducible across machines and runs.
     images_paths = sorted(glob(os.path.join(extract_path, '*.jpg')))
     final_images_paths = []
     for image_path in images_paths:
         image = imread(image_path)
         if image.shape != self.image_shape:
             msg = '{} has wrong shape {}. It should be {}. Removing file'
             print(msg.format(image_path, image.shape, self.image_shape))
             # Deleting the file means later calls will not glob it again.
             os.remove(image_path)
         else:
             final_images_paths.append(image_path)
     return np.array(final_images_paths)
예제 #3
0
    # Abort when no checkpoint was found; otherwise report which one is used.
    # NOTE(review): ckpt_path, sess, a2b/b2a, a_real/b_real, crop_size and
    # dataset are defined above this fragment and are not visible here.
    if ckpt_path is None:
        raise Exception('No checkpoint!')
    else:
        print('Copy variables from % s' % ckpt_path)

    #--test--#
    # Translate every training image of both domains with the generators.
    b_list = glob('./Datasets/' + dataset + '/bounding_box_train-Market/*.jpg')
    a_list = glob('./Datasets/' + dataset + '/bounding_box_train-Duke/*.jpg')

    # Output folders: Market images (b_list) go through b2a into b_save_dir,
    # Duke images (a_list) go through a2b into a_save_dir.
    b_save_dir = './test_predictions/' + dataset + '_spgan' + '/bounding_box_train_market2duke/'
    a_save_dir = './test_predictions/' + dataset + '_spgan' + '/bounding_box_train_duke2market/'
    # presumably creates both output folders — confirm against utils.mkdir
    utils.mkdir([a_save_dir, b_save_dir])

    # Duke -> Market style via a2b; outputs get a 'market_' filename prefix.
    for i in range(len(a_list)):
        # Resize to crop_size x crop_size, then add a batch axis in place.
        # Assumes a 3-channel image: the in-place shape assignment fails if
        # the element count does not match (e.g. a grayscale input).
        a_real_ipt = im.imresize(im.imread(a_list[i]), [crop_size, crop_size])
        a_real_ipt.shape = 1, crop_size, crop_size, 3
        a2b_opt = sess.run(a2b, feed_dict={a_real: a_real_ipt})
        a_img_opt = a2b_opt

        img_name = os.path.basename(a_list[i])
        img_name = 'market_' + img_name  # market_style
        # presumably immerge(batch, 1, 1) tiles the 1-image batch into a
        # single image before writing — confirm against im.immerge
        im.imwrite(im.immerge(a_img_opt, 1, 1), a_save_dir + img_name)
        print('Save %s' % (a_save_dir + img_name))

    # Market -> Duke style via b2a, mirroring the loop above.
    # NOTE(review): this loop continues past the visible fragment; the save
    # step for b_img_opt is not shown here.
    for i in range(len(b_list)):
        b_real_ipt = im.imresize(im.imread(b_list[i]), [crop_size, crop_size])
        b_real_ipt.shape = 1, crop_size, crop_size, 3
        b2a_opt = sess.run(b2a, feed_dict={b_real: b_real_ipt})
        b_img_opt = b2a_opt
        img_name = os.path.basename(b_list[i])
예제 #4
0
파일: datasets.py 프로젝트: rootkit/gan-lib
 def transform(self, data):
     """Read the image referenced by ``data[0]`` and return it normalized.

     Only the first element of ``data`` is used (presumably an image path —
     confirm with callers); any other elements are ignored here.
     """
     path = data[0]
     pixels = imread(path)
     return normalize(pixels)
예제 #5
0
 def process_func(self, example_line):
     """Read the image at path ``example_line`` and return it unchanged."""
     image = imread(example_line)
     return image