Exemplo n.º 1
0
    def _read(self, file_path: str):
        """Read a dataset file and yield instances built by ``self.text_to_instance``.

        Two input formats are supported:

        * ``*.jsonl`` -- each line is a JSON object containing a single logical form.
        * ``*.examples`` -- SEMPRE-style example lines; logical forms are read from
          per-question gzipped files in ``self._offline_logical_forms_directory``
          (required for this format).
        """
        # If the offline logical forms directory contains a single tarball, assume
        # it holds logical forms for all questions and un-tar it first.  The
        # directory may legitimately be None (e.g. for *.jsonl input), so guard
        # here: the original called os.listdir(None) before the explicit
        # None-check below could ever fire.
        if self._offline_logical_forms_directory is not None:
            tarball_with_all_lfs: str = None
            for filename in os.listdir(self._offline_logical_forms_directory):
                if filename.endswith(".tar.gz"):
                    tarball_with_all_lfs = os.path.join(self._offline_logical_forms_directory,
                                                        filename)
                    break
            if tarball_with_all_lfs is not None:
                logger.info(f"Found a tarball in offline logical forms directory: {tarball_with_all_lfs}")
                logger.info("Assuming it contains logical forms for all questions and un-taring it.")
                # If you're running this with beaker, the input directory will be read-only and we
                # cannot untar the files in the directory itself. So we will do so in /tmp, but that
                # means the new offline logical forms directory will be /tmp.
                self._offline_logical_forms_directory = "/tmp/"
                tarfile.open(tarball_with_all_lfs,
                             mode='r:gz').extractall(path=self._offline_logical_forms_directory)
        data: List[Dict[str, str]] = []
        if file_path.endswith(".jsonl"):
            with open(file_path, "r") as data_file:
                for line in data_file:
                    if not line:
                        continue
                    line_data = json.loads(line)
                    # Wrap the single logical form in a list so both input
                    # formats produce the same schema downstream.
                    line_data["logical_form"] = [line_data["logical_form"]]
                    data.append(line_data)
        elif file_path.endswith(".examples"):
            num_examples = 0
            num_examples_without_lf = 0
            if self._offline_logical_forms_directory is None:
                raise RuntimeError("Logical forms directory required when processing examples files!")
            with open(file_path, "r") as data_file:
                for line in data_file:
                    num_examples += 1
                    line_data = wtq_data_util.parse_example_line(line)
                    example_id = line_data["id"]
                    logical_forms_file = os.path.join(self._offline_logical_forms_directory,
                                                      f"{example_id}.gz")
                    # Skip questions for which offline search produced no output.
                    if not os.path.exists(logical_forms_file):
                        num_examples_without_lf += 1
                        continue
                    with gzip.open(logical_forms_file, "rt") as lf_file:
                        logical_forms = [x.strip() for x in lf_file.readlines()][:self._max_num_logical_forms]
                    line_data["logical_form"] = logical_forms
                    data.append(line_data)
            logger.info(f"Skipped {num_examples_without_lf} out of {num_examples} examples")
        else:
            raise RuntimeError(f"Unknown file type: {file_path}. Was expecting either *.examples or *.jsonl")

        for datum in data:
            # We want the tagged file, but the ``*.examples`` files typically point to CSV.
            table_filename = os.path.join(self._tables_directory,
                                          datum["table_filename"].replace("csv", "tagged"))
            # Close the table file promptly instead of relying on GC.
            with open(table_filename) as table_file:
                table_lines = [line.split("\t") for line in table_file.readlines()]
            instance = self.text_to_instance(logical_forms=datum["logical_form"],
                                             table_lines=table_lines,
                                             question=datum["question"])
            if instance is not None:
                yield instance
