def test_shift_indices_tuple(self):
    """Verify c_f.shift_indices_tuple: indices that point into the current
    batch (columns 0 and 2 of a pairs tuple, column 0 of a triplets tuple)
    stay unchanged, while indices that point into the memory bank are
    offset by batch_size so they address the concatenated
    [batch, memory] label/embedding tensors.
    """
    for dtype in TEST_DTYPES:
        batch_size = 32
        pair_miner = PairMarginMiner(pos_margin=0, neg_margin=1)
        triplet_miner = TripletMarginMiner(margin=1)
        self.loss = CrossBatchMemory(
            loss=ContrastiveLoss(),
            embedding_size=self.embedding_size,
            memory_size=self.memory_size,
        )

        def check_pair_shift(raw, shifted, all_labels):
            # Anchor columns (0, 2) index the current batch: unshifted.
            self.assertTrue(torch.equal(raw[0], shifted[0]))
            self.assertTrue(torch.equal(raw[2], shifted[2]))
            # Positive/negative columns (1, 3) index the memory: +batch_size.
            self.assertTrue(torch.equal(raw[1], shifted[1] - batch_size))
            self.assertTrue(torch.equal(raw[3], shifted[3] - batch_size))
            a1, p, a2, n = shifted
            # Positives share labels; negatives have differing labels.
            self.assertTrue(not torch.any((all_labels[a1] - all_labels[p]).bool()))
            self.assertTrue(torch.all((all_labels[a2] - all_labels[n]).bool()))

        for _ in range(30):
            embeddings = (
                torch.randn(batch_size, self.embedding_size)
                .to(TEST_DEVICE)
                .type(dtype)
            )
            labels = torch.arange(batch_size).to(TEST_DEVICE)
            _ = self.loss(embeddings, labels)  # side effect: fills the memory bank
            all_labels = torch.cat([labels, self.loss.label_memory], dim=0)

            # Case 1: exhaustive batch-vs-memory pairs.
            raw = lmu.get_all_pairs_indices(labels, self.loss.label_memory)
            check_pair_shift(
                raw, c_f.shift_indices_tuple(raw, batch_size), all_labels
            )

            # Case 2: pairs produced by a pair miner.
            raw = pair_miner(
                embeddings,
                labels,
                self.loss.embedding_memory,
                self.loss.label_memory,
            )
            check_pair_shift(
                raw, c_f.shift_indices_tuple(raw, batch_size), all_labels
            )

            # Case 3: triplets — only anchors stay in batch coordinates.
            raw = triplet_miner(
                embeddings,
                labels,
                self.loss.embedding_memory,
                self.loss.label_memory,
            )
            shifted = c_f.shift_indices_tuple(raw, batch_size)
            self.assertTrue(torch.equal(raw[0], shifted[0]))
            self.assertTrue(torch.equal(raw[1], shifted[1] - batch_size))
            self.assertTrue(torch.equal(raw[2], shifted[2] - batch_size))
            a, p, n = shifted
            self.assertTrue(not torch.any((all_labels[a] - all_labels[p]).bool()))
            self.assertTrue(torch.all((all_labels[p] - all_labels[n]).bool()))
def test_sanity_check(self):
    # Sanity check: CrossBatchMemory with batch_size == memory_size must be
    # equivalent to applying the inner loss function directly to the batch.
    for dtype in TEST_DTYPES:
        for use_enqueue_idx in (False, True):
            for memory_size in range(20, 40, 5):
                inner_loss = NTXentLoss(temperature=0.1)
                inner_miner = TripletMarginMiner(margin=0.1)
                loss = CrossBatchMemory(
                    loss=inner_loss,
                    embedding_size=self.embedding_size,
                    memory_size=memory_size,
                )
                loss_with_miner = CrossBatchMemory(
                    loss=inner_loss,
                    embedding_size=self.embedding_size,
                    memory_size=memory_size,
                    miner=inner_miner,
                )
                for _ in range(10):
                    if use_enqueue_idx:
                        # Second half of the batch is enqueued into memory;
                        # first half acts as the query set.
                        enqueue_idx = torch.arange(memory_size, memory_size * 2)
                        not_enqueue_idx = torch.arange(memory_size)
                        batch_size = memory_size * 2
                    else:
                        enqueue_idx = None
                        batch_size = memory_size
                    embeddings = (
                        torch.randn(batch_size, self.embedding_size)
                        .to(TEST_DEVICE)
                        .type(dtype)
                    )
                    labels = torch.randint(0, 4, (batch_size,)).to(TEST_DEVICE)

                    # --- no miner: memory loss vs. inner loss directly ---
                    if use_enqueue_idx:
                        pairs = lmu.get_all_pairs_indices(
                            labels[not_enqueue_idx], labels[enqueue_idx]
                        )
                        pairs = c_f.shift_indices_tuple(pairs, memory_size)
                        expected = inner_loss(embeddings, labels, pairs)
                    else:
                        expected = inner_loss(embeddings, labels)
                    actual = loss(embeddings, labels, enqueue_idx=enqueue_idx)
                    self.assertTrue(torch.isclose(expected, actual))

                    # --- with miner: mined triplets must match as well ---
                    if use_enqueue_idx:
                        triplets = inner_miner(
                            embeddings[not_enqueue_idx],
                            labels[not_enqueue_idx],
                            embeddings[enqueue_idx],
                            labels[enqueue_idx],
                        )
                        triplets = c_f.shift_indices_tuple(triplets, memory_size)
                        expected = inner_loss(embeddings, labels, triplets)
                    else:
                        triplets = inner_miner(embeddings, labels)
                        expected = inner_loss(embeddings, labels, triplets)
                    actual = loss_with_miner(
                        embeddings, labels, enqueue_idx=enqueue_idx
                    )
                    self.assertTrue(torch.isclose(expected, actual))
