コード例 #1
0
def make_translation_data(src_file,
                          tgt_file,
                          src_dicts,
                          tgt_dicts,
                          tokenizer,
                          max_src_length=64,
                          max_tgt_length=64,
                          add_bos=True,
                          data_type='int64',
                          num_workers=1,
                          verbose=False):
    """Binarize a parallel source/target text corpus for translation.

    The source side is binarized with ``bos_word=None`` and ``eos_word=None``.
    The target side always gets ``eos_word=onmt.constants.EOS_WORD``, and
    ``bos_word=onmt.constants.BOS_WORD`` when ``add_bos`` is true.

    Args:
        src_file: Path of the source-side text file.
        tgt_file: Path of the target-side text file.
        src_dicts: Vocabulary object(s) forwarded to
            ``Binarizer.binarize_file`` for the source side.
        tgt_dicts: Vocabulary object(s) forwarded to
            ``Binarizer.binarize_file`` for the target side.
        tokenizer: Tokenizer forwarded to ``Binarizer.binarize_file`` for
            both sides.
        max_src_length: Only reported in the summary message; no length
            filtering is performed here.
        max_tgt_length: Only reported in the summary message; no length
            filtering is performed here.
        add_bos: Whether to pass ``onmt.constants.BOS_WORD`` as the target
            ``bos_word``.
        data_type: Forwarded to ``Binarizer.binarize_file``.
        num_workers: Forwarded to ``Binarizer.binarize_file``.
        verbose: Forwarded to ``Binarizer.binarize_file``.

    Returns:
        tuple: ``(src, tgt)``, the ``'data'`` entries of the source and
        target binarization results.
    """
    print("[INFO] Binarizing file %s ..." % src_file)
    binarized_src = Binarizer.binarize_file(src_file,
                                            src_dicts,
                                            tokenizer,
                                            bos_word=None,
                                            eos_word=None,
                                            data_type=data_type,
                                            num_workers=num_workers,
                                            verbose=verbose)

    tgt_bos_word = onmt.constants.BOS_WORD if add_bos else None

    print("[INFO] Binarizing file %s ..." % tgt_file)
    binarized_tgt = Binarizer.binarize_file(tgt_file,
                                            tgt_dicts,
                                            tokenizer,
                                            bos_word=tgt_bos_word,
                                            eos_word=onmt.constants.EOS_WORD,
                                            data_type=data_type,
                                            num_workers=num_workers,
                                            verbose=verbose)

    src = binarized_src['data']
    tgt = binarized_tgt['data']

    # A parallel corpus is expected to have matching source and target counts.
    # Warn on disagreement (matching make_asr_data) instead of letting a
    # misaligned corpus pass through silently.
    if len(binarized_src['sizes']) != len(binarized_tgt['sizes']):
        print("Warning: data size mismatched.")

    # No sentences are filtered yet, so the reported "ignored" count is 0.
    ignored = 0

    print(('Prepared %d sentences ' +
           '(%d ignored due to length == 0 or src len > %d or tgt len > %d)') %
          (len(src), ignored, max_src_length, max_tgt_length))

    return src, tgt
コード例 #2
0
def make_asr_data(src_file,
                  tgt_file,
                  tgt_dicts,
                  tokenizer,
                  max_src_length=64,
                  max_tgt_length=64,
                  add_bos=True,
                  data_type='int64',
                  num_workers=1,
                  verbose=False,
                  input_type='word',
                  stride=1,
                  concat=4,
                  prev_context=0,
                  fp16=False,
                  reshape=True,
                  asr_format="h5",
                  output_format="raw"):
    """Binarize an ASR corpus: speech features on the source, text on the target.

    The source file goes through ``SpeechBinarizer.binarize_file``. The
    target file goes through ``Binarizer.binarize_file`` with
    ``eos_word=onmt.constants.EOS_WORD``, plus
    ``bos_word=onmt.constants.BOS_WORD`` when ``add_bos`` is true.

    Args:
        src_file: Path of the source (speech) input, read according to
            ``asr_format``.
        tgt_file: Path of the target-side text file.
        tgt_dicts: Vocabulary object(s) forwarded to
            ``Binarizer.binarize_file`` for the target side.
        tokenizer: Tokenizer forwarded to the target-side binarizer only.
        max_src_length: Only reported in the summary message; no length
            filtering is performed here.
        max_tgt_length: Only reported in the summary message; no length
            filtering is performed here.
        add_bos: Whether to pass ``onmt.constants.BOS_WORD`` as the target
            ``bos_word``.
        data_type: Forwarded to the target-side binarizer.
        num_workers: Forwarded to both binarizers.
        verbose: Forwarded to the target-side binarizer only.
        input_type: Accepted for interface compatibility; not used here.
        stride: Forwarded to ``SpeechBinarizer.binarize_file``.
        concat: Forwarded to ``SpeechBinarizer.binarize_file``.
        prev_context: Forwarded to ``SpeechBinarizer.binarize_file``.
        fp16: Forwarded to ``SpeechBinarizer.binarize_file``.
        reshape: Accepted for interface compatibility; not used here.
        asr_format: Forwarded as ``input_format`` to
            ``SpeechBinarizer.binarize_file``.
        output_format: Forwarded to ``SpeechBinarizer.binarize_file``.

    Returns:
        tuple: ``(src, tgt, src_sizes, tgt_sizes)``, the ``'data'`` and
        ``'sizes'`` entries of the source and target binarization results.
    """
    print('[INFO] Processing %s  ...' % src_file)
    binarized_src = SpeechBinarizer.binarize_file(src_file,
                                                  input_format=asr_format,
                                                  output_format=output_format,
                                                  concat=concat,
                                                  stride=stride,
                                                  fp16=fp16,
                                                  prev_context=prev_context,
                                                  num_workers=num_workers)

    src = binarized_src['data']
    src_sizes = binarized_src['sizes']

    tgt_bos_word = onmt.constants.BOS_WORD if add_bos else None

    print("[INFO] Binarizing file %s ..." % tgt_file)
    binarized_tgt = Binarizer.binarize_file(tgt_file,
                                            tgt_dicts,
                                            tokenizer,
                                            bos_word=tgt_bos_word,
                                            eos_word=onmt.constants.EOS_WORD,
                                            data_type=data_type,
                                            num_workers=num_workers,
                                            verbose=verbose)

    tgt = binarized_tgt['data']
    tgt_sizes = binarized_tgt['sizes']

    # Source utterances and target transcripts are expected to pair up
    # one-to-one; only a warning is emitted when they do not.
    if len(src_sizes) != len(tgt_sizes):
        print("Warning: data size mismatched.")

    # No utterances are filtered yet, so the reported "ignored" count is 0.
    ignored = 0

    print(('Prepared %d sentences ' +
           '(%d ignored due to length == 0 or src len > %d or tgt len > %d)') %
          (len(src), ignored, max_src_length, max_tgt_length))

    return src, tgt, src_sizes, tgt_sizes