Пример #1
0
def predict(mean=30.0, std=50.0):
    """Run the trained U-Net on the test set and dump the results to a .mat file.

    Args:
        mean: Intensity mean subtracted from the test images. The pair
            ``mean=0.0, std=1.0`` is treated as a sentinel meaning "recompute
            both statistics from the training images".
        std: Intensity standard deviation the test images are divided by.

    Side effects:
        Creates ``predictions_path`` if it does not exist and writes
        ``predictions.mat`` into it with the keys ``pred``, ``image``,
        ``mask`` and ``name``.
    """
    # (0, 1) is the sentinel for "derive the statistics from the training data".
    if mean == 0.0 and std == 1.0:
        train_imgs, _unused_masks, _unused_names = load_data(train_images_path)
        mean = np.mean(train_imgs)
        std = np.std(train_imgs)

    test_imgs, test_masks, test_names = load_data(test_images_path)
    # Unnormalized uint8 copy, saved next to the predictions for inspection.
    raw_test_imgs = test_imgs.astype(np.uint8)

    # Normalize in place with the same statistics used for training.
    test_imgs -= mean
    test_imgs /= std

    net = unet()
    net.load_weights(weights_path)
    predicted_masks = net.predict(test_imgs, verbose=1)

    if not os.path.exists(predictions_path):
        os.mkdir(predictions_path)

    out_file = os.path.join(predictions_path, 'predictions.mat')
    savemat(out_file, {
        'pred': predicted_masks,
        'image': raw_test_imgs,
        'mask': test_masks,
        'name': test_names,
    })
Пример #2
0
def train():
    """Train the U-Net on the training split and save the final weights.

    Training images are standardized with their own mean/std, and the same
    statistics are applied to the validation images. The training set is then
    passed through ``oversample``. Weights are initialised from
    ``init_weights_path`` when that file exists. The model is optimized with
    Adam on ``dice_coef_loss`` and logs to TensorBoard under ``log_path``. The
    final weights are written to ``weights_path/weights_<epochs>.h5``.
    """
    train_x, train_y, _ = load_data(train_images_path)

    mu = np.mean(train_x)
    sigma = np.std(train_x)
    train_x -= mu
    train_x /= sigma

    valid_x, valid_y, _ = load_data(valid_images_path)
    # Validation data reuses the training statistics.
    valid_x -= mu
    valid_x /= sigma

    train_x, train_y = oversample(train_x, train_y)

    net = unet()
    if os.path.exists(init_weights_path):
        net.load_weights(init_weights_path)

    adam = Adam(lr=base_lr)
    net.compile(optimizer=adam, loss=dice_coef_loss, metrics=[dice_coef])

    if not os.path.exists(log_path):
        os.mkdir(log_path)
    tensorboard_cb = TensorBoard(log_dir=log_path)

    net.fit(train_x,
            train_y,
            validation_data=(valid_x, valid_y),
            batch_size=batch_size,
            epochs=epochs,
            shuffle=True,
            callbacks=[tensorboard_cb])

    if not os.path.exists(weights_path):
        os.mkdir(weights_path)
    final_weights = os.path.join(weights_path, "weights_{}.h5".format(epochs))
    net.save_weights(final_weights)
Пример #3
0
def test():
    """Evaluate ``unet`` on ``test_dataset`` and record the mean Dice score.

    Each prediction is binarized at the global threshold ``rate``, scaled to
    {0, 255}, and scored with ``dice_loss.dice`` against the ground-truth mask.
    The mean score over all batches is printed and appended to ``res['dice']``.

    Side effects:
        Rebinds the module globals ``img_y`` and ``mask_arrary`` to the last
        processed prediction/mask, and appends to ``res['dice']``.

    Raises:
        ZeroDivisionError: if the test loader yields no batches.
    """
    global res, img_y, mask_arrary
    epoch_dice = 0
    with torch.no_grad():
        dataloaders = DataLoader(test_dataset,
                                 batch_size=1,
                                 shuffle=True,
                                 num_workers=0)
        for x, mask in dataloaders:
            # x[1:] holds sample identifiers (per the original note: patient
            # number, then slice number); they are not needed for scoring.
            x = x[0].to(device)
            y = unet(x)
            mask_arrary = mask[1].cpu().squeeze(0).detach().numpy()
            img_y = torch.squeeze(y).cpu().numpy()
            # Binarize in a single step. The previous two in-place passes
            # (>= rate -> 1, then < rate -> 0) zeroed everything when rate > 1.
            img_y = (img_y >= rate).astype(img_y.dtype) * 255
            epoch_dice += dice_loss.dice(img_y, mask_arrary)
        mean_dice = epoch_dice / len(dataloaders)
        print('test dice %f' % mean_dice)
        res['dice'].append(mean_dice)
