def iter_chunk_tokenize_and_featurize(
    task, examples: list, tokenizer, feat_spec: FeaturizationSpec, phase, verbose=False
):
    """Lazily tokenize and featurize examples, yielding DataRows one at a time.

    SQuAD-style QA examples are expanded through ``to_feature_list`` (one example
    may yield several rows); every other task type yields exactly one row per
    example via ``tokenize(...).featurize(...)``.

    Args:
        task (Task): Task object.
        examples (list[Example]): list of task Examples.
        tokenizer: TODO (issue #1188)
        feat_spec (FeaturizationSpec): Tokenization-related metadata.
        phase (str): string identifying the data subset (e.g., train, val or test).
        verbose: If True, display progress bar.

    Yields:
        DataRow containing tokenized and featurized examples.

    """
    progress = maybe_tqdm(examples, desc="Tokenizing", verbose=verbose)
    for index, item in enumerate(progress):
        regular_log(logger, index)
        # TODO: Better solution (issue #1184)
        if task.TASK_TYPE == TaskTypes.SQUAD_STYLE_QA:
            yield from item.to_feature_list(
                tokenizer=tokenizer,
                max_seq_length=feat_spec.max_seq_length,
                doc_stride=task.doc_stride,
                max_query_length=task.max_query_length,
                set_type=phase,
            )
            continue
        tokenized = item.tokenize(tokenizer)
        yield tokenized.featurize(tokenizer, feat_spec)
def run_train_context(self, verbose=True):
    """Run the training loop as a generator, yielding the train state after every step.

    At ``weighted_sampling_start_step`` the dataloaders are rebuilt with weighted
    sampling enabled, and at ``weighted_loss_start_step`` the loss weights are
    fetched; both thresholds come from the global train config.

    Args:
        verbose: If True, display a progress bar.

    Yields:
        The TrainState object, once per training step.

    """
    dataloaders = self.get_train_dataloader_dict()
    loss_weights = {}
    train_state = TrainState.from_task_name_list(
        self.jiant_task_container.task_run_config.train_task_list
    )
    train_config = self.jiant_task_container.global_train_config
    step_losses = []
    steps = maybe_tqdm(range(train_config.max_steps), desc="Training", verbose=verbose)
    for step in steps:
        regular_log(logger, step, interval=10, tag='train')

        # Switch to weighted sampling / weighted loss once their start step is reached.
        if step == train_config.weighted_sampling_start_step:
            dataloaders = self.get_train_dataloader_dict(do_weighted_sampling=True)
        if step == train_config.weighted_loss_start_step:
            loss_weights = self.get_loss_weights_dict()

        step_losses.append(
            self.run_train_step(
                train_dataloader_dict=dataloaders,
                train_state=train_state,
                loss_weights_dict=loss_weights,
            )
        )

        # Every 100 steps: report the mean loss over all steps so far and flush the writer.
        if step % 100 == 0:
            logger.info('[train] loss: %f', np.mean(step_losses))
            self.tf_writer.flush()
        yield train_state
def generic_read_squad_examples(
    path: str, set_type: str, example_class: type = dict, read_title: bool = True
):
    """Read a SQuAD-format JSON file and build one example per question.

    Args:
        path: path to the JSON file; its top-level "data" list is read.
        set_type: data phase; compared against PHASE.TRAIN to decide how answers are read.
        example_class: callable invoked with keyword arguments to build each example.
        read_title: if True use each entry's "title", otherwise use "-".

    Returns:
        list of objects produced by ``example_class``, in file order.

    """
    with open(path, "r", encoding="utf-8") as reader:
        entries = json.load(reader)["data"]

    is_training = set_type == PHASE.TRAIN
    examples = []
    for step, entry in enumerate(maybe_tqdm(entries, desc="Reading SQuAD Entries")):
        regular_log(logger, step)
        title = entry["title"] if read_title else "-"
        for paragraph in entry["paragraphs"]:
            context_text = paragraph["context"]
            for qa in paragraph["qas"]:
                qas_id = qa["id"]
                question_text = qa["question"]
                is_impossible = qa.get("is_impossible", False)

                answer_text = None
                start_position_character = None
                answers = []
                if not is_impossible and is_training:
                    # Training uses only the first annotated answer.
                    first_answer = qa["answers"][0]
                    answer_text = first_answer["text"]
                    start_position_character = first_answer["answer_start"]
                elif not is_impossible:
                    # Outside training, keep every annotated answer.
                    answers = qa["answers"]

                examples.append(
                    example_class(
                        qas_id=qas_id,
                        question_text=question_text,
                        context_text=context_text,
                        answer_text=answer_text,
                        start_position_character=start_position_character,
                        title=title,
                        is_impossible=is_impossible,
                        answers=answers,
                    )
                )
    return examples
def load_combined(path):
    """Load the combined NewsQA CSV and normalize line endings in story texts.

    Args:
        path: path to the CSV file.

    Returns:
        pandas DataFrame; when a "story_text" column is present, every "\\r\\n"
        in it has been replaced by "\\n".

    """
    frame = pd.read_csv(
        path,
        encoding="utf-8",
        dtype={"is_answer_absent": float},
        # Empty per-column NA lists together with keep_default_na=False.
        na_values={"question": [], "story_text": [], "validated_answers": []},
        keep_default_na=False,
    )
    if "story_text" not in frame.keys():
        return frame

    rows = display.tqdm(frame.itertuples(), total=len(frame), desc="Adjusting story texts")
    for step, row_ in enumerate(rows):
        regular_log(logger, step)
        frame.at[row_.Index, "story_text"] = row_.story_text.replace("\r\n", "\n")
    return frame
def smart_truncate(dataset: torch_utils.ListDataset, max_seq_length: int, verbose: bool = False):
    """Shrink every example to the length of the longest valid sequence in the dataset.

    Args:
        dataset (torch_utils.ListDataset): ListDataset to truncate if possible.
        max_seq_length (int): The maximum total input sequence length.
        verbose (bool): If True, display progress bar tracking truncation progress.

    Returns:
        Tuple[torch_utils.ListDataset, int]: truncated dataset, and length of the longest sequence.
        The input dataset is returned unchanged when nothing can be trimmed.

    Raises:
        RuntimeError: if the first data row has no "input_mask" field.

    """
    if "input_mask" not in dataset.data[0]["data_row"].get_fields():
        raise RuntimeError("Smart truncate not supported")

    positions = np.arange(max_seq_length)

    def _valid_length(datum):
        # TODO: document why reshape and max happen here (for cola this isn't necessary).
        #       (issue #1185)
        mask = datum["data_row"].input_mask.reshape(-1, max_seq_length).max(-2)
        # Index of the last unmasked position, plus one.
        return positions[mask.astype(bool)].max() + 1

    max_valid_length = max([_valid_length(datum) for datum in dataset.data])
    if max_valid_length == max_seq_length:
        return dataset, max_seq_length

    truncated = []
    progress = maybe_tqdm(dataset.data, desc="Smart truncate data", verbose=verbose)
    for step, datum in enumerate(progress):
        regular_log(logger, step)
        truncated.append(
            smart_truncate_datum(
                datum=datum, max_seq_length=max_seq_length, max_valid_length=max_valid_length,
            )
        )
    return torch_utils.ListDataset(truncated), max_valid_length
def preprocess_all_glue_data(input_base_path, output_base_path, task_name_ls=None):
    """Convert GLUE tasks and write one JSON config per task.

    Args:
        input_base_path: passed through to ``convert_glue_data`` as the input location.
        output_base_path: output root; "data" and "configs" subdirectories are created in it.
        task_name_ls: task names to process; defaults to all keys of GLUE_CONVERSION.

    """
    if task_name_ls is None:
        task_name_ls = GLUE_CONVERSION.keys()

    data_dir = os.path.join(output_base_path, "data")
    config_dir = os.path.join(output_base_path, "configs")
    for directory in (output_base_path, data_dir, config_dir):
        os.makedirs(directory, exist_ok=True)

    for step, task_name in enumerate(tqdm.tqdm(task_name_ls)):
        regular_log(logger, step)
        paths_dict = convert_glue_data(
            input_base_path=input_base_path,
            task_data_path=os.path.join(data_dir, task_name),
            task_name=task_name,
        )
        py_io.write_json(
            data={"task": task_name, "paths": paths_dict, "name": task_name},
            path=os.path.join(config_dir, f"{task_name}.json"),
        )
def get_preds_for_single_tagging_task(
    task, test_dataloader, runner: jiant_runner.JiantRunner, verbose: bool = True
):
    """Generate predictions for a single tagging task.

    The model is put in eval mode and run without gradients over the test
    dataloader. For each example, the argmax label at every position selected by
    the batch's label mask is paired with the example's tokens.

    Args:
        task: tagging task; must provide ``name``, ``ID_TO_LABEL``, ``LABELS`` and
            ``get_test_examples()``.
        test_dataloader: iterable of ``(batch, batch_metadata)`` pairs.
            NOTE(review): assumed to iterate in the same order as
            ``task.get_test_examples()`` -- confirm against the caller.
        runner (jiant_runner.JiantRunner): supplies ``model`` and ``device``.
        verbose (bool): If True, display a progress bar.

    Returns:
        list with one entry per processed example; each entry is a list of
        ``(token, label)`` tuples. When fewer labels than tokens were predicted,
        the remaining tokens are paired with ``task.LABELS[-1]``.

    Raises:
        RuntimeError: if more labels than tokens were predicted for an example.

    """
    jiant_model, device = runner.model, runner.device
    jiant_model.eval()
    test_examples = task.get_test_examples()
    preds_list = []
    example_i = 0
    for step, (batch, batch_metadata) in enumerate(
        maybe_tqdm(test_dataloader, desc=f"Eval ({task.name}, Test)", verbose=verbose)
    ):
        regular_log(logger, step, interval=10)
        batch = batch.to(device)

        with torch.no_grad():
            model_output = wrap_jiant_forward(
                jiant_model=jiant_model, batch=batch, task=task, compute_loss=False,
            )
        batch_logits = model_output.logits.detach().cpu().numpy()
        label_mask_arr = batch.label_mask.cpu().bool().numpy()
        preds_arr = np.argmax(batch_logits, axis=-1)
        for i in range(len(batch)):
            tokens = test_examples[example_i].tokens
            # noinspection PyUnresolvedReferences
            labels = [task.ID_TO_LABEL[class_i] for class_i in preds_arr[i][label_mask_arr[i]]]
            if len(labels) > len(tokens):
                raise RuntimeError(
                    f"Predicted {len(labels)} labels for an example with only "
                    f"{len(tokens)} tokens"
                )
            # zip stops at the shorter sequence, so this covers both the equal-length
            # case and the fewer-labels-than-tokens case.
            this_preds = list(zip(tokens, labels))
            # Pad tokens without a prediction using the last label. The tuple order
            # must match the zipped pairs above: (token, label).
            this_preds += [(token, task.LABELS[-1]) for token in tokens[len(labels):]]
            preds_list.append(this_preds)
            example_i += 1
    return preds_list
