示例#1
0
def upscale():
    """Super-resolve ``lr.jpg`` with the saved generator and write ``sr.jpg``.

    Loads ``gen_model500.h5`` (deserializing it requires the custom
    ``vgg_loss``), maps the input pixels from [0, 255] to [-1, 1], runs the
    generator on the single image and maps the prediction back to 8-bit
    pixels.

    Raises:
        FileNotFoundError: if ``lr.jpg`` cannot be read by OpenCV.
    """
    loss = VGG_LOSS((384, 384, 3))
    model = keras.models.load_model('gen_model500.h5',
                                    custom_objects={'vgg_loss': loss.vgg_loss})

    img = cv2.imread("lr.jpg")
    if img is None:
        # cv2.imread reports a missing/unreadable file by returning None.
        raise FileNotFoundError("could not read input image 'lr.jpg'")
    img = (img.astype(np.float32) - 127.5) / 127.5

    # predict() works on batches: add a batch axis, then take the only result.
    output = model.predict(np.expand_dims(img, axis=0))[0]
    print(output.shape)

    # Nothing here guarantees the prediction stays inside [-1, 1]; clip before
    # the uint8 cast so out-of-range values saturate instead of wrapping.
    output = np.clip((output + 1) * 127.5, 0, 255).astype(np.uint8)
    cv2.imwrite("sr.jpg", output)
示例#2
0
def load_backend_model(models, backend, tiling, img_shape, tile_size, use_init):
    '''Function that loads the respective backend model.

    The loaded network is wrapped in a new Keras ``Model`` whose ``Input``
    has either the square tile shape or ``img_shape``.

    models          -- container with the model paths, indexed by backend;
                       for backend 2 the entry is indexed again by
                       ``int(use_init)``
    backend         -- integer with the used backend
                       (0 = SRDense, 1 = SRResNet, 2 = SRGAN)
    tiling          -- bool should tiling be used
    img_shape       -- input shape used when tiling is disabled
    tile_size       -- integer if tiling is used
    use_init        -- for backend 2 should initialization be used

    returns         -- loaded keras model
    raises          -- ValueError if ``backend`` is not 0, 1 or 2
    '''
    if backend not in (0, 1, 2):
        # Previously this fell through to an UnboundLocalError on return.
        raise ValueError("unknown backend: %r" % (backend,))

    if backend == 0:
        print("load SRDense")
        model = load_model(models[backend])
    elif backend == 1:
        print("load SRResNet")
        loss = VGG_LOSS((504, 504, 3))
        model = load_model(models[backend],
                           custom_objects={"tf": tf, "loss": loss.loss})
    else:
        if use_init:
            print("load initialized SRGAN")
        else:
            print("load SRGAN")
        model = load_model(models[backend][int(use_init)],
                           custom_objects={"tf": tf})

    # NOTE(review): kept from the original; depending on the Keras version,
    # popping from ``model.layers`` may not alter the model graph at all.
    model.layers.pop(0)

    # Rebind the network to a new input so it accepts tiles or whole images.
    if tiling:
        _in = Input(shape=(tile_size, tile_size, 3))
    else:
        _in = Input(shape=img_shape)
    _out = model(_in)
    return Model(_in, _out)
示例#3
0
def one_image():
    """Run the newest inception generator checkpoint on ``INPUT_DIR``.

    The checkpoint epoch is read from ``last_model_epoch.txt`` inside
    ``MODEL_DIR``. The generator is rebuilt for 64x64x3 inputs, its weights
    are loaded and it is compiled with the VGG loss. Producing the output
    into ``OUTPUT_DIR`` is delegated to ``Utils.generate_one_image``.
    """
    low_res_shape = (64, 64, 3)
    vgg = VGG_LOSS(low_res_shape)
    opt = get_optimizer()

    epoch_file = MODEL_DIR + 'last_model_epoch.txt'
    newest_epoch = Utils.get_last_epoch_number(epoch_file)
    weights_path = MODEL_DIR + "inception_gen_model" + str(newest_epoch) + ".h5"

    net = Generator(low_res_shape, DOWNSCALE_FACTOR, "inception").generator()
    net.load_weights(weights_path)
    net.compile(loss=vgg.vgg_loss, optimizer=opt)

    Utils.generate_one_image(INPUT_DIR, net, OUTPUT_DIR)
示例#4
0
def plot_image():
    """Upscale the image at ``INPUT_DIR`` and plot it next to the original.

    ``INPUT_DIR`` is used as an image file path here. The newest inception
    generator checkpoint in ``MODEL_DIR`` is loaded and fed the normalized
    image. The denormalized prediction and the RGB version of the input are
    passed to ``Utils.generate_two_plot``, which writes to ``OUTPUT_DIR``.

    Raises:
        FileNotFoundError: if the input image cannot be read by OpenCV.
    """
    lr_shape = (64, 64, 3)
    loss = VGG_LOSS(lr_shape)
    optimizer = get_optimizer()

    image = cv2.imread(INPUT_DIR)
    if image is None:
        # cv2.imread reports a missing/unreadable file by returning None.
        raise FileNotFoundError("could not read input image: %s" % (INPUT_DIR,))
    image_lr = Utils.normalize(image)

    last_epoch_number = Utils.get_last_epoch_number(MODEL_DIR + 'last_model_epoch.txt')
    gen_model = MODEL_DIR + "inception_gen_model" + str(last_epoch_number) + ".h5"

    # NOTE(review): the generator is built for 64x64x3 inputs; presumably the
    # input image must have that size — confirm against Generator.
    generator = Generator(lr_shape, DOWNSCALE_FACTOR, "inception").generator()
    generator.load_weights(gen_model)
    generator.compile(loss=loss.vgg_loss, optimizer=optimizer)

    gen_img = generator.predict(np.expand_dims(image_lr, axis=0))
    sr_image = Utils.denormalize(gen_img)

    # OpenCV loads BGR; convert the original for plotting.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    Utils.generate_two_plot(sr_image, image, OUTPUT_DIR)
示例#5
0
                        action='store',
                        dest='number_of_images',
                        default=5,
                        help='Number of Images',
                        type=int)

    parser.add_argument(
        '-t',
        '--test_type',
        action='store',
        dest='test_type',
        default='test_model',
        help='Option to test model output or to test low resolution image')

    values = parser.parse_args()

    loss = VGG_LOSS(image_shape)
    model = load_model(values.model_dir,
                       custom_objects={'vgg_loss': loss.vgg_loss})

    if values.test_type == 'test_model':
        test_model(values.input_hig_res, model, values.number_of_images,
                   values.output_dir)

    elif values.test_type == 'test_lr_images':
        test_model_for_lr_images(values.input_low_res, model,
                                 values.number_of_images, values.output_dir)

    else:
        print("No such option")
示例#6
0
文件: eval.py 项目: Timozen/srcc
# Number of leading entries in the model lists of ``main`` that are Keras
# networks; every later entry is evaluated with nearest-neighbour upscaling.
_NUM_NETWORKS = 6


def _load_network(i, path, custom_objects):
    """Return the Keras model with index ``i``, or None for the baseline."""
    if i < _NUM_NETWORKS:
        return load_model(path, custom_objects=custom_objects[i])
    return None


