def report(self): outfile_name = "{}/dyck_ending_acc.{}".format( get_results_dir_of_args(self.args), self.split_name) summary_str = f"{self.total_correct / self.total}" tqdm.write("Ending Accuracy ({}): {}".format(self.split_name, summary_str)) output_str = "{}\n".format(summary_str) with open(outfile_name, "w") as f: f.write(output_str)
def report(self):
    """Evaluates any remaining batches, then computes and writes the oracle BLEU score and cutoff lengths."""
    if self.references_batch:
        self.evaluate_batches()
        self.clear_batches()
    bleu = sacrebleu.corpus_bleu(self.predictions, [self.references])
    tqdm.write("Oracle bleu score: {}".format(bleu.score))
    outfile_name = os.path.join(get_results_dir_of_args(self.args),
                                "oracle_bleu.{}".format(self.split_name))
    with open(outfile_name, "w") as f:
        try:
            f.write("{}\n{}\n".format(bleu.score, str(bleu)))
        except ZeroDivisionError:
            f.write("{}\nNA\n".format(bleu.score))
    outfile_name = os.path.join(
        get_results_dir_of_args(self.args),
        "oracle_bleu_lengths.{}".format(self.split_name))
    with open(outfile_name, "w") as f:
        f.write("\n".join(self.cutoff_lengths))
def main(): argp = ArgumentParser() argp.add_argument("base_dir") argp.add_argument("--white_list") argp.add_argument("--black_list") argp.add_argument("--metric", choices=VALID_METRICS) cli_args = argp.parse_args() black_list = cli_args.black_list.split( ",") if cli_args.black_list else None white_list = cli_args.white_list.split( ",") if cli_args.white_list else None if cli_args.metric is None or cli_args.metric == "close_bracket": get_metric = get_mean metric_file = "close_bracket_acc.{}" elif cli_args.metric == "ppl": get_metric = get_ppl metric_file = "{}.perplexity" elif cli_args.metric == "probe_acc": get_metric = get_ppl metric_file = "dyck_ending_acc.{}" dev_medians = [] test_medians = [] print("Iterating through file system") for config in os.listdir(cli_args.base_dir): if config.startswith("."): continue if black_list and any([item in config for item in black_list]): continue if white_list and not any([item in config for item in white_list]): continue print(f"\t{config}") config_path = os.path.join(cli_args.base_dir, config) try: config_args = utils.load_config_from_path(config_path) results_dir = utils.get_results_dir_of_args(config_args) dev_results_path = os.path.join(results_dir, metric_file.format("dev")) dev_medians.append(get_metric(dev_results_path)) test_results_path = os.path.join(results_dir, metric_file.format("test")) test_medians.append(get_metric(test_results_path)) except FileNotFoundError: print(f"Not found: {config_path}") continue dev_median = np.median(dev_medians) test_median = np.median(test_medians) print(f"In domain median: {dev_median} (out of {len(dev_medians)})") print(f"Out of domain median: {test_median} (out of {len(test_medians)})")
def report_results_dict(args, results, split_name):
    """
    Aggregates statistics and writes them to disk.

    Arguments:
      results: string-keyed results dictionary from get_dyck_eval_dict
      split_name: string split name, one of train, dev, test.
    """
    # Report raw statistics
    output_dir = utils.get_results_dir_of_args(args)
    output_path = os.path.join(output_dir, 'dyck-k-eval.json')
    tqdm.write('Writing results to {}'.format(output_path))
    with open(output_path, 'w') as fout:
        json.dump(results, fout)

    # Report summary: closing-bracket accuracy per bracket distance
    result_column = []
    indices = []
    for i in range(10000):
        key_correct = 'diff{}-1'.format(i)
        key_incorrect = 'diff{}-0'.format(i)
        correct_count = results['correct_closing_bracket_constraint'].get(key_correct, 0)
        incorrect_count = results['correct_closing_bracket_constraint'].get(key_incorrect, 0)
        if correct_count + incorrect_count >= 1:
            result_column.append(correct_count / (correct_count + incorrect_count))
            indices.append(i)
    output_dir = utils.get_results_dir_of_args(args)
    output_path = os.path.join(output_dir, 'summary-{}.json'.format(split_name))
    tqdm.write('Writing results to {}'.format(output_path))
    with open(output_path, 'w') as fout:
        json.dump(list(zip(result_column, indices)), fout)
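# Worked example (hypothetical counts) of the summary computed above: if
# results['correct_closing_bracket_constraint'] contained {'diff3-1': 8, 'diff3-0': 2},
# the accuracy appended for bracket distance 3 would be 8 / (8 + 2) == 0.8, and
# summary-<split>.json would contain the pair [0.8, 3].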
def main(): argp = ArgumentParser() argp.add_argument("base_dir") cli_args = argp.parse_args() test_results = defaultdict(list) results = [] print("Iterating through file system") # for config in os.listdir(cli_args.base_dir): for dir_name, sub_dirs, file_list in os.walk(cli_args.base_dir): for config in file_list: config = os.path.join(dir_name, config) # import ipdb; ipdb.set_trace() for pattern in (seed_pattern_lstm, seed_pattern_transformer): s = re.findall(pattern, config) if s: break if not s: # no match continue ls, eos = s[0] ls = int(ls) print(f"\t{config}") # config_path = os.path.join(cli_args.base_dir, config) # import ipdb; ipdb.set_trace() config_args = utils.load_config_from_path(config) #_path) results_dir = utils.get_results_dir_of_args(config_args) dev_results_path = os.path.join( results_dir, "oracle_exact_match_acc.dev_sampled") dev_result = get_result(dev_results_path) if dev_result is None: print("\t> error with dev - check if file exists") continue test_results_path = os.path.join( results_dir, "oracle_exact_match_acc.test_sampled") test_result = get_result(test_results_path) if test_result is None: print("\t> error with test - check if file exists") continue results.append((config, dev_result, test_result)) test_results[(eos, ls)].append(test_result) out_str = "" for k in sorted(test_results.keys()): out_str += " & {:.2f}".format(np.median(test_results[k])) if k[1] == 40: # max length split out_str += "\\\\\n" print(out_str)
def report(self): """ Runs the detokenizer and then sacrebleu to calculate the score. """ predictions = self.detokenize(self.predictions) references = self.detokenize(self.references) bleu = sacrebleu.corpus_bleu(predictions, [references]) tqdm.write("Bleu score: {}".format(bleu.score)) outfile_name = os.path.join(get_results_dir_of_args(self.args), "bleu.{}".format(self.split_name)) with open(outfile_name, "w") as f: try: f.write("{}\n{}\n".format(bleu.score, str(bleu))) except ZeroDivisionError: f.write("{}\nNA\n".format(bleu.score))
def report(self): outfile_name = "{}/close_bracket_acc.{}".format( get_results_dir_of_args(self.args), self.split_name) output = [] all_results = [] for key in sorted(self.total): # if self.total[key] >= 5: all_results.append(self.total_correct[key] / self.total[key]) output.append(f"{key},{self.total_correct[key]/self.total[key]}") result_min = np.min(all_results) result_median = np.median(all_results) result_1pt = np.percentile(all_results, [.25]) summary_str = f"{result_median} {result_1pt} {result_min}" tqdm.write("Weighted Closing Bracket Accuracy ({}): {}".format( self.split_name, summary_str)) output_str = "{}\n{}\n".format("\n".join(output), summary_str) with open(outfile_name, "w") as f: f.write(output_str)
def main():
    """Entry point: loads a YAML config, sets random seeds, prepares output directories, and dispatches to LM or seq2seq training."""
    argp = ArgumentParser()
    argp.add_argument('config')
    argp.add_argument('--train', action="store_true")
    argp.add_argument('--train-seq2seq', action="store_true")
    argp.add_argument('--train-truncation', action="store_true")
    argp_args = argp.parse_args()

    args = yaml.safe_load(open(argp_args.config))
    args['train-seq2seq'] = argp_args.train_seq2seq
    args['train-truncation'] = argp_args.train_truncation
    args['train'] = (argp_args.train or argp_args.train_truncation
                     or argp_args.train_seq2seq)
    args['device'] = torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu")

    # Set random seed
    torch.manual_seed(args['lm']['seed'])
    np.random.seed(args['lm']['seed'])

    # Prepare results and model directories
    output_dir = utils.get_results_dir_of_args(args)
    tqdm.write("Writing results to {}".format(output_dir))
    os.makedirs(output_dir, exist_ok=True)
    copyfile(argp_args.config, "{}/{}".format(output_dir, "config.yaml"))

    # seq2seq
    model_dir = utils.get_lm_path_of_args(args)
    if not model_dir.endswith(".params"):
        os.makedirs(model_dir, exist_ok=True)

    # Search for dataset
    dataset = dataset_lookup[args['data']['dataset_type']](args)

    # Run whatever experiment is necessary
    if args['lm']['lm_type'] == 'rnnlm':
        run_lm(args, dataset)
    elif args['lm']['lm_type'] in seq2seq_lookup:
        # truncation
        model_dir = utils.get_trunc_model_path_of_args(args)
        if not model_dir.endswith(".params"):
            os.makedirs(model_dir, exist_ok=True)
        run_seq2seq(args, dataset)
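# Minimal sketch of the config structure this entry point reads directly; only the
# keys referenced in main() above are shown, and the values are illustrative rather
# than taken from an actual config in the repository:
#
#   lm:
#     seed: 42          # used to seed torch and numpy
#     lm_type: rnnlm    # 'rnnlm' dispatches to run_lm; keys in seq2seq_lookup dispatch to run_seq2seq
#   data:
#     dataset_type: ... # looked up in dataset_lookup
#
# Other fields (model dimensions, paths, etc.) are consumed elsewhere, e.g. by
# utils.get_results_dir_of_args and the model constructors.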
argp = ArgumentParser()
argp.add_argument('config')
cli_args = argp.parse_args()
args = yaml.safe_load(open(cli_args.config))

# Determine whether CUDA is available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
args['device'] = device

# Construct the language model and dataset objects
dataset = Dataset(args)
input_size = args['lm']['embedding_dim']
hidden_size = args['lm']['hidden_dim']
recurrent_model = rnn.PytorchRecurrentModel(args, input_size, hidden_size,
                                            args['lm']['num_layers'])
lm_model = lm.TraditionalLanguageModel(args, recurrent_model)

# Prepare to write results
output_dir = utils.get_results_dir_of_args(args)
tqdm.write('Writing results to {}'.format(output_dir))
os.makedirs(output_dir, exist_ok=True)

# Train and load most recent parameters
train(args, lm_model, dataset.get_train_dataloader(), dataset.get_dev_dataloader())
lm_model.load_state_dict(torch.load(utils.get_lm_path_of_args(args)))

# Evaluate language model
reporter.run_evals(args, lm_model, dataset, 'dev')
reporter.run_evals(args, lm_model, dataset, 'test')
from models import Seq2SeqLSTM
from utils import load_config_from_path, get_results_dir_of_args

matplotlib.use('agg')
matplotlib.rcParams['font.size'] = 16
matplotlib.rcParams['legend.fontsize'] = 'large'
matplotlib.rcParams['figure.titlesize'] = 'medium'

SCAN_HS = namedtuple("SCAN_hidden_states", ["train", "dev", "test"])
COLOR_MAP = 'cool'
FIGURE_PATH = "results/plots/scan/{}_plot.png"
NO_EOS_CONFIG = "configs/scan/plots/scan_-EOS.yaml"
EOS_CONFIG = "configs/scan/plots/scan_+EOS.yaml"
NO_EOS_DIR = get_results_dir_of_args(load_config_from_path(NO_EOS_CONFIG))
# "results/ls-22_eos-F/scan-scanls22-bs16_dec_cellstandard_dec_num_layers1_dropout0.5_embed_dim200_enc_cellstandard_enc_num_layers2_hidden_dim200_seq2seq_lr0.001_seed500_tfr0.5_eosF-model_typeoracle_truncation_pathmodels/ls-22_eos-F"
EOS_DIR = get_results_dir_of_args(load_config_from_path(EOS_CONFIG))
# "results/ls-22_eos-T/scan-scanls22-bs16_dec_cellstandard_dec_num_layers1_dropout0.5_embed_dim200_enc_cellstandard_enc_num_layers2_hidden_dim200_seq2seq_lr0.001_seed500_tfr0.5_eosT-model_typeoracle_truncation_pathmodels/ls-22_eos-T"


def get_sequence_cutoffs(sequences):
    """Returns one slice per sequence, covering its span in the concatenated token stream."""
    cum_lens = np.cumsum([len(seq) for seq in sequences]).tolist()
    return [slice(start, end)
            for start, end in zip([0] + cum_lens[:-1], cum_lens)]


def load_observations(data_loader):
    """Collects the target tokens (without the first token) of every observation in the data loader."""
    all_observations = []
    for _, _, _, observations in data_loader:
        all_observations.extend([obs.target_tokens[1:] for obs in observations])
    return all_observations


def slices_to_ints(slices_list):
    arr_len = sum([s.stop - s.start for s in slices_list])
def report(self):
    cache = self.lm.dump_cache(clear=False)
    outfile_name = "{}/cache.{}.npy".format(
        get_results_dir_of_args(self.args), self.split_name)
    np.save(outfile_name, cache)