def _gather_interaction_matrices(
    cfg,
    local_interaction_matrix,
):
    """Collect the interaction matrices of all workers for `_calc_loss_matching`.

    With a single worker (``cfg.distributed_world_size`` unset, 0 or 1) the
    local matrix is returned untouched. Otherwise a detached CPU copy of the
    local matrix is exchanged through ``all_gather_list`` and the per-worker
    matrices are concatenated along ``dim=1`` in the order they were gathered.

    :param cfg: config object; reads ``distributed_world_size``,
        ``global_loss_buf_sz`` and ``local_rank``
    :param local_interaction_matrix: tensor computed by this worker
    :return: the local matrix, or the concatenation of every worker's matrix
    """
    world_size = cfg.distributed_world_size or 1
    if world_size <= 1:
        return local_interaction_matrix

    # Build a detached CPU copy to ship to the other workers.
    cpu_copy = torch.empty_like(local_interaction_matrix).cpu()
    cpu_copy.copy_(local_interaction_matrix)
    cpu_copy.detach_()

    gathered = all_gather_list(
        [cpu_copy],
        max_size=cfg.global_loss_buf_sz,
    )

    # The exchanged copies are detached, so the slot belonging to this worker
    # reuses the original tensor instead of its gathered copy.
    target_device = local_interaction_matrix.device
    pieces = [
        local_interaction_matrix
        if worker_idx == cfg.local_rank
        else payload[0].to(target_device)
        for worker_idx, payload in enumerate(gathered)
    ]
    return torch.cat(pieces, dim=1)
def _calc_loss(
    args,
    loss_function,
    local_q_vector,
    local_ctx_vectors,
    local_positive_idxs,
    local_hard_negatives_idxs: list = None,
) -> Tuple[T, bool]:
    """
    Computes the in-batch negatives loss, exchanging the representations
    between all workers first when running with more than one worker.

    In the multi-worker case every worker contributes its question vectors,
    context vectors, positive indices and hard-negative indices. Index lists
    are shifted by the number of context rows contributed by the preceding
    workers so that they address rows of the concatenated context tensor.

    :param args: config object; reads ``distributed_world_size``,
        ``global_loss_buf_sz`` and ``local_rank``
    :param loss_function: object whose ``calc`` receives, positionally, the
        question vectors, context vectors, positive indices and hard-negative
        indices
    :param local_q_vector: question vectors of this worker
    :param local_ctx_vectors: context vectors of this worker
    :param local_positive_idxs: per-question positive context indices
    :param local_hard_negatives_idxs: per-question lists of hard-negative
        context indices (iterated in the multi-worker branch)
    :return: the ``(loss, is_correct)`` pair produced by ``loss_function.calc``
    """
    world_size = args.distributed_world_size or 1

    if world_size > 1:

        def _detached_cpu_copy(tensor):
            # Copy that can be shipped to the other workers.
            return torch.empty_like(tensor).cpu().copy_(tensor).detach_()

        gathered = all_gather_list(
            [
                _detached_cpu_copy(local_q_vector),
                _detached_cpu_copy(local_ctx_vectors),
                local_positive_idxs,
                local_hard_negatives_idxs,
            ],
            max_size=args.global_loss_buf_sz,
        )

        device = local_q_vector.device
        q_parts = []
        ctx_parts = []
        positive_idx_per_question = []
        hard_negatives_per_question = []
        ctx_offset = 0

        for worker_idx, payload in enumerate(gathered):
            sent_q, sent_ctx, sent_positive, sent_hard_neg = payload

            if worker_idx == args.local_rank:
                # The exchanged tensors are detached copies; keep the original
                # local tensors for this worker's own slot.
                q_part, ctx_part = local_q_vector, local_ctx_vectors
                positive, hard_neg = local_positive_idxs, local_hard_negatives_idxs
            else:
                q_part = sent_q.to(device)
                ctx_part = sent_ctx.to(device)
                positive, hard_neg = sent_positive, sent_hard_neg

            q_parts.append(q_part)
            ctx_parts.append(ctx_part)
            positive_idx_per_question.extend(idx + ctx_offset for idx in positive)
            hard_negatives_per_question.extend(
                [idx + ctx_offset for idx in row] for row in hard_neg
            )
            ctx_offset += sent_ctx.size(0)

        global_q_vector = torch.cat(q_parts, dim=0)
        global_ctxs_vector = torch.cat(ctx_parts, dim=0)
    else:
        global_q_vector = local_q_vector
        global_ctxs_vector = local_ctx_vectors
        positive_idx_per_question = local_positive_idxs
        hard_negatives_per_question = local_hard_negatives_idxs

    loss, is_correct = loss_function.calc(
        global_q_vector,
        global_ctxs_vector,
        positive_idx_per_question,
        hard_negatives_per_question,
    )
    return loss, is_correct
