Example #1
    def testVocabularyFromFile(self):
        # write vocabs to file and create new ones from those files
        self.word_vocab.to_file(self.temp_file_word)
        self.char_vocab.to_file(self.temp_file_char)

        word_vocab2 = Vocabulary.from_file(self.temp_file_word)
        char_vocab2 = Vocabulary.from_file(self.temp_file_char)
        self.assertEqual(self.word_vocab.itos, word_vocab2.itos)
        self.assertEqual(self.char_vocab.itos, char_vocab2.itos)
        os.remove(self.temp_file_char)
        os.remove(self.temp_file_word)
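
The test above verifies that writing a vocabulary to disk and reading it back preserves the index-to-token mapping (itos). A minimal stand-alone sketch of that round trip, assuming only the API used in the test plus a constructor that accepts a token list (the import path and constructor signature are assumptions, not taken from this project):

import os
import tempfile

from joeynmt.vocabulary import Vocabulary  # assumed import path

# build a tiny vocabulary; the constructor signature is an assumption
vocab = Vocabulary(tokens=["the", "cat", "sat"])

# write it out and read it back, as the test does
tmp_path = os.path.join(tempfile.gettempdir(), "vocab.txt")
vocab.to_file(tmp_path)
restored = Vocabulary.from_file(tmp_path)

# the round trip must preserve the index-to-token mapping
assert vocab.itos == restored.itos
os.remove(tmp_path)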
Example #2
def translate(cfg_file: str, ckpt: str, output_path: str = None) -> None:
    """
    Interactive translation function.
    Loads the model from a checkpoint and either translates input piped via
    stdin or, when run from a terminal, prompts for sentences to translate
    interactively.
    The input has to be pre-processed according to the data that the model
    was trained on, i.e. tokenized or split into subwords.
    Translations are printed to stdout, or to a file if an output path is
    given.

    :param cfg_file: path to configuration file
    :param ckpt: path to checkpoint to load
    :param output_path: path to output file; if None, print to stdout
    """
    def _load_line_as_data(line):
        """ Create a dataset from one line via a temporary file. """
        # write src input to temporary file
        tmp_name = "tmp"
        tmp_suffix = ".src"
        tmp_filename = tmp_name + tmp_suffix
        with open(tmp_filename, "w") as tmp_file:
            tmp_file.write("{}\n".format(line))

        test_data = MonoDataset(path=tmp_name, ext=tmp_suffix, field=src_field)

        # remove temporary file
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)

        return test_data

    def _translate_data(test_data):
        """ Translates given dataset, using parameters from outer scope. """
        hypotheses = validate_on_data(model,
                                      data=test_data,
                                      batch_size=batch_size,
                                      trg_level=level,
                                      max_output_length=max_output_length,
                                      eval_metrics=[],
                                      use_cuda=use_cuda,
                                      loss_function=None,
                                      beam_size=beam_size,
                                      beam_alpha=beam_alpha)[2]
        return hypotheses

    cfg = load_config(cfg_file)

    # when checkpoint is not specified, take latest from model dir
    if ckpt is None:
        model_dir = cfg["training"]["model_dir"]
        ckpt = get_latest_checkpoint(model_dir)

    batch_size = cfg["training"].get("batch_size", 1)
    use_cuda = cfg["training"].get("use_cuda", False)
    level = cfg["data"]["level"]
    max_output_length = cfg["training"].get("max_output_length", None)

    # read vocabs
    src_vocab_file = cfg["data"].get(
        "src_vocab", cfg["training"]["model_dir"] + "/src_vocab.txt")
    trg_vocab_file = cfg["data"].get(
        "trg_vocab", cfg["training"]["model_dir"] + "/trg_vocab.txt")
    src_vocab = Vocabulary.from_file(src_vocab_file)
    trg_vocab = Vocabulary.from_file(trg_vocab_file)

    data_cfg = cfg["data"]
    level = data_cfg["level"]
    lowercase = data_cfg["lowercase"]

    tok_fun = list if level == "char" else str.split

    src_field = Field(init_token=None,
                      eos_token=EOS_TOKEN,
                      pad_token=PAD_TOKEN,
                      tokenize=tok_fun,
                      batch_first=True,
                      lower=lowercase,
                      unk_token=UNK_TOKEN,
                      include_lengths=True)
    src_field.vocab = src_vocab

    # load model state from disk
    model_checkpoint = load_checkpoint(ckpt, use_cuda=use_cuda)

    # build model and load parameters into it
    model = build_model(cfg["model"],
                        vocabs={
                            "src": src_vocab,
                            "trg": trg_vocab
                        })
    model.load_state_dict(model_checkpoint["model_state"])

    if use_cuda:
        model.cuda()

    # whether to use beam search for decoding, 0: greedy decoding
    if "testing" in cfg.keys():
        beam_size = cfg["testing"].get("beam_size", 0)
        beam_alpha = cfg["testing"].get("alpha", 0)
    else:
        beam_size = 0
        beam_alpha = 0
    if beam_alpha < 0:
        raise ConfigurationError("alpha for length penalty should be >= 0")

    if not sys.stdin.isatty():
        # file given
        test_data = MonoDataset(path=sys.stdin, ext="", field=src_field)
        hypotheses = _translate_data(test_data)

        if output_path is not None:
            with open(output_path, mode="w", encoding="utf-8") as out_file:
                for hyp in hypotheses:
                    out_file.write(hyp + "\n")
            print("Translations saved to: {}".format(output_path))
        else:
            for hyp in hypotheses:
                print(hyp)

    else:
        # enter interactive mode
        batch_size = 1
        while True:
            try:
                src_input = input("\nPlease enter a source sentence "
                                  "(pre-processed): \n")
                if not src_input.strip():
                    break

                # every line has to be made into dataset
                test_data = _load_line_as_data(line=src_input)

                hypotheses = _translate_data(test_data)
                print("JoeyNMT: {}".format(hypotheses[0]))

            except (KeyboardInterrupt, EOFError):
                print("\nBye.")
                break
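
Because the function dispatches on sys.stdin.isatty(), the same entry point covers both batch and interactive use. A usage sketch, with the configuration path and output file name invented for illustration:

# With pre-tokenized source sentences piped into stdin, hypotheses are
# written to hyps.txt; ckpt=None falls back to the latest checkpoint
# in the configured model directory.
translate(cfg_file="configs/model.yaml", ckpt=None, output_path="hyps.txt")

# Started from a terminal instead, the same call prompts for a source
# sentence, prints the hypothesis, and stops on an empty line or Ctrl-D.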
Example #3
File: data.py Project: deep-spin/S7
def load_data(data_cfg: dict):
    """
    Load train, dev and optionally test data as specified in configuration.
    Vocabularies are created from the training set with a limit of `voc_limit`
    tokens and a minimum token frequency of `voc_min_freq`
    (specified in the configuration dictionary).

    The training data is filtered to include sentences up to `max_sent_length`
    on source and target side.

    :param data_cfg: configuration dictionary for data
        ("data" part of configuation file)
    :return:
        - train_data: training dataset
        - dev_data: development dataset
        - test_data: test dataset if given, otherwise None
        - vocabs: dictionary from src and trg (and possibly other fields) to
            their corresponding vocab objects
    """
    data_format = data_cfg.get("format", "bitext")
    formats = {"bitext", "tsv"}
    assert data_format in formats

    train_path = data_cfg["train"]
    dev_path = data_cfg["dev"]
    test_path = data_cfg.get("test", None)

    lowercase = data_cfg["lowercase"]
    max_sent_length = data_cfg["max_sent_length"]
    default_level = data_cfg.get("level", "word")
    voc_limit = data_cfg.get("voc_limit", sys.maxsize)
    voc_min_freq = data_cfg.get("voc_min_freq", 1)

    tokenizers = {
        "word": str.split,
        "bpe": str.split,
        "char": list,
        "char-decompose": decompose_tokenize,
        "tag": partial(str.split, sep=";")
    }

    # names of all columns/fields present in the data, as configured
    all_fields = data_cfg.get("columns", ["src", "trg"])
    label_fields = data_cfg.get("label_fields", [])
    assert all(label in all_fields for label in label_fields)
    sequential_fields = [
        field for field in all_fields if field not in label_fields
    ]
    src_fields = [f_name for f_name in all_fields if f_name != "trg"]
    trg_fields = ["trg"]

    suffixes = {
        f_name: data_cfg.get(f_name, "")
        for f_name in sequential_fields
    }

    seq_field_cls = partial(Field,
                            eos_token=EOS_TOKEN,
                            pad_token=PAD_TOKEN,
                            unk_token=UNK_TOKEN,
                            batch_first=True,
                            lower=lowercase,
                            include_lengths=True)

    fields = dict()
    # build a tokenizing field for each sequential column; only "trg" gets BOS

    for f_name in sequential_fields:
        bos = BOS_TOKEN if f_name == "trg" else None
        current_level = data_cfg.get(f_name + "_level", default_level)
        assert current_level in tokenizers, "Invalid tokenization level"
        tok_fun = tokenizers[current_level]

        fields[f_name] = seq_field_cls(init_token=bos, tokenize=tok_fun)

    for f_name in label_fields:
        fields[f_name] = Field(sequential=False)

    filter_ex = partial(filter_example,
                        max_sent_length=max_sent_length,
                        keys=tuple(fields.keys()))

    if data_format == "bitext":
        dataset_cls = partial(TranslationDataset,
                              exts=("." + suffixes["src"],
                                    "." + suffixes["trg"]),
                              fields=(fields["src"], fields["trg"]))
    else:
        dataset_cls = partial(TSVDataset,
                              fields=fields,
                              columns=all_fields,
                              label_columns=label_fields)

    train_data = dataset_cls(path=train_path, filter_pred=filter_ex)

    vocabs = dict()

    # every field needs a vocabulary, but not necessarily a *distinct* one:
    # with share_src_vocabs, all source fields share a single Vocabulary
    share_src_vocabs = data_cfg.get("share_src_vocabs", False)
    if share_src_vocabs:
        field_groups = [src_fields, trg_fields]
    else:
        field_groups = [[f] for f in sequential_fields]
    for f_group in field_groups:
        if len(f_group) == 1:
            f_name = f_group[0]
            max_size = data_cfg.get("{}_voc_limit".format(f_name), voc_limit)
            min_freq = data_cfg.get("{}_voc_min_freq".format(f_name),
                                    voc_min_freq)
            vocab_file = data_cfg.get("{}_vocab".format(f_name), None)
        else:
            # multiple fields sharing a vocabulary
            max_size = voc_limit
            min_freq = voc_min_freq
            vocab_file = None

        if vocab_file is not None:
            f_vocab = Vocabulary.from_file(vocab_file)
        else:
            f_vocab = Vocabulary.from_dataset(train_data, f_group, max_size,
                                              min_freq)
        for f_name in f_group:
            vocabs[f_name] = f_vocab

    label_field_groups = [[lf] for lf in label_fields]
    for f_group in label_field_groups:
        if len(f_group) == 1:
            f_name = f_group[0]
            max_size = data_cfg.get("{}_voc_limit".format(f_name), voc_limit)
            min_freq = data_cfg.get("{}_voc_min_freq".format(f_name),
                                    voc_min_freq)
            vocab_file = data_cfg.get("{}_vocab".format(f_name), None)
        else:
            # multiple fields sharing a vocabulary
            max_size = voc_limit
            min_freq = voc_min_freq
            vocab_file = None

        if vocab_file is not None:
            f_vocab = Vocabulary.from_file(vocab_file)
        else:
            f_vocab = Vocabulary.from_dataset(train_data,
                                              f_group,
                                              max_size,
                                              min_freq,
                                              sequential=False)
        for f_name in f_group:
            vocabs[f_name] = f_vocab

    dev_data = dataset_cls(path=dev_path)

    if test_path is not None:
        trg_suffix = suffixes["trg"]
        if data_format != "bitext" or isfile(test_path + "." + trg_suffix):
            test_dataset_cls = dataset_cls
        else:
            test_dataset_cls = partial(MonoDataset,
                                       ext="." + suffixes["src"],
                                       field=fields["src"])
        test_data = test_dataset_cls(path=test_path)
    else:
        test_data = None

    for field_name in fields:
        fields[field_name].vocab = vocabs[field_name]

    ret = {
        "train_data": train_data,
        "dev_data": dev_data,
        "test_data": test_data,
        "vocabs": vocabs
    }
    return ret
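
The keys read above correspond directly to the "data" section of the configuration file. A minimal sketch of such a section for the bitext format, passed in as a dictionary (all file paths and suffixes below are invented for illustration):

data_cfg = {
    "format": "bitext",        # or "tsv" for a single multi-column file
    "train": "data/train",     # reads data/train.de and data/train.en
    "dev": "data/dev",
    "test": "data/test",       # optional; without data/test.en, MonoDataset is used
    "src": "de",               # file suffix for the "src" field
    "trg": "en",               # file suffix for the "trg" field
    "level": "bpe",            # word | bpe | char | char-decompose | tag
    "lowercase": False,
    "max_sent_length": 100,
    "voc_limit": 32000,
    "voc_min_freq": 1,
}

loaded = load_data(data_cfg)
train_data = loaded["train_data"]
vocabs = loaded["vocabs"]      # e.g. vocabs["src"], vocabs["trg"]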