Exemple #1
0
class Train(object):
    """
    Main GAN trainer. Responsible for training the GAN and pre-training the generator autoencoder.

    The trainer owns the data loader and three models (autoencoder, discriminator and generator), all
    built from the width reported by the data loader. Training state is checkpointed once per epoch.
    """

    def __init__(self, config):
        """
        Construct a new GAN trainer
        :param Config config: The parsed network configuration.
        """
        self.config = config

        LOG.info("CUDA version: {0}".format(version.cuda))
        LOG.info("Creating data loader from path {0}".format(config.FILENAME))

        self.data_loader = Data(
            config.FILENAME,
            config.BATCH_SIZE,
            polarisations=config.POLARISATIONS,  # Polarisations to use
            frequencies=config.FREQUENCIES,  # Frequencies to use
            max_inputs=config.MAX_SAMPLES,  # Max inputs per polarisation and frequency
            normalise=config.NORMALISE)  # Normalise inputs

        shape = self.data_loader.get_input_shape()
        width = shape[1]
        LOG.info("Creating models with input shape {0}".format(shape))
        self._autoencoder = Autoencoder(width)
        self._discriminator = Discriminator(width)
        # TODO: Get correct input and output widths for generator
        self._generator = Generator(width, width)

        # The underscore attributes always hold the modules as constructed; the public ones hold
        # whatever `.cuda()` returned when CUDA is enabled.
        if config.USE_CUDA:
            LOG.info("Using CUDA")
            self.autoencoder = self._autoencoder.cuda()
            self.discriminator = self._discriminator.cuda()
            self.generator = self._generator.cuda()
        else:
            LOG.info("Using CPU")
            self.autoencoder = self._autoencoder
            self.discriminator = self._discriminator
            self.generator = self._generator

    def check_requeue(self, epochs_complete):
        """
        Check and re-queue the training script if it has completed the desired number of training epochs per session
        :param int epochs_complete: Number of epochs completed
        :return: True if the script has been requeued, False if not
        :rtype bool
        """
        requeue_epochs = self.config.REQUEUE_EPOCHS
        if requeue_epochs <= 0 or epochs_complete < requeue_epochs:
            return False  # No requeue needed

        # We've completed enough epochs for this instance. We need to kill it and requeue
        LOG.info("REQUEUE_EPOCHS of {0} met, calling REQUEUE_SCRIPT".format(
            requeue_epochs))
        script = self.config.REQUEUE_SCRIPT
        # dirname() is '' for a bare script name, and an empty cwd makes subprocess fail;
        # None runs the script from the current working directory instead.
        subprocess.call(script,
                        shell=True,
                        cwd=os.path.dirname(script) or None)
        return True  # Requeue performed

    def load_state(self, checkpoint, module, optimiser=None):
        """
        Load the provided checkpoint into the provided module and optimiser.
        This function checks whether the load threw an exception and logs it to the user.
        :param Checkpoint checkpoint: The checkpoint to load
        :param module: The pytorch module to load the checkpoint into.
        :param optimiser: The pytorch optimiser to load the checkpoint into.
        :return: None if the load failed, int number of epochs in the checkpoint if load succeeded
        """
        try:
            module.load_state_dict(checkpoint.module_state)
            if optimiser is not None:
                optimiser.load_state_dict(checkpoint.optimiser_state)
            return checkpoint.epoch
        except (RuntimeError, ValueError):
            # ValueError is included because optimisers raise it when the saved parameter
            # groups do not match the optimiser being restored.
            LOG.exception(
                "Error loading module state. This is most likely an input size mismatch. Please delete the old module saved state, or change the input size"
            )
            return None

    def close(self):
        """
        Close the data loader used by the trainer.
        """
        self.data_loader.close()

    def generate_labels(self, num_samples, pattern):
        """
        Generate labels for the discriminator.
        :param int num_samples: Number of input samples to generate labels for.
        :param list pattern: Label row repeated for every sample. The GAN loop passes [1.0] for real
                             and [0.0] for fake.
        :return: Float tensor with one copy of `pattern` per sample, moved to the GPU when USE_CUDA is set
        """
        var = torch.FloatTensor([pattern] * num_samples)
        return var.cuda() if self.config.USE_CUDA else var

    def _train_autoencoder(self):
        """
        Main training loop for the autencoder.
        This function will return False if:
        - Loading the autoencoder succeeded, but the NN model did not load the state dicts correctly.
        - The script needs to be re-queued because the NN has been trained for REQUEUE_EPOCHS
        :return: True if training was completed, False if training needs to continue.
        :rtype bool
        """

        criterion = nn.SmoothL1Loss()

        # The optimiser must own the parameters of the module being trained. Building it over
        # the generator's parameters would leave the autoencoder's weights untouched by step().
        optimiser = optim.Adam(self.autoencoder.parameters(),
                               lr=0.00003,
                               betas=(0.5, 0.999))
        checkpoint = Checkpoint("autoencoder")
        epoch = 0
        if checkpoint.load():
            epoch = self.load_state(checkpoint, self.autoencoder, optimiser)
            if epoch is None:
                # The checkpoint exists but could not be applied; load_state has already
                # logged why. Bail out instead of comparing None with an int below.
                return False
            if epoch >= self.config.MAX_AUTOENCODER_EPOCHS:
                LOG.info("Autoencoder already trained")
                return True
            LOG.info(
                "Autoencoder training beginning from epoch {0}".format(epoch))
        else:
            LOG.info('Autoencoder checkpoint not found. Training from start')

        # Train autoencoder
        self._autoencoder.set_mode(Autoencoder.Mode.AUTOENCODER)

        vis_path = os.path.join(
            os.path.splitext(self.config.FILENAME)[0], "autoencoder",
            str(datetime.now()))
        with Visualiser(vis_path) as vis:
            epochs_complete = 0  # Epochs run in this session (not the checkpointed total)
            while epoch < self.config.MAX_AUTOENCODER_EPOCHS:

                if self.check_requeue(epochs_complete):
                    return False  # Requeue needed and training not complete

                for step, (data, _, _) in enumerate(self.data_loader):
                    if self.config.USE_CUDA:
                        data = data.cuda()

                    if self.config.ADD_DROPOUT:
                        # Drop out parts of the input, but compute loss on the full input.
                        out = self.autoencoder(nn.functional.dropout(
                            data, 0.5))
                    else:
                        out = self.autoencoder(data)

                    loss = criterion(out.cpu(), data.cpu())
                    self.autoencoder.zero_grad()
                    loss.backward()
                    optimiser.step()

                    loss_value = loss.item()
                    vis.step_autoencoder(loss_value)

                    # Report data and save checkpoint
                    fmt = "Epoch [{0}/{1}], Step[{2}/{3}], loss: {4:.4f}"
                    LOG.info(
                        fmt.format(epoch + 1,
                                   self.config.MAX_AUTOENCODER_EPOCHS,
                                   step + 1, len(self.data_loader),
                                   loss_value))

                epoch += 1
                epochs_complete += 1

                checkpoint.set(self.autoencoder.state_dict(),
                               optimiser.state_dict(), epoch).save()

                LOG.info("Plotting autoencoder progress")
                vis.plot_training(epoch)
                data, _, _ = next(iter(self.data_loader))
                if self.config.USE_CUDA:
                    # Only move to the GPU when configured; an unconditional .cuda() breaks CPU runs.
                    data = data.cuda()
                vis.test_autoencoder(epoch, self.autoencoder, data)

        LOG.info("Autoencoder training complete")
        return True  # Training complete

    def _train_gan(self):
        """
        Adversarial training loop for the discriminator and generator.
        TODO: Add in autoencoder to perform dimensionality reduction on data
        TODO: Not working yet - trying to work out good autoencoder model first
        :return: None. Returns early (before MAX_EPOCHS is reached) when the script has been re-queued.
        """

        criterion = nn.BCELoss()

        discriminator_optimiser = optim.Adam(self.discriminator.parameters(),
                                             lr=0.003,
                                             betas=(0.5, 0.999))
        discriminator_scheduler = optim.lr_scheduler.LambdaLR(
            discriminator_optimiser, lambda epoch: 0.97**epoch)
        discriminator_checkpoint = Checkpoint("discriminator")
        discriminator_epoch = 0
        if discriminator_checkpoint.load():
            discriminator_epoch = self.load_state(discriminator_checkpoint,
                                                  self.discriminator,
                                                  discriminator_optimiser)
        else:
            LOG.info('Discriminator checkpoint not found')

        generator_optimiser = optim.Adam(self.generator.parameters(),
                                         lr=0.003,
                                         betas=(0.5, 0.999))
        generator_scheduler = optim.lr_scheduler.LambdaLR(
            generator_optimiser, lambda epoch: 0.97**epoch)
        generator_checkpoint = Checkpoint("generator")
        generator_epoch = 0
        if generator_checkpoint.load():
            generator_epoch = self.load_state(generator_checkpoint,
                                              self.generator,
                                              generator_optimiser)
        else:
            LOG.info('Generator checkpoint not found')

        # load_state returns None on failure; a missing checkpoint leaves the epoch at 0.
        if discriminator_epoch is None or generator_epoch is None:
            epoch = 0
            LOG.info(
                "Discriminator or generator failed to load, training from start"
            )
        else:
            epoch = min(generator_epoch, discriminator_epoch)
            LOG.info("Generator loaded at epoch {0}".format(generator_epoch))
            LOG.info("Discriminator loaded at epoch {0}".format(
                discriminator_epoch))
            LOG.info("Training from lowest epoch {0}".format(epoch))

        vis_path = os.path.join(
            os.path.splitext(self.config.FILENAME)[0], "gan",
            str(datetime.now()))
        with Visualiser(vis_path) as vis:
            real_labels = None  # all 1s
            fake_labels = None  # all 0s
            epochs_complete = 0  # Epochs run in this session (not the checkpointed total)
            while epoch < self.config.MAX_EPOCHS:

                if self.check_requeue(epochs_complete):
                    return  # Requeue needed and training not complete

                for step, (data, noise1,
                           noise2) in enumerate(self.data_loader):
                    # Labels are cached and only rebuilt when the batch size changes
                    # (e.g. a short final batch).
                    batch_size = data.size(0)
                    if real_labels is None or real_labels.size(
                            0) != batch_size:
                        real_labels = self.generate_labels(batch_size, [1.0])
                    if fake_labels is None or fake_labels.size(
                            0) != batch_size:
                        fake_labels = self.generate_labels(batch_size, [0.0])

                    if self.config.USE_CUDA:
                        data = data.cuda()
                        noise1 = noise1.cuda()
                        noise2 = noise2.cuda()

                    # ============= Train the discriminator =============
                    # Pass real noise through first - ideally the discriminator will return 1
                    d_output_real = self.discriminator(data)
                    # Pass generated noise through - ideally the discriminator will return 0.
                    # The generator output is detached: this backward pass only feeds the
                    # discriminator's optimiser, and generator gradients are zeroed before
                    # the generator's own update below.
                    d_output_fake1 = self.discriminator(
                        self.generator(noise1).detach())

                    # Determine the loss of the discriminator by adding up the real and fake loss and backpropagate
                    d_loss_real = criterion(
                        d_output_real, real_labels
                    )  # How good the discriminator is on real input
                    d_loss_fake = criterion(
                        d_output_fake1, fake_labels
                    )  # How good the discriminator is on fake input
                    d_loss = d_loss_real + d_loss_fake
                    self.discriminator.zero_grad()
                    d_loss.backward()
                    discriminator_optimiser.step()

                    # =============== Train the generator ===============
                    # Pass in fake noise to the generator and get it to generate "real" noise
                    # Judge how good this noise is with the discriminator
                    d_output_fake2 = self.discriminator(self.generator(noise2))

                    # Determine the loss of the generator using the discriminator and backpropagate
                    g_loss = criterion(d_output_fake2, real_labels)
                    self.discriminator.zero_grad()
                    self.generator.zero_grad()
                    g_loss.backward()
                    generator_optimiser.step()

                    d_real_value = d_loss_real.item()
                    d_fake_value = d_loss_fake.item()
                    g_value = g_loss.item()
                    vis.step(d_real_value, d_fake_value, g_value)

                    # Report data and save checkpoint
                    fmt = "Epoch [{0}/{1}], Step[{2}/{3}], d_loss_real: {4:.4f}, d_loss_fake: {5:.4f}, g_loss: {6:.4f}"
                    LOG.info(
                        fmt.format(epoch + 1, self.config.MAX_EPOCHS, step + 1,
                                   len(self.data_loader), d_real_value,
                                   d_fake_value, g_value))

                epoch += 1
                epochs_complete += 1

                discriminator_checkpoint.set(
                    self.discriminator.state_dict(),
                    discriminator_optimiser.state_dict(), epoch).save()
                generator_checkpoint.set(self.generator.state_dict(),
                                         generator_optimiser.state_dict(),
                                         epoch).save()
                vis.plot_training(epoch)

                data, noise1, _ = next(iter(self.data_loader))
                if self.config.USE_CUDA:
                    data = data.cuda()
                    noise1 = noise1.cuda()
                vis.test(epoch, self.data_loader.get_input_size_first(),
                         self.discriminator, self.generator, noise1, data)

                # The explicit epoch keeps the decay in sync with the restored epoch
                # counter when training resumes from a checkpoint.
                generator_scheduler.step(epoch)
                discriminator_scheduler.step(epoch)

                LOG.info("Learning rates: d {0} g {1}".format(
                    discriminator_optimiser.param_groups[0]["lr"],
                    generator_optimiser.param_groups[0]["lr"]))

        LOG.info("GAN Training complete")

    def __call__(self):
        """
        Entry point of the trainer.

        Runs the autoencoder pre-training stage via `_train_autoencoder`. That stage saves the model
        and optimiser state each epoch and restores the latest saved state on the next run, so it can
        be interrupted and resumed. Visualisation output for each run goes into a new timestamped
        directory derived from the input file path.

        NOTE(review): the GAN stage (`_train_gan`) is not invoked from here in the code as it stands;
        this method returns once the autoencoder stage returns, whether or not it completed.
        """

        # Load the autoencoder, and train it if needed.
        if not self._train_autoencoder():
            # Autoencoder training incomplete
            return
