class GANUnetModel148():
    """GAN-based deformable registration model on 148^3 input crops.

    Three sub-networks are built and wired together:
      * generator      -- U-Net with "gap filling" skip paths; maps a
                          (subject, template) pair to a 60^3 x 3 deformation
                          field ``phi`` (one channel per X/Y/Z displacement).
      * transformation -- non-trainable warping layer applying ``phi`` to the
                          (center-cropped) subject image.
      * discriminator  -- patch classifier judging (warped/reference, template)
                          pairs; outputs a 15^3 validity map.

    The combined model trains only the generator, using a binary cross-entropy
    term plus a gradient penalty on ``phi`` for smoothness.
    """

    def __init__(self):
        K.set_image_data_format('channels_last')  # set format
        K.set_image_dim_ordering('tf')
        self.DEBUG = 1
        self.crop_size_g = (148, 148, 148)
        self.crop_size_d = (60, 60, 60)
        self.channels = 1
        self.input_shape_g = self.crop_size_g + (self.channels,)
        self.input_shape_d = self.crop_size_d + (self.channels,)
        self.output_shape_g = (60, 60, 60) + (3,)  # phi has three outputs. one for each X, Y, and Z dimensions
        self.output_shape_d = (15, 15, 15) + (self.channels,)
        #self.output_shape_d_v2 = (5, 5, 5) + (self.channels,)
        self.batch_sz = 1  # for testing locally to avoid memory allocation

        # Number of filters in the first layer of G and D
        self.gf = 32
        self.df = 32

        optimizerD = Adam(0.001, decay=0.05)  # in the paper the learning rate is 0.001 and weight decay is 0.5
        #optimizerD = SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True)
        self.decay = 0.5
        self.iterations_decay = 50
        self.learning_rate = 0.001
        optimizerG = Adam(0.001, decay=0.05)  # in the paper the decay after 50K iterations by 0.5
        #optimizerG = SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True)

        # Build the three networks
        self.generator = self.build_generator()
        self.generator.summary()
        self.discriminator = self.build_discriminator()
        self.discriminator.summary()
        self.transformation = self.build_transformation()
        self.transformation.summary()

        # Compile the discriminator
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizerD,
                                   metrics=['accuracy'])

        # Build the generator graph
        img_S = Input(shape=self.input_shape_g)  # subject image S
        img_T = Input(shape=self.input_shape_g)  # template image T

        # By conditioning on T generate a warped transformation function of S
        phi = self.generator([img_S, img_T])
        # Transform S
        warped_S = self.transformation([img_S, phi])

        # Use Python partial to provide loss function with additional deformable field argument
        partial_gp_loss = partial(self.gradient_penalty_loss, phi=phi)
        partial_gp_loss.__name__ = 'gradient_penalty'  # Keras requires function names

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # Discriminator determines validity of translated images / condition pairs
        validity = self.discriminator([warped_S, img_T])

        self.combined = Model(inputs=[img_S, img_T], outputs=validity)
        self.combined.summary()
        self.combined.compile(loss=partial_gp_loss, optimizer=optimizerG)

        if self.DEBUG:
            log_path = '/nrs/scicompsoft/elmalakis/GAN_Registration_Data/flydata/forSalma/lo_res/logs_ganunet_v148/'
            self.callback = TensorBoard(log_path)
            self.callback.set_model(self.combined)

        self.data_loader = DataLoader(batch_sz=self.batch_sz,
                                      crop_size=self.crop_size_g,
                                      dataset_name='fly')

    """ Generator Network """
    def build_generator(self):
        """U-Net Generator.

        Valid-padded convolutions shrink each level; "gap" convolution chains
        shrink the skip tensors so that, after a Cropping3D, they match the
        upsampled decoder tensors (sizes annotated per line). Returns a Model
        mapping [img_S, img_T] (148^3 each) -> phi (60^3 x 3).
        """

        def conv3d(input_tensor, n_filters, kernel_size=(3, 3, 3),
                   batch_normalization=True, scale=True, padding='valid',
                   use_bias=False, name=''):
            """3D convolutional layer (+ batch normalization) followed by LeakyReLU activation."""
            layer = Conv3D(filters=n_filters,
                           kernel_size=kernel_size,
                           padding=padding,
                           use_bias=use_bias,
                           name=name + '_conv3d')(input_tensor)
            # if batch_normalization:
            #     layer = BatchNormalization(name=name+'_bn')(layer)
            #layer = Activation('relu', name=name+'_actrelu')(layer)
            # Add BN after activation
            if batch_normalization:
                layer = BatchNormalization(momentum=0.8, name=name + '_bn', scale=scale)(layer)
            layer = LeakyReLU(alpha=0.2, name=name + '_actleakyrelu')(layer)
            return layer

        def deconv3d(input_tensor, n_filters, kernel_size=(3, 3, 3),
                     batch_normalization=True, scale=True, padding='valid',
                     use_bias=False, name=''):
            """3D upsampling + convolution (+ batch normalization) followed by LeakyReLU activation."""
            layer = UpSampling3D(size=2)(input_tensor)
            layer = Conv3D(filters=n_filters,
                           kernel_size=kernel_size,
                           padding=padding,
                           use_bias=use_bias,
                           name=name + '_conv3d')(layer)
            # BN before activation
            if batch_normalization:
                layer = BatchNormalization(momentum=0.8, name=name + '_bn', scale=scale)(layer)
            layer = LeakyReLU(alpha=0.2, name=name + '_actleakyrelu')(layer)
            return layer

        img_S = Input(shape=self.input_shape_g, name='input_img_S')  # 148x148x148
        img_T = Input(shape=self.input_shape_g, name='input_img_T')  # 148x148x148

        # Combine subject image and template image to produce input
        #combined_imgs = Concatenate(axis=-1, name='combine_imgs_g')([img_S, img_T])
        combined_imgs = Add(name='combine_imgs_g')([img_S, img_T])

        # downsampling
        down1 = conv3d(input_tensor=combined_imgs, n_filters=self.gf, padding='valid', name='down1_1')  #146
        down1 = conv3d(input_tensor=down1, n_filters=self.gf, padding='valid', name='down1_2')  #144
        pool1 = MaxPooling3D(pool_size=(2, 2, 2), name='pool1')(down1)  #72

        down2 = conv3d(input_tensor=pool1, n_filters=2 * self.gf, padding='valid', name='down2_1')  #70
        down2 = conv3d(input_tensor=down2, n_filters=2 * self.gf, padding='valid', name='down2_2')  #68
        pool2 = MaxPooling3D(pool_size=(2, 2, 2), name='pool2')(down2)  #34

        down3 = conv3d(input_tensor=pool2, n_filters=4 * self.gf, padding='valid', name='down3_1')  #32
        down3 = conv3d(input_tensor=down3, n_filters=4 * self.gf, padding='valid', name='down3_2')  #30
        pool3 = MaxPooling3D(pool_size=(2, 2, 2), name='pool3')(down3)  #15

        center = conv3d(input_tensor=pool3, n_filters=8 * self.gf, padding='valid', name='center1')  #13
        center = conv3d(input_tensor=center, n_filters=8 * self.gf, padding='valid', name='center2')  #11

        # upsampling with gap filling
        up3 = deconv3d(input_tensor=center, n_filters=4 * self.gf, padding='same', name='up3')  #22
        gap3 = conv3d(input_tensor=down3, n_filters=4 * self.gf, padding='valid', name='gap3_1')  #28
        gap3 = conv3d(input_tensor=gap3, n_filters=4 * self.gf, padding='valid', name='gap3_2')  #26
        up3 = concatenate([Cropping3D(2)(gap3), up3], name='up3concat')  #22
        up3 = conv3d(input_tensor=up3, n_filters=4 * self.gf, padding='valid', name='up3conv_1')  #20
        up3 = conv3d(input_tensor=up3, n_filters=4 * self.gf, padding='valid', name='up3conv_2')  #18

        up2 = deconv3d(input_tensor=up3, n_filters=2 * self.gf, padding='same', name='up2')  #36
        gap2 = conv3d(input_tensor=down2, n_filters=2 * self.gf, padding='valid', name='gap2_1')  #66
        for i in range(2, 7):
            gap2 = conv3d(input_tensor=gap2, n_filters=2 * self.gf, padding='valid', name='gap2_' + str(i))  #56
        up2 = concatenate([Cropping3D(10)(gap2), up2], name='up2concat')  #36
        up2 = conv3d(input_tensor=up2, n_filters=2 * self.gf, padding='valid', name='up2conv_1')  #34
        up2 = conv3d(input_tensor=up2, n_filters=2 * self.gf, padding='valid', name='up2conv_2')  #32

        up1 = deconv3d(input_tensor=up2, n_filters=self.gf, padding='same', name='up1')  #64
        gap1 = conv3d(input_tensor=down1, n_filters=self.gf, padding='valid', name='gap1_1')  #142
        for i in range(2, 21):
            gap1 = conv3d(input_tensor=gap1, n_filters=self.gf, padding='valid', name='gap1_' + str(i))  #104
        up1 = concatenate([Cropping3D(20)(gap1), up1], name='up1concat')  #64
        up1 = conv3d(input_tensor=up1, n_filters=self.gf, padding='valid', name='up1conv_1')  #62
        up1 = conv3d(input_tensor=up1, n_filters=self.gf, padding='valid', name='up1conv_2')  #60

        phi = Conv3D(filters=3, kernel_size=(1, 1, 1), padding='same', use_bias=False, name='phi')(up1)  #60

        model = Model([img_S, img_T], outputs=phi, name='generator_model')
        return model

    """ Discriminator Network """
    def build_discriminator(self):
        """Patch discriminator.

        Takes the 60^3 warped/reference image and the full 148^3 template
        (center-cropped to 60^3 inside the graph); outputs a 15^3 sigmoid
        validity map.
        """

        def d_layer(layer_input, filters, f_size=5, bn=True, scale=True, name=''):  #change the bn to False
            """Discriminator layer"""
            d = Conv3D(filters, kernel_size=f_size, strides=1, padding='same',
                       name=name + '_conv3d')(layer_input)
            if bn:
                d = BatchNormalization(momentum=0.8, name=name + '_bn', scale=scale)(d)
            d = LeakyReLU(alpha=0.2, name=name + '_leakyrelu')(d)
            return d

        img_S = Input(shape=self.input_shape_d, name='input_img_A')  # 60 warped_img or reference
        img_T = Input(shape=self.input_shape_g, name='input_img_T')  # 148 template
        img_T_cropped = Cropping3D(cropping=44)(img_T)  # 60

        # Combine image and conditioning image to produce input
        #combined_imgs = Concatenate(axis=-1, name='combine_imgs_d')([img_S, img_T_cropped])
        combined_imgs = Add(name='combine_imgs_d')([img_S, img_T_cropped])

        d1 = d_layer(combined_imgs, self.df, bn=False, name='d1')  # 60
        d2 = d_layer(d1, self.df * 2, name='d2')  # 60
        pool = MaxPooling3D(pool_size=(2, 2, 2), name='d2_pool')(d2)  # 30
        d3 = d_layer(pool, self.df * 4, name='d3')  # 30
        d4 = d_layer(d3, self.df * 8, name='d4')  # 30
        pool = MaxPooling3D(pool_size=(2, 2, 2), name='d4_pool')(d4)  # 15
        d5 = d_layer(pool, self.df * 8, name='d5')  # 15

        # ToDo: Use FC layer at the end like specified in the paper
        validity = Conv3D(1, kernel_size=4, strides=1, padding='same',
                          activation='sigmoid', name='validity')(d5)  #9
        #d6 = Conv3D(1, kernel_size=4, strides=1, padding='same', name='validity')(d5)  # 6x6x6
        #validity = Flatten(data_format='channels_last')(d6)
        #x = Reshape((6*6*6*512,))(d5)  # hack to avoid flatten bug
        #validity = Dense(1, activation='sigmoid')(x)  # Use FC layer
        #d6 = Flatten(input_shape=(self.batch_sz,) + (6,6,6,512))(d5)
        #validity = Dense(1, activation='sigmoid')(d5)

        return Model([img_S, img_T], validity, name='discriminator_model')

    """ Discriminator Network v2 """
    def build_discriminator_v2(self):
        """Alternative strided-convolution discriminator (unused by __init__)."""

        def d_layer(layer_input, filters, f_size=5, bn=True, name=''):
            """Discriminator layer"""
            d = Conv3D(filters, kernel_size=f_size, strides=2, padding='same',
                       name=name + '_conv3d')(layer_input)
            d = LeakyReLU(alpha=0.2, name=name + '_leakyrelu')(d)
            if bn:
                d = BatchNormalization(momentum=0.8, name=name + '_bn')(d)
            return d

        img_A = Input(shape=self.input_shape_d, name='input_img_A')  # 68x68x68 warped_img or reference
        img_T = Input(shape=self.input_shape_g, name='input_img_T')  # 148x148x148 template
        img_T_cropped = Cropping3D(cropping=40)(img_T)  # 68x68x68

        # Concatenate image and conditioning image by channels to produce input
        combined_imgs = Concatenate(axis=-1)([img_A, img_T_cropped])

        d1 = d_layer(combined_imgs, self.df, bn=False, name='d1')
        d2 = d_layer(d1, self.df * 2, name='d2')
        d3 = d_layer(d2, self.df * 4, name='d3')
        d4 = d_layer(d3, self.df * 8, name='d4')
        #d5 = d_layer(d4, self.df*16, name='d5')
        # d5 = Flatten()(d4)
        # validity = Dense(1, activation='sigmoid')(d5)
        validity = Conv3D(1, kernel_size=4, strides=1, padding='same',
                          activation='sigmoid', name='disc_sig')(d4)  # 5x5x5

        return Model([img_A, img_T], validity, name='discriminator_model')

    """ Deformable Transformation Layer """
    def build_transformation(self):
        """Non-trainable layer warping the (center-cropped) subject by phi."""
        img_S = Input(shape=self.input_shape_g, name='input_img_S_transform')  # 148
        phi = Input(shape=self.output_shape_g, name='input_phi_transform')  # 60
        img_S_cropped = Cropping3D(cropping=44)(img_S)  # 60
        warped_S = Lambda(dense_image_warp_3D,
                          output_shape=(60, 60, 60, 1))([img_S_cropped, phi])
        return Model([img_S, phi], warped_S, name='transformation_layer')

    """ Define losses """
    def gradient_penalty_loss(self, y_true, y_pred, phi):
        """Computes gradient penalty on phi to ensure smoothness.

        Binary cross-entropy on the discriminator output plus the mean L2
        norm of the numerical gradient of the deformation field phi.
        """
        lr = K.mean(K.binary_crossentropy(y_true, y_pred), axis=-1)
        # compute the numerical gradient of phi
        gradients = numerical_gradient_3D(phi)
        #if self.DEBUG: gradients = K.print_tensor(gradients, message='gradients are:')
        # compute the euclidean norm by squaring ...
        gradients_sqr = K.square(gradients)
        # ... summing over the rows ...
        gradients_sqr_sum = K.sum(gradients_sqr,
                                  axis=np.arange(1, len(gradients_sqr.shape)))
        # ... and sqrt
        gradient_l2_norm = K.sqrt(gradients_sqr_sum)
        # compute lambda * (1 - ||grad||)^2 still for each single sample
        #gradient_penalty = K.square(1 - gradient_l2_norm)
        # return the mean as loss over all the batch samples
        return K.mean(gradient_l2_norm) + lr
        # return gradients_sqr_sum + lr

    """ Training """
    def train(self, epochs, batch_size=1, sample_interval=50):
        """Alternating D/G training loop with noisy-label flipping.

        The reference image fed to the discriminator as "real" is a blend of
        subject and template; label flipping occurs with probability ~0.15.
        """
        # Adversarial loss ground truths
        disc_patch = self.output_shape_d
        input_sz = 148
        output_sz = 60
        gap = int((input_sz - output_sz) / 2)

        # hard labels
        validhard = np.ones((self.batch_sz,) + disc_patch)
        fakehard = np.zeros((self.batch_sz,) + disc_patch)
        # hard labels with only one output
        #validhard = np.ones((self.batch_sz, 1))
        #fakehard = np.zeros((self.batch_sz, 1))
        # soft labels only smooth the labels of postive samples
        # https://arxiv.org/abs/1701.00160
        # https://github.com/soumith/ganhacks/issues/41
        #smooth = 0.1  # validhard -smooth
        #validsoft = 0.9 + 0.1 * np.random.random_sample((self.batch_sz,) + disc_patch)  # random between [0.9, 1)
        #fakesoft = 0.1 * np.random.random_sample((self.batch_sz,) + disc_patch)  # random between [0, 0.1)
        # NOTE(review): the soft labels below are computed but never used;
        # training uses the hard labels. Kept to preserve behavior (they
        # consume draws from the global NumPy RNG stream).
        validsoft = np.random.uniform(low=0.7, high=1.2, size=(self.batch_sz,) + disc_patch)
        fakesoft = np.random.uniform(low=0.0, high=0.3, size=(self.batch_sz,) + disc_patch)

        start_time = datetime.datetime.now()

        for epoch in range(epochs):
            for batch_i, (batch_img, batch_img_template, batch_img_golden) in enumerate(
                    self.data_loader.load_batch()):
                # ---------------------
                #  Train Discriminator
                # ---------------------
                #assert not np.any(np.isnan(batch_img))
                #assert not np.any(np.isnan(batch_img_template))
                phi = self.generator.predict([batch_img, batch_img_template])  #24x24x24
                #assert not np.any(np.isnan(phi))
                # deformable transformation
                transform = self.transformation.predict([batch_img, phi])  #24x24x24
                #assert not np.any(np.isnan(transform))

                # Create a ref image by perturbing the subject image with the template image
                perturbation_factor_alpha = 0.1 if epoch > epochs / 2 else 0.2
                batch_ref = perturbation_factor_alpha * batch_img + \
                    (1 - perturbation_factor_alpha) * batch_img_template  #64x64x64

                batch_img_sub = np.zeros((self.batch_sz, output_sz, output_sz, output_sz, self.channels),
                                         dtype=batch_img.dtype)
                batch_ref_sub = np.zeros((self.batch_sz, output_sz, output_sz, output_sz, self.channels),
                                         dtype=batch_ref.dtype)
                batch_temp_sub = np.zeros((self.batch_sz, output_sz, output_sz, output_sz, self.channels),
                                          dtype=batch_img_template.dtype)
                #batch_golden_sub = np.zeros((self.batch_sz, output_sz, output_sz, output_sz, self.channels), dtype=batch_img_golden.dtype)
                # take the centered output_sz^3 region of each input crop
                batch_img_sub[:, :, :, :, :] = batch_img[:, 0 + gap:0 + gap + output_sz,
                                                         0 + gap:0 + gap + output_sz,
                                                         0 + gap:0 + gap + output_sz, :]
                batch_ref_sub[:, :, :, :, :] = batch_ref[:, 0 + gap:0 + gap + output_sz,
                                                         0 + gap:0 + gap + output_sz,
                                                         0 + gap:0 + gap + output_sz, :]
                #batch_golden_sub[:, :, :, :, :] = batch_img_golden[:, 0 + gap:0 + gap + output_sz,
                #                                                   0 + gap:0 + gap + output_sz,
                #                                                   0 + gap:0 + gap + output_sz, :]
                batch_temp_sub[:, :, :, :, :] = batch_img_template[:, 0 + gap:0 + gap + output_sz,
                                                                   0 + gap:0 + gap + output_sz,
                                                                   0 + gap:0 + gap + output_sz, :]
                #assert not np.any(np.isnan(batch_img_sub))
                #assert not np.any(np.isnan(batch_ref_sub))
                #assert not np.any(np.isnan(batch_temp_sub))

                # Train the discriminator (R -> T is valid, S -> T is fake)
                # Noisy and soft labels
                noisy_prob = 1 - np.sqrt(1 - np.random.random())  # peak near low values and falling off towards high values
                if noisy_prob < 0.85:
                    # occasionally flip labels to introduce noisy labels
                    d_loss_real = self.discriminator.train_on_batch(
                        [batch_ref_sub, batch_img_template], validhard)
                    d_loss_fake = self.discriminator.train_on_batch(
                        [transform, batch_img_template], fakehard)
                else:
                    d_loss_real = self.discriminator.train_on_batch(
                        [batch_ref_sub, batch_img_template], fakehard)
                    d_loss_fake = self.discriminator.train_on_batch(
                        [transform, batch_img_template], validhard)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                # ---------------------
                #  Train Generator
                # ---------------------
                # Train the generator (to fool the discriminator)
                g_loss = self.combined.train_on_batch([batch_img, batch_img_template], validhard)

                elapsed_time = datetime.datetime.now() - start_time

                print("[Epoch %d/%d] [Batch %d/%d] [D loss average: %f, acc average: %3d%%, D loss fake:%f, acc: %3d%%, D loss real: %f, acc: %3d%%] [G loss: %f] time: %s"
                      % (epoch, epochs, batch_i, self.data_loader.n_batches,
                         d_loss[0], 100 * d_loss[1],
                         d_loss_fake[0], 100 * d_loss_fake[1],
                         d_loss_real[0], 100 * d_loss_real[1],
                         g_loss, elapsed_time))

                if self.DEBUG:
                    #self.write_log(self.callback, ['g_loss'], [g_loss[0]], batch_i)
                    self.write_log(self.callback, ['g_loss'], [g_loss], batch_i)
                    self.write_log(self.callback, ['d_loss'], [d_loss[0]], batch_i)

                # If at save interval => save generated image samples
                if batch_i % sample_interval == 0 and epoch != 0 and epoch % 5 == 0:
                    self.sample_images(epoch, batch_i)

    def write_log(self, callback, names, logs, batch_no):
        """Write scalar values to the TensorBoard callback's event writer."""
        #https://github.com/eriklindernoren/Keras-GAN/issues/52
        for name, value in zip(names, logs):
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value
            summary_value.tag = name
            callback.writer.add_summary(summary, batch_no)
            callback.writer.flush()

    def sample_images(self, epoch, batch_i):
        """Predict a full validation volume patch-by-patch and save results.

        Writes the warped image (nrrd) and deformation field (nifti), then
        saves the combined model (whole .h5, architecture .json, weights .h5).
        """
        path = '/nrs/scicompsoft/elmalakis/GAN_Registration_Data/flydata/forSalma/lo_res/'
        os.makedirs(path + 'generated_v148/', exist_ok=True)

        idx, imgs_S = self.data_loader.load_data(is_validation=True)
        imgs_T = self.data_loader.img_template
        # Mask the image with the template mask
        # already done in the preparing phase
        # imgs_T_mask = self.data_loader.mask_template
        # imgs_S = imgs_S * imgs_T_mask

        predict_img = np.zeros(imgs_S.shape, dtype=imgs_S.dtype)
        predict_phi = np.zeros(imgs_S.shape + (3,), dtype=imgs_S.dtype)

        input_sz = self.crop_size_g
        step = (60, 60, 60)
        # border between the 148^3 input and its 60^3 output region
        gap = (int((input_sz[0] - step[0]) / 2),
               int((input_sz[1] - step[1]) / 2),
               int((input_sz[2] - step[2]) / 2))

        start_time = datetime.datetime.now()
        for row in range(0, imgs_S.shape[0] - input_sz[0], step[0]):
            for col in range(0, imgs_S.shape[1] - input_sz[1], step[1]):
                for vol in range(0, imgs_S.shape[2] - input_sz[2], step[2]):
                    patch_sub_img = np.zeros((1, input_sz[0], input_sz[1], input_sz[2], 1),
                                             dtype=imgs_S.dtype)
                    patch_templ_img = np.zeros((1, input_sz[0], input_sz[1], input_sz[2], 1),
                                               dtype=imgs_T.dtype)

                    patch_sub_img[0, :, :, :, 0] = imgs_S[row:row + input_sz[0],
                                                          col:col + input_sz[1],
                                                          vol:vol + input_sz[2]]
                    patch_templ_img[0, :, :, :, 0] = imgs_T[row:row + input_sz[0],
                                                            col:col + input_sz[1],
                                                            vol:vol + input_sz[2]]

                    patch_predict_phi = self.generator.predict([patch_sub_img, patch_templ_img])
                    patch_predict_warped = self.transformation.predict([patch_sub_img, patch_predict_phi])

                    predict_img[row + gap[0]:row + gap[0] + step[0],
                                col + gap[1]:col + gap[1] + step[1],
                                vol + gap[2]:vol + gap[2] + step[2]] = patch_predict_warped[0, :, :, :, 0]
                    predict_phi[row + gap[0]:row + gap[0] + step[0],
                                col + gap[1]:col + gap[1] + step[1],
                                vol + gap[2]:vol + gap[2] + step[2], :] = patch_predict_phi[0, :, :, :, :]

        elapsed_time = datetime.datetime.now() - start_time
        print(" --- Prediction time: %s" % (elapsed_time))

        nrrd.write(path + "generated_v148/%d_%d_%d" % (epoch, batch_i, idx), predict_img)
        self.data_loader._write_nifti(path + "generated_v148/phi%d_%d_%d" % (epoch, batch_i, idx),
                                      predict_phi)

        file_name = 'gan_network'
        # save the whole network
        # FIX: was gan.combined (module-level global) -- use self.combined so
        # the method works on the instance it is called on.
        self.combined.save(path + 'generated_v148/' + file_name + str(epoch) + '.whole.h5',
                           overwrite=True)
        print('Save the whole network to disk as a .whole.h5 file')
        model_json = self.combined.to_json()
        with open(path + 'generated_v148/' + file_name + str(epoch) + '_arch.json', 'w') as json_file:
            json_file.write(model_json)
        self.combined.save_weights(path + 'generated_v148/' + file_name + str(epoch) + '_weights.h5',
                                   overwrite=True)
        print('Save the network architecture in .json file and weights in .h5 file')
