def str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    Without a converter argparse stores the raw string, and every non-empty
    string is truthy, so ``--load_model False`` would silently enable loading.

    Args:
        value: The raw argument string; a ``bool`` is returned unchanged.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'Boolean value expected, got {!r}'.format(value))


def interpolate_linearly(z0, z1, num_samples):
    """Stack ``num_samples`` evenly spaced points on the segment from z0 to z1.

    Point ``i`` is ``(1 - t_i) * z0 + t_i * z1`` with ``t_i`` taken from
    ``np.linspace(0, 1, num_samples)``, so the first point equals ``z0`` and
    the last equals ``z1``. The points are stacked with ``np.vstack``; when
    z0 and z1 are single-sample slices (``array[[i]]``, as in ``main``) this
    yields one row per point.

    Args:
        z0: Start point (NumPy array).
        z1: End point, broadcast-compatible with ``z0``.
        num_samples: Number of points including both end points; must be >= 1.

    Returns:
        np.ndarray with the interpolation points stacked along axis 0.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError(
            'num_samples must be >= 1, got {}'.format(num_samples))
    # Build every point first and stack once, instead of re-allocating the
    # whole batch with np.vstack on each loop iteration.
    points = [(1 - t) * z0 + t * z1 for t in np.linspace(0, 1, num_samples)]
    return np.vstack(points)


def main():
    """Train (or load) a generator that inverts the 2-D scattering transform
    on MNIST.

    It then saves and displays the images the generator produces along a
    linear interpolation between the scattering coefficients of two random
    training digits.

    Command-line options:
        --num_epochs: number of training epochs (int, default 2).
        --load_model: if true, load ``model.pth`` from the directory returned
            by ``get_cache_dir('reg_inverse_example')`` instead of training
            (parsed with ``str2bool``, default False).
        --dir_save_images: sub-directory of that cache directory where the
            interpolated images are written as ``<index>.png``.

    All tensors and modules are moved to the GPU with ``.cuda()``, so a CUDA
    device is required.
    """
    parser = argparse.ArgumentParser(
        description='Regularized inverse scattering')
    # type=int: otherwise a value given on the command line stays a string
    # and range(num_epochs) fails.
    parser.add_argument('--num_epochs', default=2, type=int,
                        help='Number of epochs to train')
    # str2bool: otherwise "--load_model False" is a truthy string.
    parser.add_argument('--load_model', default=False, type=str2bool,
                        help='Load a trained model?')
    parser.add_argument('--dir_save_images', default='interpolation_images',
                        help='Dir to save the sequence of images')
    args = parser.parse_args()

    num_epochs = args.num_epochs
    load_model = args.load_model
    dir_save_images = args.dir_save_images

    dir_to_save = get_cache_dir('reg_inverse_example')

    transforms_to_apply = transforms.Compose([
        transforms.ToTensor(),
        # MNIST is single-channel, so mean/std need exactly one entry; recent
        # torchvision versions reject a 3-entry mean on a 1x28x28 tensor.
        # Pixel values end up in [-1, 1].
        transforms.Normalize((0.5,), (0.5,)),
    ])

    mnist_dir = get_dataset_dir("MNIST", create=True)
    dataset = datasets.MNIST(mnist_dir, train=True, download=True,
                             transform=transforms_to_apply)
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True,
                            pin_memory=True)

    # Two random training images: the end points of the interpolation.
    fixed_dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
    fixed_batch = next(iter(fixed_dataloader))[0].float().cuda()

    scattering = Scattering(M=28, N=28, J=2)
    scattering.cuda()

    scattering_fixed_batch = scattering(fixed_batch).squeeze(1)
    num_input_channels = scattering_fixed_batch.shape[1]
    num_hidden_channels = num_input_channels

    generator = Generator(num_input_channels, num_hidden_channels)
    generator.cuda()
    generator.train()

    # Either train the network or load a trained model
    ##################################################
    filename_model = os.path.join(dir_to_save, 'model.pth')
    if load_model:
        generator.load_state_dict(torch.load(filename_model))
    else:
        criterion = torch.nn.L1Loss()
        optimizer = optim.Adam(generator.parameters())

        for idx_epoch in range(num_epochs):
            print('Training epoch {}'.format(idx_epoch))
            for current_batch in dataloader:
                optimizer.zero_grad()
                batch_images = current_batch[0].float().cuda()
                batch_scattering = scattering(batch_images).squeeze(1)
                batch_inverse_scattering = generator(batch_scattering)
                loss = criterion(batch_inverse_scattering, batch_images)
                loss.backward()
                optimizer.step()

        print('Saving results in {}'.format(dir_to_save))
        torch.save(generator.state_dict(), filename_model)

    generator.eval()

    # Batch of linear interpolation points in the scattering space
    ##############################################################
    scattering_np = scattering_fixed_batch.cpu().numpy()
    num_samples = 32
    batch_z = interpolate_linearly(scattering_np[[0]], scattering_np[[1]],
                                   num_samples)
    z = torch.from_numpy(batch_z).float().cuda()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        path = generator(z).cpu().numpy().squeeze(1)
    # Map [-1, 1] to [0, 1]; clip so the uint8 conversion below stays in
    # range even if the generator output leaves [-1, 1].
    path = np.clip((path + 1) / 2, 0, 1)

    # Store and show the nonlinear interpolation in the image space
    ###############################################################
    dir_path = os.path.join(dir_to_save, dir_save_images)
    os.makedirs(dir_path, exist_ok=True)

    for idx_image in range(num_samples):
        current_image = np.uint8(path[idx_image] * 255.0)
        filename = os.path.join(dir_path, '{}.png'.format(idx_image))
        Image.fromarray(current_image).save(filename)
        plt.imshow(current_image, cmap='gray')
        plt.axis('off')
        plt.pause(0.1)
        plt.draw()
# Compute the order-0/1/2 solid-harmonic scattering coefficients for QM7,
# cache each order as a .npy file, then hand the concatenated coefficients and
# the QM7 energies to evaluate_linear_regression.
order_0, order_1, order_2 = compute_qm7_solid_harmonic_scattering_coefficients(
    M=M, N=N, O=O, J=J, L=L, integral_powers=integral_powers,
    sigma=sigma, batch_size=8)

# Leading dimension of the order-0 result (presumably one entry per molecule).
n_molecules = order_0.size(0)

# Flatten every order into an (n_molecules, n_coefficients) NumPy matrix.
np_order_0, np_order_1, np_order_2 = (
    order.numpy().reshape((n_molecules, -1))
    for order in (order_0, order_1, order_2))

# The cache file name encodes every hyper-parameter of the transform.
basename = 'qm7_L_{}_J_{}_sigma_{}_MNO_{}_powers_{}.npy'.format(
    L, J, sigma, (M, N, O), integral_powers)
cachedir = get_cache_dir("qm7/experiments")

for order_prefix, order_coefficients in (('order_0_', np_order_0),
                                         ('order_1_', np_order_1),
                                         ('order_2_', np_order_2)):
    np.save(os.path.join(cachedir, order_prefix + basename),
            order_coefficients)

# All matrices are 2-D, so hstack joins them along the coefficient axis.
scattering_coef = np.hstack((np_order_0, np_order_1, np_order_2))
target = get_qm7_energies()

print('order 1 : {} coef, order 2 : {} coefs'.format(
    np_order_1.shape[1], np_order_2.shape[1]))

evaluate_linear_regression(scattering_coef, target)
return self.main(input_tensor) if __name__ == '__main__': parser = argparse.ArgumentParser(description='Regularized inverse scattering') parser.add_argument('--num_epochs', default=2, help='Number of epochs to train') parser.add_argument('--load_model', default=False, help='Load a trained model?') parser.add_argument('--dir_save_images', default='interpolation_images', help='Dir to save the sequence of images') args = parser.parse_args() num_epochs = args.num_epochs load_model = args.load_model dir_save_images = args.dir_save_images dir_to_save = get_cache_dir('reg_inverse_example') transforms_to_apply = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Pixel values should be in [-1,1] ]) mnist_dir = get_dataset_dir("MNIST", create=True) dataset = datasets.MNIST(mnist_dir, train=True, download=True, transform=transforms_to_apply) dataloader = DataLoader(dataset, batch_size=128, shuffle=True, pin_memory=True) fixed_dataloader = DataLoader(dataset, batch_size=2, shuffle=True) fixed_batch = next(iter(fixed_dataloader)) fixed_batch = fixed_batch[0].float().cuda() scattering = Scattering(J=2, shape=(28, 28))