def _super_resolve(i, model, lr):
    """Upscale one low-resolution image/tile with the i-th evaluated model.

    i     -- model index; selects the pre-/post-processing pipeline
    model -- loaded Keras model (unused for the nearest-neighbour baseline)
    lr    -- low-resolution image array

    returns -- the super-resolved image
    """
    if i < 2:
        # SRDense variants: index 1 ("-norm") works on inputs scaled to
        # [0, 1] and its output is scaled back by 255; index 0 gets the raw
        # pixels. Both outputs are clipped to [0, 255].
        if i == 1:
            lr = lr.astype(np.float64) / 255
        tmp = np.squeeze(model.predict(np.expand_dims(lr, axis=0)))
        if i == 1:
            tmp = tmp * 255
        tmp[tmp < 0] = 0
        tmp[tmp > 255] = 255
        return tmp.astype(np.uint8)
    if i < _NUM_NETWORKS:
        # SRResNet / SRGAN variants are fed images rescaled to [-1, 1].
        batch = np.expand_dims(rescale_imgs_to_neg1_1(lr), axis=0)
        return Utils.denormalize(np.squeeze(model.predict(batch), axis=0))
    # Baseline: nearest-neighbour upscaling by the fixed factor 4.
    return cv2.resize(lr, (0, 0), fx=4, fy=4, interpolation=cv2.INTER_NEAREST)


def _image_metrics(hr, sr):
    """Return the (MSE, PSNR, SSIM, MSSIM) tuple of ``sr`` w.r.t. ``hr``."""
    return (metrics.MSE(hr, sr), metrics.PSNR(hr, sr),
            metrics.SSIM(hr, sr), metrics.MSSIM(hr, sr))


def _column_means(rows):
    """Average a list of 4-tuples of metrics column by column."""
    return tuple(np.mean([row[k] for row in rows]) for k in range(4))


def _console_line(key, values):
    """Format one (MSE, PSNR, SSIM, MSSIM) tuple for the console report."""
    return (key + ":   MSE = " + str(values[0]) + ", PSNR = " +
            str(values[1]) + ", SSIM = " + str(values[2]) +
            ", MSSIM = " + str(values[3]))


def _file_line(key, values):
    """Format one metric tuple as a space-separated report file line."""
    return " ".join([key] + [str(v) for v in values]) + "\n"


def _report(title, filename, performance):
    """Print ``performance`` under ``title`` and write it to ``filename``."""
    print(title)
    with open(filename, "w") as f:
        for key, values in performance.items():
            print(_console_line(key, values))
            f.write(_file_line(key, values))


def _load_test_images(root, down_scaling_factor, interpolation):
    """Build (lr, hr) pairs from every file in ``root`` whose name has ".jpg".

    hr is the RGB image passed through ``Utils.crop_into_lr_shape`` with
    shape (3024, 4032); lr is hr downscaled by ``down_scaling_factor``.
    """
    pairs = []
    for name in os.listdir(root):
        if ".jpg" not in name:
            continue
        bgr = cv2.imread(os.path.join(root, name), cv2.IMREAD_COLOR)
        hr = Utils.crop_into_lr_shape(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB),
                                      shape=(3024, 4032))
        lr = cv2.resize(hr, (0, 0),
                        fx=1 / down_scaling_factor,
                        fy=1 / down_scaling_factor,
                        interpolation=interpolation)
        pairs.append((lr, hr))
    return pairs


def _evaluate_tiles(model_paths, model_names, tile_shapes, custom_objects,
                    test_images):
    """Score every model on single tiles: mean over tiles, then images."""
    performance = {}
    for i, mp in tqdm(enumerate(model_paths), total=len(model_paths)):
        keras.backend.clear_session()
        model = _load_network(i, mp, custom_objects)
        hr_shape, lr_shape = tile_shapes[i]

        per_image = []
        for lr_img, hr_img in tqdm(test_images):
            lr_tiles = Utils.tile_image(lr_img, shape=lr_shape)
            hr_tiles = Utils.tile_image(hr_img, shape=hr_shape)
            per_tile = [_image_metrics(hr, _super_resolve(i, model, lr))
                        for lr, hr in zip(lr_tiles, hr_tiles)]
            per_image.append(_column_means(per_tile))

        performance[model_names[i]] = _column_means(per_image)
    return performance


def _evaluate_whole_lr(model_paths, model_names, custom_objects, test_images):
    """Score every model on the whole low-resolution images.

    A model whose evaluation raises on any image is recorded as
    ("err", "err", "err", "err").
    """
    performance = {}
    for i, mp in tqdm(enumerate(model_paths), total=len(model_paths)):
        keras.backend.clear_session()
        model = _load_network(i, mp, custom_objects)
        if model is not None:
            # Rebind the input layer so the network accepts the full lr image.
            _in = Input(shape=test_images[0][0].shape)
            model = Model(_in, model(_in))

        rows = []
        failed = False
        for lr_img, hr_img in tqdm(test_images):
            try:
                sr = _super_resolve(i, model, lr_img)
                rows.append(_image_metrics(hr_img, sr))
            except Exception as err:
                # NOTE(review): presumably guards against e.g. out-of-memory
                # on full-size images; report instead of silently swallowing.
                print("whole lr evaluation failed for " + model_names[i] +
                      ": " + repr(err))
                failed = True

        if failed:
            performance[model_names[i]] = ("err", "err", "err", "err")
        else:
            performance[model_names[i]] = _column_means(rows)
    return performance


