Example #1
import datetime
import time

import numpy as np
import torch

import onmt

# `opt` (the parsed command-line options) and the helpers init_vocab,
# make_lm_data, make_asr_data, make_translation_data, save_vocabulary and
# IndexedDatasetBuilder are defined elsewhere in the same script.


def main():
    dicts = {}

    tokenizer = onmt.Tokenizer(opt.input_type, opt.lower)

    start = time.time()
    # for ASR and LM we only need to build vocab for the 'target' language
    if opt.asr or opt.lm:
        dicts['tgt'] = init_vocab('target', [opt.train_tgt],
                                  opt.tgt_vocab,
                                  opt.tgt_vocab_size,
                                  tokenizer,
                                  num_workers=opt.num_threads)
    elif opt.join_vocab:
        dicts['src'] = init_vocab('source', [opt.train_src, opt.train_tgt],
                                  opt.src_vocab,
                                  opt.tgt_vocab_size,
                                  tokenizer,
                                  num_workers=opt.num_threads)
        dicts['tgt'] = dicts['src']

    else:
        dicts['src'] = init_vocab('source', [opt.train_src],
                                  opt.src_vocab,
                                  opt.src_vocab_size,
                                  tokenizer,
                                  num_workers=opt.num_threads)

        dicts['tgt'] = init_vocab('target', [opt.train_tgt],
                                  opt.tgt_vocab,
                                  opt.tgt_vocab_size,
                                  tokenizer,
                                  num_workers=opt.num_threads)

    elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
    print("Vocabulary generated after %s" % elapse)

    if opt.lm:
        print('Preparing language model training data ...')
        train = dict()
        train['tgt'] = make_lm_data(opt.train_tgt, dicts['tgt'])
        train['src'] = None

        valid = dict()
        valid['tgt'] = make_lm_data(opt.valid_tgt, dicts['tgt'])
        valid['src'] = None

    elif opt.asr:
        print('Preparing acoustic model training data ...')
        train = dict()
        train['src'], train['tgt'] = make_asr_data(
            opt.train_src,
            opt.train_tgt,
            dicts['tgt'],
            max_src_length=opt.src_seq_length,
            max_tgt_length=opt.tgt_seq_length,
            input_type=opt.input_type,
            stride=opt.stride,
            concat=opt.concat,
            prev_context=opt.previous_context,
            fp16=opt.fp16,
            reshape=(opt.reshape_speech == 1),
            asr_format=opt.asr_format)

        print('Preparing validation ...')
        valid = dict()
        valid['src'], valid['tgt'] = make_asr_data(
            opt.valid_src,
            opt.valid_tgt,
            dicts['tgt'],
            max_src_length=max(1024, opt.src_seq_length),
            max_tgt_length=max(1024, opt.tgt_seq_length),
            input_type=opt.input_type,
            stride=opt.stride,
            concat=opt.concat,
            prev_context=opt.previous_context,
            fp16=opt.fp16,
            reshape=(opt.reshape_speech == 1),
            asr_format=opt.asr_format)

    else:
        start = time.time()
        print('Binarizing the data to train translation models...')
        train = dict()
        train['src'], train['tgt'] = make_translation_data(
            opt.train_src,
            opt.train_tgt,
            dicts['src'],
            dicts['tgt'],
            tokenizer,
            max_src_length=opt.src_seq_length,
            max_tgt_length=opt.tgt_seq_length,
            add_bos=(not opt.no_bos),
            data_type=opt.data_type,
            num_workers=opt.num_threads,
            verbose=opt.verbose)

        print('Preparing validation ...')
        valid = dict()
        valid['src'], valid['tgt'] = make_translation_data(
            opt.valid_src,
            opt.valid_tgt,
            dicts['src'],
            dicts['tgt'],
            tokenizer,
            max_src_length=max(1024, opt.src_seq_length),
            max_tgt_length=max(1024, opt.tgt_seq_length),
            add_bos=(not opt.no_bos),
            data_type=opt.data_type,
            num_workers=opt.num_threads,
            verbose=opt.verbose)

        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Binarization finished after %s" % elapse)

    if opt.src_vocab is None and not opt.asr and not opt.lm:
        save_vocabulary('source', dicts['src'], opt.save_data + '.src.dict')
    if opt.tgt_vocab is None:
        save_vocabulary('target', dicts['tgt'], opt.save_data + '.tgt.dict')

    if opt.format == 'raw':

        print('Saving data to \'' + opt.save_data + '.train.pt\'...')
        save_data = {
            'dicts': dicts,
            'type': opt.src_type,
            'train': train,
            'valid': valid
        }
        torch.save(save_data, opt.save_data + '.train.pt')
        print("Done")

    elif opt.format == 'bin':
        print('Saving data to indexed data files')

        if opt.asr:
            raise AssertionError(
                "ASR data format isn't compatible with binary")
        # save dicts in this format
        torch.save(dicts, opt.save_data + '.dict.pt')

        # binarize the training set first
        for side in ['src', 'tgt']:

            if train[side] is None:
                continue

            # honour opt.data_type here, as the validation loop below does
            if opt.data_type == 'int64':
                dtype = np.int64
            else:
                dtype = np.int32

            if side == 'src' and opt.asr:
                dtype = np.double

            data = IndexedDatasetBuilder(opt.save_data + ".train.%s.bin" % side,
                                         dtype=dtype)

            # add each item from the training data to the indexed dataset
            for tensor in train[side]:
                data.add_item(tensor)

            data.finalize(opt.save_data + ".train.%s.idx" % side)

        # binarize the validation set
        for side in ['src', 'tgt']:

            if valid[side] is None:
                continue

            if opt.data_type == 'int64':
                dtype = np.int64
            else:
                dtype = np.int32

            if side == 'src' and opt.asr:
                dtype = np.double

            data = IndexedDatasetBuilder(opt.save_data + ".valid.%s.bin" % side,
                                         dtype=dtype)

            # add each item from the validation data to the indexed dataset
            for tensor in valid[side]:
                data.add_item(tensor)

            data.finalize(opt.save_data + ".valid.%s.idx" % side)

        print("Done")
    elif opt.format in ['mmap', 'mmem']:
        start = time.time()
        print('Saving data to memory indexed data files')
        from onmt.data_utils.MMapIndexedDataset import MMapIndexedDatasetBuilder

        if opt.asr:
            raise AssertionError(
                "ASR data format isn't compatible with memory indexed format")

        # save dicts in this format
        torch.save(dicts, opt.save_data + '.dict.pt')

        # binarize the training set, then the matching validation set
        for side in ['src', 'tgt']:
            if train[side] is None:
                continue

            if opt.data_type == 'int64':
                dtype = np.int64
            else:
                dtype = np.int32

            if side == 'src' and opt.asr:
                dtype = np.double

            train_data = MMapIndexedDatasetBuilder(opt.save_data +
                                                   ".train.%s.bin" % side,
                                                   dtype=dtype)

            # add each item from the training data to the indexed dataset
            for tensor in train[side]:
                train_data.add_item(tensor)

            train_data.finalize(opt.save_data + ".train.%s.idx" % side)

            del train_data

            if valid[side] is None:
                continue

            if side == 'src' and opt.asr:
                dtype = np.double

            valid_data = MMapIndexedDatasetBuilder(opt.save_data +
                                                   ".valid.%s.bin" % side,
                                                   dtype=dtype)

            # add each item from the validation data to the indexed dataset
            for tensor in valid[side]:
                valid_data.add_item(tensor)

            valid_data.finalize(opt.save_data + ".valid.%s.idx" % side)

            del valid_data
        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Saving finished after %s" % elapse)

    else:
        raise NotImplementedError
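
For the 'raw' format, everything the script produces round-trips through a single pickle. A minimal sketch of loading it back, assuming 'demo' stands in for whatever prefix was passed as opt.save_data (the key layout follows the save_data dict in the branch above):

import torch

# Load the bundle written by the 'raw' branch of main().
bundle = torch.load('demo.train.pt')

dicts = bundle['dicts']                  # source/target vocabularies
train, valid = bundle['train'], bundle['valid']
print(bundle['type'])                    # opt.src_type at preprocessing time
print(len(train['tgt']))                 # item count, assuming a list of tensors
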
Example #2
import numpy as np
import torch

# `opt` (the parsed command-line options) and the helpers init_vocab,
# make_lm_data, make_asr_data, make_translation_data, save_vocabulary and
# IndexedDatasetBuilder are defined elsewhere in the same script.


def main():

    dicts = {}

    # for ASR and LM we only need to build vocab for the 'target' language
    if opt.asr or opt.lm:
        dicts['tgt'] = init_vocab('target', opt.train_tgt, opt.tgt_vocab,
                                  opt.tgt_vocab_size, input_type=opt.input_type)
    elif opt.join_vocab:
        dicts['src'] = init_vocab('source', [opt.train_src, opt.train_tgt],
                                  opt.src_vocab, opt.tgt_vocab_size,
                                  join=True, input_type=opt.input_type)
        dicts['tgt'] = dicts['src']

    else:
        dicts['src'] = init_vocab('source', opt.train_src, opt.src_vocab,
                                  opt.src_vocab_size, input_type=opt.input_type)

        dicts['tgt'] = init_vocab('target', opt.train_tgt, opt.tgt_vocab,
                                  opt.tgt_vocab_size, input_type=opt.input_type)


    if opt.lm:
        print('Preparing language model training data ...')
        train = dict()
        train['tgt'] = make_lm_data(opt.train_tgt, dicts['tgt'])
        train['src'] = None

        valid = dict()
        valid['tgt'] = make_lm_data(opt.valid_tgt, dicts['tgt'])
        valid['src'] = None

    elif opt.asr:
        print('Preparing acoustic model training data ...')
        train = dict()
        train['src'], train['tgt'] = make_asr_data(
            opt.train_src, opt.train_tgt, dicts['tgt'],
            max_src_length=opt.src_seq_length,
            max_tgt_length=opt.tgt_seq_length,
            input_type=opt.input_type,
            stride=opt.stride, concat=opt.concat,
            prev_context=opt.previous_context,
            fp16=opt.fp16,
            reshape=(opt.reshape_speech == 1),
            asr_format=opt.asr_format)

        print('Preparing validation ...')
        valid = dict()
        valid['src'], valid['tgt'] = make_asr_data(
            opt.valid_src, opt.valid_tgt, dicts['tgt'],
            max_src_length=max(1024, opt.src_seq_length),
            max_tgt_length=max(1024, opt.tgt_seq_length),
            input_type=opt.input_type,
            stride=opt.stride, concat=opt.concat,
            prev_context=opt.previous_context,
            fp16=opt.fp16,
            reshape=(opt.reshape_speech == 1),
            asr_format=opt.asr_format)

    else:
        print('Preparing translation model training data ...')
        train = dict()
        train['src'], train['tgt'] = make_translation_data(
            opt.train_src, opt.train_tgt,
            dicts['src'], dicts['tgt'],
            max_src_length=opt.src_seq_length,
            max_tgt_length=opt.tgt_seq_length,
            sort_by_target=opt.sort_by_target,
            input_type=opt.input_type)

        print('Preparing validation ...')
        valid = dict()
        valid['src'], valid['tgt'] = make_translation_data(
            opt.valid_src, opt.valid_tgt,
            dicts['src'], dicts['tgt'],
            max_src_length=max(1024, opt.src_seq_length),
            max_tgt_length=max(1024, opt.tgt_seq_length),
            input_type=opt.input_type)

    if opt.src_vocab is None and not opt.asr and not opt.lm:
        save_vocabulary('source', dicts['src'], opt.save_data + '.src.dict')
    if opt.tgt_vocab is None:
        save_vocabulary('target', dicts['tgt'], opt.save_data + '.tgt.dict')


    if opt.format == 'raw':

        print('Saving data to \'' + opt.save_data + '.train.pt\'...')
        save_data = {'dicts': dicts,
                     'type':  opt.src_type,
                     'train': train,
                     'valid': valid}
        torch.save(save_data, opt.save_data + '.train.pt')
        print("Done")

    elif opt.format == 'bin':
        print('Saving data to indexed data files')

        if opt.asr:
            raise AssertionError(
                "ASR data format isn't compatible with binary")
        # save dicts in this format
        torch.save(dicts, opt.save_data + '.dict.pt')

        # binarize the training set first
        for side in ['src', 'tgt']:

            if train[side] is None:
                continue

            dtype = np.int32

            if side == 'src' and opt.asr:
                dtype = np.double

            data = IndexedDatasetBuilder(opt.save_data + ".train.%s.bin" % side,
                                         dtype=dtype)

            # add each item from the training data to the indexed dataset
            for tensor in train[side]:
                data.add_item(tensor)

            data.finalize(opt.save_data + ".train.%s.idx" % side)

        # binarize the validation set
        for side in ['src', 'tgt']:

            if valid[side] is None:
                continue

            dtype = np.int32

            if side == 'src' and opt.asr:
                dtype = np.double

            data = IndexedDatasetBuilder(opt.save_data + ".valid.%s.bin" % side,
                                         dtype=dtype)

            # add each item from the validation data to the indexed dataset
            for tensor in valid[side]:
                data.add_item(tensor)

            data.finalize(opt.save_data + ".valid.%s.idx" % side)

        print("Done")

    else:
        raise NotImplementedError
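
The 'bin' branches above only exercise two methods on the builder: add_item and finalize. As a rough illustration of what such a .bin/.idx pair contains, here is a toy re-implementation of that interface; it mirrors the calling convention only, not the actual on-disk layout of the project's IndexedDatasetBuilder:

import numpy as np


class ToyIndexedBuilder:
    """Concatenate items into one flat .bin file; record offsets in .idx."""

    def __init__(self, bin_path, dtype=np.int32):
        self.out = open(bin_path, 'wb')
        self.dtype = dtype
        self.offsets = [0]  # element offset at which each item starts

    def add_item(self, tensor):
        arr = np.asarray(tensor, dtype=self.dtype)
        self.out.write(arr.tobytes())
        self.offsets.append(self.offsets[-1] + arr.size)

    def finalize(self, idx_path):
        self.out.close()
        np.asarray(self.offsets, dtype=np.int64).tofile(idx_path)


# usage, mirroring the loops above ('demo' stands in for opt.save_data)
builder = ToyIndexedBuilder('demo.train.tgt.bin')
for sent in [[2, 14, 7, 3], [2, 9, 3]]:
    builder.add_item(sent)
builder.finalize('demo.train.tgt.idx')
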
Example #3
import torch

# `opt` (the parsed command-line options) and the helpers initVocabulary,
# makeData, saveVocabulary and IndexedDatasetBuilder are defined elsewhere
# in the same script.


def main():

    dicts = {}

    if opt.join_vocab:
        dicts['src'] = initVocabulary('source', [opt.train_src, opt.train_tgt],
                                      opt.src_vocab,
                                      opt.tgt_vocab_size,
                                      join=True,
                                      input_type=opt.input_type)
        dicts['tgt'] = dicts['src']
    else:
        dicts['src'] = initVocabulary('source',
                                      opt.train_src,
                                      opt.src_vocab,
                                      opt.src_vocab_size,
                                      input_type=opt.input_type)

        dicts['tgt'] = initVocabulary('target',
                                      opt.train_tgt,
                                      opt.tgt_vocab,
                                      opt.tgt_vocab_size,
                                      input_type=opt.input_type)
    train = {}
    valid = {}

    print('Preparing training ...')

    train['src'], train['tgt'] = makeData(opt.train_src,
                                          opt.train_tgt,
                                          dicts['src'],
                                          dicts['tgt'],
                                          max_src_length=opt.src_seq_length,
                                          max_tgt_length=opt.tgt_seq_length,
                                          sort_by_target=opt.sort_by_target,
                                          input_type=opt.input_type)

    print('Preparing validation ...')

    valid['src'], valid['tgt'] = makeData(
        opt.valid_src,
        opt.valid_tgt,
        dicts['src'],
        dicts['tgt'],
        max_src_length=max(1024, opt.src_seq_length),
        max_tgt_length=max(1024, opt.tgt_seq_length),
        input_type=opt.input_type)

    if opt.format == 'raw':

        if opt.src_vocab is None:
            saveVocabulary('source', dicts['src'], opt.save_data + '.src.dict')
        if opt.tgt_vocab is None:
            saveVocabulary('target', dicts['tgt'], opt.save_data + '.tgt.dict')

        print('Saving data to \'' + opt.save_data + '.train.pt\'...')
        save_data = {
            'dicts': dicts,
            'type': opt.src_type,
            'train': train,
            'valid': valid
        }
        torch.save(save_data, opt.save_data + '.train.pt')
        print("Done")
    elif opt.format == 'bin':
        # save the dictionary first
        torch.save(dicts, opt.save_data + '.dict.pt')

        train_bin = dict()

        # binarize the data now
        for side in ['src', 'tgt']:
            train_bin[side] = IndexedDatasetBuilder(opt.save_data +
                                                    ".train.%s.bin" % side)

            for tensor_sent in train[side]:
                train_bin[side].add_item(tensor_sent)

            train_bin[side].finalize(opt.save_data + ".train.%s.idx" % side)

        valid_bin = dict()
        for side in ['src', 'tgt']:
            valid_bin[side] = IndexedDatasetBuilder(opt.save_data +
                                                    ".valid.%s.bin" % side)

            for tensor_sent in valid[side]:
                valid_bin[side].add_item(tensor_sent)

            valid_bin[side].finalize(opt.save_data + ".valid.%s.idx" % side)

        print("Done")
    else:
        raise NotImplementedError
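
To close the loop on the indexed format, a counterpart to the toy builder sketched after Example #2: read item i back by slicing the flat .bin file at the offsets stored in .idx. Again an illustration of the mechanism under the same assumptions, not the project's actual reader:

import numpy as np
import torch


def toy_read_item(bin_path, idx_path, i, dtype=np.int32):
    # element offsets written by ToyIndexedBuilder.finalize
    offsets = np.fromfile(idx_path, dtype=np.int64)
    # memory-map the flat data file and copy out the i-th slice
    data = np.memmap(bin_path, dtype=dtype, mode='r')
    return torch.from_numpy(np.array(data[offsets[i]:offsets[i + 1]]))


print(toy_read_item('demo.train.tgt.bin', 'demo.train.tgt.idx', 0))
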