Example #1
if not os.path.exists(samples_dir): os.makedirs(samples_dir)

# testing loop
times = []
s = time.time()
for img_path in test_paths:
    # prepare data
    img_name = ntpath.basename(img_path).split('.')[0]
    img_lr = misc.imread(img_path, mode='RGB').astype(np.float)
    img_lr = misc.imresize(img_lr, (60, 80))
    im = preprocess(img_lr)
    im = np.expand_dims(im, axis=0)
    # generate enhanced image
    s = time.time()
    gen = generator.predict(im)
    gen = deprocess(gen)  # Rescale to 0-1
    tot = time.time() - s
    times.append(tot)
    # save sample images
    misc.imsave(os.path.join(samples_dir, img_name + '_gen.png'), gen[0])

# some statistics
num_test = len(test_paths)
if (num_test == 0):
    print("\nFound no images for test")
else:
    print("\nTotal images: {0}".format(num_test))
    Ttime = sum(times)
    print("Time taken: {0} sec at {1} fps".format(Ttime, num_test / Ttime))
    print("\nSaved generated images in in {0}\n".format(samples_dir))
Example #2
for img_path in test_paths:
    # prepare data
    img_name = ntpath.basename(img_path).split('.')[0]
    img_lrd = misc.imread(img_path, mode='RGB').astype(np.float)
    inp_h, inp_w, _ = img_lrd.shape  # save the input im-shape
    img_lrd = misc.imresize(img_lrd, (lr_height, lr_width))
    im = preprocess(img_lrd)
    im = np.expand_dims(im, axis=0)
    # generate enhanced image
    s = time.time()
    gen_op = generator.predict(im)
    gen_lr, gen_hr, gen_mask = gen_op[0], gen_op[1], gen_op[2]
    tot = time.time() - s
    times.append(tot)
    # save sample images
    gen_lr = deprocess(gen_lr).reshape(lr_shape)
    gen_hr = deprocess(gen_hr).reshape(hr_shape)
    gen_mask = gen_mask.reshape(lr_height, lr_width)
    # little clean-up of the saliency map
    # >> further post-processing could yield a more informative map
    gen_mask[gen_mask < 0.1] = 0
    # reshape and save generated images for observation
    img_lrd = misc.imresize(img_lrd, (inp_h, inp_w))
    gen_lr = misc.imresize(gen_lr, (inp_h, inp_w))
    gen_mask = misc.imresize(gen_mask, (inp_h, inp_w))
    gen_hr = misc.imresize(gen_hr, (inp_h * scale, inp_w * scale))
    misc.imsave(os.path.join(samples_dir, img_name + '.png'), img_lrd)
    misc.imsave(os.path.join(samples_dir, img_name + '_En.png'), gen_lr)
    misc.imsave(os.path.join(samples_dir, img_name + '_Sal.png'), gen_mask)
    misc.imsave(os.path.join(samples_dir, img_name + '_SESR.png'), gen_hr)
    print("tested: {0}".format(img_path))
Example #3
        # train the generators
        image_features = gan_model.vgg.predict(imgs_hr)
        if (model_name.lower() == "srdrm-gan"):
            # custom loss function for SRDRM-GAN
            g_loss = gan_model.combined.train_on_batch(
                [imgs_lr, imgs_hr], [valid, image_features, imgs_hr])
        else:
            g_loss = gan_model.combined.train_on_batch([imgs_lr, imgs_hr],
                                                       [valid, image_features])
        # increment step, and show the progress
        step += 1
        elapsed_time = datetime.datetime.now() - start_time
        if (step % 10 == 0):
            print("[Epoch %d: batch %d/%d] [d_loss: %f] [g_loss: %03f]" %
                  (epoch, i + 1, steps_per_epoch, d_loss[0], g_loss[0]))
        ## validate and save generated samples at regular intervals
        if (step % sample_interval == 0):
            imgs_lr, imgs_hr = data_loader.load_val_data(batch_size=2)
            fake_hr = gan_model.generator.predict(imgs_lr)
            gen_imgs = np.concatenate([deprocess(fake_hr), deprocess(imgs_hr)])
            save_val_samples(samples_dir, gen_imgs, step)
    # increment epoch, save model at regular intervals
    epoch += 1
    ## save model and weights
    if (epoch % ckpt_interval == 0):
        ckpt_name = os.path.join(checkpoint_dir, ("model_%d" % epoch))
        with open(ckpt_name + "_.json", "w") as json_file:
            json_file.write(gan_model.generator.to_json())
        gan_model.generator.save_weights(ckpt_name + "_.h5")
        print("\nSaved trained model in {0}\n".format(checkpoint_dir))
Example #4
# load weights into new model
funie_gan_generator.load_weights(model_h5)
print("\nLoaded data and model")

# testing loop
times = []
s = time.time()
for img_path in test_paths:
    # prepare data
    inp_img = read_and_resize(img_path, (256, 256))
    im = preprocess(inp_img)
    im = np.expand_dims(im, axis=0)  # (1,256,256,3)
    # generate enhanced image
    s = time.time()
    gen = funie_gan_generator.predict(im)
    gen_img = deprocess(gen)[0]
    tot = time.time() - s
    times.append(tot)
    # save output images
    img_name = ntpath.basename(img_path)
    out_img = np.hstack((inp_img, gen_img)).astype('uint8')
    Image.fromarray(out_img).save(join(samples_dir, img_name))

# some statistics
num_test = len(test_paths)
if (num_test == 0):
    print("\nFound no images for test")
else:
    print("\nTotal images: {0}".format(num_test))
    # accumulate frame processing times (excluding the first, warm-up frame)
    Ttime, Mtime = np.sum(times[1:]), np.mean(times[1:])
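    # the original snippet ends above; a plausible continuation, mirroring
    # the reporting in Example #1 (hypothetical, not in the source):
    print("Time taken: {0} sec at {1} fps".format(Ttime, (num_test - 1) / Ttime))
    print("Mean time per frame: {0} sec".format(Mtime))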
Example #5
def train(cfg):
    """ Training pipeline
         - cfg: yaml file with training parameters (see configs/)
    """
    # dataset info
    dataset = cfg["dataset_name"]
    data_path = cfg["dataset_path"]
    # image info
    chans = cfg["channels"]
    im_res = (cfg["im_width"], cfg["im_height"])
    im_shape = (im_res[1], im_res[0], chans)
    # training params
    num_epochs = cfg["num_epochs"]
    batch_size = cfg["batch_size"]
    val_interval = cfg["val_interval"]
    ckpt_interval = cfg["ckpt_interval"]
    # create validation and checkpoint directories
    val_dir = join("samples/", dataset + '_usal')
    if not exists(val_dir): os.makedirs(val_dir)
    ckpt_dir = join("checkpoints/", dataset + '_usal')
    if not exists(ckpt_dir): os.makedirs(ckpt_dir)

    ## data pipeline
    data_loader = dataLoaderSOD(data_path, dataset, im_res)
    steps_per_epoch = (data_loader.num_train // batch_size)
    num_step = num_epochs * steps_per_epoch

    ## define model, load pretrained weights
    model = SVAM_Net(model_h5='checkpoints/vgg16_ed_pt.h5')

    ## compile model
    model.compile(optimizer=Adam(3e-4, 0.5),
                  loss=[
                      'binary_crossentropy', EdgeHoldLoss, EdgeHoldLoss,
                      'binary_crossentropy'
                  ],
                  loss_weights=[0.5, 1, 2, 4],
                  metrics=['accuracy'])

    ## setup training pipeline and fit model
    print("\nTraining SVAM-Net...")
    it, step, epoch = 1, 1, 1
    while (step <= num_step):
        for imgs, masks in data_loader.load_batch(batch_size):
            loss = model.train_on_batch(imgs, [masks, masks, masks, masks])
            # increment step, and show the progress
            it += 1
            step += 1
            if not step % 100:
                print("Epoch {0}:{1}/{2}. Loss: {3}".format(
                    epoch, step, num_step, loss[0]))
            ## validate and save samples at regular intervals
            if (step % val_interval == 0):
                inp_img, gt_sal = data_loader.load_val_data(batch_size=1)
                saml, sambu, samtd, out = model.predict(inp_img)
                inp_img = deprocess(inp_img).reshape(im_shape)
                saml, sambu, samtd, out = deprocess_gens(
                    saml, sambu, samtd, out, im_res)
                Image.fromarray(inp_img).save(join(val_dir,
                                                   "%d_in.png" % step))
                Image.fromarray(np.hstack(
                    (saml, sambu, samtd,
                     out))).save(join(val_dir, "%d_sal.png" % step))
        epoch += 1
        it = 0
        ## save model at regular intervals
        if epoch % ckpt_interval == 0:
            model.save_weights(join(ckpt_dir, ("model_%d.h5" % epoch)))
            print("\nSaved model in {0}\n".format(ckpt_dir))
Example #6
model_json = checkpoint_dir + model_name_by_epoch + ".json"

with open(model_json, "r") as json_file:
    loaded_model_json = json_file.read()
funie_gan_generator = model_from_json(loaded_model_json)
funie_gan_generator.load_weights(model_h5)
print("\nLoaded data and model")

times = []
s = time.time()
for root, dirs, files in os.walk(test_paths):
    for img_path in files:
        if not img_path.lower().endswith('.jpg'):
            continue
        img_name = ntpath.basename(img_path).split('.')[0]
        im, shape = read_and_resize(os.path.join(root, img_path), (256, 256))
        im = preprocess(im)
        s = time.time()
        gen = funie_gan_generator.predict(im)
        gen = deprocess(gen, shape)
        tot = time.time() - s
        times.append(tot)
        misc.imsave(os.path.join(root, img_name + '_gen.png'), gen[0])

# some statistics
num_test = len(times)
if (num_test == 0):
    print("\nFound no test images")
else:
    print("\nTotal images: {0}".format(num_test))