def smart_truncate_cache(
    cache: shared_caching.ChunkedFilesDataCache,
    max_seq_length: int,
    max_valid_length: int,
    verbose: bool = False,
):
    """Truncate every datum in every chunk of a cache, overwriting each chunk in place.

    Args:
        cache (shared_caching.ChunkedFilesDataCache): cache whose chunk files are rewritten.
        max_seq_length (int): forwarded to ``smart_truncate_datum``.
        max_valid_length (int): forwarded to ``smart_truncate_datum``.
        verbose (bool): If True, display progress bars.

    """
    chunk_indices = maybe_trange(cache.num_chunks, desc="Smart truncate chunks", verbose=verbose)
    for chunk_i in chunk_indices:
        original_chunk = torch.load(cache.get_chunk_path(chunk_i))
        truncated_chunk = []
        progress = maybe_tqdm(original_chunk, desc="Smart truncate chunk-datum", verbose=verbose)
        for step, datum in enumerate(progress):
            regular_log(logger, step)
            truncated_datum = smart_truncate_datum(
                datum=datum, max_seq_length=max_seq_length, max_valid_length=max_valid_length,
            )
            truncated_chunk.append(truncated_datum)
        # Saved to the same path the chunk was loaded from.
        torch.save(truncated_chunk, cache.get_chunk_path(chunk_i))
