Пример #1
0
    def evaluate(self, model):
        """
        Evaluate a model on ``self.dataset`` and log its character error rate.

        An ``AudioDataLoader`` thread feeds a bounded queue built from
        ``self.dataset`` while ``self.decoder.search`` consumes it. The decoder's
        results are then written to ``../data/train_result/<DecoderClassName>.csv``.

        Args:
            model (torch.nn.Module): model to evaluate

        Returns:
            None. The CER returned by ``self.decoder.search`` is only logged.
        """
        logger.info('evaluate() start')

        # Bounded queue (maxsize = 2 * num_workers) so the loader cannot run
        # arbitrarily far ahead of the decoder.
        eval_queue = queue.Queue(self.num_workers << 1)
        # NOTE(review): the trailing 0 is passed positionally to AudioDataLoader;
        # presumably a worker id -- confirm against its signature.
        eval_loader = AudioDataLoader(self.dataset, eval_queue, self.batch_size, 0)
        eval_loader.start()

        cer = self.decoder.search(model, eval_queue, self.device, self.print_every)
        self.decoder.save_result('../data/train_result/%s.csv' % type(self.decoder).__name__)

        # Join the loader before reporting completion so the "completed" line
        # really marks the end of all evaluation work.
        eval_loader.join()

        # Lazy %-args: formatting is deferred to the logging framework.
        logger.info('Evaluate CER: %s', cer)
        logger.info('evaluate() completed')
Пример #2
0
    def train(self,
              model: nn.Module,
              batch_size: int,
              epoch_time_step: int,
              num_epochs: int,
              teacher_forcing_ratio: float = 0.99,
              resume: bool = False) -> nn.Module:
        """
        Run training for a given model.

        Each epoch does the following, in order:
        - shuffles every training set;
        - trains through a ``MultiDataLoader``;
        - saves a checkpoint;
        - decays the teacher forcing ratio;
        - validates;
        - writes the epoch's results to csv.

        A final checkpoint is saved after the last epoch.

        Args:
            model (torch.nn.Module): model to train
            batch_size (int): batch size for experiment
            epoch_time_step (int): number of time step for training
                (recomputed from the restored training sets when ``resume`` is True)
            num_epochs (int): number of epochs for training
            teacher_forcing_ratio (float): teacher forcing ratio (default 0.99)
            resume (bool, optional): resume training with the latest checkpoint, (default False)

        Returns:
            torch.nn.Module: the trained model (the restored one when ``resume`` is True)
        """
        start_epoch = 0

        if resume:
            checkpoint = Checkpoint()
            latest_checkpoint_path = checkpoint.get_latest_checkpoint()
            resume_checkpoint = checkpoint.load(latest_checkpoint_path)
            model = resume_checkpoint.model
            self.optimizer = resume_checkpoint.optimizer
            self.trainset_list = resume_checkpoint.trainset_list
            self.validset = resume_checkpoint.validset
            # Per-epoch checkpoints below are saved after that epoch's training
            # pass, so training continues with the next epoch.
            start_epoch = resume_checkpoint.epoch + 1

            total_trainset_len = sum(len(trainset) for trainset in self.trainset_list)
            epoch_time_step = math.ceil(total_trainset_len / batch_size)

        logger.info('start')
        train_begin_time = time.time()

        for epoch in range(start_epoch, num_epochs):
            logger.info('Epoch %d start', epoch)
            # Bounded queue (maxsize = 2 * num_workers) limits how far the loader
            # workers can run ahead of training.
            train_queue = queue.Queue(self.num_workers << 1)

            for trainset in self.trainset_list:
                trainset.shuffle()

            # Training
            train_loader = MultiDataLoader(self.trainset_list, train_queue,
                                           batch_size, self.num_workers)
            train_loader.start()

            train_loss, train_cer = self.__train_epoches(
                model, epoch, epoch_time_step, train_begin_time, train_queue,
                teacher_forcing_ratio)
            train_loader.join()

            Checkpoint(model, self.optimizer, self.trainset_list,
                       self.validset, epoch).save()
            logger.info('Epoch %d (Training) Loss %0.4f CER %0.4f',
                        epoch, train_loss, train_cer)

            # Linear decay of teacher forcing, clamped at the configured minimum.
            teacher_forcing_ratio = max(self.min_teacher_forcing_ratio,
                                        teacher_forcing_ratio - self.teacher_forcing_step)

            # Validation
            valid_queue = queue.Queue(self.num_workers << 1)
            valid_loader = AudioDataLoader(self.validset, valid_queue,
                                           batch_size, 0)
            valid_loader.start()

            valid_loss, valid_cer = self.validate(model, valid_queue)
            valid_loader.join()

            logger.info('Epoch %d (Validate) Loss %0.4f CER %0.4f',
                        epoch, valid_loss, valid_cer)
            self.__save_epoch_result(
                train_result=[self.train_dict, train_loss, train_cer],
                valid_result=[self.valid_dict, valid_loss, valid_cer])
            logger.info(
                'Epoch %d Training result saved as a csv file complete !!',
                epoch)
            torch.cuda.empty_cache()

        # Same argument layout as the per-epoch checkpoint (no criterion).
        # The resume path above restores only model, optimizer, trainset_list,
        # validset and epoch from a checkpoint.
        Checkpoint(model, self.optimizer, self.trainset_list,
                   self.validset, num_epochs).save()
        return model
