def train_model(model, train_dataset, valid_dataset, fields,
                text_model, text_train_dataset, text_valid_dataset, text_fields,
                discrim_models, discrim_optims, gen_optims,
                optim, text_optim, model_opt):
    train_iter = make_train_data_iter(train_dataset, opt)
    valid_iter = make_valid_data_iter(valid_dataset, opt)
    text_train_iter = make_train_data_iter(text_train_dataset, opt)
    text_valid_iter = make_valid_data_iter(text_valid_dataset, opt)

    train_loss = make_loss_compute(model, fields["tgt"].vocab,
                                   train_dataset, opt)
    text_loss = make_loss_compute(model, fields["tgt"].vocab,
                                  train_dataset, opt, True)
    valid_loss = make_loss_compute(model, fields["tgt"].vocab,
                                   valid_dataset, opt)
    text_valid_loss = make_loss_compute(model, fields["tgt"].vocab,
                                        text_valid_dataset, opt, True)

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    data_type = train_dataset.data_type

    trainer = onmt.AudioTextTrainerAdv(
        model, train_iter, valid_iter,
        text_model, text_train_iter, text_valid_iter,
        train_loss, text_loss, valid_loss, text_valid_loss,
        optim, text_optim, discrim_models,
        [model_opt.gen_label, model_opt.gen_label], gen_optims,
        model_opt.gen_lambda, trunc_size, shard_size, data_type,
        model_opt.mult)

    # Separate iterators (batch size 32) for the discriminator trainer,
    # since iterators are stateful and cannot be shared.
    train_iter = make_train_data_iter(train_dataset, opt, 32)
    text_train_iter = make_train_data_iter(text_train_dataset, opt, 32)
    discrim_trainer = onmt.DiscrimTrainer(
        discrim_models, [train_iter, text_train_iter],
        [valid_iter, text_valid_iter], discrim_optims,
        [0.1, 0.9], shard_size)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # 0. Train the discriminators, reporting their validation loss
        # before and after.
        if epoch > 1:
            src_valid_stats, tgt_valid_stats, st_valid_stats = \
                discrim_trainer.validate()
            print('(before) Discrim validation src loss: %g' % src_valid_stats.loss)
            print('(before) Discrim validation src/tgt loss: %g' % st_valid_stats.loss)
            print('(before) Discrim validation tgt loss: %g' % tgt_valid_stats.loss)

        src_train_stats, tgt_train_stats = discrim_trainer.train(
            epoch, discrim_report_func)
        print('Discrim src loss: %g' % src_train_stats.loss)
        print('Discrim tgt loss: %g' % tgt_train_stats.loss)

        src_valid_stats, tgt_valid_stats, st_valid_stats = \
            discrim_trainer.validate()
        print('(after) Discrim validation src loss: %g' % src_valid_stats.loss)
        print('(after) Discrim validation src/tgt loss: %g' % st_valid_stats.loss)
        print('(after) Discrim validation tgt loss: %g' % tgt_valid_stats.loss)

        # 1. Train for one epoch on the training set.
        train_stats, text_train_stats = trainer.train(epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())
        print('Text perplexity: %g' % text_train_stats.ppl())
        print('Text accuracy: %g' % text_train_stats.accuracy())

        # 2. Validate on the validation set.
        valid_stats = trainer.validate()
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())
        text_valid_stats = trainer.validate_text()
        print('Text validation perplexity: %g' % text_valid_stats.ppl())
        print('Text validation accuracy: %g' % text_valid_stats.accuracy())

        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)

        # 4. Update the learning rate.
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)
            discrim_trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)
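# A minimal sketch of the iterator and loss helpers that every
# train_model variant in this file assumes. Signatures are inferred from
# the call sites above; the bodies are assumptions modeled on the
# onmt.io.OrderedIterator / onmt.Loss APIs of the same OpenNMT-py
# vintage, not the repo's actual implementations.
import onmt
import onmt.io
import onmt.Loss

def make_train_data_iter(train_dataset, opt, batch_size=None):
    # An explicit batch_size (e.g. 32 for the discriminator passes)
    # overrides opt.batch_size.
    return onmt.io.OrderedIterator(
        dataset=train_dataset,
        batch_size=batch_size or opt.batch_size,
        device=opt.gpuid[0] if opt.gpuid else -1,
        repeat=False)

def make_valid_data_iter(valid_dataset, opt):
    # Validation batches are sorted and iterated once per call.
    return onmt.io.OrderedIterator(
        dataset=valid_dataset,
        batch_size=opt.batch_size,
        device=opt.gpuid[0] if opt.gpuid else -1,
        train=False, sort=True)

def make_loss_compute(model, tgt_vocab, dataset, opt, text=False):
    # 'text' mirrors the trailing boolean at some call sites above; what
    # it selects in the real code is not shown here, so this sketch
    # leaves it unused.
    compute = onmt.Loss.NMTLossCompute(model.generator, tgt_vocab)
    if opt.gpuid:
        compute = compute.cuda()
    return compute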
def train_model(auto_models, valid_model, train_data, valid_data,
                fields_list, valid_fields, optims,
                discrim_models, discrim_optims, labels):
    # train_model(models, valid_model, train, valid, fields, fields_valid, optims,
    #             discrim_models, discrim_optims, advers_optims, labels)
    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches

    valid_iter = make_valid_data_iter(valid_data, opt)
    valid_loss = make_loss_compute(valid_model, valid_fields["tgt"].vocab,
                                   valid_data, opt)

    # trainers[0] only validates and checkpoints; the AdvTrainers that
    # follow do the actual per-domain adversarial training.
    trainers = []
    trainers.append(
        onmt.Trainer(valid_model, valid_iter, valid_iter,
                     valid_loss, valid_loss, optims[0],
                     trunc_size, shard_size))
    for model, discrim_model, optim, label, train, fields in zip(
            auto_models, discrim_models, optims, labels,
            train_data, fields_list):
        train_iter = make_train_data_iter(train, opt)
        train_loss = make_loss_compute(model, fields["tgt"].vocab, train, opt)
        trainers.append(
            onmt.AdvTrainer(model, discrim_model, train_iter, valid_iter,
                            train_loss, valid_loss, optim, label,
                            trunc_size, shard_size))

    discrim_trainers = []
    for model, discrim_optim, data in zip(discrim_models, discrim_optims,
                                          train_data):
        train_iter = make_train_data_iter(data, opt)
        discrim_trainers.append(
            onmt.DiscrimTrainer(model, train_iter, discrim_optim, shard_size))
    # for model, optim, data in zip(discrim_models, advers_optims, advers_data):
    #     train_iter = make_train_data_iter(data, opt)
    #     discrim_trainers.append(
    #         onmt.DiscrimTrainer(model, train_iter, optim, shard_size))

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # 1a. Train each discriminator for one epoch on the training set.
        for label, trainer in zip(labels, discrim_trainers):
            train_stats = trainer.train(epoch, label, discrim_report_func)
            print('Train loss: %g' % train_stats.loss)
            # print('Train accuracy: %g' % train_stats.accuracy())
            if opt.exp_host:
                train_stats.log("train", experiment, optim.lr)

        # 1b. Train each generator for one epoch on the training set.
        for trainer in trainers[1:]:
            train_stats = trainer.train(epoch, report_func)
            print('Train perplexity: %g' % train_stats.ppl())
            print('Train accuracy: %g' % train_stats.accuracy())
            if opt.exp_host:
                train_stats.log("train", experiment, optim.lr)

        # 2. Validate on the validation set.
        valid_stats = trainers[0].validate()
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server.
        if opt.exp_host:
            valid_stats.log("valid", experiment, optim.lr)

        '''
        # 4. Update the learning rate.
        for trainer in trainers[1:]:
            trainer.epoch_step(valid_stats.ppl(), epoch)
        for trainer in discrim_trainers:
            trainer.epoch_step(valid_stats.ppl(), epoch)
        '''

        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            trainers[0].drop_checkpoint(opt, epoch, valid_fields, valid_stats)
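# Hypothetical wiring for the variant above; every name below is a
# placeholder, not the repo's real entry point. It illustrates the
# intended pairing of arguments: one autoencoder, discriminator,
# optimizer, label, dataset, and fields per domain, while valid_model is
# only validated and checkpointed. The [0.1, 0.9] labels match the
# smoothed discriminator targets used by the other variants in this file.
#
#   train_model(auto_models=[src_auto, tgt_auto], valid_model=nmt_model,
#               train_data=[src_train, tgt_train], valid_data=valid,
#               fields_list=[src_fields, tgt_fields], valid_fields=fields,
#               optims=[optim_src, optim_tgt],
#               discrim_models=[d_src, d_tgt],
#               discrim_optims=[d_optim_src, d_optim_tgt],
#               labels=[0.1, 0.9])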
def train_model(model, train_dataset, valid_dataset, fields,
                text_model, text_train_dataset, text_valid_dataset, text_fields,
                speech_model, speech_train_dataset,
                discrim_models, discrim_optims,
                optim, adv_optim, speech_optim, model_opt, big_text):
    train_iter = make_train_data_iter(train_dataset, opt)
    valid_iter = make_valid_data_iter(valid_dataset, opt)
    text_valid_iter = make_valid_data_iter(text_valid_dataset, opt)
    text_train_iter = make_train_data_iter(text_train_dataset, opt)
    # dR: deletion rate applied to text batches; default to 0.2 for older
    # checkpoints saved before the delete_rate option existed.
    try:
        text_train_iter.dR = model_opt.delete_rate
    except AttributeError:
        text_train_iter.dR = 0.2
    speech_train_iter = make_train_data_iter(speech_train_dataset, opt)

    train_loss = make_loss_compute(model, fields["tgt"].vocab,
                                   train_dataset, opt)
    text_loss = make_loss_compute(model, fields["tgt"].vocab,
                                  train_dataset, opt, True)
    valid_loss = make_loss_compute(model, fields["tgt"].vocab,
                                   valid_dataset, opt)
    text_valid_loss = make_loss_compute(model, fields["tgt"].vocab,
                                        text_valid_dataset, opt)

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    data_type = train_dataset.data_type

    speech_lambda = opt.auto_lambda
    if model_opt.weighted:
        speech_lambda = speech_lambda / float(opt.mult)

    print("label:", model_opt.gen_label)
    try:
        print(model_opt.unsup)
    except AttributeError:
        model_opt.unsup = False

    if opt.no_adv:
        discrim_models = [None, None]

    # The FMatch variant adds feature matching to the adversarial loss;
    # fall back to the plain adversarial trainer when the option is absent.
    if getattr(opt, 'feature_match', False):
        trainer_cls = onmt.AudioTextSpeechTrainerAdvFMatch
    else:
        trainer_cls = onmt.AudioTextSpeechTrainerAdv
    trainer = trainer_cls(
        model, train_iter, valid_iter,
        text_model, text_train_iter, text_valid_iter,
        speech_model, speech_train_iter,
        train_loss, text_loss, valid_loss, text_valid_loss,
        optim, adv_optim, speech_optim, discrim_models,
        [model_opt.gen_label, model_opt.gen_label],
        model_opt.gen_lambda, speech_lambda,
        trunc_size, shard_size, data_type,
        model_opt.mult, model_opt.t_mult, model_opt.unsup,
        big_text=big_text)

    if opt.ff_speech_decoder:
        trainer.ff = True

    if not opt.no_adv:
        # Separate iterators (batch size 32) for the discriminator trainer.
        speech_train_iter = make_train_data_iter(speech_train_dataset, opt, 32)
        text_train_iter = make_train_data_iter(text_train_dataset, opt, 32)
        try:
            text_train_iter.dR = model_opt.delete_rate
        except AttributeError:
            text_train_iter.dR = 0.2
        discrim_trainer = onmt.DiscrimTrainer(
            discrim_models, [speech_train_iter, text_train_iter],
            [valid_iter, text_valid_iter], discrim_optims,
            [0.1, 0.9], shard_size, big_text)

    if model_opt.unsup:
        override = 50
    else:
        # override = -1
        override = 2000
    print("OVERRIDE: " + str(override))

    try:
        model_opt.start_mask = max(0, model_opt.start_mask)
    except AttributeError:
        model_opt.start_mask = 0
        model_opt.end_mask = 0

    advOnly = False

    if big_text:
        # Large text corpora are sharded as <text_data>.train.<i>.pt;
        # rotate through the shards, one per epoch.
        nText = len(glob.glob(opt.text_data + '.train.[0-9]*.pt'))
        text_idx = opt.start_epoch % nText
        print("idx:", text_idx, opt.start_epoch, nText)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        if big_text:
            try:
                text_train_dataset = torch.load(
                    opt.text_data + '.train.' + str(text_idx) + '.pt')
            except IOError:
                # Wrap around to the first shard.
                text_train_dataset = torch.load(opt.text_data + '.train.1.pt')
                text_idx = 1
            text_train_dataset.fields = text_fields
            print("LOADED BIG TEXT:", text_idx)
            text_idx += 1

        if not opt.no_adv:
            if epoch > 1:
                src_valid_stats, tgt_valid_stats, st_valid_stats = \
                    discrim_trainer.validate()
                print('(before) Discrim validation src loss: %g' % src_valid_stats.loss)
                print('(before) Discrim validation src/tgt loss: %g' % st_valid_stats.loss)
                print('(before) Discrim validation tgt loss: %g' % tgt_valid_stats.loss)

            if big_text:
                text_train_iter = make_train_data_iter(text_train_dataset, opt, 32)
                text_train_iter.dR = model_opt.delete_rate
                src_train_stats, tgt_train_stats = discrim_trainer.train(
                    epoch, discrim_report_func, text=text_train_iter,
                    startMask=model_opt.start_mask,
                    endMask=model_opt.end_mask)  # , override)
            else:
                src_train_stats, tgt_train_stats = discrim_trainer.train(
                    epoch, discrim_report_func,
                    startMask=model_opt.start_mask,
                    endMask=model_opt.end_mask)  # , override)
            print('Discrim src loss: %g' % src_train_stats.loss)
            print('Discrim tgt loss: %g' % tgt_train_stats.loss)

            src_valid_stats, tgt_valid_stats, st_valid_stats = \
                discrim_trainer.validate()
            print('(after) Discrim validation src loss: %g' % src_valid_stats.loss)
            print('(after) Discrim validation src/tgt loss: %g' % st_valid_stats.loss)
            print('(after) Discrim validation tgt loss: %g' % tgt_valid_stats.loss)

        # 1. Train for one epoch on the training set.
        if big_text:
            text_train_iter = make_train_data_iter(text_train_dataset, opt)
            text_train_iter.dR = model_opt.delete_rate
            train_stats, text_train_stats, speech_train_stats, discrim_train_stats = \
                trainer.train(epoch, report_func, override,
                              text=text_train_iter,
                              startMask=model_opt.start_mask,
                              endMask=model_opt.end_mask,
                              advOnly=advOnly)
        else:
            train_stats, text_train_stats, speech_train_stats, discrim_train_stats = \
                trainer.train(epoch, report_func, override,
                              startMask=model_opt.start_mask,
                              endMask=model_opt.end_mask,
                              advOnly=advOnly)

        if not opt.unsup and not advOnly:
            print('Train perplexity: %g' % train_stats.ppl())
            print('Train accuracy: %g' % train_stats.accuracy())
        if not advOnly:
            print('Text perplexity: %g' % text_train_stats.ppl())
            print('Text accuracy: %g' % text_train_stats.accuracy())
        try:
            print('Speech MSE: %g' % speech_train_stats.loss)
        except AttributeError:
            pass
        try:
            print('Discrim loss: %g' % discrim_train_stats.loss)
        except AttributeError:
            pass

        # 2. Validate on the validation set.
        valid_stats = trainer.validate()
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())
        text_valid_stats = trainer.validate_text()
        print('Text validation perplexity: %g' % text_valid_stats.ppl())
        print('Text validation accuracy: %g' % text_valid_stats.accuracy())

        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)

        # 4. Update the learning rate.
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)
            if not opt.no_adv:
                discrim_trainer.drop_checkpoint(model_opt, epoch, fields,
                                                valid_stats)
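# The .dR attribute set on the text iterators above suggests the iterator
# applies word-deletion noise to text batches, the denoising-autoencoder
# trick from unsupervised NMT. A self-contained sketch of that noising
# step; delete_tokens and its defaults are hypothetical illustrations,
# not the repo's actual iterator code.
import random

def delete_tokens(tokens, dR=0.2, specials=('<s>', '</s>')):
    """Drop each non-special token independently with probability dR."""
    kept = [t for t in tokens if t in specials or random.random() >= dR]
    return kept if kept else list(tokens[:1])  # never return empty

# e.g. delete_tokens('the cat sat on the mat'.split(), dR=0.2)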
def train_model(auto_models, valid_model, train_data, valid_data,
                fields_list, valid_fields, optims,
                discrim_models, discrim_optims, labels, advers_optims):
    # train_model(models, valid_model, train, valid, fields, fields_valid, optims,
    #             discrim_models, discrim_optims, advers_optims, labels)
    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches

    valid_iter = make_valid_data_iter(valid_data, opt)
    valid_loss = make_loss_compute(valid_model, valid_fields["tgt"].vocab,
                                   valid_data, opt)
    # Only validates and checkpoints the shared model.
    valid_trainer = onmt.Trainer(valid_model, valid_iter, valid_iter,
                                 valid_loss, valid_loss, optims[0],
                                 trunc_size, shard_size)

    src_train_iter = make_train_data_iter(train_data[0], opt)
    src_train_loss = make_loss_compute(auto_models[0],
                                       fields_list[0]["tgt"].vocab,
                                       train_data[0], opt)
    src_trainer = onmt.Trainer(auto_models[0], src_train_iter, valid_iter,
                               src_train_loss, valid_loss, optims[0],
                               trunc_size, shard_size)

    tgt_train_iter = make_train_data_iter(train_data[1], opt)
    tgt_train_loss = make_loss_compute(auto_models[1],
                                       fields_list[1]["tgt"].vocab,
                                       train_data[1], opt)
    tgt_trainer = onmt.Trainer(auto_models[1], tgt_train_iter, valid_iter,
                               tgt_train_loss, valid_loss, optims[1],
                               trunc_size, shard_size)

    # Iterators are stateful, so each trainer gets fresh ones.
    # unsup_trainer is only exercised in the commented-out block below.
    src_train_iter = make_train_data_iter(train_data[0], opt)
    tgt_train_iter = make_train_data_iter(train_data[1], opt)
    src_train_loss = make_loss_compute(auto_models[0],
                                       fields_list[0]["tgt"].vocab,
                                       train_data[0], opt)
    tgt_train_loss = make_loss_compute(auto_models[1],
                                       fields_list[1]["tgt"].vocab,
                                       train_data[1], opt)
    unsup_trainer = onmt.UnsupTrainer(
        auto_models, [None, None], discrim_models,
        [src_train_iter, tgt_train_iter], valid_iter,
        [src_train_loss, tgt_train_loss], [None, None], valid_loss,
        [optims[0], None], [None, None], [0.9, 0.9],
        trunc_size, shard_size)

    src_train_iter = make_train_data_iter(train_data[0], opt)
    tgt_train_iter = make_train_data_iter(train_data[1], opt)
    discrim_trainer = onmt.DiscrimTrainer(discrim_models,
                                          [src_train_iter, tgt_train_iter],
                                          discrim_optims, labels, shard_size)

    src_train_iter = make_train_data_iter(train_data[0], opt)
    tgt_train_iter = make_train_data_iter(train_data[1], opt)
    advers_trainer = onmt.DiscrimTrainer(discrim_models,
                                         [src_train_iter, tgt_train_iter],
                                         advers_optims, [0.9, 0.9], shard_size)

    '''
    for epoch in range(10):
        train_stats = discrim_trainer.train(epoch, discrim_report_func)
        print('Discrim Train loss: %g' % train_stats.loss)

    for epoch in range(10):
        train_stats = discrim_trainer.train(epoch, discrim_report_func)
        print('Discrim Train loss: %g' % train_stats.loss)
        train_stats = advers_trainer.train(epoch, discrim_report_func)
        print('Advers Train loss: %g' % train_stats.loss)
    '''

    # First pass: train the two autoencoders on their own.
    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # 1. Train for one epoch on the training set.
        train_stats = src_trainer.train(epoch, discrim_report_func)
        print('SRC Train perplexity: %g' % train_stats.ppl())
        print('SRC Train accuracy: %g' % train_stats.accuracy())

        train_stats = tgt_trainer.train(epoch, discrim_report_func)
        print('TGT Train perplexity: %g' % train_stats.ppl())
        print('TGT Train accuracy: %g' % train_stats.accuracy())

        if opt.exp_host:
            train_stats.log("train", experiment, optims[0].lr)

    # Second pass: alternate autoencoder, discriminator, and adversarial
    # updates, validating with the shared model.
    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # 1. Train for one epoch on the training set.
        train_stats = src_trainer.train(epoch, discrim_report_func)
        print('SRC Train perplexity: %g' % train_stats.ppl())
        print('SRC Train accuracy: %g' % train_stats.accuracy())

        train_stats = discrim_trainer.train(epoch, discrim_report_func)
        print('Discrim Train loss: %g' % train_stats.loss)
        train_stats = advers_trainer.train(epoch, discrim_report_func)
        print('Advers Train loss: %g' % train_stats.loss)

        '''
        train_stats = unsup_trainer.train(epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())
        '''

        # 2. Validate on the validation set.
        valid_stats = valid_trainer.validate()
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server.
        if opt.exp_host:
            valid_stats.log("valid", experiment, optims[0].lr)

        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            valid_trainer.drop_checkpoint(opt, epoch, valid_fields, valid_stats)
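# The discrim_trainer / advers_trainer pair above alternates GAN-style:
# the discriminator is fit with the true per-domain labels, then the same
# DiscrimTrainer class is reused with both domains pushed toward the
# smoothed label 0.9, so the gradients drive the encoders toward
# domain-indistinguishable representations. A hypothetical sketch of the
# two objectives with BCE loss; domain_loss and the variable names are
# illustrative, not the repo's API.
import torch
import torch.nn.functional as F

def domain_loss(discrim, hidden, label):
    """BCE of the discriminator's domain score against a smoothed label."""
    score = discrim(hidden)                 # (batch, 1), values in [0, 1]
    target = torch.full_like(score, label)  # e.g. 0.1 / 0.9
    return F.binary_cross_entropy(score, target)

# Discriminator step: true labels, encoder outputs detached so only the
# discriminator updates.
#   d_loss = domain_loss(D, h_src.detach(), 0.1) + domain_loss(D, h_tgt.detach(), 0.9)
# Adversarial step: both domains aimed at the same label; gradients flow
# into the encoders while the discriminator's optimizer sits this one out.
#   g_loss = domain_loss(D, h_src, 0.9) + domain_loss(D, h_tgt, 0.9)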