def train_dcgan(*,
                generator,
                discriminator,
                train_iterator,
                device,
                n_epoch,
                generator_opt,
                discriminator_opt,
                n_noise_channels,
                callbacks: Optional[Sequence[Callable]] = None,
                logger: TBLogger):
    generator = generator.to(device)
    discriminator = discriminator.to(device)
    criterion = F.binary_cross_entropy_with_logits

    callbacks = callbacks or []
    for epoch in tqdm(range(n_epoch)):
        generator_losses = []
        discriminator_losses_on_real = []
        discriminator_losses_on_fake = []

        with train_iterator as iterator:
            for real_batch, _ in iterator:
                real_batch = transform_gan(real_batch)
                batch_size = len(real_batch)
                discriminator_opt.zero_grad()
                # train discriminator on real
                real_loss = process_batch(real_batch,
                                          torch.ones(batch_size, 1, 1, 1),
                                          discriminator, criterion)
                # train discriminator on fake
                noise = generate_noise(batch_size, n_noise_channels, device)
                fake_batch = generator(noise)
                fake_loss = process_batch(fake_batch.detach(),
                                          torch.zeros(batch_size, 1, 1, 1),
                                          discriminator, criterion)
                discriminator_opt.step()
                # train generator
                generator_opt.zero_grad()
                target_for_generator = torch.ones(batch_size, 1, 1, 1)
                generator_loss = process_batch(fake_batch,
                                               target_for_generator,
                                               discriminator, criterion)
                generator_opt.step()

                generator_losses.append(generator_loss)
                discriminator_losses_on_real.append(real_loss)
                discriminator_losses_on_fake.append(fake_loss)

        # run callbacks
        for callback in callbacks:
            callback(epoch)

        losses = {
            'Generator': np.mean(generator_losses),
            'Discriminator on fake': np.mean(discriminator_losses_on_fake),
            'Discriminator on real': np.mean(discriminator_losses_on_real)
        }
        logger.policies(losses, epoch)
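

# `process_batch` and `generate_noise` are used above but defined elsewhere.
# Below is a minimal sketch consistent with the call sites (an assumption,
# not the original implementation): `process_batch` moves the data to the
# model's device, runs a forward/backward pass and returns the scalar loss,
# while `generate_noise` samples DCGAN-style latent vectors.
import torch


def process_batch(batch, target, model, criterion):
    # infer the model's device so CPU-resident batches and targets still work
    device = next(model.parameters()).device
    logits = model(batch.to(device))
    loss = criterion(logits, target.to(device))
    loss.backward()
    return loss.item()


def generate_noise(batch_size, n_noise_channels, device):
    # latent input shaped (N, C, 1, 1) for a convolutional generator
    return torch.randn(batch_size, n_noise_channels, 1, 1, device=device)
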
def run_experiment(*, device, download: bool, n_epoch: int, batch_size: int,
                   n_noise_channels: int, data_path: str,
                   experiment_path: str):
    # path to save everything related to experiment
    data_path = Path(data_path).expanduser()
    experiment_path = Path(experiment_path).expanduser()
    # dataset and batch iterator
    dataset = CelebDataset(root=data_path, download=download)
    indices = list(range(len(dataset)))
    train_iterator = DataBatchIterator(dataset, indices, batch_size=batch_size)
    # models
    generator = Generator(in_channels=n_noise_channels).to(device)
    discriminator = Discriminator().to(device)
    generator.apply(init_weights)
    discriminator.apply(init_weights)

    # TODO: remove hardcode
    optimizer_parameters = dict(lr=1e-4, betas=(0.5, 0.99))
    generator_opt = Adam(generator.parameters(), **optimizer_parameters)
    discriminator_opt = Adam(discriminator.parameters(),
                             **optimizer_parameters)

    fixed_noise = torch.randn(64, n_noise_channels, 1, 1, device=device)

    def predict_on_fixed_noise(epoch, prefix='fixed_noise', compression=1):
        predict = to_numpy(inference_step(fixed_noise, generator))
        os.makedirs(experiment_path / prefix, exist_ok=True)
        save_numpy(predict,
                   experiment_path / prefix / f'{epoch}.npy.gz',
                   compression=compression)

    def save_models(epoch):
        os.makedirs(experiment_path / 'generator', exist_ok=True)
        os.makedirs(experiment_path / 'discriminator', exist_ok=True)
        save_torch(generator, experiment_path / f'generator/generator_{epoch}')
        save_torch(discriminator,
                   experiment_path / f'discriminator/discriminator_{epoch}')

    logger = TBLogger(experiment_path / 'logs')
    epoch_callbacks = [predict_on_fixed_noise, save_models]

    train_dcgan(generator=generator,
                generator_opt=generator_opt,
                discriminator=discriminator,
                discriminator_opt=discriminator_opt,
                train_iterator=train_iterator,
                device=device,
                n_epoch=n_epoch,
                n_noise_channels=n_noise_channels,
                callbacks=epoch_callbacks,
                logger=logger)
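

# `init_weights` and `inference_step` are referenced above without
# definitions. Plausible sketches, hedged as assumptions: `init_weights`
# follows the usual DCGAN initialization (Radford et al.: conv weights
# ~ N(0, 0.02), batch-norm weights ~ N(1, 0.02)), and `inference_step`
# is a gradient-free forward pass.
import torch
import torch.nn as nn


def init_weights(module: nn.Module):
    classname = module.__class__.__name__
    if 'Conv' in classname:
        nn.init.normal_(module.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        nn.init.normal_(module.weight.data, 1.0, 0.02)
        nn.init.constant_(module.bias.data, 0.0)


def inference_step(batch, model):
    # eval mode disables dropout/batch-norm updates during prediction
    model.eval()
    with torch.no_grad():
        output = model(batch)
    model.train()
    return output
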
def run_on_generated(*, device, n_epoch: int, batch_size: int,
                     batches_per_epoch: int, val_batches_per_epoch: int,
                     experiment_path: str, embedding_fidelity: float,
                     model_path: str, n_noise_channels: int):
    model_path = Path(model_path).expanduser()
    experiment_path = Path(experiment_path).expanduser()

    generator = Generator(in_channels=n_noise_channels).to(device)
    generator.load_state_dict(torch.load(model_path))
    train_iterator = ModelBatchGenerator(generator,
                                         batch_size=batch_size,
                                         batches_per_epoch=batches_per_epoch,
                                         n_noise_channels=n_noise_channels)
    val_iterator = ModelBatchGenerator(generator,
                                       batch_size=batch_size,
                                       batches_per_epoch=val_batches_per_epoch,
                                       n_noise_channels=n_noise_channels)
    # text loader and encoder
    encoder = SigmoidTorchEncoder(beta=embedding_fidelity)
    text_loader = TextLoader()
    text_iterator = text_loader.create_generator()
    # model
    stegoanalyser = Stegoanalyser().to(device)
    stegoanalyser.apply(init_weights)
    optimizer_parameters = dict(lr=1e-4, betas=(0.5, 0.99))
    stegoanalyser_opt = Adam(stegoanalyser.parameters(), **optimizer_parameters)

    def save_models(epoch):
        os.makedirs(experiment_path / 'stegoanalyser', exist_ok=True)
        save_torch(stegoanalyser, experiment_path / f'stegoanalyser/stegoanalyser_{epoch}')

    logger = TBLogger(experiment_path / 'logs')
    epoch_callbacks = [save_models]
    # train on generated images: with/without messages
    train_stego(
        stegoanalyser=stegoanalyser,
        train_iterator=train_iterator,
        val_iterator=val_iterator,
        text_iterator=text_iterator,
        n_epoch=n_epoch,
        stegoanalyser_opt=stegoanalyser_opt,
        callbacks=epoch_callbacks,
        logger=logger,
        encoder=encoder
    )
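

# `ModelBatchGenerator` is not shown here. Judging from its use in
# `train_stego`, it must be a context manager whose iterator yields
# (batch, label) pairs sampled from the frozen generator; the sketch below
# is one plausible implementation under that assumption.
import torch


class ModelBatchGenerator:
    def __init__(self, generator, *, batch_size, batches_per_epoch,
                 n_noise_channels):
        self.generator = generator
        self.batch_size = batch_size
        self.batches_per_epoch = batches_per_epoch
        self.n_noise_channels = n_noise_channels

    def __enter__(self):
        return self._iterate()

    def __exit__(self, *exc_info):
        return False

    def _iterate(self):
        # sample noise on the same device as the generator's parameters
        device = next(self.generator.parameters()).device
        for _ in range(self.batches_per_epoch):
            noise = torch.randn(self.batch_size, self.n_noise_channels, 1, 1,
                                device=device)
            with torch.no_grad():
                yield self.generator(noise), None
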
def run_on_real(*, device, download: bool, n_epoch: int, batch_size: int, data_path: str, experiment_path: str,
                embedding_fidelity: float):
    data_path = Path(data_path).expanduser()
    experiment_path = Path(experiment_path).expanduser()
    os.makedirs(experiment_path, exist_ok=True)
    # dataset and batch iterator: real images
    dataset = CelebDataset(root=data_path, download=download)
    train_indices, val_indices, test_indices = split_data(dataset, train_size=0.6, val_size=0.1, test_size=0.3)
    train_iterator = DataBatchIterator(dataset, train_indices, batch_size=batch_size)
    val_iterator = DataBatchIterator(dataset, val_indices, batch_size=batch_size)
    # save indices for reproducibility
    save(test_indices, experiment_path / 'test_indices.json')
    # text loader and encoder
    encoder = SigmoidTorchEncoder(beta=embedding_fidelity)
    text_loader = TextLoader()
    text_iterator = text_loader.create_generator()
    # model
    stegoanalyser = Stegoanalyser().to(device)
    # stegoanalyser.apply(init_weights)
    optimizer_parameters = dict(lr=1e-4, betas=(0.9, 0.99))
    stegoanalyser_opt = Adam(stegoanalyser.parameters(), **optimizer_parameters)

    def save_models(epoch):
        os.makedirs(experiment_path / 'stegoanalyser', exist_ok=True)
        save_torch(stegoanalyser, experiment_path / f'stegoanalyser/stegoanalyser_{epoch}')

    logger = TBLogger(experiment_path / 'logs')
    epoch_callbacks = [save_models]
    # train on real images: with/without messages
    train_stego(
        stegoanalyser=stegoanalyser,
        train_iterator=train_iterator,
        val_iterator=val_iterator,
        text_iterator=text_iterator,
        n_epoch=n_epoch,
        stegoanalyser_opt=stegoanalyser_opt,
        callbacks=epoch_callbacks,
        logger=logger,
        encoder=encoder
    )
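

# `split_data` is assumed to partition the dataset indices by the given
# fractions; a minimal sketch of that behaviour (not the original code):
import numpy as np


def split_data(dataset, *, train_size, val_size, test_size):
    assert abs(train_size + val_size + test_size - 1.) < 1e-6
    indices = np.random.permutation(len(dataset))
    n_train = int(train_size * len(dataset))
    n_val = int(val_size * len(dataset))
    return (indices[:n_train].tolist(),
            indices[n_train:n_train + n_val].tolist(),
            indices[n_train + n_val:].tolist())
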
def train_sgan(*, generator: nn.Module, image_analyser: nn.Module,
               message_analyser: nn.Module, generator_opt: Optimizer,
               image_analyser_opt: Optimizer, message_analyser_opt: Optimizer,
               encoder: SigmoidTorchEncoder, image_iterator: DataBatchIterator,
               text_iterator: Iterator, device: str, n_epoch: int,
               n_noise_channels: int, start_stego_epoch: int,
               loss_balancer: float, callbacks: Sequence[Callable],
               logger: TBLogger):
    generator = generator.to(device)
    image_analyser = image_analyser.to(device)
    message_analyser = message_analyser.to(device)
    criterion = F.binary_cross_entropy_with_logits

    assert 0. < loss_balancer < 1., loss_balancer
    callbacks = callbacks or []
    for epoch in tqdm(range(n_epoch)):
        generator_image_losses = []
        generator_message_losses = []
        image_analyser_real_losses = []
        image_analyser_fake_losses = []
        message_analyser_losses = []

        with image_iterator as iterator:
            for real_batch, _ in iterator:
                batch_size = len(real_batch)
                real_batch = transform_gan(real_batch)
                image_analyser_opt.zero_grad()
                # train discriminator on real
                real_images_target = torch.ones(batch_size, 1, 1, 1)
                image_analyser_real_losses.append(
                    process_batch(real_batch, real_images_target,
                                  image_analyser, criterion))
                # train discriminator on fake
                generated_images = generator(
                    generate_noise(batch_size, n_noise_channels, device))
                image_analyser_fake_losses.append(
                    process_batch(generated_images.detach(),
                                  torch.zeros(batch_size, 1, 1, 1),
                                  image_analyser, criterion))
                image_analyser_opt.step()
                # train generator
                generator_opt.zero_grad()
                generator_image_losses.append(
                    process_batch(generated_images,
                                  torch.ones(batch_size, 1, 1, 1),
                                  image_analyser, criterion))
                # rescale gradients
                scale_gradients(generator, loss_balancer)
                generator_opt.step()
                if epoch > start_stego_epoch:
                    # start second part
                    containers = generator(
                        generate_noise(batch_size, n_noise_channels, device))
                    # to [0...255]
                    containers = inverse_transform_gan(containers)

                    labels = np.random.choice([0, 1], (batch_size, 1, 1, 1))
                    encoded_images = []
                    for container, label in zip(containers, labels):
                        if label == 1:
                            msg = bytes_to_bits(next(text_iterator))
                            key = generate_random_key(container.shape[1:],
                                                      len(msg))
                            # to [-1, 1]
                            container = transform_encoder(container)
                            container = encoder.encode(container, msg, key)
                            # to [0...255]
                            container = inverse_transform_encoder(container)
                        encoded_images.append(container)

                    encoded_images = torch.stack(encoded_images)
                    labels = torch.from_numpy(labels).float()
                    # train analyser
                    message_analyser_opt.zero_grad()
                    message_analyser_losses.append(
                        process_batch(encoded_images.detach(), labels,
                                      message_analyser, criterion))
                    message_analyser_opt.step()
                    # train generator again, against flipped labels
                    labels = 1 - labels
                    generator_opt.zero_grad()
                    generator_message_losses.append(
                        process_batch(encoded_images, labels, message_analyser,
                                      criterion))
                    scale_gradients(generator, 1 - loss_balancer)
                    generator_opt.step()

        # run callbacks
        for callback in callbacks:
            callback(epoch)

        # note: the message-related losses stay empty (their means are NaN)
        # until `start_stego_epoch` has passed
        losses = {
            'Generator image': np.mean(generator_image_losses),
            'Generator message': np.mean(generator_message_losses),
            'Image discriminator on real': np.mean(image_analyser_real_losses),
            'Image discriminator on fake': np.mean(image_analyser_fake_losses),
            'Message discriminator': np.mean(message_analyser_losses)
        }
        logger.policies(losses, epoch)
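

# `scale_gradients` balances the two generator objectives. A straightforward
# reading of its use above is an in-place rescaling of the gradients produced
# by the preceding backward pass (an assumed implementation):
import torch.nn as nn


def scale_gradients(model: nn.Module, factor: float):
    # multiply every existing parameter gradient of `model` by `factor`
    for parameter in model.parameters():
        if parameter.grad is not None:
            parameter.grad.mul_(factor)
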
def run_experiment(*, device, download: bool, data_path: str,
                   experiment_path: str, n_epoch: int, batch_size: int,
                   n_noise_channels: int, embedding_fidelity: float,
                   loss_balancer: float, start_stego_epoch: int):
    # path to save everything related to experiment
    data_path = Path(data_path).expanduser()
    experiment_path = Path(experiment_path).expanduser()

    # dataset and image iterator
    dataset = CelebDataset(root=data_path, download=download)
    indices = list(range(len(dataset)))
    train_iterator = DataBatchIterator(dataset, indices, batch_size=batch_size)
    # text loader and encoder
    encoder = SigmoidTorchEncoder(beta=embedding_fidelity)
    text_loader = TextLoader()
    text_iterator = text_loader.create_generator()
    # create models (both discriminator models have similar structure)
    image_analyser = Discriminator().to(device)
    message_analyser = Stegoanalyser().to(device)
    generator = Generator(in_channels=n_noise_channels).to(device)
    generator.apply(init_weights)
    image_analyser.apply(init_weights)
    message_analyser.apply(init_weights)

    optimizer_parameters = dict(lr=1e-4, betas=(0.5, 0.99))
    generator_opt = Adam(generator.parameters(), **optimizer_parameters)
    image_analyser_opt = Adam(image_analyser.parameters(),
                              **optimizer_parameters)
    message_analyser_opt = Adam(message_analyser.parameters(),
                                **optimizer_parameters)

    fixed_noise = torch.randn(64, n_noise_channels, 1, 1, device=device)

    def predict_on_fixed_noise(epoch, prefix='fixed_noise', compression=1):
        predict = to_numpy(inference_step(fixed_noise, generator))
        os.makedirs(experiment_path / prefix, exist_ok=True)
        save_numpy(predict,
                   experiment_path / prefix / f'{epoch}.npy.gz',
                   compression=compression)

    def save_models(epoch):
        os.makedirs(experiment_path / 'generator', exist_ok=True)
        os.makedirs(experiment_path / 'image_analyser', exist_ok=True)
        os.makedirs(experiment_path / 'message_analyser', exist_ok=True)

        save_torch(generator, experiment_path / f'generator/generator_{epoch}')
        save_torch(image_analyser,
                   experiment_path / f'image_analyser/image_analyser_{epoch}')
        save_torch(
            message_analyser,
            experiment_path / f'message_analyser/message_analyser_{epoch}')

    logger = TBLogger(experiment_path / 'logs')
    epoch_callbacks = [predict_on_fixed_noise, save_models]

    train_sgan(generator=generator,
               image_analyser=image_analyser,
               message_analyser=message_analyser,
               generator_opt=generator_opt,
               image_analyser_opt=image_analyser_opt,
               message_analyser_opt=message_analyser_opt,
               encoder=encoder,
               image_iterator=train_iterator,
               text_iterator=text_iterator,
               n_epoch=n_epoch,
               start_stego_epoch=start_stego_epoch,
               n_noise_channels=n_noise_channels,
               loss_balancer=loss_balancer,
               logger=logger,
               callbacks=epoch_callbacks,
               device=device)
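

# A hypothetical invocation of `run_experiment`; every value below is
# illustrative and not taken from the original configuration:
if __name__ == '__main__':
    run_experiment(device='cuda',
                   download=True,
                   data_path='~/data/celeba',
                   experiment_path='~/experiments/sgan',
                   n_epoch=30,
                   batch_size=64,
                   n_noise_channels=100,
                   embedding_fidelity=10.,
                   loss_balancer=0.5,
                   start_stego_epoch=5)
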
def train_stego(*, stegoanalyser: nn.Module,
                train_iterator: DataBatchIterator,
                val_iterator: DataBatchIterator,
                text_iterator: Iterator,
                n_epoch: int, stegoanalyser_opt: Optimizer,
                callbacks: Optional[Sequence[Callable]] = None, logger: TBLogger,
                encoder: SigmoidTorchEncoder):
    criterion = F.binary_cross_entropy_with_logits
    callbacks = callbacks or []

    for epoch in tqdm(range(n_epoch)):
        stegoanalyser_losses = []
        with train_iterator as iterator:
            for real_batch, _ in iterator:
                batch_size = len(real_batch)
                labels = np.random.choice([0, 1], (batch_size, 1, 1, 1))
                encoded_images = []
                for image, label in zip(real_batch, labels):
                    if label == 1:
                        msg = bytes_to_bits(next(text_iterator))
                        key = generate_random_key(image.shape[1:], len(msg))
                        image = encoder.encode(transform_encoder(image), msg, key)
                        image = inverse_transform_encoder(image)
                    encoded_images.append(image)

                encoded_images = torch.stack(encoded_images)
                labels = torch.from_numpy(labels).float()
                # train the stegoanalyser
                stegoanalyser_opt.zero_grad()
                stegoanalyser_losses.append(
                    process_batch(encoded_images.detach(), labels, stegoanalyser, criterion))
                stegoanalyser_opt.step()

        with val_iterator as iterator:
            accuracy = []
            for real_batch, _ in iterator:
                batch_size = len(real_batch)

                labels = np.random.choice([0, 1], batch_size)
                encoded_images = []
                for image, label in zip(real_batch, labels):
                    if label == 1:
                        msg = bytes_to_bits(next(text_iterator))
                        key = generate_random_key(image.shape[1:], len(msg))
                        image = encoder.encode(transform_encoder(image), msg, key)
                        image = inverse_transform_encoder(image)
                    encoded_images.append(image)

                encoded_images = torch.stack(encoded_images)
                # evaluate the stegoanalyser
                out = inference_step(encoded_images, stegoanalyser).cpu().detach()
                out = torch.sigmoid(out) > 0.5
                out = out.reshape(len(encoded_images)).numpy()
                accuracy_score = sklearn.metrics.accuracy_score(labels, out)
                accuracy.append(accuracy_score)

            mean_accuracy = np.mean(accuracy)
            print(f'validation accuracy: {mean_accuracy:.4f}')

            losses = {'Stegoanalyser loss': np.mean(stegoanalyser_losses),
                      'Val accuracy': mean_accuracy}
            logger.policies(losses, epoch)

            # run callbacks
            for callback in callbacks:
                callback(epoch)
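

# The message helpers used by `train_stego` and `train_sgan` are defined
# elsewhere; the sketches below describe one plausible contract (assumptions,
# not the original code): `bytes_to_bits` unpacks a byte string into a {0, 1}
# bit array and `generate_random_key` picks the pixel positions used to embed
# those bits.
import numpy as np


def bytes_to_bits(message: bytes) -> np.ndarray:
    # unpack each byte into 8 bits, most significant bit first
    return np.unpackbits(np.frombuffer(message, dtype=np.uint8))


def generate_random_key(image_shape, message_length):
    # choose `message_length` distinct flat pixel indices within the image
    n_pixels = int(np.prod(image_shape))
    return np.random.choice(n_pixels, size=message_length, replace=False)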