def _do_biencoder_fwd_pass(
    model: nn.Module,
    input: BiEncoderBatch,
    tensorizer: Tensorizer,
    cfg,
    encoder_type: str,
    rep_positions=0,
    loss_scale: float = None,
) -> Tuple[torch.Tensor, int]:
    """
    Runs one bi-encoder forward pass on a batch and computes its loss.

    The batch is moved to ``cfg.device`` and attention masks are built with
    ``tensorizer``. The model is run with gradients disabled when it is not in
    training mode. The loss is computed by ``_calc_loss`` with a
    ``BiEncoderNllLoss`` instance.

    :param model: bi-encoder module; called with question/context ids, segments
        and attention masks plus ``encoder_type`` and ``representation_token_pos``
    :param input: batch to process
    :param tensorizer: provides ``get_attn_mask`` for the id tensors
    :param cfg: config; reads ``device``, ``n_gpu`` and
        ``train.gradient_accumulation_steps``
    :param encoder_type: forwarded to the model as ``encoder_type``
    :param rep_positions: forwarded to the model as ``representation_token_pos``
    :param loss_scale: forwarded to ``_calc_loss``
    :return: tuple of (loss tensor, summed value of the correctness tensor
        returned by ``_calc_loss`` as a Python number)
    """
    input = BiEncoderBatch(**move_to_device(input._asdict(), cfg.device))

    q_attn_mask = tensorizer.get_attn_mask(input.question_ids)
    ctx_attn_mask = tensorizer.get_attn_mask(input.context_ids)

    model_args = (
        input.question_ids,
        input.question_segments,
        q_attn_mask,
        input.context_ids,
        input.ctx_segments,
        ctx_attn_mask,
    )
    model_kwargs = dict(
        encoder_type=encoder_type,
        representation_token_pos=rep_positions,
    )

    if model.training:
        model_out = model(*model_args, **model_kwargs)
    else:
        # evaluation: no autograd graph is needed
        with torch.no_grad():
            model_out = model(*model_args, **model_kwargs)

    local_q_vector, local_ctx_vectors = model_out

    loss_function = BiEncoderNllLoss()

    loss, is_correct = _calc_loss(
        cfg,
        loss_function,
        local_q_vector,
        local_ctx_vectors,
        input.is_positive,
        input.hard_negatives,
        loss_scale=loss_scale,
    )

    is_correct = is_correct.sum().item()

    if cfg.n_gpu > 1:
        # presumably one loss value per GPU replica (DataParallel) -- reduce to a scalar
        loss = loss.mean()

    gradient_accumulation_steps = cfg.train.gradient_accumulation_steps
    if gradient_accumulation_steps > 1:
        # Scale so accumulated gradients match a single large batch. The divisor
        # must come from the same config path as the condition above.
        loss = loss / gradient_accumulation_steps

    return loss, is_correct
def _do_biencoder_fwd_pass(model: nn.Module, input: BiEncoderBatch, tensorizer: Tensorizer, args) -> (torch.Tensor, int):
    """
    Runs one bi-encoder forward pass (args-based variant) and computes its loss.

    The batch is moved to ``args.device``. The model receives question/context
    ids, segments and attention masks built by ``tensorizer``. Gradients are
    disabled when the model is not in training mode.

    NOTE(review): this function name is also used by another definition in this
    source. If both live in one module, the later definition shadows the earlier
    one -- confirm which is intended.

    :param model: bi-encoder module
    :param input: batch to process
    :param tensorizer: provides ``get_attn_mask`` for the id tensors
    :param args: reads ``device``, ``n_gpu`` and ``gradient_accumulation_steps``
    :return: tuple of (loss tensor, summed correctness tensor from ``_calc_loss``
        as a Python number)
    """
    batch = BiEncoderBatch(**move_to_device(input._asdict(), args.device))

    # Positional model inputs; the masks are built question-first, then contexts.
    model_inputs = (
        batch.question_ids,
        batch.question_segments,
        tensorizer.get_attn_mask(batch.question_ids),
        batch.context_ids,
        batch.ctx_segments,
        tensorizer.get_attn_mask(batch.context_ids),
    )

    if not model.training:
        with torch.no_grad():
            q_vectors, ctx_vectors = model(*model_inputs)
    else:
        q_vectors, ctx_vectors = model(*model_inputs)

    loss, correct_mask = _calc_loss(
        args,
        BiEncoderNllLoss(),
        q_vectors,
        ctx_vectors,
        batch.is_positive,
        batch.hard_negatives,
    )
    correct_count = correct_mask.sum().item()

    if args.n_gpu > 1:
        loss = loss.mean()

    accumulation_steps = args.gradient_accumulation_steps
    if accumulation_steps > 1:
        loss = loss / accumulation_steps

    return loss, correct_count
