Example #1
def train(opt):
    random.seed(opt.seed)
    torch.manual_seed(opt.seed)
    torch.cuda.manual_seed_all(opt.seed)
    device = check_envirionment(opt.use_cuda)

    audio_paths, script_paths = load_data_list(opt.data_list_path,
                                               opt.dataset_path)

    epoch_time_step, trainset_list, validset = split_dataset(
        opt, audio_paths, script_paths)
    model = build_ensemble(['model_path1', 'model_path2', 'model_path3'],
                           opt.ensemble_method, device)

    optimizer = optim.Adam(model.module.parameters(), lr=opt.init_lr)
    optimizer = Optimizer(optimizer, None, 0, opt.max_grad_norm)
    criterion = nn.NLLLoss(reduction='sum', ignore_index=PAD_token).to(device)

    trainer = SupervisedTrainer(
        optimizer=optimizer,
        criterion=criterion,
        trainset_list=trainset_list,
        validset=validset,
        num_workers=opt.num_workers,
        high_plateau_lr=opt.high_plateau_lr,
        low_plateau_lr=opt.low_plateau_lr,
        decay_threshold=opt.decay_threshold,
        exp_decay_period=opt.exp_decay_period,
        device=device,
        teacher_forcing_step=opt.teacher_forcing_step,
        min_teacher_forcing_ratio=opt.min_teacher_forcing_ratio,
        print_every=opt.print_every,
        save_result_every=opt.save_result_every,
        checkpoint_every=opt.checkpoint_every)
    model = trainer.train(model=model,
                          batch_size=opt.batch_size,
                          epoch_time_step=epoch_time_step,
                          num_epochs=opt.num_epochs,
                          teacher_forcing_ratio=opt.teacher_forcing_ratio,
                          resume=opt.resume)
    Checkpoint(model, model.optimizer, model.criterion, model.trainset_list,
               model.validset, opt.num_epochs).save()
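
The examples above rely on an Optimizer wrapper that is constructed as
Optimizer(optimizer, scheduler, scheduler_period, max_grad_norm) and later
drives zero_grad(), step(model) and get_lr(). Below is a minimal sketch of
such a wrapper, inferred purely from those call sites; the repository's
real class may differ in detail.

import torch


class Optimizer(object):
    """Sketch: clip gradients before each step and drive an optional LR
    scheduler for a fixed number of steps (behaviour is an assumption)."""

    def __init__(self, optimizer, scheduler=None, scheduler_period=0,
                 max_grad_norm=0):
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.scheduler_period = scheduler_period
        self.max_grad_norm = max_grad_norm
        self.count = 0

    def step(self, model, loss=None):
        # `loss` is accepted because Example #4 calls step(model, loss.item()).
        if self.max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           self.max_grad_norm)
        self.optimizer.step()

        if self.scheduler is not None:
            self.scheduler.step()
            self.count += 1
            # Drop the scheduler once its period has elapsed.
            if self.count == self.scheduler_period:
                self.scheduler, self.scheduler_period, self.count = None, 0, 0

    def set_scheduler(self, scheduler, scheduler_period):
        self.scheduler = scheduler
        self.scheduler_period = scheduler_period
        self.count = 0

    def zero_grad(self):
        self.optimizer.zero_grad()

    def get_lr(self):
        # All parameter groups share a single learning rate here.
        return self.optimizer.param_groups[0]['lr']
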
Example #2
    def train(self,
              model: nn.Module,
              batch_size: int,
              epoch_time_step: int,
              num_epochs: int,
              teacher_forcing_ratio: float = 0.99,
              resume: bool = False) -> nn.Module:
        """
        Run training for a given model.

        Args:
            model (torch.nn.Module): model to train
            batch_size (int): batch size for experiment
            epoch_time_step (int): number of time steps in one training epoch
            num_epochs (int): number of epochs for training
            teacher_forcing_ratio (float): teacher forcing ratio (default: 0.99)
            resume (bool, optional): resume training from the latest checkpoint (default: False)
        """
        start_epoch = 0

        if resume:
            checkpoint = Checkpoint()
            latest_checkpoint_path = checkpoint.get_latest_checkpoint()
            resume_checkpoint = checkpoint.load(latest_checkpoint_path)
            model = resume_checkpoint.model
            self.optimizer = resume_checkpoint.optimizer
            self.trainset_list = resume_checkpoint.trainset_list
            self.validset = resume_checkpoint.validset
            start_epoch = resume_checkpoint.epoch + 1
            epoch_time_step = 0

            for trainset in self.trainset_list:
                epoch_time_step += len(trainset)

            epoch_time_step = math.ceil(epoch_time_step / batch_size)

        logger.info('start')
        train_begin_time = time.time()

        for epoch in range(start_epoch, num_epochs):
            logger.info('Epoch %d start' % epoch)
            train_queue = queue.Queue(self.num_workers << 1)

            for trainset in self.trainset_list:
                trainset.shuffle()

            # Training
            train_loader = MultiDataLoader(self.trainset_list, train_queue,
                                           batch_size, self.num_workers)
            train_loader.start()

            train_loss, train_cer = self.__train_epoches(
                model, epoch, epoch_time_step, train_begin_time, train_queue,
                teacher_forcing_ratio)
            train_loader.join()

            Checkpoint(model, self.optimizer, self.trainset_list,
                       self.validset, epoch).save()
            logger.info('Epoch %d (Training) Loss %0.4f CER %0.4f' %
                        (epoch, train_loss, train_cer))

            teacher_forcing_ratio -= self.teacher_forcing_step
            teacher_forcing_ratio = max(self.min_teacher_forcing_ratio,
                                        teacher_forcing_ratio)

            # Validation
            valid_queue = queue.Queue(self.num_workers << 1)
            valid_loader = AudioDataLoader(self.validset, valid_queue,
                                           batch_size, 0)
            valid_loader.start()

            valid_loss, valid_cer = self.validate(model, valid_queue)
            valid_loader.join()

            logger.info('Epoch %d (Validate) Loss %0.4f CER %0.4f' %
                        (epoch, valid_loss, valid_cer))
            self.__save_epoch_result(
                train_result=[self.train_dict, train_loss, train_cer],
                valid_result=[self.valid_dict, valid_loss, valid_cer])
            logger.info('Epoch %d training result saved to a CSV file' %
                        epoch)
            torch.cuda.empty_cache()

        Checkpoint(model, self.optimizer, self.trainset_list, self.validset,
                   num_epochs).save()
        return model
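
The resume path above expects a Checkpoint helper with save(), load() and
get_latest_checkpoint(). The following is a minimal sketch consistent with
the six-field form used in Example #4 (this example's own save() calls omit
criterion); the directory layout and file name are assumptions, not the
repository's actual ones.

import os
import time

import torch


class Checkpoint(object):
    """Sketch of a checkpoint helper: pickle the trainer state into a
    timestamped directory and reload the newest one on resume."""

    SAVE_PATH = 'checkpoints'        # hypothetical root directory
    FILENAME = 'trainer_states.pt'   # hypothetical file name

    def __init__(self, model=None, optimizer=None, criterion=None,
                 trainset_list=None, validset=None, epoch=None):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.trainset_list = trainset_list
        self.validset = validset
        self.epoch = epoch

    def save(self):
        date_time = time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime())
        path = os.path.join(self.SAVE_PATH, date_time)
        os.makedirs(path, exist_ok=True)
        torch.save({'model': self.model, 'optimizer': self.optimizer,
                    'criterion': self.criterion,
                    'trainset_list': self.trainset_list,
                    'validset': self.validset, 'epoch': self.epoch},
                   os.path.join(path, self.FILENAME))

    def load(self, path):
        state = torch.load(os.path.join(path, self.FILENAME),
                           map_location='cpu')
        return Checkpoint(state['model'], state['optimizer'],
                          state['criterion'], state['trainset_list'],
                          state['validset'], state['epoch'])

    def get_latest_checkpoint(self):
        # Timestamped names sort lexicographically, so max() is the newest.
        return os.path.join(self.SAVE_PATH, max(os.listdir(self.SAVE_PATH)))
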
Example #3
    def __train_epoches(self, model: nn.Module, epoch: int,
                        epoch_time_step: int, train_begin_time: float,
                        queue: queue.Queue,
                        teacher_forcing_ratio: float) -> Tuple[float, float]:
        """
        Run one training epoch.

        Args:
            model (torch.nn.Module): model to train
            epoch (int): current epoch number
            epoch_time_step (int): total number of time steps in one epoch
            train_begin_time (float): timestamp at which training began
            queue (queue.Queue): training queue, containing inputs, targets, input_lengths, target_lengths
            teacher_forcing_ratio (float): teacher forcing ratio (default: 0.99)

        Returns: loss, cer
            - **loss** (float): loss of current epoch
            - **cer** (float): character error rate of current epoch
        """
        cer = 1.0
        epoch_loss_total = 0.
        total_num = 0
        timestep = 0

        model.train()

        begin_time = epoch_begin_time = time.time()
        num_workers = self.num_workers

        while True:
            inputs, targets, input_lengths, target_lengths = queue.get()

            if inputs.shape[0] == 0:
                # An empty batch signals that one loader has finished
                num_workers -= 1
                logger.debug('left train_loader: %d' % num_workers)

                if num_workers == 0:
                    break
                else:
                    continue

            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            model = model.to(self.device)

            if self.architecture == 'las':
                if isinstance(model, nn.DataParallel):
                    model.module.flatten_parameters()
                else:
                    model.flatten_parameters()

                logit = model(inputs=inputs,
                              input_lengths=input_lengths,
                              targets=targets,
                              teacher_forcing_ratio=teacher_forcing_ratio)
                logit = torch.stack(logit, dim=1).to(self.device)
                targets = targets[:, 1:]

            elif self.architecture == 'transformer':
                logit = model(inputs,
                              input_lengths,
                              targets,
                              return_attns=False)

            else:
                raise ValueError("Unsupported architecture : {0}".format(
                    self.architecture))

            hypothesis = logit.max(-1)[1]
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  targets.contiguous().view(-1))
            epoch_loss_total += loss.item()

            cer = self.metric(targets, hypothesis)
            total_num += int(input_lengths.sum())

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step(model)

            timestep += 1
            torch.cuda.empty_cache()

            if timestep % self.print_every == 0:
                current_time = time.time()
                elapsed = current_time - begin_time
                epoch_elapsed = (current_time - epoch_begin_time) / 60.0
                train_elapsed = (current_time - train_begin_time) / 3600.0

                logger.info(
                    'timestep: {:4d}/{:4d}, loss: {:.4f}, cer: {:.2f}, elapsed: {:.2f}s {:.2f}m {:.2f}h, lr: {:.5f}'
                    .format(timestep, epoch_time_step,
                            epoch_loss_total / total_num, cer, elapsed,
                            epoch_elapsed, train_elapsed,
                            self.optimizer.get_lr()))
                begin_time = time.time()

            if timestep % self.save_result_every == 0:
                self.__save_step_result(self.train_step_result,
                                        epoch_loss_total / total_num, cer)

            if timestep % self.checkpoint_every == 0:
                Checkpoint(model, self.optimizer, self.trainset_list,
                           self.validset, epoch).save()

            del inputs, input_lengths, targets, logit, loss, hypothesis

        Checkpoint(model, self.optimizer, self.trainset_list, self.validset,
                   epoch).save()

        logger.info('train() completed')
        return epoch_loss_total / total_num, cer
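
The while-loop above drains a shared queue that loader threads fill in the
background; each loader enqueues a zero-length batch when its dataset is
exhausted, which is what the inputs.shape[0] == 0 branch counts down. A
hedged sketch of one such producer thread follows (QueueFeedingLoader is a
hypothetical name, and collation is simplified: the real MultiDataLoader /
AudioDataLoader pad variable-length batches).

import threading

import torch


class QueueFeedingLoader(threading.Thread):
    """Sketch of a producer thread: push (inputs, targets, input_lengths,
    target_lengths) batches into a queue.Queue, then a zero-length
    sentinel batch so the consumer knows this worker has finished."""

    def __init__(self, dataset, q, batch_size):
        super().__init__()
        self.dataset = dataset   # assumed to yield (input, target) tensor pairs
        self.queue = q
        self.batch_size = batch_size

    def run(self):
        for start in range(0, len(self.dataset), self.batch_size):
            batch = [self.dataset[i] for i in
                     range(start, min(start + self.batch_size,
                                      len(self.dataset)))]
            # Real code pads to a common length; equal shapes are assumed here.
            inputs = torch.stack([b[0] for b in batch])
            targets = torch.stack([b[1] for b in batch])
            input_lengths = torch.IntTensor([b[0].size(0) for b in batch])
            target_lengths = torch.IntTensor([b[1].size(0) for b in batch])
            self.queue.put((inputs, targets, input_lengths, target_lengths))

        # Zero-length sentinel: inputs.shape[0] == 0 closes one loader.
        empty = torch.zeros(0)
        self.queue.put((empty, empty, empty, empty))
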
Example #4
def train(opt):
    random.seed(opt.seed)
    torch.manual_seed(opt.seed)
    torch.cuda.manual_seed_all(opt.seed)
    device = check_envirionment(opt.use_cuda)

    if not opt.resume:
        audio_paths, script_paths = load_data_list(opt.data_list_path,
                                                   opt.dataset_path)

        epoch_time_step, trainset_list, validset = split_dataset(
            opt, audio_paths, script_paths)
        model = build_model(opt, device)

        optimizer = optim.Adam(model.module.parameters(),
                               lr=opt.init_lr,
                               weight_decay=1e-05)

        if opt.rampup_period > 0:
            scheduler = RampUpLR(optimizer, opt.init_lr, opt.high_plateau_lr,
                                 opt.rampup_period)
            optimizer = Optimizer(optimizer, scheduler, opt.rampup_period,
                                  opt.max_grad_norm)
        else:
            optimizer = Optimizer(optimizer, None, 0, opt.max_grad_norm)

        if opt.label_smoothing == 0.0:
            criterion = nn.NLLLoss(reduction='sum',
                                   ignore_index=PAD_token).to(device)
        else:
            criterion = LabelSmoothingLoss(len(char2id),
                                           PAD_token,
                                           opt.label_smoothing,
                                           dim=-1).to(device)

    else:
        trainset_list = None
        validset = None
        model = None
        optimizer = None
        criterion = None
        epoch_time_step = None

    trainer = SupervisedTrainer(
        optimizer=optimizer,
        criterion=criterion,
        trainset_list=trainset_list,
        validset=validset,
        num_workers=opt.num_workers,
        high_plateau_lr=opt.high_plateau_lr,
        low_plateau_lr=opt.low_plateau_lr,
        decay_threshold=opt.decay_threshold,
        exp_decay_period=opt.exp_decay_period,
        device=device,
        teacher_forcing_step=opt.teacher_forcing_step,
        min_teacher_forcing_ratio=opt.min_teacher_forcing_ratio,
        print_every=opt.print_every,
        save_result_every=opt.save_result_every,
        checkpoint_every=opt.checkpoint_every)
    model = trainer.train(model=model,
                          batch_size=opt.batch_size,
                          epoch_time_step=epoch_time_step,
                          num_epochs=opt.num_epochs,
                          teacher_forcing_ratio=opt.teacher_forcing_ratio,
                          resume=opt.resume)
    Checkpoint(model, model.optimizer, model.criterion, model.trainset_list,
               model.validset, opt.num_epochs).save()
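
When opt.label_smoothing is non-zero, the code above swaps NLLLoss for
LabelSmoothingLoss(num_classes, ignore_index, smoothing, dim=-1). A
plausible sketch of such a criterion, operating on log-probabilities with
sum reduction just like the NLLLoss branch (the repository's own
implementation may differ):

import torch
import torch.nn as nn


class LabelSmoothingLoss(nn.Module):
    """Sketch: spread `smoothing` probability mass uniformly over the
    non-target classes and ignore padding positions entirely."""

    def __init__(self, num_classes, ignore_index, smoothing=0.1, dim=-1):
        super(LabelSmoothingLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.dim = dim

    def forward(self, logit, target):
        # logit: (N, num_classes) log-probabilities, target: (N,) class ids.
        with torch.no_grad():
            label_smoothed = torch.zeros_like(logit)
            label_smoothed.fill_(self.smoothing / (self.num_classes - 1))
            label_smoothed.scatter_(self.dim, target.unsqueeze(self.dim),
                                    self.confidence)
            label_smoothed[target == self.ignore_index, :] = 0  # mask padding
        return -torch.sum(label_smoothed * logit)
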
    def train(self,
              model,
              batch_size,
              epoch_time_step,
              num_epochs,
              teacher_forcing_ratio=0.99,
              resume=False):
        """
        Run training for a given model.

        Args:
            model (torch.nn.Module): model to train
            batch_size (int): batch size for experiment
            epoch_time_step (int): number of time steps in one training epoch
            num_epochs (int): number of epochs for training
            teacher_forcing_ratio (float): teacher forcing ratio (default: 0.99)
            resume (bool, optional): resume training from the latest checkpoint (default: False)
        """
        start_epoch = 0
        prev_train_cer = 1.

        if resume:
            checkpoint = Checkpoint()
            latest_checkpoint_path = checkpoint.get_latest_checkpoint()
            resume_checkpoint = checkpoint.load(latest_checkpoint_path)
            model = resume_checkpoint.model
            self.optimizer = resume_checkpoint.optimizer
            self.criterion = resume_checkpoint.criterion
            self.trainset_list = resume_checkpoint.trainset_list
            self.validset = resume_checkpoint.validset
            start_epoch = resume_checkpoint.epoch + 1
            epoch_time_step = 0
            for trainset in self.trainset_list:
                epoch_time_step += len(trainset)

            epoch_time_step = math.ceil(epoch_time_step / batch_size)

            for g in self.optimizer.optimizer.param_groups:
                g['lr'] = 1e-04

            print("Learning rate : %f", self.optimizer.get_lr())

        logger.info('start')
        train_begin_time = time.time()

        for epoch in range(start_epoch, num_epochs):
            train_queue = queue.Queue(self.num_workers << 1)
            for trainset in self.trainset_list:
                trainset.shuffle()

            # Training
            train_loader = MultiAudioLoader(self.trainset_list, train_queue,
                                            batch_size, self.num_workers)
            train_loader.start()
            train_loss, train_cer = self.train_epoches(model, epoch,
                                                       epoch_time_step,
                                                       train_begin_time,
                                                       train_queue,
                                                       teacher_forcing_ratio)
            train_loader.join()

            Checkpoint(model, self.optimizer, self.criterion,
                       self.trainset_list, self.validset, epoch).save()
            logger.info('Epoch %d (Training) Loss %0.4f CER %0.4f' %
                        (epoch, train_loss, train_cer))

            if prev_train_cer - train_cer < self.decay_threshold:
                self.optimizer.set_scheduler(
                    ExponentialDecayLR(self.optimizer.optimizer,
                                       self.optimizer.get_lr(),
                                       self.low_plateau_lr,
                                       self.exp_decay_period),
                    self.exp_decay_period)

            prev_train_cer = train_cer
            teacher_forcing_ratio -= self.teacher_forcing_step
            teacher_forcing_ratio = max(self.min_teacher_forcing_ratio,
                                        teacher_forcing_ratio)

            # Validation
            valid_queue = queue.Queue(self.num_workers << 1)
            valid_loader = AudioLoader(self.validset, valid_queue, batch_size,
                                       0)
            valid_loader.start()

            valid_cer = self.validate(model, valid_queue)
            valid_loader.join()

            logger.info('Epoch %d (Validate) Loss %0.4f CER %0.4f' %
                        (epoch, 0.0, valid_cer))
            self._save_epoch_result(
                train_result=[self.train_dict, train_loss, train_cer],
                valid_result=[self.valid_dict, 0.0, valid_cer])
            logger.info('Epoch %d training result saved to a CSV file' %
                        epoch)

        return model
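
This version also adjusts the learning rate over time: RampUpLR raises it
linearly during the warm-up period configured in train(opt), and
ExponentialDecayLR is swapped in once the training CER stops improving by
at least decay_threshold. Minimal sketches matching those constructor
calls follow; the step()-per-timestep interface is an assumption inferred
from the Optimizer wrapper, not confirmed by the source.

class RampUpLR(object):
    """Sketch: ramp the LR linearly from init_lr to high_plateau_lr over
    `period` calls to step()."""

    def __init__(self, optimizer, init_lr, high_plateau_lr, period):
        self.optimizer = optimizer
        self.init_lr = init_lr
        self.high_plateau_lr = high_plateau_lr
        self.period = period
        self.count = 0

    def step(self):
        self.count = min(self.count + 1, self.period)
        lr = self.init_lr + ((self.high_plateau_lr - self.init_lr)
                             * self.count / self.period)
        for g in self.optimizer.param_groups:
            g['lr'] = lr


class ExponentialDecayLR(object):
    """Sketch: decay the LR geometrically from its current value down to
    low_plateau_lr over `period` calls to step()."""

    def __init__(self, optimizer, init_lr, low_plateau_lr, period):
        self.optimizer = optimizer
        self.lr = init_lr
        # Per-step factor so the LR reaches low_plateau_lr after `period`
        # steps: factor = (low / init) ** (1 / period). Assumes period > 0.
        self.factor = (low_plateau_lr / init_lr) ** (1.0 / period)
        self.low_plateau_lr = low_plateau_lr

    def step(self):
        self.lr = max(self.lr * self.factor, self.low_plateau_lr)
        for g in self.optimizer.param_groups:
            g['lr'] = self.lr
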
    def train_epoches(self, model, epoch, epoch_time_step, train_begin_time,
                      queue, teacher_forcing_ratio):
        """
        Run one training epoch.

        Args:
            model (torch.nn.Module): model to train
            epoch (int): current epoch number
            epoch_time_step (int): total number of time steps in one epoch
            train_begin_time (float): timestamp at which training began
            queue (queue.Queue): training queue, containing inputs, targets, input_lengths, target_lengths
            teacher_forcing_ratio (float): teacher forcing ratio (default: 0.99)

        Returns: loss, cer
            - **loss** (float): loss of current epoch
            - **cer** (float): character error rate of current epoch
        """
        cer = 1.0
        epoch_loss_total = 0.
        total_num = 0
        timestep = 0

        model.train()
        begin_time = epoch_begin_time = time.time()

        while True:
            inputs, scripts, input_lengths, target_lengths = queue.get()

            if inputs.shape[0] == 0:
                # An empty batch signals that one loader has finished
                self.num_workers -= 1
                logger.debug('left train_loader: %d' % self.num_workers)

                if self.num_workers == 0:
                    break
                else:
                    continue

            inputs = inputs.to(self.device)
            scripts = scripts.to(self.device)
            targets = scripts[:, 1:]

            model.module.flatten_parameters()
            output = model(inputs,
                           input_lengths,
                           scripts,
                           teacher_forcing_ratio=teacher_forcing_ratio)[0]

            logit = torch.stack(output, dim=1).to(self.device)
            hypothesis = logit.max(-1)[1]

            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  targets.contiguous().view(-1))
            epoch_loss_total += loss.item()

            cer = self.metric(targets, hypothesis)
            total_num += int(input_lengths.sum())

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step(model, loss.item())

            timestep += 1
            torch.cuda.empty_cache()

            if timestep % self.print_every == 0:
                current_time = time.time()
                elapsed = current_time - begin_time
                epoch_elapsed = (current_time - epoch_begin_time) / 60.0
                train_elapsed = (current_time - train_begin_time) / 3600.0

                logger.info(
                    'timestep: {:4d}/{:4d}, loss: {:.4f}, cer: {:.2f}, elapsed: {:.2f}s {:.2f}m {:.2f}h'
                    .format(timestep, epoch_time_step,
                            epoch_loss_total / total_num, cer, elapsed,
                            epoch_elapsed, train_elapsed))
                begin_time = time.time()

            if timestep % self.save_result_every == 0:
                self._save_step_result(self.train_step_result,
                                       epoch_loss_total / total_num, cer)

            if timestep % self.checkpoint_every == 0:
                Checkpoint(model, self.optimizer, self.criterion,
                           self.trainset_list, self.validset, epoch).save()

            del inputs, input_lengths, scripts, targets, output, logit, loss, hypothesis

        Checkpoint(model, self.optimizer, self.criterion, self.trainset_list,
                   self.validset, epoch).save()

        logger.info('train() completed')
        return epoch_loss_total / total_num, cer
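
Both training loops report CER through self.metric(targets, hypothesis).
A hedged sketch of such a metric is given below: mean Levenshtein distance
over id sequences, normalised by reference length. The function name
character_error_rate and the ignore_ids parameter are hypothetical; the
repository's metric also maps ids back to characters and strips special
tokens, which is elided here.

def character_error_rate(targets, hypotheses, ignore_ids=()):
    """Sketch: total edit distance divided by total reference length."""

    def edit_distance(ref, hyp):
        # Single-row dynamic-programming Levenshtein distance.
        d = list(range(len(hyp) + 1))
        for i, r in enumerate(ref, 1):
            prev, d[0] = d[0], i
            for j, h in enumerate(hyp, 1):
                prev, d[j] = d[j], min(d[j] + 1, d[j - 1] + 1,
                                       prev + (r != h))
        return d[len(hyp)]

    total_dist, total_length = 0, 0
    for target, hyp in zip(targets, hypotheses):
        ref_seq = [int(t) for t in target if int(t) not in ignore_ids]
        hyp_seq = [int(h) for h in hyp if int(h) not in ignore_ids]
        total_dist += edit_distance(ref_seq, hyp_seq)
        total_length += len(ref_seq)

    return total_dist / max(total_length, 1)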