コード例 #1
0
class SimpleModel(CustomModule, EmbeddingGenerator):
    """Day/night autoencoder pair trained only on reconstruction.

    ``ae_day`` and ``ae_night`` each get their own ``LowerEncoder`` and
    ``UpperDecoder`` but are built around the *same* ``UpperEncoder`` and
    ``LowerDecoder`` instances, so those two modules are shared between the
    domains. Each autoencoder is trained with its own Adam optimizer on an L1
    reconstruction loss. Embeddings are read from layers of the shared upper
    encoder via forward hooks.
    """

    ae_day: Autoencoder
    ae_night: Autoencoder

    def __init__(self):
        # One UpperEncoder / LowerDecoder object is handed to both
        # autoencoders, so their weights are shared across the two domains.
        encoder_upper, decoder_lower = UpperEncoder(), LowerDecoder()
        self.ae_day = Autoencoder(LowerEncoder(), encoder_upper, decoder_lower,
                                  UpperDecoder())
        self.ae_night = Autoencoder(LowerEncoder(), encoder_upper,
                                    decoder_lower, UpperDecoder())
        self.loss_fn = nn.L1Loss()

        # Created in init_optimizers().
        self.optimizer_day = None
        self.optimizer_night = None
        self.scheduler_day = None
        self.scheduler_night = None

    def __call__(self, input):
        raise NotImplementedError

    def init_optimizers(self):
        """
        Is called right before training and after model has been moved to GPU.
        Supposed to initialize optimizers and schedulers.

        Creates one Adam optimizer (lr=1e-4) per autoencoder, each paired with
        a ReduceLROnPlateau scheduler (patience=15).
        """
        # NOTE(review): the shared upper encoder / lower decoder are presumably
        # included in both ae_day.parameters() and ae_night.parameters(), so
        # they are updated by both optimizers, each with its own Adam state.
        self.optimizer_day = Adam(self.ae_day.parameters(), lr=1e-4)
        self.optimizer_night = Adam(self.ae_night.parameters(), lr=1e-4)
        self.scheduler_day = ReduceLROnPlateau(self.optimizer_day,
                                               patience=15,
                                               verbose=True)
        self.scheduler_night = ReduceLROnPlateau(self.optimizer_night,
                                                 patience=15,
                                                 verbose=True)

    def train_epoch(self, train_loader, epoch, use_cuda, log_path, **kwargs):
        """Train both autoencoders for one epoch.

        ``train_loader`` yields ``(day_img, night_img)`` batches. Per batch the
        day and the night autoencoder each take one optimizer step on their
        own L1 reconstruction loss. The epoch-mean losses are fed to the two
        LR schedulers, printed, and appended to ``<log_path>/log.txt``.
        """
        loss_day_sum, loss_night_sum = 0.0, 0.0

        for day_img, night_img in train_loader:
            if use_cuda:
                day_img, night_img = day_img.cuda(), night_img.cuda()

            # train day autoencoder
            self.optimizer_day.zero_grad()
            out_day = self.ae_day(day_img)
            loss_day = self.loss_fn(out_day, day_img)
            loss_day.backward()
            self.optimizer_day.step()

            # train night autoencoder
            self.optimizer_night.zero_grad()
            out_night = self.ae_night(night_img)
            loss_night = self.loss_fn(out_night, night_img)
            loss_night.backward()
            self.optimizer_night.step()

            # Accumulate plain floats: adding the loss tensors themselves
            # would chain their autograd history across the whole epoch.
            loss_day_sum += loss_day.item()
            loss_night_sum += loss_night.item()

        loss_day_mean = loss_day_sum / len(train_loader)
        loss_night_mean = loss_night_sum / len(train_loader)

        self.scheduler_day.step(loss_day_mean, epoch)
        self.scheduler_night.step(loss_night_mean, epoch)

        # log losses
        log_str = f'[Epoch {epoch}] Train day loss: {loss_day_mean} Train night loss: {loss_night_mean}'
        print(log_str)
        with open(os.path.join(log_path, 'log.txt'), 'a+') as f:
            f.write(log_str + '\n')

    def validate(self, val_loader, epoch, use_cuda, log_path, **kwargs):
        """Evaluate both autoencoders on ``val_loader`` and save samples.

        Logs the mean day/night L1 reconstruction losses to stdout and
        ``<log_path>/log.txt``. For the first image of the last batch it saves
        the inputs, their reconstructions and the two cross-domain
        translations as ``<log_path>/<epoch>_<name>.jpeg``.
        """
        loss_day_sum, loss_night_sum = 0.0, 0.0
        day_img, night_img, out_day, out_night = (None, ) * 4

        # The sample translations are inside no_grad too; nothing here is
        # back-propagated, so no autograd graph needs to be recorded.
        with torch.no_grad():
            for day_img, night_img in val_loader:
                if use_cuda:
                    day_img, night_img = day_img.cuda(), night_img.cuda()

                out_day = self.ae_day(day_img)
                loss_day_sum += self.loss_fn(out_day, day_img).item()

                out_night = self.ae_night(night_img)
                loss_night_sum += self.loss_fn(out_night, night_img).item()

            # domain translation: encode with one domain, decode with the other
            day_to_night = self.ae_night.decode(
                self.ae_day.encode(day_img[0].unsqueeze(0)))
            night_to_day = self.ae_day.decode(
                self.ae_night.encode(night_img[0].unsqueeze(0)))

        loss_day_mean = loss_day_sum / len(val_loader)
        loss_night_mean = loss_night_sum / len(val_loader)

        # log losses
        log_str = f'[Epoch {epoch}] Val day loss: {loss_day_mean} Val night loss: {loss_night_mean}'
        print(log_str)
        with open(os.path.join(log_path, 'log.txt'), 'a+') as f:
            f.write(log_str + '\n')

        # save sample images
        samples = {
            'day_img': day_img[0],
            'night_img': night_img[0],
            'out_day': out_day[0],
            'out_night': out_night[0],
            'day_to_night': day_to_night[0],
            'night_to_day': night_to_day[0]
        }

        for name, img in samples.items():
            ToPILImage()(img.cpu()).save(
                os.path.join(log_path, f'{epoch}_{name}.jpeg'), 'JPEG')

    def register_hooks(self, layers):
        """
        This function is not supposed to be called from outside the class.

        Registers a forward hook on each attribute named in ``layers`` of the
        shared upper encoder. Returns ``(handles, embedding_dict)``; after a
        forward pass ``embedding_dict[name]`` holds that layer's detached
        output.
        """
        handles = []
        embedding_dict = {}

        def get_hook(name, embedding_dict):
            # Factory so each hook captures its own layer name.
            def hook(model, input, output):
                embedding_dict[name] = output.detach()

            return hook

        for layer in layers:
            hook = get_hook(layer, embedding_dict)
            handles.append(
                getattr(self.ae_day.encoder_upper,
                        layer).register_forward_hook(hook))

        return handles, embedding_dict

    def deregister_hooks(self, handles):
        """
        This function is not supposed to be called from outside the class.

        Removes every hook handle returned by ``register_hooks``.
        """
        for handle in handles:
            handle.remove()

    def get_day_embeddings(self, img, layers):
        """
        Returns deep embeddings for the passed layers inside the upper encoder,
        computed by encoding ``img`` with the day autoencoder.

        The result maps each layer name to that layer's detached output.
        """
        handles, embedding_dict = self.register_hooks(layers)
        try:
            # forward pass; the hooks fill embedding_dict
            self.ae_day.encode(img)
        finally:
            # Never leave hooks attached, even if the forward pass raises.
            self.deregister_hooks(handles)

        return embedding_dict

    def get_night_embeddings(self, img, layers):
        """
        Returns deep embeddings for the passed layers inside the upper encoder,
        computed by encoding ``img`` with the night autoencoder.

        The hooks are attached through ``ae_day.encoder_upper``, which is the
        same module object ``ae_night`` was built with, so they also fire
        here (presumably ``encode`` runs the upper encoder).
        """
        handles, embedding_dict = self.register_hooks(layers)
        try:
            # forward pass; the hooks fill embedding_dict
            self.ae_night.encode(img)
        finally:
            # Never leave hooks attached, even if the forward pass raises.
            self.deregister_hooks(handles)

        return embedding_dict

    def train(self):
        """Put both autoencoders into training mode."""
        self.ae_day.train()
        self.ae_night.train()

    def eval(self):
        """Put both autoencoders into evaluation mode."""
        self.ae_day.eval()
        self.ae_night.eval()

    def cuda(self):
        """Move both autoencoders to the GPU (in place; returns None)."""
        self.ae_day.cuda()
        self.ae_night.cuda()

    def state_dict(self):
        """Return the weights of every sub-module, keyed by role."""
        # encoder_upper is shared, so it is stored once (via ae_day).
        # NOTE(review): the shared LowerDecoder is only covered if
        # Autoencoder.decoder includes it -- confirm against Autoencoder.
        return {
            'encoder_lower_day': self.ae_day.encoder_lower.state_dict(),
            'encoder_lower_night': self.ae_night.encoder_lower.state_dict(),
            'encoder_upper': self.ae_day.encoder_upper.state_dict(),
            'decoder_day': self.ae_day.decoder.state_dict(),
            'decoder_night': self.ae_night.decoder.state_dict()
        }

    def optim_state_dict(self):
        """Return the state of both optimizers."""
        return {
            'optimizer_day': self.optimizer_day.state_dict(),
            'optimizer_night': self.optimizer_night.state_dict()
        }

    def load_state_dict(self, state):
        """Load weights produced by ``state_dict``."""
        self.ae_day.encoder_lower.load_state_dict(state['encoder_lower_day'])
        self.ae_night.encoder_lower.load_state_dict(
            state['encoder_lower_night'])
        self.ae_day.encoder_upper.load_state_dict(state['encoder_upper'])
        self.ae_day.decoder.load_state_dict(state['decoder_day'])
        self.ae_night.decoder.load_state_dict(state['decoder_night'])

    def load_optim_state_dict(self, state):
        """Load optimizer state produced by ``optim_state_dict``."""
        self.optimizer_day.load_state_dict(state['optimizer_day'])
        self.optimizer_night.load_state_dict(state['optimizer_night'])
