def main():
    """Build vocabularies and binarize every train/validation corpus to disk.

    All settings are read from the module-level ``opt`` namespace.  Corpora
    and their language tags are given as ``|``-separated lists.  Each file
    (pair) is converted on its own and written immediately, through
    ``save_dataset``, to
    ``<dirname(opt.save_data)>/<train|valid>.<idx>.<src_lang>-<tgt_lang>``.
    The language table and the vocabularies are finally stored in
    ``opt.save_data + '.dict.pt'``.

    Data-generation modes, in order of precedence: ``opt.lm`` (not
    implemented), ``opt.bases2s`` (source-only data, no vocabulary),
    ``opt.asr`` (target vocabulary only) and, otherwise, text translation.

    Raises:
        NotImplementedError: if ``opt.lm`` is set.
        AssertionError: if the file lists and language lists of a split do
            not have the same length.
    """
    dicts = {}

    tokenizer = onmt.Tokenizer(opt.input_type, opt.lower)

    def lang_tensor(lang):
        # Only one language id is stored per dataset; according to the
        # original author's note it is broadcast over the samples later.
        return [torch.Tensor([dicts['langs'][lang]])]

    def save_split(prefix, label, idx, src_lang, tgt_lang, data):
        """Write one converted dataset into its own directory."""
        print("Saving %s set %i %s-%s to disk ..." %
              (label, idx, src_lang, tgt_lang))

        # take basedir from opt.save_data
        path = os.path.join(dirname(opt.save_data),
                            "%s.%i.%s-%s" % (prefix, idx, src_lang, tgt_lang))
        os.makedirs(path, exist_ok=True)

        # save data immediately
        save_dataset(path, data, opt.format, dicts, opt.src_type)

    def source_only_specs(file_spec, lang_spec):
        """Pair each source file with its language tag."""
        files = file_spec.split("|")
        file_langs = lang_spec.split("|")
        # zip() would silently drop the surplus entries of the longer list.
        assert len(files) == len(file_langs)
        return list(zip(files, file_langs))

    def parallel_specs(src_spec, tgt_spec, src_lang_spec, tgt_lang_spec):
        """Pair source/target files with their source/target language tags."""
        src_input_files = src_spec.split("|")
        tgt_input_files = tgt_spec.split("|")
        src_file_langs = src_lang_spec.split("|")
        tgt_file_langs = tgt_lang_spec.split("|")

        assert len(src_input_files) == len(src_file_langs)
        assert len(src_input_files) == len(tgt_input_files)
        assert len(tgt_input_files) == len(tgt_file_langs)

        return list(zip(src_input_files, tgt_input_files,
                        src_file_langs, tgt_file_langs))

    def binarize_bases2s(prefix, label, pairs, idx):
        """Convert and save source-only datasets, numbering them from idx."""
        for src_file, src_lang in pairs:
            # First, read and convert data to tensor format
            src_data, src_sizes = make_bases2s_data(
                src_file,
                stride=opt.stride,
                concat=opt.concat,
                prev_context=opt.previous_context,
                num_workers=opt.num_threads,
                fp16=opt.fp16,
                asr_format=opt.asr_format,
                output_format=opt.format)

            data = {
                'src': src_data,
                'src_sizes': src_sizes,
                'src_lang': lang_tensor(src_lang),
                'tgt_sizes': None,
                'tgt': None,
                'tgt_lang': None,
            }

            # There is no target side, so the source language names both
            # halves of the directory name.
            save_split(prefix, label, idx, src_lang, src_lang, data)
            idx += 1

    def binarize_asr(prefix, label, pairs, idx, max_src_length,
                     max_tgt_length, **extra):
        """Convert and save speech/text pairs, numbering them from idx."""
        for src_file, tgt_file, src_lang, tgt_lang in pairs:
            # First, read and convert data to tensor format
            src_data, tgt_data, src_sizes, tgt_sizes = make_asr_data(
                src_file,
                tgt_file,
                dicts['tgt'],
                tokenizer,
                max_src_length=max_src_length,
                max_tgt_length=max_tgt_length,
                input_type=opt.input_type,
                stride=opt.stride,
                concat=opt.concat,
                prev_context=opt.previous_context,
                fp16=opt.fp16,
                asr_format=opt.asr_format,
                output_format=opt.format,
                **extra)

            # save each dataset as bilingual (no multi parallel data)
            data = {
                'src': src_data,
                'tgt': tgt_data,
                'src_sizes': src_sizes,
                'tgt_sizes': tgt_sizes,
                'src_lang': lang_tensor(src_lang),
                'tgt_lang': lang_tensor(tgt_lang),
            }

            save_split(prefix, label, idx, src_lang, tgt_lang, data)
            idx += 1

    def binarize_translation(prefix, label, pairs, idx, max_src_length,
                             max_tgt_length):
        """Convert and save text pairs, numbering them from idx."""
        for src_file, tgt_file, src_lang, tgt_lang in pairs:
            src_data, tgt_data, src_sizes, tgt_sizes = make_translation_data(
                src_file,
                tgt_file,
                dicts['src'],
                dicts['tgt'],
                tokenizer,
                max_src_length=max_src_length,
                max_tgt_length=max_tgt_length,
                add_bos=(not opt.no_bos),
                data_type=opt.data_type,
                num_workers=opt.num_threads,
                verbose=opt.verbose)

            # save each dataset as bilingual (no multi parallel data)
            data = {
                'src': src_data,
                'tgt': tgt_data,
                'src_sizes': src_sizes,
                'tgt_sizes': tgt_sizes,
                'src_lang': lang_tensor(src_lang),
                'tgt_lang': lang_tensor(tgt_lang),
            }

            save_split(prefix, label, idx, src_lang, tgt_lang, data)
            idx += 1

    # Construct the set of languages from the training languages.  It is
    # sorted so that the language ids are the same on every run: the
    # iteration order of a bare set of strings depends on hash randomization.
    langs = sorted(
        set(opt.train_src_lang.split("|") + opt.train_tgt_lang.split("|")))

    if opt.load_dict is not None:
        # NOTE: only the language table of the loaded dictionary is reused;
        # the vocabularies are still (re)created below.
        loaded_dict = torch.load(opt.load_dict)
        dicts['langs'] = loaded_dict['langs']
        new_languages = [lang for lang in langs if lang not in dicts['langs']]

        print("Loaded dictionary for languages: ", dicts['langs'])
        if new_languages:
            # Unknown languages get the next free ids, existing ids are kept.
            for lang in new_languages:
                dicts['langs'][lang] = len(dicts['langs'])
            print("Added new languages: ", new_languages)
    else:
        dicts['langs'] = {lang: idx for idx, lang in enumerate(langs)}

        print(dicts['langs'])

    start = time.time()

    src_train_files = opt.train_src.split("|")

    # TODO: adding new words to the existing dictionary in case loading from previously created dict
    if opt.bases2s:
        print("Do not create dictionary")
    else:
        tgt_train_files = opt.train_tgt.split("|")

        if opt.asr or opt.lm:
            # for ASR and LM we only need to build vocab for the 'target' language
            dicts['tgt'] = init_vocab('target',
                                      tgt_train_files,
                                      opt.tgt_vocab,
                                      opt.tgt_vocab_size,
                                      tokenizer,
                                      num_workers=opt.num_threads)
        elif opt.join_vocab:
            # One vocabulary shared by both sides.  The de-duplicated file
            # collection is sorted to give a stable order across runs.
            dicts['src'] = init_vocab('source',
                                      sorted(set(src_train_files +
                                                 tgt_train_files)),
                                      opt.src_vocab,
                                      opt.tgt_vocab_size,
                                      tokenizer,
                                      num_workers=opt.num_threads)
            dicts['tgt'] = dicts['src']
        else:
            dicts['src'] = init_vocab('source',
                                      src_train_files,
                                      opt.src_vocab,
                                      opt.src_vocab_size,
                                      tokenizer,
                                      num_workers=opt.num_threads)

            dicts['tgt'] = init_vocab('target',
                                      tgt_train_files,
                                      opt.tgt_vocab,
                                      opt.tgt_vocab_size,
                                      tokenizer,
                                      num_workers=opt.num_threads)

    elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
    print("Vocabulary generated after %s" % elapse)

    # DATA GENERATION starts from here

    if opt.lm:
        raise NotImplementedError

    elif opt.bases2s:
        print('Preparing training speech autoencoder ...')
        binarize_bases2s(
            "train", "training",
            source_only_specs(opt.train_src, opt.train_src_lang),
            opt.starting_train_idx)

        binarize_bases2s(
            "valid", "validation",
            source_only_specs(opt.valid_src, opt.valid_src_lang),
            opt.starting_valid_idx)

    elif opt.asr:
        print('Preparing training acoustic model ...')
        binarize_asr(
            "train", "training",
            parallel_specs(opt.train_src, opt.train_tgt,
                           opt.train_src_lang, opt.train_tgt_lang),
            opt.starting_train_idx,
            opt.src_seq_length,
            opt.tgt_seq_length,
            num_workers=opt.num_threads)

        print('Preparing validation ...')
        # Validation uses a length limit of at least 1024 on both sides.
        # NOTE(review): unlike the training call, no num_workers is passed
        # here; kept as in the original -- confirm whether that is intended.
        binarize_asr(
            "valid", "validation",
            parallel_specs(opt.valid_src, opt.valid_tgt,
                           opt.valid_src_lang, opt.valid_tgt_lang),
            opt.starting_valid_idx,
            max(1024, opt.src_seq_length),
            max(1024, opt.tgt_seq_length))

    else:
        # Translation dataset
        train_pairs = parallel_specs(opt.train_src, opt.train_tgt,
                                     opt.train_src_lang, opt.train_tgt_lang)

        start = time.time()
        print('Binarizing data to train translation models...')
        binarize_translation("train", "training", train_pairs,
                             opt.starting_train_idx,
                             opt.src_seq_length,
                             opt.tgt_seq_length)

        print('Preparing validation ...')
        # Validation uses a length limit of at least 1024 on both sides.
        binarize_translation(
            "valid", "validation",
            parallel_specs(opt.valid_src, opt.valid_tgt,
                           opt.valid_src_lang, opt.valid_tgt_lang),
            opt.starting_valid_idx,
            max(1024, opt.src_seq_length),
            max(1024, opt.tgt_seq_length))

        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Binarization finished after %s" % elapse)

    print("Saving dictionary to %s" % (opt.save_data + '.dict.pt'))
    torch.save(dicts, opt.save_data + '.dict.pt')

    # Only vocabularies that were not supplied by the user are written out.
    # A source vocabulary exists in translation mode only.
    if opt.src_vocab is None and not (opt.asr or opt.lm or opt.bases2s):
        save_vocabulary('source', dicts['src'], opt.save_data + '.src.dict')
    if opt.tgt_vocab is None and not opt.bases2s:
        save_vocabulary('target', dicts['tgt'], opt.save_data + '.tgt.dict')

    print("Finished.")