class SimpleModel(CustomModule, EmbeddingGenerator):
    """
    Day/night autoencoder pair trained with an L1 reconstruction loss.

    Both autoencoders are constructed with the same ``UpperEncoder`` and ``LowerDecoder`` instances
    and with their own ``LowerEncoder`` / ``UpperDecoder``. Each autoencoder has its own optimizer
    and learning-rate scheduler, created by ``init_optimizers``.
    """

    ae_day: Autoencoder
    ae_night: Autoencoder

    def __init__(self):
        """Build both autoencoders and the loss; optimizers and schedulers stay None until init_optimizers."""
        # The same upper encoder / lower decoder objects are handed to both autoencoders.
        encoder_upper, decoder_lower = UpperEncoder(), LowerDecoder()
        self.ae_day = Autoencoder(LowerEncoder(), encoder_upper, decoder_lower,
                                  UpperDecoder())
        self.ae_night = Autoencoder(LowerEncoder(), encoder_upper,
                                    decoder_lower, UpperDecoder())
        self.loss_fn = nn.L1Loss()

        self.optimizer_day = None
        self.optimizer_night = None
        self.scheduler_day = None
        self.scheduler_night = None

    def __call__(self, input):
        """Not supported; use the train/validate/embedding methods instead."""
        raise NotImplementedError

    def init_optimizers(self):
        """
        Is called right before training and after model has been moved to GPU.
        Supposed to initialize optimizers and schedulers.
        """
        self.optimizer_day = Adam(self.ae_day.parameters(), lr=1e-4)
        self.optimizer_night = Adam(self.ae_night.parameters(), lr=1e-4)
        self.scheduler_day = ReduceLROnPlateau(self.optimizer_day,
                                               patience=15,
                                               verbose=True)
        self.scheduler_night = ReduceLROnPlateau(self.optimizer_night,
                                                 patience=15,
                                                 verbose=True)

    @staticmethod
    def _append_log(log_path, log_str):
        """Print ``log_str`` and append it as one line to ``log.txt`` inside ``log_path``."""
        print(log_str)
        with open(os.path.join(log_path, 'log.txt'), 'a+') as f:
            f.write(log_str + '\n')

    def train_epoch(self, train_loader, epoch, use_cuda, log_path, **kwargs):
        """
        Train both autoencoders for one pass over ``train_loader``.

        :param train_loader: Iterable of ``(day_img, night_img)`` batches that also supports ``len()``.
        :param epoch: Epoch number, used for the schedulers and the log line.
        :param use_cuda: When true, each batch is moved to the GPU before use.
        :param log_path: Directory that receives ``log.txt``.
        """
        loss_day_sum, loss_night_sum = 0.0, 0.0

        for day_img, night_img in train_loader:
            if use_cuda:
                day_img, night_img = day_img.cuda(), night_img.cuda()

            # train day autoencoder
            self.optimizer_day.zero_grad()
            out_day = self.ae_day(day_img)
            loss_day = self.loss_fn(out_day, day_img)
            loss_day.backward()
            self.optimizer_day.step()

            # train night autoencoder
            self.optimizer_night.zero_grad()
            out_night = self.ae_night(night_img)
            loss_night = self.loss_fn(out_night, night_img)
            loss_night.backward()
            self.optimizer_night.step()

            # Accumulate plain Python floats. Summing the loss tensors themselves would keep
            # every batch's autograd history alive for the whole epoch.
            loss_day_sum += loss_day.item()
            loss_night_sum += loss_night.item()

        loss_day_mean = loss_day_sum / len(train_loader)
        loss_night_mean = loss_night_sum / len(train_loader)

        self.scheduler_day.step(loss_day_mean, epoch)
        self.scheduler_night.step(loss_night_mean, epoch)

        # log losses
        log_str = f'[Epoch {epoch}] Train day loss: {loss_day_mean} Train night loss: {loss_night_mean}'
        self._append_log(log_path, log_str)

    def validate(self, val_loader, epoch, use_cuda, log_path, **kwargs):
        """
        Compute mean reconstruction losses over ``val_loader`` and save sample images.

        The samples are taken from the first element of the last batch: inputs, reconstructions and
        both cross-domain translations, written as ``{epoch}_{name}.jpeg`` into ``log_path``.

        :param val_loader: Iterable of ``(day_img, night_img)`` batches that also supports ``len()``.
        :param epoch: Epoch number, used in the log line and the image file names.
        :param use_cuda: When true, each batch is moved to the GPU before use.
        :param log_path: Directory that receives ``log.txt`` and the sample images.
        """
        loss_day_sum, loss_night_sum = 0.0, 0.0
        day_img, night_img, out_day, out_night = (None, ) * 4

        with torch.no_grad():
            for day_img, night_img in val_loader:
                if use_cuda:
                    day_img, night_img = day_img.cuda(), night_img.cuda()

                out_day = self.ae_day(day_img)
                loss_day = self.loss_fn(out_day, day_img)

                out_night = self.ae_night(night_img)
                loss_night = self.loss_fn(out_night, night_img)

                loss_day_sum += loss_day.item()
                loss_night_sum += loss_night.item()

        loss_day_mean = loss_day_sum / len(val_loader)
        loss_night_mean = loss_night_sum / len(val_loader)

        # domain translation; no gradients are needed for sample images either
        with torch.no_grad():
            day_to_night = self.ae_night.decode(
                self.ae_day.encode(day_img[0].unsqueeze(0)))
            night_to_day = self.ae_day.decode(
                self.ae_night.encode(night_img[0].unsqueeze(0)))

        # log losses
        log_str = f'[Epoch {epoch}] Val day loss: {loss_day_mean} Val night loss: {loss_night_mean}'
        self._append_log(log_path, log_str)

        # save sample images
        samples = {
            'day_img': day_img[0],
            'night_img': night_img[0],
            'out_day': out_day[0],
            'out_night': out_night[0],
            'day_to_night': day_to_night[0],
            'night_to_day': night_to_day[0]
        }

        for name, img in samples.items():
            ToPILImage()(img.cpu()).save(
                os.path.join(log_path, f'{epoch}_{name}.jpeg'), 'JPEG')

    def register_hooks(self, layers):
        """
        This function is not supposed to be called from outside the class.

        Registers a forward hook on each named attribute of ``self.ae_day.encoder_upper``. The same
        upper-encoder instance was passed to both autoencoders in ``__init__``, which is why the
        night embeddings also go through these hooks.

        :param layers: Attribute names of the layers to hook.
        :return: ``(handles, embedding_dict)``; the dict is filled with detached outputs, keyed by
                 layer name, when a forward pass runs.
        """
        handles = []
        embedding_dict = {}

        def get_hook(name, embedding_dict):
            def hook(model, input, output):
                embedding_dict[name] = output.detach()

            return hook

        for layer in layers:
            hook = get_hook(layer, embedding_dict)
            handles.append(
                getattr(self.ae_day.encoder_upper,
                        layer).register_forward_hook(hook))

        return handles, embedding_dict

    def deregister_hooks(self, handles):
        """
        This function is not supposed to be called from outside the class.

        Removes every hook handle returned by ``register_hooks``.
        """
        for handle in handles:
            handle.remove()

    def _collect_embeddings(self, autoencoder, img, layers):
        """Run ``autoencoder.encode(img)`` with hooks on ``layers`` and return the captured outputs."""
        handles, embedding_dict = self.register_hooks(layers)
        try:
            # forward pass
            autoencoder.encode(img)
        finally:
            # Always remove the hooks, even if the forward pass raises, so they cannot
            # pile up on the encoder and fire on later passes.
            self.deregister_hooks(handles)

        return embedding_dict

    def get_day_embeddings(self, img, layers):
        """
        Returns deep embeddings for the passed layers inside the upper encoder.
        """
        return self._collect_embeddings(self.ae_day, img, layers)

    def get_night_embeddings(self, img, layers):
        """
        Returns deep embeddings for the passed layers inside the upper encoder.
        """
        return self._collect_embeddings(self.ae_night, img, layers)

    def train(self):
        """Switch both autoencoders to training mode."""
        self.ae_day.train()
        self.ae_night.train()

    def eval(self):
        """Switch both autoencoders to evaluation mode."""
        self.ae_day.eval()
        self.ae_night.eval()

    def cuda(self):
        """Move both autoencoders to the GPU."""
        self.ae_day.cuda()
        self.ae_night.cuda()

    def state_dict(self):
        """
        Return the model weights as a dict of sub-module state dicts.

        The upper encoder is stored once, read from ``ae_day``.
        """
        return {
            'encoder_lower_day': self.ae_day.encoder_lower.state_dict(),
            'encoder_lower_night': self.ae_night.encoder_lower.state_dict(),
            'encoder_upper': self.ae_day.encoder_upper.state_dict(),
            'decoder_day': self.ae_day.decoder.state_dict(),
            'decoder_night': self.ae_night.decoder.state_dict()
        }

    def optim_state_dict(self):
        """Return the state of both optimizers; requires init_optimizers to have been called."""
        return {
            'optimizer_day': self.optimizer_day.state_dict(),
            'optimizer_night': self.optimizer_night.state_dict()
        }

    def load_state_dict(self, state):
        """Restore model weights from a dict produced by ``state_dict``."""
        self.ae_day.encoder_lower.load_state_dict(state['encoder_lower_day'])
        self.ae_night.encoder_lower.load_state_dict(
            state['encoder_lower_night'])
        self.ae_day.encoder_upper.load_state_dict(state['encoder_upper'])
        self.ae_day.decoder.load_state_dict(state['decoder_day'])
        self.ae_night.decoder.load_state_dict(state['decoder_night'])

    def load_optim_state_dict(self, state):
        """Restore both optimizers from a dict produced by ``optim_state_dict``."""
        self.optimizer_day.load_state_dict(state['optimizer_day'])
        self.optimizer_night.load_state_dict(state['optimizer_night'])
