Code Example #1
import torch
import torch.nn as nn
from torchvision.models import resnet50

# ConvVAE and FullyConnectedLayers are project-specific modules assumed to be
# importable from elsewhere in this repository.


class ETCNet(nn.Module):
    """
    Encoded Template Convolutional Network (ETCNet) is a neural network architecture based on the usage of a
    pretrained VAE which encodes a given template and uses its encoding as weights for the first convolution of a
    ResNet.
    """
    def __init__(self, output_size=10):
        super(ETCNet, self).__init__()

        self.vae = ConvVAE()
        for p in self.vae.parameters():
            p.requires_grad = False
        self.vae.eval()  # keep the pretrained VAE in inference mode

        self.resnet_model = resnet50(pretrained=True)
        self.resnet_model.fc = FullyConnectedLayers()
        self.output = nn.Linear(128, output_size)

    def forward(self, x, x_object):
        mu, logvar = self.vae.encoder(x_object)
        template_weights = self.vae.reparametrize(mu, logvar)
        template_weights = template_weights[0].repeat(3, 1, 1, 1)
        template_weights = template_weights.permute(1, 0, 2, 3)
        self.resnet_model.conv1.weight = torch.nn.Parameter(
            template_weights, requires_grad=False)
        x = self.resnet_model(x)
        x = self.output(x)
        return x

    def load_vae(self, path):
        self.vae.load_state_dict(torch.load(path))
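
A minimal usage sketch, under assumed shapes: the 224x224 RGB scene, the 64x64 single-channel template sized for ConvVAE's encoder, and the VAE checkpoint path are illustrative assumptions, not taken from the project.

model = ETCNet(output_size=10)
model.load_vae('./trained_models/vae.pth')  # hypothetical path to a trained ConvVAE
model.eval()

scene = torch.randn(1, 3, 224, 224)     # batch of RGB scene images
template = torch.randn(1, 1, 64, 64)    # template image, sized for ConvVAE's encoder

with torch.no_grad():
    scores = model(scene, template)     # per-class scores, shape (1, output_size)
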
Code Example #2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt
from scipy import ndimage
from torchvision.models import resnet50
from torchvision.models.resnet import Bottleneck

# ConvVAE, Adapter and MatchingNet are project-specific modules assumed to be
# importable from elsewhere in this repository.


class GMNETCNet(nn.Module):
    """

    """
    def __init__(self, output_matching_size=None):
        super(GMNETCNet, self).__init__()

        self.vae = ConvVAE()
        for p in self.vae.parameters():
            p.requires_grad = False
        self.vae.eval()  # keep the pretrained VAE in inference mode

        resnet_model = resnet50(pretrained=True)
        self.cut_resnet = nn.Sequential(*(list(resnet_model.children())[:6]))
        for module in self.cut_resnet.modules():
            if isinstance(module, Bottleneck) and isinstance(
                    module.conv2, nn.Conv2d):
                module.conv2 = Adapter(module.conv2)

        resnet_model = resnet50(pretrained=True)
        self.cut_resnet_template = nn.Sequential(
            *(list(resnet_model.children())[:6]))
        for module in self.cut_resnet_template.modules():
            if isinstance(module, Bottleneck) and isinstance(
                    module.conv2, nn.Conv2d):
                module.conv2 = Adapter(module.conv2)

        self.adapt_pool = nn.AdaptiveAvgPool2d(1)
        self.batch_norm = nn.BatchNorm2d(512)

        self.matching_model = MatchingNet(1024, output_matching_size)
        # padding=1 already gives "same" output size for a 3x3 kernel with stride 1
        self.output_conv = nn.Conv2d(256, 1, 3, stride=1, padding=1)

    def forward(self, x, x_object, resized_template):

        mu, logvar = self.vae.encoder(resized_template)
        template_weights = self.vae.reparametrize(mu, logvar)
        template_weights = template_weights[0].repeat(3, 1, 1, 1)
        template_weights = template_weights.permute(1, 0, 2, 3)
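        # Inject the encoded template as the frozen weights of the first Conv2d in
        # each truncated ResNet branch (the loops below stop after the first match).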
        for module in self.cut_resnet.modules():
            if isinstance(module, nn.Conv2d):
                module.weight = torch.nn.Parameter(template_weights,
                                                   requires_grad=False)
                break

        for module in self.cut_resnet_template.modules():
            if isinstance(module, nn.Conv2d):
                module.weight = torch.nn.Parameter(template_weights,
                                                   requires_grad=False)
                break

        x_object = self.cut_resnet_template(x_object)
        x_object = self.adapt_pool(x_object)
        x_object = F.normalize(x_object, p=2, dim=-1)

        x = self.batch_norm(self.cut_resnet(x))

        x_object = x_object.repeat(1, 1, x.shape[-2], x.shape[-1])
        outputs = torch.cat((x, x_object), 1)

        outputs = self.matching_model(outputs)
        outputs = self.output_conv(outputs)
        return outputs

    def load_vae(self, path):
        self.vae.load_state_dict(torch.load(path))

    @staticmethod
    def get_count(matrix, plot=False):
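        """
        Count template instances in a 2-D response map by locating its local maxima.

        Maxima are detected with corner-trimmed 3x3 and 7x7 footprints, kept only
        where the local max-min contrast exceeds a third of the global maximum, and
        the union of both sets of peak coordinates is returned as the count. When
        `plot` is True the detected peaks are drawn over the map.
        """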
        footprint_3x3 = np.ones((3, 3))
        footprint_3x3[0, 0] = 0
        footprint_3x3[2, 0] = 0
        footprint_3x3[0, 2] = 0
        footprint_3x3[2, 2] = 0

        footprint_7x7 = np.ones((7, 7))
        footprint_7x7[0, 0] = 0
        footprint_7x7[6, 0] = 0
        footprint_7x7[0, 6] = 0
        footprint_7x7[6, 6] = 0

        # Small maximum analysis
        data_max = ndimage.maximum_filter(matrix, footprint=footprint_3x3)
        maxima = (matrix == data_max)
        data_min = ndimage.minimum_filter(matrix, footprint=footprint_3x3)
        diff = ((data_max - data_min) > (matrix.max() / 3))
        maxima[diff == 0] = 0

        labeled, num_objects = ndimage.label(maxima)
        slices = ndimage.find_objects(labeled)
        x, y = [], []
        for dy, dx in slices:
            x_center = (dx.start + dx.stop - 1) // 2
            x.append(x_center)
            y_center = (dy.start + dy.stop - 1) // 2
            y.append(y_center)

        coordinates_small = list(zip(x, y))

        # Big maximum analysis
        data_max = ndimage.maximum_filter(matrix, footprint=footprint_7x7)
        maxima = (matrix == data_max)
        data_min = ndimage.minimum_filter(matrix, footprint=footprint_7x7)
        diff = ((data_max - data_min) > (matrix.max() / 3))
        maxima[diff == 0] = 0

        labeled, num_objects = ndimage.label(maxima)
        slices = ndimage.find_objects(labeled)
        x, y = [], []
        for dy, dx in slices:
            x_center = (dx.start + dx.stop - 1) // 2
            x.append(x_center)
            y_center = (dy.start + dy.stop - 1) // 2
            y.append(y_center)

        coordinates_big = list(zip(x, y))

        coordinates = np.array(list(set(coordinates_big + coordinates_small)))

        if plot:
            plt.figure()
            plt.imshow(matrix, cmap="gray")
            plt.axis('off')
            plt.autoscale(False)
            plt.plot(coordinates[:, 0], coordinates[:, 1], 'rx')
            # plt.title("Located instances")
            plt.show()

        return len(coordinates)
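
A minimal usage sketch, under assumed input sizes: the scene and template resolutions, the resized-template size fed to ConvVAE, and the checkpoint path are illustrative assumptions, not taken from the project.

model = GMNETCNet()
model.load_vae('./trained_models/vae.pth')    # hypothetical path to a trained ConvVAE
model.eval()

scene = torch.randn(1, 3, 480, 480)           # RGB scene image batch
template = torch.randn(1, 3, 120, 120)        # template crop for the template branch
resized_template = torch.randn(1, 1, 64, 64)  # template resized for ConvVAE's encoder

with torch.no_grad():
    response = model(scene, template, resized_template)  # single-channel response map

count = GMNETCNet.get_count(response[0, 0].cpu().numpy())  # get_count expects a 2-D array
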
Code Example #3
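    # Excerpt from the VAE training script: train_set, val_set, batch_size,
    # run_name, device and file_exists are defined or imported earlier in the
    # same script, along with torch, nn, optim, DataLoader and ConvVAE.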
    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=0)
    val_loader = DataLoader(val_set,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=0)
    # test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)

    model = ConvVAE()
    model = model.to(device)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    init_epoch = 0
    # model = nn.DataParallel(model)
    if file_exists('./trained_models/checkpoints/' + run_name +
                   '_checkpoint.pth'):
        print("Loading checkpoint.", flush=True)
        checkpoint = torch.load('./trained_models/checkpoints/' + run_name +
                                '_checkpoint.pth')
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        init_epoch = checkpoint['epoch']
        print("Init epoch:", init_epoch, flush=True)

        model.train()
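
The loading code above expects a checkpoint dict with 'epoch', 'model_state_dict' and 'optimizer_state_dict' keys. A minimal sketch of a compatible save call, assuming `epoch` is the current index of the training loop (the save call itself is not part of this excerpt):

torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, './trained_models/checkpoints/' + run_name + '_checkpoint.pth')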