def main(args):
    """Benchmark the average forward-pass latency of ``unet``.

    Builds the network on a ``[1, args.low_size, args.low_size, 3]`` float32
    placeholder and feeds the same random input ``args.iters`` times. It then
    prints the mean wall-clock time per run in milliseconds.

    Args:
        args: Namespace providing ``low_size`` and ``iters``.
    """
    x_low = tf.placeholder(dtype=tf.float32,
                           shape=[1, args.low_size, args.low_size, 3],
                           name='input_low')
    input_low = np.random.randn(1, args.low_size, args.low_size, 3)

    out = unet(x_low)

    config = None
    # ``config`` must be passed by keyword: the first positional parameter of
    # tf.Session is ``target``, so tf.Session(config) misroutes a ConfigProto.
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        # perf_counter is monotonic and high resolution, unlike time.time().
        time_start = time.perf_counter()
        for _ in range(args.iters):
            sess.run(out, feed_dict={x_low: input_low})
        time_end = time.perf_counter()
        print("ms:%.1f ms" % ((time_end - time_start) * 1000.0 / args.iters))
Пример #5
0
def train():
    """Train ``unet`` on ``train_dataset``, validating every 100 steps.

    Every 100 steps this function:

    - appends the global step to ``res['epoch']``;
    - appends the current training loss to ``res['loss']``;
    - calls ``test()``, which appends the validation Dice to ``res['dice']``.

    After training, both curves are plotted to ``examples.jpg`` and a final
    ``test()`` is run.

    NOTE(review): the plot assumes ``res['epoch']``, ``res['loss']`` and
    ``res['dice']`` start empty so that the three series stay aligned.
    """
    global res
    dataloaders = DataLoader(train_dataset,
                             batch_size=2,
                             shuffle=True,
                             num_workers=0)
    dt_size = len(dataloaders.dataset)
    # Batches per epoch (ceil division, matching DataLoader without drop_last).
    steps_per_epoch = (dt_size - 1) // dataloaders.batch_size + 1
    for epoch in range(epochs):
        step = 0
        for x, y in dataloaders:
            step += 1
            x = x[0].to(device)
            y = y[1].to(device)
            print(x.size())
            print(y.size())
            optimizer.zero_grad()
            outputs = unet(x)
            loss = criterion(outputs, y)
            loss.backward()
            optimizer.step()

            if step % 100 == 0:
                # Global step across epochs. The former (epoch + 1) * step was
                # not monotonic (epoch 1 step 100 == epoch 0 step 200).
                res['epoch'].append(epoch * steps_per_epoch + step)
                res['loss'].append(loss.item())
                print(
                    "epoch%d step%d/%d train_loss:%0.3f" %
                    (epoch, step, steps_per_epoch, loss.item()),
                    end='')
                test()

    # Training loss and validation Dice are both recorded once per logging
    # step, so they share res['epoch'] as the x axis. The old code read a
    # never-populated res['cost'] key and squeezed the whole res dict.
    plt.figure()
    plt.plot(res['epoch'], res['loss'], label='Train loss')
    plt.plot(res['epoch'],
             res['dice'],
             label='Validation dice',
             color='#FF9966')
    plt.xlabel('steps')
    plt.ylabel('value')
    plt.title("Model: train loss / validation dice")
    plt.legend()
    plt.savefig("examples.jpg")

    test()
