Example #1

# Imports needed to run this snippet (not shown in the original excerpt);
# the kymatio helper locations below are assumed from the library's MNIST example.
from kymatio.datasets import get_dataset_dir
from kymatio.torch import Scattering2D as Scattering
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


def main():

    transforms_to_apply = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))  # MNIST is single-channel; pixel values end up in [-1, 1]
    ])

    mnist_dir = get_dataset_dir("MNIST", create=True)
    dataset = datasets.MNIST(mnist_dir,
                             train=False,
                             download=False,
                             transform=transforms_to_apply)
    dataloader = DataLoader(dataset,
                            batch_size=20,
                            shuffle=True,
                            pin_memory=True)

    # Fixed pair of images (not used further in this snippet)
    fixed_dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
    fixed_batch = next(iter(fixed_dataloader))
    fixed_batch = fixed_batch[0].float().cuda()

    scattering = Scattering(J=2, shape=(28, 28))
    scattering.cuda()

    # Inspect the scattering output for a single batch, then stop.
    # (Variable is a no-op in modern PyTorch, so tensors are used directly.)
    for current_batch in dataloader:
        batch_images = current_batch[0].float().cuda()
        batch_scattering = scattering(batch_images).squeeze(1)

        print(batch_scattering.shape)
        break
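
For reference, the printed shape can be predicted from the scattering parameters: with J=2 and kymatio's default of L=8 orientations, a 2D scattering transform produces 1 + J*L + J*(J-1)*L^2/2 = 81 channels, and the 28x28 input is downsampled by 2^J to 7x7, so the loop above should print torch.Size([20, 81, 7, 7]) for the batch size of 20 used here. A minimal bookkeeping sketch (the helper name is illustrative, not part of the original code):

def scattering2d_num_channels(J, L=8):
    # zeroth-, first- and second-order coefficient counts for a 2D scattering transform
    return 1 + J * L + (L ** 2) * J * (J - 1) // 2

J, L, size, batch_size = 2, 8, 28, 20
print((batch_size, scattering2d_num_channels(J, L), size // 2 ** J, size // 2 ** J))
# -> (20, 81, 7, 7)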

Example #2

# Regularized inverse scattering on MNIST. The imports below are not part of the
# original excerpt and assume a recent kymatio release; the Generator network is
# defined elsewhere in the source file (an illustrative sketch is given after this example).
import argparse
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from kymatio.caching import get_cache_dir
from kymatio.datasets import get_dataset_dir
from kymatio.torch import Scattering2D as Scattering


def main():
    parser = argparse.ArgumentParser(
        description='Regularized inverse scattering')
    parser.add_argument('--num_epochs',
                        default=2,
                        type=int,
                        help='Number of epochs to train')
    parser.add_argument('--load_model',
                        action='store_true',
                        help='Load a trained model?')
    parser.add_argument('--dir_save_images',
                        default='interpolation_images',
                        help='Dir to save the sequence of images')
    args = parser.parse_args()

    num_epochs = args.num_epochs
    load_model = args.load_model
    dir_save_images = args.dir_save_images

    dir_to_save = get_cache_dir('reg_inverse_example')

    transforms_to_apply = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))  # MNIST is single-channel; pixel values end up in [-1, 1]
    ])

    mnist_dir = get_dataset_dir("MNIST", create=True)
    dataset = datasets.MNIST(mnist_dir,
                             train=True,
                             download=True,
                             transform=transforms_to_apply)
    dataloader = DataLoader(dataset,
                            batch_size=128,
                            shuffle=True,
                            pin_memory=True)

    fixed_dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
    fixed_batch = next(iter(fixed_dataloader))
    fixed_batch = fixed_batch[0].float().cuda()

    scattering = Scattering(J=2, shape=(28, 28))  # the original used the older Scattering(M=28, N=28, J=2) signature
    scattering.cuda()

    scattering_fixed_batch = scattering(fixed_batch).squeeze(1)
    num_input_channels = scattering_fixed_batch.shape[1]
    num_hidden_channels = num_input_channels

    generator = Generator(num_input_channels, num_hidden_channels)
    generator.cuda()
    generator.train()

    # Either train the network or load a trained model
    ##################################################
    if load_model:
        filename_model = os.path.join(dir_to_save, 'model.pth')
        generator.load_state_dict(torch.load(filename_model))
    else:
        criterion = torch.nn.L1Loss()
        optimizer = optim.Adam(generator.parameters())

        for idx_epoch in range(num_epochs):
            print('Training epoch {}'.format(idx_epoch))
            for current_batch in dataloader:
                generator.zero_grad()
                batch_images = current_batch[0].float().cuda()
                batch_scattering = scattering(batch_images).squeeze(1)
                batch_inverse_scattering = generator(batch_scattering)
                loss = criterion(batch_inverse_scattering, batch_images)
                loss.backward()
                optimizer.step()

        print('Saving results in {}'.format(dir_to_save))

        torch.save(generator.state_dict(),
                   os.path.join(dir_to_save, 'model.pth'))

    generator.eval()

    # We create the batch containing the linear interpolation points in the scattering space
    ########################################################################################
    z0 = scattering_fixed_batch.cpu().numpy()[[0]]
    z1 = scattering_fixed_batch.cpu().numpy()[[1]]
    batch_z = np.copy(z0)
    num_samples = 32
    interval = np.linspace(0, 1, num_samples)
    for t in interval:
        if t > 0:
            zt = (1 - t) * z0 + t * z1
            batch_z = np.vstack((batch_z, zt))
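    # A vectorized alternative to the loop above (a sketch, not from the original code):
    #   ts = np.linspace(0, 1, num_samples)[:, None, None, None]
    #   batch_z = (1 - ts) * z0 + ts * z1
    # The (num_samples, 1, 1, 1) weights broadcast against the (1, C, H, W) endpoints.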

    z = torch.from_numpy(batch_z).float().cuda()
    path = generator(z).detach().cpu().numpy().squeeze(1)
    path = (path + 1) / 2  # The pixels are now in [0, 1]

    # We show and store the nonlinear interpolation in the image space
    ##################################################################
    dir_path = os.path.join(dir_to_save, dir_save_images)

    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

    for idx_image in range(num_samples):
        current_image = np.uint8(path[idx_image] * 255.0)
        filename = os.path.join(dir_path, '{}.png'.format(idx_image))
        Image.fromarray(current_image).save(filename)
        plt.imshow(current_image, cmap='gray')
        plt.axis('off')
        plt.pause(0.1)
        plt.draw()
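
The Generator used above is not defined in these excerpts. The sketch below shows one plausible decoder, assuming the scattering coefficients form a (num_input_channels, 7, 7) map that must be upsampled back to a 28x28 image in [-1, 1]; the layer choices (3x3 convolutions, batch norm, bilinear upsampling, final Tanh) are illustrative and not taken from the original source.

import torch.nn as nn

class Generator(nn.Module):
    # Hypothetical decoder: maps (B, C_in, 7, 7) scattering coefficients
    # to (B, 1, 28, 28) images in [-1, 1]. Layer choices are illustrative.
    def __init__(self, num_input_channels, num_hidden_channels):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(num_input_channels, num_hidden_channels, 3, padding=1),
            nn.BatchNorm2d(num_hidden_channels),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),  # 7 -> 14
            nn.Conv2d(num_hidden_channels, num_hidden_channels, 3, padding=1),
            nn.BatchNorm2d(num_hidden_channels),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),  # 14 -> 28
            nn.Conv2d(num_hidden_channels, 1, 3, padding=1),
            nn.Tanh(),  # match the [-1, 1] normalization of the training images
        )

    def forward(self, x):
        return self.main(x)
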
Example #3

    transforms_to_apply = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))  # MNIST is single-channel; pixel values end up in [-1, 1]
    ])

    mnist_dir = get_dataset_dir("MNIST", create=True)
    dataset = datasets.MNIST(mnist_dir, train=True, download=True, transform=transforms_to_apply)
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True, pin_memory=True)

    fixed_dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
    fixed_batch = next(iter(fixed_dataloader))
    fixed_batch = fixed_batch[0].float().cuda()

    scattering = Scattering(J=2, shape=(28, 28))
    scattering.cuda()

    scattering_fixed_batch = scattering(fixed_batch).squeeze(1)
    num_input_channels = scattering_fixed_batch.shape[1]
    num_hidden_channels = num_input_channels

    generator = Generator(num_input_channels, num_hidden_channels)
    generator.cuda()
    generator.train()

    # Either train the network or load a trained model
    ##################################################
    if load_model:
        filename_model = os.path.join(dir_to_save, 'model.pth')
        generator.load_state_dict(torch.load(filename_model))
    else: