# Example 1
    def _train_gan(self):
        """
        Adversarial training loop for the generator/discriminator pair.

        Loads the latest "generator" and "discriminator" checkpoints (if
        any), resumes from the lowest common epoch, then alternates one
        discriminator update and one generator update per batch until
        ``self.config.MAX_EPOCHS`` is reached. Losses and periodic test
        output are streamed through a ``Visualiser``, and both models are
        checkpointed after every epoch.

        TODO: Add in autoencoder to perform dimensionality reduction on data
        TODO: Not working yet - trying to work out good autoencoder model first
        :return: None. May return early (without the completion log message)
                 if ``check_requeue`` signals the job must be requeued.
        """

        # Binary cross-entropy over the discriminator's sigmoid output.
        criterion = nn.BCELoss()

        # Discriminator: Adam with exponentially decaying LR (0.97 ** epoch).
        discriminator_optimiser = optim.Adam(self.discriminator.parameters(),
                                             lr=0.003,
                                             betas=(0.5, 0.999))
        discriminator_scheduler = optim.lr_scheduler.LambdaLR(
            discriminator_optimiser, lambda epoch: 0.97**epoch)
        discriminator_checkpoint = Checkpoint("discriminator")
        discriminator_epoch = 0
        if discriminator_checkpoint.load():
            # load_state restores model + optimiser state and returns the
            # epoch the checkpoint was saved at (None on failure).
            discriminator_epoch = self.load_state(discriminator_checkpoint,
                                                  self.discriminator,
                                                  discriminator_optimiser)
        else:
            LOG.info('Discriminator checkpoint not found')

        # Generator: same optimiser/scheduler configuration as above.
        generator_optimiser = optim.Adam(self.generator.parameters(),
                                         lr=0.003,
                                         betas=(0.5, 0.999))
        generator_scheduler = optim.lr_scheduler.LambdaLR(
            generator_optimiser, lambda epoch: 0.97**epoch)
        generator_checkpoint = Checkpoint("generator")
        generator_epoch = 0
        if generator_checkpoint.load():
            generator_epoch = self.load_state(generator_checkpoint,
                                              self.generator,
                                              generator_optimiser)
        else:
            LOG.info('Generator checkpoint not found')

        # If either state failed to restore, retrain both from scratch;
        # otherwise resume from the lower of the two saved epochs so the
        # lagging model catches up.
        if discriminator_epoch is None or generator_epoch is None:
            epoch = 0
            LOG.info(
                "Discriminator or generator failed to load, training from start"
            )
        else:
            epoch = min(generator_epoch, discriminator_epoch)
            LOG.info("Generator loaded at epoch {0}".format(generator_epoch))
            LOG.info("Discriminator loaded at epoch {0}".format(
                discriminator_epoch))
            LOG.info("Training from lowest epoch {0}".format(epoch))

        # Visualiser output goes under <FILENAME stem>/gan/<timestamp>.
        vis_path = os.path.join(
            os.path.splitext(self.config.FILENAME)[0], "gan",
            str(datetime.now()))
        with Visualiser(vis_path) as vis:
            # Label tensors are built lazily and rebuilt only when the batch
            # size changes (e.g. the final, smaller batch of an epoch).
            real_labels = None  # all 1s
            fake_labels = None  # all 0s
            # Epochs completed in *this* process, used for requeue decisions
            # (distinct from `epoch`, which counts from the checkpoint).
            epochs_complete = 0
            while epoch < self.config.MAX_EPOCHS:

                if self.check_requeue(epochs_complete):
                    return  # Requeue needed and training not complete

                # Each batch yields the real sample plus two independent
                # noise tensors: noise1 for the discriminator step, noise2
                # for the generator step.
                for step, (data, noise1,
                           noise2) in enumerate(self.data_loader):
                    batch_size = data.size(0)
                    if real_labels is None or real_labels.size(
                            0) != batch_size:
                        real_labels = self.generate_labels(batch_size, [1.0])
                    if fake_labels is None or fake_labels.size(
                            0) != batch_size:
                        fake_labels = self.generate_labels(batch_size, [0.0])

                    if self.config.USE_CUDA:
                        data = data.cuda()
                        noise1 = noise1.cuda()
                        noise2 = noise2.cuda()

                    # ============= Train the discriminator =============
                    # Pass real noise through first - ideally the discriminator will return 1 #[1, 0]
                    d_output_real = self.discriminator(data)
                    # Pass generated noise through - ideally the discriminator will return 0 #[0, 1]
                    d_output_fake1 = self.discriminator(self.generator(noise1))

                    # Determine the loss of the discriminator by adding up the real and fake loss and backpropagate
                    d_loss_real = criterion(
                        d_output_real, real_labels
                    )  # How good the discriminator is on real input
                    d_loss_fake = criterion(
                        d_output_fake1, fake_labels
                    )  # How good the discriminator is on fake input
                    d_loss = d_loss_real + d_loss_fake
                    self.discriminator.zero_grad()
                    d_loss.backward()
                    discriminator_optimiser.step()

                    # =============== Train the generator ===============
                    # Pass in fake noise to the generator and get it to generate "real" noise
                    # Judge how good this noise is with the discriminator
                    d_output_fake2 = self.discriminator(self.generator(noise2))

                    # Determine the loss of the generator using the discriminator and backpropagate.
                    # Non-saturating objective: generator is rewarded when the
                    # discriminator labels its output "real".
                    g_loss = criterion(d_output_fake2, real_labels)
                    # Clear discriminator grads too: g_loss.backward() also
                    # accumulates into the discriminator, whose optimiser is
                    # deliberately NOT stepped here.
                    self.discriminator.zero_grad()
                    self.generator.zero_grad()
                    g_loss.backward()
                    generator_optimiser.step()

                    vis.step(d_loss_real.item(), d_loss_fake.item(),
                             g_loss.item())

                    # Report data and save checkpoint
                    fmt = "Epoch [{0}/{1}], Step[{2}/{3}], d_loss_real: {4:.4f}, d_loss_fake: {5:.4f}, g_loss: {6:.4f}"
                    LOG.info(
                        fmt.format(epoch + 1, self.config.MAX_EPOCHS, step + 1,
                                   len(self.data_loader), d_loss_real,
                                   d_loss_fake, g_loss))

                epoch += 1
                epochs_complete += 1

                # Persist both models (+ optimiser state) every epoch so a
                # requeued job can resume where it left off.
                discriminator_checkpoint.set(
                    self.discriminator.state_dict(),
                    discriminator_optimiser.state_dict(), epoch).save()
                generator_checkpoint.set(self.generator.state_dict(),
                                         generator_optimiser.state_dict(),
                                         epoch).save()
                vis.plot_training(epoch)

                # Grab one fresh batch purely for the end-of-epoch visual test.
                data, noise1, _ = iter(self.data_loader).__next__()
                if self.config.USE_CUDA:
                    data = data.cuda()
                    noise1 = noise1.cuda()
                vis.test(epoch, self.data_loader.get_input_size_first(),
                         self.discriminator, self.generator, noise1, data)

                # NOTE(review): passing `epoch` to scheduler.step() is the
                # legacy PyTorch API (deprecated in torch >= 1.4) — confirm
                # the pinned torch version still supports it.
                generator_scheduler.step(epoch)
                discriminator_scheduler.step(epoch)

                LOG.info("Learning rates: d {0} g {1}".format(
                    discriminator_optimiser.param_groups[0]["lr"],
                    generator_optimiser.param_groups[0]["lr"]))

        LOG.info("GAN Training complete")
