Esempio n. 1
0
    def augment_generator():
        """Yield one combined feature tuple per ``(base, edit)`` input example.

        Each yielded value is a one-element tuple holding the byte-converted,
        lower-cased base tokens, concatenated with whatever
        ``parse_instance(edit)`` returns (expected to be a tuple, since it is
        joined with ``+``).
        """
        for base_text, edit_spec in input_examples:
            lowered_tokens = [token.lower() for token in base_text.split(' ')]
            # Wrap in a 1-tuple so it can be concatenated with the parsed edit.
            base_field = (convert_to_bytes(lowered_tokens),)
            yield base_field + parse_instance(edit_spec)
Esempio n. 2
0
def parse_instance(instance, noiser=None, free=None):
    """Convert one raw edit example into the tuple of model input features.

    Args:
        instance: One of
            * a string of four tab-separated fields ``base, output, src, tgt``;
            * a dict with the keys ``'base'``, ``'output'``, ``'src'`` and
              ``'tgt'`` (a missing key is treated as the empty string);
            * any other sequence already holding those four fields in order.
        noiser: Optional callable. When truthy it is called with the tuple
            ``(src_words, tgt_words, insert_words, delete_words)`` and must
            return four replacement values in the same order.
        free: Optional set of words excluded from both the insert and the
            delete word lists. Defaults to an empty set.

    Returns:
        A 9-tuple, in this order:
        ``orig_base_ids``, ``extended_base_ids`` (both from ``create_oov``),
        ``orig_output_ids`` (``words2ids(output_words)``),
        ``extended_output_ids`` (``words2ids(output_words, oov)``),
        the ids of the source words, target words, insert words and delete
        words, and finally ``convert_to_bytes(oov)``.

    Raises:
        AssertionError: If ``instance`` does not resolve to exactly four
            fields.
    """
    if isinstance(instance, str):
        instance = instance.split('\t')
    elif isinstance(instance, dict):
        instance = [
            instance.get('base', ''),
            instance.get('output', ''),
            instance.get('src', ''),
            instance.get('tgt', ''),
        ]

    assert len(instance) == 4, (
        'expected 4 fields (base, output, src, tgt), got %d' % len(instance))

    base, output = instance[:2]
    base_words = base.lower().split(' ')
    output_words = output.lower().split(' ')

    # NOTE(review): create_oov presumably also returns the out-of-vocabulary
    # words found in the base sentence -- confirm against its definition.
    orig_base_ids, extended_base_ids, oov = create_oov(base_words)
    orig_output_ids = words2ids(output_words)
    extended_output_ids = words2ids(output_words, oov)

    src, tgt = instance[2:]
    src_words = src.lower().split(' ')
    tgt_words = tgt.lower().split(' ')

    if free is None:
        free = set()

    src_vocab = set(src_words)
    tgt_vocab = set(tgt_words)

    # Inserted words appear only in the target; deleted words appear only in
    # the source. Deletions must be a set difference (not an intersection,
    # which would select the words the edit keeps).
    insert_words = sorted(tgt_vocab - src_vocab - free)
    delete_words = sorted(src_vocab - tgt_vocab - free)

    if noiser:
        src_words, tgt_words, insert_words, delete_words = noiser(
            (src_words, tgt_words, insert_words, delete_words))

    # Downstream consumers never receive an empty sequence: empty word lists
    # get a single placeholder token instead.
    if len(insert_words) == 0:
        insert_words.append(vocab.UNKNOWN_TOKEN)

    if len(delete_words) == 0:
        delete_words.append(vocab.UNKNOWN_TOKEN)

    if len(oov) == 0:
        oov = [vocab.PAD_TOKEN]

    return (orig_base_ids, extended_base_ids, orig_output_ids,
            extended_output_ids, words2ids(src_words), words2ids(tgt_words),
            words2ids(insert_words), words2ids(delete_words),
            convert_to_bytes(oov))
Esempio n. 3
0
def map_str_to_bytes(instance):
    """Apply ``convert_to_bytes`` to ``instance``, element-wise for tuples.

    A tuple input yields a tuple with ``convert_to_bytes`` applied to every
    element; any other input is passed to ``convert_to_bytes`` as a whole and
    that result is returned.
    """
    if not isinstance(instance, tuple):
        return convert_to_bytes(instance)

    converted = [convert_to_bytes(element) for element in instance]
    return tuple(converted)