def iter_chunk_tokenize_and_featurize(
    task, examples: list, tokenizer, feat_spec: FeaturizationSpec, phase, verbose=False
):
    """Generator of DataRows containing tokenized and featurized examples.

    Args:
        task (Task): Task object
        examples (list[Example]): list of task Examples.
        tokenizer: TODO  (Issue #44)
        feat_spec (FeaturizationSpec): Tokenization-related metadata.
        phase (str): string identifying the data subset (e.g., train, val or test).
        verbose: If True, display progress bar.

    Yields:
        DataRow containing tokenized and featurized examples.

    """
    for example in maybe_tqdm(examples, desc="Tokenizing", verbose=verbose):
        # TODO: Better solution  (Issue #48)
        if task.TASK_TYPE == TaskTypes.SQUAD_STYLE_QA:
            yield from example.to_feature_list(
                tokenizer=tokenizer,
                max_seq_length=feat_spec.max_seq_length,
                doc_stride=task.doc_stride,
                max_query_length=task.max_query_length,
                set_type=phase,
            )
        else:
            yield example.tokenize(tokenizer).featurize(tokenizer, feat_spec)

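# Usage sketch (not from the original source): drive the generator above to build
# fixed-size chunks of DataRows. The `task`, `examples`, `tokenizer`, and `feat_spec`
# objects are assumed to be provided by the surrounding caching pipeline; only the
# chunking pattern is illustrated here.
def chunk_featurized_examples(task, examples, tokenizer, feat_spec, chunk_size=10000):
    chunks, current = [], []
    for data_row in iter_chunk_tokenize_and_featurize(
        task=task,
        examples=examples,
        tokenizer=tokenizer,
        feat_spec=feat_spec,
        phase="train",
        verbose=False,
    ):
        current.append(data_row)
        if len(current) == chunk_size:
            chunks.append(current)
            current = []
    if current:
        # Flush the final, possibly smaller, chunk.
        chunks.append(current)
    return chunks
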
def run_train_context(self, verbose=True):
    train_dataloader_dict = self.get_train_dataloader_dict()
    loss_weights_dict = {}
    train_state = TrainState.from_task_name_list(
        self.jiant_task_container.task_run_config.train_task_list)
    global_train_config = self.jiant_task_container.global_train_config
    losses = []
    for step in maybe_tqdm(
            range(global_train_config.max_steps),
            desc="Training",
            verbose=verbose,
    ):
        regular_log(logger, step, interval=10, tag='train')
        if step == global_train_config.weighted_sampling_start_step:
            train_dataloader_dict = self.get_train_dataloader_dict(
                do_weighted_sampling=True)
        if step == global_train_config.weighted_loss_start_step:
            loss_weights_dict = self.get_loss_weights_dict()
        loss_per_step = self.run_train_step(
            train_dataloader_dict=train_dataloader_dict,
            train_state=train_state,
            loss_weights_dict=loss_weights_dict,
        )
        losses.append(loss_per_step)
        if step % 100 == 0:
            logger.info('[train] loss: %f', np.mean(losses))
            self.tf_writer.flush()
        yield train_state

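# Usage sketch (assumption, not from the original source): the method above is a
# generator, so a caller can interleave checkpointing or early stopping between
# steps. `runner` and `save_checkpoint` are hypothetical placeholders; only the
# consumption pattern is shown.
def train_with_periodic_checkpoints(runner, save_checkpoint, save_every=1000):
    for train_state in runner.run_train_context(verbose=True):
        # `global_steps` is the counter used elsewhere in this module (see resume_train_context).
        if train_state.global_steps % save_every == 0:
            save_checkpoint(runner, train_state)
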
def run_val(
    val_dataloader,
    val_labels,
    jiant_model: JiantModel,
    task,
    device,
    local_rank,
    return_preds=False,
    verbose=True,
):
    # Reminder:
    #   val_dataloader contains mostly PyTorch-relevant info
    #   val_labels might contain more detailed information needed for full evaluation
    if not local_rank == -1:
        return
    jiant_model.eval()
    total_eval_loss = 0
    nb_eval_steps, nb_eval_examples = 0, 0
    evaluation_scheme = evaluate.get_evaluation_scheme_for_task(task=task)
    eval_accumulator = evaluation_scheme.get_accumulator()
    for step, (batch, batch_metadata) in enumerate(
        maybe_tqdm(val_dataloader, desc=f"Eval ({task.name}, Val)", verbose=verbose)
    ):
        batch = batch.to(device)
        with torch.no_grad():
            model_output = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=batch,
                task=task,
                compute_loss=True,
            )
        batch_logits = model_output.logits.detach().cpu().numpy()
        batch_loss = model_output.loss.mean().item()
        total_eval_loss += batch_loss
        eval_accumulator.update(
            batch_logits=batch_logits,
            batch_loss=batch_loss,
            batch=batch,
            batch_metadata=batch_metadata,
        )
        nb_eval_examples += len(batch)
        nb_eval_steps += 1
    eval_loss = total_eval_loss / nb_eval_steps
    tokenizer = (
        jiant_model.tokenizer
        if not torch_utils.is_data_parallel(jiant_model)
        else jiant_model.module.tokenizer
    )
    output = {
        "accumulator": eval_accumulator,
        "loss": eval_loss,
        "metrics": evaluation_scheme.compute_metrics_from_accumulator(
            task=task,
            accumulator=eval_accumulator,
            labels=val_labels,
            tokenizer=tokenizer,
        ),
    }
    if return_preds:
        output["preds"] = evaluation_scheme.get_preds_from_accumulator(
            task=task,
            accumulator=eval_accumulator,
        )
    return output

def run_test(test_dataloader, jiant_model: JiantModel, task, device, local_rank, verbose=True):
    if not local_rank == -1:
        return
    jiant_model.eval()
    evaluation_scheme = evaluate.get_evaluation_scheme_for_task(task=task)
    eval_accumulator = evaluation_scheme.get_accumulator()
    for step, (batch, batch_metadata) in enumerate(
        maybe_tqdm(test_dataloader, desc=f"Eval ({task.name}, Test)", verbose=verbose)
    ):
        batch = batch.to(device)
        with torch.no_grad():
            model_output = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=batch,
                task=task,
                compute_loss=False,
            )
        batch_logits = model_output.logits.detach().cpu().numpy()
        eval_accumulator.update(
            batch_logits=batch_logits,
            batch_loss=0,
            batch=batch,
            batch_metadata=batch_metadata,
        )
    return {
        "preds": evaluation_scheme.get_preds_from_accumulator(
            task=task,
            accumulator=eval_accumulator,
        ),
        "accumulator": eval_accumulator,
    }

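# Usage sketch (assumption): run the test loop above and persist the predictions.
# `test_dataloader`, `jiant_model`, `task`, and `device` are assumed to come from the
# runner setup; torch.save is used only as a simple serialization choice.
def save_test_preds(test_dataloader, jiant_model, task, device, output_path="test_preds.p"):
    test_output = run_test(
        test_dataloader=test_dataloader,
        jiant_model=jiant_model,
        task=task,
        device=device,
        local_rank=-1,
        verbose=False,
    )
    torch.save(test_output["preds"], output_path)
    return test_output
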
def iter_chunk_tokenize_and_featurize(
    examples: list, tokenizer, feat_spec: FeaturizationSpec, phase, verbose=False
):
    """Generator of DataRows containing tokenized and featurized examples.

    Args:
        examples (list[Example]): list of task Examples.
        tokenizer: TODO  (Issue #44)
        feat_spec (FeaturizationSpec): Tokenization-related metadata.
        phase (str): string identifying the data subset (e.g., train, val or test).
        verbose: If True, display progress bar.

    Yields:
        DataRow containing tokenized and featurized examples.

    """
    for example in maybe_tqdm(examples, desc="Tokenizing", verbose=verbose):
        # TODO: Better solution  (Issue #48)
        if isinstance(example, squad_style.Example):
            # TODO: Expose parameters  (Issue #49)
            yield from example.to_feature_list(
                tokenizer=tokenizer,
                max_seq_length=feat_spec.max_seq_length,
                doc_stride=128,
                max_query_length=64,
                set_type=phase,
            )
        else:
            yield example.tokenize(tokenizer).featurize(tokenizer, feat_spec)

def read_examples(self, path, set_type):
    input_data = read_json(path, encoding="utf-8")["data"]
    is_training = set_type == PHASE.TRAIN
    examples = []
    data = take_one(input_data)
    for paragraph in maybe_tqdm(data["paragraphs"]):
        for qa in paragraph["qas"]:
            qas_id = qa["id"]
            # Because answers can also come from questions, we're going to abuse notation
            # slightly and put the entire background+situation+question into the "context"
            # and leave nothing for the "question"
            question_text = " "
            if self.include_background:
                context_segments = [
                    paragraph["background"],
                    paragraph["situation"],
                    qa["question"],
                ]
            else:
                context_segments = [paragraph["situation"], qa["question"]]
            full_context = " ".join(segment.strip() for segment in context_segments)
            if is_training:
                answer = qa["answers"][0]
                start_position_character = full_context.find(answer["text"])
                answer_text = answer["text"]
                answers = []
            else:
                start_position_character = None
                answer_text = None
                answers = [{
                    "text": answer["text"],
                    "answer_start": full_context.find(answer["text"]),
                } for answer in qa["answers"]]
            example = Example(
                qas_id=qas_id,
                question_text=question_text,
                context_text=full_context,
                answer_text=answer_text,
                start_position_character=start_position_character,
                title="",
                is_impossible=False,
                answers=answers,
                background_text=paragraph["background"],
                situation_text=paragraph["situation"],
            )
            examples.append(example)
    return examples

def tokenize_and_featurize(
    task, examples: list, tokenizer, feat_spec: FeaturizationSpec, phase, verbose=False
):
    """Create list of DataRows containing tokenized and featurized examples.

    Args:
        task (Task): Task object
        examples (list[Example]): list of task Examples.
        tokenizer: TODO  (issue #1188)
        feat_spec (FeaturizationSpec): Tokenization-related metadata.
        phase (str): string identifying the data subset (e.g., train, val or test).
        verbose: If True, display progress bar.

    Returns:
        List of DataRows containing tokenized and featurized examples.

    """
    # TODO: Better solution  (issue #1184)
    if task.TASK_TYPE == TaskTypes.SQUAD_STYLE_QA:
        data_rows = []
        for step, example in enumerate(
                maybe_tqdm(examples, desc="Tokenizing", verbose=verbose)):
            regular_log(logger, step)
            data_rows += example.to_feature_list(
                tokenizer=tokenizer,
                max_seq_length=feat_spec.max_seq_length,
                doc_stride=task.doc_stride,
                max_query_length=task.max_query_length,
                set_type=phase,
            )
    else:
        data_rows = []
        for step, example in enumerate(
                maybe_tqdm(examples, desc="Tokenizing", verbose=verbose)):
            regular_log(logger, step)
            data_rows.append(
                example.tokenize(tokenizer).featurize(tokenizer, feat_spec))
    return data_rows

def run_train_context(self, verbose=True):
    train_dataloader_dict = self.get_train_dataloader_dict()
    train_state = TrainState.from_task_name_list(
        self.jiant_task_container.task_run_config.train_task_list)
    for _ in maybe_tqdm(
            range(self.jiant_task_container.global_train_config.max_steps),
            desc="Training",
            verbose=verbose,
    ):
        self.run_train_step(train_dataloader_dict=train_dataloader_dict,
                            train_state=train_state)
        yield train_state

def generic_read_squad_examples(
    path: str, set_type: str, example_class: type = dict, read_title: bool = True
):
    with open(path, "r", encoding="utf-8") as reader:
        input_data = json.load(reader)["data"]

    is_training = set_type == PHASE.TRAIN
    examples = []
    for step, entry in enumerate(
            maybe_tqdm(input_data, desc="Reading SQuAD Entries")):
        regular_log(logger, step)
        if read_title:
            title = entry["title"]
        else:
            title = "-"
        for paragraph in entry["paragraphs"]:
            context_text = paragraph["context"]
            for qa in paragraph["qas"]:
                qas_id = qa["id"]
                question_text = qa["question"]
                start_position_character = None
                answer_text = None
                answers = []
                if "is_impossible" in qa:
                    is_impossible = qa["is_impossible"]
                else:
                    is_impossible = False
                if not is_impossible:
                    if is_training:
                        answer = qa["answers"][0]
                        answer_text = answer["text"]
                        start_position_character = answer["answer_start"]
                    else:
                        answers = qa["answers"]
                example = example_class(
                    qas_id=qas_id,
                    question_text=question_text,
                    context_text=context_text,
                    answer_text=answer_text,
                    start_position_character=start_position_character,
                    title=title,
                    is_impossible=is_impossible,
                    answers=answers,
                )
                examples.append(example)
    return examples

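# Usage sketch (assumption): read a minimal SQuAD-format file with the generic reader
# above. With the default example_class=dict, each example comes back as a plain dict.
# PHASE.TRAIN is the phase constant used elsewhere in this module; the tiny JSON payload
# below is illustrative only.
def _demo_generic_read_squad_examples(tmp_path="tiny_squad.json"):
    tiny = {
        "data": [{
            "title": "Example",
            "paragraphs": [{
                "context": "The quick brown fox jumps over the lazy dog.",
                "qas": [{
                    "id": "q1",
                    "question": "What jumps over the lazy dog?",
                    "answers": [{"text": "The quick brown fox", "answer_start": 0}],
                }],
            }],
        }]
    }
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(tiny, f)
    # Returns a list of dicts with qas_id, question_text, context_text, answer_text, ...
    return generic_read_squad_examples(path=tmp_path, set_type=PHASE.TRAIN)
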
def tokenize_and_featurize(
    examples: list, tokenizer, feat_spec: FeaturizationSpec, phase, verbose=False
):
    """Create list of DataRows containing tokenized and featurized examples.

    Args:
        examples (list[Example]): list of task Examples.
        tokenizer: TODO  (Issue #44)
        feat_spec (FeaturizationSpec): Tokenization-related metadata.
        phase (str): string identifying the data subset (e.g., train, val or test).
        verbose: If True, display progress bar.

    Returns:
        List of DataRows containing tokenized and featurized examples.

    """
    # TODO: Better solution  (Issue #48)
    if isinstance(examples[0], squad_style.Example):
        data_rows = []
        for example in maybe_tqdm(examples, desc="Tokenizing", verbose=verbose):
            # TODO: Expose parameters  (Issue #49)
            data_rows += example.to_feature_list(
                tokenizer=tokenizer,
                feat_spec=feat_spec,
                max_seq_length=feat_spec.max_seq_length,
                doc_stride=128,
                max_query_length=64,
                set_type=phase,
            )
    else:
        data_rows = [
            example.tokenize(tokenizer).featurize(tokenizer, feat_spec)
            for example in maybe_tqdm(examples, desc="Tokenizing", verbose=verbose)
        ]
    return data_rows

def resume_train_context(self, train_state, verbose=True):
    train_dataloader_dict = self.get_train_dataloader_dict()
    start_position = train_state.global_steps
    for _ in maybe_tqdm(
        range(start_position, self.jiant_task_container.global_train_config.max_steps),
        desc="Training",
        initial=start_position,
        total=self.jiant_task_container.global_train_config.max_steps,
        verbose=verbose,
    ):
        self.run_train_step(
            train_dataloader_dict=train_dataloader_dict, train_state=train_state
        )
        yield train_state

def get_preds_for_single_tagging_task(
    task, test_dataloader, runner: jiant_runner.JiantRunner, verbose: bool = True
):
    """Generate predictions for a single tagging task"""
    jiant_model, device = runner.model, runner.device
    jiant_model.eval()
    test_examples = task.get_test_examples()
    preds_list = []
    example_i = 0
    for step, (batch, batch_metadata) in enumerate(
            maybe_tqdm(test_dataloader, desc=f"Eval ({task.name}, Test)", verbose=verbose)):
        batch = batch.to(device)
        with torch.no_grad():
            model_output = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=batch,
                task=task,
                compute_loss=False,
            )
        batch_logits = model_output.logits.detach().cpu().numpy()
        label_mask_arr = batch.label_mask.cpu().bool().numpy()
        preds_arr = np.argmax(batch_logits, axis=-1)
        for i in range(len(batch)):
            # noinspection PyUnresolvedReferences
            labels = [
                task.ID_TO_LABEL[class_i]
                for class_i in preds_arr[i][label_mask_arr[i]]
            ]
            if len(labels) == len(test_examples[example_i].tokens):
                this_preds = list(zip(test_examples[example_i].tokens, labels))
            elif len(labels) < len(test_examples[example_i].tokens):
                # Predictions can be shorter than the example when the input was truncated;
                # pad the remaining tokens with the default label so the output stays a
                # consistent list of (token, label) pairs.
                this_preds = list(zip(test_examples[example_i].tokens, labels))
                this_preds += [
                    (token, task.LABELS[-1])
                    for token in test_examples[example_i].tokens[len(labels):]
                ]
            else:
                raise RuntimeError("Got more predicted labels than tokens")
            preds_list.append(this_preds)
            example_i += 1
    return preds_list

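# Usage sketch (assumption): dump the (token, label) pairs produced above in a simple
# CoNLL-style layout, one token per line and a blank line between examples. The file
# name is arbitrary.
def write_tagging_preds(preds_list, output_path="tagging_preds.txt"):
    with open(output_path, "w", encoding="utf-8") as f:
        for example_preds in preds_list:
            for token, label in example_preds:
                f.write(f"{token}\t{label}\n")
            f.write("\n")
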
def smart_truncate(dataset: torch_utils.ListDataset, max_seq_length: int, verbose: bool = False):
    """Truncate data to the length of the longest example in the dataset.

    Args:
        dataset (torch_utils.ListDataset): ListDataset to truncate if possible.
        max_seq_length (int): The maximum total input sequence length.
        verbose (bool): If True, display progress bar tracking truncation progress.

    Returns:
        Tuple[torch_utils.ListDataset, int]: truncated dataset, and length of the longest sequence.

    """
    if "input_mask" not in dataset.data[0]["data_row"].get_fields():
        raise RuntimeError("Smart truncate not supported")
    valid_length_ls = []
    range_idx = np.arange(max_seq_length)
    for datum in dataset.data:
        # TODO: document why reshape and max happen here (for cola this isn't necessary).
        #       (issue #1185)
        indexer = datum["data_row"].input_mask.reshape(-1, max_seq_length).max(-2)
        valid_length_ls.append(range_idx[indexer.astype(bool)].max() + 1)
    max_valid_length = max(valid_length_ls)
    if max_valid_length == max_seq_length:
        return dataset, max_seq_length
    new_datum_ls = []
    for step, datum in enumerate(
            maybe_tqdm(dataset.data, desc="Smart truncate data", verbose=verbose)):
        regular_log(logger, step)
        new_datum_ls.append(
            smart_truncate_datum(
                datum=datum,
                max_seq_length=max_seq_length,
                max_valid_length=max_valid_length,
            ))
    new_dataset = torch_utils.ListDataset(new_datum_ls)
    return new_dataset, max_valid_length

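# Usage sketch (assumption): truncate a cached dataset before building a DataLoader.
# `dataset` is assumed to be the torch_utils.ListDataset produced by the featurization
# and caching steps above; max_seq_length should match the value used at featurization time.
def truncate_cached_dataset(dataset, max_seq_length=256):
    truncated_dataset, max_valid_length = smart_truncate(
        dataset=dataset,
        max_seq_length=max_seq_length,
        verbose=True,
    )
    logger.info("Longest valid sequence after truncation: %d", max_valid_length)
    return truncated_dataset
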
def read_squad_examples(cls, path, set_type):
    with open(path, "r", encoding="utf-8") as reader:
        input_data = json.load(reader)["data"]

    is_training = set_type == PHASE.TRAIN
    examples = []
    for entry in maybe_tqdm(input_data, desc="Reading SQuAD Entries"):
        title = entry["title"]
        for paragraph in entry["paragraphs"]:
            context_text = paragraph["context"]
            for qa in paragraph["qas"]:
                qas_id = qa["id"]
                question_text = qa["question"]
                start_position_character = None
                answer_text = None
                answers = []
                if "is_impossible" in qa:
                    is_impossible = qa["is_impossible"]
                else:
                    is_impossible = False
                if not is_impossible:
                    if is_training:
                        answer = qa["answers"][0]
                        answer_text = answer["text"]
                        start_position_character = answer["answer_start"]
                    else:
                        answers = qa["answers"]
                example = cls.Example(
                    qas_id=qas_id,
                    question_text=question_text,
                    context_text=context_text,
                    answer_text=answer_text,
                    start_position_character=start_position_character,
                    title=title,
                    is_impossible=is_impossible,
                    answers=answers,
                )
                examples.append(example)
    return examples

def smart_truncate_cache(
    cache: shared_caching.ChunkedFilesDataCache,
    max_seq_length: int,
    max_valid_length: int,
    verbose: bool = False,
):
    for chunk_i in maybe_trange(cache.num_chunks, desc="Smart truncate chunks", verbose=verbose):
        chunk = torch.load(cache.get_chunk_path(chunk_i))
        new_chunk = []
        for datum in maybe_tqdm(chunk, desc="Smart truncate chunk-datum", verbose=verbose):
            new_chunk.append(
                smart_truncate_datum(
                    datum=datum,
                    max_seq_length=max_seq_length,
                    max_valid_length=max_valid_length,
                ))
        torch.save(new_chunk, cache.get_chunk_path(chunk_i))

def compute_predictions_logits_v2(
    partial_examples: List[PartialExample],
    all_results,
    n_best_size,
    max_answer_length,
    do_lower_case,
    version_2_with_negative,
    null_score_diff_threshold,
    tokenizer,
    skip_get_final_text=False,
    verbose=True,
):
    """Compute final predictions (and the log-odds of null, if needed) from model results."""
    unique_id_to_result = {}
    for result in all_results:
        unique_id_to_result[result.unique_id] = result

    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "PrelimPrediction",
        ["feature_index", "start_index", "end_index", "start_logit", "end_logit"],
    )

    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
    scores_diff_json = collections.OrderedDict()

    for example in maybe_tqdm(partial_examples, verbose=verbose):
        features = example.partial_features

        prelim_predictions = []
        # keep track of the minimum score of null start+end of position 0
        score_null = 1000000  # large and positive
        min_null_feature_index = 0  # the paragraph slice with min null score
        null_start_logit = 0  # the start logit at the slice with min null score
        null_end_logit = 0  # the end logit at the slice with min null score
        for (feature_index, feature) in enumerate(features):
            result = unique_id_to_result[feature.unique_id]
            start_indexes = _get_best_indexes(result.start_logits, n_best_size)
            end_indexes = _get_best_indexes(result.end_logits, n_best_size)
            # if we could have irrelevant answers, get the min score of irrelevant
            if version_2_with_negative:
                feature_null_score = result.start_logits[0] + result.end_logits[0]
                if feature_null_score < score_null:
                    score_null = feature_null_score
                    min_null_feature_index = feature_index
                    null_start_logit = result.start_logits[0]
                    null_end_logit = result.end_logits[0]
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # We could hypothetically create invalid predictions, e.g., predict
                    # that the start of the span is in the question. We throw out all
                    # invalid predictions.
                    if start_index >= len(feature.tokens):
                        continue
                    if end_index >= len(feature.tokens):
                        continue
                    if start_index not in feature.token_to_orig_map:
                        continue
                    if end_index not in feature.token_to_orig_map:
                        continue
                    if not feature.token_is_max_context.get(start_index, False):
                        continue
                    if end_index < start_index:
                        continue
                    length = end_index - start_index + 1
                    if length > max_answer_length:
                        continue
                    prelim_predictions.append(
                        _PrelimPrediction(
                            feature_index=feature_index,
                            start_index=start_index,
                            end_index=end_index,
                            start_logit=result.start_logits[start_index],
                            end_logit=result.end_logits[end_index],
                        )
                    )
        if version_2_with_negative:
            prelim_predictions.append(
                _PrelimPrediction(
                    feature_index=min_null_feature_index,
                    start_index=0,
                    end_index=0,
                    start_logit=null_start_logit,
                    end_logit=null_end_logit,
                )
            )
        prelim_predictions = sorted(
            prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True
        )

        _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
            "NbestPrediction", ["text", "start_logit", "end_logit"]
        )

        seen_predictions = {}
        nbest = []
        for pred in prelim_predictions:
            if len(nbest) >= n_best_size:
                break
            feature = features[pred.feature_index]
            if pred.start_index > 0:  # this is a non-null prediction
                tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
                orig_doc_start = feature.token_to_orig_map[pred.start_index]
                orig_doc_end = feature.token_to_orig_map[pred.end_index]
                orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]

                tok_text = tokenizer.convert_tokens_to_string(tok_tokens)

                # tok_text = " ".join(tok_tokens)
                #
                # # De-tokenize WordPieces that have been split off.
                # tok_text = tok_text.replace(" ##", "")
                # tok_text = tok_text.replace("##", "")

                # Clean whitespace
                tok_text = tok_text.strip()
                tok_text = " ".join(tok_text.split())
                orig_text = " ".join(orig_tokens)

                if not skip_get_final_text:
                    final_text = get_final_text(tok_text, orig_text, do_lower_case)
                else:
                    final_text = tok_text

                if final_text in seen_predictions:
                    continue

                seen_predictions[final_text] = True
            else:
                final_text = ""
                seen_predictions[final_text] = True

            nbest.append(
                _NbestPrediction(
                    text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit
                )
            )
        # if we didn't include the empty option in the n-best, include it
        if version_2_with_negative:
            if "" not in seen_predictions:
                nbest.append(
                    _NbestPrediction(
                        text="", start_logit=null_start_logit, end_logit=null_end_logit
                    )
                )

            # In very rare edge cases we could only have single null prediction.
            # So we just create a nonce prediction in this case to avoid failure.
            if len(nbest) == 1:
                nbest.insert(0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))

        # In very rare edge cases we could have no valid predictions. So we
        # just create a nonce prediction in this case to avoid failure.
        if not nbest:
            nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))

        assert len(nbest) >= 1

        total_scores = []
        best_non_null_entry = None
        for entry in nbest:
            total_scores.append(entry.start_logit + entry.end_logit)
            if not best_non_null_entry:
                if entry.text:
                    best_non_null_entry = entry

        probs = _compute_softmax(total_scores)

        nbest_json = []
        for (i, entry) in enumerate(nbest):
            output = collections.OrderedDict()
            output["text"] = entry.text
            output["probability"] = probs[i]
            output["start_logit"] = entry.start_logit
            output["end_logit"] = entry.end_logit
            nbest_json.append(output)

        assert len(nbest_json) >= 1

        if not version_2_with_negative:
            all_predictions[example.qas_id] = nbest_json[0]["text"]
        else:
            # predict "" iff the null score - the score of best non-null > threshold
            score_diff = (
                score_null - best_non_null_entry.start_logit - best_non_null_entry.end_logit
            )
            scores_diff_json[example.qas_id] = score_diff
            if score_diff > null_score_diff_threshold:
                all_predictions[example.qas_id] = ""
            else:
                all_predictions[example.qas_id] = best_non_null_entry.text
        all_nbest_json[example.qas_id] = nbest_json

    return all_predictions

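# Usage sketch (assumption): the function above returns an OrderedDict mapping
# qas_id -> predicted answer text; writing it out as JSON is left to the caller.
# The output path is arbitrary.
def write_squad_predictions(all_predictions, output_path="predictions.json"):
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(all_predictions, f, indent=2)
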
def run_val(
    val_dataloader,
    val_labels,
    jiant_model: JiantModel,
    task,
    device,
    local_rank,
    tf_writer: SummaryWriter,
    global_step: Optional[int] = None,
    phase=None,
    return_preds=False,
    return_logits=True,
    return_encoder_output: bool = False,
    verbose=True,
    split='valid',
):
    # Reminder:
    #   val_dataloader contains mostly PyTorch-relevant info
    #   val_labels might contain more detailed information needed for full evaluation
    has_labels = True  # TODO: Automatically detect whether the dataset has labels.
    if not local_rank == -1:
        return
    jiant_model.eval()
    evaluation_scheme = evaluate.get_evaluation_scheme_for_task(task=task)
    eval_accumulator = evaluation_scheme.get_accumulator()
    total_eval_loss = 0
    nb_eval_steps, nb_eval_examples = 0, 0
    encoder_outputs = []
    for step, (batch, batch_metadata) in enumerate(
            maybe_tqdm(val_dataloader, desc=f"Eval ({task.name}, {str(phase)})", verbose=verbose)):
        regular_log(logger, step, interval=10, tag=split)
        batch = batch.to(device)
        with torch.no_grad():
            model_outputs = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=batch,
                task=task,
                compute_loss=has_labels,
                get_encoder_output=return_encoder_output,
            )
        if return_encoder_output:
            model_output, encoder_output = model_outputs
            encoder_outputs.append(encoder_output)
        else:
            model_output = model_outputs
        batch_logits = model_output.logits.detach().cpu().numpy()
        if has_labels:
            batch_loss = model_output.loss.mean().item()
        else:
            batch_loss = 0
        total_eval_loss += batch_loss
        eval_accumulator.update(
            batch_logits=batch_logits,
            batch_loss=batch_loss,
            batch=batch,
            batch_metadata=batch_metadata,
        )
        nb_eval_examples += len(batch)
        nb_eval_steps += 1
    eval_loss = total_eval_loss / nb_eval_steps

    output = {
        "accumulator": eval_accumulator,
    }

    if has_labels:
        tokenizer = (jiant_model.tokenizer
                     if not torch_utils.is_data_parallel(jiant_model)
                     else jiant_model.module.tokenizer)
        metrics = evaluation_scheme.compute_metrics_from_accumulator(
            task=task,
            accumulator=eval_accumulator,
            labels=val_labels,
            tokenizer=tokenizer,
        )
        output.update({
            "loss": eval_loss,
            "metrics": metrics,
        })
        if global_step is not None:
            for metric_name, metric_value in metrics.minor.items():
                tf_writer.add_scalar(f'{split}/{metric_name}', metric_value,
                                     global_step=global_step)

    if return_preds:
        output["preds"] = evaluation_scheme.get_preds_from_accumulator(
            task=task,
            accumulator=eval_accumulator,
        )
    if isinstance(eval_accumulator, evaluate.ConcatenateLogitsAccumulator) and return_logits:
        output["logits"] = eval_accumulator.get_accumulated()
    if return_encoder_output:
        output["encoder_outputs_pooled"] = np.concatenate(
            [encoder_output.pooled for encoder_output in encoder_outputs])
        output["encoder_outputs_unpooled"] = np.concatenate(
            [encoder_output.unpooled for encoder_output in encoder_outputs])

    if global_step is not None:
        tf_writer.add_scalar(f'{split}/loss', eval_loss, global_step=global_step)
        tf_writer.flush()
    return output

def run_test(
    test_dataloader,
    jiant_model: JiantModel,
    task,
    device,
    local_rank,
    verbose=True,
    return_preds=True,
    return_logits=True,
    return_encoder_output: bool = False,
):
    if not local_rank == -1:
        return
    jiant_model.eval()
    evaluation_scheme = evaluate.get_evaluation_scheme_for_task(task=task)
    eval_accumulator = evaluation_scheme.get_accumulator()
    encoder_outputs = []
    for step, (batch, batch_metadata) in enumerate(
            maybe_tqdm(test_dataloader, desc=f"Eval ({task.name}, Test)", verbose=verbose)):
        regular_log(logger, step, interval=10, tag='test')
        batch = batch.to(device)
        with torch.no_grad():
            model_outputs = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=batch,
                task=task,
                compute_loss=False,
                get_encoder_output=return_encoder_output,
            )
        if return_encoder_output:
            model_output, encoder_output = model_outputs
            encoder_outputs.append(encoder_output)
        else:
            model_output = model_outputs
        batch_logits = model_output.logits.detach().cpu().numpy()
        eval_accumulator.update(
            batch_logits=batch_logits,
            batch_loss=0,
            batch=batch,
            batch_metadata=batch_metadata,
        )

    output = {
        "accumulator": eval_accumulator,
    }
    if return_preds:
        output["preds"] = evaluation_scheme.get_preds_from_accumulator(
            task=task,
            accumulator=eval_accumulator,
        )
    if isinstance(eval_accumulator, evaluate.ConcatenateLogitsAccumulator) and return_logits:
        output["logits"] = eval_accumulator.get_accumulated()
    if return_encoder_output:
        output["encoder_outputs_pooled"] = np.concatenate(
            [encoder_output.pooled for encoder_output in encoder_outputs])
        output["encoder_outputs_unpooled"] = np.concatenate(
            [encoder_output.unpooled for encoder_output in encoder_outputs])
    return output

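# Usage sketch (assumption): collect pooled encoder representations for a test set using
# the run_test variant above with return_encoder_output=True. The np.save calls are an
# arbitrary serialization choice; the dict keys are the ones populated by run_test.
def export_test_encoder_outputs(test_dataloader, jiant_model, task, device, prefix="test"):
    output = run_test(
        test_dataloader=test_dataloader,
        jiant_model=jiant_model,
        task=task,
        device=device,
        local_rank=-1,
        verbose=False,
        return_preds=False,
        return_logits=False,
        return_encoder_output=True,
    )
    np.save(f"{prefix}_encoder_pooled.npy", output["encoder_outputs_pooled"])
    np.save(f"{prefix}_encoder_unpooled.npy", output["encoder_outputs_unpooled"])
    return output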