def _do_biencoder_fwd_pass(
    model: nn.Module,
    input: BiEncoderBatch,
    tensorizer: Tensorizer,
    cfg,
    encoder_type: str,
    rep_positions=0,
    loss_scale: float = None,
) -> Tuple[torch.Tensor, int]:

    input = BiEncoderBatch(**move_to_device(input._asdict(), cfg.device))

    q_attn_mask = tensorizer.get_attn_mask(input.question_ids)
    ctx_attn_mask = tensorizer.get_attn_mask(input.context_ids)

    if model.training:
        model_out = model(
            input.question_ids,
            input.question_segments,
            q_attn_mask,
            input.context_ids,
            input.ctx_segments,
            ctx_attn_mask,
            encoder_type=encoder_type,
            representation_token_pos=rep_positions,
        )
    else:
        with torch.no_grad():
            model_out = model(
                input.question_ids,
                input.question_segments,
                q_attn_mask,
                input.context_ids,
                input.ctx_segments,
                ctx_attn_mask,
                encoder_type=encoder_type,
                representation_token_pos=rep_positions,
            )

    local_q_vector, local_ctx_vectors = model_out

    loss_function = BiEncoderNllLoss()

    loss, is_correct = _calc_loss(
        cfg,
        loss_function,
        local_q_vector,
        local_ctx_vectors,
        input.is_positive,
        input.hard_negatives,
        loss_scale=loss_scale,
    )

    is_correct = is_correct.sum().item()

    if cfg.n_gpu > 1:
        loss = loss.mean()
    if cfg.train.gradient_accumulation_steps > 1:
        loss = loss / cfg.train.gradient_accumulation_steps
    return loss, is_correct
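
A minimal usage sketch (not part of the snippet above) of how this forward pass is typically driven from a training loop. The names `biencoder`, `optimizer`, and `train_batches` are assumptions standing in for objects built elsewhere by the training script:

def train_steps(biencoder, optimizer, tensorizer, cfg, train_batches):
    # Accumulate gradients over several forward passes before stepping,
    # matching the cfg.train.gradient_accumulation_steps scaling above.
    for step, batch in enumerate(train_batches):
        loss, correct_cnt = _do_biencoder_fwd_pass(
            biencoder, batch, tensorizer, cfg, encoder_type=None
        )
        loss.backward()
        if (step + 1) % cfg.train.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()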
Example 2
def _do_biencoder_fwd_pass(model: nn.Module, input: BiEncoderBatch,
                           tensorizer: Tensorizer,
                           args) -> Tuple[torch.Tensor, int]:
    input = BiEncoderBatch(**move_to_device(input._asdict(), args.device))

    q_attn_mask = tensorizer.get_attn_mask(input.question_ids)
    ctx_attn_mask = tensorizer.get_attn_mask(input.context_ids)

    if model.training:
        model_out = model(input.question_ids, input.question_segments,
                          q_attn_mask, input.context_ids, input.ctx_segments,
                          ctx_attn_mask)
    else:
        with torch.no_grad():
            model_out = model(input.question_ids, input.question_segments,
                              q_attn_mask, input.context_ids,
                              input.ctx_segments, ctx_attn_mask)

    local_q_vector, local_ctx_vectors = model_out

    loss_function = BiEncoderNllLoss()

    loss, is_correct = _calc_loss(args, loss_function, local_q_vector,
                                  local_ctx_vectors, input.is_positive,
                                  input.hard_negatives)

    is_correct = is_correct.sum().item()

    if args.n_gpu > 1:
        loss = loss.mean()
    if args.gradient_accumulation_steps > 1:
        loss = loss / args.gradient_accumulation_steps

    return loss, is_correct
Example 3
def gen_ctx_vectors(ctx_rows: List[Tuple[object, str, str]], model: nn.Module, tensorizer: Tensorizer,
                    insert_title: bool = True) -> List[Tuple[object, np.ndarray]]:
    n = len(ctx_rows)
    bsz = args.batch_size  # `args` is expected to be a module-level argparse namespace
    total = 0
    results = []
    for j, batch_start in enumerate(range(0, n, bsz)):

        batch_token_tensors = [tensorizer.text_to_tensor(ctx[1], title=ctx[2] if insert_title else None) for ctx in
                               ctx_rows[batch_start:batch_start + bsz]]

        ctx_ids_batch = torch.stack(batch_token_tensors, dim=0)
        ctx_seg_batch = torch.zeros_like(ctx_ids_batch)
        ctx_attn_mask = tensorizer.get_attn_mask(ctx_ids_batch)
        with torch.no_grad():
            _, out, _ = model(ctx_ids_batch, ctx_seg_batch, ctx_attn_mask)
        out = out.cpu()

        ctx_ids = [r[0] for r in ctx_rows[batch_start:batch_start + bsz]]

        assert len(ctx_ids) == out.size(0)

        total += len(ctx_ids)

        results.extend([
            (ctx_ids[i], out[i].view(-1).numpy())
            for i in range(out.size(0))
        ])

        if total % 10 == 0:
            logger.info('Encoded passages %d', total)

    return results
Example 4
def generate_question_vectors(
    question_encoder: torch.nn.Module,
    tensorizer: Tensorizer,
    questions: List[str],
    bsz: int,
    query_token: str = None,
    selector: RepTokenSelector = None,
) -> T:
    n = len(questions)
    query_vectors = []

    with torch.no_grad():
        for j, batch_start in enumerate(range(0, n, bsz)):
            batch_questions = questions[batch_start:batch_start + bsz]

            if query_token:
                if query_token == "[START_ENT]":
                    batch_token_tensors = [
                        _select_span_with_token(q,
                                                tensorizer,
                                                token_str=query_token)
                        for q in batch_questions
                    ]
                else:
                    batch_token_tensors = [
                        tensorizer.text_to_tensor(" ".join([query_token, q]))
                        for q in batch_questions
                    ]
            else:
                batch_token_tensors = [
                    tensorizer.text_to_tensor(q) for q in batch_questions
                ]

            q_ids_batch = torch.stack(batch_token_tensors, dim=0).cuda()
            q_seg_batch = torch.zeros_like(q_ids_batch).cuda()
            q_attn_mask = tensorizer.get_attn_mask(q_ids_batch)

            if selector:
                rep_positions = selector.get_positions(q_ids_batch, tensorizer)

                _, out, _ = BiEncoder.get_representation(
                    question_encoder,
                    q_ids_batch,
                    q_seg_batch,
                    q_attn_mask,
                    representation_token_pos=rep_positions,
                )
            else:
                _, out, _ = question_encoder(q_ids_batch, q_seg_batch,
                                             q_attn_mask)

            query_vectors.extend(out.cpu().split(1, dim=0))

            if len(query_vectors) % 100 == 0:
                logger.info("Encoded queries %d", len(query_vectors))

    query_tensor = torch.cat(query_vectors, dim=0)
    logger.info("Total encoded queries tensor %s", query_tensor.size())
    assert query_tensor.size(0) == len(questions)
    return query_tensor
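
A hedged usage sketch for the function above. `encoder` and `tensorizer` are assumed to be a loaded DPR question encoder and its matching tensorizer (setup is not shown in this snippet), and a CUDA device is required because the function moves batches to GPU via .cuda():

questions = [
    "who wrote the declaration of independence",
    "what is the capital of france",
]
# Returns an (n_questions, hidden_dim) tensor on CPU.
q_vectors = generate_question_vectors(encoder, tensorizer, questions, bsz=32)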
Example 5
def gen_ctx_vectors(
    cfg: DictConfig,
    ctx_rows: List[Tuple[object, BiEncoderPassage]],
    model: nn.Module,
    tensorizer: Tensorizer,
    insert_title: bool = True,
) -> List[Tuple[object, np.ndarray]]:
    n = len(ctx_rows)
    bsz = cfg.batch_size
    total = 0
    results = []
    for j, batch_start in enumerate(range(0, n, bsz)):
        batch = ctx_rows[batch_start : batch_start + bsz]
        batch_token_tensors = [
            tensorizer.text_to_tensor(
                ctx[1].text, title=ctx[1].title if insert_title else None
            )
            for ctx in batch
        ]

        ctx_ids_batch = move_to_device(
            torch.stack(batch_token_tensors, dim=0), cfg.device
        )
        ctx_seg_batch = move_to_device(torch.zeros_like(ctx_ids_batch), cfg.device)
        ctx_attn_mask = move_to_device(
            tensorizer.get_attn_mask(ctx_ids_batch), cfg.device
        )
        with torch.no_grad():
            _, out, _ = model(ctx_ids_batch, ctx_seg_batch, ctx_attn_mask)
        out = out.cpu()

        ctx_ids = [r[0] for r in batch]
        extra_info = []
        # rows may carry extra trailing fields beyond (id, passage); keep them
        if len(batch[0]) > 3:
            extra_info = [r[3:] for r in batch]

        assert len(ctx_ids) == out.size(0)
        total += len(ctx_ids)

        # TODO: refactor to avoid 'if'
        if extra_info:
            results.extend(
                [
                    (ctx_ids[i], out[i].view(-1).numpy(), *extra_info[i])
                    for i in range(out.size(0))
                ]
            )
        else:
            results.extend(
                [(ctx_ids[i], out[i].view(-1).numpy()) for i in range(out.size(0))]
            )

        if total % 10 == 0:
            logger.info("Encoded passages %d", total)
    return results
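
The (id, vector) pairs returned by gen_ctx_vectors are typically serialized for later indexing. A minimal sketch using pickle; `save_ctx_vectors` and its arguments are illustrative names, not part of the snippet above:

import pickle

def save_ctx_vectors(results, out_file: str):
    # results: List[Tuple[object, np.ndarray]] as produced by gen_ctx_vectors
    with open(out_file, mode="wb") as f:
        pickle.dump(results, f)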
Example 6
def gen_ctx_vectors(
        ctx_rows: List[Tuple[object, str, str]],
        model: nn.Module,
        tensorizer: Tensorizer,
        insert_title: bool = True) -> List[Tuple[object, np.ndarray]]:
    n = len(ctx_rows)
    bsz = args.batch_size
    total = 0
    results = []
    for j, batch_start in enumerate(range(0, n, bsz)):

        all_txt = []
        for ctx in ctx_rows[batch_start:batch_start + bsz]:
            if ctx[2]:
                txt = ['title:', ctx[2], 'context:', ctx[1]]
            else:
                txt = ['context:', ctx[1]]
            txt = ' '.join(txt)
            all_txt.append(txt)
        batch_token_tensors = [
            tensorizer.text_to_tensor(txt, max_length=250) for txt in all_txt
        ]
        #batch_token_tensors = [tensorizer.text_to_tensor(ctx[1], title=ctx[2] if insert_title else None) for ctx in #original
        #                       ctx_rows[batch_start:batch_start + bsz]]                                             #original

        ctx_ids_batch = move_to_device(torch.stack(batch_token_tensors, dim=0),
                                       args.device)
        ctx_seg_batch = move_to_device(torch.zeros_like(ctx_ids_batch),
                                       args.device)
        ctx_attn_mask = move_to_device(tensorizer.get_attn_mask(ctx_ids_batch),
                                       args.device)
        with torch.no_grad():
            _, out, _ = model(ctx_ids_batch, ctx_seg_batch, ctx_attn_mask)
        out = out.cpu()

        ctx_ids = [r[0] for r in ctx_rows[batch_start:batch_start + bsz]]

        assert len(ctx_ids) == out.size(0)

        total += len(ctx_ids)

        #results.extend([
        #    (ctx_ids[i], out[i].view(-1).numpy())
        #    for i in range(out.size(0))
        #])

        results.extend([(ctx_ids[i], out[i].numpy())
                        for i in range(out.size(0))])

        if total % 10 == 0:
            logger.info('Encoded passages %d', total)

    return results
Example 7
    def get_positions(self, input_ids: T, tensorizer: Tensorizer,
                      model: torch.nn.Module):
        attention_masks = tensorizer.get_attn_mask(input_ids)
        rep_positions = []

        for attention_mask in attention_masks:
            if model.training:
                input_length = (attention_mask != 0).sum().item()
                rep_position = random.randint(0, input_length - 1)
                rep_positions.append(rep_position)
            else:
                # Fall back to default
                rep_positions.append(self.static_position)
        # Positions are used as indices; int8 would overflow past position 127,
        # so use long dtype.
        rep_positions = torch.tensor(rep_positions,
                                     dtype=torch.long).unsqueeze(-1).repeat(
                                         1, 2)
        return rep_positions
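
A self-contained illustration of the sampling logic above: for each row, a representation-token position is drawn uniformly from the non-padded span of the attention mask. Plain torch only; no DPR classes are needed:

import random
import torch

# Two sequences of true lengths 4 and 2, right-padded to length 6.
attention_masks = torch.tensor([
    [1, 1, 1, 1, 0, 0],
    [1, 1, 0, 0, 0, 0],
])
positions = []
for mask in attention_masks:
    input_length = int((mask != 0).sum().item())
    positions.append(random.randint(0, input_length - 1))
# Shape (batch, 2), matching the unsqueeze(-1).repeat(1, 2) above.
rep_positions = torch.tensor(positions, dtype=torch.long).unsqueeze(-1).repeat(1, 2)
print(rep_positions.shape)  # torch.Size([2, 2])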
Example 8
def _do_biencoder_fwd_pass(
    model: nn.Module,
    input: BiEncoderBatch,
    tensorizer: Tensorizer,
    loss_function,
    cfg,
    encoder_type: str,
    rep_positions_q=0,
    rep_positions_c=0,
    loss_scale: float = None,
    clustering: bool = False,
) -> Tuple[torch.Tensor, int]:  # NOTE: the clustering / matching branches below return 3-tuples

    input = BiEncoderBatch(**move_to_device(input._asdict(), cfg.device))

    q_attn_mask = tensorizer.get_attn_mask(input.question_ids)
    ctx_attn_mask = tensorizer.get_attn_mask(input.context_ids)

    if model.training:
        model_out = model(
            input.question_ids,
            input.question_segments,
            q_attn_mask,
            input.context_ids,
            input.ctx_segments,
            ctx_attn_mask,
            encoder_type=encoder_type,
            representation_token_pos_q=rep_positions_q,
            representation_token_pos_c=rep_positions_c,
        )
    else:
        with torch.no_grad():
            model_out = model(
                input.question_ids,
                input.question_segments,
                q_attn_mask,
                input.context_ids,
                input.ctx_segments,
                ctx_attn_mask,
                encoder_type=encoder_type,
                representation_token_pos_q=rep_positions_q,
                representation_token_pos_c=rep_positions_c,
            )

    local_q_vector, local_ctx_vectors = model_out

    if cfg.others.is_matching:  # MatchBiEncoder model
        loss, ml_is_correct, matching_is_correct = _calc_loss_matching(
            cfg,
            model,
            loss_function,
            local_q_vector,
            local_ctx_vectors,
            input.is_positive,
            input.hard_negatives,
            loss_scale=loss_scale,
        )
        ml_is_correct = ml_is_correct.sum().item()
        matching_is_correct = matching_is_correct.sum().item()
    else:
        loss, is_correct = calc_loss(
            cfg,
            loss_function,
            local_q_vector,
            local_ctx_vectors,
            input.is_positive,
            input.hard_negatives,
            loss_scale=loss_scale,
        )
        is_correct = is_correct.sum().item()

    if cfg.n_gpu > 1:
        loss = loss.mean()
    if cfg.train.gradient_accumulation_steps > 1:
        loss = loss / cfg.train.gradient_accumulation_steps

    if clustering:
        assert not cfg.others.is_matching
        return loss, is_correct, model_out
    elif cfg.others.is_matching:
        return loss, ml_is_correct, matching_is_correct
    else:
        return loss, is_correct
Example 9
def gen_ctx_vectors(
    cfg: DictConfig,
    ctx_rows: List[Tuple[object, BiEncoderPassage]],
    q_rows: List[object],
    model: nn.Module,
    tensorizer: Tensorizer,
    insert_title: bool = True,
) -> List[Tuple[object, np.ndarray]]:
    n = len(ctx_rows)
    bsz = cfg.batch_size
    total = 0
    results = []
    for j, batch_start in enumerate(range(0, n, bsz)):
        # Passage preprocess  # TODO: max seq length check
        batch = ctx_rows[batch_start:batch_start + bsz]
        batch_token_tensors = [
            tensorizer.text_to_tensor(
                ctx[1].text, title=ctx[1].title if insert_title else None)
            for ctx in batch
        ]

        ctx_ids_batch = move_to_device(torch.stack(batch_token_tensors, dim=0),
                                       cfg.device)
        ctx_seg_batch = move_to_device(torch.zeros_like(ctx_ids_batch),
                                       cfg.device)
        ctx_attn_mask = move_to_device(tensorizer.get_attn_mask(ctx_ids_batch),
                                       cfg.device)

        # Question preprocess
        q_batch = q_rows[batch_start:batch_start + bsz]
        q_batch_token_tensors = [
            tensorizer.text_to_tensor(qq) for qq in q_batch
        ]

        q_ids_batch = move_to_device(torch.stack(q_batch_token_tensors, dim=0),
                                     cfg.device)
        q_seg_batch = move_to_device(torch.zeros_like(q_ids_batch), cfg.device)
        q_attn_mask = move_to_device(tensorizer.get_attn_mask(q_ids_batch),
                                     cfg.device)

        # Selector
        from dpr.data.biencoder_data import DEFAULT_SELECTOR
        selector = DEFAULT_SELECTOR
        rep_positions = selector.get_positions(q_ids_batch, tensorizer)

        with torch.no_grad():
            q_dense, ctx_dense = model(
                q_ids_batch,
                q_seg_batch,
                q_attn_mask,
                ctx_ids_batch,
                ctx_seg_batch,
                ctx_attn_mask,
                representation_token_pos=rep_positions,
            )
        q_dense = q_dense.cpu()
        ctx_dense = ctx_dense.cpu()
        ctx_ids = [r[0] for r in batch]

        assert len(ctx_ids) == q_dense.size(0) == ctx_dense.size(0)
        total += len(ctx_ids)

        results.extend([(ctx_ids[i], q_dense[i].numpy(), ctx_dense[i].numpy(),
                         q_dense[i].numpy().dot(ctx_dense[i].numpy()))
                        for i in range(q_dense.size(0))])

        if total % 10 == 0:
            logger.info("Encoded questions / passages %d", total)
            # break
    return results
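
The last tuple element appended above is the query-passage dot-product score. A self-contained numpy illustration of that scoring step:

import numpy as np

q_vec = np.array([0.1, 0.3, 0.5], dtype=np.float32)
ctx_vec = np.array([0.2, 0.1, 0.4], dtype=np.float32)
# DPR ranks passages by the inner product of query and passage vectors.
score = q_vec.dot(ctx_vec)
print(float(score))  # 0.25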
Example 10
def generate_question_vectors(
    question_encoder: torch.nn.Module,
    tensorizer: Tensorizer,
    questions: List[str],
    bsz: int,
    query_token: str = None,
    selector: RepTokenSelector = None,
) -> T:
    n = len(questions)
    query_vectors = []

    with torch.no_grad():
        for j, batch_start in enumerate(range(0, n, bsz)):
            batch_questions = questions[batch_start:batch_start + bsz]

            if query_token:
                # TODO: tmp workaround for EL, remove or revise
                if query_token == "[START_ENT]":
                    batch_tensors = [
                        _select_span_with_token(q,
                                                tensorizer,
                                                token_str=query_token)
                        for q in batch_questions
                    ]
                else:
                    batch_tensors = [
                        tensorizer.text_to_tensor(" ".join([query_token, q]))
                        for q in batch_questions
                    ]
            elif isinstance(batch_questions[0], T):
                batch_tensors = list(batch_questions)
            else:
                batch_tensors = [
                    tensorizer.text_to_tensor(q) for q in batch_questions
                ]

            # TODO: this only works for Wav2vec pipeline but will crash the regular text pipeline
            max_vector_len = max(q_t.size(1) for q_t in batch_tensors)
            min_vector_len = min(q_t.size(1) for q_t in batch_tensors)

            if max_vector_len != min_vector_len:
                # TODO: _pad_to_len move to utils
                from dpr.models.reader import _pad_to_len
                batch_tensors = [
                    _pad_to_len(q.squeeze(0), 0, max_vector_len)
                    for q in batch_tensors
                ]

            q_ids_batch = torch.stack(batch_tensors, dim=0).cuda()
            q_seg_batch = torch.zeros_like(q_ids_batch).cuda()
            q_attn_mask = tensorizer.get_attn_mask(q_ids_batch)

            if selector:
                rep_positions = selector.get_positions(q_ids_batch, tensorizer)

                _, out, _ = BiEncoder.get_representation(
                    question_encoder,
                    q_ids_batch,
                    q_seg_batch,
                    q_attn_mask,
                    representation_token_pos=rep_positions,
                )
            else:
                _, out, _ = question_encoder(q_ids_batch, q_seg_batch,
                                             q_attn_mask)

            query_vectors.extend(out.cpu().split(1, dim=0))

            if len(query_vectors) % 100 == 0:
                logger.info("Encoded queries %d", len(query_vectors))

    query_tensor = torch.cat(query_vectors, dim=0)
    logger.info("Total encoded queries tensor %s", query_tensor.size())
    assert query_tensor.size(0) == len(questions)
    return query_tensor
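
A self-contained sketch of the padding step used above (the `_pad_to_len` branch): variable-length token tensors are right-padded with a pad id to a common length before stacking. `pad_to_len` is an illustrative stand-in for dpr.models.reader._pad_to_len, and `pad_id=0` is an assumption, not taken from the snippet:

import torch
import torch.nn.functional as F

def pad_to_len(seq: torch.Tensor, pad_id: int, max_len: int) -> torch.Tensor:
    # Right-pad a 1-D tensor of token ids with pad_id up to max_len.
    return F.pad(seq, (0, max_len - seq.size(0)), value=pad_id)

tensors = [torch.tensor([101, 2054, 102]), torch.tensor([101, 102])]
max_len = max(t.size(0) for t in tensors)
batch = torch.stack([pad_to_len(t, 0, max_len) for t in tensors], dim=0)
print(batch.shape)  # torch.Size([2, 3])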