コード例 #1
0
class ETCNet(nn.Module):
    """
    Encoded Template Convolutional Network (ETCNet) is a neural network architecture based on the usage of a
    pretrained VAE which encodes a given template and uses its encoding as weights for the first convolution of a
    ResNet.

    The VAE is frozen: its parameters never require gradients and it is kept in
    eval mode even when the whole network is switched to training mode.

    Args:
        output_size: number of outputs of the final linear layer.
    """
    def __init__(self, output_size=10):
        super().__init__()

        self.vae = ConvVAE()
        for p in self.vae.parameters():
            p.requires_grad = False
        self.vae.train(False)

        self.resnet_model = resnet50(pretrained=True)
        # Replace the classifier head. NOTE(review): self.output expects 128
        # input features, so FullyConnectedLayers presumably emits 128 — confirm.
        self.resnet_model.fc = FullyConnectedLayers()
        self.output = nn.Linear(128, output_size)

    def train(self, mode=True):
        """Set training mode for the network while keeping the frozen VAE in eval mode.

        nn.Module.train() recurses into every child, which would otherwise undo
        the ``self.vae.train(False)`` done in ``__init__``.
        """
        super().train(mode)
        self.vae.train(False)
        return self

    def forward(self, x, x_object):
        """Run the ResNet on ``x`` after loading template-derived conv1 weights.

        Args:
            x: input batch for the ResNet.
            x_object: template batch fed to the VAE encoder; only the encoding
                of its first element (index 0) is used.

        Returns:
            Output of the final linear layer.
        """
        # The VAE is frozen and the Parameter below detaches its input anyway,
        # so there is no reason to record an autograd graph here.
        with torch.no_grad():
            mu, logvar = self.vae.encoder(x_object)
            latent = self.vae.reparametrize(mu, logvar)
        # NOTE(review): if latent[0] is (C, kH, kW) — TODO confirm — repeat
        # gives (3, C, kH, kW) and the permute gives (C, 3, kH, kW): C filters
        # over a 3-channel input.
        template_weights = latent[0].repeat(3, 1, 1, 1).permute(1, 0, 2, 3)
        self.resnet_model.conv1.weight = torch.nn.Parameter(
            template_weights, requires_grad=False)
        x = self.resnet_model(x)
        return self.output(x)

    def load_vae(self, path, map_location=None):
        """Load VAE weights from a ``torch.save``-d state dict at ``path``.

        Args:
            path: checkpoint file path.
            map_location: forwarded to ``torch.load`` (e.g. ``"cpu"`` to load a
                GPU checkpoint on a CPU-only machine); ``None`` keeps the
                default behaviour.
        """
        self.vae.load_state_dict(torch.load(path, map_location=map_location))
コード例 #2
0
        resized_template,
        (1, resized_template.shape[-3], resized_template.shape[-2],
         resized_template.shape[-1]))
    decoded, _, _ = model(templates)

    im1 = transforms.ToPILImage()(templates[0]).convert("RGB")
    im2 = transforms.ToPILImage()(decoded[0]).convert("RGB")
    # im1.save('../../test.jpg')
    # im2.save('../../test_decoded.jpg')
    Image.fromarray(np.hstack((np.array(im1), np.array(im2)))).show()
    decoded, _, _ = model2(templates)
    im2 = transforms.ToPILImage()(decoded[0]).convert("RGB")
    Image.fromarray(np.hstack((np.array(im1), np.array(im2)))).show()

    mu, logvar = model.encoder(templates)
    z = model.reparametrize(mu, logvar)
    encoded = z[0].detach().numpy()
    # encoded = torch.reshape(z, (8, 8, z.shape[-2], z.shape[-1]))

    # Plot
    fig, axs = plt.subplots(8,
                            8,
                            figsize=(10, 10),
                            facecolor='w',
                            edgecolor='k')
    fig.subplots_adjust(hspace=.1, wspace=.0005)

    axs = axs.ravel()
    plt.axis('off')
    for i in range(64):
        # axs[i].contourf(encoded[i], 5, cmap=plt.cm.Oranges)
