def train(self, model, crit, train_loader, valid_loader):
    """Train ``model`` with Adam and return the best validation weights.

    A fresh ``optim.Adam`` optimizer is built over ``model.parameters()``.
    Training runs for ``self.config.n_epochs`` epochs. After every training
    epoch the evaluator makes a single pass over ``valid_loader``. Whenever
    the validation loss is less than or equal to the lowest loss seen so far,
    a deep copy of ``model.state_dict()`` is kept.

    Args:
        model: Model whose parameters the Adam optimizer is built over.
        crit: Loss function attached to both engines for
            ``Trainer.step`` / ``Trainer.validate`` to use.
        train_loader: Iterable of training mini-batches.
        valid_loader: Iterable of validation mini-batches.

    Returns:
        A deep copy of the state dict with the lowest validation loss.
        If no validation epoch produced a comparable loss (``n_epochs`` is 0,
        or every loss was NaN), returns a deep copy of the model's current
        state dict instead of failing.
    """
    from copy import deepcopy

    optimizer = optim.Adam(model.parameters())

    trainer = Engine(Trainer.step)
    trainer.model, trainer.crit, trainer.optimizer = model, crit, optimizer

    evaluator = Engine(Trainer.validate)
    evaluator.model, evaluator.crit = model, crit
    evaluator.lowest_loss = np.inf
    # Always define the attribute. It used to be created only inside
    # check_loss, so the final return raised AttributeError when no
    # validation epoch ever recorded a best model.
    evaluator.best_model = None

    Trainer.attach(trainer, evaluator, verbose=self.config.verbose)

    def run_validation(engine, evaluator, valid_loader):
        # One full pass over the validation data per training epoch.
        evaluator.run(valid_loader, max_epochs=1)

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED, run_validation, evaluator, valid_loader
    )

    @evaluator.on(Events.EPOCH_COMPLETED)
    def check_loss(engine):
        loss = float(engine.state.metrics['loss'])
        # A NaN loss never satisfies <=, so a diverged epoch is never kept.
        if loss <= engine.lowest_loss:
            engine.lowest_loss = loss
            # Snapshot the weights; the live state dict keeps changing.
            engine.best_model = deepcopy(engine.model.state_dict())

    trainer.run(train_loader, max_epochs=self.config.n_epochs)

    if evaluator.best_model is None:
        return deepcopy(model.state_dict())
    return evaluator.best_model
