Example #3
def test_autoencoder(image_shape, batch_size=128, latent_size=2):
    encoder = AutoEncoder(image_shape, latent_size=latent_size)
    batch = torch.zeros(batch_size, *image_shape)
    x_rec, z = encoder(batch)

    assert x_rec.shape == batch.shape
    assert z.shape == (batch_size, latent_size)
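
# The AutoEncoder exercised by test_autoencoder above is not shown in this snippet.
# A minimal sketch of a model that satisfies the same interface -- constructed as
# AutoEncoder(image_shape, latent_size=...) with forward() returning (reconstruction,
# latent) -- is given below. The flatten + MLP architecture is an assumption for
# illustration only, not the original implementation.
import torch
import torch.nn as nn


class MinimalAutoEncoder(nn.Module):
    def __init__(self, image_shape, latent_size=2):
        super().__init__()
        self.image_shape = image_shape
        in_features = 1
        for dim in image_shape:
            in_features *= dim
        # Encoder: flattened image -> latent code
        self.encoder = nn.Sequential(nn.Flatten(), nn.Linear(in_features, 128),
                                     nn.ReLU(), nn.Linear(128, latent_size))
        # Decoder: latent code -> flattened image
        self.decoder = nn.Sequential(nn.Linear(latent_size, 128), nn.ReLU(),
                                     nn.Linear(128, in_features))

    def forward(self, x):
        z = self.encoder(x)
        x_rec = self.decoder(z).view(x.shape[0], *self.image_shape)
        return x_rec, z

# With AutoEncoder = MinimalAutoEncoder, test_autoencoder((1, 28, 28)) passes as written.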
def main():

    dataset = PrepareDataset()
    cfg = Config()
    plotter = Plotter()
    preprocessor = Preprocess()

    train_img, train_label, test_img, test_label = dataset.get_dataset()

    train_img, train_label = dataset.filter_2500(images=train_img,
                                                 labels=train_label,
                                                 class_num_to_filter=[2, 4, 9],
                                                 images_to_keep=2500)

    # dataset.plot_sample(x_train)

    x_train = preprocessor.normalize(train_img)
    y_train = x_train.copy()

    if cfg.add_noise:
        x_train = preprocessor.add_noise(x_train)
    # dataset.plot_sample(x_train)
    x_test = preprocessor.normalize(test_img)
    x_val = x_test.copy()
    y_val = x_test.copy()

    model, classifier = AutoEncoder(return_encoder_classifier=True,
                                    use_skip_conn=cfg.use_skip_conn)

    #model = EncoderWithSkipConn()

    if os.path.exists(cfg.model_name):
        print("Saved Model found, Loading...")
        model = load_model(cfg.model_name)

    if not os.path.exists(cfg.model_name) and not cfg.train_encoder:
        raise FileNotFoundError('No saved model')

    if cfg.train_encoder:
        history = train_encoder(cfg, model, x_train, y_train, x_val, y_val)

        plotter.plot_history(history)
    test_encoder(x_test, test_label, model, plotter)

    train_img = preprocessor.normalize(train_img)
    train_label = to_categorical(train_label)

    #print(train_label)

    classifier = copy_weights(model, classifier)
    classifier.save("classifier.h5")

    history = train_classifier(cfg, classifier, train_img, train_label)

    plotter.plot_history(history)

    test_img = preprocessor.normalize(test_img)
    test_label = [i[0] for i in test_label.tolist()]
    test_classifier(cfg, classifier, test_img, test_label, plotter)
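
# copy_weights in the example above (used before saving classifier.h5) is assumed to
# transfer the trained encoder weights into the classifier; its implementation is not
# shown here. A minimal Keras-style sketch of that idea, matching layers by name
# (hypothetical helper, for illustration only):
def copy_weights_by_name(source_model, target_model):
    """Copy weights between layers that share a name in two Keras models."""
    source_layers = {layer.name: layer for layer in source_model.layers}
    for layer in target_model.layers:
        if layer.name in source_layers and layer.weights:
            layer.set_weights(source_layers[layer.name].get_weights())
    return target_model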
Example #5
def main():
    global args, best_loss
    args = parser.parse_args()

    # model = AutoEncoder().cuda()
    model = AutoEncoder(items=9)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=1e-5)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Data loader
    train_loader = DataLoader(trainData,
                              batch_size=args.batch_size,
                              shuffle=True)
    val_loader = DataLoader(valData, batch_size=args.batch_size, shuffle=True)
    test_loader = DataLoader(testData,
                             batch_size=args.batch_size,
                             shuffle=True)

    if args.evaluate:
        test(test_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        loss1 = validate(val_loader, model, criterion)

        # remember best loss and save checkpoint
        is_best = loss1 < best_loss
        best_loss = min(loss1, best_loss)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_loss': best_loss,
                'optimizer': optimizer.state_dict(),
            }, is_best)
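
# save_checkpoint above is not defined in this snippet. It presumably follows the usual
# PyTorch recipe: always write the latest state and copy it aside when the validation
# loss improves. A sketch with assumed default filenames:
import shutil

import torch


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)  # always keep the most recent state
    if is_best:
        # snapshot the best-performing checkpoint separately
        shutil.copyfile(filename, 'model_best.pth.tar')
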
class GanTrainer(Trainer):
    def __init__(self, train_loader, test_loader, valid_loader, general_args,
                 trainer_args):
        super(GanTrainer, self).__init__(train_loader, test_loader,
                                         valid_loader, general_args)
        # Paths
        self.loadpath = trainer_args.loadpath
        self.savepath = trainer_args.savepath

        # Load the auto-encoder
        self.use_autoencoder = False
        if trainer_args.autoencoder_path and os.path.exists(
                trainer_args.autoencoder_path):
            self.use_autoencoder = True
            self.autoencoder = AutoEncoder(general_args=general_args).to(
                self.device)
            self.load_pretrained_autoencoder(trainer_args.autoencoder_path)
            self.autoencoder.eval()

        # Load the generator
        self.generator = Generator(general_args=general_args).to(self.device)
        if trainer_args.generator_path and os.path.exists(
                trainer_args.generator_path):
            self.load_pretrained_generator(trainer_args.generator_path)

        self.discriminator = Discriminator(general_args=general_args).to(
            self.device)

        # Optimizers and schedulers
        self.generator_optimizer = torch.optim.Adam(
            params=self.generator.parameters(), lr=trainer_args.generator_lr)
        self.discriminator_optimizer = torch.optim.Adam(
            params=self.discriminator.parameters(),
            lr=trainer_args.discriminator_lr)
        self.generator_scheduler = lr_scheduler.StepLR(
            optimizer=self.generator_optimizer,
            step_size=trainer_args.generator_scheduler_step,
            gamma=trainer_args.generator_scheduler_gamma)
        self.discriminator_scheduler = lr_scheduler.StepLR(
            optimizer=self.discriminator_optimizer,
            step_size=trainer_args.discriminator_scheduler_step,
            gamma=trainer_args.discriminator_scheduler_gamma)

        # Load saved states
        if os.path.exists(self.loadpath):
            self.load()

        # Loss function and stored losses
        self.adversarial_criterion = nn.BCEWithLogitsLoss()
        self.generator_time_criterion = nn.MSELoss()
        self.generator_frequency_criterion = nn.MSELoss()
        self.generator_autoencoder_criterion = nn.MSELoss()

        # Define labels
        self.real_label = 1
        self.generated_label = 0

        # Loss scaling factors
        self.lambda_adv = trainer_args.lambda_adversarial
        self.lambda_freq = trainer_args.lambda_freq
        self.lambda_autoencoder = trainer_args.lambda_autoencoder

        # Spectrogram converter
        self.spectrogram = Spectrogram(normalized=True).to(self.device)

        # Boolean indicating if the model needs to be saved
        self.need_saving = True

        # Boolean if the generator receives the feedback from the discriminator
        self.use_adversarial = trainer_args.use_adversarial

    def load_pretrained_generator(self, generator_path):
        """
        Loads a pre-trained generator. Can be used to stabilize the training.
        :param generator_path: location of the pre-trained generator (string).
        :return: None
        """
        checkpoint = torch.load(generator_path, map_location=self.device)
        self.generator.load_state_dict(checkpoint['generator_state_dict'])

    def load_pretrained_autoencoder(self, autoencoder_path):
        """
        Loads a pre-trained auto-encoder. Can be used to compute an embedding-space loss during training.
        :param autoencoder_path: location of the pre-trained auto-encoder (string).
        :return: None
        """
        checkpoint = torch.load(autoencoder_path, map_location=self.device)
        self.autoencoder.load_state_dict(checkpoint['autoencoder_state_dict'])

    def train(self, epochs):
        """
        Trains the GAN for a given number of pseudo-epochs.
        :param epochs: Number of times to iterate over a part of the dataset (int).
        :return: None
        """
        for epoch in range(epochs):
            for i in range(self.train_batches_per_epoch):
                self.generator.train()
                self.discriminator.train()
                # Transfer to GPU
                local_batch = next(self.train_loader_iter)
                input_batch, target_batch = local_batch[0].to(
                    self.device), local_batch[1].to(self.device)
                batch_size = input_batch.shape[0]

                ############################
                # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
                ###########################
                # Train the discriminator with real data
                self.discriminator_optimizer.zero_grad()
                # BCEWithLogitsLoss expects float targets, so the label tensor is created as float
                label = torch.full((batch_size, ),
                                   self.real_label,
                                   dtype=torch.float,
                                   device=self.device)
                output = self.discriminator(target_batch)

                # Compute and store the discriminator loss on real data
                loss_discriminator_real = self.adversarial_criterion(
                    output, torch.unsqueeze(label, dim=1))
                self.train_losses['discriminator_adversarial']['real'].append(
                    loss_discriminator_real.item())
                loss_discriminator_real.backward()

                # Train the discriminator with fake data
                generated_batch = self.generator(input_batch)
                label.fill_(self.generated_label)
                output = self.discriminator(generated_batch.detach())

                # Compute and store the discriminator loss on fake data
                loss_discriminator_generated = self.adversarial_criterion(
                    output, torch.unsqueeze(label, dim=1))
                self.train_losses['discriminator_adversarial']['fake'].append(
                    loss_discriminator_generated.item())
                loss_discriminator_generated.backward()

                # Update the discriminator weights
                self.discriminator_optimizer.step()

                ############################
                # Update G network: maximize log(D(G(z)))
                ###########################
                self.generator_optimizer.zero_grad()

                # Get the spectrogram
                specgram_target_batch = self.spectrogram(target_batch)
                specgram_fake_batch = self.spectrogram(generated_batch)

                # Fake labels are real for the generator cost
                label.fill_(self.real_label)
                output = self.discriminator(generated_batch)

                # Compute the generator loss on fake data
                # Get the adversarial loss
                loss_generator_adversarial = torch.zeros(size=[1],
                                                         device=self.device)
                if self.use_adversarial:
                    loss_generator_adversarial = self.adversarial_criterion(
                        output, torch.unsqueeze(label, dim=1))
                self.train_losses['generator_adversarial'].append(
                    loss_generator_adversarial.item())

                # Get the L2 loss in time domain
                loss_generator_time = self.generator_time_criterion(
                    generated_batch, target_batch)
                self.train_losses['time_l2'].append(loss_generator_time.item())

                # Get the L2 loss in frequency domain
                loss_generator_frequency = self.generator_frequency_criterion(
                    specgram_fake_batch, specgram_target_batch)
                self.train_losses['freq_l2'].append(
                    loss_generator_frequency.item())

                # Get the L2 loss in embedding space
                loss_generator_autoencoder = torch.zeros(size=[1],
                                                         device=self.device,
                                                         requires_grad=True)
                if self.use_autoencoder:
                    # Get the embeddings
                    _, embedding_target_batch = self.autoencoder(target_batch)
                    _, embedding_generated_batch = self.autoencoder(
                        generated_batch)
                    loss_generator_autoencoder = self.generator_autoencoder_criterion(
                        embedding_generated_batch, embedding_target_batch)
                    self.train_losses['autoencoder_l2'].append(
                        loss_generator_autoencoder.item())

                # Combine the different losses
                loss_generator = self.lambda_adv * loss_generator_adversarial + loss_generator_time + \
                                 self.lambda_freq * loss_generator_frequency + \
                                 self.lambda_autoencoder * loss_generator_autoencoder

                # Back-propagate and update the generator weights
                loss_generator.backward()
                self.generator_optimizer.step()

                # Print message
                if not (i % 10):
                    message = 'Batch {}: \n' \
                              '\t Generator: \n' \
                              '\t\t Time: {} \n' \
                              '\t\t Frequency: {} \n' \
                              '\t\t Autoencoder {} \n' \
                              '\t\t Adversarial: {} \n' \
                              '\t Discriminator: \n' \
                              '\t\t Real {} \n' \
                              '\t\t Fake {} \n'.format(i,
                                                       loss_generator_time.item(),
                                                       loss_generator_frequency.item(),
                                                       loss_generator_autoencoder.item(),
                                                       loss_generator_adversarial.item(),
                                                       loss_discriminator_real.item(),
                                                       loss_discriminator_generated.item())
                    print(message)

            # Evaluate the model
            with torch.no_grad():
                self.eval()

            # Save the trainer state
            self.save()
            # if self.need_saving:
            #     self.save()

            # Increment epoch counter
            self.epoch += 1
            self.generator_scheduler.step()
            self.discriminator_scheduler.step()

    def eval(self):
        self.generator.eval()
        self.discriminator.eval()
        batch_losses = {'time_l2': [], 'freq_l2': []}
        for i in range(self.valid_batches_per_epoch):
            # Transfer to GPU
            local_batch = next(self.valid_loader_iter)
            input_batch, target_batch = local_batch[0].to(
                self.device), local_batch[1].to(self.device)

            generated_batch = self.generator(input_batch)

            # Get the spectrogram
            specgram_target_batch = self.spectrogram(target_batch)
            specgram_generated_batch = self.spectrogram(generated_batch)

            loss_generator_time = self.generator_time_criterion(
                generated_batch, target_batch)
            batch_losses['time_l2'].append(loss_generator_time.item())
            loss_generator_frequency = self.generator_frequency_criterion(
                specgram_generated_batch, specgram_target_batch)
            batch_losses['freq_l2'].append(loss_generator_frequency.item())

        # Store the validation losses
        self.valid_losses['time_l2'].append(np.mean(batch_losses['time_l2']))
        self.valid_losses['freq_l2'].append(np.mean(batch_losses['freq_l2']))

        # Display validation losses
        message = 'Epoch {}: \n' \
                  '\t Time: {} \n' \
                  '\t Frequency: {} \n'.format(self.epoch,
                                               np.mean(batch_losses['time_l2']),
                                               np.mean(batch_losses['freq_l2']))
        print(message)

        # Check if the loss is decreasing
        self.check_improvement()
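
    # check_improvement (called just above) is not included in these snippets; it is assumed
    # to flag the trainer for saving when the validation loss reaches a new minimum. A
    # hypothetical sketch of that logic (the name and attributes used are assumptions):
    def check_improvement_sketch(self):
        latest_time_l2 = self.valid_losses['time_l2'][-1]
        self.need_saving = latest_time_l2 <= min(self.valid_losses['time_l2'])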

    def save(self):
        """
        Saves the model(s), optimizer(s), scheduler(s) and losses
        :return: None
        """
        torch.save(
            {
                'epoch':
                self.epoch,
                'generator_state_dict':
                self.generator.state_dict(),
                'discriminator_state_dict':
                self.discriminator.state_dict(),
                'generator_optimizer_state_dict':
                self.generator_optimizer.state_dict(),
                'discriminator_optimizer_state_dict':
                self.discriminator_optimizer.state_dict(),
                'generator_scheduler_state_dict':
                self.generator_scheduler.state_dict(),
                'discriminator_scheduler_state_dict':
                self.discriminator_scheduler.state_dict(),
                'train_losses':
                self.train_losses,
                'test_losses':
                self.test_losses,
                'valid_losses':
                self.valid_losses
            }, self.savepath)

    def load(self):
        """
        Loads the model(s), optimizer(s), scheduler(s) and losses
        :return: None
        """
        checkpoint = torch.load(self.loadpath, map_location=self.device)
        self.epoch = checkpoint['epoch']
        self.generator.load_state_dict(checkpoint['generator_state_dict'])
        self.discriminator.load_state_dict(
            checkpoint['discriminator_state_dict'])
        self.generator_optimizer.load_state_dict(
            checkpoint['generator_optimizer_state_dict'])
        self.discriminator_optimizer.load_state_dict(
            checkpoint['discriminator_optimizer_state_dict'])
        self.generator_scheduler.load_state_dict(
            checkpoint['generator_scheduler_state_dict'])
        self.discriminator_scheduler.load_state_dict(
            checkpoint['discriminator_scheduler_state_dict'])
        self.train_losses = checkpoint['train_losses']
        self.test_losses = checkpoint['test_losses']
        self.valid_losses = checkpoint['valid_losses']

    def evaluate_metrics(self, n_batches):
        """
        Evaluates the quality of the reconstruction with the SNR and LSD metrics on a specified number of batches
        :param n_batches: number of batches to process
        :return: mean and std for each metric
        """
        with torch.no_grad():
            snrs = []
            lsds = []
            generator = self.generator.eval()
            for k in range(n_batches):
                # Transfer to GPU
                local_batch = next(self.test_loader_iter)
                input_batch, target_batch = local_batch[0].to(
                    self.device), local_batch[1].to(self.device)

                # Generates a batch
                generated_batch = generator(input_batch)

                # Get the metrics
                snrs.append(
                    snr(x=generated_batch.squeeze(),
                        x_ref=target_batch.squeeze()))
                lsds.append(
                    lsd(x=generated_batch.squeeze(),
                        x_ref=target_batch.squeeze()))

            snrs = torch.cat(snrs).cpu().numpy()
            lsds = torch.cat(lsds).cpu().numpy()

            # Some signals corresponding to silence will be all zeroes and cause troubles due to the logarithm
            snrs[np.isinf(snrs)] = np.nan
            lsds[np.isinf(lsds)] = np.nan
        return np.nanmean(snrs), np.nanstd(snrs), np.nanmean(lsds), np.nanstd(
            lsds)
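
# The snr and lsd helpers used by evaluate_metrics above are not shown in this snippet.
# Sketches of the usual definitions are given below: per-sample SNR in dB, and a
# log-spectral distance computed on magnitude spectrograms. The STFT parameters and the
# exact LSD normalisation are assumptions; the original implementations may differ.
import torch


def snr(x, x_ref):
    """Per-sample signal-to-noise ratio in dB for batched 1D signals of shape (B, T)."""
    noise_power = torch.sum((x - x_ref) ** 2, dim=-1)
    signal_power = torch.sum(x_ref ** 2, dim=-1)
    return 10 * torch.log10(signal_power / noise_power)


def lsd(x, x_ref, n_fft=512):
    """Per-sample log-spectral distance between batched 1D signals of shape (B, T)."""
    eps = 1e-8
    spec = torch.stft(x, n_fft=n_fft, return_complex=True).abs()
    spec_ref = torch.stft(x_ref, n_fft=n_fft, return_complex=True).abs()
    log_diff = torch.log10(spec ** 2 + eps) - torch.log10(spec_ref ** 2 + eps)
    # RMS over frequency bins, then averaged over time frames
    return torch.sqrt(torch.mean(log_diff ** 2, dim=1)).mean(dim=-1)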
Example #8
def run(tuner=None, config=None):

    if config.load_data_from_pkl:
        # To speed up the chain: post-processing involves a loop over the data for
        # normalisation, so load data that has already been prepped.

        import pickle
        dataFile = open("/Users/drdre/inputz/MNIST/preprocessed/full.pkl",
                        "rb")
        train_loader = pickle.load(dataFile)
        test_loader = pickle.load(dataFile)
        input_dimension = pickle.load(dataFile)
        train_ds_mean = pickle.load(dataFile)
        dataFile.close()
        tuner.register_dataLoaders(train_loader, test_loader)

    else:
        #load data, internally registers train and test dataloaders
        tuner.register_dataLoaders(*load_data(config=config))

        input_dimension = tuner.get_input_dimension()

        train_ds_mean = tuner.get_train_dataset_mean()

        # import pickle
        # dataFile=open("/Users/drdre/inputz/calo/preprocessed/all_la.pkl","wb")
        # pickle.dump(tuner.train_loader,dataFile)
        # pickle.dump(tuner.test_loader,dataFile)
        # pickle.dump(input_dimension,dataFile)
        # pickle.dump(train_ds_mean,dataFile)
        # dataFile.close()

    #set model properties
    model = None
    activation_fct = (torch.nn.ReLU()
                      if config.model.activation_fct.lower() == "relu" else None)

    configString = "_".join(
        str(i) for i in [
            config.model.model_type, config.data.data_type,
            config.n_train_samples, config.n_test_samples,
            config.engine.n_batch_samples, config.engine.n_epochs, config.
            engine.learning_rate, config.model.n_latent_hierarchy_lvls, config.
            model.n_latent_nodes, config.model.activation_fct, config.tag
        ])

    date = datetime.datetime.now().strftime("%y%m%d")

    if config.data.data_type == 'calo':
        configString += "_nlayers_{0}_{1}".format(len(config.data.calo_layers),
                                                  config.particle_type)

    #TODO wrap all these in a container class
    if config.model.model_type == "AE":
        model = AutoEncoder(input_dimension=input_dimension,
                            config=config,
                            activation_fct=activation_fct)

    elif config.model.model_type == "sparseAE":
        model = SparseAutoEncoder(input_dimension=input_dimension,
                                  config=config,
                                  activation_fct=activation_fct)

    elif config.model.model_type == "VAE":
        model = VariationalAutoEncoder(input_dimension=input_dimension,
                                       config=config,
                                       activation_fct=activation_fct)

    elif config.model.model_type == "cVAE":
        model = ConditionalVariationalAutoEncoder(
            input_dimension=input_dimension,
            config=config,
            activation_fct=activation_fct)

    elif config.model.model_type == "sVAE":
        model = SequentialVariationalAutoEncoder(
            input_dimension=input_dimension,
            config=config,
            activation_fct=activation_fct)

    elif config.model.model_type == "HiVAE":
        model = HierarchicalVAE(input_dimension=input_dimension,
                                activation_fct=activation_fct,
                                config=config)

    elif config.model.model_type == "DiVAE":
        activation_fct = torch.nn.Tanh()
        model = DiVAE(input_dimension=input_dimension,
                      config=config,
                      activation_fct=activation_fct)
    else:
        logger.error("Unknown model type: %s", config.model.model_type)
        raise NotImplementedError

    model.create_networks()
    model.set_dataset_mean(train_ds_mean)
    # model.set_input_dimension(input_dimension)

    #TODO avoid this if statement
    if config.model.model_type == "DiVAE": model.set_train_bias()

    model.print_model_info()
    optimiser = torch.optim.Adam(model.parameters(),
                                 lr=config.engine.learning_rate)

    tuner.register_model(model)
    tuner.register_optimiser(optimiser)

    if not config.load_model:
        gif_frames = []
        logger.debug("Start Epoch Loop")
        for epoch in range(1, config.engine.n_epochs + 1):
            train_loss = tuner.train(epoch)
            test_loss, input_data, output_data, zetas, labels = tuner.test()

            if config.create_gif:
                #TODO improve
                if config.data.data_type == 'calo':
                    gif_frames.append(
                        plot_calo_images(input_data,
                                         output_data,
                                         output="{0}/{2}_reco_{1}.png".format(
                                             config.output_path, configString,
                                             date),
                                         do_gif=True))
                else:
                    gif_frames.append(
                        gif_output(input_data,
                                   output_data,
                                   epoch=epoch,
                                   max_epochs=config.engine.n_epochs,
                                   train_loss=train_loss,
                                   test_loss=test_loss))

            if model.type == "DiVAE" and config.sample_from_prior:
                random_samples = model.generate_samples()
                #TODO make a plot of the output

        if config.create_gif:
            gif.save(gif_frames,
                     "{0}/runs_{1}.gif".format(config.output_path,
                                               configString),
                     duration=200)

        if config.save_model:
            tuner.save_model(configString)
            if model.type == "DiVAE":
                tuner.save_rbm(configString)

    else:
        tuner.load_model(set_eval=True)

    #TODO move this around
    if config.generate_samples:
        if config.load_model:
            configString = config.input_model.split("/")[-1].replace('.pt', '')

        if config.model.model_type == "DiVAE":
            from utils.generate_samples import generate_samples_divae
            generate_samples_divae(tuner._model, configString)

        #TODO split this up in plotting and generation routine and have one
        #common function for all generative models.
        elif config.model.model_type == "VAE":
            from utils.generate_samples import generate_samples_vae
            generate_samples_vae(tuner._model, configString)

        elif config.model.model_type == "cVAE":
            from utils.generate_samples import generate_samples_cvae
            generate_samples_cvae(tuner._model, configString)

        elif config.model.model_type == "sVAE":
            from utils.generate_samples import generate_samples_svae
            generate_samples_svae(tuner._model, configString)

    if config.create_plots:
        if config.data.data_type == 'calo':
            if config.model.model_type == "sVAE":
                test_loss, input_data, output_data, zetas, labels = tuner.test()
                plot_calo_image_sequence(input_data,
                                         output_data,
                                         input_dimension,
                                         output="{0}/{2}_{1}.png".format(
                                             config.output_path, configString,
                                             date))
            else:
                test_loss, input_data, output_data, zetas, labels = tuner.test()
                plot_calo_images(input_data,
                                 output_data,
                                 output="{0}/{2}_reco_{1}.png".format(
                                     config.output_path, configString, date))
        else:
            test_loss, input_data, output_data, zetas, labels = tuner.test()
            if not config.model.model_type == "cVAE" and not config.model.model_type == "DiVAE":
                plot_latent_space(zetas,
                                  labels,
                                  output="{0}/{2}_latSpace_{1}".format(
                                      config.output_path, configString, date),
                                  dimensions=0)
            plot_MNIST_output(input_data,
                              output_data,
                              output="{0}/{2}_reco_{1}.png".format(
                                  config.output_path, configString, date))
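
# run() above expects a nested config object; the concrete configuration class is not part
# of this snippet. A minimal stand-in covering the fields that run() actually reads, with
# placeholder values (field names come from the code above, the values are assumptions):
from types import SimpleNamespace

example_config = SimpleNamespace(
    load_data_from_pkl=False,
    load_model=False,
    save_model=True,
    create_gif=False,
    create_plots=True,
    generate_samples=False,
    sample_from_prior=False,
    n_train_samples=10000,
    n_test_samples=2000,
    tag='demo',
    output_path='./output',
    input_model='',
    particle_type='',
    model=SimpleNamespace(model_type='AE', activation_fct='relu',
                          n_latent_hierarchy_lvls=1, n_latent_nodes=32),
    data=SimpleNamespace(data_type='mnist', calo_layers=[]),
    engine=SimpleNamespace(n_batch_samples=128, n_epochs=5, learning_rate=1e-3),
)
# Usage would then be run(tuner=some_tuner, config=example_config).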
def test(dset,
         dset_type,
         model_name,
         epochs=30,
         optimizer='adam',
         loss_fun='binary_crossentropy',
         batch_size=64,
         kernel_size=(5, 5),
         n_values=[1, 3, 5, 10],
         net_parameters=None):

    ext = 'csv' if dset_type == 'csv' else 'npz'

    norm_dset_path = f'dset/{dset}/{dset_type}/normal.{ext}'
    atk_dset_path = f'dset/{dset}/{dset_type}/attack.{ext}'

    is_img = dset_type != 'csv'
    normalize = dset_type != 'bin'

    print_with_time(f'[INFO] Loading {dset_type} dataset {dset}...')
    normal_dataset, attack_dataset, input_shape = load_dataset(
        norm_dset_path, atk_dset_path, is_img, normalize=normalize)

    x_train, x_test, x_test_treshold, x_val = split_normal_dataset(
        normal_dataset)
    x_abnormal = attack_dataset  #[:x_test.shape[0]]

    print_with_time(f'[INFO] Dataset {dset} loaded.')

    x_train, x_test, x_test_treshold, x_val, x_abnormal, input_shape = parse_datasets(
        x_train, x_test, x_test_treshold, x_val, x_abnormal, input_shape, dset,
        dset_type, model_name)

    n_results = {}

    for n in n_values:

        print_with_time(f'[INFO] Training {n} network(s)')

        z_values = [0, 1, 2, 3]
        y_true = np.array([False] * x_test.shape[0] +
                          [True] * x_abnormal.shape[0])
        y_pred_matrix = np.zeros((n, len(z_values), y_true.shape[0]))

        cols = x_test.shape[0] + x_abnormal.shape[0]
        cols_tresh = x_test_treshold.shape[0]

        final_losses = np.zeros(n)
        pred_losses = np.zeros((n, cols))
        pred_losses_tresh = np.zeros((n, cols_tresh))

        for i in range(n):

            print_with_time(f'[INFO] Training {i}th {model_name}')

            layers_size = net_parameters[model_name]

            model = AutoEncoder(model_name, input_shape, layers_size,
                                kernel_size).model
            model.compile(optimizer=optimizer, loss=loss_fun)

            #print(model.summary())

            # train on only normal training data
            history = model.fit(x=x_train,
                                y=x_train,
                                epochs=epochs,
                                batch_size=batch_size,
                                validation_data=(x_val, x_val),
                                verbose=0)

            print_with_time(f'[INFO] {i}th {model_name} trained for {epochs}')

            # average validation loss over all epochs for this network
            final_losses[i] = np.nanmean(history.history['val_loss'])

            # test
            losses = []
            x_concat = np.concatenate([x_test, x_abnormal], axis=0)

            for x in x_concat:
                #compute loss for each test sample
                x = np.expand_dims(x, axis=0)
                loss = model.test_on_batch(x, x)
                losses.append(loss)

            pred_losses[i] = losses
            print_with_time(
                f'[INFO] {i}th {model_name} tested on {x_concat.shape[0]} samples'
            )

            losses = []

            for x in x_test_treshold:
                x = np.expand_dims(x, axis=0)
                loss = model.test_on_batch(x, x)
                losses.append(loss)

            pred_losses_tresh[i] = losses
            print_with_time(f'[INFO] Tested {i}th {model_name}')

        mean_tresh_losses = np.nanmean(
            pred_losses_tresh, axis=0) if n > 1 else pred_losses_tresh[0]
        var_tresh_losses = np.nanvar(pred_losses_tresh,
                                     axis=0) if n > 1 else pred_losses_tresh[0]

        mean_pred_losses = np.nanmean(pred_losses,
                                      axis=0) if n > 1 else pred_losses[0]
        var_pred_losses = np.nanvar(pred_losses,
                                    axis=0) if n > 1 else pred_losses[0]

        filename = f'{n}_{model_name}_{epochs}_{dset}_{dset_type}'
        plot_predictions(var_tresh_losses, var_pred_losses, filename, n)
        print_with_time('[INFO] Plotted results')

        z_results = {}

        for z in z_values:

            treshold_upperbound = np.nanmean(
                var_tresh_losses) + np.nanstd(var_tresh_losses) * z
            treshold_lowerbound = np.nanmean(
                var_tresh_losses) - np.nanstd(var_tresh_losses) * z

            if n > 1:
                y_pred = np.array(
                    [x > treshold_upperbound for x in var_pred_losses])

            else:
                y_pred = np.array([
                    x < treshold_lowerbound or x > treshold_upperbound
                    for x in var_pred_losses
                ])

            print(f'Z: {z}', end=', ')
            acc, f1 = compute_and_save_metrics(y_true, y_pred, False)

            z_results[z] = (acc, f1)

            n_results[n] = z_results

    return n_results
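
# compute_and_save_metrics in the loop above is assumed to report accuracy and F1 for the
# boolean anomaly predictions (its third argument presumably toggles saving); it is not
# defined in this snippet. A minimal sketch of the metric computation with scikit-learn
# (hypothetical helper, saving omitted):
from sklearn.metrics import accuracy_score, f1_score


def compute_metrics(y_true, y_pred, verbose=True):
    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    if verbose:
        print(f'accuracy={acc:.4f}, f1={f1:.4f}')
    return acc, f1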
    'L_ora': args.L_ora,
}

if args.type == 'markov':
    fnet = FeatureNet(n_actions=4,
                      input_shape=x0.shape[1:],
                      n_latent_dims=args.latent_dims,
                      n_hidden_layers=1,
                      n_units_per_layer=32,
                      lr=args.learning_rate,
                      coefs=coefs)
elif args.type == 'autoencoder':
    fnet = AutoEncoder(n_actions=4,
                       input_shape=x0.shape[1:],
                       n_latent_dims=args.latent_dims,
                       n_hidden_layers=1,
                       n_units_per_layer=32,
                       lr=args.learning_rate,
                       coefs=coefs)
elif args.type == 'pixel-predictor':
    fnet = PixelPredictor(n_actions=4,
                          input_shape=x0.shape[1:],
                          n_latent_dims=args.latent_dims,
                          n_hidden_layers=1,
                          n_units_per_layer=32,
                          lr=args.learning_rate,
                          coefs=coefs)

fnet.print_summary()

n_test_samples = 2000
class DeepLatte(nn.Module):
    def __init__(self, in_features, en_features, de_features, in_size,
                 h_channels, kernel_sizes, num_layers, fc_h_features,
                 out_features, **kwargs):
        """
        params:
            in_features (int): size of input sample
            en_features (list): list of number of features for the encoder layers
            de_features (list): list of number of features for the decoder layers
            in_size (int, int): height and width of input tensor as (height, width)
            h_channels (int or list): number of channels of the hidden state; if a list, its length must equal num_layers
            kernel_sizes (list): size of the convolution kernels
            num_layers (int): number of layers in ConvLSTM
            fc_h_features (int): size of hidden features in the FC layer
            out_features (int): size of output sample
        """

        super(DeepLatte, self).__init__()

        self.kwargs = kwargs
        self.device = kwargs.get('device', 'cpu')

        # sparse layer
        self.sparse_layer = DiagPruneLinear(in_features=in_features,
                                            device=self.device)

        # auto_encoder layer
        self.ae = AutoEncoder(in_features=in_features,
                              en_features=en_features,
                              de_features=de_features)

        if kwargs.get('ae_pretrain_weight') is not None:
            self.ae.load_state_dict(kwargs['ae_pretrain_weight'])
            for param in self.ae.parameters():
                param.requires_grad = True

        # conv_lstm layers
        h_channels = self._extend_for_multilayer(
            h_channels, num_layers)  # len(h_channels) == num_layers
        self.conv_lstm_list = nn.ModuleList()
        for i in kernel_sizes:
            self.conv_lstm_list.append(
                ConvLSTM(in_size=in_size,
                         in_channels=en_features[-1],
                         h_channels=h_channels,
                         kernel_size=(i, i),
                         num_layers=num_layers,
                         batch_first=kwargs.get('batch_first', True),
                         output_last=kwargs.get('only_last_state', True),
                         device=self.device))

        self.fc = Stack2Linear(in_features=h_channels[-1] * len(kernel_sizes),
                               h_features=fc_h_features,
                               out_features=out_features)

    def forward(self, input_data):  # shape: (b, t, c, h, w)
        """
        param:
            input_data (batch_size, seq_len, num_channels, height, width)
        """

        batch_size, seq_len, num_channels, _, _ = input_data.shape
        x = input_data.permute(
            0, 1, 3, 4,
            2)  # shape: (b, t, h, w, c), moving feature dimension to the last

        # sparse layer
        sparse_x = self.sparse_layer(x)

        # auto-encoder layer
        en_x, de_x = self.ae(sparse_x)
        en_x = en_x.permute(
            0, 1, 4, 2,
            3)  # shape: (b, t, c, h, w), moving the channel dimension back in front of height and width

        # conv_lstm layers
        conv_lstm_out_list = []
        for conv_lstm in self.conv_lstm_list:
            _, (_, cell_last_state) = conv_lstm(en_x)
            conv_lstm_out_list.append(cell_last_state)

        conv_lstm_out = torch.cat(conv_lstm_out_list,
                                  dim=1)  # shape: (b, c, h, w)

        # fully-connected layer
        out = conv_lstm_out.permute(
            0, 2, 3,
            1)  # shape: (b, h, w, c), moving feature dimension to the last
        out = self.fc(out)
        out = out.permute(
            0, 3, 1,
            2)  # shape: (b, c, h, w), moving the channel dimension back in front of height and width

        return out, sparse_x, en_x, de_x, conv_lstm_out

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        if not isinstance(param, list):
            param = [param] * num_layers
        return param
class AutoEncoderTrainer(Trainer):
    def __init__(self, train_loader, test_loader, valid_loader, general_args, trainer_args):
        super(AutoEncoderTrainer, self).__init__(train_loader, test_loader, valid_loader, general_args)
        # Paths
        self.loadpath = trainer_args.loadpath
        self.savepath = trainer_args.savepath

        # Model
        self.autoencoder = AutoEncoder(general_args=general_args).to(self.device)

        # Optimizer and scheduler
        self.optimizer = torch.optim.Adam(params=self.autoencoder.parameters(), lr=trainer_args.lr)
        self.scheduler = lr_scheduler.StepLR(optimizer=self.optimizer,
                                             step_size=trainer_args.scheduler_step,
                                             gamma=trainer_args.scheduler_gamma)

        # Load saved states
        if os.path.exists(trainer_args.loadpath):
            self.load()

        # Loss function
        self.time_criterion = nn.MSELoss()
        self.frequency_criterion = nn.MSELoss()

        # Boolean to differentiate generator from auto-encoder
        self.is_autoencoder = True

    def train(self, epochs):
        """
        Trains the auto-encoder for a given number of pseudo-epochs
        :param epochs: Number of pseudo-epochs to perform
        :return: None
        """
        for epoch in range(epochs):
            self.autoencoder.train()
            for i in range(self.train_batches_per_epoch):
                data_batch = next(self.train_loader_iter)
                # Transfer to GPU
                input_batch, target_batch = data_batch[0].to(self.device), data_batch[1].to(self.device)

                # Reset the gradients before the backward passes
                self.optimizer.zero_grad()

                # Train with input samples
                generated_batch, _ = self.autoencoder(input_batch)
                specgram_input_batch = self.spectrogram(input_batch)
                specgram_generated_batch = self.spectrogram(generated_batch)

                # Compute the input losses
                input_time_l2_loss = self.time_criterion(generated_batch, input_batch)
                input_freq_l2_loss = self.frequency_criterion(specgram_generated_batch, specgram_input_batch)
                input_loss = input_time_l2_loss + input_freq_l2_loss
                input_loss.backward()

                # Train with target samples
                generated_batch, _ = self.autoencoder(target_batch)
                specgram_target_batch = self.spectrogram(target_batch)
                specgram_generated_batch = self.spectrogram(generated_batch)

                # Compute the target losses
                target_time_l2_loss = self.time_criterion(generated_batch, target_batch)
                target_freq_l2_loss = self.frequency_criterion(specgram_generated_batch, specgram_target_batch)
                target_loss = target_time_l2_loss + target_freq_l2_loss
                target_loss.backward()

                # Update weights
                self.optimizer.step()

                # Store losses
                self.train_losses['time_l2'].append((input_time_l2_loss + target_time_l2_loss).item())
                self.train_losses['freq_l2'].append((input_freq_l2_loss + target_freq_l2_loss).item())

            # Print message
            message = 'Train, epoch {}: \n' \
                      '\t Time: {} \n' \
                      '\t Frequency: {} \n'.format(
                self.epoch, np.mean(self.train_losses['time_l2'][-self.train_batches_per_epoch:]),
                np.mean(self.train_losses['freq_l2'][-self.train_batches_per_epoch:]))
            print(message)

            with torch.no_grad():
                self.eval()

            # Save the trainer state
            if self.need_saving:
                self.save()

            # Increment epoch counter
            self.epoch += 1
            self.scheduler.step()

    def eval(self):
        """
        Evaluates the auto-encoder on the validation set.
        :return: None
        """
        with torch.no_grad():
            self.autoencoder.eval()
            batch_losses = {'time_l2': [], 'freq_l2': []}
            for i in range(self.valid_batches_per_epoch):
                # Transfer to GPU
                data_batch = next(self.valid_loader_iter)
                data_batch = torch.cat(data_batch).to(self.device)

                # Forward pass
                generated_batch, _ = self.autoencoder(data_batch)

                # Get the spectrogram
                specgram_batch = self.spectrogram(data_batch)
                specgram_generated_batch = self.spectrogram(generated_batch)

                # Compute and store the loss
                time_l2_loss = self.time_criterion(generated_batch, data_batch)
                freq_l2_loss = self.frequency_criterion(specgram_generated_batch, specgram_batch)
                batch_losses['time_l2'].append(time_l2_loss.item())
                batch_losses['freq_l2'].append(freq_l2_loss.item())

            # Store the validation losses
            self.valid_losses['time_l2'].append(np.mean(batch_losses['time_l2']))
            self.valid_losses['freq_l2'].append(np.mean(batch_losses['freq_l2']))

            # Display the validation loss
            message = 'Validation, epoch {}: \n' \
                      '\t Time: {} \n' \
                      '\t Frequency: {} \n'.format(self.epoch,
                                                   np.mean(batch_losses['time_l2']),
                                                   np.mean(batch_losses['freq_l2']))
            print(message)

            # Check if the loss is decreasing
            self.check_improvement()

    def save(self):
        """
        Saves the model(s), optimizer(s), scheduler(s) and losses
        :return: None
        """
        torch.save({
            'epoch': self.epoch,
            'autoencoder_state_dict': self.autoencoder.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'train_losses': self.train_losses,
            'test_losses': self.test_losses,
            'valid_losses': self.valid_losses
        }, self.savepath)

    def load(self):
        """
        Loads the model(s), optimizer(s), scheduler(s) and losses
        :return: None
        """
        checkpoint = torch.load(self.loadpath, map_location=self.device)
        self.epoch = checkpoint['epoch']
        self.autoencoder.load_state_dict(checkpoint['autoencoder_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.train_losses = checkpoint['train_losses']
        self.test_losses = checkpoint['test_losses']
        self.valid_losses = checkpoint['valid_losses']

    def plot_autoencoder_embedding_space(self, n_batches, fig_savepath=None):
        """
        Plots a 2D representation of the embedding space. Can be useful to determine whether or not the auto-encoder's
        features can be used to improve the generation of realistic samples.
        :param n_batches: number of batches to use for the plot.
        :param fig_savepath: location where to save the figure
        :return: None
        """
        n_pairs = n_batches * self.valid_loader.batch_size
        n_features = 9
        with torch.no_grad():
            autoencoder = self.autoencoder.eval()
            embeddings = []
            for k in range(n_batches):
                # Transfer to GPU
                data_batch = next(self.valid_loader_iter)
                data_batch = torch.cat(data_batch).to(self.device)

                # Forward pass
                _, embedding_batch = autoencoder(data_batch)

                # Store the embeddings
                embeddings.append(embedding_batch)

            # Convert list to tensor
            embeddings = torch.cat(embeddings)

        # Randomly select features from the channel dimension
        random_features = np.random.randint(embeddings.shape[1], size=n_features)
        # Plot embeddings
        fig, axes = plt.subplots(3, 3, figsize=(12, 12))
        for i, random_feature in enumerate(random_features):
            # Map embedding to a 2D representation
            tsne = TSNE(n_components=2, verbose=0, perplexity=50)
            tsne_results = tsne.fit_transform(embeddings[:, random_feature, :].detach().cpu().numpy())
            for k in range(2):
                label = ('input' if k == 0 else 'target')
                axes[i // 3][i % 3].scatter(tsne_results[k * n_pairs: (k + 1) * n_pairs, 0],
                                            tsne_results[k * n_pairs: (k + 1) * n_pairs, 1], label=label)
                axes[i // 3][i % 3].set_title('Channel {}'.format(random_feature), fontsize=14)
                axes[i // 3][i % 3].legend()

        # Save plot if needed
        if fig_savepath:
            plt.savefig(fig_savepath)
        plt.show()
Example #13
a_train_x = pd.read_csv(
    'sample_data/processed_data/autoencoder_data/train_x.csv', index_col=0)
a_train_y = pd.read_csv(
    'sample_data/processed_data/autoencoder_data/train_y.csv', index_col=0)
a_test_x = pd.read_csv(
    'sample_data/processed_data/autoencoder_data/test_x.csv', index_col=0)
a_test_y = pd.read_csv(
    'sample_data/processed_data/autoencoder_data/test_y.csv', index_col=0)
print(a_train_x.head())
print(a_train_x.shape)

print('Scaling data...')
scaler = MinMaxScaler(feature_range=(-1, 1))
x_train_a = scaler.fit_transform(a_train_x.iloc[:, 1:])
x_test_a = scaler.transform(a_test_x.iloc[:, 1:])

autoencoder = AutoEncoder(20, x_train_a.shape[1])
autoencoder.build_model(100, 50, 50, 100)

print('Training model...')
autoencoder.train_model(autoencoder.autoencoder,
                        x_train_a,
                        epochs=20,
                        model_name='autoencoder')

print('Testing model...')
autoencoder.test_model(autoencoder.autoencoder, x_test_a)

print('Encoding data...')
a_full_data = pd.read_csv(
    'sample_data/processed_data/autoencoder_data/full_x.csv', index_col=0)
a_scaled_full = pd.DataFrame(scaler.transform(a_full_data.iloc[:, 1:]))