Example #1
    def report(self):
        """Writes the ending accuracy for this split to disk and logs it."""
        outfile_name = "{}/dyck_ending_acc.{}".format(
            get_results_dir_of_args(self.args), self.split_name)
        summary_str = f"{self.total_correct / self.total}"
        tqdm.write("Ending Accuracy ({}): {}".format(self.split_name,
                                                     summary_str))
        output_str = "{}\n".format(summary_str)
        with open(outfile_name, "w") as f:
            f.write(output_str)
Example #2
    def report(self):
        """Computes oracle BLEU and writes the score and cutoff lengths to disk."""
        if self.references_batch:
            self.evaluate_batches()
            self.clear_batches()
        bleu = sacrebleu.corpus_bleu(self.predictions, [self.references])
        tqdm.write("Oracle bleu score: {}".format(bleu.score))
        outfile_name = os.path.join(get_results_dir_of_args(self.args),
                                    "oracle_bleu.{}".format(self.split_name))
        with open(outfile_name, "w") as f:
            try:
                f.write("{}\n{}\n".format(bleu.score, str(bleu)))
            except ZeroDivisionError:
                f.write("{}\nNA\n".format(bleu.score))

        outfile_name = os.path.join(
            get_results_dir_of_args(self.args),
            "oracle_bleu_lengths.{}".format(self.split_name))
        with open(outfile_name, "w") as f:
            f.write("\n".join(self.cutoff_lengths))
Example #3
def main():
    argp = ArgumentParser()
    argp.add_argument("base_dir")
    argp.add_argument("--white_list")
    argp.add_argument("--black_list")
    argp.add_argument("--metric", choices=VALID_METRICS)
    cli_args = argp.parse_args()

    black_list = cli_args.black_list.split(
        ",") if cli_args.black_list else None
    white_list = cli_args.white_list.split(
        ",") if cli_args.white_list else None
    if cli_args.metric is None or cli_args.metric == "close_bracket":
        get_metric = get_mean
        metric_file = "close_bracket_acc.{}"
    elif cli_args.metric == "ppl":
        get_metric = get_ppl
        metric_file = "{}.perplexity"
    elif cli_args.metric == "probe_acc":
        get_metric = get_ppl
        metric_file = "dyck_ending_acc.{}"

    dev_medians = []
    test_medians = []
    print("Iterating through file system")
    for config in os.listdir(cli_args.base_dir):
        if config.startswith("."):
            continue
        if black_list and any([item in config for item in black_list]):
            continue
        if white_list and not any([item in config for item in white_list]):
            continue

        print(f"\t{config}")
        config_path = os.path.join(cli_args.base_dir, config)
        try:
            config_args = utils.load_config_from_path(config_path)
            results_dir = utils.get_results_dir_of_args(config_args)
            dev_results_path = os.path.join(results_dir,
                                            metric_file.format("dev"))
            dev_medians.append(get_metric(dev_results_path))
            test_results_path = os.path.join(results_dir,
                                             metric_file.format("test"))
            test_medians.append(get_metric(test_results_path))
        except FileNotFoundError:
            print(f"Not found: {config_path}")
            continue

    dev_median = np.median(dev_medians)
    test_median = np.median(test_medians)
    print(f"In domain median: {dev_median} (out of {len(dev_medians)})")
    print(f"Out of domain median: {test_median} (out of {len(test_medians)})")
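The helpers get_mean and get_ppl used above are not included in this listing. Purely as an illustration, a hypothetical get_mean compatible with the close_bracket_acc files written by the reporter in Example #7 (per-key "key,value" lines followed by a summary line) could look like this:

import numpy as np

def get_mean(results_path):
    # Hypothetical reader: average the per-key accuracies stored as
    # "key,value" lines, ignoring the trailing summary line.
    values = []
    with open(results_path) as f:
        for line in f:
            parts = line.strip().split(",")
            if len(parts) != 2:
                continue
            try:
                values.append(float(parts[1]))
            except ValueError:
                continue
    return float(np.mean(values))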
Example #4
def report_results_dict(args, results, split_name):
    """Aggregates statistics and writes them to disk.

    Arguments:
      args: experiment arguments used to resolve the results directory.
      results: string-keyed results dictionary from get_dyck_eval_dict.
      split_name: string split name in {train, dev, test}.
    """
    # Report raw statistics
    output_dir = utils.get_results_dir_of_args(args)
    output_path = os.path.join(output_dir, 'dyck-k-eval.json')
    tqdm.write('Writing results to {}'.format(output_path))
    with open(output_path, 'w') as fout:
        json.dump(results, fout)

    # Report summary
    result_column = []
    indices = []
    # Bucket accuracy by the "diff" index encoded in the result keys.
    for i in range(10000):
        key_correct = 'diff{}-1'.format(i)
        key_incorrect = 'diff{}-0'.format(i)
        correct_count = (
            results['correct_closing_bracket_constraint'][key_correct]
            if key_correct in results['correct_closing_bracket_constraint']
            else 0)
        incorrect_count = (
            results['correct_closing_bracket_constraint'][key_incorrect]
            if key_incorrect in results['correct_closing_bracket_constraint']
            else 0)
        if correct_count + incorrect_count >= 1:
            result_column.append(correct_count /
                                 (correct_count + incorrect_count))
            indices.append(i)
    output_dir = utils.get_results_dir_of_args(args)
    output_path = os.path.join(output_dir,
                               'summary-{}.json'.format(split_name))
    tqdm.write('Writing results to {}'.format(output_path))
    with open(output_path, 'w') as fout:
        json.dump(list(zip(result_column, indices)), fout)
Example #5
def main():
    argp = ArgumentParser()
    argp.add_argument("base_dir")
    cli_args = argp.parse_args()

    test_results = defaultdict(list)
    results = []
    print("Iterating through file system")
    for dir_name, sub_dirs, file_list in os.walk(cli_args.base_dir):
        for config in file_list:
            config = os.path.join(dir_name, config)

            for pattern in (seed_pattern_lstm, seed_pattern_transformer):
                s = re.findall(pattern, config)
                if s:
                    break
            if not s:  # no match
                continue

            ls, eos = s[0]
            ls = int(ls)
            print(f"\t{config}")
            config_args = utils.load_config_from_path(config)
            results_dir = utils.get_results_dir_of_args(config_args)
            dev_results_path = os.path.join(
                results_dir, "oracle_exact_match_acc.dev_sampled")
            dev_result = get_result(dev_results_path)
            if dev_result is None:
                print("\t> error with dev - check if file exists")
                continue
            test_results_path = os.path.join(
                results_dir, "oracle_exact_match_acc.test_sampled")
            test_result = get_result(test_results_path)
            if test_result is None:
                print("\t> error with test - check if file exists")
                continue
            results.append((config, dev_result, test_result))
            test_results[(eos, ls)].append(test_result)

    out_str = ""
    for k in sorted(test_results.keys()):
        out_str += " & {:.2f}".format(np.median(test_results[k]))
        if k[1] == 40:  # max length split
            out_str += "\\\\\n"
    print(out_str)
Example #6
    def report(self):
        """
        Runs the detokenizer and then sacrebleu to calculate the score.
        """
        predictions = self.detokenize(self.predictions)
        references = self.detokenize(self.references)
        bleu = sacrebleu.corpus_bleu(predictions, [references])
        tqdm.write("Bleu score: {}".format(bleu.score))
        outfile_name = os.path.join(get_results_dir_of_args(self.args),
                                    "bleu.{}".format(self.split_name))
        with open(outfile_name, "w") as f:
            try:
                f.write("{}\n{}\n".format(bleu.score, str(bleu)))
            except ZeroDivisionError:
                f.write("{}\nNA\n".format(bleu.score))
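For context, sacrebleu.corpus_bleu takes a list of hypothesis strings and a list of reference lists (one per reference set), and returns an object whose score attribute is the corpus-level BLEU. A standalone sketch with made-up sentences:

import sacrebleu

predictions = ["the cat sat on the mat", "a dog barked"]
references = ["the cat sat on a mat", "the dog barked"]

# The second argument is a list of reference sets, hence the extra brackets.
bleu = sacrebleu.corpus_bleu(predictions, [references])
print(bleu.score)  # corpus-level BLEU as a float between 0 and 100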
Example #7
    def report(self):
        """Writes per-key closing bracket accuracy plus a median/percentile/min summary."""
        outfile_name = "{}/close_bracket_acc.{}".format(
            get_results_dir_of_args(self.args), self.split_name)
        output = []
        all_results = []
        for key in sorted(self.total):
            # if self.total[key] >= 5:
            all_results.append(self.total_correct[key] / self.total[key])
            output.append(f"{key},{self.total_correct[key]/self.total[key]}")
        result_min = np.min(all_results)
        result_median = np.median(all_results)
        result_1pt = np.percentile(all_results, 0.25)
        summary_str = f"{result_median} {result_1pt} {result_min}"

        tqdm.write("Weighted Closing Bracket Accuracy ({}): {}".format(
            self.split_name, summary_str))
        output_str = "{}\n{}\n".format("\n".join(output), summary_str)
        with open(outfile_name, "w") as f:
            f.write(output_str)
Example #8
def main():
    argp = ArgumentParser()
    argp.add_argument('config')
    argp.add_argument('--train', action="store_true")
    argp.add_argument('--train-seq2seq', action="store_true")
    argp.add_argument('--train-truncation', action="store_true")
    argp_args = argp.parse_args()
    args = yaml.safe_load(open(argp_args.config))
    args['train-seq2seq'] = argp_args.train_seq2seq
    args['train-truncation'] = argp_args.train_truncation
    args['train'] = (argp_args.train or argp_args.train_truncation
                     or argp_args.train_seq2seq)
    args['device'] = torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu")

    # set random seed
    torch.manual_seed(args['lm']['seed'])
    np.random.seed(args['lm']['seed'])

    # prepare results and model directories
    output_dir = utils.get_results_dir_of_args(args)
    tqdm.write("Writing results to {}".format(output_dir))
    os.makedirs(output_dir, exist_ok=True)
    copyfile(argp_args.config, "{}/{}".format(output_dir, "config.yaml"))

    # seq2seq
    model_dir = utils.get_lm_path_of_args(args)
    if not model_dir.endswith(".params"):
        os.makedirs(model_dir, exist_ok=True)

    # Search for dataset
    dataset = dataset_lookup[args['data']['dataset_type']](args)

    # Run whatever experiment necessary
    if args['lm']['lm_type'] == 'rnnlm':
        run_lm(args, dataset)
    elif args['lm']['lm_type'] in seq2seq_lookup:
        # truncation
        model_dir = utils.get_trunc_model_path_of_args(args)
        if not model_dir.endswith(".params"):
            os.makedirs(model_dir, exist_ok=True)
        run_seq2seq(args, dataset)
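The main() above reads a nested YAML config. Below is a hypothetical minimal structure, written as the dict yaml.safe_load would return, covering only the keys accessed here and in the next snippet; every value is a placeholder, and valid dataset_type / lm_type values come from dataset_lookup and seq2seq_lookup, which are not shown.

# Hypothetical minimal config; all values are placeholders.
args = {
    'lm': {
        'lm_type': 'rnnlm',      # or a key of seq2seq_lookup (not shown)
        'seed': 1,
        'embedding_dim': 200,
        'hidden_dim': 200,
        'num_layers': 2,
    },
    'data': {
        'dataset_type': 'dyck',  # must be a key of dataset_lookup (not shown)
    },
}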
Example #9
def main():
    argp = ArgumentParser()
    argp.add_argument('config')
    args = argp.parse_args()
    args = yaml.safe_load(open(args.config))

    # Determine whether CUDA is available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args['device'] = device

    # Construct the language model and dataset objects
    dataset = Dataset(args)
    input_size = args['lm']['embedding_dim']
    hidden_size = args['lm']['hidden_dim']
    recurrent_model = rnn.PytorchRecurrentModel(args, input_size, hidden_size,
                                                args['lm']['num_layers'])
    lm_model = lm.TraditionalLanguageModel(args, recurrent_model)

    # Prepare to write results
    output_dir = utils.get_results_dir_of_args(args)
    tqdm.write('Writing results to {}'.format(output_dir))
    os.makedirs(output_dir, exist_ok=True)

    # Train and load most recent parameters
    train(args, lm_model, dataset.get_train_dataloader(),
          dataset.get_dev_dataloader())
    lm_model.load_state_dict(torch.load(utils.get_lm_path_of_args(args)))

    # Evaluate language model
    reporter.run_evals(args, lm_model, dataset, 'dev')
    reporter.run_evals(args, lm_model, dataset, 'test')
Example #10
import matplotlib
import numpy as np
from collections import namedtuple

from models import Seq2SeqLSTM
from utils import load_config_from_path, get_results_dir_of_args

matplotlib.use('agg')
matplotlib.rcParams['font.size'] = 16
matplotlib.rcParams['legend.fontsize'] = 'large'
matplotlib.rcParams['figure.titlesize'] = 'medium'

SCAN_HS = namedtuple("SCAN_hidden_states", ["train", "dev", "test"])
COLOR_MAP = 'cool'

FIGURE_PATH = "results/plots/scan/{}_plot.png"
NO_EOS_CONFIG = "configs/scan/plots/scan_-EOS.yaml"
EOS_CONFIG = "configs/scan/plots/scan_+EOS.yaml"

NO_EOS_DIR = get_results_dir_of_args(load_config_from_path(NO_EOS_CONFIG)) # "results/ls-22_eos-F/scan-scanls22-bs16_dec_cellstandard_dec_num_layers1_dropout0.5_embed_dim200_enc_cellstandard_enc_num_layers2_hidden_dim200_seq2seq_lr0.001_seed500_tfr0.5_eosF-model_typeoracle_truncation_pathmodels/ls-22_eos-F"
EOS_DIR = get_results_dir_of_args(load_config_from_path(EOS_CONFIG)) # "results/ls-22_eos-T/scan-scanls22-bs16_dec_cellstandard_dec_num_layers1_dropout0.5_embed_dim200_enc_cellstandard_enc_num_layers2_hidden_dim200_seq2seq_lr0.001_seed500_tfr0.5_eosT-model_typeoracle_truncation_pathmodels/ls-22_eos-T"


def get_sequence_cutoffs(sequences):
    # Map each sequence to its span (a slice) within the concatenation of all sequences.
    cum_lens = np.cumsum([len(seq) for seq in sequences]).tolist()
    return [slice(start, end) for start, end in zip([0] + cum_lens[:-1], cum_lens)]

def load_observations(data_loader):
    all_observations = []
    for _, _, _, observations in data_loader:
        all_observations.extend([obs.target_tokens[1:] for obs in observations])
    return all_observations

def slices_to_ints(slices_list):
    arr_len = sum([s.stop - s.start for s in slices_list])
Example #11
    def report(self):
        """Saves the language model's cache for this split as a .npy file."""
        cache = self.lm.dump_cache(clear=False)
        outfile_name = "{}/cache.{}.npy".format(
            get_results_dir_of_args(self.args), self.split_name)
        np.save(outfile_name, cache)
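A cache saved this way can be restored later with numpy.load. A minimal sketch (the path below is made up, and allow_pickle is only needed if the cache is an object array):

import numpy as np

cache = np.load("results/some_run/cache.dev.npy", allow_pickle=True)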