class GAN_pix2pix():
    """pix2pix-style GAN for deformable registration on 256^3 crops.

    The generator is a strided-convolution U-Net producing a 128^3 x 3
    deformation field phi; a non-trainable transformation layer warps the
    (center-cropped) subject with phi. The discriminator is an LSGAN-style
    (mse loss) patch classifier. The combined model is trained against
    [validity, golden image] with a smoothness + mae objective.
    """

    def __init__(self):
        K.set_image_data_format('channels_last')  # set format
        self.DEBUG = 1

        # Input shape
        self.img_rows = 256
        self.img_cols = 256
        self.img_vols = 256
        self.channels = 1
        self.batch_sz = 1
        self.crop_size = (self.img_rows, self.img_cols, self.img_vols)
        self.img_shape = self.crop_size + (self.channels,)
        self.output_size = 128
        self.output_shape_g = (self.output_size, self.output_size, self.output_size) + (3,)  # phi has three outputs. one for each X, Y, and Z dimensions
        self.input_shape_d = (self.output_size, self.output_size, self.output_size) + (1,)

        # Calculate output shape of D
        patch = int(self.output_size / 2 ** 4)
        self.output_shape_d = (patch, patch, patch, self.channels)

        # Number of filters in the first layer of G and D
        self.gf = 32
        self.df = 32

        # NOTE(review): the second positional argument of Adam binds to
        # beta_1, not decay; the sibling models pass decay=0.05 by keyword.
        # Verify beta_1=0.05 is intended before changing.
        optimizer = Adam(0.001, 0.05)
        #optimizer = SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True)
        #optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.summary()
        self.discriminator.compile(loss='mse',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        # -------------------------
        # Construct Computational
        #   Graph of Generator
        # -------------------------

        # Build the generator
        self.generator = self.build_generator()
        self.generator.summary()

        self.transformation = self.build_transformation()
        self.transformation.summary()

        # Input images and their conditioning images
        img_S = Input(shape=self.img_shape)
        img_T = Input(shape=self.img_shape)

        # Generate the deformable funtion
        phi = self.generator([img_S, img_T])
        # Transform S
        warped_S = self.transformation([img_S, phi])

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # Discriminators determines validity of translated images / condition pairs
        validity = self.discriminator([warped_S, img_T])

        self.combined = Model(inputs=[img_S, img_T], outputs=[validity, warped_S])
        self.combined.summary()

        # Use Python partial to provide loss function with the deformable field argument
        partial_gp_loss = partial(self.smoothness_loss, phi=phi)
        partial_gp_loss.__name__ = 'smoothness'  # Keras requires function names

        self.combined.compile(loss=[partial_gp_loss, 'mae'],
                              loss_weights=[1, 100],
                              optimizer=optimizer)

        if self.DEBUG:
            log_path = '/nrs/scicompsoft/elmalakis/GAN_Registration_Data/flydata/forSalma/lo_res/logs_ganpix2pix_remod_golden_smooth/'
            self.callback = TensorBoard(log_path)
            self.callback.set_model(self.combined)

        self.data_loader = DataLoader(batch_sz=self.batch_sz,
                                      crop_size=self.crop_size,
                                      dataset_name='fly',
                                      min_max=False,
                                      restricted_mask=False,
                                      use_hist_equilized_data=False,
                                      use_sharpen=False,
                                      use_golden=True)

    def build_generator(self):
        """U-Net Generator.

        Maps [img_S, img_T] (256^3 each) -> phi (128^3 x 3). Per-line size
        annotations give the spatial extent after each layer.
        """

        def conv2d(layer_input, filters, f_size=4, bn=True):
            """Layers used during downsampling"""
            d = Conv3D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            d = LeakyReLU(alpha=0.2)(d)
            return d

        def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):  # dropout is 50 ->change from the implementaion
            """Layers used during upsampling"""
            u = UpSampling3D(size=2)(layer_input)
            u = Conv3D(filters, kernel_size=f_size, padding='same')(u)  # remove the strides
            if dropout_rate:
                u = Dropout(dropout_rate)(u)
            u = BatchNormalization(momentum=0.8)(u)
            u = Activation('relu')(u)
            u = Concatenate()([u, skip_input])
            return u

        img_S = Input(shape=self.img_shape, name='input_img_S')  #256
        img_T = Input(shape=self.img_shape, name='input_img_T')  #256

        d0 = Concatenate(axis=-1, name='combine_imgs_g')([img_S, img_T])
        #d0 = Add(name='combine_imgs_g')([img_S, img_T])  #256

        # Downsampling
        d1 = conv2d(d0, self.gf, bn=False)  #128
        d2 = conv2d(d1, self.gf * 2)  #64
        d3 = conv2d(d2, self.gf * 4)  #32
        d4 = conv2d(d3, self.gf * 8)  #16
        d5 = conv2d(d4, self.gf * 8)  #8
        d6 = conv2d(d5, self.gf * 8)  #4
        d7 = conv2d(d6, self.gf * 8)  #2

        # Upsampling
        u1 = deconv2d(d7, d6, self.gf * 8)  #4
        u2 = deconv2d(u1, d5, self.gf * 8)  #8
        u3 = deconv2d(u2, d4, self.gf * 8)  #16
        u4 = deconv2d(u3, d3, self.gf * 4)  #32
        u5 = deconv2d(u4, d2, self.gf * 2)  #64
        u6 = deconv2d(u5, d1, self.gf)  #128
        #u7 = UpSampling3D(size=2)(u6)  #256

        phi = Conv3D(filters=3, kernel_size=1, strides=1, padding='same')(u6)  #128

        return Model([img_S, img_T], outputs=phi, name='generator_model')

    def build_discriminator(self):
        """Patch discriminator over (warped/golden 128^3, template 256^3) pairs.

        Output is a linear (no sigmoid) validity map, matching the mse/LSGAN
        loss used at compile time.
        """

        def d_layer(layer_input, filters, f_size=4, bn=True):
            """Discriminator layer"""
            d = Conv3D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            d = LeakyReLU(alpha=0.2)(d)
            return d

        img_S = Input(shape=self.input_shape_d)  #128 S
        img_T = Input(shape=self.img_shape)  #256 T
        img_T_cropped = Cropping3D(cropping=64)(img_T)  # 128

        combined_imgs = Concatenate(axis=-1)([img_S, img_T_cropped])
        #combined_imgs = Add()([img_S, img_T])

        d1 = d_layer(combined_imgs, self.df, bn=False)
        d2 = d_layer(d1, self.df * 2)
        d3 = d_layer(d2, self.df * 4)
        d4 = d_layer(d3, self.df * 8)

        validity = Conv3D(1, kernel_size=4, strides=1, padding='same',
                          name='disc_sig')(d4)  #original is linear activation no sigmoid

        return Model([img_S, img_T], validity, name='discriminator_model')

    def build_transformation(self):
        """Non-trainable layer warping the (center-cropped) subject by phi."""
        img_S = Input(shape=self.img_shape, name='input_img_S_transform')  # 256
        phi = Input(shape=self.output_shape_g, name='input_phi_transform')  # 128
        img_S_cropped = Cropping3D(cropping=64)(img_S)  # 128
        warped_S = Lambda(dense_image_warp_3D,
                          output_shape=(128, 128, 128, 1))([img_S_cropped, phi])
        return Model([img_S, phi], warped_S, name='transformation_layer')

    """ Define losses """
    def smoothness_loss(self, y_true, y_pred, phi):
        """Computes gradient penalty on phi to ensure smoothness.

        Mean squared error on the discriminator output plus the mean squared
        numerical gradient of the deformation field phi.
        """
        # mean square error loss
        mse_loss = K.mean(K.square(y_pred - y_true), axis=-1)
        # compute the numerical gradient of phi
        gradients = numerical_gradient_3D(phi)
        #if self.DEBUG: gradients = K.print_tensor(gradients, message='gradients are:')
        # compute the euclidean norm by squaring ...
        gradients_sqr = K.square(gradients)
        # ... summing over the rows ...
        gradients_sqr_sum = K.sum(gradients_sqr,
                                  axis=np.arange(1, len(gradients_sqr.shape)))
        # ... and sqrt
        #gradient_l2_norm = K.sqrt(gradients_sqr_sum)
        # compute lambda * (1 - ||grad||)^2 still for each single sample
        #gradient_penalty = K.square(1 - gradient_l2_norm)
        # return the mean as loss over all the batch samples
        #return K.mean(gradient_l2_norm) + mse_loss
        return K.mean(gradients_sqr_sum) + mse_loss
        #return K.mean(gradient_penalty) + mse_loss

    """ Training """
    def train(self, epochs, batch_size=1, sample_interval=50):
        """Alternating D/G training loop using the "golden" image as the real
        sample for the discriminator and as the mae target for the generator.
        """
        path = '/nrs/scicompsoft/elmalakis/GAN_Registration_Data/flydata/forSalma/lo_res/'
        # NOTE(review): this directory differs from the one used by
        # sample_images ('generated_pix2pix_remod_golden_smooth/') -- confirm
        # which output folder is intended.
        os.makedirs(path + 'generated_pix2pix_remod_smooth/', exist_ok=True)

        output_sz = 128
        input_sz = 256
        gap = int((input_sz - output_sz) / 2)

        # Adversarial loss ground truths
        valid = np.ones((self.batch_sz,) + self.output_shape_d)
        fake = np.zeros((self.batch_sz,) + self.output_shape_d)
        # NOTE(review): soft labels are computed but never used; training uses
        # the hard labels above. Kept to preserve behavior (they consume draws
        # from the global NumPy RNG stream).
        validsoft = np.random.uniform(low=0.7, high=1.2, size=(self.batch_sz,) + self.output_shape_d)
        fakesoft = np.random.uniform(low=0.0, high=0.3, size=(self.batch_sz,) + self.output_shape_d)

        start_time = datetime.datetime.now()

        for epoch in range(epochs):
            for batch_i, (batch_img, batch_img_template, batch_img_golden) in enumerate(
                    self.data_loader.load_batch()):
                # ---------------------
                #  Train Discriminator
                # ---------------------
                # Condition on template and generate a transform
                phi = self.generator.predict([batch_img, batch_img_template])
                transform = self.transformation.predict([batch_img, phi])

                # Create a ref image by perturbing the subject image with the template image
                perturbation_factor_alpha = 0.1 if epoch > epochs / 2 else 0.2
                batch_ref = perturbation_factor_alpha * batch_img + \
                    (1 - perturbation_factor_alpha) * batch_img_template

                batch_img_sub = np.zeros((self.batch_sz, output_sz, output_sz, output_sz, self.channels),
                                         dtype=batch_img.dtype)
                batch_ref_sub = np.zeros((self.batch_sz, output_sz, output_sz, output_sz, self.channels),
                                         dtype=batch_ref.dtype)
                batch_temp_sub = np.zeros((self.batch_sz, output_sz, output_sz, output_sz, self.channels),
                                          dtype=batch_img_template.dtype)
                batch_golden_sub = np.zeros((self.batch_sz, output_sz, output_sz, output_sz, self.channels),
                                            dtype=batch_img_golden.dtype)
                # take the centered output_sz^3 region of each input crop
                batch_img_sub[:, :, :, :, :] = batch_img[:, 0 + gap:0 + gap + output_sz,
                                                         0 + gap:0 + gap + output_sz,
                                                         0 + gap:0 + gap + output_sz, :]
                batch_ref_sub[:, :, :, :, :] = batch_ref[:, 0 + gap:0 + gap + output_sz,
                                                         0 + gap:0 + gap + output_sz,
                                                         0 + gap:0 + gap + output_sz, :]
                batch_golden_sub[:, :, :, :, :] = batch_img_golden[:, 0 + gap:0 + gap + output_sz,
                                                                   0 + gap:0 + gap + output_sz,
                                                                   0 + gap:0 + gap + output_sz, :]
                batch_temp_sub[:, :, :, :, :] = batch_img_template[:, 0 + gap:0 + gap + output_sz,
                                                                   0 + gap:0 + gap + output_sz,
                                                                   0 + gap:0 + gap + output_sz, :]

                # use noisy targets to get the GAN out of any local minima (mode collapse)
                noisy_prob = 1 - np.sqrt(1 - np.random.random())  # peak near low values and falling off towards high values
                if noisy_prob < 0.85:
                    # occasionally flip labels to introduce noisy labels
                    d_loss_real = self.discriminator.train_on_batch(
                        [batch_golden_sub, batch_img_template], valid)
                    d_loss_fake = self.discriminator.train_on_batch(
                        [transform, batch_img_template], fake)
                else:
                    d_loss_real = self.discriminator.train_on_batch(
                        [batch_golden_sub, batch_img_template], fake)
                    d_loss_fake = self.discriminator.train_on_batch(
                        [transform, batch_img_template], valid)
                # d_loss_real = self.discriminator.train_on_batch([batch_img_golden, batch_img_template], valid)
                # d_loss_fake = self.discriminator.train_on_batch([transform, batch_img_template], fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                # -----------------
                #  Train Generator
                # -----------------
                g_loss = self.combined.train_on_batch([batch_img, batch_img_template],
                                                      [valid, batch_golden_sub])

                elapsed_time = datetime.datetime.now() - start_time

                print("[Epoch %d/%d] [Batch %d/%d] [D loss average: %f, acc average: %3d%%, D loss fake:%f, acc: %3d%%, D loss real: %f, acc: %3d%%] [G loss: %f] time: %s"
                      % (epoch, epochs, batch_i, self.data_loader.n_batches,
                         d_loss[0], 100 * d_loss[1],
                         d_loss_fake[0], 100 * d_loss_fake[1],
                         d_loss_real[0], 100 * d_loss_real[1],
                         g_loss[0], elapsed_time))

                if self.DEBUG:
                    self.write_log(self.callback, ['g_loss'], [g_loss[0]], batch_i)
                    self.write_log(self.callback, ['d_loss'], [d_loss[0]], batch_i)

                if batch_i % sample_interval == 0 and epoch != 0 and epoch % 5 == 0:
                    self.sample_images(epoch, batch_i)

    def write_log(self, callback, names, logs, batch_no):
        """Write scalar values to the TensorBoard callback's event writer."""
        #https://github.com/eriklindernoren/Keras-GAN/issues/52
        for name, value in zip(names, logs):
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value
            summary_value.tag = name
            callback.writer.add_summary(summary, batch_no)
            callback.writer.flush()

    def sample_images(self, epoch, batch_i):
        """Predict a full validation volume patch-by-patch and save results.

        Writes the warped image (nrrd) and deformation field (nifti); every
        10th epoch also saves the combined model (whole .h5, arch .json,
        weights .h5).
        """
        path = '/nrs/scicompsoft/elmalakis/GAN_Registration_Data/flydata/forSalma/lo_res/'
        os.makedirs(path + 'generated_pix2pix_remod_golden_smooth/', exist_ok=True)

        idx, imgs_S = self.data_loader.load_data(is_validation=True)
        imgs_T = self.data_loader.img_template

        predict_img = np.zeros(imgs_S.shape, dtype=imgs_S.dtype)
        predict_phi = np.zeros(imgs_S.shape + (3,), dtype=imgs_S.dtype)

        input_sz = self.crop_size
        output_sz = (self.output_size, self.output_size, self.output_size)
        step = (24, 24, 24)
        # border between the 256^3 input and its 128^3 output region
        gap = (int((input_sz[0] - output_sz[0]) / 2),
               int((input_sz[1] - output_sz[1]) / 2),
               int((input_sz[2] - output_sz[2]) / 2))

        start_time = datetime.datetime.now()
        for row in range(0, imgs_S.shape[0] - input_sz[0], step[0]):
            for col in range(0, imgs_S.shape[1] - input_sz[1], step[1]):
                for vol in range(0, imgs_S.shape[2] - input_sz[2], step[2]):
                    patch_sub_img = np.zeros((1, input_sz[0], input_sz[1], input_sz[2], 1),
                                             dtype=imgs_S.dtype)
                    patch_templ_img = np.zeros((1, input_sz[0], input_sz[1], input_sz[2], 1),
                                               dtype=imgs_T.dtype)

                    patch_sub_img[0, :, :, :, 0] = imgs_S[row:row + input_sz[0],
                                                          col:col + input_sz[1],
                                                          vol:vol + input_sz[2]]
                    patch_templ_img[0, :, :, :, 0] = imgs_T[row:row + input_sz[0],
                                                            col:col + input_sz[1],
                                                            vol:vol + input_sz[2]]

                    patch_predict_phi = self.generator.predict([patch_sub_img, patch_templ_img])
                    patch_predict_warped = self.transformation.predict([patch_sub_img, patch_predict_phi])

                    predict_img[row + gap[0]:row + gap[0] + output_sz[0],
                                col + gap[1]:col + gap[1] + output_sz[1],
                                vol + gap[2]:vol + gap[2] + output_sz[2]] = patch_predict_warped[0, :, :, :, 0]
                    predict_phi[row + gap[0]:row + gap[0] + output_sz[0],
                                col + gap[1]:col + gap[1] + output_sz[1],
                                vol + gap[2]:vol + gap[2] + output_sz[2], :] = patch_predict_phi[0, :, :, :, :]

        elapsed_time = datetime.datetime.now() - start_time
        print(" --- Prediction time: %s" % (elapsed_time))

        nrrd.write(path + "generated_pix2pix_remod_golden_smooth/%d_%d_%d" % (epoch, batch_i, idx),
                   predict_img)
        self.data_loader._write_nifti(
            path + "generated_pix2pix_remod_golden_smooth/phi%d_%d_%d" % (epoch, batch_i, idx),
            predict_phi)

        if epoch % 10 == 0:
            file_name = 'gan_network' + str(epoch)
            # save the whole network
            # FIX: was gan.combined (module-level global) -- use self.combined
            # so the method works on the instance it is called on.
            self.combined.save(path + 'generated_pix2pix_remod_golden_smooth/' + file_name + '.whole.h5',
                               overwrite=True)
            print('Save the whole network to disk as a .whole.h5 file')
            model_json = self.combined.to_json()
            with open(path + 'generated_pix2pix_remod_golden_smooth/' + file_name + '_arch.json', 'w') as json_file:
                json_file.write(model_json)
            self.combined.save_weights(path + 'generated_pix2pix_remod_golden_smooth/' + file_name + '_weights.h5',
                                       overwrite=True)
            print('Save the network architecture in .json file and weights in .h5 file')
