예제 #1
0
def main():
    """Load the preprocessed dataset, build (or restore) the NMT model and
    its optimizer, and launch training via trainModel.

    Relies on module-level `opt` (parsed CLI options), `onmt`, `torch`,
    `nn`, and `trainModel`.
    """

    print("Loading data from '%s'" % opt.data)

    # Preprocessed data: train/valid index tensors plus the vocabularies.
    dataset = torch.load(opt.data)

    trainData = onmt.Dataset(dataset['train']['src'], dataset['train']['tgt'],
                             opt.batch_size, opt.cuda)
    validData = onmt.Dataset(dataset['valid']['src'], dataset['valid']['tgt'],
                             opt.batch_size, opt.cuda)

    dicts = dataset['dicts']
    print(' * vocabulary size. source = %d; target = %d' %
          (dicts['src'].size(), dicts['tgt'].size()))
    print(' * number of training sentences. %d' % len(dataset['train']['src']))
    print(' * maximum batch size. %d' % opt.batch_size)

    print('Building model...')

    if opt.train_from is None:
        # Fresh model: encoder/decoder plus a generator projecting decoder
        # states onto the target vocabulary.
        encoder = onmt.Models.Encoder(opt, dicts['src'])
        decoder = onmt.Models.Decoder(opt, dicts['tgt'])
        generator = nn.Sequential(nn.Linear(opt.rnn_size, dicts['tgt'].size()),
                                  nn.LogSoftmax())
        # NOTE(review): `opt.cuda > 1` looks like it was meant to test the
        # number of GPUs (cf. opt.gpus) — a boolean cuda flag never exceeds
        # 1, so DataParallel would never be applied. Confirm upstream.
        if opt.cuda > 1:
            generator = nn.DataParallel(generator, device_ids=opt.gpus)
        model = onmt.Models.NMTModel(encoder, decoder, generator)
        if opt.cuda > 1:
            model = nn.DataParallel(model, device_ids=opt.gpus)
        if opt.cuda:
            model.cuda()
        else:
            model.cpu()

        # Keep a direct handle on the (possibly DataParallel-wrapped)
        # generator on the model object.
        model.generator = generator

        # Uniform init in [-param_init, param_init] for all parameters.
        for p in model.parameters():
            p.data.uniform_(-opt.param_init, opt.param_init)

        optim = onmt.Optim(model.parameters(),
                           opt.optim,
                           opt.learning_rate,
                           opt.max_grad_norm,
                           lr_decay=opt.learning_rate_decay,
                           start_decay_at=opt.start_decay_at)
    else:
        # Resume: model and optimizer state come from the checkpoint, and
        # training restarts at the epoch after the saved one.
        print('Loading from checkpoint at %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from)
        model = checkpoint['model']
        if opt.cuda:
            model.cuda()
        else:
            model.cpu()
        optim = checkpoint['optim']
        opt.start_epoch = checkpoint['epoch'] + 1

    nParams = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % nParams)

    trainModel(model, trainData, validData, dataset, optim)
예제 #2
0
def main():
    """Load a 'raw' preprocessed dataset, build a language model and its
    loss function, and hand everything to a cross-entropy (XE) trainer.

    Relies on module-level `opt`, `onmt`, `torch`, `time`, `datetime`,
    `build_language_model`, `NMTLossFunc`, `FP16XETrainer`, `XETrainer`.
    """

    start = time.time()
    print("Loading data from '%s'" % opt.data)

    if opt.data_format == 'raw':
        dataset = torch.load(opt.data)
        # Report load time as H:MM:SS.
        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse )


        train_data = onmt.Dataset(dataset['train']['src'],
                                 dataset['train']['tgt'], opt.batch_size_words,
                                 data_type=dataset.get("type", "text"),
                                 batch_size_sents=opt.batch_size_sents,
                                 multiplier = opt.batch_size_multiplier,
                                 sort_by_target=opt.sort_by_target)
        valid_data = onmt.Dataset(dataset['valid']['src'],
                                 dataset['valid']['tgt'], opt.batch_size_words,
                                 data_type=dataset.get("type", "text"),
                                 batch_size_sents=opt.batch_size_sents)

        dicts = dataset['dicts']
        # A source dictionary is absent for pure language-modelling data.
        if "src" in dicts:
            print(' * vocabulary size. source = %d; target = %d' %
            (dicts['src'].size(), dicts['tgt'].size()))
        else:
            print(' * vocabulary size. target = %d' %
            (dicts['tgt'].size()))

        print(' * number of training sentences. %d' %
          train_data.size())
        print(' * maximum batch size (words per batch). %d' % opt.batch_size_words)

    else:
        # Only the 'raw' torch-serialized format is supported here.
        raise NotImplementedError

    print('Building model...')
    model = build_language_model(opt, dicts)
    
    
    """ Building the loss function """

    loss_function = NMTLossFunc(dicts['tgt'].size(), label_smoothing=opt.label_smoothing)

    n_params = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % n_params)
    
    # Multi-GPU / virtual-GPU configurations are explicitly rejected.
    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        raise NotImplementedError("Warning! Multi-GPU training is not fully tested and potential bugs can happen.")
    else:
        # FP16 training uses a dedicated trainer variant.
        if opt.fp16:
            trainer = FP16XETrainer(model, loss_function, train_data, valid_data, dicts, opt)
        else:
            trainer = XETrainer(model, loss_function, train_data, valid_data, dicts, opt)

    
    # opt.load_from optionally resumes from a saved checkpoint.
    trainer.run(save_file=opt.load_from)
예제 #3
0
파일: neural.py 프로젝트: sam-lev/NeuralMT
    def train(self, data_path, working_dir):
        """Train an OpenNMT model from preprocessed data in `data_path`,
        checkpointing under `working_dir`, then move the final checkpoint
        to self._model.

        A TrainingInterrupt is handled: the checkpoint it carries is saved
        as the final model; every other exception propagates.
        """
        logger = logging.getLogger('mmt.train.OpenNMTDecoder')
        logger.info('Training started for data "%s"' % data_path)

        save_model = os.path.join(working_dir, 'train_model')

        # Loading training data ----------------------------------------------------------------------------------------
        data_file = os.path.join(data_path, 'train_processed.train.pt')
        logger.info('Loading data from "%s"... START' % data_file)
        start_time = time.time()
        data_set = torch.load(data_file)
        logger.info('Loading data... END %.2fs' % (time.time() - start_time))

        src_dict, trg_dict = data_set['dicts']['src'], data_set['dicts']['tgt']
        src_train, trg_train = data_set['train']['src'], data_set['train']['tgt']
        src_valid, trg_valid = data_set['valid']['src'], data_set['valid']['tgt']

        # Creating trainer ---------------------------------------------------------------------------------------------

        logger.info('Building model... START')
        start_time = time.time()
        # Fixed seed keeps model initialization reproducible across runs.
        trainer = NMTEngineTrainer.new_instance(src_dict, trg_dict, random_seed=3435, gpu_ids=self._gpus)
        logger.info('Building model... END %.2fs' % (time.time() - start_time))

        # Creating data sets -------------------------------------------------------------------------------------------

        logger.info('Creating Data... START')
        start_time = time.time()
        train_data = onmt.Dataset(src_train, trg_train, trainer.batch_size, self._gpus)
        # volatile=True: validation batches are inference-only.
        valid_data = onmt.Dataset(src_valid, trg_valid, trainer.batch_size, self._gpus, volatile=True)
        logger.info('Creating Data... END %.2fs' % (time.time() - start_time))

        logger.info(' Vocabulary size. source = %d; target = %d' % (src_dict.size(), trg_dict.size()))
        logger.info(' Number of training sentences. %d' % len(data_set['train']['src']))
        logger.info(' Maximum batch size. %d' % trainer.batch_size)

        # Training model -----------------------------------------------------------------------------------------------

        logger.info('Training model... START')
        try:
            start_time = time.time()
            checkpoint = trainer.train_model(train_data, valid_data=valid_data, save_path=save_model)
            logger.info('Training model... END %.2fs' % (time.time() - start_time))
        except TrainingInterrupt as e:
            # The interrupt still carries the most recent checkpoint path.
            checkpoint = e.checkpoint
            logger.info('Training model... INTERRUPTED %.2fs' % (time.time() - start_time))

        # Saving last checkpoint ---------------------------------------------------------------------------------------
        model_folder = os.path.abspath(os.path.join(self._model, os.path.pardir))
        if not os.path.isdir(model_folder):
            os.mkdir(model_folder)

        logger.info('Storing model "%s" to %s' % (checkpoint, self._model))
        # Move (not copy) the checkpoint into its final location.
        os.rename(checkpoint, self._model)
