Exemplo n.º 1
0
def evaluate(model_file, data_path, image_index=0):
    """Run single-image super-resolution inference with a TFLite model.

    Loads the TFLite model at ``model_file``, fetches the (lr, hr) image pair
    at ``image_index`` from the DIV2K dataset rooted at ``data_path``, crops
    both images to the model's fixed input/output shapes, runs inference, and
    returns the results plus a quality metric.

    Args:
        model_file: Path to a .tflite model with fixed input/output shapes.
        data_path: Root directory of the DIV2K dataset.
        image_index: Index of the evaluation image pair. Defaults to 0.

    Returns:
        Tuple ``(sr, hr, psnr)``: the super-resolved image and the cropped
        ground-truth image as uint8 arrays in [0, 255], and the PSNR in dB
        computed on the float images (assumed scaled to [0, 1] by DIV2K —
        TODO confirm against the dataset class).

    Raises:
        ValueError: If the low-resolution image is smaller than the model's
            fixed input dimensions.
    """
    def calc_psnr(y, y_target):
        # PSNR in dB for images in [0, 1]; identical images are capped at
        # 100 dB to avoid log10(0).
        mse = np.mean((y - y_target)**2)
        if mse == 0:
            return 100
        return 20. * math.log10(1. / math.sqrt(mse))

    interpreter = tf.lite.Interpreter(model_path=model_file)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    input_shape = input_details[0]['shape']
    output_shape = output_details[0]['shape']
    div2k = DIV2K(data_path, patch_size=0)
    # Get lr, hr image pair
    lr, hr = div2k[image_index]
    # Check if image size can be used for inference.
    # Attach the message to the exception instead of print + bare raise,
    # so callers see the reason in the ValueError itself.
    if lr.shape[0] < input_shape[1] or lr.shape[1] < input_shape[2]:
        raise ValueError(
            f'Eval image {image_index} has invalid dimensions. Expecting h >= {input_shape[1]} and w >= {input_shape[2]}.'
        )
    # Crop lr, hr images to match fixed shapes of the tensorflow lite model
    lr = lr[:input_shape[1], :input_shape[2]]
    lr = np.expand_dims(lr, 0)   # add batch dimension
    lr = np.expand_dims(lr, -1)  # add channel dimension
    interpreter.set_tensor(input_details[0]['index'], lr)
    interpreter.invoke()
    sr = interpreter.get_tensor(output_details[0]['index']).squeeze()
    hr = hr[:output_shape[1], :output_shape[2]]
    return np.clip(np.round(sr * 255.), 0, 255).astype(np.uint8), np.clip(
        np.round(hr * 255.), 0, 255).astype(np.uint8), calc_psnr(sr, hr)
Exemplo n.º 2
0
def get_data_loader(cfg, data_dir, batch_size=None, num_workers=8):
    """Build a shuffled DataLoader over randomly cropped DIV2K patches.

    Args:
        cfg: Config mapping with keys "batch_size", "hr_crop_size", "scale".
        data_dir: Root directory of the DIV2K dataset.
        batch_size: Overrides ``cfg["batch_size"]`` when given.
        num_workers: Worker processes for data loading (previously a
            hard-coded 8; parameterized for flexibility, default unchanged).

    Returns:
        A shuffled ``DataLoader`` yielding transformed DIV2K samples.
    """
    batch_size = cfg["batch_size"] if batch_size is None else batch_size
    # Random HR crop (with the matching LR crop derived via scale), then
    # convert to tensors.
    transform = transforms.Compose([
        RandomCrop(cfg["hr_crop_size"], cfg["scale"]),
        ToTensor()])
    dataset = DIV2K(data_dir, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                      num_workers=num_workers)
Exemplo n.º 3
0
def train(model, dataset_path, scale_factor=3, num_epochs=1, batch_size=32):
    """Compile ``model`` with Adam/MSE and fit it on a DIV2K data generator.

    Args:
        model: Keras model accepting batches produced by ``DIV2K``.
        dataset_path: Root directory of the DIV2K dataset.
        scale_factor: Super-resolution upscaling factor.
        num_epochs: Number of training epochs.
        batch_size: Samples per batch.

    Returns:
        The same model object, trained in place.
    """
    data_generator = DIV2K(
        dataset_path, scale_factor=scale_factor, batch_size=batch_size)
    model.compile(optimizer='adam', loss='mse')
    model.fit(data_generator, epochs=num_epochs, workers=8)
    return model