def download_newsqa_data_and_write_config(
    task_name: str, task_data_path: str, task_config_path: str
):
    """Assemble the NewsQA dataset from local archives and write the task config.

    Requires ``cnn_stories.tgz`` and either the extracted
    ``newsqa-data-v1/newsqa-data-v1.csv`` or ``newsqa-data-v1.zip`` to already be
    present in ``task_data_path``. Auxiliary CSV files are downloaded, story texts
    are extracted from the CNN archive and patched to match the annotation
    indices, out-of-range answer spans are clipped or dropped, and the questions
    that have a consensus answer span are written to ``train.jsonl``,
    ``val.jsonl`` and ``test.jsonl``. The downloaded auxiliary files are removed
    at the end.

    Args:
        task_name: name stored in the written config.
        task_data_path: directory holding the inputs and receiving the outputs.
        task_config_path: path of the JSON config to write.

    Raises:
        AssertionError: if a required archive is missing.
        ValueError: if a story ID is in none of the train/dev/test ID lists.

    """

    def get_consensus_answer(row_):
        """Return (start, end) of the consensus answer span, or (None, None)."""
        answer_char_start, answer_char_end = None, None
        if row_.validated_answers:
            # Read from the argument itself rather than an enclosing loop variable.
            validated_answers_ = json.loads(row_.validated_answers)
            answer_, max_count = max(validated_answers_.items(), key=itemgetter(1))
            total_count = sum(validated_answers_.values())
            if max_count >= total_count / 2.0:
                if answer_ != "none" and answer_ != "bad_question":
                    answer_char_start, answer_char_end = map(int, answer_.split(":"))
                else:
                    # No valid answer.
                    pass
        else:
            # Check row_.answer_char_ranges for most common answer.
            # No validation was done so there must be an answer with consensus.
            answers = Counter()
            for user_answer in row_.answer_char_ranges.split("|"):
                for ans in user_answer.split(","):
                    answers[ans] += 1
            top_answer = answers.most_common(1)
            if top_answer:
                top_answer, _ = top_answer[0]
                if ":" in top_answer:
                    answer_char_start, answer_char_end = map(int, top_answer.split(":"))

        return answer_char_start, answer_char_end

    def load_combined(path):
        """Load the combined CSV and replace "\\r\\n" with "\\n" in story texts."""
        result = pd.read_csv(
            path,
            encoding="utf-8",
            dtype=dict(is_answer_absent=float),
            na_values=dict(question=[], story_text=[], validated_answers=[]),
            keep_default_na=False,
        )

        if "story_text" in result.keys():
            for step, row_ in enumerate(
                display.tqdm(
                    result.itertuples(), total=len(result), desc="Adjusting story texts"
                )
            ):
                regular_log(logger, step)
                story_text_ = row_.story_text.replace("\r\n", "\n")
                result.at[row_.Index, "story_text"] = story_text_

        return result

    def _map_answers(answers):
        """Parse a "|"/","-separated range string into per-annotator answer dicts."""
        result = []
        for a in answers.split("|"):
            user_answers = []
            result.append(dict(sourcerAnswers=user_answers))
            for r in a.split(","):
                if r == "None":
                    user_answers.append(dict(noAnswer=True))
                else:
                    start_, end_ = map(int, r.split(":"))
                    user_answers.append(dict(s=start_, e=end_))
        return result

    def strip_empty_strings(strings):
        """Remove trailing empty strings from the list in place and return it."""
        while strings and strings[-1] == "":
            del strings[-1]
        return strings

    def _read_line_set(file_name):
        """Return the set of newline-separated entries of an auxiliary file."""
        with open(os.path.join(task_data_path, file_name), "r", encoding="utf-8") as f:
            return set(f.read().split("\n"))

    # Require: cnn_stories.tgz
    cnn_stories_path = os.path.join(task_data_path, "cnn_stories.tgz")
    # Explicit raise (not `assert`) so the check is not stripped under `python -O`.
    if not os.path.exists(cnn_stories_path):
        raise AssertionError(
            "Download CNN Stories from https://cs.nyu.edu/~kcho/DMQA/ and save to "
            + cnn_stories_path
        )
    # Require: newsqa-data-v1/newsqa-data-v1.csv
    dataset_path = os.path.join(task_data_path, "newsqa-data-v1", "newsqa-data-v1.csv")
    if os.path.exists(dataset_path):
        pass
    elif os.path.exists(os.path.join(task_data_path, "newsqa-data-v1.zip")):
        download_utils.unzip_file(
            zip_path=os.path.join(task_data_path, "newsqa-data-v1.zip"),
            extract_location=task_data_path,
            delete=False,
        )
    else:
        raise AssertionError(
            "Download https://www.microsoft.com/en-us/research/project/newsqa-dataset/#!download"
            " and save to " + os.path.join(task_data_path, "newsqa-data-v1.zip")
        )

    # Download auxiliary data
    os.makedirs(task_data_path, exist_ok=True)
    file_name_list = [
        "train_story_ids.csv",
        "dev_story_ids.csv",
        "test_story_ids.csv",
        "stories_requiring_extra_newline.csv",
        "stories_requiring_two_extra_newlines.csv",
        "stories_to_decode_specially.csv",
    ]
    for file_name in file_name_list:
        download_utils.download_file(
            f"https://raw.githubusercontent.com/Maluuba/newsqa/master/maluuba/newsqa/{file_name}",
            os.path.join(task_data_path, file_name),
        )

    dataset = load_combined(dataset_path)
    remaining_story_ids = set(dataset["story_id"])
    stories_requiring_extra_newline = _read_line_set("stories_requiring_extra_newline.csv")
    stories_requiring_two_extra_newlines = _read_line_set(
        "stories_requiring_two_extra_newlines.csv"
    )
    stories_to_decode_specially = _read_line_set("stories_to_decode_specially.csv")

    # Start combining data files
    story_id_to_text = {}
    with tarfile.open(cnn_stories_path, mode="r:gz", encoding="utf-8") as t:
        highlight_indicator = "@highlight"

        copyright_line_pattern = re.compile(
            "^(Copyright|Entire contents of this article copyright, )"
        )
        with display.tqdm(total=len(remaining_story_ids), desc="Getting story texts") as pbar:
            for step, member in enumerate(t.getmembers()):
                regular_log(logger, step)
                story_id = member.name
                if story_id in remaining_story_ids:
                    remaining_story_ids.remove(story_id)
                    story_file = t.extractfile(member)

                    # Correct discrepancies in stories.
                    # Problems are caused by using several programming languages and libraries.
                    # When ingesting the stories, we started with Python 2.
                    # After dealing with unicode issues, we tried switching to Python 3.
                    # That caused inconsistency problems so we switched back to Python 2.
                    # Furthermore, when crowdsourcing, JavaScript and HTML templating perturbed
                    # the stories.
                    # So here we map the text to be compatible with the indices.
                    lines = [s_.strip().decode("utf-8") for s_ in story_file.readlines()]
                    story_file.close()

                    # Keep only the lines before the first highlight marker.
                    highlights_start = lines.index(highlight_indicator)
                    story_lines = lines[:highlights_start]
                    story_lines = strip_empty_strings(story_lines)
                    while len(story_lines) > 1 and copyright_line_pattern.search(
                        story_lines[-1]
                    ):
                        story_lines = strip_empty_strings(story_lines[:-2])
                    if story_id in stories_requiring_two_extra_newlines:
                        story_text = "\n\n\n".join(story_lines)
                    elif story_id in stories_requiring_extra_newline:
                        story_text = "\n\n".join(story_lines)
                    else:
                        story_text = "\n".join(story_lines)

                    story_text = story_text.replace("\xe2\x80\xa2", "\xe2\u20ac\xa2")
                    story_text = story_text.replace("\xe2\x82\xac", "\xe2\u201a\xac")
                    story_text = story_text.replace("\r", "\n")
                    if story_id in stories_to_decode_specially:
                        story_text = story_text.replace("\xe9", "\xc3\xa9")
                    story_id_to_text[story_id] = story_text

                    pbar.update()

                    # Stop scanning the archive once every needed story was found.
                    if len(remaining_story_ids) == 0:
                        break

    for step, row in enumerate(
        display.tqdm(dataset.itertuples(), total=len(dataset), desc="Setting story texts")
    ):
        regular_log(logger, step)
        # Set story_text since we cannot include it in the dataset.
        story_text = story_id_to_text[row.story_id]
        dataset.at[row.Index, "story_text"] = story_text

        # Handle endings that are too large.
        answer_char_ranges = row.answer_char_ranges.split("|")
        updated_answer_char_ranges = []
        ranges_updated = False
        for user_answer_char_ranges in answer_char_ranges:
            updated_user_answer_char_ranges = []
            for char_range in user_answer_char_ranges.split(","):
                if char_range != "None":
                    start, end = map(int, char_range.split(":"))
                    if end > len(story_text):
                        ranges_updated = True
                        end = len(story_text)
                    if start < end:
                        updated_user_answer_char_ranges.append("%d:%d" % (start, end))
                    else:
                        # It's unclear why but sometimes the end is after the start.
                        # We'll filter these out.
                        ranges_updated = True
                else:
                    updated_user_answer_char_ranges.append(char_range)
            if updated_user_answer_char_ranges:
                updated_user_answer_char_ranges = ",".join(updated_user_answer_char_ranges)
                updated_answer_char_ranges.append(updated_user_answer_char_ranges)
        if ranges_updated:
            updated_answer_char_ranges = "|".join(updated_answer_char_ranges)
            dataset.at[row.Index, "answer_char_ranges"] = updated_answer_char_ranges

        if row.validated_answers and not pd.isnull(row.validated_answers):
            updated_validated_answers = {}
            validated_answers = json.loads(row.validated_answers)
            for char_range, count in validated_answers.items():
                if ":" in char_range:
                    start, end = map(int, char_range.split(":"))
                    if end > len(story_text):
                        ranges_updated = True
                        end = len(story_text)
                    if start < end:
                        char_range = "{}:{}".format(start, end)
                        updated_validated_answers[char_range] = count
                    else:
                        # It's unclear why but sometimes the end is after the start.
                        # We'll filter these out.
                        ranges_updated = True
                else:
                    updated_validated_answers[char_range] = count
            if ranges_updated:
                updated_validated_answers = json.dumps(
                    updated_validated_answers, ensure_ascii=False, separators=(",", ":")
                )
                dataset.at[row.Index, "validated_answers"] = updated_validated_answers

    # Process Splits
    data = []
    cache = dict()

    train_story_ids = set(
        pd.read_csv(os.path.join(task_data_path, "train_story_ids.csv"))["story_id"].values
    )
    dev_story_ids = set(
        pd.read_csv(os.path.join(task_data_path, "dev_story_ids.csv"))["story_id"].values
    )
    test_story_ids = set(
        pd.read_csv(os.path.join(task_data_path, "test_story_ids.csv"))["story_id"].values
    )

    def _get_data_type(story_id_):
        """Return "train", "dev" or "test" for a story ID; raise if it is in no split."""
        if story_id_ in train_story_ids:
            return "train"
        elif story_id_ in dev_story_ids:
            return "dev"
        elif story_id_ in test_story_ids:
            return "test"
        else:
            # Raise (not return) the error, and report the ID that was actually looked up.
            raise ValueError("{} not found in any story ID set.".format(story_id_))

    for step, row in enumerate(
        display.tqdm(dataset.itertuples(), total=len(dataset), desc="Building json")
    ):
        regular_log(logger, step)
        # One datum per story; `cache` maps story ID to that datum's question list.
        questions = cache.get(row.story_id)
        if questions is None:
            questions = []
            datum = dict(
                storyId=row.story_id,
                type=_get_data_type(row.story_id),
                text=row.story_text,
                questions=questions,
            )
            cache[row.story_id] = questions
            data.append(datum)
        q = dict(
            q=row.question,
            answers=_map_answers(row.answer_char_ranges),
            isAnswerAbsent=row.is_answer_absent,
        )
        if row.is_question_bad != "?":
            q["isQuestionBad"] = float(row.is_question_bad)
        if row.validated_answers and not pd.isnull(row.validated_answers):
            validated_answers = json.loads(row.validated_answers)
            q["validatedAnswers"] = []
            for answer, count in validated_answers.items():
                answer_item = dict(count=count)
                if answer == "none":
                    answer_item["noAnswer"] = True
                elif answer == "bad_question":
                    answer_item["badQuestion"] = True
                else:
                    s, e = map(int, answer.split(":"))
                    answer_item["s"] = s
                    answer_item["e"] = e
                q["validatedAnswers"].append(answer_item)
        consensus_start, consensus_end = get_consensus_answer(row)
        if consensus_start is None and consensus_end is None:
            if q.get("isQuestionBad", 0) >= 0.5:
                q["consensus"] = dict(badQuestion=True)
            else:
                q["consensus"] = dict(noAnswer=True)
        else:
            q["consensus"] = dict(s=consensus_start, e=consensus_end)
        questions.append(q)

    phase_dict = {
        "train": [],
        "val": [],
        "test": [],
    }
    phase_map = {"train": "train", "dev": "val", "test": "test"}
    for entry in data:
        phase = phase_map[entry["type"]]
        output_entry = {"text": entry["text"], "storyId": entry["storyId"], "qas": []}
        for qn in entry["questions"]:
            # Only questions with a consensus answer span are exported.
            if "badQuestion" in qn["consensus"] or "noAnswer" in qn["consensus"]:
                continue
            output_entry["qas"].append({"question": qn["q"], "answer": qn["consensus"]})
        phase_dict[phase].append(output_entry)
    for phase, phase_data in phase_dict.items():
        py_io.write_jsonl(
            phase_data, os.path.join(task_data_path, f"{phase}.jsonl"), skip_if_exists=True
        )
    py_io.write_json(
        data={
            "task": task_name,
            "paths": {
                "train": os.path.join(task_data_path, "train.jsonl"),
                "val": os.path.join(task_data_path, "val.jsonl"),
                # NOTE(review): "test" points at val.jsonl although test.jsonl is
                # written above -- confirm whether this is intentional.
                "test": os.path.join(task_data_path, "val.jsonl"),
            },
            "name": task_name,
        },
        path=task_config_path,
        skip_if_exists=True,
    )
    # Clean up the auxiliary files downloaded above.
    for file_name in file_name_list:
        os.remove(os.path.join(task_data_path, file_name))
