def test_from_torch_dataset_with_transform(self):
    """Check that a transform passed to ``from_torch_dataset`` is attached.

    A one-sample ``FakeData`` dataset is wrapped with a ``ToTensor``
    transform; both the resulting ``LightlyDataset`` and the torch dataset it
    wraps (``.dataset``) must end up with a non-``None`` ``transform``.
    """
    fake_images = torchvision.datasets.FakeData(size=1, image_size=(3, 32, 32))
    to_tensor = torchvision.transforms.ToTensor()
    wrapped = LightlyDataset.from_torch_dataset(
        fake_images,
        transform=to_tensor,
    )

    # The transform must be visible on the wrapper itself ...
    self.assertIsNotNone(wrapped.transform)
    # ... and on the underlying torch dataset, which was created without one.
    self.assertIsNotNone(wrapped.dataset.transform)
student_out, epoch=self.current_epoch) return loss def configure_optimizers(self): optim = torch.optim.Adam(self.parameters(), lr=0.001) return optim model = DINO() # we ignore object detection annotations by setting target_transform to return 0 pascal_voc = torchvision.datasets.VOCDetection("datasets/pascal_voc", download=True, target_transform=lambda t: 0) dataset = LightlyDataset.from_torch_dataset(pascal_voc) # or create a dataset from a folder containing images or videos: # dataset = LightlyDataset("path/to/folder") collate_fn = DINOCollateFunction() dataloader = torch.utils.data.DataLoader( dataset, batch_size=64, collate_fn=collate_fn, shuffle=True, drop_last=True, num_workers=8, ) gpus = torch.cuda.device_count()
def forward(self, x):
    """Embed a batch of inputs.

    Runs the backbone, flattens every dimension from 1 onward, and passes the
    result through the projection head.

    Args:
        x: Input batch for ``self.backbone`` (assumed to be an image batch of
            shape (batch, channels, height, width) — TODO confirm).

    Returns:
        The projection head's output ``z``.
    """
    x = self.backbone(x).flatten(start_dim=1)
    z = self.projection_head(x)
    return z


# Backbone: a torchvision ResNet-18 (no weights argument is passed) with its
# last child module removed; in torchvision's ResNet that last child is the
# final fully-connected classification layer.
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = BarlowTwins(backbone)

# Use a CUDA GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# CIFAR-10 is downloaded into "datasets/cifar10" (download=True) and wrapped
# in a LightlyDataset.
cifar10 = torchvision.datasets.CIFAR10("datasets/cifar10", download=True)
dataset = LightlyDataset.from_torch_dataset(cifar10)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")

# input_size=32 matches the 32x32 CIFAR-10 images. NOTE(review): presumably
# this collate function produces the augmented views the loss compares —
# confirm against ImageCollateFunction.
collate_fn = ImageCollateFunction(input_size=32)

# drop_last=True discards a trailing incomplete batch, so every training step
# sees exactly batch_size (256) samples.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

criterion = BarlowTwinsLoss()
def test_from_torch_dataset(self):
    """Wrapping a torch dataset keeps its length and lists one filename per sample.

    A one-sample ``FakeData`` dataset is wrapped with
    ``LightlyDataset.from_torch_dataset``; the wrapper must report the same
    length as the source, and ``get_filenames()`` must return exactly as many
    entries as the wrapper has samples.
    """
    torch_dataset = torchvision.datasets.FakeData(size=1, image_size=(3, 32, 32))
    wrapped = LightlyDataset.from_torch_dataset(torch_dataset)

    source_len = len(torch_dataset)
    self.assertEqual(source_len, len(wrapped))

    filenames = wrapped.get_filenames()
    self.assertEqual(len(filenames), len(wrapped))