Exemplo n.º 4
0
 def representative_dataset_gen():
     """Yield up to 50 cropped LR samples for TFLite quantization calibration.

     NOTE(review): relies on enclosing-scope variables ``data_path``,
     ``scale_factor`` and ``input_shape`` (presumably the TFLite model's
     fixed NHWC input shape — confirm against the enclosing function).
     """
     div2k = DIV2K(data_path, scale_factor=scale_factor, patch_size=0)
     for i in range(50):
         x, _ = div2k[i]
         # Skip images that are not within input h,w boundaries
         if x.shape[0] > input_shape[1] and x.shape[1] > input_shape[2]:
             # crop to input shape starting from top left corner of image
             x = x[:input_shape[1], :input_shape[2]]
             x = np.expand_dims(x, 0)   # add batch dimension
             x = np.expand_dims(x, -1)  # add channel dimension
             yield [x]
Exemplo n.º 5
0
    "converts the input.jpg to output.jpg using the pretrained model makin it 4x larger.",
    action="store_true")

args = parser.parse_args()

if args.pretrained:
    # Inference mode: upscale input.jpg with the pretrained generator.
    img = cv2.imread('input.jpg')
    # NOTE(review): rebinds the factory name `generator` to the model instance.
    generator = generator()
    generator.load_weights('pre-trained/generator.h5')
    x, y, c = img.shape
    # Add a batch dimension before predicting; write the first (only) result.
    img = generator.predict(tf.reshape(img, [1, x, y, c]))
    cv2.imwrite('output.jpg', img[0])

