def main():
    parser = argparse.ArgumentParser(
        description='Regularized inverse scattering')
    parser.add_argument('--num_epochs',
                        default=2,
                        type=int,
                        help='Number of epochs to train')
    parser.add_argument('--load_model',
                        action='store_true',
                        help='Load a trained model?')
    parser.add_argument('--dir_save_images',
                        default='interpolation_images',
                        help='Dir to save the sequence of images')
    args = parser.parse_args()

    num_epochs = args.num_epochs
    load_model = args.load_model
    dir_save_images = args.dir_save_images

    dir_to_save = get_cache_dir('reg_inverse_example')

    transforms_to_apply = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            (0.5,),
            (0.5,))  # MNIST is single-channel; pixel values should be in [-1, 1]
    ])

    mnist_dir = get_dataset_dir("MNIST", create=True)
    dataset = datasets.MNIST(mnist_dir,
                             train=True,
                             download=True,
                             transform=transforms_to_apply)
    dataloader = DataLoader(dataset,
                            batch_size=128,
                            shuffle=True,
                            pin_memory=True)

    fixed_dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
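    # Two fixed images whose scattering coefficients will serve as the endpoints
    # of the interpolation in scattering space.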
    fixed_batch = next(iter(fixed_dataloader))
    fixed_batch = fixed_batch[0].float().cuda()

    scattering = Scattering(M=28, N=28, J=2)
    scattering.cuda()

    # Run the scattering once on the fixed batch; the number of scattering
    # channels sets the generator's input and hidden widths.
    scattering_fixed_batch = scattering(fixed_batch).squeeze(1)
    num_input_channels = scattering_fixed_batch.shape[1]
    num_hidden_channels = num_input_channels

    generator = Generator(num_input_channels, num_hidden_channels)
    generator.cuda()
    generator.train()

    # Either train the network or load a trained model
    ##################################################
    if load_model:
        filename_model = os.path.join(dir_to_save, 'model.pth')
        generator.load_state_dict(torch.load(filename_model))
    else:
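        # Train the generator to invert the scattering: minimize the L1 distance
        # between the reconstructed images and the originals.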
        criterion = torch.nn.L1Loss()
        optimizer = optim.Adam(generator.parameters())

        for idx_epoch in range(num_epochs):
            print('Training epoch {}'.format(idx_epoch))
            for _, current_batch in enumerate(dataloader):
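                # One step: scatter the batch, reconstruct it with the generator,
                # and backpropagate the L1 reconstruction error.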
                generator.zero_grad()
                batch_images = current_batch[0].float().cuda()
                batch_scattering = scattering(batch_images).squeeze(1)
                batch_inverse_scattering = generator(batch_scattering)
                loss = criterion(batch_inverse_scattering, batch_images)
                loss.backward()
                optimizer.step()

        print('Saving results in {}'.format(dir_to_save))

        torch.save(generator.state_dict(),
                   os.path.join(dir_to_save, 'model.pth'))

    generator.eval()

    # We create the batch containing the linear interpolation points in the scattering space
    ########################################################################################
    z0 = scattering_fixed_batch.cpu().numpy()[[0]]
    z1 = scattering_fixed_batch.cpu().numpy()[[1]]
    batch_z = np.copy(z0)
    num_samples = 32
    interval = np.linspace(0, 1, num_samples)
    for t in interval:
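        # batch_z already contains z0, so only the points with t > 0 are appended.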
        if t > 0:
            zt = (1 - t) * z0 + t * z1
            batch_z = np.vstack((batch_z, zt))

    z = torch.from_numpy(batch_z).float().cuda()
    path = generator(z).data.cpu().numpy().squeeze(1)
    path = (path + 1) / 2  # The pixels are now in [0, 1]

    # We show and store the nonlinear interpolation in the image space
    ##################################################################
    dir_path = os.path.join(dir_to_save, dir_save_images)

    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

    for idx_image in range(num_samples):
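        # Convert each reconstruction to an 8-bit grayscale image, save it to disk
        # and display it.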
        current_image = np.uint8(path[idx_image] * 255.0)
        filename = os.path.join(dir_path, '{}.png'.format(idx_image))
        Image.fromarray(current_image).save(filename)
        plt.imshow(current_image, cmap='gray')
        plt.axis('off')
        plt.pause(0.1)
        plt.draw()
Example #2
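# Compute order 0, 1 and 2 solid harmonic scattering coefficients for the QM7
# molecules, in batches of 8.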
order_0, order_1, order_2 = compute_qm7_solid_harmonic_scattering_coefficients(
    M=M,
    N=N,
    O=O,
    J=J,
    L=L,
    integral_powers=integral_powers,
    sigma=sigma,
    batch_size=8)

n_molecules = order_0.size(0)

np_order_0 = order_0.numpy().reshape((n_molecules, -1))
np_order_1 = order_1.numpy().reshape((n_molecules, -1))
np_order_2 = order_2.numpy().reshape((n_molecules, -1))

basename = 'qm7_L_{}_J_{}_sigma_{}_MNO_{}_powers_{}.npy'.format(
    L, J, sigma, (M, N, O), integral_powers)
cachedir = get_cache_dir("qm7/experiments")
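# Save the flattened coefficients of each order to the cache directory.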
np.save(os.path.join(cachedir, 'order_0_' + basename), np_order_0)
np.save(os.path.join(cachedir, 'order_1_' + basename), np_order_1)
np.save(os.path.join(cachedir, 'order_2_' + basename), np_order_2)

scattering_coef = np.concatenate([np_order_0, np_order_1, np_order_2], axis=1)
target = get_qm7_energies()

print('order 1: {} coefs, order 2: {} coefs'.format(np_order_1.shape[1],
                                                    np_order_2.shape[1]))
evaluate_linear_regression(scattering_coef, target)
Example #3
        return self.main(input_tensor)


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Regularized inverse scattering')
    parser.add_argument('--num_epochs', default=2, type=int, help='Number of epochs to train')
    parser.add_argument('--load_model', action='store_true', help='Load a trained model?')
    parser.add_argument('--dir_save_images', default='interpolation_images', help='Dir to save the sequence of images')
    args = parser.parse_args()

    num_epochs = args.num_epochs
    load_model = args.load_model
    dir_save_images = args.dir_save_images

    dir_to_save = get_cache_dir('reg_inverse_example')

    transforms_to_apply = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))  # MNIST is single-channel; pixel values should be in [-1, 1]
    ])

    mnist_dir = get_dataset_dir("MNIST", create=True)
    dataset = datasets.MNIST(mnist_dir, train=True, download=True, transform=transforms_to_apply)
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True, pin_memory=True)

    fixed_dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
    fixed_batch = next(iter(fixed_dataloader))
    fixed_batch = fixed_batch[0].float().cuda()

    scattering = Scattering(J=2, shape=(28, 28))