Example #1
import torch
import torch.distributed as dist

# Import path assumed from VISSL's layout; adjust to your checkout if needed.
from vissl.losses.simclr_info_nce_loss import SimclrInfoNCECriterion


def worker_fn(gpu_id: int, world_size: int, batch_size: int):
    # One process per GPU: the GPU index doubles as the distributed rank.
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://0.0.0.0:1234",
        world_size=world_size,
        rank=gpu_id,
    )
    # Each rank contributes a (batch_size, 3) block filled with its own rank id,
    # which makes it easy to check which rank produced which rows after the gather.
    embeddings = torch.full(size=(batch_size, 3),
                            fill_value=float(gpu_id),
                            requires_grad=True).cuda(gpu_id)
    gathered = SimclrInfoNCECriterion.gather_embeddings(embeddings)
    if world_size == 1:
        assert gathered.equal(
            torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
                         device=f"cuda:{gpu_id}"))
    if world_size == 2:
        assert gathered.equal(
            torch.tensor(
                [
                    [0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0],
                    [1.0, 1.0, 1.0],
                    [1.0, 1.0, 1.0],
                ],
                device=f"cuda:{gpu_id}",
            ))
    # The gather must keep the autograd graph intact.
    assert gathered.requires_grad
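For context, a minimal sketch of how such a worker could be launched (an assumed harness, not part of the original example): torch.multiprocessing.spawn starts one process per GPU and passes each worker its process index as the first argument, which worker_fn uses as gpu_id.

import torch.multiprocessing as mp

if __name__ == "__main__":
    world_size = 2   # number of GPUs; value assumed for this sketch
    batch_size = 2
    # spawn calls worker_fn(i, world_size, batch_size) for i in 0..world_size-1
    mp.spawn(worker_fn, args=(world_size, batch_size), nprocs=world_size)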
Example #2
def test_simclr_info_nce_masks(self):
    # EMBEDDING_DIM and BUFFER_PARAMS_STRUCT come from the module scope of
    # the test file (not shown in this snippet).
    BATCH_SIZE = 4
    WORLD_SIZE = 2
    buffer_params = BUFFER_PARAMS_STRUCT(
        BATCH_SIZE * WORLD_SIZE, WORLD_SIZE, EMBEDDING_DIM
    )
    criterion = SimclrInfoNCECriterion(buffer_params=buffer_params, temperature=0.1)
    # Each local sample has exactly one positive: its other augmented copy.
    self.assertTrue(
        criterion.pos_mask.equal(
            torch.tensor(
                [
                    [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                ]
            )
        )
    )
    # Negatives are every gathered sample except the sample itself and its positive.
    self.assertTrue(
        criterion.neg_mask.equal(
            torch.tensor(
                [
                    [0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0],
                    [1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0],
                    [0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0],
                    [1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0],
                ]
            )
        )
    )
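To see where the expected tensors come from, here is an illustrative reconstruction (not VISSL's actual implementation). It assumes the local batch holds two augmented copies of each image laid out as two half-batches, so row i's positive sits at column (i + batch_size // 2) % batch_size, and that this test corresponds to rank 0, whose block is the first batch_size of the gathered columns.

import torch

batch_size, world_size = 4, 2
total = batch_size * world_size
pos_mask = torch.zeros(batch_size, total)
neg_mask = torch.ones(batch_size, total)
for i in range(batch_size):
    pos = (i + batch_size // 2) % batch_size  # index of the other augmented copy
    pos_mask[i, pos] = 1.0
    neg_mask[i, i] = 0.0    # a sample is never its own negative
    neg_mask[i, pos] = 0.0  # nor is its positive

This reproduces exactly the pos_mask and neg_mask asserted in the test above.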
Example #3
def worker_fn(gpu_id: int, world_size: int, batch_size: int,
              sync_file: str):
    # Same gather check as Example #1, but the process-group rendezvous goes
    # through a shared sync file instead of a fixed TCP address.
    init_distributed_on_file(world_size=world_size,
                             gpu_id=gpu_id,
                             sync_file=sync_file)
    embeddings = torch.full(size=(batch_size, 3),
                            fill_value=float(gpu_id),
                            requires_grad=True).cuda(gpu_id)
    gathered = SimclrInfoNCECriterion.gather_embeddings(embeddings)
    if world_size == 1:
        assert gathered.equal(
            torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
                         device=f"cuda:{gpu_id}"))
    if world_size == 2:
        assert gathered.equal(
            torch.tensor(
                [
                    [0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0],
                    [1.0, 1.0, 1.0],
                    [1.0, 1.0, 1.0],
                ],
                device=f"cuda:{gpu_id}",
            ))
    # The gather must keep the autograd graph intact.
    assert gathered.requires_grad
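init_distributed_on_file is a helper from the VISSL test suite. A plausible sketch of what it wraps (an assumption, not the actual helper) is PyTorch's file-based rendezvous, which avoids having to pick a free TCP port in tests:

import torch
import torch.distributed as dist

def init_distributed_on_file(world_size: int, gpu_id: int, sync_file: str):
    # all ranks rendezvous through the same shared file
    torch.cuda.set_device(gpu_id)
    dist.init_process_group(
        backend="nccl",
        init_method=f"file://{sync_file}",
        world_size=world_size,
        rank=gpu_id,
    )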
Example #4
def test_simclr_backward(self):
    EMBEDDING_DIM = 3
    BATCH_SIZE = 4
    WORLD_SIZE = 1
    buffer_params = BUFFER_PARAMS_STRUCT(
        BATCH_SIZE * WORLD_SIZE, WORLD_SIZE, EMBEDDING_DIM
    )
    criterion = SimclrInfoNCECriterion(buffer_params=buffer_params, temperature=0.1)
    embeddings = torch.tensor(
        [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]],
        requires_grad=True,
    )

    # The loss must be differentiable with respect to the embeddings.
    self.assertTrue(embeddings.grad is None)
    criterion(embeddings).backward()
    self.assertTrue(embeddings.grad is not None)
    print(embeddings.grad)
    # One step of gradient descent on the embeddings should reduce the loss.
    with torch.no_grad():
        next_embeddings = embeddings - embeddings.grad  # gradient descent
        self.assertTrue(criterion(next_embeddings) < criterion(embeddings))
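The same decreasing-loss check extends naturally over several optimizer steps. A minimal sketch, assuming the criterion and embeddings constructed above and a small step size:

optimizer = torch.optim.SGD([embeddings], lr=0.1)
losses = []
for _ in range(5):
    optimizer.zero_grad()
    loss = criterion(embeddings)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
# Expected for this toy input and small learning rate: the loss shrinks overall.
assert losses[-1] < losses[0]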
Example #5
def test_simclr_info_nce_loss(self):
    # Forward-pass smoke test; BUFFER_PARAMS is a module-level constant of
    # the test file (not shown in this snippet).
    loss_layer = SimclrInfoNCECriterion(buffer_params=BUFFER_PARAMS,
                                        temperature=0.1)
    _ = loss_layer(self._get_embedding())
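_get_embedding is a helper on the test class whose body is not shown in this snippet. A hypothetical stand-in consistent with the module-level constants, enough for a forward-pass smoke test, might look like:

@staticmethod
def _get_embedding():
    # hypothetical helper: any (BATCH_SIZE, EMBEDDING_DIM) tensor will do here
    return torch.ones([BATCH_SIZE, EMBEDDING_DIM])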