class Train(object):
    """
    Main GAN trainer. Responsible for training the GAN and pre-training the generator autoencoder.
    """

    def __init__(self, config):
        """
        Construct a new GAN trainer.

        :param Config config: The parsed network configuration.
        """
        self.config = config
        LOG.info("CUDA version: {0}".format(version.cuda))
        LOG.info("Creating data loader from path {0}".format(config.FILENAME))

        self.data_loader = Data(
            config.FILENAME,
            config.BATCH_SIZE,
            polarisations=config.POLARISATIONS,  # Polarisations to use
            frequencies=config.FREQUENCIES,      # Frequencies to use
            max_inputs=config.MAX_SAMPLES,       # Max inputs per polarisation and frequency
            normalise=config.NORMALISE)          # Normalise inputs

        shape = self.data_loader.get_input_shape()
        width = shape[1]
        LOG.info("Creating models with input shape {0}".format(shape))

        self._autoencoder = Autoencoder(width)
        self._discriminator = Discriminator(width)
        # TODO: Get correct input and output widths for generator
        self._generator = Generator(width, width)

        if config.USE_CUDA:
            LOG.info("Using CUDA")
            self.autoencoder = self._autoencoder.cuda()
            self.discriminator = self._discriminator.cuda()
            self.generator = self._generator.cuda()
        else:
            LOG.info("Using CPU")
            self.autoencoder = self._autoencoder
            self.discriminator = self._discriminator
            self.generator = self._generator

    def check_requeue(self, epochs_complete):
        """
        Check and re-queue the training script if it has completed the desired number of
        training epochs per session.

        :param int epochs_complete: Number of epochs completed.
        :return: True if the script has been requeued, False if not.
        :rtype: bool
        """
        if self.config.REQUEUE_EPOCHS > 0:
            if epochs_complete >= self.config.REQUEUE_EPOCHS:
                # We've completed enough epochs for this instance. We need to kill it and requeue.
                LOG.info("REQUEUE_EPOCHS of {0} met, calling REQUEUE_SCRIPT".format(self.config.REQUEUE_EPOCHS))
                subprocess.call(self.config.REQUEUE_SCRIPT,
                                shell=True,
                                cwd=os.path.dirname(self.config.REQUEUE_SCRIPT))
                return True  # Requeue performed

        return False  # No requeue needed

    def load_state(self, checkpoint, module, optimiser=None):
        """
        Load the provided checkpoint into the provided module and optimiser.
        This function checks whether the load threw an exception and logs it to the user.

        :param Checkpoint checkpoint: The checkpoint to load.
        :param module: The pytorch module to load the checkpoint into.
        :param optimiser: The pytorch optimiser to load the checkpoint into.
        :return: None if the load failed, int number of epochs in the checkpoint if the load succeeded.
        """
        try:
            module.load_state_dict(checkpoint.module_state)
            if optimiser is not None:
                optimiser.load_state_dict(checkpoint.optimiser_state)
            return checkpoint.epoch
        except RuntimeError:
            LOG.exception(
                "Error loading module state. This is most likely an input size mismatch. "
                "Please delete the old module saved state, or change the input size")
            return None

    def close(self):
        """
        Close the data loader used by the trainer.
        """
        self.data_loader.close()

    def generate_labels(self, num_samples, pattern):
        """
        Generate labels for the discriminator.

        :param int num_samples: Number of input samples to generate labels for.
        :param list pattern: Pattern to generate. Should be either [1, 0] or [0, 1].
        :return: New labels for the discriminator.
        """
        var = torch.FloatTensor([pattern] * num_samples)
        return var.cuda() if self.config.USE_CUDA else var

    def _train_autoencoder(self):
        """
        Main training loop for the autoencoder.

        This function will return False if:
        - Loading the autoencoder succeeded, but the NN model did not load the state dicts correctly.
        - The script needs to be re-queued because the NN has been trained for REQUEUE_EPOCHS.

        :return: True if training was completed, False if training needs to continue.
        :rtype: bool
        """
        criterion = nn.SmoothL1Loss()
        # Optimise the autoencoder's own parameters; their state is saved alongside this optimiser below.
        optimiser = optim.Adam(self.autoencoder.parameters(), lr=0.00003, betas=(0.5, 0.999))
        checkpoint = Checkpoint("autoencoder")
        epoch = 0
        if checkpoint.load():
            epoch = self.load_state(checkpoint, self.autoencoder, optimiser)
            if epoch is None:
                # The checkpoint exists but the state dicts did not load correctly (see docstring).
                return False
            if epoch >= self.config.MAX_AUTOENCODER_EPOCHS:
                LOG.info("Autoencoder already trained")
                return True
            LOG.info("Autoencoder training beginning from epoch {0}".format(epoch))
        else:
            LOG.info('Autoencoder checkpoint not found. Training from start')

        # Train autoencoder
        self._autoencoder.set_mode(Autoencoder.Mode.AUTOENCODER)
        vis_path = os.path.join(
            os.path.splitext(self.config.FILENAME)[0], "autoencoder", str(datetime.now()))
        with Visualiser(vis_path) as vis:
            epochs_complete = 0
            while epoch < self.config.MAX_AUTOENCODER_EPOCHS:

                if self.check_requeue(epochs_complete):
                    return False  # Requeue needed and training not complete

                for step, (data, _, _) in enumerate(self.data_loader):
                    if self.config.USE_CUDA:
                        data = data.cuda()

                    if self.config.ADD_DROPOUT:
                        # Drop out parts of the input, but compute loss on the full input.
                        out = self.autoencoder(nn.functional.dropout(data, 0.5))
                    else:
                        out = self.autoencoder(data)

                    loss = criterion(out.cpu(), data.cpu())
                    self.autoencoder.zero_grad()
                    loss.backward()
                    optimiser.step()

                    vis.step_autoencoder(loss.item())

                # Report data and save checkpoint
                fmt = "Epoch [{0}/{1}], Step[{2}/{3}], loss: {4:.4f}"
                LOG.info(fmt.format(epoch + 1, self.config.MAX_AUTOENCODER_EPOCHS,
                                    step, len(self.data_loader), loss))

                epoch += 1
                epochs_complete += 1

                checkpoint.set(self.autoencoder.state_dict(), optimiser.state_dict(), epoch).save()

                LOG.info("Plotting autoencoder progress")
                vis.plot_training(epoch)
                data, _, _ = next(iter(self.data_loader))
                if self.config.USE_CUDA:
                    data = data.cuda()
                vis.test_autoencoder(epoch, self.autoencoder, data)

        LOG.info("Autoencoder training complete")
        return True  # Training complete

    def _train_gan(self):
        """
        TODO: Add in autoencoder to perform dimensionality reduction on data
        TODO: Not working yet - trying to work out good autoencoder model first

        :return:
        """
        criterion = nn.BCELoss()

        discriminator_optimiser = optim.Adam(self.discriminator.parameters(), lr=0.003, betas=(0.5, 0.999))
        discriminator_scheduler = optim.lr_scheduler.LambdaLR(discriminator_optimiser, lambda epoch: 0.97 ** epoch)
        discriminator_checkpoint = Checkpoint("discriminator")
        discriminator_epoch = 0
        if discriminator_checkpoint.load():
            discriminator_epoch = self.load_state(discriminator_checkpoint,
                                                  self.discriminator,
                                                  discriminator_optimiser)
        else:
            LOG.info('Discriminator checkpoint not found')

        generator_optimiser = optim.Adam(self.generator.parameters(), lr=0.003, betas=(0.5, 0.999))
        generator_scheduler = optim.lr_scheduler.LambdaLR(generator_optimiser, lambda epoch: 0.97 ** epoch)
        generator_checkpoint = Checkpoint("generator")
        generator_epoch = 0
        if generator_checkpoint.load():
            generator_epoch = self.load_state(generator_checkpoint, self.generator, generator_optimiser)
        else:
            LOG.info('Generator checkpoint not found')

        if discriminator_epoch is None or generator_epoch is None:
            epoch = 0
            LOG.info("Discriminator or generator failed to load, training from start")
        else:
            epoch = min(generator_epoch, discriminator_epoch)
            LOG.info("Generator loaded at epoch {0}".format(generator_epoch))
            LOG.info("Discriminator loaded at epoch {0}".format(discriminator_epoch))
            LOG.info("Training from lowest epoch {0}".format(epoch))

        vis_path = os.path.join(
            os.path.splitext(self.config.FILENAME)[0], "gan", str(datetime.now()))
        with Visualiser(vis_path) as vis:
            real_labels = None  # all 1s
            fake_labels = None  # all 0s
            epochs_complete = 0
            while epoch < self.config.MAX_EPOCHS:

                if self.check_requeue(epochs_complete):
                    return  # Requeue needed and training not complete

                for step, (data, noise1, noise2) in enumerate(self.data_loader):
                    batch_size = data.size(0)
                    if real_labels is None or real_labels.size(0) != batch_size:
                        real_labels = self.generate_labels(batch_size, [1.0])
                    if fake_labels is None or fake_labels.size(0) != batch_size:
                        fake_labels = self.generate_labels(batch_size, [0.0])

                    if self.config.USE_CUDA:
                        data = data.cuda()
                        noise1 = noise1.cuda()
                        noise2 = noise2.cuda()

                    # ============= Train the discriminator =============
                    # Pass real noise through first - ideally the discriminator will return 1 [1, 0]
                    d_output_real = self.discriminator(data)
                    # Pass generated noise through - ideally the discriminator will return 0 [0, 1]
                    d_output_fake1 = self.discriminator(self.generator(noise1))

                    # Determine the loss of the discriminator by adding up the real and fake loss and backpropagate
                    d_loss_real = criterion(d_output_real, real_labels)   # How good the discriminator is on real input
                    d_loss_fake = criterion(d_output_fake1, fake_labels)  # How good the discriminator is on fake input
                    d_loss = d_loss_real + d_loss_fake
                    self.discriminator.zero_grad()
                    d_loss.backward()
                    discriminator_optimiser.step()

                    # =============== Train the generator ===============
                    # Pass in fake noise to the generator and get it to generate "real" noise
                    # Judge how good this noise is with the discriminator
                    d_output_fake2 = self.discriminator(self.generator(noise2))

                    # Determine the loss of the generator using the discriminator and backpropagate
                    g_loss = criterion(d_output_fake2, real_labels)
                    self.discriminator.zero_grad()
                    self.generator.zero_grad()
                    g_loss.backward()
                    generator_optimiser.step()

                    vis.step(d_loss_real.item(), d_loss_fake.item(), g_loss.item())

                # Report data and save checkpoint
                fmt = "Epoch [{0}/{1}], Step[{2}/{3}], d_loss_real: {4:.4f}, d_loss_fake: {5:.4f}, g_loss: {6:.4f}"
                LOG.info(fmt.format(epoch + 1, self.config.MAX_EPOCHS, step + 1, len(self.data_loader),
                                    d_loss_real, d_loss_fake, g_loss))

                epoch += 1
                epochs_complete += 1

                discriminator_checkpoint.set(self.discriminator.state_dict(),
                                             discriminator_optimiser.state_dict(),
                                             epoch).save()
                generator_checkpoint.set(self.generator.state_dict(),
                                         generator_optimiser.state_dict(),
                                         epoch).save()

                vis.plot_training(epoch)

                data, noise1, _ = next(iter(self.data_loader))
                if self.config.USE_CUDA:
                    data = data.cuda()
                    noise1 = noise1.cuda()
                vis.test(epoch, self.data_loader.get_input_size_first(),
                         self.discriminator, self.generator, noise1, data)

                generator_scheduler.step(epoch)
                discriminator_scheduler.step(epoch)
                LOG.info("Learning rates: d {0} g {1}".format(
                    discriminator_optimiser.param_groups[0]["lr"],
                    generator_optimiser.param_groups[0]["lr"]))

        LOG.info("GAN Training complete")

    def __call__(self):
        """
        Main training loop for the GAN.

        The training process is interruptible; the model and optimiser states are saved to disk each
        epoch, and the latest states are restored when the trainer is resumed.

        If the script is not able to load the generator's saved state, it will attempt to load the
        pre-trained generator autoencoder from the generator_decoder_complete checkpoint (if it exists).
        If this also fails, the generator is pre-trained as an autoencoder. This training is also
        interruptible, and will produce the generator_decoder_complete checkpoint on completion.

        On successfully restoring generator and discriminator state, the trainer will proceed from the
        earliest restored epoch. For example, if the generator is restored from epoch 7 and the
        discriminator is restored from epoch 5, training will proceed from epoch 5.

        Visualisation plots are produced each epoch and stored in
        /path_to_input_file_directory/{gan/generator_auto_encoder}/{timestamp}/{epoch}
        Each time the trainer is run, it creates a new timestamp directory using the current time.
        """
        # Load the autoencoder, and train it if needed.
        if not self._train_autoencoder():
            # Autoencoder training incomplete
            return
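

# Illustrative usage sketch (not part of the trainer itself): how this class is intended to be driven
# from a script. How the Config object is parsed is an assumption here - substitute the project's own
# configuration loading.
def _example_run_trainer(config):
    """Minimal sketch: construct the trainer, run it, and release the data loader."""
    trainer = Train(config)
    try:
        trainer()        # Pre-trains the autoencoder; returns early if training was interrupted or requeued
    finally:
        trainer.close()  # Always close the underlying data loader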
class SimpleModel(CustomModule, EmbeddingGenerator):
    ae_day: Autoencoder
    ae_night: Autoencoder

    def __init__(self):
        # Share the weights of the upper encoder and lower decoder between both autoencoders
        encoder_upper, decoder_lower = UpperEncoder(), LowerDecoder()
        self.ae_day = Autoencoder(LowerEncoder(), encoder_upper, decoder_lower, UpperDecoder())
        self.ae_night = Autoencoder(LowerEncoder(), encoder_upper, decoder_lower, UpperDecoder())

        self.loss_fn = nn.L1Loss()

        self.optimizer_day = None
        self.optimizer_night = None
        self.scheduler_day = None
        self.scheduler_night = None

    def __call__(self, input):
        raise NotImplementedError

    def init_optimizers(self):
        """
        Is called right before training and after the model has been moved to the GPU.
        Supposed to initialize optimizers and schedulers.
        """
        self.optimizer_day = Adam(self.ae_day.parameters(), lr=1e-4)
        self.optimizer_night = Adam(self.ae_night.parameters(), lr=1e-4)
        self.scheduler_day = ReduceLROnPlateau(self.optimizer_day, patience=15, verbose=True)
        self.scheduler_night = ReduceLROnPlateau(self.optimizer_night, patience=15, verbose=True)

    def train_epoch(self, train_loader, epoch, use_cuda, log_path, **kwargs):
        loss_day_sum, loss_night_sum = 0, 0

        for day_img, night_img in train_loader:
            if use_cuda:
                day_img, night_img = day_img.cuda(), night_img.cuda()

            # zero day gradients
            self.optimizer_day.zero_grad()

            # train day autoencoder
            out_day = self.ae_day(day_img)
            loss_day = self.loss_fn(out_day, day_img)

            # optimize
            loss_day.backward()
            self.optimizer_day.step()

            # zero night gradients
            self.optimizer_night.zero_grad()

            # train night autoencoder
            out_night = self.ae_night(night_img)
            loss_night = self.loss_fn(out_night, night_img)

            # optimize
            loss_night.backward()
            self.optimizer_night.step()

            # accumulate scalar losses only, so the computation graphs can be freed
            loss_day_sum += loss_day.item()
            loss_night_sum += loss_night.item()

        loss_day_mean = loss_day_sum / len(train_loader)
        loss_night_mean = loss_night_sum / len(train_loader)

        self.scheduler_day.step(loss_day_mean, epoch)
        self.scheduler_night.step(loss_night_mean, epoch)

        # log losses
        log_str = f'[Epoch {epoch}] Train day loss: {loss_day_mean} Train night loss: {loss_night_mean}'
        print(log_str)
        with open(os.path.join(log_path, 'log.txt'), 'a+') as f:
            f.write(log_str + '\n')

    def validate(self, val_loader, epoch, use_cuda, log_path, **kwargs):
        loss_day_sum, loss_night_sum = 0, 0
        day_img, night_img, out_day, out_night = (None,) * 4

        with torch.no_grad():
            for day_img, night_img in val_loader:
                if use_cuda:
                    day_img, night_img = day_img.cuda(), night_img.cuda()

                out_day = self.ae_day(day_img)
                loss_day = self.loss_fn(out_day, day_img)

                out_night = self.ae_night(night_img)
                loss_night = self.loss_fn(out_night, night_img)

                loss_day_sum += loss_day.item()
                loss_night_sum += loss_night.item()

            loss_day_mean = loss_day_sum / len(val_loader)
            loss_night_mean = loss_night_sum / len(val_loader)

            # domain translation on the last batch's first sample
            day_to_night = self.ae_night.decode(self.ae_day.encode(day_img[0].unsqueeze(0)))
            night_to_day = self.ae_day.decode(self.ae_night.encode(night_img[0].unsqueeze(0)))

        # log losses
        log_str = f'[Epoch {epoch}] Val day loss: {loss_day_mean} Val night loss: {loss_night_mean}'
        print(log_str)
        with open(os.path.join(log_path, 'log.txt'), 'a+') as f:
            f.write(log_str + '\n')

        # save sample images
        samples = {
            'day_img': day_img[0],
            'night_img': night_img[0],
            'out_day': out_day[0],
            'out_night': out_night[0],
            'day_to_night': day_to_night[0],
            'night_to_day': night_to_day[0]
        }
        for name, img in samples.items():
            ToPILImage()(img.cpu()).save(os.path.join(log_path, f'{epoch}_{name}.jpeg'), 'JPEG')

    def register_hooks(self, layers):
        """
        This function is not supposed to be called from outside the class.
        """
        handles = []
        embedding_dict = {}

        def get_hook(name, embedding_dict):
            def hook(model, input, output):
                embedding_dict[name] = output.detach()
            return hook

        # The upper encoder is shared, so hooks registered on ae_day also fire for ae_night.
        for layer in layers:
            hook = get_hook(layer, embedding_dict)
            handles.append(getattr(self.ae_day.encoder_upper, layer).register_forward_hook(hook))

        return handles, embedding_dict

    def deregister_hooks(self, handles):
        """
        This function is not supposed to be called from outside the class.
        """
        for handle in handles:
            handle.remove()

    def get_day_embeddings(self, img, layers):
        """
        Returns deep embeddings for the passed layers inside the upper encoder.
        """
        handles, embedding_dict = self.register_hooks(layers)

        # forward pass
        self.ae_day.encode(img)

        self.deregister_hooks(handles)
        return embedding_dict

    def get_night_embeddings(self, img, layers):
        """
        Returns deep embeddings for the passed layers inside the upper encoder.
        """
        handles, embedding_dict = self.register_hooks(layers)

        # forward pass
        self.ae_night.encode(img)

        self.deregister_hooks(handles)
        return embedding_dict

    def train(self):
        self.ae_day.train()
        self.ae_night.train()

    def eval(self):
        self.ae_day.eval()
        self.ae_night.eval()

    def cuda(self):
        self.ae_day.cuda()
        self.ae_night.cuda()

    def state_dict(self):
        return {
            'encoder_lower_day': self.ae_day.encoder_lower.state_dict(),
            'encoder_lower_night': self.ae_night.encoder_lower.state_dict(),
            'encoder_upper': self.ae_day.encoder_upper.state_dict(),
            'decoder_day': self.ae_day.decoder.state_dict(),
            'decoder_night': self.ae_night.decoder.state_dict()
        }

    def optim_state_dict(self):
        return {
            'optimizer_day': self.optimizer_day.state_dict(),
            'optimizer_night': self.optimizer_night.state_dict()
        }

    def load_state_dict(self, state):
        self.ae_day.encoder_lower.load_state_dict(state['encoder_lower_day'])
        self.ae_night.encoder_lower.load_state_dict(state['encoder_lower_night'])
        self.ae_day.encoder_upper.load_state_dict(state['encoder_upper'])
        self.ae_day.decoder.load_state_dict(state['decoder_day'])
        self.ae_night.decoder.load_state_dict(state['decoder_night'])

    def load_optim_state_dict(self, state):
        self.optimizer_day.load_state_dict(state['optimizer_day'])
        self.optimizer_night.load_state_dict(state['optimizer_night'])
class CycleModel(CustomModule):
    ae_day: Autoencoder
    ae_night: Autoencoder
    reconstruction_loss_factor: float
    cycle_loss_factor: float

    def __init__(self, reconstruction_loss_factor: float, cycle_loss_factor: float):
        # share weights of the upper encoder & lower decoder
        encoder_upper, decoder_lower = UpperEncoder(), LowerDecoder()
        self.ae_day = Autoencoder(LowerEncoder(), encoder_upper, decoder_lower, UpperDecoder())
        self.ae_night = Autoencoder(LowerEncoder(), encoder_upper, decoder_lower, UpperDecoder())

        self.loss_fn = nn.L1Loss()
        self.reconstruction_loss_factor = reconstruction_loss_factor
        self.cycle_loss_factor = cycle_loss_factor

        self.optimizer = None
        self.scheduler = None

    def __call__(self, input):
        raise NotImplementedError

    def init_optimizers(self):
        """
        Is called right before training and after the model has been moved to the GPU.
        Supposed to initialize optimizers and schedulers.
        """
        # Collect parameters in a set so the shared upper encoder / lower decoder weights
        # are not handed to the optimizer twice.
        parameters = set()
        parameters |= set(self.ae_day.parameters())
        parameters |= set(self.ae_night.parameters())
        self.optimizer = Adam(parameters)

        # initialize scheduler
        self.scheduler = ReduceLROnPlateau(self.optimizer, patience=15, verbose=True)

    def train_epoch(self, train_loader, epoch, use_cuda, log_path, **kwargs):
        loss_day2night2day_sum, loss_night2day2night_sum, loss_day2day_sum, loss_night2night_sum = 0, 0, 0, 0

        for day_img, night_img in train_loader:
            if use_cuda:
                day_img, night_img = day_img.cuda(), night_img.cuda()

            # Day -> Night -> Day
            self.optimizer.zero_grad()
            loss_day2night2day, loss_day2day = \
                self.cycle_plus_reconstruction_loss(day_img, self.ae_day, self.ae_night)
            loss = loss_day2night2day * self.cycle_loss_factor + loss_day2day * self.reconstruction_loss_factor
            loss.backward()
            self.optimizer.step()

            # Night -> Day -> Night
            self.optimizer.zero_grad()
            loss_night2day2night, loss_night2night = \
                self.cycle_plus_reconstruction_loss(night_img, self.ae_night, self.ae_day)
            loss = loss_night2day2night * self.cycle_loss_factor + loss_night2night * self.reconstruction_loss_factor
            loss.backward()
            self.optimizer.step()

            # accumulate scalar losses only, so the computation graphs can be freed
            loss_day2night2day_sum += loss_day2night2day.item()
            loss_day2day_sum += loss_day2day.item()
            loss_night2day2night_sum += loss_night2day2night.item()
            loss_night2night_sum += loss_night2night.item()

        loss_day2night2day_mean = loss_day2night2day_sum / len(train_loader)
        loss_day2day_mean = loss_day2day_sum / len(train_loader)
        loss_night2day2night_mean = loss_night2day2night_sum / len(train_loader)
        loss_night2night_mean = loss_night2night_sum / len(train_loader)
        loss_mean = (loss_day2night2day_mean + loss_day2day_mean +
                     loss_night2day2night_mean + loss_night2night_mean) / 4

        self.scheduler.step(loss_mean, epoch)

        # log losses
        log_str = f'[Epoch {epoch}] ' \
                  f'Train loss day -> night -> day: {loss_day2night2day_mean} ' \
                  f'Train loss night -> day -> night: {loss_night2day2night_mean} ' \
                  f'Train loss day -> day: {loss_day2day_mean} ' \
                  f'Train loss night -> night: {loss_night2night_mean}'
        print(log_str)
        with open(os.path.join(log_path, 'log.txt'), 'a+') as f:
            f.write(log_str + '\n')

    def validate(self, val_loader, epoch, use_cuda, log_path, **kwargs):
        loss_day2night2day_sum, loss_night2day2night_sum, loss_day2day_sum, loss_night2night_sum = 0, 0, 0, 0
        day_img, night_img = None, None

        with torch.no_grad():
            for day_img, night_img in val_loader:
                if use_cuda:
                    day_img, night_img = day_img.cuda(), night_img.cuda()

                # Day -> Night -> Day and Day -> Day
                loss_day2night2day, loss_day2day = \
                    self.cycle_plus_reconstruction_loss(day_img, self.ae_day, self.ae_night)

                # Night -> Day -> Night and Night -> Night
                loss_night2day2night, loss_night2night = \
                    self.cycle_plus_reconstruction_loss(night_img, self.ae_night, self.ae_day)

                loss_day2night2day_sum += loss_day2night2day.item()
                loss_day2day_sum += loss_day2day.item()
                loss_night2day2night_sum += loss_night2day2night.item()
                loss_night2night_sum += loss_night2night.item()

            loss_day2night2day_mean = loss_day2night2day_sum / len(val_loader)
            loss_night2day2night_mean = loss_night2day2night_sum / len(val_loader)
            loss_day2day_mean = loss_day2day_sum / len(val_loader)
            loss_night2night_mean = loss_night2night_sum / len(val_loader)

            # log losses
            log_str = f'[Epoch {epoch}] ' \
                      f'Val loss day -> night -> day: {loss_day2night2day_mean} ' \
                      f'Val loss night -> day -> night: {loss_night2day2night_mean} ' \
                      f'Val loss day -> day: {loss_day2day_mean} ' \
                      f'Val loss night -> night: {loss_night2night_mean}'
            print(log_str)
            with open(os.path.join(log_path, 'log.txt'), 'a+') as f:
                f.write(log_str + '\n')

            # create sample images from the last batch's first sample
            latent_day = self.ae_day.encode(day_img[0].unsqueeze(0))
            latent_night = self.ae_night.encode(night_img[0].unsqueeze(0))

            # reconstruction
            day2day = self.ae_day.decode(latent_day)
            night2night = self.ae_night.decode(latent_night)

            # domain translation
            day2night = self.ae_night.decode(latent_day)
            night2day = self.ae_day.decode(latent_night)

            # cycle
            day2night2day = self.ae_day.decode(self.ae_night.encode(day2night))
            night2day2night = self.ae_night.decode(self.ae_day.encode(night2day))

        # save sample images
        samples = {
            'day_img': day_img[0],
            'night_img': night_img[0],
            'day2day': day2day[0],
            'night2night': night2night[0],
            'day2night': day2night[0],
            'night2day': night2day[0],
            'day2night2day': day2night2day[0],
            'night2day2night': night2day2night[0],
        }
        for name, img in samples.items():
            ToPILImage()(img.cpu()).save(os.path.join(log_path, f'{epoch}_{name}.jpeg'), 'JPEG')

    def cycle_plus_reconstruction_loss(self, image, autoencoder1, autoencoder2):
        # send the image through the cycle
        intermediate_latent_1 = autoencoder1.encode(image)
        intermediate_opposite = autoencoder2.decode(intermediate_latent_1)
        intermediate_latent_2 = autoencoder2.encode(intermediate_opposite)
        cycle_img = autoencoder1.decode(intermediate_latent_2)

        # do simple reconstruction
        reconstructed_img = autoencoder1.decode(intermediate_latent_1)

        cycle_loss = self.loss_fn(cycle_img, image)
        reconstruction_loss = self.loss_fn(reconstructed_img, image)

        return cycle_loss, reconstruction_loss

    def train(self):
        self.ae_day.train()
        self.ae_night.train()

    def eval(self):
        self.ae_day.eval()
        self.ae_night.eval()

    def cuda(self):
        self.ae_day.cuda()
        self.ae_night.cuda()

    def state_dict(self):
        return {
            'encoder_lower_day': self.ae_day.encoder_lower.state_dict(),
            'encoder_lower_night': self.ae_night.encoder_lower.state_dict(),
            'encoder_upper': self.ae_day.encoder_upper.state_dict(),
            'decoder_day': self.ae_day.decoder.state_dict(),
            'decoder_night': self.ae_night.decoder.state_dict()
        }

    def optim_state_dict(self):
        return {
            'optimizer': self.optimizer.state_dict(),
        }

    def load_state_dict(self, state):
        self.ae_day.encoder_lower.load_state_dict(state['encoder_lower_day'])
        self.ae_night.encoder_lower.load_state_dict(state['encoder_lower_night'])
        self.ae_day.encoder_upper.load_state_dict(state['encoder_upper'])
        self.ae_day.decoder.load_state_dict(state['decoder_day'])
        self.ae_night.decoder.load_state_dict(state['decoder_night'])
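

# Illustrative training driver sketch (an assumption, not part of the module): how CycleModel's
# lifecycle methods fit together. The data loaders are assumed to yield (day_img, night_img) batches,
# and 'logs' is a placeholder log directory.
def _example_train_cycle_model(train_loader, val_loader, epochs=10, use_cuda=False):
    model = CycleModel(reconstruction_loss_factor=1.0, cycle_loss_factor=1.0)
    if use_cuda:
        model.cuda()             # move both autoencoders to the GPU first
    model.init_optimizers()      # then build the optimizer and scheduler over the (shared) parameters

    for epoch in range(epochs):
        model.train()
        model.train_epoch(train_loader, epoch, use_cuda, 'logs')
        model.eval()
        model.validate(val_loader, epoch, use_cuda, 'logs')

        # persist model and optimizer state so training can be resumed later
        torch.save({'model': model.state_dict(), 'optim': model.optim_state_dict()},
                   os.path.join('logs', f'checkpoint_{epoch}.pt'))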
class CycleVAE(CustomModule, EmbeddingGenerator):
    """
    CycleVAE model. This is the model which was used for evaluation.
    """

    def get_day_embeddings(self, img, layers):
        """
        Returns deep embeddings for the passed layers inside the upper encoder.
        """
        # forward pass; only the latent code is returned, the sampled noise is discarded
        latent = self.ae_day.encode(img)[0]
        return {'latent': latent}

    def get_night_embeddings(self, img, layers):
        """
        Returns deep embeddings for the passed layers inside the upper encoder.
        """
        # forward pass
        latent = self.ae_night.encode(img)[0]
        return {'latent': latent}

    def __init__(self, params: dict):
        self.params = params

        # share weights of the upper encoder & lower decoder
        encoder_upper, decoder_lower = UpperEncoder(), LowerDecoder()
        self.ae_day = Autoencoder(LowerEncoder(), encoder_upper, decoder_lower, UpperDecoder())
        self.ae_night = Autoencoder(LowerEncoder(), encoder_upper, decoder_lower, UpperDecoder())

        self.reconst_loss = nn.L1Loss()

        self.optimizer = None
        self.scheduler = None

    def __call__(self, input):
        raise NotImplementedError

    def init_optimizers(self):
        """
        Is called right before training and after the model has been moved to the GPU.
        Supposed to initialize optimizers and schedulers.
        """
        params = list(self.ae_day.parameters()) + list(self.ae_night.parameters())
        self.optimizer = Adam([p for p in params if p.requires_grad], lr=self.params['lr'])
        self.scheduler = ReduceLROnPlateau(self.optimizer, patience=self.params['patience'], verbose=True)

    def train_epoch(self, train_loader, epoch, use_cuda, log_path, **kwargs):
        loss_sum = 0

        for img_day, img_night in train_loader:
            if use_cuda:
                img_day, img_night = img_day.cuda(), img_night.cuda()

            self.optimizer.zero_grad()

            latent_day, noise_day = self.ae_day.encode(img_day)
            latent_night, noise_night = self.ae_night.encode(img_night)

            # same domain reconstruction
            reconst_day = self.ae_day.decode(latent_day + noise_day)
            reconst_night = self.ae_night.decode(latent_night + noise_night)

            # cross domain
            night_to_day = self.ae_day.decode(latent_night + noise_night)
            day_to_night = self.ae_night.decode(latent_day + noise_day)

            # encode again for cycle loss
            latent_night_to_day, noise_night_to_day = self.ae_day.encode(night_to_day)
            latent_day_to_night, noise_day_to_night = self.ae_night.encode(day_to_night)

            # ... and decode again
            reconst_cycle_day = self.ae_day.decode(latent_day_to_night + noise_day_to_night)
            reconst_cycle_night = self.ae_night.decode(latent_night_to_day + noise_night_to_day)

            # loss formulations
            loss_reconst_day = self.reconst_loss(reconst_day, img_day)
            loss_reconst_night = self.reconst_loss(reconst_night, img_night)
            loss_kl_reconst_day = kl_loss(latent_day)
            loss_kl_reconst_night = kl_loss(latent_night)
            loss_cycle_day = self.reconst_loss(reconst_cycle_day, img_day)
            loss_cycle_night = self.reconst_loss(reconst_cycle_night, img_night)
            loss_kl_cycle_day = kl_loss(latent_night_to_day)
            loss_kl_cycle_night = kl_loss(latent_day_to_night)

            loss = \
                self.params['loss_reconst'] * (loss_reconst_day + loss_reconst_night) + \
                self.params['loss_kl_reconst'] * (loss_kl_reconst_day + loss_kl_reconst_night) + \
                self.params['loss_cycle'] * (loss_cycle_day + loss_cycle_night) + \
                self.params['loss_kl_cycle'] * (loss_kl_cycle_day + loss_kl_cycle_night)

            loss.backward()
            self.optimizer.step()

            loss_sum += loss.detach().item()

        loss_mean = loss_sum / len(train_loader)
        self.scheduler.step(loss_mean, epoch)

        # log loss
        log_str = f'[Epoch {epoch}] Train loss: {loss_mean}'
        print(log_str)
        with open(os.path.join(log_path, 'log.txt'), 'a+') as f:
            f.write(log_str + '\n')

    def validate(self, val_loader, epoch, use_cuda, log_path, **kwargs):
        loss_sum = 0
        img_day, img_night, reconst_day, reconst_night, reconst_cycle_day, reconst_cycle_night = (None,) * 6

        with torch.no_grad():
            for img_day, img_night in val_loader:
                if use_cuda:
                    img_day, img_night = img_day.cuda(), img_night.cuda()

                latent_day, noise_day = self.ae_day.encode(img_day)
                latent_night, noise_night = self.ae_night.encode(img_night)

                # same domain reconstruction
                reconst_day = self.ae_day.decode(latent_day + noise_day)
                reconst_night = self.ae_night.decode(latent_night + noise_night)

                # cross domain
                night_to_day = self.ae_day.decode(latent_night + noise_night)
                day_to_night = self.ae_night.decode(latent_day + noise_day)

                # encode again for cycle loss
                latent_night_to_day, noise_night_to_day = self.ae_day.encode(night_to_day)
                latent_day_to_night, noise_day_to_night = self.ae_night.encode(day_to_night)

                # ... and decode again
                reconst_cycle_day = self.ae_day.decode(latent_day_to_night + noise_day_to_night)
                reconst_cycle_night = self.ae_night.decode(latent_night_to_day + noise_night_to_day)

                # loss formulations
                loss_reconst_day = self.reconst_loss(reconst_day, img_day)
                loss_reconst_night = self.reconst_loss(reconst_night, img_night)
                loss_kl_reconst_day = kl_loss(latent_day)
                loss_kl_reconst_night = kl_loss(latent_night)
                loss_cycle_day = self.reconst_loss(reconst_cycle_day, img_day)
                loss_cycle_night = self.reconst_loss(reconst_cycle_night, img_night)
                loss_kl_cycle_day = kl_loss(latent_night_to_day)
                loss_kl_cycle_night = kl_loss(latent_day_to_night)

                loss = \
                    self.params['loss_reconst'] * (loss_reconst_day + loss_reconst_night) + \
                    self.params['loss_kl_reconst'] * (loss_kl_reconst_day + loss_kl_reconst_night) + \
                    self.params['loss_cycle'] * (loss_cycle_day + loss_cycle_night) + \
                    self.params['loss_kl_cycle'] * (loss_kl_cycle_day + loss_kl_cycle_night)

                loss_sum += loss.detach().item()

            loss_mean = loss_sum / len(val_loader)

            # domain translation on the last batch's first sample
            day_to_night = self.ae_night.decode(self.ae_day.encode(img_day[0].unsqueeze(0))[0])
            night_to_day = self.ae_day.decode(self.ae_night.encode(img_night[0].unsqueeze(0))[0])

        # log loss
        log_str = f'[Epoch {epoch}] Val loss: {loss_mean}'
        print(log_str)
        with open(os.path.join(log_path, 'log.txt'), 'a+') as f:
            f.write(log_str + '\n')

        # save sample images
        samples = {
            'day_img': img_day[0],
            'night_img': img_night[0],
            'reconst_day': reconst_day[0],
            'reconst_night': reconst_night[0],
            'reconst_cycle_day': reconst_cycle_day[0],
            'reconst_cycle_night': reconst_cycle_night[0],
            'day_to_night': day_to_night[0],
            'night_to_day': night_to_day[0]
        }
        for name, img in samples.items():
            ToPILImage()(img.cpu()).save(os.path.join(log_path, f'{epoch}_{name}.jpeg'), 'JPEG')

    def train(self):
        self.ae_day.train()
        self.ae_night.train()

    def eval(self):
        self.ae_day.eval()
        self.ae_night.eval()

    def cuda(self):
        self.ae_day.cuda()
        self.ae_night.cuda()

    def state_dict(self):
        return {
            'encoder_lower_day': self.ae_day.encoder_lower.state_dict(),
            'encoder_lower_night': self.ae_night.encoder_lower.state_dict(),
            'encoder_upper': self.ae_day.encoder_upper.state_dict(),
            'decoder_lower': self.ae_day.decoder_lower.state_dict(),
            'decoder_upper_day': self.ae_day.decoder_upper.state_dict(),
            'decoder_upper_night': self.ae_night.decoder_upper.state_dict()
        }

    def optim_state_dict(self):
        return self.optimizer.state_dict()

    def load_state_dict(self, state):
        self.ae_day.encoder_lower.load_state_dict(state['encoder_lower_day'])
        self.ae_night.encoder_lower.load_state_dict(state['encoder_lower_night'])
        self.ae_day.encoder_upper.load_state_dict(state['encoder_upper'])
        self.ae_day.decoder_lower.load_state_dict(state['decoder_lower'])
        self.ae_day.decoder_upper.load_state_dict(state['decoder_upper_day'])
        self.ae_night.decoder_upper.load_state_dict(state['decoder_upper_night'])

    def load_optim_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)
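

# The kl_loss() used above is defined elsewhere in the project. Below is a minimal sketch of one
# common formulation, under the assumption suggested by the usage here: the encoder outputs the
# posterior mean ("latent") with the variance fixed to one, and the separately sampled unit-variance
# noise is added before decoding. The KL divergence against a standard normal prior then reduces to
# a penalty on the squared latent mean. The real kl_loss may differ.
def _example_kl_loss(latent):
    # 0.5 * E[mu^2], averaged over the batch and latent dimensions
    return 0.5 * torch.mean(latent ** 2)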