def main():
    start = time.time()
    print("Loading data from '%s'" % opt.data)

    if opt.data_format == 'raw':
        dataset = torch.load(opt.data)
        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse)

        train_data = onmt.Dataset(dataset['train']['src'], dataset['train']['tgt'],
                                  opt.batch_size_words,
                                  data_type=dataset.get("type", "text"),
                                  batch_size_sents=opt.batch_size_sents,
                                  multiplier=opt.batch_size_multiplier,
                                  sort_by_target=opt.sort_by_target)

        valid_data = onmt.Dataset(dataset['valid']['src'], dataset['valid']['tgt'],
                                  opt.batch_size_words,
                                  data_type=dataset.get("type", "text"),
                                  batch_size_sents=opt.batch_size_sents)

        dicts = dataset['dicts']
        if "src" in dicts:
            print(' * vocabulary size. source = %d; target = %d' %
                  (dicts['src'].size(), dicts['tgt'].size()))
        else:
            print(' * vocabulary size. target = %d' % (dicts['tgt'].size()))
        print(' * number of training sentences. %d' % train_data.size())
        print(' * maximum batch size (words per batch). %d' % opt.batch_size_words)
    else:
        raise NotImplementedError

    print('Building model...')
    model = build_language_model(opt, dicts)

    """ Building the loss function """
    loss_function = NMTLossFunc(dicts['tgt'].size(), label_smoothing=opt.label_smoothing)

    n_params = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % n_params)

    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        raise NotImplementedError("Warning! Multi-GPU training is not fully tested and potential bugs can happen.")
    else:
        if opt.fp16:
            trainer = FP16XETrainer(model, loss_function, train_data, valid_data, dicts, opt)
        else:
            trainer = XETrainer(model, loss_function, train_data, valid_data, dicts, opt)

    trainer.run(save_file=opt.load_from)
def main():
    start = time.time()
    print("Loading data from '%s'" % opt.data)

    if opt.data_format == 'raw':
        dataset = torch.load(opt.data)
        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse)

        train_data = LanguageModelDataset(dataset['train']['tgt'],
                                          batch_size_sents=opt.batch_size_sents,
                                          seq_length=opt.lm_seq_length)
        valid_data = LanguageModelDataset(dataset['valid']['tgt'],
                                          batch_size_sents=opt.batch_size_sents,
                                          seq_length=opt.lm_seq_length)

        dicts = dataset['dicts']
        if "src" in dicts:
            print(' * vocabulary size. source = %d; target = %d' %
                  (dicts['src'].size(), dicts['tgt'].size()))
        else:
            print(' * vocabulary size. target = %d' % (dicts['tgt'].size()))
        print(' * number of training sentences. %d' % train_data.size())
        print(' * maximum batch size (words per batch). %d' % opt.batch_size_words)
    else:
        raise NotImplementedError

    print('Building model...')
    model = build_language_model(opt, dicts)
    print(model)

    """ Building the loss function """
    loss_function = NMTLossFunc(dicts['tgt'].size(), label_smoothing=opt.label_smoothing)

    n_params = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % n_params)

    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        raise NotImplementedError("Multi-GPU training is not supported ATM.")
    else:
        # if opt.fp16:
        #     trainer = FP16XETrainer(model, loss_function, train_data, valid_data, dicts, opt)
        # else:
        trainer = XETrainer(model, loss_function, train_data, valid_data, dicts, opt)

    trainer.run(save_file=opt.load_from)
def main():
    if opt.data_format == 'raw':
        start = time.time()
        if opt.data.endswith(".train.pt"):
            print("Loading data from '%s'" % opt.data)
            dataset = torch.load(opt.data)
        else:
            print("Loading data from %s" % opt.data + ".train.pt")
            dataset = torch.load(opt.data + ".train.pt")

        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse)

        train_data = onmt.Dataset(dataset['train']['src'], dataset['train']['tgt'],
                                  opt.batch_size_words,
                                  data_type=dataset.get("type", "text"),
                                  batch_size_sents=opt.batch_size_sents,
                                  multiplier=opt.batch_size_multiplier,
                                  reshape_speech=opt.reshape_speech,
                                  augment=opt.augment_speech)

        valid_data = onmt.Dataset(dataset['valid']['src'], dataset['valid']['tgt'],
                                  opt.batch_size_words,
                                  data_type=dataset.get("type", "text"),
                                  batch_size_sents=opt.batch_size_sents,
                                  reshape_speech=opt.reshape_speech)

        dicts = dataset['dicts']
        if "src" in dicts:
            print(' * vocabulary size. source = %d; target = %d' %
                  (dicts['src'].size(), dicts['tgt'].size()))
        else:
            print(' * vocabulary size. target = %d' % (dicts['tgt'].size()))
        print(' * number of training sentences. %d' % len(dataset['train']['src']))
        print(' * maximum batch size (words per batch). %d' % opt.batch_size_words)

    elif opt.data_format == 'bin':
        from onmt.data_utils.IndexedDataset import IndexedInMemoryDataset

        dicts = torch.load(opt.data + ".dict.pt")

        train_path = opt.data + '.train'
        train_src = IndexedInMemoryDataset(train_path + '.src')
        train_tgt = IndexedInMemoryDataset(train_path + '.tgt')

        train_data = onmt.Dataset(train_src, train_tgt,
                                  opt.batch_size_words,
                                  data_type=opt.encoder_type,
                                  batch_size_sents=opt.batch_size_sents,
                                  multiplier=opt.batch_size_multiplier)

        valid_path = opt.data + '.valid'
        valid_src = IndexedInMemoryDataset(valid_path + '.src')
        valid_tgt = IndexedInMemoryDataset(valid_path + '.tgt')

        valid_data = onmt.Dataset(valid_src, valid_tgt,
                                  opt.batch_size_words,
                                  data_type=opt.encoder_type,
                                  batch_size_sents=opt.batch_size_sents)
    else:
        raise NotImplementedError

    print('Building model...')

    if not opt.fusion:
        model = build_model(opt, dicts)

        """ Building the loss function """
        if opt.ctc_loss != 0:
            loss_function = NMTAndCTCLossFunc(dicts['tgt'].size(),
                                              label_smoothing=opt.label_smoothing,
                                              ctc_weight=opt.ctc_loss)
        else:
            loss_function = NMTLossFunc(dicts['tgt'].size(),
                                        label_smoothing=opt.label_smoothing)
    else:
        from onmt.ModelConstructor import build_fusion
        from onmt.modules.Loss import FusionLoss

        model = build_fusion(opt, dicts)
        loss_function = FusionLoss(dicts['tgt'].size(), label_smoothing=opt.label_smoothing)

    n_params = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % n_params)

    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        raise NotImplementedError("Warning! Multi-GPU training is not fully tested and potential bugs can happen.")
    else:
        if opt.fp16:
            trainer = FP16XETrainer(model, loss_function, train_data, valid_data, dicts, opt)
        else:
            trainer = XETrainer(model, loss_function, train_data, valid_data, dicts, opt)

    trainer.run(save_file=opt.load_from)
def main():
    if opt.data_format == 'raw':
        start = time.time()
        if opt.data.endswith(".train.pt"):
            print("Loading data from '%s'" % opt.data)
            dataset = torch.load(opt.data)  # This requires a lot of cpu memory!
        else:
            print("Loading data from %s" % opt.data + ".train.pt")
            dataset = torch.load(opt.data + ".train.pt")

        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse)

        train_data = onmt.Dataset(dataset['train']['src'], dataset['train']['tgt'],
                                  opt.batch_size_words,
                                  data_type=dataset.get("type", "text"),
                                  batch_size_sents=opt.batch_size_sents,
                                  multiplier=opt.batch_size_multiplier,
                                  reshape_speech=opt.reshape_speech,
                                  augment=opt.augment_speech)

        valid_data = onmt.Dataset(dataset['valid']['src'], dataset['valid']['tgt'],
                                  opt.batch_size_words,
                                  data_type=dataset.get("type", "text"),
                                  batch_size_sents=opt.batch_size_sents,
                                  reshape_speech=opt.reshape_speech)

        dicts = dataset['dicts']
        print(' * number of training sentences. %d' % len(dataset['train']['src']))
        print(' * maximum batch size (words per batch). %d' % opt.batch_size_words)

    elif opt.data_format == 'bin':
        from onmt.data_utils.IndexedDataset import IndexedInMemoryDataset

        dicts = torch.load(opt.data + ".dict.pt")

        train_path = opt.data + '.train'
        train_src = IndexedInMemoryDataset(train_path + '.src')
        train_tgt = IndexedInMemoryDataset(train_path + '.tgt')

        train_data = onmt.Dataset(train_src, train_tgt,
                                  opt.batch_size_words,
                                  data_type=opt.encoder_type,
                                  batch_size_sents=opt.batch_size_sents,
                                  multiplier=opt.batch_size_multiplier)

        valid_path = opt.data + '.valid'
        valid_src = IndexedInMemoryDataset(valid_path + '.src')
        valid_tgt = IndexedInMemoryDataset(valid_path + '.tgt')

        valid_data = onmt.Dataset(valid_src, valid_tgt,
                                  opt.batch_size_words,
                                  data_type=opt.encoder_type,
                                  batch_size_sents=opt.batch_size_sents)
    else:
        raise NotImplementedError

    additional_data = []
    if opt.additional_data != "none":
        add_data = opt.additional_data.split(";")
        add_format = opt.additional_data_format.split(";")
        assert len(add_data) == len(add_format)

        for i in range(len(add_data)):
            if add_format[i] == 'raw':
                if add_data[i].endswith(".train.pt"):
                    print("Loading data from '%s'" % add_data[i])
                    add_dataset = torch.load(add_data[i])
                else:
                    print("Loading data from %s" % add_data[i] + ".train.pt")
                    add_dataset = torch.load(add_data[i] + ".train.pt")

                additional_data.append(onmt.Dataset(add_dataset['train']['src'],
                                                    add_dataset['train']['tgt'],
                                                    opt.batch_size_words,
                                                    data_type=add_dataset.get("type", "text"),
                                                    batch_size_sents=opt.batch_size_sents,
                                                    multiplier=opt.batch_size_multiplier,
                                                    reshape_speech=opt.reshape_speech,
                                                    augment=opt.augment_speech))

                add_dicts = add_dataset['dicts']
                for d in ['src', 'tgt']:
                    if d in dicts:
                        if d in add_dicts:
                            assert dicts[d].size() == add_dicts[d].size()
                    else:
                        if d in add_dicts:
                            dicts[d] = add_dicts[d]

            elif add_format[i] == 'bin':
                from onmt.data_utils.IndexedDataset import IndexedInMemoryDataset

                train_path = add_data[i] + '.train'
                train_src = IndexedInMemoryDataset(train_path + '.src')
                train_tgt = IndexedInMemoryDataset(train_path + '.tgt')

                additional_data.append(onmt.Dataset(train_src, train_tgt,
                                                    opt.batch_size_words,
                                                    data_type=opt.encoder_type,
                                                    batch_size_sents=opt.batch_size_sents,
                                                    multiplier=opt.batch_size_multiplier))

    # Restore from checkpoint
    if opt.load_from:
        checkpoint = torch.load(opt.load_from, map_location=lambda storage, loc: storage)
        print("* Loading dictionaries from the checkpoint")
        dicts = checkpoint['dicts']
    else:
        dicts['tgt'].patch(opt.patch_vocab_multiplier)
        checkpoint = None

    if "src" in dicts:
        print(' * vocabulary size. source = %d; target = %d' %
              (dicts['src'].size(), dicts['tgt'].size()))
    else:
        print(' * vocabulary size. target = %d' % (dicts['tgt'].size()))

    print('Building model...')

    if not opt.fusion:
        model = build_model(opt, dicts)

        """ Building the loss function """
        if opt.ctc_loss != 0:
            loss_function = NMTAndCTCLossFunc(dicts['tgt'].size(),
                                              label_smoothing=opt.label_smoothing,
                                              ctc_weight=opt.ctc_loss)
        else:
            loss_function = NMTLossFunc(dicts['tgt'].size(),
                                        label_smoothing=opt.label_smoothing)
    else:
        from onmt.ModelConstructor import build_fusion
        from onmt.modules.Loss import FusionLoss

        model = build_fusion(opt, dicts)
        loss_function = FusionLoss(dicts['tgt'].size(), label_smoothing=opt.label_smoothing)

    n_params = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % n_params)

    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        raise NotImplementedError("Warning! Multi-GPU training is not fully tested and potential bugs can happen.")
    else:
        # if opt.fp16:
        #     trainer = FP16XETrainer(model, loss_function, train_data, valid_data, dicts, opt)
        # else:
        trainer = XETrainer(model, loss_function, train_data, valid_data, dicts, opt)

    if len(additional_data) > 0:
        trainer.add_additional_data(additional_data, opt.data_ratio)

    trainer.run(checkpoint=checkpoint)
def main():
    start = time.time()
    print("Loading data from '%s'" % opt.data)

    if opt.data_format == 'raw':
        dataset = torch.load(opt.data)
        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse)

        trainData = onmt.Dataset(dataset['train']['src'], dataset['train']['tgt'],
                                 opt.batch_size_words, opt.gpus,
                                 max_seq_num=opt.batch_size_sents,
                                 pad_count=opt.pad_count,
                                 multiplier=opt.batch_size_multiplier,
                                 sort_by_target=opt.sort_by_target)
        validData = onmt.Dataset(dataset['valid']['src'], dataset['valid']['tgt'],
                                 opt.batch_size_words, opt.gpus,
                                 max_seq_num=opt.batch_size_sents)

        dicts = dataset['dicts']
        print(' * vocabulary size. source = %d; target = %d' %
              (dicts['src'].size(), dicts['tgt'].size()))
        print(' * number of training sentences. %d' % len(dataset['train']['src']))
        print(' * maximum batch size (words per batch). %d' % opt.batch_size_words)

    elif opt.data_format == 'bin':
        from onmt.data_utils.IndexedDataset import IndexedInMemoryDataset

        dicts = torch.load(opt.data + ".dict.pt")

        #~ train = {}
        train_path = opt.data + '.train'
        train_src = IndexedInMemoryDataset(train_path + '.src')
        train_tgt = IndexedInMemoryDataset(train_path + '.tgt')

        trainData = onmt.Dataset(train_src, train_tgt,
                                 opt.batch_size_words, opt.gpus,
                                 max_seq_num=opt.batch_size_sents,
                                 pad_count=opt.pad_count,
                                 multiplier=opt.batch_size_multiplier,
                                 sort_by_target=opt.sort_by_target)

        valid_path = opt.data + '.valid'
        valid_src = IndexedInMemoryDataset(valid_path + '.src')
        valid_tgt = IndexedInMemoryDataset(valid_path + '.tgt')

        validData = onmt.Dataset(valid_src, valid_tgt,
                                 opt.batch_size_words, opt.gpus,
                                 max_seq_num=opt.batch_size_sents)
    else:
        raise NotImplementedError

    print('Building model...')
    model = build_model(opt, dicts)

    """ Building the loss function """
    loss_function = NMTLossFunc(dicts['tgt'].size(), label_smoothing=opt.label_smoothing)

    nParams = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % nParams)

    optim = None

    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        #~ trainer = MultiGPUXETrainer(model, loss_function, trainData, validData, dataset, opt)
        raise NotImplementedError("Warning! Multi-GPU training is not fully tested and potential bugs can happen.")
    else:
        if opt.fp16:
            trainer = FP16XETrainer(model, loss_function, trainData, validData, dicts, opt)
        else:
            trainer = XETrainer(model, loss_function, trainData, validData, dicts, opt)

    trainer.run(save_file=opt.load_from)
def main():
    if not opt.multi_dataset:
        if opt.data_format in ['bin', 'raw']:
            start = time.time()

            if opt.data.endswith(".train.pt"):
                print("Loading data from '%s'" % opt.data)
                dataset = torch.load(opt.data)
            else:
                print("Loading data from %s" % opt.data + ".train.pt")
                dataset = torch.load(opt.data + ".train.pt")

            elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
            print("Done after %s" % elapse)

            dicts = dataset['dicts']

            # For backward compatibility
            train_dict = defaultdict(lambda: None, dataset['train'])
            valid_dict = defaultdict(lambda: None, dataset['valid'])

            if train_dict['src_lang'] is not None:
                assert 'langs' in dicts
                train_src_langs = train_dict['src_lang']
                train_tgt_langs = train_dict['tgt_lang']
            else:
                # allocate new languages
                dicts['langs'] = {'src': 0, 'tgt': 1}
                train_src_langs = list()
                train_tgt_langs = list()
                # Allocate one language id for the bilingual case
                train_src_langs.append(torch.Tensor([dicts['langs']['src']]))
                train_tgt_langs.append(torch.Tensor([dicts['langs']['tgt']]))

            if not opt.streaming:
                train_data = onmt.Dataset(numpy_to_torch(train_dict['src']), numpy_to_torch(train_dict['tgt']),
                                          train_dict['src_sizes'], train_dict['tgt_sizes'],
                                          train_src_langs, train_tgt_langs,
                                          batch_size_words=opt.batch_size_words,
                                          data_type=dataset.get("type", "text"), sorting=True,
                                          batch_size_sents=opt.batch_size_sents,
                                          multiplier=opt.batch_size_multiplier,
                                          augment=opt.augment_speech,
                                          upsampling=opt.upsampling,
                                          num_split=len(opt.gpus))
            else:
                train_data = onmt.StreamDataset(train_dict['src'], train_dict['tgt'],
                                                train_src_langs, train_tgt_langs,
                                                batch_size_words=opt.batch_size_words,
                                                data_type=dataset.get("type", "text"), sorting=True,
                                                batch_size_sents=opt.batch_size_sents,
                                                multiplier=opt.batch_size_multiplier,
                                                augment=opt.augment_speech,
                                                upsampling=opt.upsampling)

            if valid_dict['src_lang'] is not None:
                assert 'langs' in dicts
                valid_src_langs = valid_dict['src_lang']
                valid_tgt_langs = valid_dict['tgt_lang']
            else:
                # allocate new languages
                valid_src_langs = list()
                valid_tgt_langs = list()
                # Allocate one language id for the bilingual case
                valid_src_langs.append(torch.Tensor([dicts['langs']['src']]))
                valid_tgt_langs.append(torch.Tensor([dicts['langs']['tgt']]))

            if not opt.streaming:
                valid_data = onmt.Dataset(numpy_to_torch(valid_dict['src']), numpy_to_torch(valid_dict['tgt']),
                                          valid_dict['src_sizes'], valid_dict['tgt_sizes'],
                                          valid_src_langs, valid_tgt_langs,
                                          batch_size_words=opt.batch_size_words,
                                          data_type=dataset.get("type", "text"), sorting=True,
                                          batch_size_sents=opt.batch_size_sents,
                                          upsampling=opt.upsampling,
                                          num_split=len(opt.gpus))
            else:
                valid_data = onmt.StreamDataset(numpy_to_torch(valid_dict['src']), numpy_to_torch(valid_dict['tgt']),
                                                valid_src_langs, valid_tgt_langs,
                                                batch_size_words=opt.batch_size_words,
                                                data_type=dataset.get("type", "text"), sorting=True,
                                                batch_size_sents=opt.batch_size_sents,
                                                upsampling=opt.upsampling)

            print(' * number of training sentences. %d' % len(dataset['train']['src']))
            print(' * maximum batch size (words per batch). %d' % opt.batch_size_words)

        elif opt.data_format in ['scp', 'scpmem', 'mmem']:
            print("Loading memory mapped data files ....")
            start = time.time()
            from onmt.data.mmap_indexed_dataset import MMapIndexedDataset
            from onmt.data.scp_dataset import SCPIndexDataset

            dicts = torch.load(opt.data + ".dict.pt")
            if opt.data_format in ['scp', 'scpmem']:
                audio_data = torch.load(opt.data + ".scp_path.pt")

            # allocate languages if not present
            if 'langs' not in dicts:
                dicts['langs'] = {'src': 0, 'tgt': 1}
            else:
                print(dicts['langs'])

            train_path = opt.data + '.train'
            if opt.data_format in ['scp', 'scpmem']:
                train_src = SCPIndexDataset(audio_data['train'], concat=opt.concat)
            else:
                train_src = MMapIndexedDataset(train_path + '.src')

            train_tgt = MMapIndexedDataset(train_path + '.tgt')

            # check the lang files if they exist (in the case of multi-lingual models)
            if os.path.exists(train_path + '.src_lang.bin'):
                assert 'langs' in dicts
                train_src_langs = MMapIndexedDataset(train_path + '.src_lang')
                train_tgt_langs = MMapIndexedDataset(train_path + '.tgt_lang')
            else:
                train_src_langs = list()
                train_tgt_langs = list()
                # Allocate a Tensor(1) for the bilingual case
                train_src_langs.append(torch.Tensor([dicts['langs']['src']]))
                train_tgt_langs.append(torch.Tensor([dicts['langs']['tgt']]))

            # check the length files if they exist
            if os.path.exists(train_path + '.src_sizes.npy'):
                train_src_sizes = np.load(train_path + '.src_sizes.npy')
                train_tgt_sizes = np.load(train_path + '.tgt_sizes.npy')
            else:
                train_src_sizes, train_tgt_sizes = None, None

            if opt.encoder_type == 'audio':
                data_type = 'audio'
            else:
                data_type = 'text'

            if not opt.streaming:
                train_data = onmt.Dataset(train_src, train_tgt,
                                          train_src_sizes, train_tgt_sizes,
                                          train_src_langs, train_tgt_langs,
                                          batch_size_words=opt.batch_size_words,
                                          data_type=data_type, sorting=True,
                                          batch_size_sents=opt.batch_size_sents,
                                          multiplier=opt.batch_size_multiplier,
                                          src_align_right=opt.src_align_right,
                                          augment=opt.augment_speech,
                                          upsampling=opt.upsampling,
                                          cleaning=True, verbose=True,
                                          num_split=len(opt.gpus))
            else:
                train_data = onmt.StreamDataset(train_src, train_tgt,
                                                train_src_langs, train_tgt_langs,
                                                batch_size_words=opt.batch_size_words,
                                                data_type=data_type, sorting=False,
                                                batch_size_sents=opt.batch_size_sents,
                                                multiplier=opt.batch_size_multiplier,
                                                upsampling=opt.upsampling)

            valid_path = opt.data + '.valid'
            if opt.data_format in ['scp', 'scpmem']:
                valid_src = SCPIndexDataset(audio_data['valid'], concat=opt.concat)
            else:
                valid_src = MMapIndexedDataset(valid_path + '.src')
            valid_tgt = MMapIndexedDataset(valid_path + '.tgt')

            if os.path.exists(valid_path + '.src_lang.bin'):
                assert 'langs' in dicts
                valid_src_langs = MMapIndexedDataset(valid_path + '.src_lang')
                valid_tgt_langs = MMapIndexedDataset(valid_path + '.tgt_lang')
            else:
                valid_src_langs = list()
                valid_tgt_langs = list()
                # Allocate one language id for the bilingual case
                valid_src_langs.append(torch.Tensor([dicts['langs']['src']]))
                valid_tgt_langs.append(torch.Tensor([dicts['langs']['tgt']]))

            # check the length files if they exist
            if os.path.exists(valid_path + '.src_sizes.npy'):
                valid_src_sizes = np.load(valid_path + '.src_sizes.npy')
                valid_tgt_sizes = np.load(valid_path + '.tgt_sizes.npy')
            else:
                valid_src_sizes, valid_tgt_sizes = None, None

            if not opt.streaming:
                valid_data = onmt.Dataset(valid_src, valid_tgt,
                                          valid_src_sizes, valid_tgt_sizes,
                                          valid_src_langs, valid_tgt_langs,
                                          batch_size_words=opt.batch_size_words,
                                          data_type=data_type, sorting=True,
                                          batch_size_sents=opt.batch_size_sents,
                                          src_align_right=opt.src_align_right,
                                          cleaning=True, verbose=True, debug=True,
                                          num_split=len(opt.gpus))
            else:
                # for validation data, we have to go through sentences (very slow but to ensure correctness)
                valid_data = onmt.StreamDataset(valid_src, valid_tgt,
                                                valid_src_langs, valid_tgt_langs,
                                                batch_size_words=opt.batch_size_words,
                                                data_type=data_type, sorting=True,
                                                batch_size_sents=opt.batch_size_sents)

            elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
            print("Done after %s" % elapse)
        else:
            raise NotImplementedError

        print(' * number of sentences in training data: %d' % train_data.size())
        print(' * number of sentences in validation data: %d' % valid_data.size())

    else:
        print("[INFO] Reading multiple dataset ...")
        # raise NotImplementedError

        dicts = torch.load(opt.data + ".dict.pt")

        root_dir = os.path.dirname(opt.data)

        print("Loading training data ...")

        train_dirs, valid_dirs = dict(), dict()

        # scan the data directory to find the training data
        for dir_ in os.listdir(root_dir):
            if os.path.isdir(os.path.join(root_dir, dir_)):
                if str(dir_).startswith("train"):
                    idx = int(dir_.split(".")[1])
                    train_dirs[idx] = dir_
                if dir_.startswith("valid"):
                    idx = int(dir_.split(".")[1])
                    valid_dirs[idx] = dir_

        train_sets, valid_sets = list(), list()

        for (idx_, dir_) in sorted(train_dirs.items()):

            data_dir = os.path.join(root_dir, dir_)
            print("[INFO] Loading training data %i from %s" % (idx_, dir_))

            if opt.data_format in ['bin', 'raw']:
                raise NotImplementedError

            elif opt.data_format in ['scp', 'scpmem', 'mmem']:
                from onmt.data.mmap_indexed_dataset import MMapIndexedDataset
                from onmt.data.scp_dataset import SCPIndexDataset

                if opt.data_format in ['scp', 'scpmem']:
                    audio_data = torch.load(os.path.join(data_dir, "data.scp_path.pt"))
                    src_data = SCPIndexDataset(audio_data, concat=opt.concat)
                else:
                    src_data = MMapIndexedDataset(os.path.join(data_dir, "data.src"))

                tgt_data = MMapIndexedDataset(os.path.join(data_dir, "data.tgt"))

                src_lang_data = MMapIndexedDataset(os.path.join(data_dir, 'data.src_lang'))
                tgt_lang_data = MMapIndexedDataset(os.path.join(data_dir, 'data.tgt_lang'))

                if os.path.exists(os.path.join(data_dir, 'data.src_sizes.npy')):
                    src_sizes = np.load(os.path.join(data_dir, 'data.src_sizes.npy'))
                    tgt_sizes = np.load(os.path.join(data_dir, 'data.tgt_sizes.npy'))
                else:
                    src_sizes, tgt_sizes = None, None

                if opt.encoder_type == 'audio':
                    data_type = 'audio'
                else:
                    data_type = 'text'

                if not opt.streaming:
                    train_data = onmt.Dataset(src_data, tgt_data,
                                              src_sizes, tgt_sizes,
                                              src_lang_data, tgt_lang_data,
                                              batch_size_words=opt.batch_size_words,
                                              data_type=data_type, sorting=True,
                                              batch_size_sents=opt.batch_size_sents,
                                              multiplier=opt.batch_size_multiplier,
                                              src_align_right=opt.src_align_right,
                                              augment=opt.augment_speech,
                                              upsampling=opt.upsampling,
                                              cleaning=True, verbose=True,
                                              num_split=len(opt.gpus))

                    train_sets.append(train_data)
                else:
                    print("Multi-dataset not implemented for Streaming tasks.")
                    raise NotImplementedError

        for (idx_, dir_) in sorted(valid_dirs.items()):

            data_dir = os.path.join(root_dir, dir_)
            print("[INFO] Loading validation data %i from %s" % (idx_, dir_))

            if opt.data_format in ['bin', 'raw']:
                raise NotImplementedError

            elif opt.data_format in ['scp', 'scpmem', 'mmem']:

                if opt.data_format in ['scp', 'scpmem']:
                    audio_data = torch.load(os.path.join(data_dir, "data.scp_path.pt"))
                    src_data = SCPIndexDataset(audio_data, concat=opt.concat)
                else:
                    src_data = MMapIndexedDataset(os.path.join(data_dir, "data.src"))

                tgt_data = MMapIndexedDataset(os.path.join(data_dir, "data.tgt"))

                src_lang_data = MMapIndexedDataset(os.path.join(data_dir, 'data.src_lang'))
                tgt_lang_data = MMapIndexedDataset(os.path.join(data_dir, 'data.tgt_lang'))

                if os.path.exists(os.path.join(data_dir, 'data.src_sizes.npy')):
                    src_sizes = np.load(os.path.join(data_dir, 'data.src_sizes.npy'))
                    tgt_sizes = np.load(os.path.join(data_dir, 'data.tgt_sizes.npy'))
                else:
                    src_sizes, tgt_sizes = None, None

                if opt.encoder_type == 'audio':
                    data_type = 'audio'
                else:
                    data_type = 'text'

                if not opt.streaming:
                    valid_data = onmt.Dataset(src_data, tgt_data,
                                              src_sizes, tgt_sizes,
                                              src_lang_data, tgt_lang_data,
                                              batch_size_words=opt.batch_size_words,
                                              data_type=data_type, sorting=True,
                                              batch_size_sents=opt.batch_size_sents,
                                              src_align_right=opt.src_align_right,
                                              cleaning=True, verbose=True, debug=True,
                                              num_split=len(opt.gpus))

                    valid_sets.append(valid_data)
                else:
                    raise NotImplementedError

        train_data = train_sets
        valid_data = valid_sets

    if opt.load_from:
        checkpoint = torch.load(opt.load_from, map_location=lambda storage, loc: storage)
        print("* Loading dictionaries from the checkpoint")
        dicts = checkpoint['dicts']
    else:
        dicts['tgt'].patch(opt.patch_vocab_multiplier)
        checkpoint = None

    # Put the vocab mask from dicts to the datasets
    for data in [train_data, valid_data]:
        if isinstance(data, list):
            for i, data_ in enumerate(data):
                data_.set_mask(dicts['tgt'].vocab_mask)
                data[i] = data_
        else:
            data.set_mask(dicts['tgt'].vocab_mask)

    if "src" in dicts:
        print(' * vocabulary size. source = %d; target = %d' %
              (dicts['src'].size(), dicts['tgt'].size()))
    else:
        print('[INFO] vocabulary size. target = %d' % (dicts['tgt'].size()))

    print('* Building model...')

    if not opt.fusion:
        if opt.bayes_by_backprop:
            model = build_bayesian_model(opt, dicts)
        else:
            model = build_model(opt, dicts)

        """ Building the loss function """
        # if opt.ctc_loss != 0:
        #     loss_function = NMTAndCTCLossFunc(dicts['tgt'].size(),
        #                                       label_smoothing=opt.label_smoothing,
        #                                       ctc_weight=opt.ctc_loss)
        if opt.nce:
            from onmt.modules.nce.nce_loss import NCELoss
            loss_function = NCELoss(opt.model_size, dicts['tgt'].size(),
                                    noise_ratio=opt.nce_noise,
                                    logz=9, label_smoothing=opt.label_smoothing)
        else:
            loss_function = NMTLossFunc(opt.model_size, dicts['tgt'].size(),
                                        label_smoothing=opt.label_smoothing,
                                        mirror=opt.mirror_loss,
                                        fast_xentropy=opt.fast_xentropy)

        # This function replaces modules with the more optimized counterparts so that it can run faster
        # Currently exp with LayerNorm
        if not opt.memory_profiling:
            optimize_model(model, fp16=opt.fp16)
    else:
        from onmt.model_factory import build_fusion
        from onmt.modules.loss import FusionLoss

        model = build_fusion(opt, dicts)
        loss_function = FusionLoss(dicts['tgt'].size(), label_smoothing=opt.label_smoothing)

    n_params = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % n_params)

    if not opt.debugging and len(opt.gpus) == 1:
        if opt.bayes_by_backprop:
            from onmt.train_utils.bayes_by_backprop_trainer import BayesianTrainer
            trainer = BayesianTrainer(model, loss_function, train_data, valid_data, dicts, opt)
        else:
            trainer = XETrainer(model, loss_function, train_data, valid_data, dicts, opt)
    else:
        from onmt.train_utils.new_trainer import Trainer
        trainer = Trainer(model, loss_function, train_data, valid_data, dicts, opt)

    trainer.run(checkpoint=checkpoint)
def main():
    start = time.time()
    print("Loading data from '%s'" % opt.data)
    dataset = torch.load(opt.data)
    elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
    print("Done after %s" % elapse)

    #~ dict_checkpoint = opt.load_from
    #~ if dict_checkpoint:
    #~     print('Loading dicts from checkpoint at %s' % dict_checkpoint)
    #~     checkpoint = torch.load(dict_checkpoint, map_location=lambda storage, loc: storage)
    #~     dataset['dicts'] = checkpoint['dicts']
    #~ else:
    #~     checkpoint = None

    trainData = onmt.Dataset(dataset['train']['src'], dataset['train']['tgt'],
                             opt.batch_size_words, opt.gpus,
                             data_type=dataset.get("type", "text"),
                             max_seq_num=opt.batch_size_sents)
    validData = onmt.Dataset(dataset['valid']['src'], dataset['valid']['tgt'],
                             opt.batch_size_words, opt.gpus, volatile=True,
                             data_type=dataset.get("type", "text"),
                             max_seq_num=opt.batch_size_sents)

    dicts = dataset['dicts']
    print(' * vocabulary size. source = %d; target = %d' %
          (dicts['src'].size(), dicts['tgt'].size()))
    print(' * number of training sentences. %d' % len(dataset['train']['src']))
    print(' * maximum batch size (words per batch). %d' % opt.batch_size_words)

    print('Building model...')
    model = build_model(opt, dicts)

    """ Building the loss function """
    loss_function = NMTLossFunc(dataset['dicts']['tgt'].size(),
                                label_smoothing=opt.label_smoothing,
                                shard_size=opt.max_generator_batches)

    #~ print(model)
    #~ print(loss_function)
    nParams = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % nParams)

    optim = None

    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        trainer = MultiGPUXETrainer(model, loss_function, trainData, validData, dataset, opt)
        print("Warning! Multi-GPU training is used. Not fully tested and potential bugs can happen.")
    else:
        trainer = XETrainer(model, loss_function, trainData, validData, dataset, opt)

    trainer.run(save_file=opt.load_from)
def main():
    if opt.data_format == 'raw':
        start = time.time()
        if opt.data.endswith(".train.pt"):
            print("Loading data from '%s'" % opt.data)
            dataset = torch.load(opt.data)
        else:
            print("Loading data from %s" % opt.data + ".train.pt")
            dataset = torch.load(opt.data + ".train.pt")

        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse)

        # For backward compatibility
        train_dict = defaultdict(lambda: None, dataset['train'])
        valid_dict = defaultdict(lambda: None, dataset['valid'])

        train_data = onmt.Dataset(train_dict['src'], train_dict['tgt'],
                                  train_dict['src_atbs'], train_dict['tgt_atbs'],
                                  batch_size_words=opt.batch_size_words,
                                  data_type=dataset.get("type", "text"),
                                  batch_size_sents=opt.batch_size_sents,
                                  multiplier=opt.batch_size_multiplier,
                                  augment=opt.augment_speech,
                                  upsampling=opt.upsampling)
        valid_data = onmt.Dataset(valid_dict['src'], valid_dict['tgt'],
                                  valid_dict['src_atbs'], valid_dict['tgt_atbs'],
                                  batch_size_words=opt.batch_size_words,
                                  data_type=dataset.get("type", "text"),
                                  batch_size_sents=opt.batch_size_sents,
                                  upsampling=opt.upsampling)

        dicts = dataset['dicts']
        print(' * number of training sentences. %d' % len(dataset['train']['src']))
        print(' * maximum batch size (words per batch). %d' % opt.batch_size_words)

    elif opt.data_format == 'bin':
        print("Loading memory binned data files ....")
        start = time.time()
        from onmt.data_utils.IndexedDataset import IndexedInMemoryDataset

        dicts = torch.load(opt.data + ".dict.pt")

        train_path = opt.data + '.train'
        train_src = IndexedInMemoryDataset(train_path + '.src')
        train_tgt = IndexedInMemoryDataset(train_path + '.tgt')

        train_data = onmt.Dataset(train_src, train_tgt,
                                  batch_size_words=opt.batch_size_words,
                                  data_type="text",
                                  batch_size_sents=opt.batch_size_sents,
                                  multiplier=opt.batch_size_multiplier)

        valid_path = opt.data + '.valid'
        valid_src = IndexedInMemoryDataset(valid_path + '.src')
        valid_tgt = IndexedInMemoryDataset(valid_path + '.tgt')

        valid_data = onmt.Dataset(valid_src, valid_tgt,
                                  batch_size_words=opt.batch_size_words,
                                  data_type="text",
                                  batch_size_sents=opt.batch_size_sents)

        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse)

    elif opt.data_format == 'mmem':
        print("Loading memory mapped data files ....")
        start = time.time()
        from onmt.data_utils.MMapIndexedDataset import MMapIndexedDataset

        dicts = torch.load(opt.data + ".dict.pt")

        train_path = opt.data + '.train'
        train_src = MMapIndexedDataset(train_path + '.src')
        train_tgt = MMapIndexedDataset(train_path + '.tgt')

        train_data = onmt.Dataset(train_src, train_tgt,
                                  batch_size_words=opt.batch_size_words,
                                  data_type="text",
                                  batch_size_sents=opt.batch_size_sents,
                                  multiplier=opt.batch_size_multiplier)

        valid_path = opt.data + '.valid'
        valid_src = MMapIndexedDataset(valid_path + '.src')
        valid_tgt = MMapIndexedDataset(valid_path + '.tgt')

        valid_data = onmt.Dataset(valid_src, valid_tgt,
                                  batch_size_words=opt.batch_size_words,
                                  data_type="text",
                                  batch_size_sents=opt.batch_size_sents)

        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse)
    else:
        raise NotImplementedError

    additional_data = []
    if opt.additional_data != "none":
        add_data = opt.additional_data.split(";")
        add_format = opt.additional_data_format.split(";")
        assert len(add_data) == len(add_format)

        for i in range(len(add_data)):
            if add_format[i] == 'raw':
                if add_data[i].endswith(".train.pt"):
                    print("Loading data from '%s'" % add_data[i])
                    add_dataset = torch.load(add_data[i])
                else:
                    print("Loading data from %s" % add_data[i] + ".train.pt")
                    add_dataset = torch.load(add_data[i] + ".train.pt")

                additional_data.append(onmt.Dataset(add_dataset['train']['src'],
                                                    add_dataset['train']['tgt'],
                                                    batch_size_words=opt.batch_size_words,
                                                    data_type=dataset.get("type", "text"),
                                                    batch_size_sents=opt.batch_size_sents,
                                                    multiplier=opt.batch_size_multiplier,
                                                    reshape_speech=opt.reshape_speech,
                                                    augment=opt.augment_speech))
            elif add_format[i] == 'bin':
                from onmt.data_utils.IndexedDataset import IndexedInMemoryDataset

                train_path = add_data[i] + '.train'
                train_src = IndexedInMemoryDataset(train_path + '.src')
                train_tgt = IndexedInMemoryDataset(train_path + '.tgt')

                additional_data.append(onmt.Dataset(train_src, train_tgt,
                                                    batch_size_words=opt.batch_size_words,
                                                    data_type=opt.encoder_type,
                                                    batch_size_sents=opt.batch_size_sents,
                                                    multiplier=opt.batch_size_multiplier))

    if opt.load_from:
        checkpoint = torch.load(opt.load_from, map_location=lambda storage, loc: storage)
        print("* Loading dictionaries from the checkpoint")
        dicts = checkpoint['dicts']
    else:
        dicts['tgt'].patch(opt.patch_vocab_multiplier)
        checkpoint = None

    if "src" in dicts:
        print(' * vocabulary size. source = %d; target = %d' %
              (dicts['src'].size(), dicts['tgt'].size()))
    else:
        print(' * vocabulary size. target = %d' % (dicts['tgt'].size()))

    print('Building model...')

    if not opt.fusion:
        if opt.bert_scalar and opt.finetune_bert:
            print("WARNING: only BERT is fine-tuned; the scalar parameters are not, "
                  "so please set opt.bert_scalar to False")
        print("Using scalar-mixed bert vector: ", opt.bert_scalar)
        print("Fine-tuning Bert+Transformer: ", opt.finetune_bert)

        model = build_model(opt, dicts)

        if not opt.finetune_bert:
            for param in model.bert.parameters():
                param.requires_grad = False

        if not opt.finetune_bert and opt.bert_scalar:
            scalar_mix = ScalarMix(onmt.Constants.BERT_LAYERS,
                                   do_layer_norm=False,
                                   initial_scalar_parameters=None,
                                   trainable=True)
            model.add_module("scalar_mix", scalar_mix)

        print(model)

        # for name, param in model.bert_model.named_parameters():
        #     print(name, param, param.requires_grad)
        #     # the params in bert_model which require gradient:
        #     if param.requires_grad:
        #         print(name)

        """ Building the loss function """
        if opt.ctc_loss != 0:
            loss_function = NMTAndCTCLossFunc(dicts['tgt'].size(),
                                              label_smoothing=opt.label_smoothing,
                                              ctc_weight=opt.ctc_loss)
        else:
            loss_function = NMTLossFunc(dicts['tgt'].size(),
                                        label_smoothing=opt.label_smoothing)
    else:
        from onmt.ModelConstructor import build_fusion
        from onmt.modules.Loss import FusionLoss

        model = build_fusion(opt, dicts)
        loss_function = FusionLoss(dicts['tgt'].size(), label_smoothing=opt.label_smoothing)

    n_params = sum([p.nelement() for p in model.parameters()])
    print('* number of all parameters: %d' % n_params)

    n_params_grad = sum([p.nelement() for p in model.parameters() if p.requires_grad])
    print('* number of all parameters that need gradient: %d' % n_params_grad)

    n_params_nograd = sum([p.nelement() for p in model.parameters() if not p.requires_grad])
    print('* number of all parameters that do not need gradient: %d' % n_params_nograd)

    assert n_params == (n_params_grad + n_params_nograd)

    # print(model)
    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        raise NotImplementedError("Warning! Multi-GPU training is not fully tested and potential bugs can happen.")
    else:
        trainer = XETrainer(model, loss_function, train_data, valid_data, dicts, opt, setup_optimizer=True)
        if len(additional_data) > 0:
            trainer.add_additional_data(additional_data, opt.data_ratio)

    trainer.run(checkpoint=checkpoint)
def main():
    if opt.data_format in ['bin', 'raw']:
        start = time.time()

        if opt.data.endswith(".train.pt"):
            print("Loading data from '%s'" % opt.data)
            dataset = torch.load(opt.data)
        else:
            print("Loading data from %s" % opt.data + ".train.pt")
            dataset = torch.load(opt.data + ".train.pt")

        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse)

        dicts = dataset['dicts']

        # For backward compatibility
        train_dict = defaultdict(lambda: None, dataset['train'])
        valid_dict = defaultdict(lambda: None, dataset['valid'])

        if train_dict['src_lang'] is not None:
            assert 'langs' in dicts
            train_src_langs = train_dict['src_lang']
            train_tgt_langs = train_dict['tgt_lang']
        else:
            # allocate new languages
            dicts['langs'] = {'src': 0, 'tgt': 1}
            train_src_langs = list()
            train_tgt_langs = list()
            # Allocate one language id for the bilingual case
            train_src_langs.append(torch.Tensor([dicts['langs']['src']]))
            train_tgt_langs.append(torch.Tensor([dicts['langs']['tgt']]))

        if not opt.streaming:
            train_data = onmt.Dataset(train_dict['src'], train_dict['tgt'],
                                      train_src_langs, train_tgt_langs,
                                      batch_size_words=opt.batch_size_words,
                                      data_type=dataset.get("type", "text"), sorting=True,
                                      batch_size_sents=opt.batch_size_sents,
                                      multiplier=opt.batch_size_multiplier,
                                      augment=opt.augment_speech,
                                      upsampling=opt.upsampling)
        else:
            train_data = onmt.StreamDataset(train_dict['src'], train_dict['tgt'],
                                            train_src_langs, train_tgt_langs,
                                            batch_size_words=opt.batch_size_words,
                                            data_type=dataset.get("type", "text"), sorting=True,
                                            batch_size_sents=opt.batch_size_sents,
                                            multiplier=opt.batch_size_multiplier,
                                            augment=opt.augment_speech,
                                            upsampling=opt.upsampling)

        if valid_dict['src_lang'] is not None:
            assert 'langs' in dicts
            valid_src_langs = valid_dict['src_lang']
            valid_tgt_langs = valid_dict['tgt_lang']
        else:
            # allocate new languages
            valid_src_langs = list()
            valid_tgt_langs = list()
            # Allocate one language id for the bilingual case
            valid_src_langs.append(torch.Tensor([dicts['langs']['src']]))
            valid_tgt_langs.append(torch.Tensor([dicts['langs']['tgt']]))

        if not opt.streaming:
            valid_data = onmt.Dataset(valid_dict['src'], valid_dict['tgt'],
                                      valid_src_langs, valid_tgt_langs,
                                      batch_size_words=opt.batch_size_words,
                                      data_type=dataset.get("type", "text"), sorting=True,
                                      batch_size_sents=opt.batch_size_sents,
                                      upsampling=opt.upsampling)
        else:
            valid_data = onmt.StreamDataset(valid_dict['src'], valid_dict['tgt'],
                                            valid_src_langs, valid_tgt_langs,
                                            batch_size_words=opt.batch_size_words,
                                            data_type=dataset.get("type", "text"), sorting=True,
                                            batch_size_sents=opt.batch_size_sents,
                                            upsampling=opt.upsampling)

        print(' * number of training sentences. %d' % len(dataset['train']['src']))
        print(' * maximum batch size (words per batch). %d' % opt.batch_size_words)

    elif opt.data_format == 'mmem':
        print("Loading memory mapped data files ....")
        start = time.time()
        from onmt.data.mmap_indexed_dataset import MMapIndexedDataset

        dicts = torch.load(opt.data + ".dict.pt")

        # allocate languages if not present
        if 'langs' not in dicts:
            dicts['langs'] = {'src': 0, 'tgt': 1}
        else:
            print(dicts['langs'])

        train_path = opt.data + '.train'
        train_src = MMapIndexedDataset(train_path + '.src')
        train_tgt = MMapIndexedDataset(train_path + '.tgt')

        # check the lang files if they exist (in the case of multi-lingual models)
        if os.path.exists(train_path + '.src_lang.bin'):
            assert 'langs' in dicts
            train_src_langs = MMapIndexedDataset(train_path + '.src_lang')
            train_tgt_langs = MMapIndexedDataset(train_path + '.tgt_lang')
        else:
            train_src_langs = list()
            train_tgt_langs = list()
            # Allocate a Tensor(1) for the bilingual case
            train_src_langs.append(torch.Tensor([dicts['langs']['src']]))
            train_tgt_langs.append(torch.Tensor([dicts['langs']['tgt']]))

        if opt.encoder_type == 'audio':
            data_type = 'audio'
        else:
            data_type = 'text'

        if not opt.streaming:
            train_data = onmt.Dataset(train_src, train_tgt,
                                      train_src_langs, train_tgt_langs,
                                      batch_size_words=opt.batch_size_words,
                                      data_type=data_type, sorting=True,
                                      batch_size_sents=opt.batch_size_sents,
                                      multiplier=opt.batch_size_multiplier,
                                      src_align_right=opt.src_align_right,
                                      upsampling=opt.upsampling,
                                      cleaning=True, verbose=True)
        else:
            train_data = onmt.StreamDataset(train_src, train_tgt,
                                            train_src_langs, train_tgt_langs,
                                            batch_size_words=opt.batch_size_words,
                                            data_type=data_type, sorting=False,
                                            batch_size_sents=opt.batch_size_sents,
                                            multiplier=opt.batch_size_multiplier,
                                            upsampling=opt.upsampling)

        valid_path = opt.data + '.valid'
        valid_src = MMapIndexedDataset(valid_path + '.src')
        valid_tgt = MMapIndexedDataset(valid_path + '.tgt')

        if os.path.exists(valid_path + '.src_lang.bin'):
            assert 'langs' in dicts
            valid_src_langs = MMapIndexedDataset(valid_path + '.src_lang')
            valid_tgt_langs = MMapIndexedDataset(valid_path + '.tgt_lang')
        else:
            valid_src_langs = list()
            valid_tgt_langs = list()
            # Allocate one language id for the bilingual case
            valid_src_langs.append(torch.Tensor([dicts['langs']['src']]))
            valid_tgt_langs.append(torch.Tensor([dicts['langs']['tgt']]))

        if not opt.streaming:
            valid_data = onmt.Dataset(valid_src, valid_tgt,
                                      valid_src_langs, valid_tgt_langs,
                                      batch_size_words=opt.batch_size_words,
                                      data_type="text", sorting=False,
                                      batch_size_sents=opt.batch_size_sents,
                                      src_align_right=opt.src_align_right,
                                      cleaning=True, verbose=True)
        else:
            # for validation data, we have to go through sentences (very slow but to ensure correctness)
            valid_data = onmt.StreamDataset(valid_src, valid_tgt,
                                            valid_src_langs, valid_tgt_langs,
                                            batch_size_words=opt.batch_size_words,
                                            data_type="text", sorting=True,
                                            batch_size_sents=opt.batch_size_sents)

        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse)
    else:
        raise NotImplementedError

    # additional_data = []
    # if opt.additional_data != "none":
    #     add_data = opt.additional_data.split(";")
    #     add_format = opt.additional_data_format.split(";")
    #     assert (len(add_data) == len(add_format))
    #     for i in range(len(add_data)):
    #         if add_format[i] == 'raw':
    #             if add_data[i].endswith(".train.pt"):
    #                 print("Loading data from '%s'" % opt.data)
    #                 add_dataset = torch.load(add_data[i])
    #             else:
    #                 print("Loading data from %s" % opt.data + ".train.pt")
    #                 add_dataset = torch.load(add_data[i] + ".train.pt")
    #
    #             additional_data.append(onmt.Dataset(add_dataset['train']['src'],
    #                                                 dataset['train']['tgt'],
    #                                                 batch_size_words=opt.batch_size_words,
    #                                                 data_type=dataset.get("type", "text"), sorting=True,
    #                                                 batch_size_sents=opt.batch_size_sents,
    #                                                 multiplier=opt.batch_size_multiplier,
    #                                                 reshape_speech=opt.reshape_speech,
    #                                                 augment=opt.augment_speech))
    #         elif add_format[i] == 'bin':
    #             from onmt.data.indexed_dataset import IndexedInMemoryDataset
    #
    #             train_path = add_data[i] + '.train'
    #             train_src = IndexedInMemoryDataset(train_path + '.src')
    #             train_tgt = IndexedInMemoryDataset(train_path + '.tgt')
    #
    #             additional_data.append(onmt.Dataset(train_src, train_tgt,
    #                                                 batch_size_words=opt.batch_size_words,
    #                                                 data_type=opt.encoder_type,
    #                                                 batch_size_sents=opt.batch_size_sents,
    #                                                 multiplier=opt.batch_size_multiplier))

    if opt.load_from:
        checkpoint = torch.load(opt.load_from, map_location=lambda storage, loc: storage)
        print("* Loading dictionaries from the checkpoint")
        dicts = checkpoint['dicts']
    else:
        dicts['tgt'].patch(opt.patch_vocab_multiplier)
        checkpoint = None

    if "src" in dicts:
        print(' * vocabulary size. source = %d; target = %d' %
              (dicts['src'].size(), dicts['tgt'].size()))
    else:
        print(' * vocabulary size. target = %d' % (dicts['tgt'].size()))

    print(' * number of sentences in training data: %d' % train_data.size())
    print(' * number of sentences in validation data: %d' % valid_data.size())

    print('* Building model...')

    if not opt.fusion:
        model = build_model(opt, dicts)

        """ Building the loss function """
        if opt.ctc_loss != 0:
            loss_function = NMTAndCTCLossFunc(dicts['tgt'].size(),
                                              label_smoothing=opt.label_smoothing,
                                              ctc_weight=opt.ctc_loss)
        else:
            loss_function = NMTLossFunc(opt.model_size, dicts['tgt'].size(),
                                        label_smoothing=opt.label_smoothing,
                                        mirror=opt.mirror_loss)

        # This function replaces modules with the more optimized counterparts so that it can run faster
        # Currently exp with LayerNorm
        optimize_model(model)
    else:
        from onmt.model_factory import build_fusion
        from onmt.modules.loss import FusionLoss

        model = build_fusion(opt, dicts)
        loss_function = FusionLoss(dicts['tgt'].size(), label_smoothing=opt.label_smoothing)

    n_params = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % n_params)

    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        raise NotImplementedError("Multi-GPU training is not supported at the moment.")
    else:
        trainer = XETrainer(model, loss_function, train_data, valid_data, dicts, opt)

    trainer.run(checkpoint=checkpoint)
def main():
    start = time.time()
    print("Loading data from '%s'" % opt.data)

    if opt.data_format == 'raw':
        dataset = torch.load(opt.data)
        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse)

        dicts = dataset['dicts']

        # For backward compatibility
        train_dict = defaultdict(lambda: None, dataset['train'])
        valid_dict = defaultdict(lambda: None, dataset['valid'])

        if train_dict['src_lang'] is not None:
            assert 'langs' in dicts
            train_src_langs = train_dict['src_lang']
            train_tgt_langs = train_dict['tgt_lang']
        else:
            # allocate new languages
            dicts['langs'] = {'src': 0, 'tgt': 1}
            train_src_langs = list()
            train_tgt_langs = list()
            # Allocate one language id for the bilingual case
            train_src_langs.append(torch.Tensor([dicts['langs']['src']]))
            train_tgt_langs.append(torch.Tensor([dicts['langs']['tgt']]))

        train_data = LanguageModelDataset(dataset['train']['tgt'], train_tgt_langs,
                                          batch_size_sents=opt.batch_size_sents,
                                          seq_length=opt.lm_seq_length)

        if valid_dict['src_lang'] is not None:
            assert 'langs' in dicts
            valid_src_langs = valid_dict['src_lang']
            valid_tgt_langs = valid_dict['tgt_lang']
        else:
            # allocate new languages
            valid_src_langs = list()
            valid_tgt_langs = list()
            # Allocate one language id for the bilingual case
            valid_src_langs.append(torch.Tensor([dicts['langs']['src']]))
            valid_tgt_langs.append(torch.Tensor([dicts['langs']['tgt']]))

        valid_data = LanguageModelDataset(dataset['valid']['tgt'], valid_tgt_langs,
                                          batch_size_sents=opt.batch_size_sents,
                                          seq_length=opt.lm_seq_length)

        if opt.load_from:
            checkpoint = torch.load(opt.load_from, map_location=lambda storage, loc: storage)
            print("* Loading dictionaries from the checkpoint")
            dicts = checkpoint['dicts']
        else:
            dicts['tgt'].patch(opt.patch_vocab_multiplier)
            checkpoint = None

        if "src" in dicts:
            print(' * vocabulary size. source = %d; target = %d' %
                  (dicts['src'].size(), dicts['tgt'].size()))
        else:
            print(' * vocabulary size. target = %d' % (dicts['tgt'].size()))
        print(' * number of training sentences. %d' % train_data.size())
        print(' * maximum batch size (words per batch). %d' % (opt.batch_size_sents * opt.lm_seq_length))
    else:
        raise NotImplementedError

    print('Building model...')
    model = build_language_model(opt, dicts)
    optimize_model(model)

    """ Building the loss function """
    loss_function = NMTLossFunc(opt.model_size, dicts['tgt'].size(),
                                label_smoothing=opt.label_smoothing)

    n_params = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % n_params)

    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        raise NotImplementedError("Multi-GPU training is not supported ATM.")
    else:
        # if opt.fp16:
        #     trainer = FP16XETrainer(model, loss_function, train_data, valid_data, dicts, opt)
        # else:
        trainer = XETrainer(model, loss_function, train_data, valid_data, dicts, opt)

    trainer.run(checkpoint=checkpoint)
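# Note: every main() variant above assumes module-level scaffolding that is not shown in
# these excerpts: the standard-library and third-party imports, an `opt` namespace produced
# by an argument parser, the onmt builder/loss/trainer helpers (build_model, build_language_model,
# NMTLossFunc, XETrainer, ...), and an entry-point guard. The sketch below is only an
# illustrative, minimal reconstruction of that boilerplate; the option names and defaults here
# are assumptions for demonstration, not the full option set defined by the real training scripts.

import argparse
import datetime
import os
import time
from collections import defaultdict

import numpy as np
import torch

import onmt  # the Dataset/StreamDataset classes and the model, loss and trainer helpers
             # are imported from submodules of this package (exact paths vary per version above)

parser = argparse.ArgumentParser(description='train.py')
# Hypothetical subset of options referenced by the snippets above.
parser.add_argument('-data', required=True, help='prefix of the preprocessed data (or a *.train.pt file)')
parser.add_argument('-data_format', default='raw', help="data format: 'raw', 'bin', 'mmem', ...")
parser.add_argument('-batch_size_words', type=int, default=2048, help='maximum words per batch')
parser.add_argument('-batch_size_sents', type=int, default=128, help='maximum sentences per batch')
parser.add_argument('-batch_size_multiplier', type=int, default=1)
parser.add_argument('-label_smoothing', type=float, default=0.0)
parser.add_argument('-load_from', default='', help='checkpoint to resume from')
parser.add_argument('-gpus', default=[], nargs='+', type=int, help='list of gpu ids to use')

opt = parser.parse_args()

if __name__ == "__main__":
    main()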