Example #1
def load_data(
    data_cfg: dict
) -> Tuple[Dataset, Dataset, Optional[Dataset], Vocabulary, Vocabulary]:
    """
    Load train, dev and optionally test data as specified in configuration.
    Vocabularies are created from the training set with a limit of `voc_limit`
    tokens and a minimum token frequency of `voc_min_freq`
    (specified in the configuration dictionary).

    The training data is filtered to include sentences up to `max_sent_length`
    on source and target side.

    If you set ``random_train_subset``, a random selection of this size is used
    from the training set instead of the full training set.

    :param data_cfg: configuration dictionary for data
        ("data" part of configuation file)
    :return:
        - train_data: training dataset
        - dev_data: development dataset
        - test_data: test dataset if given, otherwise None
        - src_vocab: source vocabulary extracted from training data
        - trg_vocab: target vocabulary extracted from training data
    """
    # load data from files
    src_lang = data_cfg["src"]
    trg_lang = data_cfg["trg"]
    train_path = data_cfg["train"]
    dev_path = data_cfg["dev"]
    test_path = data_cfg.get("test", None)
    level = data_cfg["level"]
    lowercase = data_cfg["lowercase"]
    max_sent_length = data_cfg["max_sent_length"]

    tok_fun = lambda s: list(s) if level == "char" else s.split()

    src_field = data.Field(init_token=None,
                           eos_token=EOS_TOKEN,
                           pad_token=PAD_TOKEN,
                           tokenize=tok_fun,
                           batch_first=True,
                           lower=lowercase,
                           unk_token=UNK_TOKEN,
                           include_lengths=True)

    trg_field = data.Field(init_token=BOS_TOKEN,
                           eos_token=EOS_TOKEN,
                           pad_token=PAD_TOKEN,
                           tokenize=tok_fun,
                           unk_token=UNK_TOKEN,
                           batch_first=True,
                           lower=lowercase,
                           include_lengths=True)

    train_data = TranslationDataset(
        path=train_path,
        exts=("." + src_lang, "." + trg_lang),
        fields=(src_field, trg_field),
        filter_pred=lambda x: len(vars(x)['src']) <= max_sent_length and len(
            vars(x)['trg']) <= max_sent_length)

    src_max_size = data_cfg.get("src_voc_limit", sys.maxsize)
    src_min_freq = data_cfg.get("src_voc_min_freq", 1)
    trg_max_size = data_cfg.get("trg_voc_limit", sys.maxsize)
    trg_min_freq = data_cfg.get("trg_voc_min_freq", 1)

    src_vocab_file = data_cfg.get("src_vocab", None)
    trg_vocab_file = data_cfg.get("trg_vocab", None)

    src_vocab = build_vocab(field="src",
                            min_freq=src_min_freq,
                            max_size=src_max_size,
                            dataset=train_data,
                            vocab_file=src_vocab_file)
    trg_vocab = build_vocab(field="trg",
                            min_freq=trg_min_freq,
                            max_size=trg_max_size,
                            dataset=train_data,
                            vocab_file=trg_vocab_file)

    random_train_subset = data_cfg.get("random_train_subset", -1)
    if random_train_subset > -1:
        # select this many training examples randomly and discard the rest
        keep_ratio = random_train_subset / len(train_data)
        keep, _ = train_data.split(split_ratio=[keep_ratio, 1 - keep_ratio],
                                   random_state=random.getstate())
        train_data = keep

    dev_data = TranslationDataset(path=dev_path,
                                  exts=("." + src_lang, "." + trg_lang),
                                  fields=(src_field, trg_field))
    test_data = None
    if test_path is not None:
        # check if target exists
        if os.path.isfile(test_path + "." + trg_lang):
            test_data = TranslationDataset(path=test_path,
                                           exts=("." + src_lang,
                                                 "." + trg_lang),
                                           fields=(src_field, trg_field))
        else:
            # no target is given -> create dataset from src only
            test_data = MonoDataset(path=test_path,
                                    ext="." + src_lang,
                                    field=src_field)
    src_field.vocab = src_vocab
    trg_field.vocab = trg_vocab
    return train_data, dev_data, test_data, src_vocab, trg_vocab
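
For reference, a minimal sketch of how this loader might be called; the keys mirror those read by load_data above, while the file paths and sizes are placeholders:

# Hypothetical configuration; "train"/"dev"/"test" are path prefixes, so e.g.
# "data/train" expects data/train.en and data/train.de to exist.
data_cfg = {
    "src": "en",
    "trg": "de",
    "train": "data/train",
    "dev": "data/dev",
    "test": "data/test",          # optional; omit to get test_data = None
    "level": "word",              # or "char"
    "lowercase": True,
    "max_sent_length": 50,
    "src_voc_limit": 32000,       # optional, defaults to sys.maxsize
    "trg_voc_limit": 32000,
    "src_voc_min_freq": 1,
    "trg_voc_min_freq": 1,
}

train_data, dev_data, test_data, src_vocab, trg_vocab = load_data(data_cfg)
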
Example #2

import torch
from torchtext.data import Field, BucketIterator      # moved to torchtext.legacy.data in torchtext >= 0.9
from torchtext.datasets import TranslationDataset     # moved to torchtext.legacy.datasets in torchtext >= 0.9


# `preprocess` and `tokenize` come from the example's own (not shown) text-cleaning helpers.
def tokenize_word(text):
    text = preprocess(text)
    return tokenize(text)


SRC = Field(tokenize=tokenize_word, lower=True, batch_first=True)

TRG = Field(tokenize=tokenize_word,
            init_token='<sos>',
            eos_token='<eos>',
            lower=True,
            batch_first=True)

fields = [('src', SRC), ('trg', TRG)]

# TranslationDataset appends each extension to the path, so this reads 'lang.en' and 'lang.de'
ds = TranslationDataset('lang.', ('en', 'de'), fields)

train_ds, test_ds = ds.split(0.9)

SRC.build_vocab(ds)
TRG.build_vocab(ds)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 32

train_iter = BucketIterator(train_ds, batch_size=batch_size, device=device)

# train=False keeps evaluation batches deterministic (no shuffling)
test_iter = BucketIterator(test_ds, batch_size=batch_size, train=False, device=device)
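
A short usage sketch (not part of the original example), showing how batches from these iterators are typically consumed; with batch_first=True each attribute is a tensor of shape (batch_size, seq_len):

# Usage sketch: iterate over one epoch of training batches.
for batch in train_iter:
    src = batch.src  # LongTensor of token indices, (batch_size, src_len)
    trg = batch.trg  # LongTensor of token indices, (batch_size, trg_len)
    # the forward pass of a model would go here, e.g. output = model(src, trg)
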
Example #3
# NOTE: the beginning of this example was cut off; the imports below are restored so the
# rest of the snippet reads coherently, but only the tail of the `tgt` field definition
# survives. `tokenizer_fr` (and presumably a matching English tokenizer, the BOS/EOS
# constants and the `src` field) were defined in the missing part.
import torch
import torch.nn as nn
from torchtext.data import Field, BucketIterator, interleave_keys
from torchtext.datasets import TranslationDataset

tgt = Field(tokenize=tokenizer_fr,
            lower=True,
            init_token=BOS,
            eos_token=EOS,
            batch_first=True)

# prefix_f = 'data/escape.en-de.tok.50k'
prefix_f = 'data/data'
parallel_dataset = TranslationDataset(path=prefix_f,
                                      exts=('.en', '.fr'),
                                      fields=[('src', src), ('tgt', tgt)])

src.build_vocab(parallel_dataset, max_size=15000)
tgt.build_vocab(parallel_dataset, max_size=15000)

train, valid = parallel_dataset.split(split_ratio=0.97)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

BATCH_SIZE = 24

train_iterator, valid_iterator = BucketIterator.splits(
    (train, valid),
    batch_size=BATCH_SIZE,
    sort_key=lambda x: interleave_keys(len(x.src), len(x.tgt)),
    device=device)


class Encoder(nn.Module):
    def __init__(self, hidden_dim: int, src_ntoken: int, dropout: float):
        super().__init__()
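
The original example is cut off at this point. Purely as an illustration (an assumption, not the original author's code), a minimal sketch of how such an encoder is commonly completed with an embedding layer followed by a GRU:

# Illustrative sketch only; layer choices and hyperparameters are assumptions.
import torch.nn as nn


class SketchEncoder(nn.Module):
    def __init__(self, hidden_dim: int, src_ntoken: int, dropout: float):
        super().__init__()
        self.embedding = nn.Embedding(src_ntoken, hidden_dim)
        self.rnn = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        # src: (batch_size, src_len) token indices
        embedded = self.dropout(self.embedding(src))
        outputs, hidden = self.rnn(embedded)  # outputs: (batch_size, src_len, hidden_dim)
        return outputs, hidden
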
Example #4
def load_data(
    data_cfg: dict
) -> Tuple[Dataset, Dataset, Optional[Dataset], Vocabulary, Vocabulary]:
    """
    Load train, dev and optionally test data as specified in configuration.
    Vocabularies are created from the training set with a limit of `voc_limit`
    tokens and a minimum token frequency of `voc_min_freq`
    (specified in the configuration dictionary).

    The training data is filtered to include sentences up to `max_sent_length`
    on source and target side.

    If you set ``random_train_subset``, a random selection of this size is used
    from the training set instead of the full training set.

    :param data_cfg: configuration dictionary for data
        ("data" part of configuation file)
    :return:
        - train_data: training dataset
        - dev_data: development dataset
        - test_data: test dataset if given, otherwise None
        - src_vocab: source vocabulary extracted from training data
        - trg_vocab: target vocabulary extracted from training data
        - train_kb: TranslationDataset built from the train KB
        - dev_kb: TranslationDataset built from the dev KB
        - test_kb: TranslationDataset built from the test KB
        - train_kb_lookup: list of KB association lookup indices from train data to KBs
        - dev_kb_lookup: list of KB association lookup indices from dev data to KBs
        - test_kb_lookup: list of KB association lookup indices from test data to KBs
        - train_kb_lengths: list of train KB lengths
        - dev_kb_lengths: list of dev KB lengths
        - test_kb_lengths: list of test KB lengths
        - train_kb_truvals: MonoDataset of true values for the train KB
        - dev_kb_truvals: MonoDataset of true values for the dev KB
        - test_kb_truvals: MonoDataset of true values for the test KB
        - trv_vocab: vocabulary over true-value tokens
        - Canonizer: canonization callable class (KB task only, otherwise None)
        - dev_data_canon: canonized dev data (for loss reporting)
        - test_data_canon: canonized test data (for loss reporting)
    """
    # load data from files
    src_lang = data_cfg["src"]
    trg_lang = data_cfg["trg"]

    train_path = data_cfg["train"]
    dev_path = data_cfg["dev"]
    test_path = data_cfg.get("test", None)
    level = data_cfg["level"]
    lowercase = data_cfg["lowercase"]
    max_sent_length = data_cfg.get("max_sent_length", sys.maxsize * .1)

    #kb stuff
    kb_task = bool(data_cfg.get("kb_task", False))

    if kb_task:
        kb_src = data_cfg.get("kb_src", "kbk")
        kb_trg = data_cfg.get("kb_trg", "kbv")
        kb_lkp = data_cfg.get("kb_lkp", "lkp")
        kb_len = data_cfg.get("kb_len", "len")
        kb_trv = data_cfg.get("kb_truvals", "trv")
        global_trv = data_cfg.get("global_trv", "")
        if global_trv:
            print(
                "UserWarning: the global_trv parameter is deprecated; omit it instead."
            )
        trutrg = data_cfg.get("trutrg", "car")
        canonization_mode = data_cfg.get("canonization_mode", "canonize")
        assert canonization_mode in ["canonize", "hash"], canonization_mode

        # TODO FIXME: the following is hardcoded; add it to the config
        pnctprepro = True
    else:
        # the remaining KB-specific variables get default values at the end of load_data (non-KB case)
        pnctprepro = False

    # default joeyNMT behaviour for sentences

    tok_fun = list if level == "char" else (
        pkt_tokenize if pnctprepro else tokenize)

    src_field = data.Field(init_token=None,
                           eos_token=EOS_TOKEN,
                           pad_token=PAD_TOKEN,
                           tokenize=tok_fun,
                           batch_first=True,
                           lower=lowercase,
                           unk_token=UNK_TOKEN,
                           include_lengths=True)

    trg_field = data.Field(init_token=BOS_TOKEN,
                           eos_token=EOS_TOKEN,
                           pad_token=PAD_TOKEN,
                           tokenize=tok_fun,
                           unk_token=UNK_TOKEN,
                           batch_first=True,
                           lower=lowercase,
                           include_lengths=True)

    if kb_task:
        # NOTE: lowercase MUST be False for datasets with tokens that may include whitespace!
        # The torchtext lowercase pipeline seems to operate not on a token as a single unit
        # (at the dataset field level) but lowercases the individual whitespace-separated words
        # WITHIN a token, which leads to the vocab not recognizing tokens even though they were
        # added to field.vocab via joeynmt.vocabulary.build_vocab.
        # Another way to circumvent this would be to lowercase in the same manner before
        # calling field.process.
        trv_field = data.Field(init_token=BOS_TOKEN,
                               eos_token=EOS_TOKEN,
                               pad_token=PAD_TOKEN,
                               tokenize=lambda entire_line: [entire_line],
                               unk_token=UNK_TOKEN,
                               batch_first=True,
                               lower=False,
                               include_lengths=False)

    train_data = TranslationDataset(
        path=train_path,
        exts=("." + src_lang, "." + trg_lang),
        fields=(src_field, trg_field),
        filter_pred=lambda x: len(vars(x)['src']) <= max_sent_length and len(
            vars(x)['trg']) <= max_sent_length)

    if kb_task:  #load train_kb and metadata

        # NOTE change trg_lang to trutrg for dev/test
        # train_data has been loaded with normal extension (canonized files, e.g. train.carno)
        # dev/test_data will be loaded from non canonized files
        canon_trg = trg_lang  # keep this for loss reporting (load dev/test data from here separately)
        trg_lang = trutrg

        train_kb_truvals = MonoDataset(path=train_path,
                                       ext=("." + kb_trv),
                                       field=("kbtrv", trv_field),
                                       filter_pred=lambda x: True)

        train_kb = TranslationDataset(path=train_path,
                                      exts=("." + kb_src, "." + kb_trg),
                                      fields=(("kbsrc", src_field),
                                              ("kbtrg", trg_field)),
                                      filter_pred=lambda x: True)

        with open(train_path + "." + kb_lkp, "r") as lkp:
            lookup = lkp.readlines()
        train_kb_lookup = [int(elem[:-1]) for elem in lookup if elem[:-1]]
        with open(train_path + "." + kb_len, "r") as lens:
            lengths = lens.readlines()
        train_kb_lengths = [int(elem[:-1]) for elem in lengths if elem[:-1]]

    # Now that we have the training data, build the vocabulary from it;
    # dev and test data are handled further below.

    src_max_size = data_cfg.get("src_voc_limit", sys.maxsize)
    src_min_freq = data_cfg.get("src_voc_min_freq", 1)
    trg_max_size = data_cfg.get("trg_voc_limit", sys.maxsize)
    trg_min_freq = data_cfg.get("trg_voc_min_freq", 1)

    src_vocab_file = data_cfg.get("src_vocab", None)

    # NOTE unused
    trg_vocab_file = data_cfg.get("trg_vocab", None)
    trg_kb_vocab_file = data_cfg.get("trg_kb_vocab", None)
    trg_vocab_file = trg_vocab_file if not trg_kb_vocab_file else trg_kb_vocab_file  # prefer to use joint trg_kb_vocab_file if specified

    vocab_building_datasets = train_data if not kb_task else (train_data,
                                                              train_kb)
    vocab_building_src_fields = "src" if not kb_task else ("src", "kbsrc")
    vocab_building_trg_fields = "trg" if not kb_task else ("trg", "kbtrg")

    src_vocab = build_vocab(fields=vocab_building_src_fields,
                            min_freq=src_min_freq,
                            max_size=src_max_size,
                            dataset=vocab_building_datasets,
                            vocab_file=src_vocab_file)
    trg_vocab = build_vocab(fields=vocab_building_trg_fields,
                            min_freq=trg_min_freq,
                            max_size=trg_max_size,
                            dataset=vocab_building_datasets,
                            vocab_file=trg_vocab_file)

    random_train_subset = data_cfg.get("random_train_subset", -1)
    if random_train_subset > -1:
        # select this many training examples randomly and discard the rest
        keep_ratio = random_train_subset / len(train_data)
        keep, _ = train_data.split(split_ratio=[keep_ratio, 1 - keep_ratio],
                                   random_state=random.getstate())
        train_data = keep

    dev_data = TranslationDataset(path=dev_path,
                                  exts=("." + src_lang, "." + trg_lang),
                                  fields=(src_field, trg_field))

    if kb_task:  #load dev kb and metadata; load canonized dev data for loss reporting

        dev_data_canon = TranslationDataset(path=dev_path,
                                            exts=("." + src_lang,
                                                  "." + canon_trg),
                                            fields=(src_field, trg_field))

        dev_kb = TranslationDataset(path=dev_path,
                                    exts=("." + kb_src, "." + kb_trg),
                                    fields=(("kbsrc", src_field), ("kbtrg",
                                                                   trg_field)),
                                    filter_pred=lambda x: True)
        dev_kb_truvals = MonoDataset(path=dev_path,
                                     ext=("." + kb_trv),
                                     field=("kbtrv", trv_field),
                                     filter_pred=lambda x: True)

        with open(dev_path + "." + kb_lkp, "r") as lkp:
            lookup = lkp.readlines()
        dev_kb_lookup = [int(elem[:-1]) for elem in lookup if elem[:-1]]
        with open(dev_path + "." + kb_len, "r") as lens:
            lengths = lens.readlines()
        dev_kb_lengths = [int(elem[:-1]) for elem in lengths if elem[:-1]]

    test_data = None
    if test_path is not None:
        # check if target exists
        if os.path.isfile(test_path + "." + trg_lang):
            test_data = TranslationDataset(path=test_path,
                                           exts=("." + src_lang,
                                                 "." + trg_lang),
                                           fields=(src_field, trg_field))
        else:
            # no target is given -> create dataset from src only
            test_data = MonoDataset(path=test_path,
                                    ext="." + src_lang,
                                    field=src_field)
    if kb_task:  #load test kb and metadata

        test_data_canon = TranslationDataset(path=test_path,
                                             exts=("." + src_lang,
                                                   "." + canon_trg),
                                             fields=(src_field, trg_field))
        test_kb = TranslationDataset(path=test_path,
                                     exts=("." + kb_src, "." + kb_trg),
                                     fields=(("kbsrc", src_field),
                                             ("kbtrg", trg_field)),
                                     filter_pred=lambda x: True)
        test_kb_truvals = MonoDataset(path=test_path,
                                      ext=("." + kb_trv),
                                      field=("kbtrv", trv_field),
                                      filter_pred=lambda x: True)

        with open(test_path + "." + kb_lkp, "r") as lkp:
            lookup = lkp.readlines()
        test_kb_lookup = [int(elem[:-1]) for elem in lookup if elem[:-1]]
        with open(test_path + "." + kb_len, "r") as lens:
            lengths = lens.readlines()
        test_kb_lengths = [int(elem[:-1]) for elem in lengths if elem[:-1]]

    # finally actually set the .vocab field attributes
    src_field.vocab = src_vocab  # also sets kb_src_field.vocab if they're the same (variables point to the same object)
    trg_field.vocab = trg_vocab

    if kb_task:
        # NOTE: this vocab is hard-coded to be built from the concatenation of the train+dev+test trv files!
        # trv_path = train_path[:len(train_path)-train_path[::-1].find("/")]+global_trv
        # assert os.path.isfile(trv_path)

        trv_ext = "." + kb_trv

        trv_train_path = train_path + trv_ext
        trv_dev_path = dev_path + trv_ext
        trv_test_path = test_path + trv_ext

        assert os.path.isfile(trv_train_path)

        # try to make vocabulary exactly as large as needed

        # trv_vocab._from_file(trv_path)
        trv_vocab = deepcopy(trg_vocab)
        # FIXME only add this for source copying?
        trv_vocab._from_list(src_vocab.itos)

        trv_vocab._from_file(trv_train_path)
        trv_vocab._from_file(trv_dev_path)
        trv_vocab._from_file(trv_test_path)
        if canonization_mode == "canonize":
            # stanford data
            assert "schedule" in trv_vocab.itos
        # NOTE really important for model.postprocess:
        # trv_vocab must begin with trg_vocab
        # to look up canonical tokens correctly
        assert trg_vocab.itos == trv_vocab.itos[:len(trg_vocab)]

        print(
            f"Added true value lines as tokens to trv_vocab of length={len(trv_vocab)}"
        )
        trv_field.vocab = trv_vocab

    if kb_task:
        # make canonization function to create KB from source for batches without one

        entities_path = "data/kvr/kvret_entities_altered.json"  # TODO FIXME add to config
        entities = load_json(fp=entities_path)
        efficient_entities = preprocess_entity_dict(entities,
                                                    lower=lowercase,
                                                    tok_fun=tok_fun)

        if canonization_mode == "hash":
            # initialize with train knowledgebases
            hash_vocab = build_vocab(max_size=4096,
                                     dataset=train_kb,
                                     fields=vocab_building_trg_fields,
                                     min_freq=1,
                                     vocab_file=trv_train_path)
            hash_vocab._from_file(trv_train_path)
            hash_vocab._from_file(trv_dev_path)
            hash_vocab._from_file(trv_test_path)

            # assert False, hash_vocab.itos

        # assert False, canonize_sequence(["your", "meeting", "in", "conference", "room", "100", "is", "with", "martha"], efficient_entities) # assert False, # NOTE
        # assert False, hash_canons(["Sure" , "the", "chinese", "good", "luck", "chinese", "food", "takeaway", "is", "on","the_good_luck_chinese_food_takeaway_address"], hash_vocab.itos) # assert False, # NOTE

        if canonization_mode == "canonize":

            class Canonizer:
                def __init__(self, copy_from_source: bool = False):
                    self.copy_from_source = bool(copy_from_source)

                def __call__(self, seq):
                    processed, indices, matches = canonize_sequence(
                        seq, efficient_entities)
                    return processed, indices, matches
        elif canonization_mode == "hash":

            class Canonizer:
                def __init__(self, copy_from_source: bool = False):
                    self.copy_from_source = bool(copy_from_source)

                def __call__(self, seq):
                    processed, indices, matches = hash_canons(
                        seq, hash_vocab.itos)
                    return processed, indices, matches
        else:
            raise ValueError(
                f"canonization mode {canonization_mode} not implemented")

    if not kb_task:  #default values for normal pipeline
        train_kb, dev_kb, test_kb = None, None, None
        train_kb_lookup, dev_kb_lookup, test_kb_lookup = [], [], []
        train_kb_lengths, dev_kb_lengths, test_kb_lengths = [], [], []
        train_kb_truvals, dev_kb_truvals, test_kb_truvals = [], [], []
        trv_vocab = None
        dev_data_canon, test_data_canon = [], []
        Canonizer = None

    # FIXME: return a dict here instead of this long tuple
    return train_data, dev_data, test_data,\
        src_vocab, trg_vocab,\
        train_kb, dev_kb, test_kb,\
        train_kb_lookup, dev_kb_lookup, test_kb_lookup,\
        train_kb_lengths, dev_kb_lengths, test_kb_lengths,\
        train_kb_truvals, dev_kb_truvals, test_kb_truvals,\
        trv_vocab, Canonizer, \
        dev_data_canon, test_data_canon
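
Since this variant returns a long tuple (see the FIXME above), here is a sketch of how a caller might unpack it; the names and order follow the return statement, while data_cfg is a hypothetical configuration dictionary like the one shown after Example #1:

# Unpack in the order of the return statement above.
(train_data, dev_data, test_data,
 src_vocab, trg_vocab,
 train_kb, dev_kb, test_kb,
 train_kb_lookup, dev_kb_lookup, test_kb_lookup,
 train_kb_lengths, dev_kb_lengths, test_kb_lengths,
 train_kb_truvals, dev_kb_truvals, test_kb_truvals,
 trv_vocab, Canonizer,
 dev_data_canon, test_data_canon) = load_data(data_cfg)

# With kb_task enabled, Canonizer is a class; otherwise it is None.
canonize = Canonizer(copy_from_source=True) if Canonizer is not None else None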