else:
    # Training mode: build DIV2K train/valid pipelines.
    train_loader = DIV2K(scale=args.UPSCALING,
                         subset='train',
                         HR_SIZE=args.HR_PATCH_SIZE)
    # repeat_count=None: repeat indefinitely for training.
    train_ds = train_loader.dataset(batch_size=args.BATCH_SIZE,
                                    random_transform=True,
                                    repeat_count=None)
    valid_loader = DIV2K(scale=args.UPSCALING,
                         subset='valid',
                         HR_SIZE=args.HR_PATCH_SIZE)
    # Validation: one pass, batch size 1, no augmentation.
    valid_ds = valid_loader.dataset(batch_size=1,
                                    random_transform=False,
                                    repeat_count=1)

    # NOTE(review): same name-rebinding pattern as above for both models.
    generator = generator()
    discriminator = discriminator(HR_SIZE=args.HR_PATCH_SIZE)
    pre_train(generator,
Exemplo n.º 6
0
def main(args):
    """Adversarially train an SRGAN-style generator against a discriminator.

    Expects ``args`` to provide: output_folder, train_scale, batch_size,
    pre_trained_weight_path, train_lr_size, learning_schedule, learning_rate,
    max_train_step, interval_save_weight, interval_validation,
    interval_show_info, valid_image_path.
    """

    # Create output folders for model diagrams, weights and validation images.
    if not os.path.exists(args.output_folder):
        os.mkdir(args.output_folder)
    model_subs_path = os.path.join(args.output_folder, 'arch_img')
    if not os.path.exists(model_subs_path):
        os.mkdir(model_subs_path)
    weights_path = os.path.join(args.output_folder, 'weights')
    if not os.path.exists(weights_path):
        os.mkdir(weights_path)
    valid_img_save_path = os.path.join(args.output_folder, 'valid_img')
    if not os.path.exists(valid_img_save_path):
        os.mkdir(valid_img_save_path)

    # Dataset
    train_loader = DIV2K(scale=args.train_scale,
                         downgrade='unknown',
                         subset='train')

    # Create a tf.data.Dataset
    train_ds = train_loader.dataset(batch_size=args.batch_size,
                                    random_transform=True)

    gene = Generator()

    ## load pre_trained_weights
    # Pick the most recently created .h5 checkpoint (sorted by ctime, newest
    # first) from the pre-trained weights directory.
    files_path = os.path.join(args.pre_trained_weight_path, '*.h5')
    latest_file_path = sorted(glob.iglob(files_path),
                              key=os.path.getctime,
                              reverse=True)[0]
    gene.load_weights(latest_file_path)

    # Discriminator operates on HR-sized images (LR size * scale).
    disc = Discriminator(args.train_lr_size * args.train_scale)
    tf.keras.utils.plot_model(disc,
                              to_file=os.path.join(model_subs_path,
                                                   'discriminator.png'),
                              show_shapes=True,
                              show_layer_names=True,
                              expand_nested=True)
    # Feature extractor used for the perceptual (content) loss.
    content_model = Content_Net()
    tf.keras.utils.plot_model(content_model,
                              to_file=os.path.join(model_subs_path,
                                                   'content.png'),
                              show_shapes=True,
                              show_layer_names=True,
                              expand_nested=True)

    # Step-wise learning-rate schedule shared by both optimizers.
    learning_rate_fn = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        args.learning_schedule, args.learning_rate)

    generator_optimizer = tf.keras.optimizers.Adam(learning_rate_fn)
    discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate_fn)

    @tf.function
    def train_step(lr, hr, generator, discriminator, content):
        """One GAN update: generator and discriminator trained together.

        Returns (perceptual_loss, disc_loss) for logging.
        """

        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            ## re-scale
            ## lr: 0 ~ 1
            ## hr: -1 ~ 1
            lr = tf.cast(lr, tf.float32)
            hr = tf.cast(hr, tf.float32)
            lr = lr / 255
            hr = hr / 127.5 - 1

            sr = generator(lr, training=True)

            # Discriminator scores for generated (sr) and real (hr) images.
            sr_output = discriminator(sr, training=True)
            hr_output = discriminator(hr, training=True)

            disc_loss = discriminator_loss(sr_output, hr_output)

            # Perceptual loss = pixel MSE + content (feature) loss
            # + small adversarial term.
            mse_loss = mse_based_loss(sr, hr)
            gen_loss = generator_loss(sr_output)
            cont_loss = content_loss(content, sr, hr)
            perceptual_loss = mse_loss + cont_loss + 1e-3 * gen_loss

        gradients_of_generator = gen_tape.gradient(
            perceptual_loss, generator.trainable_variables)
        gradients_of_discriminator = disc_tape.gradient(
            disc_loss, discriminator.trainable_variables)

        generator_optimizer.apply_gradients(
            zip(gradients_of_generator, generator.trainable_variables))
        discriminator_optimizer.apply_gradients(
            zip(gradients_of_discriminator, discriminator.trainable_variables))

        return perceptual_loss, disc_loss

    def valid_step(image_path, gene, step):
        """Super-resolve the validation image and save a step-numbered PNG."""
        # Drop any alpha channel; keep the first three (RGB) channels.
        img = np.array(Image.open(image_path))[..., :3]

        # Scale input to [0, 1] and add a batch dimension.
        scaled_lr_img = tf.cast(img, tf.float32)
        scaled_lr_img = scaled_lr_img / 255
        scaled_lr_img = scaled_lr_img[np.newaxis, :, :, :]

        sr_img = gene(scaled_lr_img).numpy()

        # Map generator output from [-1, 1] back to uint8 [0, 255].
        sr_img = np.clip(sr_img, -1, 1)
        sr_img = (sr_img + 1) * 127.5
        sr_img = np.around(sr_img)
        sr_img = sr_img.astype(np.uint8)

        im = Image.fromarray(sr_img[0])

        im.save(
            os.path.join(valid_img_save_path, 'step_{0:07d}.png'.format(step)))
        print('[step_{}.png] save images'.format(step))

    # Main training loop: periodically checkpoint, validate, and log timing.
    start = time.time()
    for step, (lr, hr) in enumerate(train_ds.take(args.max_train_step + 1)):

        gen_loss, disc_loss = train_step(lr, hr, gene, disc, content_model)

        if step % args.interval_save_weight == 0:
            gene.save_weights(
                os.path.join(weights_path, 'gen_step_{}.h5'.format(step)))
            disc.save_weights(
                os.path.join(weights_path, 'disc_step_{}.h5'.format(step)))

        if step % args.interval_validation == 0:
            valid_step(args.valid_image_path, gene, step)

        if step % args.interval_show_info == 0:
            print(
                'step:{}/{} gen_loss {}, disc_loss {}, Training time is {} step/s'
                .format(step, args.max_train_step, gen_loss, disc_loss,
                        (time.time() - start) / args.interval_show_info))
            start = time.time()
