def train(
    self,
    model,
    crit,
    optimizer,
    train_loader,
    valid_loader,
    src_vocab,
    tgt_vocab,
    n_epochs,
    lr_scheduler=None,
):
    """Train ``model`` for ``n_epochs`` epochs, validating after each epoch.

    Exactly one of ``src_vocab`` / ``tgt_vocab`` must be ``None``. The
    non-``None`` one designates the target language. The choice is stored
    as ``is_src_target`` on both engines (``True`` when ``tgt_vocab`` is
    ``None``) so that ``self.step`` / ``self.validate`` can read it.

    Args:
        model: model to train; attached to both the training and the
            validation engine.
        crit: loss criterion attached to both engines.
        optimizer: optimizer attached to the training engine.
        train_loader: iterable of training batches.
        valid_loader: iterable of validation batches, run once per epoch.
        src_vocab: source-side vocabulary, or ``None``.
        tgt_vocab: target-side vocabulary, or ``None``.
        n_epochs: number of training epochs.
        lr_scheduler: optional scheduler whose ``step()`` is called once
            after every validation pass.

    Returns:
        ``model``. When ``n_epochs > 0``, its weights are first replaced by
        ``evaluator.best_model``.

    Raises:
        NotImplementedError: if both vocabularies are ``None`` or both are
            given.
    """
    # Reject the both-None case first; otherwise it would fall into the
    # "src_vocab is None" case and silently pick is_src_target=False.
    if src_vocab is None and tgt_vocab is None:
        raise NotImplementedError('You cannot assign None both vocab.')
    if src_vocab is not None and tgt_vocab is not None:
        raise NotImplementedError('You should assign None one of vocab to designate target language.')
    # Exactly one vocab is None here; a missing tgt_vocab means the source
    # side is the target.
    is_src_target = tgt_vocab is None

    trainer = Engine(self.step)
    trainer.config = self.config
    trainer.model, trainer.crit = model, crit
    trainer.optimizer, trainer.lr_scheduler = optimizer, lr_scheduler
    trainer.epoch_idx = 0
    trainer.is_src_target = is_src_target

    evaluator = Engine(self.validate)
    evaluator.config = self.config
    evaluator.model, evaluator.crit = model, crit
    evaluator.best_loss = np.inf
    evaluator.is_src_target = is_src_target

    self.attach(trainer, evaluator, verbose=self.config.verbose)

    def run_validation(engine, evaluator, valid_loader):
        # One validation pass per training epoch, then advance the LR schedule.
        evaluator.run(valid_loader, max_epochs=1)
        if engine.lr_scheduler is not None:
            engine.lr_scheduler.step()

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED, run_validation, evaluator, valid_loader
    )
    # check_best is registered before save_model; registration order is kept.
    evaluator.add_event_handler(
        Events.EPOCH_COMPLETED, self.check_best
    )
    evaluator.add_event_handler(
        Events.EPOCH_COMPLETED,
        self.save_model,
        trainer,
        self.config,
        src_vocab,
        tgt_vocab,
    )

    trainer.run(train_loader, max_epochs=n_epochs)

    if n_epochs > 0:
        # NOTE(review): assumes self.check_best stored a state dict in
        # evaluator.best_model during training — confirm.
        model.load_state_dict(evaluator.best_model)

    return model
def train(self, models, language_models, crits, optimizers, train_loader, valid_loader, vocabs, n_epochs, lr_schedulers=None):
    """Train several models jointly, validating them after every epoch.

    Args:
        models: models to train; attached to both engines as ``models``.
        language_models: attached to the training engine as
            ``language_models`` for ``self.step`` to use.
        crits: loss criteria attached to both engines.
        optimizers: optimizers attached to the training engine.
        train_loader: iterable of training batches.
        valid_loader: iterable of validation batches, run once per epoch.
        vocabs: forwarded to ``self.save_model`` on every validation epoch.
        n_epochs: number of training epochs.
        lr_schedulers: optional iterable of schedulers; each one's ``step()``
            is called once after every validation pass.

    Returns:
        ``models``, the same object that was passed in.
    """
    # Training engine: every piece of state self.step reads hangs off it.
    trainer = Engine(self.step)
    trainer.config = self.config
    trainer.models = models
    trainer.crits = crits
    trainer.optimizers = optimizers
    trainer.lr_schedulers = lr_schedulers
    trainer.language_models = language_models
    trainer.epoch_idx = 0

    # Validation engine, plus one best-loss slot per direction, both
    # starting at +inf.
    evaluator = Engine(self.validate)
    evaluator.config = self.config
    evaluator.models = models
    evaluator.crits = crits
    evaluator.best_x2y = np.inf
    evaluator.best_y2x = np.inf

    self.attach(trainer, evaluator, verbose=self.config.verbose)

    def _validate_then_step_schedulers(train_engine, eval_engine, loader):
        eval_engine.run(loader, max_epochs=1)
        schedulers = train_engine.lr_schedulers
        if schedulers is None:
            return
        for scheduler in schedulers:
            scheduler.step()

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        _validate_then_step_schedulers,
        evaluator,
        valid_loader,
    )
    # check_best is attached before save_model; this order is kept.
    evaluator.add_event_handler(Events.EPOCH_COMPLETED, self.check_best)
    evaluator.add_event_handler(
        Events.EPOCH_COMPLETED,
        self.save_model,
        trainer,
        self.config,
        vocabs,
    )

    trainer.run(train_loader, max_epochs=n_epochs)

    return models
def train(
    self,
    model,
    crit,
    optimizer,
    train_loader,
    valid_loader,
    src_vocab,
    tgt_vocab,
    n_epochs,
    lr_scheduler=None,
):
    """Train ``model`` for ``n_epochs`` epochs with per-epoch validation.

    Args:
        model: model to train; attached to both engines.
        crit: loss criterion attached to both engines.
        optimizer: optimizer attached to the training engine.
        train_loader: iterable of training batches.
        valid_loader: iterable of validation batches, run once per epoch.
        src_vocab: forwarded to ``self.save_model``.
        tgt_vocab: forwarded to ``self.save_model``.
        n_epochs: number of training epochs.
        lr_scheduler: optional scheduler whose ``step()`` is called once
            after every validation pass.

    Returns:
        ``model``, the same object that was passed in (weights as left by
        the final training step).
    """
    # Training engine and the state self.step reads from it.
    trainer = Engine(self.step)
    trainer.config = self.config
    trainer.model = model
    trainer.crit = crit
    trainer.optimizer = optimizer
    trainer.lr_scheduler = lr_scheduler
    trainer.epoch_idx = 0

    # Validation engine; the best loss starts at +inf.
    evaluator = Engine(self.validate)
    evaluator.config = self.config
    evaluator.model = model
    evaluator.crit = crit
    evaluator.best_loss = np.inf

    self.attach(trainer, evaluator, verbose=self.config.verbose)

    def _after_train_epoch(train_engine, eval_engine, loader):
        eval_engine.run(loader, max_epochs=1)
        scheduler = train_engine.lr_scheduler
        if scheduler is not None:
            scheduler.step()

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        _after_train_epoch,
        evaluator,
        valid_loader,
    )
    # check_best is attached before save_model; this order is kept.
    evaluator.add_event_handler(Events.EPOCH_COMPLETED, self.check_best)
    evaluator.add_event_handler(
        Events.EPOCH_COMPLETED,
        self.save_model,
        trainer,
        self.config,
        src_vocab,
        tgt_vocab,
    )

    trainer.run(train_loader, max_epochs=n_epochs)

    return model