Exemplo n.º 1
0
    def __iter__(self):
        """Yield ``(task, batch)`` pairs forever; this generator never ends.

        A task is drawn at random from ``self.sampling_pools`` whenever
        ``self.step`` is a multiple of ``self.accum_steps``; on every other
        step the most recently selected task is reused.  When the cached
        iterator of a task is exhausted, a new one is created from
        ``self.name2loader`` and stored back into ``self.name2iter``.
        """
        current_task = self.sampling_pools[0]
        while True:
            resample = self.step % self.accum_steps == 0
            if resample:
                current_task = random.choice(self.sampling_pools)
                if self.distributed:
                    # keep every process on the same task
                    # (presumably any_broadcast ships rank 0's pick -- confirm)
                    current_task = any_broadcast(current_task, 0)
            self.step += 1

            try:
                batch = next(self.name2iter[current_task])
            except StopIteration:
                # exhausted: restart from the loader and cache the new
                # iterator only once it has produced a batch
                restarted = iter(self.name2loader[current_task])
                batch = next(restarted)
                self.name2iter[current_task] = restarted

            yield current_task, batch
Exemplo n.º 2
0
Arquivo: itm.py Projeto: zmykevin/UC2
def get_hard_negs(model, loader, hard_negative_num=20):
    """Extract hard negatives in both directions with the current model.

    Every batch of ``loader`` is scored by ``model``; for the batch's text
    the highest-scoring candidate images are kept, and each (score, text)
    pair is also recorded per image so that the highest-scoring texts of
    every image can be selected afterwards.

    Args:
        model: callable as ``model(batch, compute_loss=False)``; it is put in
            eval mode for the duration and switched back to train mode at
            the end.
        loader: iterable of batches exposing ``'gt_txt_id'`` and
            ``'neg_img_ids'``; ``loader.dataset.datasets`` must each provide
            ``all_img_ids``.
        hard_negative_num: maximum number of hard negatives kept per text
            and per image.

    Returns:
        ``(txt2hardimgs, img2hardtxts)``: a dict mapping each text id to its
        hard image ids, and a dict mapping each image id to its hard text
        ids.  ``img2hardtxts`` is only filled on Horovod rank 0; it stays
        empty on the other ranks.
    """
    LOGGER.info("start running hard negative extraction")
    st = time()
    if hvd.rank() == 0:
        pbar = tqdm(total=len(loader))
    else:
        pbar = NoOp()
    model.eval()

    txt2hardimgs = {}
    img_to_score_txts = defaultdict(list)
    # pure scoring pass: no autograd graph is needed
    with torch.no_grad():
        for batch in loader:
            scores = model(batch, compute_loss=False).squeeze(-1)
            txt = batch['gt_txt_id']
            imgs = batch['neg_img_ids']
            # record hard images; clamp k so that a batch with fewer
            # candidates than requested does not make topk raise
            # (scores is used as a 1-D tensor, one entry per image)
            k = min(hard_negative_num, scores.numel())
            hard_indices = scores.topk(k, sorted=False)[1].tolist()
            txt2hardimgs[txt] = [imgs[i] for i in hard_indices]
            # record img2txts; convert once instead of one .item() per image
            score_values = scores.tolist()
            for i, img in enumerate(imgs):
                img_to_score_txts[img].append((score_values[i], txt))
            pbar.update(1)
    pbar.close()

    LOGGER.info("start computing hard texts from images...")
    n_less_neg = 0
    tot_text = 0
    img2hardtxts = {}
    # need to gather hard texts from all GPUs
    all_img_ids = [
        i for dset in loader.dataset.datasets for i in dset.all_img_ids
    ]
    # presumably makes every rank iterate rank 0's id list in the same order
    # so the gathers below line up -- confirm against any_broadcast
    all_img_ids = any_broadcast(all_img_ids, 0)
    for img in all_img_ids:
        score_txts = img_to_score_txts[img]
        # every rank has to take part in the gather, so it must happen
        # before the early exit for non-zero ranks
        gathered = [pair
                    for pairs in all_gather_list(score_txts)
                    for pair in pairs]
        if hvd.rank() != 0:
            # only rank 0 needs to compute
            continue
        # plain comprehensions also cope with an image that received no
        # (score, text) pair on any rank
        scores = [score for score, _ in gathered]
        txts = [text for _, text in gathered]
        tot_text += len(txts)
        if len(txts) < hard_negative_num:
            # not enough negatives: keep everything there is
            img2hardtxts[img] = list(txts)
            n_less_neg += 1
        else:
            hard_indices = torch.tensor(scores).topk(hard_negative_num,
                                                     sorted=False)[1].tolist()
            img2hardtxts[img] = [txts[i] for i in hard_indices]

    n_less_neg = sum(all_gather_list(n_less_neg))
    if n_less_neg:
        LOGGER.info(f"Warning: {n_less_neg} images did not "
                    f"sample enough negatives")
    # guard the average against an empty run (no images at all)
    n_imgs = max(len(img_to_score_txts), 1)
    LOGGER.info(f"hard negative extraction finished "
                f"in {int(time() - st)} seconds "
                f"({tot_text//n_imgs} texts per images)")

    model.train()
    return txt2hardimgs, img2hardtxts