def validate_biencoder_average_rank(self) -> float:
    """
    Validates the biencoder model using each question's gold passage's rank across the set of passages
    from the dataset.

    It generates vectors for the specified amount of negative passages for each question
    (see --val_av_rank_xxx params) and stores them in RAM together with the question vectors.
    Then the similarity scores are calculated for the entire
    num_questions x (num_questions x num_passages_per_question) matrix and sorted per question.
    Each question's gold passage rank in that sorted list of scores is averaged across all the questions.

    The rank is the 0-based position of the gold passage in the descending sort. When
    ``self.distributed_factor`` is greater than 1, the per-worker rank sums and question counts are
    exchanged through ``all_gather_list`` before averaging.

    :return: averaged rank number
    """
    logger.info("Retriever average rank validation ...")

    cfg = self.cfg
    self.model.eval()
    distributed_factor = self.distributed_factor

    # The dev iterator is created on first use and cached on the trainer.
    if not self.dev_iterator:
        self.dev_iterator = self.get_data_iterator(
            cfg.train.dev_batch_size, False, shuffle=False, rank=cfg.local_rank
        )
    data_iterator = self.dev_iterator

    sub_batch_size = cfg.biencoder.val_av_rank_bsz
    sim_score_f = self.biencoder_loss_function.get_similarity_function()
    # One (1, dim)-style slice per question / context; concatenated after the loop.
    q_represenations = []
    ctx_represenations = []
    positive_idx_per_question = []

    # Dataset index used when the iterator yields bare batches instead of (batch, dataset) tuples.
    dataset = 0
    for i, samples_batch in enumerate(data_iterator.iterate_ds_data()):
        # Stop once this worker has collected its share of the question budget.
        if (
            len(q_represenations)
            > cfg.biencoder.val_av_rank_max_qs / distributed_factor
        ):
            break

        if isinstance(samples_batch, Tuple):
            samples_batch, dataset = samples_batch
        ds_cfg = self.ds_cfg.dev_datasets[dataset]

        # Biencoder data config
        biencoder_data_config = BiEncoderDataConfig(
            insert_title=True,
            num_hard_negatives=cfg.biencoder.hard_negatives,
            num_other_negatives=cfg.biencoder.other_negatives,
            shuffle=False,
            shuffle_positives=ds_cfg.shuffle_positives,
            hard_neg_fallback=True,
            query_token=ds_cfg.special_token,
        )

        # Prepare data
        biencoder_batch: BiEncoderBatch = create_ofa_input(
            mode="retriever",
            wiki_data=self.ds_cfg.wiki_data,
            tensorizer=self.tensorizer,
            samples=samples_batch,
            biencoder_config=biencoder_data_config,
            reader_config=None,
        )

        # Number of context vectors stored so far; used below to shift this batch's
        # positive indices into positions of the final concatenated context tensor.
        total_ctxs = len(ctx_represenations)
        ctxs_ids = biencoder_batch.context_ids.to(cfg.device)
        ctxs_segments = biencoder_batch.ctx_segments.to(cfg.device)
        bsz = ctxs_ids.size(0)

        # Get the token to be used for representation selection
        rep_positions_q = ds_cfg.selector.get_positions(
            biencoder_batch.question_ids, self.tensorizer, self.model
        )
        rep_positions_c = ds_cfg.selector.get_positions(
            biencoder_batch.context_ids, self.tensorizer, self.model
        )

        # Biencoder training config
        biencoder_training_config = BiEncoderTrainingConfig(
            encoder_type=ds_cfg.encoder_type,
            rep_positions_q=rep_positions_q,
            rep_positions_c=rep_positions_c,
        )

        # Split contexts batch into sub batches since it is supposed to be too large to be processed in one batch
        for j, batch_start in enumerate(range(0, bsz, sub_batch_size)):
            # Questions are only passed together with the first context sub-batch.
            q_ids, q_segments = (
                (biencoder_batch.question_ids.to(cfg.device), biencoder_batch.question_segments.to(cfg.device))
                if j == 0
                else (None, None)
            )

            if j == 0 and cfg.n_gpu > 1 and q_ids.size(0) == 1:
                # If we are in DP (but not in DDP) mode, all model input tensors should have batch size >1 or 0,
                # otherwise the other input tensors will be split but only the first split will be called
                # NOTE(review): when this guard fires, neither the questions nor the first context
                # sub-batch of this batch are encoded, yet the positive indices are still appended
                # after the loop - the q_num assertion below would then fail. Confirm this is intended.
                continue

            ctx_ids_batch = ctxs_ids[batch_start : batch_start + sub_batch_size]
            ctx_seg_batch = ctxs_segments[
                batch_start : batch_start + sub_batch_size
            ]

            # Prepare input data
            biencoder_batch_j = BiEncoderBatch(
                question_ids=q_ids,
                question_segments=q_segments,
                context_IDs=None,  # not used
                context_ids=ctx_ids_batch,
                ctx_segments=ctx_seg_batch,
                is_positive=None,  # not used
                hard_negatives=None,  # not used
                encoder_type=None,  # not used
            )

            with torch.no_grad():
                biencoder_preds: BiEncoderPredictionBatch = self.forward_fn(
                    trainer=self,
                    mode="retriever",
                    backward=False,
                    step=False,
                    biencoder_input=biencoder_batch_j,
                    biencoder_config=biencoder_training_config,
                    reader_inputs=None,
                    reader_config=None,
                    inference_only=True,
                )
            q_dense = biencoder_preds.question_vector
            ctx_dense = biencoder_preds.context_vector

            # Vectors are moved to CPU and stored one row per list entry.
            if q_dense is not None:
                q_represenations.extend(q_dense.cpu().split(1, dim=0))

            ctx_represenations.extend(ctx_dense.cpu().split(1, dim=0))

        # Shift the batch-local positive indices by the number of contexts stored before this batch.
        batch_positive_idxs = biencoder_batch.is_positive
        positive_idx_per_question.extend(
            [total_ctxs + v for v in batch_positive_idxs]
        )

        logger.info(
            "Retriever Av.rank validation: step %d, computed ctx_vectors %d, q_vectors %d",
            i,
            len(ctx_represenations),
            len(q_represenations),
        )

    ctx_represenations = torch.cat(ctx_represenations, dim=0)
    q_represenations = torch.cat(q_represenations, dim=0)

    logger.info(
        "Retriever Av.rank validation: total q_vectors size=%s", q_represenations.size()
    )
    logger.info(
        "Retriever Av.rank validation: total ctx_vectors size=%s", ctx_represenations.size()
    )

    # Calculate similarity scores with the function supplied by the loss object
    q_num = q_represenations.size(0)
    assert q_num == len(positive_idx_per_question)

    scores = sim_score_f(q_represenations, ctx_represenations)
    values, indices = torch.sort(scores, dim=1, descending=True)

    rank = 0
    for i, idx in enumerate(positive_idx_per_question):
        # aggregate the rank of the known gold passage in the sorted results for each question
        gold_idx = (indices[i] == idx).nonzero()
        rank += gold_idx.item()

    if distributed_factor > 1:
        # each node calculated its own rank, exchange the information between nodes and calculate the
        # "global" average rank
        # NOTE: the set of passages is still unique for every node
        eval_stats = all_gather_list([rank, q_num], max_size=100)
        for i, item in enumerate(eval_stats):
            remote_rank, remote_q_num = item
            # This worker's own contribution is already included in rank / q_num.
            if i != cfg.local_rank:
                rank += remote_rank
                q_num += remote_q_num

    av_rank = float(rank / q_num)
    logger.info(
        "Retriever Av.rank validation: average rank %s, total questions=%d", av_rank, q_num
    )
    return av_rank