def compute_predictions_logits_v2(
    partial_examples: List[PartialExample],
    all_results,
    n_best_size,
    max_answer_length,
    do_lower_case,
    version_2_with_negative,
    null_score_diff_threshold,
    tokenizer,
    skip_get_final_text=False,
    verbose=True,
):
    """Pick the predicted answer text for every example from start/end logits.

    For each example, candidate spans from all of its features are filtered,
    ranked by ``start_logit + end_logit`` and de-duplicated by final text into an
    n-best list. Without ``version_2_with_negative`` the top n-best text is the
    prediction; with it, the empty string is predicted when the null score minus
    the best non-null score exceeds ``null_score_diff_threshold``.

    Args:
        partial_examples: examples providing ``partial_features``, ``doc_tokens``, ``qas_id``.
        all_results: results providing ``unique_id``, ``start_logits``, ``end_logits``.
        n_best_size: number of top start/end indexes considered and n-best entries kept.
        max_answer_length: maximum span length, in tokens, of a candidate answer.
        do_lower_case: forwarded to ``get_final_text``.
        version_2_with_negative: if True, a null (no-answer) prediction is considered.
        null_score_diff_threshold: threshold for predicting the null answer.
        tokenizer: used to convert span tokens back to a string.
        skip_get_final_text: if True, use the detokenized text directly as the final text.
        verbose: If True, display a progress bar.

    Returns:
        OrderedDict mapping ``qas_id`` to the predicted answer text. The n-best
        lists and score differences are computed but not returned.

    """
    unique_id_to_result = {res.unique_id: res for res in all_results}

    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "PrelimPrediction",
        ["feature_index", "start_index", "end_index", "start_logit", "end_logit"],
    )
    _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "NbestPrediction", ["text", "start_logit", "end_logit"]
    )

    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
    scores_diff_json = collections.OrderedDict()

    for step, example in enumerate(maybe_tqdm(partial_examples, verbose=verbose)):
        regular_log(logger, step, interval=10)
        features = example.partial_features

        candidates = []
        # keep track of the minimum score of null start+end of position 0
        score_null = 1000000  # large and positive
        min_null_feature_index = 0  # the paragraph slice with min null score
        null_start_logit = 0  # the start logit at the slice with min null score
        null_end_logit = 0  # the end logit at the slice with min null score
        for feature_index, feature in enumerate(features):
            result = unique_id_to_result[feature.unique_id]
            start_indexes = _get_best_indexes(result.start_logits, n_best_size)
            end_indexes = _get_best_indexes(result.end_logits, n_best_size)
            # if we could have irrelevant answers, get the min score of irrelevant
            if version_2_with_negative:
                feature_null_score = result.start_logits[0] + result.end_logits[0]
                if feature_null_score < score_null:
                    score_null = feature_null_score
                    min_null_feature_index = feature_index
                    null_start_logit = result.start_logits[0]
                    null_end_logit = result.end_logits[0]
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # We could hypothetically create invalid predictions, e.g., predict
                    # that the start of the span is in the question. We throw out all
                    # invalid predictions.
                    if start_index >= len(feature.tokens) or end_index >= len(feature.tokens):
                        continue
                    if (
                        start_index not in feature.token_to_orig_map
                        or end_index not in feature.token_to_orig_map
                    ):
                        continue
                    if not feature.token_is_max_context.get(start_index, False):
                        continue
                    if end_index < start_index:
                        continue
                    if end_index - start_index + 1 > max_answer_length:
                        continue
                    candidates.append(
                        _PrelimPrediction(
                            feature_index=feature_index,
                            start_index=start_index,
                            end_index=end_index,
                            start_logit=result.start_logits[start_index],
                            end_logit=result.end_logits[end_index],
                        )
                    )
        if version_2_with_negative:
            candidates.append(
                _PrelimPrediction(
                    feature_index=min_null_feature_index,
                    start_index=0,
                    end_index=0,
                    start_logit=null_start_logit,
                    end_logit=null_end_logit,
                )
            )
        # Best combined logit first; list.sort is stable like sorted().
        candidates.sort(key=lambda x: (x.start_logit + x.end_logit), reverse=True)

        seen_texts = set()
        nbest = []
        for pred in candidates:
            if len(nbest) >= n_best_size:
                break
            feature = features[pred.feature_index]
            if pred.start_index > 0:  # this is a non-null prediction
                tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
                orig_doc_start = feature.token_to_orig_map[pred.start_index]
                orig_doc_end = feature.token_to_orig_map[pred.end_index]
                orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]

                # Detokenize, then collapse all runs of whitespace to single spaces.
                tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
                tok_text = " ".join(tok_text.strip().split())
                orig_text = " ".join(orig_tokens)

                if skip_get_final_text:
                    final_text = tok_text
                else:
                    final_text = get_final_text(tok_text, orig_text, do_lower_case)
                if final_text in seen_texts:
                    continue
            else:
                final_text = ""
            seen_texts.add(final_text)

            nbest.append(
                _NbestPrediction(
                    text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit
                )
            )
        if version_2_with_negative:
            # if we didn't include the empty option in the n-best, include it
            if "" not in seen_texts:
                nbest.append(
                    _NbestPrediction(
                        text="", start_logit=null_start_logit, end_logit=null_end_logit
                    )
                )

            # In very rare edge cases we could only have single null prediction.
            # So we just create a nonce prediction in this case to avoid failure.
            if len(nbest) == 1:
                nbest.insert(0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))

        # In very rare edge cases we could have no valid predictions. So we
        # just create a nonce prediction in this case to avoid failure.
        if not nbest:
            nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))

        assert len(nbest) >= 1

        total_scores = [entry.start_logit + entry.end_logit for entry in nbest]
        # First n-best entry with non-empty text, if any.
        best_non_null_entry = next((entry for entry in nbest if entry.text), None)

        probs = _compute_softmax(total_scores)

        nbest_json = [
            collections.OrderedDict(
                [
                    ("text", entry.text),
                    ("probability", probs[i]),
                    ("start_logit", entry.start_logit),
                    ("end_logit", entry.end_logit),
                ]
            )
            for i, entry in enumerate(nbest)
        ]

        assert len(nbest_json) >= 1

        if version_2_with_negative:
            # predict "" iff the null score - the score of best non-null > threshold
            score_diff = (
                score_null - best_non_null_entry.start_logit - best_non_null_entry.end_logit
            )
            scores_diff_json[example.qas_id] = score_diff
            if score_diff > null_score_diff_threshold:
                all_predictions[example.qas_id] = ""
            else:
                all_predictions[example.qas_id] = best_non_null_entry.text
        else:
            all_predictions[example.qas_id] = nbest_json[0]["text"]
        all_nbest_json[example.qas_id] = nbest_json

    return all_predictions
