class ETCNet(nn.Module):
    """Encoded Template Convolutional Network (ETCNet).

    A frozen, pretrained convolutional VAE encodes a template image and the
    sampled latent code is installed as the weights of the first convolution
    (``conv1``) of a pretrained ResNet-50. The ResNet's ``fc`` head is replaced
    by ``FullyConnectedLayers`` followed by a final linear output layer.
    """

    def __init__(self, output_size=10):
        """Build the frozen VAE, the ResNet-50 backbone and the output layer.

        Args:
            output_size: Number of features produced by the final linear layer.
        """
        super().__init__()
        self.vae = ConvVAE()
        # The VAE is a fixed encoder: freeze its parameters and keep it in
        # eval mode (``train`` below stops ``model.train()`` from undoing this).
        for p in self.vae.parameters():
            p.requires_grad = False
        self.vae.train(False)
        self.resnet_model = resnet50(pretrained=True)
        self.resnet_model.fc = FullyConnectedLayers()
        # NOTE(review): assumes FullyConnectedLayers emits 128 features - confirm.
        self.output = nn.Linear(128, output_size)

    def train(self, mode=True):
        """Set the training mode while keeping the frozen VAE in eval mode.

        ``nn.Module.train`` recurses into every submodule, so without this
        override a call to ``model.train()`` would re-enable training mode in
        the VAE that ``__init__`` deliberately put in eval mode.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, like ``nn.Module.train``.
        """
        super().train(mode)
        self.vae.train(False)
        return self

    def forward(self, x, x_object):
        """Run ``x`` through the ResNet whose ``conv1`` is built from ``x_object``.

        Args:
            x: Input images for the ResNet-50 backbone.
            x_object: Template batch fed to the VAE encoder. Only the encoding
                of the first template (index 0) is used for the weights.

        Returns:
            Output of the final linear layer.
        """
        # The VAE is frozen and its result is wrapped in a non-trainable
        # Parameter, so no autograd graph is needed for the encoding.
        with torch.no_grad():
            mu, logvar = self.vae.encoder(x_object)
            latent = self.vae.reparametrize(mu, logvar)
        # Repeat the first latent sample across the 3 input channels:
        # (C, H, W) -> (3, C, H, W) -> (C, 3, H, W) == (out, in, kH, kW).
        # NOTE(review): assumes the latent is (batch, C, H, W) - TODO confirm.
        template_weights = latent[0].repeat(3, 1, 1, 1).permute(1, 0, 2, 3)
        self.resnet_model.conv1.weight = nn.Parameter(
            template_weights, requires_grad=False)
        return self.output(self.resnet_model(x))

    def load_vae(self, path, map_location=None):
        """Load VAE weights from a checkpoint file.

        Args:
            path: Path to a state dict saved with ``torch.save``.
            map_location: Forwarded to ``torch.load``. For example, ``"cpu"``
                loads a GPU checkpoint on a CPU-only machine.

        Note:
            ``torch.load`` unpickles the file; only load trusted checkpoints.
        """
        self.vae.load_state_dict(torch.load(path, map_location=map_location))
