def train(self, initial_epoch, nb_epochs, steps_per_epoch=int(1e3)):
    """Train the discriminator and the combined generator model adversarially.

    Each step trains the discriminator once on a real ``(img, mask)`` pair
    labelled 1 and once on an ``(img, generated)`` pair labelled 0, then
    trains ``self.combined_model`` on ``img`` with the "valid" (1) labels.
    After every epoch the averaged discriminator loss/metric is logged,
    ``self.loss_validate.error_log(epoch)`` is called, and on every 50th
    epoch the weights of the combined model, generator and discriminator
    are saved.

    Args:
        initial_epoch: Epoch index to resume from; training runs epochs
            ``initial_epoch + 1`` through ``initial_epoch + nb_epochs``
            inclusive.
        nb_epochs: Number of epochs to run.
        steps_per_epoch: Number of batches drawn per epoch; must be >= 1.

    Raises:
        ValueError: If ``steps_per_epoch`` is smaller than 1 (the per-epoch
            logging needs at least one trained batch).
    """
    if steps_per_epoch < 1:
        raise ValueError(
            'steps_per_epoch must be >= 1, got {}'.format(steps_per_epoch))

    batch_size = self.batch_generator.batch_size
    # Adversarial ground truths: 1 = real pair, 0 = generated pair.
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    # range() excludes its upper bound, so "+ 1" is required for exactly
    # nb_epochs epochs to run.
    for epoch in range(initial_epoch + 1, initial_epoch + nb_epochs + 1):
        discrim_losses = []
        for _ in range(steps_per_epoch):
            # -------------------
            # Train discriminator
            # -------------------
            # NOTE(review): _get_batch() is called afresh every step; if it
            # is a generator function this builds a new iterator each time
            # and only its first item is used -- confirm that successive
            # calls yield different batches (e.g. random sampling).
            img, mask = next(self._get_batch())
            generated = self.generator.predict(img)
            generated = np.reshape(generated, (batch_size, *self.mask_shape))
            # Real (img, mask) pairs are labelled 1, generated pairs 0.
            discrim_losses.append(self.discriminator.train_on_batch(
                np.concatenate([img, mask], axis=-1), valid))
            discrim_losses.append(self.discriminator.train_on_batch(
                np.concatenate([img, generated], axis=-1), fake))

            # ---------------
            # Train generator
            # ---------------
            # Trained against "valid" labels: presumably the combined model
            # holds the discriminator frozen so only the generator learns.
            self.combined_model.train_on_batch(img, valid)

        # IOU is computed on the last batch of the epoch only.
        logging.info('One image IOU: {:.2f}'.format(
            DCGAN.metric(generated, mask)))
        # Average over all discriminator updates of this epoch; assumes
        # train_on_batch returns [loss, accuracy] (discriminator compiled
        # with a single metric) -- confirm against its compile call.
        discrim_loss = np.mean(discrim_losses, axis=0)
        logging.info('Discrimination loss: {:.2f}, accuracy: {:.2f}'.format(
            discrim_loss[0], 100 * discrim_loss[1]))
        # log error value
        self.loss_validate.error_log(epoch)
        if epoch % 50 == 0:
            self.combined_model.save_weights(
                self.dcgan_weights_path.format(epoch))
            self.generator.save_weights(
                self.generator_weights_path.format(epoch))
            self.discriminator.save_weights(
                self.discriminator_weights_path.format(epoch))
discriminator_nn.trainable = False # ------------------------ # Define Optimizers opt_discriminator = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08) opt_dcgan = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08) # ------------------------- # compile generator generator_nn.compile(loss='mae', optimizer=opt_discriminator) # ---------------------- # MAKE FULL DCGAN # ---------------------- dc_gan_nn = DCGAN(generator_model=generator_nn, discriminator_model=discriminator_nn, input_img_dim=input_img_dim, patch_dim=sub_patch_dim) dc_gan_nn.summary() # --------------------- # Compile DCGAN # we use a combination of mae and bin_crossentropy loss = ['mae', 'binary_crossentropy'] loss_weights = [1E2, 1] dc_gan_nn.compile(loss=loss, loss_weights=loss_weights, optimizer=opt_dcgan) # --------------------- # ENABLE DISCRIMINATOR AND COMPILE discriminator_nn.trainable = True discriminator_nn.compile(loss='binary_crossentropy',
opt_DCGAN = Adam(lr=lr, beta_1=beta_1, beta_2=beta_2, epsilon=epsilon) # UNetGenerator generator = UNetGenerator(input_dim, output_channels) generator.summary() generator.compile(loss='mae', optimizer=opt_generator) # L1 loss for image generation # Patch GAN Discriminator nb_patches, patch_gan_dim = patch_utils.num_patches(output_dim, patch_size) discriminator = PatchGanDiscriminator(output_dim, patch_size, nb_patches) discriminator.summary() discriminator.trainable = False # DCGAN dc_gan = DCGAN(generator, discriminator, input_dim, patch_size) dc_gan.summary() # Total Loss loss = ['mae', 'binary_crossentropy'] loss_weights = [1e2, 1] dc_gan.compile(loss=loss, loss_weights=loss_weights, optimizer=opt_DCGAN) discriminator.trainable = True discriminator.compile(loss='binary_crossentropy', optimizer=opt_discriminator) # ------------------- # set up tensorboard # -------------------