def run_test(
    test_dataloader,
    jiant_model: JiantModel,
    task,
    device,
    local_rank,
    verbose=True,
    return_preds=True,
    return_logits=True,
    return_encoder_output: bool = False,
):
    """Run the model over a test dataloader without computing a loss.

    Args:
        test_dataloader: iterable of ``(batch, batch_metadata)`` pairs.
        jiant_model (JiantModel): model to evaluate; switched to eval mode.
        task: task being evaluated.
        device: device each batch is moved to.
        local_rank: evaluation only runs when this equals -1.
        verbose: If True, display a progress bar.
        return_preds: if True, include "preds" in the output.
        return_logits: if True and the accumulator is a ConcatenateLogitsAccumulator,
            include "logits" in the output.
        return_encoder_output: if True, also collect encoder outputs and include their
            concatenated pooled/unpooled values in the output.

    Returns:
        dict with "accumulator" and, depending on the flags, "preds", "logits",
        "encoder_outputs_pooled" and "encoder_outputs_unpooled"; None when
        ``local_rank`` is not -1.

    """
    if local_rank != -1:
        return
    jiant_model.eval()
    evaluation_scheme = evaluate.get_evaluation_scheme_for_task(task=task)
    accumulator = evaluation_scheme.get_accumulator()

    encoder_outputs = []
    batches = maybe_tqdm(test_dataloader, desc=f"Eval ({task.name}, Test)", verbose=verbose)
    for step, (batch, batch_metadata) in enumerate(batches):
        regular_log(logger, step, interval=10, tag='test')
        batch = batch.to(device)

        with torch.no_grad():
            forward_result = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=batch,
                task=task,
                compute_loss=False,
                get_encoder_output=return_encoder_output,
            )
            # With encoder output requested, the forward result is a pair.
            if return_encoder_output:
                model_output, encoder_output = forward_result
                encoder_outputs.append(encoder_output)
            else:
                model_output = forward_result
        accumulator.update(
            batch_logits=model_output.logits.detach().cpu().numpy(),
            batch_loss=0,
            batch=batch,
            batch_metadata=batch_metadata,
        )

    output = {"accumulator": accumulator}
    if return_preds:
        output["preds"] = evaluation_scheme.get_preds_from_accumulator(
            task=task, accumulator=accumulator,
        )
    if isinstance(accumulator, evaluate.ConcatenateLogitsAccumulator) and return_logits:
        output["logits"] = accumulator.get_accumulated()
    if return_encoder_output:
        output["encoder_outputs_pooled"] = np.concatenate(
            [enc.pooled for enc in encoder_outputs]
        )
        output["encoder_outputs_unpooled"] = np.concatenate(
            [enc.unpooled for enc in encoder_outputs]
        )
    return output