def search(tables_directory: str,
           input_examples_file: str,
           output_file: str,
           max_path_length: int,
           max_num_logical_forms: int,
           use_agenda: bool) -> None:
    """Exhaustively search the action space for logical forms whose denotation
    matches each example's target values, writing results to ``output_file``.

    Parameters
    ----------
    tables_directory : Root directory containing the tagged table files.
    input_examples_file : SEMPRE-style ``*.examples`` file to process.
    output_file : Path of the single human-readable output file.
    max_path_length : Maximum action-sequence length explored by the walker.
    max_num_logical_forms : Cap on correct logical forms written per question.
    use_agenda : If True, restrict the search using the world's agenda.
    """
    data = [wikitables_util.parse_example_line(example_line) for example_line in
            open(input_examples_file)]
    tokenizer = WordTokenizer()
    with open(output_file, "w") as output_file_pointer:
        for instance_data in data:
            utterance = instance_data["question"]
            question_id = instance_data["id"]
            # Strip surrounding quotes that some example files carry.
            if utterance.startswith('"') and utterance.endswith('"'):
                utterance = utterance[1:-1]
            # For example: csv/200-csv/47.csv -> tagged/200-tagged/47.tagged
            table_file = instance_data["table_filename"].replace("csv", "tagged")
            # pylint: disable=protected-access
            target_list = [TableQuestionContext._normalize_string(value) for value in
                           instance_data["target_values"]]
            try:
                target_value_list = evaluator.to_value_list(target_list)
            except Exception:  # was a bare except; keep KeyboardInterrupt etc. uncaught
                # Log the offending values, then retry so the real error
                # surfaces with a full traceback (the call is deterministic).
                print(target_list)
                target_value_list = evaluator.to_value_list(target_list)
            tokenized_question = tokenizer.tokenize(utterance)
            table_file = f"{tables_directory}/{table_file}"
            context = TableQuestionContext.read_from_file(table_file, tokenized_question)
            world = WikiTablesVariableFreeWorld(context)
            walker = ActionSpaceWalker(world, max_path_length=max_path_length)
            correct_logical_forms = []
            print(f"{question_id} {utterance}", file=output_file_pointer)
            if use_agenda:
                agenda = world.get_agenda()
                print(f"Agenda: {agenda}", file=output_file_pointer)
                all_logical_forms = walker.get_logical_forms_with_agenda(agenda=agenda,
                                                                         max_num_logical_forms=10000)
            else:
                all_logical_forms = walker.get_all_logical_forms(max_num_logical_forms=10000)
            for logical_form in all_logical_forms:
                try:
                    denotation = world.execute(logical_form)
                except ExecutionError:
                    print(f"Failed to execute: {logical_form}", file=sys.stderr)
                    continue
                if isinstance(denotation, list):
                    denotation_list = [str(denotation_item) for denotation_item in denotation]
                else:
                    # For numbers and dates
                    denotation_list = [str(denotation)]
                denotation_value_list = evaluator.to_value_list(denotation_list)
                if evaluator.check_denotation(target_value_list, denotation_value_list):
                    correct_logical_forms.append(logical_form)
            if not correct_logical_forms:
                print("NO LOGICAL FORMS FOUND!", file=output_file_pointer)
            for logical_form in correct_logical_forms[:max_num_logical_forms]:
                print(logical_form, file=output_file_pointer)
            print(file=output_file_pointer)
Exemplo n.º 3
0
def make_data(input_examples_file: str, tables_directory: str,
              archived_model_file: str, output_dir: str,
              num_logical_forms: int, variable_free: bool) -> None:
    """Decode logical forms with a trained model and keep the correct ones.

    For each example, up to ``num_logical_forms`` logical forms that evaluate
    to the target denotation are written to ``output_dir/<example_id>.gz``.
    ``variable_free`` selects between the two reader/evaluation code paths.
    """
    if variable_free:
        reader = WikiTablesVariableFreeDatasetReader(
            tables_directory=tables_directory,
            keep_if_no_logical_forms=True,
            output_agendas=True)
    else:
        reader = WikiTablesDatasetReader(tables_directory=tables_directory,
                                         keep_if_no_dpd=True,
                                         output_agendas=True)
    dataset = reader.read(input_examples_file)
    input_lines = []
    with open(input_examples_file) as input_file:
        input_lines = input_file.readlines()
    if variable_free:
        new_tables_config = {}
    else:
        # Note: Double { for escaping {.
        new_tables_config = f"{{model: {{tables_directory: {tables_directory}}}}}"
    archive = load_archive(archived_model_file, overrides=new_tables_config)
    model = archive.model
    model.training = False
    # pylint: disable=protected-access -- widen the decoding beam for data generation.
    model._decoder_trainer._max_num_decoded_sequences = 100
    for instance, example_line in zip(dataset, input_lines):
        outputs = model.forward_on_instance(instance)
        parsed_info = parse_example_line(example_line)
        example_id = parsed_info["id"]
        logical_forms = outputs["logical_form"]
        correct_logical_forms = []
        for logical_form in logical_forms:
            if variable_free:
                world = instance.fields["world"].metadata
                target_values = instance.fields["target_values"].metadata
                logical_form_is_correct = world.evaluate_logical_form(
                    logical_form, target_values)
            else:
                logical_form_is_correct = model._executor.evaluate_logical_form(
                    logical_form, example_line)
            if logical_form_is_correct:
                correct_logical_forms.append(logical_form)
                if len(correct_logical_forms) >= num_logical_forms:
                    break
        num_found = len(correct_logical_forms)
        print(f"{num_found} found for {example_id}")
        if num_found == 0:
            continue
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        # Context manager guarantees the gzip file is closed even if a write
        # fails (the original leaked the handle on error).
        with gzip.open(os.path.join(output_dir, f"{example_id}.gz"),
                       "wb") as output_file:
            for logical_form in correct_logical_forms:
                output_file.write((logical_form + "\n").encode('utf-8'))
