Example #1
    def __init__(self, dataloader_kNN, num_classes):
        super().__init__(dataloader_kNN, num_classes)
        # create a ResNet backbone and remove the classification head;
        # a 1x1 conv then maps the pooled features to num_ftrs channels
        # (num_ftrs is a module-level constant in the benchmark script)
        resnet = torchvision.models.resnet18()
        last_conv_channels = list(resnet.children())[-1].in_features
        self.backbone = nn.Sequential(
            *list(resnet.children())[:-1],
            nn.Conv2d(last_conv_channels, num_ftrs, 1),
        )
        # create a simsiam model based on ResNet
        self.resnet_simsiam = \
            lightly.models.SimSiam(self.backbone, num_ftrs=num_ftrs)
        # override the projection head; each tuple is
        # (input_dim, output_dim, batch_norm_layer, nonlinearity)
        self.resnet_simsiam.projection_mlp = ProjectionHead([
            (
                self.resnet_simsiam.num_ftrs,
                self.resnet_simsiam.proj_hidden_dim,
                nn.BatchNorm1d(self.resnet_simsiam.proj_hidden_dim),
                nn.ReLU(inplace=True)
            ),
            (
                self.resnet_simsiam.proj_hidden_dim,
                self.resnet_simsiam.out_dim,
                nn.BatchNorm1d(self.resnet_simsiam.out_dim),
                None
            )
        ])
        self.criterion = lightly.loss.SymNegCosineSimilarityLoss()

        # memory bank that swaps each projection for its nearest neighbor
        # (nn_size is a module-level constant in the benchmark script)
        self.nn_replacer = NNMemoryBankModule(size=nn_size)
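For reference, one training step for this model would presumably look like the sketch below. It assumes the pre-1.2 lightly API, in which the SimSiam wrapper takes two augmented views, returns ((z0, p0), (z1, p1)) tuples, and SymNegCosineSimilarityLoss consumes those tuples directly:

# hypothetical training step (old lightly API assumed; x0, x1 are two views)
(z0, p0), (z1, p1) = self.resnet_simsiam(x0, x1)
# replace each projection with its nearest neighbor from the memory bank
z0 = self.nn_replacer(z0.detach(), update=False)
z1 = self.nn_replacer(z1.detach(), update=True)
loss = self.criterion((z0, p0), (z1, p1))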
Example #2
File: nnclr.py  Project: lightly-ai/lightly
    def __init__(self):
        super().__init__()
        # ResNet-18 backbone without the final classification layer
        resnet = torchvision.models.resnet18()
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        # projection MLP (512 -> 512 -> 128) and prediction MLP (128 -> 512 -> 128)
        self.projection_head = NNCLRProjectionHead(512, 512, 128)
        self.prediction_head = NNCLRPredictionHead(128, 512, 128)
        # fixed-size queue of past embeddings for nearest-neighbor lookup
        self.memory_bank = NNMemoryBankModule(size=4096)

        self.criterion = NTXentLoss()
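A single training step with this model would then follow the pattern of lightly's documented NNCLR example: swap each projection for its nearest neighbor from the memory bank, then compute a symmetric NTXentLoss. A minimal sketch, assuming the forward pass shown in Example #5 (x0 and x1 are two augmented views of the same images):

z0, p0 = model(x0)
z1, p1 = model(x1)
z0 = model.memory_bank(z0, update=False)  # lookup only
z1 = model.memory_bank(z1, update=True)   # lookup, then enqueue this batch
loss = 0.5 * (model.criterion(z0, p1) + model.criterion(z1, p0))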
Example #3
    def __init__(self, dataloader_kNN, num_classes):
        super().__init__(dataloader_kNN, num_classes)
        # create a ResNet backbone and remove the classification head
        resnet = torchvision.models.resnet18()
        last_conv_channels = list(resnet.children())[-1].in_features
        self.backbone = nn.Sequential(
            *list(resnet.children())[:-1],
            nn.Conv2d(last_conv_channels, num_ftrs, 1),
        )
        # create a BYOL model based on ResNet
        self.resnet_byol = \
            lightly.models.BYOL(self.backbone, num_ftrs=num_ftrs)
        self.criterion = lightly.loss.SymNegCosineSimilarityLoss()

        # nearest-neighbor memory bank, as in Example #1
        self.nn_replacer = NNMemoryBankModule(size=nn_size)
Example #4
    def test_memory_bank(self):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        for model_name, config in self.resnet_variants.items():
            resnet = resnet_generator(model_name)
            model = NNCLR(get_backbone(resnet), **config).to(device)

            for nn_size in [2**3, 2**8]:
                nn_replacer = NNMemoryBankModule(size=nn_size)

                with torch.no_grad():
                    for i in range(10):
                        x0 = torch.rand(
                            (self.batch_size, 3, 64, 64)).to(device)
                        x1 = torch.rand(
                            (self.batch_size, 3, 64, 64)).to(device)
                        (z0, p0), (z1, p1) = model(x0, x1)
                        z0 = nn_replacer(z0.detach(), update=False)  # lookup only
                        z1 = nn_replacer(z1.detach(), update=True)   # lookup, then enqueue batch
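Conceptually, the memory bank being exercised here keeps a fixed-size queue of past embeddings and returns, for each query, the stored embedding with the highest cosine similarity; update=True additionally enqueues the current batch. A rough, self-contained sketch of that lookup (illustrative only, not the library's implementation):

import torch
import torch.nn.functional as F

def nearest_neighbors(queries: torch.Tensor, bank: torch.Tensor) -> torch.Tensor:
    # normalize rows so the dot product equals cosine similarity
    q = F.normalize(queries, dim=1)
    b = F.normalize(bank, dim=1)
    # index of the most similar bank entry for each query
    idx = (q @ b.t()).argmax(dim=1)
    return bank[idx]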
Example #5
File: nnclr.py  Project: lightly-ai/lightly
    def forward(self, x):
        y = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(y)
        p = self.prediction_head(z)
        z = z.detach()  # stop-gradient: the projection targets carry no gradient
        return z, p


resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = NNCLR(backbone)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

memory_bank = NNMemoryBankModule(size=4096)
memory_bank.to(device)

cifar10 = torchvision.datasets.CIFAR10("datasets/cifar10", download=True)
dataset = LightlyDataset.from_torch_dataset(cifar10)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")

# the collate function produces two augmented views of each image
collate_fn = SimCLRCollateFunction(input_size=32)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
)
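The snippet ends here, but the upstream example file presumably continues with a standard training loop. A sketch of the remainder, using the same nearest-neighbor step as in Example #2 (the SGD optimizer, learning rate, and epoch count are illustrative, not prescribed by the snippet):

criterion = NTXentLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.06)

for epoch in range(10):
    total_loss = 0
    for (x0, x1), _, _ in dataloader:
        x0, x1 = x0.to(device), x1.to(device)
        z0, p0 = model(x0)
        z1, p1 = model(x1)
        # replace projections with their nearest neighbors from the bank
        z0 = memory_bank(z0, update=False)
        z1 = memory_bank(z1, update=True)
        loss = 0.5 * (criterion(z0, p1) + criterion(z1, p0))
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    print(f"epoch: {epoch:>02}, loss: {total_loss / len(dataloader):.5f}")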