def main():
    start = time.time()
    print("Loading data from '%s'" % opt.data)

    if opt.data_format == 'raw':
        dataset = torch.load(opt.data)
        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse)

        train_data = onmt.Dataset(dataset['train']['src'], dataset['train']['tgt'],
                                  opt.batch_size_words,
                                  data_type=dataset.get("type", "text"),
                                  batch_size_sents=opt.batch_size_sents,
                                  multiplier=opt.batch_size_multiplier,
                                  sort_by_target=opt.sort_by_target)
        valid_data = onmt.Dataset(dataset['valid']['src'], dataset['valid']['tgt'],
                                  opt.batch_size_words,
                                  data_type=dataset.get("type", "text"),
                                  batch_size_sents=opt.batch_size_sents)

        dicts = dataset['dicts']
        if "src" in dicts:
            print(' * vocabulary size. source = %d; target = %d' %
                  (dicts['src'].size(), dicts['tgt'].size()))
        else:
            print(' * vocabulary size. target = %d' % dicts['tgt'].size())
        print(' * number of training sentences. %d' % train_data.size())
        print(' * maximum batch size (words per batch). %d' % opt.batch_size_words)
    else:
        raise NotImplementedError

    print('Building model...')
    model = build_language_model(opt, dicts)

    # build the loss function
    loss_function = NMTLossFunc(dicts['tgt'].size(),
                                label_smoothing=opt.label_smoothing)

    n_params = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % n_params)

    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        raise NotImplementedError("Warning! Multi-GPU training is not fully tested "
                                  "and potential bugs can happen.")
    else:
        if opt.fp16:
            trainer = FP16XETrainer(model, loss_function, train_data, valid_data, dicts, opt)
        else:
            trainer = XETrainer(model, loss_function, train_data, valid_data, dicts, opt)

    trainer.run(save_file=opt.load_from)
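# Each main() in this file relies on module-level state: the option object `opt`,
# the imported libraries, and the model/trainer constructors. A minimal sketch of
# that scaffolding follows, assuming an argparse-based train.py; only options the
# functions actually read are shown, and the commented import paths are
# assumptions that vary across versions of the repo.
import argparse
import datetime
import time

import torch

import onmt
# from onmt.ModelConstructor import build_model, build_language_model
# from onmt.Loss import NMTLossFunc, NMTAndCTCLossFunc
# from onmt.trainer.XETrainer import XETrainer, FP16XETrainer

parser = argparse.ArgumentParser(description='train.py')
parser.add_argument('-data', required=True, help='path to the preprocessed *.train.pt data')
parser.add_argument('-data_format', default='raw', help='raw|bin|mmem')
parser.add_argument('-load_from', default='', help='checkpoint to resume from')
parser.add_argument('-batch_size_words', type=int, default=2048, help='token cap per batch')
parser.add_argument('-batch_size_sents', type=int, default=128, help='sentence cap per batch')
parser.add_argument('-gpus', type=int, nargs='*', default=[], help='GPU ids to use')
opt = parser.parse_args()

if __name__ == "__main__":
    main()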
def main():
    start = time.time()
    print("Loading data from '%s'" % opt.data)

    if opt.data_format == 'raw':
        dataset = torch.load(opt.data)
        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse)

        train_data = LanguageModelDataset(
            dataset['train']['tgt'],
            batch_size_sents=opt.batch_size_sents,
            seq_length=opt.lm_seq_length)
        valid_data = LanguageModelDataset(
            dataset['valid']['tgt'],
            batch_size_sents=opt.batch_size_sents,
            seq_length=opt.lm_seq_length)

        dicts = dataset['dicts']
        if "src" in dicts:
            print(' * vocabulary size. source = %d; target = %d' %
                  (dicts['src'].size(), dicts['tgt'].size()))
        else:
            print(' * vocabulary size. target = %d' % dicts['tgt'].size())
        print(' * number of training sentences. %d' % train_data.size())
        print(' * maximum batch size (words per batch). %d' % opt.batch_size_words)
    else:
        raise NotImplementedError

    print('Building model...')
    model = build_language_model(opt, dicts)
    print(model)

    # build the loss function
    loss_function = NMTLossFunc(dicts['tgt'].size(),
                                label_smoothing=opt.label_smoothing)

    n_params = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % n_params)

    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        raise NotImplementedError("Multi-GPU training is not supported at the moment.")
    else:
        # FP16 training is disabled for the language model in this version:
        # if opt.fp16:
        #     trainer = FP16XETrainer(model, loss_function, train_data, valid_data, dicts, opt)
        # else:
        trainer = XETrainer(model, loss_function, train_data, valid_data, dicts, opt)

    trainer.run(save_file=opt.load_from)
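# LanguageModelDataset is repo-specific; a hypothetical sketch of the batching it
# implies: concatenate all target sentences into one token stream and reshape it
# into batch_size_sents parallel streams read in seq_length-sized chunks (the
# classic truncated-BPTT layout). The function name and interface are
# illustrative only.
def lm_batches(sentences, batch_size_sents, seq_length):
    # sentences: list of 1-D LongTensors of token ids
    stream = torch.cat([s.view(-1) for s in sentences])        # one long stream
    n = stream.numel() // batch_size_sents
    stream = stream[:n * batch_size_sents].view(batch_size_sents, n)
    for i in range(0, n - 1, seq_length):
        length = min(seq_length, n - 1 - i)
        inputs = stream[:, i:i + length]                       # current tokens
        targets = stream[:, i + 1:i + 1 + length]              # next tokens
        yield inputs, targets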
def main():
    if opt.data_format == 'raw':
        start = time.time()
        if opt.data.endswith(".train.pt"):
            print("Loading data from '%s'" % opt.data)
            dataset = torch.load(opt.data)
        else:
            print("Loading data from %s" % (opt.data + ".train.pt"))
            dataset = torch.load(opt.data + ".train.pt")
        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse)

        train_data = onmt.Dataset(dataset['train']['src'], dataset['train']['tgt'],
                                  opt.batch_size_words,
                                  data_type=dataset.get("type", "text"),
                                  batch_size_sents=opt.batch_size_sents,
                                  multiplier=opt.batch_size_multiplier,
                                  reshape_speech=opt.reshape_speech,
                                  augment=opt.augment_speech)
        valid_data = onmt.Dataset(dataset['valid']['src'], dataset['valid']['tgt'],
                                  opt.batch_size_words,
                                  data_type=dataset.get("type", "text"),
                                  batch_size_sents=opt.batch_size_sents,
                                  reshape_speech=opt.reshape_speech)

        dicts = dataset['dicts']
        if "src" in dicts:
            print(' * vocabulary size. source = %d; target = %d' %
                  (dicts['src'].size(), dicts['tgt'].size()))
        else:
            print(' * vocabulary size. target = %d' % dicts['tgt'].size())
        print(' * number of training sentences. %d' % len(dataset['train']['src']))
        print(' * maximum batch size (words per batch). %d' % opt.batch_size_words)
    elif opt.data_format == 'bin':
        from onmt.data_utils.IndexedDataset import IndexedInMemoryDataset

        dicts = torch.load(opt.data + ".dict.pt")

        train_path = opt.data + '.train'
        train_src = IndexedInMemoryDataset(train_path + '.src')
        train_tgt = IndexedInMemoryDataset(train_path + '.tgt')
        train_data = onmt.Dataset(train_src, train_tgt,
                                  opt.batch_size_words,
                                  data_type=opt.encoder_type,
                                  batch_size_sents=opt.batch_size_sents,
                                  multiplier=opt.batch_size_multiplier)

        valid_path = opt.data + '.valid'
        valid_src = IndexedInMemoryDataset(valid_path + '.src')
        valid_tgt = IndexedInMemoryDataset(valid_path + '.tgt')
        valid_data = onmt.Dataset(valid_src, valid_tgt,
                                  opt.batch_size_words,
                                  data_type=opt.encoder_type,
                                  batch_size_sents=opt.batch_size_sents)
    else:
        raise NotImplementedError

    print('Building model...')
    if not opt.fusion:
        model = build_model(opt, dicts)

        # build the loss function
        if opt.ctc_loss != 0:
            loss_function = NMTAndCTCLossFunc(dicts['tgt'].size(),
                                              label_smoothing=opt.label_smoothing,
                                              ctc_weight=opt.ctc_loss)
        else:
            loss_function = NMTLossFunc(dicts['tgt'].size(),
                                        label_smoothing=opt.label_smoothing)
    else:
        from onmt.ModelConstructor import build_fusion
        from onmt.modules.Loss import FusionLoss

        model = build_fusion(opt, dicts)
        loss_function = FusionLoss(dicts['tgt'].size(), label_smoothing=opt.label_smoothing)

    n_params = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % n_params)

    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        raise NotImplementedError("Warning! Multi-GPU training is not fully tested "
                                  "and potential bugs can happen.")
    else:
        if opt.fp16:
            trainer = FP16XETrainer(model, loss_function, train_data, valid_data, dicts, opt)
        else:
            trainer = XETrainer(model, loss_function, train_data, valid_data, dicts, opt)

    trainer.run(save_file=opt.load_from)
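# NMTAndCTCLossFunc is repo-specific: it mixes the usual cross-entropy with a
# CTC term on the encoder output, weighted by opt.ctc_loss. A hedged sketch of
# that combination using standard PyTorch losses; the interface and the choice
# of blank/padding indices are illustrative assumptions, and the repo's version
# additionally applies label smoothing.
import torch.nn.functional as F

def nmt_and_ctc_loss(dec_logits, targets, enc_log_probs,
                     enc_lengths, tgt_lengths, ctc_weight, pad_idx=0):
    # cross-entropy over the decoder output, ignoring padding
    ce = F.cross_entropy(dec_logits.view(-1, dec_logits.size(-1)),
                         targets.view(-1), ignore_index=pad_idx)
    # CTC expects (T, N, V) log-probabilities from the encoder
    ctc = F.ctc_loss(enc_log_probs, targets, enc_lengths, tgt_lengths,
                     blank=0, zero_infinity=True)
    return ce + ctc_weight * ctc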
def main():
    if opt.data_format == 'raw':
        start = time.time()
        if opt.data.endswith(".train.pt"):
            print("Loading data from '%s'" % opt.data)
            dataset = torch.load(opt.data)  # this requires a lot of CPU memory!
        else:
            print("Loading data from %s" % (opt.data + ".train.pt"))
            dataset = torch.load(opt.data + ".train.pt")
        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse)

        train_data = onmt.Dataset(dataset['train']['src'], dataset['train']['tgt'],
                                  opt.batch_size_words,
                                  data_type=dataset.get("type", "text"),
                                  batch_size_sents=opt.batch_size_sents,
                                  multiplier=opt.batch_size_multiplier,
                                  reshape_speech=opt.reshape_speech,
                                  augment=opt.augment_speech)
        valid_data = onmt.Dataset(dataset['valid']['src'], dataset['valid']['tgt'],
                                  opt.batch_size_words,
                                  data_type=dataset.get("type", "text"),
                                  batch_size_sents=opt.batch_size_sents,
                                  reshape_speech=opt.reshape_speech)

        dicts = dataset['dicts']
        print(' * number of training sentences. %d' % len(dataset['train']['src']))
        print(' * maximum batch size (words per batch). %d' % opt.batch_size_words)
    elif opt.data_format == 'bin':
        from onmt.data_utils.IndexedDataset import IndexedInMemoryDataset

        dicts = torch.load(opt.data + ".dict.pt")

        train_path = opt.data + '.train'
        train_src = IndexedInMemoryDataset(train_path + '.src')
        train_tgt = IndexedInMemoryDataset(train_path + '.tgt')
        train_data = onmt.Dataset(train_src, train_tgt,
                                  opt.batch_size_words,
                                  data_type=opt.encoder_type,
                                  batch_size_sents=opt.batch_size_sents,
                                  multiplier=opt.batch_size_multiplier)

        valid_path = opt.data + '.valid'
        valid_src = IndexedInMemoryDataset(valid_path + '.src')
        valid_tgt = IndexedInMemoryDataset(valid_path + '.tgt')
        valid_data = onmt.Dataset(valid_src, valid_tgt,
                                  opt.batch_size_words,
                                  data_type=opt.encoder_type,
                                  batch_size_sents=opt.batch_size_sents)
    else:
        raise NotImplementedError

    additional_data = []
    if opt.additional_data != "none":
        add_data = opt.additional_data.split(";")
        add_format = opt.additional_data_format.split(";")
        assert len(add_data) == len(add_format)
        for i in range(len(add_data)):
            if add_format[i] == 'raw':
                if add_data[i].endswith(".train.pt"):
                    print("Loading data from '%s'" % add_data[i])
                    add_dataset = torch.load(add_data[i])
                else:
                    print("Loading data from %s" % (add_data[i] + ".train.pt"))
                    add_dataset = torch.load(add_data[i] + ".train.pt")
                additional_data.append(onmt.Dataset(add_dataset['train']['src'],
                                                    add_dataset['train']['tgt'],
                                                    opt.batch_size_words,
                                                    data_type=add_dataset.get("type", "text"),
                                                    batch_size_sents=opt.batch_size_sents,
                                                    multiplier=opt.batch_size_multiplier,
                                                    reshape_speech=opt.reshape_speech,
                                                    augment=opt.augment_speech))
                add_dicts = add_dataset['dicts']
                # additional corpora must share the vocabulary of the main one;
                # adopt a dictionary only when the main dataset lacks it
                for d in ['src', 'tgt']:
                    if d in dicts:
                        if d in add_dicts:
                            assert dicts[d].size() == add_dicts[d].size()
                    elif d in add_dicts:
                        dicts[d] = add_dicts[d]
            elif add_format[i] == 'bin':
                from onmt.data_utils.IndexedDataset import IndexedInMemoryDataset

                train_path = add_data[i] + '.train'
                train_src = IndexedInMemoryDataset(train_path + '.src')
                train_tgt = IndexedInMemoryDataset(train_path + '.tgt')
                additional_data.append(onmt.Dataset(train_src, train_tgt,
                                                    opt.batch_size_words,
                                                    data_type=opt.encoder_type,
                                                    batch_size_sents=opt.batch_size_sents,
                                                    multiplier=opt.batch_size_multiplier))

    # restore from a checkpoint if requested
    if opt.load_from:
        checkpoint = torch.load(opt.load_from, map_location=lambda storage, loc: storage)
        print("* Loading dictionaries from the checkpoint")
        dicts = checkpoint['dicts']
    else:
        dicts['tgt'].patch(opt.patch_vocab_multiplier)
        checkpoint = None

    if "src" in dicts:
        print(' * vocabulary size. source = %d; target = %d' %
              (dicts['src'].size(), dicts['tgt'].size()))
    else:
        print(' * vocabulary size. target = %d' % dicts['tgt'].size())

    print('Building model...')
    if not opt.fusion:
        model = build_model(opt, dicts)

        # build the loss function
        if opt.ctc_loss != 0:
            loss_function = NMTAndCTCLossFunc(dicts['tgt'].size(),
                                              label_smoothing=opt.label_smoothing,
                                              ctc_weight=opt.ctc_loss)
        else:
            loss_function = NMTLossFunc(dicts['tgt'].size(),
                                        label_smoothing=opt.label_smoothing)
    else:
        from onmt.ModelConstructor import build_fusion
        from onmt.modules.Loss import FusionLoss

        model = build_fusion(opt, dicts)
        loss_function = FusionLoss(dicts['tgt'].size(), label_smoothing=opt.label_smoothing)

    n_params = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % n_params)

    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        raise NotImplementedError("Warning! Multi-GPU training is not fully tested "
                                  "and potential bugs can happen.")
    else:
        # FP16 training is disabled in this version:
        # if opt.fp16:
        #     trainer = FP16XETrainer(model, loss_function, train_data, valid_data, dicts, opt)
        # else:
        trainer = XETrainer(model, loss_function, train_data, valid_data, dicts, opt)
        if len(additional_data) > 0:
            trainer.add_additional_data(additional_data, opt.data_ratio)

    trainer.run(checkpoint=checkpoint)
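# trainer.add_additional_data(additional_data, opt.data_ratio) mixes the extra
# corpora into training; its real implementation lives inside the trainer. A
# hypothetical sketch of ratio-based interleaving, assuming a ratio like (2, 1)
# means two batches of the main set per batch of the additional set, and that
# the extra set is non-empty and recycled as needed:
from itertools import cycle

def interleave(main_batches, extra_batches, ratio=(2, 1)):
    main_it, extra_it = iter(main_batches), cycle(extra_batches)
    while True:
        for _ in range(ratio[0]):
            try:
                yield next(main_it)       # batches from the main corpus
            except StopIteration:
                return                    # one epoch of main data ends the mix
        for _ in range(ratio[1]):
            yield next(extra_it)          # batches from the additional corpus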
def main():
    start = time.time()
    print("Loading data from '%s'" % opt.data)

    if opt.data_format == 'raw':
        dataset = torch.load(opt.data)
        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse)

        trainData = onmt.Dataset(dataset['train']['src'], dataset['train']['tgt'],
                                 opt.batch_size_words, opt.gpus,
                                 max_seq_num=opt.batch_size_sents,
                                 pad_count=opt.pad_count,
                                 multiplier=opt.batch_size_multiplier,
                                 sort_by_target=opt.sort_by_target)
        validData = onmt.Dataset(dataset['valid']['src'], dataset['valid']['tgt'],
                                 opt.batch_size_words, opt.gpus,
                                 max_seq_num=opt.batch_size_sents)

        dicts = dataset['dicts']
        print(' * vocabulary size. source = %d; target = %d' %
              (dicts['src'].size(), dicts['tgt'].size()))
        print(' * number of training sentences. %d' % len(dataset['train']['src']))
        print(' * maximum batch size (words per batch). %d' % opt.batch_size_words)
    elif opt.data_format == 'bin':
        from onmt.data_utils.IndexedDataset import IndexedInMemoryDataset

        dicts = torch.load(opt.data + ".dict.pt")

        train_path = opt.data + '.train'
        train_src = IndexedInMemoryDataset(train_path + '.src')
        train_tgt = IndexedInMemoryDataset(train_path + '.tgt')
        trainData = onmt.Dataset(train_src, train_tgt,
                                 opt.batch_size_words, opt.gpus,
                                 max_seq_num=opt.batch_size_sents,
                                 pad_count=opt.pad_count,
                                 multiplier=opt.batch_size_multiplier,
                                 sort_by_target=opt.sort_by_target)

        valid_path = opt.data + '.valid'
        valid_src = IndexedInMemoryDataset(valid_path + '.src')
        valid_tgt = IndexedInMemoryDataset(valid_path + '.tgt')
        validData = onmt.Dataset(valid_src, valid_tgt,
                                 opt.batch_size_words, opt.gpus,
                                 max_seq_num=opt.batch_size_sents)
    else:
        raise NotImplementedError

    print('Building model...')
    model = build_model(opt, dicts)

    # build the loss function
    loss_function = NMTLossFunc(dicts['tgt'].size(),
                                label_smoothing=opt.label_smoothing)

    nParams = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % nParams)

    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        # trainer = MultiGPUXETrainer(model, loss_function, trainData, validData, dataset, opt)
        raise NotImplementedError("Warning! Multi-GPU training is not fully tested "
                                  "and potential bugs can happen.")
    else:
        if opt.fp16:
            trainer = FP16XETrainer(model, loss_function, trainData, validData, dicts, opt)
        else:
            trainer = XETrainer(model, loss_function, trainData, validData, dicts, opt)

    trainer.run(save_file=opt.load_from)
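# onmt.Dataset batches by token count (batch_size_words) with a cap on the
# number of sentences (max_seq_num above). A minimal sketch of that batching
# policy, under the assumption that the padded cost of a batch is its longest
# sentence times its size; the function name is illustrative only.
def make_batches(sentences, batch_size_words, max_seq_num):
    batch, longest = [], 0
    for sent in sentences:
        longest_if_added = max(longest, len(sent))
        # with padding, adding sent would cost longest_if_added * (len(batch) + 1) tokens
        if batch and (longest_if_added * (len(batch) + 1) > batch_size_words
                      or len(batch) >= max_seq_num):
            yield batch
            batch, longest = [], 0
        batch.append(sent)
        longest = max(longest, len(sent))
    if batch:
        yield batch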
def main():
    start = time.time()
    print("Loading data from '%s'" % opt.data)
    dataset = torch.load(opt.data)
    elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
    print("Done after %s" % elapse)

    trainData = onmt.Dataset(dataset['train']['src'], dataset['train']['tgt'],
                             opt.batch_size_words, opt.gpus,
                             data_type=dataset.get("type", "text"),
                             max_seq_num=opt.batch_size_sents)
    validData = onmt.Dataset(dataset['valid']['src'], dataset['valid']['tgt'],
                             opt.batch_size_words, opt.gpus,
                             volatile=True,
                             data_type=dataset.get("type", "text"),
                             max_seq_num=opt.batch_size_sents)

    dicts = dataset['dicts']
    print(' * vocabulary size. source = %d; target = %d' %
          (dicts['src'].size(), dicts['tgt'].size()))
    print(' * number of training sentences. %d' % len(dataset['train']['src']))
    print(' * maximum batch size (words per batch). %d' % opt.batch_size_words)

    print('Building model...')
    model = build_model(opt, dicts)

    # build the loss function
    loss_function = NMTLossFunc(dicts['tgt'].size(),
                                label_smoothing=opt.label_smoothing,
                                shard_size=opt.max_generator_batches)

    nParams = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % nParams)

    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        trainer = MultiGPUXETrainer(model, loss_function, trainData, validData, dataset, opt)
        print("Warning! Multi-GPU training is used. Not fully tested and "
              "potential bugs can happen.")
    else:
        trainer = XETrainer(model, loss_function, trainData, validData, dataset, opt)

    trainer.run(save_file=opt.load_from)
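# The volatile=True flag above dates this version to PyTorch <= 0.3; since
# PyTorch 0.4 the flag is gone and validation should run under torch.no_grad()
# instead. A minimal sketch of the modern equivalent (names are illustrative):
def validate(model, valid_batches):
    model.eval()
    with torch.no_grad():                 # replaces Variable(..., volatile=True)
        for batch in valid_batches:
            model(batch)
    model.train()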
def main():
    if opt.data_format == 'raw':
        start = time.time()
        if opt.data.endswith(".train.pt"):
            print("Loading data from '%s'" % opt.data)
            dataset = torch.load(opt.data)
        else:
            print("Loading data from %s" % (opt.data + ".train.pt"))
            dataset = torch.load(opt.data + ".train.pt")
        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse)

        # for backward compatibility: fields missing from older data files default to None
        train_dict = defaultdict(lambda: None, dataset['train'])
        valid_dict = defaultdict(lambda: None, dataset['valid'])

        train_data = onmt.Dataset(train_dict['src'], train_dict['tgt'],
                                  train_dict['src_atbs'], train_dict['tgt_atbs'],
                                  batch_size_words=opt.batch_size_words,
                                  data_type=dataset.get("type", "text"),
                                  batch_size_sents=opt.batch_size_sents,
                                  multiplier=opt.batch_size_multiplier,
                                  augment=opt.augment_speech,
                                  upsampling=opt.upsampling)
        valid_data = onmt.Dataset(valid_dict['src'], valid_dict['tgt'],
                                  valid_dict['src_atbs'], valid_dict['tgt_atbs'],
                                  batch_size_words=opt.batch_size_words,
                                  data_type=dataset.get("type", "text"),
                                  batch_size_sents=opt.batch_size_sents,
                                  upsampling=opt.upsampling)

        dicts = dataset['dicts']
        print(' * number of training sentences. %d' % len(dataset['train']['src']))
        print(' * maximum batch size (words per batch). %d' % opt.batch_size_words)
    elif opt.data_format == 'bin':
        print("Loading memory binned data files ....")
        start = time.time()
        from onmt.data_utils.IndexedDataset import IndexedInMemoryDataset

        dicts = torch.load(opt.data + ".dict.pt")

        train_path = opt.data + '.train'
        train_src = IndexedInMemoryDataset(train_path + '.src')
        train_tgt = IndexedInMemoryDataset(train_path + '.tgt')
        train_data = onmt.Dataset(train_src, train_tgt,
                                  batch_size_words=opt.batch_size_words,
                                  data_type="text",
                                  batch_size_sents=opt.batch_size_sents,
                                  multiplier=opt.batch_size_multiplier)

        valid_path = opt.data + '.valid'
        valid_src = IndexedInMemoryDataset(valid_path + '.src')
        valid_tgt = IndexedInMemoryDataset(valid_path + '.tgt')
        valid_data = onmt.Dataset(valid_src, valid_tgt,
                                  batch_size_words=opt.batch_size_words,
                                  data_type="text",
                                  batch_size_sents=opt.batch_size_sents)

        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse)
    elif opt.data_format == 'mmem':
        print("Loading memory mapped data files ....")
        start = time.time()
        from onmt.data_utils.MMapIndexedDataset import MMapIndexedDataset

        dicts = torch.load(opt.data + ".dict.pt")

        train_path = opt.data + '.train'
        train_src = MMapIndexedDataset(train_path + '.src')
        train_tgt = MMapIndexedDataset(train_path + '.tgt')
        train_data = onmt.Dataset(train_src, train_tgt,
                                  batch_size_words=opt.batch_size_words,
                                  data_type="text",
                                  batch_size_sents=opt.batch_size_sents,
                                  multiplier=opt.batch_size_multiplier)

        valid_path = opt.data + '.valid'
        valid_src = MMapIndexedDataset(valid_path + '.src')
        valid_tgt = MMapIndexedDataset(valid_path + '.tgt')
        valid_data = onmt.Dataset(valid_src, valid_tgt,
                                  batch_size_words=opt.batch_size_words,
                                  data_type="text",
                                  batch_size_sents=opt.batch_size_sents)

        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse)
    else:
        raise NotImplementedError

    additional_data = []
    if opt.additional_data != "none":
        add_data = opt.additional_data.split(";")
        add_format = opt.additional_data_format.split(";")
        assert len(add_data) == len(add_format)
        for i in range(len(add_data)):
            if add_format[i] == 'raw':
                if add_data[i].endswith(".train.pt"):
                    print("Loading data from '%s'" % add_data[i])
                    add_dataset = torch.load(add_data[i])
                else:
                    print("Loading data from %s" % (add_data[i] + ".train.pt"))
                    add_dataset = torch.load(add_data[i] + ".train.pt")
                additional_data.append(onmt.Dataset(add_dataset['train']['src'],
                                                    add_dataset['train']['tgt'],
                                                    batch_size_words=opt.batch_size_words,
                                                    data_type=add_dataset.get("type", "text"),
                                                    batch_size_sents=opt.batch_size_sents,
                                                    multiplier=opt.batch_size_multiplier,
                                                    reshape_speech=opt.reshape_speech,
                                                    augment=opt.augment_speech))
            elif add_format[i] == 'bin':
                from onmt.data_utils.IndexedDataset import IndexedInMemoryDataset

                train_path = add_data[i] + '.train'
                train_src = IndexedInMemoryDataset(train_path + '.src')
                train_tgt = IndexedInMemoryDataset(train_path + '.tgt')
                additional_data.append(onmt.Dataset(train_src, train_tgt,
                                                    batch_size_words=opt.batch_size_words,
                                                    data_type=opt.encoder_type,
                                                    batch_size_sents=opt.batch_size_sents,
                                                    multiplier=opt.batch_size_multiplier))

    if opt.load_from:
        checkpoint = torch.load(opt.load_from, map_location=lambda storage, loc: storage)
        print("* Loading dictionaries from the checkpoint")
        dicts = checkpoint['dicts']
    else:
        dicts['tgt'].patch(opt.patch_vocab_multiplier)
        checkpoint = None

    if "src" in dicts:
        print(' * vocabulary size. source = %d; target = %d' %
              (dicts['src'].size(), dicts['tgt'].size()))
    else:
        print(' * vocabulary size. target = %d' % dicts['tgt'].size())

    print('Building model...')
    if not opt.fusion:
        if opt.bert_scalar and opt.finetune_bert:
            print("WARNING: when fine-tuning BERT the scalar-mix parameters are not "
                  "trained; please set opt.bert_scalar to False")
        print("Using scalar-mixed BERT vectors:", opt.bert_scalar)
        print("Fine-tuning BERT together with the Transformer:", opt.finetune_bert)

        model = build_model(opt, dicts)

        if not opt.finetune_bert:
            # freeze the BERT encoder
            for param in model.bert.parameters():
                param.requires_grad = False

        if not opt.finetune_bert and opt.bert_scalar:
            # learn a scalar mixture over the frozen BERT layers
            scalar_mix = ScalarMix(
                onmt.Constants.BERT_LAYERS,
                do_layer_norm=False,
                initial_scalar_parameters=None,
                trainable=True,
            )
            model.add_module("scalar_mix", scalar_mix)
        print(model)

        # build the loss function
        if opt.ctc_loss != 0:
            loss_function = NMTAndCTCLossFunc(dicts['tgt'].size(),
                                              label_smoothing=opt.label_smoothing,
                                              ctc_weight=opt.ctc_loss)
        else:
            loss_function = NMTLossFunc(dicts['tgt'].size(),
                                        label_smoothing=opt.label_smoothing)
    else:
        from onmt.ModelConstructor import build_fusion
        from onmt.modules.Loss import FusionLoss

        model = build_fusion(opt, dicts)
        loss_function = FusionLoss(dicts['tgt'].size(), label_smoothing=opt.label_smoothing)

    n_params = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % n_params)
    n_params_grad = sum([p.nelement() for p in model.parameters() if p.requires_grad])
    print('* number of parameters that require gradient: %d' % n_params_grad)
    n_params_nograd = sum([p.nelement() for p in model.parameters() if not p.requires_grad])
    print('* number of parameters that do not require gradient: %d' % n_params_nograd)
    assert n_params == n_params_grad + n_params_nograd

    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        raise NotImplementedError("Warning! Multi-GPU training is not fully tested "
                                  "and potential bugs can happen.")
    else:
        trainer = XETrainer(model, loss_function, train_data, valid_data, dicts, opt,
                            setup_optimizer=True)
        if len(additional_data) > 0:
            trainer.add_additional_data(additional_data, opt.data_ratio)

    trainer.run(checkpoint=checkpoint)