Пример #3
0
    def train(self,
              model,
              batch_size,
              epoch_time_step,
              num_epochs,
              teacher_forcing_ratio=0.99,
              resume=False):
        """
        Run training for a given model.

        Each epoch does the following, in order:
        - shuffles every training set;
        - trains through a ``MultiDataLoader``;
        - saves a checkpoint;
        - switches to an exponential-decay LR scheduler when training CER
          improved by less than ``self.decay_threshold``;
        - decays the teacher forcing ratio;
        - validates;
        - writes the epoch's results to csv.

        Args:
            model (torch.nn.Module): model to train
            batch_size (int): batch size for experiment
            epoch_time_step (int): number of time step for training
                (recomputed from the restored training sets when ``resume`` is True)
            num_epochs (int): number of epochs for training
            teacher_forcing_ratio (float): teacher forcing ratio (default 0.99)
            resume (bool, optional): resume training with the latest checkpoint, (default False)

        Returns:
            torch.nn.Module: the trained model (the restored one when ``resume`` is True)
        """
        start_epoch = 0
        # Previous epoch's training CER, used to detect a plateau (initial value 1.0).
        prev_train_cer = 1.

        if resume:
            checkpoint = Checkpoint()
            latest_checkpoint_path = checkpoint.get_latest_checkpoint()
            resume_checkpoint = checkpoint.load(latest_checkpoint_path)
            model = resume_checkpoint.model
            self.optimizer = resume_checkpoint.optimizer
            self.criterion = resume_checkpoint.criterion
            self.trainset_list = resume_checkpoint.trainset_list
            self.validset = resume_checkpoint.validset
            # NOTE(review): the per-epoch checkpoint below is saved after the
            # epoch's training pass, yet training resumes at that same epoch.
            # This is correct only if train_epoches also writes mid-epoch
            # checkpoints -- confirm, otherwise this should be epoch + 1.
            start_epoch = resume_checkpoint.epoch
            total_trainset_len = sum(len(trainset) for trainset in self.trainset_list)
            epoch_time_step = math.ceil(total_trainset_len / batch_size)

        logger.info('start')
        train_begin_time = time.time()

        for epoch in range(start_epoch, num_epochs):
            # Bounded queue (maxsize = 2 * num_workers) limits how far the loader
            # workers can run ahead of training.
            train_queue = queue.Queue(self.num_workers << 1)
            for trainset in self.trainset_list:
                trainset.shuffle()

            # Training
            train_loader = MultiDataLoader(self.trainset_list, train_queue,
                                           batch_size, self.num_workers)
            train_loader.start()
            train_loss, train_cer = self.train_epoches(model, epoch,
                                                       epoch_time_step,
                                                       train_begin_time,
                                                       train_queue,
                                                       teacher_forcing_ratio)
            train_loader.join()

            Checkpoint(model, self.optimizer, self.criterion,
                       self.trainset_list, self.validset, epoch).save()
            logger.info('Epoch %d (Training) Loss %0.4f CER %0.4f',
                        epoch, train_loss, train_cer)

            # Plateau check: if training CER improved by less than the threshold,
            # install an ExponentialDecayLR starting from the current LR
            # (presumably decaying toward low_plateau_lr). This runs on every
            # epoch that fails the check, so the scheduler is re-created each time.
            if prev_train_cer - train_cer < self.decay_threshold:
                self.optimizer.set_scheduler(
                    ExponentialDecayLR(self.optimizer.optimizer,
                                       self.optimizer.get_lr(),
                                       self.low_plateau_lr,
                                       self.exp_decay_period),
                    self.exp_decay_period)

            prev_train_cer = train_cer
            # Linear decay of teacher forcing, clamped at the configured minimum.
            teacher_forcing_ratio = max(self.min_teacher_forcing_ratio,
                                        teacher_forcing_ratio - self.teacher_forcing_step)

            # Validation
            valid_queue = queue.Queue(self.num_workers << 1)
            valid_loader = AudioDataLoader(self.validset, valid_queue,
                                           batch_size, 0)
            valid_loader.start()

            # validate() yields only a CER here. This constant stands in for the
            # validation loss in both the log line and the saved csv row.
            placeholder_valid_loss = 1.0
            valid_cer = self.validate(model, valid_queue)
            valid_loader.join()

            logger.info('Epoch %d (Validate) Loss %0.4f CER %0.4f',
                        epoch, placeholder_valid_loss, valid_cer)
            self._save_epoch_result(
                train_result=[self.train_dict, train_loss, train_cer],
                valid_result=[self.valid_dict, placeholder_valid_loss, valid_cer])
            logger.info(
                'Epoch %d Training result saved as a csv file complete !!',
                epoch)

        return model