class learn_kernel_model(torch.nn.Module):
    """A single 1-in/1-out 2-D convolution followed by a ReLU.

    The convolution weight is the kernel being learned. ``torch.nn.Conv2d``'s
    defaults apply: stride 1, no padding, and a learned bias term.

    Args:
        kernel_size: forwarded unchanged to ``torch.nn.Conv2d`` (an int for a
            square kernel, or an (h, w) pair).
    """

    def __init__(self, kernel_size):
        super().__init__()
        # One input and one output channel: the weight tensor *is* the kernel.
        self.conv1 = torch.nn.Conv2d(1, 1, kernel_size)

    def forward(self, x):
        """Convolve ``x`` with the learned kernel and zero out negative responses."""
        response = self.conv1(x)
        return torch.relu(response)


# Create the training dataset.
# Each sample is a random shape x (from tools), its blurred version
# y = fftconvolve(x, kernel, 'valid') -- 'valid' keeps only positions where
# kernel and x fully overlap, so y is smaller than x -- and a noise image n
# with y's shape.
training_size = 100
# Standard deviation of the additive noise. With 0.0 (the value used here)
# every training_n image is all zeros, so the "noisy" plot below actually
# shows the clean blurred image; raise it to inject real noise.
sigma = 0.0
training_x = []
training_y = []
training_n = []
# NOTE(review): presumably a 33x33 Gaussian with std 5 -- confirm in tools.
kernel = tools.get_gaussian_kernel(33, 5)
for _ in range(training_size):
    # Shape generation and noise draws stay interleaved in one loop so the
    # sequence of random draws is unchanged.
    x = tools.get_stably_bounded_shape(-3, 3, -3, 3, 64, 64)
    training_x.append(x)
    y = signal.fftconvolve(x, kernel, mode='valid')
    training_y.append(y)
    n = np.random.randn(*y.shape) * sigma
    training_n.append(n)
training_x = np.array(training_x)
training_y = np.array(training_y)
training_n = np.array(training_n)

# Check the example.
plt.figure()
plt.title('noisy downsample')
plt.imshow(training_y[0] + training_n[0])
Example #2
0
        self.cv1 = torch.nn.Conv2d(1, 64, 8)
        self.fc1 = torch.nn.Linear(64 * 25 * 25, 64)
        self.fc2 = torch.nn.Linear(64, 15)
        x, y = torch.meshgrid(torch.linspace(- 3, 3, 64), torch.linspace(- 3, 3, 64))
        self.basis = torch.stack([x ** i * y ** j for i in range(5) for j in range(5 - i)])
    def forward(self, x):
        """Map an input image to a sigmoid-squashed polynomial surface.

        ``cv1`` followed by ``fc1``/``fc2`` produce one coefficient per basis
        image in ``self.basis``. The output is the sigmoid of the
        coefficient-weighted sum of those basis images, so every value lies
        in (0, 1).

        Args:
            x: input batch with one channel. ``fc1`` expects 64 * 25 * 25
                features after ``cv1`` (an 8x8 kernel), e.g. an
                (N, 1, 32, 32) input -- TODO confirm against callers.

        Returns:
            Tensor of shape (N, 1, H, W), where (H, W) are the spatial dims
            of ``self.basis``.
        """
        x = torch.relu(self.cv1(x))
        # flatten(1) keeps the batch dimension explicit; view(-1, ...) would
        # silently regroup rows if the input had the wrong spatial size.
        x = torch.relu(self.fc1(x.flatten(1)))
        coeffs = self.fc2(x)
        # self.basis is a plain tensor attribute, not a registered buffer, so
        # module.to()/.cuda()/.double() never move it. Match it to the
        # activations so the einsum sees a single device and dtype.
        basis = self.basis.to(device=coeffs.device, dtype=coeffs.dtype)
        # (batch, n_basis) x (n_basis, H, W) -> (batch, H, W)
        surface = torch.einsum('li, ijk -> ljk', coeffs, basis)
        return torch.sigmoid(surface).unsqueeze(1)

def _as_batch_tensor(samples):
    """Stack equally-shaped arrays into a float32 tensor with a channel axis.

    ``N`` arrays of shape ``S`` become a tensor of shape ``(N, 1, *S)``.
    """
    return torch.from_numpy(np.array(samples)).unsqueeze(1).float()


print('generating data')
n_images = 500
sigma = .1  # standard deviation of the additive Gaussian noise
# NOTE(review): presumably a 33x33 Gaussian with std 5 -- confirm in tools.
kernel = tools.get_gaussian_kernel(33, 5)
training_x, training_y, training_n = [], [], []
for _ in range(n_images):
    # Shape generation and noise draws stay interleaved in one loop so the
    # sequence of random draws is unchanged.
    x = tools.get_stably_bounded_shape(- 3, 3, - 3, 3, 64, 64)
    training_x.append(x)
    # 'valid' keeps only full-overlap positions, so y is smaller than x.
    y = signal.fftconvolve(x, kernel, mode='valid')
    training_y.append(y)
    # The noise is stored separately; this block does not add it to y.
    n = np.random.normal(0, sigma, y.shape)
    training_n.append(n)
training_x = _as_batch_tensor(training_x)
training_y = _as_batch_tensor(training_y)
training_n = _as_batch_tensor(training_n)

print('plotting')