Ejemplo n.º 1
0
    def __init__(self, output_matching_size=None):
        """Assemble the frozen VAE, two adapted ResNet-50 trunks and the matching head.

        Args:
            output_matching_size: Passed through unchanged as the second
                argument of ``MatchingNet``.
        """
        super(GMNETCNet, self).__init__()

        # The VAE is a fixed encoder: gradients disabled, evaluation mode.
        frozen_vae = ConvVAE()
        for weight in frozen_vae.parameters():
            weight.requires_grad = False
        frozen_vae.train(False)
        self.vae = frozen_vae

        # Two independent trunks, each cut from its own pretrained ResNet-50
        # (first six child modules) so that no weights are shared.
        for trunk_name in ("cut_resnet", "cut_resnet_template"):
            backbone = resnet50(pretrained=True)
            trunk = nn.Sequential(*list(backbone.children())[:6])
            for block in trunk.modules():
                if not isinstance(block, Bottleneck):
                    continue
                # Wrap the middle convolution of every bottleneck in an Adapter.
                if isinstance(block.conv2, nn.Conv2d):
                    block.conv2 = Adapter(block.conv2)
            setattr(self, trunk_name, trunk)

        self.adapt_pool = nn.AdaptiveAvgPool2d(1)
        self.batch_norm = nn.BatchNorm2d(512)

        self.matching_model = MatchingNet(1024, output_matching_size)
        # TODO compute padding same
        self.output_conv = nn.Conv2d(256, 1, 3, stride=1, padding=1)
Ejemplo n.º 2
0
    def __init__(self, output_size=10):
        """Create the frozen VAE, the ResNet-50 backbone and the output layer.

        Args:
            output_size: Number of output features of the final linear layer.
        """
        super(ETCNet, self).__init__()

        # The VAE only serves as a fixed encoder: no gradients, evaluation mode.
        frozen_vae = ConvVAE()
        for weight in frozen_vae.parameters():
            weight.requires_grad = False
        frozen_vae.train(False)
        self.vae = frozen_vae

        # Pretrained ResNet-50 whose classifier is swapped for the project's
        # fully connected stack.
        backbone = resnet50(pretrained=True)
        backbone.fc = FullyConnectedLayers()
        self.resnet_model = backbone

        self.output = nn.Linear(128, output_size)
Ejemplo n.º 3
0
class ETCNet(nn.Module):
    """
    Encoded Template Convolutional Network (ETCNet) is a neural network architecture based on the usage of a
    pretrained VAE which encodes a given template and uses its encoding as weights for the first convolution of a
    ResNet.
    """

    def __init__(self, output_size=10):
        """Create the frozen VAE, the ResNet-50 backbone and the output layer.

        Args:
            output_size: Number of output features of the final linear layer.
        """
        super(ETCNet, self).__init__()

        # The VAE only serves as a fixed encoder: no gradients, evaluation mode.
        self.vae = ConvVAE()
        for p in self.vae.parameters():
            p.requires_grad = False
        self.vae.train(False)

        self.resnet_model = resnet50(pretrained=True)
        self.resnet_model.fc = FullyConnectedLayers()
        # Assumes FullyConnectedLayers emits 128 features -- TODO confirm.
        self.output = nn.Linear(128, output_size)

    def forward(self, x, x_object):
        """Run ``x`` through the ResNet after installing template-derived weights.

        Args:
            x: Input passed to the ResNet.
            x_object: Template passed to the VAE encoder.

        Returns:
            The output of the final linear layer.
        """
        mu, logvar = self.vae.encoder(x_object)
        template_weights = self.vae.reparametrize(mu, logvar)
        # Only index 0 along the first dimension is used (presumably the first
        # sample of the batch -- verify against the caller). The repeat/permute
        # pair assumes that sample is 3-D, producing a 4-D tensor whose second
        # dimension has size 3 -- TODO confirm against ConvVAE.
        template_weights = template_weights[0].repeat(3, 1, 1, 1)
        template_weights = template_weights.permute(1, 0, 2, 3)
        # The first convolution's weight is replaced on every call and is
        # excluded from gradient updates.
        self.resnet_model.conv1.weight = torch.nn.Parameter(
            template_weights, requires_grad=False)
        x = self.resnet_model(x)
        return self.output(x)

    def load_vae(self, path, map_location=None):
        """Load VAE weights from a saved state dict.

        Args:
            path: File path (or file-like object) accepted by ``torch.load``.
            map_location: Forwarded to ``torch.load``; pass e.g. ``"cpu"`` to
                load a checkpoint saved on a GPU onto a CPU-only machine.
                ``None`` keeps the previous behavior.
        """
        self.vae.load_state_dict(torch.load(path, map_location=map_location))
