Ejemplo n.º 1
0
def train(upscaling_factor,
          residual_blocks,
          feature_size,
          path_prediction,
          checkpoint_dir,
          img_width,
          img_height,
          img_depth,
          subpixel_NN,
          nn,
          restore,
          batch_size=1,
          div_patches=4,
          epochs=10):
    """Train the 3D SRGAN generator/discriminator pair and save checkpoints.

    Builds a TF1 graph (generator, discriminator, losses, Adam optimizers),
    optionally restores the latest checkpoint found in ``restore`` and then
    trains for ``epochs`` epochs. Each epoch runs
    ``ceil(0.8 * len(subject_list) / batch_size)`` dataset iterations; the
    patches of every iteration are split into ``div_patches`` chunks that
    are fed to the networks one chunk at a time. A checkpoint is saved at
    the end of every epoch.

    Args:
        upscaling_factor: Factor by which target patches are downsampled to
            build the generator input; forwarded to ``generator``.
        residual_blocks: Number of generator residual blocks (``nb``).
        feature_size: Forwarded to ``generator``.
        path_prediction: Directory where sample NIfTI volumes are written.
        checkpoint_dir: ``save_path`` prefix given to ``tf.train.Saver.save``.
        img_width, img_height, img_depth: Spatial size of a target patch.
        subpixel_NN, nn: Upsampling options forwarded to ``generator``.
        restore: Directory holding a checkpoint to resume from, or ``None``.
        batch_size: Batch size passed to ``Train_dataset``.
        div_patches: Number of chunks each iteration's patches are split
            into; ``batch_size * num_patches`` is assumed divisible by it.
        epochs: Number of training epochs.
    """
    traindataset = Train_dataset(batch_size)
    iterations_train = math.ceil(
        (len(traindataset.subject_list) * 0.8) / batch_size)
    print(iterations_train)
    num_patches = traindataset.num_patches
    # Number of patches fed to the networks in one session.run call.
    chunk = int((batch_size * num_patches) / div_patches)

    # ##========================== DEFINE MODEL ============================##
    t_input_gen = tf.placeholder('float32', [chunk, None, None, None, 1],
                                 name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder(
        'float32', [chunk, img_width, img_height, img_depth, 1],
        name='t_target_image')
    # Fed during the generator update but not used by any loss below.
    t_input_mask = tf.placeholder(
        'float32', [chunk, img_width, img_height, img_depth, 1],
        name='t_image_input_mask')

    net_gen = generator(input_gen=t_input_gen,
                        kernel=3,
                        nb=residual_blocks,
                        upscaling_factor=upscaling_factor,
                        img_height=img_height,
                        img_width=img_width,
                        img_depth=img_depth,
                        subpixel_NN=subpixel_NN,
                        nn=nn,
                        feature_size=feature_size,
                        is_train=True,
                        reuse=False)
    net_d, disc_out_real = discriminator(input_disc=t_target_image,
                                         kernel=3,
                                         is_train=True,
                                         reuse=False)
    _, disc_out_fake = discriminator(input_disc=net_gen.outputs,
                                     kernel=3,
                                     is_train=True,
                                     reuse=True)

    # Generator with shared weights (reuse=True) used to produce sample
    # predictions during evaluation.
    # NOTE(review): built with is_train=True; confirm whether inference
    # should use is_train=False.
    gen_test = generator(t_input_gen,
                         kernel=3,
                         nb=residual_blocks,
                         upscaling_factor=upscaling_factor,
                         img_height=img_height,
                         img_width=img_width,
                         img_depth=img_depth,
                         subpixel_NN=subpixel_NN,
                         nn=nn,
                         feature_size=feature_size,
                         is_train=True,
                         reuse=True)

    # ###========================== DEFINE TRAIN OPS ==========================###

    # NOTE(review): this random draw happens once, while the graph is built,
    # so a whole run either always or never uses the flipped labels. The
    # "noisy labels" intent suggests a per-step flip (e.g. with tf.cond) —
    # confirm before relying on it.
    if np.random.uniform() > 0.1:
        # give correct classifications
        y_gan_real = tf.ones_like(disc_out_real)
        y_gan_fake = tf.zeros_like(disc_out_real)
    else:
        # give wrong classifications (noisy labels)
        y_gan_real = tf.zeros_like(disc_out_real)
        y_gan_fake = tf.ones_like(disc_out_real)

    # Least-squares discriminator loss against (smoothed) labels.
    d_loss_real = tf.reduce_mean(tf.square(disc_out_real -
                                           smooth_gan_labels(y_gan_real)),
                                 name='d_loss_real')
    d_loss_fake = tf.reduce_mean(tf.square(disc_out_fake -
                                           smooth_gan_labels(y_gan_fake)),
                                 name='d_loss_fake')
    d_loss = d_loss_real + d_loss_fake

    mse_loss = tf.reduce_sum(tf.square(net_gen.outputs - t_target_image),
                             axis=[0, 1, 2, 3, 4],
                             name='g_loss_mse')

    # Gradient-difference loss: compare absolute finite differences along
    # each of the three spatial axes.
    dx_real = t_target_image[:, 1:, :, :, :] - t_target_image[:, :-1, :, :, :]
    dy_real = t_target_image[:, :, 1:, :, :] - t_target_image[:, :, :-1, :, :]
    dz_real = t_target_image[:, :, :, 1:, :] - t_target_image[:, :, :, :-1, :]
    dx_fake = net_gen.outputs[:, 1:, :, :, :] - net_gen.outputs[:, :-1, :, :, :]
    dy_fake = net_gen.outputs[:, :, 1:, :, :] - net_gen.outputs[:, :, :-1, :, :]
    dz_fake = net_gen.outputs[:, :, :, 1:, :] - net_gen.outputs[:, :, :, :-1, :]

    gd_loss = tf.reduce_sum(tf.square(tf.abs(dx_real) - tf.abs(dx_fake))) + \
              tf.reduce_sum(tf.square(tf.abs(dy_real) - tf.abs(dy_fake))) + \
              tf.reduce_sum(tf.square(tf.abs(dz_real) - tf.abs(dz_fake)))

    # Adversarial term weighted by 10e-2 (== 0.1).
    g_gan_loss = 10e-2 * tf.reduce_mean(
        tf.square(disc_out_fake -
                  smooth_gan_labels(tf.ones_like(disc_out_real))),
        name='g_loss_gan')

    g_loss = mse_loss + g_gan_loss + gd_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(1e-4, trainable=False)
    global_step = tf.Variable(0, trainable=False)
    decay_rate = 0.5
    decay_steps = 4920  # every 2 epochs (more or less)
    learning_rate = tf.train.inverse_time_decay(lr_v,
                                                global_step=global_step,
                                                decay_rate=decay_rate,
                                                decay_steps=decay_steps)

    # Optimizers. The generator update increments global_step once per
    # training step; without it global_step stays 0 and the inverse-time
    # decay above never takes effect.
    g_optim = tf.train.AdamOptimizer(learning_rate).minimize(
        g_loss, var_list=g_vars, global_step=global_step)
    d_optim = tf.train.AdamOptimizer(learning_rate).minimize(
        d_loss, var_list=d_vars)

    session = tf.Session()
    try:
        tl.layers.initialize_global_variables(session)

        step = 0
        saver = tf.train.Saver()

        if restore is not None:
            saver.restore(session, tf.train.latest_checkpoint(restore))
        # Epoch numbering starts at 0 whether or not a checkpoint was restored.
        val_restore = 0

        # Chunk index at which samples are written and metrics computed
        # (falls back to the only chunk when div_patches == 1).
        eval_k = 1 if div_patches > 1 else 0

        array_psnr = []
        array_ssim = []

        for j in range(val_restore, epochs + val_restore):
            for i in range(iterations_train):
                # ====================== LOAD DATA =========================== #
                xt_total = traindataset.patches_true(i)
                xm_total = traindataset.mask(i)
                for k in range(div_patches):
                    print('{}'.format(k))
                    xt = xt_total[k * chunk:(k + 1) * chunk]
                    xm = xm_total[k * chunk:(k + 1) * chunk]

                    # NORMALIZING: scale each patch so its maximum becomes 1
                    # (0 maps to -1). Each patch's factor is kept so that the
                    # patch can be denormalized with its own factor later.
                    normfactors = []
                    for t in range(xt.shape[0]):
                        normfactor = (np.amax(xt[t])) / 2
                        if normfactor != 0:
                            xt[t] = ((xt[t] - normfactor) / normfactor)
                        normfactors.append(normfactor)

                    # Blur only the three spatial axes (never the batch or
                    # channel axis, which would mix different patches), then
                    # nearest-neighbour downsample by upscaling_factor.
                    x_generator = gaussian_filter(xt, sigma=(0, 1, 1, 1, 0))
                    scale = 1 / upscaling_factor
                    xgenin = zoom(x_generator, [1, scale, scale, scale, 1],
                                  prefilter=False,
                                  order=0)

                    # ========================= train SRGAN ========================= #
                    # update D
                    errd, _ = session.run([d_loss, d_optim], {
                        t_target_image: xt,
                        t_input_gen: xgenin
                    })
                    # update G
                    errg, errmse, errgan, errgd, _ = session.run(
                        [g_loss, mse_loss, g_gan_loss, gd_loss, g_optim], {
                            t_input_gen: xgenin,
                            t_target_image: xt,
                            t_input_mask: xm
                        })
                    print(
                        "Epoch [%2d/%2d] [%4d/%4d] [%4d/%4d]: d_loss: %.8f g_loss: %.8f (mse: %.6f gdl: %.6f adv: %.6f)"
                        % (j, epochs + val_restore, i, iterations_train, k,
                           div_patches - 1, errd, errg, errmse, errgd, errgan))

                    # ========================= evaluate & save model ========================= #

                    if k == eval_k and i % 20 == 0:
                        # Factor used to normalize patch 0 of this chunk.
                        norm0 = normfactors[0]

                        # Recomputed at every evaluation so the metrics below
                        # always compare against the current target patch.
                        x_true_img = xt[0]
                        if norm0 != 0:
                            x_true_img = ((x_true_img + 1) * norm0)  # denormalize

                        if j - val_restore == 0:
                            img_true = nib.Nifti1Image(x_true_img, np.eye(4))
                            img_true.to_filename(
                                os.path.join(path_prediction,
                                             str(j) + str(i) + 'true.nii.gz'))

                            x_gen_img = xgenin[0]
                            if norm0 != 0:
                                x_gen_img = ((x_gen_img + 1) * norm0)  # denormalize
                            img_gen = nib.Nifti1Image(x_gen_img, np.eye(4))
                            img_gen.to_filename(
                                os.path.join(path_prediction,
                                             str(j) + str(i) + 'gen.nii.gz'))

                        x_pred = session.run(gen_test.outputs,
                                             {t_input_gen: xgenin})
                        x_pred_img = x_pred[0]
                        if norm0 != 0:
                            x_pred_img = ((x_pred_img + 1) * norm0)  # denormalize
                        img_pred = nib.Nifti1Image(x_pred_img, np.eye(4))
                        img_pred.to_filename(
                            os.path.join(path_prediction,
                                         str(j) + str(i) + '.nii.gz'))

                        # Metrics on the masked volumes, with the data range
                        # spanning both the prediction and the target.
                        val_max = max(np.amax(x_pred_img), np.amax(x_true_img))
                        val_min = min(np.amin(x_pred_img), np.amin(x_true_img))
                        masked_true = np.multiply(x_true_img, xm[0])
                        masked_pred = np.multiply(x_pred_img, xm[0])
                        val_psnr = psnr(masked_true,
                                        masked_pred,
                                        data_range=val_max - val_min)
                        val_ssim = ssim(masked_true,
                                        masked_pred,
                                        data_range=val_max - val_min,
                                        multichannel=True)
                        array_psnr.append(val_psnr)
                        array_ssim.append(val_ssim)
                        print("PSNR: %.6f SSIM: %.6f" % (val_psnr, val_ssim))

            saver.save(sess=session, save_path=checkpoint_dir, global_step=step)
            print("Saved step: [%2d]" % step)
            step = step + 1
    finally:
        session.close()
Ejemplo n.º 2
0
def train(upscaling_factor,
          img_width,
          img_height,
          img_depth,
          batch_size=1,
          div_patches=4,
          epochs=10,
          residual_blocks=None,
          path_prediction=None,
          checkpoint_dir=None):
    """Train the 3D SRGAN generator/discriminator pair.

    Builds a TF1 graph (generator, discriminator, losses, Adam optimizers)
    and trains for ``epochs`` epochs. Each epoch runs
    ``ceil(0.8 * len(subject_list) / batch_size)`` dataset iterations; the
    patches of every iteration are split into ``div_patches`` chunks that
    are fed to the networks one chunk at a time. Every 20 iterations, sample
    volumes are written to ``path_prediction`` and a checkpoint is saved.

    Args:
        upscaling_factor: Factor by which target patches are downsampled to
            build the generator input; the same value is given to
            ``generator`` so the output matches the target size.
        img_width, img_height, img_depth: Spatial size of a target patch.
        batch_size: Batch size passed to ``Train_dataset``.
        div_patches: Number of chunks each iteration's patches are split
            into; ``batch_size * num_patches`` is assumed divisible by it.
        epochs: Number of training epochs.
        residual_blocks: Number of generator residual blocks; defaults to
            ``args.residual_blocks``.
        path_prediction: Directory for sample NIfTI volumes; defaults to
            ``args.path_prediction``.
        checkpoint_dir: ``save_path`` prefix for ``tf.train.Saver.save``;
            defaults to ``args.checkpoint_dir``.
    """
    # Fall back to the module-level ``args`` for settings not passed in.
    if residual_blocks is None:
        residual_blocks = args.residual_blocks
    if path_prediction is None:
        path_prediction = args.path_prediction
    if checkpoint_dir is None:
        checkpoint_dir = args.checkpoint_dir

    traindataset = Train_dataset(batch_size)
    iterations_train = math.ceil(
        (len(traindataset.subject_list) * 0.8) / batch_size)
    num_patches = traindataset.num_patches
    # Number of patches fed to the networks in one session.run call.
    chunk = int((batch_size * num_patches) / div_patches)

    ###========================== DEFINE MODEL ============================###
    t_input_gen = tf.placeholder('float32', [chunk, None, None, None, 1],
                                 name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder(
        'float32', [chunk, img_width, img_height, img_depth, 1],
        name='t_target_image')
    # Fed during the generator update but not used by any loss below.
    t_input_mask = tf.placeholder(
        'float32', [chunk, img_width, img_height, img_depth, 1],
        name='t_image_input_mask')

    net_gen = generator(input_gen=t_input_gen,
                        kernel=3,
                        nb=residual_blocks,
                        upscaling_factor=upscaling_factor,
                        img_height=img_height,
                        img_width=img_width,
                        img_depth=img_depth,
                        is_train=True,
                        reuse=False)
    net_d, disc_out_real = discriminator(input_disc=t_target_image,
                                         kernel=3,
                                         is_train=True,
                                         reuse=False)
    _, disc_out_fake = discriminator(input_disc=net_gen.outputs,
                                     kernel=3,
                                     is_train=True,
                                     reuse=True)

    # Inference-mode generator sharing the training generator's weights.
    gen_test = generator(t_input_gen,
                         kernel=3,
                         nb=residual_blocks,
                         upscaling_factor=upscaling_factor,
                         img_height=img_height,
                         img_width=img_width,
                         img_depth=img_depth,
                         is_train=False,
                         reuse=True)

    # ###========================== DEFINE TRAIN OPS ==========================###

    # use disc_out_real in both cases because shape will be equal in disc_out_real and disc_out_fake
    # if not, problems for not specifying input shape for generator

    # NOTE(review): this random draw happens once, while the graph is built,
    # so a whole run either always or never uses the flipped labels. The
    # "noisy labels" intent suggests a per-step flip (e.g. with tf.cond) —
    # confirm before relying on it.
    if np.random.uniform() > 0.1:
        # give correct classifications
        y_gan_real = tf.ones_like(disc_out_real)
        y_gan_fake = tf.zeros_like(disc_out_real)
    else:
        # give wrong classifications (noisy labels)
        y_gan_real = tf.zeros_like(disc_out_real)
        y_gan_fake = tf.ones_like(disc_out_real)

    # Least-squares discriminator loss against (smoothed) labels.
    d_loss_real = tf.reduce_mean(tf.square(disc_out_real -
                                           smooth_gan_labels(y_gan_real)),
                                 name='d_loss_real')
    d_loss_fake = tf.reduce_mean(tf.square(disc_out_fake -
                                           smooth_gan_labels(y_gan_fake)),
                                 name='d_loss_fake')
    d_loss = d_loss_real + d_loss_fake

    # A graph placeholder has no value at graph-construction time, so the
    # MSE term is always part of the generator loss.
    mse_loss = tf.reduce_sum(tf.square(net_gen.outputs - t_target_image),
                             axis=[0, 1, 2, 3, 4],
                             name='g_loss_mse')

    # Gradient-difference loss: compare absolute finite differences along
    # each of the three spatial axes.
    dx_real = t_target_image[:, 1:, :, :, :] - t_target_image[:, :-1, :, :, :]
    dy_real = t_target_image[:, :, 1:, :, :] - t_target_image[:, :, :-1, :, :]
    dz_real = t_target_image[:, :, :, 1:, :] - t_target_image[:, :, :, :-1, :]
    dx_fake = net_gen.outputs[:, 1:, :, :, :] - net_gen.outputs[:, :-1, :, :, :]
    dy_fake = net_gen.outputs[:, :, 1:, :, :] - net_gen.outputs[:, :, :-1, :, :]
    dz_fake = net_gen.outputs[:, :, :, 1:, :] - net_gen.outputs[:, :, :, :-1, :]

    gd_loss = tf.reduce_sum(tf.square(tf.abs(dx_real) - tf.abs(dx_fake))) + \
              tf.reduce_sum(tf.square(tf.abs(dy_real) - tf.abs(dy_fake))) + \
              tf.reduce_sum(tf.square(tf.abs(dz_real) - tf.abs(dz_fake)))

    # Adversarial term weighted by 10e-2 (== 0.1).
    g_gan_loss = 10e-2 * tf.reduce_mean(
        tf.square(disc_out_fake -
                  smooth_gan_labels(tf.ones_like(disc_out_real))),
        name='g_loss_gan')

    g_loss = mse_loss + g_gan_loss + gd_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(1e-4, trainable=False)
    global_step = tf.Variable(0, trainable=False)
    decay_rate = 0.5
    decay_steps = 4920  # every 2 epochs (more or less)
    learning_rate = tf.train.inverse_time_decay(lr_v,
                                                global_step=global_step,
                                                decay_rate=decay_rate,
                                                decay_steps=decay_steps)

    # Optimizers. The generator update increments global_step once per
    # training step; without it global_step stays 0 and the inverse-time
    # decay above never takes effect.
    g_optim = tf.train.AdamOptimizer(learning_rate).minimize(
        g_loss, var_list=g_vars, global_step=global_step)
    d_optim = tf.train.AdamOptimizer(learning_rate).minimize(
        d_loss, var_list=d_vars)

    session = tf.Session()
    try:
        tl.layers.initialize_global_variables(session)

        # One Saver for the whole run; building a new one per save would keep
        # adding save/restore ops to the graph.
        saver = tf.train.Saver()
        step = 0

        # Chunk index at which samples are written and checkpoints saved
        # (falls back to the only chunk when div_patches == 1).
        eval_k = 1 if div_patches > 1 else 0

        for j in range(epochs):
            for i in range(iterations_train):
                ###====================== LOAD DATA ===========================###
                xt_total = traindataset.patches_true(i)
                xm_total = traindataset.mask(i)
                for k in range(div_patches):
                    print('{}'.format(k))
                    xt = xt_total[k * chunk:(k + 1) * chunk]
                    xm = xm_total[k * chunk:(k + 1) * chunk]

                    # NORMALIZING: scale each patch so its maximum becomes 1
                    # (0 maps to -1). Each patch's factor is kept so that the
                    # patch can be denormalized with its own factor later.
                    normfactors = []
                    for t in range(xt.shape[0]):
                        normfactor = (np.amax(xt[t])) / 2
                        if normfactor != 0:
                            xt[t] = ((xt[t] - normfactor) / normfactor)
                        normfactors.append(normfactor)

                    # RESIZING, don't normalize, XT already normalized
                    scale = 1 / upscaling_factor
                    xgenin = zoom(xt, [1, scale, scale, scale, 1])

                    ###========================= train SRGAN =========================###
                    # update D
                    errd, _ = session.run([d_loss, d_optim], {
                        t_target_image: xt,
                        t_input_gen: xgenin
                    })
                    # update G
                    errg, errmse, errgan, errgd, _ = session.run(
                        [g_loss, mse_loss, g_gan_loss, gd_loss, g_optim], {
                            t_input_gen: xgenin,
                            t_target_image: xt,
                            t_input_mask: xm
                        })
                    print(
                        "Epoch [%2d/%2d] [%4d/%4d] [%4d/%4d]: d_loss: %.8f g_loss: %.8f (mse: %.6f gdl: %.6f adv: %.6f)"
                        % (j, epochs, i, iterations_train, k, div_patches - 1,
                           errd, errg, errmse, errgd, errgan))

                    ###========================= evaluate & save model =========================###

                    if k == eval_k and i % 20 == 0:
                        # Factor used to normalize patch 0 of this chunk.
                        norm0 = normfactors[0]

                        if j == 0:
                            x_true_img = xt[0]
                            if norm0 != 0:
                                x_true_img = ((x_true_img + 1) * norm0)  # denormalize
                            img_true = nib.Nifti1Image(x_true_img, np.eye(4))
                            img_true.to_filename(
                                os.path.join(path_prediction,
                                             str(j) + str(i) + 'true.nii.gz'))

                            x_gen_img = xgenin[0]
                            if norm0 != 0:
                                x_gen_img = ((x_gen_img + 1) * norm0)  # denormalize
                            img_gen = nib.Nifti1Image(x_gen_img, np.eye(4))
                            img_gen.to_filename(
                                os.path.join(path_prediction,
                                             str(j) + str(i) + 'gen.nii.gz'))

                        x_pred = session.run(gen_test.outputs,
                                             {t_input_gen: xgenin})
                        x_pred_img = x_pred[0]
                        if norm0 != 0:
                            x_pred_img = ((x_pred_img + 1) * norm0)  # denormalize
                        img_pred = nib.Nifti1Image(x_pred_img, np.eye(4))
                        img_pred.to_filename(
                            os.path.join(path_prediction,
                                         str(j) + str(i) + '.nii.gz'))

                        saver.save(sess=session,
                                   save_path=checkpoint_dir,
                                   global_step=step)
                        print("Saved step: [%2d]" % step)
                        step = step + 1
    finally:
        session.close()
Ejemplo n.º 3
0
def train(upscaling_factor,
          residual_blocks,
          feature_size,
          path_prediction,
          checkpoint_dir,
          img_width,
          img_height,
          img_depth,
          subpixel_NN,
          nn,
          batch_size=1,
          div_patches=4,
          epochs=12,
          mse_epochs=4):
    """Train the 3D SRGAN in two phases, the second using autoencoder features.

    During the first ``mse_epochs`` epochs the generator minimizes
    ``mse + gradient-difference + 0.1 * adversarial`` with Adam. In later
    epochs it minimizes ``g_loss_ae`` with Adagrad (5e-5). ``g_loss_ae`` is
    0.1 times the squared distance between autoencoder latent codes of the
    prediction and of the target, plus the adversarial term. The autoencoder
    lives in a separate session and is restored from
    ``AUTOENCODER_CHECPOINTS``. The discriminator is updated with Adam at
    every step. A checkpoint is saved at the end of every epoch.

    Args:
        upscaling_factor: Factor by which target patches are downsampled to
            build the generator input; forwarded to ``generator``.
        residual_blocks: Number of generator residual blocks (``nb``).
        feature_size: Forwarded to ``generator``.
        path_prediction: Directory where sample NIfTI volumes are written.
        checkpoint_dir: ``save_path`` prefix for ``tf.train.Saver.save``.
        img_width, img_height, img_depth: Spatial size of a target patch.
        subpixel_NN, nn: Upsampling options forwarded to ``generator``.
        batch_size: Batch size passed to ``Train_dataset``.
        div_patches: Number of chunks each iteration's patches are split
            into; ``batch_size * num_patches`` is assumed divisible by it.
        epochs: Number of training epochs.
        mse_epochs: Number of initial epochs that use the MSE-based
            generator loss before switching to the autoencoder loss.
    """
    traindataset = Train_dataset(batch_size)
    iterations_train = math.ceil(
        (len(traindataset.subject_list) * 0.8) / batch_size)
    num_patches = traindataset.num_patches
    # Number of patches fed to the networks in one session.run call.
    chunk = int((batch_size * num_patches) / div_patches)

    # ##========================== DEFINE MODEL ============================##
    t_input_gen = tf.placeholder('float32', [chunk, None, None, None, 1],
                                 name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder(
        'float32', [chunk, img_width, img_height, img_depth, 1],
        name='t_target_image')
    # Fed during generator updates but not used by any loss below.
    t_input_mask = tf.placeholder(
        'float32', [chunk, img_width, img_height, img_depth, 1],
        name='t_image_input_mask')
    # Autoencoder latent codes, computed in session_ae and fed back in.
    ae_real_z = tf.placeholder('float32', [chunk, 16, 16, 12, 30],
                               name='ae_real_z')
    ae_fake_z = tf.placeholder('float32', [chunk, 16, 16, 12, 30],
                               name='ae_fake_z')

    net_gen = generator(input_gen=t_input_gen,
                        kernel=3,
                        nb=residual_blocks,
                        upscaling_factor=upscaling_factor,
                        img_height=img_height,
                        img_width=img_width,
                        img_depth=img_depth,
                        subpixel_NN=subpixel_NN,
                        nn=nn,
                        feature_size=feature_size,
                        is_train=True,
                        reuse=False)
    net_d, disc_out_real = discriminator(input_disc=t_target_image,
                                         kernel=3,
                                         is_train=True,
                                         reuse=False)
    _, disc_out_fake = discriminator(input_disc=net_gen.outputs,
                                     kernel=3,
                                     is_train=True,
                                     reuse=True)

    # Generator with shared weights (reuse=True) used for predictions.
    # NOTE(review): built with is_train=True; confirm whether inference
    # should use is_train=False.
    gen_test = generator(t_input_gen,
                         kernel=3,
                         nb=residual_blocks,
                         upscaling_factor=upscaling_factor,
                         img_height=img_height,
                         img_width=img_width,
                         img_depth=img_depth,
                         subpixel_NN=subpixel_NN,
                         nn=nn,
                         feature_size=feature_size,
                         is_train=True,
                         reuse=True)

    # autoencoder
    ae_out = autoencoder(x_auto=t_target_image)

    # ###========================== DEFINE TRAIN OPS ==========================###

    # use disc_out_real in both cases because shape will be equal in disc_out_real and disc_out_fake
    # if not, problems for not specifying input shape for generator

    # NOTE(review): this random draw happens once, while the graph is built,
    # so a whole run either always or never uses the flipped labels. The
    # "noisy labels" intent suggests a per-step flip (e.g. with tf.cond) —
    # confirm before relying on it.
    if np.random.uniform() > 0.1:
        # give correct classifications
        y_gan_real = tf.ones_like(disc_out_real)
        y_gan_fake = tf.zeros_like(disc_out_real)
    else:
        # give wrong classifications (noisy labels)
        y_gan_real = tf.zeros_like(disc_out_real)
        y_gan_fake = tf.ones_like(disc_out_real)

    # Least-squares discriminator loss against (smoothed) labels.
    d_loss_real = tf.reduce_mean(tf.square(disc_out_real -
                                           smooth_gan_labels(y_gan_real)),
                                 name='d_loss_real')
    d_loss_fake = tf.reduce_mean(tf.square(disc_out_fake -
                                           smooth_gan_labels(y_gan_fake)),
                                 name='d_loss_fake')
    d_loss = d_loss_real + d_loss_fake

    # Adversarial term weighted by 10e-2 (== 0.1).
    g_gan_loss = 10e-2 * tf.reduce_mean(
        tf.square(disc_out_fake -
                  smooth_gan_labels(tf.ones_like(disc_out_real))),
        name='g_loss_gan')
    mse_loss = tf.reduce_sum(tf.square(net_gen.outputs - t_target_image),
                             axis=[0, 1, 2, 3, 4],
                             name='g_loss_mse')

    # Gradient-difference loss: compare absolute finite differences along
    # each of the three spatial axes.
    dx_real = t_target_image[:, 1:, :, :, :] - t_target_image[:, :-1, :, :, :]
    dy_real = t_target_image[:, :, 1:, :, :] - t_target_image[:, :, :-1, :, :]
    dz_real = t_target_image[:, :, :, 1:, :] - t_target_image[:, :, :, :-1, :]
    dx_fake = net_gen.outputs[:, 1:, :, :, :] - net_gen.outputs[:, :-1, :, :, :]
    dy_fake = net_gen.outputs[:, :, 1:, :, :] - net_gen.outputs[:, :, :-1, :, :]
    dz_fake = net_gen.outputs[:, :, :, 1:, :] - net_gen.outputs[:, :, :, :-1, :]

    gd_loss = tf.reduce_sum(tf.square(tf.abs(dx_real) - tf.abs(dx_fake))) + \
              tf.reduce_sum(tf.square(tf.abs(dy_real) - tf.abs(dy_fake))) + \
              tf.reduce_sum(tf.square(tf.abs(dz_real) - tf.abs(dz_fake)))

    g_loss_mse = mse_loss + g_gan_loss + gd_loss

    # NOTE(review): ae_real_z / ae_fake_z are placeholders fed with numpy
    # values, so the first term has no gradient w.r.t. the generator; only
    # g_gan_loss actually trains G in this phase. Making the feature loss
    # effective requires building the autoencoder on net_gen.outputs in the
    # training graph — confirm intent.
    g_loss_ae = 10e-2 * tf.reduce_sum(tf.square(ae_real_z - ae_fake_z),
                                      axis=[0, 1, 2, 3, 4],
                                      name='ae_loss') + g_gan_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v_g = tf.Variable(1e-4, trainable=False)
        lr_v_d = tf.Variable(1e-4, trainable=False)
    global_step = tf.Variable(0, trainable=False)
    decay_rate = 0.5
    decay_steps = 4920  # every 2 epochs (more or less)
    learning_rate_g = tf.train.inverse_time_decay(lr_v_g,
                                                  global_step=global_step,
                                                  decay_rate=decay_rate,
                                                  decay_steps=decay_steps)
    learning_rate_d = tf.train.inverse_time_decay(lr_v_d,
                                                  global_step=global_step,
                                                  decay_rate=decay_rate,
                                                  decay_steps=decay_steps)

    # Optimizers. Exactly one generator optimizer runs per training step and
    # it increments global_step, so the inverse-time decays above advance
    # (without it global_step stays 0 and the rates never decay).
    g_optim_mse = tf.train.AdamOptimizer(learning_rate_g).minimize(
        g_loss_mse, var_list=g_vars, global_step=global_step)
    # Adagrad tried: no better than without the AE, but it converges (5e-5, 1e-5).
    # Adam/RMSprop: does not converge (5e-5).
    # Adadelta tried: no better than without the AE, converges (5e-5; 1e-4 is
    # worse); works better without lr decay.
    g_optim_ae = tf.train.AdagradOptimizer(5e-5).minimize(
        g_loss_ae, var_list=g_vars, global_step=global_step)
    d_optim = tf.train.AdamOptimizer(learning_rate_d).minimize(
        d_loss, var_list=d_vars)

    session = tf.Session()
    session_ae = tf.Session()
    try:
        tl.layers.initialize_global_variables(session)
        session_ae.run(tf.global_variables_initializer())

        # Variables restored into session_ae from the autoencoder checkpoint:
        # everything except names containing 'SRGAN', '9', 'learning_rate'
        # or 'beta'.
        var_list = [
            v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
            if 'SRGAN' not in v.name and '9' not in v.name
            and 'learning_rate' not in v.name and 'beta' not in v.name
        ]
        saver_ae = tf.train.Saver(var_list)
        saver_ae.restore(session_ae,
                         tf.train.latest_checkpoint(AUTOENCODER_CHECPOINTS))
        initialize_uninitialized_vars(session_ae)

        # One Saver for the whole run; building a new one per save would keep
        # adding save/restore ops to the graph.
        saver = tf.train.Saver()
        step = 0

        # Chunk index at which samples are written (falls back to the only
        # chunk when div_patches == 1).
        eval_k = 1 if div_patches > 1 else 0

        for j in range(epochs):
            for i in range(iterations_train):
                ###====================== LOAD DATA ===========================###
                xt_total = traindataset.patches_true(i)
                xm_total = traindataset.mask(i)
                for k in range(div_patches):
                    print('{}'.format(k))
                    xt = xt_total[k * chunk:(k + 1) * chunk]
                    xm = xm_total[k * chunk:(k + 1) * chunk]

                    # NORMALIZING: scale each patch so its maximum becomes 1
                    # (0 maps to -1). Each patch's factor is kept so that the
                    # patch can be denormalized with its own factor later.
                    normfactors = []
                    for t in range(xt.shape[0]):
                        normfactor = (np.amax(xt[t])) / 2
                        if normfactor != 0:
                            xt[t] = ((xt[t] - normfactor) / normfactor)
                        normfactors.append(normfactor)

                    # RESIZING, don't normalize, XT already normalized.
                    # The blur is applied only to the three spatial axes,
                    # never to the batch or channel axis (which would mix
                    # different patches).
                    scale = 1 / upscaling_factor
                    x_generator = zoom(xt, [1, scale, scale, scale, 1])
                    xgenin = gaussian_filter(x_generator, sigma=(0, 1, 1, 1, 0))

                    ###========================= train SRGAN =========================###
                    # update D
                    errd, _ = session.run([d_loss, d_optim], {
                        t_target_image: xt,
                        t_input_gen: xgenin
                    })

                    if j < mse_epochs:
                        errg, errmse, errgan, errgd, _ = session.run(
                            [
                                g_loss_mse, mse_loss, g_gan_loss, gd_loss,
                                g_optim_mse
                            ], {
                                t_input_gen: xgenin,
                                t_target_image: xt,
                                t_input_mask: xm
                            })
                        print(
                            "Epoch [%2d/%2d] [%4d/%4d] [%4d/%4d]: d_loss: %.8f g_loss: %.8f (mse: %.6f gdl: %.6f adv: %.6f)"
                            % (j, epochs, i, iterations_train, k, div_patches - 1,
                               errd, errg, errmse, errgd, errgan))
                    else:
                        # loss autoencoder: latent codes of prediction and target
                        x_pred_ae = session.run(gen_test.outputs,
                                                {t_input_gen: xgenin})
                        ae_fake_z_val = session_ae.run(ae_out['z'],
                                                       {t_target_image: x_pred_ae})
                        ae_real_z_val = session_ae.run(ae_out['z'],
                                                       {t_target_image: xt})
                        # update G
                        errg, errgan, _ = session.run(
                            [g_loss_ae, g_gan_loss, g_optim_ae], {
                                t_input_gen: xgenin,
                                t_target_image: xt,
                                t_input_mask: xm,
                                ae_real_z: ae_real_z_val,
                                ae_fake_z: ae_fake_z_val
                            })

                        print(
                            "Epoch [%2d/%2d] [%4d/%4d] [%4d/%4d]: d_loss: %.8f g_loss: %.8f (adv: %.6f)"
                            % (j, epochs, i, iterations_train, k, div_patches - 1,
                               errd, errg, errgan))

                    ###========================= evaluate & save model =========================###

                    if k == eval_k and i % 20 == 0:
                        # Factor used to normalize patch 0 of this chunk.
                        norm0 = normfactors[0]

                        if j == 0:
                            x_true_img = xt[0]
                            if norm0 != 0:
                                x_true_img = ((x_true_img + 1) * norm0)  # denormalize
                            img_true = nib.Nifti1Image(x_true_img, np.eye(4))
                            img_true.to_filename(
                                os.path.join(path_prediction,
                                             str(j) + str(i) + 'true.nii.gz'))

                            x_gen_img = xgenin[0]
                            if norm0 != 0:
                                x_gen_img = ((x_gen_img + 1) * norm0)  # denormalize
                            img_gen = nib.Nifti1Image(x_gen_img, np.eye(4))
                            img_gen.to_filename(
                                os.path.join(path_prediction,
                                             str(j) + str(i) + 'gen.nii.gz'))

                        x_pred = session.run(gen_test.outputs,
                                             {t_input_gen: xgenin})
                        x_pred_img = x_pred[0]
                        if norm0 != 0:
                            x_pred_img = ((x_pred_img + 1) * norm0)  # denormalize
                        img_pred = nib.Nifti1Image(x_pred_img, np.eye(4))
                        img_pred.to_filename(
                            os.path.join(path_prediction,
                                         str(j) + str(i) + '.nii.gz'))

            saver.save(sess=session, save_path=checkpoint_dir, global_step=step)
            print("Saved step: [%2d]" % step)
            step = step + 1
    finally:
        session_ae.close()
        session.close()