Example #1
def load_data(data_cfg: dict, datasets: list = None) \
        -> Tuple[Dataset, Dataset, Optional[Dataset], Vocabulary, Vocabulary]:
    """
    Load train, dev and optionally test data as specified in configuration.
    Vocabularies are created from the training set with a limit of `voc_limit`
    tokens and a minimum token frequency of `voc_min_freq`
    (specified in the configuration dictionary).

    The training data is filtered to include sentences up to `max_sent_length`
    on source and target side.

    If you set ``random_train_subset``, a random selection of this size is used
    from the training set instead of the full training set.

    :param data_cfg: configuration dictionary for data
        ("data" part of configuation file)
    :param datasets: list of dataset names to load
    :return:
        - train_data: training dataset
        - dev_data: development dataset
        - test_data: test dataset if given, otherwise None
        - src_vocab: source vocabulary extracted from training data
        - trg_vocab: target vocabulary extracted from training data
    """
    if datasets is None:
        datasets = ["train", "dev", "test"]

    # load data from files
    src_lang = data_cfg["src"]
    trg_lang = data_cfg["trg"]
    train_path = data_cfg.get("train", None)
    dev_path = data_cfg.get("dev", None)
    test_path = data_cfg.get("test", None)

    if train_path is None and dev_path is None and test_path is None:
        raise ValueError('Please specify at least one data source path.')

    level = data_cfg["level"]
    lowercase = data_cfg["lowercase"]
    max_sent_length = data_cfg["max_sent_length"]

    tok_fun = lambda s: list(s) if level == "char" else s.split()

    src_field = Field(init_token=None,
                      eos_token=EOS_TOKEN,
                      pad_token=PAD_TOKEN,
                      tokenize=tok_fun,
                      batch_first=True,
                      lower=lowercase,
                      unk_token=UNK_TOKEN,
                      include_lengths=True)

    trg_field = Field(init_token=BOS_TOKEN,
                      eos_token=EOS_TOKEN,
                      pad_token=PAD_TOKEN,
                      unk_token=UNK_TOKEN,
                      tokenize=tok_fun,
                      batch_first=True,
                      lower=lowercase,
                      include_lengths=True)

    train_data = None
    if "train" in datasets and train_path is not None:
        logger.info("Loading training data...")
        train_data = TranslationDataset(
            path=train_path,
            exts=("." + src_lang, "." + trg_lang),
            fields=(src_field, trg_field),
            filter_pred=lambda x: len(vars(x)['src']) <= max_sent_length and
            len(vars(x)['trg']) <= max_sent_length)

        random_train_subset = data_cfg.get("random_train_subset", -1)
        if random_train_subset > -1:
            # select this many training examples randomly and discard the rest
            keep_ratio = random_train_subset / len(train_data)
            keep, _ = train_data.split(
                split_ratio=[keep_ratio, 1 - keep_ratio],
                random_state=random.getstate())
            train_data = keep

    src_max_size = data_cfg.get("src_voc_limit", sys.maxsize)
    src_min_freq = data_cfg.get("src_voc_min_freq", 1)
    trg_max_size = data_cfg.get("trg_voc_limit", sys.maxsize)
    trg_min_freq = data_cfg.get("trg_voc_min_freq", 1)

    src_vocab_file = data_cfg.get("src_vocab", None)
    trg_vocab_file = data_cfg.get("trg_vocab", None)

    assert (train_data is not None) or (src_vocab_file is not None)
    assert (train_data is not None) or (trg_vocab_file is not None)

    logger.info("Building vocabulary...")
    src_vocab = build_vocab(field="src",
                            min_freq=src_min_freq,
                            max_size=src_max_size,
                            dataset=train_data,
                            vocab_file=src_vocab_file)
    trg_vocab = build_vocab(field="trg",
                            min_freq=trg_min_freq,
                            max_size=trg_max_size,
                            dataset=train_data,
                            vocab_file=trg_vocab_file)

    dev_data = None
    if "dev" in datasets and dev_path is not None:
        logger.info("Loading dev data...")
        dev_data = TranslationDataset(path=dev_path,
                                      exts=("." + src_lang, "." + trg_lang),
                                      fields=(src_field, trg_field))

    test_data = None
    if "test" in datasets and test_path is not None:
        logger.info("Loading test data...")
        # check if target exists
        if os.path.isfile(test_path + "." + trg_lang):
            test_data = TranslationDataset(path=test_path,
                                           exts=("." + src_lang,
                                                 "." + trg_lang),
                                           fields=(src_field, trg_field))
        else:
            # no target is given -> create dataset from src only
            test_data = MonoDataset(path=test_path,
                                    ext="." + src_lang,
                                    field=src_field)
    src_field.vocab = src_vocab
    trg_field.vocab = trg_vocab
    logger.info("Data loaded.")
    return train_data, dev_data, test_data, src_vocab, trg_vocab
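
A minimal usage sketch for load_data. The data_cfg below is hypothetical; only its keys mirror what the function actually reads, and the file paths are placeholders (e.g. "data/train" expects data/train.de and data/train.en).

# Hypothetical "data" section of a config file, given as a plain dict.
data_cfg = {
    "src": "de",
    "trg": "en",
    "train": "data/train",
    "dev": "data/dev",
    "test": "data/test",
    "level": "word",            # "word", "bpe" or "char"
    "lowercase": True,
    "max_sent_length": 50,
    "src_voc_limit": 32000,
    "trg_voc_limit": 32000,
}

# Load only train and dev; test_data comes back as None.
train_data, dev_data, test_data, src_vocab, trg_vocab = load_data(
    data_cfg, datasets=["train", "dev"])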
Example #2
def translate(cfg_file: str,
              ckpt: str,
              output_path: str = None,
              batch_class: Batch = Batch,
              n_best: int = 1) -> None:
    """
    Interactive translation function.
    Loads a model from the checkpoint and either translates input read from
    stdin or, when run in a terminal, asks interactively for sentences to
    translate.
    The input has to be pre-processed according to the data the model was
    trained on, i.e. tokenized or split into subwords.
    Translations are written to ``output_path`` if given, otherwise printed
    to stdout.

    :param cfg_file: path to configuration file
    :param ckpt: path to checkpoint to load
    :param output_path: path to output file
    :param batch_class: class type of batch
    :param n_best: number of candidates to display
    """
    def _load_line_as_data(line):
        """ Create a dataset from one line via a temporary file. """
        # write src input to temporary file
        tmp_name = "tmp"
        tmp_suffix = ".src"
        tmp_filename = tmp_name + tmp_suffix
        with open(tmp_filename, "w") as tmp_file:
            tmp_file.write("{}\n".format(line))

        test_data = MonoDataset(path=tmp_name, ext=tmp_suffix, field=src_field)

        # remove temporary file
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)

        return test_data

    def _translate_data(test_data):
        """ Translates given dataset, using parameters from outer scope. """
        # pylint: disable=unused-variable
        score, loss, ppl, sources, sources_raw, references, hypotheses, \
        hypotheses_raw, attention_scores = validate_on_data(
            model, data=test_data, batch_size=batch_size,
            batch_class=batch_class, batch_type=batch_type, level=level,
            max_output_length=max_output_length, eval_metric="",
            use_cuda=use_cuda, compute_loss=False, beam_size=beam_size,
            beam_alpha=beam_alpha, postprocess=postprocess,
            bpe_type=bpe_type, sacrebleu=sacrebleu, n_gpu=n_gpu, n_best=n_best)
        return hypotheses

    cfg = load_config(cfg_file)
    model_dir = cfg["training"]["model_dir"]

    _ = make_logger(model_dir, mode="translate")  # version string returned, unused here

    # when checkpoint is not specified, take latest from model dir
    if ckpt is None:
        ckpt = get_latest_checkpoint(model_dir)

    # read vocabs
    src_vocab_file = cfg["data"].get("src_vocab", model_dir + "/src_vocab.txt")
    trg_vocab_file = cfg["data"].get("trg_vocab", model_dir + "/trg_vocab.txt")
    src_vocab = Vocabulary(file=src_vocab_file)
    trg_vocab = Vocabulary(file=trg_vocab_file)

    data_cfg = cfg["data"]
    level = data_cfg["level"]
    lowercase = data_cfg["lowercase"]

    tok_fun = lambda s: list(s) if level == "char" else s.split()

    src_field = Field(init_token=None,
                      eos_token=EOS_TOKEN,
                      pad_token=PAD_TOKEN,
                      tokenize=tok_fun,
                      batch_first=True,
                      lower=lowercase,
                      unk_token=UNK_TOKEN,
                      include_lengths=True)
    src_field.vocab = src_vocab

    # parse test args
    batch_size, batch_type, use_cuda, device, n_gpu, level, _, \
        max_output_length, beam_size, beam_alpha, postprocess, \
        bpe_type, sacrebleu, _, _ = parse_test_args(cfg, mode="translate")

    # load model state from disk
    model_checkpoint = load_checkpoint(ckpt, use_cuda=use_cuda)

    # build model and load parameters into it
    model = build_model(cfg["model"], src_vocab=src_vocab, trg_vocab=trg_vocab)
    model.load_state_dict(model_checkpoint["model_state"])

    if use_cuda:
        model.to(device)

    if not sys.stdin.isatty():
        # input is piped via stdin (not a terminal)
        test_data = MonoDataset(path=sys.stdin, ext="", field=src_field)
        all_hypotheses = _translate_data(test_data)

        if output_path is not None:
            # write to output file if given

            def write_to_file(output_path_set, hypotheses):
                with open(output_path_set, mode="w", encoding="utf-8") \
                        as out_file:
                    for hyp in hypotheses:
                        out_file.write(hyp + "\n")
                logger.info("Translations saved to: %s.", output_path_set)

            if n_best > 1:
                for n in range(n_best):
                    file_name, file_extension = os.path.splitext(output_path)
                    write_to_file(
                        "{}-{}{}".format(
                            file_name, n,
                            file_extension if file_extension else ""), [
                                all_hypotheses[i]
                                for i in range(n, len(all_hypotheses), n_best)
                            ])
            else:
                write_to_file("{}".format(output_path), all_hypotheses)
        else:
            # print to stdout
            for hyp in all_hypotheses:
                print(hyp)

    else:
        # enter interactive mode
        batch_size = 1
        batch_type = "sentence"
        while True:
            try:
                src_input = input("\nPlease enter a source sentence "
                                  "(pre-processed): \n")
                if not src_input.strip():
                    break

                # every line has to be made into dataset
                test_data = _load_line_as_data(line=src_input)
                hypotheses = _translate_data(test_data)

                print("JoeyNMT: Hypotheses ranked by score")
                for i, hyp in enumerate(hypotheses):
                    print("JoeyNMT #{}: {}".format(i + 1, hyp))

            except (KeyboardInterrupt, EOFError):
                print("\nBye.")
                break
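
A hedged sketch of calling translate() non-interactively; the config and output paths below are placeholders. Since input is read from stdin when it is not a TTY, the pre-processed source file would be piped in by the caller.

# Hypothetical paths; the config must be the one the model was trained with.
translate(cfg_file="configs/transformer_wmt.yaml",
          ckpt=None,                    # falls back to the latest checkpoint in model_dir
          output_path="hyps/test.txt",  # with n_best=2: hyps/test-0.txt and hyps/test-1.txt
          n_best=2)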
Example #3
from pathlib import Path

from torchtext.legacy.data import Field, RawField
import numpy as np

from utils.entities_list import Entities_list
from utils.class_utils import keys_vocab_cls, iob_labels_vocab_cls, entities_vocab_cls

MAX_BOXES_NUM = 70  # limit the max number of boxes per document
MAX_TRANSCRIPT_LEN_GLOBAL = 64  # limit the max text length per box
MAX_WIDTH = 1024
# text string label converter
TextSegmentsField = Field(sequential=True,
                          use_vocab=True,
                          include_lengths=True,
                          batch_first=True)
TextSegmentsField.vocab = keys_vocab_cls
# iob string label converter
IOBTagsField = Field(sequential=True,
                     is_target=True,
                     use_vocab=True,
                     batch_first=True)
IOBTagsField.vocab = iob_labels_vocab_cls
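
# Illustrative sketch (not part of the original module): with the vocabulary
# assigned above, TextSegmentsField.process pads and numericalizes a batch of
# character-tokenized transcripts; because include_lengths=True it also
# returns the true lengths. The transcripts below are made up.
_example_transcripts = [list("INVOICE NO 123"), list("TOTAL 42.00")]
_text_tensor, _text_lengths = TextSegmentsField.process(_example_transcripts)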


class Document:
    def __init__(
            self,
            boxes_and_transcripts_file: Path,
            image_file: Path,
            #resized_image_size: Tuple[int, int] = (480, 960),
            segment_height: int = 64,