class CycleModel(CustomModule):
    """
    Day/night autoencoder pair trained with a weighted sum of cycle and reconstruction L1 losses.

    Both autoencoders are constructed with the same ``UpperEncoder`` and ``LowerDecoder`` instances.
    A single optimizer and scheduler, created by ``init_optimizers``, cover both autoencoders.
    """

    ae_day: Autoencoder
    ae_night: Autoencoder
    reconstruction_loss_factor: float
    cycle_loss_factor: float

    def __init__(self, reconstruction_loss_factor: float, cycle_loss_factor: float):
        """
        :param reconstruction_loss_factor: Weight of the plain reconstruction loss in the training loss.
        :param cycle_loss_factor: Weight of the cycle loss in the training loss.
        """
        # share weights of the upper encoder & lower decoder
        encoder_upper, decoder_lower = UpperEncoder(), LowerDecoder()
        self.ae_day = Autoencoder(LowerEncoder(), encoder_upper, decoder_lower, UpperDecoder())
        self.ae_night = Autoencoder(LowerEncoder(), encoder_upper, decoder_lower, UpperDecoder())
        self.loss_fn = nn.L1Loss()
        self.reconstruction_loss_factor = reconstruction_loss_factor
        self.cycle_loss_factor = cycle_loss_factor

        self.optimizer = None
        self.scheduler = None

    def __call__(self, input):
        """Not supported; use the train/validate methods instead."""
        raise NotImplementedError

    def init_optimizers(self):
        """
        Is called right before training and after model has been moved to GPU.
        Supposed to initialize optimizers and schedulers.
        """
        # Collect each parameter once, in a stable order. A set would also de-duplicate the
        # parameters both autoencoders have in common, but its iteration order differs between
        # runs, and optimizer state dicts identify parameters by position.
        parameters = []
        seen_ids = set()
        for autoencoder in (self.ae_day, self.ae_night):
            for param in autoencoder.parameters():
                if id(param) not in seen_ids:
                    seen_ids.add(id(param))
                    parameters.append(param)
        self.optimizer = Adam(parameters)

        # initialize scheduler
        self.scheduler = ReduceLROnPlateau(self.optimizer, patience=15, verbose=True)

    @staticmethod
    def _append_log(log_path, log_str):
        """Print ``log_str`` and append it as one line to ``log.txt`` inside ``log_path``."""
        print(log_str)
        with open(os.path.join(log_path, 'log.txt'), 'a+') as f:
            f.write(log_str + '\n')

    def train_epoch(self, train_loader, epoch, use_cuda, log_path, **kwargs):
        """
        Train for one pass over ``train_loader``, taking one optimizer step per direction per batch.

        :param train_loader: Iterable of ``(day_img, night_img)`` batches that also supports ``len()``.
        :param epoch: Epoch number, used for the scheduler and the log line.
        :param use_cuda: When true, each batch is moved to the GPU before use.
        :param log_path: Directory that receives ``log.txt``.
        """
        loss_day2night2day_sum, loss_night2day2night_sum, loss_day2day_sum, loss_night2night_sum = 0.0, 0.0, 0.0, 0.0

        for day_img, night_img in train_loader:
            if use_cuda:
                day_img, night_img = day_img.cuda(), night_img.cuda()

            # Day -> Night -> Day
            self.optimizer.zero_grad()
            loss_day2night2day, loss_day2day = self.cycle_plus_reconstruction_loss(day_img, self.ae_day, self.ae_night)
            loss = loss_day2night2day * self.cycle_loss_factor + loss_day2day * self.reconstruction_loss_factor
            loss.backward()
            self.optimizer.step()

            # Night -> Day -> Night
            self.optimizer.zero_grad()
            loss_night2day2night, loss_night2night \
                = self.cycle_plus_reconstruction_loss(night_img, self.ae_night, self.ae_day)
            loss = loss_night2day2night * self.cycle_loss_factor + loss_night2night * self.reconstruction_loss_factor
            loss.backward()
            self.optimizer.step()

            # Accumulate plain Python floats. Summing the loss tensors themselves would keep
            # every batch's autograd history alive for the whole epoch.
            loss_day2night2day_sum += loss_day2night2day.item()
            loss_day2day_sum += loss_day2day.item()
            loss_night2day2night_sum += loss_night2day2night.item()
            loss_night2night_sum += loss_night2night.item()

        loss_day2night2day_mean = loss_day2night2day_sum / len(train_loader)
        loss_day2day_mean = loss_day2day_sum / len(train_loader)
        loss_night2day2night_mean = loss_night2day2night_sum / len(train_loader)
        loss_night2night_mean = loss_night2night_sum / len(train_loader)
        # The scheduler monitors the unweighted average of the four mean losses.
        loss_mean = (loss_day2night2day_mean + loss_day2day_mean + loss_night2day2night_mean + loss_night2night_mean)/4

        self.scheduler.step(loss_mean, epoch)

        # log losses
        log_str = f'[Epoch {epoch}] ' \
            f'Train loss day -> night -> day: {loss_day2night2day_mean} ' \
            f'Train loss night -> day -> night: {loss_night2day2night_mean} ' \
            f'Train loss day -> day: {loss_day2day_mean} ' \
            f'Train loss night -> night: {loss_night2night_mean}'
        self._append_log(log_path, log_str)

    def validate(self, val_loader, epoch, use_cuda, log_path, **kwargs):
        """
        Compute mean cycle and reconstruction losses over ``val_loader`` and save sample images.

        The samples are built from the first element of the last batch and written as
        ``{epoch}_{name}.jpeg`` into ``log_path``.

        :param val_loader: Iterable of ``(day_img, night_img)`` batches that also supports ``len()``.
        :param epoch: Epoch number, used in the log line and the image file names.
        :param use_cuda: When true, each batch is moved to the GPU before use.
        :param log_path: Directory that receives ``log.txt`` and the sample images.
        """
        loss_day2night2day_sum, loss_night2day2night_sum, loss_day2day_sum, loss_night2night_sum = 0.0, 0.0, 0.0, 0.0
        day_img, night_img = None, None

        with torch.no_grad():
            for day_img, night_img in val_loader:
                if use_cuda:
                    day_img, night_img = day_img.cuda(), night_img.cuda()

                # Day -> Night -> Day  and  Day -> Day
                loss_day2night2day, loss_day2day = \
                    self.cycle_plus_reconstruction_loss(day_img, self.ae_day, self.ae_night)

                # Night -> Day -> Night  and  Night -> Night
                loss_night2day2night, loss_night2night = \
                    self.cycle_plus_reconstruction_loss(night_img, self.ae_night, self.ae_day)

                loss_day2night2day_sum += loss_day2night2day.item()
                loss_day2day_sum += loss_day2day.item()
                loss_night2day2night_sum += loss_night2day2night.item()
                loss_night2night_sum += loss_night2night.item()

        loss_day2night2day_mean = loss_day2night2day_sum / len(val_loader)
        loss_night2day2night_mean = loss_night2day2night_sum / len(val_loader)
        loss_day2day_mean = loss_day2day_sum / len(val_loader)
        loss_night2night_mean = loss_night2night_sum / len(val_loader)

        # log losses
        log_str = f'[Epoch {epoch}] ' \
            f'Val loss day -> night -> day: {loss_day2night2day_mean} ' \
            f'Val loss night -> day -> night: {loss_night2day2night_mean} ' \
            f'Val loss day -> day: {loss_day2day_mean} ' \
            f'Val loss night -> night: {loss_night2night_mean}'
        self._append_log(log_path, log_str)

        # create sample images; no gradients are needed for them
        with torch.no_grad():
            latent_day = self.ae_day.encode(day_img[0].unsqueeze(0))
            latent_night = self.ae_night.encode(night_img[0].unsqueeze(0))
            # reconstruction
            day2day = self.ae_day.decode(latent_day)
            night2night = self.ae_night.decode(latent_night)
            # domain translation
            day2night = self.ae_night.decode(latent_day)
            night2day = self.ae_day.decode(latent_night)
            # cycle
            day2night2day = self.ae_day.decode(self.ae_night.encode(day2night))
            night2day2night = self.ae_night.decode(self.ae_day.encode(night2day))

        # save sample images
        samples = {
            'day_img': day_img[0],
            'night_img': night_img[0],
            'day2day': day2day[0],
            'night2night': night2night[0],
            'day2night': day2night[0],
            'night2day': night2day[0],
            'day2night2day': day2night2day[0],
            'night2day2night': night2day2night[0],
        }

        for name, img in samples.items():
            ToPILImage()(img.cpu()).save(os.path.join(log_path, f'{epoch}_{name}.jpeg'), 'JPEG')

    def cycle_plus_reconstruction_loss(self, image, autoencoder1, autoencoder2):
        """
        Compute the cycle loss and the reconstruction loss for ``image``.

        :param image: Input batch belonging to the domain of ``autoencoder1``.
        :param autoencoder1: Autoencoder of the image's own domain.
        :param autoencoder2: Autoencoder of the opposite domain.
        :return: ``(cycle_loss, reconstruction_loss)``, both computed with ``self.loss_fn`` against ``image``.
        """
        # send the image through the cycle
        intermediate_latent_1 = autoencoder1.encode(image)
        intermediate_opposite = autoencoder2.decode(intermediate_latent_1)
        intermediate_latent_2 = autoencoder2.encode(intermediate_opposite)
        cycle_img = autoencoder1.decode(intermediate_latent_2)

        # do simple reconstruction
        reconstructed_img = autoencoder1.decode(intermediate_latent_1)

        cycle_loss = self.loss_fn(cycle_img, image)
        reconstruction_loss = self.loss_fn(reconstructed_img, image)
        return cycle_loss, reconstruction_loss

    def train(self):
        """Switch both autoencoders to training mode."""
        self.ae_day.train()
        self.ae_night.train()

    def eval(self):
        """Switch both autoencoders to evaluation mode."""
        self.ae_day.eval()
        self.ae_night.eval()

    def cuda(self):
        """Move both autoencoders to the GPU."""
        self.ae_day.cuda()
        self.ae_night.cuda()

    def state_dict(self):
        """
        Return the model weights as a dict of sub-module state dicts.

        The upper encoder is stored once, read from ``ae_day``.
        """
        return {
            'encoder_lower_day': self.ae_day.encoder_lower.state_dict(),
            'encoder_lower_night': self.ae_night.encoder_lower.state_dict(),
            'encoder_upper': self.ae_day.encoder_upper.state_dict(),
            'decoder_day': self.ae_day.decoder.state_dict(),
            'decoder_night': self.ae_night.decoder.state_dict()
        }

    def optim_state_dict(self):
        """Return the optimizer state; requires init_optimizers to have been called."""
        return {
            'optimizer': self.optimizer.state_dict(),
        }

    def load_state_dict(self, state):
        """Restore model weights from a dict produced by ``state_dict``."""
        self.ae_day.encoder_lower.load_state_dict(state['encoder_lower_day'])
        self.ae_night.encoder_lower.load_state_dict(state['encoder_lower_night'])
        self.ae_day.encoder_upper.load_state_dict(state['encoder_upper'])
        self.ae_day.decoder.load_state_dict(state['decoder_day'])
        self.ae_night.decoder.load_state_dict(state['decoder_night'])

    def load_optim_state_dict(self, state):
        """Restore the optimizer from a dict produced by ``optim_state_dict``."""
        self.optimizer.load_state_dict(state['optimizer'])
