def _do_biencoder_fwd_pass(
    model: nn.Module,
    input: BiEncoderBatch,
    tensorizer: Tensorizer,
    cfg,
    encoder_type: str,
    rep_positions=0,
    loss_scale: float = None,
) -> Tuple[torch.Tensor, int]:
    # Move the whole batch to the configured device before the forward pass.
    input = BiEncoderBatch(**move_to_device(input._asdict(), cfg.device))

    q_attn_mask = tensorizer.get_attn_mask(input.question_ids)
    ctx_attn_mask = tensorizer.get_attn_mask(input.context_ids)

    if model.training:
        model_out = model(
            input.question_ids,
            input.question_segments,
            q_attn_mask,
            input.context_ids,
            input.ctx_segments,
            ctx_attn_mask,
            encoder_type=encoder_type,
            representation_token_pos=rep_positions,
        )
    else:
        with torch.no_grad():
            model_out = model(
                input.question_ids,
                input.question_segments,
                q_attn_mask,
                input.context_ids,
                input.ctx_segments,
                ctx_attn_mask,
                encoder_type=encoder_type,
                representation_token_pos=rep_positions,
            )

    local_q_vector, local_ctx_vectors = model_out

    loss_function = BiEncoderNllLoss()
    loss, is_correct = _calc_loss(
        cfg,
        loss_function,
        local_q_vector,
        local_ctx_vectors,
        input.is_positive,
        input.hard_negatives,
        loss_scale=loss_scale,
    )

    is_correct = is_correct.sum().item()

    if cfg.n_gpu > 1:
        loss = loss.mean()
    if cfg.train.gradient_accumulation_steps > 1:
        loss = loss / cfg.train.gradient_accumulation_steps
    return loss, is_correct

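# Usage sketch (illustration only, not from the original source): a minimal
# training step built around the cfg-based _do_biencoder_fwd_pass above.
# `biencoder`, `batch`, `tensorizer`, `optimizer`, and `cfg` are assumed to
# be prepared by the caller; encoder_type=None selects the model's default.
def _train_step_sketch(biencoder, batch, tensorizer, cfg, optimizer):
    loss, correct_cnt = _do_biencoder_fwd_pass(
        biencoder, batch, tensorizer, cfg, encoder_type=None
    )
    loss.backward()  # loss is already scaled for gradient accumulation
    optimizer.step()
    optimizer.zero_grad()
    return loss.item(), correct_cnt
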
def _do_biencoder_fwd_pass(
    model: nn.Module,
    input: BiEncoderBatch,
    tensorizer: Tensorizer,
    args,
) -> Tuple[torch.Tensor, int]:
    input = BiEncoderBatch(**move_to_device(input._asdict(), args.device))

    q_attn_mask = tensorizer.get_attn_mask(input.question_ids)
    ctx_attn_mask = tensorizer.get_attn_mask(input.context_ids)

    if model.training:
        model_out = model(
            input.question_ids,
            input.question_segments,
            q_attn_mask,
            input.context_ids,
            input.ctx_segments,
            ctx_attn_mask,
        )
    else:
        with torch.no_grad():
            model_out = model(
                input.question_ids,
                input.question_segments,
                q_attn_mask,
                input.context_ids,
                input.ctx_segments,
                ctx_attn_mask,
            )

    local_q_vector, local_ctx_vectors = model_out

    loss_function = BiEncoderNllLoss()
    loss, is_correct = _calc_loss(
        args,
        loss_function,
        local_q_vector,
        local_ctx_vectors,
        input.is_positive,
        input.hard_negatives,
    )
    is_correct = is_correct.sum().item()

    if args.n_gpu > 1:
        loss = loss.mean()
    if args.gradient_accumulation_steps > 1:
        loss = loss / args.gradient_accumulation_steps
    return loss, is_correct

def gen_ctx_vectors(
    ctx_rows: List[Tuple[object, str, str]],
    model: nn.Module,
    tensorizer: Tensorizer,
    insert_title: bool = True,
) -> List[Tuple[object, np.array]]:
    n = len(ctx_rows)
    bsz = args.batch_size  # `args` is the module-level argparse namespace
    total = 0
    results = []
    for j, batch_start in enumerate(range(0, n, bsz)):
        batch_token_tensors = [
            tensorizer.text_to_tensor(ctx[1], title=ctx[2] if insert_title else None)
            for ctx in ctx_rows[batch_start:batch_start + bsz]
        ]

        ctx_ids_batch = torch.stack(batch_token_tensors, dim=0)
        ctx_seg_batch = torch.zeros_like(ctx_ids_batch)
        ctx_attn_mask = tensorizer.get_attn_mask(ctx_ids_batch)
        with torch.no_grad():
            _, out, _ = model(ctx_ids_batch, ctx_seg_batch, ctx_attn_mask)
        out = out.cpu()

        ctx_ids = [r[0] for r in ctx_rows[batch_start:batch_start + bsz]]

        assert len(ctx_ids) == out.size(0)

        total += len(ctx_ids)

        results.extend(
            [(ctx_ids[i], out[i].view(-1).numpy()) for i in range(out.size(0))]
        )

        if total % 10 == 0:
            logger.info('Encoded passages %d', total)

    return results

def generate_question_vectors(
    question_encoder: torch.nn.Module,
    tensorizer: Tensorizer,
    questions: List[str],
    bsz: int,
    query_token: str = None,
    selector: RepTokenSelector = None,
) -> T:
    n = len(questions)
    query_vectors = []

    with torch.no_grad():
        for j, batch_start in enumerate(range(0, n, bsz)):
            batch_questions = questions[batch_start:batch_start + bsz]

            if query_token:
                if query_token == "[START_ENT]":
                    batch_token_tensors = [
                        _select_span_with_token(q, tensorizer, token_str=query_token)
                        for q in batch_questions
                    ]
                else:
                    batch_token_tensors = [
                        tensorizer.text_to_tensor(" ".join([query_token, q]))
                        for q in batch_questions
                    ]
            else:
                batch_token_tensors = [
                    tensorizer.text_to_tensor(q) for q in batch_questions
                ]

            q_ids_batch = torch.stack(batch_token_tensors, dim=0).cuda()
            q_seg_batch = torch.zeros_like(q_ids_batch).cuda()
            q_attn_mask = tensorizer.get_attn_mask(q_ids_batch)

            if selector:
                rep_positions = selector.get_positions(q_ids_batch, tensorizer)
                _, out, _ = BiEncoder.get_representation(
                    question_encoder,
                    q_ids_batch,
                    q_seg_batch,
                    q_attn_mask,
                    representation_token_pos=rep_positions,
                )
            else:
                _, out, _ = question_encoder(q_ids_batch, q_seg_batch, q_attn_mask)

            query_vectors.extend(out.cpu().split(1, dim=0))

            if len(query_vectors) % 100 == 0:
                logger.info("Encoded queries %d", len(query_vectors))

    query_tensor = torch.cat(query_vectors, dim=0)
    logger.info("Total encoded queries tensor %s", query_tensor.size())
    assert query_tensor.size(0) == len(questions)
    return query_tensor

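# Usage sketch (illustration only; `q_encoder` and `tensorizer` are assumed
# to be a loaded DPR question encoder and its matching tensorizer): encode a
# small list of queries with the function above. Note the function moves
# batches to GPU via .cuda(), so a CUDA device is required as written.
def _encode_queries_sketch(q_encoder, tensorizer):
    questions = ["who wrote the iliad", "when was the eiffel tower built"]
    q_vectors = generate_question_vectors(
        q_encoder, tensorizer, questions, bsz=32, query_token=None, selector=None
    )
    return q_vectors  # shape: [len(questions), hidden_dim]
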
def gen_ctx_vectors(
    cfg: DictConfig,
    ctx_rows: List[Tuple[object, BiEncoderPassage]],
    model: nn.Module,
    tensorizer: Tensorizer,
    insert_title: bool = True,
) -> List[Tuple[object, np.array]]:
    n = len(ctx_rows)
    bsz = cfg.batch_size
    total = 0
    results = []
    for j, batch_start in enumerate(range(0, n, bsz)):
        batch = ctx_rows[batch_start : batch_start + bsz]
        batch_token_tensors = [
            tensorizer.text_to_tensor(
                ctx[1].text, title=ctx[1].title if insert_title else None
            )
            for ctx in batch
        ]

        ctx_ids_batch = move_to_device(
            torch.stack(batch_token_tensors, dim=0), cfg.device
        )
        ctx_seg_batch = move_to_device(torch.zeros_like(ctx_ids_batch), cfg.device)
        ctx_attn_mask = move_to_device(
            tensorizer.get_attn_mask(ctx_ids_batch), cfg.device
        )
        with torch.no_grad():
            _, out, _ = model(ctx_ids_batch, ctx_seg_batch, ctx_attn_mask)
        out = out.cpu()

        ctx_ids = [r[0] for r in batch]
        extra_info = []
        if len(batch[0]) > 3:
            extra_info = [r[3:] for r in batch]

        assert len(ctx_ids) == out.size(0)
        total += len(ctx_ids)

        # TODO: refactor to avoid 'if'
        if extra_info:
            results.extend(
                [
                    (ctx_ids[i], out[i].view(-1).numpy(), *extra_info[i])
                    for i in range(out.size(0))
                ]
            )
        else:
            results.extend(
                [(ctx_ids[i], out[i].view(-1).numpy()) for i in range(out.size(0))]
            )

        if total % 10 == 0:
            logger.info("Encoded passages %d", total)

    return results

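# Usage sketch (illustration only): encode one shard of passages with the
# cfg-based gen_ctx_vectors above and pickle the (id, vector) tuples, in the
# spirit of DPR's generate_dense_embeddings script. `shard_passages` is an
# assumed List[Tuple[id, BiEncoderPassage]]; `out_file` is an assumed path.
import pickle

def _encode_shard_sketch(cfg, shard_passages, encoder, tensorizer, out_file: str):
    encoder.eval()
    data = gen_ctx_vectors(cfg, shard_passages, encoder, tensorizer, insert_title=True)
    with open(out_file, mode="wb") as f:
        pickle.dump(data, f)
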
def gen_ctx_vectors(
    ctx_rows: List[Tuple[object, str, str]],
    model: nn.Module,
    tensorizer: Tensorizer,
    insert_title: bool = True,
) -> List[Tuple[object, np.array]]:
    n = len(ctx_rows)
    bsz = args.batch_size  # `args` is the module-level argparse namespace
    total = 0
    results = []
    for j, batch_start in enumerate(range(0, n, bsz)):
        # Format each passage as "title: <title> context: <text>" instead of
        # relying on the tensorizer's built-in title handling.
        all_txt = []
        for ctx in ctx_rows[batch_start:batch_start + bsz]:
            if ctx[2]:
                txt = ['title:', ctx[2], 'context:', ctx[1]]
            else:
                txt = ['context:', ctx[1]]
            txt = ' '.join(txt)
            all_txt.append(txt)
        batch_token_tensors = [
            tensorizer.text_to_tensor(txt, max_length=250) for txt in all_txt
        ]
        # original:
        # batch_token_tensors = [tensorizer.text_to_tensor(ctx[1], title=ctx[2] if insert_title else None)
        #                        for ctx in ctx_rows[batch_start:batch_start + bsz]]

        ctx_ids_batch = move_to_device(torch.stack(batch_token_tensors, dim=0), args.device)
        ctx_seg_batch = move_to_device(torch.zeros_like(ctx_ids_batch), args.device)
        ctx_attn_mask = move_to_device(tensorizer.get_attn_mask(ctx_ids_batch), args.device)
        with torch.no_grad():
            _, out, _ = model(ctx_ids_batch, ctx_seg_batch, ctx_attn_mask)
        out = out.cpu()

        ctx_ids = [r[0] for r in ctx_rows[batch_start:batch_start + bsz]]

        assert len(ctx_ids) == out.size(0)

        total += len(ctx_ids)

        # original:
        # results.extend([(ctx_ids[i], out[i].view(-1).numpy()) for i in range(out.size(0))])
        results.extend([(ctx_ids[i], out[i].numpy()) for i in range(out.size(0))])

        if total % 10 == 0:
            logger.info('Encoded passages %d', total)

    return results

def get_positions(self, input_ids: T, tensorizer: Tensorizer, model: torch.nn.Module):
    attention_masks = tensorizer.get_attn_mask(input_ids)
    rep_positions = []

    for attention_mask in attention_masks:
        if model.training:
            # Pick a random non-padding position as the representation token.
            input_length = (attention_mask != 0).sum()
            rep_position = random.randint(0, input_length - 1)
            rep_positions.append(rep_position)
        else:
            # Fall back to the default (static) position at inference time.
            rep_positions.append(self.static_position)

    # Use a long dtype: int8 would overflow for positions beyond 127, and
    # index tensors are expected to be int64.
    rep_positions = torch.tensor(rep_positions, dtype=torch.long).unsqueeze(-1).repeat(1, 2)
    return rep_positions

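# Sketch (an assumption, not code from the original source) of how the
# [batch, 2] positions tensor returned above might be consumed: one hidden
# state is gathered per sequence at the selected token position, mirroring
# how BiEncoder.get_representation handles tensor-valued positions.
def _gather_rep_vectors_sketch(sequence_output: T, rep_positions: T) -> T:
    # sequence_output: [batch, seq_len, hidden]; rep_positions: [batch, 2]
    return torch.stack(
        [
            sequence_output[i, rep_positions[i, 1], :]
            for i in range(sequence_output.size(0))
        ],
        dim=0,
    )
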
def _do_biencoder_fwd_pass(
    model: nn.Module,
    input: BiEncoderBatch,
    tensorizer: Tensorizer,
    loss_function,
    cfg,
    encoder_type: str,
    rep_positions_q=0,
    rep_positions_c=0,
    loss_scale: float = None,
    clustering: bool = False,
) -> Tuple[torch.Tensor, int]:
    input = BiEncoderBatch(**move_to_device(input._asdict(), cfg.device))

    q_attn_mask = tensorizer.get_attn_mask(input.question_ids)
    ctx_attn_mask = tensorizer.get_attn_mask(input.context_ids)

    if model.training:
        model_out = model(
            input.question_ids,
            input.question_segments,
            q_attn_mask,
            input.context_ids,
            input.ctx_segments,
            ctx_attn_mask,
            encoder_type=encoder_type,
            representation_token_pos_q=rep_positions_q,
            representation_token_pos_c=rep_positions_c,
        )
    else:
        with torch.no_grad():
            model_out = model(
                input.question_ids,
                input.question_segments,
                q_attn_mask,
                input.context_ids,
                input.ctx_segments,
                ctx_attn_mask,
                encoder_type=encoder_type,
                representation_token_pos_q=rep_positions_q,
                representation_token_pos_c=rep_positions_c,
            )

    local_q_vector, local_ctx_vectors = model_out

    if cfg.others.is_matching:  # MatchBiEncoder model
        loss, ml_is_correct, matching_is_correct = _calc_loss_matching(
            cfg,
            model,
            loss_function,
            local_q_vector,
            local_ctx_vectors,
            input.is_positive,
            input.hard_negatives,
            loss_scale=loss_scale,
        )
        ml_is_correct = ml_is_correct.sum().item()
        matching_is_correct = matching_is_correct.sum().item()
    else:
        loss, is_correct = calc_loss(
            cfg,
            loss_function,
            local_q_vector,
            local_ctx_vectors,
            input.is_positive,
            input.hard_negatives,
            loss_scale=loss_scale,
        )
        is_correct = is_correct.sum().item()

    if cfg.n_gpu > 1:
        loss = loss.mean()
    if cfg.train.gradient_accumulation_steps > 1:
        loss = loss / cfg.train.gradient_accumulation_steps

    if clustering:
        assert not cfg.others.is_matching
        return loss, is_correct, model_out
    elif cfg.others.is_matching:
        return loss, ml_is_correct, matching_is_correct
    else:
        return loss, is_correct

def gen_ctx_vectors(
    cfg: DictConfig,
    ctx_rows: List[Tuple[object, BiEncoderPassage]],
    q_rows: List[object],
    model: nn.Module,
    tensorizer: Tensorizer,
    insert_title: bool = True,
) -> List[Tuple[object, np.array]]:
    n = len(ctx_rows)
    bsz = cfg.batch_size
    total = 0
    results = []
    for j, batch_start in enumerate(range(0, n, bsz)):
        # Passage preprocessing
        # TODO: max seq length check
        batch = ctx_rows[batch_start:batch_start + bsz]
        batch_token_tensors = [
            tensorizer.text_to_tensor(
                ctx[1].text, title=ctx[1].title if insert_title else None
            )
            for ctx in batch
        ]
        ctx_ids_batch = move_to_device(torch.stack(batch_token_tensors, dim=0), cfg.device)
        ctx_seg_batch = move_to_device(torch.zeros_like(ctx_ids_batch), cfg.device)
        ctx_attn_mask = move_to_device(tensorizer.get_attn_mask(ctx_ids_batch), cfg.device)

        # Question preprocessing
        q_batch = q_rows[batch_start:batch_start + bsz]
        q_batch_token_tensors = [tensorizer.text_to_tensor(qq) for qq in q_batch]
        q_ids_batch = move_to_device(torch.stack(q_batch_token_tensors, dim=0), cfg.device)
        q_seg_batch = move_to_device(torch.zeros_like(q_ids_batch), cfg.device)
        q_attn_mask = move_to_device(tensorizer.get_attn_mask(q_ids_batch), cfg.device)

        # Representation token selector
        from dpr.data.biencoder_data import DEFAULT_SELECTOR

        selector = DEFAULT_SELECTOR
        rep_positions = selector.get_positions(q_ids_batch, tensorizer)

        with torch.no_grad():
            q_dense, ctx_dense = model(
                q_ids_batch,
                q_seg_batch,
                q_attn_mask,
                ctx_ids_batch,
                ctx_seg_batch,
                ctx_attn_mask,
                representation_token_pos=rep_positions,
            )
        q_dense = q_dense.cpu()
        ctx_dense = ctx_dense.cpu()

        ctx_ids = [r[0] for r in batch]
        assert len(ctx_ids) == q_dense.size(0) == ctx_dense.size(0)

        total += len(ctx_ids)
        # Store (id, question vector, context vector, dot-product score) per row.
        results.extend(
            [
                (
                    ctx_ids[i],
                    q_dense[i].numpy(),
                    ctx_dense[i].numpy(),
                    q_dense[i].numpy().dot(ctx_dense[i].numpy()),
                )
                for i in range(q_dense.size(0))
            ]
        )

        if total % 10 == 0:
            logger.info("Encoded questions / passages %d", total)

    return results

def generate_question_vectors(
    question_encoder: torch.nn.Module,
    tensorizer: Tensorizer,
    questions: List[str],
    bsz: int,
    query_token: str = None,
    selector: RepTokenSelector = None,
) -> T:
    n = len(questions)
    query_vectors = []

    with torch.no_grad():
        for j, batch_start in enumerate(range(0, n, bsz)):
            batch_questions = questions[batch_start:batch_start + bsz]

            if query_token:
                # TODO: tmp workaround for EL, remove or revise
                if query_token == "[START_ENT]":
                    batch_tensors = [
                        _select_span_with_token(q, tensorizer, token_str=query_token)
                        for q in batch_questions
                    ]
                else:
                    batch_tensors = [
                        tensorizer.text_to_tensor(" ".join([query_token, q]))
                        for q in batch_questions
                    ]
            elif isinstance(batch_questions[0], T):
                # Audio (or otherwise pre-tensorized) queries arrive as tensors already.
                batch_tensors = [q for q in batch_questions]
            else:
                batch_tensors = [tensorizer.text_to_tensor(q) for q in batch_questions]

            # TODO: this only works for the Wav2Vec pipeline but will crash the regular text pipeline
            max_vector_len = max(q_t.size(1) for q_t in batch_tensors)
            min_vector_len = min(q_t.size(1) for q_t in batch_tensors)

            if max_vector_len != min_vector_len:
                # TODO: move _pad_to_len to utils
                from dpr.models.reader import _pad_to_len

                batch_tensors = [
                    _pad_to_len(q.squeeze(0), 0, max_vector_len) for q in batch_tensors
                ]

            q_ids_batch = torch.stack(batch_tensors, dim=0).cuda()
            q_seg_batch = torch.zeros_like(q_ids_batch).cuda()
            q_attn_mask = tensorizer.get_attn_mask(q_ids_batch)

            if selector:
                rep_positions = selector.get_positions(q_ids_batch, tensorizer)
                _, out, _ = BiEncoder.get_representation(
                    question_encoder,
                    q_ids_batch,
                    q_seg_batch,
                    q_attn_mask,
                    representation_token_pos=rep_positions,
                )
            else:
                _, out, _ = question_encoder(q_ids_batch, q_seg_batch, q_attn_mask)

            query_vectors.extend(out.cpu().split(1, dim=0))

            if len(query_vectors) % 100 == 0:
                logger.info("Encoded queries %d", len(query_vectors))

    query_tensor = torch.cat(query_vectors, dim=0)
    logger.info("Total encoded queries tensor %s", query_tensor.size())
    assert query_tensor.size(0) == len(questions)
    return query_tensor