Exemplo n.º 7
0
def main(args):
    """Pre-train the SRGAN generator alone with a pixel-wise MSE loss.

    Expects ``args`` to provide: output_folder, train_scale, batch_size,
    train_lr_size, learning_rate, max_pre_train_step, interval_save_weight,
    interval_validation, interval_show_info, valid_image_path.
    """

    # Create output folders for model diagrams, weights and validation images.
    if not os.path.exists(args.output_folder):
        os.mkdir(args.output_folder)
    model_subs_path = os.path.join(args.output_folder, 'arch_img')
    if not os.path.exists(model_subs_path):
        os.mkdir(model_subs_path)
    weights_path = os.path.join(args.output_folder, 'pre_weights')
    if not os.path.exists(weights_path):
        os.mkdir(weights_path)
    valid_img_save_path = os.path.join(args.output_folder, 'pre_valid_img')
    if not os.path.exists(valid_img_save_path):
        os.mkdir(valid_img_save_path)

    # Dataset
    train_loader = DIV2K(
        scale=args.train_scale, downgrade='unknown',
        subset='train')  # 'bicubic', 'unknown', 'mild' or 'difficult'

    # Create a tf.data.Dataset
    train_ds = train_loader.dataset(batch_size=args.batch_size,
                                    random_transform=True)

    # Define Model
    gene = Generator(args.train_lr_size, scale=args.train_scale)
    tf.keras.utils.plot_model(gene,
                              to_file=os.path.join(model_subs_path,
                                                   'generator.png'),
                              show_shapes=True,
                              show_layer_names=True,
                              expand_nested=True)
    generator_optimizer = tf.keras.optimizers.Adam(args.learning_rate)

    @tf.function
    def pre_train_step(lr, hr, generator):
        """One MSE-only generator update; returns the scalar loss."""
        with tf.GradientTape() as gen_tape:
            ## re-scale
            ## lr: 0 ~ 1
            ## hr: -1 ~ 1
            lr = tf.cast(lr, tf.float32)
            hr = tf.cast(hr, tf.float32)
            lr = lr / 255
            hr = hr / 127.5 - 1

            sr = generator(lr, training=True)
            loss = mse_based_loss(sr, hr)

        gradients_of_generator = gen_tape.gradient(
            loss, generator.trainable_variables)
        generator_optimizer.apply_gradients(
            zip(gradients_of_generator, generator.trainable_variables))

        return loss

    def valid_step(image_path, gene, step):
        """Super-resolve the validation image and save a step-numbered PNG."""
        # Drop any alpha channel; keep the first three (RGB) channels.
        img = np.array(Image.open(image_path))[..., :3]

        # Scale input to [0, 1] and add a batch dimension.
        scaled_lr_img = tf.cast(img, tf.float32)
        scaled_lr_img = scaled_lr_img / 255
        scaled_lr_img = scaled_lr_img[np.newaxis, :, :, :]

        sr_img = gene(scaled_lr_img).numpy()

        # Map generator output from [-1, 1] back to uint8 [0, 255].
        sr_img = np.clip(sr_img, -1, 1)
        sr_img = (sr_img + 1) * 127.5
        sr_img = np.around(sr_img)
        sr_img = sr_img.astype(np.uint8)

        im = Image.fromarray(sr_img[0])

        im.save(
            os.path.join(valid_img_save_path, 'step_{0:07d}.png'.format(step)))
        print('[step_{}.png] save images'.format(step))

    # Start Train
    start = time.time()
    for pre_step, (lr, hr) in enumerate(
            train_ds.take(args.max_pre_train_step + 1)):

        mse = pre_train_step(lr, hr, gene)

        if pre_step % args.interval_save_weight == 0:
            gene.save_weights(
                os.path.join(weights_path,
                             'pre_gen_step_{}.h5'.format(pre_step)))

        if pre_step % args.interval_validation == 0:
            valid_step(args.valid_image_path, gene, pre_step)

        if pre_step % args.interval_show_info == 0:
            print('step:{}/{} MSE_LOSS {}, Training time is {} step/s'.format(
                pre_step, args.max_pre_train_step, mse,
                (time.time() - start) / args.interval_show_info))
            start = time.time()