예제 #4
0
def main():
    """Load preprocessed session data, build a BLSTM model, optionally
    initialize parameters and embeddings, and start training.

    Relies on module-level `opt`, `onmt`, `torch`, `BLSTM`, `trainModel`.
    """

    print("Loading data from '%s'" % opt.data)

    dataset = torch.load(opt.data)
    features = dataset['features']

    trainData = onmt.Dataset(dataset['train']['data'], opt.gpus)

    # volatile=True: validation batches are inference-only.
    validData = onmt.Dataset(dataset['valid']['data'], opt.gpus, volatile=True)

    dicts = dataset['dicts']
    # Options saved at preprocessing time (source of the batch size below).
    opt_pre = dataset['opt']
    print(' * vocabulary size. = %d' % (dicts['vocab'].size()))
    print(' * maximum batch size. %d' % opt_pre.batch_size)

    print('Building model...')

    model = BLSTM.BLSTM(opt, dicts, dataset['train']['type'], features)

    if opt.param_init:
        print('Intializing model parameters.')
        # Uniform init, then zero every lookup table's PAD row so padding
        # embeddings contribute nothing.
        for p in model.parameters():
            p.data.uniform_(-opt.param_init, opt.param_init)
        model.word_lut.weight.data[onmt.Constants.PAD].zero_()
        model.session_lut.weight.data[onmt.Constants.PAD].zero_()
        model.format_lut.weight.data[onmt.Constants.PAD].zero_()
        model.pos_lut.weight.data[onmt.Constants.PAD].zero_()
    if opt.emb_init:
        print('Intializing embeddings.')
        # Copy pretrained vectors for every vocab index present in w2v.
        w2v = dicts['w2v']
        for i in range(model.word_lut.weight.size(0)):
            if i in w2v:
                model.word_lut.weight[i].data.copy_(torch.FloatTensor(w2v[i]))
                #model.word_lut.weight[i].data.copy_(torch.from_numpy(w2v[i]))

    if len(opt.gpus) >= 1:
        model.cuda()
    else:
        model.cpu()

    model_optim = onmt.Optim(opt.optim,
                             opt.learning_rate,
                             opt.max_grad_norm,
                             opt.stop_lr,
                             lr_decay=opt.learning_rate_decay,
                             start_decay_at=opt.start_decay_at)
    model_optim.set_parameters(model.parameters())
    trainModel(model, trainData, validData, dataset, model_optim, dicts)
예제 #5
0
    def buildData(self, srcBatch, goldBatch):
        """Convert raw source/gold token batches into an onmt.Dataset.

        Mirrors the index conversion done by preprocess.py so that
        translation-time input matches training-time preprocessing.
        """
        # Source side: optionally prepend BOS, map tokens to vocab indices.
        srcData = []
        if self.start_with_bos:
            for sent in srcBatch:
                srcData.append(
                    self.src_dict.convertToIdx(sent, onmt.Constants.UNK_WORD,
                                               onmt.Constants.BOS_WORD))
        else:
            for sent in srcBatch:
                srcData.append(
                    self.src_dict.convertToIdx(sent, onmt.Constants.UNK_WORD))

        # Target side is optional; each gold sentence is framed BOS ... EOS.
        tgtData = None
        if goldBatch:
            tgtData = []
            for sent in goldBatch:
                tgtData.append(
                    self.tgt_dict.convertToIdx(sent, onmt.Constants.UNK_WORD,
                                               onmt.Constants.BOS_WORD,
                                               onmt.Constants.EOS_WORD))

        # Word-level batch size is effectively unbounded (9999); the real
        # cap is max_seq_num from opt.batch_size.
        return onmt.Dataset(srcData, tgtData, 9999, [self.opt.gpu],
                            max_seq_num=self.opt.batch_size)
예제 #6
0
    def buildData(self, srcBatch, goldBatch):
        """Build a volatile onmt.Dataset from raw batches (text or image).

        Must stay in sync with the conversion performed by preprocess.py.
        """
        if self._type == "text":
            # Map each token sequence to vocabulary indices (OOV -> UNK).
            srcData = [self.src_dict.convertToIdx(sent,
                                                  onmt.Constants.UNK_WORD)
                       for sent in srcBatch]
        elif self._type == "img":
            # Each "sentence" is a one-element list holding an image filename.
            srcData = []
            for sent in srcBatch:
                img = Image.open(self.opt.src_img_dir + "/" + sent[0])
                srcData.append(transforms.ToTensor()(img))

        tgtData = None
        if goldBatch:
            # Gold targets get BOS/EOS framing.
            tgtData = [self.tgt_dict.convertToIdx(sent,
                                                  onmt.Constants.UNK_WORD,
                                                  onmt.Constants.BOS_WORD,
                                                  onmt.Constants.EOS_WORD)
                       for sent in goldBatch]

        return onmt.Dataset(srcData, tgtData,
                            self.opt.batch_size, self.opt.cuda,
                            volatile=True, data_type=self._type,
                            balance=False)
예제 #7
0
    def buildData(self, srcBatch, goldBatch):
        """Build a volatile onmt.Dataset from raw source/gold lines.

        Each source line goes through onmt.IO.readSrcLine, collecting the
        word indices and any per-word features; gold lines go through
        onmt.IO.readTgtLine. When goldBatch is empty the dataset's target
        side is None.
        """
        srcFeats = []
        if self.src_feature_dicts:
            srcFeats = [[] for i in range(len(self.src_feature_dicts))]
        srcData = []
        tgtData = None
        for b in srcBatch:
            _, srcD, srcFeat = onmt.IO.readSrcLine(b, self.src_dict,
                                                   self.src_feature_dicts,
                                                   self._type)
            srcData += [srcD]
            for i in range(len(srcFeats)):
                srcFeats[i] += [srcFeat[i]]

        if goldBatch:
            # BUG FIX: tgtData was left as None here, so `tgtData += [tgtD]`
            # raised TypeError on the first gold line. Start from an empty
            # list instead.
            tgtData = []
            for b in goldBatch:
                # NOTE(review): self.src_dict is passed for the target side;
                # presumably this should be self.tgt_dict (cf. the sibling
                # buildData that uses tgt_dict) — confirm before changing.
                _, tgtD, tgtFeat = onmt.IO.readTgtLine(b, self.src_dict, None,
                                                       self._type)
                tgtData += [tgtD]

        return onmt.Dataset(srcData,
                            tgtData,
                            self.opt.batch_size,
                            self.opt.cuda,
                            volatile=True,
                            data_type=self._type,
                            srcFeatures=srcFeats)
예제 #8
0
    def buildData(self,
                  srcBatch,
                  goldBatch,
                  alignBatch=None,
                  tgtUniBatch=None):
        """Index raw batches and wrap them in a volatile onmt.Dataset."""
        # convertToIdx returns a tuple here; keep only the index tensor.
        srcData = []
        for sent in srcBatch:
            srcData.append(
                self.src_dict.convertToIdx(sent, onmt.Constants.UNK_WORD)[0])

        tgtData = None
        if goldBatch:
            tgtData = []
            for sent in goldBatch:
                tgtData.append(
                    self.tgt_dict.convertToIdx(sent,
                                               onmt.Constants.UNK_WORD,
                                               onmt.Constants.BOS_WORD,
                                               onmt.Constants.EOS_WORD)[0])

        return onmt.Dataset(srcData, tgtData, tgtUniBatch, alignBatch,
                            self.opt.batch_size, self.opt.cuda,
                            volatile=True)
예제 #9
0
    def build_data(self, src_sents, tgt_sents):
        """Index raw sentences into an onmt.Dataset.

        Kept consistent with the conversion done by preprocess.py.
        """
        unk = onmt.Constants.UNK_WORD

        # Source side: optionally prepend BOS.
        if self.start_with_bos:
            src_data = [self.src_dict.convertToIdx(s, unk,
                                                   onmt.Constants.BOS_WORD)
                        for s in src_sents]
        else:
            src_data = [self.src_dict.convertToIdx(s, unk)
                        for s in src_sents]

        # Target side is optional; gold sentences are framed BOS ... EOS.
        tgt_data = None
        if tgt_sents:
            tgt_data = [self.tgt_dict.convertToIdx(s, unk,
                                                   onmt.Constants.BOS_WORD,
                                                   onmt.Constants.EOS_WORD)
                        for s in tgt_sents]

        # Word-based batching is disabled (sys.maxsize); batching is by
        # sentence count via batch_size_sents.
        return onmt.Dataset(src_data, tgt_data, sys.maxsize,
                            data_type=self._type,
                            batch_size_sents=self.opt.batch_size)
예제 #10
0
    def buildData(self, srcBatch, goldBatch):
        """Index raw batches into a volatile onmt.Dataset.

        Empty source sentences are replaced by a single-UNK sequence so the
        dataset never contains a zero-length source.
        """
        # (The original built srcData twice: a list comprehension whose
        # result was immediately discarded, then this loop. The dead
        # comprehension has been removed.)
        srcData = []
        for b in srcBatch:
            idx = self.src_dict.convertToIdx(b, onmt.Constants.UNK_WORD)
            if len(idx) <= 0:
                # Convert the UNK token itself to obtain a non-empty
                # index sequence for an empty input line.
                idx = self.src_dict.convertToIdx(onmt.Constants.UNK_WORD,
                                                 onmt.Constants.UNK_WORD)
            srcData.append(idx)
        tgtData = None
        if goldBatch:
            # Gold targets get BOS/EOS framing.
            tgtData = [
                self.tgt_dict.convertToIdx(b, onmt.Constants.UNK_WORD,
                                           onmt.Constants.BOS_WORD,
                                           onmt.Constants.EOS_WORD)
                for b in goldBatch
            ]

        return onmt.Dataset(srcData,
                            tgtData,
                            self.opt.batch_size,
                            self.opt.cuda,
                            volatile=True)
예제 #11
0
    def buildData(self, srcBatch, cxtBatch, goldBatch):
        """Index source, context, and gold batches into an onmt.Dataset."""
        unk = onmt.Constants.UNK_WORD

        srcData = [self.src_dict.convertToIdx(sent, unk)
                   for sent in srcBatch]

        # Context side is optional and indexed with its own dictionary.
        cxtData = None
        if cxtBatch:
            cxtData = [self.cxt_dict.convertToIdx(sent, unk)
                       for sent in cxtBatch]

        # Gold targets get BOS/EOS framing.
        tgtData = None
        if goldBatch:
            tgtData = [self.tgt_dict.convertToIdx(sent, unk,
                                                  onmt.Constants.BOS_WORD,
                                                  onmt.Constants.EOS_WORD)
                       for sent in goldBatch]

        return onmt.Dataset(srcData, cxtData, tgtData,
                            self.opt.batch_size, self.opt.cuda,
                            volatile=True)
예제 #12
0
    def __init__(self, opt, dicts):
        """Multi-set training-data loader.

        Scans opt.data for shard files whose names contain 'train.' and,
        under the 'all' loading strategy, loads every shard into memory and
        builds one onmt.Dataset per set ID. Other strategies are not
        implemented.
        """

        self.data_path = opt.data

        # Stored but not used in this constructor; presumably used later to
        # oversample the smallest dataset — TODO confirm.
        self.smallest_dataset_multiplicator = opt.smallest_dataset_multiplicator

        self.dicts = dicts

        nSets = dicts['nSets']
        setIDs = dicts['setIDs']
        self.setIDs = setIDs

        # Initialize the place holder for training data
        self.trainSets = dict()

        self.trainSets['src'] = list()
        self.trainSets['tgt'] = list()

        # One (src, tgt) accumulator pair per set.
        for i in range(nSets):
            self.trainSets['src'].append(list())
            self.trainSets['tgt'].append(list())

        self.loading_strategy = opt.loading_strategy
        """ First, we have to read the data path to detect n training files"""
        train_files = []
        for root, dirs, files in os.walk(opt.data + "/"):
            for tfile in files:
                if "train." in tfile:
                    train_files.append(tfile)

        self.train_files = train_files
        self.datasets = dict()

        if self.loading_strategy == 'all':
            """ Load all the data in the shards """
            # Shards are processed in numeric-suffix order
            # ('train.0', 'train.1', ...).
            for train_file in sorted(self.train_files,
                                     key=lambda a: int(a.split(".")[-1])):
                print("Loading training data from '%s'" %
                      (opt.data + "/" + train_file) + "...")
                data_ = torch.load(opt.data + "/" + train_file)

                # Concatenate each shard's per-set src/tgt lists.
                for i in range(nSets):
                    self.trainSets['src'][i] += data_['src'][i]
                    self.trainSets['tgt'][i] += data_['tgt'][i]

                    #~ trainSets[i] = onmt.Dataset(dataset['train']['src'][i], dataset['train']['tgt'][i],
                    #~ opt.batch_size, opt.gpus)

            # After loading everything, make a dataset
            for i in range(nSets):
                self.datasets[i] = onmt.Dataset(self.trainSets['src'][i],
                                                self.trainSets['tgt'][i],
                                                opt.batch_size, opt.gpus)

        else:
            # Lazy / sharded loading strategies are not implemented.
            raise NotImplementedError
        """ Next, if the loading strategy is to load all then pre-load all shards """
예제 #13
0
    def build_data(self, src_sents, tgt_sents, type='mt'):
        """Index raw sentences into an onmt.Dataset.

        Must stay consistent with preprocess.py. For 'asr' input the source
        is passed through untouched; for 'mt' it is mapped to vocab indices.
        """
        if type == 'mt':
            src_data = []
            if self.start_with_bos:
                for sent in src_sents:
                    src_data.append(
                        self.src_dict.convertToIdx(sent,
                                                   onmt.Constants.UNK_WORD,
                                                   onmt.Constants.BOS_WORD))
            else:
                for sent in src_sents:
                    src_data.append(
                        self.src_dict.convertToIdx(sent,
                                                   onmt.Constants.UNK_WORD))
        elif type == 'asr':
            # Audio features need no dictionary lookup.
            src_data = src_sents
        else:
            raise NotImplementedError

        # Gold targets may be framed without BOS when no_bos_gold is set.
        tgt_bos_word = self.opt.bos_token
        if self.opt.no_bos_gold:
            tgt_bos_word = None

        tgt_data = None
        if tgt_sents:
            tgt_data = []
            for sent in tgt_sents:
                tgt_data.append(
                    self.tgt_dict.convertToIdx(sent,
                                               onmt.Constants.UNK_WORD,
                                               tgt_bos_word,
                                               onmt.Constants.EOS_WORD))

        src_atbs = None

        tgt_atbs = None
        if self.attributes:
            # One converted attribute value per source sentence, keyed by
            # attribute name; self.attributes is consumed positionally in
            # the order self.atb_dict iterates.
            tgt_atbs = dict()
            idx = 0
            for name in self.atb_dict:
                tgt_atbs[name] = [
                    self.atb_dict[name].convertToIdx([self.attributes[idx]],
                                                     onmt.Constants.UNK_WORD)
                    for _ in src_sents
                ]
                idx = idx + 1

        # Word-level batching disabled; batch by sentence count.
        return onmt.Dataset(src_data,
                            tgt_data,
                            src_atbs=src_atbs,
                            tgt_atbs=tgt_atbs,
                            batch_size_words=sys.maxsize,
                            data_type=self._type,
                            batch_size_sents=self.opt.batch_size)
예제 #14
0
    def build_asr_data(self, src_data, tgt_sents):
        """Build an ASR onmt.Dataset; source features pass through as-is.

        Target conversion must mirror preprocess.py.
        """
        tgt_data = None
        if tgt_sents:
            tgt_data = []
            for sent in tgt_sents:
                tgt_data.append(
                    self.tgt_dict.convertToIdx(sent,
                                               onmt.constants.UNK_WORD,
                                               onmt.constants.BOS_WORD,
                                               onmt.constants.EOS_WORD))

        # Word-level batching disabled; batch by sentence count.
        return onmt.Dataset(src_data, tgt_data,
                            batch_size_words=sys.maxsize,
                            data_type=self._type,
                            batch_size_sents=self.opt.batch_size)
예제 #15
0
    def buildASRData(self, srcData, goldBatch):
        """Build an ASR onmt.Dataset; srcData is used untouched.

        Target conversion must mirror preprocess.py.
        """
        tgtData = None
        if goldBatch:
            tgtData = []
            for sent in goldBatch:
                tgtData.append(
                    self.tgt_dict.convertToIdx(sent,
                                               onmt.constants.UNK_WORD,
                                               onmt.constants.BOS_WORD,
                                               onmt.constants.EOS_WORD))

        # Word-level batching disabled; the cap is max_seq_num.
        return onmt.Dataset(srcData, tgtData, sys.maxsize,
                            [self.opt.gpu],
                            data_type=self._type,
                            max_seq_num=self.opt.batch_size)
예제 #16
0
    def buildData(self, srcBatch, goldBatch):
        """Index padded source batches and map labels to {0, 1} tensors."""
        srcData = []
        for sent in srcBatch:
            srcData.append(
                self.src_dict.convertToIdx(sent, onmt.Constants.UNK_WORD,
                                           padding=True))

        tgtData = []
        if goldBatch:
            for label in goldBatch:
                # Labels other than label0/label1 are silently skipped,
                # matching the original behaviour.
                if label == self.label0:
                    tgtData.append(torch.LongTensor([0]))
                elif label == self.label1:
                    tgtData.append(torch.LongTensor([1]))

        return onmt.Dataset(srcData, tgtData,
                            self.opt.batch_size, self.opt.cuda,
                            volatile=True)
예제 #17
0
    def buildData(self, srcBatch, goldBatch):
        """Index both sides with BOS/EOS framing and build an onmt.Dataset."""
        def to_idx(dictionary, sent):
            # Both sides share identical UNK/BOS/EOS framing.
            return dictionary.convertToIdx(
                labels=sent,
                unkWord=onmt.Constants.UNK_WORD,
                bosWord=onmt.Constants.BOS_WORD,
                eosWord=onmt.Constants.EOS_WORD)

        srcData = [to_idx(self.src_dict, sent) for sent in srcBatch]

        tgtData = None
        if goldBatch:
            tgtData = [to_idx(self.tgt_dict, sent) for sent in goldBatch]

        return onmt.Dataset(srcData, tgtData,
                            self.opt.batch_size, self.opt.cuda)
예제 #18
0
    def build_data(self, src_sents, tgt_sents):
        """Index raw source/target sentences into an onmt.Dataset.

        Must stay consistent with the conversion done by preprocess.py.
        """

        # Source side: optionally prepend BOS.
        if self.start_with_bos:
            src_data = [
                self.src_dict.convertToIdx(b, onmt.Constants.UNK_WORD,
                                           onmt.Constants.BOS_WORD)
                for b in src_sents
            ]
        else:
            src_data = [
                self.src_dict.convertToIdx(b, onmt.Constants.UNK_WORD)
                for b in src_sents
            ]

        tgt_data = None
        if tgt_sents:
            # Gold targets are framed BOS ... EOS.
            tgt_data = [
                self.tgt_dict.convertToIdx(b, onmt.Constants.UNK_WORD,
                                           onmt.Constants.BOS_WORD,
                                           onmt.Constants.EOS_WORD)
                for b in tgt_sents
            ]

        src_atbs = None

        if self.attributes:
            # One converted attribute value per source sentence, keyed by
            # attribute name; self.attributes is consumed positionally in
            # the order self.atb_dict iterates.
            tgt_atbs = dict()

            idx = 0
            for i in self.atb_dict:

                tgt_atbs[i] = [
                    self.atb_dict[i].convertToIdx([self.attributes[idx]],
                                                  onmt.Constants.UNK_WORD)
                    for _ in src_sents
                ]
                idx = idx + 1

        else:
            tgt_atbs = None

        # Word-level batching disabled (sys.maxsize); batch by sentence count.
        return onmt.Dataset(src_data,
                            tgt_data,
                            src_atbs=src_atbs,
                            tgt_atbs=tgt_atbs,
                            batch_size_words=sys.maxsize,
                            data_type=self._type,
                            batch_size_sents=self.opt.batch_size)
예제 #19
0
    def build_data(self, src_sents, tgt_sents, type='mt'):
        """Index raw sentences into an onmt.Dataset for 'mt' (text) or
        'asr' (audio-feature) input.

        Must stay consistent with the conversion done by preprocess.py.
        """

        if type == 'mt':
            # Source side: optionally prepend BOS.
            if self.start_with_bos:
                src_data = [
                    self.src_dict.convertToIdx(b, onmt.constants.UNK_WORD,
                                               onmt.constants.BOS_WORD)
                    for b in src_sents
                ]
            else:
                src_data = [
                    self.src_dict.convertToIdx(b, onmt.constants.UNK_WORD)
                    for b in src_sents
                ]
            data_type = 'text'
        elif type == 'asr':
            # no need to deal with this
            src_data = src_sents
            data_type = 'audio'
        else:
            raise NotImplementedError

        # Gold targets may drop BOS when no_bos_gold is set.
        tgt_bos_word = self.opt.bos_token
        if self.opt.no_bos_gold:
            tgt_bos_word = None
        tgt_data = None
        if tgt_sents:
            tgt_data = [
                self.tgt_dict.convertToIdx(b, onmt.constants.UNK_WORD,
                                           tgt_bos_word,
                                           onmt.constants.EOS_WORD)
                for b in tgt_sents
            ]

        # Single language tag for the whole batch on each side.
        src_lang_data = [torch.Tensor([self.lang_dict[self.src_lang]])]
        tgt_lang_data = [torch.Tensor([self.lang_dict[self.tgt_lang]])]

        # Word-level batching disabled; batch by sentence count.
        return onmt.Dataset(src_data,
                            tgt_data,
                            src_langs=src_lang_data,
                            tgt_langs=tgt_lang_data,
                            batch_size_words=sys.maxsize,
                            data_type=data_type,
                            batch_size_sents=self.opt.batch_size,
                            src_align_right=self.opt.src_align_right)
    def buildData(self, srcBatch, goldBatch):
        """Read raw lines through onmt.IO and wrap them in an onmt.Dataset."""
        srcData = []
        for line in srcBatch:
            _, srcD = onmt.IO.readSrcLine(line, self.src_dict)
            srcData.append(srcD)

        tgtData = []
        if goldBatch:
            for line in goldBatch:
                # Target-side features are discarded here.
                _, tgtD, tgtFeat = onmt.IO.readTgtLine(line, self.tgt_dict)
                tgtData.append(tgtD)

        return onmt.Dataset(srcData, tgtData,
                            self.opt.batch_size, self.opt.cuda,
                            volatile=True)
    def buildData(self, src1Batch, src2Batch, goldBatch):
        """Index a sentence-pair batch plus integer labels into a Dataset."""
        def pad_idx(batch):
            # Both inputs share the source dictionary and padded conversion.
            return [self.src_dict.convertToIdx(sent,
                                               onmt.Constants.UNK_WORD,
                                               padding=True)
                    for sent in batch]

        src1Data = pad_idx(src1Batch)
        src2Data = pad_idx(src2Batch)

        tgtData = []
        if goldBatch:
            # Labels arrive as strings; store each as a 1-element LongTensor.
            tgtData = [torch.LongTensor([int(label)]) for label in goldBatch]

        return onmt.Dataset(src1Data, src2Data, tgtData, self.opt.batch_size,
                            self.opt.cuda)
예제 #22
0
    def buildData(self, srcBatch):
        """Read raw source lines through onmt.IO and wrap them in a
        volatile onmt.Dataset with an empty target side.
        """
        # NOTE(review): srcFeats starts empty and is never resized, so the
        # inner `for i in range(len(srcFeats))` loop never runs and srcFeat
        # is discarded — srcFeatures is always passed as []. Confirm whether
        # features were meant to be collected (cf. the variant that sizes
        # srcFeats from src_feature_dicts).
        srcFeats = []
        srcData = []
        tgtData = []
        for b in srcBatch:
            _, srcD, srcFeat = onmt.IO.readSrcLine(b, self.src_dict, None,
                                                   "text")
            srcData += [srcD]
            for i in range(len(srcFeats)):
                srcFeats[i] += [srcFeat[i]]

        return onmt.Dataset(srcData,
                            tgtData,
                            self.opt.batch_size,
                            True,
                            volatile=True,
                            data_type="text",
                            srcFeatures=srcFeats)
예제 #23
0
    def buildData(self, srcBatch, goldBatch):
        """Index both sides with a shared dictionary; keep input order."""
        srcData = [self.all_dict.convertToIdx(sent, onmt.Constants.UNK_WORD)
                   for sent in srcBatch]

        tgtData = None
        if goldBatch:
            # Gold targets get BOS/EOS framing from the same shared dict.
            tgtData = [self.all_dict.convertToIdx(sent,
                                                  onmt.Constants.UNK_WORD,
                                                  onmt.Constants.BOS_WORD,
                                                  onmt.Constants.EOS_WORD)
                       for sent in goldBatch]

        # sort_key=False preserves the caller's sentence order.
        return onmt.Dataset(srcData, tgtData,
                            self.opt.batch_size, self.opt.cuda,
                            volatile=True, sort_key=False)
예제 #24
0
    def buildData(self, srcBatch, goldBatch):
        """Index raw batches into a volatile onmt.Dataset.

        Must match the conversion done by preprocess.py.
        """
        srcData = []
        for sent in srcBatch:
            srcData.append(
                self.src_dict.convertToIdx(sent, onmt.Constants.UNK_WORD))

        tgtData = None
        if goldBatch:
            tgtData = []
            for sent in goldBatch:
                # Gold targets get BOS/EOS framing.
                tgtData.append(
                    self.tgt_dict.convertToIdx(sent,
                                               onmt.Constants.UNK_WORD,
                                               onmt.Constants.BOS_WORD,
                                               onmt.Constants.EOS_WORD))

        return onmt.Dataset(srcData, tgtData,
                            self.opt.batch_size, self.opt.cuda,
                            volatile=True, data_type=self._type)
예제 #25
0
    def buildData(self, srcBatch, fea1_Batch, fea2_Batch, fea3_Batch,
                  fea4_Batch, fea5_Batch, goldBatch):
        """Index the source batch, five parallel feature batches, and the
        optional gold batch into a volatile onmt.Dataset.

        All five feature batches share self.feature_dict and identical
        conversion, so the conversion is factored into one helper instead
        of being written out five times.
        """
        def _to_idx(dictionary, batch):
            # Plain index conversion, OOV -> UNK, no BOS/EOS framing.
            return [dictionary.convertToIdx(b, onmt.Constants.UNK_WORD)
                    for b in batch]

        srcData = _to_idx(self.src_dict, srcBatch)
        fea1_Data, fea2_Data, fea3_Data, fea4_Data, fea5_Data = (
            _to_idx(self.feature_dict, batch)
            for batch in (fea1_Batch, fea2_Batch, fea3_Batch,
                          fea4_Batch, fea5_Batch))

        tgtData = None
        if goldBatch:
            # Gold targets get BOS/EOS framing.
            tgtData = [
                self.tgt_dict.convertToIdx(b, onmt.Constants.UNK_WORD,
                                           onmt.Constants.BOS_WORD,
                                           onmt.Constants.EOS_WORD)
                for b in goldBatch
            ]

        return onmt.Dataset(srcData, tgtData, fea1_Data, fea2_Data,
                            fea3_Data, fea4_Data, fea5_Data,
                            self.opt.batch_size, self.opt.cuda,
                            volatile=True)
예제 #26
0
def main():
    """Load the dataset, build or restore the model and optimizer, and train.

    All configuration is read from the module-level ``opt`` namespace.
    Fixes relative to the original: the duplicated ``dicts`` assignment is
    removed, and an unsupported ``opt.encoder_type`` now raises immediately
    instead of printing and later crashing with a NameError on the
    undefined ``encoder``.
    """
    print("Loading data from '%s'" % opt.data)

    dataset = torch.load(opt.data)
    # When resuming, reuse the vocabularies stored in the checkpoint so the
    # word/index mappings stay consistent with the saved parameters.
    dict_checkpoint = (opt.train_from
                       if opt.train_from else opt.train_from_state_dict)
    if dict_checkpoint:
        print('Loading dicts from checkpoint at %s' % dict_checkpoint)
        checkpoint = torch.load(dict_checkpoint,
                                map_location=lambda storage, loc: storage)
        dataset['dicts'] = checkpoint['dicts']

    trainData = onmt.Dataset(dataset['train']['src'],
                             dataset['train']['tgt'],
                             opt.batch_size,
                             opt.gpus,
                             data_type=dataset.get("type", "text"),
                             srcFeatures=dataset['train'].get('src_features'),
                             tgtFeatures=dataset['train'].get('tgt_features'),
                             alignment=dataset['train'].get('alignments'))
    validData = onmt.Dataset(dataset['valid']['src'],
                             dataset['valid']['tgt'],
                             opt.batch_size,
                             opt.gpus,
                             volatile=True,
                             data_type=dataset.get("type", "text"),
                             srcFeatures=dataset['valid'].get('src_features'),
                             tgtFeatures=dataset['valid'].get('tgt_features'),
                             alignment=dataset['valid'].get('alignments'))

    dicts = dataset['dicts']
    print(' * vocabulary size. source = %d; target = %d' %
          (dicts['src'].size(), dicts['tgt'].size()))
    if 'src_features' in dicts:
        for j in range(len(dicts['src_features'])):
            print(' * src feature %d size = %d' %
                  (j, dicts['src_features'][j].size()))

    print(' * number of training sentences. %d' % len(dataset['train']['src']))
    print(' * maximum batch size. %d' % opt.batch_size)

    print('Building model...')

    if opt.encoder_type == "text":
        encoder = onmt.Models.Encoder(opt, dicts['src'],
                                      dicts.get('src_features', None))
    elif opt.encoder_type == "img":
        encoder = onmt.modules.ImageEncoder(opt)
        assert ("type" not in dataset or dataset["type"] == "img")
    else:
        # Fail fast: the original printed a message and fell through,
        # crashing later with a NameError on the undefined `encoder`.
        raise ValueError("Unsupported encoder type %s" % (opt.encoder_type))

    decoder = onmt.Models.Decoder(opt, dicts['tgt'])

    if opt.copy_attn:
        generator = onmt.modules.CopyGenerator(opt, dicts['src'], dicts['tgt'])
    else:
        generator = nn.Sequential(nn.Linear(opt.rnn_size, dicts['tgt'].size()),
                                  nn.LogSoftmax())
        if opt.share_decoder_embeddings:
            # Tie the output projection to the decoder input embeddings.
            generator[0].weight = decoder.embeddings.word_lut.weight

    # Third argument presumably toggles multi-gpu handling inside NMTModel
    # -- TODO confirm against onmt.Models.NMTModel's signature.
    model = onmt.Models.NMTModel(encoder, decoder, len(opt.gpus) > 1)

    if opt.train_from:
        print('Loading model from checkpoint at %s' % opt.train_from)
        chk_model = checkpoint['model']
        generator_state_dict = chk_model.generator.state_dict()
        # Strip generator weights here; they are restored separately above.
        model_state_dict = {
            k: v
            for k, v in chk_model.state_dict().items() if 'generator' not in k
        }
        model.load_state_dict(model_state_dict)
        generator.load_state_dict(generator_state_dict)
        opt.start_epoch = checkpoint['epoch'] + 1

    if opt.train_from_state_dict:
        print('Loading model from checkpoint at %s' %
              opt.train_from_state_dict)
        model.load_state_dict(checkpoint['model'])
        generator.load_state_dict(checkpoint['generator'])
        opt.start_epoch = checkpoint['epoch'] + 1

    if len(opt.gpus) >= 1:
        model.cuda()
        generator.cuda()
    else:
        model.cpu()
        generator.cpu()

    if len(opt.gpus) > 1:
        print('Multi gpu training ', opt.gpus)
        model = nn.DataParallel(model, device_ids=opt.gpus, dim=1)
        generator = nn.DataParallel(generator, device_ids=opt.gpus, dim=0)

    model.generator = generator

    if not opt.train_from_state_dict and not opt.train_from:
        if opt.param_init != 0.0:
            print('Initializing params')
            for p in model.parameters():
                p.data.uniform_(-opt.param_init, opt.param_init)

        encoder.embeddings.load_pretrained_vectors(opt.pre_word_vecs_enc)
        decoder.embeddings.load_pretrained_vectors(opt.pre_word_vecs_dec)

        optim = onmt.Optim(opt.optim,
                           opt.learning_rate,
                           opt.max_grad_norm,
                           lr_decay=opt.learning_rate_decay,
                           start_decay_at=opt.start_decay_at,
                           opt=opt)
    else:
        print('Loading optimizer from checkpoint:')
        optim = checkpoint['optim']
        print(optim)

    # Must run after any DataParallel wrapping so the optimizer tracks the
    # final parameter set.
    optim.set_parameters(model.parameters())

    if opt.train_from or opt.train_from_state_dict:
        # set_parameters resets internal state; restore it from checkpoint.
        optim.optimizer.load_state_dict(
            checkpoint['optim'].optimizer.state_dict())

    nParams = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % nParams)
    enc = 0
    dec = 0
    for name, param in model.named_parameters():
        if 'encoder' in name:
            enc += param.nelement()
        elif 'decoder' in name:
            dec += param.nelement()
        else:
            print(name, param.nelement())
    print('encoder: ', enc)
    print('decoder: ', dec)

    check_model_path()

    trainModel(model, trainData, validData, dataset, optim)
# ----- Example #27 (예제 #27) -----
def main():
    """Load raw or binary training data, build the model and loss, and run
    cross-entropy training.

    All configuration comes from the module-level ``opt`` namespace.
    """
    if opt.data_format == 'raw':
        load_start = time.time()
        # Resolve the actual .train.pt path before loading.
        if opt.data.endswith(".train.pt"):
            print("Loading data from '%s'" % opt.data)
            data_file = opt.data
        else:
            print("Loading data from %s" % opt.data + ".train.pt")
            data_file = opt.data + ".train.pt"
        dataset = torch.load(data_file)

        elapse = str(datetime.timedelta(seconds=int(time.time() - load_start)))
        print("Done after %s" % elapse)

        train_data = onmt.Dataset(dataset['train']['src'],
                                  dataset['train']['tgt'],
                                  opt.batch_size_words,
                                  data_type=dataset.get("type", "text"),
                                  batch_size_sents=opt.batch_size_sents,
                                  multiplier=opt.batch_size_multiplier,
                                  reshape_speech=opt.reshape_speech,
                                  augment=opt.augment_speech)
        valid_data = onmt.Dataset(dataset['valid']['src'],
                                  dataset['valid']['tgt'],
                                  opt.batch_size_words,
                                  data_type=dataset.get("type", "text"),
                                  batch_size_sents=opt.batch_size_sents,
                                  reshape_speech=opt.reshape_speech)

        dicts = dataset['dicts']
        # Speech-style datasets may have no source dictionary.
        if "src" in dicts:
            print(' * vocabulary size. source = %d; target = %d' %
                  (dicts['src'].size(), dicts['tgt'].size()))
        else:
            print(' * vocabulary size. target = %d' % (dicts['tgt'].size()))

        print(' * number of training sentences. %d' %
              len(dataset['train']['src']))
        print(' * maximum batch size (words per batch). %d' %
              opt.batch_size_words)

    elif opt.data_format == 'bin':

        from onmt.data_utils.IndexedDataset import IndexedInMemoryDataset

        dicts = torch.load(opt.data + ".dict.pt")

        def binary_pair(prefix):
            # One (.src, .tgt) indexed-dataset pair for a data split.
            return (IndexedInMemoryDataset(prefix + '.src'),
                    IndexedInMemoryDataset(prefix + '.tgt'))

        train_src, train_tgt = binary_pair(opt.data + '.train')
        train_data = onmt.Dataset(train_src,
                                  train_tgt,
                                  opt.batch_size_words,
                                  data_type=opt.encoder_type,
                                  batch_size_sents=opt.batch_size_sents,
                                  multiplier=opt.batch_size_multiplier)

        valid_src, valid_tgt = binary_pair(opt.data + '.valid')
        valid_data = onmt.Dataset(valid_src,
                                  valid_tgt,
                                  opt.batch_size_words,
                                  data_type=opt.encoder_type,
                                  batch_size_sents=opt.batch_size_sents)

    else:
        raise NotImplementedError

    print('Building model...')

    if opt.fusion:
        from onmt.ModelConstructor import build_fusion
        from onmt.modules.Loss import FusionLoss

        model = build_fusion(opt, dicts)
        loss_function = FusionLoss(dicts['tgt'].size(),
                                   label_smoothing=opt.label_smoothing)
    else:
        model = build_model(opt, dicts)
        # Choose the loss: CTC-augmented when a non-zero CTC weight is set.
        if opt.ctc_loss != 0:
            loss_function = NMTAndCTCLossFunc(
                dicts['tgt'].size(),
                label_smoothing=opt.label_smoothing,
                ctc_weight=opt.ctc_loss)
        else:
            loss_function = NMTLossFunc(dicts['tgt'].size(),
                                        label_smoothing=opt.label_smoothing)

    n_params = sum(p.nelement() for p in model.parameters())
    print('* number of parameters: %d' % n_params)

    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        raise NotImplementedError(
            "Warning! Multi-GPU training is not fully tested and potential bugs can happen.")

    trainer_cls = FP16XETrainer if opt.fp16 else XETrainer
    trainer = trainer_cls(model, loss_function, train_data, valid_data,
                          dicts, opt)

    trainer.run(save_file=opt.load_from)
def main():
    """Train a context-augmented NMT model, optionally resuming from a checkpoint.

    All configuration is read from the module-level ``opt`` namespace.

    Fixes relative to the original:
      * the model is constructed *before* the ``opt.train_from`` branch;
        previously ``model.load_state_dict`` was called ahead of ``model``'s
        definition, a guaranteed NameError when resuming;
      * restoring ``encoder_cxt`` state is skipped when no context encoder
        exists (it is ``None`` unless ``opt.add_context`` is set).
    """
    print("Loading data from '%s'" % opt.data)

    dataset = torch.load(opt.data)

    # Resuming (either style) reuses the checkpoint's vocabularies.
    dict_checkpoint = opt.train_from if opt.train_from else opt.train_from_state_dict
    if dict_checkpoint:
        print('Loading dicts from checkpoint at %s' % dict_checkpoint)
        checkpoint = torch.load(dict_checkpoint)
        dataset['dicts'] = checkpoint['dicts']

    trainData = onmt.Dataset(dataset['train']['src'], dataset['train']['cxt'],
                             dataset['train']['tgt'], opt.batch_size, opt.gpus)
    validData = onmt.Dataset(dataset['valid']['src'], dataset['valid']['cxt'],
                             dataset['valid']['tgt'], opt.batch_size, opt.gpus,
                             volatile=True)

    dicts = dataset['dicts']
    print(' * vocabulary size. source = %d; target = %d' %
          (dicts['src'].size(), dicts['tgt'].size()))
    print(' * number of training sentences. %d' %
          len(dataset['train']['src']))
    print(' * maximum batch size. %d' % opt.batch_size)

    print('Building model...')

    encoder_src = onmt.Models.Encoder(opt, dicts['src'])
    if opt.add_context:
        encoder_cxt = onmt.Models.Encoder(opt, dicts['cxt'])
    else:
        encoder_cxt = None
    decoder = onmt.Models.Decoder(opt, dicts['tgt'])

    generator = nn.Sequential(
        nn.Linear(opt.rnn_size, dicts['tgt'].size()),
        nn.LogSoftmax())

    # NOTE(review): the original passed encoder_src twice and never wired
    # encoder_cxt into the model; preserved as-is pending confirmation of
    # NMTModel's expected arguments (encoder_cxt may be None here).
    model = onmt.Models.NMTModel(encoder_src, encoder_src, decoder)

    if opt.train_from:
        print('Loading model from checkpoint at %s' % opt.train_from)
        chk_model = checkpoint['model']
        generator_state_dict = chk_model.generator.state_dict()
        # Generator weights are restored separately just below.
        model_state_dict = {k: v for k, v in chk_model.state_dict().items() if 'generator' not in k}
        model.load_state_dict(model_state_dict)
        generator.load_state_dict(generator_state_dict)
        opt.start_epoch = checkpoint['epoch'] + 1

    if opt.train_from_state_dict:
        print('Loading model from checkpoint at %s' % opt.train_from_state_dict)
        encoder_src.load_state_dict(checkpoint['encoder_src'])
        if encoder_cxt is not None:
            encoder_cxt.load_state_dict(checkpoint['encoder_cxt'])
        decoder.load_state_dict(checkpoint['decoder'])
        generator.load_state_dict(checkpoint['generator'])
        opt.start_epoch = checkpoint['epoch'] + 1

    if len(opt.gpus) >= 1:
        model.cuda()
        generator.cuda()
    else:
        model.cpu()
        generator.cpu()

    if len(opt.gpus) > 1:
        model = nn.DataParallel(model, device_ids=opt.gpus, dim=1)
        generator = nn.DataParallel(generator, device_ids=opt.gpus, dim=0)

    model.generator = generator

    if not opt.train_from_state_dict and not opt.train_from:
        # Fresh run: uniform init, then load any pretrained embeddings.
        for p in model.parameters():
            p.data.uniform_(-opt.param_init, opt.param_init)

        encoder_src.load_pretrained_vectors(opt)
        if opt.add_context:
            encoder_cxt.load_pretrained_vectors(opt)
        decoder.load_pretrained_vectors(opt)

        optim = onmt.Optim(
            opt.optim, opt.learning_rate, opt.max_grad_norm,
            lr_decay=opt.learning_rate_decay,
            start_decay_at=opt.start_decay_at
        )
    else:
        print('Loading optimizer from checkpoint:')
        optim = checkpoint['optim']
        print(optim.start_decay_at)

    # Must run after GPU placement / DataParallel wrapping.
    optim.set_parameters(model.parameters())

    if opt.train_from or opt.train_from_state_dict:
        # set_parameters resets internal state; restore it from checkpoint.
        optim.optimizer.load_state_dict(checkpoint['optim'].optimizer.state_dict())

    nParams = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % nParams)

    trainModel(model, trainData, validData, dataset, optim)
# ----- Example #29 (예제 #29) -----
def main():
    """Train a continuous-output NMT model: the decoder predicts fixed,
    L2-normalized target embedding vectors instead of a vocabulary softmax.

    All configuration is read from the module-level ``opt`` namespace.
    NOTE: the order below matters — parameters must be initialized and
    embeddings frozen *before* ``optim.set_parameters`` and the trainable
    parameter count.
    """
    print("Loading data from '%s'" % opt.data)
    dataset = torch.load(opt.data)

    # Only an explicit --train_from checkpoint can override the dataset dicts.
    dict_checkpoint = opt.train_from if opt.train_from else None

    if dict_checkpoint:
        print('Loading dicts from checkpoint at %s' % dict_checkpoint)
        checkpoint = torch.load(dict_checkpoint)
        dataset['dicts'] = checkpoint['dicts']

    trainData = onmt.Dataset(dataset['train']['src'], dataset['train']['tgt'],
                             opt.batch_size, opt.gpus)
    validData = onmt.Dataset(dataset['valid']['src'],
                             dataset['valid']['tgt'],
                             opt.batch_size,
                             opt.gpus,
                             volatile=True)

    dicts = dataset['dicts']
    print(' * vocabulary size. source = %d; target = %d' %
          (dicts['src'].size(), dicts['tgt'].size()))
    print(' * number of training sentences. %d' % len(dataset['train']['src']))
    print(' * maximum batch size. %d' % opt.batch_size)

    print('Building model...')

    encoder = onmt.Models.Encoder(opt, dicts['src'], opt.fix_src_emb)
    decoder = onmt.Models.Decoder(opt, dicts['tgt'], opt.tie_emb)

    output_dim = opt.output_emb_size

    # The generator maps decoder states into embedding space, optionally
    # through an extra non-linear layer.
    if not opt.nonlin_gen:
        generator = nn.Sequential(nn.Linear(opt.rnn_size, output_dim))
    else:  #add a non-linear layer before generating the continuous vector
        generator = nn.Sequential(nn.Linear(opt.rnn_size, output_dim),
                                  nn.ReLU(), nn.Linear(output_dim, output_dim))

    #output is just an embedding
    target_embeddings = nn.Embedding(dicts['tgt'].size(), opt.output_emb_size)

    #normalize the embeddings (L2, with a clamp to avoid division by zero)
    norm = dicts['tgt'].embeddings.norm(p=2, dim=1,
                                        keepdim=True).clamp(min=1e-12)
    target_embeddings.weight.data.copy_(dicts['tgt'].embeddings.div(norm))

    #target embeddings are fixed and not trained
    target_embeddings.weight.requires_grad = False
    # elif opt.loss != "maxmargin": # with max-margin loss, the target embeddings can be fine-tuned as well.
    # target_embeddings.weight.requires_grad=False

    model = onmt.Models.NMTModel(encoder, decoder)

    if opt.train_from:
        # Checkpoint stores encoder/decoder/generator state dicts separately;
        # re-prefix encoder/decoder keys to match the combined model.
        print('Loading model from checkpoint at %s' % opt.train_from)
        generator_state_dict = checkpoint['generator']
        encoder_state_dict = [('encoder.' + k, v)
                              for k, v in checkpoint['encoder'].items()]
        decoder_state_dict = [('decoder.' + k, v)
                              for k, v in checkpoint['decoder'].items()]
        model_state_dict = dict(encoder_state_dict + decoder_state_dict)

        model.load_state_dict(model_state_dict)
        generator.load_state_dict(generator_state_dict)

        if not opt.train_anew:  #load from
            # Resume epoch numbering only when continuing the same run.
            opt.start_epoch = checkpoint['epoch'] + 1

    if len(opt.gpus) >= 1:
        model.cuda()
        generator.cuda()
        target_embeddings.cuda()
    else:
        model.cpu()
        generator.cpu()
        target_embeddings.cpu()

    if len(opt.gpus) > 1:
        model = nn.DataParallel(model, device_ids=opt.gpus, dim=1)
        generator = nn.DataParallel(generator, device_ids=opt.gpus, dim=0)

    model.generator = generator

    if not opt.train_from:
        # Fresh run: uniform init, then load pretrained vectors and apply
        # the optional tying/freezing of embeddings.
        for p in model.parameters():
            p.data.uniform_(-opt.param_init, opt.param_init)

        encoder.load_pretrained_vectors(opt)
        decoder.load_pretrained_vectors(opt)

        if opt.tie_emb:
            decoder.tie_embeddings(target_embeddings)

        if opt.fix_src_emb:
            #fix and normalize the source embeddings
            source_embeddings = nn.Embedding(dicts['src'].size(),
                                             opt.output_emb_size)
            norm = dicts['src'].embeddings.norm(p=2, dim=1,
                                                keepdim=True).clamp(min=1e-12)
            source_embeddings.weight.data.copy_(
                dicts['src'].embeddings.div(norm))

            #turn this off to initialize embeddings as well as make them trainable
            source_embeddings.weight.requires_grad = False
            if len(opt.gpus) >= 1:
                source_embeddings.cuda()
            else:
                source_embeddings.cpu()
            encoder.fix_embeddings(source_embeddings)

        optim = onmt.Optim(opt.optim,
                           opt.learning_rate,
                           opt.max_grad_norm,
                           lr_decay=opt.learning_rate_decay,
                           start_decay_at=opt.start_decay_at)
    elif opt.train_anew:  #restart optimizer, sometimes useful for training with
        optim = onmt.Optim(opt.optim,
                           opt.learning_rate,
                           opt.max_grad_norm,
                           lr_decay=opt.learning_rate_decay,
                           start_decay_at=opt.start_decay_at)
    else:
        print('Loading optimizer from checkpoint:')
        optim = checkpoint['optim']
        print(optim)

    # Must run after GPU placement / DataParallel wrapping.
    optim.set_parameters(model.parameters())

    if opt.train_from and not opt.train_anew:
        # set_parameters resets internal state; restore it from checkpoint.
        optim.optimizer.load_state_dict(
            checkpoint['optim'].optimizer.state_dict())

    # Count only trainable parameters (frozen embeddings excluded).
    nParams = sum(
        [p.nelement() for p in model.parameters() if p.requires_grad])
    print('* number of trainable parameters: %d' % nParams)

    trainModel(model, trainData, validData, dataset, target_embeddings, optim)
# ----- Example #30 (예제 #30) -----
def main():
    """Load raw or binary data (plus optional extra corpora), build the model
    and loss, and launch cross-entropy training.

    All configuration is read from the module-level ``opt`` namespace.
    """
    if opt.data_format == 'raw':
        start = time.time()
        if opt.data.endswith(".train.pt"):
            print("Loading data from '%s'" % opt.data)
            dataset = torch.load(
                opt.data)  # This requires a lot of cpu memory!
        else:
            print("Loading data from %s" % opt.data + ".train.pt")
            dataset = torch.load(opt.data + ".train.pt")

        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse)

        train_data = onmt.Dataset(dataset['train']['src'],
                                  dataset['train']['tgt'],
                                  opt.batch_size_words,
                                  data_type=dataset.get("type", "text"),
                                  batch_size_sents=opt.batch_size_sents,
                                  multiplier=opt.batch_size_multiplier,
                                  reshape_speech=opt.reshape_speech,
                                  augment=opt.augment_speech)
        valid_data = onmt.Dataset(dataset['valid']['src'],
                                  dataset['valid']['tgt'],
                                  opt.batch_size_words,
                                  data_type=dataset.get("type", "text"),
                                  batch_size_sents=opt.batch_size_sents,
                                  reshape_speech=opt.reshape_speech)

        dicts = dataset['dicts']

        print(' * number of training sentences. %d' %
              len(dataset['train']['src']))
        print(' * maximum batch size (words per batch). %d' %
              opt.batch_size_words)

    elif opt.data_format == 'bin':

        from onmt.data_utils.IndexedDataset import IndexedInMemoryDataset

        dicts = torch.load(opt.data + ".dict.pt")

        train_path = opt.data + '.train'
        train_src = IndexedInMemoryDataset(train_path + '.src')
        train_tgt = IndexedInMemoryDataset(train_path + '.tgt')

        train_data = onmt.Dataset(train_src,
                                  train_tgt,
                                  opt.batch_size_words,
                                  data_type=opt.encoder_type,
                                  batch_size_sents=opt.batch_size_sents,
                                  multiplier=opt.batch_size_multiplier)

        valid_path = opt.data + '.valid'
        valid_src = IndexedInMemoryDataset(valid_path + '.src')
        valid_tgt = IndexedInMemoryDataset(valid_path + '.tgt')

        valid_data = onmt.Dataset(valid_src,
                                  valid_tgt,
                                  opt.batch_size_words,
                                  data_type=opt.encoder_type,
                                  batch_size_sents=opt.batch_size_sents)

    else:
        raise NotImplementedError

    # Optional extra corpora: ';'-separated paths with matching format tags,
    # each becoming an additional Dataset mixed in by the trainer.
    additional_data = []
    if (opt.additional_data != "none"):
        add_data = opt.additional_data.split(";")
        add_format = opt.additional_data_format.split(";")
        assert (len(add_data) == len(add_format))
        for i in range(len(add_data)):
            if add_format[i] == 'raw':
                if add_data[i].endswith(".train.pt"):
                    print("Loading data from '%s'" % add_data[i])
                    add_dataset = torch.load(add_data[i])
                else:
                    print("Loading data from %s" % add_data[i] + ".train.pt")
                    add_dataset = torch.load(add_data[i] + ".train.pt")

                additional_data.append(
                    onmt.Dataset(add_dataset['train']['src'],
                                 add_dataset['train']['tgt'],
                                 opt.batch_size_words,
                                 data_type=add_dataset.get("type", "text"),
                                 batch_size_sents=opt.batch_size_sents,
                                 multiplier=opt.batch_size_multiplier,
                                 reshape_speech=opt.reshape_speech,
                                 augment=opt.augment_speech))
                add_dicts = add_dataset['dicts']

                # Additional corpora must share vocab sizes with the main
                # data; adopt their dicts when the main data lacks one.
                for d in ['src', 'tgt']:
                    if (d in dicts):
                        if (d in add_dicts):
                            assert (dicts[d].size() == add_dicts[d].size())
                    else:
                        if (d in add_dicts):
                            dicts[d] = add_dicts[d]

            elif add_format[i] == 'bin':

                from onmt.data_utils.IndexedDataset import IndexedInMemoryDataset

                train_path = add_data[i] + '.train'
                train_src = IndexedInMemoryDataset(train_path + '.src')
                train_tgt = IndexedInMemoryDataset(train_path + '.tgt')

                additional_data.append(
                    onmt.Dataset(train_src,
                                 train_tgt,
                                 opt.batch_size_words,
                                 data_type=opt.encoder_type,
                                 batch_size_sents=opt.batch_size_sents,
                                 multiplier=opt.batch_size_multiplier))

    # Restore from checkpoint
    if opt.load_from:
        checkpoint = torch.load(opt.load_from,
                                map_location=lambda storage, loc: storage)
        print("* Loading dictionaries from the checkpoint")
        dicts = checkpoint['dicts']
    else:
        # NOTE(review): vocab padding/patching only happens when training
        # from scratch — presumably checkpoints already store patched dicts.
        dicts['tgt'].patch(opt.patch_vocab_multiplier)
        checkpoint = None

    if "src" in dicts:
        print(' * vocabulary size. source = %d; target = %d' %
              (dicts['src'].size(), dicts['tgt'].size()))
    else:
        print(' * vocabulary size. target = %d' % (dicts['tgt'].size()))

    print('Building model...')

    if not opt.fusion:
        model = build_model(opt, dicts)
        """ Building the loss function """
        if opt.ctc_loss != 0:
            loss_function = NMTAndCTCLossFunc(
                dicts['tgt'].size(),
                label_smoothing=opt.label_smoothing,
                ctc_weight=opt.ctc_loss)
        else:
            loss_function = NMTLossFunc(dicts['tgt'].size(),
                                        label_smoothing=opt.label_smoothing)
    else:
        from onmt.ModelConstructor import build_fusion
        from onmt.modules.Loss import FusionLoss

        model = build_fusion(opt, dicts)

        loss_function = FusionLoss(dicts['tgt'].size(),
                                   label_smoothing=opt.label_smoothing)

    n_params = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % n_params)

    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        raise NotImplementedError(
            "Warning! Multi-GPU training is not fully tested and potential bugs can happen."
        )
    else:
        # if opt.fp16:
        #     trainer = FP16XETrainer(model, loss_function, train_data, valid_data, dicts, opt)
        # else:
        trainer = XETrainer(model, loss_function, train_data, valid_data,
                            dicts, opt)
        if (len(additional_data) > 0):
            trainer.add_additional_data(additional_data, opt.data_ratio)

    trainer.run(checkpoint=checkpoint)