# Example no. 2
# 0
def main():
    """Build vocabularies, binarize one train/validation corpus and save it.

    All settings are read from the module-level ``opt`` namespace.  Depending
    on ``opt.lm`` / ``opt.asr`` the data is prepared for language modelling,
    speech recognition or (default) text translation.  The result is written
    according to ``opt.format``:

    * ``'raw'``: a single ``opt.save_data + '.train.pt'`` file holding the
      dictionaries and both splits;
    * ``'bin'``: ``IndexedDatasetBuilder`` ``.bin``/``.idx`` files per split
      and side, plus ``opt.save_data + '.dict.pt'``;
    * ``'mmap'`` / ``'mmem'``: the same layout written with
      ``MMapIndexedDatasetBuilder``.

    Raises:
        AssertionError: if ``opt.asr`` is combined with an indexed format.
        NotImplementedError: for any other ``opt.format``.
    """
    dicts = {}

    tokenizer = onmt.Tokenizer(opt.input_type, opt.lower)

    def asr_split(src_file, tgt_file, max_src_length, max_tgt_length):
        """Convert one speech/text corpus into a {'src', 'tgt'} dict."""
        split = dict()
        split['src'], split['tgt'] = make_asr_data(
            src_file,
            tgt_file,
            dicts['tgt'],
            max_src_length=max_src_length,
            max_tgt_length=max_tgt_length,
            input_type=opt.input_type,
            stride=opt.stride,
            concat=opt.concat,
            prev_context=opt.previous_context,
            fp16=opt.fp16,
            reshape=(opt.reshape_speech == 1),
            asr_format=opt.asr_format)
        return split

    def translation_split(src_file, tgt_file, max_src_length, max_tgt_length):
        """Convert one parallel text corpus into a {'src', 'tgt'} dict."""
        split = dict()
        split['src'], split['tgt'] = make_translation_data(
            src_file,
            tgt_file,
            dicts['src'],
            dicts['tgt'],
            tokenizer,
            max_src_length=max_src_length,
            max_tgt_length=max_tgt_length,
            add_bos=(not opt.no_bos),
            data_type=opt.data_type,
            num_workers=opt.num_threads,
            verbose=opt.verbose)
        return split

    def write_indexed(builder_cls, split_name, side, tensors, dtype):
        """Write one side of one split as a .bin/.idx file pair."""
        builder = builder_cls(
            opt.save_data + ".%s.%s.bin" % (split_name, side), dtype=dtype)

        # add every item of the split to the indexed data
        for tensor in tensors:
            builder.add_item(tensor)

        builder.finalize(opt.save_data + ".%s.%s.idx" % (split_name, side))

    start = time.time()
    # for ASR and LM we only need to build vocab for the 'target' language
    if opt.asr or opt.lm:
        dicts['tgt'] = init_vocab('target', [opt.train_tgt],
                                  opt.tgt_vocab,
                                  opt.tgt_vocab_size,
                                  tokenizer,
                                  num_workers=opt.num_threads)
    elif opt.join_vocab:
        # one vocabulary, built from both sides, shared by source and target
        dicts['src'] = init_vocab('source', [opt.train_src, opt.train_tgt],
                                  opt.src_vocab,
                                  opt.tgt_vocab_size,
                                  tokenizer,
                                  num_workers=opt.num_threads)
        dicts['tgt'] = dicts['src']

    else:
        dicts['src'] = init_vocab('source', [opt.train_src],
                                  opt.src_vocab,
                                  opt.src_vocab_size,
                                  tokenizer,
                                  num_workers=opt.num_threads)

        dicts['tgt'] = init_vocab('target', [opt.train_tgt],
                                  opt.tgt_vocab,
                                  opt.tgt_vocab_size,
                                  tokenizer,
                                  num_workers=opt.num_threads)

    elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
    print("Vocabulary generated after %s" % elapse)

    if opt.lm:
        print('Preparing training language model ...')
        # language modelling has no source side
        train = {'tgt': make_lm_data(opt.train_tgt, dicts['tgt']),
                 'src': None}
        valid = {'tgt': make_lm_data(opt.valid_tgt, dicts['tgt']),
                 'src': None}

    elif opt.asr:
        print('Preparing training acoustic model ...')
        train = asr_split(opt.train_src, opt.train_tgt,
                          opt.src_seq_length, opt.tgt_seq_length)

        print('Preparing validation ...')
        # validation uses a length limit of at least 1024 on both sides
        valid = asr_split(opt.valid_src, opt.valid_tgt,
                          max(1024, opt.src_seq_length),
                          max(1024, opt.tgt_seq_length))

    else:
        start = time.time()
        print('Binarizing the data to train translation models...')
        train = translation_split(opt.train_src, opt.train_tgt,
                                  opt.src_seq_length, opt.tgt_seq_length)

        print('Preparing validation ...')
        # validation uses a length limit of at least 1024 on both sides
        valid = translation_split(opt.valid_src, opt.valid_tgt,
                                  max(1024, opt.src_seq_length),
                                  max(1024, opt.tgt_seq_length))

        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Binarization finished after %s" % elapse)

    # Only vocabularies that were not supplied by the user are written out.
    # A source vocabulary exists in translation mode only.
    if opt.src_vocab is None and not (opt.asr or opt.lm):
        save_vocabulary('source', dicts['src'], opt.save_data + '.src.dict')
    if opt.tgt_vocab is None:
        save_vocabulary('target', dicts['tgt'], opt.save_data + '.tgt.dict')

    # Element type of the indexed files.  The training files of the 'bin'
    # format used to be written as int32 regardless of opt.data_type, while
    # the validation files honoured it; both now follow opt.data_type.
    # No float dtype is needed: ASR never reaches the indexed writers.
    token_dtype = np.int64 if opt.data_type == 'int64' else np.int32

    if opt.format == 'raw':

        print('Saving data to \'' + opt.save_data + '.train.pt\'...')
        save_data = {
            'dicts': dicts,
            'type': opt.src_type,
            'train': train,
            'valid': valid
        }
        torch.save(save_data, opt.save_data + '.train.pt')
        print("Done")

    elif opt.format == 'bin':
        print('Saving data to indexed data files')

        if opt.asr:
            print("ASR data format isn't compatible with binary")
            raise AssertionError
        # save dicts in this format
        torch.save(dicts, opt.save_data + '.dict.pt')

        # binarize the training set first
        for side in ['src', 'tgt']:
            if train[side] is not None:
                write_indexed(IndexedDatasetBuilder, "train", side,
                              train[side], token_dtype)

        # binarize the validation set
        for side in ['src', 'tgt']:
            if valid[side] is not None:
                write_indexed(IndexedDatasetBuilder, "valid", side,
                              valid[side], token_dtype)

        print("Done")
    elif opt.format in ['mmap', 'mmem']:
        start = time.time()
        print('Saving data to memory indexed data files')
        from onmt.data_utils.MMapIndexedDataset import MMapIndexedDatasetBuilder

        if opt.asr:
            print(
                "ASR data format isn't compatible with memory indexed format")
            raise AssertionError

        # save dicts in this format
        torch.save(dicts, opt.save_data + '.dict.pt')

        for side in ['src', 'tgt']:
            # a side that is missing from the training set is skipped
            # entirely, including its validation data
            if train[side] is None:
                continue

            write_indexed(MMapIndexedDatasetBuilder, "train", side,
                          train[side], token_dtype)

            if valid[side] is None:
                continue

            write_indexed(MMapIndexedDatasetBuilder, "valid", side,
                          valid[side], token_dtype)

        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Saving finished after %s" % elapse)

    else:
        raise NotImplementedError