コード例 #3
0
class GMNETCNet(nn.Module):
    """
    Matching-network variant of ETCNet.

    A frozen, pretrained ConvVAE encodes ``resized_template``; the first sampled
    latent is broadcast over three input channels and installed as the first
    convolution's weights of two truncated ResNet-50 trunks (one for the query
    image, one for the template crop ``x_object``). The pooled, channel-wise
    L2-normalised template embedding is tiled over the query feature map,
    concatenated with it along channels, passed through ``MatchingNet`` and a
    final 3x3 convolution producing a single-channel map. ``get_count`` turns
    such a map into an instance count by local-maximum detection.

    Args:
        output_matching_size: forwarded to ``MatchingNet``.
    """
    def __init__(self, output_matching_size=None):
        super().__init__()

        # Frozen VAE: no gradients, and kept in eval mode (see train()).
        self.vae = ConvVAE()
        for p in self.vae.parameters():
            p.requires_grad = False
        self.vae.train(False)

        # Two independent trunks (no weight sharing): query image / template.
        self.cut_resnet = self._build_adapted_trunk()
        self.cut_resnet_template = self._build_adapted_trunk()

        self.adapt_pool = nn.AdaptiveAvgPool2d(1)
        self.batch_norm = nn.BatchNorm2d(512)

        # 1024 = query features (512) concatenated with tiled template features.
        self.matching_model = MatchingNet(1024, output_matching_size)
        # kernel 3, stride 1, padding 1 preserves the spatial size ("same").
        self.output_conv = nn.Conv2d(256, 1, 3, stride=1, padding=1)

    @staticmethod
    def _build_adapted_trunk():
        """Return the first six children of a pretrained ResNet-50 with Bottleneck conv2 wrapped in Adapter."""
        resnet_model = resnet50(pretrained=True)
        trunk = nn.Sequential(*list(resnet_model.children())[:6])
        for module in trunk.modules():
            if isinstance(module, Bottleneck) and isinstance(
                    module.conv2, nn.Conv2d):
                module.conv2 = Adapter(module.conv2)
        return trunk

    def train(self, mode=True):
        """Set training mode for the network while keeping the frozen VAE in eval mode.

        nn.Module.train() recurses into every child, which would otherwise undo
        the ``self.vae.train(False)`` done in ``__init__``.
        """
        super().train(mode)
        self.vae.train(False)
        return self

    def _template_stem_weights(self, resized_template):
        """Encode the first template with the frozen VAE and reshape it into conv weights."""
        # Frozen VAE; the Parameter built from this result is detached anyway.
        with torch.no_grad():
            mu, logvar = self.vae.encoder(resized_template)
            latent = self.vae.reparametrize(mu, logvar)
        # NOTE(review): if latent[0] is (C, kH, kW) — TODO confirm — this yields
        # (C, 3, kH, kW): C filters over a 3-channel input.
        return latent[0].repeat(3, 1, 1, 1).permute(1, 0, 2, 3)

    @staticmethod
    def _set_stem_weights(trunk, weights):
        """Replace the weights of the first nn.Conv2d found in ``trunk`` with a frozen Parameter."""
        for module in trunk.modules():
            if isinstance(module, nn.Conv2d):
                module.weight = torch.nn.Parameter(weights,
                                                   requires_grad=False)
                break

    def _encode_template(self, x_object):
        """Return the pooled template embedding, L2-normalised over channels."""
        features = self.adapt_pool(self.cut_resnet_template(x_object))
        # After AdaptiveAvgPool2d(1) the trailing dims have size 1, so the
        # normalisation must run over the channel dim (1); normalising over
        # dim=-1 would map every feature to its sign.
        return F.normalize(features, p=2, dim=1)

    def forward(self, x, x_object, resized_template):
        """Compute the single-channel matching map for query ``x``.

        Args:
            x: query image batch.
            x_object: template crop batch fed to the template trunk; its batch
                size must match ``x`` for the channel concatenation.
            resized_template: template batch fed to the VAE encoder; only its
                first element's encoding is used for the stem weights.

        Returns:
            Output of ``self.output_conv`` (one channel).
        """
        stem_weights = self._template_stem_weights(resized_template)
        self._set_stem_weights(self.cut_resnet, stem_weights)
        self._set_stem_weights(self.cut_resnet_template, stem_weights)

        template = self._encode_template(x_object)
        x = self.batch_norm(self.cut_resnet(x))

        # Tile the 1x1 template embedding over the query feature map.
        template = template.repeat(1, 1, x.shape[-2], x.shape[-1])
        outputs = torch.cat((x, template), 1)

        outputs = self.matching_model(outputs)
        return self.output_conv(outputs)

    def load_vae(self, path, map_location=None):
        """Load VAE weights from a ``torch.save``-d state dict at ``path``.

        Args:
            path: checkpoint file path.
            map_location: forwarded to ``torch.load``; ``None`` keeps the
                default behaviour.
        """
        self.vae.load_state_dict(torch.load(path, map_location=map_location))

    @staticmethod
    def _cross_footprint(size):
        """Return a ``size`` x ``size`` ones footprint with its four corner cells zeroed."""
        footprint = np.ones((size, size))
        footprint[0, 0] = footprint[0, -1] = 0
        footprint[-1, 0] = footprint[-1, -1] = 0
        return footprint

    @staticmethod
    def _peak_centers(matrix, footprint):
        """Return (x, y) centers of connected local-maximum regions of ``matrix``.

        A cell is a candidate when it equals the maximum over ``footprint`` and
        the max-min range over that footprint exceeds a third of the global max.
        """
        data_max = ndimage.maximum_filter(matrix, footprint=footprint)
        data_min = ndimage.minimum_filter(matrix, footprint=footprint)
        maxima = (matrix == data_max) & ((data_max - data_min) >
                                         (matrix.max() / 3))
        labeled, _ = ndimage.label(maxima)
        return [((dx.start + dx.stop - 1) // 2, (dy.start + dy.stop - 1) // 2)
                for dy, dx in ndimage.find_objects(labeled)]

    @staticmethod
    def get_count(matrix, plot=False):
        """Count instances in a 2-D response map by local-maximum detection.

        Peaks are detected with a 3x3 and a 7x7 corner-less footprint and the
        union of their (x, y) centers is counted.

        Args:
            matrix: 2-D array-like response map.
            plot: if True, show the map with the detected centers marked.

        Returns:
            Number of distinct detected centers.
        """
        matrix = np.asarray(matrix)
        centers = set()
        for size in (3, 7):
            centers.update(
                GMNETCNet._peak_centers(matrix,
                                        GMNETCNet._cross_footprint(size)))

        # Always (N, 2), so column indexing below works even when N == 0.
        coordinates = np.array(list(centers), dtype=int).reshape(-1, 2)

        if plot:
            plt.figure()
            plt.imshow(matrix, cmap="gray")
            plt.axis('off')
            plt.autoscale(False)
            if len(coordinates):
                plt.plot(coordinates[:, 0], coordinates[:, 1], 'rx')
            plt.show()

        return len(coordinates)