    def test_distance_weighted_miner(self, with_ref_labels=False):
        for dtype in TEST_DTYPES:
            embedding_angles = torch.arange(0, 256)
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=dtype,
            ).to(self.device)  # 2D embeddings
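            # angle_to_coord places each angle on the unit circle, so the 2D
            # embeddings have known, well-spread pairwise distances.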
            ref_embeddings = embeddings.clone() if with_ref_labels else None
            labels = torch.randint(low=0, high=2, size=(256,))
            ref_labels = torch.randint(low=0, high=2, size=(256,)) if with_ref_labels else None

            a, _, n = lmu.get_all_triplets_indices(labels, ref_labels)
            if with_ref_labels:
                all_an_dist = torch.nn.functional.pairwise_distance(embeddings[a], ref_embeddings[n], 2)
            else:
                all_an_dist = torch.nn.functional.pairwise_distance(embeddings[a], embeddings[n], 2)
            min_an_dist = torch.min(all_an_dist)
            
            for non_zero_cutoff_int in range(5, 15):
                non_zero_cutoff = (float(non_zero_cutoff_int) / 10.) - 0.01
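                # The sweep covers cutoffs 0.49..1.39; the miner should return
                # anchor-negative distances roughly uniform below each cutoff.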
                miner = DistanceWeightedMiner(0, non_zero_cutoff)
                a, p, n = miner(embeddings, labels, ref_embeddings, ref_labels)
                if with_ref_labels:
                    anchors, positives, negatives = embeddings[a], ref_embeddings[p], ref_embeddings[n]
                else:
                    anchors, positives, negatives = embeddings[a], embeddings[p], embeddings[n]
                an_dist = torch.nn.functional.pairwise_distance(anchors, negatives, 2)
                self.assertTrue(torch.max(an_dist) <= non_zero_cutoff)
                an_dist_var = torch.var(an_dist)
                an_dist_mean = torch.mean(an_dist)
                target_var = ((non_zero_cutoff - min_an_dist)**2) / 12 # variance formula for uniform distribution
                target_mean = (non_zero_cutoff - min_an_dist) / 2
                self.assertTrue(torch.abs(an_dist_var - target_var) / target_var < 0.1)
                self.assertTrue(torch.abs(an_dist_mean - target_mean) / target_mean < 0.1)
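            # Closed-form targets used above: U(a, b) has mean (a + b) / 2 and
            # var (b - a)^2 / 12. min_an_dist is near zero for this unit-circle
            # layout, so (b - a) / 2 approximates the mean well.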
    def test_get_all_pairs_triplets_indices(self):
        original_x = torch.arange(10)

        for i in range(1, 11):
            x = original_x.repeat(i)
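            # Each of the 10 labels appears i times, so every element has
            # (i - 1) positives and (len(x) - i) negatives.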
            correct_num_pos = len(x) * (i - 1)
            correct_num_neg = len(x) * (len(x) - i)
            a1, p, a2, n = lmu.get_all_pairs_indices(x)
            self.assertTrue(len(a1) == len(p) == correct_num_pos)
            self.assertTrue(len(a2) == len(n) == correct_num_neg)

            correct_num_triplets = len(x) * (i - 1) * (len(x) - i)
            a, p, n = lmu.get_all_triplets_indices(x)
            self.assertTrue(len(a) == len(p) == len(n) == correct_num_triplets)
from functools import partial

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

# Project-local helpers assumed importable from the surrounding package:
# PadResize, bbox_collate_fn, create_datasetBB, prepare_evaluation_dataloadersBB,
# MetricBatchSampler, create_models, RAdam, get_all_triplets_indices,
# evaluate, sim_func, sim_func_pair, and logger.


def train_eval(args, train_data, dev_data, positions):
    _bbox_collate_fn = partial(bbox_collate_fn, max_bb_num=args.bb_num)

    # Create dataset & dataloader
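    # Train-time augmentation; RandomErasing operates on tensors, so it must
    # come after ToTensor in the pipeline.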
    trans = [
        PadResize(224),
        transforms.RandomRotation(degrees=args.aug_rot),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
        transforms.RandomErasing(p=args.aug_erase_p,
                                 scale=(args.aug_erase_min,
                                        args.aug_erase_max))
    ]
    trans = transforms.Compose(trans)
    dev_trans = [
        PadResize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]
    dev_trans = transforms.Compose(dev_trans)

    train_dataset, train_char_idx = \
        create_datasetBB(args.root, train_data, positions,
            post_crop_transform=trans,
            collate_fn_bbox=_bbox_collate_fn,
            bbox_scale=args.bb_scale)

    train_sampler = MetricBatchSampler(train_dataset,
                                       train_char_idx,
                                       n_max_per_char=args.n_max_per_char,
                                       n_batch_size=args.n_batch_size,
                                       n_random=args.n_random)
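    # With a batch_sampler, DataLoader requires batch_size to stay at its
    # default of 1; batch composition is controlled entirely by the sampler.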
    train_dataloader = DataLoader(train_dataset,
                                  batch_sampler=train_sampler,
                                  batch_size=1,
                                  num_workers=5)

    eval_train_dataloaders = \
        prepare_evaluation_dataloadersBB(args, args.eval_split*3, train_data, positions,
            post_crop_transform=dev_trans,
            collate_fn_bbox=_bbox_collate_fn,
            bbox_scale=args.bb_scale
        )
    eval_dev_dataloaders = \
        prepare_evaluation_dataloadersBB(args, args.eval_split, dev_data, positions,
            post_crop_transform=dev_trans,
            collate_fn_bbox=_bbox_collate_fn,
            bbox_scale=args.bb_scale
        )

    # Construct model & optimizer
    device = "cpu" if args.gpu < 0 else "cuda:{}".format(args.gpu)

    trunk, model = create_models(args.emb_dim, args.dropout)
    trunk.to(device)
    model.to(device)

    if args.optimizer == "SGD":
        optimizer = torch.optim.SGD(list(trunk.parameters()) +
                                    list(model.parameters()),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.decay)
    elif args.optimizer == "Adam":
        optimizer = torch.optim.Adam(list(trunk.parameters()) +
                                     list(model.parameters()),
                                     lr=args.lr,
                                     weight_decay=args.decay)
    elif args.optimizer == "RAdam":
        optimizer = RAdam(list(trunk.parameters()) + list(model.parameters()),
                          lr=args.lr,
                          weight_decay=args.decay)
    else:
        raise ValueError(f"Unsupported optimizer: {args.optimizer}")

    def lr_func(step):
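        # Linear warmup, then stepwise decay. Example with warmup=500,
        # decay_freq=10000, decay_factor=2: the scale ramps linearly to 1.0
        # over the first 500 steps, stays at 1.0 through step 9999, then
        # drops to 1/2 at step 10000, 1/4 at step 20000, and so on.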
        if step < args.warmup:
            return (step + 1) / args.warmup
        else:
            steps_decay = step // args.decay_freq
            return 1 / args.decay_factor**steps_decay

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_func)
    if args.optimizer == "RAdam":
        # RAdam rectifies the adaptive learning rate itself, so the external
        # warmup/decay schedule is skipped here.
        scheduler = None

    best_dev_eer = 1.0  # EER lies in [0, 1]; initialize to the worst value
    for i_epoch in range(args.epoch):
        logger.info(f"EPOCH: {i_epoch}")

        bar = tqdm(total=len(train_dataloader), smoothing=0.0)
        # Make sure dropout is active; evaluate() at the end of a previous
        # epoch may have left the modules in eval mode.
        trunk.train()
        model.train()
        for (img, mask), labels in train_dataloader:
            optimizer.zero_grad()

            img, mask = img.to(device), mask.to(device)
            embedding = model(trunk([img, mask]))

            a_idx, p_idx, n_idx = get_all_triplets_indices(labels)
            if a_idx.size(0) == 0:
                logger.info("Zero triplet. Skip.")
                continue
            anchors, positives, negatives = embedding[a_idx], embedding[
                p_idx], embedding[n_idx]
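            # sim_func returns similarities; negate to get distance-like scores
            # so the usual margin term relu(d(a, p) - d(a, n) + margin) applies.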
            a_p_dist = -sim_func(anchors, positives)
            a_n_dist = -sim_func(anchors, negatives)

            dist = a_p_dist - a_n_dist
            loss_modified = dist + args.margin
            relued = torch.nn.functional.relu(loss_modified)
            num_non_zero_triplets = (relued > 0).nonzero().size(0)
            if num_non_zero_triplets > 0:
                loss = torch.sum(relued) / num_non_zero_triplets
                loss.backward()
                optimizer.step()

            # lr_func counts optimizer steps, so the schedule advances once
            # per batch rather than once per epoch.
            if scheduler is not None:
                scheduler.step()
            bar.update()
        bar.close()

        if i_epoch % args.eval_freq == 0:
            train_eer, train_eer_std = evaluate(args,
                                                trunk,
                                                model,
                                                eval_train_dataloaders,
                                                sim_func=sim_func_pair)
            dev_eer, dev_eer_std = evaluate(args,
                                            trunk,
                                            model,
                                            eval_dev_dataloaders,
                                            sim_func=sim_func_pair)
            logger.info("Train EER (mean, std):\t{}\t{}".format(
                train_eer, train_eer_std))
            logger.info("Eval EER (mean, std):\t{}\t{}".format(
                dev_eer, dev_eer_std))
            if dev_eer < best_dev_eer:
                logger.info("New best model!")
                best_dev_eer = dev_eer

                if args.save_model:
                    save_models = {
                        "trunk": trunk.state_dict(),
                        "embedder": model.state_dict(),
                        "args": [args.emb_dim, args.dropout]
                    }
                    torch.save(save_models, f"model/{args.suffix}.mdl")

    return best_dev_eer
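

# The hand-rolled loss above (mean over the non-zero relu terms) is what
# pytorch-metric-learning's TripletMarginLoss computes with its default
# AvgNonZeroReducer. A minimal self-contained sketch of that equivalence,
# assuming the library is installed; CosineSimilarity stands in for sim_func.
def _pml_triplet_loss_sketch():
    import torch
    from pytorch_metric_learning.distances import CosineSimilarity
    from pytorch_metric_learning.losses import TripletMarginLoss

    embeddings = torch.randn(32, 128)
    labels = torch.randint(0, 4, (32,))
    # Mines all triplets internally and averages over the violating ones.
    loss_fn = TripletMarginLoss(margin=0.1, distance=CosineSimilarity())
    return loss_fn(embeddings, labels)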
    def test_per_anchor_reducer(self):
        for inner_reducer in [MeanReducer(), AvgNonZeroReducer()]:
            reducer = PerAnchorReducer(inner_reducer)
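            # PerAnchorReducer first averages pair losses per anchor, producing
            # one value per element, then applies the wrapped inner reducer.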
            batch_size = 100
            embedding_size = 64
            for dtype in TEST_DTYPES:
                embeddings = (
                    torch.randn(batch_size, embedding_size).type(dtype).to(TEST_DEVICE)
                )
                labels = torch.randint(0, 10, (batch_size,))
                # get_all_pairs_indices returns (a1, p, a2, n); split it into
                # positive and negative pair index tuples.
                all_pairs = lmu.get_all_pairs_indices(labels)
                pos_pair_indices = all_pairs[:2]
                neg_pair_indices = all_pairs[2:]
                triplet_indices = lmu.get_all_triplets_indices(labels)

                for indices, reduction_type in [
                    (torch.arange(batch_size), "element"),
                    (pos_pair_indices, "pos_pair"),
                    (neg_pair_indices, "neg_pair"),
                    (triplet_indices, "triplet"),
                ]:
                    loss_size = (
                        len(indices) if reduction_type == "element" else len(indices[0])
                    )
                    losses = torch.randn(loss_size).type(dtype).to(TEST_DEVICE)
                    loss_dict = {
                        "loss": {
                            "losses": losses,
                            "indices": indices,
                            "reduction_type": reduction_type,
                        }
                    }
                    if reduction_type == "triplet":
                        self.assertRaises(
                            NotImplementedError,
                            lambda: reducer(loss_dict, embeddings, labels),
                        )
                        continue

                    output = reducer(loss_dict, embeddings, labels)
                    if reduction_type == "element":
                        loss_dict = {
                            "loss": {
                                "losses": losses,
                                "indices": c_f.torch_arange_from_size(embeddings),
                                "reduction_type": "element",
                            }
                        }
                    else:
                        anchors = indices[0]
                        correct_output = torch.zeros(
                            batch_size, device=TEST_DEVICE, dtype=dtype
                        )
                        for i in range(len(embeddings)):
                            matching_pairs_mask = anchors == i
                            num_matching_pairs = torch.sum(matching_pairs_mask)
                            if num_matching_pairs > 0:
                                correct_output[i] = (
                                    torch.sum(losses[matching_pairs_mask])
                                    / num_matching_pairs
                                )
                        loss_dict = {
                            "loss": {
                                "losses": correct_output,
                                "indices": c_f.torch_arange_from_size(embeddings),
                                "reduction_type": "element",
                            }
                        }
                    correct_output = inner_reducer(loss_dict, embeddings, labels)
                    self.assertTrue(torch.isclose(output, correct_output))
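

# A small usage sketch for PerAnchorReducer, assuming pytorch-metric-learning
# is installed: pair losses from a pair-based loss are averaged per anchor
# before the final reduction, de-emphasizing anchors that form many pairs.
def _per_anchor_reducer_sketch():
    import torch
    from pytorch_metric_learning.losses import ContrastiveLoss
    from pytorch_metric_learning.reducers import PerAnchorReducer

    embeddings = torch.randn(16, 8)
    labels = torch.randint(0, 3, (16,))
    loss_fn = ContrastiveLoss(reducer=PerAnchorReducer())
    return loss_fn(embeddings, labels)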