def main():
    """Build vocabularies, binarize multilingual corpora and save them.

    All settings are read from the module-level ``opt`` namespace.  In
    translation mode the training and validation corpora are ``|``-separated
    lists of files with matching lists of language tags; all files of a split
    are concatenated into one dataset together with their language ids.  The
    LM and ASR modes process ``opt.train_*`` / ``opt.valid_*`` as given.

    Output depends on ``opt.format``:

    * ``'raw'`` / ``'bin'``: a single ``opt.save_data + '.train.pt'`` file;
    * ``'mmap'`` / ``'mmem'``: ``MMapIndexedDatasetBuilder`` ``.bin``/``.idx``
      files per split and field, plus ``opt.save_data + '.dict.pt'``.

    Raises:
        AssertionError: if file and language lists differ in length, or if
            ``opt.asr`` is combined with a memory-indexed format.
        NotImplementedError: for any other ``opt.format``.
    """
    dicts = {}

    tokenizer = onmt.Tokenizer(opt.input_type, opt.lower)

    def asr_split(src_file, tgt_file, max_src_length, max_tgt_length):
        """Convert one speech/text corpus into a {'src', 'tgt'} dict."""
        split = dict()
        split['src'], split['tgt'] = make_asr_data(
            src_file,
            tgt_file,
            dicts['tgt'],
            max_src_length=max_src_length,
            max_tgt_length=max_tgt_length,
            input_type=opt.input_type,
            stride=opt.stride,
            concat=opt.concat,
            prev_context=opt.previous_context,
            fp16=opt.fp16,
            reshape=(opt.reshape_speech == 1),
            asr_format=opt.asr_format)
        return split

    def parallel_specs(src_spec, tgt_spec, src_lang_spec, tgt_lang_spec):
        """Pair source/target files with their source/target language tags."""
        src_input_files = src_spec.split("|")
        tgt_input_files = tgt_spec.split("|")
        src_file_langs = src_lang_spec.split("|")
        tgt_file_langs = tgt_lang_spec.split("|")

        assert len(src_input_files) == len(src_file_langs)
        assert len(src_input_files) == len(tgt_input_files)
        assert len(tgt_input_files) == len(tgt_file_langs)

        return list(zip(src_input_files, tgt_input_files,
                        src_file_langs, tgt_file_langs))

    def translation_split(pairs, max_src_length, max_tgt_length):
        """Binarize every file pair and concatenate them into one split."""
        split = dict()
        split['src'], split['tgt'] = list(), list()
        split['src_lang'], split['tgt_lang'] = list(), list()

        for src_file, tgt_file, src_lang, tgt_lang in pairs:
            src_data, tgt_data = make_translation_data(
                src_file,
                tgt_file,
                dicts['src'],
                dicts['tgt'],
                tokenizer,
                max_src_length=max_src_length,
                max_tgt_length=max_tgt_length,
                add_bos=(not opt.no_bos),
                data_type=opt.data_type,
                num_workers=opt.num_threads,
                verbose=opt.verbose)

            # For single-file cases we only need to have 1 language per file
            # which will be broadcasted; otherwise each sample carries its
            # own language id.
            n_ids = 1 if len(pairs) == 1 else len(src_data)
            src_id = dicts['langs'][src_lang]
            tgt_id = dicts['langs'][tgt_lang]

            split['src'] += src_data
            split['tgt'] += tgt_data
            split['src_lang'] += [torch.Tensor([src_id])
                                  for _ in range(n_ids)]
            split['tgt_lang'] += [torch.Tensor([tgt_id])
                                  for _ in range(n_ids)]

        return split

    def write_mmap(builder_cls, split_name, field, tensors, dtype):
        """Write one field of one split as a .bin/.idx file pair."""
        builder = builder_cls(
            opt.save_data + ".%s.%s.bin" % (split_name, field), dtype=dtype)

        # add every item of the split to the indexed data
        for tensor in tensors:
            builder.add_item(tensor)

        builder.finalize(opt.save_data + ".%s.%s.idx" % (split_name, field))

    # Construct the set of languages from the training languages.  It is
    # sorted so that the language ids are the same on every run: the
    # iteration order of a bare set of strings depends on hash randomization.
    langs = sorted(
        set(opt.train_src_lang.split("|") + opt.train_tgt_lang.split("|")))

    dicts['langs'] = {lang: idx for idx, lang in enumerate(langs)}

    print(dicts['langs'])

    start = time.time()

    src_train_files = opt.train_src.split("|")
    tgt_train_files = opt.train_tgt.split("|")
    # for ASR and LM we only need to build vocab for the 'target' language
    if opt.asr or opt.lm:
        dicts['tgt'] = init_vocab('target',
                                  tgt_train_files,
                                  opt.tgt_vocab,
                                  opt.tgt_vocab_size,
                                  tokenizer,
                                  num_workers=opt.num_threads)
    elif opt.join_vocab:
        # One vocabulary shared by both sides.  The de-duplicated file
        # collection is sorted to give a stable order across runs.
        dicts['src'] = init_vocab('source',
                                  sorted(set(src_train_files +
                                             tgt_train_files)),
                                  opt.src_vocab,
                                  opt.tgt_vocab_size,
                                  tokenizer,
                                  num_workers=opt.num_threads)
        dicts['tgt'] = dicts['src']

    else:
        dicts['src'] = init_vocab('source',
                                  src_train_files,
                                  opt.src_vocab,
                                  opt.src_vocab_size,
                                  tokenizer,
                                  num_workers=opt.num_threads)

        dicts['tgt'] = init_vocab('target',
                                  tgt_train_files,
                                  opt.tgt_vocab,
                                  opt.tgt_vocab_size,
                                  tokenizer,
                                  num_workers=opt.num_threads)

    elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
    print("Vocabulary generated after %s" % elapse)

    if opt.lm:
        print('Preparing training language model ...')
        # language modelling has no source side and no language ids
        train = {'tgt': make_lm_data(opt.train_tgt, dicts['tgt']),
                 'src': None}
        valid = {'tgt': make_lm_data(opt.valid_tgt, dicts['tgt']),
                 'src': None}

    elif opt.asr:
        print('Preparing training acoustic model ...')
        train = asr_split(opt.train_src, opt.train_tgt,
                          opt.src_seq_length, opt.tgt_seq_length)

        print('Preparing validation ...')
        # validation uses a length limit of at least 1024 on both sides
        valid = asr_split(opt.valid_src, opt.valid_tgt,
                          max(1024, opt.src_seq_length),
                          max(1024, opt.tgt_seq_length))

    else:
        train_pairs = parallel_specs(opt.train_src, opt.train_tgt,
                                     opt.train_src_lang, opt.train_tgt_lang)

        start = time.time()
        print('Binarizing data to train translation models...')
        train = translation_split(train_pairs,
                                  opt.src_seq_length,
                                  opt.tgt_seq_length)

        print('Preparing validation ...')
        # validation uses a length limit of at least 1024 on both sides
        valid = translation_split(
            parallel_specs(opt.valid_src, opt.valid_tgt,
                           opt.valid_src_lang, opt.valid_tgt_lang),
            max(1024, opt.src_seq_length),
            max(1024, opt.tgt_seq_length))

        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Binarization finished after %s" % elapse)

    # Only vocabularies that were not supplied by the user are written out.
    # A source vocabulary exists in translation mode only.
    if opt.src_vocab is None and not (opt.asr or opt.lm):
        save_vocabulary('source', dicts['src'], opt.save_data + '.src.dict')
    if opt.tgt_vocab is None:
        save_vocabulary('target', dicts['tgt'], opt.save_data + '.tgt.dict')

    # SAVE DATA
    if opt.format in ['raw', 'bin']:

        print('Saving data to \'' + opt.save_data + '.train.pt\'...')
        save_data = {
            'dicts': dicts,
            'type': opt.src_type,
            'train': train,
            'valid': valid
        }
        torch.save(save_data, opt.save_data + '.train.pt')
        print("Done")

    elif opt.format in ['mmap', 'mmem']:
        print('Saving data to memory indexed data files')
        from onmt.data.mmap_indexed_dataset import MMapIndexedDatasetBuilder

        if opt.asr:
            print(
                "ASR data format isn't compatible with memory indexed format")
            raise AssertionError

        # save dicts in this format
        torch.save(dicts, opt.save_data + '.dict.pt')

        # No float dtype is needed: ASR never reaches this point.
        dtype = np.int64 if opt.data_type == 'int64' else np.int32

        for set_ in ['src', 'tgt', 'src_lang', 'tgt_lang']:
            # .get(): the LM splits carry no 'src_lang' / 'tgt_lang' entries,
            # which used to raise KeyError here.  A field that is missing
            # from the training set is skipped entirely, including its
            # validation data.
            if train.get(set_) is None:
                continue

            write_mmap(MMapIndexedDatasetBuilder, "train", set_,
                       train[set_], dtype)

            if valid.get(set_) is None:
                continue

            write_mmap(MMapIndexedDatasetBuilder, "valid", set_,
                       valid[set_], dtype)

    else:
        raise NotImplementedError