Exemple #4
0
class CycleVAE(CustomModule, EmbeddingGenerator):
    """
    CycleVAE model. This is the model which was used for evaluation.
    """
    def get_day_embeddings(self, img, layers):
        """
        Returns the latent embedding of a day image.

        :param img: Input passed to the day autoencoder's ``encode``.
        :param layers: Accepted but not used; the result always contains the
            single entry ``latent``.
        :return: Dict ``{'latent': ...}`` holding the first element of the
            encoder output.
        """
        encoded = self.ae_day.encode(img)
        return {'latent': encoded[0]}

    def get_night_embeddings(self, img, layers):
        """
        Returns the latent embedding of a night image.

        :param img: Input passed to the night autoencoder's ``encode``.
        :param layers: Accepted but not used; the result always contains the
            single entry ``latent``.
        :return: Dict ``{'latent': ...}`` holding the first element of the
            encoder output.
        """
        encoded = self.ae_night.encode(img)
        return {'latent': encoded[0]}

    def __init__(self, params: dict):
        """
        Builds the day and the night autoencoder.

        Each autoencoder gets its own lower encoder and upper decoder, while
        one upper encoder and one lower decoder instance are passed to both.

        :param dict params: Hyper-parameters; stored as ``self.params`` and
            read by other methods of this class.
        """
        self.params = params

        # these two instances are handed to both autoencoders (weight sharing)
        shared_encoder_upper = UpperEncoder()
        shared_decoder_lower = LowerDecoder()

        self.ae_day = Autoencoder(
            LowerEncoder(),
            shared_encoder_upper,
            shared_decoder_lower,
            UpperDecoder(),
        )
        self.ae_night = Autoencoder(
            LowerEncoder(),
            shared_encoder_upper,
            shared_decoder_lower,
            UpperDecoder(),
        )

        self.reconst_loss = nn.L1Loss()

        # both are created later by init_optimizers()
        self.optimizer = None
        self.scheduler = None

    def __call__(self, input):
        raise NotImplementedError

    def init_optimizers(self):
        """
        Is called right before training and after model has been moved to GPU.
        Initializes the Adam optimizer and the ReduceLROnPlateau scheduler.

        The constructor passes the same upper encoder and lower decoder
        instances to both autoencoders, so their parameters may be yielded by
        ``ae_day.parameters()`` as well as by ``ae_night.parameters()``. Every
        parameter is therefore registered only once (first occurrence wins):
        PyTorch warns about duplicate parameters within a parameter group and
        announces that they will become an error.

        NOTE(review): an optimizer state saved while duplicates were still
        registered has a different number of parameters and presumably
        cannot be restored into this optimizer - confirm before resuming old
        runs.

        Reads ``self.params['lr']`` and ``self.params['patience']``.
        """
        candidates = list(self.ae_day.parameters()) + list(
            self.ae_night.parameters())

        # de-duplicate by identity while keeping the original order
        seen_ids = set()
        trainable = []
        for param in candidates:
            if not param.requires_grad or id(param) in seen_ids:
                continue
            seen_ids.add(id(param))
            trainable.append(param)

        self.optimizer = Adam(trainable, lr=self.params['lr'])

        patience = self.params['patience']
        try:
            self.scheduler = ReduceLROnPlateau(self.optimizer,
                                               patience=patience,
                                               verbose=True)
        except TypeError:
            # ``verbose`` is deprecated in recent PyTorch releases; fall back
            # for versions that no longer accept the argument
            self.scheduler = ReduceLROnPlateau(self.optimizer,
                                               patience=patience)

    def train_epoch(self, train_loader, epoch, use_cuda, log_path, **kwargs):
        """
        Trains both autoencoders for one pass over ``train_loader``.

        Per batch the weighted sum of same-domain reconstruction, cycle
        reconstruction and the corresponding ``kl_loss`` terms is
        back-propagated and one optimizer step is taken. Afterwards the
        scheduler is stepped with the mean batch loss, which is also printed
        and appended to ``log.txt`` inside ``log_path``.

        :param train_loader: Iterable yielding ``(img_day, img_night)``
            batches; must support ``len()``.
        :param epoch: Current epoch, forwarded to the scheduler and used in
            the log message.
        :param use_cuda: If true, batches are moved to the GPU via ``.cuda()``.
        :param log_path: Directory containing ``log.txt``.
        :param kwargs: Ignored.
        """
        weights = self.params
        running_loss = 0

        for day_batch, night_batch in train_loader:
            if use_cuda:
                day_batch = day_batch.cuda()
                night_batch = night_batch.cuda()

            self.optimizer.zero_grad()

            # encode each domain with its own autoencoder
            code_day, eps_day = self.ae_day.encode(day_batch)
            code_night, eps_night = self.ae_night.encode(night_batch)

            # decode within the same domain
            rec_day = self.ae_day.decode(code_day + eps_day)
            rec_night = self.ae_night.decode(code_night + eps_night)

            # decode with the decoder of the opposite domain
            fake_day = self.ae_day.decode(code_night + eps_night)
            fake_night = self.ae_night.decode(code_day + eps_day)

            # re-encode the translated images ...
            code_fake_day, eps_fake_day = self.ae_day.encode(fake_day)
            code_fake_night, eps_fake_night = self.ae_night.encode(fake_night)

            # ... and translate them back to close the cycle
            cycle_day = self.ae_day.decode(code_fake_night + eps_fake_night)
            cycle_night = self.ae_night.decode(code_fake_day + eps_fake_day)

            # individual loss terms
            l_rec_day = self.reconst_loss(rec_day, day_batch)
            l_rec_night = self.reconst_loss(rec_night, night_batch)
            l_kl_day = kl_loss(code_day)
            l_kl_night = kl_loss(code_night)
            l_cyc_day = self.reconst_loss(cycle_day, day_batch)
            l_cyc_night = self.reconst_loss(cycle_night, night_batch)
            l_kl_cyc_day = kl_loss(code_fake_day)
            l_kl_cyc_night = kl_loss(code_fake_night)

            # weighted total; terms are summed in this fixed order
            loss = (
                weights['loss_reconst'] * (l_rec_day + l_rec_night)
                + weights['loss_kl_reconst'] * (l_kl_day + l_kl_night)
                + weights['loss_cycle'] * (l_cyc_day + l_cyc_night)
                + weights['loss_kl_cycle'] * (l_kl_cyc_day + l_kl_cyc_night)
            )

            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()

        loss_mean = running_loss / len(train_loader)
        self.scheduler.step(loss_mean, epoch)

        # report the epoch result on stdout and in the log file
        message = f'[Epoch {epoch}] Train loss: {loss_mean}'
        print(message)
        log_file = os.path.join(log_path, 'log.txt')
        with open(log_file, 'a+') as handle:
            handle.write(message + '\n')

    def validate(self, val_loader, epoch, use_cuda, log_path, **kwargs):
        """
        Computes the mean validation loss, logs it and saves sample images.

        All forward passes run without gradient tracking, including the
        final domain translation of the sample images. The first sample of
        the last batch is written to ``log_path`` together with its
        reconstructions, cycle reconstructions and translations as JPEG files
        named ``{epoch}_{name}.jpeg``.

        :param val_loader: Iterable yielding ``(img_day, img_night)``
            batches; must support ``len()``.
        :param epoch: Current epoch, used in the log message and file names.
        :param use_cuda: If true, batches are moved to the GPU via ``.cuda()``.
        :param log_path: Directory receiving ``log.txt`` and the images.
        :param kwargs: Ignored.
        """
        loss_sum = 0
        img_day, img_night, reconst_day, reconst_night, reconst_cycle_day, reconst_cycle_night = (
            None, ) * 6

        with torch.no_grad():
            for img_day, img_night in val_loader:
                if use_cuda:
                    img_day, img_night = img_day.cuda(), img_night.cuda()

                latent_day, noise_day = self.ae_day.encode(img_day)
                latent_night, noise_night = self.ae_night.encode(img_night)

                # same domain reconstruction
                reconst_day = self.ae_day.decode(latent_day + noise_day)
                reconst_night = self.ae_night.decode(latent_night +
                                                     noise_night)

                # cross domain
                night_to_day = self.ae_day.decode(latent_night + noise_night)
                day_to_night = self.ae_night.decode(latent_day + noise_day)

                # encode again for cycle loss
                latent_night_to_day, noise_night_to_day = self.ae_day.encode(
                    night_to_day)
                latent_day_to_night, noise_day_to_night = self.ae_night.encode(
                    day_to_night)

                # and decode again
                reconst_cycle_day = self.ae_day.decode(latent_day_to_night +
                                                       noise_day_to_night)
                reconst_cycle_night = self.ae_night.decode(
                    latent_night_to_day + noise_night_to_day)

                # loss formulations
                loss_reconst_day = self.reconst_loss(reconst_day, img_day)
                loss_reconst_night = self.reconst_loss(reconst_night,
                                                       img_night)
                loss_kl_reconst_day = kl_loss(latent_day)
                loss_kl_reconst_night = kl_loss(latent_night)
                loss_cycle_day = self.reconst_loss(reconst_cycle_day, img_day)
                loss_cycle_night = self.reconst_loss(reconst_cycle_night,
                                                     img_night)
                loss_kl_cycle_day = kl_loss(latent_night_to_day)
                loss_kl_cycle_night = kl_loss(latent_day_to_night)

                loss = (
                    self.params['loss_reconst'] * (loss_reconst_day + loss_reconst_night)
                    + self.params['loss_kl_reconst'] * (loss_kl_reconst_day + loss_kl_reconst_night)
                    + self.params['loss_cycle'] * (loss_cycle_day + loss_cycle_night)
                    + self.params['loss_kl_cycle'] * (loss_kl_cycle_day + loss_kl_cycle_night)
                )

                loss_sum += loss.item()

            loss_mean = loss_sum / len(val_loader)

            # domain translation of the first sample of the last batch, using
            # only the first encoder output (no noise added); kept inside
            # no_grad so that no autograd graph is built for tensors that are
            # only saved as images
            day_to_night = self.ae_night.decode(
                self.ae_day.encode(img_day[0].unsqueeze(0))[0])
            night_to_day = self.ae_day.decode(
                self.ae_night.encode(img_night[0].unsqueeze(0))[0])

        # log loss
        log_str = f'[Epoch {epoch}] Val loss: {loss_mean}'
        print(log_str)
        with open(os.path.join(log_path, 'log.txt'), 'a+') as f:
            f.write(log_str + '\n')

        # save sample images
        samples = {
            'day_img': img_day[0],
            'night_img': img_night[0],
            'reconst_day': reconst_day[0],
            'reconst_night': reconst_night[0],
            'reconst_cycle_day': reconst_cycle_day[0],
            'reconst_cycle_night': reconst_cycle_night[0],
            'day_to_night': day_to_night[0],
            'night_to_day': night_to_day[0],
        }

        for name, img in samples.items():
            ToPILImage()(img.cpu()).save(
                os.path.join(log_path, f'{epoch}_{name}.jpeg'), 'JPEG')

    def train(self):
        """Calls ``train()`` on the day autoencoder, then on the night autoencoder."""
        for autoencoder in (self.ae_day, self.ae_night):
            autoencoder.train()

    def eval(self):
        """Calls ``eval()`` on the day autoencoder, then on the night autoencoder."""
        for autoencoder in (self.ae_day, self.ae_night):
            autoencoder.eval()

    def cuda(self):
        """Calls ``cuda()`` on the day autoencoder, then on the night autoencoder."""
        for autoencoder in (self.ae_day, self.ae_night):
            autoencoder.cuda()

    def state_dict(self):
        """
        Collects the state dicts of all sub-modules.

        The upper encoder and the lower decoder are stored once, taken from
        the day autoencoder; the constructor passes the same instances to
        both autoencoders.

        :return: Dict with the keys ``encoder_lower_day``,
            ``encoder_lower_night``, ``encoder_upper``, ``decoder_lower``,
            ``decoder_upper_day`` and ``decoder_upper_night``.
        """
        day, night = self.ae_day, self.ae_night

        state = {}
        state['encoder_lower_day'] = day.encoder_lower.state_dict()
        state['encoder_lower_night'] = night.encoder_lower.state_dict()
        state['encoder_upper'] = day.encoder_upper.state_dict()
        state['decoder_lower'] = day.decoder_lower.state_dict()
        state['decoder_upper_day'] = day.decoder_upper.state_dict()
        state['decoder_upper_night'] = night.decoder_upper.state_dict()
        return state

    def optim_state_dict(self):
        """
        Returns the state dict of the optimizer.

        :return: Result of ``self.optimizer.state_dict()``.
        """
        optimizer = self.optimizer
        return optimizer.state_dict()

    def load_state_dict(self, state):
        """
        Restores all sub-modules from a dict as produced by ``state_dict``.

        :param state: Mapping with the keys ``encoder_lower_day``,
            ``encoder_lower_night``, ``encoder_upper``, ``decoder_lower``,
            ``decoder_upper_day`` and ``decoder_upper_night``.
        """
        day, night = self.ae_day, self.ae_night

        # (owning autoencoder, sub-module attribute, key in ``state``);
        # upper encoder and lower decoder are only loaded through the day
        # autoencoder
        targets = (
            (day, 'encoder_lower', 'encoder_lower_day'),
            (night, 'encoder_lower', 'encoder_lower_night'),
            (day, 'encoder_upper', 'encoder_upper'),
            (day, 'decoder_lower', 'decoder_lower'),
            (day, 'decoder_upper', 'decoder_upper_day'),
            (night, 'decoder_upper', 'decoder_upper_night'),
        )
        for owner, attribute, key in targets:
            getattr(owner, attribute).load_state_dict(state[key])

    def load_optim_state_dict(self, state_dict):
        """
        Restores the optimizer from a previously saved state.

        :param state_dict: Optimizer state as returned by ``optim_state_dict``.
        """
        optimizer = self.optimizer
        optimizer.load_state_dict(state_dict)