def validate_nll(self) -> float:
    logger.info("NLL validation ...")
    args = self.args
    self.biencoder.eval()
    data_iterator = self.get_data_iterator(
        args.dev_file, args.dev_batch_size, shuffle=False
    )

    total_loss = 0.0
    start_time = time.time()
    total_correct_predictions = 0
    num_hard_negatives = args.hard_negatives
    num_other_negatives = args.other_negatives
    log_result_step = args.log_batch_step
    batches = 0

    for i, samples_batch in enumerate(data_iterator.iterate_data()):
        biencoder_input = BiEncoder.create_biencoder_input(
            samples_batch,
            self.tensorizer,
            True,
            num_hard_negatives,
            num_other_negatives,
            shuffle=False,
        )

        loss, correct_cnt = _do_biencoder_fwd_pass(
            self.biencoder, biencoder_input, self.tensorizer, args
        )
        total_loss += loss.item()
        total_correct_predictions += correct_cnt
        batches += 1
        if (i + 1) % log_result_step == 0:
            logger.info(
                "Eval step: %d , used_time=%f sec., loss=%f ",
                i,
                time.time() - start_time,
                loss.item(),
            )

    total_loss = total_loss / batches
    total_samples = batches * args.dev_batch_size * self.distributed_factor
    correct_ratio = float(total_correct_predictions / total_samples)
    logger.info(
        "NLL Validation: loss = %f. correct prediction ratio %d/%d ~ %f",
        total_loss,
        total_correct_predictions,
        total_samples,
        correct_ratio,
    )
    return total_loss
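
# Illustrative sketch (not part of the trainer code): the (loss, correct_cnt) pair
# returned by _do_biencoder_fwd_pass above is conceptually an in-batch softmax NLL
# over a question-by-context score matrix plus an argmax hit count. The helper name,
# shapes and inputs below are assumptions for a toy example only.
import torch
import torch.nn.functional as F


def _toy_nll_and_correct(q_vectors: torch.Tensor, ctx_vectors: torch.Tensor, positive_idx: list):
    # scores[i, j] = dot-product similarity between question i and context j
    scores = torch.matmul(q_vectors, ctx_vectors.t())
    log_probs = F.log_softmax(scores, dim=1)
    target = torch.tensor(positive_idx)
    loss = F.nll_loss(log_probs, target, reduction="mean")
    # a prediction counts as correct when the highest-scoring context is the gold one
    correct = (scores.argmax(dim=1) == target).sum().item()
    return loss, correct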
def create_ofa_input(
    mode: str,
    wiki_data: TokenizedWikipediaPassages,
    tensorizer: Tensorizer,
    samples: List[Tuple[BiEncoderSampleTokenized, ReaderSample]],
    biencoder_config: BiEncoderDataConfig,
    reader_config: ReaderDataConfig,
) -> Union[
    BiEncoderBatch,
    List[ReaderBatch],
    Tuple[BiEncoderBatch, List[ReaderBatch]],
]:
    assert mode in ["retriever", "reader", "both"], f"Invalid mode: {mode}"
    retriever_samples, reader_samples = zip(*samples)

    # Retriever (bi-encoder)
    if mode in ["retriever", "both"]:
        biencoder_batch = BiEncoder.create_biencoder_input(
            samples=retriever_samples,
            tensorizer=tensorizer,
            insert_title=biencoder_config.insert_title,
            num_hard_negatives=biencoder_config.num_hard_negatives,
            num_other_negatives=biencoder_config.num_other_negatives,
            shuffle=biencoder_config.shuffle,
            shuffle_positives=biencoder_config.shuffle_positives,
            hard_neg_fallback=biencoder_config.hard_neg_fallback,
            query_token=biencoder_config.query_token,
        )

    # Reader
    if mode in ["reader", "both"]:
        num_samples = len(samples)
        num_sub_batches = reader_config.num_sub_batches
        assert num_sub_batches > 0
        sub_batch_size = math.ceil(num_samples / num_sub_batches)

        reader_batches: List[ReaderBatch] = []
        for batch_i in range(num_sub_batches):
            start = batch_i * sub_batch_size
            end = min(start + sub_batch_size, num_samples)
            if start >= end:
                break

            reader_batch = create_reader_input(
                wiki_data=wiki_data,
                tensorizer=tensorizer,
                samples=reader_samples[start:end],
                passages_per_question=reader_config.passages_per_question,
                max_length=reader_config.max_length,
                max_n_answers=reader_config.max_n_answers,
                is_train=reader_config.is_train,
                shuffle=reader_config.shuffle,
            )
            reader_batches.append(reader_batch)

    if mode == "retriever":
        return biencoder_batch
    elif mode == "reader":
        return reader_batches
    else:
        return biencoder_batch, reader_batches
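
# Illustrative sketch of the reader sub-batch split used above: the samples are cut
# into num_sub_batches chunks of size ceil(n / num_sub_batches); the last chunk may be
# shorter and trailing empty chunks are skipped. Toy helper, not project code.
import math


def _toy_sub_batch_bounds(num_samples: int, num_sub_batches: int):
    sub_batch_size = math.ceil(num_samples / num_sub_batches)
    bounds = []
    for batch_i in range(num_sub_batches):
        start = batch_i * sub_batch_size
        end = min(start + sub_batch_size, num_samples)
        if start >= end:
            break
        bounds.append((start, end))
    return bounds


# e.g. _toy_sub_batch_bounds(10, 4) -> [(0, 3), (3, 6), (6, 9), (9, 10)]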
def _train_epoch(
    self,
    scheduler,
    epoch: int,
    eval_step: int,
    train_data_iterator: ShardedDataIterator,
):
    args = self.args
    rolling_train_loss = 0.0
    epoch_loss = 0
    epoch_correct_predictions = 0

    log_result_step = args.log_batch_step
    rolling_loss_step = args.train_rolling_loss_step
    num_hard_negatives = args.hard_negatives
    num_other_negatives = args.other_negatives
    seed = args.seed

    self.biencoder.train()
    epoch_batches = train_data_iterator.max_iterations
    data_iteration = 0

    for i, samples_batch in enumerate(
        train_data_iterator.iterate_data(epoch=epoch)
    ):
        # to be able to resume shuffled ctx pools
        data_iteration = train_data_iterator.get_iteration()
        random.seed(seed + epoch + data_iteration)

        biencoder_batch = BiEncoder.create_biencoder_input(
            samples_batch,
            self.tensorizer,
            True,
            num_hard_negatives,
            num_other_negatives,
            shuffle=True,
            shuffle_positives=args.shuffle_positive_ctx,
        )

        loss, correct_cnt = _do_biencoder_fwd_pass(
            self.biencoder, biencoder_batch, self.tensorizer, args
        )

        epoch_correct_predictions += correct_cnt
        epoch_loss += loss.item()
        rolling_train_loss += loss.item()

        if args.fp16:
            from apex import amp

            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(self.optimizer), args.max_grad_norm
                )
        else:
            loss.backward()
            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(
                    self.biencoder.parameters(), args.max_grad_norm
                )

        if (i + 1) % args.gradient_accumulation_steps == 0:
            self.optimizer.step()
            scheduler.step()
            self.biencoder.zero_grad()

        if i % log_result_step == 0:
            lr = self.optimizer.param_groups[0]["lr"]
            logger.info(
                "Epoch: %d: Step: %d/%d, loss=%f, lr=%f",
                epoch,
                data_iteration,
                epoch_batches,
                loss.item(),
                lr,
            )

        if (i + 1) % rolling_loss_step == 0:
            logger.info("Train batch %d", data_iteration)
            latest_rolling_train_av_loss = rolling_train_loss / rolling_loss_step
            logger.info(
                "Avg. loss per last %d batches: %f",
                rolling_loss_step,
                latest_rolling_train_av_loss,
            )
            rolling_train_loss = 0.0

        if data_iteration % eval_step == 0:
            logger.info(
                "Validation: Epoch: %d Step: %d/%d",
                epoch,
                data_iteration,
                epoch_batches,
            )
            self.validate_and_save(
                epoch, train_data_iterator.get_iteration(), scheduler
            )
            self.biencoder.train()

    self.validate_and_save(epoch, data_iteration, scheduler)

    epoch_loss = (epoch_loss / epoch_batches) if epoch_batches > 0 else 0
    logger.info("Av Loss per epoch=%f", epoch_loss)
    logger.info("epoch total correct predictions=%d", epoch_correct_predictions)
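
# Minimal sketch of the gradient-accumulation pattern used in the loop above:
# gradients from several micro-batches accumulate via repeated backward() calls,
# and the optimizer/scheduler step only fires every accumulation_steps batches.
# model, optimizer, scheduler and batches below are placeholders, not project objects.
import torch


def _toy_accumulation_loop(model, optimizer, scheduler, batches, accumulation_steps: int, max_grad_norm: float):
    model.train()
    for i, batch in enumerate(batches):
        loss = model(batch)  # assumed to return a scalar loss tensor
        loss.backward()      # gradients accumulate across micro-batches
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            scheduler.step()
            model.zero_grad()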
def validate_average_rank(self) -> float:
    """
    Validates biencoder model using each question's gold passage's rank across the set of passages from the dataset.
    It generates vectors for a specified amount of negative passages from each question (see --val_av_rank_xxx params)
    and stores them in RAM as well as question vectors.
    Then the similarity scores are calculated for the entire
    num_questions x (num_questions x num_passages_per_question) matrix and sorted per question.
    Each question's gold passage rank in that sorted list of scores is averaged across all the questions.
    :return: averaged rank number
    """
    logger.info("Average rank validation ...")

    args = self.args
    self.biencoder.eval()
    distributed_factor = self.distributed_factor

    data_iterator = self.get_data_iterator(
        args.dev_file, args.dev_batch_size, shuffle=False
    )

    sub_batch_size = args.val_av_rank_bsz
    sim_score_f = BiEncoderNllLoss.get_similarity_function()
    q_represenations = []
    ctx_represenations = []
    positive_idx_per_question = []

    num_hard_negatives = args.val_av_rank_hard_neg
    num_other_negatives = args.val_av_rank_other_neg

    log_result_step = args.log_batch_step

    for i, samples_batch in enumerate(data_iterator.iterate_data()):
        # samples += 1
        if len(q_represenations) > args.val_av_rank_max_qs / distributed_factor:
            break

        biencoder_input = BiEncoder.create_biencoder_input(
            samples_batch,
            self.tensorizer,
            True,
            num_hard_negatives,
            num_other_negatives,
            shuffle=False,
        )
        total_ctxs = len(ctx_represenations)
        ctxs_ids = biencoder_input.context_ids
        ctxs_segments = biencoder_input.ctx_segments
        bsz = ctxs_ids.size(0)

        # split contexts batch into sub batches since it is supposed to be too large to be processed in one batch
        for j, batch_start in enumerate(range(0, bsz, sub_batch_size)):
            q_ids, q_segments = (
                (biencoder_input.question_ids, biencoder_input.question_segments)
                if j == 0
                else (None, None)
            )

            if j == 0 and args.n_gpu > 1 and q_ids.size(0) == 1:
                # if we are in DP (but not in DDP) mode, all model input tensors should have batch size >1 or 0,
                # otherwise the other input tensors will be split but only the first split will be called
                continue

            ctx_ids_batch = ctxs_ids[batch_start : batch_start + sub_batch_size]
            ctx_seg_batch = ctxs_segments[batch_start : batch_start + sub_batch_size]

            q_attn_mask = self.tensorizer.get_attn_mask(q_ids)
            ctx_attn_mask = self.tensorizer.get_attn_mask(ctx_ids_batch)
            with torch.no_grad():
                q_dense, ctx_dense = self.biencoder(
                    q_ids,
                    q_segments,
                    q_attn_mask,
                    ctx_ids_batch,
                    ctx_seg_batch,
                    ctx_attn_mask,
                )

            if q_dense is not None:
                q_represenations.extend(q_dense.cpu().split(1, dim=0))

            ctx_represenations.extend(ctx_dense.cpu().split(1, dim=0))

        batch_positive_idxs = biencoder_input.is_positive
        positive_idx_per_question.extend(
            [total_ctxs + v for v in batch_positive_idxs]
        )

        if (i + 1) % log_result_step == 0:
            logger.info(
                "Av.rank validation: step %d, computed ctx_vectors %d, q_vectors %d",
                i,
                len(ctx_represenations),
                len(q_represenations),
            )

    ctx_represenations = torch.cat(ctx_represenations, dim=0)
    q_represenations = torch.cat(q_represenations, dim=0)

    logger.info(
        "Av.rank validation: total q_vectors size=%s", q_represenations.size()
    )
    logger.info(
        "Av.rank validation: total ctx_vectors size=%s", ctx_represenations.size()
    )

    q_num = q_represenations.size(0)
    assert q_num == len(positive_idx_per_question)

    scores = sim_score_f(q_represenations, ctx_represenations)
    values, indices = torch.sort(scores, dim=1, descending=True)

    rank = 0
    for i, idx in enumerate(positive_idx_per_question):
        # aggregate the rank of the known gold passage in the sorted results for each question
        gold_idx = (indices[i] == idx).nonzero()
        rank += gold_idx.item()

    if distributed_factor > 1:
        # each node calculated its own rank; exchange the information between nodes and calculate the "global" average rank
        # NOTE: the set of passages is still unique for every node
        eval_stats = all_gather_list([rank, q_num], max_size=100)
        for i, item in enumerate(eval_stats):
            remote_rank, remote_q_num = item
            if i != args.local_rank:
                rank += remote_rank
                q_num += remote_q_num

    av_rank = float(rank / q_num)
    logger.info(
        "Av.rank validation: average rank %s, total questions=%d", av_rank, q_num
    )
    return av_rank
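
# Illustrative sketch of the rank aggregation performed above, on toy tensors: for each
# question, scores against all pooled contexts are sorted in descending order and the
# position of the gold context in that ordering is its rank; the average over questions
# (0 = gold ranked first) is what validate_average_rank reports. Standalone toy example,
# not project code.
import torch


def _toy_average_rank(scores: torch.Tensor, positive_idx_per_question: list) -> float:
    # scores: [num_questions, num_contexts]
    _, indices = torch.sort(scores, dim=1, descending=True)
    rank_sum = 0
    for i, gold in enumerate(positive_idx_per_question):
        rank_sum += (indices[i] == gold).nonzero().item()
    return rank_sum / len(positive_idx_per_question)


# e.g. a question whose gold context has the 3rd highest score contributes rank 2:
# print(_toy_average_rank(torch.tensor([[0.1, 0.9, 0.5]]), [0]))  # -> 2.0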
def _train_epoch(
    self,
    scheduler,
    epoch: int,
    eval_step: int,
    train_data_iterator: MultiSetDataIterator,
):
    cfg = self.cfg
    rolling_train_loss = 0.0
    epoch_loss = 0
    epoch_correct_predictions, epoch_correct_predictions_matching = 0, 0

    log_result_step = cfg.train.log_batch_step
    rolling_loss_step = cfg.train.train_rolling_loss_step
    num_hard_negatives = cfg.train.hard_negatives
    num_other_negatives = cfg.train.other_negatives
    seed = cfg.seed

    self.biencoder.train()
    epoch_batches = train_data_iterator.max_iterations
    data_iteration = 0

    dataset = 0
    for i, samples_batch in enumerate(
        train_data_iterator.iterate_ds_data(epoch=epoch)
    ):
        if isinstance(samples_batch, Tuple):
            samples_batch, dataset = samples_batch

        ds_cfg = self.ds_cfg.train_datasets[dataset]
        special_token = ds_cfg.special_token
        encoder_type = ds_cfg.encoder_type
        shuffle_positives = ds_cfg.shuffle_positives

        # to be able to resume shuffled ctx pools
        data_iteration = train_data_iterator.get_iteration()
        random.seed(seed + epoch + data_iteration)

        biencoder_batch = BiEncoder.create_biencoder_input(
            samples_batch,
            self.tensorizer,
            True,
            num_hard_negatives,
            num_other_negatives,
            shuffle=True,
            shuffle_positives=shuffle_positives,
            query_token=special_token,
        )

        # get the token to be used for representation selection
        from dpr.data.biencoder_data import DEFAULT_SELECTOR

        selector = ds_cfg.selector if ds_cfg else DEFAULT_SELECTOR
        rep_positions_q = selector.get_positions(
            biencoder_batch.question_ids, self.tensorizer, self.biencoder
        )
        rep_positions_c = selector.get_positions(
            biencoder_batch.context_ids, self.tensorizer, self.biencoder
        )

        loss_scale = (
            cfg.loss_scale_factors[dataset] if cfg.loss_scale_factors else None
        )
        outp = _do_biencoder_fwd_pass(
            self.biencoder,
            biencoder_batch,
            self.tensorizer,
            self.loss_function,
            cfg,
            encoder_type=encoder_type,
            rep_positions_q=rep_positions_q,
            rep_positions_c=rep_positions_c,
            loss_scale=loss_scale,
            clustering=self.clustering,
        )

        if self.clustering:
            loss, correct_cnt, (question_vector, context_vector) = outp
            question_vector = question_vector.clone().detach().cpu().numpy()
            context_vector = context_vector.clone().detach().cpu().numpy()

            model_outs = ForwardPassOutputsTrain(
                loss=None,
                biencoder_is_correct=None,
                biencoder_input=biencoder_batch,
                biencoder_preds=BiEncoderPredictionBatch(
                    question_vector=question_vector,
                    context_vector=context_vector,
                ),
                reader_input=None,
                reader_preds=None,
            )
            iterator: ShardedDataIteratorClustering = train_data_iterator.iterables[
                dataset
            ]
            iterator.record_predictions(epoch=epoch, model_outs=model_outs)
        elif cfg.others.is_matching:
            loss, correct_cnt, correct_cnt_matching = outp
            epoch_correct_predictions_matching += correct_cnt_matching
        else:
            loss, correct_cnt = outp

        epoch_correct_predictions += correct_cnt
        epoch_loss += loss.item()
        rolling_train_loss += loss.item()

        if cfg.fp16:
            from apex import amp

            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
            if cfg.train.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(self.optimizer), cfg.train.max_grad_norm
                )
        else:
            loss.backward()
            if cfg.train.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(
                    self.biencoder.parameters(), cfg.train.max_grad_norm
                )

        if (i + 1) % cfg.train.gradient_accumulation_steps == 0:
            self.optimizer.step()
            scheduler.step()
            self.biencoder.zero_grad()

        if i % log_result_step == 0:
            lr = self.optimizer.param_groups[0]["lr"]
            logger.info(
                "Epoch: %d: Step: %d/%d, loss=%f, lr=%f",
                epoch,
                data_iteration,
                epoch_batches,
                loss.item(),
                lr,
            )

        if (i + 1) % rolling_loss_step == 0:
            logger.info("Train batch %d", data_iteration)
            latest_rolling_train_av_loss = rolling_train_loss / rolling_loss_step
            logger.info(
                "Avg. loss per last %d batches: %f",
                rolling_loss_step,
                latest_rolling_train_av_loss,
            )
            rolling_train_loss = 0.0

        if data_iteration % eval_step == 0:
            logger.info(
                "rank=%d, Validation: Epoch: %d Step: %d/%d",
                cfg.local_rank,
                epoch,
                data_iteration,
                epoch_batches,
            )
            self.validate_and_save(
                epoch, train_data_iterator.get_iteration(), scheduler
            )
            self.biencoder.train()

    logger.info("Epoch finished on %d", cfg.local_rank)

    # If we just evaluated at the last iteration, we don't need to evaluate again
    if data_iteration % eval_step != 0:
        self.validate_and_save(epoch, data_iteration, scheduler)

    epoch_loss = (epoch_loss / epoch_batches) if epoch_batches > 0 else 0
    logger.info("Av Loss per epoch=%f", epoch_loss)
    logger.info("epoch total correct predictions=%d", epoch_correct_predictions)
    if cfg.others.is_matching:
        logger.info(
            "epoch total correct matching predictions=%d",
            epoch_correct_predictions_matching,
        )
def validate_average_rank(self) -> float:
    """
    Validates biencoder model using each question's gold passage's rank across the set of passages from the dataset.
    It generates vectors for a specified amount of negative passages from each question (see --val_av_rank_xxx params)
    and stores them in RAM as well as question vectors.
    Then the similarity scores are calculated for the entire
    num_questions x (num_questions x num_passages_per_question) matrix and sorted per question.
    Each question's gold passage rank in that sorted list of scores is averaged across all the questions.
    :return: averaged rank number
    """
    logger.info("Average rank validation ...")

    cfg = self.cfg
    self.biencoder.eval()
    distributed_factor = self.distributed_factor

    if not self.dev_iterator:
        self.dev_iterator = self.get_data_iterator(
            cfg.train.dev_batch_size, False, shuffle=False, rank=cfg.local_rank
        )
    data_iterator = self.dev_iterator

    sub_batch_size = cfg.train.val_av_rank_bsz
    sim_score_f = self.loss_function.get_similarity_function()
    q_represenations = []
    ctx_represenations = []
    positive_idx_per_question = []

    num_hard_negatives = cfg.train.val_av_rank_hard_neg
    num_other_negatives = cfg.train.val_av_rank_other_neg

    dataset = 0
    for i, samples_batch in enumerate(data_iterator.iterate_ds_data()):
        # samples += 1
        if (
            len(q_represenations)
            > cfg.train.val_av_rank_max_qs / distributed_factor
        ):
            break

        if isinstance(samples_batch, Tuple):
            samples_batch, dataset = samples_batch

        biencoder_input = BiEncoder.create_biencoder_input(
            samples_batch,
            self.tensorizer,
            True,
            num_hard_negatives,
            num_other_negatives,
            shuffle=False,
        )
        total_ctxs = len(ctx_represenations)
        ctxs_ids = biencoder_input.context_ids.to(cfg.device)
        ctxs_segments = biencoder_input.ctx_segments.to(cfg.device)
        bsz = ctxs_ids.size(0)

        # get the token to be used for representation selection
        ds_cfg = self.ds_cfg.dev_datasets[dataset]
        encoder_type = ds_cfg.encoder_type
        rep_positions_q = ds_cfg.selector.get_positions(
            biencoder_input.question_ids, self.tensorizer, self.biencoder
        )
        rep_positions_c = ds_cfg.selector.get_positions(
            biencoder_input.context_ids, self.tensorizer, self.biencoder
        )

        # split contexts batch into sub batches since it is supposed to be too large to be processed in one batch
        for j, batch_start in enumerate(range(0, bsz, sub_batch_size)):
            q_ids, q_segments = (
                (
                    biencoder_input.question_ids.to(cfg.device),
                    biencoder_input.question_segments.to(cfg.device),
                )
                if j == 0
                else (None, None)
            )

            if j == 0 and cfg.n_gpu > 1 and q_ids.size(0) == 1:
                # if we are in DP (but not in DDP) mode, all model input tensors should have batch size >1 or 0,
                # otherwise the other input tensors will be split but only the first split will be called
                continue

            ctx_ids_batch = ctxs_ids[batch_start : batch_start + sub_batch_size]
            ctx_seg_batch = ctxs_segments[batch_start : batch_start + sub_batch_size]

            q_attn_mask = self.tensorizer.get_attn_mask(q_ids)
            ctx_attn_mask = self.tensorizer.get_attn_mask(ctx_ids_batch).to(cfg.device)
            with torch.no_grad():
                q_dense, ctx_dense = self.biencoder(
                    q_ids,
                    q_segments,
                    q_attn_mask,
                    ctx_ids_batch,
                    ctx_seg_batch,
                    ctx_attn_mask,
                    encoder_type=encoder_type,
                    representation_token_pos_q=rep_positions_q,
                    representation_token_pos_c=rep_positions_c,
                )

            if q_dense is not None:
                q_represenations.extend(q_dense.cpu().split(1, dim=0))

            ctx_represenations.extend(ctx_dense.cpu().split(1, dim=0))

        batch_positive_idxs = biencoder_input.is_positive
        positive_idx_per_question.extend(
            [total_ctxs + v for v in batch_positive_idxs]
        )

        logger.info(
            "Av.rank validation: step %d, computed ctx_vectors %d, q_vectors %d",
            i,
            len(ctx_represenations),
            len(q_represenations),
        )

    ctx_represenations = torch.cat(ctx_represenations, dim=0)
    q_represenations = torch.cat(q_represenations, dim=0)

    if cfg.others.is_matching:
        # Need to compute by batch of contexts
        logger.info("Average rank validation for interaction layers...")
        num_batches = math.ceil(len(ctx_represenations) / sub_batch_size)
        interaction_scores = []

        for i in range(num_batches):
            start = i * sub_batch_size
            end = min(start + sub_batch_size, len(ctx_represenations))

            with torch.no_grad():
                interaction_score = self.biencoder(
                    q_pooled_out=q_represenations.to(cfg.device),
                    ctx_pooled_out=ctx_represenations[start:end].to(cfg.device),
                    is_matching=True,
                ).cpu()
            interaction_scores.append(interaction_score)

            logger.info(
                "Av.rank validation (interaction): step %d/%d", i, num_batches
            )

        interaction_scores = torch.cat(
            interaction_scores, dim=1
        )  # concatenate along context dim
        logger.info(
            "Av.rank validation (interaction): total interaction matrix size=%s",
            interaction_scores.size(),
        )

    logger.info(
        "Av.rank validation: total q_vectors size=%s", q_represenations.size()
    )
    logger.info(
        "Av.rank validation: total ctx_vectors size=%s", ctx_represenations.size()
    )

    # Calculate cosine similarity scores
    q_num = q_represenations.size(0)
    assert q_num == len(positive_idx_per_question)

    scores = sim_score_f(q_represenations, ctx_represenations)
    values, indices = torch.sort(scores, dim=1, descending=True)

    rank = 0
    for i, idx in enumerate(positive_idx_per_question):
        # aggregate the rank of the known gold passage in the sorted results for each question
        gold_idx = (indices[i] == idx).nonzero()
        rank += gold_idx.item()

    if distributed_factor > 1:
        # each node calculated its own rank; exchange the information between nodes and calculate the "global" average rank
        # NOTE: the set of passages is still unique for every node
        eval_stats = all_gather_list([rank, q_num], max_size=100)
        for i, item in enumerate(eval_stats):
            remote_rank, remote_q_num = item
            if i != cfg.local_rank:
                rank += remote_rank
                q_num += remote_q_num

    av_rank = float(rank / q_num)
    logger.info(
        "Av.rank validation: average rank %s, total questions=%d", av_rank, q_num
    )

    # Calculate interaction scores
    if cfg.others.is_matching:
        interaction_q_num = q_represenations.size(0)
        assert interaction_q_num == len(positive_idx_per_question)

        interaction_rank = 0
        values, indices = torch.sort(interaction_scores, dim=1, descending=True)
        for i, idx in enumerate(positive_idx_per_question):
            # aggregate the rank of the known gold passage in the sorted results for each question
            gold_idx = (indices[i] == idx).nonzero()
            interaction_rank += gold_idx.item()

        if distributed_factor > 1:
            # each node calculated its own rank; exchange the information between nodes and calculate the "global" average rank
            # NOTE: the set of passages is still unique for every node
            eval_stats = all_gather_list(
                [interaction_rank, interaction_q_num], max_size=100
            )
            for i, item in enumerate(eval_stats):
                remote_rank, remote_q_num = item
                if i != cfg.local_rank:
                    interaction_rank += remote_rank
                    interaction_q_num += remote_q_num

        interaction_av_rank = float(interaction_rank / interaction_q_num)
        logger.info(
            "Av.rank validation (interaction): average rank %s, total questions=%d",
            interaction_av_rank,
            interaction_q_num,
        )

    return interaction_av_rank if cfg.others.is_matching else av_rank
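
# Illustrative sketch of the distributed aggregation above: every worker computes a
# local (rank_sum, question_count) pair, all pairs are exchanged (all_gather_list in
# the real code), and each worker adds the remote contributions before dividing.
# Plain lists stand in for the collective op here; this is a toy example only.
def _toy_global_average_rank(per_node_stats, local_rank: int) -> float:
    rank_sum, q_num = per_node_stats[local_rank]
    for i, (remote_rank, remote_q_num) in enumerate(per_node_stats):
        if i != local_rank:
            rank_sum += remote_rank
            q_num += remote_q_num
    return rank_sum / q_num


# e.g. two workers with (rank_sum=30, 10 questions) and (rank_sum=50, 10 questions)
# both arrive at the same global average: (30 + 50) / 20 = 4.0
# print(_toy_global_average_rank([(30, 10), (50, 10)], local_rank=0))  # -> 4.0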
def validate_nll(self) -> float:
    logger.info("NLL validation ...")
    cfg = self.cfg
    self.biencoder.eval()

    if not self.dev_iterator:
        self.dev_iterator = self.get_data_iterator(
            cfg.train.dev_batch_size, False, shuffle=False, rank=cfg.local_rank
        )
    data_iterator = self.dev_iterator

    total_loss = 0.0
    start_time = time.time()
    total_correct_predictions, total_correct_predictions_matching = 0, 0
    num_hard_negatives = cfg.train.hard_negatives
    num_other_negatives = cfg.train.other_negatives
    log_result_step = cfg.train.log_batch_step
    batches = 0
    dataset = 0

    for i, samples_batch in enumerate(data_iterator.iterate_ds_data()):
        if isinstance(samples_batch, Tuple):
            samples_batch, dataset = samples_batch

        logger.info("Eval step: %d ,rnk=%s", i, cfg.local_rank)

        biencoder_input = BiEncoder.create_biencoder_input(
            samples_batch,
            self.tensorizer,
            insert_title=True,
            num_hard_negatives=num_hard_negatives,
            num_other_negatives=num_other_negatives,
            shuffle=False,
        )

        # get the token to be used for representation selection
        ds_cfg = self.ds_cfg.dev_datasets[dataset]
        rep_positions_q = ds_cfg.selector.get_positions(
            biencoder_input.question_ids, self.tensorizer, self.biencoder
        )
        rep_positions_c = ds_cfg.selector.get_positions(
            biencoder_input.context_ids, self.tensorizer, self.biencoder
        )
        encoder_type = ds_cfg.encoder_type

        outp = _do_biencoder_fwd_pass(
            self.biencoder,
            biencoder_input,
            self.tensorizer,
            self.loss_function,
            cfg,
            encoder_type=encoder_type,
            rep_positions_q=rep_positions_q,
            rep_positions_c=rep_positions_c,
        )
        if cfg.others.is_matching:
            loss, correct_cnt, correct_cnt_matching = outp
            total_correct_predictions_matching += correct_cnt_matching
        else:
            loss, correct_cnt = outp

        total_loss += loss.item()
        total_correct_predictions += correct_cnt
        batches += 1
        if (i + 1) % log_result_step == 0:
            logger.info(
                "Eval step: %d , used_time=%f sec., loss=%f ",
                i,
                time.time() - start_time,
                loss.item(),
            )

    total_loss = total_loss / batches
    total_samples = batches * cfg.train.dev_batch_size * self.distributed_factor
    correct_ratio = float(total_correct_predictions / total_samples)

    to_log = (
        f"NLL Validation: loss = {total_loss:.4f} correct prediction ratio "
        f"{total_correct_predictions}/{total_samples} ~ {correct_ratio:.4f}"
    )
    if cfg.others.is_matching:
        correct_ratio_matching = float(
            total_correct_predictions_matching / total_samples
        )
        to_log += (
            f", matching correct prediction ratio {total_correct_predictions_matching}/{total_samples}"
            f" ~ {correct_ratio_matching:.4f}"
        )
    logger.info(to_log)

    return total_loss