def worker_fn(gpu_id: int, world_size: int, batch_size: int):
    """Distributed worker: check that ``gather_embeddings`` concatenates each
    rank's embeddings in rank order and keeps the autograd graph alive.

    Each rank fills a ``(batch_size, 3)`` tensor with its own rank id so the
    gathered result is trivially checkable.

    NOTE(review): the expected-value asserts hard-code two rows per rank,
    i.e. they assume ``batch_size == 2`` — confirm against the spawning test.
    """
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://0.0.0.0:1234",
        world_size=world_size,
        rank=gpu_id,
    )
    try:
        embeddings = torch.full(
            size=(batch_size, 3), fill_value=float(gpu_id), requires_grad=True
        ).cuda(gpu_id)
        gathered = SimclrInfoNCECriterion.gather_embeddings(embeddings)
        if world_size == 1:
            assert gathered.equal(
                torch.tensor(
                    [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
                    device=f"cuda:{gpu_id}",
                )
            )
        if world_size == 2:
            assert gathered.equal(
                torch.tensor(
                    [
                        [0.0, 0.0, 0.0],
                        [0.0, 0.0, 0.0],
                        [1.0, 1.0, 1.0],
                        [1.0, 1.0, 1.0],
                    ],
                    device=f"cuda:{gpu_id}",
                )
            )
        assert gathered.requires_grad
    finally:
        # Tear down the process group so the fixed TCP port (1234) is freed;
        # leaking it makes any subsequent init_process_group in the same test
        # run flaky.
        dist.destroy_process_group()
def test_simclr_info_nce_masks(self):
    """Each sample's positive is its other augmented view (offset by the
    local batch size); every remaining column except the sample itself is
    a negative."""
    batch_size = 4
    world_size = 2
    params = BUFFER_PARAMS_STRUCT(
        batch_size * world_size, world_size, EMBEDDING_DIM
    )
    criterion = SimclrInfoNCECriterion(buffer_params=params, temperature=0.1)
    expected_pos = torch.tensor(
        [
            [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        ]
    )
    expected_neg = torch.tensor(
        [
            [0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0],
            [1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0],
            [0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0],
            [1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0],
        ]
    )
    self.assertTrue(criterion.pos_mask.equal(expected_pos))
    self.assertTrue(criterion.neg_mask.equal(expected_neg))
def worker_fn(gpu_id: int, world_size: int, batch_size: int, sync_file: str):
    """Distributed worker (file-based rendezvous): check that
    ``gather_embeddings`` stacks each rank's embeddings in rank order and
    preserves ``requires_grad``. Rank ``i`` contributes rows of value ``i``.
    """
    init_distributed_on_file(
        world_size=world_size, gpu_id=gpu_id, sync_file=sync_file
    )
    device = f"cuda:{gpu_id}"
    local = torch.full(
        size=(batch_size, 3), fill_value=float(gpu_id), requires_grad=True
    ).cuda(gpu_id)
    gathered = SimclrInfoNCECriterion.gather_embeddings(local)
    # Expected gathered values for the world sizes this test is spawned with
    # (two rows per rank, filled with the rank id).
    expected_rows = {
        1: [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
        2: [
            [0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0],
            [1.0, 1.0, 1.0],
            [1.0, 1.0, 1.0],
        ],
    }
    if world_size in expected_rows:
        expected = torch.tensor(expected_rows[world_size], device=device)
        assert gathered.equal(expected)
    assert gathered.requires_grad
def test_simclr_backward(self):
    """End-to-end autograd check: the SimCLR loss is differentiable w.r.t.
    the embeddings, and one plain gradient-descent step (learning rate 1.0)
    strictly decreases the loss."""
    EMBEDDING_DIM = 3
    BATCH_SIZE = 4
    WORLD_SIZE = 1
    buffer_params = BUFFER_PARAMS_STRUCT(
        BATCH_SIZE * WORLD_SIZE, WORLD_SIZE, EMBEDDING_DIM
    )
    criterion = SimclrInfoNCECriterion(buffer_params=buffer_params, temperature=0.1)
    embeddings = torch.tensor(
        [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]],
        requires_grad=True,
    )
    # No gradient before backward; populated after.
    self.assertIsNone(embeddings.grad)
    criterion(embeddings).backward()
    self.assertIsNotNone(embeddings.grad)
    with torch.no_grad():
        next_embeddings = embeddings - embeddings.grad  # gradient descent
    # The step along the negative gradient must lower the loss.
    self.assertLess(criterion(next_embeddings).item(), criterion(embeddings).item())
def test_simclr_info_nce_loss(self):
    """Smoke test: the criterion's forward pass runs on a synthetic batch
    of embeddings without raising."""
    criterion = SimclrInfoNCECriterion(
        buffer_params=BUFFER_PARAMS, temperature=0.1
    )
    criterion(self._get_embedding())