Exemplo n.º 4
0
 def test_parse_example_line(self):
     """The first line of the sample examples file parses into the expected dict."""
     # pylint: disable=no-self-use,protected-access
     sample_path = self.FIXTURES_ROOT / "data" / "wikitables" / "sample_data.examples"
     with open(sample_path) as sample_file:
         first_line = sample_file.readlines()[0]
     parsed = wikitables_util.parse_example_line(first_line)
     expected_question = 'what was the last year where this team was a part of the usl a-league?'
     assert parsed == {'id': 'nt-0',
                       'table_filename': 'tables/590.csv',
                       'target_values': ['2004'],
                       'question': expected_question}
Exemplo n.º 5
0
def search(tables_directory: str, input_examples_file: str, output_path: str,
           max_path_length: int, max_num_logical_forms: int, use_agenda: bool,
           output_separate_files: bool) -> None:
    """Search for logical forms that evaluate to each example's target values.

    Parameters
    ----------
    tables_directory : Root directory containing the tagged table files.
    input_examples_file : SEMPRE-style ``*.examples`` file to process.
    output_path : Output directory (when ``output_separate_files``) or a
        single output file path otherwise.
    max_path_length : Maximum action-sequence length explored by the walker.
    max_num_logical_forms : Cap on logical forms written per question
        (single-file mode only).
    use_agenda : If True, restrict the search using the world's agenda.
    output_separate_files : If True, write one gzipped file per question.
    """
    data = [
        wikitables_util.parse_example_line(example_line)
        for example_line in open(input_examples_file)
    ]
    tokenizer = WordTokenizer()
    if output_separate_files and not os.path.exists(output_path):
        os.makedirs(output_path)
    # Make sure the shared output file is closed even if the search raises
    # partway through (the original leaked it on error paths).
    output_file_pointer = None
    if not output_separate_files:
        output_file_pointer = open(output_path, "w")
    try:
        for instance_data in data:
            utterance = instance_data["question"]
            question_id = instance_data["id"]
            # Strip surrounding quotes that some example files carry.
            if utterance.startswith('"') and utterance.endswith('"'):
                utterance = utterance[1:-1]
            # For example: csv/200-csv/47.csv -> tagged/200-tagged/47.tagged
            table_file = instance_data["table_filename"].replace("csv", "tagged")
            target_list = instance_data["target_values"]
            tokenized_question = tokenizer.tokenize(utterance)
            table_file = f"{tables_directory}/{table_file}"
            context = TableQuestionContext.read_from_file(table_file,
                                                          tokenized_question)
            world = WikiTablesVariableFreeWorld(context)
            walker = ActionSpaceWalker(world, max_path_length=max_path_length)
            correct_logical_forms = []
            if use_agenda:
                agenda = world.get_agenda()
                all_logical_forms = walker.get_logical_forms_with_agenda(
                    agenda=agenda,
                    max_num_logical_forms=10000,
                    allow_partial_match=True)
            else:
                all_logical_forms = walker.get_all_logical_forms(
                    max_num_logical_forms=10000)
            for logical_form in all_logical_forms:
                if world.evaluate_logical_form(logical_form, target_list):
                    correct_logical_forms.append(logical_form)
            if output_separate_files and correct_logical_forms:
                with gzip.open(f"{output_path}/{question_id}.gz",
                               "wt") as lf_file_pointer:
                    for logical_form in correct_logical_forms:
                        print(logical_form, file=lf_file_pointer)
            elif not output_separate_files:
                print(f"{question_id} {utterance}", file=output_file_pointer)
                if use_agenda:
                    print(f"Agenda: {agenda}", file=output_file_pointer)
                if not correct_logical_forms:
                    print("NO LOGICAL FORMS FOUND!", file=output_file_pointer)
                for logical_form in correct_logical_forms[:max_num_logical_forms]:
                    print(logical_form, file=output_file_pointer)
                print(file=output_file_pointer)
    finally:
        if output_file_pointer is not None:
            output_file_pointer.close()