def validate_average_rank(self) -> float:
    """
    Validates the biencoder by averaging, over all questions, the position of each
    question's gold passage among all passages encoded during this validation run.

    Question and passage vectors are computed batch by batch (passages in sub-batches of
    ``cfg.train.val_av_rank_bsz``, see --val_av_rank_xxx params) and kept on CPU. A single
    num_questions x num_passages score matrix is then built with the similarity function of
    ``BiEncoderNllLoss`` and sorted per question in descending order; the 0-based position
    of the gold passage in that ordering is the question's rank.

    When ``self.distributed_factor`` is greater than 1, the rank sums and question counts
    of all workers are exchanged through ``all_gather_list`` before averaging.

    :return: averaged rank number
    """
    logger.info("Average rank validation ...")

    cfg = self.cfg
    self.biencoder.eval()
    distributed_factor = self.distributed_factor

    # Lazily create the dev iterator and cache it on the trainer.
    if not self.dev_iterator:
        self.dev_iterator = self.get_data_iterator(
            cfg.train.dev_batch_size, False, shuffle=False, rank=cfg.local_rank
        )
    dev_iterator = self.dev_iterator

    ctx_chunk_size = cfg.train.val_av_rank_bsz
    similarity_fn = BiEncoderNllLoss.get_similarity_function()
    question_rows = []
    context_rows = []
    gold_ctx_positions = []

    hard_neg_count = cfg.train.val_av_rank_hard_neg
    other_neg_count = cfg.train.val_av_rank_other_neg

    log_every = cfg.train.log_batch_step
    dataset = 0  # used when the iterator yields bare batches
    for step, samples_batch in enumerate(dev_iterator.iterate_ds_data()):
        # Each worker only collects its share of the question budget.
        if len(question_rows) > cfg.train.val_av_rank_max_qs / distributed_factor:
            break

        if isinstance(samples_batch, Tuple):
            samples_batch, dataset = samples_batch

        batch_input = BiEncoder.create_biencoder_input2(
            samples_batch,
            self.tensorizer,
            True,
            hard_neg_count,
            other_neg_count,
            shuffle=False,
        )
        # Contexts stored before this batch; shifts the batch-local positive indices below.
        ctx_offset = len(context_rows)
        all_ctx_ids = batch_input.context_ids
        all_ctx_segments = batch_input.ctx_segments
        ctx_count = all_ctx_ids.size(0)

        # token positions used for representation selection
        ds_cfg = self.ds_cfg.dev_datasets[dataset]
        encoder_type = ds_cfg.encoder_type
        rep_positions = ds_cfg.selector.get_positions(
            batch_input.question_ids, self.tensorizer
        )

        # The context batch is assumed too large for one forward pass, so it is chunked.
        for chunk_idx, chunk_start in enumerate(range(0, ctx_count, ctx_chunk_size)):
            if chunk_idx == 0:
                # Questions go through the model only with the first context chunk.
                q_ids = batch_input.question_ids
                q_segments = batch_input.question_segments
                if cfg.n_gpu > 1 and q_ids.size(0) == 1:
                    # In DP (but not DDP) mode all model input tensors need batch size >1 or 0,
                    # otherwise the other inputs are split but only the first split is called.
                    continue
            else:
                q_ids = None
                q_segments = None

            chunk_stop = chunk_start + ctx_chunk_size
            ctx_ids_chunk = all_ctx_ids[chunk_start:chunk_stop]
            ctx_seg_chunk = all_ctx_segments[chunk_start:chunk_stop]

            q_attn_mask = self.tensorizer.get_attn_mask(q_ids)
            ctx_attn_mask = self.tensorizer.get_attn_mask(ctx_ids_chunk)
            with torch.no_grad():
                q_dense, ctx_dense = self.biencoder(
                    q_ids,
                    q_segments,
                    q_attn_mask,
                    ctx_ids_chunk,
                    ctx_seg_chunk,
                    ctx_attn_mask,
                    encoder_type=encoder_type,
                    representation_token_pos=rep_positions,
                )

            if q_dense is not None:
                question_rows.extend(q_dense.cpu().split(1, dim=0))
            context_rows.extend(ctx_dense.cpu().split(1, dim=0))

        gold_ctx_positions.extend(
            [ctx_offset + local_idx for local_idx in batch_input.is_positive]
        )

        if (step + 1) % log_every == 0:
            logger.info(
                "Av.rank validation: step %d, computed ctx_vectors %d, q_vectors %d",
                step,
                len(context_rows),
                len(question_rows),
            )

    ctx_matrix = torch.cat(context_rows, dim=0)
    q_matrix = torch.cat(question_rows, dim=0)

    logger.info("Av.rank validation: total q_vectors size=%s", q_matrix.size())
    logger.info("Av.rank validation: total ctx_vectors size=%s", ctx_matrix.size())

    q_num = q_matrix.size(0)
    assert q_num == len(gold_ctx_positions)

    scores = similarity_fn(q_matrix, ctx_matrix)
    _, sorted_ctx_idxs = torch.sort(scores, dim=1, descending=True)

    # Sum, over questions, the position of the gold passage in the sorted results.
    rank = sum(
        (sorted_ctx_idxs[q_idx] == gold_pos).nonzero().item()
        for q_idx, gold_pos in enumerate(gold_ctx_positions)
    )

    if distributed_factor > 1:
        # Every node computed its own rank sum over its own passage set; combine them
        # into a "global" average. The local contribution is already counted.
        gathered_stats = all_gather_list([rank, q_num], max_size=100)
        for worker_idx, (worker_rank, worker_q_num) in enumerate(gathered_stats):
            if worker_idx != cfg.local_rank:
                rank += worker_rank
                q_num += worker_q_num

    av_rank = float(rank / q_num)
    logger.info(
        "Av.rank validation: average rank %s, total questions=%d", av_rank, q_num
    )
    return av_rank
