def main(opt, device_id):
    # opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        raise Exception('You need to load a model')

    logger.info('Loading data from %s' % opt.data)
    dataset = next(lazily_load_dataset("train", opt))
    data_type = dataset.data_type
    logger.info('Data type %s' % data_type)

    # Load fields generated from preprocess phase.
    fields = _load_fields(dataset, data_type, opt, checkpoint)

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    dataset_iter = build_dataset_iter(lazily_load_dataset("train", opt), fields, opt)

    out_file = codecs.open(opt.output, 'w+', 'utf-8')
    scorer = onmt.translate.GNMTGlobalScorer(opt.alpha, opt.beta,
                                             opt.coverage_penalty,
                                             opt.length_penalty)
    translation_builder = TranslationBuilder(dataset, fields,
                                             n_best=opt.n_best,
                                             replace_unk=opt.replace_unk,
                                             has_tgt=False)

    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt), fields, opt)

    trainer = build_trainer(opt, device_id, model, fields, optim, data_type,
                            model_saver=model_saver)
    translator = Translator(trainer.model, fields, opt.beam_size,
                            global_scorer=scorer, out_file=out_file,
                            report_score=False, copy_attn=model_opt.copy_attn,
                            logger=logger)

    # Interleave inference and training: translate each batch, then train on
    # that same batch for a single step.
    for i, batch in enumerate(dataset_iter):
        unprocessed_translations = translator.translate_batch(batch, dataset)
        translations = translation_builder.from_batch(unprocessed_translations)
        print("Translations: ", ' '.join(translations[0].pred_sents[0]))
        trainer.train_from_data(batch, train_steps=1)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
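
# --- Illustration (not part of the training code above) ---
# A minimal, self-contained sketch of the "translate, then train on the same
# batch" pattern that main() implements, written against plain PyTorch. The
# model, data, and loss below are toy placeholders (assumptions), not the
# OpenNMT-py objects used above.
import torch
import torch.nn as nn

toy_model = nn.Linear(8, 4)                              # stand-in for the seq2seq model
toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)
toy_loss_fn = nn.CrossEntropyLoss()
toy_batches = [(torch.randn(16, 8), torch.randint(0, 4, (16,))) for _ in range(3)]

for step, (src, tgt) in enumerate(toy_batches):
    # 1) "Translate": run inference on the incoming batch without gradients.
    toy_model.eval()
    with torch.no_grad():
        preds = toy_model(src).argmax(dim=-1)
    print("step %d: first prediction = %d" % (step, preds[0].item()))

    # 2) Train on that same batch for a single step, mirroring
    #    trainer.train_from_data(batch, train_steps=1).
    toy_model.train()
    toy_optimizer.zero_grad()
    loss = toy_loss_fn(toy_model(src), tgt)
    loss.backward()
    toy_optimizer.step()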
def main(opt, device_id, batch_queue=None, semaphore=None):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    data_keys = [f"src.{src_type}" for src_type in opt.src_types] + ["tgt"]
    for side in data_keys:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    logger.info('Building model...')
    if opt.type_append:
        model = MultiSourceS2STypeAppendedModelBuilder.build_model(
            opt.src_types, model_opt, opt, fields, checkpoint)
    else:
        model = MultiSourceModelBuilder.build_model(
            opt.src_types, model_opt, opt, fields, checkpoint)
    logger.info(model)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = MultiSourceModelSaver.build_model_saver(
        opt.src_types, model_opt, opt, model, fields, optim)

    if opt.consist_reg:
        trainer = MultiSourceCRTrainer.build_trainer(
            opt.src_types, opt, device_id, model, fields, optim,
            model_saver=model_saver)
    elif opt.type_append:
        trainer = MultiSourceTypeAppendedTrainer.build_trainer(
            opt.src_types, opt, device_id, model, fields, optim,
            model_saver=model_saver)
    else:
        trainer = MultiSourceTrainer.build_trainer(
            opt.src_types, opt, device_id, model, fields, optim,
            model_saver=model_saver)

    if batch_queue is None:
        if len(opt.data_ids) > 1:
            train_shards = []
            for train_id in opt.data_ids:
                shard_base = "train_" + train_id
                train_shards.append(shard_base)
            train_iter = MultiSourceInputter.build_dataset_iter_multiple(
                opt.src_types, train_shards, fields, opt)
        else:
            if opt.data_ids[0] is not None:
                shard_base = "train_" + opt.data_ids[0]
            else:
                shard_base = "train"
            train_iter = MultiSourceInputter.build_dataset_iter(
                opt.src_types, shard_base, fields, opt)
    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()

    valid_iter = MultiSourceInputter.build_dataset_iter(
        opt.src_types, "valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')

    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
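
# --- Illustration (not part of the training code above) ---
# Sketch of the producer side that the batch_queue/semaphore path in main()
# expects: the consumer (_train_iter) releases the semaphore after each get,
# so a producer that acquires it before each put can never run more than N
# batches ahead of training. Everything below (make_batches, the queue size,
# the batch format) is a made-up stand-in; a real setup would typically use
# torch.multiprocessing and real dataset batches.
import multiprocessing as mp


def make_batches():
    # Hypothetical batch generator standing in for the real dataset iterator.
    for i in range(100):
        yield {"step": i}


def batch_producer(batch_queue, semaphore):
    for batch in make_batches():
        semaphore.acquire()      # block once N batches are in flight
        batch_queue.put(batch)


if __name__ == "__main__":
    queue = mp.Queue(maxsize=40)
    semaphore = mp.Semaphore(40)     # allow at most 40 in-flight batches
    producer = mp.Process(target=batch_producer, args=(queue, semaphore), daemon=True)
    producer.start()

    # Drain a few batches the same way _train_iter() does, releasing the
    # semaphore after each get so the producer can keep filling the queue.
    for _ in range(3):
        batch = queue.get()
        semaphore.release()
        print("got", batch)
    # main(opt, device_id, batch_queue=queue, semaphore=semaphore)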
def train_single(self, output_model_dir: Path, opt, device_id,
                 batch_queue=None, semaphore=None):
    from roosterize.ml.onmt.MultiSourceInputter import MultiSourceInputter
    from roosterize.ml.onmt.MultiSourceModelBuilder import MultiSourceModelBuilder
    from roosterize.ml.onmt.MultiSourceModelSaver import MultiSourceModelSaver
    from roosterize.ml.onmt.MultiSourceTrainer import MultiSourceTrainer
    from onmt.inputters.inputter import load_old_vocab, old_style_vocab
    from onmt.train_single import (
        configure_process, _tally_parameters, _check_save_model_path)
    from onmt.utils.optimizers import Optimizer
    from onmt.utils.parse import ArgumentParser

    configure_process(opt, device_id)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        self.logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        self.logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')
    # end if

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab
    # end if

    # Report src and tgt vocab sizes, including for features
    data_keys = [f"src.{src_type}" for src_type in self.config.get_src_types()] + ["tgt"]
    for side in data_keys:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        # end try
        for sn, sf in f_iter:
            if sf.use_vocab:
                self.logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))
    # end for

    # Build model
    model = MultiSourceModelBuilder.build_model(
        self.config.get_src_types(), model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    self.logger.info('encoder: %d' % enc)
    self.logger.info('decoder: %d' % dec)
    self.logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = MultiSourceModelSaver.build_model_saver(
        self.config.get_src_types(), model_opt, opt, model, fields, optim)

    trainer = MultiSourceTrainer.build_trainer(
        self.config.get_src_types(), opt, device_id, model, fields, optim,
        model_saver=model_saver)

    if batch_queue is None:
        if len(opt.data_ids) > 1:
            train_shards = []
            for train_id in opt.data_ids:
                shard_base = "train_" + train_id
                train_shards.append(shard_base)
            # end for
            train_iter = MultiSourceInputter.build_dataset_iter_multiple(
                self.config.get_src_types(), train_shards, fields, opt)
        else:
            if opt.data_ids[0] is not None:
                shard_base = "train_" + opt.data_ids[0]
            else:
                shard_base = "train"
            # end if
            train_iter = MultiSourceInputter.build_dataset_iter(
                self.config.get_src_types(), shard_base, fields, opt)
        # end if
    else:
        assert semaphore is not None, "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch
            # end while
        # end def

        train_iter = _train_iter()
    # end if

    valid_iter = MultiSourceInputter.build_dataset_iter(
        self.config.get_src_types(), "valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        self.logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        self.logger.info('Starting training on CPU, could be very slow')
    # end if

    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        self.logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    # end if

    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    time_begin = trainer.report_manager.start_time
    time_end = time.time()

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()

    # Dump train metrics
    train_history = trainer.report_manager.get_joint_history()
    train_metrics = {
        "time_begin": time_begin,
        "time_end": time_end,
        "time": time_end - time_begin,
        "train_history": train_history,
    }
    IOUtils.dump(output_model_dir / "train-metrics.json", train_metrics,
                 IOUtils.Format.jsonNoSort)

    # Get the best step, depending on the lowest val_xent (cross entropy)
    best_loss = min([th["val_xent"] for th in train_history])
    best_step = [th["step"] for th in train_history
                 if th["val_xent"] == best_loss][-1]  # Take the last if multiple
    IOUtils.dump(output_model_dir / "best-step.json", best_step, IOUtils.Format.json)
    return
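
# --- Illustration (not part of the training code above) ---
# Toy example of the best-step selection at the end of train_single(): scan
# the (made-up) train_history for the lowest validation cross entropy and take
# the last step that reaches it when several tie.
toy_train_history = [
    {"step": 1000, "val_xent": 2.31},
    {"step": 2000, "val_xent": 1.98},
    {"step": 3000, "val_xent": 1.98},
    {"step": 4000, "val_xent": 2.05},
]

toy_best_loss = min(th["val_xent"] for th in toy_train_history)
toy_best_step = [th["step"] for th in toy_train_history
                 if th["val_xent"] == toy_best_loss][-1]
print(toy_best_step)  # -> 3000 (last of the two tied steps)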
def validate(opt, device_id=0):
    configure_process(opt, device_id)
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']

    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    # _check_save_model_path(opt)

    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)
    tgt_field = dict(fields)["tgt"].base_field
    valid_loss = onmt.utils.loss.build_loss_compute(model, tgt_field, opt, train=False)

    model.eval()
    with torch.no_grad():
        stats = onmt.utils.Statistics()
        for batch in valid_iter:
            src, src_lengths = batch.src if isinstance(batch.src, tuple) \
                else (batch.src, None)
            tgt = batch.tgt

            # F-prop through the model.
            outputs, attns = model(src, tgt, src_lengths)

            # Compute loss.
            _, batch_stats = valid_loss(batch, outputs, attns)

            # Update statistics.
            stats.update(batch_stats)

    print('n words: %d' % stats.n_words)
    print('Validation perplexity: %g' % stats.ppl())
    print('Validation accuracy: %g' % stats.accuracy())
    print('Validation avg attention entropy: %g' % stats.attn_entropy())
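
# --- Illustration (not part of the validation code above) ---
# Standalone sketch of what the reported numbers mean, using made-up counts.
# OpenNMT-py's Statistics accumulates the summed cross-entropy loss, the number
# of (non-padding) target words, and the number of correctly predicted words;
# perplexity and accuracy are derived from them roughly as below (the actual
# implementation also clamps the exponent to avoid overflow).
import math

total_loss = 5234.7      # summed per-token cross-entropy over the valid set (made up)
n_words = 2500           # non-padding target tokens (made up)
n_correct = 1430         # tokens whose argmax prediction matched the target (made up)

perplexity = math.exp(min(total_loss / n_words, 100))
accuracy = 100.0 * n_correct / n_words

print('Validation perplexity: %g' % perplexity)   # ~8.1
print('Validation accuracy: %g' % accuracy)       # 57.2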