def _evaluate_stitched(model_paths, model_names, tile_shapes, custom_objects,
                       test_images):
    """Score every model on images stitched back together from tiles.

    Returns a dict mapping model name to (simple, advanced) metric tuples:
    ``ImageStitching.stitch_images`` on plain tiles and
    ``ImageStitching.stitching`` on tiles created with ``overlap=True``.
    """
    performance = {}
    for i, mp in tqdm(enumerate(model_paths), total=len(model_paths)):
        keras.backend.clear_session()
        model = _load_network(i, mp, custom_objects)
        lr_shape = tile_shapes[i][1]

        simple_rows = []
        overlap_rows = []
        for lr_img, hr_img in tqdm(test_images):
            lr_tiles = Utils.tile_image(lr_img, shape=lr_shape)
            lr_tiles_overlap = Utils.tile_image(lr_img, shape=lr_shape,
                                                overlap=True)
            sr_tiles = [_super_resolve(i, model, t) for t in lr_tiles]
            sr_tiles_overlap = [_super_resolve(i, model, t)
                                for t in lr_tiles_overlap]

            hr_h, hr_w = hr_img.shape[0], hr_img.shape[1]
            tile_h, tile_w = sr_tiles[0].shape[0], sr_tiles[0].shape[1]
            sr_simple = ImageStitching.stitch_images(
                sr_tiles, hr_w, hr_h, tile_w, tile_h,
                hr_w // tile_w, hr_h // tile_h)
            sr_advanced = ImageStitching.stitching(
                sr_tiles_overlap,
                LR=None,
                image_size=(hr_h, hr_w),
                adjustRGB=False,
                overlap=True)

            simple_rows.append(_image_metrics(hr_img, sr_simple))
            overlap_rows.append(_image_metrics(hr_img, sr_advanced))

        performance[model_names[i]] = (_column_means(simple_rows),
                                       _column_means(overlap_rows))
    return performance


def main():
    """Evaluate all super-resolution models on the DSIDS test images.

    Depending on the module-level switches ``TILES``, ``WHOLE_LR`` and
    ``STITCHED``, every model is scored (MSE, PSNR, SSIM, MSSIM) on single
    tiles, on the whole low-resolution image and on images stitched from
    tiles. The results are printed and written to ``tile_performance.txt``,
    ``whole_lr_performance.txt`` and ``stitch_performance.txt``.
    """
    # paths to the models; the last entry is the nearest-neighbour baseline
    model_paths = [
        os.path.join("..", "models", "SRDense-Type-3_ep80.h5"),
        os.path.join("..", "models", "srdense-norm.h5"),
        os.path.join("..", "models", "srresnet85.h5"),
        os.path.join("..", "models", "gen_model90.h5"),
        os.path.join("..", "models", "srgan60.h5"),
        os.path.join("..", "models", "srgan-mse-20.h5"), "Nearest"
    ]

    # corresponding names of the models
    model_names = [
        "SRDense", "SRDense-norm", "SRResNet", "SRGAN-from-scratch",
        "SRGAN-percept.-loss", "SRGAN-mse", "NearestNeighbor"
    ]

    # corresponding (hr tile shape, lr tile shape) pairs
    tile_shapes = [((168, 168), (42, 42)), ((168, 168), (42, 42)),
                   ((504, 504), (126, 126)), ((336, 336), (84, 84)),
                   ((504, 504), (126, 126)), ((504, 504), (126, 126)),
                   ((336, 336), (84, 84))]

    # custom objects needed to deserialize the respective models
    custom_objects = [{}, {"tf": tf}, {"tf": tf}, {"tf": tf}, {"tf": tf},
                      {"tf": tf}, {}]

    # list of (lr, hr) test image pairs
    test_images = _load_test_images(os.path.join("..", "DSIDS", "test"),
                                    down_scaling_factor=4,
                                    interpolation=cv2.INTER_CUBIC)

    if TILES:
        # First: performance metrics on single image tiles.
        _report("Performance on single tiles:", "tile_performance.txt",
                _evaluate_tiles(model_paths, model_names, tile_shapes,
                                custom_objects, test_images))

    if WHOLE_LR:
        # Second: performance metrics on a single upscaled image.
        _report("Performance on whole lr:", "whole_lr_performance.txt",
                _evaluate_whole_lr(model_paths, model_names, custom_objects,
                                   test_images))

    if STITCHED:
        # Third: performance metrics on stitched images.
        stitch_performance = _evaluate_stitched(model_paths, model_names,
                                                tile_shapes, custom_objects,
                                                test_images)
        print("Performance on stitched images:")
        with open("stitch_performance.txt", "w") as f:
            for key, (simple, advanced) in stitch_performance.items():
                print("simple stitch:  " + _console_line(key, simple))
                print("advanced stitch:  " + _console_line(key, advanced))
                f.write(_file_line(key, simple))
                f.write(_file_line(key, advanced))
示例#7
0
def train(epochs, batch_size, input_dir, output_dir, model_save_dir,
          number_of_images, train_test_ratio):
    """Train an SRGAN generator/discriminator pair on ``.png`` images.

    Each step first trains the discriminator on a random batch of real
    high-resolution images (labels in (0.8, 1]) and generated images (labels
    in [0, 0.2)). It then trains the combined ``gan`` network on a fresh
    random batch with the discriminator frozen. Losses are appended per
    epoch to ``losses.txt`` (truncated at start). Sample images are plotted
    at epoch 1 and every 5 epochs, and generator weights are saved every
    500 epochs.

    Raises:
        ValueError: if the training set yields no full batch, in which case
            no loss could ever be reported.
    """
    x_train_lr, x_train_hr, x_test_lr, x_test_hr = Utils.load_training_data(
        input_dir, '.png', number_of_images, train_test_ratio)
    loss = VGG_LOSS(image_shape)

    batch_count = int(x_train_hr.shape[0] / batch_size)
    if batch_count < 1:
        raise ValueError(
            "batch_size %r is larger than the training set (%d images)" %
            (batch_size, x_train_hr.shape[0]))

    # Low-resolution input shape derived from the module-level image_shape.
    shape = (image_shape[0] // downscale_factor,
             image_shape[1] // downscale_factor, image_shape[2])

    generator, _ = Generator(shape).generator()
    discriminator = Discriminator(image_shape).discriminator()

    optimizer = Utils_model.get_optimizer()
    generator.compile(loss=loss.vgg_loss, optimizer=optimizer)
    discriminator.compile(loss="binary_crossentropy", optimizer=optimizer)

    gan = get_gan_network(discriminator, shape, generator, optimizer,
                          loss.vgg_loss)

    losses_path = model_save_dir + 'losses.txt'
    # Truncate the loss log of a previous run.
    with open(losses_path, 'w+'):
        pass

    for e in range(1, epochs + 1):
        print('-' * 15, 'Epoch %d' % e, '-' * 15)
        for _ in tqdm(range(batch_count)):
            rand_nums = np.random.randint(0,
                                          x_train_hr.shape[0],
                                          size=batch_size)

            image_batch_hr = x_train_hr[rand_nums]
            image_batch_lr = x_train_lr[rand_nums]
            generated_images_sr = generator.predict(image_batch_lr)

            # Noisy labels: real in (0.8, 1], fake in [0, 0.2).
            real_data_Y = np.ones(
                batch_size) - np.random.random_sample(batch_size) * 0.2
            fake_data_Y = np.random.random_sample(batch_size) * 0.2

            discriminator.trainable = True

            d_loss_real = discriminator.train_on_batch(image_batch_hr,
                                                       real_data_Y)
            d_loss_fake = discriminator.train_on_batch(generated_images_sr,
                                                       fake_data_Y)
            discriminator_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

            # Train the generator through the combined network on a new batch.
            rand_nums = np.random.randint(0,
                                          x_train_hr.shape[0],
                                          size=batch_size)
            image_batch_hr = x_train_hr[rand_nums]
            image_batch_lr = x_train_lr[rand_nums]

            gan_Y = np.ones(
                batch_size) - np.random.random_sample(batch_size) * 0.2
            discriminator.trainable = False
            gan_loss = gan.train_on_batch(image_batch_lr,
                                          [image_batch_hr, gan_Y])

        print("discriminator_loss : %f" % discriminator_loss)
        print("gan_loss :", gan_loss)
        gan_loss = str(gan_loss)

        with open(losses_path, 'a') as loss_file:
            loss_file.write(
                'epoch%d : gan_loss = %s ; discriminator_loss = %f\n' %
                (e, gan_loss, discriminator_loss))

        if e == 1 or e % 5 == 0:
            Utils.plot_generated_images(output_dir, e, generator, x_test_hr,
                                        x_test_lr)
        if e % 500 == 0:
            generator.save_weights(model_save_dir + 'gen_w%d.h5' % e)
示例#8
0
def train(epochs, batch_size, input_dir, output_dir, model_save_dir, number_of_images, train_test_ratio, image_extension):
    """Train a plain and a complex SRGAN side by side on the same batches.

    Both generator/discriminator pairs see the same random high/low
    resolution batches each step. Discriminators get noisy labels (real in
    (0.8, 1], fake in [0, 0.2)), and the combined GANs are trained with the
    discriminators frozen. The losses of both pairs are printed and appended
    per epoch to ``losses.txt``. Sample images of both generators are plotted
    every epoch, and the plain generator/discriminator are saved every 50
    epochs.

    Raises:
        ValueError: if the training set yields no full batch, in which case
            no loss could ever be reported.
    """
    # Loading images
    x_train_lr, x_train_hr, x_test_lr, x_test_hr = \
        Utils.load_training_data(input_dir, image_extension, image_shape, number_of_images, train_test_ratio)

    print('======= Loading VGG_loss ========')
    # One VGG loss per generator
    loss = VGG_LOSS(image_shape)
    loss2 = VGG_LOSS(image_shape)
    print('====== VGG_LOSS =======', loss)

    batch_count = int(x_train_hr.shape[0] / batch_size)
    print('====== Batch_count =======', batch_count)
    if batch_count < 1:
        raise ValueError(
            "batch_size %r is larger than the training set (%d images)" %
            (batch_size, x_train_hr.shape[0]))

    shape = (image_shape[0] // downscale_factor, image_shape[1] // downscale_factor, image_shape[2])
    print('====== Shape =======', shape)

    # Generator description
    generator = Generator(shape).generator()
    complex_generator = complex_Generator(shape).generator()
    # Discriminator description
    discriminator = Discriminator(image_shape).discriminator()
    discriminator2 = Discriminator(image_shape).discriminator()

    optimizer = Utils_model.get_optimizer()

    generator.compile(loss=loss.vgg_loss, optimizer=optimizer)
    complex_generator.compile(loss=loss2.vgg_loss, optimizer=optimizer)

    discriminator.compile(loss="binary_crossentropy", optimizer=optimizer)
    discriminator2.compile(loss="binary_crossentropy", optimizer=optimizer)

    gan = get_gan_network(discriminator, shape, generator, optimizer, loss.vgg_loss)
    complex_gan = get_gan_network(discriminator2, shape, complex_generator, optimizer, loss2.vgg_loss)

    losses_path = model_save_dir + 'losses.txt'
    # Truncate the loss log of a previous run.
    with open(losses_path, 'w+'):
        pass

    for e in range(1, epochs + 1):
        print('-' * 15, 'Epoch %d' % e, '-' * 15)
        for _ in tqdm(range(batch_count)):
            rand_nums = np.random.randint(0, x_train_hr.shape[0], size=batch_size)

            image_batch_hr = x_train_hr[rand_nums]
            image_batch_lr = x_train_lr[rand_nums]
            generated_images_sr = generator.predict(image_batch_lr)
            generated_images_csr = complex_generator.predict(image_batch_lr)
            real_data_Y = np.ones(batch_size) - np.random.random_sample(batch_size) * 0.2
            fake_data_Y = np.random.random_sample(batch_size) * 0.2

            discriminator.trainable = True
            discriminator2.trainable = True

            d_loss_real = discriminator.train_on_batch(image_batch_hr, real_data_Y)
            d_loss_fake = discriminator.train_on_batch(generated_images_sr, fake_data_Y)
            discriminator_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

            d_loss_creal = discriminator2.train_on_batch(image_batch_hr, real_data_Y)
            d_loss_cfake = discriminator2.train_on_batch(generated_images_csr, fake_data_Y)
            discriminator_c_loss = 0.5 * np.add(d_loss_cfake, d_loss_creal)

            # Train both generators through their combined networks on a new batch.
            rand_nums = np.random.randint(0, x_train_hr.shape[0], size=batch_size)
            image_batch_hr = x_train_hr[rand_nums]
            image_batch_lr = x_train_lr[rand_nums]

            gan_Y = np.ones(batch_size) - np.random.random_sample(batch_size) * 0.2
            discriminator.trainable = False
            discriminator2.trainable = False
            gan_loss = gan.train_on_batch(image_batch_lr, [image_batch_hr, gan_Y])
            gan_c_loss = complex_gan.train_on_batch(image_batch_lr, [image_batch_hr, gan_Y])

        print("discriminator_loss : %f" % discriminator_loss)
        print("discriminator_c_loss : %f" % discriminator_c_loss)
        print("gan_loss :", gan_loss)
        print("gan_c_loss :", gan_c_loss)
        gan_loss = str(gan_loss)
        gan_c_loss = str(gan_c_loss)

        with open(losses_path, 'a') as loss_file:
            # The complex pair's losses used to be computed but never logged.
            loss_file.write('epoch%d : gan_loss = %s ; discriminator_loss = %f ; '
                            'gan_c_loss = %s ; discriminator_c_loss = %f\n'
                            % (e, gan_loss, discriminator_loss, gan_c_loss, discriminator_c_loss))

        Utils.plot_generated_images(output_dir, e, generator, complex_generator, x_test_hr, x_test_lr)
        if e % 50 == 0:
            generator.save(model_save_dir + 'gen_model%d.h5' % e)
            discriminator.save(model_save_dir + 'dis_model%d.h5' % e)
def train(epochs, batch_size, input_dir, tgt_dir, output_dir, model_save_dir,
          number_of_images, train_test_ratio, saved_model):
    """Fine-tune a saved generator with a new Conv3D head on ``.npy`` data.

    The generator in ``saved_model`` is cut before its last layer, and three
    new Conv3D layers are attached. All but the last four layers of the
    resulting model are frozen. The model is trained with the VGG loss on
    random batches; the per-epoch loss is appended to ``losses.txt``
    (truncated at start). A test-batch loss is appended to
    ``test_losses.txt`` at epoch 1 and every 5 epochs, and the model is saved
    every 50 epochs.

    Raises:
        ValueError: if the training set yields no full batch, in which case
            no loss could ever be reported.
    """
    x_train_lr, x_train_hr, x_test_lr, x_test_hr = Utils.load_training_data(
        input_dir, tgt_dir, '.npy', number_of_images, train_test_ratio)

    loss = VGG_LOSS(image_shape)

    batch_count = int(x_train_hr.shape[0] / batch_size)
    if batch_count < 1:
        raise ValueError(
            "batch_size %r is larger than the training set (%d images)" %
            (batch_size, x_train_hr.shape[0]))

    # The original also built a fresh Generator here, but it was overwritten
    # by the fine-tuning model below before ever being used.
    generator_old = load_model(saved_model,
                               custom_objects={'vgg_loss': loss.vgg_loss})
    # Attach the new head to the output of the second-to-last layer.
    x_tmp = generator_old.layers[-2].output
    x_tmp = Conv3D(filters=128,
                   kernel_size=5,
                   strides=1,
                   padding="same",
                   name="TransConv2d_2")(x_tmp)
    x_tmp = Conv3D(filters=32,
                   kernel_size=3,
                   strides=1,
                   padding="same",
                   name="TransConv2d_3")(x_tmp)
    x_tmp = Conv3D(filters=1,
                   kernel_size=1,
                   strides=1,
                   padding="same",
                   name="TransConv2d_4")(x_tmp)
    generator = Model(inputs=generator_old.input, outputs=x_tmp)

    # Freeze everything except the last four layers (the three new Conv3D
    # layers plus, presumably, the old layer they are attached to).
    for layers in generator.layers[:-4]:
        layers.trainable = False

    optimizer = Utils_model.get_optimizer()
    generator.compile(loss=loss.vgg_loss, optimizer=optimizer)

    losses_path = model_save_dir + 'losses.txt'
    # Truncate the loss log of a previous run.
    with open(losses_path, 'w+'):
        pass

    for e in range(1, epochs + 1):
        print('-' * 15, 'Epoch %d' % e, '-' * 15)
        for _ in tqdm(range(batch_count)):
            rand_nums = np.random.randint(0,
                                          x_train_hr.shape[0],
                                          size=batch_size)

            image_batch_hr = x_train_hr[rand_nums]
            image_batch_lr = x_train_lr[rand_nums]

            gan_loss = generator.train_on_batch(image_batch_lr, image_batch_hr)

        print("gan_loss :", gan_loss)
        gan_loss = str(gan_loss)

        with open(losses_path, 'a') as loss_file:
            loss_file.write('epoch%d : Resnet_loss = %s ; \n' % (e, gan_loss))

        if e == 1 or e % 5 == 0:
            rand_nums = np.random.randint(0,
                                          x_test_hr.shape[0],
                                          size=batch_size)
            image_batch_hr = x_test_hr[rand_nums]
            image_batch_lr = x_test_lr[rand_nums]
            test_loss = generator.test_on_batch(image_batch_lr, image_batch_hr)
            print("test_loss :", test_loss)
            test_loss = str(test_loss)
            with open(model_save_dir + 'test_losses.txt', 'a') as loss_file:
                loss_file.write('epoch%d : test_loss = %s ; \n' % (e, test_loss))
        if e % 50 == 0:
            generator.save(model_save_dir + 'Resnet_model%d.h5' % e)
示例#10
0
def train(epochs, batch_size, input_dir, output_dir, model_save_dir,
          number_of_images, train_test_ratio, image_extension):
    """Adversarially train the generator and discriminator on 32x32 patches.

    Loads the LR/HR training data, builds and compiles the generator,
    discriminator and combined GAN, and cuts the training images into 32x32
    patches. Each epoch then alternates discriminator and GAN updates.
    Per-epoch losses are appended to ``losses.txt``. Sample images are
    plotted on epoch 1 and every 5th epoch, and both models are saved every
    200 epochs.

    Depends on module-level names defined elsewhere in the file:
    ``image_shape``, ``downscale_factor``, ``Utils``, ``Utils_model``,
    ``VGG_LOSS``, ``Generator``, ``Discriminator``, ``get_gan_network``,
    ``extract_2d`` and ``tqdm``.

    Args:
        epochs: number of training epochs.
        batch_size: samples per training step.
        input_dir: directory the training images are loaded from.
        output_dir: directory passed to ``Utils.plot_generated_images``.
        model_save_dir: string prefix for the loss log and model files. It
            is concatenated directly with file names, so include the
            trailing path separator.
        number_of_images: number of images to load.
        train_test_ratio: split ratio passed to ``Utils.load_training_data``.
        image_extension: file extension of the images to load.
    """
    x_train_lr, x_train_hr, x_test_lr, x_test_hr = Utils.load_training_data(
        input_dir, image_extension, image_shape, number_of_images,
        train_test_ratio)

    print('======= Loading VGG_loss ========')
    loss = VGG_LOSS(image_shape)
    print('====== VGG_LOSS =======', loss)

    # NOTE(review): counted on the whole images, before they are replaced by
    # patches below, so one "epoch" is not one pass over the patch set.
    batch_count = int(x_train_hr.shape[0] / batch_size)
    print('====== Batch_count =======', batch_count)

    # The generator consumes images downscaled by downscale_factor.
    shape = (image_shape[0] // downscale_factor,
             image_shape[1] // downscale_factor, image_shape[2])
    print('====== Shape =======', shape)

    generator = Generator(shape).generator()
    discriminator = Discriminator(image_shape).discriminator()

    optimizer = Utils_model.get_optimizer()
    generator.compile(loss=loss.vgg_loss, optimizer=optimizer)
    discriminator.compile(loss="binary_crossentropy", optimizer=optimizer)
    gan = get_gan_network(discriminator, shape, generator, optimizer,
                          loss.vgg_loss)

    # Start this run with an empty loss log.
    with open(model_save_dir + 'losses.txt', 'w+'):
        pass

    # Train on 32x32 patches. Both results are indexed as arrays below, so
    # both calls are used the same way (no tuple-unpacking). Assumes
    # extract_2d returns the patch array itself -- TODO confirm.
    x_train_hr = extract_2d(x_train_hr, (32, 32))
    x_train_lr = extract_2d(x_train_lr, (32, 32))

    for e in range(1, epochs + 1):
        print('-' * 15, 'Epoch %d' % e, '-' * 15)
        for _ in tqdm(range(batch_count)):
            # Discriminator step: real HR samples vs. generated SR samples,
            # with noisy targets (real in (0.8, 1], fake in [0, 0.2)).
            rand_nums = np.random.randint(0,
                                          x_train_hr.shape[0],
                                          size=batch_size)
            image_batch_hr = x_train_hr[rand_nums]
            image_batch_lr = x_train_lr[rand_nums]
            generated_images_sr = generator.predict(image_batch_lr)

            real_data_Y = np.ones(
                batch_size) - np.random.random_sample(batch_size) * 0.2
            fake_data_Y = np.random.random_sample(batch_size) * 0.2

            discriminator.trainable = True
            d_loss_real = discriminator.train_on_batch(image_batch_hr,
                                                       real_data_Y)
            d_loss_fake = discriminator.train_on_batch(generated_images_sr,
                                                       fake_data_Y)
            discriminator_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

            # Generator step through the combined GAN on a fresh batch.
            rand_nums = np.random.randint(0,
                                          x_train_hr.shape[0],
                                          size=batch_size)
            image_batch_hr = x_train_hr[rand_nums]
            image_batch_lr = x_train_lr[rand_nums]

            gan_Y = np.ones(
                batch_size) - np.random.random_sample(batch_size) * 0.2
            discriminator.trainable = False
            gan_loss = gan.train_on_batch(image_batch_lr,
                                          [image_batch_hr, gan_Y])

        print("discriminator_loss : %f" % discriminator_loss)
        print("gan_loss :", gan_loss)
        gan_loss = str(gan_loss)

        with open(model_save_dir + 'losses.txt', 'a') as loss_file:
            loss_file.write(
                'epoch%d : gan_loss = %s ; discriminator_loss = %f\n' %
                (e, gan_loss, discriminator_loss))

        if e == 1 or e % 5 == 0:
            Utils.plot_generated_images(output_dir, e, generator, x_test_hr,
                                        x_test_lr)
        if e % 200 == 0:
            generator.save(model_save_dir + 'gen_model%d.h5' % e)
            discriminator.save(model_save_dir + 'dis_model%d.h5' % e)
示例#11
0
def train(epochs, batch_size, input_dir, tgt_dir, output_dir, model_save_dir,
          number_of_images, train_test_ratio):
    """Train the generator alone (no discriminator) on ``.npy`` data with VGG loss.

    Each epoch runs ``batch_count`` generator updates on randomly sampled
    input/target pairs. The last update's loss is appended to
    ``losses.txt``, and the generator is saved every 500 epochs.

    Depends on module-level names defined elsewhere in the file:
    ``image_shape``, ``Utils``, ``Utils_model``, ``VGG_LOSS``, ``Generator``
    and ``tqdm``.

    Args:
        epochs: number of training epochs.
        batch_size: samples per training step.
        input_dir: directory of input (LR) ``.npy`` files.
        tgt_dir: directory of target (HR) ``.npy`` files.
        output_dir: unused (sample-image plotting is disabled).
        model_save_dir: string prefix for the loss log and model files. It
            is concatenated directly with file names.
        number_of_images: number of samples to load.
        train_test_ratio: split ratio passed to ``Utils.load_training_data``.
    """
    x_train_lr, x_train_hr, x_test_lr, x_test_hr = Utils.load_training_data(
        input_dir, tgt_dir, '.npy', number_of_images, train_test_ratio)

    loss = VGG_LOSS(image_shape)

    batch_count = int(x_train_hr.shape[0] / batch_size)
    # The generator is built for a rank-4 input shape taken from image_shape.
    shape = (image_shape[0], image_shape[1], image_shape[2], image_shape[3])

    generator = Generator(shape).generator()

    optimizer = Utils_model.get_optimizer()
    generator.compile(loss=loss.vgg_loss, optimizer=optimizer)

    # Start this run with an empty loss log.
    with open(model_save_dir + 'losses.txt', 'w+'):
        pass

    for e in range(1, epochs + 1):
        print('-' * 15, 'Epoch %d' % e, '-' * 15)
        for _ in tqdm(range(batch_count)):
            rand_nums = np.random.randint(0,
                                          x_train_hr.shape[0],
                                          size=batch_size)
            image_batch_hr = x_train_hr[rand_nums]
            image_batch_lr = x_train_lr[rand_nums]
            # Only the generator is trained here; no separate forward pass is
            # needed because there is no discriminator step.
            gan_loss = generator.train_on_batch(image_batch_lr, image_batch_hr)

        print("gan_loss :", gan_loss)
        gan_loss = str(gan_loss)

        with open(model_save_dir + 'losses.txt', 'a') as loss_file:
            loss_file.write('epoch%d : gan_loss = %s ; \n' % (e, gan_loss))

        if e % 500 == 0:
            generator.save(model_save_dir + 'gen_model%d.h5' % e)
示例#12
0
    # print(model.summary())
    img = cv2.imread("lr.jpg")
    img = (img.astype(np.float32) - 127.5) / 127.5
    output = model.predict(np.expand_dims(img, axis=0))
    # print(output)
    output = output[0]
    print(output.shape)

    output = (output + 1) * 127.5
    output = output.astype(np.uint8)
    # print(output)
    cv2.imwrite("sr.jpg", output)


# upscale()
# Standalone inference: load the trained 4x generator, rebuild it for a
# 96x96x3 input and super-resolve "lr.jpg".
loss = VGG_LOSS((384, 384, 3))
# Load the model (4x). The custom VGG loss must be supplied so Keras can
# deserialize the saved model.
model = keras.models.load_model('gen_model500.h5',
                                custom_objects={'vgg_loss': loss.vgg_loss})
# Change the model's input shape to match the resolution of the input image.
model = change_model(model, new_shape=(None, 96, 96, 3))

img = cv2.imread("lr.jpg")  # read image
if img is None:
    # cv2.imread reports a missing/unreadable file by returning None rather
    # than raising, which would otherwise fail later with an AttributeError.
    raise FileNotFoundError("could not read input image 'lr.jpg'")
img = (img.astype(np.float32) - 127.5) / 127.5  # normalize to [-1, 1]
output = model.predict(np.expand_dims(img, axis=0))  # add batch axis, predict
output = output[0]  # drop the batch axis
print(output.shape)
output = (output + 1) * 127.5  # denormalize: maps [-1, 1] back to [0, 255]
# NOTE(review): `output` is neither cast to uint8 nor written to disk here.
示例#13
0
def train(epochs, batch_size, input_dir, output_dir, model_save_dir,
          number_of_images, train_test_ratio, resume_train, downscale_factor,
          arch):
    """Adversarially train a generator/discriminator pair, optionally resuming.

    Sets up a TF1 session with on-demand GPU memory growth, loads the
    training data and builds the models for ``arch``. When ``resume_train``
    is ``True``, it restores their weights from the checkpoint epoch recorded
    in ``last_model_epoch.txt``.

    Each epoch runs one discriminator step (a real batch and a generated
    batch) and one GAN step on a freshly sampled batch. Losses are written
    through ``Utils.save_losses_file``, and models are checkpointed every
    ``EPOCHS_CHECKPOINT`` epochs.

    Depends on module-level names defined elsewhere in the file: ``tf``,
    ``set_session``, ``Utils``, ``Utils_model``, ``VGG_LOSS``, ``IMAGE_SHAPE``,
    ``Generator``, ``Discriminator``, ``gan_network``, ``EPOCHS_CHECKPOINT``
    and ``tqdm``.

    Args:
        epochs: number of epochs to run in this call.
        batch_size: samples per training step.
        input_dir: directory passed to ``Utils.load_training_data``.
        output_dir: unused.
        model_save_dir: string prefix for checkpoint files. It is
            concatenated directly with file names.
        number_of_images: number of images to load.
        train_test_ratio: unused.
        resume_train: resume from the last checkpoint when ``== True``.
        downscale_factor: HR-to-LR size ratio, also passed to ``Generator``.
        arch: architecture name passed to ``Generator`` and used as the
            prefix of checkpoint and log file names.
    """
    # Let TensorFlow grow GPU memory on demand instead of reserving it all.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    set_session(sess)

    hr_images, hr_label, lr_images, lr_label = Utils.load_training_data(
        input_dir, number_of_images)

    print(hr_images)
    loss = VGG_LOSS(IMAGE_SHAPE)
    lr_shape = (IMAGE_SHAPE[0] // downscale_factor,
                IMAGE_SHAPE[1] // downscale_factor, IMAGE_SHAPE[2])
    print(lr_shape)
    generator = Generator(lr_shape, downscale_factor, arch).generator()
    discriminator = Discriminator(IMAGE_SHAPE).discriminator()

    optimizer = Utils_model.get_optimizer()

    if resume_train == True:
        last_epoch_number = Utils.get_last_epoch_number(model_save_dir +
                                                        'last_model_epoch.txt')

        gen_model = model_save_dir + arch + "_gen_model" + str(
            last_epoch_number) + ".h5"
        dis_model = model_save_dir + arch + "_dis_model" + str(
            last_epoch_number) + ".h5"
        generator.load_weights(gen_model)
        discriminator.load_weights(dis_model)
        # NOTE(review): the loop below starts *at* the checkpointed epoch, so
        # that epoch is trained a second time -- confirm whether
        # get_last_epoch_number already returns the next epoch.
    else:
        last_epoch_number = 1

    generator.compile(loss=loss.vgg_loss, optimizer=optimizer)
    discriminator.compile(loss="binary_crossentropy", optimizer=optimizer)

    gan = gan_network(discriminator, lr_shape, generator, optimizer,
                      loss.vgg_loss)

    for e in range(last_epoch_number, last_epoch_number + epochs):
        print('-' * 15, 'Epoch %d' % e, '-' * 15)
        # One training step per "epoch".
        for _ in tqdm(range(1)):
            rand_nums = np.random.randint(0,
                                          hr_images.shape[0],
                                          size=batch_size)
            image_batch_hr = hr_images[rand_nums]
            image_batch_lr = lr_images[rand_nums]
            generated_images = generator.predict(
                image_batch_lr)  # array of generated images

            # Noisy targets: real in (0.8, 1], fake in [0, 0.2).
            real_data = np.ones(
                batch_size) - np.random.random_sample(batch_size) * 0.2
            fake_data = np.random.random_sample(batch_size) * 0.2

            discriminator.trainable = True

            discriminator_loss_real = discriminator.train_on_batch(
                image_batch_hr, real_data)
            discriminator_loss_fake = discriminator.train_on_batch(
                generated_images, fake_data)
            discriminator_loss = 0.5 * np.add(
                discriminator_loss_fake,
                discriminator_loss_real)  # mean of discriminator losses

            # Generator step on a freshly drawn batch rather than the batch
            # the discriminator was just trained on.
            rand_nums = np.random.randint(0,
                                          hr_images.shape[0],
                                          size=batch_size)
            image_batch_hr = hr_images[rand_nums]
            image_batch_lr = lr_images[rand_nums]

            discriminator.trainable = False
            gan_loss = gan.train_on_batch(image_batch_lr,
                                          [image_batch_hr, real_data])

        print("discriminator_loss : %f" % discriminator_loss)
        print("gan_loss : ", gan_loss)
        gan_loss = str(gan_loss)
        Utils.save_losses_file(model_save_dir, e, gan_loss, discriminator_loss,
                               arch + '_losses.txt')

        if e % EPOCHS_CHECKPOINT == 0:
            # Record the checkpoint epoch so a later run can resume from it.
            Utils.save_losses_file(model_save_dir, e, gan_loss,
                                   discriminator_loss, 'last_model_epoch.txt')
            generator.save(model_save_dir + arch + '_gen_model%d.h5' % e)
            discriminator.save(model_save_dir + arch + '_dis_model%d.h5' % e)
示例#14
0
def train(img_shape,
          epochs,
          batch_size,
          rescaling_factor,
          input_dirs,
          output_dir,
          model_save_dir,
          train_test_ratio,
          gpu=1):
    """Train the SR generator alone (no discriminator) from image folders.

    Batches come from ``create_data_generator``. When ``gpu > 1`` the
    generator is wrapped with ``multi_gpu_model``, falling back to the
    single model if that fails. The generator is saved every 5 epochs as
    ``srresnet<epoch>.h5`` and once more at the end as ``srresnet.h5``.

    Depends on module-level names defined elsewhere in the file:
    ``create_data_generator``, ``rescale_imgs_to_neg1_1``, ``VGG_LOSS``,
    ``Generator``, ``get_model_memory_usage``, ``Utils_model``,
    ``multi_gpu_model``, ``cv2`` and ``tqdm``.

    Args:
        img_shape: HR image shape ``(height, width, channels)``.
        epochs: number of training epochs.
        batch_size: samples per training step.
        rescaling_factor: HR-to-LR size ratio.
        input_dirs: pair of data directories passed to
            ``create_data_generator``. ``input_dirs[1]/ignore`` is also
            listed to size an epoch and to pick the test images.
        output_dir: unused (test-image generation is disabled).
        model_save_dir: directory the models are saved into.
        train_test_ratio: validation fraction.
        gpu: when greater than 1, try to train on 2 GPUs.
    """
    lr_shape = (img_shape[0] // rescaling_factor,
                img_shape[1] // rescaling_factor, img_shape[2])

    img_train_gen, img_test_gen = create_data_generator(
        input_dirs[1],
        input_dirs[0],
        target_size_lr=(lr_shape[0], lr_shape[1]),
        target_size_hr=(img_shape[0], img_shape[1]),
        preproc_lr=rescale_imgs_to_neg1_1,
        preproc_hr=rescale_imgs_to_neg1_1,
        validation_split=train_test_ratio,
        batch_size=batch_size)

    image_dir = os.path.join(input_dirs[1], 'ignore')
    image_files = os.listdir(image_dir)

    # Number of training batches per epoch (the validation share excluded).
    batch_count = int(
        (len(image_files) / batch_size) * (1 - train_test_ratio))

    test_image = []
    for img in sorted(image_files):
        if 'niklas_city_0009' in img:
            test_image.append(
                rescale_imgs_to_neg1_1(cv2.imread(os.path.join(image_dir,
                                                               img))))

    print("test length: ", len(test_image))

    # The loss compares generator outputs, which have the HR shape, so build
    # it from the img_shape argument.
    loss = VGG_LOSS(img_shape)

    generator = Generator(lr_shape, rescaling_factor).generator()

    print('memory usage generator: ',
          get_model_memory_usage(batch_size, generator))

    optimizer = Utils_model.get_optimizer()

    par_generator = generator
    if gpu > 1:
        try:
            print("multi_gpu_model generator")
            par_generator = multi_gpu_model(generator, gpus=2)
        except Exception:
            # Best-effort: fall back to the single-GPU model.
            par_generator = generator
            print("single_gpu_model generator")
    else:
        print("single_gpu_model generator")

    par_generator.compile(loss=loss.loss, optimizer=optimizer)

    par_generator.summary()

    for e in range(1, epochs + 1):
        print('-' * 15, 'Epoch %d' % e, '-' * 15)
        for _ in tqdm(range(batch_count)):
            batch = next(img_train_gen)
            image_batch_hr = batch[1]
            image_batch_lr = batch[0]

            # Only full-size batches are trained on.
            if image_batch_hr.shape[0] == batch_size and image_batch_lr.shape[
                    0] == batch_size:
                par_generator.train_on_batch(image_batch_lr, image_batch_hr)
            else:
                print("weird multi_gpu_model batch error dis: ")
                print("hr batch shape: ", image_batch_hr.shape)
                print("lr batch shape: ", image_batch_lr.shape)

        if e % 5 == 0:
            generator.save(os.path.join(model_save_dir, 'srresnet%d.h5' % e))

    # Final model; the file name has no placeholder, so it is not %-formatted.
    generator.save(os.path.join(model_save_dir, 'srresnet.h5'))
示例#15
0
文件: train.py 项目: Timozen/srcc
def train(img_shape, epochs, batch_size, rescaling_factor, input_dirs,
          output_dir, model_save_dir, train_test_ratio):
    """Adversarially train the SR generator and discriminator from image folders.

    Batches come from ``create_data_generator``. Both models are wrapped
    with ``multi_gpu_model`` (2 GPUs) when possible, falling back to the
    single models. Each step trains the discriminator on a real and a
    generated batch, then trains the combined GAN on a new batch. Only
    full-size batches are used.

    Per-epoch losses are appended to ``losses.txt`` in ``model_save_dir``.
    Test images are generated on epoch 1 and every 5th epoch, and both
    models are saved every 5 epochs.

    Depends on module-level names defined elsewhere in the file:
    ``create_data_generator``, ``rescale_imgs_to_neg1_1``, ``VGG_LOSS``,
    ``Generator``, ``Discriminator``, ``get_model_memory_usage``,
    ``Utils_model``, ``multi_gpu_model``, ``get_gan_network``, ``Utils``,
    ``cv2`` and ``tqdm``.

    Args:
        img_shape: HR image shape ``(height, width, channels)``.
        epochs: number of training epochs.
        batch_size: samples per training step.
        rescaling_factor: HR-to-LR size ratio.
        input_dirs: pair of data directories passed to
            ``create_data_generator``. ``input_dirs[1]/ignore`` is also
            listed to size an epoch and to pick the test images.
        output_dir: directory passed to ``Utils.generate_test_image``.
        model_save_dir: directory for the loss log and saved models.
        train_test_ratio: validation fraction.
    """
    lr_shape = (img_shape[0] // rescaling_factor,
                img_shape[1] // rescaling_factor, img_shape[2])

    img_train_gen, img_test_gen = create_data_generator(
        input_dirs[1],
        input_dirs[0],
        target_size_lr=(lr_shape[0], lr_shape[1]),
        target_size_hr=(img_shape[0], img_shape[1]),
        preproc_lr=rescale_imgs_to_neg1_1,
        preproc_hr=rescale_imgs_to_neg1_1,
        validation_split=train_test_ratio,
        batch_size=batch_size)
    # The loss compares generator outputs, which have the HR shape, so build
    # it from the img_shape argument.
    loss = VGG_LOSS(img_shape)

    image_dir = os.path.join(input_dirs[1], 'ignore')
    image_files = os.listdir(image_dir)

    # Number of training batches per epoch (the validation share excluded).
    batch_count = int(
        (len(image_files) / batch_size) * (1 - train_test_ratio))

    test_image = []
    for img in sorted(image_files):
        if 'niklas_city_0009' in img:
            test_image.append(
                rescale_imgs_to_neg1_1(cv2.imread(os.path.join(image_dir,
                                                               img))))

    print("test length: ", len(test_image))

    generator = Generator(lr_shape, rescaling_factor).generator()
    discriminator = Discriminator(img_shape).discriminator()

    print('memory usage generator: ',
          get_model_memory_usage(batch_size, generator))
    print('memory usage discriminator: ',
          get_model_memory_usage(batch_size, discriminator))

    optimizer = Utils_model.get_optimizer()

    # Best-effort multi-GPU wrapping; fall back to the plain models.
    try:
        print("multi_gpu_model generator")
        par_generator = multi_gpu_model(generator, gpus=2)
    except Exception:
        par_generator = generator
        print("single_gpu_model generator")

    try:
        print("multi_gpu_model discriminator")
        par_discriminator = multi_gpu_model(discriminator, gpus=2)
    except Exception:
        par_discriminator = discriminator
        print("single_gpu_model discriminator")

    par_generator.compile(loss=loss.loss, optimizer=optimizer)
    par_discriminator.compile(loss="binary_crossentropy", optimizer=optimizer)

    gan, par_gan = get_gan_network(par_discriminator, lr_shape, par_generator,
                                   optimizer, loss.loss, batch_size)

    par_discriminator.summary()
    par_generator.summary()
    par_gan.summary()

    # The same file is truncated here and appended to every epoch.
    loss_log_path = os.path.join(model_save_dir, 'losses.txt')
    with open(loss_log_path, 'w+'):
        pass

    for e in range(1, epochs + 1):
        print('-' * 15, 'Epoch %d' % e, '-' * 15)

        if e == 100:
            # NOTE(review): in Keras, `lr` is normally a backend variable
            # already used by the compiled training function. Rebinding the
            # attribute may not change the effective rate
            # (keras.backend.set_value would) -- confirm for the Keras version
            # in use.
            optimizer.lr = 1e-5

        # Reset per epoch. An epoch without any full-size batch then logs
        # nan/None, instead of failing or repeating an earlier epoch's value.
        discriminator_loss = float('nan')
        gan_loss = None

        for _ in tqdm(range(batch_count)):
            batch = next(img_train_gen)
            image_batch_hr = batch[1]
            image_batch_lr = batch[0]
            generated_images_sr = generator.predict(image_batch_lr)

            # Noisy targets: real in (0.8, 1], fake in [0, 0.2).
            real_data_Y = np.ones(batch_size) - \
                np.random.random_sample(batch_size)*0.2
            fake_data_Y = np.random.random_sample(batch_size) * 0.2

            par_discriminator.trainable = True

            if image_batch_hr.shape[0] == batch_size and image_batch_lr.shape[
                    0] == batch_size:
                d_loss_real = par_discriminator.train_on_batch(
                    image_batch_hr, real_data_Y)
                d_loss_fake = par_discriminator.train_on_batch(
                    generated_images_sr, fake_data_Y)
                discriminator_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
            else:
                print("weird multi_gpu_model batch error dis: ")
                print("hr batch shape: ", image_batch_hr.shape)
                print("lr batch shape: ", image_batch_lr.shape)
                print("real y shape: ", real_data_Y.shape)

            batch = next(img_train_gen)
            image_batch_hr = batch[1]
            image_batch_lr = batch[0]

            gan_Y = np.ones(batch_size) - \
                np.random.random_sample(batch_size)*0.2
            # Toggle the same model object that was unfrozen above.
            par_discriminator.trainable = False

            if image_batch_hr.shape[0] == batch_size and image_batch_lr.shape[
                    0] == batch_size:
                gan_loss = par_gan.train_on_batch(image_batch_lr,
                                                  [image_batch_hr, gan_Y])
            else:
                print("weird multi_gpu_model batch error gan: ")
                print("hr batch shape: ", image_batch_hr.shape)
                print("lr batch shape: ", image_batch_lr.shape)
                print("gan y shape: ", gan_Y.shape)

        print("discriminator_loss : %f" % discriminator_loss)
        print("gan_loss :", gan_loss)
        gan_loss = str(gan_loss)

        with open(loss_log_path, 'a') as loss_file:
            loss_file.write(
                'epoch%d : gan_loss = %s ; discriminator_loss = %f\n' %
                (e, gan_loss, discriminator_loss))

        if e == 1 or e % 5 == 0:
            Utils.generate_test_image(output_dir, e, generator, test_image)
        if e % 5 == 0:
            generator.save(os.path.join(model_save_dir, 'gen_model%d.h5' % e))
            discriminator.save(
                os.path.join(model_save_dir, 'dis_model%d.h5' % e))