def single_process_function(
    rank,
    world_size,
    lr,
    model,
    inputs,
    labels,
    loss_fn,
    miner_fn,
    original_model,
    original_loss_fn,
    original_miner_fn,
    correct_loss,
    correct_indices_tuple,
    is_tuple_loss,
    ref_outputs,
    ref_labels,
):
    """Worker entry point run once per spawned process (one per GPU rank).

    Wraps the model/loss/miner in their distributed wrappers, runs one
    forward/backward/step on this rank's shard of the batch, and asserts
    that the distributed computation matches the single-process reference
    values (`correct_loss`, `correct_indices_tuple`, `original_*`) computed
    by the caller on the full batch.

    NOTE(review): reconstructed from whitespace-mangled source — the
    `loss_optimizer` lines are placed inside the `else` branch because
    `loss_optimizer.step()` is only reached under `not is_tuple_loss`.
    """
    setup(rank, world_size)
    device = torch.device("cuda:{}".format(rank))
    ddp_mp_model = DDP(model.to(device), device_ids=[rank], output_device=rank)
    if is_tuple_loss:
        # Parameterless loss: no optimizer needed for it.
        loss_fn = distributed.DistributedLossWrapper(loss=loss_fn)
    else:
        # Loss with learnable parameters: wrap for DDP and give it its own optimizer.
        loss_fn = distributed.DistributedLossWrapper(
            loss=loss_fn.to(device), device_ids=[rank], output_device=rank
        )
        loss_optimizer = optim.SGD(loss_fn.parameters(), lr=lr)
        loss_optimizer.zero_grad()
    miner_fn = distributed.DistributedMinerWrapper(miner=miner_fn)
    optimizer = optim.SGD(ddp_mp_model.parameters(), lr=lr)
    optimizer.zero_grad()
    outputs = ddp_mp_model(inputs[rank].to(device))
    if ref_outputs is not None:
        # Reference-embedding path: mine between batch and ref embeddings,
        # then shift ref indices past the gathered batch embeddings.
        ref_outputs[rank] = ref_outputs[rank].to(device)
        indices_tuple = miner_fn(
            outputs, labels[rank], ref_outputs[rank], ref_labels[rank]
        )
        indices_tuple = c_f.shift_indices_tuple(
            indices_tuple, len(outputs) * world_size
        )
        loss = loss_fn(
            [outputs, ref_outputs[rank]],
            [labels[rank], ref_labels[rank]],
            indices_tuple,
        )
    else:
        indices_tuple = miner_fn(outputs, labels[rank])
        loss = loss_fn(outputs, labels[rank], indices_tuple)
    if is_tuple_loss:
        # Reducer bookkeeping must match the single-process reference.
        pos_loss_size = loss_fn.loss.reducer.reducers["pos_loss"].losses_size
        neg_loss_size = loss_fn.loss.reducer.reducers["neg_loss"].losses_size
        correct_pos_loss_size = original_loss_fn.reducer.reducers[
            "pos_loss"
        ].losses_size
        correct_neg_loss_size = original_loss_fn.reducer.reducers[
            "neg_loss"
        ].losses_size
        assert pos_loss_size == correct_pos_loss_size
        assert neg_loss_size == correct_neg_loss_size
    else:
        loss_size = loss_fn.loss.module.reducer.losses_size
        correct_loss_size = original_loss_fn.reducer.losses_size
        assert loss_size == correct_loss_size
    # Distributed loss value must equal the full-batch reference loss.
    assert torch.isclose(loss, torch.from_numpy(correct_loss).to(device))
    assert miner_fn.miner.num_pos_pairs == original_miner_fn.num_pos_pairs
    assert miner_fn.miner.num_neg_pairs == original_miner_fn.num_neg_pairs
    for i in range(len(correct_indices_tuple)):
        assert torch.all(
            indices_tuple[i]
            == (torch.from_numpy(correct_indices_tuple[i]).to(device))
        )
    dist.barrier()
    loss.backward()
    original_model = original_model.to(device)
    # Before optimizer.step(), DDP replica must differ from the
    # already-stepped reference model; after stepping, they must match.
    assert not parameters_are_equal(original_model, ddp_mp_model.module)
    dist.barrier()
    optimizer.step()
    dist.barrier()
    assert parameters_are_equal(original_model, ddp_mp_model.module)
    if not is_tuple_loss:
        # Same before/after check for the loss function's own parameters.
        original_loss_fn = original_loss_fn.to(device)
        assert not parameters_are_equal(original_loss_fn, loss_fn.loss.module)
        dist.barrier()
        loss_optimizer.step()
        dist.barrier()
        assert parameters_are_equal(original_loss_fn, loss_fn.loss.module)
    dist.barrier()
    cleanup()