Exemplo n.º 6
0
    def _read(self, file_path: str):
        """Read a SEMPRE-style ``*.examples`` file and yield instances.

        When ``self._offline_logical_forms_directory`` is set, per-question
        gzipped search output is attached to each instance; questions with no
        search output are skipped unless ``self._keep_if_no_logical_forms``.
        """
        with open(file_path, "r") as data_file:
            num_missing_logical_forms = 0
            num_lines = 0
            num_instances = 0
            for line in data_file.readlines():
                line = line.strip("\n")
                if not line:
                    continue
                num_lines += 1
                parsed_info = wikitables_util.parse_example_line(line)
                question = parsed_info["question"]
                # We want the tagged file, but the ``*.examples`` files typically point to CSV.
                table_filename = os.path.join(
                    self._tables_directory,
                    parsed_info["table_filename"].replace("csv", "tagged"))
                if self._offline_logical_forms_directory:
                    logical_forms_filename = os.path.join(
                        self._offline_logical_forms_directory,
                        parsed_info["id"] + '.gz')
                    try:
                        # Context manager closes the gzip handle promptly
                        # (the original left it open).
                        with gzip.open(logical_forms_filename) as logical_forms_file:
                            logical_forms = [
                                logical_form_line.strip().decode('utf-8')
                                for logical_form_line in logical_forms_file
                            ]
                    except FileNotFoundError:
                        logger.debug(
                            f'Missing search output for instance {parsed_info["id"]}; skipping...'
                        )
                        logical_forms = None
                        num_missing_logical_forms += 1
                        if not self._keep_if_no_logical_forms:
                            continue
                else:
                    logical_forms = None

                # Close the table file instead of relying on GC.
                with open(table_filename) as table_file:
                    table_lines = [
                        table_line.split("\t")
                        for table_line in table_file.readlines()
                    ]
                instance = self.text_to_instance(
                    question=question,
                    table_lines=table_lines,
                    target_values=parsed_info["target_values"],
                    offline_search_output=logical_forms)
                if instance is not None:
                    num_instances += 1
                    yield instance

        if self._offline_logical_forms_directory:
            logger.info(
                f"Missing logical forms for {num_missing_logical_forms} out of {num_lines} instances"
            )
            logger.info(f"Kept {num_instances} instances")
Exemplo n.º 7
0
def search(tables_directory: str,
           input_examples_file: str,
           output_path: str,
           max_path_length: int,
           max_num_logical_forms: int,
           use_agenda: bool,
           output_separate_files: bool) -> None:
    """Search for logical forms that evaluate to each example's target values.

    Parameters
    ----------
    tables_directory : Root directory containing the tagged table files.
    input_examples_file : SEMPRE-style ``*.examples`` file to process.
    output_path : Output directory (when ``output_separate_files``) or a
        single output file path otherwise.
    max_path_length : Maximum action-sequence length explored by the walker.
    max_num_logical_forms : Cap on logical forms written per question
        (single-file mode only).
    use_agenda : If True, restrict the search using the world's agenda.
    output_separate_files : If True, write one gzipped file per question.
    """
    data = [wikitables_util.parse_example_line(example_line) for example_line in
            open(input_examples_file)]
    tokenizer = WordTokenizer()
    if output_separate_files and not os.path.exists(output_path):
        os.makedirs(output_path)
    # Make sure the shared output file is closed even if the search raises
    # partway through (the original leaked it on error paths).
    output_file_pointer = None
    if not output_separate_files:
        output_file_pointer = open(output_path, "w")
    try:
        for instance_data in data:
            utterance = instance_data["question"]
            question_id = instance_data["id"]
            # Strip surrounding quotes that some example files carry.
            if utterance.startswith('"') and utterance.endswith('"'):
                utterance = utterance[1:-1]
            # For example: csv/200-csv/47.csv -> tagged/200-tagged/47.tagged
            table_file = instance_data["table_filename"].replace("csv", "tagged")
            target_list = instance_data["target_values"]
            tokenized_question = tokenizer.tokenize(utterance)
            table_file = f"{tables_directory}/{table_file}"
            context = TableQuestionContext.read_from_file(table_file, tokenized_question)
            world = WikiTablesVariableFreeWorld(context)
            walker = ActionSpaceWalker(world, max_path_length=max_path_length)
            correct_logical_forms = []
            if use_agenda:
                agenda = world.get_agenda()
                all_logical_forms = walker.get_logical_forms_with_agenda(agenda=agenda,
                                                                         max_num_logical_forms=10000)
            else:
                all_logical_forms = walker.get_all_logical_forms(max_num_logical_forms=10000)
            for logical_form in all_logical_forms:
                if world.evaluate_logical_form(logical_form, target_list):
                    correct_logical_forms.append(logical_form)
            if output_separate_files and correct_logical_forms:
                with gzip.open(f"{output_path}/{question_id}.gz", "wt") as lf_file_pointer:
                    for logical_form in correct_logical_forms:
                        print(logical_form, file=lf_file_pointer)
            elif not output_separate_files:
                print(f"{question_id} {utterance}", file=output_file_pointer)
                if use_agenda:
                    print(f"Agenda: {agenda}", file=output_file_pointer)
                if not correct_logical_forms:
                    print("NO LOGICAL FORMS FOUND!", file=output_file_pointer)
                for logical_form in correct_logical_forms[:max_num_logical_forms]:
                    print(logical_form, file=output_file_pointer)
                print(file=output_file_pointer)
    finally:
        if output_file_pointer is not None:
            output_file_pointer.close()