Ejemplo n.º 4
0
from datasets.ilsvrc_dataset import ILSVRC
from models.cnn_vae import ConvVAE

if __name__ == "__main__":
    # Load two trained ConvVAE checkpoints and decode one randomly chosen
    # test template.
    image_shape = (255, 255)
    data_root = './data/ILSVRC/ILSVRC2015'

    transform = transforms.Compose([transforms.ToTensor()])

    test_set = ILSVRC(data_root,
                      image_shape=image_shape,
                      data_percentage=0.5,
                      train=False,
                      transform=transform)

    model = ConvVAE()
    model.load_state_dict(torch.load("./trained_models/ConvVAE.pt"))
    model.eval()
    model2 = ConvVAE()
    model2.load_state_dict(
        torch.load("./trained_models/ConvVAE_firstConv_GMN_batch.pt"))
    model2.eval()

    # random.randint includes its upper bound, so randint(0, len(test_set))
    # could yield len(test_set) and raise IndexError below; randrange
    # excludes the upper bound.
    index = random.randrange(len(test_set))
    print(index)
    images, _, ground_truth, count, resized_template = test_set[index]
    # Add a leading batch dimension of size 1 in front of the last three
    # dimensions of the template.
    templates = torch.reshape(
        resized_template,
        (1, resized_template.shape[-3], resized_template.shape[-2],
         resized_template.shape[-1]))
    decoded, _, _ = model(templates)