Пример #6
0
def predict(mean=20.0, std=43.0):
    """Predict masks for the test set, save them, and render contour overlays.

    Args:
        mean: Intensity mean subtracted from the test images. The pair
            ``mean=0.0, std=1.0`` is a sentinel meaning "recompute both
            statistics from the training images".
        std: Intensity standard deviation the test images are divided by.

    Returns:
        Tuple ``(imgs_mask_test, imgs_mask_pred, names_test)``.

    Side effects:
        Writes ``predictions.mat`` and one ``<name>.png`` overlay per test
        image into ``predictions_path``. In each overlay the predicted contour
        is drawn in channel 0 and the ground-truth contour in channel 2.
    """

    def _contour_image(binary_mask):
        """Return a float64 canvas with the outer contours of *binary_mask* at 255."""
        mask_u8 = (np.round(binary_mask) * 255.0).astype(np.uint8)
        # findContours returns (image, contours, hierarchy) in OpenCV 3.x but
        # (contours, hierarchy) in 2.x and 4.x; [-2] is the contour list in both.
        contours = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL,
                                    cv2.CHAIN_APPROX_NONE)[-2]
        canvas = np.zeros(mask_u8.shape)
        cv2.drawContours(canvas, contours, -1, (255, 0, 0), 1)
        return canvas

    # load and normalize data
    if mean == 0.0 and std == 1.0:
        imgs_train, _, _ = load_data(train_images_path)
        mean = np.mean(imgs_train)
        std = np.std(imgs_train)

    imgs_test, imgs_mask_test, names_test = load_data(test_images_path)
    original_imgs_test = imgs_test.astype(np.uint8)

    imgs_test -= mean
    imgs_test /= std

    # load model with weights
    model = unet()
    model.load_weights(weights_path)

    # make predictions
    imgs_mask_pred = model.predict(imgs_test, verbose=1)

    # save to mat file for further processing
    if not os.path.exists(predictions_path):
        os.mkdir(predictions_path)

    matdict = {
        "pred": imgs_mask_pred,
        "image": original_imgs_test,
        "mask": imgs_mask_test,
        "name": names_test,
    }
    savemat(os.path.join(predictions_path, "predictions.mat"), matdict)

    # save images with segmentation and ground truth mask overlay
    for i, image in enumerate(original_imgs_test):
        # segmentation mask is for the middle slice
        image_rgb = gray2rgb(image[:, :, 1])

        pred = _contour_image(imgs_mask_pred[i][:, :, 0])
        mask = _contour_image(imgs_mask_test[i][:, :, 0])

        # Combine image with contours: blank the contour pixels in grey, then
        # paint the prediction into channel 0 and the ground truth into channel 2.
        pred_rgb = np.array(image_rgb)
        annotation = pred_rgb[:, :, 1]
        annotation[np.maximum(pred, mask) == 255] = 0
        pred_rgb[:, :, 0] = pred_rgb[:, :, 1] = pred_rgb[:, :, 2] = annotation
        pred_rgb[:, :, 2] = np.maximum(pred_rgb[:, :, 2], mask)
        pred_rgb[:, :, 0] = np.maximum(pred_rgb[:, :, 0], pred)

        imsave(os.path.join(predictions_path, names_test[i] + ".png"),
               pred_rgb)

    return imgs_mask_test, imgs_mask_pred, names_test