Exemplo n.º 8
0
 def test_parse_example_line(self):
     """Parsing the first sample example line yields the expected fields."""
     # pylint: disable=no-self-use,protected-access
     sample_path = (self.FIXTURES_ROOT / "data" / "wikitables" /
                    "sample_data.examples")
     with open(sample_path) as sample_file:
         first_line = sample_file.readlines()[0]
     parsed = wikitables_util.parse_example_line(first_line)
     expected_question = 'what was the last year where this team was a part of the usl a-league?'
     expected = {
         'id': 'nt-0',
         'question': expected_question,
         'table_filename': 'tables/590.csv',
         'target_values': ['2004'],
     }
     assert parsed == expected
Exemplo n.º 9
0
def rerank_lf(model_file, input_examples_file, params_file, lf_directory,
              output_directory):
    """Rerank offline-search logical forms with a latent-alignment model.

    For each example with search output in ``lf_directory``, scores all
    candidate logical forms, and writes the top 10 (by similarity, descending)
    to ``output_directory/<example_id>.gz``.  Examples without search output
    are skipped.
    """
    model = load_archive(model_file).model
    model.eval()

    params = Params.from_file(params_file)
    latent_alignment_reader = DatasetReader.from_params(
        params.pop('dataset_reader'))

    with open(input_examples_file) as input_file:
        input_lines = input_file.readlines()

    found = 0.0
    for line in tqdm(input_lines):
        parsed_info = util.parse_example_line(line)
        example_id = parsed_info["id"]
        lf_output_filename = os.path.join(lf_directory,
                                          parsed_info["id"] + '.gz')
        # Keep the try block narrow: only the gzip open/read should be allowed
        # to raise FileNotFoundError.  The original wrapped the whole loop
        # body, which could silently swallow unrelated missing-file errors.
        try:
            with gzip.open(lf_output_filename) as lf_file:
                sempre_forms = [lf.strip().decode('utf-8') for lf in lf_file]
        except FileNotFoundError:
            continue
        question = parsed_info['question']
        instance = latent_alignment_reader.text_to_instance(
            question, sempre_forms)
        output = model.forward_on_instance(instance)
        similarities = output['all_similarities']
        top_lfs = [
            lf for lf, score in sorted(zip(sempre_forms, similarities),
                                       key=lambda x: x[1],
                                       reverse=True)[:10]
        ]
        if not os.path.exists(output_directory):
            os.makedirs(output_directory)
        # Context manager closes the output file even if a write fails.
        with gzip.open(os.path.join(output_directory, f"{example_id}.gz"),
                       "wb") as output_file:
            for logical_form in top_lfs:
                output_file.write((logical_form + "\n").encode('utf-8'))
        found += 1.0
    print(f"Found for {found/len(input_lines)} examples")
Exemplo n.º 10
0
def make_data(
    input_examples_file: str,
    tables_directory: str,
    archived_model_file: str,
    output_dir: str,
    num_logical_forms: int,
) -> None:
    """Decode logical forms with a trained model and keep the correct ones.

    For each example, up to ``num_logical_forms`` logical forms that evaluate
    to the target values are written to ``output_dir/<example_id>.gz``.
    Examples with no correct logical form produce no output file.
    """
    reader = WikiTablesDatasetReader(
        tables_directory=tables_directory, keep_if_no_logical_forms=True, output_agendas=True
    )
    dataset = reader.read(input_examples_file)
    input_lines = []
    with open(input_examples_file) as input_file:
        input_lines = input_file.readlines()
    archive = load_archive(archived_model_file)
    model = archive.model
    model.training = False
    # pylint: disable=protected-access -- widen the decoding beam for data generation.
    model._decoder_trainer._max_num_decoded_sequences = 100
    for instance, example_line in zip(dataset, input_lines):
        outputs = model.forward_on_instance(instance)
        world = instance.fields["world"].metadata
        parsed_info = util.parse_example_line(example_line)
        example_id = parsed_info["id"]
        target_list = parsed_info["target_values"]
        logical_forms = outputs["logical_form"]
        correct_logical_forms = []
        for logical_form in logical_forms:
            if world.evaluate_logical_form(logical_form, target_list):
                correct_logical_forms.append(logical_form)
                if len(correct_logical_forms) >= num_logical_forms:
                    break
        num_found = len(correct_logical_forms)
        print(f"{num_found} found for {example_id}")
        if num_found == 0:
            continue
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        # Context manager guarantees the gzip file is closed even if a write
        # fails (the original leaked the handle on error).
        with gzip.open(os.path.join(output_dir, f"{example_id}.gz"), "wb") as output_file:
            for logical_form in correct_logical_forms:
                output_file.write((logical_form + "\n").encode("utf-8"))