def validate_average_rank(self) -> float:
    """
    Validates the biencoder model using the rank of each question's gold passage
    among the passages encoded from the dev set.

    Vectors are generated for a configured number of negative passages per
    question (see ``cfg.train.val_av_rank_hard_neg`` /
    ``cfg.train.val_av_rank_other_neg``). They are kept in RAM together with
    the question vectors. Similarity scores are then calculated for the whole
    num_questions x num_collected_passages matrix and sorted per question in
    descending order. The (0-based) position of each question's gold passage in
    its sorted list is averaged across all questions, so lower is better.

    In DP mode (``cfg.n_gpu > 1``), batches containing a single question are
    skipped entirely.

    :return: averaged rank number
    """
    logger.info("Average rank validation ...")

    cfg = self.cfg
    self.biencoder.eval()
    distributed_factor = self.distributed_factor

    if not self.dev_iterator:
        self.dev_iterator = self.get_data_iterator(
            cfg.train.dev_batch_size, False, shuffle=False, rank=cfg.local_rank
        )
    data_iterator = self.dev_iterator

    sub_batch_size = cfg.train.val_av_rank_bsz
    sim_score_f = BiEncoderNllLoss.get_similarity_function()
    q_representations = []
    ctx_representations = []
    positive_idx_per_question = []

    num_hard_negatives = cfg.train.val_av_rank_hard_neg
    num_other_negatives = cfg.train.val_av_rank_other_neg

    log_result_step = cfg.train.log_batch_step
    dataset = 0
    for i, samples_batch in enumerate(data_iterator.iterate_ds_data()):
        if len(q_representations) > cfg.train.val_av_rank_max_qs / distributed_factor:
            break

        # the iterator may yield (samples, dataset_index) pairs
        if isinstance(samples_batch, tuple):
            samples_batch, dataset = samples_batch

        biencoder_input = BiEncoder.create_biencoder_input2(
            samples_batch,
            self.tensorizer,
            True,
            num_hard_negatives,
            num_other_negatives,
            shuffle=False,
        )

        # If we are in DP (but not in DDP) mode, all model input tensors should have
        # batch size >1 or 0; otherwise the other input tensors will be split but only
        # the first split will be called. Skip such a batch as a whole: skipping only
        # the sub-batch that carries the questions would still record this batch's
        # remaining contexts and gold indices, desynchronizing questions, contexts and
        # positives.
        if cfg.n_gpu > 1 and biencoder_input.question_ids.size(0) == 1:
            continue

        # offset of this batch's contexts in the global context list
        total_ctxs = len(ctx_representations)
        ctxs_ids = biencoder_input.context_ids
        ctxs_segments = biencoder_input.ctx_segments
        bsz = ctxs_ids.size(0)

        # get the token to be used for representation selection
        ds_cfg = self.ds_cfg.dev_datasets[dataset]
        encoder_type = ds_cfg.encoder_type
        rep_positions = ds_cfg.selector.get_positions(
            biencoder_input.question_ids, self.tensorizer
        )

        # split contexts batch into sub batches since it is supposed to be too large to be processed in one batch
        for j, batch_start in enumerate(range(0, bsz, sub_batch_size)):
            # questions are encoded only together with the first context sub-batch
            q_ids, q_segments = (
                (biencoder_input.question_ids, biencoder_input.question_segments)
                if j == 0
                else (None, None)
            )

            batch_end = batch_start + sub_batch_size
            ctx_ids_batch = ctxs_ids[batch_start:batch_end]
            ctx_seg_batch = ctxs_segments[batch_start:batch_end]

            q_attn_mask = self.tensorizer.get_attn_mask(q_ids)
            ctx_attn_mask = self.tensorizer.get_attn_mask(ctx_ids_batch)
            with torch.no_grad():
                q_dense, ctx_dense = self.biencoder(
                    q_ids,
                    q_segments,
                    q_attn_mask,
                    ctx_ids_batch,
                    ctx_seg_batch,
                    ctx_attn_mask,
                    encoder_type=encoder_type,
                    representation_token_pos=rep_positions,
                )

            if q_dense is not None:
                q_representations.extend(q_dense.cpu().split(1, dim=0))

            ctx_representations.extend(ctx_dense.cpu().split(1, dim=0))

        # gold passage indices are local to the batch -- shift them to global positions
        batch_positive_idxs = biencoder_input.is_positive
        positive_idx_per_question.extend(
            [total_ctxs + v for v in batch_positive_idxs]
        )

        if (i + 1) % log_result_step == 0:
            logger.info(
                "Av.rank validation: step %d, computed ctx_vectors %d, q_vectors %d",
                i,
                len(ctx_representations),
                len(q_representations),
            )

    ctx_representations = torch.cat(ctx_representations, dim=0)
    q_representations = torch.cat(q_representations, dim=0)

    logger.info(
        "Av.rank validation: total q_vectors size=%s", q_representations.size()
    )
    logger.info(
        "Av.rank validation: total ctx_vectors size=%s", ctx_representations.size()
    )

    q_num = q_representations.size(0)
    assert q_num == len(positive_idx_per_question)

    scores = sim_score_f(q_representations, ctx_representations)
    _, indices = torch.sort(scores, dim=1, descending=True)

    rank = 0
    for i, idx in enumerate(positive_idx_per_question):
        # aggregate the rank of the known gold passage in the sorted results for each question
        gold_idx = (indices[i] == idx).nonzero()
        rank += gold_idx.item()

    if distributed_factor > 1:
        # Each node calculated its own rank; exchange the information between nodes
        # and calculate the "global" average rank.
        # NOTE: the set of passages is still unique for every node.
        eval_stats = all_gather_list([rank, q_num], max_size=100)
        # The gathered list holds one [rank, q_num] entry per process, including
        # this one. Summing every entry avoids matching list positions against
        # cfg.local_rank, which is not the global process rank on multi-node runs.
        rank = sum(remote_rank for remote_rank, _ in eval_stats)
        q_num = sum(remote_q_num for _, remote_q_num in eval_stats)

    av_rank = float(rank / q_num)
    logger.info(
        "Av.rank validation: average rank %s, total questions=%d", av_rank, q_num
    )
    return av_rank