resized_template, (1, resized_template.shape[-3], resized_template.shape[-2], resized_template.shape[-1])) decoded, _, _ = model(templates) im1 = transforms.ToPILImage()(templates[0]).convert("RGB") im2 = transforms.ToPILImage()(decoded[0]).convert("RGB") # im1.save('../../test.jpg') # im2.save('../../test_decoded.jpg') Image.fromarray(np.hstack((np.array(im1), np.array(im2)))).show() decoded, _, _ = model2(templates) im2 = transforms.ToPILImage()(decoded[0]).convert("RGB") Image.fromarray(np.hstack((np.array(im1), np.array(im2)))).show() mu, logvar = model.encoder(templates) z = model.reparametrize(mu, logvar) encoded = z[0].detach().numpy() # encoded = torch.reshape(z, (8, 8, z.shape[-2], z.shape[-1])) # Plot fig, axs = plt.subplots(8, 8, figsize=(10, 10), facecolor='w', edgecolor='k') fig.subplots_adjust(hspace=.1, wspace=.0005) axs = axs.ravel() plt.axis('off') for i in range(64): # axs[i].contourf(encoded[i], 5, cmap=plt.cm.Oranges)
class GMNETCNet(nn.Module):
    """Template-matching network with VAE-encoded first-layer weights.

    A frozen, pretrained convolutional VAE encodes a resized template; the
    sampled latent code becomes the first-convolution weights of two truncated
    ResNet-50 backbones (one for the query image, one for the template crop).
    The template features are average-pooled to one vector, L2-normalised over
    channels, tiled over the spatial grid of the (batch-normalised) image
    features and concatenated with them. ``MatchingNet`` and a final 3x3
    convolution then produce a single-channel output map.
    """

    def __init__(self, output_matching_size=None):
        """Build the frozen VAE, both backbones and the matching head.

        Args:
            output_matching_size: Forwarded to ``MatchingNet``.
        """
        super().__init__()
        self.vae = ConvVAE()
        # The VAE is a fixed encoder: freeze its parameters and keep it in
        # eval mode (``train`` below stops ``model.train()`` from undoing this).
        for p in self.vae.parameters():
            p.requires_grad = False
        self.vae.train(False)
        self.cut_resnet = self._build_backbone()
        self.cut_resnet_template = self._build_backbone()
        self.adapt_pool = nn.AdaptiveAvgPool2d(1)
        self.batch_norm = nn.BatchNorm2d(512)
        # 1024 = 512 image channels + 512 tiled template channels.
        self.matching_model = MatchingNet(1024, output_matching_size)
        # padding=1 keeps the spatial size ("same") for a 3x3 kernel, stride 1.
        # NOTE(review): assumes MatchingNet outputs 256 channels - confirm.
        self.output_conv = nn.Conv2d(256, 1, 3, stride=1, padding=1)

    @staticmethod
    def _build_backbone():
        """Return the first six children of a pretrained ResNet-50 with each Bottleneck's conv2 wrapped in an Adapter."""
        resnet_model = resnet50(pretrained=True)
        backbone = nn.Sequential(*list(resnet_model.children())[:6])
        # Snapshot the module list so the tree is not mutated mid-iteration.
        for module in list(backbone.modules()):
            if isinstance(module, Bottleneck) and isinstance(
                    module.conv2, nn.Conv2d):
                module.conv2 = Adapter(module.conv2)
        return backbone

    @staticmethod
    def _set_first_conv_weight(backbone, weight):
        """Install ``weight`` as a frozen Parameter on the first nn.Conv2d of ``backbone`` (no-op if there is none)."""
        first_conv = next(
            (m for m in backbone.modules() if isinstance(m, nn.Conv2d)), None)
        if first_conv is not None:
            first_conv.weight = nn.Parameter(weight, requires_grad=False)

    def train(self, mode=True):
        """Set the training mode while keeping the frozen VAE in eval mode.

        ``nn.Module.train`` recurses into every submodule, so without this
        override ``model.train()`` would re-enable training mode in the VAE.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, like ``nn.Module.train``.
        """
        super().train(mode)
        self.vae.train(False)
        return self

    def forward(self, x, x_object, resized_template):
        """Match the template against the image and return a 1-channel map.

        Args:
            x: Query images for ``cut_resnet``.
            x_object: Template crops for ``cut_resnet_template``.
            resized_template: Template batch for the VAE encoder; only the
                encoding of the first sample (index 0) is used as weights.

        Returns:
            Output of ``output_conv`` (a single channel).
        """
        # Frozen VAE + non-trainable Parameter: no autograd graph needed.
        with torch.no_grad():
            mu, logvar = self.vae.encoder(resized_template)
            latent = self.vae.reparametrize(mu, logvar)
        # (C, H, W) -> (3, C, H, W) -> (C, 3, H, W) == (out, in, kH, kW).
        # NOTE(review): assumes the latent is (batch, C, H, W) - TODO confirm.
        template_weights = latent[0].repeat(3, 1, 1, 1).permute(1, 0, 2, 3)
        self._set_first_conv_weight(self.cut_resnet, template_weights)
        self._set_first_conv_weight(self.cut_resnet_template, template_weights)

        x_object = self.adapt_pool(self.cut_resnet_template(x_object))
        # L2-normalise over channels (dim=1). The previous dim=-1 normalised
        # each 1x1 pooled value by itself, collapsing it to 0/+-1.
        x_object = F.normalize(x_object, p=2, dim=1)
        x = self.batch_norm(self.cut_resnet(x))
        # Tile the pooled template descriptor over the image feature grid.
        x_object = x_object.expand(-1, -1, x.shape[-2], x.shape[-1])
        outputs = torch.cat((x, x_object), 1)
        outputs = self.matching_model(outputs)
        return self.output_conv(outputs)

    def load_vae(self, path, map_location=None):
        """Load VAE weights from a checkpoint file.

        Args:
            path: Path to a state dict saved with ``torch.save``.
            map_location: Forwarded to ``torch.load``. For example, ``"cpu"``
                loads a GPU checkpoint on a CPU-only machine.

        Note:
            ``torch.load`` unpickles the file; only load trusted checkpoints.
        """
        self.vae.load_state_dict(torch.load(path, map_location=map_location))

    @staticmethod
    def _corner_free_footprint(size):
        """Return a size x size footprint of ones with its four corner pixels zeroed."""
        footprint = np.ones((size, size))
        footprint[[0, 0, -1, -1], [0, -1, 0, -1]] = 0
        return footprint

    @staticmethod
    def _local_maxima(matrix, footprint):
        """Return (x, y) centres of connected local-maximum regions of ``matrix`` under ``footprint``."""
        data_max = ndimage.maximum_filter(matrix, footprint=footprint)
        data_min = ndimage.minimum_filter(matrix, footprint=footprint)
        maxima = matrix == data_max
        # Reject flat areas: the neighbourhood spread must exceed a third of
        # the global maximum (written as a negated '>' to keep NaN handling).
        maxima[~((data_max - data_min) > (matrix.max() / 3))] = False
        labeled, _ = ndimage.label(maxima)
        return [((dx.start + dx.stop - 1) // 2, (dy.start + dy.stop - 1) // 2)
                for dy, dx in ndimage.find_objects(labeled)]

    @staticmethod
    def get_count(matrix, plot=False):
        """Count peaks in a 2-D response map.

        Local maxima are searched at two scales: a 3x3 and a 7x7 neighbourhood,
        each with its four corner pixels removed. A pixel is a peak candidate
        when it equals its neighbourhood maximum and the neighbourhood's
        max-min spread exceeds one third of ``matrix.max()``. Connected
        (4-neighbour) candidates are merged into one peak at the centre of
        their bounding box. Peaks from both scales are de-duplicated by exact
        position.

        Args:
            matrix: 2-D array (presumably the model output converted to NumPy -
                confirm at the call site).
            plot: If true, show ``matrix`` with the detected peaks marked.

        Returns:
            Number of distinct peak positions.
        """
        points = set()
        for size in (3, 7):
            footprint = GMNETCNet._corner_free_footprint(size)
            points.update(GMNETCNet._local_maxima(matrix, footprint))
        # reshape keeps a (0, 2) shape when nothing is found, so plotting an
        # empty result no longer raises IndexError.
        coordinates = np.array(sorted(points), dtype=int).reshape(-1, 2)
        if plot:
            plt.figure()
            plt.imshow(matrix, cmap="gray")
            plt.axis('off')
            plt.autoscale(False)
            plt.plot(coordinates[:, 0], coordinates[:, 1], 'rx')
            plt.show()
        return len(coordinates)