Exemplo n.º 11
0
def process_file(file_path: str, out_path: str, max_examples: int = 200):
    """Collect (question, sempre_forms) pairs and dump them as JSON.

    Reads up to ``max_examples`` example lines from ``file_path`` (the default
    of 200 preserves the previously hard-coded limit), looks up each example's
    gzipped DPD output under ``DPD_PATH``, and writes the collected pairs to
    ``out_path``.  Examples with no DPD output are skipped.
    """
    examples = []
    with open(file_path, "r") as data_file:
        for line in tqdm(data_file.readlines()[:max_examples]):
            line = line.strip("\n")
            if not line:
                continue
            parsed_info = util.parse_example_line(line)
            question = parsed_info["question"]
            dpd_output_filename = os.path.join(DPD_PATH,
                                               parsed_info["id"] + '.gz')
            try:
                # Context manager closes the gzip handle (the original leaked it).
                with gzip.open(dpd_output_filename) as dpd_file:
                    sempre_forms = [
                        dpd_line.strip().decode('utf-8') for dpd_line in dpd_file
                    ]
            except FileNotFoundError:
                continue
            examples.append((question, sempre_forms))
    with open(out_path, "w") as out_file:
        json.dump(examples, out_file, indent=2)
def main(examples_file: str,
         tables_directory: str,
         logical_forms_directory: str,
         output_file: str) -> None:
    """Interactively label offline-search logical forms as correct or not.

    Shows each question alongside its candidate logical forms and appends
    accepted (or hand-written) logical forms as JSON lines to ``output_file``.
    Example ids already present in ``output_file`` are skipped, which makes
    the labeling session resumable.
    """
    examples: List[Dict] = []
    with open(examples_file) as input_file:
        for line in input_file:
            examples.append(wtq_util.parse_example_line(line))
    random.shuffle(examples)  # Shuffling to label in random order

    # Resume support: collect the ids already written to the output file.
    processed_examples: Set[str] = set()
    if os.path.exists(output_file):
        with open(output_file) as output_file_for_reading:
            for line in output_file_for_reading:
                example_id = json.loads(line)["id"]
                processed_examples.add(example_id)

    with open(output_file, "a") as output_file_for_appending:
        for example in examples:
            example_id = example["id"]
            if example_id in processed_examples:
                # We've already labeled this example
                continue
            question = example["question"]
            table_filename = example["table_filename"]
            full_table_filename = os.path.join(tables_directory, table_filename)
            table_lines = []
            # The examples point at CSV files; the readable copies are TSV.
            with open(full_table_filename.replace(".csv", ".tsv")) as table_file:
                table_lines = table_file.readlines()
            logical_forms_file = os.path.join(logical_forms_directory, f"{example_id}.gz")
            # No offline search output for this example; nothing to label.
            if not os.path.exists(logical_forms_file):
                continue
            # Show the whole table once per example, before its logical forms.
            print("".join(table_lines))
            print()
            with gzip.open(logical_forms_file, "rt") as lf_file:
                for i, logical_form in enumerate(lf_file):
                    logical_form = logical_form.strip()
                    print(question)
                    print(logical_form)
                    print()
                    user_input = None
                    # Re-prompt until one of the four valid choices is given.
                    while user_input not in ["y", "n", "w", "s"]:
                        user_input = input("Correct? ('y'es / 'n'o / 'w'rite correct lf / 's'kip): ")
                        user_input = user_input.lower()
                    if user_input == "s":
                        # Skip the remaining logical forms for this example.
                        break
                    elif user_input == "y":
                        # Accept the displayed logical form and move on.
                        instance_output = {"id": example_id,
                                           "question": question,
                                           "table_filename": table_filename,
                                           "logical_form": logical_form}
                        print(json.dumps(instance_output), file=output_file_for_appending)
                        break
                    elif user_input == "w":
                        # Record a hand-written logical form instead.
                        correct_logical_form = input("Enter correct logical form: ")
                        instance_output = {"id": example_id,
                                           "question": question,
                                           "table_filename": table_filename,
                                           "logical_form": correct_logical_form}
                        print(json.dumps(instance_output), file=output_file_for_appending)
                        break
                    # Only reached after 'n': stop once the display cap is hit.
                    if i >= MAX_NUM_LOGICAL_FORMS_TO_SHOW:
                        break
     dest="output_separate_files",
     action="store_true",
     help="""If set, the script will output gzipped
                     files, one per example. You may want to do this if you;re making data to
                     train a parser.""",
 )
 parser.add_argument(
     "--num-splits",
     dest="num_splits",
     type=int,
     default=0,
     help="Number of splits to make of the data, to run as many processes (default 0)",
 )
 args = parser.parse_args()
 input_data = [
     wikitables_util.parse_example_line(example_line) for example_line in open(args.data_file)
 ]
 if args.num_splits == 0 or len(input_data) <= args.num_splits or not args.output_separate_files:
     search(
         args.table_directory,
         input_data,
         args.output_path,
         args.max_path_length,
         args.max_num_logical_forms,
         args.use_agenda,
         args.output_separate_files,
         args.conservative,
     )
 else:
     chunk_size = math.ceil(len(input_data) / args.num_splits)
     start_index = 0
