Example #1
def validate(model, loader, epoch, device, args):
    model.eval()
    tic = time.time()
    cc_loss = AverageMeter()
    kldiv_loss = AverageMeter()
    nss_loss = AverageMeter()
    sim_loss = AverageMeter()

    with torch.no_grad():  # no gradients during validation; also permits .numpy() below
        for (img, gt) in loader:
            img = img.to(device)
            gt = gt.to(device)
            pred_map = model(img)
            # Blur the predicted saliency map before scoring it
            blur_map = pred_map.cpu().squeeze(0).clone().numpy()
            blur_map = blur(blur_map).unsqueeze(0).to(device)

            cc_loss.update(cc(blur_map, gt))
            kldiv_loss.update(kldiv(blur_map, gt))
            nss_loss.update(nss(blur_map, gt))
            sim_loss.update(similarity(blur_map, gt))
    judge_loss = args.kldiv_coeff * kldiv_loss.avg + args.cc_coeff * cc_loss.avg + args.sim_coeff * sim_loss.avg
    #judge_loss = args.kldiv_coeff * kldiv_loss.avg + args.cc_coeff * cc_loss.avg
    #judge_loss = args.kldiv_coeff * kldiv_loss.avg
    print(
        '[{:2d},   val] CC : {:.5f}, KLDIV : {:.5f}, NSS : {:.5f}, SIM : {:.5f}, sum : {:.5f},  time:{:.3f} minutes'
        .format(epoch, cc_loss.avg, kldiv_loss.avg, nss_loss.avg, sim_loss.avg,
                judge_loss, (time.time() - tic) / 60))
    sys.stdout.flush()

    return judge_loss
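This snippet leans on a few project helpers. A minimal sketch of what they might look like — the `sigma` value and the numpy-in/tensor-out contract of `blur` are assumptions inferred from the calls above:

import torch
from scipy.ndimage import gaussian_filter

class AverageMeter:
    """Running average of a scalar metric."""
    def __init__(self):
        self.sum, self.count = 0.0, 0

    def update(self, val, n=1):
        self.sum += float(val) * n
        self.count += n

    @property
    def avg(self):
        return self.sum / max(self.count, 1)

def blur(saliency_map, sigma=5):
    # Gaussian-smooth a numpy saliency map and hand back a tensor, matching
    # the `.unsqueeze(0)` call on the result in the snippet above.
    return torch.from_numpy(gaussian_filter(saliency_map, sigma=sigma))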
Example #2
def texture_loss(target, prediction, adv_1):
    prediction1 = tf.image.rgb_to_grayscale(blur(19, prediction))
    target1 = tf.image.rgb_to_grayscale(blur(19, target))
    enhanced1 = tf.reshape(prediction1, [-1, PATCH_WIDTH * PATCH_HEIGHT])
    dslr1 = tf.reshape(target1, [-1, PATCH_WIDTH * PATCH_HEIGHT])
    adversarial_1 = tf.multiply(enhanced1, 1 - adv_1) + tf.multiply(
        dslr1, adv_1)
    adversarial_image1 = tf.reshape(adversarial_1,
                                    [-1, PATCH_HEIGHT, PATCH_WIDTH, 1])
    discrim_predictions1 = models.adversarial_1(adversarial_image1)
    discrim_target1 = tf.concat([adv_1, 1 - adv_1], 1)

    # Cross-entropy of the texture discriminator; minimizing the negated value
    # returned below pushes the generator to fool the discriminator.
    loss_texture = -tf.reduce_sum(discrim_target1 * tf.log(
        tf.clip_by_value(discrim_predictions1, 1e-10, 1.0)))
    correct_predictions1 = tf.equal(tf.argmax(discrim_predictions1, 1),
                                    tf.argmax(discrim_target1, 1))
    discim_accuracy1 = tf.reduce_mean(tf.cast(correct_predictions1,
                                              tf.float32))
    return -loss_texture, discim_accuracy1
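Several snippets here (this one, plus Examples #10, #11 and #13) call a size-first `blur(size, image)` helper on TF tensors. A plausible sketch as a Gaussian depthwise convolution — the kernel construction and `sigma` default are assumptions:

import numpy as np
import tensorflow as tf

def blur(size, image, sigma=3.0):
    # Build a normalized 2-D Gaussian kernel of the requested size.
    ax = np.arange(size) - (size - 1) / 2.0
    xx, yy = np.meshgrid(ax, ax)
    kernel = np.exp(-(xx ** 2 + yy ** 2) / (2.0 * sigma ** 2))
    kernel /= kernel.sum()
    # Apply it to every channel independently.
    channels = image.get_shape().as_list()[-1]
    kernel = np.tile(kernel[:, :, None, None], (1, 1, channels, 1))
    return tf.nn.depthwise_conv2d(image, tf.constant(kernel, dtype=tf.float32),
                                  strides=[1, 1, 1, 1], padding='SAME')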
Example #3
def dataset(dataset_train_path, batch_size, scale_factor):
    assert os.path.exists(dataset_train_path)
    data = []
    for file in os.listdir(dataset_train_path):
        if file.endswith('.bmp'):
            filepath = os.path.join(dataset_train_path, file)
            # Convert RGB to luminance, then cut the image into 36x36 patches
            img = imageio.imread(filepath).dot([0.299, 0.587, 0.114])
            patches = extract_patches(img, (36, 36), 0.166)

            data += [patches[idx] for idx in range(patches.shape[0])]

    # Inputs are blurred upscaled patches; targets are the clean upscaled ones
    mod_data = [
        from_numpy(np.expand_dims(
            blur(upscale(patch, scale_factor), scale_factor), 0)).float()
        for patch in data
    ]
    data = [
        from_numpy(np.expand_dims(upscale(patch, scale_factor), 0)).float()
        for patch in data
    ]
    l = len(data)
    for idx in range(0, l, batch_size):
        yield (stack(mod_data[idx:min(idx + batch_size, l)]),
               stack(data[idx:min(idx + batch_size, l)]))
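A hypothetical driver for the generator above — the path, batch size, and `train_step` are made up, and the batch shape is inferred from the `expand_dims`/`stack` calls:

for lowres_batch, gt_batch in dataset('data/train', batch_size=16, scale_factor=2):
    # both tensors are plausibly shaped (B, 1, 36 * scale_factor, 36 * scale_factor)
    train_step(lowres_batch, gt_batch)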
Example #4
def main():
    """Main function
    """
    # Load one of these sample image, show different color channels.
    img = load_image(os.path.join('samples', 'IMG_6566.JPG'))
    show_custom_channels(img, color_space='rgb', title='Input image')

    # Zoom in and show a small 500x500 window to see triplets of color
    # values
    zoomed = zoom_in(img, 850, 950, height=500, width=500)
    show_custom_channels(zoomed, color_space='rgb', title='Zoomed-in window')

    # Separate H&E color stain channels from the image
    channel_lst, cmap_lst = show_custom_channels(
        zoomed,
        color_space='hed',
        title='Immunohistochemical staining colors separation')

    # Select eosin channel for processing
    sel_chn = np.copy(channel_lst[1])
    sel_cmap = cmap_lst[1]

    # Add noise and do a simple denoising task
    noised = add_noise(sel_chn)
    denoised = simple_denoise(noised, kernel_size=3)
    show_with_cmap([sel_chn, noised, denoised], [sel_cmap] * 3, [
        'Original image', 'With Gaussian noise', 'Denoised with Median filter'
    ])

    # Apply blurring and add noise and do a simple deblurring task, using the
    # Wiener filter
    blurred_noised = add_noise(blur(sel_chn, block_size=3), sigma=3)
    deblurred = simple_deblur(blurred_noised)
    show_with_cmap([sel_chn, blurred_noised, deblurred], [sel_cmap] * 3, [
        'Original image', 'With blurring and noise',
        'Deblurred with Wiener filter'
    ])

    # Detect cell boundary and overlay the results on images
    detect_cell_boundary(sel_chn, sel_cmap)
    plt.show()
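The captions name the filters, so `simple_denoise` is presumably a thin wrapper over a median filter. A sketch under that assumption:

from scipy.ndimage import median_filter

def simple_denoise(img, kernel_size=3):
    # Median filtering suppresses impulsive noise while preserving edges.
    return median_filter(img, size=kernel_size)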
Example #5
    discim_accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

    # 2) content loss
    '''
    CONTENT_LAYER = 'relu5_4'

    enhanced_vgg = vgg.net(vgg_dir, vgg.preprocess(enhanced * 255))
    dslr_vgg = vgg.net(vgg_dir, vgg.preprocess(dslr_image * 255))

    content_size = utils._tensor_size(dslr_vgg[CONTENT_LAYER]) * batch_size
    loss_content = 2 * tf.nn.l2_loss(enhanced_vgg[CONTENT_LAYER] - dslr_vgg[CONTENT_LAYER]) / content_size
    '''
    loss_content = loss.content_loss(dslr_image, enhanced, batch_size)

    # 3) color loss
    enhanced_blur = lutils.blur(enhanced)
    dslr_blur = lutils.blur(dslr_image)

    #loss_color = tf.reduce_sum(tf.pow(dslr_blur - enhanced_blur, 2))/(2 * batch_size)
    loss_color = tf.reduce_sum(
        tf.abs(dslr_image - enhanced)) / (2 * batch_size)

    #loss_color = loss.color_loss(dslr_image, enhanced, batch_size)

    # 4) total variation loss

    batch_shape = (batch_size, PATCH_WIDTH, PATCH_HEIGHT, 3)
    tv_y_size = lutils._tensor_size(enhanced[:, 1:, :, :])
    tv_x_size = lutils._tensor_size(enhanced[:, :, 1:, :])
    y_tv = tf.nn.l2_loss(enhanced[:, 1:, :, :] -
                         enhanced[:, :batch_shape[1] - 1, :, :])
    x_tv = tf.nn.l2_loss(enhanced[:, :, 1:, :] -
                         enhanced[:, :, :batch_shape[2] - 1, :])
    # Completion assumed from the parallel snippet in Example #6
    loss_tv = 2 * (x_tv / tv_x_size + y_tv / tv_y_size) / batch_size
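`lutils._tensor_size` is presumably the per-example element count used to normalize the TV terms; a sketch of the conventional form:

from functools import reduce

def _tensor_size(tensor):
    # Product of all dimensions except the batch dimension.
    return reduce(lambda a, b: a * b, tensor.get_shape().as_list()[1:], 1)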
Example #6
            # 2) content loss

            CONTENT_LAYER = 'relu5_4'

            enhanced_vgg = vgg.net(vgg_dir, vgg.preprocess(enhanced * 255))
            dslr_vgg = vgg.net(vgg_dir, vgg.preprocess(dslr_image[i] * 255))

            content_size = utils._tensor_size(
                dslr_vgg[CONTENT_LAYER]) * batch_size / GPU_NUM
            loss_content = 2 * tf.nn.l2_loss(
                enhanced_vgg[CONTENT_LAYER] -
                dslr_vgg[CONTENT_LAYER]) / content_size

            # 3) color loss

            enhanced_blur = utils.blur(enhanced)
            dslr_blur = utils.blur(dslr_image[i])

            loss_color = tf.reduce_sum(tf.pow(dslr_blur - enhanced_blur,
                                              2)) / (2 * batch_size / GPU_NUM)

            # 4) total variation loss

            batch_shape = (batch_size / GPU_NUM, PATCH_WIDTH, PATCH_HEIGHT, 3)
            tv_y_size = utils._tensor_size(enhanced[:, 1:, :, :])
            tv_x_size = utils._tensor_size(enhanced[:, :, 1:, :])
            y_tv = tf.nn.l2_loss(enhanced[:, 1:, :, :] -
                                 enhanced[:, :batch_shape[1] - 1, :, :])
            x_tv = tf.nn.l2_loss(enhanced[:, :, 1:, :] -
                                 enhanced[:, :, :batch_shape[2] - 1, :])
            loss_tv = 2 * (x_tv / tv_x_size +
                           y_tv / tv_y_size) / (batch_size / GPU_NUM)
            # (denominator assumed, by analogy with the per-GPU terms above)
Example #7
base_path = 'F://university/course/4th_1/acquisition/prj'
if args.overfit:
    train_names = glob(base_path+'/overfit/*.jpg')
else:
    train_names = glob(base_path+'/train/*.jpg')
for i in range(len(train_names)):
    name = train_names[i]
    img = io.imread(name, as_gray=True)  # `as_grey` was renamed `as_gray` in skimage
    if img.dtype == np.uint8:
        img = np.float32(img / 255.0)
    patches_gt,_ = utils.Img2patch(img,patch_size=img_size)
    train_patches_gt.append(patches_gt)

    # blur
    kernel = utils.kernel_generator(kernel_size=36)
    img_blur = utils.blur(img,kernel)
    patches_blur, _ = utils.Img2patch(img_blur, patch_size=img_size)
    train_patches_blur.append(patches_blur)

    # noise
    sigma = np.random.randint(20,30)
    img_noise = utils.noise(img, sigma)
    patches_noise, _ = utils.Img2patch(img_noise, patch_size=img_size)
    train_patches_noise.append(patches_noise)
    if i%100==0:
        print(i,' loaded')

# (2000, x, 256) -> (sum, 256): stack every per-image patch array
# Do not use concatenate, it is too slow
train_patches_gt = np.vstack(train_patches_gt)
train_patches_blur = np.vstack(train_patches_blur)
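Given the call signature `utils.blur(img, kernel)`, a plausible implementation is a plain 2-D convolution of the grayscale image with the generated kernel — the boundary handling is an assumption:

from scipy.signal import convolve2d

def blur(img, kernel):
    return convolve2d(img, kernel, mode='same', boundary='symm')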
Example #8
def train_and_validate(dataset_path,
                       batch_size,
                       scale_factor,
                       num_epochs,
                       learning_rate,
                       weight_decay,
                       output_dir,
                       verbose=True):

    model_output_dir = os.path.join(output_dir, 'model')

    model = network.ten()

    logging.info('TEN Model loaded')
    if verbose:
        print('TEN Model loaded')
        total_params = sum(p.numel() for p in model.parameters())
        print(f'Total Parameters: {total_params}')

    if cuda:
        model = model.cuda()

    model.train()

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=learning_rate,
                                 weight_decay=weight_decay)

    logging.info('Adam optimizer loaded')
    if verbose:
        print('Adam optimizer loaded')

    total_epoch_time = 0
    losses = []

    # Set the initial best loss arbitrarily high
    best_val_loss = float('inf')

    for epoch in range(1, num_epochs + 1):

        model.train()

        logging.info(f'Epoch {epoch} Start')
        if verbose:
            print(f'\n<----- START EPOCH {epoch} ------->\n')

        start_time = time()

        total_loss_for_this_epoch = 0

        # For each batch
        batch = 1
        for patches, gt in dataset(dataset_path, batch_size, scale_factor):

            optimizer.zero_grad()

            if cuda:
                patches = patches.cuda()
                gt = gt.cuda()

            pred = model(patches)
            loss = compute_loss(pred, gt)

            logging.info(f'Epoch {epoch} Batch {batch} Loss {loss.item()}')
            if verbose and (batch - 1) % 10 == 0:
                print(f'Epoch {epoch} Batch {batch} Loss {loss.item()}')

            loss.backward()
            optimizer.step()

            total_loss_for_this_epoch += loss.item()

            batch += 1

        # `batch` ends one past the number of batches, so subtract one
        avg_loss = total_loss_for_this_epoch / (batch - 1)
        losses.append(avg_loss)

        epoch_time = time() - start_time
        if verbose:
            print(f'Epoch time: {epoch_time}')
        total_epoch_time += epoch_time

        # Validation
        model.eval()
        val_img_file = random.choice(
            [f for f in os.listdir(dataset_path) if f.endswith('.bmp')])
        val_img = imageio.imread(os.path.join(dataset_path, val_img_file)).dot(
            [0.299, 0.587, 0.114])
        mod_val_img = torch.from_numpy(
            blur(upscale(val_img, scale_factor),
                 scale_factor)).float().unsqueeze(0).unsqueeze(0)
        val_img = torch.from_numpy(upscale(
            val_img, scale_factor)).float().unsqueeze(0).unsqueeze(0)
        if cuda:
            mod_val_img = mod_val_img.cuda()
            val_img = val_img.cuda()
        with torch.no_grad():
            out = model(mod_val_img)
            val_loss = compute_loss(out, val_img).item()

        if verbose:
            print(
                f'Epoch {epoch} Validation Image {val_img_file} Loss {val_loss}'
            )

        # Save current model
        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, model_output_dir, 'current.pth')
        logging.info('Current model saved')
        if verbose:
            print('Current model saved')

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint(
                {
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, model_output_dir, 'best.pth')
            logging.info('Best model saved')
            if verbose:
                print('Best model saved')

        # Save model every 20 epochs
        if epoch % 20 == 0:
            save_checkpoint(
                {
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, model_output_dir, f'epoch_{epoch}.pth')
            logging.info(f'Epoch {epoch} Model saved')
            if verbose:
                print(f'Epoch {epoch} Model saved')

        # Learning rate decay

        if epoch % 30 == 0 and epoch <= 60:
            learning_rate = learning_rate / 10
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate
            logging.info(
                f'Epoch {epoch}: Learning rate decayed by factor of 10')

        logging.info(f'Epoch {epoch} completed')
        if verbose:
            print(
                f'\n<----- END EPOCH {epoch} Time elapsed: {time()-start_time}------->\n'
            )

    logging.info('All epochs completed')
    logging.info(f'Average Time: {total_epoch_time/num_epochs:.4f} seconds')
    logging.info(f'Average Loss: {sum(losses) / len(losses):.4f}')
    if verbose:
        print('All epochs completed')
        print(f'Average Time: {total_epoch_time/num_epochs:.4f} seconds')
        print(f'Average Loss: {sum(losses) / len(losses):.4f}')

    if verbose:
        print('Losses array: ', losses)
        print('Best Validation Loss', best_val_loss)
Example #9
import fourrier

path = "../Tests/usa.mp4"
cap = cv2.VideoCapture(path)
ret, now = cap.read()

images = []
i = 1

background = np.zeros_like(now) + 0.5
num_updates = np.zeros_like(now) + 1

while True:

    ret, now = cap.read()
    if not ret:  # stop at the end of the video instead of blurring a None frame
        break
    now = utils.blur(now)

    # manage image bundle
    images.append(cv2.cvtColor(now, cv2.COLOR_BGR2GRAY) / 255)
    if len(images) > 10:
        images.pop(0)

    # get motion detection
    i += 1

    # frame skip
    if i % 4 == 1:
        motion = fourrier.image_fft(np.asarray(images))[0, :, :]
        thresh = utils.white_mask(motion)

        masked = now.copy()
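`utils.white_mask` presumably thresholds the FFT motion map into a binary mask; a sketch with an assumed cutoff:

import numpy as np

def white_mask(motion, cutoff=0.5):
    return (motion > cutoff).astype(np.uint8) * 255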
Example #10
def main(args):
  setproctitle.setproctitle('hdrnet_run')

  inputs = get_input_list(args.input)


  # -------- Load params ----------------------------------------------------
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  with tf.Session(config=config) as sess:
    checkpoint_path = tf.train.latest_checkpoint(args.checkpoint_dir)
    if checkpoint_path is None:
      log.error('Could not find a checkpoint in {}'.format(args.checkpoint_dir))
      return


  # -------- Setup graph ----------------------------------------------------
  tf.reset_default_graph()
  t_fullres_input = tf.placeholder(tf.float32, (1, width, heighth, 3))
  target = tf.placeholder(tf.float32, (1, width, heighth, 3))
  t_lowres_input = utils.blur(5,t_fullres_input)
  img_low = tf.image.resize_images(
                    t_lowres_input, [width // args.scale, heighth // args.scale],
                    method=tf.image.ResizeMethod.BICUBIC)  # integer division keeps the sizes ints under Python 3
  img_high = utils.Getfilter(5,t_fullres_input)


  with tf.variable_scope('inference'):
    prediction = models.Resnet(img_low,img_high,t_fullres_input)
  ssim = MultiScaleSSIM(target,prediction)
  psnr = metrics.psnr(target, prediction)
  saver = tf.train.Saver()

  start = time.time()  # time.clock() was removed in Python 3.8
  with tf.Session(config=config) as sess:
    log.info('Restoring weights from {}'.format(checkpoint_path))
    saver.restore(sess, checkpoint_path)
    SSIM = 0
    PSNR = 0
    for idx, input_path in enumerate(inputs):
      target_path = args.target + input_path.split('/')[2]
      log.info("Processing {}".format(input_path,target_path))
      im_input = cv2.imread(input_path, -1)  # -1 means read as is, no conversions.
      im_target = cv2.imread(target_path, -1)

      if im_input.shape[2] == 4:
        log.info("Input {} has 4 channels, dropping alpha".format(input_path))
        im_input = im_input[:, :, :3]
        im_target = im_target[:, :, :3]


      im_input = np.flip(im_input, 2)  # OpenCV reads BGR, convert back to RGB.
      im_target = np.flip(im_target, 2)


      im_input = skimage.img_as_float(im_input)
      im_target = skimage.img_as_float(im_target)


      im_input = im_input[np.newaxis, :, :, :]
      im_target = im_target[np.newaxis, :, :, :]


      feed_dict = {
          t_fullres_input: im_input,
          target:im_target
      }


      ssim1,psnr1 =  sess.run([ssim,psnr], feed_dict=feed_dict)
      SSIM = SSIM + ssim1
      PSNR = PSNR + psnr1
      if idx >= 999:  # evaluate on at most 1000 images
        break
    num_images = min(len(inputs), 1000)
    print("SSIM: %s, PSNR: %s" % (SSIM / num_images, PSNR / num_images))
  end = time.time()
  print("Elapsed %s seconds" % str(end - start))
Example #11
def main(args):
    setproctitle.setproctitle('hdrnet_run')

    inputs = get_input_list(args.input)

    # -------- Load params ----------------------------------------------------
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        checkpoint_path = tf.train.latest_checkpoint(args.checkpoint_dir)
        if checkpoint_path is None:
            log.error('Could not find a checkpoint in {}'.format(
                args.checkpoint_dir))
            return

        # metapath = ".".join([checkpoint_path, "meta"])
        # log.info('Loading graph from {}'.format(metapath))
        # tf.train.import_meta_graph(metapath)

        # model_params = utils.get_model_params(sess)

    # -------- Setup graph ----------------------------------------------------
    tf.reset_default_graph()
    t_fullres_input = tf.placeholder(tf.float32, (1, width, heighth, 3))
    t_lowres_input = utils.blur(5, t_fullres_input)
    img_low = tf.image.resize_images(
        t_lowres_input, [width // args.scale, heighth // args.scale],
        method=tf.image.ResizeMethod.BICUBIC)
    img_high = utils.Getfilter(5, t_fullres_input)

    with tf.variable_scope('inference'):
        prediction = models.Resnet(img_low, img_high, t_fullres_input)
    output = tf.cast(255.0 * tf.squeeze(tf.clip_by_value(prediction, 0, 1)),
                     tf.uint8)
    saver = tf.train.Saver()

    with tf.Session(config=config) as sess:
        log.info('Restoring weights from {}'.format(checkpoint_path))
        saver.restore(sess, checkpoint_path)

        for idx, input_path in enumerate(inputs):

            log.info("Processing {}".format(input_path))
            im_input = cv2.imread(input_path,
                                  -1)  # -1 means read as is, no conversions.
            if im_input.shape[2] == 4:
                log.info("Input {} has 4 channels, dropping alpha".format(
                    input_path))
                im_input = im_input[:, :, :3]

            im_input = np.flip(im_input,
                               2)  # OpenCV reads BGR, convert back to RGB.

            # log.info("Max level: {}".format(np.amax(im_input[:, :, 0])))
            # log.info("Max level: {}".format(np.amax(im_input[:, :, 1])))
            # log.info("Max level: {}".format(np.amax(im_input[:, :, 2])))

            # HACK for HDR+.
            if im_input.dtype == np.uint16 and args.hdrp:
                log.info(
                    "Using HDR+ hack for uint16 input. Assuming input white level is 32767."
                )
                # im_input = im_input / 32767.0
                # im_input = im_input / 32767.0 /2
                # im_input = im_input / (1.0*2**16)
            # Both branches scaled identically, so the conversion is hoisted out
            im_input = skimage.img_as_float(im_input)

            # Make or Load lowres image
            # lowres_input = skimage.transform.resize(
            #   im_input, [im_input.shape[0]/args.scale, im_input.shape[1]/args.scale], order = 0)
            # im_input = cv2.resize(lowres_input,(2000,1500),interpolation=cv2.INTER_CUBIC)
            # im_input1 = utils.blur(im_input)
            # lowres_input = cv2.resize(im_input1, (im_input1.shape[1]/args.scale,im_input1.shape[0]/args.scale),
            #                           interpolation=cv2.INTER_CUBIC )

            fname = os.path.splitext(os.path.basename(input_path))[0]
            output_path = os.path.join(args.output, fname + ".png")
            basedir = os.path.dirname(output_path)

            im_input = im_input[np.newaxis, :, :, :]
            # lowres_input = lowres_input[np.newaxis, :, :, :]

            feed_dict = {
                t_fullres_input: im_input
                # t_lowres_input: lowres_input
            }

            out_ = sess.run(output, feed_dict=feed_dict)

            if not os.path.exists(basedir):
                os.makedirs(basedir)

            skimage.io.imsave(output_path, out_)
Example #12
    orig_image = tf.reshape(orig_, [-1, PATCH_HEIGHT, PATCH_WIDTH, 3])

    # get processed enhanced image

    enhanced = models.resnet(bad_image)

    # transform both orig and enhanced images to grayscale

    enhanced_gray = tf.reshape(tf.image.rgb_to_grayscale(enhanced),
                               [-1, PATCH_WIDTH * PATCH_HEIGHT])
    orig_gray = tf.reshape(tf.image.rgb_to_grayscale(orig_image),
                           [-1, PATCH_WIDTH * PATCH_HEIGHT])

    #  color loss

    enhanced_blur = utils.blur(enhanced)
    orig_blur = utils.blur(orig_image)

    loss_color = tf.reduce_sum(tf.pow(orig_blur - enhanced_blur,
                                      2)) / (2 * batch_size)
    loss_generator = w_color * loss_color

    # optimize parameters of image enhancement

    generator_vars = [
        v for v in tf.global_variables() if v.name.startswith("generator")
    ]
    train_step_gen = tf.train.AdamOptimizer(learning_rate).minimize(
        loss_generator, var_list=generator_vars)

    saver = tf.train.Saver(var_list=generator_vars, max_to_keep=100)
Example #13
def main(args, data_params):
  procname = os.path.basename(args.checkpoint_dir)
  #setproctitle.setproctitle('hdrnet_{}'.format(procname))

  log.info('Preparing summary and checkpoint directory {}'.format(
      args.checkpoint_dir))
  if not os.path.exists(args.checkpoint_dir):
    os.makedirs(args.checkpoint_dir)

  tf.set_random_seed(1234)  # Make experiments repeatable

  # Select an architecture

  # Add model parameters to the graph (so they are saved to disk at checkpoint)

  # --- Train/Test datasets ---------------------------------------------------
  data_pipe = getattr(dp, args.data_pipeline)
  with tf.variable_scope('train_data'):
    train_data_pipeline = data_pipe(
        args.data_dir,
        shuffle=True,
        batch_size=args.batch_size, nthreads=args.data_threads,
        fliplr=args.fliplr, flipud=args.flipud, rotate=args.rotate,
        random_crop=args.random_crop, params=data_params,
        output_resolution=args.output_resolution,scale=args.scale)
    train_samples = train_data_pipeline.samples

    train_samples['high_input'] = Getfilter(5,train_samples['image_input'])

    train_samples['lowres_input1'] = blur(5,train_samples['lowres_input'])
    train_samples['low_input'] = tf.image.resize_images(train_samples['lowres_input1'],
                      [args.output_resolution[0] // args.scale, args.output_resolution[1] // args.scale],
                      method=tf.image.ResizeMethod.BICUBIC)



  if args.eval_data_dir is not None:
    with tf.variable_scope('eval_data'):
      eval_data_pipeline = data_pipe(
          args.eval_data_dir,
          shuffle=False,
          batch_size=1, nthreads=1,
          fliplr=False, flipud=False, rotate=False,
          random_crop=False, params=data_params,
          output_resolution=args.output_resolution,scale=args.scale)
      eval_samples = eval_data_pipeline.samples  # was train_data_pipeline.samples, almost certainly a bug
  # ---------------------------------------------------------------------------
  # Note: these swaps are sampled once, at graph construction, not per step
  swaps = np.reshape(np.random.randint(0, 2, args.batch_size), [args.batch_size, 1])
  swaps = tf.convert_to_tensor(swaps)
  swaps = tf.cast(swaps, tf.float32)

  swaps1 = np.reshape(np.random.randint(0, 2, args.batch_size), [args.batch_size, 1])
  swaps1 = tf.convert_to_tensor(swaps1)
  swaps1 = tf.cast(swaps1, tf.float32)

  # Training graph
  with tf.variable_scope('inference'):
    prediction = models.Resnet(train_samples['low_input'],train_samples['high_input'],train_samples['image_input'])
    loss,loss_content,loss_color,loss_filter,loss_texture,loss_tv,discim_accuracy,discim_accuracy1 =\
      metrics.l2_loss(train_samples['image_output'], prediction, swaps, swaps1, args.batch_size)
    psnr = metrics.psnr(train_samples['image_output'], prediction)
    loss_ssim = MultiScaleSSIM(train_samples['image_output'],prediction)


  # Evaluation graph
  if args.eval_data_dir is not None:
    with tf.name_scope('eval'):
      with tf.variable_scope('inference', reuse=True):
        eval_prediction = models.Resnet( eval_samples['low_input'],eval_samples['high_input'],eval_samples['image_input'])
      eval_psnr = metrics.psnr(eval_samples['image_output'], eval_prediction)

  # Optimizer
  # "discriminator1" shares the "discriminator" prefix, so a single negated
  # check excludes both (the original `not a or b` misparsed the intent)
  model_vars = [v for v in tf.global_variables() if not v.name.startswith("inference/l2_loss/discriminator")]
  discriminator_vars = [v for v in tf.global_variables() if v.name.startswith("inference/l2_loss/discriminator")]
  discriminator_vars1 = [v for v in tf.global_variables() if v.name.startswith("inference/l2_loss/discriminator1")]

  global_step = tf.contrib.framework.get_or_create_global_step()
  with tf.name_scope('optimizer'):
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    updates = tf.group(*update_ops, name='update_ops')
    log.info("Adding {} update ops".format(len(update_ops)))

    reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    if reg_losses and args.weight_decay is not None and args.weight_decay > 0:
      print("Regularization losses:")
      for rl in reg_losses:
        print(" ", rl.name)
      opt_loss = loss + args.weight_decay*sum(reg_losses)
    else:
      print("No regularization.")
      opt_loss = loss


    with tf.control_dependencies([updates]):
      opt = tf.train.AdamOptimizer(args.learning_rate)
      minimize = opt.minimize(opt_loss, name='optimizer', global_step=global_step,var_list=model_vars)
      # Only the generator step advances global_step; passing it to all three
      # optimizers would increment it three times per train_op run
      minimize1 = opt.minimize(-loss_filter, name='optimizer1', var_list=discriminator_vars)
      minimize2 = opt.minimize(-loss_texture, name='optimizer2', var_list=discriminator_vars1)


  # Average loss and psnr for display
  with tf.name_scope("moving_averages"):
    ema = tf.train.ExponentialMovingAverage(decay=0.99)
    update_ma = ema.apply([loss,loss_content,loss_color,loss_filter,loss_texture,loss_tv,discim_accuracy,discim_accuracy1,psnr,loss_ssim])
    loss = ema.average(loss)
    loss_content=ema.average(loss_content)
    loss_color=ema.average(loss_color)
    loss_filter=ema.average(loss_filter)
    loss_texture=ema.average(loss_texture)
    loss_tv=ema.average(loss_tv)
    discim_accuracy = ema.average(discim_accuracy)
    discim_accuracy1 = ema.average(discim_accuracy1)
    psnr = ema.average(psnr)
    loss_ssim = ema.average(loss_ssim)

  # Training stepper operation
  train_op = tf.group(minimize,minimize1,minimize2,update_ma)

  # Save a few graphs to tensorboard
  summaries = [
    tf.summary.scalar('loss', loss),
    tf.summary.scalar('loss_content',loss_content),
    tf.summary.scalar('loss_color',loss_color),
    tf.summary.scalar('loss_filter', loss_filter),
    tf.summary.scalar('loss_texture', loss_texture),
    tf.summary.scalar('loss_tv', loss_tv),
    tf.summary.scalar('discim_accuracy',discim_accuracy),
    tf.summary.scalar('discim_accuracy1', discim_accuracy1),
    tf.summary.scalar('psnr', psnr),
    tf.summary.scalar('ssim', loss_ssim),
    tf.summary.scalar('learning_rate', args.learning_rate),
    tf.summary.scalar('batch_size', args.batch_size),
  ]

  log_fetches = {
      "loss_content":loss_content,
      "loss_color":loss_color,
      "loss_filter":loss_filter,
      "loss_texture": loss_texture,
      "loss_tv":loss_tv,
      "discim_accuracy":discim_accuracy,
      "discim_accuracy1": discim_accuracy1,
      "step": global_step,
      "loss": loss,
      "psnr": psnr,
      "loss_ssim":loss_ssim}

  # `startswith("a" or "b")` only ever tests "a"; one prefix check covers both
  model_vars = [v for v in tf.global_variables() if not v.name.startswith("inference/l2_loss/discriminator")]
  discriminator_vars = [v for v in tf.global_variables() if v.name.startswith("inference/l2_loss/discriminator")]
  discriminator_vars1 = [v for v in tf.global_variables() if v.name.startswith("inference/l2_loss/discriminator1")]

  # Train config
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True  # Do not canibalize the entire GPU

  sv = tf.train.Supervisor(
      local_init_op=tf.variables_initializer(discriminator_vars),  # non-deprecated alias of initialize_variables
      saver=tf.train.Saver(var_list=model_vars,max_to_keep=100),
      logdir=args.checkpoint_dir,
      save_summaries_secs=args.summary_interval,
      save_model_secs=args.checkpoint_interval)
  # Train loop
  with sv.managed_session(config=config) as sess:
    sv.loop(args.log_interval, log_hook, (sess,log_fetches))
    last_eval = time.time()
    while True:
      if sv.should_stop():
        log.info("stopping supervisor")
        break
      try:
        step, _ = sess.run([global_step, train_op])
        since_eval = time.time()-last_eval

        if args.eval_data_dir is not None and since_eval > args.eval_interval:
          log.info("Evaluating on {} images at step {}".format(
              eval_data_pipeline.nsamples, step))

          p_ = 0
          eval_data_pipeline.nsamples = 3
          for it in range(eval_data_pipeline.nsamples):
            p_ += sess.run(eval_psnr)
          p_ /= eval_data_pipeline.nsamples

          sv.summary_writer.add_summary(tf.Summary(value=[
            tf.Summary.Value(tag="psnr/eval", simple_value=p_)]), global_step=step)

          log.info("  Evaluation PSNR = {:.1f} dB".format(p_))


          last_eval = time.time()

      except tf.errors.AbortedError:
        log.error("Aborted")
        break
      except KeyboardInterrupt:
        break
    chkpt_path = os.path.join(args.checkpoint_dir, 'on_stop.ckpt')
    log.info("Training complete, saving chkpt {}".format(chkpt_path))
    sv.saver.save(sess, chkpt_path)
    sv.request_stop()
Example #14
    adv_ = tf.placeholder(tf.float32, [None, 1])
    adv_color_ = tf.placeholder(tf.float32, [None, 1])

    # get processed enhanced image

    enhanced = models.resnet(phone_image)
    phone_image_gen = models.resnet1(enhanced)

    # transform both dslr and enhanced images to grayscale

    enhanced_gray = tf.reshape(tf.image.rgb_to_grayscale(enhanced),
                               [-1, PATCH_WIDTH * PATCH_HEIGHT])
    dslr_gray = tf.reshape(tf.image.rgb_to_grayscale(dslr_image),
                           [-1, PATCH_WIDTH * PATCH_HEIGHT])
    enhanced_blur = tf.reshape(utils.blur(enhanced),
                               [-1, PATCH_WIDTH * PATCH_HEIGHT * 3])
    dslr_blur = tf.reshape(utils.blur(dslr_image),
                           [-1, PATCH_WIDTH * PATCH_HEIGHT * 3])

    # push randomly the enhanced or dslr image to an adversarial CNN-discriminator

    adversarial_ = tf.multiply(enhanced_gray, 1 - adv_) + tf.multiply(
        dslr_gray, adv_)
    adversarial_image = tf.reshape(adversarial_,
                                   [-1, PATCH_HEIGHT, PATCH_WIDTH, 1])
    adversarial_color_ = tf.multiply(
        enhanced_blur, 1 - adv_color_) + tf.multiply(dslr_blur, adv_color_)
    adversarial_color_image = tf.reshape(adversarial_color_,
                                         [-1, PATCH_HEIGHT, PATCH_WIDTH, 3])
Example #15
        if do_rotate:
            rot_ccw = lambda im: rotate(im, ROTATE_ANGLE)
            apply_modification_and_save(img, img_fp, 'rot_ccw', rot_ccw)

            rot_cw = lambda im: rotate(im, -ROTATE_ANGLE)
            apply_modification_and_save(img, img_fp, 'rot_cw', rot_cw)

        if do_skew:
            skew_r = lambda im: skew(im, SKEW_ANGLE)
            apply_modification_and_save(img, img_fp, 'skew_r', skew_r)

            skew_l = lambda im: skew(im, -SKEW_ANGLE)
            apply_modification_and_save(img, img_fp, 'skew_l', skew_l)

        if do_blur:
            blur_ = lambda im: blur(im, BLUR_RADIUS)
            apply_modification_and_save(img, img_fp, 'blur', blur_)

        if do_underline:
            ul = lambda im: underline(im, FONT_SIZE, BORDER)
            apply_modification_and_save(img, img_fp, 'ul', ul)

        if do_complex:
            skew_r_blur = lambda im: blur(skew(im, SKEW_ANGLE), BLUR_RADIUS)
            apply_modification_and_save(img, img_fp, 'skew_r_blur',
                                        skew_r_blur)

            skew_l_blur = lambda im: blur(skew(im, -SKEW_ANGLE), BLUR_RADIUS)
            apply_modification_and_save(img, img_fp, 'skew_l_blur',
                                        skew_l_blur)
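The `blur(im, BLUR_RADIUS)` used by these lambdas is plausibly PIL-based; a sketch under that assumption:

from PIL import ImageFilter

def blur(im, radius):
    return im.filter(ImageFilter.GaussianBlur(radius))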
Example #16
        loss_discrim_texture, loss_texture, discrim_accuracy_gray = models.discriminator_loss(enhanced_gray, dslr_gray, adv_, discrim_target)
    else:
        loss_discrim_texture = zero_
        loss_texture = zero_
        discrim_accuracy_gray = zero_

    # content loss
    CONTENT_LAYER = 'relu5_4'
    phone_vgg = vgg.net(vgg_dir, vgg.preprocess(phone_image * 255))
    rec_vgg = vgg.net(vgg_dir, vgg.preprocess(rec_image * 255))

    content_size = utils._tensor_size(phone_vgg[CONTENT_LAYER]) * batch_size
    loss_content = 2 * tf.nn.l2_loss(phone_vgg[CONTENT_LAYER] - rec_vgg[CONTENT_LAYER]) / content_size

    # color loss
    enhanced_blur = tf.reshape(utils.blur(enhanced), [-1, PATCH_HEIGHT, PATCH_WIDTH, 3])
    dslr_blur = tf.reshape(utils.blur(dslr_image), [-1, PATCH_HEIGHT, PATCH_WIDTH, 3])
    loss_discrim_color, loss_color, discrim_accuracy_color = models.discriminator_loss(enhanced_blur, dslr_blur, adv_, discrim_target)

    # gradient loss
    if w_gradient != 0:
        enhanced_gradient = tf.reshape(utils.gradient(enhanced_gray), [-1, PATCH_HEIGHT, PATCH_WIDTH, 2])
        dslr_gradient = tf.reshape(utils.gradient(dslr_gray), [-1, PATCH_HEIGHT, PATCH_WIDTH, 2])
        loss_discrim_gradient, loss_gradient, discrim_accuracy_gradient = models.discriminator_loss(enhanced_gradient, dslr_gradient, adv_, discrim_target)
    else:
        loss_discrim_gradient = zero_
        loss_gradient = zero_
        discrim_accuracy_gradient = zero_

    #laplacian loss
    if w_laplacian != 0: