示例#1
0
def main():
    """Train a language model from a ``raw`` (torch-pickled) dataset.

    Every setting is read from the module-level ``opt`` namespace.

    The function:
      1. loads ``opt.data`` with ``torch.load``;
      2. wraps its ``train``/``valid`` splits in ``onmt.Dataset``;
      3. builds the model with ``build_language_model``;
      4. trains with ``FP16XETrainer`` (when ``opt.fp16``) or ``XETrainer``.

    Raises:
        NotImplementedError: if ``opt.data_format`` is not ``'raw'``, or if
            more than one (real or virtual) GPU is requested.
    """
    t0 = time.time()
    print("Loading data from '%s'" % opt.data)

    # Only the torch-pickled format is supported by this entry point.
    if opt.data_format != 'raw':
        raise NotImplementedError

    dataset = torch.load(opt.data)
    elapsed = str(datetime.timedelta(seconds=int(time.time() - t0)))
    print("Done after %s" % elapsed)

    # Older dumps may lack a "type" entry; treat those as text.
    input_type = dataset.get("type", "text")

    train_data = onmt.Dataset(dataset['train']['src'],
                              dataset['train']['tgt'],
                              opt.batch_size_words,
                              data_type=input_type,
                              batch_size_sents=opt.batch_size_sents,
                              multiplier=opt.batch_size_multiplier,
                              sort_by_target=opt.sort_by_target)
    valid_data = onmt.Dataset(dataset['valid']['src'],
                              dataset['valid']['tgt'],
                              opt.batch_size_words,
                              data_type=input_type,
                              batch_size_sents=opt.batch_size_sents)

    dicts = dataset['dicts']
    if "src" not in dicts:
        print(' * vocabulary size. target = %d' % (dicts['tgt'].size()))
    else:
        print(' * vocabulary size. source = %d; target = %d' %
              (dicts['src'].size(), dicts['tgt'].size()))

    print(' * number of training sentences. %d' % train_data.size())
    print(' * maximum batch size (words per batch). %d' % opt.batch_size_words)

    print('Building model...')
    model = build_language_model(opt, dicts)

    # Loss over the target vocabulary; smoothing strength comes from opt.label_smoothing.
    loss_function = NMTLossFunc(dicts['tgt'].size(),
                                label_smoothing=opt.label_smoothing)

    n_params = sum(p.nelement() for p in model.parameters())
    print('* number of parameters: %d' % n_params)

    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        raise NotImplementedError("Warning! Multi-GPU training is not fully tested and potential bugs can happen.")

    trainer_cls = FP16XETrainer if opt.fp16 else XETrainer
    trainer = trainer_cls(model, loss_function, train_data, valid_data, dicts, opt)
    trainer.run(save_file=opt.load_from)
示例#2
0
def main():
    """Train a language model on the target side of a ``raw`` dataset.

    Every setting is read from the module-level ``opt`` namespace.

    The ``train`` and ``valid`` splits are wrapped in ``LanguageModelDataset``
    objects built from their ``tgt`` entries only, using
    ``seq_length=opt.lm_seq_length`` (exact chunking semantics live in
    ``LanguageModelDataset``; TODO confirm). The model is built by
    ``build_language_model`` and trained with ``XETrainer``.

    Raises:
        NotImplementedError: if ``opt.data_format`` is not ``'raw'``, or if
            more than one (real or virtual) GPU is requested.
    """
    load_start = time.time()
    print("Loading data from '%s'" % opt.data)

    if opt.data_format != 'raw':
        raise NotImplementedError

    dataset = torch.load(opt.data)
    elapsed = str(datetime.timedelta(seconds=int(time.time() - load_start)))
    print("Done after %s" % elapsed)

    def lm_split(name):
        # Only the target-side data of the split is used.
        return LanguageModelDataset(dataset[name]['tgt'],
                                    batch_size_sents=opt.batch_size_sents,
                                    seq_length=opt.lm_seq_length)

    train_data = lm_split('train')
    valid_data = lm_split('valid')

    dicts = dataset['dicts']
    if "src" not in dicts:
        print(' * vocabulary size. target = %d' % (dicts['tgt'].size()))
    else:
        print(' * vocabulary size. source = %d; target = %d' %
              (dicts['src'].size(), dicts['tgt'].size()))

    print(' * number of training sentences. %d' % train_data.size())
    print(' * maximum batch size (words per batch). %d' % opt.batch_size_words)

    print('Building model...')
    model = build_language_model(opt, dicts)
    print(model)

    # Loss over the target vocabulary; smoothing strength comes from opt.label_smoothing.
    loss_function = NMTLossFunc(dicts['tgt'].size(),
                                label_smoothing=opt.label_smoothing)

    n_params = sum(p.nelement() for p in model.parameters())
    print('* number of parameters: %d' % n_params)

    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        raise NotImplementedError("Multi-GPU training is not supported ATM.")

    # Mixed precision (FP16XETrainer) is not wired up here: always use XETrainer.
    trainer = XETrainer(model, loss_function, train_data, valid_data, dicts, opt)
    trainer.run(save_file=opt.load_from)
示例#3
0
def main():
    """Train a source/target model (or a fusion model) from ``raw`` or ``bin`` data.

    Every setting is read from the module-level ``opt`` namespace.

    * ``'raw'``: ``opt.data`` is a ``*.train.pt`` file, or the prefix of one.
      It is loaded with ``torch.load`` and provides ``train``, ``valid`` and
      ``dicts`` entries.
    * ``'bin'``: ``opt.data`` is a prefix for ``.dict.pt`` plus the indexed
      ``.train.src/.tgt`` and ``.valid.src/.tgt`` files.

    A regular model uses ``NMTAndCTCLossFunc`` when ``opt.ctc_loss`` is
    non-zero and ``NMTLossFunc`` otherwise. With ``opt.fusion`` set,
    ``build_fusion``/``FusionLoss`` are used instead.

    Raises:
        NotImplementedError: for any other ``opt.data_format``, or when more
            than one (real or virtual) GPU is requested.
    """
    fmt = opt.data_format

    if fmt == 'raw':
        t0 = time.time()
        # Accept either the full "*.train.pt" file name or just its prefix.
        if opt.data.endswith(".train.pt"):
            pt_file = opt.data
            print("Loading data from '%s'" % pt_file)
        else:
            pt_file = opt.data + ".train.pt"
            print("Loading data from %s" % pt_file)
        dataset = torch.load(pt_file)

        elapsed = str(datetime.timedelta(seconds=int(time.time() - t0)))
        print("Done after %s" % elapsed)

        input_type = dataset.get("type", "text")
        train_data = onmt.Dataset(dataset['train']['src'],
                                  dataset['train']['tgt'],
                                  opt.batch_size_words,
                                  data_type=input_type,
                                  batch_size_sents=opt.batch_size_sents,
                                  multiplier=opt.batch_size_multiplier,
                                  reshape_speech=opt.reshape_speech,
                                  augment=opt.augment_speech)
        valid_data = onmt.Dataset(dataset['valid']['src'],
                                  dataset['valid']['tgt'],
                                  opt.batch_size_words,
                                  data_type=input_type,
                                  batch_size_sents=opt.batch_size_sents,
                                  reshape_speech=opt.reshape_speech)

        dicts = dataset['dicts']
        if "src" not in dicts:
            print(' * vocabulary size. target = %d' % (dicts['tgt'].size()))
        else:
            print(' * vocabulary size. source = %d; target = %d' %
                  (dicts['src'].size(), dicts['tgt'].size()))

        print(' * number of training sentences. %d' % len(dataset['train']['src']))
        print(' * maximum batch size (words per batch). %d' % opt.batch_size_words)

    elif fmt == 'bin':
        from onmt.data_utils.IndexedDataset import IndexedInMemoryDataset

        dicts = torch.load(opt.data + ".dict.pt")

        def indexed_split(split):
            # (source, target) indexed datasets for one split of the corpus.
            prefix = opt.data + '.' + split
            return (IndexedInMemoryDataset(prefix + '.src'),
                    IndexedInMemoryDataset(prefix + '.tgt'))

        src, tgt = indexed_split('train')
        train_data = onmt.Dataset(src, tgt, opt.batch_size_words,
                                  data_type=opt.encoder_type,
                                  batch_size_sents=opt.batch_size_sents,
                                  multiplier=opt.batch_size_multiplier)

        src, tgt = indexed_split('valid')
        valid_data = onmt.Dataset(src, tgt, opt.batch_size_words,
                                  data_type=opt.encoder_type,
                                  batch_size_sents=opt.batch_size_sents)

    else:
        raise NotImplementedError

    print('Building model...')

    if opt.fusion:
        from onmt.ModelConstructor import build_fusion
        from onmt.modules.Loss import FusionLoss

        model = build_fusion(opt, dicts)
        loss_function = FusionLoss(dicts['tgt'].size(),
                                   label_smoothing=opt.label_smoothing)
    else:
        model = build_model(opt, dicts)
        # A non-zero opt.ctc_loss selects the combined NMT + CTC loss.
        if opt.ctc_loss != 0:
            loss_function = NMTAndCTCLossFunc(dicts['tgt'].size(),
                                              label_smoothing=opt.label_smoothing,
                                              ctc_weight=opt.ctc_loss)
        else:
            loss_function = NMTLossFunc(dicts['tgt'].size(),
                                        label_smoothing=opt.label_smoothing)

    n_params = sum(p.nelement() for p in model.parameters())
    print('* number of parameters: %d' % n_params)

    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        raise NotImplementedError("Warning! Multi-GPU training is not fully tested and potential bugs can happen.")

    trainer_cls = FP16XETrainer if opt.fp16 else XETrainer
    trainer = trainer_cls(model, loss_function, train_data, valid_data, dicts, opt)
    trainer.run(save_file=opt.load_from)
示例#4
0
def main():
    """Train a source/target model, optionally mixing in additional corpora.

    Every setting is read from the module-level ``opt`` namespace.

    Steps:
      1. Load the primary data. ``opt.data_format`` is ``'raw'`` (a
         torch-pickled ``*.train.pt`` file or its prefix) or ``'bin'``
         (indexed ``.src``/``.tgt`` files plus a ``.dict.pt`` vocabulary).
      2. Load each extra corpus listed in ``opt.additional_data``. Entries are
         ``;``-separated, with one matching format per entry in
         ``opt.additional_data_format``. For ``raw`` extras, their
         vocabularies are checked against, or merged into, ``dicts``.
      3. When ``opt.load_from`` is set, take the dictionaries from that
         checkpoint. Otherwise patch the target vocabulary with
         ``opt.patch_vocab_multiplier``.
      4. Build the model and loss, then run a single-GPU ``XETrainer``.

    Raises:
        NotImplementedError: for an unsupported primary or additional data
            format, or when more than one (real or virtual) GPU is requested.
        ValueError: when ``opt.additional_data`` and
            ``opt.additional_data_format`` have different entry counts, or
            when an additional raw corpus has a vocabulary whose size differs
            from the primary one.
    """

    def _load_train_pt(path):
        # Accept either the full "*.train.pt" file name or just its prefix.
        if path.endswith(".train.pt"):
            print("Loading data from '%s'" % path)
            return torch.load(path)  # This requires a lot of cpu memory!
        print("Loading data from %s" % path + ".train.pt")
        return torch.load(path + ".train.pt")

    def _raw_train_dataset(data):
        # Training split of a torch-pickled dataset, with the speech options forwarded.
        return onmt.Dataset(data['train']['src'],
                            data['train']['tgt'],
                            opt.batch_size_words,
                            data_type=data.get("type", "text"),
                            batch_size_sents=opt.batch_size_sents,
                            multiplier=opt.batch_size_multiplier,
                            reshape_speech=opt.reshape_speech,
                            augment=opt.augment_speech)

    def _bin_train_dataset(prefix):
        # Training split stored as "<prefix>.train.src" / "<prefix>.train.tgt" indexed files.
        from onmt.data_utils.IndexedDataset import IndexedInMemoryDataset

        train_path = prefix + '.train'
        train_src = IndexedInMemoryDataset(train_path + '.src')
        train_tgt = IndexedInMemoryDataset(train_path + '.tgt')
        return onmt.Dataset(train_src,
                            train_tgt,
                            opt.batch_size_words,
                            data_type=opt.encoder_type,
                            batch_size_sents=opt.batch_size_sents,
                            multiplier=opt.batch_size_multiplier)

    # ---- primary data ----------------------------------------------------
    if opt.data_format == 'raw':
        start = time.time()
        dataset = _load_train_pt(opt.data)
        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse)

        train_data = _raw_train_dataset(dataset)
        valid_data = onmt.Dataset(dataset['valid']['src'],
                                  dataset['valid']['tgt'],
                                  opt.batch_size_words,
                                  data_type=dataset.get("type", "text"),
                                  batch_size_sents=opt.batch_size_sents,
                                  reshape_speech=opt.reshape_speech)

        dicts = dataset['dicts']

        print(' * number of training sentences. %d' %
              len(dataset['train']['src']))
        print(' * maximum batch size (words per batch). %d' %
              opt.batch_size_words)

    elif opt.data_format == 'bin':
        from onmt.data_utils.IndexedDataset import IndexedInMemoryDataset

        dicts = torch.load(opt.data + ".dict.pt")
        train_data = _bin_train_dataset(opt.data)

        valid_path = opt.data + '.valid'
        valid_src = IndexedInMemoryDataset(valid_path + '.src')
        valid_tgt = IndexedInMemoryDataset(valid_path + '.tgt')
        valid_data = onmt.Dataset(valid_src,
                                  valid_tgt,
                                  opt.batch_size_words,
                                  data_type=opt.encoder_type,
                                  batch_size_sents=opt.batch_size_sents)

    else:
        raise NotImplementedError

    # ---- additional training corpora -------------------------------------
    additional_data = []
    if opt.additional_data != "none":
        add_data = opt.additional_data.split(";")
        add_format = opt.additional_data_format.split(";")
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(add_data) != len(add_format):
            raise ValueError(
                "additional_data has %d entries but additional_data_format has %d"
                % (len(add_data), len(add_format)))

        for data_path, data_format in zip(add_data, add_format):
            if data_format == 'raw':
                add_dataset = _load_train_pt(data_path)
                additional_data.append(_raw_train_dataset(add_dataset))

                # Vocabularies must agree. A side (src/tgt) present in both must
                # have the same size; a side only the extra corpus has is adopted.
                add_dicts = add_dataset['dicts']
                for side in ('src', 'tgt'):
                    if side not in add_dicts:
                        continue
                    if side in dicts:
                        if dicts[side].size() != add_dicts[side].size():
                            raise ValueError(
                                "'%s' vocabulary of additional data '%s' has size %d, expected %d"
                                % (side, data_path, add_dicts[side].size(), dicts[side].size()))
                    else:
                        dicts[side] = add_dicts[side]

            elif data_format == 'bin':
                additional_data.append(_bin_train_dataset(data_path))

            else:
                # Match the primary-data handling: an unknown format is an error,
                # not something to skip silently (which would desynchronise
                # additional_data from the configured list).
                raise NotImplementedError(
                    "Unsupported additional data format '%s' for '%s'"
                    % (data_format, data_path))

    # ---- restore from checkpoint -----------------------------------------
    if opt.load_from:
        # map_location returning the storage unchanged keeps tensors on CPU.
        checkpoint = torch.load(opt.load_from,
                                map_location=lambda storage, loc: storage)
        print("* Loading dictionaries from the checkpoint")
        dicts = checkpoint['dicts']
    else:
        dicts['tgt'].patch(opt.patch_vocab_multiplier)
        checkpoint = None

    if "src" in dicts:
        print(' * vocabulary size. source = %d; target = %d' %
              (dicts['src'].size(), dicts['tgt'].size()))
    else:
        print(' * vocabulary size. target = %d' % (dicts['tgt'].size()))

    # ---- model and loss ---------------------------------------------------
    print('Building model...')

    if not opt.fusion:
        model = build_model(opt, dicts)
        # A non-zero opt.ctc_loss selects the combined NMT + CTC loss.
        if opt.ctc_loss != 0:
            loss_function = NMTAndCTCLossFunc(
                dicts['tgt'].size(),
                label_smoothing=opt.label_smoothing,
                ctc_weight=opt.ctc_loss)
        else:
            loss_function = NMTLossFunc(dicts['tgt'].size(),
                                        label_smoothing=opt.label_smoothing)
    else:
        from onmt.ModelConstructor import build_fusion
        from onmt.modules.Loss import FusionLoss

        model = build_fusion(opt, dicts)
        loss_function = FusionLoss(dicts['tgt'].size(),
                                   label_smoothing=opt.label_smoothing)

    n_params = sum(p.nelement() for p in model.parameters())
    print('* number of parameters: %d' % n_params)

    # ---- training --------------------------------------------------------
    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        raise NotImplementedError(
            "Warning! Multi-GPU training is not fully tested and potential bugs can happen."
        )

    trainer = XETrainer(model, loss_function, train_data, valid_data,
                        dicts, opt)
    if additional_data:
        trainer.add_additional_data(additional_data, opt.data_ratio)

    trainer.run(checkpoint=checkpoint)
示例#5
0
def main():
    """Train a source/target model from ``raw`` or ``bin`` preprocessed data.

    Every setting is read from the module-level ``opt`` namespace.

    * ``'raw'``: ``opt.data`` is loaded with ``torch.load`` and provides
      ``train``, ``valid`` and ``dicts`` entries.
    * ``'bin'``: ``opt.data`` is a prefix for ``.dict.pt`` plus the indexed
      ``.train.src/.tgt`` and ``.valid.src/.tgt`` files.

    Training uses ``FP16XETrainer`` when ``opt.fp16`` is set, otherwise
    ``XETrainer``.

    Raises:
        NotImplementedError: for any other ``opt.data_format``, or when more
            than one (real or virtual) GPU is requested.
    """
    start = time.time()
    print("Loading data from '%s'" % opt.data)

    if opt.data_format == 'raw':
        dataset = torch.load(opt.data)
        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse)

        train_data = onmt.Dataset(dataset['train']['src'],
                                  dataset['train']['tgt'],
                                  opt.batch_size_words,
                                  opt.gpus,
                                  max_seq_num=opt.batch_size_sents,
                                  pad_count=opt.pad_count,
                                  multiplier=opt.batch_size_multiplier,
                                  sort_by_target=opt.sort_by_target)
        valid_data = onmt.Dataset(dataset['valid']['src'],
                                  dataset['valid']['tgt'],
                                  opt.batch_size_words,
                                  opt.gpus,
                                  max_seq_num=opt.batch_size_sents)

        dicts = dataset['dicts']
        # Guard the source vocabulary as the sibling entry points do, so
        # target-only dictionaries do not raise KeyError here.
        if "src" in dicts:
            print(' * vocabulary size. source = %d; target = %d' %
                  (dicts['src'].size(), dicts['tgt'].size()))
        else:
            print(' * vocabulary size. target = %d' % (dicts['tgt'].size()))
        print(' * number of training sentences. %d' %
              len(dataset['train']['src']))
        print(' * maximum batch size (words per batch). %d' %
              opt.batch_size_words)

    elif opt.data_format == 'bin':
        from onmt.data_utils.IndexedDataset import IndexedInMemoryDataset

        dicts = torch.load(opt.data + ".dict.pt")

        train_path = opt.data + '.train'
        train_src = IndexedInMemoryDataset(train_path + '.src')
        train_tgt = IndexedInMemoryDataset(train_path + '.tgt')

        train_data = onmt.Dataset(train_src,
                                  train_tgt,
                                  opt.batch_size_words,
                                  opt.gpus,
                                  max_seq_num=opt.batch_size_sents,
                                  pad_count=opt.pad_count,
                                  multiplier=opt.batch_size_multiplier,
                                  sort_by_target=opt.sort_by_target)

        valid_path = opt.data + '.valid'
        valid_src = IndexedInMemoryDataset(valid_path + '.src')
        valid_tgt = IndexedInMemoryDataset(valid_path + '.tgt')

        valid_data = onmt.Dataset(valid_src,
                                  valid_tgt,
                                  opt.batch_size_words,
                                  opt.gpus,
                                  max_seq_num=opt.batch_size_sents)

    else:
        raise NotImplementedError

    print('Building model...')
    model = build_model(opt, dicts)

    # Loss over the target vocabulary; smoothing strength comes from opt.label_smoothing.
    loss_function = NMTLossFunc(dicts['tgt'].size(),
                                label_smoothing=opt.label_smoothing)

    n_params = sum(p.nelement() for p in model.parameters())
    print('* number of parameters: %d' % n_params)

    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        raise NotImplementedError(
            "Warning! Multi-GPU training is not fully tested and potential bugs can happen."
        )

    if opt.fp16:
        trainer = FP16XETrainer(model, loss_function, train_data, valid_data,
                                dicts, opt)
    else:
        trainer = XETrainer(model, loss_function, train_data, valid_data,
                            dicts, opt)

    trainer.run(save_file=opt.load_from)
示例#6
0
def main():

    if not opt.multi_dataset:
        if opt.data_format in ['bin', 'raw']:
            start = time.time()

            if opt.data.endswith(".train.pt"):
                print("Loading data from '%s'" % opt.data)
                dataset = torch.load(opt.data)
            else:
                print("Loading data from %s" % opt.data + ".train.pt")
                dataset = torch.load(opt.data + ".train.pt")

            elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
            print("Done after %s" % elapse)

            dicts = dataset['dicts']

            # For backward compatibility
            train_dict = defaultdict(lambda: None, dataset['train'])
            valid_dict = defaultdict(lambda: None, dataset['valid'])

            if train_dict['src_lang'] is not None:
                assert 'langs' in dicts
                train_src_langs = train_dict['src_lang']
                train_tgt_langs = train_dict['tgt_lang']
            else:
                # allocate new languages
                dicts['langs'] = {'src': 0, 'tgt': 1}
                train_src_langs = list()
                train_tgt_langs = list()
                # Allocation one for the bilingual case
                train_src_langs.append(torch.Tensor([dicts['langs']['src']]))
                train_tgt_langs.append(torch.Tensor([dicts['langs']['tgt']]))

            if not opt.streaming:
                train_data = onmt.Dataset(
                    numpy_to_torch(train_dict['src']),
                    numpy_to_torch(train_dict['tgt']),
                    train_dict['src_sizes'],
                    train_dict['tgt_sizes'],
                    train_src_langs,
                    train_tgt_langs,
                    batch_size_words=opt.batch_size_words,
                    data_type=dataset.get("type", "text"),
                    sorting=True,
                    batch_size_sents=opt.batch_size_sents,
                    multiplier=opt.batch_size_multiplier,
                    augment=opt.augment_speech,
                    upsampling=opt.upsampling,
                    num_split=len(opt.gpus))
            else:
                train_data = onmt.StreamDataset(
                    train_dict['src'],
                    train_dict['tgt'],
                    train_src_langs,
                    train_tgt_langs,
                    batch_size_words=opt.batch_size_words,
                    data_type=dataset.get("type", "text"),
                    sorting=True,
                    batch_size_sents=opt.batch_size_sents,
                    multiplier=opt.batch_size_multiplier,
                    augment=opt.augment_speech,
                    upsampling=opt.upsampling)

            if valid_dict['src_lang'] is not None:
                assert 'langs' in dicts
                valid_src_langs = valid_dict['src_lang']
                valid_tgt_langs = valid_dict['tgt_lang']
            else:
                # allocate new languages
                valid_src_langs = list()
                valid_tgt_langs = list()

                # Allocation one for the bilingual case
                valid_src_langs.append(torch.Tensor([dicts['langs']['src']]))
                valid_tgt_langs.append(torch.Tensor([dicts['langs']['tgt']]))

            if not opt.streaming:
                valid_data = onmt.Dataset(
                    numpy_to_torch(valid_dict['src']),
                    numpy_to_torch(valid_dict['tgt']),
                    valid_dict['src_sizes'],
                    valid_dict['tgt_sizes'],
                    valid_src_langs,
                    valid_tgt_langs,
                    batch_size_words=opt.batch_size_words,
                    data_type=dataset.get("type", "text"),
                    sorting=True,
                    batch_size_sents=opt.batch_size_sents,
                    upsampling=opt.upsampling,
                    num_split=len(opt.gpus))
            else:
                valid_data = onmt.StreamDataset(
                    numpy_to_torch(valid_dict['src']),
                    numpy_to_torch(valid_dict['tgt']),
                    valid_src_langs,
                    valid_tgt_langs,
                    batch_size_words=opt.batch_size_words,
                    data_type=dataset.get("type", "text"),
                    sorting=True,
                    batch_size_sents=opt.batch_size_sents,
                    upsampling=opt.upsampling)

            print(' * number of training sentences. %d' %
                  len(dataset['train']['src']))
            print(' * maximum batch size (words per batch). %d' %
                  opt.batch_size_words)

        elif opt.data_format in ['scp', 'scpmem', 'mmem']:
            print("Loading memory mapped data files ....")
            start = time.time()
            from onmt.data.mmap_indexed_dataset import MMapIndexedDataset
            from onmt.data.scp_dataset import SCPIndexDataset

            dicts = torch.load(opt.data + ".dict.pt")
            if opt.data_format in ['scp', 'scpmem']:
                audio_data = torch.load(opt.data + ".scp_path.pt")

            # allocate languages if not
            if 'langs' not in dicts:
                dicts['langs'] = {'src': 0, 'tgt': 1}
            else:
                print(dicts['langs'])

            train_path = opt.data + '.train'
            if opt.data_format in ['scp', 'scpmem']:
                train_src = SCPIndexDataset(audio_data['train'],
                                            concat=opt.concat)
            else:
                train_src = MMapIndexedDataset(train_path + '.src')

            train_tgt = MMapIndexedDataset(train_path + '.tgt')

            # check the lang files if they exist (in the case of multi-lingual models)
            if os.path.exists(train_path + '.src_lang.bin'):
                assert 'langs' in dicts
                train_src_langs = MMapIndexedDataset(train_path + '.src_lang')
                train_tgt_langs = MMapIndexedDataset(train_path + '.tgt_lang')
            else:
                train_src_langs = list()
                train_tgt_langs = list()
                # Allocate a Tensor(1) for the bilingual case
                train_src_langs.append(torch.Tensor([dicts['langs']['src']]))
                train_tgt_langs.append(torch.Tensor([dicts['langs']['tgt']]))

            # check the length files if they exist
            if os.path.exists(train_path + '.src_sizes.npy'):
                train_src_sizes = np.load(train_path + '.src_sizes.npy')
                train_tgt_sizes = np.load(train_path + '.tgt_sizes.npy')
            else:
                train_src_sizes, train_tgt_sizes = None, None

            if opt.encoder_type == 'audio':
                data_type = 'audio'
            else:
                data_type = 'text'

            if not opt.streaming:
                train_data = onmt.Dataset(
                    train_src,
                    train_tgt,
                    train_src_sizes,
                    train_tgt_sizes,
                    train_src_langs,
                    train_tgt_langs,
                    batch_size_words=opt.batch_size_words,
                    data_type=data_type,
                    sorting=True,
                    batch_size_sents=opt.batch_size_sents,
                    multiplier=opt.batch_size_multiplier,
                    src_align_right=opt.src_align_right,
                    augment=opt.augment_speech,
                    upsampling=opt.upsampling,
                    cleaning=True,
                    verbose=True,
                    num_split=len(opt.gpus))
            else:
                train_data = onmt.StreamDataset(
                    train_src,
                    train_tgt,
                    train_src_langs,
                    train_tgt_langs,
                    batch_size_words=opt.batch_size_words,
                    data_type=data_type,
                    sorting=False,
                    batch_size_sents=opt.batch_size_sents,
                    multiplier=opt.batch_size_multiplier,
                    upsampling=opt.upsampling)

            valid_path = opt.data + '.valid'
            if opt.data_format in ['scp', 'scpmem']:
                valid_src = SCPIndexDataset(audio_data['valid'],
                                            concat=opt.concat)
            else:
                valid_src = MMapIndexedDataset(valid_path + '.src')
            valid_tgt = MMapIndexedDataset(valid_path + '.tgt')

            if os.path.exists(valid_path + '.src_lang.bin'):
                assert 'langs' in dicts
                valid_src_langs = MMapIndexedDataset(valid_path + '.src_lang')
                valid_tgt_langs = MMapIndexedDataset(valid_path + '.tgt_lang')
            else:
                valid_src_langs = list()
                valid_tgt_langs = list()

                # Allocation one for the bilingual case
                valid_src_langs.append(torch.Tensor([dicts['langs']['src']]))
                valid_tgt_langs.append(torch.Tensor([dicts['langs']['tgt']]))

            # check the length files if they exist
            if os.path.exists(valid_path + '.src_sizes.npy'):
                valid_src_sizes = np.load(valid_path + '.src_sizes.npy')
                valid_tgt_sizes = np.load(valid_path + '.tgt_sizes.npy')
            else:
                valid_src_sizes, valid_tgt_sizes = None, None

            if not opt.streaming:
                valid_data = onmt.Dataset(
                    valid_src,
                    valid_tgt,
                    valid_src_sizes,
                    valid_tgt_sizes,
                    valid_src_langs,
                    valid_tgt_langs,
                    batch_size_words=opt.batch_size_words,
                    data_type=data_type,
                    sorting=True,
                    batch_size_sents=opt.batch_size_sents,
                    src_align_right=opt.src_align_right,
                    cleaning=True,
                    verbose=True,
                    debug=True,
                    num_split=len(opt.gpus))
            else:
                # for validation data, we have to go through sentences (very slow but to ensure correctness)
                valid_data = onmt.StreamDataset(
                    valid_src,
                    valid_tgt,
                    valid_src_langs,
                    valid_tgt_langs,
                    batch_size_words=opt.batch_size_words,
                    data_type=data_type,
                    sorting=True,
                    batch_size_sents=opt.batch_size_sents)

            elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
            print("Done after %s" % elapse)

        else:
            raise NotImplementedError

        print(' * number of sentences in training data: %d' %
              train_data.size())
        print(' * number of sentences in validation data: %d' %
              valid_data.size())

    else:
        print("[INFO] Reading multiple dataset ...")
        # raise NotImplementedError

        dicts = torch.load(opt.data + ".dict.pt")

        root_dir = os.path.dirname(opt.data)

        print("Loading training data ...")

        train_dirs, valid_dirs = dict(), dict()

        # scan the data directory to find the training data
        for dir_ in os.listdir(root_dir):
            if os.path.isdir(os.path.join(root_dir, dir_)):
                if str(dir_).startswith("train"):
                    idx = int(dir_.split(".")[1])
                    train_dirs[idx] = dir_
                if dir_.startswith("valid"):
                    idx = int(dir_.split(".")[1])
                    valid_dirs[idx] = dir_

        train_sets, valid_sets = list(), list()

        for (idx_, dir_) in sorted(train_dirs.items()):

            data_dir = os.path.join(root_dir, dir_)
            print("[INFO] Loading training data %i from %s" % (idx_, dir_))

            if opt.data_format in ['bin', 'raw']:
                raise NotImplementedError

            elif opt.data_format in ['scp', 'scpmem', 'mmem']:
                from onmt.data.mmap_indexed_dataset import MMapIndexedDataset
                from onmt.data.scp_dataset import SCPIndexDataset

                if opt.data_format in ['scp', 'scpmem']:
                    audio_data = torch.load(
                        os.path.join(data_dir, "data.scp_path.pt"))
                    src_data = SCPIndexDataset(audio_data, concat=opt.concat)
                else:
                    src_data = MMapIndexedDataset(
                        os.path.join(data_dir, "data.src"))

                tgt_data = MMapIndexedDataset(
                    os.path.join(data_dir, "data.tgt"))

                src_lang_data = MMapIndexedDataset(
                    os.path.join(data_dir, 'data.src_lang'))
                tgt_lang_data = MMapIndexedDataset(
                    os.path.join(data_dir, 'data.tgt_lang'))

                if os.path.exists(os.path.join(data_dir,
                                               'data.src_sizes.npy')):
                    src_sizes = np.load(
                        os.path.join(data_dir, 'data.src_sizes.npy'))
                    tgt_sizes = np.load(
                        os.path.join(data_dir, 'data.tgt_sizes.npy'))
                else:
                    src_sizes, sizes = None, None

                if opt.encoder_type == 'audio':
                    data_type = 'audio'
                else:
                    data_type = 'text'

                if not opt.streaming:
                    train_data = onmt.Dataset(
                        src_data,
                        tgt_data,
                        src_sizes,
                        tgt_sizes,
                        src_lang_data,
                        tgt_lang_data,
                        batch_size_words=opt.batch_size_words,
                        data_type=data_type,
                        sorting=True,
                        batch_size_sents=opt.batch_size_sents,
                        multiplier=opt.batch_size_multiplier,
                        src_align_right=opt.src_align_right,
                        augment=opt.augment_speech,
                        upsampling=opt.upsampling,
                        cleaning=True,
                        verbose=True,
                        num_split=len(opt.gpus))

                    train_sets.append(train_data)

                else:
                    print("Multi-dataset not implemented for Streaming tasks.")
                    raise NotImplementedError

        for (idx_, dir_) in sorted(valid_dirs.items()):

            data_dir = os.path.join(root_dir, dir_)

            print("[INFO] Loading validation data %i from %s" % (idx_, dir_))

            if opt.data_format in ['bin', 'raw']:
                raise NotImplementedError

            elif opt.data_format in ['scp', 'scpmem', 'mmem']:

                if opt.data_format in ['scp', 'scpmem']:
                    audio_data = torch.load(
                        os.path.join(data_dir, "data.scp_path.pt"))
                    src_data = SCPIndexDataset(audio_data, concat=opt.concat)
                else:
                    src_data = MMapIndexedDataset(
                        os.path.join(data_dir, "data.src"))

                tgt_data = MMapIndexedDataset(
                    os.path.join(data_dir, "data.tgt"))

                src_lang_data = MMapIndexedDataset(
                    os.path.join(data_dir, 'data.src_lang'))
                tgt_lang_data = MMapIndexedDataset(
                    os.path.join(data_dir, 'data.tgt_lang'))

                if os.path.exists(os.path.join(data_dir,
                                               'data.src_sizes.npy')):
                    src_sizes = np.load(
                        os.path.join(data_dir, 'data.src_sizes.npy'))
                    tgt_sizes = np.load(
                        os.path.join(data_dir, 'data.tgt_sizes.npy'))
                else:
                    src_sizes, sizes = None, None

                if opt.encoder_type == 'audio':
                    data_type = 'audio'
                else:
                    data_type = 'text'

                if not opt.streaming:
                    valid_data = onmt.Dataset(
                        src_data,
                        tgt_data,
                        src_sizes,
                        tgt_sizes,
                        src_lang_data,
                        tgt_lang_data,
                        batch_size_words=opt.batch_size_words,
                        data_type=data_type,
                        sorting=True,
                        batch_size_sents=opt.batch_size_sents,
                        src_align_right=opt.src_align_right,
                        cleaning=True,
                        verbose=True,
                        debug=True,
                        num_split=len(opt.gpus))

                    valid_sets.append(valid_data)

                else:
                    raise NotImplementedError

        train_data = train_sets
        valid_data = valid_sets

    if opt.load_from:
        checkpoint = torch.load(opt.load_from,
                                map_location=lambda storage, loc: storage)
        print("* Loading dictionaries from the checkpoint")
        dicts = checkpoint['dicts']
    else:
        dicts['tgt'].patch(opt.patch_vocab_multiplier)
        checkpoint = None

    # Put the vocab mask from dicts to the datasets
    for data in [train_data, valid_data]:
        if isinstance(data, list):
            for i, data_ in enumerate(data):
                data_.set_mask(dicts['tgt'].vocab_mask)
                data[i] = data_
        else:
            data.set_mask(dicts['tgt'].vocab_mask)

    if "src" in dicts:
        print(' * vocabulary size. source = %d; target = %d' %
              (dicts['src'].size(), dicts['tgt'].size()))
    else:
        print('[INFO] vocabulary size. target = %d' % (dicts['tgt'].size()))

    print('* Building model...')

    if not opt.fusion:
        if opt.bayes_by_backprop:
            model = build_bayesian_model(opt, dicts)
        else:
            model = build_model(opt, dicts)
        """ Building the loss function """
        # if opt.ctc_loss != 0:
        #     pass
        #     loss_function = NMTAndCTCLossFunc(dicts['tgt'].size(),
        #                                       label_smoothing=opt.label_smoothing,
        #                                       ctc_weight=opt.ctc_loss)
        if opt.nce:
            from onmt.modules.nce.nce_loss import NCELoss
            loss_function = NCELoss(opt.model_size,
                                    dicts['tgt'].size(),
                                    noise_ratio=opt.nce_noise,
                                    logz=9,
                                    label_smoothing=opt.label_smoothing)
        else:
            loss_function = NMTLossFunc(opt.model_size,
                                        dicts['tgt'].size(),
                                        label_smoothing=opt.label_smoothing,
                                        mirror=opt.mirror_loss,
                                        fast_xentropy=opt.fast_xentropy)

        # This function replaces modules with the more optimized counterparts so that it can run faster
        # Currently exp with LayerNorm
        if not opt.memory_profiling:
            optimize_model(model, fp16=opt.fp16)

    else:
        from onmt.model_factory import build_fusion
        from onmt.modules.loss import FusionLoss

        model = build_fusion(opt, dicts)

        loss_function = FusionLoss(dicts['tgt'].size(),
                                   label_smoothing=opt.label_smoothing)

    n_params = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % n_params)

    if not opt.debugging and len(opt.gpus) == 1:
        if opt.bayes_by_backprop:

            from onmt.train_utils.bayes_by_backprop_trainer import BayesianTrainer
            trainer = BayesianTrainer(model, loss_function, train_data,
                                      valid_data, dicts, opt)

        else:
            trainer = XETrainer(model, loss_function, train_data, valid_data,
                                dicts, opt)
    else:
        from onmt.train_utils.new_trainer import Trainer
        trainer = Trainer(model, loss_function, train_data, valid_data, dicts,
                          opt)

    trainer.run(checkpoint=checkpoint)
示例#7
0
def main():
    """Train a sequence-to-sequence model from a single raw dataset file.

    Uses the module-level ``opt`` namespace: ``opt.data`` is read with
    ``torch.load`` and must contain ``'train'``, ``'valid'`` and ``'dicts'``
    entries. The model and NMT loss are built from those dictionaries, and
    training is delegated to ``MultiGPUXETrainer`` when more than one (real or
    virtual) GPU is requested, otherwise to ``XETrainer``.
    """
    start = time.time()
    print("Loading data from '%s'" % opt.data)
    dataset = torch.load(opt.data)
    elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
    print("Done after %s" % elapse)

    # NOTE: dictionaries are always taken from the dataset here; ``opt.load_from``
    # is only forwarded to ``trainer.run`` at the end of this function.

    data_type = dataset.get("type", "text")
    train_data = onmt.Dataset(dataset['train']['src'],
                              dataset['train']['tgt'],
                              opt.batch_size_words,
                              opt.gpus,
                              data_type=data_type,
                              max_seq_num=opt.batch_size_sents)
    valid_data = onmt.Dataset(dataset['valid']['src'],
                              dataset['valid']['tgt'],
                              opt.batch_size_words,
                              opt.gpus,
                              volatile=True,
                              data_type=data_type,
                              max_seq_num=opt.batch_size_sents)

    dicts = dataset['dicts']
    # Only report the source vocabulary when the dataset actually has one
    # (same guard as the other entry points in this file).
    if 'src' in dicts:
        print(' * vocabulary size. source = %d; target = %d' %
              (dicts['src'].size(), dicts['tgt'].size()))
    else:
        print(' * vocabulary size. target = %d' % dicts['tgt'].size())
    print(' * number of training sentences. %d' % len(dataset['train']['src']))
    print(' * maximum batch size (words per batch). %d' % opt.batch_size_words)

    print('Building model...')
    model = build_model(opt, dicts)

    # Building the loss function (use the already-bound ``dicts`` consistently).
    loss_function = NMTLossFunc(dicts['tgt'].size(),
                                label_smoothing=opt.label_smoothing,
                                shard_size=opt.max_generator_batches)

    n_params = sum(p.nelement() for p in model.parameters())
    print('* number of parameters: %d' % n_params)

    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        trainer = MultiGPUXETrainer(model, loss_function, train_data, valid_data,
                                    dataset, opt)
        print("Warning! Multi-GPU training is used. "
              "Not fully tested and potential bugs can happen.")
    else:
        trainer = XETrainer(model, loss_function, train_data, valid_data,
                            dataset, opt)

    trainer.run(save_file=opt.load_from)
示例#8
0
def main():
    """Train an NMT model, optionally with a (frozen or fine-tuned) BERT encoder.

    Uses the module-level ``opt`` namespace.

    * Main training/validation data is loaded according to ``opt.data_format``:
      ``'raw'`` (``torch.load`` of a ``.train.pt`` file), ``'bin'`` (in-memory
      indexed files) or ``'mmem'`` (memory-mapped indexed files). Any other
      format raises ``NotImplementedError``.
    * Extra training corpora can be given in ``opt.additional_data`` as a
      ``;``-separated list, with one matching entry per corpus in
      ``opt.additional_data_format`` (``'raw'`` or ``'bin'``). They are attached
      to the trainer with ``opt.data_ratio``.
    * Multi-GPU training raises ``NotImplementedError``.
    """
    if opt.data_format == 'raw':
        start = time.time()

        if opt.data.endswith(".train.pt"):
            print("Loading data from '%s'" % opt.data)
            dataset = torch.load(opt.data)
        else:
            print("Loading data from %s" % opt.data + ".train.pt")
            dataset = torch.load(opt.data + ".train.pt")

        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse)

        # For backward compatibility: keys absent from the stored split read as None.
        train_dict = defaultdict(lambda: None, dataset['train'])
        valid_dict = defaultdict(lambda: None, dataset['valid'])

        train_data = onmt.Dataset(train_dict['src'], train_dict['tgt'],
                                  train_dict['src_atbs'], train_dict['tgt_atbs'],
                                  batch_size_words=opt.batch_size_words,
                                  data_type=dataset.get("type", "text"),
                                  batch_size_sents=opt.batch_size_sents,
                                  multiplier=opt.batch_size_multiplier,
                                  augment=opt.augment_speech,
                                  upsampling=opt.upsampling)
        valid_data = onmt.Dataset(valid_dict['src'], valid_dict['tgt'],
                                  valid_dict['src_atbs'], valid_dict['tgt_atbs'],
                                  batch_size_words=opt.batch_size_words,
                                  data_type=dataset.get("type", "text"),
                                  batch_size_sents=opt.batch_size_sents,
                                  upsampling=opt.upsampling)

        dicts = dataset['dicts']

        print(' * number of training sentences. %d' % len(dataset['train']['src']))
        print(' * maximum batch size (words per batch). %d' % opt.batch_size_words)

    elif opt.data_format == 'bin':
        print("Loading memory binned data files ....")
        start = time.time()
        from onmt.data_utils.IndexedDataset import IndexedInMemoryDataset

        dicts = torch.load(opt.data + ".dict.pt")

        train_path = opt.data + '.train'
        train_src = IndexedInMemoryDataset(train_path + '.src')
        train_tgt = IndexedInMemoryDataset(train_path + '.tgt')

        train_data = onmt.Dataset(train_src,
                                  train_tgt,
                                  batch_size_words=opt.batch_size_words,
                                  data_type="text",
                                  batch_size_sents=opt.batch_size_sents,
                                  multiplier=opt.batch_size_multiplier)

        valid_path = opt.data + '.valid'
        valid_src = IndexedInMemoryDataset(valid_path + '.src')
        valid_tgt = IndexedInMemoryDataset(valid_path + '.tgt')

        valid_data = onmt.Dataset(valid_src,
                                  valid_tgt,
                                  batch_size_words=opt.batch_size_words,
                                  data_type="text",
                                  batch_size_sents=opt.batch_size_sents)

        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse)
    elif opt.data_format == 'mmem':
        print("Loading memory mapped data files ....")
        start = time.time()
        from onmt.data_utils.MMapIndexedDataset import MMapIndexedDataset

        dicts = torch.load(opt.data + ".dict.pt")

        train_path = opt.data + '.train'
        train_src = MMapIndexedDataset(train_path + '.src')
        train_tgt = MMapIndexedDataset(train_path + '.tgt')

        train_data = onmt.Dataset(train_src,
                                  train_tgt,
                                  batch_size_words=opt.batch_size_words,
                                  data_type="text",
                                  batch_size_sents=opt.batch_size_sents,
                                  multiplier=opt.batch_size_multiplier)

        valid_path = opt.data + '.valid'
        valid_src = MMapIndexedDataset(valid_path + '.src')
        valid_tgt = MMapIndexedDataset(valid_path + '.tgt')

        valid_data = onmt.Dataset(valid_src,
                                  valid_tgt,
                                  batch_size_words=opt.batch_size_words,
                                  data_type="text",
                                  batch_size_sents=opt.batch_size_sents)
        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse)

    else:
        raise NotImplementedError

    additional_data = []
    if opt.additional_data != "none":
        add_data = opt.additional_data.split(";")
        add_format = opt.additional_data_format.split(";")
        assert len(add_data) == len(add_format)
        for add_path, add_fmt in zip(add_data, add_format):
            if add_fmt == 'raw':
                if add_path.endswith(".train.pt"):
                    add_file = add_path
                else:
                    add_file = add_path + ".train.pt"
                # Log the file actually being loaded (not the main ``opt.data``).
                print("Loading data from '%s'" % add_file)
                add_dataset = torch.load(add_file)

                # Pair the additional source side with its OWN target side and
                # data type. The main ``dataset`` must not be used here: it would
                # mismatch the sentences and does not exist for 'bin'/'mmem'.
                additional_data.append(onmt.Dataset(add_dataset['train']['src'],
                                                    add_dataset['train']['tgt'],
                                                    batch_size_words=opt.batch_size_words,
                                                    data_type=add_dataset.get("type", "text"),
                                                    batch_size_sents=opt.batch_size_sents,
                                                    multiplier=opt.batch_size_multiplier,
                                                    reshape_speech=opt.reshape_speech,
                                                    augment=opt.augment_speech))
            elif add_fmt == 'bin':
                from onmt.data_utils.IndexedDataset import IndexedInMemoryDataset

                train_path = add_path + '.train'
                train_src = IndexedInMemoryDataset(train_path + '.src')
                train_tgt = IndexedInMemoryDataset(train_path + '.tgt')

                additional_data.append(onmt.Dataset(train_src,
                                                    train_tgt,
                                                    batch_size_words=opt.batch_size_words,
                                                    data_type=opt.encoder_type,
                                                    batch_size_sents=opt.batch_size_sents,
                                                    multiplier=opt.batch_size_multiplier))
            else:
                # Fail loudly instead of silently dropping a requested corpus.
                raise NotImplementedError(
                    "Unsupported additional data format: %s" % add_fmt)

    if opt.load_from:
        checkpoint = torch.load(opt.load_from, map_location=lambda storage, loc: storage)
        print("* Loading dictionaries from the checkpoint")
        dicts = checkpoint['dicts']
    else:
        dicts['tgt'].patch(opt.patch_vocab_multiplier)
        checkpoint = None

    if "src" in dicts:
        print(' * vocabulary size. source = %d; target = %d' %
              (dicts['src'].size(), dicts['tgt'].size()))
    else:
        print(' * vocabulary size. target = %d' %
              (dicts['tgt'].size()))

    print('Building model...')

    if not opt.fusion:
        if opt.bert_scalar and opt.finetune_bert:
            print("WARNING: we only fine tune bert, we don't finetune scalar parameters, please set opt.bert_scalar False")

        print("Using scalared bert vector: ", opt.bert_scalar)
        print("Using Bert+Transformer to finetuning : ", opt.finetune_bert)

        model = build_model(opt, dicts)

        if not opt.finetune_bert:
            # Freeze the BERT sub-module so only the remaining parameters train.
            for param in model.bert.parameters():
                param.requires_grad = False

        if not opt.finetune_bert and opt.bert_scalar:
            # Trainable mix over the BERT layers; add_module registers it so its
            # parameters show up in model.parameters().
            scalar_mix = ScalarMix(
                onmt.Constants.BERT_LAYERS,
                do_layer_norm=False,
                initial_scalar_parameters=None,
                trainable=True,
            )
            model.add_module("scalar_mix", scalar_mix)
        print(model)

        # Building the loss function
        if opt.ctc_loss != 0:
            loss_function = NMTAndCTCLossFunc(dicts['tgt'].size(),
                                              label_smoothing=opt.label_smoothing,
                                              ctc_weight=opt.ctc_loss)
        else:
            loss_function = NMTLossFunc(dicts['tgt'].size(),
                                        label_smoothing=opt.label_smoothing)
    else:
        from onmt.ModelConstructor import build_fusion
        from onmt.modules.Loss import FusionLoss

        model = build_fusion(opt, dicts)

        loss_function = FusionLoss(dicts['tgt'].size(), label_smoothing=opt.label_smoothing)

    n_params = sum(p.nelement() for p in model.parameters())
    print('* number of all parameters: %d' % n_params)

    n_params_grad = sum(p.nelement() for p in model.parameters() if p.requires_grad)
    print('* number of all parameters that need gradient: %d' % n_params_grad)

    n_params_nograd = sum(p.nelement() for p in model.parameters() if not p.requires_grad)
    print('* number of all parameters that do not need gradient: %d' % n_params_nograd)

    assert n_params == (n_params_grad + n_params_nograd)

    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        raise NotImplementedError("Warning! Multi-GPU training is not fully tested and potential bugs can happen.")
    else:
        trainer = XETrainer(model, loss_function, train_data, valid_data, dicts, opt,
                            setup_optimizer=True)
        if len(additional_data) > 0:
            trainer.add_additional_data(additional_data, opt.data_ratio)

    trainer.run(checkpoint=checkpoint)
示例#9
0
def main():
    """Train an NMT model from ``.train.pt`` files or memory-mapped data.

    Uses the module-level ``opt`` namespace. ``opt.data_format`` selects the
    loader:

    * ``'bin'`` / ``'raw'``: ``torch.load`` of ``<opt.data>.train.pt`` (or of
      ``opt.data`` itself when it already ends with ``.train.pt``).
    * ``'mmem'``: memory-mapped ``<opt.data>.{train,valid}.{src,tgt}`` files
      plus ``<opt.data>.dict.pt``. Optional ``*.src_lang`` / ``*.tgt_lang``
      files provide per-sentence languages.

    ``opt.streaming`` switches between ``onmt.Dataset`` and
    ``onmt.StreamDataset``. Unknown formats and multi-GPU runs raise
    ``NotImplementedError``.
    """

    def _bilingual_langs(lang_ids):
        """Return one-element (src_langs, tgt_langs) lists for the bilingual case."""
        return ([torch.Tensor([lang_ids['src']])],
                [torch.Tensor([lang_ids['tgt']])])

    if opt.data_format in ['bin', 'raw']:
        start = time.time()

        if opt.data.endswith(".train.pt"):
            print("Loading data from '%s'" % opt.data)
            dataset = torch.load(opt.data)
        else:
            print("Loading data from %s" % opt.data + ".train.pt")
            dataset = torch.load(opt.data + ".train.pt")

        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse)

        dicts = dataset['dicts']

        # For backward compatibility: keys absent from the stored split read as None.
        train_dict = defaultdict(lambda: None, dataset['train'])
        valid_dict = defaultdict(lambda: None, dataset['valid'])

        # Both dataset classes take identical arguments in this branch.
        dataset_cls = onmt.StreamDataset if opt.streaming else onmt.Dataset

        if train_dict['src_lang'] is not None:
            assert 'langs' in dicts
            train_src_langs = train_dict['src_lang']
            train_tgt_langs = train_dict['tgt_lang']
        else:
            # No language annotations: allocate a default bilingual mapping.
            dicts['langs'] = {'src': 0, 'tgt': 1}
            train_src_langs, train_tgt_langs = _bilingual_langs(dicts['langs'])

        train_data = dataset_cls(train_dict['src'],
                                 train_dict['tgt'],
                                 train_src_langs,
                                 train_tgt_langs,
                                 batch_size_words=opt.batch_size_words,
                                 data_type=dataset.get("type", "text"),
                                 sorting=True,
                                 batch_size_sents=opt.batch_size_sents,
                                 multiplier=opt.batch_size_multiplier,
                                 augment=opt.augment_speech,
                                 upsampling=opt.upsampling)

        if valid_dict['src_lang'] is not None:
            assert 'langs' in dicts
            valid_src_langs = valid_dict['src_lang']
            valid_tgt_langs = valid_dict['tgt_lang']
        else:
            valid_src_langs, valid_tgt_langs = _bilingual_langs(dicts['langs'])

        valid_data = dataset_cls(valid_dict['src'],
                                 valid_dict['tgt'],
                                 valid_src_langs,
                                 valid_tgt_langs,
                                 batch_size_words=opt.batch_size_words,
                                 data_type=dataset.get("type", "text"),
                                 sorting=True,
                                 batch_size_sents=opt.batch_size_sents,
                                 upsampling=opt.upsampling)

        print(' * number of training sentences. %d' %
              len(dataset['train']['src']))
        print(' * maximum batch size (words per batch). %d' %
              opt.batch_size_words)

    elif opt.data_format == 'mmem':
        print("Loading memory mapped data files ....")
        start = time.time()
        from onmt.data.mmap_indexed_dataset import MMapIndexedDataset

        dicts = torch.load(opt.data + ".dict.pt")

        # allocate languages if not present in the dictionary file
        if 'langs' not in dicts:
            dicts['langs'] = {'src': 0, 'tgt': 1}
        else:
            print(dicts['langs'])

        # Used for BOTH train and valid splits so audio models validate on audio.
        data_type = 'audio' if opt.encoder_type == 'audio' else 'text'

        train_path = opt.data + '.train'
        train_src = MMapIndexedDataset(train_path + '.src')
        train_tgt = MMapIndexedDataset(train_path + '.tgt')

        # check the lang files if they exist (in the case of multi-lingual models)
        if os.path.exists(train_path + '.src_lang.bin'):
            assert 'langs' in dicts
            train_src_langs = MMapIndexedDataset(train_path + '.src_lang')
            train_tgt_langs = MMapIndexedDataset(train_path + '.tgt_lang')
        else:
            train_src_langs, train_tgt_langs = _bilingual_langs(dicts['langs'])

        if not opt.streaming:
            train_data = onmt.Dataset(train_src,
                                      train_tgt,
                                      train_src_langs,
                                      train_tgt_langs,
                                      batch_size_words=opt.batch_size_words,
                                      data_type=data_type,
                                      sorting=True,
                                      batch_size_sents=opt.batch_size_sents,
                                      multiplier=opt.batch_size_multiplier,
                                      src_align_right=opt.src_align_right,
                                      upsampling=opt.upsampling,
                                      cleaning=True,
                                      verbose=True)
        else:
            train_data = onmt.StreamDataset(
                train_src,
                train_tgt,
                train_src_langs,
                train_tgt_langs,
                batch_size_words=opt.batch_size_words,
                data_type=data_type,
                sorting=False,
                batch_size_sents=opt.batch_size_sents,
                multiplier=opt.batch_size_multiplier,
                upsampling=opt.upsampling)

        valid_path = opt.data + '.valid'
        valid_src = MMapIndexedDataset(valid_path + '.src')
        valid_tgt = MMapIndexedDataset(valid_path + '.tgt')

        if os.path.exists(valid_path + '.src_lang.bin'):
            assert 'langs' in dicts
            valid_src_langs = MMapIndexedDataset(valid_path + '.src_lang')
            valid_tgt_langs = MMapIndexedDataset(valid_path + '.tgt_lang')
        else:
            valid_src_langs, valid_tgt_langs = _bilingual_langs(dicts['langs'])

        if not opt.streaming:
            valid_data = onmt.Dataset(valid_src,
                                      valid_tgt,
                                      valid_src_langs,
                                      valid_tgt_langs,
                                      batch_size_words=opt.batch_size_words,
                                      data_type=data_type,
                                      sorting=False,
                                      batch_size_sents=opt.batch_size_sents,
                                      src_align_right=opt.src_align_right,
                                      cleaning=True,
                                      verbose=True)
        else:
            # for validation data, we have to go through sentences (very slow but to ensure correctness)
            valid_data = onmt.StreamDataset(
                valid_src,
                valid_tgt,
                valid_src_langs,
                valid_tgt_langs,
                batch_size_words=opt.batch_size_words,
                data_type=data_type,
                sorting=True,
                batch_size_sents=opt.batch_size_sents)

        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse)

    else:
        raise NotImplementedError

    # NOTE: loading of ``opt.additional_data`` is not supported by this entry point.

    if opt.load_from:
        checkpoint = torch.load(opt.load_from,
                                map_location=lambda storage, loc: storage)
        print("* Loading dictionaries from the checkpoint")
        dicts = checkpoint['dicts']
    else:
        dicts['tgt'].patch(opt.patch_vocab_multiplier)
        checkpoint = None

    if "src" in dicts:
        print(' * vocabulary size. source = %d; target = %d' %
              (dicts['src'].size(), dicts['tgt'].size()))
    else:
        print(' * vocabulary size. target = %d' % (dicts['tgt'].size()))

    print(' * number of sentences in training data: %d' % train_data.size())
    print(' * number of sentences in validation data: %d' % valid_data.size())

    print('* Building model...')

    if not opt.fusion:
        model = build_model(opt, dicts)
        # Building the loss function
        if opt.ctc_loss != 0:
            loss_function = NMTAndCTCLossFunc(
                dicts['tgt'].size(),
                label_smoothing=opt.label_smoothing,
                ctc_weight=opt.ctc_loss)
        else:
            loss_function = NMTLossFunc(opt.model_size,
                                        dicts['tgt'].size(),
                                        label_smoothing=opt.label_smoothing,
                                        mirror=opt.mirror_loss)

        # This function replaces modules with the more optimized counterparts so that it can run faster
        # Currently exp with LayerNorm
        optimize_model(model)

    else:
        from onmt.model_factory import build_fusion
        from onmt.modules.loss import FusionLoss

        model = build_fusion(opt, dicts)

        loss_function = FusionLoss(dicts['tgt'].size(),
                                   label_smoothing=opt.label_smoothing)

    n_params = sum(p.nelement() for p in model.parameters())
    print('* number of parameters: %d' % n_params)

    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        raise NotImplementedError(
            "Multi-GPU training is not supported at the moment.")
    else:
        trainer = XETrainer(model, loss_function, train_data, valid_data,
                            dicts, opt)

    trainer.run(checkpoint=checkpoint)
示例#10
0
def _lm_split_langs(split_dict, dicts):
    """Return the ``(src_langs, tgt_langs)`` language ids for one data split.

    If the split already carries language ids (``split_dict['src_lang']`` is
    not None), they are returned unchanged, and ``dicts`` must contain a
    ``'langs'`` entry. Otherwise each side gets a one-element list holding a
    ``torch.Tensor`` built from ``dicts['langs']['src']`` /
    ``dicts['langs']['tgt']`` (the bilingual fallback).

    Args:
        split_dict: ``defaultdict(lambda: None, ...)`` view of a dataset
            split, so missing keys read as None (older preprocessed data).
        dicts: the dictionaries table of the dataset.

    Returns:
        Tuple ``(src_langs, tgt_langs)``.

    Raises:
        ValueError: the split has language ids but ``dicts`` has no
            ``'langs'`` entry.
        KeyError: the fallback is taken and ``dicts['langs']`` lacks the
            ``'src'`` / ``'tgt'`` keys.
    """
    if split_dict['src_lang'] is not None:
        # An explicit check instead of `assert`, which is stripped under -O.
        if 'langs' not in dicts:
            raise ValueError(
                "data split provides language ids but dicts has no 'langs' entry")
        return split_dict['src_lang'], split_dict['tgt_lang']

    langs = dicts['langs']
    # Allocate one language id per side for the bilingual case.
    return [torch.Tensor([langs['src']])], [torch.Tensor([langs['tgt']])]


def main():
    """Train a language model on the target side of a preprocessed dataset.

    Reads all settings from the module-level ``opt``. Only
    ``opt.data_format == 'raw'`` (a file readable by ``torch.load``) is
    supported. Any other format raises NotImplementedError, as does
    requesting more than one GPU or virtual GPU.
    """
    start = time.time()
    print("Loading data from '%s'" % opt.data)

    if opt.data_format != 'raw':
        raise NotImplementedError

    dataset = torch.load(opt.data)
    elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
    print("Done after %s" % elapse)

    dicts = dataset['dicts']

    # For backward compatibility: older data has no 'src_lang'/'tgt_lang'
    # entries, and the defaultdict makes those lookups return None.
    train_dict = defaultdict(lambda: None, dataset['train'])
    valid_dict = defaultdict(lambda: None, dataset['valid'])

    if train_dict['src_lang'] is None:
        # No language ids in the data: allocate new languages.
        dicts['langs'] = {'src': 0, 'tgt': 1}

    # The language model consumes only the target side, so the source
    # language ids returned by the helper are not needed.
    _, train_tgt_langs = _lm_split_langs(train_dict, dicts)
    train_data = LanguageModelDataset(
        dataset['train']['tgt'],
        train_tgt_langs,
        batch_size_sents=opt.batch_size_sents,
        seq_length=opt.lm_seq_length)

    _, valid_tgt_langs = _lm_split_langs(valid_dict, dicts)
    valid_data = LanguageModelDataset(
        dataset['valid']['tgt'],
        valid_tgt_langs,
        batch_size_sents=opt.batch_size_sents,
        seq_length=opt.lm_seq_length)

    if opt.load_from:
        # Map all tensors to CPU storage regardless of the device they were saved from.
        checkpoint = torch.load(opt.load_from,
                                map_location=lambda storage, loc: storage)
        print("* Loading dictionaries from the checkpoint")
        # NOTE(review): the datasets above were built with the dataset's
        # dicts; the checkpoint's dicts are assumed compatible - confirm.
        dicts = checkpoint['dicts']
    else:
        dicts['tgt'].patch(opt.patch_vocab_multiplier)
        checkpoint = None

    if "src" in dicts:
        print(' * vocabulary size. source = %d; target = %d' %
              (dicts['src'].size(), dicts['tgt'].size()))
    else:
        print(' * vocabulary size. target = %d' % (dicts['tgt'].size()))

    print(' * number of training sentences. %d' % train_data.size())
    print(' * maximum batch size (words per batch). %d' %
          (opt.batch_size_sents * opt.lm_seq_length))

    print('Building model...')
    model = build_language_model(opt, dicts)
    # Swap modules for more optimized counterparts so training runs faster.
    optimize_model(model)

    # Building the loss function
    loss_function = NMTLossFunc(opt.model_size,
                                dicts['tgt'].size(),
                                label_smoothing=opt.label_smoothing)

    n_params = sum(p.nelement() for p in model.parameters())
    print('* number of parameters: %d' % n_params)

    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        raise NotImplementedError("Multi-GPU training is not supported ATM.")

    trainer = XETrainer(model, loss_function, train_data, valid_data,
                        dicts, opt)
    trainer.run(checkpoint=checkpoint)