def train(self, train_params: AttributeDict, loss_func, optimizer):
    """Train the model for ``params.n_epochs`` epochs and save a final checkpoint.

    Args:
        train_params: Training-specific parameters; merged over (and
            overriding) ``self.common_params``.
        loss_func: Loss function forwarded to ``self._train_model``.
        optimizer: Optimizer driving training; its state is checkpointed
            alongside the model.
    """
    # Merge common and train params (train params take precedence).
    params = AttributeDict(self.common_params.copy())
    params.update(train_params)

    self._set_mode(Estimator.Mode.TRAIN)

    encoder_params = params.encoder_params
    decoder_params = params.decoder_params

    src_corpus_file_path = os.path.join(self.data_set_dir,
                                        params.src_corpus_filename)
    tgt_corpus_file_path = os.path.join(self.data_set_dir,
                                        params.tgt_corpus_filename)

    data_loader = self._prepare_data_loader(src_corpus_file_path,
                                            tgt_corpus_file_path,
                                            params,
                                            encoder_params.max_seq_len,
                                            decoder_params.max_seq_len)

    # Defaults guard against params.n_epochs == 0 (loop body never runs).
    epoch = 0
    avg_loss = 0.
    for epoch in range(params.n_epochs):
        avg_loss = self._train_model(data_loader, params, self.model,
                                     loss_func, optimizer, self.device,
                                     epoch + 1)

    # Use the merged params for the save directory, consistent with every
    # other lookup in this method (was: train_params.model_save_directory,
    # which breaks if the key is supplied only via common_params).
    save_dir_path = os.path.join(params.model_save_directory,
                                 get_checkpoint_dir_path(epoch + 1))
    # exist_ok=True avoids the check-then-create race of the previous
    # os.path.exists() guard and is a no-op if the directory exists.
    os.makedirs(save_dir_path, exist_ok=True)

    # Save a checkpoint for the last epoch only.
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss
        },
        os.path.join(save_dir_path, 'checkpoint.tar'))
def eval(self, eval_params: AttributeDict, loss_func):
    """Evaluate a checkpointed model on the evaluation corpus.

    Loads model weights from the checkpoint described by the merged
    params, runs evaluation, prints the metrics, and returns them.

    Args:
        eval_params: Evaluation-specific parameters; merged over (and
            overriding) ``self.common_params``.
        loss_func: Loss function forwarded to ``self._eval_model``.

    Returns:
        Tuple ``(avg_loss, bleu_score)`` as produced by
        ``self._eval_model`` (previously these values were only printed
        and discarded).
    """
    self._set_mode(Estimator.Mode.EVAL)

    # Merge common and eval params (eval params take precedence).
    params = AttributeDict(self.common_params.copy())
    params.update(eval_params)

    encoder_params = params.encoder_params
    decoder_params = params.decoder_params

    # Load checkpoint and restore model weights before evaluation.
    checkpoint = self._load_checkpoint(params)
    self.model.load_state_dict(checkpoint['model_state_dict'])

    src_corpus_file_path = os.path.join(self.data_set_dir,
                                        params.src_corpus_filename)
    tgt_corpus_file_path = os.path.join(self.data_set_dir,
                                        params.tgt_corpus_filename)

    data_loader = self._prepare_data_loader(src_corpus_file_path,
                                            tgt_corpus_file_path,
                                            params,
                                            encoder_params.max_seq_len,
                                            decoder_params.max_seq_len)

    avg_loss, bleu_score = self._eval_model(data_loader, params, self.model,
                                            loss_func, self.device,
                                            self.tgt_id2word)
    print(f'Avg loss: {avg_loss:05.3f}, BLEU score: {bleu_score}')
    # Return the metrics so callers can consume them programmatically
    # (backward-compatible: the method previously returned None).
    return avg_loss, bleu_score