Example #1
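This test exercises c_f.shift_indices_tuple together with CrossBatchMemory: pair and triplet indices mined against the memory bank must have their memory-side entries shifted by batch_size so they index into the concatenation of the current batch and the memory. The method comes from the library's test suite, so its module-level imports are not shown; a plausible reconstruction, assuming TEST_DEVICE and TEST_DTYPES are the suite's device and dtype settings:

import torch
from pytorch_metric_learning.losses import ContrastiveLoss, CrossBatchMemory
from pytorch_metric_learning.miners import PairMarginMiner, TripletMarginMiner
from pytorch_metric_learning.utils import common_functions as c_f
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu
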
    def test_shift_indices_tuple(self):
        for dtype in TEST_DTYPES:
            batch_size = 32
            pair_miner = PairMarginMiner(pos_margin=0, neg_margin=1)
            triplet_miner = TripletMarginMiner(margin=1)
            self.loss = CrossBatchMemory(
                loss=ContrastiveLoss(),
                embedding_size=self.embedding_size,
                memory_size=self.memory_size,
            )
            for i in range(30):
                embeddings = (
                    torch.randn(batch_size, self.embedding_size)
                    .to(TEST_DEVICE)
                    .type(dtype)
                )
                labels = torch.arange(batch_size).to(TEST_DEVICE)
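                # the forward pass enqueues this batch into the memory bank; the returned loss value is not used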
                loss = self.loss(embeddings, labels)
                all_labels = torch.cat([labels, self.loss.label_memory], dim=0)

                indices_tuple = lmu.get_all_pairs_indices(
                    labels, self.loss.label_memory
                )
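                # memory-side indices (positions 1 and 3) are offset by batch_size so they index into torch.cat([labels, label_memory])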
                shifted = c_f.shift_indices_tuple(indices_tuple, batch_size)
                self.assertTrue(torch.equal(indices_tuple[0], shifted[0]))
                self.assertTrue(torch.equal(indices_tuple[2], shifted[2]))
                self.assertTrue(torch.equal(indices_tuple[1], shifted[1] - batch_size))
                self.assertTrue(torch.equal(indices_tuple[3], shifted[3] - batch_size))
                a1, p, a2, n = shifted
                self.assertFalse(torch.any((all_labels[a1] - all_labels[p]).bool()))
                self.assertTrue(torch.all((all_labels[a2] - all_labels[n]).bool()))

                indices_tuple = pair_miner(
                    embeddings,
                    labels,
                    self.loss.embedding_memory,
                    self.loss.label_memory,
                )
                shifted = c_f.shift_indices_tuple(indices_tuple, batch_size)
                self.assertTrue(torch.equal(indices_tuple[0], shifted[0]))
                self.assertTrue(torch.equal(indices_tuple[2], shifted[2]))
                self.assertTrue(torch.equal(indices_tuple[1], shifted[1] - batch_size))
                self.assertTrue(torch.equal(indices_tuple[3], shifted[3] - batch_size))
                a1, p, a2, n = shifted
                self.assertFalse(torch.any((all_labels[a1] - all_labels[p]).bool()))
                self.assertTrue(torch.all((all_labels[a2] - all_labels[n]).bool()))

                indices_tuple = triplet_miner(
                    embeddings,
                    labels,
                    self.loss.embedding_memory,
                    self.loss.label_memory,
                )
                shifted = c_f.shift_indices_tuple(indices_tuple, batch_size)
                self.assertTrue(torch.equal(indices_tuple[0], shifted[0]))
                self.assertTrue(torch.equal(indices_tuple[1], shifted[1] - batch_size))
                self.assertTrue(torch.equal(indices_tuple[2], shifted[2] - batch_size))
                a, p, n = shifted
                self.assertFalse(torch.any((all_labels[a] - all_labels[p]).bool()))
                self.assertTrue(torch.all((all_labels[p] - all_labels[n]).bool()))
Example #2
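A sanity check: CrossBatchMemory whose memory holds exactly one batch (or, with enqueue_idx, whose enqueued half exactly fills the memory) should behave like the inner loss called directly. Assumed imports, following the same conventions as Example #1:

import torch
from pytorch_metric_learning.losses import CrossBatchMemory, NTXentLoss
from pytorch_metric_learning.miners import TripletMarginMiner
from pytorch_metric_learning.utils import common_functions as c_f
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu
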
    def test_sanity_check(self):
        # CrossBatchMemory with batch_size == memory_size should be equivalent to using the inner loss function directly
        for dtype in TEST_DTYPES:
            for test_enqueue_idx in [False, True]:
                for memory_size in range(20, 40, 5):
                    inner_loss = NTXentLoss(temperature=0.1)
                    inner_miner = TripletMarginMiner(margin=0.1)
                    loss = CrossBatchMemory(
                        loss=inner_loss,
                        embedding_size=self.embedding_size,
                        memory_size=memory_size,
                    )
                    loss_with_miner = CrossBatchMemory(
                        loss=inner_loss,
                        embedding_size=self.embedding_size,
                        memory_size=memory_size,
                        miner=inner_miner,
                    )
                    for i in range(10):
                        if test_enqueue_idx:
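                            # the second half of the batch is enqueued into memory; the first half serves only as queries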
                            enqueue_idx = torch.arange(memory_size, memory_size * 2)
                            not_enqueue_idx = torch.arange(memory_size)
                            batch_size = memory_size * 2
                        else:
                            enqueue_idx = None
                            batch_size = memory_size
                        embeddings = (
                            torch.randn(batch_size, self.embedding_size)
                            .to(TEST_DEVICE)
                            .type(dtype)
                        )
                        labels = torch.randint(0, 4, (batch_size,)).to(TEST_DEVICE)

                        if test_enqueue_idx:
                            pairs = lmu.get_all_pairs_indices(
                                labels[not_enqueue_idx], labels[enqueue_idx]
                            )
                            pairs = c_f.shift_indices_tuple(pairs, memory_size)
                            inner_loss_val = inner_loss(embeddings, labels, pairs)
                        else:
                            inner_loss_val = inner_loss(embeddings, labels)
                        loss_val = loss(embeddings, labels, enqueue_idx=enqueue_idx)
                        self.assertTrue(torch.isclose(inner_loss_val, loss_val))

                        if test_enqueue_idx:
                            triplets = inner_miner(
                                embeddings[not_enqueue_idx],
                                labels[not_enqueue_idx],
                                embeddings[enqueue_idx],
                                labels[enqueue_idx],
                            )
                            triplets = c_f.shift_indices_tuple(triplets, memory_size)
                            inner_loss_val = inner_loss(embeddings, labels, triplets)
                        else:
                            triplets = inner_miner(embeddings, labels)
                            inner_loss_val = inner_loss(embeddings, labels, triplets)
                        loss_val = loss_with_miner(
                            embeddings, labels, enqueue_idx=enqueue_idx
                        )
                        self.assertTrue(torch.isclose(inner_loss_val, loss_val))
Example #3
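This is the per-process worker for a DDP test: each rank wraps the model, loss, and miner in the library's distributed wrappers, then checks that the resulting loss, mined indices, and parameter updates match the single-process reference values computed by the caller (Example #4). setup, cleanup, and parameters_are_equal are helpers defined elsewhere in the test file; the imports below are an assumed reconstruction:

import torch
import torch.distributed as dist
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from pytorch_metric_learning.utils import common_functions as c_f
from pytorch_metric_learning.utils import distributed
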
def single_process_function(
    rank,
    world_size,
    lr,
    model,
    inputs,
    labels,
    loss_fn,
    miner_fn,
    original_model,
    original_loss_fn,
    original_miner_fn,
    correct_loss,
    correct_indices_tuple,
    is_tuple_loss,
    ref_outputs,
    ref_labels,
):
    setup(rank, world_size)
    device = torch.device("cuda:{}".format(rank))

    ddp_mp_model = DDP(model.to(device), device_ids=[rank], output_device=rank)

    if is_tuple_loss:
        loss_fn = distributed.DistributedLossWrapper(loss=loss_fn)
    else:
        loss_fn = distributed.DistributedLossWrapper(loss=loss_fn.to(device),
                                                     device_ids=[rank],
                                                     output_device=rank)
        loss_optimizer = optim.SGD(loss_fn.parameters(), lr=lr)
        loss_optimizer.zero_grad()

    miner_fn = distributed.DistributedMinerWrapper(miner=miner_fn)

    optimizer = optim.SGD(ddp_mp_model.parameters(), lr=lr)
    optimizer.zero_grad()
    outputs = ddp_mp_model(inputs[rank].to(device))

    if ref_outputs is not None:
        ref_outputs[rank] = ref_outputs[rank].to(device)
        indices_tuple = miner_fn(outputs, labels[rank], ref_outputs[rank],
                                 ref_labels[rank])
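    # the miner wrapper gathers outputs from every rank, so the gathered query
    # set has len(outputs) * world_size rows; ref indices are shifted past it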
        indices_tuple = c_f.shift_indices_tuple(indices_tuple,
                                                len(outputs) * world_size)
        loss = loss_fn(
            [outputs, ref_outputs[rank]],
            [labels[rank], ref_labels[rank]],
            indices_tuple,
        )
    else:
        indices_tuple = miner_fn(outputs, labels[rank])
        loss = loss_fn(outputs, labels[rank], indices_tuple)

    if is_tuple_loss:
        pos_loss_size = loss_fn.loss.reducer.reducers["pos_loss"].losses_size
        neg_loss_size = loss_fn.loss.reducer.reducers["neg_loss"].losses_size
        correct_pos_loss_size = original_loss_fn.reducer.reducers[
            "pos_loss"].losses_size
        correct_neg_loss_size = original_loss_fn.reducer.reducers[
            "neg_loss"].losses_size
        assert pos_loss_size == correct_pos_loss_size
        assert neg_loss_size == correct_neg_loss_size
    else:
        loss_size = loss_fn.loss.module.reducer.losses_size
        correct_loss_size = original_loss_fn.reducer.losses_size
        assert loss_size == correct_loss_size

    assert torch.isclose(loss, torch.from_numpy(correct_loss).to(device))
    assert miner_fn.miner.num_pos_pairs == original_miner_fn.num_pos_pairs
    assert miner_fn.miner.num_neg_pairs == original_miner_fn.num_neg_pairs
    for i in range(len(correct_indices_tuple)):
        assert torch.all(indices_tuple[i] == (
            torch.from_numpy(correct_indices_tuple[i]).to(device)))

    dist.barrier()
    loss.backward()

    original_model = original_model.to(device)
    assert not parameters_are_equal(original_model, ddp_mp_model.module)
    dist.barrier()
    optimizer.step()
    dist.barrier()
    assert parameters_are_equal(original_model, ddp_mp_model.module)

    if not is_tuple_loss:
        original_loss_fn = original_loss_fn.to(device)
        assert not parameters_are_equal(original_loss_fn, loss_fn.loss.module)
        dist.barrier()
        loss_optimizer.step()
        dist.barrier()
        assert parameters_are_equal(original_loss_fn, loss_fn.loss.module)

    dist.barrier()
    cleanup()
Example #4
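The driver for Example #3: it computes the reference loss, mined indices, and parameter updates in a single process, then spawns world_size workers and lets single_process_function verify that the distributed wrappers reproduce them. ToyMpModel is a small test-suite model; assumed imports beyond those listed for Example #3:

import torch.multiprocessing as mp
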
    def loss_and_miner_tester(self,
                              loss_class,
                              miner_class,
                              is_tuple_loss,
                              test_ref_emb=False):
        for dtype in TEST_DTYPES:
            for world_size in range(1, 5):
                batch_size = 20
                lr = 1
                inputs = [
                    torch.randn(batch_size, 10).type(dtype)
                    for _ in range(world_size)
                ]
                labels = [
                    torch.randint(low=0, high=2, size=(batch_size, ))
                    for _ in range(world_size)
                ]
                original_model = ToyMpModel().type(dtype)
                model = ToyMpModel().type(dtype)
                model.load_state_dict(original_model.state_dict())

                original_model = original_model.to(self.device)
                original_loss_fn = self.create_loss(loss_class, is_tuple_loss,
                                                    dtype)
                loss_fn = self.create_loss(loss_class, is_tuple_loss, dtype)
                if not is_tuple_loss:
                    loss_fn.load_state_dict(original_loss_fn.state_dict())
                    assert parameters_are_equal(original_loss_fn, loss_fn)
                    original_loss_fn = original_loss_fn.to(self.device)
                    loss_optimizer = optim.SGD(original_loss_fn.parameters(),
                                               lr=lr)
                    loss_optimizer.zero_grad()

                original_miner_fn = miner_class()
                miner_fn = miner_class()

                optimizer = optim.SGD(original_model.parameters(), lr=lr)
                optimizer.zero_grad()
                all_inputs = torch.cat(inputs, dim=0).to(self.device)
                all_labels = torch.cat(labels, dim=0).to(self.device)
                all_outputs = original_model(all_inputs)
                if test_ref_emb:
                    ref_outputs = [
                        torch.randn(batch_size, 5).type(dtype).detach()
                        for _ in range(world_size)
                    ]
                    ref_labels = [
                        torch.randint(low=0, high=2, size=(batch_size, ))
                        for _ in range(world_size)
                    ]
                    all_ref_outputs = torch.cat(ref_outputs,
                                                dim=0).to(self.device)
                    all_ref_labels = torch.cat(ref_labels,
                                               dim=0).to(self.device)
                    correct_indices_tuple = original_miner_fn(
                        all_outputs, all_labels, all_ref_outputs,
                        all_ref_labels)
                    correct_indices_tuple = c_f.shift_indices_tuple(
                        correct_indices_tuple, len(all_outputs))
                    all_outputs = torch.cat([all_outputs, all_ref_outputs],
                                            dim=0).to(self.device)
                    all_labels = torch.cat([all_labels, all_ref_labels],
                                           dim=0).to(self.device)
                else:
                    ref_outputs, ref_labels = None, None
                    correct_indices_tuple = original_miner_fn(
                        all_outputs, all_labels)
                correct_loss = original_loss_fn(all_outputs, all_labels,
                                                correct_indices_tuple)
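                # DDP averages gradients across replicas, so the reference loss is divided
                # by world_size to match the model gradients the spawned workers produce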
                (correct_loss / world_size).backward(retain_graph=True)
                optimizer.step()
                if not is_tuple_loss:
                    for p in original_loss_fn.parameters():
                        # each DDP replica computes the loss over the full gathered batch,
                        # so undo the world_size division applied to the reference loss above
                        p.grad *= world_size
                    loss_optimizer.step()

                mp.spawn(
                    single_process_function,
                    args=(
                        world_size,
                        lr,
                        model,
                        inputs,
                        labels,
                        loss_fn,
                        miner_fn,
                        original_model,
                        original_loss_fn,
                        original_miner_fn,
                        correct_loss.detach().cpu().numpy(),
                        tuple([x.cpu().numpy()
                               for x in correct_indices_tuple]),
                        is_tuple_loss,
                        ref_outputs,
                        ref_labels,
                    ),
                    nprocs=world_size,
                    join=True,
                )
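Note that mp.spawn prepends the process rank to the args tuple, which is why single_process_function takes rank as its first parameter even though rank does not appear in args.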