コード例 #2
0
ファイル: cycle_model.py プロジェクト: alebeck/ImageRetrieval
class CycleModel(CustomModule):
    """Day/night autoencoder pair trained with reconstruction + cycle losses.

    Both autoencoders share one ``UpperEncoder`` and one ``LowerDecoder``
    instance. For every batch, each domain's image is reconstructed directly
    (X -> X) and sent around a cycle through the other domain (X -> Y -> X).
    The two L1 losses are weighted by ``reconstruction_loss_factor`` and
    ``cycle_loss_factor``. A single Adam optimizer covers all parameters.
    """
    ae_day: Autoencoder
    ae_night: Autoencoder
    reconstruction_loss_factor: float
    cycle_loss_factor: float

    def __init__(self, reconstruction_loss_factor: float, cycle_loss_factor: float):
        # share weights of the upper encoder & lower decoder
        encoder_upper, decoder_lower = UpperEncoder(), LowerDecoder()
        self.ae_day = Autoencoder(LowerEncoder(), encoder_upper, decoder_lower, UpperDecoder())
        self.ae_night = Autoencoder(LowerEncoder(), encoder_upper, decoder_lower, UpperDecoder())
        self.loss_fn = nn.L1Loss()
        self.reconstruction_loss_factor = reconstruction_loss_factor
        self.cycle_loss_factor = cycle_loss_factor

        # Created in init_optimizers().
        self.optimizer = None
        self.scheduler = None

    def __call__(self, input):
        raise NotImplementedError

    def init_optimizers(self):
        """
        Is called right before training and after model has been moved to GPU.
        Supposed to initialize optimizers and schedulers.
        """
        # The shared modules show up in both autoencoders' parameters, so the
        # list must be deduplicated. A set did that, but set order depends on
        # tensor ids and changes between runs, so the optimizer's saved state
        # (indexed by parameter position) could not be matched reliably on
        # reload. dict.fromkeys deduplicates while keeping first-seen order.
        parameters = list(dict.fromkeys(
            [*self.ae_day.parameters(), *self.ae_night.parameters()]))
        self.optimizer = Adam(parameters)

        # initialize scheduler
        self.scheduler = ReduceLROnPlateau(self.optimizer, patience=15, verbose=True)

    def train_epoch(self, train_loader, epoch, use_cuda, log_path, **kwargs):
        """Train for one epoch on ``(day_img, night_img)`` batches.

        For each batch, one optimizer step is taken on the weighted day losses
        (day -> night -> day cycle plus day -> day reconstruction). A second
        step is then taken on the corresponding night losses. The mean of the
        four epoch-mean losses drives the LR scheduler. All four are printed
        and appended to ``<log_path>/log.txt``.
        """
        loss_day2night2day_sum, loss_night2day2night_sum, loss_day2day_sum, loss_night2night_sum = 0.0, 0.0, 0.0, 0.0

        for day_img, night_img in train_loader:
            if use_cuda:
                day_img, night_img = day_img.cuda(), night_img.cuda()

            # Day -> Night -> Day
            self.optimizer.zero_grad()
            loss_day2night2day, loss_day2day = self.cycle_plus_reconstruction_loss(day_img, self.ae_day, self.ae_night)
            loss = loss_day2night2day * self.cycle_loss_factor + loss_day2day * self.reconstruction_loss_factor
            loss.backward()
            self.optimizer.step()

            # Night -> Day -> Night
            self.optimizer.zero_grad()
            loss_night2day2night, loss_night2night \
                = self.cycle_plus_reconstruction_loss(night_img, self.ae_night, self.ae_day)
            loss = loss_night2day2night * self.cycle_loss_factor + loss_night2night * self.reconstruction_loss_factor
            loss.backward()
            self.optimizer.step()

            # Accumulate plain floats: adding the loss tensors themselves
            # would chain their autograd history across the whole epoch.
            loss_day2night2day_sum += loss_day2night2day.item()
            loss_day2day_sum += loss_day2day.item()
            loss_night2day2night_sum += loss_night2day2night.item()
            loss_night2night_sum += loss_night2night.item()

        loss_day2night2day_mean = loss_day2night2day_sum / len(train_loader)
        loss_day2day_mean = loss_day2day_sum / len(train_loader)
        loss_night2day2night_mean = loss_night2day2night_sum / len(train_loader)
        loss_night2night_mean = loss_night2night_sum / len(train_loader)
        loss_mean = (loss_day2night2day_mean + loss_day2day_mean + loss_night2day2night_mean + loss_night2night_mean)/4

        self.scheduler.step(loss_mean, epoch)

        # log losses
        log_str = f'[Epoch {epoch}] ' \
            f'Train loss day -> night -> day: {loss_day2night2day_mean} ' \
            f'Train loss night -> day -> night: {loss_night2day2night_mean} ' \
            f'Train loss day -> day: {loss_day2day_mean} ' \
            f'Train loss night -> night: {loss_night2night_mean}'
        print(log_str)
        with open(os.path.join(log_path, 'log.txt'), 'a+') as f:
            f.write(log_str + '\n')

    def validate(self, val_loader, epoch, use_cuda, log_path, **kwargs):
        """Evaluate the four losses on ``val_loader`` and save sample images.

        Mean losses are printed and appended to ``<log_path>/log.txt``. For the
        first image of the last batch, the inputs plus their reconstructions,
        translations and cycles are saved as ``<log_path>/<epoch>_<name>.jpeg``.
        """
        loss_day2night2day_sum, loss_night2day2night_sum, loss_day2day_sum, loss_night2night_sum = 0.0, 0.0, 0.0, 0.0
        day_img, night_img = None, None

        # The sample images are generated inside no_grad too; nothing here is
        # back-propagated, so no autograd graph needs to be recorded.
        with torch.no_grad():
            for day_img, night_img in val_loader:
                if use_cuda:
                    day_img, night_img = day_img.cuda(), night_img.cuda()

                # Day -> Night -> Day  and  Day -> Day
                loss_day2night2day, loss_day2day = \
                    self.cycle_plus_reconstruction_loss(day_img, self.ae_day, self.ae_night)

                # Night -> Day -> Night  and  Night -> Night
                loss_night2day2night, loss_night2night = \
                    self.cycle_plus_reconstruction_loss(night_img, self.ae_night, self.ae_day)

                loss_day2night2day_sum += loss_day2night2day.item()
                loss_day2day_sum += loss_day2day.item()
                loss_night2day2night_sum += loss_night2day2night.item()
                loss_night2night_sum += loss_night2night.item()

            # create sample images
            latent_day = self.ae_day.encode(day_img[0].unsqueeze(0))
            latent_night = self.ae_night.encode(night_img[0].unsqueeze(0))
            # reconstruction
            day2day = self.ae_day.decode(latent_day)
            night2night = self.ae_night.decode(latent_night)
            # domain translation
            day2night = self.ae_night.decode(latent_day)
            night2day = self.ae_day.decode(latent_night)
            # cycle
            day2night2day = self.ae_day.decode(self.ae_night.encode(day2night))
            night2day2night = self.ae_night.decode(self.ae_day.encode(night2day))

        loss_day2night2day_mean = loss_day2night2day_sum / len(val_loader)
        loss_night2day2night_mean = loss_night2day2night_sum / len(val_loader)
        loss_day2day_mean = loss_day2day_sum / len(val_loader)
        loss_night2night_mean = loss_night2night_sum / len(val_loader)

        # log losses
        log_str = f'[Epoch {epoch}] ' \
            f'Val loss day -> night -> day: {loss_day2night2day_mean} ' \
            f'Val loss night -> day -> night: {loss_night2day2night_mean} ' \
            f'Val loss day -> day: {loss_day2day_mean} ' \
            f'Val loss night -> night: {loss_night2night_mean}'
        print(log_str)
        with open(os.path.join(log_path, 'log.txt'), 'a+') as f:
            f.write(log_str + '\n')

        # save sample images
        samples = {
            'day_img': day_img[0],
            'night_img': night_img[0],
            'day2day': day2day[0],
            'night2night': night2night[0],
            'day2night': day2night[0],
            'night2day': night2day[0],
            'day2night2day': day2night2day[0],
            'night2day2night': night2day2night[0],
        }

        for name, img in samples.items():
            ToPILImage()(img.cpu()).save(os.path.join(log_path, f'{epoch}_{name}.jpeg'), 'JPEG')

    def cycle_plus_reconstruction_loss(self, image, autoencoder1, autoencoder2):
        """Return ``(cycle_loss, reconstruction_loss)`` for ``image``.

        ``image`` belongs to ``autoencoder1``'s domain. The cycle path encodes
        with ``autoencoder1``, decodes and re-encodes with ``autoencoder2``,
        and decodes with ``autoencoder1``. The reconstruction path decodes
        ``autoencoder1``'s latent directly. Both losses are L1 against
        ``image``.
        """
        # send the image through the cycle
        intermediate_latent_1 = autoencoder1.encode(image)
        intermediate_opposite = autoencoder2.decode(intermediate_latent_1)
        intermediate_latent_2 = autoencoder2.encode(intermediate_opposite)
        cycle_img = autoencoder1.decode(intermediate_latent_2)

        # do simple reconstruction
        reconstructed_img = autoencoder1.decode(intermediate_latent_1)

        cycle_loss = self.loss_fn(cycle_img, image)
        reconstruction_loss = self.loss_fn(reconstructed_img, image)
        return cycle_loss, reconstruction_loss

    def train(self):
        """Put both autoencoders into training mode."""
        self.ae_day.train()
        self.ae_night.train()

    def eval(self):
        """Put both autoencoders into evaluation mode."""
        self.ae_day.eval()
        self.ae_night.eval()

    def cuda(self):
        """Move both autoencoders to the GPU (in place; returns None)."""
        self.ae_day.cuda()
        self.ae_night.cuda()

    def state_dict(self):
        """Return the weights of every sub-module, keyed by role."""
        # encoder_upper is shared, so it is stored once (via ae_day).
        return {
            'encoder_lower_day': self.ae_day.encoder_lower.state_dict(),
            'encoder_lower_night': self.ae_night.encoder_lower.state_dict(),
            'encoder_upper': self.ae_day.encoder_upper.state_dict(),
            'decoder_day': self.ae_day.decoder.state_dict(),
            'decoder_night': self.ae_night.decoder.state_dict()
        }

    def optim_state_dict(self):
        """Return the optimizer state, keyed as ``{'optimizer': ...}``."""
        return {
            'optimizer': self.optimizer.state_dict(),
        }

    def load_state_dict(self, state):
        """Load weights produced by ``state_dict``."""
        self.ae_day.encoder_lower.load_state_dict(state['encoder_lower_day'])
        self.ae_night.encoder_lower.load_state_dict(state['encoder_lower_night'])
        self.ae_day.encoder_upper.load_state_dict(state['encoder_upper'])
        self.ae_day.decoder.load_state_dict(state['decoder_day'])
        self.ae_night.decoder.load_state_dict(state['decoder_night'])

    def load_optim_state_dict(self, state):
        """Load optimizer state produced by ``optim_state_dict``."""
        self.optimizer.load_state_dict(state['optimizer'])