def loss_and_miner_tester(
    self, loss_class, miner_class, is_tuple_loss, test_ref_emb=False
):
    """Driver: compare distributed training against a single-process reference.

    For each dtype and world size, computes the reference loss, gradients,
    and mined indices on the full (concatenated) batch, applies one
    optimizer step to the reference model, then spawns `world_size`
    processes running `single_process_function`, each of which asserts its
    distributed results match the reference values passed in.

    NOTE(review): reconstructed from whitespace-mangled source — the
    `loss_optimizer` construction is placed inside the
    `if not is_tuple_loss:` guard, since SGD on an empty parameter list
    raises and `loss_optimizer.step()` is only used under that guard.
    """
    for dtype in TEST_DTYPES:
        for world_size in range(1, 5):
            batch_size = 20
            lr = 1
            # One input/label shard per rank.
            inputs = [
                torch.randn(batch_size, 10).type(dtype) for _ in range(world_size)
            ]
            labels = [
                torch.randint(low=0, high=2, size=(batch_size,))
                for _ in range(world_size)
            ]
            # `model` starts as an exact copy of `original_model` so the
            # distributed run and the reference run begin identically.
            original_model = ToyMpModel().type(dtype)
            model = ToyMpModel().type(dtype)
            model.load_state_dict(original_model.state_dict())
            original_model = original_model.to(self.device)
            original_loss_fn = self.create_loss(loss_class, is_tuple_loss, dtype)
            loss_fn = self.create_loss(loss_class, is_tuple_loss, dtype)
            if not is_tuple_loss:
                # Loss has learnable parameters: sync copies and give the
                # reference its own optimizer.
                loss_fn.load_state_dict(original_loss_fn.state_dict())
                assert parameters_are_equal(original_loss_fn, loss_fn)
                original_loss_fn = original_loss_fn.to(self.device)
                loss_optimizer = optim.SGD(original_loss_fn.parameters(), lr=lr)
                loss_optimizer.zero_grad()
            original_miner_fn = miner_class()
            miner_fn = miner_class()
            optimizer = optim.SGD(original_model.parameters(), lr=lr)
            optimizer.zero_grad()
            # Reference forward pass over the full concatenated batch.
            all_inputs = torch.cat(inputs, dim=0).to(self.device)
            all_labels = torch.cat(labels, dim=0).to(self.device)
            all_outputs = original_model(all_inputs)
            if test_ref_emb:
                # Detached reference embeddings (not produced by the model).
                ref_outputs = [
                    torch.randn(batch_size, 5).type(dtype).detach()
                    for _ in range(world_size)
                ]
                ref_labels = [
                    torch.randint(low=0, high=2, size=(batch_size,))
                    for _ in range(world_size)
                ]
                all_ref_outputs = torch.cat(ref_outputs, dim=0).to(self.device)
                all_ref_labels = torch.cat(ref_labels, dim=0).to(self.device)
                correct_indices_tuple = original_miner_fn(
                    all_outputs, all_labels, all_ref_outputs, all_ref_labels
                )
                # Shift ref indices so they address the concatenated
                # [outputs, ref_outputs] tensors below.
                correct_indices_tuple = c_f.shift_indices_tuple(
                    correct_indices_tuple, len(all_outputs)
                )
                all_outputs = torch.cat(
                    [all_outputs, all_ref_outputs], dim=0
                ).to(self.device)
                all_labels = torch.cat(
                    [all_labels, all_ref_labels], dim=0
                ).to(self.device)
            else:
                ref_outputs, ref_labels = None, None
                correct_indices_tuple = original_miner_fn(all_outputs, all_labels)
            correct_loss = original_loss_fn(
                all_outputs, all_labels, correct_indices_tuple
            )
            # DDP averages gradients across ranks, so divide the reference
            # loss by world_size to get matching model gradients.
            (correct_loss / world_size).backward(retain_graph=True)
            optimizer.step()
            if not is_tuple_loss:
                for p in original_loss_fn.parameters():
                    # Each replica loss function sees gradients from the global batch
                    p.grad *= world_size
                loss_optimizer.step()
            mp.spawn(
                single_process_function,
                args=(
                    world_size,
                    lr,
                    model,
                    inputs,
                    labels,
                    loss_fn,
                    miner_fn,
                    original_model,
                    original_loss_fn,
                    original_miner_fn,
                    correct_loss.detach().cpu().numpy(),
                    tuple([x.cpu().numpy() for x in correct_indices_tuple]),
                    is_tuple_loss,
                    ref_outputs,
                    ref_labels,
                ),
                nprocs=world_size,
                join=True,
            )