class GANUnetNoGapFillingModel():
    """GAN-based 3D image-registration model (U-Net generator, no gap filling).

    The generator predicts a dense deformation field ``phi`` (3 channels, one
    per X/Y/Z displacement) for a 192^3 subject patch conditioned on a template
    patch.  A transformation layer warps the subject with ``phi`` and a
    PatchGAN-style discriminator judges the warped result against the template.
    Generator and discriminator operate on same-size (192^3) patches — unlike
    the valid-padding variant, no border gap has to be cropped away.
    """

    def __init__(self):
        K.set_image_data_format('channels_last')  # set format
        # NOTE(review): removed from newer Keras versions — verify against the
        # pinned Keras release before upgrading.
        K.set_image_dim_ordering('tf')
        self.DEBUG = 1

        # Input shape: cubic 192^3 single-channel patches.
        self.img_rows = 192
        self.img_cols = 192
        self.img_vols = 192
        self.channels = 1
        self.batch_sz = 1  # for testing locally to avoid memory allocation
        self.crop_size = (self.img_rows, self.img_cols, self.img_vols)
        self.img_shape = self.crop_size + (self.channels,)
        self.output_size = 192
        # phi has three outputs, one for each X, Y, and Z dimension.
        self.output_shape_g = self.crop_size + (3,)
        self.input_shape_d = self.crop_size + (1,)

        # Calculate output shape of D: four stride-2 convolutions halve the
        # spatial size 4 times (192 / 2^4 = 12 per axis).
        patch = self.output_size // 2 ** 4
        self.output_shape_d = (patch, patch, patch, self.channels)

        # Number of filters in the first layer of G and D.
        self.gf = 32
        self.df = 32

        # Train the discriminator faster than the generator
        #optimizerD = SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True)   # in the paper the learning rate is 0.001 and weight decay is 0.5
        optimizerD = Adam(0.001, decay=0.05)  # in the paper the decay after 50K iterations by 0.5
        self.decay = 0.5
        self.iterations_decay = 50
        self.learning_rate = 0.001
        #optimizerG = SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True)   # in the paper the decay after 50K iterations by 0.5
        optimizerG = Adam(0.001, decay=0.05)  # in the paper the decay after 50K iterations by 0.5

        # Build and compile the discriminator.
        self.discriminator = self.build_discriminator()
        self.discriminator.summary()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizerD,
                                   metrics=['accuracy'])

        # Build the generator.
        self.generator = self.build_generator()
        self.generator.summary()

        # Build the deformable transformation layer.
        self.transformation = self.build_transformation()
        self.transformation.summary()

        # Input images and their conditioning images.
        img_S = Input(shape=self.img_shape)
        img_T = Input(shape=self.img_shape)

        # By conditioning on T generate a warped transformation function of S.
        phi = self.generator([img_S, img_T])
        # Transform S.
        warped_S = self.transformation([img_S, phi])

        # Use Python partial to provide the loss function with the additional
        # deformable-field argument.
        partial_gp_loss = partial(self.gradient_penalty_loss, phi=phi)
        partial_gp_loss.__name__ = 'gradient_penalty'  # Keras requires function names

        # For the combined model we will only train the generator.
        self.discriminator.trainable = False

        # Discriminator determines validity of translated images / condition pairs.
        validity = self.discriminator([warped_S, img_T])

        self.combined = Model(inputs=[img_S, img_T], outputs=validity)
        self.combined.summary()
        self.combined.compile(loss=partial_gp_loss, optimizer=optimizerG)

        if self.DEBUG:
            log_path = '/nrs/scicompsoft/elmalakis/GAN_Registration_Data/flydata/forSalma/lo_res/logs_ganunet_nogap/'
            os.makedirs(log_path, exist_ok=True)
            self.callback = TensorBoard(log_path)
            self.callback.set_model(self.combined)

        self.data_loader = DataLoader(batch_sz=self.batch_sz,
                                      crop_size=self.crop_size,
                                      dataset_name='fly',
                                      min_max=False,
                                      restricted_mask=False,
                                      use_hist_equilized_data=False,
                                      use_golden=False)

    """
    Generator Network
    """
    def build_generator(self):
        """U-Net Generator"""

        def conv3d(input_tensor, n_filters, kernel_size=(3, 3, 3),
                   batch_normalization=True, scale=True, padding='valid',
                   use_bias=False, name=''):
            """3D convolutional layer (+ batch normalization) followed by ReLU activation."""
            layer = Conv3D(filters=n_filters,
                           kernel_size=kernel_size,
                           padding=padding,
                           use_bias=use_bias,
                           name=name + '_conv3d')(input_tensor)
            if batch_normalization:
                layer = BatchNormalization(momentum=0.8, name=name + '_bn', scale=scale)(layer)
            #layer = LeakyReLU(alpha=0.2, name=name + '_actleakyrelu')(layer)
            layer = Activation("relu")(layer)
            return layer

        def deconv3d(input_tensor, n_filters, kernel_size=(3, 3, 3),
                     batch_normalization=True, scale=True, padding='valid',
                     use_bias=False, name=''):
            """3D deconvolutional layer (upsample + conv, + batch normalization) followed by ReLU activation."""
            layer = UpSampling3D(size=2)(input_tensor)
            layer = Conv3D(filters=n_filters,
                           kernel_size=kernel_size,
                           padding=padding,
                           use_bias=use_bias,
                           name=name + '_conv3d')(layer)
            if batch_normalization:
                layer = BatchNormalization(momentum=0.8, name=name + '_bn', scale=scale)(layer)
            #layer = LeakyReLU(alpha=0.2, name=name + '_actleakyrelu')(layer)
            layer = Activation("relu")(layer)
            return layer

        img_S = Input(shape=self.img_shape, name='input_img_S')
        img_T = Input(shape=self.img_shape, name='input_img_T')
        # Condition the generator by summing subject and template volumes.
        combined_imgs = Add(name='combine_imgs_g')([img_S, img_T])

        # Downsampling path (spatial size noted per stage).
        down1 = conv3d(input_tensor=combined_imgs, n_filters=self.gf, padding='same', name='down1_1')  # 192
        down1 = conv3d(input_tensor=down1, n_filters=self.gf, padding='same', name='down1_2')  # 192
        pool1 = MaxPooling3D(pool_size=(2, 2, 2), name='pool1')(down1)  # 96

        down2 = conv3d(input_tensor=pool1, n_filters=2 * self.gf, padding='same', name='down2_1')  # 96
        down2 = conv3d(input_tensor=down2, n_filters=2 * self.gf, padding='same', name='down2_2')  # 96
        pool2 = MaxPooling3D(pool_size=(2, 2, 2), name='pool2')(down2)  # 48

        down3 = conv3d(input_tensor=pool2, n_filters=4 * self.gf, padding='same', name='down3_1')  # 48
        down3 = conv3d(input_tensor=down3, n_filters=4 * self.gf, padding='same', name='down3_2')  # 48
        pool3 = MaxPooling3D(pool_size=(2, 2, 2), name='pool3')(down3)  # 24

        down4 = conv3d(input_tensor=pool3, n_filters=8 * self.gf, padding='same', name='down4_1')  # 24
        down4 = conv3d(input_tensor=down4, n_filters=8 * self.gf, padding='same', name='down4_2')  # 24
        pool4 = MaxPooling3D(pool_size=(2, 2, 2), name='pool4')(down4)  # 12

        down5 = conv3d(input_tensor=pool4, n_filters=8 * self.gf, padding='same', name='down5_1')  # 12
        down5 = conv3d(input_tensor=down5, n_filters=8 * self.gf, padding='same', name='down5_2')  # 12
        pool5 = MaxPooling3D(pool_size=(2, 2, 2), name='pool5')(down5)  # 6

        center = conv3d(input_tensor=pool5, n_filters=16 * self.gf, padding='same', name='center1')  # 6
        center = conv3d(input_tensor=center, n_filters=16 * self.gf, padding='same', name='center2')  # 6

        # Upsampling path with skip connections to the matching down stage.
        up5 = deconv3d(input_tensor=center, n_filters=8 * self.gf, padding='same', name='up5')  # 12
        up5 = concatenate([up5, down5])  # 12
        up5 = conv3d(input_tensor=up5, n_filters=8 * self.gf, padding='same', name='up5_1')  # 12
        up5 = conv3d(input_tensor=up5, n_filters=8 * self.gf, padding='same', name='up5_2')  # 12

        up4 = deconv3d(input_tensor=up5, n_filters=8 * self.gf, padding='same', name='up4')  # 24
        up4 = concatenate([up4, down4])  # 24
        up4 = conv3d(input_tensor=up4, n_filters=8 * self.gf, padding='same', name='up4_1')  # 24
        up4 = conv3d(input_tensor=up4, n_filters=8 * self.gf, padding='same', name='up4_2')  # 24

        up3 = deconv3d(input_tensor=up4, n_filters=4 * self.gf, padding='same', name='up3')  # 48
        up3 = concatenate([up3, down3])  # 48
        up3 = conv3d(input_tensor=up3, n_filters=4 * self.gf, padding='same', name='up3_1')  # 48
        up3 = conv3d(input_tensor=up3, n_filters=4 * self.gf, padding='same', name='up3_2')  # 48

        up2 = deconv3d(input_tensor=up3, n_filters=2 * self.gf, padding='same', name='up2')  # 96
        up2 = concatenate([up2, down2])  # 96
        up2 = conv3d(input_tensor=up2, n_filters=2 * self.gf, padding='same', name='up2_1')  # 96
        up2 = conv3d(input_tensor=up2, n_filters=2 * self.gf, padding='same', name='up2_2')  # 96

        up1 = deconv3d(input_tensor=up2, n_filters=self.gf, padding='same', name='up1')  # 192
        up1 = concatenate([up1, down1])  # 192
        up1 = conv3d(input_tensor=up1, n_filters=self.gf, padding='same', name='up1_1')  # 192
        up1 = conv3d(input_tensor=up1, n_filters=self.gf, padding='same', name='up1_2')  # 192

        # 1x1x1 projection to the 3-channel deformation field.
        phi = Conv3D(filters=3, kernel_size=(1, 1, 1), strides=1,
                     use_bias=False, padding='same', name='phi')(up1)  # 192

        model = Model([img_S, img_T], outputs=phi, name='generator_model')
        return model

    """
    Discriminator Network
    """
    def build_discriminator(self):
        """PatchGAN-style discriminator over (warped subject, template) pairs."""

        def d_layer(layer_input, filters, f_size=4, bn=True):
            """Discriminator layer: stride-2 conv (+ BN) + ReLU; halves spatial size."""
            d = Conv3D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            #d = LeakyReLU(alpha=0.2)(d)
            d = Activation("relu")(d)
            return d

        img_S = Input(shape=self.img_shape)  # 192 S
        img_T = Input(shape=self.img_shape)  # 192 T
        combined_imgs = Add()([img_S, img_T])

        d1 = d_layer(combined_imgs, self.df, bn=False)
        d2 = d_layer(d1, self.df * 2)
        d3 = d_layer(d2, self.df * 4)
        d4 = d_layer(d3, self.df * 8)

        validity = Conv3D(1, kernel_size=4, strides=1, padding='same',
                          activation='sigmoid', name='disc_sig')(d4)

        return Model([img_S, img_T], validity, name='discriminator_model')

    """
    Transformation Network
    """
    def build_transformation(self):
        """Wrap the dense 3D warp as a Keras model: (subject, phi) -> warped subject."""
        img_S = Input(shape=self.img_shape, name='input_img_S_transform')  # 192
        phi = Input(shape=self.output_shape_g, name='input_phi_transform')  # 192
        warped_S = Lambda(dense_image_warp_3D, output_shape=self.input_shape_d)([img_S, phi])
        return Model([img_S, phi], warped_S, name='transformation_layer')

    """
    Define losses
    """
    def gradient_penalty_loss(self, y_true, y_pred, phi):
        """Adversarial binary cross-entropy plus a gradient penalty on phi to ensure smoothness."""
        lr = K.mean(K.binary_crossentropy(y_true, y_pred), axis=-1)
        # Compute the numerical gradient of phi.
        gradients = numerical_gradient_3D(phi)
        # #if self.DEBUG: gradients = K.print_tensor(gradients, message='gradients are:')
        # Compute the euclidean norm by squaring ...
        gradients_sqr = K.square(gradients)
        # ... summing over all non-batch axes ...
        gradients_sqr_sum = K.sum(gradients_sqr, axis=np.arange(1, len(gradients_sqr.shape)))
        # ... and sqrt.
        gradient_l2_norm = K.sqrt(gradients_sqr_sum)
        #gradient_penalty = K.square(1 - gradient_l2_norm)
        # Return the mean as loss over all the batch samples.
        return K.mean(gradient_l2_norm) + lr
        #return gradients_sqr_sum + lr

    """
    Training
    """
    def train(self, epochs, batch_size=1, sample_interval=50):
        """Alternate discriminator/generator updates over the data loader's batches.

        The discriminator's "real" sample is a reference image obtained by
        blending subject and template; its "fake" sample is the generator-warped
        subject.  Samples/checkpoints are written every ``sample_interval``
        batches on epochs divisible by 5 (excluding epoch 0).
        """
        # Adversarial loss ground truths (hard labels).
        valid = np.ones((self.batch_sz,) + self.output_shape_d)
        fake = np.zeros((self.batch_sz,) + self.output_shape_d)

        start_time = datetime.datetime.now()

        for epoch in range(epochs):
            for batch_i, (batch_img, batch_img_template, batch_img_golden) in enumerate(self.data_loader.load_batch()):
                # ---------------------
                #  Train Discriminator
                # ---------------------
                # Condition on T and generate a translation.
                phi = self.generator.predict([batch_img, batch_img_template])
                transform = self.transformation.predict([batch_img, phi])  # 256x256x256

                # Create a ref image by perturbing the subject image with the
                # template image; the blend shifts toward the template in the
                # second half of training.
                perturbation_factor_alpha = 0.1 if epoch > epochs / 2 else 0.2
                batch_ref = perturbation_factor_alpha * batch_img \
                            + (1 - perturbation_factor_alpha) * batch_img_template

                d_loss_real = self.discriminator.train_on_batch([batch_ref, batch_img_template], valid)
                d_loss_fake = self.discriminator.train_on_batch([transform, batch_img_template], fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                # ---------------------
                #  Train Generator
                # ---------------------
                g_loss = self.combined.train_on_batch([batch_img, batch_img_template], valid)

                elapsed_time = datetime.datetime.now() - start_time

                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss average: %f, acc average: %3d%%, D loss fake:%f, acc: %3d%%, D loss real: %f, acc: %3d%%] [G loss: %f] time: %s"
                    % (epoch, epochs, batch_i, self.data_loader.n_batches,
                       d_loss[0], 100 * d_loss[1],
                       d_loss_fake[0], 100 * d_loss_fake[1],
                       d_loss_real[0], 100 * d_loss_real[1],
                       g_loss, elapsed_time))

                if self.DEBUG:
                    self.write_log(self.callback, ['g_loss'], [g_loss], batch_i)
                    self.write_log(self.callback, ['d_loss'], [d_loss[0]], batch_i)

                # If at save interval => save generated image samples.
                if batch_i % sample_interval == 0 and epoch != 0 and epoch % 5 == 0:
                    self.sample_images(epoch, batch_i)

    def write_log(self, callback, names, logs, batch_no):
        """Write scalar values to the TensorBoard callback's summary writer.

        https://github.com/eriklindernoren/Keras-GAN/issues/52
        """
        for name, value in zip(names, logs):
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value
            summary_value.tag = name
            callback.writer.add_summary(summary, batch_no)
            callback.writer.flush()

    def sample_images(self, epoch, batch_i):
        """Predict a full validation volume patch-by-patch and write the result.

        Slides a crop-size window with a (64, 64, 64) step over the validation
        subject, warps each patch, and stitches warped intensities and phi into
        full-size volumes; every 10th epoch also checkpoints the combined model.
        NOTE(review): patches overlap (window 192, step 64) and later patches
        overwrite earlier ones; trailing regions not reached by a full window
        are left as zeros — confirm this is intended.
        """
        path = '/nrs/scicompsoft/elmalakis/GAN_Registration_Data/flydata/forSalma/lo_res/'
        os.makedirs(path + 'generated_unet_nogap/', exist_ok=True)

        idx, imgs_S = self.data_loader.load_data(is_validation=True)
        imgs_T = self.data_loader.img_template

        predict_img = np.zeros(imgs_S.shape, dtype=imgs_S.dtype)
        predict_phi = np.zeros(imgs_S.shape + (3,), dtype=imgs_S.dtype)

        input_sz = self.crop_size
        output_sz = (self.output_size, self.output_size, self.output_size)
        step = (64, 64, 64)

        start_time = datetime.datetime.now()
        for row in range(0, imgs_S.shape[0] - input_sz[0], step[0]):
            for col in range(0, imgs_S.shape[1] - input_sz[1], step[1]):
                for vol in range(0, imgs_S.shape[2] - input_sz[2], step[2]):
                    patch_sub_img = np.zeros((1, input_sz[0], input_sz[1], input_sz[2], 1), dtype=imgs_S.dtype)
                    patch_templ_img = np.zeros((1, input_sz[0], input_sz[1], input_sz[2], 1), dtype=imgs_T.dtype)

                    patch_sub_img[0, :, :, :, 0] = imgs_S[row:row + input_sz[0],
                                                          col:col + input_sz[1],
                                                          vol:vol + input_sz[2]]
                    patch_templ_img[0, :, :, :, 0] = imgs_T[row:row + input_sz[0],
                                                            col:col + input_sz[1],
                                                            vol:vol + input_sz[2]]

                    patch_predict_phi = self.generator.predict([patch_sub_img, patch_templ_img])
                    patch_predict_warped = self.transformation.predict([patch_sub_img, patch_predict_phi])

                    predict_img[row:row + output_sz[0],
                                col:col + output_sz[1],
                                vol:vol + output_sz[2]] = patch_predict_warped[0, :, :, :, 0]
                    predict_phi[row:row + output_sz[0],
                                col:col + output_sz[1],
                                vol:vol + output_sz[2], :] = patch_predict_phi[0, :, :, :, :]

        elapsed_time = datetime.datetime.now() - start_time
        print(" --- Prediction time: %s" % (elapsed_time))

        nrrd.write(path + "generated_unet_nogap/%d_%d_%d" % (epoch, batch_i, idx), predict_img)
        self.data_loader._write_nifti(path + "generated_unet_nogap/phi%d_%d_%d" % (epoch, batch_i, idx), predict_phi)

        if epoch % 10 == 0:
            file_name = 'gan_network' + str(epoch)
            # Save the whole network.
            self.combined.save(path + 'generated_unet_nogap/' + file_name + '.whole.h5', overwrite=True)
            print('Save the whole network to disk as a .whole.h5 file')
            model_json = self.combined.to_json()
            with open(path + 'generated_unet_nogap/' + file_name + '_arch.json', 'w') as json_file:
                json_file.write(model_json)
            self.combined.save_weights(path + 'generated_unet_nogap/' + file_name + '_weights.h5', overwrite=True)
            print('Save the network architecture in .json file and weights in .h5 file')