예제 #1
0
 def test_from_torch_dataset_with_transform(self):
     # Wrap a one-sample FakeData dataset (3x32x32 images) together with a
     # ToTensor transform, then check that a transform is set both on the
     # LightlyDataset wrapper and on the torch dataset it wraps.
     fake_images = torchvision.datasets.FakeData(
         size=1, image_size=(3, 32, 32))
     to_tensor = torchvision.transforms.ToTensor()
     wrapped = LightlyDataset.from_torch_dataset(fake_images,
                                                 transform=to_tensor)
     self.assertIsNotNone(wrapped.transform)
     self.assertIsNotNone(wrapped.dataset.transform)
예제 #2
0
파일: dino.py 프로젝트: lightly-ai/lightly
                              student_out,
                              epoch=self.current_epoch)
        return loss

    def configure_optimizers(self):
        # Single Adam optimizer over everything yielded by self.parameters(),
        # using a fixed learning rate of 0.001.
        return torch.optim.Adam(self.parameters(), lr=0.001)


model = DINO()


def _ignore_target(target):
    """Discard a detection annotation and return the constant label ``0``.

    This replaces ``lambda t: 0``. With ``num_workers > 0`` the DataLoader
    hands the dataset, and with it this ``target_transform``, to its worker
    processes. Under the "spawn" start method that requires pickling.
    Lambdas cannot be pickled, but a function defined with ``def`` at module
    level is pickled by reference.
    """
    return 0


# we ignore object detection annotations by setting target_transform to return 0
pascal_voc = torchvision.datasets.VOCDetection(
    "datasets/pascal_voc",
    download=True,
    target_transform=_ignore_target,
)
dataset = LightlyDataset.from_torch_dataset(pascal_voc)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")

collate_fn = DINOCollateFunction()

# NOTE(review): with "spawn" (e.g. on Windows), the code that iterates this
# loader must also sit under an `if __name__ == "__main__":` guard; that code
# is outside this excerpt.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,  # skip a final incomplete batch
    num_workers=8,
)

gpus = torch.cuda.device_count()
예제 #3
0
    def forward(self, x):
        # Run the backbone, collapse every dimension after the leading one
        # (presumably the batch dimension) and feed the result to the head.
        features = self.backbone(x)
        flat_features = features.flatten(start_dim=1)
        return self.projection_head(flat_features)


# Backbone: every child module of ResNet-18 except the last one (presumably
# the final classification layer) chained in a Sequential.
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*tuple(resnet.children())[:-1])
model = BarlowTwins(backbone)

# Prefer the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)

cifar10 = torchvision.datasets.CIFAR10("datasets/cifar10", download=True)
dataset = LightlyDataset.from_torch_dataset(cifar10)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")

collate_fn = ImageCollateFunction(input_size=32)

# Shuffled batches of 256; a trailing incomplete batch is dropped.
dataloader = torch.utils.data.DataLoader(
    dataset,
    collate_fn=collate_fn,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

criterion = BarlowTwinsLoss()
예제 #4
0
 def test_from_torch_dataset(self):
     # Converting a torch dataset must keep its length, and the resulting
     # LightlyDataset must report exactly one filename per sample.
     torch_dataset = torchvision.datasets.FakeData(
         size=1, image_size=(3, 32, 32))
     lightly_dataset = LightlyDataset.from_torch_dataset(torch_dataset)
     expected_length = len(torch_dataset)
     self.assertEqual(expected_length, len(lightly_dataset))
     filenames = lightly_dataset.get_filenames()
     self.assertEqual(len(filenames), len(lightly_dataset))