def train(
    self,
    model, crit, optimizer,
    train_loader, valid_loader,
    src_vocab, tgt_vocab,
    n_epochs,
    lr_scheduler=None,
):
    """Train ``model`` for ``n_epochs``, validating after every epoch.

    Exactly one of ``src_vocab`` / ``tgt_vocab`` must be ``None``. The choice
    is recorded as ``is_src_target`` on both engines:

    * True when only ``src_vocab`` is given.
    * False when only ``tgt_vocab`` is given.

    Presumably ``self.step`` / ``self.validate`` read this flag; verify
    against their implementations.

    After each training epoch the evaluator makes one pass over
    ``valid_loader``. Then ``lr_scheduler.step()`` is called if a scheduler
    was supplied. Evaluator epoch-completion runs ``self.check_best`` and
    then ``self.save_model``.

    Args:
        model: Model attached to both engines.
        crit: Loss function attached to both engines.
        optimizer: Optimizer attached to the training engine.
        train_loader: Iterable of training mini-batches.
        valid_loader: Iterable of validation mini-batches.
        src_vocab: Source-side vocabulary, or None.
        tgt_vocab: Target-side vocabulary, or None.
        n_epochs: Number of training epochs.
        lr_scheduler: Optional scheduler stepped once per epoch.

    Returns:
        ``model``. When ``n_epochs > 0``, it is loaded with
        ``evaluator.best_model``. That attribute is assumed to be populated by
        ``self.check_best`` (not visible here; TODO confirm).

    Raises:
        NotImplementedError: if both vocabs are given, or both are None.
    """
    if src_vocab is not None and tgt_vocab is not None:
        raise NotImplementedError('You should assign None one of vocab to designate target language.')
    # This case used to be unreachable. Both-None fell through to
    # `is_src_target = False` instead of raising.
    if src_vocab is None and tgt_vocab is None:
        raise NotImplementedError('You cannot assign None both vocab.')
    # Exactly one vocab is None here. A missing tgt_vocab means the source
    # side is the target.
    is_src_target = tgt_vocab is None

    trainer = Engine(self.step)
    trainer.config = self.config
    trainer.model, trainer.crit = model, crit
    trainer.optimizer, trainer.lr_scheduler = optimizer, lr_scheduler
    trainer.epoch_idx = 0
    trainer.is_src_target = is_src_target

    evaluator = Engine(self.validate)
    evaluator.config = self.config
    evaluator.model, evaluator.crit = model, crit
    evaluator.best_loss = np.inf
    evaluator.is_src_target = is_src_target

    self.attach(trainer, evaluator, verbose=self.config.verbose)

    def run_validation(engine, evaluator, valid_loader):
        evaluator.run(valid_loader, max_epochs=1)

        # The scheduler advances once per epoch, after validation.
        if engine.lr_scheduler is not None:
            engine.lr_scheduler.step()

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED, run_validation, evaluator, valid_loader
    )
    # Registration order matters: check_best runs before save_model.
    evaluator.add_event_handler(
        Events.EPOCH_COMPLETED, self.check_best
    )
    evaluator.add_event_handler(
        Events.EPOCH_COMPLETED,
        self.save_model,
        trainer,
        self.config,
        src_vocab,
        tgt_vocab,
    )

    trainer.run(train_loader, max_epochs=n_epochs)

    # With zero epochs no validation ran, so there is no best model to load.
    if n_epochs > 0:
        model.load_state_dict(evaluator.best_model)

    return model
def train(
    self,
    model, crit, optimizer,
    train_loader, valid_loader,
    src_vocab, tgt_vocab,
    n_epochs,
    lr_scheduler=None,
):
    """Run ``n_epochs`` of training with a validation pass after each epoch.

    Two ignite engines are built:

    * a training engine driven by ``self.step``, carrying the model, loss,
      optimizer and optional scheduler;
    * a validation engine driven by ``self.validate``, carrying the model and
      loss, with ``best_loss`` starting at infinity.

    Both are wired together through ``self.attach``. Whenever a training
    epoch completes, the validation engine runs once over ``valid_loader``,
    and then the scheduler (if any) is stepped. Whenever a validation run
    completes, ``self.check_best`` runs, followed by ``self.save_model``
    (given the training engine, config and both vocabularies).

    Returns:
        ``model`` itself, holding whatever weights training left it with.
        It is not reloaded from any best checkpoint here.
    """
    cfg = self.config

    # Training engine and the state its process function reads.
    train_engine = Engine(self.step)
    train_engine.config = cfg
    train_engine.model = model
    train_engine.crit = crit
    train_engine.optimizer = optimizer
    train_engine.lr_scheduler = lr_scheduler
    train_engine.epoch_idx = 0

    # Validation engine; best_loss begins unbounded so any loss can win.
    valid_engine = Engine(self.validate)
    valid_engine.config = cfg
    valid_engine.model = model
    valid_engine.crit = crit
    valid_engine.best_loss = np.inf

    self.attach(train_engine, valid_engine, verbose=cfg.verbose)

    def _validate_then_schedule(engine):
        # Closure over valid_engine / valid_loader instead of handler args.
        valid_engine.run(valid_loader, max_epochs=1)
        scheduler = engine.lr_scheduler
        if scheduler is not None:
            scheduler.step()

    train_engine.add_event_handler(Events.EPOCH_COMPLETED, _validate_then_schedule)

    # check_best is registered first so it fires before save_model.
    valid_engine.add_event_handler(Events.EPOCH_COMPLETED, self.check_best)
    valid_engine.add_event_handler(
        Events.EPOCH_COMPLETED,
        self.save_model,
        train_engine, cfg, src_vocab, tgt_vocab,
    )

    train_engine.run(train_loader, max_epochs=n_epochs)
    return model