def _calc_loss_graded(
    cfg,
    loss_function,
    local_q_vector,
    local_ctx_vectors,
    local_positive_idxs,
    local_hard_negatives_idxs: list = None,
    local_negatives_idxs: list = None,
    local_related_idxs: list = None,
    local_highly_related_idxs: list = None,
    local_relations: list = None,
    loss_scale: float = None,
) -> Tuple[T, bool]:
    """
    Calculates the graded in-batch negatives loss and supports running it in DDP mode
    by exchanging the representations across all the nodes.

    Only question vectors, context vectors, positive indices and hard-negative indices
    are exchanged between workers. The graded inputs (``local_negatives_idxs``,
    ``local_related_idxs``, ``local_highly_related_idxs``, ``local_relations``) are not
    exchanged, so they are only supported when running with a single worker.

    :param cfg: config object; reads ``distributed_world_size``, ``global_loss_buf_sz``
        and ``local_rank``
    :param loss_function: object whose ``calc`` accepts the keyword arguments passed below
    :param local_q_vector: question vectors of this worker
    :param local_ctx_vectors: context vectors of this worker
    :param local_positive_idxs: per-question positive context indices
    :param local_hard_negatives_idxs: per-question lists of hard-negative context indices;
        iterated in the multi-worker branch, so it must not be None there
    :param local_negatives_idxs: graded input forwarded to the loss (single worker only)
    :param local_related_idxs: graded input forwarded to the loss (single worker only)
    :param local_highly_related_idxs: graded input forwarded to the loss (single worker only)
    :param local_relations: graded input forwarded to the loss (single worker only)
    :param loss_scale: forwarded unchanged to ``loss_function.calc``
    :return: the ``(loss, is_correct)`` pair produced by ``loss_function.calc``
    :raises NotImplementedError: if any graded input is given while
        ``cfg.distributed_world_size`` is greater than 1
    """
    distributed_world_size = cfg.distributed_world_size or 1

    # Bind the graded inputs on every path. Previously they were only assigned in the
    # single-worker branch, so the multi-worker branch failed with an unbound local
    # when it reached the loss call.
    negatives_per_question = local_negatives_idxs
    related_per_question = local_related_idxs
    highly_related_per_question = local_highly_related_idxs
    relations_per_question = local_relations

    if distributed_world_size > 1:
        # The graded lists are not exchanged between workers, so combining them with the
        # gathered (global) vectors would be inconsistent. Fail loudly before any
        # communication happens instead of computing a loss on mismatched inputs.
        graded_inputs = (
            ("local_negatives_idxs", local_negatives_idxs),
            ("local_related_idxs", local_related_idxs),
            ("local_highly_related_idxs", local_highly_related_idxs),
            ("local_relations", local_relations),
        )
        supplied = [name for name, value in graded_inputs if value is not None]
        if supplied:
            raise NotImplementedError(
                "_calc_loss_graded does not exchange graded inputs across workers; "
                f"got {', '.join(supplied)} with "
                f"distributed_world_size={distributed_world_size}"
            )

        # Detached CPU copies are what gets shipped to the other workers.
        q_vector_to_send = (
            torch.empty_like(local_q_vector).cpu().copy_(local_q_vector).detach_()
        )
        ctx_vector_to_send = (
            torch.empty_like(local_ctx_vectors).cpu().copy_(local_ctx_vectors).detach_()
        )

        global_question_ctx_vectors = all_gather_list(
            [
                q_vector_to_send,
                ctx_vector_to_send,
                local_positive_idxs,
                local_hard_negatives_idxs,
            ],
            max_size=cfg.global_loss_buf_sz,
        )

        global_q_vector = []
        global_ctxs_vector = []
        positive_idx_per_question = []
        hard_negatives_per_question = []

        # Number of context rows contributed by the workers processed so far; index lists
        # are shifted by it so they address rows of the concatenated context tensor.
        total_ctxs = 0

        for i, item in enumerate(global_question_ctx_vectors):
            q_vector, ctx_vectors, positive_idx, hard_negatives_idxs = item

            if i != cfg.local_rank:
                global_q_vector.append(q_vector.to(local_q_vector.device))
                global_ctxs_vector.append(ctx_vectors.to(local_q_vector.device))
                positive_idx_per_question.extend([v + total_ctxs for v in positive_idx])
                hard_negatives_per_question.extend(
                    [[v + total_ctxs for v in l] for l in hard_negatives_idxs]
                )
            else:
                # The exchanged tensors are detached copies; keep the original local
                # tensors for this worker's own slot.
                global_q_vector.append(local_q_vector)
                global_ctxs_vector.append(local_ctx_vectors)
                positive_idx_per_question.extend(
                    [v + total_ctxs for v in local_positive_idxs]
                )
                hard_negatives_per_question.extend(
                    [[v + total_ctxs for v in l] for l in local_hard_negatives_idxs]
                )
            total_ctxs += ctx_vectors.size(0)

        global_q_vector = torch.cat(global_q_vector, dim=0)
        global_ctxs_vector = torch.cat(global_ctxs_vector, dim=0)
    else:
        global_q_vector = local_q_vector
        global_ctxs_vector = local_ctx_vectors
        positive_idx_per_question = local_positive_idxs
        hard_negatives_per_question = local_hard_negatives_idxs

    loss, is_correct = loss_function.calc(
        q_vectors=global_q_vector,
        ctx_vectors=global_ctxs_vector,
        positive_idx_per_question=positive_idx_per_question,
        hard_negative_idx_per_question=hard_negatives_per_question,
        negative_idx_per_question=negatives_per_question,
        related_idx_per_question=related_per_question,
        highly_related_idx_per_question=highly_related_per_question,
        relations_per_question=relations_per_question,
        loss_scale=loss_scale,
    )
    return loss, is_correct