Ejemplo n.º 5
0
class GMNETCNet(nn.Module):
    """Matching network whose first convolutions use VAE-encoded template weights.

    A frozen ``ConvVAE`` encodes a resized template; the sampled encoding is
    installed as the weight of the first convolution of two truncated,
    adapter-augmented ResNet-50 trunks. The features of the input and the
    pooled features of the template object are concatenated, passed through
    ``MatchingNet`` and reduced to a single-channel map by a 3x3 convolution.
    """

    def __init__(self, output_matching_size=None):
        """Build the frozen VAE, both ResNet trunks and the matching head.

        Args:
            output_matching_size: Passed through unchanged as the second
                argument of ``MatchingNet``.
        """
        super(GMNETCNet, self).__init__()

        # The VAE only serves as a fixed encoder: no gradients, evaluation mode.
        self.vae = ConvVAE()
        for p in self.vae.parameters():
            p.requires_grad = False
        self.vae.train(False)

        # Each trunk comes from its own pretrained ResNet-50, so the two
        # trunks do not share weights.
        self.cut_resnet = self._build_trunk()
        self.cut_resnet_template = self._build_trunk()

        self.adapt_pool = nn.AdaptiveAvgPool2d(1)
        self.batch_norm = nn.BatchNorm2d(512)

        self.matching_model = MatchingNet(1024, output_matching_size)
        self.output_conv = nn.Conv2d(256, 1, 3, stride=1,
                                     padding=1)  # TODO compute padding same

    @staticmethod
    def _build_trunk():
        """Return the first six children of a pretrained ResNet-50 with adapters."""
        backbone = resnet50(pretrained=True)
        trunk = nn.Sequential(*(list(backbone.children())[:6]))
        for module in trunk.modules():
            # Wrap the middle convolution of every bottleneck in an Adapter.
            if isinstance(module, Bottleneck) and isinstance(
                    module.conv2, nn.Conv2d):
                module.conv2 = Adapter(module.conv2)
        return trunk

    @staticmethod
    def _set_first_conv_weight(trunk, weights):
        """Replace the weight of the first ``nn.Conv2d`` found in ``trunk``."""
        for module in trunk.modules():
            if isinstance(module, nn.Conv2d):
                module.weight = torch.nn.Parameter(weights,
                                                   requires_grad=False)
                break

    def forward(self, x, x_object, resized_template):
        """Compute the single-channel output map for ``x``.

        Args:
            x: Input passed through ``cut_resnet``.
            x_object: Template object passed through ``cut_resnet_template``.
            resized_template: Template passed to the VAE encoder; its encoding
                becomes the first-convolution weights of both trunks.

        Returns:
            The output of ``output_conv`` (one channel).
        """
        mu, logvar = self.vae.encoder(resized_template)
        template_weights = self.vae.reparametrize(mu, logvar)
        # Only index 0 along the first dimension is used (presumably the first
        # sample of the batch -- verify against the caller). The repeat/permute
        # pair assumes that sample is 3-D, producing a 4-D tensor whose second
        # dimension has size 3 -- TODO confirm against ConvVAE.
        template_weights = template_weights[0].repeat(3, 1, 1, 1)
        template_weights = template_weights.permute(1, 0, 2, 3)
        self._set_first_conv_weight(self.cut_resnet, template_weights)
        self._set_first_conv_weight(self.cut_resnet_template, template_weights)

        x_object = self.cut_resnet_template(x_object)
        x_object = self.adapt_pool(x_object)
        # NOTE(review): adapt_pool reduces the spatial size to 1x1, so the last
        # dimension has size 1 here and normalizing over dim=-1 divides every
        # value by its own magnitude. dim=1 (channels) may have been intended;
        # left unchanged because trained weights depend on this behavior.
        x_object = F.normalize(x_object, p=2, dim=-1)

        x = self.batch_norm(self.cut_resnet(x))

        # Tile the pooled template feature over the spatial grid of x, then
        # stack both along the channel dimension.
        x_object = x_object.repeat(1, 1, x.shape[-2], x.shape[-1])
        outputs = torch.cat((x, x_object), 1)

        outputs = self.matching_model(outputs)
        outputs = self.output_conv(outputs)
        return outputs

    def load_vae(self, path, map_location=None):
        """Load VAE weights from a saved state dict.

        Args:
            path: File path (or file-like object) accepted by ``torch.load``.
            map_location: Forwarded to ``torch.load``; pass e.g. ``"cpu"`` to
                load a checkpoint saved on a GPU onto a CPU-only machine.
                ``None`` keeps the previous behavior.
        """
        self.vae.load_state_dict(torch.load(path, map_location=map_location))

    @staticmethod
    def _corner_trimmed_footprint(size):
        """Return a ``size`` x ``size`` array of ones with the four corners zeroed."""
        footprint = np.ones((size, size))
        footprint[[0, 0, -1, -1], [0, -1, 0, -1]] = 0
        return footprint

    @staticmethod
    def _peak_coordinates(matrix, footprint, threshold):
        """Return ``(x, y)`` centers of local maxima of ``matrix`` within ``footprint``.

        A pixel counts as a maximum when it equals the maximum of its
        neighborhood and the neighborhood's max-min range exceeds ``threshold``.
        """
        data_max = ndimage.maximum_filter(matrix, footprint=footprint)
        data_min = ndimage.minimum_filter(matrix, footprint=footprint)
        maxima = (matrix == data_max) & ((data_max - data_min) > threshold)

        labeled, _ = ndimage.label(maxima)
        # Each labeled region is reduced to the center of its bounding box.
        return [((dx.start + dx.stop - 1) // 2, (dy.start + dy.stop - 1) // 2)
                for dy, dx in ndimage.find_objects(labeled)]

    @staticmethod
    def get_count(matrix, plot=False):
        """Count the distinct local maxima of a 2-D map.

        Maxima are detected with a 3x3 and a 7x7 corner-trimmed neighborhood;
        the union of both coordinate sets is counted. A maximum is kept only if
        the max-min range of its neighborhood exceeds one third of the global
        maximum of ``matrix``.

        Args:
            matrix: 2-D array-like of values.
            plot: When true, show ``matrix`` with the located maxima marked.

        Returns:
            The number of distinct maxima found (0 for an empty matrix).
        """
        matrix = np.asarray(matrix)
        if matrix.size == 0:
            # max() of an empty array raises; an empty map has no maxima.
            return 0

        threshold = matrix.max() / 3
        peaks = set()
        for size in (3, 7):
            footprint = GMNETCNet._corner_trimmed_footprint(size)
            peaks.update(
                GMNETCNet._peak_coordinates(matrix, footprint, threshold))

        # reshape keeps the array 2-D even when no peak was found, so the
        # column indexing in the plotting branch cannot fail.
        coordinates = np.array(sorted(peaks), dtype=int).reshape(-1, 2)

        if plot:
            plt.figure()
            plt.imshow(matrix, cmap="gray")
            plt.axis('off')
            plt.autoscale(False)
            plt.plot(coordinates[:, 0], coordinates[:, 1], 'rx')
            # plt.title("Located instances")
            plt.show()

        return len(coordinates)
Ejemplo n.º 6
0
from models.cnn_vae import ConvVAE

if __name__ == "__main__":
    # Load one single-channel ConvVAE per color channel and fetch a CIFAR10
    # test sample resized to 96x96.
    image_shape = (255, 255)
    dataset_root = './data/CIFAR10'

    transform = transforms.Compose([
        transforms.Resize((96, 96), interpolation=Image.NEAREST),
        transforms.ToTensor()
    ])

    test_set = datasets.CIFAR10(root=dataset_root,
                                train=False,
                                download=True,
                                transform=transform)
    model_r = ConvVAE(channels=1)
    model_g = ConvVAE(channels=1)
    model_b = ConvVAE(channels=1)

    model_r.load_state_dict(torch.load("./trained_models/ConvVAE_r.pt"))
    model_g.load_state_dict(torch.load("./trained_models/ConvVAE_g.pt"))
    model_b.load_state_dict(torch.load("./trained_models/ConvVAE_b.pt"))

    model_r.eval()
    model_g.eval()
    model_b.eval()

    # random.randint includes its upper bound, so randint(0, len(test_set))
    # could report an index one past the end of the dataset; randrange
    # excludes the upper bound.
    index = random.randrange(len(test_set))
    print(index)
    # NOTE(review): the sample below is pinned to 4942 rather than the random
    # index printed above -- confirm whether this is intentional.
    image, class_index = test_set[4942]
Ejemplo n.º 7
0
                     transform=transform)

    # train_len = len(train_set)
    # train_set, val_set = random_split(train_set, [int(train_len * 0.8), int(train_len * 0.2)])

    # Training batches are reshuffled; validation batches keep dataset order.
    # NOTE(review): val_set is not defined in the visible lines (the
    # random_split above is commented out) -- presumably it is created
    # earlier in the script; verify.
    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=0)
    val_loader = DataLoader(val_set,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=0)
    # test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)

    model = ConvVAE()
    model = model.to(device)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Resume from a per-run checkpoint when one exists: restores the model
    # weights, the optimizer state and the epoch stored in the checkpoint.
    init_epoch = 0
    # model = nn.DataParallel(model)
    if file_exists('./trained_models/checkpoints/' + run_name +
                   '_checkpoint.pth'):
        print("Loading checkpoint.", flush=True)
        # NOTE(review): torch.load is called without map_location, so the
        # checkpoint is restored to the devices it was saved from -- confirm
        # this matches `device`.
        checkpoint = torch.load('./trained_models/checkpoints/' + run_name +
                                '_checkpoint.pth')
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        init_epoch = checkpoint['epoch']