Пример #7
0
def train():
    """Train the multi-class U-Net with on-the-fly data augmentation.

    Images are standardized with the training mean/std, and validation images
    reuse those statistics. Weights are initialised from ``init_weights_path``
    when that file exists. The model is trained with Adam on
    ``dice_coef_loss``, using batches drawn from an augmenting generator and a
    pass-through generator.

    Outputs:
        - a ``weights_test_<epoch>.h5`` checkpoint every 50 epochs;
        - TensorBoard logs under ``log_path``;
        - final weights written to ``weights_test_1_<epochs>.h5``.
    """
    imgs_train, imgs_mask_train, _ = load_data(train_images_path, num_classes)

    mean = np.mean(imgs_train)
    std = np.std(imgs_train)

    imgs_train -= mean
    imgs_train /= std

    imgs_valid, imgs_mask_valid, _ = load_data(valid_images_path, num_classes)

    # Validation data reuses the training statistics.
    imgs_valid -= mean
    imgs_valid /= std

    model = unet(num_classes)

    if os.path.exists(init_weights_path):
        model.load_weights(init_weights_path)

    optimizer = Adam(lr=base_lr)
    model.compile(optimizer=optimizer,
                  loss=dice_coef_loss,
                  metrics=[dice_coef])

    if not os.path.exists(log_path):
        os.mkdir(log_path)
    # ModelCheckpoint writes into weights_path *during* training, so the
    # directory must exist before fit starts, not only before the final save.
    if not os.path.exists(weights_path):
        os.mkdir(weights_path)

    save_model = ModelCheckpoint(filepath=os.path.join(
        weights_path, "weights_test_{epoch:03d}.h5"),
                                 period=50)
    training_log = TensorBoard(log_dir=log_path)

    # Two generators with half-size batches: one augments, the other
    # (rotation_range=0) passes images through. Train_datagen presumably
    # merges them into one full batch — TODO confirm.
    datagen = ImageDataGenerator(rotation_range=10,
                                 horizontal_flip=True,
                                 width_shift_range=0.2,
                                 height_shift_range=0.2,
                                 brightness_range=[0.7, 1])
    datagen_flow = datagen.flow(imgs_train,
                                imgs_mask_train,
                                batch_size=batch_size // 2)
    datagen2 = ImageDataGenerator(rotation_range=0)
    datagen2_flow = datagen2.flow(imgs_train,
                                  imgs_mask_train,
                                  batch_size=batch_size // 2)

    # Twice len(imgs_train) per epoch because augmentation doubles the data.
    model.fit_generator(Train_datagen(datagen_flow, datagen2_flow),
                        validation_data=(imgs_valid, imgs_mask_valid),
                        steps_per_epoch=(len(imgs_train) * 2) // batch_size,
                        epochs=epochs,
                        shuffle=True,
                        callbacks=[training_log, save_model])

    model.save_weights(
        os.path.join(weights_path, 'weights_test_1_{}.h5'.format(epochs)))
Пример #8
0
def main(args):
    """Train the image-enhancement GAN (``unet`` generator + texture discriminator).

    Setup:
        - Loads a batch of training patches and a test set.
        - Builds the generator and its losses: MEON content, adversarial
          texture, color, total-variation and MS-SSIM.

    Each iteration:
        - One generator step and one discriminator step.

    Every ``args.eval_step`` iterations:
        - Evaluates on the test set and logs the metrics.
        - Optionally saves before/after crops and a generator checkpoint.
        - Reloads fresh training and test data.

    Args:
        args: Parsed command-line namespace (dataset paths, patch geometry,
            batch size, loss weights, logging/checkpoint settings, ...).
    """
    # loading training and test data
    logger.info("Loading test data...")
    test_data, test_answ = load_test_data(args.dataset, args.dataset_dir, args.test_size, args.patch_size)
    logger.info("Test data was loaded\n")

    logger.info("Loading training data...")
    train_data, train_answ = load_batch(args.dataset, args.dataset_dir, args.train_size, args.patch_size)
    logger.info("Training data was loaded\n")

    TEST_SIZE = test_data.shape[0]
    num_test_batches = int(test_data.shape[0] / args.batch_size)

    # Both optimizers use this fixed rate (args.learning_rate is NOT applied).
    # It is also the value logged to TensorBoard, so the log matches reality.
    learning_rate = 5e-5

    # defining system architecture
    with tf.Graph().as_default(), tf.Session() as sess:

        # Flattened patches are fed in and reshaped to (N, H, W, 3) images.
        phone_ = tf.placeholder(tf.float32, [None, args.patch_size])
        phone_image = tf.reshape(phone_, [-1, args.patch_height, args.patch_width, 3])

        dslr_ = tf.placeholder(tf.float32, [None, args.patch_size])
        dslr_image = tf.reshape(dslr_, [-1, args.patch_height, args.patch_width, 3])

        # Per-image swap flags (dslr or enhanced) for the discriminator.
        adv_ = tf.placeholder(tf.float32, [None, 1])
        enhanced = unet(phone_image)

        # Content loss from the MEON image-quality model.
        MEON_evaluate_model, loss_content = meon_loss(dslr_image, enhanced)

        loss_texture, discim_accuracy = texture_loss(enhanced, dslr_image, args.patch_width, args.patch_height, adv_)
        # The discriminator maximizes what the generator's texture term minimizes.
        loss_discrim = -loss_texture

        loss_color = color_loss(enhanced, dslr_image, args.batch_size)
        loss_tv = variation_loss(enhanced, args.patch_width, args.patch_height, args.batch_size)

        loss_psnr = PSNR(enhanced, dslr_image)
        loss_ssim = MultiScaleSSIM(enhanced, dslr_image)

        loss_generator = (args.w_content * loss_content + args.w_texture * loss_texture + args.w_tv * loss_tv
                          + 1000 * (1 - loss_ssim) + args.w_color * loss_color)

        # optimize parameters of image enhancement (generator) and discriminator networks
        generator_vars = [v for v in tf.global_variables() if v.name.startswith("generator")]
        discriminator_vars = [v for v in tf.global_variables() if v.name.startswith("discriminator")]
        meon_vars = [v for v in tf.global_variables() if v.name.startswith(("conv", "subtask"))]

        train_step_gen = tf.train.AdamOptimizer(learning_rate).minimize(loss_generator, var_list=generator_vars)
        train_step_disc = tf.train.AdamOptimizer(learning_rate).minimize(loss_discrim, var_list=discriminator_vars)

        saver = tf.train.Saver(var_list=generator_vars, max_to_keep=100)
        meon_saver = tf.train.Saver(var_list=meon_vars)

        logger.info('Initializing variables')
        sess.run(tf.global_variables_initializer())
        logger.info('Training network')
        train_loss_gen = 0.0
        train_acc_discrim = 0.0
        all_zeros = np.zeros((args.batch_size, 1))
        test_crops = test_data[np.random.randint(0, TEST_SIZE, 5), :]  # choose five images to visual

        # TensorBoard scalars.
        tf.summary.scalar('loss_generator', loss_generator)
        tf.summary.scalar('loss_content', loss_content)
        tf.summary.scalar('loss_color', loss_color)
        tf.summary.scalar('loss_texture', loss_texture)
        tf.summary.scalar('loss_tv', loss_tv)
        tf.summary.scalar('discim_accuracy', discim_accuracy)
        tf.summary.scalar('psnr', loss_psnr)
        tf.summary.scalar('ssim', loss_ssim)
        tf.summary.scalar('learning_rate', learning_rate)
        merge_summary = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(os.path.join(args.tesorboard_logs_dir, 'train', args.exp_name), sess.graph,
                                             filename_suffix=args.exp_name)
        test_writer = tf.summary.FileWriter(os.path.join(args.tesorboard_logs_dir, 'test', args.exp_name), sess.graph,
                                            filename_suffix=args.exp_name)
        # No second global_variables_initializer() here: no variables were
        # created since the one above, and re-running it only re-randomized them.

        # load ckpt models
        ckpt = tf.train.get_checkpoint_state(args.checkpoint_dir)
        start_i = 0
        if ckpt and ckpt.model_checkpoint_path:
            logger.info('loading checkpoint:' + ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
            import re
            # Resume from the iteration encoded in "<dataset>_iteration_<i>.ckpt".
            start_i = int(re.findall(r"_(\d+)\.ckpt", ckpt.model_checkpoint_path)[0])
        # Initialize the MEON content model from its own pretrained weights.
        MEON_evaluate_model.initialize(sess, meon_saver, args.meod_ckpt_path)

        # start training...
        for i in range(start_i, args.iter_max):

            # train generator (discriminator inputs left unswapped)
            idx_train = np.random.randint(0, args.train_size, args.batch_size)
            phone_images = train_data[idx_train]
            dslr_images = train_answ[idx_train]

            loss_temp, _ = sess.run([loss_generator, train_step_gen],
                                    feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: all_zeros})
            train_loss_gen += loss_temp / args.eval_step

            # train discriminator on a fresh batch
            idx_train = np.random.randint(0, args.train_size, args.batch_size)

            # generate image swaps (dslr or enhanced) for discriminator
            swaps = np.reshape(np.random.randint(0, 2, args.batch_size), [args.batch_size, 1])

            phone_images = train_data[idx_train]
            dslr_images = train_answ[idx_train]
            accuracy_temp, _ = sess.run([discim_accuracy, train_step_disc],
                                        feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: swaps})
            train_acc_discrim += accuracy_temp / args.eval_step

            if i % args.summary_step == 0:
                train_summary = sess.run(merge_summary,
                                         feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: swaps})
                train_writer.add_summary(train_summary, i)

            if i % args.eval_step == 0:
                # test generator and discriminator CNNs
                test_losses_gen = np.zeros((1, 7))
                test_accuracy_disc = 0.0

                for j in range(num_test_batches):
                    be = j * args.batch_size
                    en = (j + 1) * args.batch_size

                    swaps = np.reshape(np.random.randint(0, 2, args.batch_size), [args.batch_size, 1])
                    phone_images = test_data[be:en]
                    dslr_images = test_answ[be:en]

                    accuracy_disc, losses = sess.run(
                        [discim_accuracy,
                         [loss_generator, loss_content, loss_color, loss_texture, loss_tv, loss_psnr, loss_ssim]],
                        feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: swaps})

                    test_losses_gen += np.asarray(losses) / num_test_batches
                    test_accuracy_disc += accuracy_disc / num_test_batches

                logs_disc = "step %d/%d, %s | discriminator accuracy | train: %.4g, test: %.4g" % \
                            (i, args.iter_max, args.dataset, train_acc_discrim, test_accuracy_disc)
                logs_gen = "generator losses | train: %.4g, test: %.4g | content: %.4g, color: %.4g, texture: %.4g, tv: %.4g | psnr: %.4g, ssim: %.4g\n" % \
                           ((train_loss_gen,) + tuple(test_losses_gen[0]))

                logger.info(logs_disc)
                logger.info(logs_gen)

                # Summary of the last evaluated batch.
                test_summary = sess.run(merge_summary,
                                        feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: swaps})
                test_writer.add_summary(test_summary, i)

                # save visual results for several test image crops
                if args.save_visual_result:
                    enhanced_crops = sess.run(enhanced,
                                              feed_dict={phone_: test_crops, dslr_: dslr_images, adv_: all_zeros})
                    for idx, crop in enumerate(enhanced_crops):
                        before_after = np.hstack(
                            (np.reshape(test_crops[idx], [args.patch_height, args.patch_width, 3]), crop))
                        misc.imsave(
                            os.path.join(args.checkpoint_dir, str(args.dataset) + str(idx) + '_iteration_' + str(i) +
                                         '.jpg'), before_after)

                # save the model that corresponds to the current iteration
                if args.save_ckpt_file:
                    saver.save(sess,
                               os.path.join(args.checkpoint_dir, str(args.dataset) + '_iteration_' + str(i) + '.ckpt'),
                               write_meta_graph=False)

                train_loss_gen = 0.0
                train_acc_discrim = 0.0
                # Reload a different batch of training (and test) data; drop the
                # old arrays first to keep peak memory down.
                del train_data, train_answ, test_data, test_answ
                test_data, test_answ = load_test_data(args.dataset, args.dataset_dir, args.test_size, args.patch_size)
                train_data, train_answ = load_batch(args.dataset, args.dataset_dir, args.train_size, args.patch_size)