def run_val(
    val_dataloader,
    val_labels,
    jiant_model: JiantModel,
    task,
    device,
    local_rank,
    tf_writer: SummaryWriter,
    global_step: Optional[int] = None,
    phase=None,
    return_preds=False,
    return_logits=True,
    return_encoder_output: bool = False,
    verbose=True,
    split='valid',
):
    """Evaluate the model on a validation dataloader and compute loss and metrics.

    Args:
        val_dataloader: iterable of ``(batch, batch_metadata)`` pairs.
        val_labels: labels passed to the evaluation scheme when computing metrics.
        jiant_model (JiantModel): model to evaluate; switched to eval mode.
        task: task being evaluated.
        device: device each batch is moved to.
        local_rank: evaluation only runs when this equals -1.
        tf_writer (SummaryWriter): receives the loss and minor metrics when
            ``global_step`` is given.
        global_step: step under which scalars are logged; nothing is logged when None.
        phase: only used in the progress-bar description.
        return_preds: if True, include "preds" in the output.
        return_logits: if True and the accumulator is a ConcatenateLogitsAccumulator,
            include "logits" in the output.
        return_encoder_output: if True, also collect encoder outputs and include their
            concatenated pooled/unpooled values in the output.
        verbose: If True, display a progress bar.
        split: prefix of the logged scalar names and tag of the periodic log.

    Returns:
        dict with "accumulator", "loss", "metrics" and, depending on the flags,
        "preds", "logits", "encoder_outputs_pooled" and "encoder_outputs_unpooled";
        None when ``local_rank`` is not -1.

    """
    # Reminder:
    #   val_dataloader contains mostly PyTorch-relevant info
    #   val_labels might contain more details information needed for full evaluation
    has_labels = True  # TODO: automatically determine whether the dataset has labels.
    if local_rank != -1:
        return
    jiant_model.eval()
    evaluation_scheme = evaluate.get_evaluation_scheme_for_task(task=task)
    accumulator = evaluation_scheme.get_accumulator()

    total_eval_loss = 0
    nb_eval_steps, nb_eval_examples = 0, 0
    encoder_outputs = []
    batches = maybe_tqdm(
        val_dataloader, desc=f"Eval ({task.name}, {str(phase)})", verbose=verbose
    )
    for step, (batch, batch_metadata) in enumerate(batches):
        regular_log(logger, step, interval=10, tag=split)
        batch = batch.to(device)

        with torch.no_grad():
            forward_result = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=batch,
                task=task,
                compute_loss=has_labels,
                get_encoder_output=return_encoder_output,
            )
            # With encoder output requested, the forward result is a pair.
            if return_encoder_output:
                model_output, encoder_output = forward_result
                encoder_outputs.append(encoder_output)
            else:
                model_output = forward_result
        batch_logits = model_output.logits.detach().cpu().numpy()
        batch_loss = model_output.loss.mean().item() if has_labels else 0
        total_eval_loss += batch_loss
        accumulator.update(
            batch_logits=batch_logits,
            batch_loss=batch_loss,
            batch=batch,
            batch_metadata=batch_metadata,
        )

        nb_eval_examples += len(batch)
        nb_eval_steps += 1

    # Mean of the per-batch losses.
    eval_loss = total_eval_loss / nb_eval_steps
    output = {"accumulator": accumulator}

    if has_labels:
        # Under data-parallel wrapping the tokenizer is read from the wrapped module.
        if torch_utils.is_data_parallel(jiant_model):
            tokenizer = jiant_model.module.tokenizer
        else:
            tokenizer = jiant_model.tokenizer
        metrics = evaluation_scheme.compute_metrics_from_accumulator(
            task=task, accumulator=accumulator, labels=val_labels, tokenizer=tokenizer,
        )
        output.update({"loss": eval_loss, "metrics": metrics})

        if global_step is not None:
            for metric_name, metric_value in metrics.minor.items():
                tf_writer.add_scalar(
                    f'{split}/{metric_name}', metric_value, global_step=global_step
                )

    if return_preds:
        output["preds"] = evaluation_scheme.get_preds_from_accumulator(
            task=task, accumulator=accumulator,
        )
    if isinstance(accumulator, evaluate.ConcatenateLogitsAccumulator) and return_logits:
        output["logits"] = accumulator.get_accumulated()
    if return_encoder_output:
        output["encoder_outputs_pooled"] = np.concatenate(
            [enc.pooled for enc in encoder_outputs]
        )
        output["encoder_outputs_unpooled"] = np.concatenate(
            [enc.unpooled for enc in encoder_outputs]
        )
    if global_step is not None:
        tf_writer.add_scalar(f'{split}/loss', eval_loss, global_step=global_step)
        tf_writer.flush()
    return output