Exemplo n.º 14
0
                     dest="output_separate_files",
                     action="store_true",
                     help="""If set, the script will output gzipped
                     files, one per example. You may want to do this if you;re making data to
                     train a parser.""")
 parser.add_argument(
     "--num-splits",
     dest="num_splits",
     type=int,
     default=0,
     help=
     "Number of splits to make of the data, to run as many processes (default 0)"
 )
 args = parser.parse_args()
 input_data = [
     wikitables_util.parse_example_line(example_line)
     for example_line in open(args.data_file)
 ]
 if args.num_splits == 0 or len(
         input_data) <= args.num_splits or not args.output_separate_files:
     search(args.table_directory, input_data, args.output_path,
            args.max_path_length, args.max_num_logical_forms,
            args.use_agenda, args.output_separate_files, args.conservative)
 else:
     chunk_size = math.ceil(len(input_data) / args.num_splits)
     start_index = 0
     for i in range(args.num_splits):
         if i == args.num_splits - 1:
             data_split = input_data[start_index:]
         else:
             data_split = input_data[start_index:start_index + chunk_size]
Exemplo n.º 15
0
    def _read(self, file_path: str):
        """Read a SEMPRE-style ``*.examples`` file and yield instances.

        If the offline logical forms directory contains a single ``*.tar.gz``,
        it is assumed to hold logical forms for all questions and is extracted
        to ``/tmp`` first.  Questions with no per-question search output are
        skipped unless ``self._keep_if_no_logical_forms``.
        """
        # Checking if there is a single tarball with all the logical forms. If so, untaring it
        # first.
        if self._offline_logical_forms_directory:
            tarball_with_all_lfs: str = None
            for filename in os.listdir(self._offline_logical_forms_directory):
                if filename.endswith(".tar.gz"):
                    tarball_with_all_lfs = os.path.join(
                        self._offline_logical_forms_directory, filename)
                    break
            if tarball_with_all_lfs is not None:
                logger.info(
                    f"Found a tarball in offline logical forms directory: {tarball_with_all_lfs}"
                )
                logger.info(
                    "Assuming it contains logical forms for all questions and un-taring it."
                )
                # If you're running this with beaker, the input directory will be read-only and we
                # cannot untar the files in the directory itself. So we will do so in /tmp, but that
                # means the new offline logical forms directory will be /tmp.
                self._offline_logical_forms_directory = "/tmp/"
                tarfile.open(tarball_with_all_lfs, mode='r:gz').extractall(
                    path=self._offline_logical_forms_directory)
        with open(file_path, "r") as data_file:
            num_missing_logical_forms = 0
            num_lines = 0
            num_instances = 0
            for line in data_file.readlines():
                line = line.strip("\n")
                if not line:
                    continue
                num_lines += 1
                parsed_info = wikitables_util.parse_example_line(line)
                question = parsed_info["question"]
                # We want the tagged file, but the ``*.examples`` files typically point to CSV.
                table_filename = os.path.join(
                    self._tables_directory,
                    parsed_info["table_filename"].replace("csv", "tagged"))
                if self._offline_logical_forms_directory:
                    logical_forms_filename = os.path.join(
                        self._offline_logical_forms_directory,
                        parsed_info["id"] + '.gz')
                    try:
                        # Context manager closes the gzip handle promptly
                        # (the original left it open).
                        with gzip.open(logical_forms_filename) as logical_forms_file:
                            logical_forms = [
                                logical_form_line.strip().decode('utf-8')
                                for logical_form_line in logical_forms_file
                            ]
                    except FileNotFoundError:
                        logger.debug(
                            f'Missing search output for instance {parsed_info["id"]}; skipping...'
                        )
                        logical_forms = None
                        num_missing_logical_forms += 1
                        if not self._keep_if_no_logical_forms:
                            continue
                else:
                    logical_forms = None

                # Close the table file instead of relying on GC.
                with open(table_filename) as table_file:
                    table_lines = [
                        table_line.split("\t")
                        for table_line in table_file.readlines()
                    ]
                instance = self.text_to_instance(
                    question=question,
                    table_lines=table_lines,
                    target_values=parsed_info["target_values"],
                    offline_search_output=logical_forms)
                if instance is not None:
                    num_instances += 1
                    yield instance

        if self._offline_logical_forms_directory:
            logger.info(
                f"Missing logical forms for {num_missing_logical_forms} out of {num_lines} instances"
            )
            logger.info(f"Kept {num_instances} instances")