import matplotlib.pyplot as plt

# Make the project's ``net`` package (which provides ``unet``) importable;
# this must run before the import below.
sys.path.append('/userhome/Enhance')

from net import unet

# NOTE(review): SCALE is not referenced anywhere in the visible code — confirm it is needed.
SCALE = 1

# Log records carry the PID, level and source location.
logging.basicConfig(
    format="[%(process)d] %(levelname)s %(filename)s:%(lineno)s | %(message)s")
log = logging.getLogger("train")
log.setLevel(logging.INFO)

# Inference graph: a batch of 3-channel images of arbitrary height/width.
# NOTE(review): ``input`` shadows the Python builtin of the same name.
input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
with tf.variable_scope('inference'):
    output = unet(input)


def output_psnr_mse(img_orig, img_out, max_val=1.0):
    """Return the peak signal-to-noise ratio (dB) between two images.

    Args:
        img_orig: Reference image (array-like).
        img_out: Image to compare; must broadcast against ``img_orig``.
        max_val: Peak signal value. The default 1.0 matches images scaled to
            [0, 1]; pass 255.0 for 8-bit data.

    Returns:
        ``10 * log10(max_val**2 / MSE)``, or ``inf`` when the images are identical.
    """
    # Work in float64 so unsigned-integer inputs cannot wrap around when
    # subtracted (e.g. uint8 0 - 1 == 255).
    diff = np.asarray(img_orig, dtype=np.float64) - np.asarray(img_out, dtype=np.float64)
    mse = np.mean(np.square(diff))
    if mse == 0:
        # Identical images: PSNR is unbounded; avoid a divide-by-zero warning.
        return float('inf')
    return 10 * np.log10(max_val ** 2 / mse)


