Exemplo n.º 1
0
    def train(self, train_params: AttributeDict, loss_func, optimizer):
        """Train ``self.model`` for ``n_epochs`` epochs and save one final checkpoint.

        ``train_params`` is applied via ``update()`` on top of a copy of
        ``self.common_params``; every setting below is read from that merged
        ``params``. The source/target corpora are
        ``params.src_corpus_filename`` / ``params.tgt_corpus_filename`` under
        ``self.data_set_dir``. They are handed to ``self._prepare_data_loader``
        together with ``encoder_params.max_seq_len`` and
        ``decoder_params.max_seq_len``.

        After the loop, a single checkpoint dict is written with ``torch.save``
        to ``<model_save_directory>/<get_checkpoint_dir_path(epoch + 1)>/checkpoint.tar``.
        It holds the keys ``'epoch'``, ``'model_state_dict'``,
        ``'optimizer_state_dict'`` and ``'loss'``.

        Args:
            train_params: Training-specific settings merged over
                ``self.common_params``. Reads ``n_epochs``,
                ``model_save_directory``, ``encoder_params``,
                ``decoder_params`` and the corpus filenames from the merge.
            loss_func: Passed through unchanged to ``self._train_model``.
            optimizer: Passed to ``self._train_model``; its ``state_dict()``
                is stored in the checkpoint.

        Note:
            NOTE(review): if ``n_epochs`` is 0 no training happens, yet a
            checkpoint is still written with ``'epoch': 0`` and
            ``'loss': 0.0`` into the directory for epoch 1. Confirm this is
            intended.
        """
        # Overlay the train params on a copy of the common params
        # (update() is applied last, so train values override).
        params = AttributeDict(self.common_params.copy())
        params.update(train_params)
        self._set_mode(Estimator.Mode.TRAIN)

        encoder_params = params.encoder_params
        decoder_params = params.decoder_params

        src_corpus_file_path = os.path.join(self.data_set_dir,
                                            params.src_corpus_filename)
        tgt_corpus_file_path = os.path.join(self.data_set_dir,
                                            params.tgt_corpus_filename)

        data_loader = self._prepare_data_loader(src_corpus_file_path,
                                                tgt_corpus_file_path, params,
                                                encoder_params.max_seq_len,
                                                decoder_params.max_seq_len)

        # Pre-initialise so the checkpoint below is defined even when the
        # loop body never runs (n_epochs == 0).
        epoch = 0
        avg_loss = 0.
        for epoch in range(params.n_epochs):
            # The last argument is epoch + 1, i.e. epochs are counted from 1.
            avg_loss = self._train_model(data_loader, params, self.model,
                                         loss_func, optimizer, self.device,
                                         epoch + 1)

        # Read the save directory from the merged params (like every other
        # setting here) so a value provided via common_params also works.
        save_dir_path = os.path.join(params.model_save_directory,
                                     get_checkpoint_dir_path(epoch + 1))
        # exist_ok avoids the race of a separate os.path.exists() check.
        os.makedirs(save_dir_path, exist_ok=True)

        # Save a checkpoint for the last epoch only. 'epoch' stores the
        # zero-based loop index, whereas the directory name uses epoch + 1.
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss
            }, os.path.join(save_dir_path, 'checkpoint.tar'))
Exemplo n.º 2
0
    def eval(self, eval_params: AttributeDict, loss_func):
        """Evaluate the checkpointed model on a corpus, print and return metrics.

        ``eval_params`` is applied via ``update()`` on top of a copy of
        ``self.common_params``. The model weights are restored from
        ``checkpoint['model_state_dict']``, where the checkpoint comes from
        ``self._load_checkpoint(params)``. The evaluation corpora are
        ``params.src_corpus_filename`` / ``params.tgt_corpus_filename`` under
        ``self.data_set_dir``. They are handed to ``self._prepare_data_loader``
        with the encoder/decoder ``max_seq_len`` values.

        Args:
            eval_params: Evaluation-specific settings merged over
                ``self.common_params``.
            loss_func: Passed through unchanged to ``self._eval_model``.

        Returns:
            The ``(avg_loss, bleu_score)`` pair produced by
            ``self._eval_model``. The same values are also printed to stdout.
        """
        self._set_mode(Estimator.Mode.EVAL)
        # Overlay the eval params on a copy of the common params
        # (update() is applied last, so eval values override).
        params = AttributeDict(self.common_params.copy())
        params.update(eval_params)
        encoder_params = params.encoder_params
        decoder_params = params.decoder_params

        # Restore model weights; only the model state is used here.
        checkpoint = self._load_checkpoint(params)
        self.model.load_state_dict(checkpoint['model_state_dict'])

        src_corpus_file_path = os.path.join(self.data_set_dir,
                                            params.src_corpus_filename)
        tgt_corpus_file_path = os.path.join(self.data_set_dir,
                                            params.tgt_corpus_filename)

        data_loader = self._prepare_data_loader(src_corpus_file_path,
                                                tgt_corpus_file_path, params,
                                                encoder_params.max_seq_len,
                                                decoder_params.max_seq_len)
        avg_loss, bleu_score = self._eval_model(data_loader, params,
                                                self.model, loss_func,
                                                self.device, self.tgt_id2word)
        print(f'Avg loss: {avg_loss:05.3f}, BLEU score: {bleu_score}')
        # Also return the metrics so callers need not scrape stdout.
        return avg_loss, bleu_score