Exemplo n.º 16
0
    def _read_examples_file(self, file_path: str):
        """Read a SEMPRE ``*.examples`` file and yield one instance per example line.

        For each example we look up the corresponding table (the ``.tsv`` version of
        the CSV the example points at) and, when ``self._dpd_output_directory`` is
        set, the gzipped DPD search output named ``<example id>.gz`` containing
        candidate logical forms.  Examples with missing DPD output are skipped
        unless ``self._keep_if_no_dpd`` is true.  Summary statistics are logged
        after the file is exhausted.

        Parameters
        ----------
        file_path : ``str``
            Path to the ``*.examples`` file to read.
        """
        num_dpd_missing = 0
        num_lines = 0
        num_instances = 0
        with open(file_path, "r") as data_file:
            # Iterate the file object lazily instead of readlines() — avoids
            # materializing the whole file in memory.
            for line in data_file:
                line = line.strip("\n")
                if not line:
                    continue
                num_lines += 1
                parsed_info = util.parse_example_line(line)
                question = parsed_info["question"]
                # We want the TSV file, but the ``*.examples`` files typically point to CSV.
                table_filename = os.path.join(
                    self._tables_directory,
                    parsed_info["table_filename"].replace(".csv", ".tsv"))
                if self._dpd_output_directory:
                    dpd_output_filename = os.path.join(
                        self._dpd_output_directory, parsed_info["id"] + '.gz')
                    try:
                        # Context manager closes the gzip handle; the original
                        # leaked one file descriptor per example.
                        with gzip.open(dpd_output_filename) as dpd_file:
                            if self._sort_dpd_logical_forms:
                                sempre_forms = [
                                    dpd_line.strip().decode('utf-8')
                                    for dpd_line in dpd_file
                                ]
                                # We'll sort by the number of open parens in the logical form, which
                                # tells you how many nodes there are in the syntax tree.
                                sempre_forms.sort(key=lambda x: x.count('('))
                                if self._max_dpd_tries:
                                    sempre_forms = sempre_forms[:self._max_dpd_tries]
                            else:
                                # Unsorted: stop reading as soon as we have enough forms.
                                sempre_forms = []
                                for dpd_line in dpd_file:
                                    sempre_forms.append(
                                        dpd_line.strip().decode('utf-8'))
                                    if self._max_dpd_tries and len(
                                            sempre_forms) >= self._max_dpd_tries:
                                        break
                    except FileNotFoundError:
                        logger.debug(
                            f'Missing DPD output for instance {parsed_info["id"]}; skipping...'
                        )
                        sempre_forms = None
                        num_dpd_missing += 1
                        if not self._keep_if_no_dpd:
                            continue
                else:
                    sempre_forms = None

                # Close the table file deterministically (original relied on GC).
                with open(table_filename) as table_file:
                    table_lines = table_file.readlines()
                instance = self.text_to_instance(question=question,
                                                 table_lines=table_lines,
                                                 example_lisp_string=line,
                                                 dpd_output=sempre_forms)
                if instance is not None:
                    num_instances += 1
                    yield instance

        if self._dpd_output_directory:
            logger.info(
                f"Missing DPD info for {num_dpd_missing} out of {num_lines} instances"
            )
            num_with_dpd = num_lines - num_dpd_missing
            # Instances with DPD output that text_to_instance rejected
            # (e.g. unparseable logical forms).
            num_bad_lfs = num_with_dpd - num_instances
            logger.info(
                f"DPD output was bad for {num_bad_lfs} out of {num_with_dpd} instances"
            )
            if num_bad_lfs > 0:
                logger.info(
                    "Re-run with log level set to debug to see the un-parseable logical forms"
                )
            logger.info(f"Kept {num_instances} instances")