# Validation data: input images and their reference ("output") images.
input_path = "/userhome/dped/validation/input/"
target_path = "/userhome/dped/validation/output/"

# NOTE(review): despite the ``_path`` suffix this holds the checkpoint-state
# object returned by get_checkpoint_state, not a path string — confirm usage.
chkpt_path = tf.train.get_checkpoint_state("/userhome/Enhance/checkpoint/")
# The list of test images is taken from the target directory's file names.
test_images = os.listdir(target_path)
num_test_images = len(test_images)
Пример #10
0
def main(args, data_params):
    """Train the ``unet`` enhancement GAN under a tf.train.Supervisor loop.

    Graph construction:
        - Input pipelines for training data and, when ``args.eval_data_dir``
          is set, for evaluation data.
        - Generator/discriminator losses from ``compute_loss.total_loss``.
        - One Adam update op for each network, plus exponential-moving-average
          summaries.

    Training loop:
        - Alternates a generator update and a discriminator update until the
          supervisor stops.
        - Every ``args.eval_interval`` seconds, writes eval PSNR/SSIM to
          TensorBoard.
        - On exit, saves a final checkpoint.

    Args:
        args: Parsed command-line namespace (data/checkpoint paths, batch
            size, augmentation flags, intervals, learning rate, ...).
        data_params: Extra parameters forwarded to the data pipeline.
    """
    log.info('Preparing summary and checkpoint directory {}'.format(
        args.checkpoint_dir))
    if not os.path.exists(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)

    tf.set_random_seed(1234)  # Make experiments repeatable

    # --- Train/Test datasets ---------------------------------------------------
    data_pipe = getattr(dp, args.data_pipeline)
    with tf.variable_scope('train_data'):
        train_data_pipeline = data_pipe(
            args.data_dir,
            shuffle=True,
            batch_size=args.batch_size,
            nthreads=args.data_threads,
            fliplr=args.fliplr,
            flipud=args.flipud,
            rotate=args.rotate,
            random_crop=args.random_crop,
            params=data_params,
            output_resolution=args.output_resolution,
            scale=args.scale)
        train_samples = train_data_pipeline.samples

    if args.eval_data_dir is not None:
        with tf.variable_scope('eval_data'):
            eval_data_pipeline = data_pipe(
                args.eval_data_dir,
                shuffle=True,
                batch_size=args.batch_size,
                nthreads=args.data_threads,
                fliplr=False,
                flipud=False,
                rotate=False,
                random_crop=False,
                params=data_params,
                output_resolution=args.output_resolution,
                scale=args.scale)
            eval_samples = eval_data_pipeline.samples
    # ---------------------------------------------------------------------------
    # Discriminator swap flags (one 0/1 per batch element), drawn once with
    # NumPy at graph-construction time and reused as a constant.
    swaps = np.reshape(np.random.randint(0, 2, args.batch_size),
                       [args.batch_size, 1])
    swaps = tf.cast(tf.convert_to_tensor(swaps), tf.float32)

    # Training graph
    with tf.variable_scope('inference'):
        prediction = unet(train_samples['image_input'])
        loss, loss_content, loss_texture, loss_color, loss_Mssim, loss_tv, discim_accuracy = \
            compute_loss.total_loss(train_samples['image_output'], prediction, swaps, args.batch_size)
        psnr = PSNR(train_samples['image_output'], prediction)
        loss_ssim = MultiScaleSSIM(train_samples['image_output'], prediction)

    # Evaluation graph (shares weights with the training graph)
    if args.eval_data_dir is not None:
        with tf.name_scope('eval'):
            with tf.variable_scope('inference', reuse=True):
                eval_prediction = unet(eval_samples['image_input'])
            eval_psnr = PSNR(eval_samples['image_output'], eval_prediction)
            eval_ssim = MultiScaleSSIM(eval_samples['image_output'],
                                       eval_prediction)

    # Optimizer
    model_vars1 = [
        v for v in tf.global_variables()
        if v.name.startswith("inference/generator")
    ]
    discriminator_vars1 = [
        v for v in tf.global_variables()
        if v.name.startswith("inference/l2_loss/discriminator")
    ]

    global_step = tf.contrib.framework.get_or_create_global_step()
    with tf.name_scope('optimizer'):
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        updates = tf.group(*update_ops, name='update_ops')
        log.info("Adding {} update ops".format(len(update_ops)))

        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        if reg_losses and args.weight_decay is not None and args.weight_decay > 0:
            print("Regularization losses:")
            for rl in reg_losses:
                print(" ", rl.name)
            opt_loss = loss + args.weight_decay * sum(reg_losses)
        else:
            print("No regularization.")
            opt_loss = loss

        with tf.control_dependencies([updates]):
            opt = tf.train.AdamOptimizer(args.learning_rate)
            # Only the generator update advances global_step. Previously both
            # updates did, so the step counter ran at twice the iteration count.
            minimize = opt.minimize(opt_loss,
                                    name='optimizer',
                                    global_step=global_step,
                                    var_list=model_vars1)
            # NOTE(review): this reuses ``opt``; in TF1 an AdamOptimizer
            # instance shares its beta-power accumulators across minimize()
            # calls, so both updates advance the bias-correction schedule.
            minimize_discrim = opt.minimize(-loss_texture,
                                            name='discriminator',
                                            var_list=discriminator_vars1)

    # Average loss and psnr for display
    with tf.name_scope("moving_averages"):
        ema = tf.train.ExponentialMovingAverage(decay=0.99)
        update_ma = ema.apply([
            loss, loss_content, loss_texture, loss_color, loss_Mssim, loss_tv,
            discim_accuracy, psnr, loss_ssim
        ])
        loss = ema.average(loss)
        loss_content = ema.average(loss_content)
        loss_texture = ema.average(loss_texture)
        loss_color = ema.average(loss_color)
        loss_Mssim = ema.average(loss_Mssim)
        loss_tv = ema.average(loss_tv)
        discim_accuracy = ema.average(discim_accuracy)
        psnr = ema.average(psnr)
        loss_ssim = ema.average(loss_ssim)

    # Training stepper operations
    train_op = tf.group(minimize, update_ma)
    train_discrim_op = tf.group(minimize_discrim, update_ma)

    # Scalars registered here are written periodically by the Supervisor.
    summaries = [
        tf.summary.scalar('loss', loss),
        tf.summary.scalar('loss_content', loss_content),
        tf.summary.scalar('loss_color', loss_color),
        tf.summary.scalar('loss_texture', loss_texture),
        tf.summary.scalar('loss_ssim', loss_Mssim),
        tf.summary.scalar('loss_tv', loss_tv),
        tf.summary.scalar('discim_accuracy', discim_accuracy),
        tf.summary.scalar('psnr', psnr),
        tf.summary.scalar('ssim', loss_ssim),
        tf.summary.scalar('learning_rate', args.learning_rate),
        tf.summary.scalar('batch_size', args.batch_size),
    ]

    log_fetches = {
        "loss_content": loss_content,
        "loss_texture": loss_texture,
        "loss_color": loss_color,
        "loss_Mssim": loss_Mssim,
        "loss_tv": loss_tv,
        "discim_accuracy": discim_accuracy,
        "step": global_step,
        "loss": loss,
        "psnr": psnr,
        "loss_ssim": loss_ssim
    }

    # Everything except the discriminator is checkpointed; the discriminator
    # is freshly initialized in every session via local_init_op.
    model_vars = [
        v for v in tf.global_variables()
        if not v.name.startswith("inference/l2_loss/discriminator")
    ]
    discriminator_vars = [
        v for v in tf.global_variables()
        if v.name.startswith("inference/l2_loss/discriminator")
    ]

    # Train config
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # Do not canibalize the entire GPU

    sv = tf.train.Supervisor(
        saver=tf.train.Saver(var_list=model_vars, max_to_keep=100),
        # variables_initializer replaces the deprecated initialize_variables alias.
        local_init_op=tf.variables_initializer(discriminator_vars),
        logdir=args.checkpoint_dir,
        save_summaries_secs=args.summary_interval,
        save_model_secs=args.checkpoint_interval)

    n_eval_batches = 3
    # Train loop
    with sv.managed_session(config=config) as sess:
        sv.loop(args.log_interval, log_hook, (sess, log_fetches))
        last_eval = time.time()
        while True:
            if sv.should_stop():
                log.info("stopping supervisor")
                break
            try:
                step, _ = sess.run([global_step, train_op])
                sess.run(train_discrim_op)
                since_eval = time.time() - last_eval

                if args.eval_data_dir is not None and since_eval > args.eval_interval:
                    log.info("Evaluating on {} images at step {}".format(
                        n_eval_batches, step))

                    p_ = 0
                    s_ = 0
                    for _ in range(n_eval_batches):
                        # Fetch both metrics in one run so they are measured on
                        # the same eval batch. Separate runs presumably pull
                        # separate batches from the pipeline.
                        p_batch, s_batch = sess.run([eval_psnr, eval_ssim])
                        p_ += p_batch
                        s_ += s_batch
                    p_ /= n_eval_batches
                    s_ /= n_eval_batches

                    sv.summary_writer.add_summary(tf.Summary(value=[
                        tf.Summary.Value(tag="psnr/eval", simple_value=p_)
                    ]),
                                                  global_step=step)

                    sv.summary_writer.add_summary(tf.Summary(value=[
                        tf.Summary.Value(tag="ssim/eval", simple_value=s_)
                    ]),
                                                  global_step=step)

                    log.info("  Evaluation PSNR = {:.2f} dB".format(p_))
                    log.info("  Evaluation SSIM = {:.4f} ".format(s_))

                    last_eval = time.time()

            except tf.errors.AbortedError:
                log.error("Aborted")
                break
            except KeyboardInterrupt:
                break
        chkpt_path = os.path.join(args.checkpoint_dir, 'on_stop.ckpt')
        log.info("Training complete, saving chkpt {}".format(chkpt_path))
        sv.saver.save(sess, chkpt_path)
        sv.request_stop()