# Shared imports assumed by the training snippets below. Each snippet also
# relies on module-level globals defined elsewhere in the repository: the
# models (genA2B, genB2A, transA2B, disc, ...), their optimizers, the
# pyramid / up_sample / merge_image / adain helpers, the loss and
# image-saving helpers, the datasets, procname, and the save directories
# (save_dir_test, save_dir_train, model_save_dir).
import datetime
import os
import time

import numpy as np
import tensorflow as tf
from scipy import misc  # misc.imsave was removed in SciPy >= 1.2; imageio.imwrite is the usual replacement


def train(args, lsgan=True):
    
    multi_test_save_all = []

    for i in range(args.test_size):

        test_input = next(test_datasetA)

        multi_test_save_single = tf.reshape(tf.clip_by_value((test_input+1)/2, 0, 1), [args.load_img_size, args.load_img_size, 3]).numpy()

        start = time.time()

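        # Decompose the input into a Laplacian pyramid, translate the coarsest
        # low-frequency band with genA2B, and predict a high-frequency mask
        # with transA2B from the coarsest high band and the next-finer low band.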
        highs_test, lows_test = pyramid.split(test_input, args.level)
        generated_low = genA2B(lows_test[-1], training=False)
        high_with_low_test = tf.concat([highs_test[-1], lows_test[-2]], 3)
        generated_high = transA2B(high_with_low_test, training=False)

        split_list_reverse_test = []

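        # Apply the mask to every high-frequency band, coarsest to finest,
        # upsampling the mask between levels.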
        for i_ in range(1, args.level+1):

            index = 0 - i_

            high_transed_test = tf.multiply(highs_test[index], generated_high)
            split_list_reverse_test.append(high_transed_test)
            generated_high = up_sample(generated_high)

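        # The bands were collected coarsest-first; merge_image expects the
        # original finest-first ordering, so reverse the list.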
        split_list_test = []

        index = -1

        for _ in range(args.level):

            split_list_test.append(split_list_reverse_test[index])

            index = index - 1

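        # Collapse the translated pyramid back into a full-resolution image.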
        generated_full = merge_image(generated_low, split_list_test)

        cost = time.time() - start

        print(cost)
        
        multi_test_save_single = np.vstack((multi_test_save_single, tf.reshape(tf.clip_by_value((generated_full+1)/2, 0, 1), [args.load_img_size, args.load_img_size, 3]).numpy()))
        

        if i == 0:
            multi_test_save_all = multi_test_save_single
        else:
            multi_test_save_all = np.hstack((multi_test_save_all, multi_test_save_single))

    image_path = os.path.join(save_dir_test, 'iteration_{}.jpg'.format(args.start_iter))

    misc.imsave(image_path, multi_test_save_all)
def train(train_dataset, args, lsgan=True):

    if not os.path.exists(save_dir_test):
        os.mkdir(save_dir_test)

    the_time = datetime.datetime.now()
    logs = open(save_dir_test + '/' + 'train_process.txt', "a")
    logs.write('#' * 20)
    logs.write('\n')
    logs.write(str(the_time))
    logs.write('\n')
    logs.write('#' * 20)
    logs.write('\n')
    logs.write(str(args))
    logs.write('\n\n')
    logs.close()

    start = time.time()

    for iteration in range(args.start_iter, args.max_iterations):

        with tf.GradientTape() as genA2B_tape, tf.GradientTape() as transA2B_tape:
            try:
                train_full = next(train_dataset)
                trainA_full = train_full[:, :, 256:, :]
                trainB_full = train_full[:, :, :256, :]

            except tf.errors.OutOfRangeError:
                print("Error: ran out of data")
                break

            highs_A, lows_A = pyramid.split(trainA_full, args.level)
            highs_B, lows_B = pyramid.split(trainB_full, args.level)

            genA2B_output = genA2B(lows_A[-1], training=True)

            high_with_low_A = tf.concat([highs_A[-1], lows_A[-2]], 3)

            highA2B_output_o = transA2B(high_with_low_A, training=True)

            highA2B_output = highA2B_output_o

            split_list_A_reverse = []

            for i in range(1, args.level + 1):

                index = 0 - i

                high_transed_A = tf.multiply(highs_A[index], highA2B_output)
                split_list_A_reverse.append(high_transed_A)
                if i < args.level:
                    highA2B_output = up_sample(highA2B_output)

            split_list_A = []

            index = -1

            for _ in range(args.level):

                split_list_A.append(split_list_A_reverse[index])

                index = index - 1

            genA2B_output_full = merge_image(genA2B_output, split_list_A)

            # mse_loss_full = tf.reduce_mean(tf.square(tf.clip_by_value((trainB_full+1)/2, 0, 1) - tf.clip_by_value((genA2B_output_full+1)/2, 0, 1)))
            # mse_loss_low = tf.reduce_mean(tf.square(tf.clip_by_value((lows_B[-1]+1)/2, 0, 1) - tf.clip_by_value((genA2B_output+1)/2, 0, 1)))
            # mse_loss_high = tf.reduce_mean(tf.square(tf.clip_by_value((highs_B[0]+1)/2, 0, 1) - tf.clip_by_value((split_list_A[0]+1)/2, 0, 1)))

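            # Supervised MSE at three scales: the full image, the coarsest
            # low-frequency band, and the finest high-frequency band.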
            mse_loss_full = tf.reduce_mean(
                tf.square(trainB_full - genA2B_output_full))
            mse_loss_low = tf.reduce_mean(tf.square(lows_B[-1] -
                                                    genA2B_output))
            mse_loss_high = tf.reduce_mean(
                tf.square(highs_B[0] - split_list_A[0]))

            # color_loss = color_losses(genA2B_output_full, trainB_full, args.batch_size)

            # tv_loss = total_variation_loss(tf.clip_by_value((genA2B_output_full+1)/2, 0, 1))

            gen_loss = args.mse_lambda_low * mse_loss_low + args.mse_lambda_full * mse_loss_full  # + args.tv_lambda * tv_loss + args.color_lambda * color_loss#
            trans_loss = args.mse_lambda_full * mse_loss_full + args.mse_lambda_high * mse_loss_high  # + args.tv_lambda * tv_loss + args.color_lambda * color_loss#

        genA2B_gradients = genA2B_tape.gradient(gen_loss,
                                                genA2B.trainable_variables)
        transA2B_gradients = transA2B_tape.gradient(
            trans_loss, transA2B.trainable_variables)

        genA2B_optimizer.apply_gradients(
            zip(genA2B_gradients, genA2B.trainable_variables))
        transA2B_optimizer.apply_gradients(
            zip(transA2B_gradients, transA2B.trainable_variables))

        if iteration % args.log_interval == 0:

            generated = np.clip((genA2B_output_full + 1) / 2, 0, 1)
            target = np.clip((trainB_full + 1) / 2, 0, 1)

            # generated = generated[0, 100:900, 100:900, :]
            # target = target[0, 100:900, 100:900, :]

            psnr = tf.image.psnr(generated, target, max_val=1.0)

            logs_train = 'Training ' + procname + ', Iteration: {}th, time: {:.4f}, LOSSES: mse_low: {:.4f}, mse_high: {:.4f}, mse_full: {:.4f}, psnr: {:.4f}'.format(
                iteration,
                time.time() - start, mse_loss_low, mse_loss_high,
                mse_loss_full, psnr[0])

            print(logs_train)

            logs = open(save_dir_test + '/' + 'train_process.txt', "a")
            logs.write(logs_train)
            logs.write('\n')
            logs.close()

            start = time.time()

        if iteration % args.save_interval == 0:

            # generate_images(trainA_full[0], trainB_full[0], genA2B_output_full[0], genA2B_output_full[0], save_dir_train, iteration)

            if not os.path.exists(model_save_dir):
                os.mkdir(model_save_dir)

            genA2B.save_weights(model_save_dir + '/genA2B_' + str(iteration))
            transA2B.save_weights(model_save_dir + '/transA2B_' +
                                  str(iteration))

        if iteration % args.test_interval == 0:

            multi_test_save_all = []

            for i in range(args.test_size):

                test_full = next(test_datasetA)
                test_input = test_full[:, :, 256:, :]
                test_output = test_full[:, :, :256, :]

                multi_test_save_single = tf.reshape(
                    tf.clip_by_value((test_input + 1) / 2, 0, 1),
                    [args.load_img_size, args.load_img_size, 3]).numpy()

                # start_test = time.time()
                highs_test, lows_test = pyramid.split(test_input, args.level)
                generated_low = genA2B(lows_test[-1], training=False)
                high_with_low_test = tf.concat([highs_test[-1], lows_test[-2]],
                                               3)
                generated_high = transA2B(high_with_low_test, training=False)

                split_list_reverse_test = []

                for i_ in range(1, args.level + 1):

                    index = 0 - i_

                    high_transed_test = tf.multiply(highs_test[index],
                                                    generated_high)
                    split_list_reverse_test.append(high_transed_test)
                    generated_high = up_sample(generated_high)

                split_list_test = []

                index = -1

                for _ in range(args.level):

                    split_list_test.append(split_list_reverse_test[index])

                    index = index - 1

                generated_full = merge_image(generated_low, split_list_test)

                multi_test_save_single = np.vstack(
                    (multi_test_save_single,
                     tf.reshape(
                         tf.clip_by_value((generated_full + 1) / 2, 0, 1),
                         [args.load_img_size, args.load_img_size, 3]).numpy()))
                multi_test_save_single = np.vstack(
                    (multi_test_save_single,
                     tf.reshape(
                         tf.clip_by_value((test_output + 1) / 2, 0, 1),
                         [args.load_img_size, args.load_img_size, 3]).numpy()))

                if i == 0:
                    multi_test_save_all = multi_test_save_single
                else:
                    multi_test_save_all = np.hstack(
                        (multi_test_save_all, multi_test_save_single))

            image_path = os.path.join(save_dir_test,
                                      'iteration_{}.jpg'.format(iteration))

            misc.imsave(image_path, multi_test_save_all)

        if iteration % args.test_psnr_interval == 0 and iteration != 0:

            psnr_all = []
            cost_all = []

            print('******************************')
            print('testing on all test images...')

            for i in range(test_all_size):

                test_full = next(test_datasetA)
                test_input = test_full[:, :, 256:, :]
                test_output = test_full[:, :, :256, :]

                start = time.time()

                highs_test, lows_test = pyramid.split(test_input, args.level)
                generated_low = genA2B(lows_test[-1], training=False)
                high_with_low_test = tf.concat([highs_test[-1], lows_test[-2]],
                                               3)
                generated_high = transA2B(high_with_low_test, training=False)

                split_list_reverse_test = []

                for i_ in range(1, args.level + 1):

                    index = 0 - i_

                    high_transed_test = tf.multiply(highs_test[index],
                                                    generated_high)
                    split_list_reverse_test.append(high_transed_test)
                    generated_high = up_sample(generated_high)

                split_list_test = []

                index = -1

                for _ in range(args.level):

                    split_list_test.append(split_list_reverse_test[index])

                    index = index - 1

                generated_full = merge_image(generated_low, split_list_test)

                cost = time.time() - start

                generated = np.clip((generated_full + 1) / 2, 0, 1)
                target = np.clip((test_output + 1) / 2, 0, 1)

                # generated = generated[0, 100:900, 100:900, :]
                # target = target[0, 100:900, 100:900, :]

                psnr_test = tf.image.psnr(generated, target, max_val=1.0)

                psnr_all.append(psnr_test[0].numpy())
                cost_all.append(cost)

            logs_test = 'test iteration: {}th, mean psnr: {:.4f}, avg inference time: {:.4f}'.format(
                iteration, np.mean(psnr_all), np.mean(cost_all))
            print(logs_test)
            print(
                '****************************************************************************'
            )
            logs = open(save_dir_test + '/' + 'train_process.txt', "a")
            logs.write(logs_test)
            logs.write('\n')
            logs.close()
def train(train_dataset, args, lsgan=True):
    

    start = time.time()

    for iteration in range(args.start_iter, args.max_iterations):

        with tf.GradientTape() as genA2B_tape, tf.GradientTape() as transA2B_tape:
            try:
                trainA_full, trainB_full = next(train_dataset)

            except tf.errors.OutOfRangeError:
                print("Error: ran out of data")
                break

            highs_A, lows_A = pyramid.split(trainA_full, args.level)
            highs_B, lows_B = pyramid.split(trainB_full, args.level)

            # print('max pixel value_input: ', trainA.numpy().mean())
            # print('min pixel value_input: ', trainA.numpy().std())
            # print('max pixel value_target: ', trainB.numpy().mean())
            # print('min pixel value_target: ', trainB.numpy().std())
            # print('*********************************')
            
            genA2B_output = genA2B(lows_A[-1], training=True)

            high_with_low_A = tf.concat([highs_A[-1], lows_A[-2]], 3)

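            # This variant predicts one mask per pyramid level in a single
            # forward pass instead of upsampling a single coarse mask.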
            mask_0, mask_1, mask_2, mask_3 = transA2B(high_with_low_A, training=True)

            split_list_A = list(highs_A)  # shallow copy so the masking below does not mutate highs_A

            split_list_A[-1] = tf.multiply(highs_A[-1], mask_0)
            split_list_A[-2] = tf.multiply(highs_A[-2], mask_1)
            split_list_A[-3] = tf.multiply(highs_A[-3], mask_2)
            split_list_A[-4] = tf.multiply(highs_A[-4], mask_3)

            genA2B_output_full = merge_image(genA2B_output, split_list_A)

            mse_loss_full = tf.reduce_mean(tf.square(trainB_full - genA2B_output_full))
            mse_loss_low = tf.reduce_mean(tf.square(lows_B[-1] - genA2B_output))

            # color_loss = color_losses(genA2B_output_full, trainB_full, args.batch_size)

            # tv_loss = total_variation_loss(genA2B_output_full)

            genA2B_loss = args.mse_lambda_low * mse_loss_low + args.mse_lambda_full * mse_loss_full# + color_loss# + args.tv_lambda * tv_loss
            
        genA2B_gradients = genA2B_tape.gradient(genA2B_loss, genA2B.trainable_variables)
        transA2B_gradients = transA2B_tape.gradient(genA2B_loss, transA2B.trainable_variables)

        genA2B_optimizer.apply_gradients(zip(genA2B_gradients, genA2B.trainable_variables))
        transA2B_optimizer.apply_gradients(zip(transA2B_gradients, transA2B.trainable_variables))

        if iteration % args.log_interval == 0:

            generated = np.clip((genA2B_output_full+1)/2, 0, 1)
            target = np.clip((trainB_full+1)/2, 0, 1)

            psnr = tf.image.psnr(generated, target, max_val=1.0)

            print('Training ' + procname + ', Iteration: {}th, time: {:.4f}, LOSSES: mse_low: {:.4f}, mse_full: {:.4f}, psnr: {:.4f}'.format(
                iteration, time.time() - start, mse_loss_low, mse_loss_full, psnr[0]))
            start = time.time()

        if iteration % args.save_interval == 0:

            generate_images(trainA_full, trainB_full, genA2B_output_full, genA2B_output_full, save_dir_train, iteration)

            if not os.path.exists(model_save_dir):
                os.mkdir(model_save_dir)

            genA2B.save_weights(model_save_dir + '/genA2B_' + str(iteration))
            transA2B.save_weights(model_save_dir + '/transA2B_' + str(iteration))
        
        if iteration % args.test_interval == 0:

            input_images = []
            # target_images = []
            output_images = []

            for i in range(args.test_size):

                test_input, test_output = next(test_dataset)

                # start_test = time.time()
                highs_test, lows_test = pyramid.split(test_input, args.level)
                generated_low = genA2B(lows_test[-1], training=False)
                high_with_low_test = tf.concat([highs_test[-1], lows_test[-2]], 3)
                mask_0, mask_1, mask_2, mask_3 = transA2B(high_with_low_test, training=False)

                split_list_test = list(highs_test)  # copy to avoid mutating highs_test

                split_list_test[-1] = tf.multiply(highs_test[-1], mask_0)
                split_list_test[-2] = tf.multiply(highs_test[-2], mask_1)
                split_list_test[-3] = tf.multiply(highs_test[-3], mask_2)
                split_list_test[-4] = tf.multiply(highs_test[-4], mask_3)

                generated_full = merge_image(generated_low, split_list_test)

                # test_time = time.time() - start_test

                # if iteration % 100 == 0:

                #     generated = np.clip((generated_full+1)/2, 0, 1)
                #     target = np.clip((test_output+1)/2, 0, 1)

                #     psnr_test = tf.image.psnr(generated, target, max_val=1.0)
                #     print('test time: {:.4f}, psnr: {:.4f}'.format(test_time, psnr_test[0]))

                input_images.append(test_input)
                output_images.append(generated_full)
                
                # image_shape = np.shape(generated_full)
                
                # saved_image = np.zeros([image_shape[1], image_shape[2]*3, image_shape[3]])
                # saved_image[:, image_shape[1]*0 : image_shape[1]*1, :] = (test_input[0, :, :, :])# * 255.0
                # saved_image[:, image_shape[1]*1 : image_shape[1]*2, :] = (generated_full[0, :, :, :])# * 255.0
                # saved_image[:, image_shape[1]*2 : image_shape[1]*3, :] = (test_output[0, :, :, :])# * 255.0
                # # print(saved_image.shape)

                # if not os.path.exists(args.save_dir_test):
                #     os.mkdir(args.save_dir_test)

                # save_path = args.save_dir_test + '/generated_{}_{}.png'.format(iteration, i)

                # skimage.io.imsave(save_path, saved_image)

            generate_images(input_images[0], input_images[1], output_images[0], output_images[1], save_dir_test, iteration)

        if (iteration + 1) % args.test_psnr_interval == 0:

            psnr_all = []
            cost_all = []

            print('******************************')
            print('testing on all test images...')

            for i in range(test_size):

                test_input, test_output = next(test_dataset)

                start = time.time()

                highs_test, lows_test = pyramid.split(test_input, args.level)
                generated_low = genA2B(lows_test[-1], training=False)
                high_with_low_test = tf.concat([highs_test[-1], lows_test[-2]], 3)
                mask_0, mask_1, mask_2, mask_3 = transA2B(high_with_low_test, training=False)

                split_list_test = list(highs_test)  # copy to avoid mutating highs_test

                split_list_test[-1] = tf.multiply(highs_test[-1], mask_0)
                split_list_test[-2] = tf.multiply(highs_test[-2], mask_1)
                split_list_test[-3] = tf.multiply(highs_test[-3], mask_2)
                split_list_test[-4] = tf.multiply(highs_test[-4], mask_3)

                generated_full = merge_image(generated_low, split_list_test)

                cost = time.time() - start

                generated = np.clip((generated_full+1)/2, 0, 1)
                target = np.clip((test_output+1)/2, 0, 1)

                psnr_test = tf.image.psnr(generated, target, max_val=1.0)

                psnr_all.append(psnr_test[0].numpy())
                cost_all.append(cost)

            print('test iteration: {}th, mean psnr: {:.4f}, mean inference time: {:.4f}'.format(iteration, np.mean(psnr_all), np.mean(cost_all)))
            print('****************************************************************************')
def train(train_datasetA, train_datasetB, args):

    if not os.path.exists(save_dir_test):
        os.mkdir(save_dir_test)
    the_time = datetime.datetime.now()
    logs = open(save_dir_test + '/' + 'train_process.txt', "a")
    logs.write('#' * 20)
    logs.write('\n')
    logs.write(str(the_time))
    logs.write('\n')
    logs.write('#' * 20)
    logs.write('\n')
    logs.write(str(args))
    logs.write('\n\n')
    logs.close()

    start = time.time()

    for iteration in range(args.start_iter, args.max_iterations):

        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape, tf.GradientTape() as trans_tape:
            try:
                trainA_full = next(train_datasetA)
                trainB_full = next(train_datasetB)
            except tf.errors.OutOfRangeError:
                print("Error: ran out of data")
                break

            start_time = time.time()
            highs_A, lows_A = pyramid.split(trainA_full, args.level)
            highs_B, lows_B = pyramid.split(trainB_full, args.level)

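            # Transfer the channel statistics of B's coarsest band onto A's
            # via AdaIN, then rebuild the full image (saved to '2.jpg' below
            # for inspection).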
            lows_A_transed = adain(lows_A[-1], lows_B[-1])

            A_adain = merge_image(lows_A_transed, highs_A)

            print('cost: ', time.time() - start_time)

            multi_test_save_single = tf.reshape(
                tf.clip_by_value((trainA_full + 1) / 2, 0, 1),
                [args.load_img_size, args.load_img_size, 3]).numpy()
            multi_test_save_single = np.hstack(
                (multi_test_save_single,
                 tf.reshape(tf.clip_by_value(
                     (A_adain + 1) / 2, 0,
                     1), [args.load_img_size, args.load_img_size, 3]).numpy()))
            multi_test_save_single = np.hstack(
                (multi_test_save_single,
                 tf.reshape(tf.clip_by_value(
                     (trainB_full + 1) / 2, 0,
                     1), [args.load_img_size, args.load_img_size, 3]).numpy()))
            misc.imsave('2.jpg', multi_test_save_single)

            genA2B_output = gen(lows_A[-1], training=True)

            high_with_low_A = tf.concat(
                [highs_A[-1], up_sample(genA2B_output)], 3)
            high_with_low_A = tf.concat(
                [high_with_low_A, up_sample(lows_A[-1])], 3)

            mask_0, mask_1, mask_2, mask_3 = trans(high_with_low_A,
                                                   training=True)

            split_list_A = list(highs_A)  # shallow copy so the residual masking below does not mutate highs_A

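            # Residual masking: scale each high band by its mask and add the
            # band back, i.e. multiply by (1 + mask).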
            split_list_A[-1] = tf.multiply(highs_A[-1], mask_0) + highs_A[-1]
            split_list_A[-2] = tf.multiply(highs_A[-2], mask_1) + highs_A[-2]
            split_list_A[-3] = tf.multiply(highs_A[-3], mask_2) + highs_A[-3]
            split_list_A[-4] = tf.multiply(highs_A[-4], mask_3) + highs_A[-4]

            genA2B_output_full = merge_image(genA2B_output, split_list_A)

            disc_real_output = disc(trainB_full)

            disc_fake_output = disc(genA2B_output_full)

            reconstruction_loss = tf.reduce_mean(
                tf.square(genA2B_output_full - trainA_full))
            # color_loss = color_losses(genA2B_output_full, trainB_full, args.batch_size)
            # tv_loss = total_variation_loss(genA2B_output_full)

            disc_loss = discriminator_loss_cal(disc_real_output,
                                               disc_fake_output)
            generator_loss = generator_loss_cal(disc_fake_output)

            # disc_loss = discriminator_losssss('lsgan', disc_real_output, disc_fake_output)
            # generator_loss = generator_losssss('lsgan', disc_fake_output)

            gen_loss = args.dis_lambda * generator_loss + args.recons_lambda * reconstruction_loss  #+ args.color_lambda * color_loss + args.tv_lambda * tv_loss

        gen_gradients = gen_tape.gradient(gen_loss, gen.trainable_variables)
        trans_gradients = trans_tape.gradient(gen_loss,
                                              trans.trainable_variables)
        disc_gradients = disc_tape.gradient(disc_loss,
                                            disc.trainable_variables)
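        # Clip the discriminator gradients to a global norm of 15 to keep
        # adversarial updates stable.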
        disc_gradients, _ = tf.clip_by_global_norm(disc_gradients, 15)

        gen_optimizer.apply_gradients(
            zip(gen_gradients, gen.trainable_variables))
        trans_optimizer.apply_gradients(
            zip(trans_gradients, trans.trainable_variables))
        disc_optimizer.apply_gradients(
            zip(disc_gradients, disc.trainable_variables))

        if iteration % args.log_interval == 0:

            logs_train = 'Training ' + procname + ', Iteration: {}th, Duration: {:.4f}, LOSSES: recons: {:.4f}, generator: {:.4f}, disc: {:.4f}'.format(
                iteration,
                time.time() - start, reconstruction_loss,
                generator_loss.numpy(), disc_loss.numpy())

            print(logs_train)

            logs = open(save_dir_test + '/' + 'train_process.txt', "a")
            logs.write(logs_train)
            logs.write('\n')
            logs.close()

            start = time.time()

        if iteration % args.save_interval == 0:
            # generate_images(trainA_full, trainA_full, genA2B_output_full, highA_output, save_dir_train, iteration)

            if not os.path.exists(model_save_dir):
                os.mkdir(model_save_dir)
            # checkpoint.save(file_prefix = checkpoint_prefix)
            disc.save_weights(model_save_dir + '/disc_' + str(iteration))
            gen.save_weights(model_save_dir + '/gen_' + str(iteration))
            trans.save_weights(model_save_dir + '/trans_' + str(iteration))

        if iteration % args.test_interval == 0:

            multi_test_save_all = []

            for i in range(args.test_size):

                test_input = next(test_datasetA)

                multi_test_save_single = tf.reshape(
                    tf.clip_by_value((test_input + 1) / 2, 0, 1),
                    [args.load_img_size, args.load_img_size, 3]).numpy()

                # start_test = time.time()
                highs_test, lows_test = pyramid.split(test_input, args.level)
                generated_low = gen(lows_test[-1], training=False)
                high_with_low_test = tf.concat(
                    [highs_test[-1], up_sample(generated_low)], 3)
                high_with_low_test = tf.concat(
                    [high_with_low_test,
                     up_sample(lows_test[-1])], 3)
                mask_0, mask_1, mask_2, mask_3 = trans(high_with_low_test,
                                                       training=False)

                split_list_test = list(highs_test)  # copy to avoid mutating highs_test

                split_list_test[-1] = tf.multiply(highs_test[-1],
                                                  mask_0) + highs_test[-1]
                split_list_test[-2] = tf.multiply(highs_test[-2],
                                                  mask_1) + highs_test[-2]
                split_list_test[-3] = tf.multiply(highs_test[-3],
                                                  mask_2) + highs_test[-3]
                split_list_test[-4] = tf.multiply(highs_test[-4],
                                                  mask_3) + highs_test[-4]

                generated_full = merge_image(generated_low, split_list_test)

                multi_test_save_single = np.vstack(
                    (multi_test_save_single,
                     tf.reshape(
                         tf.clip_by_value((generated_full + 1) / 2, 0, 1),
                         [args.load_img_size, args.load_img_size, 3]).numpy()))

                if i == 0:
                    multi_test_save_all = multi_test_save_single
                else:
                    multi_test_save_all = np.hstack(
                        (multi_test_save_all, multi_test_save_single))

            image_path = os.path.join(save_dir_test,
                                      'iteration_{}.jpg'.format(iteration))

            misc.imsave(image_path, multi_test_save_all)
def train(train_datasetA, train_datasetB, args):

    if not os.path.exists(save_dir_test):
        os.mkdir(save_dir_test)
    the_time = datetime.datetime.now()
    logs = open(save_dir_test + '/' + 'train_process.txt', "a")
    logs.write('#' * 20)
    logs.write('\n')
    logs.write(str(the_time))
    logs.write('\n')
    logs.write('#' * 20)
    logs.write('\n')
    logs.write(str(args))
    logs.write('\n\n')
    logs.close()

    start = time.time()

    for iteration in range(args.start_iter, args.max_iterations):

        with tf.GradientTape() as genA2B_tape, tf.GradientTape() as genB2A_tape, \
                tf.GradientTape() as discA_tape, tf.GradientTape() as discB_tape, tf.GradientTape() as transA2B_tape, tf.GradientTape() as transB2A_tape:
            try:
                trainA_full = next(train_datasetA)
                trainB_full = next(train_datasetB)
            except tf.errors.OutOfRangeError:
                print("Error: ran out of data")
                break

            highs_A, lows_A = pyramid.split(trainA_full, args.level)
            highs_B, lows_B = pyramid.split(trainB_full, args.level)

            genA2B_output = genA2B(lows_A[-1], training=True)
            genB2A_output = genB2A(lows_B[-1], training=True)

            high_with_low_A = tf.concat([highs_A[-1], lows_A[-2]], 3)
            high_with_low_B = tf.concat([highs_B[-1], lows_B[-2]], 3)

            # kernel = pyramid._binomial_kernel(tf.shape(input=lows_A[-1])[3], dtype=lows_A[-1].dtype)

            # low_up_before_A = pyramid._upsample(lows_A[-1], kernel)
            # low_up_after_A = pyramid._upsample(genA2B_output, kernel)
            # low_up_before_B = pyramid._upsample(lows_B[-1], kernel)
            # low_up_after_B = pyramid._upsample(genB2A_output, kernel)

            # high_with_low_A = tf.concat([highs_A[-1], low_up_before_A], 3)
            # high_with_low_B = tf.concat([highs_B[-1], low_up_before_B], 3)

            # high_with_low_A = tf.concat([high_with_low_A, low_up_after_A], 3)
            # high_with_low_B = tf.concat([high_with_low_B, low_up_after_B], 3)

            highA2B_output_o = transA2B(high_with_low_A, training=True)
            highB2A_output_o = transB2A(high_with_low_B, training=True)

            highA2B_output = highA2B_output_o
            highB2A_output = highB2A_output_o

            split_list_A_reverse = []
            split_list_B_reverse = []

            for i in range(1, args.level + 1):

                index = 0 - i

                high_transed_A = tf.multiply(highs_A[index], highA2B_output)
                split_list_A_reverse.append(high_transed_A)
                highA2B_output = up_sample(highA2B_output)

                high_transed_B = tf.multiply(highs_B[index], highB2A_output)
                split_list_B_reverse.append(high_transed_B)
                highB2A_output = up_sample(highB2A_output)

            split_list_A = []
            split_list_B = []

            index = -1

            for _ in range(args.level):

                split_list_A.append(split_list_A_reverse[index])
                split_list_B.append(split_list_B_reverse[index])

                index = index - 1

            genA2B_output_full = merge_image(genA2B_output, split_list_A)
            genB2A_output_full = merge_image(genB2A_output, split_list_B)

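            # Map outputs and targets from [-1, 1] to [0, 1] before the
            # discriminators and the color/cycle losses.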
            genA2B_output_full = tf.clip_by_value((genA2B_output_full + 1) / 2,
                                                  0, 1)
            genB2A_output_full = tf.clip_by_value((genB2A_output_full + 1) / 2,
                                                  0, 1)
            trainA_full = tf.clip_by_value((trainA_full + 1) / 2, 0, 1)
            trainB_full = tf.clip_by_value((trainB_full + 1) / 2, 0, 1)

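            # Cycle branch: translate the coarse outputs back through the
            # opposite generators and rebuild full-resolution reconstructions
            # for the cycle-consistency loss.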
            reconstructed_A_low = genB2A(genA2B_output, training=True)
            reconstructed_B_low = genA2B(genB2A_output, training=True)

            #################

            high_with_low_A_r = tf.concat([highA2B_output_o, lows_A[-2]], 3)
            high_with_low_B_r = tf.concat([highB2A_output_o, lows_B[-2]], 3)

            highA2B_output_r = transB2A(high_with_low_A_r, training=True)
            highB2A_output_r = transA2B(high_with_low_B_r, training=True)

            split_list_A_reverse = []
            split_list_B_reverse = []

            for i in range(1, args.level + 1):

                index = 0 - i

                high_transed_A = tf.multiply(highs_A[index], highA2B_output_r)
                split_list_A_reverse.append(high_transed_A)
                highA2B_output_r = up_sample(highA2B_output_r)

                high_transed_B = tf.multiply(highs_B[index], highB2A_output_r)
                split_list_B_reverse.append(high_transed_B)
                highB2A_output_r = up_sample(highB2A_output_r)

            split_list_A = []
            split_list_B = []

            index = -1

            for _ in range(args.level):

                split_list_A.append(split_list_A_reverse[index])
                split_list_B.append(split_list_B_reverse[index])

                index = index - 1

            #################

            reconstructed_A_full = merge_image(reconstructed_A_low,
                                               split_list_A)
            reconstructed_B_full = merge_image(reconstructed_B_low,
                                               split_list_B)

            reconstructed_A_full = tf.clip_by_value(
                (reconstructed_A_full + 1) / 2, 0, 1)
            reconstructed_B_full = tf.clip_by_value(
                (reconstructed_B_full + 1) / 2, 0, 1)

            discA_real_output = discA(trainA_full)
            discB_real_output = discB(trainB_full)

            discA_fake_output = discA(genB2A_output_full)
            discB_fake_output = discB(genA2B_output_full)

            discA_loss = discriminator_losssss('lsgan', discA_real_output,
                                               discA_fake_output)
            discB_loss = discriminator_losssss('lsgan', discB_real_output,
                                               discB_fake_output)

            color_loss_A2B = color_losses(genA2B_output_full, trainB_full,
                                          args.batch_size)
            color_loss_B2A = color_losses(genB2A_output_full, trainA_full,
                                          args.batch_size)

            # content_loss_A2B = content_losses(genA2B_output_full, trainA_full, args.batch_size)
            # content_loss_B2A = content_losses(genB2A_output_full, trainB_full, args.batch_size)

            # gp_A = gradient_penalty(trainA_full, genB2A_output_full, discA)
            # gp_B = gradient_penalty(trainB_full, genA2B_output_full, discB)

            generatorA2B_loss = generator_losssss('lsgan', discB_fake_output)
            cycleA2B_loss = cycle_consistency_loss(trainA_full, trainB_full,
                                                   reconstructed_A_full,
                                                   reconstructed_B_full)

            generatorB2A_loss = generator_losssss('lsgan', discA_fake_output)
            cycleB2A_loss = cycle_consistency_loss(trainA_full, trainB_full,
                                                   reconstructed_A_full,
                                                   reconstructed_B_full)

            genA2B_loss = args.dis_lambda * generatorA2B_loss + args.cyc_lambda * cycleA2B_loss + args.color_lambda * color_loss_A2B

            genB2A_loss = args.dis_lambda * generatorB2A_loss + args.cyc_lambda * cycleB2A_loss + args.color_lambda * color_loss_B2A

        genA2B_gradients = genA2B_tape.gradient(genA2B_loss,
                                                genA2B.trainable_variables)
        genB2A_gradients = genB2A_tape.gradient(genB2A_loss,
                                                genB2A.trainable_variables)
        transA2B_gradients = transA2B_tape.gradient(
            genA2B_loss, transA2B.trainable_variables)
        transB2A_gradients = transB2A_tape.gradient(
            genB2A_loss, transB2A.trainable_variables)

        discA_gradients = discA_tape.gradient(discA_loss,
                                              discA.trainable_variables)
        discA_gradients, _ = tf.clip_by_global_norm(discA_gradients, 15)
        discB_gradients = discB_tape.gradient(discB_loss,
                                              discB.trainable_variables)
        discB_gradients, _ = tf.clip_by_global_norm(discB_gradients, 15)

        genA2B_optimizer.apply_gradients(
            zip(genA2B_gradients, genA2B.trainable_variables))
        genB2A_optimizer.apply_gradients(
            zip(genB2A_gradients, genB2A.trainable_variables))
        transA2B_optimizer.apply_gradients(
            zip(transA2B_gradients, transA2B.trainable_variables))
        transB2A_optimizer.apply_gradients(
            zip(transB2A_gradients, transB2A.trainable_variables))

        discA_optimizer.apply_gradients(
            zip(discA_gradients, discA.trainable_variables))
        discB_optimizer.apply_gradients(
            zip(discB_gradients, discB.trainable_variables))

        if iteration % args.log_interval == 0:

            logs_train = 'Training ' + procname + ', Iteration: {}th, Duration: {:.4f}, LOSSES: cycle: {:.4f}, generator: {:.4f}, disc: {:.4f}, color: {:.4f}'.format(
                iteration,
                time.time() - start, cycleA2B_loss, generatorA2B_loss.numpy(),
                discA_loss.numpy(), color_loss_A2B)

            print(logs_train)

            logs = open(save_dir_test + '/' + 'train_process.txt', "a")
            logs.write(logs_train)
            logs.write('\n')
            logs.close()

            start = time.time()

        if iteration % args.save_interval == 0:
            # generate_images(trainA_full, trainB_full, genA2B_output_full, genB2A_output_full, save_dir_train, iteration)

            # print('Time taken for iteration {} is {} sec'.format(iteration + 1, time.time() - start))

            if not os.path.exists(model_save_dir):
                os.mkdir(model_save_dir)
            # checkpoint.save(file_prefix = checkpoint_prefix)
            discA.save_weights(model_save_dir + '/discA_' + str(iteration))
            discB.save_weights(model_save_dir + '/discB_' + str(iteration))
            genA2B.save_weights(model_save_dir + '/genA2B_' + str(iteration))
            genB2A.save_weights(model_save_dir + '/genB2A_' + str(iteration))
            transA2B.save_weights(model_save_dir + '/transA2B_' +
                                  str(iteration))
            transB2A.save_weights(model_save_dir + '/transB2A_' +
                                  str(iteration))

        if iteration % args.test_interval == 0:

            multi_test_save_all = []

            for i in range(args.test_size):

                test_input, test_output = next(test_dataset)

                multi_test_save_single = tf.reshape(
                    tf.clip_by_value((test_input + 1) / 2, 0, 1),
                    [args.load_img_size, args.load_img_size, 3]).numpy()

                highs_test, lows_test = pyramid.split(test_input, args.level)
                generated_low = genA2B(lows_test[-1], training=False)
                high_with_low_test = tf.concat([highs_test[-1], lows_test[-2]],
                                               3)
                generated_high = transA2B(high_with_low_test, training=False)

                split_list_reverse_test = []

                for i_ in range(1, args.level + 1):

                    index = 0 - i_

                    high_transed_test = tf.multiply(highs_test[index],
                                                    generated_high)
                    split_list_reverse_test.append(high_transed_test)
                    generated_high = up_sample(generated_high)

                split_list_test = []

                index = -1

                for _ in range(args.level):

                    split_list_test.append(split_list_reverse_test[index])

                    index = index - 1

                generated_full = merge_image(generated_low, split_list_test)

                multi_test_save_single = np.vstack(
                    (multi_test_save_single,
                     tf.reshape(
                         tf.clip_by_value((generated_full + 1) / 2, 0, 1),
                         [args.load_img_size, args.load_img_size, 3]).numpy()))
                multi_test_save_single = np.vstack(
                    (multi_test_save_single,
                     tf.reshape(
                         tf.clip_by_value((test_output + 1) / 2, 0, 1),
                         [args.load_img_size, args.load_img_size, 3]).numpy()))

                if i == 0:
                    multi_test_save_all = multi_test_save_single
                else:
                    multi_test_save_all = np.hstack(
                        (multi_test_save_all, multi_test_save_single))

            image_path = os.path.join(save_dir_test,
                                      'iteration_{}.jpg'.format(iteration))

            misc.imsave(image_path, multi_test_save_all)

        if iteration % args.test_psnr_interval == 0:

            psnr_all = []
            cost_all = []

            print('******************************')
            print('testing on all test images...')

            for i in range(test_all_size):

                test_input, test_output = next(test_dataset)

                start = time.time()

                highs_test, lows_test = pyramid.split(test_input, args.level)
                generated_low = genA2B(lows_test[-1], training=False)
                high_with_low_test = tf.concat([highs_test[-1], lows_test[-2]],
                                               3)
                generated_high = transA2B(high_with_low_test, training=False)

                split_list_reverse_test = []

                for i_ in range(1, args.level + 1):

                    index = 0 - i_

                    high_transed_test = tf.multiply(highs_test[index],
                                                    generated_high)
                    split_list_reverse_test.append(high_transed_test)
                    generated_high = up_sample(generated_high)

                split_list_test = []

                index = -1

                for _ in range(args.level):

                    split_list_test.append(split_list_reverse_test[index])

                    index = index - 1

                generated_full = merge_image(generated_low, split_list_test)

                cost = time.time() - start

                generated = tf.clip_by_value((generated_full + 1) / 2, 0, 1)
                target = tf.clip_by_value((test_output + 1) / 2, 0, 1)

                # generated = generated[0, 100:900, 100:900, :]
                # target = target[0, 100:900, 100:900, :]

                psnr_test = tf.image.psnr(generated, target, max_val=1.0)

                psnr_all.append(psnr_test[0].numpy())
                cost_all.append(cost)

            logs_test = 'test iteration: {}th, mean psnr: {:.4f}, avg inference time: {:.4f}'.format(
                iteration, np.mean(psnr_all), np.mean(cost_all))
            print(logs_test)
            print(
                '****************************************************************************'
            )
            logs = open(save_dir_test + '/' + 'train_process.txt', "a")
            logs.write(logs_test)
            logs.write('\n')
            logs.close()
def train(train_datasetA, train_datasetB, args):

    if not os.path.exists(save_dir_test):
        os.mkdir(save_dir_test)
    the_time = datetime.datetime.now()
    logs = open(save_dir_test + '/' + 'train_process.txt', "a")
    logs.write('#' * 20)
    logs.write('\n')
    logs.write(str(the_time))
    logs.write('\n')
    logs.write('#' * 20)
    logs.write('\n')
    logs.write(str(args))
    logs.write('\n\n')
    logs.close()
    
    start = time.time()

    for iteration in range(args.start_iter, args.max_iterations):

        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape, tf.GradientTape() as trans_tape:
            try:
                trainA_full = next(train_datasetA)
                trainB_full = next(train_datasetB)
            except tf.errors.OutOfRangeError:
                print("Error: ran out of data")
                break

            highs_A, lows_A = pyramid.split(trainA_full, args.level)
            
            genA2B_output = gen(lows_A[-1], training=True)

            highA_output = trans(highs_A, training=True)

            genA2B_output_full = merge_image(genA2B_output, highA_output)

            # genA2B_output_full = tf.clip_by_value((genA2B_output_full+1)/2, 0, 1)
            # trainA_full = tf.clip_by_value((trainA_full+1)/2, 0, 1)
            # trainB_full = tf.clip_by_value((trainB_full+1)/2, 0, 1)

            disc_real_output = disc(trainB_full)

            disc_fake_output = disc(genA2B_output_full)

            disc_loss = discriminator_loss_cal(disc_real_output, disc_fake_output)
            generator_loss = generator_loss_cal(disc_fake_output)

            # generator_loss = generator_losssss('lsgan', disc_fake_output)
            # disc_loss = discriminator_losssss('lsgan', disc_real_output, disc_fake_output)

            # color_loss = color_losses(genA2B_output_full, trainB_full, args.batch_size)

            reconstruction_loss = tf.reduce_mean(tf.square(genA2B_output_full - trainA_full))
            # tv_loss = total_variation_loss(genA2B_output_full)
            
            gen_loss = args.dis_lambda * generator_loss + args.recons_lambda * reconstruction_loss #+ args.color_lambda * color_loss + args.tv_lambda * tv_loss

        gen_gradients = gen_tape.gradient(gen_loss, gen.trainable_variables)
        trans_gradients = trans_tape.gradient(gen_loss, trans.trainable_variables)
        disc_gradients = disc_tape.gradient(disc_loss, disc.trainable_variables)
        disc_gradients, _ = tf.clip_by_global_norm(disc_gradients, 15)

        gen_optimizer.apply_gradients(zip(gen_gradients, gen.trainable_variables))
        trans_optimizer.apply_gradients(zip(trans_gradients, trans.trainable_variables))
        disc_optimizer.apply_gradients(zip(disc_gradients, disc.trainable_variables))

        if iteration % args.log_interval == 0:

            logs_train = 'Training ' + procname + ', Iteration: {}th, Duration: {:.4f}, LOSSES: recons: {:.4f}, generator: {:.4f}, disc: {:.4f}'.format(
                iteration, time.time() - start, reconstruction_loss, generator_loss.numpy(), disc_loss.numpy())

            print(logs_train)

            logs = open(save_dir_test + '/' + 'train_process.txt', "a")
            logs.write(logs_train)
            logs.write('\n')
            logs.close()

            start = time.time()

        if iteration % args.save_interval == 0:
            # generate_images(trainA_full, trainA_full, genA2B_output_full, highA_output, save_dir_train, iteration)

            if not os.path.exists(model_save_dir):
                os.mkdir(model_save_dir)
            # checkpoint.save(file_prefix = checkpoint_prefix)
            disc.save_weights(model_save_dir + '/disc_' + str(iteration))
            gen.save_weights(model_save_dir + '/gen_' + str(iteration))
            trans.save_weights(model_save_dir + '/trans_' + str(iteration))
        
        if iteration % args.test_interval == 0:

            multi_test_save_all = []

            for i in range(args.test_size):

                test_input, test_output = next(test_dataset)

                multi_test_save_single = tf.reshape(tf.clip_by_value((test_input+1)/2, 0, 1), [args.load_img_size, args.load_img_size, 3]).numpy()

                # start_test = time.time()
                highs_test, lows_test = pyramid.split(test_input, args.level)
                generated_low = gen(lows_test[-1], training=False)
                generated_high = trans(highs_test, training=False)

                generated_full = merge_image(generated_low, generated_high)
                
                multi_test_save_single = np.vstack((multi_test_save_single, tf.reshape(tf.clip_by_value((generated_full+1)/2, 0, 1), [args.load_img_size, args.load_img_size, 3]).numpy()))
                multi_test_save_single = np.vstack((multi_test_save_single, tf.reshape(tf.clip_by_value((test_output+1)/2, 0, 1), [args.load_img_size, args.load_img_size, 3]).numpy()))

                if i == 0:
                    multi_test_save_all = multi_test_save_single
                else:
                    multi_test_save_all = np.hstack((multi_test_save_all, multi_test_save_single))

            image_path = os.path.join(save_dir_test, 'iteration_{}.jpg'.format(iteration))

            misc.imsave(image_path, multi_test_save_all)

        if iteration % args.test_psnr_interval == 0:

            psnr_all = []
            cost_all = []

            print('******************************')
            print('testing on all test images...')

            for i in range(test_all_size):

                test_input, test_output = next(test_dataset)

                start = time.time()

                highs_test, lows_test = pyramid.split(test_input, args.level)
                generated_low = gen(lows_test[-1], training=False)
                generated_high = trans(highs_test, training=False)

                
                generated_full = merge_image(generated_low, generated_high)

                cost = time.time() - start

                generated = tf.clip_by_value((generated_full+1)/2, 0, 1)
                target = tf.clip_by_value((test_output+1)/2, 0, 1)

                # generated = generated[0, 100:900, 100:900, :]
                # target = target[0, 100:900, 100:900, :]

                psnr_test = tf.image.psnr(generated, target, max_val=1.0)

                psnr_all.append(psnr_test[0].numpy())
                cost_all.append(cost)

            logs_test = 'test iteration: {}th, mean psnr: {:.4f}, avg inference time: {:.4f}'.format(iteration, np.mean(psnr_all), np.mean(cost_all))
            print(logs_test)
            print('****************************************************************************')
            logs = open(save_dir_test + '/' + 'train_process.txt', "a")
            logs.write(logs_test)
            logs.write('\n')
            logs.close()
def train(train_datasetA, train_datasetB, args):

    start = time.time()

    for iteration in range(args.start_iter, args.max_iterations):

        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape, tf.GradientTape() as trans_tape:
            try:
                trainA_full = next(train_datasetA)
                trainB_full = next(train_datasetB)
            except tf.errors.OutOfRangeError:
                print("Error: ran out of data")
                break

            highs_A, lows_A = pyramid.split(trainA_full, args.level)

            genA2B_output = gen(lows_A[-1], training=True)

            high_with_low_A = tf.concat([highs_A[-1], lows_A[-2]], 3)

            highA_output = trans(high_with_low_A, training=True)

            split_list_A_reverse = []

            for i in range(1, args.level + 1):

                index = 0 - i

                high_transed_A = tf.multiply(highs_A[index], highA_output)
                split_list_A_reverse.append(high_transed_A)
                if i != args.level:
                    highA_output = up_sample(highA_output)

            split_list_A = []

            index = -1

            for _ in range(args.level):

                split_list_A.append(split_list_A_reverse[index])

                index = index - 1

            genA2B_output_full = merge_image(genA2B_output, split_list_A)

            disc_real_output = disc(trainB_full)

            disc_fake_output = disc(genA2B_output_full)

            disc_loss = discriminator_losssss('lsgan', disc_real_output,
                                              disc_fake_output)

            color_loss = color_losses(genA2B_output_full, trainB_full,
                                      args.batch_size)

            generator_loss = generator_losssss('lsgan', disc_fake_output)
            reconstruction_loss = tf.reduce_mean(
                tf.square(genA2B_output_full - trainA_full))
            tv_loss = total_variation_loss(genA2B_output_full)

            gen_loss = args.dis_lambda * generator_loss + args.recons_lambda * reconstruction_loss + args.color_lambda * color_loss + args.tv_lambda * tv_loss

        gen_gradients = gen_tape.gradient(gen_loss, gen.trainable_variables)
        trans_gradients = trans_tape.gradient(gen_loss,
                                              trans.trainable_variables)
        disc_gradients = disc_tape.gradient(disc_loss,
                                            disc.trainable_variables)
        disc_gradients, _ = tf.clip_by_global_norm(disc_gradients, 15)

        gen_optimizer.apply_gradients(
            zip(gen_gradients, gen.trainable_variables))
        trans_optimizer.apply_gradients(
            zip(trans_gradients, trans.trainable_variables))
        disc_optimizer.apply_gradients(
            zip(disc_gradients, disc.trainable_variables))

        if iteration % args.log_interval == 0:

            print(
                'Training ' + procname +
                ', Iteration: {}th, Duration: {:.4f}, LOSSES: recons: {:.4f}, generator: {:.4f}, disc: {:.4f}, color: {:.4f}, tv: {:.4f}'.format(
                    iteration, time.time() - start, reconstruction_loss,
                    generator_loss.numpy(), disc_loss.numpy(), color_loss,
                    tv_loss))

            start = time.time()

        if iteration % args.save_interval == 0:
            generate_images(trainA_full, trainA_full, genA2B_output_full,
                            highA_output, save_dir_train, iteration)

            if not os.path.exists(model_save_dir):
                os.mkdir(model_save_dir)
            # checkpoint.save(file_prefix = checkpoint_prefix)
            disc.save_weights(model_save_dir + '/disc_' + str(iteration))
            gen.save_weights(model_save_dir + '/gen_' + str(iteration))
            trans.save_weights(model_save_dir + '/trans_' + str(iteration))

        if iteration % args.test_interval == 0:

            input_images = []
            output_images = []

            test_full = next(test_datasetA)

            start_test = time.time()

            highs_test, lows_test = pyramid.split(test_full, args.level)
            generated_low = gen(lows_test[-1], training=False)
            high_with_low_test = tf.concat([highs_test[-1], lows_test[-2]], 3)
            generated_high = trans(high_with_low_test, training=False)

            split_list_reverse_test = []

            for i in range(1, args.level + 1):

                index = 0 - i

                high_transed_test = tf.multiply(highs_test[index],
                                                generated_high)
                split_list_reverse_test.append(high_transed_test)
                if i != args.level:
                    generated_high = up_sample(generated_high)

            split_list_test = []

            index = -1

            for _ in range(args.level):

                split_list_test.append(split_list_reverse_test[index])

                index = index - 1

            generated_full = merge_image(generated_low, split_list_test)

            test_time_1 = time.time() - start_test

            input_images.append(test_full)
            output_images.append(generated_full)

            test_full = next(test_datasetA)

            start_test = time.time()

            highs_test, lows_test = pyramid.split(test_full, args.level)
            generated_low = gen(lows_test[-1], training=False)
            high_with_low_test = tf.concat([highs_test[-1], lows_test[-2]], 3)
            generated_high = trans(high_with_low_test, training=False)

            split_list_reverse_test = []

            for i in range(1, args.level + 1):

                index = 0 - i

                high_transed_test = tf.multiply(highs_test[index],
                                                generated_high)
                split_list_reverse_test.append(high_transed_test)
                if i != args.level:
                    generated_high = up_sample(generated_high)

            split_list_test = []

            index = -1

            for _ in range(args.level):

                split_list_test.append(split_list_reverse_test[index])

                index = index - 1

            generated_full = merge_image(generated_low, split_list_test)

            test_time_2 = time.time() - start_test

            input_images.append(test_full)
            output_images.append(generated_full)

            generate_images(input_images[0], input_images[1], output_images[0],
                            output_images[1], save_dir_test, iteration)

            if iteration % 100 == 0:
                print('test time: ', test_time_1)
                print('test time: ', test_time_2)
def train(train_datasetA, train_datasetB, args):
    
    start = time.time()

    for iteration in range(args.start_iter, args.max_iterations):

        with tf.GradientTape() as genA2B_tape, tf.GradientTape() as genB2A_tape, \
                tf.GradientTape() as discA_tape, tf.GradientTape() as discB_tape, tf.GradientTape() as transA2B_tape, tf.GradientTape() as transB2A_tape:
            try:
                trainA_full = next(train_datasetA)
                trainB_full = next(train_datasetB)
            except tf.errors.OutOfRangeError:
                print("Error, run out of data")
                break

            highs_A, lows_A = pyramid.split(trainA_full, args.level)
            highs_B, lows_B = pyramid.split(trainB_full, args.level)
            
            genA2B_output = genA2B(lows_A[-1], training=True)
            genB2A_output = genB2A(lows_B[-1], training=True)

            high_with_low_A = tf.concat([highs_A[-1], lows_A[-2]], 3)
            high_with_low_B = tf.concat([highs_B[-1], lows_B[-2]], 3)

            # kernel = pyramid._binomial_kernel(tf.shape(input=lows_A[-1])[3], dtype=lows_A[-1].dtype)

            # low_up_before_A = pyramid._upsample(lows_A[-1], kernel)
            # low_up_after_A = pyramid._upsample(genA2B_output, kernel)
            # low_up_before_B = pyramid._upsample(lows_B[-1], kernel)
            # low_up_after_B = pyramid._upsample(genB2A_output, kernel)

            # high_with_low_A = tf.concat([highs_A[-1], low_up_before_A], 3)
            # high_with_low_B = tf.concat([highs_B[-1], low_up_before_B], 3)

            # high_with_low_A = tf.concat([high_with_low_A, low_up_after_A], 3)
            # high_with_low_B = tf.concat([high_with_low_B, low_up_after_B], 3)

            highA2B_output_o = transA2B(high_with_low_A, training=True)
            highB2A_output_o = transB2A(high_with_low_B, training=True)

            highA2B_output = highA2B_output_o
            highB2A_output = highB2A_output_o

            split_list_A_reverse = []
            split_list_B_reverse = []
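
            # Apply the predicted high-frequency mask band-by-band, starting
            # from the coarsest high band (index -1) and upsampling the mask
            # to match each finer band.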

            for i in range(1, args.level+1):

                index = 0 - i

                high_transed_A = tf.multiply(highs_A[index], highA2B_output)
                split_list_A_reverse.append(high_transed_A)
                highA2B_output = up_sample(highA2B_output)

                high_transed_B = tf.multiply(highs_B[index], highB2A_output)
                split_list_B_reverse.append(high_transed_B)
                highB2A_output = up_sample(highB2A_output)

            split_list_A = []
            split_list_B = []

            index = -1

            for _ in range(args.level):

                split_list_A.append(split_list_A_reverse[index])
                split_list_B.append(split_list_B_reverse[index])

                index = index - 1

            genA2B_output_full = merge_image(genA2B_output, split_list_A)
            genB2A_output_full = merge_image(genB2A_output, split_list_B)

            reconstructed_A_low = genB2A(genA2B_output, training=True)
            reconstructed_B_low = genA2B(genB2A_output, training=True)

            # ---- cycle path: translate the outputs back and rebuild the full images ----

            high_with_low_A_r = tf.concat([highA2B_output_o, lows_A[-2]], 3)
            high_with_low_B_r = tf.concat([highB2A_output_o, lows_B[-2]], 3)

            highA2B_output_r = transB2A(high_with_low_A_r, training=True)
            highB2A_output_r = transA2B(high_with_low_B_r, training=True)

            split_list_A_reverse = []
            split_list_B_reverse = []

            for i in range(1, args.level+1):

                index = 0 - i

                high_transed_A = tf.multiply(highs_A[index], highA2B_output_r)
                split_list_A_reverse.append(high_transed_A)
                highA2B_output_r = up_sample(highA2B_output_r)

                high_transed_B = tf.multiply(highs_B[index], highB2A_output_r)
                split_list_B_reverse.append(high_transed_B)
                highB2A_output_r = up_sample(highB2A_output_r)

            split_list_A = []
            split_list_B = []

            index = -1

            for _ in range(args.level):

                split_list_A.append(split_list_A_reverse[index])
                split_list_B.append(split_list_B_reverse[index])

                index = index - 1

            # ---- end of cycle path ----

            reconstructed_A_full = merge_image(reconstructed_A_low, split_list_A)
            reconstructed_B_full = merge_image(reconstructed_B_low, split_list_B)

            # Discriminators score full-resolution real images against the
            # full-resolution translated outputs.
            discA_real_output = discA(trainA_full)
            discB_real_output = discB(trainB_full)

            discA_fake_output = discA(genB2A_output_full)
            discB_fake_output = discB(genA2B_output_full)

            discA_loss = discriminator_losssss('lsgan', discA_real_output, discA_fake_output)
            discB_loss = discriminator_losssss('lsgan', discB_real_output, discB_fake_output)

            color_loss_A2B = color_losses(genA2B_output_full, trainB_full, args.batch_size)
            color_loss_B2A = color_losses(genB2A_output_full, trainA_full, args.batch_size)

            # content_loss_A2B = content_losses(genA2B_output_full, trainA_full, args.batch_size)
            # content_loss_B2A = content_losses(genB2A_output_full, trainB_full, args.batch_size)

            # gp_A = gradient_penalty(trainA_full, genB2A_output_full, discA)
            # gp_B = gradient_penalty(trainB_full, genA2B_output_full, discB)

            generatorA2B_loss = generator_losssss('lsgan', discB_fake_output)
            cycleA2B_loss = cycle_consistency_loss(trainA_full, trainB_full, reconstructed_A_full, reconstructed_B_full)
            
            generatorB2A_loss = generator_losssss('lsgan', discA_fake_output)
            # The cycle term is symmetric in A and B, so both directions use
            # the same value.
            cycleB2A_loss = cycle_consistency_loss(trainA_full, trainB_full, reconstructed_A_full, reconstructed_B_full)
            

            genA2B_loss = args.dis_lambda * generatorA2B_loss + args.cyc_lambda * cycleA2B_loss + args.color_lambda * color_loss_A2B

            genB2A_loss = args.dis_lambda * generatorB2A_loss + args.cyc_lambda * cycleB2A_loss + args.color_lambda * color_loss_B2A

        # The trans (high-frequency) networks are optimised with the same
        # composite losses as their corresponding low-frequency generators.
        genA2B_gradients = genA2B_tape.gradient(genA2B_loss, genA2B.trainable_variables)
        genB2A_gradients = genB2A_tape.gradient(genB2A_loss, genB2A.trainable_variables)
        transA2B_gradients = transA2B_tape.gradient(genA2B_loss, transA2B.trainable_variables)
        transB2A_gradients = transB2A_tape.gradient(genB2A_loss, transB2A.trainable_variables)

        # Clip discriminator gradients by global norm to stabilise adversarial training.
        discA_gradients = discA_tape.gradient(discA_loss, discA.trainable_variables)
        discA_gradients, _ = tf.clip_by_global_norm(discA_gradients, 15)
        discB_gradients = discB_tape.gradient(discB_loss, discB.trainable_variables)
        discB_gradients, _ = tf.clip_by_global_norm(discB_gradients, 15)

        genA2B_optimizer.apply_gradients(zip(genA2B_gradients, genA2B.trainable_variables))
        genB2A_optimizer.apply_gradients(zip(genB2A_gradients, genB2A.trainable_variables))
        transA2B_optimizer.apply_gradients(zip(transA2B_gradients, transA2B.trainable_variables))
        transB2A_optimizer.apply_gradients(zip(transB2A_gradients, transB2A.trainable_variables))

        discA_optimizer.apply_gradients(zip(discA_gradients, discA.trainable_variables))
        discB_optimizer.apply_gradients(zip(discB_gradients, discB.trainable_variables))

        if iteration % args.log_interval == 0:

            print('Training ' + procname + ', Iteration: {}th, Duration: {:.4f}, LOSSES: cycle: {:.4f}, generator: {:.4f}, disc: {:.4f}, color: {:.4f}'.format(
                iteration, time.time() - start, cycleA2B_loss, generatorA2B_loss.numpy(), discA_loss.numpy(), color_loss_A2B))

            start = time.time()

        if iteration % args.save_interval == 0:
            generate_images(trainA_full, trainB_full, genA2B_output_full, genB2A_output_full, save_dir_train, iteration)

            # print('Time taken for iteration {} is {} sec'.format(iteration + 1, time.time() - start))

            if not os.path.exists(model_save_dir):
                os.mkdir(model_save_dir)
            # checkpoint.save(file_prefix = checkpoint_prefix)
            discA.save_weights(model_save_dir + '/discA_' + str(iteration))
            discB.save_weights(model_save_dir + '/discB_' + str(iteration))
            genA2B.save_weights(model_save_dir + '/genA2B_' + str(iteration))
            genB2A.save_weights(model_save_dir + '/genB2A_' + str(iteration))  
            transA2B.save_weights(model_save_dir + '/transA2B_' + str(iteration))
            transB2A.save_weights(model_save_dir + '/transB2A_' + str(iteration))  
        
        if iteration % args.test_interval == 0:

            input_images = []
            output_images = []

            test_full = next(test_datasetA)

            start_test = time.time()
            
            highs_test, lows_test = pyramid.split(test_full, args.level)
            generated_low = genA2B(lows_test[-1], training=False)
            high_with_low_test = tf.concat([highs_test[-1], lows_test[-2]], 3)
            generated_high = transA2B(high_with_low_test, training=False)

            split_list_reverse_test = []

            for i in range(1, args.level+1):

                index = 0 - i

                high_transed_test = tf.multiply(highs_test[index], generated_high)
                split_list_reverse_test.append(high_transed_test)
                generated_high = up_sample(generated_high)

            split_list_test = []

            index = -1

            for _ in range(args.level):

                split_list_test.append(split_list_reverse_test[index])

                index = index - 1


            generated_full = merge_image(generated_low, split_list_test)

            test_time_1 = time.time() - start_test

            input_images.append(test_full)
            output_images.append(generated_full)

            test_full = next(test_datasetB)

            start_test = time.time()
            
            highs_test, lows_test = pyramid.split(test_full, args.level)
            generated_low = genB2A(lows_test[-1], training=False)
            high_with_low_test = tf.concat([highs_test[-1], lows_test[-2]], 3)
            generated_high = transB2A(high_with_low_test, training=False)

            split_list_reverse_test = []

            for i in range(1, args.level+1):

                index = 0 - i

                high_transed_test = tf.multiply(highs_test[index], generated_high)
                split_list_reverse_test.append(high_transed_test)
                generated_high = up_sample(generated_high)

            split_list_test = []

            index = -1

            for _ in range(args.level):

                split_list_test.append(split_list_reverse_test[index])

                index = index - 1

            generated_full = merge_image(generated_low, split_list_test)

            test_time_2 = time.time() - start_test

            input_images.append(test_full)
            output_images.append(generated_full)

            generate_images(input_images[0], input_images[1], output_images[0], output_images[1], save_dir_test, iteration)

            if iteration % 100 == 0:
                print('test time: ', test_time_1)
                print('test time: ', test_time_2)
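
# --- Hedged sketch (assumption, not the project's code): up_sample and
# merge_image are defined elsewhere. Versions consistent with their usage
# here, i.e. 2x spatial upsampling and Laplacian-pyramid reconstruction:
def up_sample(x):
    # Double the spatial resolution of an NHWC tensor (bilinear).
    shape = tf.shape(x)
    return tf.image.resize(x, [shape[1] * 2, shape[2] * 2])

def merge_image(low, highs):
    # Collapse a Laplacian pyramid: `highs` arrives ordered fine -> coarse
    # (see split_list_* above), so rebuild from the coarsest band upward.
    out = low
    for high in reversed(highs):
        out = up_sample(out) + high
    return out
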
def train(args, lsgan=True):

    multi_test_save_all = []

    for i in range(args.test_size):

        test_input_A = next(test_datasetA)
        test_input_B = next(test_datasetB)

        multi_test_save_single = tf.reshape(tf.clip_by_value((test_input_A + 1) / 2, 0, 1), [args.load_img_size, args.load_img_size, 3]).numpy()

        start = time.time()

        # test_input = cv2.imread('../data/content3.png')
        # test_input = tf.reshape(test_input, [1, 546,366,3])

        highs_test_A, lows_test_A = pyramid.split(test_input_A, args.level)
        highs_test_B, lows_test_B = pyramid.split(test_input_B, args.level)

        genA2B_output = gen(lows_test_A[-1], training=True)

        high_with_low_A = tf.concat(
            [highs_test_A[-1], up_sample(genA2B_output)], 3)
        high_with_low_A = tf.concat(
            [high_with_low_A, up_sample(lows_test_A[-1])], 3)

        mask_0, mask_1, mask_2, mask_3 = trans(high_with_low_A, training=True)

        cv2.imwrite('mask0.png', mask_0.numpy()[0, :, :, :] * 255.0)
        cv2.imwrite('mask1.png', mask_1.numpy()[0, :, :, :] * 255.0)
        cv2.imwrite('mask2.png', mask_2.numpy()[0, :, :, :] * 255.0)
        cv2.imwrite('mask3.png', mask_3.numpy()[0, :, :, :] * 255.0)

        # misc.imsave rescales per image and overwrites the cv2 dumps written above.
        misc.imsave('mask0.png', mask_0.numpy()[0, :, :, :])
        misc.imsave('mask1.png', mask_1.numpy()[0, :, :, :])
        misc.imsave('mask2.png', mask_2.numpy()[0, :, :, :])
        misc.imsave('mask3.png', mask_3.numpy()[0, :, :, :])

        up_last = up_sample(highs_test_B[-1])
        up_2last = up_sample(highs_test_B[-2])
        up_3last = up_sample(highs_test_B[-3])
        up_4last = up_sample(highs_test_B[-4])
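
        # Diagnostics: measure how much detail each finer high band of B adds
        # beyond an upsampled copy of the next-coarser band.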

        # difference0 = tf.reduce_mean(tf.square(up_4last - highs_test_B[0])).numpy()
        difference1 = tf.reduce_mean(tf.square(up_3last -
                                               highs_test_B[0])).numpy()
        difference2 = tf.reduce_mean(tf.square(up_2last -
                                               highs_test_B[1])).numpy()
        difference3 = tf.reduce_mean(tf.square(up_last -
                                               highs_test_B[2])).numpy()
        # difference4 = tf.reduce_mean(tf.square(highs_test_A[4] - highs_test_B[4])).numpy()

        # print('high_0: ', difference0)
        print('high_1: ', difference1)
        print('high_2: ', difference2)
        print('high_3: ', difference3)

        misc.imsave('up_last.png', up_last.numpy()[0, :, :, :])
        misc.imsave('up_2last.png', up_2last.numpy()[0, :, :, :])
        misc.imsave('up_3last.png', up_3last.numpy()[0, :, :, :])
        misc.imsave('up_4last.png', up_4last.numpy()[0, :, :, :])

        misc.imsave('content_low_0.png', lows_test_A[0].numpy()[0, :, :, :])
        misc.imsave('content_low_1.png', lows_test_A[1].numpy()[0, :, :, :])
        misc.imsave('content_low_2.png', lows_test_A[2].numpy()[0, :, :, :])
        misc.imsave('content_low_3.png', lows_test_A[3].numpy()[0, :, :, :])
        # misc.imsave('content_low_4.png', lows_test_A[4].numpy()[0,:,:,:])
        # misc.imsave('content_low_5.png', lows_test[5].numpy()[0,:,:,:])
        misc.imsave('content_high_0.png', highs_test_A[0].numpy()[0, :, :, :])
        misc.imsave('content_high_1.png', highs_test_A[1].numpy()[0, :, :, :])
        misc.imsave('content_high_2.png', highs_test_A[2].numpy()[0, :, :, :])
        misc.imsave('content_high_3.png', highs_test_A[3].numpy()[0, :, :, :])
        # misc.imsave('content_high_4.png', highs_test_A[4].numpy()[0,:,:,:])
        # misc.imsave('content_high_5.png', highs_test[5].numpy()[0,:,:,:])
        misc.imsave('content.png', test_input_A.numpy()[0, :, :, :])

        misc.imsave('predict_low_0.png', lows_test_B[0].numpy()[0, :, :, :])
        misc.imsave('predict_low_1.png', lows_test_B[1].numpy()[0, :, :, :])
        misc.imsave('predict_low_2.png', lows_test_B[2].numpy()[0, :, :, :])
        misc.imsave('predict_low_3.png', lows_test_B[3].numpy()[0, :, :, :])
        # misc.imsave('predict_low_4.png', lows_test_B[4].numpy()[0,:,:,:])
        # misc.imsave('predict_low_5.png', lows_test[5].numpy()[0,:,:,:])
        misc.imsave('predict_high_0.png', highs_test_B[0].numpy()[0, :, :, :])
        misc.imsave('predict_high_1.png', highs_test_B[1].numpy()[0, :, :, :])
        misc.imsave('predict_high_2.png', highs_test_B[2].numpy()[0, :, :, :])
        misc.imsave('predict_high_3.png', highs_test_B[3].numpy()[0, :, :, :])
        # misc.imsave('predict_high_4.png', highs_test_B[4].numpy()[0,:,:,:])
        # misc.imsave('predict_high_5.png', highs_test[5].numpy()[0,:,:,:])
        misc.imsave('predict.png', test_input_B.numpy()[0, :, :, :])

        difference0 = tf.reduce_mean(
            tf.square(highs_test_A[0] - highs_test_B[0])).numpy()
        difference1 = tf.reduce_mean(
            tf.square(highs_test_A[1] - highs_test_B[1])).numpy()
        difference2 = tf.reduce_mean(
            tf.square(highs_test_A[2] - highs_test_B[2])).numpy()
        difference3 = tf.reduce_mean(
            tf.square(highs_test_A[3] - highs_test_B[3])).numpy()
        # difference4 = tf.reduce_mean(tf.square(highs_test_A[4] - highs_test_B[4])).numpy()

        print('high_0: ', difference0)
        print('high_1: ', difference1)
        print('high_2: ', difference2)
        print('high_3: ', difference3)
        # print('high_4: ', difference4)

        difference0 = tf.reduce_mean(tf.square(lows_test_A[0] -
                                               lows_test_B[0])).numpy()
        difference1 = tf.reduce_mean(tf.square(lows_test_A[1] -
                                               lows_test_B[1])).numpy()
        difference2 = tf.reduce_mean(tf.square(lows_test_A[2] -
                                               lows_test_B[2])).numpy()
        difference3 = tf.reduce_mean(tf.square(lows_test_A[3] -
                                               lows_test_B[3])).numpy()
        # difference4 = tf.reduce_mean(tf.square(lows_test_A[4] - lows_test_B[4])).numpy()

        print('low_0: ', difference0)
        print('low_1: ', difference1)
        print('low_2: ', difference2)
        print('low_3: ', difference3)
        # print('low_4: ', difference4)

        difference = tf.reduce_mean(
            tf.square(test_input_A[0] - test_input_B[0])).numpy()

        print('original: ', difference)

        # Local imports for the plotting block below (module caching makes
        # repeated imports inside the loop cheap).
        from PIL import Image
        import numpy as np
        import matplotlib.pyplot as plt

        # Plot each dumped image with plot_fig: the full images and high bands
        # get an '_h' suffix, the low bands (levels 1-3) an '_l' suffix.
        names_h = ['content', 'predict']
        names_h += ['{}_high_{}'.format(p, k)
                    for k in range(4) for p in ('content', 'predict')]
        names_l = ['{}_low_{}'.format(p, k)
                   for k in range(1, 4) for p in ('content', 'predict')]

        for name in names_h:
            plot_fig(Image.open(name + '.png'), name + '_h.png')
        for name in names_l:
            plot_fig(Image.open(name + '.png'), name + '_l.png')

        generated_low = genA2B(lows_test_A[-1], training=False)
        high_with_low_test = tf.concat([highs_test_A[-1], lows_test_A[-2]], 3)
        generated_high = transA2B(high_with_low_test, training=False)

        split_list_reverse_test = []

        for i_ in range(1, args.level + 1):

            index = 0 - i_

            high_transed_test = tf.multiply(highs_test_A[index], generated_high)
            split_list_reverse_test.append(high_transed_test)
            generated_high = up_sample(generated_high)

        split_list_test = []

        index = -1

        for _ in range(args.level):

            split_list_test.append(split_list_reverse_test[index])

            index = index - 1

        generated_full = merge_image(generated_low, split_list_test)

        cost = time.time() - start

        print(cost)

        multi_test_save_single = np.vstack(
            (multi_test_save_single,
             tf.reshape(tf.clip_by_value((generated_full + 1) / 2, 0, 1),
                        [args.load_img_size, args.load_img_size, 3]).numpy()))

        if i == 0:
            multi_test_save_all = multi_test_save_single
        else:
            multi_test_save_all = np.hstack(
                (multi_test_save_all, multi_test_save_single))

    image_path = os.path.join(save_dir_test,
                              'iteration_{}.jpg'.format(args.start_iter))

    misc.imsave(image_path, multi_test_save_all)
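
# --- Hedged sketch (assumption): plot_fig, used above, saves a visualisation
# of an image to the given path. A minimal stand-in that plots the
# pixel-intensity histogram:
def plot_fig(src, paths):
    import numpy as np
    import matplotlib.pyplot as plt
    arr = np.asarray(src, dtype=np.float32)
    plt.figure()
    plt.hist(arr.ravel(), bins=256)
    plt.savefig(paths)
    plt.close()
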
def train(args, lsgan=True):
    
    multi_test_save_all = []

    for i in range(args.test_size):

        test_input, test_output = next(test_dataset)

        multi_test_save_single = tf.reshape(tf.clip_by_value((test_input+1)/2, 0, 1), [args.load_img_size, args.load_img_size, 3]).numpy()

        highs_test, lows_test = pyramid.split(test_input, args.level)
        generated_low = genA2B(lows_test[-1], training=False)
        high_with_low_test = tf.concat([highs_test[-1], lows_test[-2]], 3)
        generated_high = transA2B(high_with_low_test, training=False)

        split_list_reverse_test = []

        for i_ in range(1, args.level+1):

            index = 0 - i_

            high_transed_test = tf.multiply(highs_test[index], generated_high)
            split_list_reverse_test.append(high_transed_test)
            generated_high = up_sample(generated_high)

        split_list_test = []

        index = -1

        for _ in range(args.level):

            split_list_test.append(split_list_reverse_test[index])

            index = index - 1

        generated_full = merge_image(generated_low, split_list_test)
        
        multi_test_save_single = np.vstack((multi_test_save_single, tf.reshape(tf.clip_by_value((generated_full+1)/2, 0, 1), [args.load_img_size, args.load_img_size, 3]).numpy()))
        multi_test_save_single = np.vstack((multi_test_save_single, tf.reshape(tf.clip_by_value((test_output+1)/2, 0, 1), [args.load_img_size, args.load_img_size, 3]).numpy()))

        if i == 0:
            multi_test_save_all = multi_test_save_single
        else:
            multi_test_save_all = np.hstack((multi_test_save_all, multi_test_save_single))

    image_path = os.path.join(save_dir_test, 'iteration_{}.jpg'.format(args.start_iter))

    # misc.imsave(image_path, multi_test_save_all)

    psnr_all = []
    cost_all = []

    print('******************************')
    print('testing on all test images...')

    for i in range(test_all_size):

        test_input, test_output = next(test_dataset)

        start = time.time()

        highs_test, lows_test = pyramid.split(test_input, args.level)
        generated_low = genA2B(lows_test[-1], training=False)
        high_with_low_test = tf.concat([highs_test[-1], lows_test[-2]], 3)
        generated_high = transA2B(high_with_low_test, training=False)

        split_list_reverse_test = []

        for i_ in range(1, args.level+1):

            index = 0 - i_

            high_transed_test = tf.multiply(highs_test[index], generated_high)
            split_list_reverse_test.append(high_transed_test)
            generated_high = up_sample(generated_high)

        split_list_test = []

        index = -1

        for _ in range(args.level):

            split_list_test.append(split_list_reverse_test[index])

            index = index - 1


        generated_full = merge_image(generated_low, split_list_test)

        cost = time.time() - start

        generated = np.clip((generated_full+1)/2, 0, 1)
        target = np.clip((test_output+1)/2, 0, 1)

        # generated = generated[0, 100:900, 100:900, :]
        # target = target[0, 100:900, 100:900, :]

        psnr_test = tf.image.psnr(generated, target, max_val=1.0)

        psnr_all.append(psnr_test[0].numpy())
        cost_all.append(cost)

        print('image {}, psnr: {:.4f}, duration: {:.4f}'.format(i, psnr_test[0].numpy(), cost))

    logs_test = 'test iteration: {}th, mean psnr: {:.4f}, avg inference time: {:.4f}'.format(args.start_iter, np.mean(psnr_all), np.mean(cost_all))
    print(logs_test)
    print('****************************************************************************')
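
# --- Reference note: tf.image.psnr computes, per image,
# PSNR = 10 * log10(max_val**2 / MSE) in decibels. A NumPy equivalent for
# sanity-checking the averages printed above:
def psnr_np(a, b, max_val=1.0):
    mse = np.mean((np.asarray(a, np.float32) - np.asarray(b, np.float32)) ** 2)
    return 10.0 * np.log10((max_val ** 2) / mse)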