# Example 2
    def _train_autoencoder(self):
        """
        Main training loop for the autoencoder.

        Resumes from the "autoencoder" checkpoint when one exists, then
        trains with a smooth L1 reconstruction loss until
        ``self.config.MAX_AUTOENCODER_EPOCHS`` is reached, checkpointing
        and plotting progress after every epoch.

        This function will return False if the script needs to be re-queued
        because the NN has been trained for REQUEUE_EPOCHS this run.
        If the checkpoint exists but its state dicts fail to load, training
        restarts from epoch 0 (matching ``_train_gan``'s behaviour).
        :return: True if training was completed, False if training needs to continue.
        :rtype bool
        """

        criterion = nn.SmoothL1Loss()

        # Bug fix: the optimiser previously wrapped self.generator's
        # parameters, so the autoencoder's weights were never updated.
        optimiser = optim.Adam(self.autoencoder.parameters(),
                               lr=0.00003,
                               betas=(0.5, 0.999))
        checkpoint = Checkpoint("autoencoder")
        epoch = 0
        if checkpoint.load():
            epoch = self.load_state(checkpoint, self.autoencoder, optimiser)
            if epoch is None:
                # Bug fix: load_state can return None; previously this fell
                # through to `while epoch < MAX` and raised a TypeError.
                epoch = 0
                LOG.info(
                    "Autoencoder state failed to load, training from start")
            elif epoch >= self.config.MAX_AUTOENCODER_EPOCHS:
                LOG.info("Autoencoder already trained")
                return True
            else:
                LOG.info(
                    "Autoencoder training beginning from epoch {0}".format(
                        epoch))
        else:
            LOG.info('Autoencoder checkpoint not found. Training from start')

        # Train autoencoder
        # NOTE(review): every other reference here is `self.autoencoder`;
        # confirm `self._autoencoder` is the same object (e.g. a property).
        self._autoencoder.set_mode(Autoencoder.Mode.AUTOENCODER)

        # Visualiser output goes under <FILENAME stem>/autoencoder/<timestamp>.
        vis_path = os.path.join(
            os.path.splitext(self.config.FILENAME)[0], "autoencoder",
            str(datetime.now()))
        with Visualiser(vis_path) as vis:
            # Epochs completed in *this* process, used for requeue decisions
            # (distinct from `epoch`, which counts from the checkpoint).
            epochs_complete = 0
            while epoch < self.config.MAX_AUTOENCODER_EPOCHS:

                if self.check_requeue(epochs_complete):
                    return False  # Requeue needed and training not complete

                for step, (data, _, _) in enumerate(self.data_loader):
                    if self.config.USE_CUDA:
                        data = data.cuda()

                    if self.config.ADD_DROPOUT:
                        # Drop out parts of the input, but compute loss on
                        # the full input (denoising-autoencoder style).
                        out = self.autoencoder(nn.functional.dropout(
                            data, 0.5))
                    else:
                        out = self.autoencoder(data)

                    # Loss is computed on CPU copies; autograd still flows
                    # back through .cpu() to the (possibly GPU) model.
                    loss = criterion(out.cpu(), data.cpu())
                    self.autoencoder.zero_grad()
                    loss.backward()
                    optimiser.step()

                    vis.step_autoencoder(loss.item())

                    # Report data and save checkpoint
                    fmt = "Epoch [{0}/{1}], Step[{2}/{3}], loss: {4:.4f}"
                    LOG.info(
                        fmt.format(epoch + 1,
                                   self.config.MAX_AUTOENCODER_EPOCHS, step,
                                   len(self.data_loader), loss))

                epoch += 1
                epochs_complete += 1

                checkpoint.set(self.autoencoder.state_dict(),
                               optimiser.state_dict(), epoch).save()

                LOG.info("Plotting autoencoder progress")
                vis.plot_training(epoch)
                # Grab one fresh batch purely for the end-of-epoch visual test.
                data, _, _ = next(iter(self.data_loader))
                # Bug fix: only move the batch to the GPU when CUDA is
                # enabled; the unconditional .cuda() crashed on CPU-only hosts.
                if self.config.USE_CUDA:
                    data = data.cuda()
                vis.test_autoencoder(epoch, self.autoencoder, data)

        LOG.info("Autoencoder training complete")
        return True  # Training complete