def gather(
    cfg,
    local_q_vector,
    local_ctx_vectors,
    local_positive_idxs,
    local_hard_negatives_idxs: list = None,
):
    """Helper for the `_calc*` functions: collect vectors and index lists from all workers.

    With a single worker (``cfg.distributed_world_size`` unset, 0 or 1) the inputs are
    returned as they are. Otherwise detached CPU copies of the local vectors are
    exchanged through ``all_gather_list`` together with the index lists, the vectors
    are concatenated along ``dim=0``, and every index is shifted by the number of
    context rows contributed by the workers gathered before it.

    :param cfg: config object; reads ``distributed_world_size``,
        ``global_loss_buf_sz`` and ``local_rank``
    :param local_q_vector: question vectors of this worker
    :param local_ctx_vectors: context vectors of this worker
    :param local_positive_idxs: per-question positive context indices
    :param local_hard_negatives_idxs: per-question lists of hard-negative context
        indices (iterated in the multi-worker branch)
    :return: tuple ``(q_vectors, ctx_vectors, positive_idxs, hard_negative_idxs)``
    """
    if (cfg.distributed_world_size or 1) <= 1:
        return (
            local_q_vector,
            local_ctx_vectors,
            local_positive_idxs,
            local_hard_negatives_idxs,
        )

    # Detached CPU copies are what gets sent to the other workers.
    outgoing_q = torch.empty_like(local_q_vector).cpu().copy_(local_q_vector).detach_()
    outgoing_ctx = (
        torch.empty_like(local_ctx_vectors).cpu().copy_(local_ctx_vectors).detach_()
    )
    per_worker = all_gather_list(
        [outgoing_q, outgoing_ctx, local_positive_idxs, local_hard_negatives_idxs],
        max_size=cfg.global_loss_buf_sz,
    )

    q_chunks = []
    ctx_chunks = []
    shifted_positive_idxs = []
    shifted_hard_negative_idxs = []
    rows_so_far = 0

    for worker_idx, (w_q, w_ctx, w_positive, w_hard_neg) in enumerate(per_worker):
        is_local = worker_idx == cfg.local_rank
        if is_local:
            # Use the original local tensors rather than their detached copies.
            q_chunks.append(local_q_vector)
            ctx_chunks.append(local_ctx_vectors)
        else:
            q_chunks.append(w_q.to(local_q_vector.device))
            ctx_chunks.append(w_ctx.to(local_q_vector.device))

        positive_src = local_positive_idxs if is_local else w_positive
        hard_neg_src = local_hard_negatives_idxs if is_local else w_hard_neg
        shifted_positive_idxs.extend(idx + rows_so_far for idx in positive_src)
        shifted_hard_negative_idxs.extend(
            [idx + rows_so_far for idx in row] for row in hard_neg_src
        )

        rows_so_far += w_ctx.size(0)

    return (
        torch.cat(q_chunks, dim=0),
        torch.cat(ctx_chunks, dim=0),
        shifted_positive_idxs,
        shifted_hard_negative_idxs,
    )