コード例 #3
0
class CycleVAE(CustomModule, EmbeddingGenerator):
    """
    CycleVAE model. This is the model which was used for evaluation.

    Day and night autoencoders share one ``UpperEncoder`` and one
    ``LowerDecoder``. ``encode`` returns a ``(latent, noise)`` pair. Training
    combines same-domain reconstruction, cross-domain cycle reconstruction
    and KL terms, weighted by the entries of ``params``.
    """
    def get_day_embeddings(self, img, layers):
        """
        Returns ``{'latent': latent}``, where ``latent`` is the first element
        of ``ae_day.encode(img)``. ``layers`` is accepted for interface
        compatibility but ignored.
        """
        # forward pass
        latent = self.ae_day.encode(img)[0]

        return {'latent': latent}

    def get_night_embeddings(self, img, layers):
        """
        Returns ``{'latent': latent}``, where ``latent`` is the first element
        of ``ae_night.encode(img)``. ``layers`` is accepted for interface
        compatibility but ignored.
        """
        # forward pass
        latent = self.ae_night.encode(img)[0]

        return {'latent': latent}

    def __init__(self, params: dict):
        # Keys read elsewhere in this class: 'lr', 'patience', 'loss_reconst',
        # 'loss_kl_reconst', 'loss_cycle', 'loss_kl_cycle'.
        self.params = params

        # share weights of the upper encoder & lower decoder
        encoder_upper, decoder_lower = UpperEncoder(), LowerDecoder()
        self.ae_day = Autoencoder(LowerEncoder(), encoder_upper, decoder_lower,
                                  UpperDecoder())
        self.ae_night = Autoencoder(LowerEncoder(), encoder_upper,
                                    decoder_lower, UpperDecoder())

        self.reconst_loss = nn.L1Loss()

        # Created in init_optimizers().
        self.optimizer = None
        self.scheduler = None

    def __call__(self, input):
        raise NotImplementedError

    def init_optimizers(self):
        """
        Is called right before training and after model has been moved to GPU.
        Supposed to initialize optimizers and schedulers.
        """
        # The shared upper encoder / lower decoder appear in both parameter
        # lists. Passing them twice makes PyTorch warn about duplicate
        # parameters and has Adam step them twice per optimizer.step(), so
        # deduplicate while keeping a deterministic, first-seen order.
        # NOTE: optimizer checkpoints saved with the old duplicated layout
        # will not load into this optimizer.
        params = dict.fromkeys(
            list(self.ae_day.parameters()) + list(self.ae_night.parameters()))
        self.optimizer = Adam([p for p in params if p.requires_grad],
                              lr=self.params['lr'])
        self.scheduler = ReduceLROnPlateau(self.optimizer,
                                           patience=self.params['patience'],
                                           verbose=True)

    def _forward_losses(self, img_day, img_night):
        """Run the reconstruction and cycle passes for one batch.

        Returns ``(loss, reconst_day, reconst_night, reconst_cycle_day,
        reconst_cycle_night)``. ``loss`` is the weighted sum of the
        reconstruction, KL and cycle terms configured in ``self.params``.
        """
        latent_day, noise_day = self.ae_day.encode(img_day)
        latent_night, noise_night = self.ae_night.encode(img_night)

        # same domain reconstruction
        reconst_day = self.ae_day.decode(latent_day + noise_day)
        reconst_night = self.ae_night.decode(latent_night + noise_night)

        # cross domain
        night_to_day = self.ae_day.decode(latent_night + noise_night)
        day_to_night = self.ae_night.decode(latent_day + noise_day)

        # encode again for cycle loss
        latent_night_to_day, noise_night_to_day = self.ae_day.encode(
            night_to_day)
        latent_day_to_night, noise_day_to_night = self.ae_night.encode(
            day_to_night)

        # ... and decode back into the original domain
        reconst_cycle_day = self.ae_day.decode(latent_day_to_night +
                                               noise_day_to_night)
        reconst_cycle_night = self.ae_night.decode(latent_night_to_day +
                                                   noise_night_to_day)

        # loss formulations
        loss_reconst_day = self.reconst_loss(reconst_day, img_day)
        loss_reconst_night = self.reconst_loss(reconst_night, img_night)
        loss_kl_reconst_day = kl_loss(latent_day)
        loss_kl_reconst_night = kl_loss(latent_night)
        loss_cycle_day = self.reconst_loss(reconst_cycle_day, img_day)
        loss_cycle_night = self.reconst_loss(reconst_cycle_night, img_night)
        loss_kl_cycle_day = kl_loss(latent_night_to_day)
        loss_kl_cycle_night = kl_loss(latent_day_to_night)

        loss = \
            self.params['loss_reconst'] * (loss_reconst_day + loss_reconst_night) + \
            self.params['loss_kl_reconst'] * (loss_kl_reconst_day + loss_kl_reconst_night) + \
            self.params['loss_cycle'] * (loss_cycle_day + loss_cycle_night) + \
            self.params['loss_kl_cycle'] * (loss_kl_cycle_day + loss_kl_cycle_night)

        return (loss, reconst_day, reconst_night, reconst_cycle_day,
                reconst_cycle_night)

    def train_epoch(self, train_loader, epoch, use_cuda, log_path, **kwargs):
        """Train for one epoch on ``(img_day, img_night)`` batches.

        One optimizer step is taken per batch on the combined loss. The epoch
        mean drives the LR scheduler and is printed and appended to
        ``<log_path>/log.txt``.
        """
        loss_sum = 0

        for img_day, img_night in train_loader:
            if use_cuda:
                img_day, img_night = img_day.cuda(), img_night.cuda()

            self.optimizer.zero_grad()
            loss = self._forward_losses(img_day, img_night)[0]
            loss.backward()
            self.optimizer.step()

            loss_sum += loss.detach().item()

        loss_mean = loss_sum / len(train_loader)
        self.scheduler.step(loss_mean, epoch)

        # log loss
        log_str = f'[Epoch {epoch}] Train loss: {loss_mean}'
        print(log_str)
        with open(os.path.join(log_path, 'log.txt'), 'a+') as f:
            f.write(log_str + '\n')

    def validate(self, val_loader, epoch, use_cuda, log_path, **kwargs):
        """Evaluate the combined loss on ``val_loader`` and save samples.

        The mean loss is printed and appended to ``<log_path>/log.txt``. For
        the first image of the last batch, the inputs, reconstructions, cycle
        reconstructions and translations are saved as
        ``<log_path>/<epoch>_<name>.jpeg``.
        """
        loss_sum = 0
        img_day, img_night, reconst_day, reconst_night, reconst_cycle_day, reconst_cycle_night = (
            None, ) * 6

        # The sample translations are inside no_grad too; nothing here is
        # back-propagated, so no autograd graph needs to be recorded.
        with torch.no_grad():
            for img_day, img_night in val_loader:
                if use_cuda:
                    img_day, img_night = img_day.cuda(), img_night.cuda()

                (loss, reconst_day, reconst_night, reconst_cycle_day,
                 reconst_cycle_night) = self._forward_losses(img_day, img_night)

                loss_sum += loss.detach().item()

            # domain translation using only the latent ([0]), without noise
            day_to_night = self.ae_night.decode(
                self.ae_day.encode(img_day[0].unsqueeze(0))[0])
            night_to_day = self.ae_day.decode(
                self.ae_night.encode(img_night[0].unsqueeze(0))[0])

        loss_mean = loss_sum / len(val_loader)

        # log loss
        log_str = f'[Epoch {epoch}] Val loss: {loss_mean}'
        print(log_str)
        with open(os.path.join(log_path, 'log.txt'), 'a+') as f:
            f.write(log_str + '\n')

        # save sample images
        samples = {
            'day_img': img_day[0],
            'night_img': img_night[0],
            'reconst_day': reconst_day[0],
            'reconst_night': reconst_night[0],
            'reconst_cycle_day': reconst_cycle_day[0],
            'reconst_cycle_night': reconst_cycle_night[0],
            'day_to_night': day_to_night[0],
            'night_to_day': night_to_day[0]
        }

        for name, img in samples.items():
            ToPILImage()(img.cpu()).save(
                os.path.join(log_path, f'{epoch}_{name}.jpeg'), 'JPEG')

    def train(self):
        """Put both autoencoders into training mode."""
        self.ae_day.train()
        self.ae_night.train()

    def eval(self):
        """Put both autoencoders into evaluation mode."""
        self.ae_day.eval()
        self.ae_night.eval()

    def cuda(self):
        """Move both autoencoders to the GPU (in place; returns None)."""
        self.ae_day.cuda()
        self.ae_night.cuda()

    def state_dict(self):
        """Return the weights of every sub-module, keyed by role."""
        # encoder_upper and decoder_lower are shared, so each is stored once
        # (via ae_day).
        return {
            'encoder_lower_day': self.ae_day.encoder_lower.state_dict(),
            'encoder_lower_night': self.ae_night.encoder_lower.state_dict(),
            'encoder_upper': self.ae_day.encoder_upper.state_dict(),
            'decoder_lower': self.ae_day.decoder_lower.state_dict(),
            'decoder_upper_day': self.ae_day.decoder_upper.state_dict(),
            'decoder_upper_night': self.ae_night.decoder_upper.state_dict()
        }

    def optim_state_dict(self):
        """Return the optimizer's state dict."""
        return self.optimizer.state_dict()

    def load_state_dict(self, state):
        """Load weights produced by ``state_dict``."""
        self.ae_day.encoder_lower.load_state_dict(state['encoder_lower_day'])
        self.ae_night.encoder_lower.load_state_dict(
            state['encoder_lower_night'])
        self.ae_day.encoder_upper.load_state_dict(state['encoder_upper'])

        self.ae_day.decoder_lower.load_state_dict(state['decoder_lower'])
        self.ae_day.decoder_upper.load_state_dict(state['decoder_upper_day'])
        self.ae_night.decoder_upper.load_state_dict(
            state['decoder_upper_night'])

    def load_optim_state_dict(self, state_dict):
        """Load optimizer state produced by ``optim_state_dict``."""
        self.optimizer.load_state_dict(state_dict)