Example No. 1
    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet18()
        # use ResNet-18 without its classification head as the backbone
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        self.projection_head = SwaVProjectionHead(512, 512, 128)
        self.prototypes = SwaVPrototypes(128, n_prototypes=512)
        self.criterion = SwaVLoss()
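Only the constructor appears in this example. For context, the sketch below wraps it in a minimal module with a plausible forward pass; the class name, the imports, and the L2 normalization step are assumptions based on the layer sizes above, not part of the original snippet.

import torchvision
from torch import nn
from lightly.loss import SwaVLoss
from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes


class SwaV(nn.Module):  # hypothetical class name; the original does not show it
    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet18()
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        self.projection_head = SwaVProjectionHead(512, 512, 128)
        self.prototypes = SwaVPrototypes(128, n_prototypes=512)
        self.criterion = SwaVLoss()

    def forward(self, x):
        x = self.backbone(x).flatten(start_dim=1)   # (batch_size, 512) backbone features
        x = self.projection_head(x)                 # project down to 128 dimensions
        x = nn.functional.normalize(x, dim=1, p=2)  # L2-normalize before prototype matching
        return self.prototypes(x)                   # similarity scores to the 512 prototypes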
Example No. 2
    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet18()
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        self.projection_head = SwaVProjectionHead(512, 512, 128)
        self.prototypes = SwaVPrototypes(128, n_prototypes=512)

        # enable sinkhorn_gather_distributed to gather features from all gpus
        # while running the sinkhorn algorithm in the loss calculation
        self.criterion = SwaVLoss(sinkhorn_gather_distributed=True)
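The sinkhorn_gather_distributed flag only matters when training runs with several processes (typically one per GPU) under torch.distributed. Below is a small hedged sketch of choosing the flag from the runtime environment; the helper function name is made up for illustration.

import torch.distributed as dist
from lightly.loss import SwaVLoss


def build_swav_loss() -> SwaVLoss:
    # gather features from all processes for the Sinkhorn step only when a
    # distributed process group is actually running
    use_distributed = dist.is_available() and dist.is_initialized()
    return SwaVLoss(sinkhorn_gather_distributed=use_distributed)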
Example No. 3
    def test_forward_pass_bsz_1(self):
        n = 32
        n_high_res = 2
        # batch size 1: each high-resolution view is a single 1 x n feature row
        high_res = [torch.eye(1, n) for i in range(n_high_res)]

        for n_low_res in range(6):
            for sinkhorn_iterations in range(3):
                criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations)
                low_res = [torch.eye(1, n) for i in range(n_low_res)]

                with self.subTest(
                    msg=f'n_low_res={n_low_res}, sinkhorn_iterations={sinkhorn_iterations}'
                ):
                    loss = criterion(high_res, low_res)
Example No. 4
    def test_forward_pass_cuda(self):
        n = 32
        n_high_res = 2
        high_res = [torch.eye(n, n).cuda() for i in range(n_high_res)]

        for n_low_res in range(6):
            for sinkhorn_iterations in range(3):
                criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations)
                low_res = [torch.eye(n, n).cuda() for i in range(n_low_res)]

                with self.subTest(
                    msg=f'n_low_res={n_low_res}, sinkhorn_iterations={sinkhorn_iterations}'
                ):
                    loss = criterion(high_res, low_res)
                    # loss should be almost zero for unit matrix
                    self.assertGreater(0.5, loss.cpu().numpy())
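The .cuda() calls above fail on machines without a GPU. A guard along the lines below keeps the suite portable by skipping the test instead; the enclosing class name is hypothetical, since the original does not show it.

import unittest

import torch


class TestSwaVLoss(unittest.TestCase):  # hypothetical class name

    @unittest.skipUnless(torch.cuda.is_available(), "requires a CUDA device")
    def test_forward_pass_cuda(self):
        ...  # body as in Example No. 4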
Example No. 5
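This snippet assumes that a model, a device, and a pascal_voc dataset were defined earlier. A minimal sketch of that setup, assuming the SwaV module sketched after Example No. 1 and the lightly imports these examples rely on; the dataset path is an arbitrary placeholder.

import torch
import torchvision
from lightly.data import LightlyDataset, SwaVCollateFunction
from lightly.loss import SwaVLoss

model = SwaV()  # e.g. the module sketched after Example No. 1
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# any torchvision image dataset works; Pascal VOC is just one choice
pascal_voc = torchvision.datasets.VOCDetection("datasets/pascal_voc", download=True)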
dataset = LightlyDataset.from_torch_dataset(pascal_voc)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")

collate_fn = SwaVCollateFunction()

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=128,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

criterion = SwaVLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

print("Starting Training")
for epoch in range(10):
    total_loss = 0
    for batch, _, _ in dataloader:
        # each batch is a list of multi-crop views; the first two are the
        # high-resolution crops, the remaining ones the low-resolution crops
        multi_crop_features = [model(x.to(device)) for x in batch]
        high_resolution = multi_crop_features[:2]
        low_resolution = multi_crop_features[2:]
        loss = criterion(high_resolution, low_resolution)
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:>.5f}")