Example #1
class GANUnetModel148():
    def __init__(self):

        K.set_image_data_format('channels_last')  # TF-style (batch, x, y, z, channels); the deprecated set_image_dim_ordering('tf') is redundant with this
        self.DEBUG = 1

        self.crop_size_g = (148, 148, 148)
        self.crop_size_d = (60, 60, 60)

        self.channels = 1

        self.input_shape_g = self.crop_size_g + (self.channels, )
        self.input_shape_d = self.crop_size_d + (self.channels, )

        self.output_shape_g = (60, 60, 60) + (3, )  # phi has three channels: one displacement per X, Y, and Z axis
        self.output_shape_d = (15, 15, 15) + (self.channels, )
        #self.output_shape_d_v2 = (5, 5, 5) + (self.channels,)

        self.batch_sz = 1  # for testing locally to avoid memory allocation

        # Number of filters in the first layer of G and D
        self.gf = 32
        self.df = 32

        # in the paper the learning rate is 0.001 and it decays by 0.5 after
        # 50K iterations; Keras's `decay` argument is applied per update, so
        # decay=0.05 only approximates that schedule
        optimizerD = Adam(0.001, decay=0.05)
        #optimizerD = SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True)
        self.decay = 0.5
        self.iterations_decay = 50
        self.learning_rate = 0.001
        optimizerG = Adam(0.001, decay=0.05)  # in the paper the LR decays by 0.5 after 50K iterations
        #optimizerG = SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True)
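        # NOTE: self.decay, self.iterations_decay, and self.learning_rate are
        # stored above but never applied in this class. A stepwise schedule
        # matching the comments ("decay by 0.5 after 50K iterations") could be
        # driven from the training loop with a hypothetical helper (a sketch,
        # not part of the original code):
        #
        #   def apply_step_decay(self, optimizer, iteration):
        #       new_lr = self.learning_rate * (
        #           self.decay ** (iteration // self.iterations_decay))
        #       K.set_value(optimizer.lr, new_lr)
        #
        # (self.iterations_decay is 50 here, while the comments say 50K)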
        # Build the three networks
        self.generator = self.build_generator()
        self.generator.summary()
        self.discriminator = self.build_discriminator()
        self.discriminator.summary()
        self.transformation = self.build_transformation()
        self.transformation.summary()

        # Compile the discriminator
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizerD,
                                   metrics=['accuracy'])

        # Build the generator
        img_S = Input(shape=self.input_shape_g)  # subject image S
        img_T = Input(shape=self.input_shape_g)  # template image T

        # Conditioned on T, generate a deformation field phi for S
        phi = self.generator([img_S, img_T])

        # Transform S
        warped_S = self.transformation([img_S, phi])

        # Use functools.partial to supply the loss function with the deformation field as an extra argument
        partial_gp_loss = partial(self.gradient_penalty_loss, phi=phi)
        partial_gp_loss.__name__ = 'gradient_penalty'  # Keras requires function names
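        # NOTE: functools.partial objects carry no __name__ of their own, and
        # Keras reads fn.__name__ when compiling, hence the assignment above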

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator determines the validity of translated image / condition pairs
        validity = self.discriminator([warped_S, img_T])

        self.combined = Model(inputs=[img_S, img_T], outputs=validity)
        self.combined.summary()
        self.combined.compile(loss=partial_gp_loss, optimizer=optimizerG)

        if self.DEBUG:
            log_path = '/nrs/scicompsoft/elmalakis/GAN_Registration_Data/flydata/forSalma/lo_res/logs_ganunet_v148/'
            self.callback = TensorBoard(log_path)
            self.callback.set_model(self.combined)

        self.data_loader = DataLoader(batch_sz=self.batch_sz,
                                      crop_size=self.crop_size_g,
                                      dataset_name='fly')

    """
    Generator Network
    """

    def build_generator(self):
        """U-Net Generator"""
        def conv3d(input_tensor,
                   n_filters,
                   kernel_size=(3, 3, 3),
                   batch_normalization=True,
                   scale=True,
                   padding='valid',
                   use_bias=False,
                   name=''):
            """
            3D convolutional layer (+ batch normalization) followed by ReLu activation
            """
            layer = Conv3D(filters=n_filters,
                           kernel_size=kernel_size,
                           padding=padding,
                           use_bias=use_bias,
                           name=name + '_conv3d')(input_tensor)
            # BN before activation
            if batch_normalization:
                layer = BatchNormalization(momentum=0.8,
                                           name=name + '_bn',
                                           scale=scale)(layer)
            layer = LeakyReLU(alpha=0.2, name=name + '_actleakyrelu')(layer)
            return layer

        def deconv3d(input_tensor,
                     n_filters,
                     kernel_size=(3, 3, 3),
                     batch_normalization=True,
                     scale=True,
                     padding='valid',
                     use_bias=False,
                     name=''):
            """
            3D deconvolutional layer (+ batch normalization) followed by ReLu activation
            """
            layer = UpSampling3D(size=2)(input_tensor)
            layer = Conv3D(filters=n_filters,
                           kernel_size=kernel_size,
                           padding=padding,
                           use_bias=use_bias,
                           name=name + '_conv3d')(layer)
            # BN before activation
            if batch_normalization:
                layer = BatchNormalization(momentum=0.8,
                                           name=name + '_bn',
                                           scale=scale)(layer)
            layer = LeakyReLU(alpha=0.2, name=name + '_actleakyrelu')(layer)
            return layer

        img_S = Input(shape=self.input_shape_g,
                      name='input_img_S')  # 148x148x148
        img_T = Input(shape=self.input_shape_g,
                      name='input_img_T')  # 148x148x148

        # Combine subject and template images; addition is used here instead of channel-wise concatenation
        #combined_imgs = Concatenate(axis=-1, name='combine_imgs_g')([img_S, img_T])
        combined_imgs = Add(name='combine_imgs_g')([img_S, img_T])

        # downsampling
        down1 = conv3d(input_tensor=combined_imgs,
                       n_filters=self.gf,
                       padding='valid',
                       name='down1_1')  #146
        down1 = conv3d(input_tensor=down1,
                       n_filters=self.gf,
                       padding='valid',
                       name='down1_2')  #144
        pool1 = MaxPooling3D(pool_size=(2, 2, 2), name='pool1')(down1)  #72

        down2 = conv3d(input_tensor=pool1,
                       n_filters=2 * self.gf,
                       padding='valid',
                       name='down2_1')  #70
        down2 = conv3d(input_tensor=down2,
                       n_filters=2 * self.gf,
                       padding='valid',
                       name='down2_2')  #68
        pool2 = MaxPooling3D(pool_size=(2, 2, 2), name='pool2')(down2)  #34

        down3 = conv3d(input_tensor=pool2,
                       n_filters=4 * self.gf,
                       padding='valid',
                       name='down3_1')  #32
        down3 = conv3d(input_tensor=down3,
                       n_filters=4 * self.gf,
                       padding='valid',
                       name='down3_2')  #30
        pool3 = MaxPooling3D(pool_size=(2, 2, 2), name='pool3')(down3)  #15

        center = conv3d(input_tensor=pool3,
                        n_filters=8 * self.gf,
                        padding='valid',
                        name='center1')  #13
        center = conv3d(input_tensor=center,
                        n_filters=8 * self.gf,
                        padding='valid',
                        name='center2')  #11

        # upsampling with gap filling
        up3 = deconv3d(input_tensor=center,
                       n_filters=4 * self.gf,
                       padding='same',
                       name='up3')  #22
        gap3 = conv3d(input_tensor=down3,
                      n_filters=4 * self.gf,
                      padding='valid',
                      name='gap3_1')  #28
        gap3 = conv3d(input_tensor=gap3,
                      n_filters=4 * self.gf,
                      padding='valid',
                      name='gap3_2')  #26
        up3 = concatenate([Cropping3D(2)(gap3), up3], name='up3concat')  #22
        up3 = conv3d(input_tensor=up3,
                     n_filters=4 * self.gf,
                     padding='valid',
                     name='up3conv_1')  #20
        up3 = conv3d(input_tensor=up3,
                     n_filters=4 * self.gf,
                     padding='valid',
                     name='up3conv_2')  #18

        up2 = deconv3d(input_tensor=up3,
                       n_filters=2 * self.gf,
                       padding='same',
                       name='up2')  #36
        gap2 = conv3d(input_tensor=down2,
                      n_filters=2 * self.gf,
                      padding='valid',
                      name='gap2_1')  #66
        for i in range(2, 7):
            gap2 = conv3d(input_tensor=gap2,
                          n_filters=2 * self.gf,
                          padding='valid',
                          name='gap2_' + str(i))  #56

        up2 = concatenate([Cropping3D(10)(gap2), up2], name='up2concat')  #36
        up2 = conv3d(input_tensor=up2,
                     n_filters=2 * self.gf,
                     padding='valid',
                     name='up2conv_1')  #34
        up2 = conv3d(input_tensor=up2,
                     n_filters=2 * self.gf,
                     padding='valid',
                     name='up2conv_2')  #32

        up1 = deconv3d(input_tensor=up2,
                       n_filters=self.gf,
                       padding='same',
                       name='up1')  #64
        gap1 = conv3d(input_tensor=down1,
                      n_filters=self.gf,
                      padding='valid',
                      name='gap1_1')  #142
        for i in range(2, 21):
            gap1 = conv3d(input_tensor=gap1,
                          n_filters=self.gf,
                          padding='valid',
                          name='gap1_' + str(i))  #104
        up1 = concatenate([Cropping3D(20)(gap1), up1], name='up1concat')  #64
        up1 = conv3d(input_tensor=up1,
                     n_filters=self.gf,
                     padding='valid',
                     name='up1conv_1')  #62
        up1 = conv3d(input_tensor=up1,
                     n_filters=self.gf,
                     padding='valid',
                     name='up1conv_2')  #60

        phi = Conv3D(filters=3,
                     kernel_size=(1, 1, 1),
                     padding='same',
                     use_bias=False,
                     name='phi')(up1)  #60

        model = Model([img_S, img_T], outputs=phi, name='generator_model')

        return model

    """
    Discriminator Network
    """

    def build_discriminator(self):
        def d_layer(layer_input,
                    filters,
                    f_size=5,
                    bn=True,
                    scale=True,
                    name=''):
            """Discriminator layer"""
            d = Conv3D(filters,
                       kernel_size=f_size,
                       strides=1,
                       padding='same',
                       name=name + '_conv3d')(layer_input)
            if bn:
                d = BatchNormalization(momentum=0.8,
                                       name=name + '_bn',
                                       scale=scale)(d)
            d = LeakyReLU(alpha=0.2, name=name + '_leakyrelu')(d)
            return d

        img_S = Input(shape=self.input_shape_d,
                      name='input_img_A')  # 60 warped_img or reference
        img_T = Input(shape=self.input_shape_g,
                      name='input_img_T')  # 148 template

        img_T_cropped = Cropping3D(cropping=44)(img_T)  # 60

        # Combine image and conditioning image; addition is used here instead of channel-wise concatenation
        #combined_imgs = Concatenate(axis=-1, name='combine_imgs_d')([img_S, img_T_cropped])
        combined_imgs = Add(name='combine_imgs_d')([img_S, img_T_cropped])
        d1 = d_layer(combined_imgs, self.df, bn=False, name='d1')  # 60
        d2 = d_layer(d1, self.df * 2, name='d2')  # 60
        pool = MaxPooling3D(pool_size=(2, 2, 2), name='d2_pool')(d2)  # 30

        d3 = d_layer(pool, self.df * 4, name='d3')  # 30
        d4 = d_layer(d3, self.df * 8, name='d4')  # 30
        pool = MaxPooling3D(pool_size=(2, 2, 2), name='d4_pool')(d4)  # 15

        d5 = d_layer(pool, self.df * 8, name='d5')  # 15

        # TODO: use an FC layer at the end, as specified in the paper
        validity = Conv3D(1,
                          kernel_size=4,
                          strides=1,
                          padding='same',
                          activation='sigmoid',
                          name='validity')(d5)  #15
        #d6 = Conv3D(1, kernel_size=4, strides=1, padding='same', name='validity')(d5)  # 6x6x6

        #validity = Flatten(data_format='channels_last')(d6)
        #x = Reshape((6*6*6*512,))(d5) # hack to avoid flatten bug
        #validity = Dense(1, activation='sigmoid')(x)

        # Use FC layer
        #d6 = Flatten(input_shape=(self.batch_sz,) + (6,6,6,512))(d5)
        #validity = Dense(1, activation='sigmoid')(d5)

        return Model([img_S, img_T], validity, name='discriminator_model')

    """
    Discriminator Network v2
    """

    def build_discriminator_v2(self):
        def d_layer(layer_input, filters, f_size=5, bn=True, name=''):
            """Discriminator layer"""
            d = Conv3D(filters,
                       kernel_size=f_size,
                       strides=2,
                       padding='same',
                       name=name + '_conv3d')(layer_input)
            d = LeakyReLU(alpha=0.2, name=name + '_leakyrelu')(d)
            if bn:
                d = BatchNormalization(momentum=0.8, name=name + '_bn')(d)
            return d

        img_A = Input(shape=self.input_shape_d,
                      name='input_img_A')  # warped_img or reference; this variant assumes a 68x68x68 input (input_shape_d above is 60x60x60, so these shapes are stale)
        img_T = Input(shape=self.input_shape_g,
                      name='input_img_T')  # 148x148x148 template

        img_T_cropped = Cropping3D(cropping=40)(img_T)  # 68x68x68

        # Concatenate image and conditioning image by channels to produce input
        combined_imgs = Concatenate(axis=-1)([img_A, img_T_cropped])

        d1 = d_layer(combined_imgs, self.df, bn=False, name='d1')
        d2 = d_layer(d1, self.df * 2, name='d2')
        d3 = d_layer(d2, self.df * 4, name='d3')
        d4 = d_layer(d3, self.df * 8, name='d4')

        #d5 = d_layer(d4, self.df*16, name='d5')
        # d5 = Flatten()(d4)
        # validity = Dense(1, activation='sigmoid')(d5)

        validity = Conv3D(1,
                          kernel_size=4,
                          strides=1,
                          padding='same',
                          activation='sigmoid',
                          name='disc_sig')(d4)  # 5x5x5

        return Model([img_A, img_T], validity, name='discriminator_model')

    """
    Deformable Transformation Layer    
    """

    def build_transformation(self):
        img_S = Input(shape=self.input_shape_g,
                      name='input_img_S_transform')  # 148
        phi = Input(shape=self.output_shape_g,
                    name='input_phi_transform')  # 60

        img_S_cropped = Cropping3D(cropping=44)(img_S)  # 60
        warped_S = Lambda(dense_image_warp_3D,
                          output_shape=(60, 60, 60, 1))([img_S_cropped, phi])

        return Model([img_S, phi], warped_S, name='transformation_layer')
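
    # NOTE: `dense_image_warp_3D` is imported from elsewhere in this project
    # and is not shown. The sketch below is an assumption about its behaviour:
    # a simplified nearest-neighbour warp (the real helper is presumably
    # trilinear, like tf.contrib.image.dense_image_warp).
    @staticmethod
    def _dense_image_warp_3D_sketch(inputs):
        image, flow = inputs  # image: (B, X, Y, Z, 1), flow: (B, X, Y, Z, 3)
        shape = K.shape(image)
        b, x, y, z = shape[0], shape[1], shape[2], shape[3]
        # sample the image at grid - flow, as tf.contrib's dense_image_warp does
        grid = tf.stack(tf.meshgrid(tf.range(x), tf.range(y), tf.range(z),
                                    indexing='ij'), axis=-1)            # (X, Y, Z, 3)
        coords = tf.cast(grid, flow.dtype)[None, ...] - flow
        coords = tf.cast(tf.round(coords), 'int32')                    # nearest neighbour
        max_idx = tf.reshape(tf.stack([x, y, z]) - 1, (1, 1, 1, 1, 3))
        coords = tf.maximum(tf.minimum(coords, max_idx), 0)            # clamp to bounds
        batch_idx = tf.tile(tf.reshape(tf.range(b), (-1, 1, 1, 1, 1)),
                            (1, x, y, z, 1))
        return tf.gather_nd(image, tf.concat([batch_idx, coords], axis=-1))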

    """
    Define losses
    """

    def gradient_penalty_loss(self, y_true, y_pred, phi):
        """
        Adversarial (binary cross-entropy) loss plus a gradient penalty on phi
        to encourage a smooth deformation field
        """
        adv_loss = K.mean(K.binary_crossentropy(y_true, y_pred), axis=-1)
        # compute the numerical gradient of phi
        gradients = numerical_gradient_3D(phi)
        #if self.DEBUG: gradients = K.print_tensor(gradients, message='gradients are:')

        # compute the euclidean norm: square ...
        gradients_sqr = K.square(gradients)
        #   ... sum over all non-batch axes ...
        gradients_sqr_sum = K.sum(gradients_sqr,
                                  axis=np.arange(1, len(gradients_sqr.shape)))
        #   ... and take the square root
        gradient_l2_norm = K.sqrt(gradients_sqr_sum)
        # return the mean over the batch plus the adversarial term
        # (alternatives tried: returning gradients_sqr_sum directly, and a
        # WGAN-GP-style K.square(1 - gradient_l2_norm) penalty)
        return K.mean(gradient_l2_norm) + adv_loss
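    # NOTE: `numerical_gradient_3D` is imported from elsewhere in this project
    # and is not shown. A minimal sketch of what such a helper could compute
    # (an assumption, not the project's implementation):
    @staticmethod
    def _numerical_gradient_3D_sketch(phi):
        # central differences along each spatial axis of a (B, X, Y, Z, 3) field
        dx = (phi[:, 2:, 1:-1, 1:-1, :] - phi[:, :-2, 1:-1, 1:-1, :]) / 2.0
        dy = (phi[:, 1:-1, 2:, 1:-1, :] - phi[:, 1:-1, :-2, 1:-1, :]) / 2.0
        dz = (phi[:, 1:-1, 1:-1, 2:, :] - phi[:, 1:-1, 1:-1, :-2, :]) / 2.0
        return K.concatenate([dx, dy, dz], axis=-1)  # (B, X-2, Y-2, Z-2, 9)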
    """
    Training
    """

    def train(self, epochs, batch_size=1, sample_interval=50):

        # Adversarial loss ground truths
        disc_patch = self.output_shape_d
        input_sz = 148
        output_sz = 60
        gap = int((input_sz - output_sz) / 2)

        # hard labels
        validhard = np.ones((self.batch_sz, ) + disc_patch)
        fakehard = np.zeros((self.batch_sz, ) + disc_patch)
        # hard labels with only one output
        #validhard = np.ones((self.batch_sz, 1))
        #fakehard = np.zeros((self.batch_sz, 1))

        # soft labels; one-sided label smoothing only smooths the positive labels
        # https://arxiv.org/abs/1701.00160
        # https://github.com/soumith/ganhacks/issues/41
        #smooth = 0.1  # validhard - smooth
        #validsoft = 0.9 + 0.1 * np.random.random_sample((self.batch_sz,) + disc_patch)  # random in [0.9, 1)
        #fakesoft = 0.1 * np.random.random_sample((self.batch_sz,) + disc_patch)         # random in [0, 0.1)
        validsoft = np.random.uniform(low=0.7,
                                      high=1.2,
                                      size=(self.batch_sz, ) + disc_patch)
        fakesoft = np.random.uniform(low=0.0,
                                     high=0.3,
                                     size=(self.batch_sz, ) + disc_patch)
        # note: validsoft/fakesoft are prepared here, but the loop below trains
        # on the hard labels

        start_time = datetime.datetime.now()
        for epoch in range(epochs):
            for batch_i, (batch_img, batch_img_template,
                          batch_img_golden) in enumerate(
                              self.data_loader.load_batch()):
                # ---------------------
                #  Train Discriminator
                # ---------------------
                #assert not np.any(np.isnan(batch_img))
                #assert not np.any(np.isnan(batch_img_template))
                phi = self.generator.predict(
                    [batch_img, batch_img_template])  # (60, 60, 60, 3)
                #assert not np.any(np.isnan(phi))

                # deformable transformation
                transform = self.transformation.predict(
                    [batch_img, phi])  # (60, 60, 60, 1)
                #assert not np.any(np.isnan(transform))

                # Create a reference image by blending the subject image with the template image
                perturbation_factor_alpha = 0.1 if epoch > epochs / 2 else 0.2
                batch_ref = perturbation_factor_alpha * batch_img + (
                    1 -
                    perturbation_factor_alpha) * batch_img_template  # 148x148x148

                batch_img_sub = np.zeros((self.batch_sz, output_sz, output_sz,
                                          output_sz, self.channels),
                                         dtype=batch_img.dtype)
                batch_ref_sub = np.zeros((self.batch_sz, output_sz, output_sz,
                                          output_sz, self.channels),
                                         dtype=batch_ref.dtype)
                batch_temp_sub = np.zeros((self.batch_sz, output_sz, output_sz,
                                           output_sz, self.channels),
                                          dtype=batch_img_template.dtype)
                #batch_golden_sub = np.zeros((self.batch_sz, output_sz, output_sz, output_sz, self.channels), dtype=batch_img_golden.dtype)

                # take the central (60, 60, 60) crop of the (148, 148, 148) inputs
                batch_img_sub[:, :, :, :, :] = batch_img[:, gap:gap + output_sz,
                                                         gap:gap + output_sz,
                                                         gap:gap + output_sz, :]
                batch_ref_sub[:, :, :, :, :] = batch_ref[:, gap:gap + output_sz,
                                                         gap:gap + output_sz,
                                                         gap:gap + output_sz, :]

                #batch_golden_sub[:, :, :, :, :] = batch_img_golden[:, gap:gap + output_sz,
                #                                                   gap:gap + output_sz,
                #                                                   gap:gap + output_sz, :]

                batch_temp_sub[:, :, :, :, :] = batch_img_template[:, gap:gap + output_sz,
                                                                   gap:gap + output_sz,
                                                                   gap:gap + output_sz, :]
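                # NOTE: the three crops above share the same pattern; a
                # hypothetical helper would make the intent explicit:
                #   def center_crop(batch, crop, gap):
                #       return batch[:, gap:gap + crop, gap:gap + crop, gap:gap + crop, :]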

                #assert not np.any(np.isnan(batch_img_sub))
                #assert not np.any(np.isnan(batch_ref_sub))
                #assert not np.any(np.isnan(batch_temp_sub))

                # Train the discriminator (R -> T is valid, S -> T is fake)
                # Noisy labels: occasionally swap the real/fake targets
                noisy_prob = 1 - np.sqrt(
                    1 - np.random.random())  # density peaks near 0, falls off towards 1
                if noisy_prob < 0.85:
                    d_loss_real = self.discriminator.train_on_batch(
                        [batch_ref_sub, batch_img_template], validhard)
                    d_loss_fake = self.discriminator.train_on_batch(
                        [transform, batch_img_template], fakehard)
                else:
                    d_loss_real = self.discriminator.train_on_batch(
                        [batch_ref_sub, batch_img_template], fakehard)
                    d_loss_fake = self.discriminator.train_on_batch(
                        [transform, batch_img_template], validhard)
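                # NOTE: with U ~ Uniform(0, 1), noisy_prob = 1 - sqrt(1 - U) has
                # CDF 1 - (1 - t)^2 and density 2(1 - t): it concentrates near 0,
                # so the labels are flipped (noisy_prob >= 0.85) with probability
                # (1 - 0.85)^2 = 2.25%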

                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                # ---------------------
                #  Train Generator
                # ---------------------
                # Train the generator (to fool the discriminator)
                g_loss = self.combined.train_on_batch(
                    [batch_img, batch_img_template], validhard)

                elapsed_time = datetime.datetime.now() - start_time

                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss average: %f, acc average: %3d%%, D loss fake:%f, acc: %3d%%, D loss real: %f, acc: %3d%%] [G loss: %f]  time: %s"
                    % (epoch, epochs, batch_i, self.data_loader.n_batches,
                       d_loss[0], 100 * d_loss[1], d_loss_fake[0],
                       100 * d_loss_fake[1], d_loss_real[0],
                       100 * d_loss_real[1], g_loss, elapsed_time))

                if self.DEBUG:
                    #self.write_log(self.callback, ['g_loss'], [g_loss[0]], batch_i)
                    self.write_log(self.callback, ['g_loss'], [g_loss],
                                   batch_i)
                    self.write_log(self.callback, ['d_loss'], [d_loss[0]],
                                   batch_i)

                # If at save interval => save generated image samples
                if batch_i % sample_interval == 0 and epoch != 0 and epoch % 5 == 0:
                    self.sample_images(epoch, batch_i)

    def write_log(self, callback, names, logs, batch_no):
        #https://github.com/eriklindernoren/Keras-GAN/issues/52
        for name, value in zip(names, logs):
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value
            summary_value.tag = name
            callback.writer.add_summary(summary, batch_no)
            callback.writer.flush()
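
    # NOTE: tf.Summary is the TF 1.x API. Under TensorFlow 2.x the equivalent
    # logging would look roughly like the sketch below, with `writer` created
    # via tf.summary.create_file_writer(log_path) (an assumption; this project
    # targets TF 1.x):
    def write_log_tf2(self, writer, names, logs, batch_no):
        with writer.as_default():
            for name, value in zip(names, logs):
                tf.summary.scalar(name, value, step=batch_no)
            writer.flush()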

    def sample_images(self, epoch, batch_i):
        path = '/nrs/scicompsoft/elmalakis/GAN_Registration_Data/flydata/forSalma/lo_res/'
        os.makedirs(path + 'generated_v148/', exist_ok=True)

        idx, imgs_S = self.data_loader.load_data(is_validation=True)
        imgs_T = self.data_loader.img_template

        # Mask the image with the template mask # already done in the preparing phase
        # imgs_T_mask = self.data_loader.mask_template
        # imgs_S = imgs_S * imgs_T_mask

        predict_img = np.zeros(imgs_S.shape, dtype=imgs_S.dtype)
        predict_phi = np.zeros(imgs_S.shape + (3, ), dtype=imgs_S.dtype)

        input_sz = self.crop_size_g
        step = (60, 60, 60)

        gap = (int((input_sz[0] - step[0]) / 2),
               int((input_sz[1] - step[1]) / 2),
               int((input_sz[2] - step[2]) / 2))
        start_time = datetime.datetime.now()
        for row in range(0, imgs_S.shape[0] - input_sz[0], step[0]):
            for col in range(0, imgs_S.shape[1] - input_sz[1], step[1]):
                for vol in range(0, imgs_S.shape[2] - input_sz[2], step[2]):
                    patch_sub_img = np.zeros(
                        (1, input_sz[0], input_sz[1], input_sz[2], 1),
                        dtype=imgs_S.dtype)
                    patch_templ_img = np.zeros(
                        (1, input_sz[0], input_sz[1], input_sz[2], 1),
                        dtype=imgs_T.dtype)

                    patch_sub_img[0, :, :, :, 0] = imgs_S[row:row + input_sz[0],
                                                          col:col + input_sz[1],
                                                          vol:vol + input_sz[2]]
                    patch_templ_img[0, :, :, :, 0] = imgs_T[row:row + input_sz[0],
                                                            col:col + input_sz[1],
                                                            vol:vol + input_sz[2]]

                    patch_predict_phi = self.generator.predict(
                        [patch_sub_img, patch_templ_img])
                    patch_predict_warped = self.transformation.predict(
                        [patch_sub_img, patch_predict_phi])

                    predict_img[row + gap[0]:row + gap[0] + step[0],
                                col + gap[1]:col + gap[1] + step[1],
                                vol + gap[2]:vol + gap[2] + step[2]] = patch_predict_warped[0, :, :, :, 0]
                    predict_phi[row + gap[0]:row + gap[0] + step[0],
                                col + gap[1]:col + gap[1] + step[1],
                                vol + gap[2]:vol + gap[2] + step[2], :] = patch_predict_phi[0, :, :, :, :]

        elapsed_time = datetime.datetime.now() - start_time
        print(" --- Prediction time: %s" % (elapsed_time))

        nrrd.write(path + "generated_v148/%d_%d_%d" % (epoch, batch_i, idx),
                   predict_img)
        self.data_loader._write_nifti(
            path + "generated_v148/phi%d_%d_%d" % (epoch, batch_i, idx),
            predict_phi)

        file_name = 'gan_network'
        # save the whole network (use self.combined; `gan` was an undefined global here)
        self.combined.save(path + 'generated_v148/' + file_name + str(epoch) +
                           '.whole.h5',
                           overwrite=True)
        print('Saved the whole network to disk as a .whole.h5 file')
        model_json = self.combined.to_json()
        with open(
                path + 'generated_v148/' + file_name + str(epoch) +
                '_arch.json', 'w') as json_file:
            json_file.write(model_json)
        self.combined.save_weights(path + 'generated_v148/' + file_name +
                                   str(epoch) + '_weights.h5',
                                   overwrite=True)
        print('Saved the network architecture as .json and the weights as .h5')
Example #2
class GAN_pix2pix():

    def __init__(self):

        K.set_image_data_format('channels_last')  # set format
        self.DEBUG = 1

        # Input shape
        self.img_rows = 256
        self.img_cols = 256
        self.img_vols = 256
        self.channels = 1
        self.batch_sz = 1

        self.crop_size = (self.img_rows, self.img_cols, self.img_vols)
        self.img_shape = self.crop_size + (self.channels,)

        self.output_size = 128
        self.output_shape_g = (self.output_size, self.output_size, self.output_size) + (3,)  # phi has three channels: one displacement per X, Y, and Z axis
        self.input_shape_d = (self.output_size, self.output_size, self.output_size) + (1,)

        # Calculate output shape of D
        patch = int(self.output_size / 2 ** 4)
        self.output_shape_d = (patch, patch, patch, self.channels)
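        # NOTE: four stride-2 conv layers in D halve 128 four times:
        # 128 / 2**4 = 8, so the PatchGAN validity map is (8, 8, 8, 1)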

        # Number of filters in the first layer of G and D
        self.gf = 32
        self.df = 32

        optimizer = Adam(0.001, decay=0.05)  # keyword needed: passed positionally, 0.05 would set beta_1 rather than the decay
        #optimizer = SGD(lr=0.001, decay=1e-6, momentum=0.9,
        #                  nesterov=True)
        #optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.summary()
        self.discriminator.compile(loss='mse',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        # -------------------------
        # Construct Computational
        #   Graph of Generator
        # -------------------------

        # Build the generator
        self.generator = self.build_generator()
        self.generator.summary()

        self.transformation = self.build_transformation()
        self.transformation.summary()

        # Input images and their conditioning images
        img_S = Input(shape=self.img_shape)
        img_T = Input(shape=self.img_shape)

        # Generate the deformation field
        phi = self.generator([img_S, img_T])
        # Transform S
        warped_S = self.transformation([img_S, phi])
        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator determines the validity of translated image / condition pairs
        validity = self.discriminator([warped_S, img_T])

        self.combined = Model(inputs=[img_S, img_T], outputs=[validity, warped_S])
        self.combined.summary()

        partial_gp_loss = partial(self.smoothness_loss, phi=phi)
        partial_gp_loss.__name__ = 'smoothness' # Keras requires function names


        self.combined.compile(loss=[partial_gp_loss, 'mae'],
                              loss_weights=[1, 100],
                              optimizer=optimizer)
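        # NOTE: the 1:100 weighting matches pix2pix, where the L1/MAE
        # reconstruction term is weighted 100x against the adversarial term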


        if self.DEBUG:
            log_path = '/nrs/scicompsoft/elmalakis/GAN_Registration_Data/flydata/forSalma/lo_res/logs_ganpix2pix_remod_golden_smooth/'
            self.callback = TensorBoard(log_path)
            self.callback.set_model(self.combined)



        self.data_loader = DataLoader(batch_sz=self.batch_sz,
                                      crop_size=self.crop_size,
                                      dataset_name='fly',
                                      min_max=False,
                                      restricted_mask=False,
                                      use_hist_equilized_data=False,
                                      use_sharpen=False,
                                      use_golden=True)

    def build_generator(self):
        """U-Net Generator"""

        def conv3d(layer_input, filters, f_size=4, bn=True):
            """Layers used during downsampling"""
            d = Conv3D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            d = LeakyReLU(alpha=0.2)(d)
            return d

        def deconv3d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):  # the reference pix2pix implementation uses dropout 0.5; changed here
            """Layers used during upsampling"""
            u = UpSampling3D(size=2)(layer_input)

            u = Conv3D(filters, kernel_size=f_size, padding='same')(u)  # upsampling is handled by UpSampling3D above, so no strided deconvolution
            if dropout_rate:
                u = Dropout(dropout_rate)(u)
            u = BatchNormalization(momentum=0.8)(u)
            u = Activation('relu')(u)
            u = Concatenate()([u, skip_input])
            return u


        img_S = Input(shape=self.img_shape, name='input_img_S') #256
        img_T = Input(shape=self.img_shape, name='input_img_T') #256

        d0 = Concatenate(axis=-1, name='combine_imgs_g')([img_S, img_T])
        #d0= Add(name='combine_imgs_g')([img_S, img_T])  #256

        # Downsampling
        d1 = conv3d(d0, self.gf, bn=False)   #128
        d2 = conv3d(d1, self.gf*2)           #64
        d3 = conv3d(d2, self.gf*4)           #32
        d4 = conv3d(d3, self.gf*8)           #16
        d5 = conv3d(d4, self.gf*8)           #8
        d6 = conv3d(d5, self.gf*8)           #4
        d7 = conv3d(d6, self.gf*8)           #2

        # Upsampling
        u1 = deconv3d(d7, d6, self.gf*8)     #4
        u2 = deconv3d(u1, d5, self.gf*8)     #8
        u3 = deconv3d(u2, d4, self.gf*8)     #16
        u4 = deconv3d(u3, d3, self.gf*4)     #32
        u5 = deconv3d(u4, d2, self.gf*2)     #64
        u6 = deconv3d(u5, d1, self.gf)       #128

        #u7 = UpSampling3D(size=2)(u6)        #256
        phi = Conv3D(filters=3, kernel_size=1, strides=1, padding='same')(u6) #128

        return  Model([img_S, img_T], outputs=phi, name='generator_model')


    def build_discriminator(self):

        def d_layer(layer_input, filters, f_size=4, bn=True):
            """Discriminator layer"""
            d = Conv3D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            d = LeakyReLU(alpha=0.2)(d)
            return d

        img_S = Input(shape=self.input_shape_d) #128 S
        img_T = Input(shape=self.img_shape) #256 T

        img_T_cropped = Cropping3D(cropping=64)(img_T)  # 128

        combined_imgs = Concatenate(axis=-1)([img_S, img_T_cropped])
        #combined_imgs = Add()([img_S, img_T])

        d1 = d_layer(combined_imgs, self.df, bn=False)
        d2 = d_layer(d1, self.df*2)
        d3 = d_layer(d2, self.df*4)
        d4 = d_layer(d3, self.df*8)

        validity = Conv3D(1, kernel_size=4, strides=1, padding='same', name='disc_sig')(d4)  # linear activation (no sigmoid), as in the original; paired with the 'mse' loss above

        return Model([img_S, img_T], validity, name='discriminator_model')


    def build_transformation(self):

        img_S = Input(shape=self.img_shape, name='input_img_S_transform')      # 256
        phi = Input(shape=self.output_shape_g, name='input_phi_transform')     # 128

        img_S_cropped = Cropping3D(cropping=64)(img_S)  # 128

        warped_S = Lambda(dense_image_warp_3D, output_shape=(128, 128, 128, 1))([img_S_cropped, phi])

        return Model([img_S, phi], warped_S,  name='transformation_layer')

    """
     Define losses
     """
    """
    Computes gradient penalty on phi to ensure smoothness
    """
    def smoothness_loss(self, y_true, y_pred, phi):

        # mean square error loss
        mse_loss = K.mean(K.square(y_pred - y_true), axis=-1)

        # compute the numerical gradient of phi
        gradients = numerical_gradient_3D(phi)
        #if self.DEBUG: gradients = K.print_tensor(gradients, message='gradients are:')

        # squared-gradient penalty: square ...
        gradients_sqr = K.square(gradients)
        #   ... and sum over all non-batch axes
        gradients_sqr_sum = K.sum(gradients_sqr, axis=np.arange(1, len(gradients_sqr.shape)))
        # alternatives tried: the L2 norm K.sqrt(gradients_sqr_sum) and a
        # WGAN-GP-style K.square(1 - gradient_l2_norm) penalty
        return K.mean(gradients_sqr_sum) + mse_loss
    """
    Training
    """
    def train(self, epochs, batch_size=1, sample_interval=50):
        path = '/nrs/scicompsoft/elmalakis/GAN_Registration_Data/flydata/forSalma/lo_res/'
        os.makedirs(path + 'generated_pix2pix_remod_smooth/', exist_ok=True)
        output_sz = 128
        input_sz = 256
        gap = int((input_sz - output_sz)/2)
        # Adversarial loss ground truths
        valid = np.ones((self.batch_sz,) + self.output_shape_d)
        fake = np.zeros((self.batch_sz,) + self.output_shape_d)
        validsoft = np.random.uniform(low=0.7, high=1.2, size=(self.batch_sz,) + self.output_shape_d)
        fakesoft = np.random.uniform(low=0.0, high=0.3, size=(self.batch_sz,) + self.output_shape_d)

        start_time = datetime.datetime.now()
        for epoch in range(epochs):
            for batch_i, (batch_img, batch_img_template, batch_img_golden) in enumerate(self.data_loader.load_batch()):
                # ---------------------
                #  Train Discriminator
                # ---------------------
                # Condition on template and generate a transform
                phi = self.generator.predict([batch_img, batch_img_template])
                transform = self.transformation.predict([batch_img, phi])

                # Create a reference image by blending the subject image with the template image
                perturbation_factor_alpha = 0.1 if epoch > epochs/2 else 0.2
                batch_ref = perturbation_factor_alpha * batch_img + (1 - perturbation_factor_alpha) * batch_img_template

                batch_img_sub = np.zeros((self.batch_sz, output_sz, output_sz, output_sz, self.channels),
                                         dtype=batch_img.dtype)
                batch_ref_sub = np.zeros((self.batch_sz, output_sz, output_sz, output_sz, self.channels),
                                         dtype=batch_ref.dtype)
                batch_temp_sub = np.zeros((self.batch_sz, output_sz, output_sz, output_sz, self.channels),
                                          dtype=batch_img_template.dtype)
                batch_golden_sub = np.zeros((self.batch_sz, output_sz, output_sz, output_sz, self.channels),
                                            dtype=batch_img_golden.dtype)

                # take the central (128, 128, 128) crop of the (256, 256, 256) inputs
                batch_img_sub[:, :, :, :, :] = batch_img[:, 0 + gap:0 + gap + output_sz,
                                               0 + gap:0 + gap + output_sz,
                                               0 + gap:0 + gap + output_sz, :]
                batch_ref_sub[:, :, :, :, :] = batch_ref[:, 0 + gap:0 + gap + output_sz,
                                               0 + gap:0 + gap + output_sz,
                                               0 + gap:0 + gap + output_sz, :]
                batch_golden_sub[:, :, :, :, :] = batch_img_golden[:, 0 + gap:0 + gap + output_sz,
                                                  0 + gap:0 + gap + output_sz,
                                                  0 + gap:0 + gap + output_sz, :]
                batch_temp_sub[:, :, :, :, :] = batch_img_template[:, 0 + gap:0 + gap + output_sz,
                                                0 + gap:0 + gap + output_sz,
                                                0 + gap:0 + gap + output_sz, :]

                # use noisy targets to help the GAN escape local minima (mode collapse)
                noisy_prob = 1 - np.sqrt(
                    1 - np.random.random())  # density peaks near 0, falls off towards 1
                if noisy_prob < 0.85:  # occasionally flip labels to introduce noisy labels
                    d_loss_real = self.discriminator.train_on_batch([batch_golden_sub, batch_img_template], valid)
                    d_loss_fake = self.discriminator.train_on_batch([transform, batch_img_template], fake)
                else:
                    d_loss_real = self.discriminator.train_on_batch([batch_golden_sub, batch_img_template], fake)
                    d_loss_fake = self.discriminator.train_on_batch([transform, batch_img_template], valid)

                # d_loss_real = self.discriminator.train_on_batch([batch_img_golden, batch_img_template], valid)
                # d_loss_fake = self.discriminator.train_on_batch([transform, batch_img_template], fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                # -----------------
                #  Train Generator
                # -----------------
                g_loss = self.combined.train_on_batch([batch_img, batch_img_template], [valid, batch_golden_sub])

                elapsed_time = datetime.datetime.now() - start_time

                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss average: %f, acc average: %3d%%, D loss fake:%f, acc: %3d%%, D loss real: %f, acc: %3d%%] [G loss: %f]  time: %s"
                    % (epoch, epochs,
                       batch_i, self.data_loader.n_batches,
                       d_loss[0], 100 * d_loss[1],
                       d_loss_fake[0], 100 * d_loss_fake[1],
                       d_loss_real[0], 100 * d_loss_real[1],
                       g_loss[0],
                       elapsed_time))


                if self.DEBUG:
                    self.write_log(self.callback, ['g_loss'], [g_loss[0]], batch_i)
                    self.write_log(self.callback, ['d_loss'], [d_loss[0]], batch_i)

                if batch_i % sample_interval == 0 and epoch != 0 and epoch % 5 == 0:
                    self.sample_images(epoch, batch_i)


    def write_log(self, callback, names, logs, batch_no):
        #https://github.com/eriklindernoren/Keras-GAN/issues/52
        for name, value in zip(names, logs):
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value
            summary_value.tag = name
            callback.writer.add_summary(summary, batch_no)
            callback.writer.flush()


    def sample_images(self, epoch, batch_i):
        path = '/nrs/scicompsoft/elmalakis/GAN_Registration_Data/flydata/forSalma/lo_res/'
        os.makedirs(path+'generated_pix2pix_remod_golden_smooth/' , exist_ok=True)

        idx, imgs_S = self.data_loader.load_data(is_validation=True)
        imgs_T = self.data_loader.img_template

        predict_img = np.zeros(imgs_S.shape, dtype=imgs_S.dtype)
        predict_phi = np.zeros(imgs_S.shape + (3,), dtype=imgs_S.dtype)

        input_sz = self.crop_size
        output_sz = (self.output_size, self.output_size, self.output_size)
        step = (24, 24, 24)

        gap = (int((input_sz[0] - output_sz[0]) / 2), int((input_sz[1] - output_sz[1]) / 2), int((input_sz[2] - output_sz[2]) / 2))
        start_time = datetime.datetime.now()

        for row in range(0, imgs_S.shape[0] - input_sz[0], step[0]):
            for col in range(0, imgs_S.shape[1] - input_sz[1], step[1]):
                for vol in range(0, imgs_S.shape[2] - input_sz[2], step[2]):

                    patch_sub_img = np.zeros((1, input_sz[0], input_sz[1], input_sz[2], 1), dtype=imgs_S.dtype)
                    patch_templ_img = np.zeros((1, input_sz[0], input_sz[1], input_sz[2], 1), dtype=imgs_T.dtype)

                    patch_sub_img[0, :, :, :, 0] = imgs_S[row:row + input_sz[0],
                                                         col:col + input_sz[1],
                                                         vol:vol + input_sz[2]]
                    patch_templ_img[0, :, :, :, 0] = imgs_T[row:row + input_sz[0],
                                                       col:col + input_sz[1],
                                                       vol:vol + input_sz[2]]

                    patch_predict_phi = self.generator.predict([patch_sub_img, patch_templ_img])
                    patch_predict_warped = self.transformation.predict([patch_sub_img, patch_predict_phi])

                    predict_img[row + gap[0]:row + gap[0] + output_sz[0],
                                col + gap[1]:col + gap[1] + output_sz[1],
                                vol + gap[2]:vol + gap[2] + output_sz[2]] = patch_predict_warped[0, :, :, :, 0]

                    predict_phi[row + gap[0]:row + gap[0] + output_sz[0],
                               col + gap[1]:col + gap[1] + output_sz[1],
                               vol + gap[2]:vol + gap[2] + output_sz[2],:] = patch_predict_phi[0, :, :, :, :]

        elapsed_time = datetime.datetime.now() - start_time
        print(" --- Prediction time: %s" % (elapsed_time))

        nrrd.write(path+"generated_pix2pix_remod_golden_smooth/%d_%d_%d" % (epoch, batch_i, idx), predict_img)
        self.data_loader._write_nifti(path + "generated_pix2pix_remod_golden_smooth/phi%d_%d_%d" % (epoch, batch_i, idx), predict_phi)

        if epoch % 10 == 0:
            file_name = 'gan_network' + str(epoch)
            # save the whole network (use self.combined; `gan` was an undefined global here)
            self.combined.save(path + 'generated_pix2pix_remod_golden_smooth/' + file_name + '.whole.h5', overwrite=True)
            print('Saved the whole network to disk as a .whole.h5 file')
            model_json = self.combined.to_json()
            with open(path + 'generated_pix2pix_remod_golden_smooth/' + file_name + '_arch.json', 'w') as json_file:
                json_file.write(model_json)
            self.combined.save_weights(path + 'generated_pix2pix_remod_golden_smooth/' + file_name + '_weights.h5', overwrite=True)
            print('Saved the network architecture as .json and the weights as .h5')
Example #3
class GANUnetNoGapFillingModel():

    def __init__(self):

        K.set_image_data_format('channels_last')  # TF-style (batch, x, y, z, channels); the deprecated set_image_dim_ordering('tf') is redundant with this
        self.DEBUG = 1

        # Input shape
        self.img_rows = 192
        self.img_cols = 192
        self.img_vols = 192
        self.channels = 1
        self.batch_sz = 1  # for testing locally to avoid memory allocation

        self.crop_size = (self.img_rows, self.img_cols, self.img_vols)

        self.img_shape = self.crop_size + (self.channels,)

        self.output_size = 192
        self.output_shape_g = self.crop_size + (3,)  # phi has three channels: one displacement per X, Y, and Z axis
        self.input_shape_d = self.crop_size + (1,)

        # Calculate output shape of D
        patch = int(self.output_size / 2 ** 4)
        self.output_shape_d = (patch, patch, patch, self.channels)
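        # NOTE: four stride-2 conv layers in D halve 192 four times:
        # 192 / 2**4 = 12, so the PatchGAN validity map is (12, 12, 12, 1)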


        # Number of filters in the first layer of G and D
        self.gf = 32
        self.df = 32

        #optimizerD = SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True)  # in the paper the learning rate is 0.001 and weight decay is 0.5
        optimizerD = Adam(0.001, decay=0.05)  # in the paper the LR decays by 0.5 after 50K iterations
        self.decay = 0.5
        self.iterations_decay = 50
        self.learning_rate = 0.001
        #optimizerG = SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True)
        optimizerG = Adam(0.001, decay=0.05)  # same schedule for the generator

        # Build the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.summary()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizerD,
                                   metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()
        self.generator.summary()
        # Build the deformable transformation layer
        self.transformation = self.build_transformation()
        self.transformation.summary()

        # Input images and their conditioning images
        img_S = Input(shape=self.img_shape)
        img_T = Input(shape=self.img_shape)

        # Conditioned on T, generate a deformation field phi for S
        phi = self.generator([img_S, img_T])

        # Transform S
        warped_S = self.transformation([img_S, phi])

        # Use Python partial to provide loss function with additional deformable field argument
        partial_gp_loss = partial(self.gradient_penalty_loss, phi=phi)
        partial_gp_loss.__name__ = 'gradient_penalty' # Keras requires function names

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator determines the validity of translated image / condition pairs
        validity = self.discriminator([warped_S, img_T])

        self.combined = Model(inputs=[img_S, img_T], outputs=validity)
        self.combined.summary()
        self.combined.compile(loss=partial_gp_loss, optimizer=optimizerG)

        if self.DEBUG:
            log_path = '/nrs/scicompsoft/elmalakis/GAN_Registration_Data/flydata/forSalma/lo_res/logs_ganunet_nogap/'
            os.makedirs(log_path, exist_ok=True)
            self.callback = TensorBoard(log_path)
            self.callback.set_model(self.combined)

        self.data_loader = DataLoader(batch_sz=self.batch_sz,
                                      crop_size=self.crop_size,
                                      dataset_name='fly',
                                      min_max=False,
                                      restricted_mask=False,
                                      use_hist_equilized_data=False,
                                      use_golden=False)

    """
    Generator Network
    """
    def build_generator(self):
        """U-Net Generator"""
        def conv3d(input_tensor,
                   n_filters,
                   kernel_size=(3, 3, 3),
                   batch_normalization=True,
                   scale=True,
                   padding='valid',
                   use_bias=False,
                   name=''):
            """
            3D convolutional layer (+ batch normalization) followed by ReLU activation
            """
            layer = Conv3D(filters=n_filters,
                           kernel_size=kernel_size,
                           padding=padding,
                           use_bias=use_bias,
                           name=name + '_conv3d')(input_tensor)
            if batch_normalization:
                layer = BatchNormalization(momentum=0.8, name=name + '_bn', scale=scale)(layer)
            #layer = LeakyReLU(alpha=0.2, name=name + '_actleakyrelu')(layer)
            layer = Activation("relu")(layer)
            return layer


        def deconv3d(input_tensor,
                     n_filters,
                     kernel_size=(3, 3, 3),
                     batch_normalization=True,
                     scale=True,
                     padding='valid',
                     use_bias=False,
                     name=''):
            """
            Upsampling followed by a 3D convolution (+ batch normalization) and ReLU activation
            """
            layer = UpSampling3D(size=2)(input_tensor)
            layer = Conv3D(filters=n_filters,
                           kernel_size=kernel_size,
                           padding=padding,
                           use_bias=use_bias,
                           name=name + '_conv3d')(layer)

            if batch_normalization:
                layer = BatchNormalization(momentum=0.8, name=name + '_bn', scale=scale)(layer)
            #layer = LeakyReLU(alpha=0.2, name=name + '_actleakyrelu')(layer)
            layer = Activation("relu")(layer)
            return layer

        img_S = Input(shape=self.img_shape, name='input_img_S')
        img_T = Input(shape=self.img_shape, name='input_img_T')

        combined_imgs = Add(name='combine_imgs_g')([img_S,img_T])

        # downsampling
        down1 = conv3d(input_tensor=combined_imgs, n_filters=self.gf, padding='same', name='down1_1')  # 192
        down1 = conv3d(input_tensor=down1, n_filters=self.gf, padding='same', name='down1_2')          # 192
        pool1 = MaxPooling3D(pool_size=(2, 2, 2), name='pool1')(down1)                                 # 96

        down2 = conv3d(input_tensor=pool1, n_filters=2 * self.gf, padding='same', name='down2_1')      # 96
        down2 = conv3d(input_tensor=down2, n_filters=2 * self.gf, padding='same', name='down2_2')      # 96
        pool2 = MaxPooling3D(pool_size=(2, 2, 2), name='pool2')(down2)                                 # 48

        down3 = conv3d(input_tensor=pool2, n_filters=4 * self.gf, padding='same', name='down3_1')      # 48
        down3 = conv3d(input_tensor=down3, n_filters=4 * self.gf, padding='same', name='down3_2')      # 48
        pool3 = MaxPooling3D(pool_size=(2, 2, 2), name='pool3')(down3)                                 #24

        down4 = conv3d(input_tensor=pool3, n_filters=8 * self.gf, padding='same', name='down4_1')      # 24
        down4 = conv3d(input_tensor=down4, n_filters=8 * self.gf, padding='same', name='down4_2')      # 24
        pool4 = MaxPooling3D(pool_size=(2, 2, 2), name='pool4')(down4)                                 # 12

        down5 = conv3d(input_tensor=pool4, n_filters=8 * self.gf, padding='same', name='down5_1')      # 12
        down5 = conv3d(input_tensor=down5, n_filters=8 * self.gf, padding='same', name='down5_2')      # 12
        pool5 = MaxPooling3D(pool_size=(2, 2, 2), name='pool5')(down5)                                  # 6

        center = conv3d(input_tensor=pool5, n_filters=16 * self.gf, padding='same', name='center1')     # 6
        center = conv3d(input_tensor=center, n_filters=16 * self.gf, padding='same', name='center2')    # 6

        # upsampling
        up5 = deconv3d(input_tensor=center, n_filters=8 * self.gf, padding='same', name='up5')          # 12
        up5 = concatenate([up5, down5])                                                                 # 12
        up5 = conv3d(input_tensor=up5, n_filters=8 * self.gf, padding='same', name='up5_1')             # 12
        up5 = conv3d(input_tensor=up5, n_filters=8 * self.gf, padding='same', name='up5_2')             # 12

        up4 = deconv3d(input_tensor=up5, n_filters=8 * self.gf, padding='same', name='up4')             #24
        up4 = concatenate([up4, down4])                                                                 # 24
        up4 = conv3d(input_tensor=up4, n_filters=8 * self.gf, padding='same', name='up4_1')             # 24
        up4 = conv3d(input_tensor=up4, n_filters=8 * self.gf, padding='same', name='up4_2')             # 24

        up3 = deconv3d(input_tensor=up4, n_filters=4 * self.gf, padding='same', name='up3')             #48
        up3 = concatenate([up3, down3])                                                                 # 48
        up3 = conv3d(input_tensor=up3, n_filters=4 * self.gf, padding='same', name='up3_1')            # 48
        up3 = conv3d(input_tensor=up3, n_filters=4 * self.gf, padding='same', name='up3_2')            # 48

        up2 = deconv3d(input_tensor=up3, n_filters=2 * self.gf, padding='same', name='up2')             # 96
        up2 = concatenate([up2, down2])                                                                # 96
        up2 = conv3d(input_tensor=up2, n_filters=2 * self.gf, padding='same', name='up2_1')            # 96
        up2 = conv3d(input_tensor=up2, n_filters=2 * self.gf, padding='same', name='up2_2')            # 96

        up1 = deconv3d(input_tensor=up2, n_filters=self.gf, padding='same', name='up1')                 # 192
        up1 = concatenate([up1, down1])                                                                 # 192
        up1 = conv3d(input_tensor=up1, n_filters=self.gf, padding='same', name='up1_1')                # 192
        up1 = conv3d(input_tensor=up1, n_filters=self.gf, padding='same', name='up1_2')                # 192

        phi = Conv3D(filters=3, kernel_size=(1, 1, 1), strides=1, use_bias=False, padding='same', name='phi')(up1)    # 192

        model = Model(inputs=[img_S, img_T], outputs=phi, name='generator_model')

        return model
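
    # NOTE: the spatial-size comments in build_generator (192 -> 96 -> ... -> 6)
    # trace a 192^3 variant of this network; with crop_size_g = (148, 148, 148)
    # the sizes run 148 -> 74 -> 37 -> 18 -> 9 -> 4, and the odd sizes mean the
    # stride-2 deconv outputs no longer match the skip connections exactly.
    # A minimal sketch to trace the per-level sizes for any edge length
    # (hypothetical helper, not part of the original model):
    @staticmethod
    def _trace_unet_sizes(edge, depth=5):
        """Spatial size per level for `depth` 2x poolings ('same' convs keep size)."""
        sizes = [edge]
        for _ in range(depth):
            sizes.append(sizes[-1] // 2)  # MaxPooling3D floors odd sizes
        return sizes
    # _trace_unet_sizes(192) -> [192, 96, 48, 24, 12, 6]
    # _trace_unet_sizes(148) -> [148, 74, 37, 18, 9, 4]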

    """
    Discriminator Network
    """
    def build_discriminator(self):

        def d_layer(layer_input, filters, f_size=4, bn=True):
            """Discriminator layer"""
            d = Conv3D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            #d = LeakyReLU(alpha=0.2)(d)
            d = Activation("relu")(d)
            return d

        img_S = Input(shape=self.input_shape_d)  # 60, warped subject crop
        img_T = Input(shape=self.input_shape_d)  # 60, template crop of the same region

        combined_imgs = Add()([img_S, img_T])    # fuse the pair by voxel-wise addition

        d1 = d_layer(combined_imgs, self.df, bn=False)
        d2 = d_layer(d1, self.df*2)
        d3 = d_layer(d2, self.df*4)
        d4 = d_layer(d3, self.df*8)

        validity = Conv3D(1, kernel_size=4, strides=1, padding='same', activation='sigmoid', name='disc_sig')(d4)

        return Model([img_S, img_T], validity, name='discriminator_model')
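
    # NOTE: four stride-2 'same' convolutions shrink a 60^3 input to
    # ceil(60 / 16) = 4^3, while output_shape_d is declared as (15, 15, 15),
    # which corresponds to only two stride-2 layers (60 / 4 = 15). A hedged
    # sketch of the shrinkage (hypothetical helper, for checking shapes only):
    @staticmethod
    def _trace_disc_sizes(edge=60, n_layers=4):
        import math
        sizes = [edge]
        for _ in range(n_layers):
            sizes.append(math.ceil(sizes[-1] / 2))  # stride 2, padding='same'
        return sizes
    # _trace_disc_sizes(60, 4) -> [60, 30, 15, 8, 4]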

    """
    Transformation Network
    """
    def build_transformation(self):
        img_S = Input(shape=self.input_shape_g, name='input_img_S_transform')   # 148, subject image
        phi = Input(shape=self.output_shape_g, name='input_phi_transform')      # 60, displacement field

        warped_S = Lambda(dense_image_warp_3D, output_shape=self.input_shape_d)([img_S, phi])

        return Model([img_S, phi], warped_S, name='transformation_layer')
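
    # dense_image_warp_3D is expected to resample the subject at the displaced
    # grid, i.e. warped(x) ~= S(x + phi(x)) up to the sign convention of the
    # field. A minimal 1-D numpy analogue with nearest-neighbour sampling
    # (illustration only; the real layer interpolates trilinearly):
    @staticmethod
    def _warp_1d_nearest(signal, phi):
        import numpy as np
        grid = np.arange(len(signal)) + phi  # displaced sample positions
        idx = np.clip(np.round(grid).astype(int), 0, len(signal) - 1)
        return signal[idx]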


    """
    Define losses
    """
    def gradient_penalty_loss(self, y_true, y_pred, phi):
        """
        Adversarial loss plus a gradient penalty on phi to encourage a smooth
        deformation field.
        """
        adv_loss = K.mean(K.binary_crossentropy(y_true, y_pred), axis=-1)
        # numerical gradient of phi along the three spatial dimensions
        gradients = numerical_gradient_3D(phi)
        # Euclidean norm of the gradient: square ...
        gradients_sqr = K.square(gradients)
        # ... sum over everything but the batch axis ...
        gradients_sqr_sum = K.sum(gradients_sqr, axis=np.arange(1, len(gradients_sqr.shape)))
        # ... and take the square root
        gradient_l2_norm = K.sqrt(gradients_sqr_sum)
        # Unlike the classic WGAN-GP term lambda * (1 - ||grad||)^2, this
        # penalizes the raw norm, pushing the field towards smoothness.
        return K.mean(gradient_l2_norm) + adv_loss
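
    # Hedged numpy cross-check of the penalty term above, assuming
    # numerical_gradient_3D behaves like np.gradient over the spatial axes
    # (hypothetical helper, not used in training):
    @staticmethod
    def _smoothness_penalty_np(phi):
        """phi: (batch, d, h, w, 3) displacement field -> scalar penalty."""
        grads = np.stack(np.gradient(phi, axis=(1, 2, 3)), axis=-1)
        sq_sum = (grads ** 2).sum(axis=tuple(range(1, grads.ndim)))
        return np.sqrt(sq_sum).mean()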


    """
    Training
    """
    def train(self, epochs, batch_size=1, sample_interval=50):
        # NOTE: batches are produced by self.data_loader with self.batch_sz;
        # the batch_size argument is currently unused.

        # Adversarial loss ground truths
        # hard labels
        valid = np.ones((self.batch_sz,) + self.output_shape_d)
        fake = np.zeros((self.batch_sz,) + self.output_shape_d)
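        # PatchGAN-style labels: one real/fake decision per spatial patch of the
        # discriminator output, hence label volumes of shape (batch,) + output_shape_d
        # rather than a single scalar per image.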

        start_time = datetime.datetime.now()
        for epoch in range(epochs):
            for batch_i, (batch_img, batch_img_template, batch_img_golden) in enumerate(self.data_loader.load_batch()):
                # ---------------------
                #  Train Discriminator
                # ---------------------
                # Condition on the template T and generate a deformation field for S
                phi = self.generator.predict([batch_img, batch_img_template])
                transform = self.transformation.predict([batch_img, phi])  # 60x60x60 warped crop
                # Create a reference image by blending the subject image with the template;
                # the blend weight alpha is annealed from 0.2 to 0.1 halfway through training
                perturbation_factor_alpha = 0.1 if epoch > epochs / 2 else 0.2
                batch_ref = perturbation_factor_alpha * batch_img + (1 - perturbation_factor_alpha) * batch_img_template

                d_loss_real = self.discriminator.train_on_batch([batch_ref, batch_img_template], valid)
                d_loss_fake = self.discriminator.train_on_batch([transform, batch_img_template], fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                # ---------------------
                #  Train Generator
                # ---------------------
                g_loss = self.combined.train_on_batch([batch_img, batch_img_template], valid)

                elapsed_time = datetime.datetime.now() - start_time

                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss average: %f, acc average: %3d%%, D loss fake:%f, acc: %3d%%, D loss real: %f, acc: %3d%%] [G loss: %f]  time: %s"
                    % (epoch, epochs,
                       batch_i, self.data_loader.n_batches,
                       d_loss[0], 100 * d_loss[1],
                       d_loss_fake[0], 100 * d_loss_fake[1],
                       d_loss_real[0], 100 * d_loss_real[1],
                       g_loss,
                       elapsed_time))

                if self.DEBUG:
                    # use a global step so the TensorBoard curves do not reset every epoch
                    step = epoch * self.data_loader.n_batches + batch_i
                    self.write_log(self.callback, ['g_loss'], [g_loss], step)
                    self.write_log(self.callback, ['d_loss'], [d_loss[0]], step)

                # Save generated image samples every `sample_interval` batches,
                # but only on every 5th epoch (skipping epoch 0)
                if batch_i % sample_interval == 0 and epoch != 0 and epoch % 5 == 0:
                    self.sample_images(epoch, batch_i)


    def write_log(self, callback, names, logs, batch_no):
        #https://github.com/eriklindernoren/Keras-GAN/issues/52
        for name, value in zip(names, logs):
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value
            summary_value.tag = name
            callback.writer.add_summary(summary, batch_no)
            callback.writer.flush()
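
    # write_log uses the TF1-style tf.Summary protobuf API. Under TensorFlow 2
    # the equivalent would be a tf.summary.create_file_writer / tf.summary.scalar
    # pair, roughly (hedged sketch, assuming TF2 and a log_dir string):
    #
    #     writer = tf.summary.create_file_writer(log_dir)
    #     with writer.as_default():
    #         tf.summary.scalar(name, value, step=batch_no)
    #         writer.flush()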


    def sample_images(self, epoch, batch_i):
        path = '/nrs/scicompsoft/elmalakis/GAN_Registration_Data/flydata/forSalma/lo_res/'
        os.makedirs(path+'generated_unet_nogap/' , exist_ok=True)

        idx, imgs_S = self.data_loader.load_data(is_validation=True)
        imgs_T = self.data_loader.img_template

        predict_img = np.zeros(imgs_S.shape, dtype=imgs_S.dtype)
        predict_phi = np.zeros(imgs_S.shape + (3,), dtype=imgs_S.dtype)

        input_sz = self.crop_size_g
        output_sz = self.crop_size_d
        # NOTE: a 64-voxel step with a 60-voxel output leaves a 4-voxel gap
        # between written patches; a step of 60 or less would tile seamlessly.
        step = (64, 64, 64)

        start_time = datetime.datetime.now()

        for row in range(0, imgs_S.shape[0] - input_sz[0], step[0]):
            for col in range(0, imgs_S.shape[1] - input_sz[1], step[1]):
                for vol in range(0, imgs_S.shape[2] - input_sz[2], step[2]):
                    patch_sub_img = np.zeros((1, input_sz[0], input_sz[1], input_sz[2], 1), dtype=imgs_S.dtype)
                    patch_templ_img = np.zeros((1, input_sz[0], input_sz[1], input_sz[2], 1), dtype=imgs_T.dtype)

                    patch_sub_img[0, :, :, :, 0] = imgs_S[row:row + input_sz[0],
                                                   col:col + input_sz[1],
                                                   vol:vol + input_sz[2]]
                    patch_templ_img[0, :, :, :, 0] = imgs_T[row:row + input_sz[0],
                                                     col:col + input_sz[1],
                                                     vol:vol + input_sz[2]]

                    patch_predict_phi = self.generator.predict([patch_sub_img, patch_templ_img])
                    patch_predict_warped = self.transformation.predict([patch_sub_img, patch_predict_phi])

                    predict_img[row:row + output_sz[0],
                                col:col + output_sz[1],
                                vol:vol + output_sz[2]] = patch_predict_warped[0, :, :, :, 0]
                    predict_phi[row:row + output_sz[0],
                                col:col + output_sz[1],
                                vol:vol + output_sz[2], :] = patch_predict_phi[0, :, :, :, :]

        elapsed_time = datetime.datetime.now() - start_time
        print(" --- Prediction time: %s" % (elapsed_time))

        nrrd.write(path + "generated_unet_nogap/%d_%d_%d.nrrd" % (epoch, batch_i, idx), predict_img)
        self.data_loader._write_nifti(path + "generated_unet_nogap/phi%d_%d_%d" % (epoch, batch_i, idx), predict_phi)

        if epoch % 10 == 0:
            file_name = 'gan_network' + str(epoch)
            # save the whole network
            self.combined.save(path + 'generated_unet_nogap/' + file_name + '.whole.h5', overwrite=True)
            print('Saved the whole network to disk as a .whole.h5 file')
            model_json = self.combined.to_json()
            with open(path + 'generated_unet_nogap/' + file_name + '_arch.json', 'w') as json_file:
                json_file.write(model_json)
            self.combined.save_weights(path + 'generated_unet_nogap/' + file_name + '_weights.h5', overwrite=True)
            print('Saved the network architecture as a .json file and the weights as a .h5 file')
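
# Hedged usage sketch (assumes the imports at the top of this file and enough
# GPU memory for 148^3 crops; the data and log paths above are site-specific):
if __name__ == '__main__':
    gan = GANUnetModel148()
    gan.train(epochs=100, batch_size=1, sample_interval=50)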