Пример #1
0
    def save(self):
        """
        Saves the current model and related training parameters into the current working directory.

        Two files are written with ``torch.save``:
            - ``self.TRAINER_STATE_NAME``: dict holding the optimizer, the training / validation
              datasets and the epoch
            - ``self.MODEL_NAME``: the whole model object

        No time-stamped subdirectory is created; both files are placed directly in ``os.getcwd()``.
        """
        save_dir = os.getcwd()
        trainer_state_path = os.path.join(save_dir, self.TRAINER_STATE_NAME)
        model_path = os.path.join(save_dir, self.MODEL_NAME)

        trainer_states = {
            'optimizer': self.optimizer,
            'trainset_list': self.trainset_list,
            'validset': self.validset,
            'epoch': self.epoch
        }
        torch.save(trainer_states, trainer_state_path)
        torch.save(self.model, model_path)
        logger.info('save checkpoints\n%s\n%s' % (trainer_state_path, model_path))
Пример #2
0
    def extract_noise(self, audio_path):
        """
        Extracts the samples that lie outside the intervals reported by ``split``.

        The file is memory-mapped as native short integers (dtype ``'h'``) and copied to float32.
        Every ``(start, end)`` interval returned by ``split(signal, top_db=30)`` is zeroed out and
        only the remaining non-zero samples are kept.

        Args:
            audio_path (str): path of the raw audio file

        Returns:
            the kept samples divided by 32767, or None when a RuntimeError / ValueError is raised
        """
        try:
            # astype() returns an in-memory copy, so the read-only mapping is never written to
            signal = np.memmap(audio_path, dtype='h', mode='r').astype('float32')
            # presumably librosa.effects.split (non-silent intervals) - confirm against the file imports
            non_silence_indices = split(signal, top_db=30)

            for (start, end) in non_silence_indices:
                signal[start:end] = 0

            # samples that were exactly 0 in the recording are dropped here as well
            noise = signal[signal != 0]
            return noise / 32767

        except RuntimeError:
            logger.info("RuntimeError in {0}".format(audio_path))
            return None

        except ValueError:
            # report the exception that was actually raised
            logger.info("ValueError in {0}".format(audio_path))
            return None
    def _validate(self, model: nn.Module, queue: queue.Queue) -> float:
        """
        Decode every batch of the validation queue and report the character error rate.

        Args:
            model (torch.nn.Module): model to evaluate
            queue (queue.Queue): validation queue, containing input, targets, input_lengths, target_lengths

        Returns: cer
            - **cer** (float): value returned by ``self.metric`` for the last batch
              (1.0 when no batch was read)
        """
        references, predictions = [], []
        cer = 1.0

        model.eval()
        logger.info('validate() start')

        inputs, targets, input_lengths, _ = queue.get()

        # a batch whose first dimension is empty marks the end of the queue
        while inputs.shape[0] != 0:
            inputs = inputs.to(self.device)
            # the first target column is dropped before decoding the references
            targets = targets[:, 1:].to(self.device)
            model.to(self.device)

            # with DataParallel, recognize() is called on the wrapped network
            recognizer = model.module if isinstance(model, nn.DataParallel) else model
            y_hats = recognizer.recognize(inputs, input_lengths)

            for row in range(targets.size(0)):
                references.append(self.vocab.label_to_string(targets[row]))
                predictions.append(self.vocab.label_to_string(y_hats[row].cpu().detach().numpy()))

            cer = self.metric(targets, y_hats)
            inputs, targets, input_lengths, _ = queue.get()

        self._save_result(references, predictions)
        logger.info('validate() completed')

        return cer
    def validate(self, model, queue):
        """
        Run argmax decoding without teacher forcing over the validation queue.

        Args:
            model (torch.nn.Module): model to evaluate; either the bare network or a wrapper that
                exposes it as ``.module`` (e.g. ``nn.DataParallel``)
            queue (queue.Queue): validation queue, containing input, scripts, input_lengths, script_lengths

        Returns: cer
            - **cer** (float): value returned by ``self.metric`` for the last batch
              (1.0 when no batch was read)
        """
        cer = 1.0

        model.eval()
        logger.info('validate() start')

        # A wrapped model exposes the real network as .module; a bare model is used directly,
        # so validation no longer fails with AttributeError when no wrapper is present.
        inner_model = getattr(model, 'module', model)

        with torch.no_grad():
            while True:
                inputs, scripts, input_lengths, script_lengths = queue.get()

                # a batch with an empty first dimension marks the end of the queue
                if inputs.shape[0] == 0:
                    break

                inputs = inputs.to(self.device)
                scripts = scripts.to(self.device)
                targets = scripts[:, 1:]  # the first script column is not scored

                inner_model.flatten_parameters()
                output = model(inputs,
                               input_lengths,
                               teacher_forcing_ratio=0.0)[0]

                # output is a sequence of per-step tensors; stack them along dim 1 and take the argmax
                logit = torch.stack(output, dim=1).to(self.device)
                hypothesis = logit.max(-1)[1]

                cer = self.metric(targets, hypothesis)

                # drop the references to the batch tensors before fetching the next batch
                del inputs, input_lengths, scripts, targets, output, logit, hypothesis

        logger.info('validate() completed')
        return cer
Пример #5
0
    def search(self, model, queue, device, print_every):
        """
        Decode every batch taken from the queue and collect reference / hypothesis strings.

        Args:
            model: network called as ``model(inputs, input_lengths, teacher_forcing_ratio=0.0,
                language_model=self.language_model)``
            queue: queue yielding (inputs, scripts, input_lengths, target_lengths); a batch whose
                inputs have an empty first dimension ends the loop
            device: device the batch tensors are moved to
            print_every (int): the CER is logged every ``print_every`` batches

        Returns:
            value returned by ``self.metric`` for the last batch (0 when no batch was read)
        """
        cer = 0
        total_sent_num = 0
        step = 0

        model.eval()

        with torch.no_grad():
            inputs, scripts, input_lengths, _ = queue.get()

            while inputs.shape[0] != 0:
                inputs = inputs.to(device)
                scripts = scripts.to(device)
                # the metric is computed without the first script column
                targets = scripts[:, 1:]

                output, _unused = model(inputs, input_lengths, teacher_forcing_ratio=0.0,
                                        language_model=self.language_model)

                # stack the per-step outputs along dim 1 and keep the argmax of the last axis
                hypothesis = torch.stack(output, dim=1).to(device).max(-1)[1]

                for idx in range(targets.size(0)):
                    # the reference string is built from the full script row
                    reference = label_to_string(scripts[idx], id2char, EOS_token)
                    self.target_list.append(reference)
                    decoded = label_to_string(hypothesis[idx].cpu().detach().numpy(), id2char, EOS_token)
                    self.hypothesis_list.append(decoded)

                cer = self.metric(targets, hypothesis)
                total_sent_num += scripts.size(0)

                if step % print_every == 0:
                    logger.info('cer: {:.2f}'.format(cer))
                step += 1

                inputs, scripts, input_lengths, _ = queue.get()

        return cer
Пример #6
0
    def search(self, model: nn.Module, queue: Queue, device: str,
               print_every: int) -> float:
        """
        Decode every batch taken from the queue and collect reference / prediction strings.

        Args:
            model (torch.nn.Module): network called with ``teacher_forcing_ratio=0.0``,
                ``language_model=self.language_model`` and ``return_decode_dict=False``
            queue (Queue): queue yielding (inputs, targets, input_lengths, target_lengths); a batch
                whose inputs have an empty first dimension ends the loop
            device (str): device the batch tensors are moved to
            print_every (int): the CER is logged every ``print_every`` batches

        Returns:
            value returned by ``self.metric`` for the last batch (0 when no batch was read)
        """
        cer = 0
        total_sent_num = 0
        step = 0

        model.eval()

        with torch.no_grad():
            inputs, targets, input_lengths, _ = queue.get()

            while inputs.shape[0] != 0:
                inputs = inputs.to(device)
                targets = targets.to(device)

                output = model(inputs, input_lengths, teacher_forcing_ratio=0.0,
                               language_model=self.language_model, return_decode_dict=False)

                # stack the per-step outputs along dim 1 and keep the argmax of the last axis
                y_hat = torch.stack(output, dim=1).to(device).max(-1)[1]

                for idx in range(targets.size(0)):
                    reference = label_to_string(targets[idx], id2char, EOS_token)
                    self.target_list.append(reference)
                    prediction = label_to_string(y_hat[idx].cpu().detach().numpy(), id2char, EOS_token)
                    self.y_hats.append(prediction)

                # the first target column is excluded from the metric
                cer = self.metric(targets[:, 1:], y_hat)
                total_sent_num += targets.size(0)

                if step % print_every == 0:
                    logger.info('cer: {:.2f}'.format(cer))
                step += 1

                inputs, targets, input_lengths, _ = queue.get()

        return cer
Пример #7
0
    def load(self, path):
        """
        Loads a Checkpoint object that was previously saved to disk.

        Args:
            path (str): path to the checkpoint subdirectory

        Returns:
            checkpoint (Checkpoint): checkpoint object with fields copied from those stored on disk
        """
        trainer_state_path = os.path.join(path, self.TRAINER_STATE_NAME)
        model_path = os.path.join(path, self.MODEL_NAME)
        logger.info('load checkpoints\n%s\n%s' % (trainer_state_path, model_path))

        # With CUDA the stored device placement is kept (map_location=None is torch.load's default);
        # without CUDA every storage is left on the CPU.
        map_location = None if torch.cuda.is_available() else (lambda storage, loc: storage)

        # NOTE: torch.load unpickles the files - only load checkpoints that come from a trusted source.
        resume_checkpoint = torch.load(trainer_state_path, map_location=map_location)
        model = torch.load(model_path, map_location=map_location)

        # Unwrap DataParallel before the type test: checking the wrapper itself would never match
        # ListenAttendSpell, and flatten_parameters() would be skipped for wrapped models.
        inner_model = model.module if isinstance(model, nn.DataParallel) else model
        if isinstance(inner_model, ListenAttendSpell):
            inner_model.flatten_parameters()  # make RNN parameters contiguous

        return Checkpoint(
            model=model,
            optimizer=resume_checkpoint['optimizer'],
            epoch=resume_checkpoint['epoch'],
            trainset_list=resume_checkpoint['trainset_list'],
            validset=resume_checkpoint['validset'],
        )
Пример #8
0
    def search(self, model: nn.Module, queue: Queue, device: str,
               print_every: int) -> float:
        """
        Run ``model.recognize`` on every batch of the queue and collect target / prediction strings.

        Args:
            model (torch.nn.Module): network to decode with; a ``nn.DataParallel`` wrapper is unwrapped
            queue (Queue): queue yielding (inputs, targets, input_lengths, target_lengths); a batch
                whose inputs have an empty first dimension ends the loop
            device (str): device the model and the batch tensors are moved to
            print_every (int): the CER is logged every ``print_every`` batches

        Returns:
            value returned by ``self.metric`` for the last batch (0 when no batch was read)
        """
        cer = 0
        total_sent_num = 0
        step = 0

        # recognize() is called on the wrapped network when DataParallel is used
        model = model.module if isinstance(model, nn.DataParallel) else model

        model.eval()
        model.to(device)

        inputs, targets, input_lengths, _ = queue.get()

        while inputs.shape[0] != 0:
            inputs = inputs.to(device)
            targets = targets.to(device)

            y_hats = model.recognize(inputs, input_lengths)

            for idx in range(targets.size(0)):
                reference = self.vocab.label_to_string(targets[idx])
                self.target_list.append(reference)
                prediction = self.vocab.label_to_string(y_hats[idx].cpu().detach().numpy())
                self.predict_list.append(prediction)

            # the first target column is excluded from the metric
            cer = self.metric(targets[:, 1:], y_hats)
            total_sent_num += targets.size(0)

            if step % print_every == 0:
                logger.info('cer: {:.2f}'.format(cer))
            step += 1

            inputs, targets, input_lengths, _ = queue.get()

        return cer
    def validate(self, model: nn.Module, queue: queue.Queue) -> float:
        """
        Decode every batch of the validation queue with ``greedy_decode`` and report the CER.

        Args:
            model (torch.nn.Module): model to evaluate
            queue (queue.Queue): validation queue, containing input, targets, input_lengths, target_lengths

        Returns: cer
            - **cer** (float): value returned by ``self.metric`` for the last batch
              (1.0 when no batch was read)
        """
        cer = 1.0

        model.eval()
        logger.info('validate() start')

        inputs, targets, input_lengths, _ = queue.get()

        # a batch whose first dimension is empty marks the end of the queue
        while inputs.shape[0] != 0:
            inputs = inputs.to(self.device)
            targets = targets[:, 1:].to(self.device)  # first target column is not scored
            model.to(self.device)

            # with DataParallel, greedy_decode() is called on the wrapped network
            decoder = model.module if isinstance(model, nn.DataParallel) else model
            y_hats = decoder.greedy_decode(inputs, input_lengths, self.device)

            cer = self.metric(targets, y_hats)
            inputs, targets, input_lengths, _ = queue.get()

        logger.info('validate() completed')

        return cer
Пример #10
0
    def save(self):
        """
        Saves the current model and related training parameters into a subdirectory of the checkpoint directory.
        The name of the subdirectory is the current local time in Y_M_D_H_M_S format.
        A subdirectory of that name that already exists is deleted and recreated first.
        """
        timestamp = time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime())
        checkpoint_dir = os.path.join(self.SAVE_PATH, self.CHECKPOINT_DIR_NAME, timestamp)

        # always start from an empty directory
        if os.path.exists(checkpoint_dir):
            shutil.rmtree(checkpoint_dir)
        os.makedirs(checkpoint_dir)

        trainer_state_path = os.path.join(checkpoint_dir, self.TRAINER_STATE_NAME)
        model_path = os.path.join(checkpoint_dir, self.MODEL_NAME)

        trainer_states = dict(
            optimizer=self.optimizer,
            trainset_list=self.trainset_list,
            validset=self.validset,
            epoch=self.epoch,
        )
        torch.save(trainer_states, trainer_state_path)
        torch.save(self.model, model_path)
        logger.info('save checkpoints\n%s\n%s' % (trainer_state_path, model_path))
Пример #11
0
    def parse_audio(self, audio_path: str, augment_method: int) -> Tensor:
        """
        Parses audio.

        Args:
             audio_path (str): path of audio file
             augment_method (int): flag indication which augmentation method to use.

        Returns: feature_vector
            - **feature_vector** (torch.FloatTensor): feature from audio file, or None when
              ``load_audio`` returns None.
        """
        signal = load_audio(audio_path,
                            self.del_silence,
                            extension=self.audio_extension)

        if signal is None:
            logger.info("Audio is None : {0}".format(audio_path))
            return None

        feature = self.transforms(signal)

        if self.normalize:
            feature -= feature.mean()
            # A constant feature map has zero deviation; dividing by it would fill the feature with
            # NaN, so the (already zero-centred) feature is left unscaled in that case.
            std = np.std(feature)
            if std > 0:
                feature /= std

        # Refer to "Sequence to Sequence Learning with Neural Network" paper
        if self.input_reverse:
            # reverse the second axis, then swap axes; the copy makes the negative-stride view contiguous
            feature = feature[:, ::-1]
            feature = FloatTensor(
                np.ascontiguousarray(np.swapaxes(feature, 0, 1)))
        else:
            feature = FloatTensor(feature).transpose(0, 1)

        if augment_method == SpectrogramParser.SPEC_AUGMENT:
            feature = self.spec_augment(feature)

        return feature
Пример #12
0
    def __init__(self, dataset_path, noiseset_size, sample_rate=16000, noise_level=0.7):
        """
        Collects the audio paths under ``dataset_path`` and builds the noise set.

        Args:
            dataset_path (str): directory that must exist; passed to ``create_audio_paths`` and
                ``create_noiseset``
            noiseset_size (int): stored as ``self.noiseset_size``
                (presumably the number of noise clips - confirm in ``create_noiseset``)
            sample_rate (int): stored as ``self.sample_rate`` (default 16000)
            noise_level (float): stored as ``self.noise_level`` (default 0.7)

        Raises:
            IOError: if ``dataset_path`` does not exist
        """
        if not os.path.exists(dataset_path):
            message = "Directory doesn`t exist: {0}".format(dataset_path)
            logger.info(message)
            # carry the reason in the exception too, so callers see it without reading the log
            raise IOError(message)

        logger.info("Create Noise injector...")

        self.noiseset_size = noiseset_size
        self.sample_rate = sample_rate
        self.noise_level = noise_level
        self.audio_paths = self.create_audio_paths(dataset_path)
        self.dataset = self.create_noiseset(dataset_path)

        logger.info("Create Noise injector complete !!")
Пример #13
0
    def evaluate(self, model):
        """
        Decode ``self.dataset`` with ``self.decoder`` and log the resulting CER.

        The decoder is asked to save its result to ``../data/train_result/<decoder class name>.csv``.
        Nothing is returned.
        """
        logger.info('evaluate() start')

        batch_queue = queue.Queue(self.num_workers << 1)
        loader = AudioLoader(self.dataset, batch_queue, self.batch_size, 0)
        loader.start()

        cer = self.decoder.search(model, batch_queue, self.device, self.print_every)

        decoder_name = type(self.decoder).__name__
        self.decoder.save_result('../data/train_result/%s.csv' % decoder_name)

        logger.info('Evaluate CER: %s' % cer)
        logger.info('evaluate() completed')

        # the loader is joined only after decoding has consumed the whole queue
        loader.join()
Пример #14
0
def split_dataset(config: DictConfig, transcripts_path: str,
                  vocab: Vocabulary):
    """
    split into training set and validation set.

    The first ``train_num`` entries returned by ``load_dataset`` form the training set and the
    remaining entries form the validation set. ``train_num`` / ``valid_num`` are fixed per dataset
    name ('kspon' or 'libri').

    Args:
        config (DictConfig): configuration; reads ``train.dataset``, ``train.batch_size``,
            ``train.num_workers``, ``train.dataset_path``, ``audio.spec_augment`` and
            ``audio.audio_extension``
        transcripts_path (str): path of  transcripts
        vocab (Vocabulary): vocabulary providing ``sos_id`` and ``eos_id``

    Returns: train_time_step, trainset_list, validset
        - **train_time_step** (int): number of time step for training
        - **trainset_list** (list): list of training dataset, one per worker
        - **validset** (SpectrogramDataset): validation dataset

    Raises:
        NotImplementedError: if ``config.train.dataset`` is neither 'kspon' nor 'libri'
    """
    logger.info("split dataset start !!")
    trainset_list = list()

    if config.train.dataset == 'kspon':
        train_num = 620000
        valid_num = 2545
    elif config.train.dataset == 'libri':
        train_num = 281241
        valid_num = 5567
    else:
        raise NotImplementedError("Unsupported Dataset : {0}".format(
            config.train.dataset))

    audio_paths, transcripts = load_dataset(transcripts_path)

    total_time_step = math.ceil(len(audio_paths) / config.train.batch_size)
    valid_time_step = math.ceil(valid_num / config.train.batch_size)
    train_time_step = total_time_step - valid_time_step

    # Split exactly at train_num: slicing at train_num + 1 pulled the first validation sample into
    # the training slice, and the per-worker partition below (capped at train_num) then dropped one
    # training sample.
    train_audio_paths = audio_paths[:train_num]
    train_transcripts = transcripts[:train_num]

    valid_audio_paths = audio_paths[train_num:]
    valid_transcripts = transcripts[train_num:]

    # spec augment doubles the number of training steps
    if config.audio.spec_augment:
        train_time_step <<= 1

    train_num_per_worker = math.ceil(train_num / config.train.num_workers)

    # audio paths & transcripts of the training part are shuffled in the same order
    tmp = list(zip(train_audio_paths, train_transcripts))
    random.shuffle(tmp)
    train_audio_paths, train_transcripts = zip(*tmp)

    # separating the train dataset by the number of workers
    for idx in range(config.train.num_workers):
        train_begin_idx = train_num_per_worker * idx
        train_end_idx = min(train_num_per_worker * (idx + 1), train_num)

        trainset_list.append(
            SpectrogramDataset(
                train_audio_paths[train_begin_idx:train_end_idx],
                train_transcripts[train_begin_idx:train_end_idx],
                vocab.sos_id,
                vocab.eos_id,
                config=config,
                spec_augment=config.audio.spec_augment,
                dataset_path=config.train.dataset_path,
                audio_extension=config.audio.audio_extension,
            ))

    validset = SpectrogramDataset(
        audio_paths=valid_audio_paths,
        transcripts=valid_transcripts,
        sos_id=vocab.sos_id,
        eos_id=vocab.eos_id,
        config=config,
        spec_augment=False,
        dataset_path=config.train.dataset_path,
        audio_extension=config.audio.audio_extension,
    )

    logger.info("split dataset complete !!")
    return train_time_step, trainset_list, validset
Пример #15
0
    def _train_epoches(
        self,
        model: nn.Module,
        epoch: int,
        epoch_time_step: int,
        train_begin_time: float,
        queue: queue.Queue,
        teacher_forcing_ratio: float,
    ) -> Tuple[nn.Module, float, float]:
        """
        Run training one epoch

        Batches are read from ``queue`` until every one of the ``self.num_workers`` loaders has sent
        its empty end-of-data batch. Progress is logged every ``self.print_every`` steps, a step
        result is saved every ``self.save_result_every`` steps and a checkpoint is written every
        ``self.checkpoint_every`` steps and once more when the epoch ends.

        Args:
            model (torch.nn.Module): model to train
            epoch (int): number of current epoch
            epoch_time_step (int): total time step in one epoch
            train_begin_time (float): time of train begin
            queue (queue.Queue): training queue, containing input, targets, input_lengths, target_lengths
            teacher_forcing_ratio (float): teaching forcing ratio (default 0.99)

        Returns: model, loss, cer
            - **model** (torch.nn.Module): the trained model
            - **loss** (float): accumulated loss divided by the summed input lengths of the epoch
            - **cer** (float): character error rate of the last scored batch (stays 1.0 for
              'rnnt' / 'conformer_t', which are never scored here)
        """
        # A 'conformer' is treated as a transducer when it has a decoder, otherwise as a CTC model.
        architecture = self.architecture
        if self.architecture == 'conformer':
            if isinstance(model, nn.DataParallel):
                architecture = 'conformer_t' if model.module.decoder is not None else 'conformer_ctc'
            else:
                architecture = 'conformer_t' if model.decoder is not None else 'conformer_ctc'

        cer = 1.0
        epoch_loss_total = 0.
        total_num = 0
        timestep = 0

        model.train()

        begin_time = epoch_begin_time = time.time()
        num_workers = self.num_workers

        while True:
            inputs, targets, input_lengths, target_lengths = queue.get()

            if inputs.shape[0] == 0:
                # Empty feats means closing one loader
                num_workers -= 1
                logger.debug('left train_loader: %d' % num_workers)

                # the epoch is over once every loader has sent its empty batch
                if num_workers == 0:
                    break
                else:
                    continue

            self.optimizer.zero_grad()

            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            input_lengths = input_lengths.to(self.device)
            target_lengths = torch.as_tensor(target_lengths).to(self.device)

            model = model.to(self.device)
            output, loss, ctc_loss, cross_entropy_loss = self._model_forward(
                teacher_forcing_ratio=teacher_forcing_ratio,
                inputs=inputs,
                input_lengths=input_lengths,
                targets=targets,
                target_lengths=target_lengths,
                model=model,
                architecture=architecture,
            )

            # CER is computed from the argmax output for every architecture except the transducers
            if architecture not in ('rnnt', 'conformer_t'):
                y_hats = output.max(-1)[1]
                cer = self.metric(targets[:, 1:], y_hats)

            loss.backward()
            # self.optimizer is a project wrapper whose step() takes the model
            # (presumably for gradient clipping - confirm in the optimizer class)
            self.optimizer.step(model)

            # the loss is normalised by the summed input lengths, not by the number of batches
            total_num += int(input_lengths.sum())
            epoch_loss_total += loss.item()

            timestep += 1
            torch.cuda.empty_cache()

            if timestep % self.print_every == 0:
                current_time = time.time()
                elapsed = current_time - begin_time  # seconds since the previous log line
                epoch_elapsed = (current_time - epoch_begin_time) / 60.0  # minutes
                train_elapsed = (current_time - train_begin_time) / 3600.0  # hours

                if architecture in ('rnnt', 'conformer_t'):
                    logger.info(
                        self.rnnt_log_format.format(
                            timestep,
                            epoch_time_step,
                            loss,
                            elapsed,
                            epoch_elapsed,
                            train_elapsed,
                            self.optimizer.get_lr(),
                        ))
                else:
                    # NOTE(review): self.log_format receives 10 values with joint CTC-attention and
                    # 8 without - presumably it is set up to match elsewhere; confirm.
                    if self.joint_ctc_attention:
                        logger.info(
                            self.log_format.format(
                                timestep,
                                epoch_time_step,
                                loss,
                                ctc_loss,
                                cross_entropy_loss,
                                cer,
                                elapsed,
                                epoch_elapsed,
                                train_elapsed,
                                self.optimizer.get_lr(),
                            ))
                    else:
                        logger.info(
                            self.log_format.format(
                                timestep,
                                epoch_time_step,
                                loss,
                                cer,
                                elapsed,
                                epoch_elapsed,
                                train_elapsed,
                                self.optimizer.get_lr(),
                            ))
                begin_time = time.time()

            if timestep % self.save_result_every == 0:
                self._save_step_result(self.train_step_result,
                                       epoch_loss_total / total_num, cer)

            if timestep % self.checkpoint_every == 0:
                Checkpoint(model, self.optimizer, self.trainset_list,
                           self.validset, epoch).save()

            # drop the references to the batch tensors before fetching the next batch
            del inputs, input_lengths, targets, output, loss

        Checkpoint(model, self.optimizer, self.trainset_list, self.validset,
                   epoch).save()
        logger.info('train() completed')

        # NOTE: raises ZeroDivisionError when no batch was processed (total_num stays 0)
        return model, epoch_loss_total / total_num, cer
Пример #16
0
    def train(
            self,
            model: nn.Module,  # model to train
            batch_size: int,  # batch size for experiment
            epoch_time_step: int,  # number of time step for training
            num_epochs: int,  # number of epochs (iteration) for training
            teacher_forcing_ratio: float = 0.99,  # teacher forcing ratio
            resume: bool = False,  # resume training with the latest checkpoint
    ) -> nn.Module:
        """
        Train ``model`` epoch by epoch, validating and checkpointing after each epoch.

        Args:
            model (torch.nn.Module): model to train
            batch_size (int): batch size for experiment
            epoch_time_step (int): number of time step for training
            num_epochs (int): number of epochs for training
            teacher_forcing_ratio (float): teacher forcing ratio (default 0.99)
            resume(bool, optional): resume training with the latest checkpoint, (default False)

        Returns:
            model (torch.nn.Module): the model after the last epoch
        """
        first_epoch = 0

        if resume:
            checkpoint = Checkpoint()
            restored = checkpoint.load(checkpoint.get_latest_checkpoint())
            model = restored.model
            self.optimizer = restored.optimizer
            self.trainset_list = restored.trainset_list
            self.validset = restored.validset
            first_epoch = restored.epoch + 1

            # the steps per epoch are recomputed from the restored training sets
            sample_count = sum(len(trainset) for trainset in self.trainset_list)
            epoch_time_step = math.ceil(sample_count / batch_size)

        logger.info('start')
        train_begin_time = time.time()

        for epoch in range(first_epoch, num_epochs):
            logger.info('Epoch %d start' % epoch)
            train_queue = queue.Queue(self.num_workers << 1)

            for trainset in self.trainset_list:
                trainset.shuffle()

            # Training
            train_loader = MultiDataLoader(self.trainset_list, train_queue, batch_size,
                                           self.num_workers, self.vocab.pad_id)
            train_loader.start()

            model, train_loss, train_cer = self._train_epoches(
                model=model,
                epoch=epoch,
                epoch_time_step=epoch_time_step,
                train_begin_time=train_begin_time,
                queue=train_queue,
                teacher_forcing_ratio=teacher_forcing_ratio,
            )
            train_loader.join()

            Checkpoint(model, self.optimizer, self.trainset_list, self.validset, epoch).save()
            logger.info('Epoch %d (Training) Loss %0.4f CER %0.4f' % (epoch, train_loss, train_cer))

            # lower the teacher forcing ratio, but never below the configured minimum
            teacher_forcing_ratio = max(self.min_teacher_forcing_ratio,
                                        teacher_forcing_ratio - self.teacher_forcing_step)

            # Validation
            valid_queue = queue.Queue(self.num_workers << 1)
            valid_loader = AudioDataLoader(self.validset, valid_queue, batch_size, 0,
                                           self.vocab.pad_id)
            valid_loader.start()

            valid_cer = self._validate(model, valid_queue)
            valid_loader.join()

            logger.info('Epoch %d CER %0.4f' % (epoch, valid_cer))

            # _validate() returns only a CER, so the training loss is stored in the validation row
            train_result = [self.train_dict, train_loss, train_cer]
            valid_result = [self.valid_dict, train_loss, valid_cer]
            self._save_epoch_result(train_result=train_result, valid_result=valid_result)
            logger.info('Epoch %d Training result saved as a csv file complete !!' % epoch)
            torch.cuda.empty_cache()

        Checkpoint(model, self.optimizer, self.trainset_list, self.validset, num_epochs).save()
        return model
Пример #17
0
def split_dataset(opt, audio_paths, script_paths):
    """
    Shuffle the data and split it into per-worker training sets and one validation set.

    Args:
        opt (ArgumentParser): set of options
        audio_paths (list): set of audio path
        script_paths (list): set of script path

    Returns: train_time_step, trainset_list, validset
        - **train_time_step** (int): number of time step for training (one extra pass per enabled augmentation)
        - **trainset_list** (list): list of training dataset, one per worker
        - **validset** (SpectrogramDataset): validation dataset
    """
    target_dict = load_targets(script_paths)

    logger.info("split dataset start !!")

    sample_count = len(audio_paths)
    train_num = math.ceil(sample_count * (1 - opt.valid_ratio))
    total_time_step = math.ceil(sample_count / opt.batch_size)
    valid_time_step = math.ceil(total_time_step * opt.valid_ratio)
    base_time_step = total_time_step - valid_time_step

    # every enabled augmentation adds one more pass over the training data
    augment_passes = sum(1 for enabled in (opt.spec_augment, opt.noise_augment) if enabled)
    train_time_step = base_time_step * (1 + augment_passes)

    train_num_per_worker = math.ceil(train_num / opt.num_workers)

    # audio_paths & script_paths shuffled in the same order
    # for separating train & validation
    pairs = list(zip(audio_paths, script_paths))
    random.shuffle(pairs)
    audio_paths, script_paths = zip(*pairs)

    # separating the train dataset by the number of workers
    trainset_list = list()
    for worker in range(opt.num_workers):
        begin = train_num_per_worker * worker
        end = min(begin + train_num_per_worker, train_num)

        trainset = SpectrogramDataset(
            audio_paths[begin:end],
            script_paths[begin:end],
            SOS_token,
            EOS_token,
            target_dict=target_dict,
            opt=opt,
            spec_augment=opt.spec_augment,
            noise_augment=opt.noise_augment,
            dataset_path=opt.dataset_path,
            noiseset_size=opt.noiseset_size,
            noise_level=opt.noise_level,
        )
        trainset_list.append(trainset)

    # everything after the first train_num shuffled samples is used for validation
    validset = SpectrogramDataset(
        audio_paths=audio_paths[train_num:],
        script_paths=script_paths[train_num:],
        sos_id=SOS_token,
        eos_id=EOS_token,
        target_dict=target_dict,
        opt=opt,
        spec_augment=False,
        noise_augment=False,
    )

    logger.info("split dataset complete !!")
    return train_time_step, trainset_list, validset
Пример #18
0
    def train(
            self,
            model: nn.Module,
            batch_size: int,
            epoch_time_step: int,
            num_epochs: int,
            teacher_forcing_ratio: float = 0.99,
            resume: bool = False
    ) -> nn.Module:
        """
        Run training for a given model.

        Each epoch trains on the shuffled training sets, saves a checkpoint, lowers the teacher
        forcing ratio, validates and stores the epoch result. The learning rate is set to 1e-04 at
        epoch 1 and 5e-05 at epoch 2; at epoch 3 a ``ReduceLROnPlateau`` scheduler is installed,
        which is then stepped with the validation loss after every validation.

        Args:
            model (torch.nn.Module): model to train
            batch_size (int): batch size for experiment
            epoch_time_step (int): number of time step for training
            num_epochs (int): number of epochs for training
            teacher_forcing_ratio (float): teaching forcing ratio (default 0.99)
            resume(bool, optional): resume training with the latest checkpoint, (default False)

        Returns:
            model (torch.nn.Module): the trained model
        """
        start_epoch = 0

        if resume:
            checkpoint = Checkpoint()
            latest_checkpoint_path = checkpoint.get_latest_checkpoint()
            resume_checkpoint = checkpoint.load(latest_checkpoint_path)
            model = resume_checkpoint.model
            self.optimizer = resume_checkpoint.optimizer
            self.trainset_list = resume_checkpoint.trainset_list
            self.validset = resume_checkpoint.validset
            start_epoch = resume_checkpoint.epoch + 1

            # the steps per epoch are recomputed from the restored training sets
            total_samples = sum(len(trainset) for trainset in self.trainset_list)
            epoch_time_step = math.ceil(total_samples / batch_size)

        logger.info('start')
        train_begin_time = time.time()

        for epoch in range(start_epoch, num_epochs):
            logger.info('Epoch %d start' % epoch)
            train_queue = queue.Queue(self.num_workers << 1)

            for trainset in self.trainset_list:
                trainset.shuffle()

            # Training
            train_loader = MultiDataLoader(self.trainset_list, train_queue, batch_size, self.num_workers)
            train_loader.start()

            # hard-coded learning rate schedule for the first epochs
            if epoch == 1:
                self.optimizer.set_lr(1e-04)
            elif epoch == 2:
                self.optimizer.set_lr(5e-05)
            elif epoch == 3:
                self.optimizer.set_scheduler(ReduceLROnPlateau(self.optimizer.optimizer, patience=1, factor=0.5), 999999)

            train_loss, train_cer = self.__train_epoches(model, epoch, epoch_time_step, train_begin_time,
                                                         train_queue, teacher_forcing_ratio)
            train_loader.join()

            Checkpoint(model, self.optimizer, self.trainset_list, self.validset, epoch).save()
            logger.info('Epoch %d (Training) Loss %0.4f CER %0.4f' % (epoch, train_loss, train_cer))

            # lower the teacher forcing ratio, but never below the configured minimum
            teacher_forcing_ratio -= self.teacher_forcing_step
            teacher_forcing_ratio = max(self.min_teacher_forcing_ratio, teacher_forcing_ratio)

            # Validation
            valid_queue = queue.Queue(self.num_workers << 1)
            valid_loader = AudioDataLoader(self.validset, valid_queue, batch_size, 0)
            valid_loader.start()

            valid_loss, valid_cer = self.validate(model, valid_queue)
            valid_loader.join()

            if isinstance(self.optimizer.scheduler, ReduceLROnPlateau):
                self.optimizer.scheduler.step(valid_loss)

            logger.info('Epoch %d (Validate) Loss %0.4f CER %0.4f' % (epoch, valid_loss, valid_cer))
            self.__save_epoch_result(train_result=[self.train_dict, train_loss, train_cer],
                                     valid_result=[self.valid_dict, valid_loss, valid_cer])
            logger.info('Epoch %d Training result saved as a csv file complete !!' % epoch)

        # Same argument list as the per-epoch checkpoint above; the extra ``self.criterion``
        # positional argument passed here before did not match it.
        Checkpoint(model, self.optimizer, self.trainset_list, self.validset, num_epochs).save()
        return model
Пример #19
0
def print_train_opts(opt):
    """
    Log every training option, one ``--name: value`` line per option.

    Args:
        opt: options object exposing each attribute named in the table below
    """
    option_names = (
        'dataset_path',
        'data_list_path',
        'label_path',
        'spec_augment',
        'noise_augment',
        'noiseset_size',
        'noise_level',
        'use_cuda',
        'batch_size',
        'num_workers',
        'num_epochs',
        'init_lr',
        'high_plateau_lr',
        'low_plateau_lr',
        'decay_threshold',
        'rampup_period',
        'exp_decay_period',
        'valid_ratio',
        'max_len',
        'max_grad_norm',
        'teacher_forcing_step',
        'min_teacher_forcing_ratio',
        'seed',
        'save_result_every',
        'checkpoint_every',
        'print_every',
        'resume',
    )

    # Options are logged in the order they are listed above
    for option_name in option_names:
        logger.info('--%s: %s' % (option_name, str(getattr(opt, option_name))))
    def train_epoches(self, model, epoch, epoch_time_step, train_begin_time,
                      queue, teacher_forcing_ratio):
        """
        Run training one epoch

        Batches are pulled from `queue` until every loader has sent its empty
        end-of-data batch. A checkpoint is saved every `self.checkpoint_every`
        steps and once more when the epoch ends.

        Args:
            model (torch.nn.Module): model to train
            epoch (int): number of current epoch
            epoch_time_step (int): total time step in one epoch
            train_begin_time (int): time of train begin
            queue (queue.Queue): training queue, containing input, targets, input_lengths, target_lengths
            teacher_forcing_ratio (float): teaching forcing ratio

        Returns: loss, cer
            - **loss** (float): loss of current epoch
            - **cer** (float): character error rate of current epoch
        """
        cer = 1.0
        epoch_loss_total = 0.
        total_num = 0
        timestep = 0

        model.train()
        begin_time = epoch_begin_time = time.time()

        # Count the still-running loaders in a local variable. Decrementing
        # self.num_workers itself would leave it at 0 after the first epoch,
        # and the caller reuses that attribute to build the next epoch's
        # queue and loader.
        num_workers = self.num_workers

        while True:
            inputs, scripts, input_lengths, target_lengths = queue.get()

            if inputs.shape[0] == 0:
                # Empty feats means closing one loader
                num_workers -= 1
                logger.debug('left train_loader: %d' % num_workers)

                if num_workers == 0:
                    break
                else:
                    continue

            inputs = inputs.to(self.device)
            scripts = scripts.to(self.device)
            # The loss is computed against the script without its first token
            targets = scripts[:, 1:]

            model.module.flatten_parameters()
            output = model(inputs,
                           input_lengths,
                           scripts,
                           teacher_forcing_ratio=teacher_forcing_ratio)[0]

            logit = torch.stack(output, dim=1).to(self.device)
            hypothesis = logit.max(-1)[1]

            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  targets.contiguous().view(-1))
            epoch_loss_total += loss.item()

            cer = self.metric(targets, hypothesis)
            total_num += int(input_lengths.sum())

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step(model, loss.item())

            timestep += 1
            torch.cuda.empty_cache()

            if timestep % self.print_every == 0:
                current_time = time.time()
                elapsed = current_time - begin_time
                epoch_elapsed = (current_time - epoch_begin_time) / 60.0
                train_elapsed = (current_time - train_begin_time) / 3600.0

                logger.info(
                    'timestep: {:4d}/{:4d}, loss: {:.4f}, cer: {:.2f}, elapsed: {:.2f}s {:.2f}m {:.2f}h'
                    .format(timestep, epoch_time_step,
                            epoch_loss_total / total_num, cer, elapsed,
                            epoch_elapsed, train_elapsed))
                # Restart the per-report stopwatch
                begin_time = time.time()

            if timestep % self.save_result_every == 0:
                self._save_step_result(self.train_step_result,
                                       epoch_loss_total / total_num, cer)

            if timestep % self.checkpoint_every == 0:
                Checkpoint(model, self.optimizer, self.criterion,
                           self.trainset_list, self.validset, epoch).save()

            # Drop the batch references before the next queue.get()
            del inputs, input_lengths, scripts, targets, output, logit, loss, hypothesis

        Checkpoint(model, self.optimizer, self.criterion, self.trainset_list,
                   self.validset, epoch).save()

        logger.info('train() completed')
        return epoch_loss_total / total_num, cer
Пример #21
0
    def search(self, model: nn.Module, queue: Queue, device: str,
               print_every: int) -> float:
        """
        Greedy-decode every batch taken from the queue and record the transcripts.

        Each target and each prediction is converted with
        ``self.vocab.label_to_string`` and appended to ``self.target_list`` /
        ``self.predict_list``.

        Args:
            model (torch.nn.Module): ListenAttendSpell, SpeechTransformer or DeepSpeech2
                instance, optionally wrapped in nn.DataParallel
            queue (Queue): queue yielding (inputs, targets, input_lengths, target_lengths);
                a batch whose inputs have zero rows ends the search
            device (str): device the model and the batches are moved to
            print_every (int): log the CER once every `print_every` batches

        Returns:
            float: value returned by ``self.metric`` for the last decoded batch
            (0 when no batch was decoded)

        Raises:
            ValueError: if the unwrapped model is not one of the supported architectures
        """
        cer = 0
        timestep = 0

        # greedy_search is called on the bare module, so strip the DataParallel wrapper first
        if isinstance(model, nn.DataParallel):
            model = model.module

        if not isinstance(model, (ListenAttendSpell, SpeechTransformer, DeepSpeech2)):
            raise ValueError("Unsupported model : {0}".format(type(model)))

        model.eval()
        model.to(device)

        while True:
            inputs, targets, input_lengths, _ = queue.get()
            if inputs.shape[0] == 0:
                break

            inputs = inputs.to(device)
            targets = targets.to(device)

            # All supported architectures expose the same greedy_search call
            y_hats = model.greedy_search(inputs, input_lengths, device)

            for idx in range(targets.size(0)):
                self.target_list.append(
                    self.vocab.label_to_string(targets[idx]))
                self.predict_list.append(
                    self.vocab.label_to_string(
                        y_hats[idx].cpu().detach().numpy()))

            # The first target token is excluded from the comparison
            cer = self.metric(targets[:, 1:], y_hats)

            # Checked before the increment, so the first batch is always logged
            if timestep % print_every == 0:
                logger.info('cer: {:.2f}'.format(cer))

            timestep += 1

        return cer
Пример #22
0
def main(config: DictConfig) -> None:
    """
    Silence warnings, log the configuration as YAML, then run inference.

    Args:
        config (DictConfig): configuration passed on to `inference`
    """
    warnings.filterwarnings('ignore')
    config_yaml = OmegaConf.to_yaml(config)
    logger.info(config_yaml)
    inference(config)
Пример #23
0
def main(config: DictConfig):
    """
    Silence warnings, log the configuration as YAML, then run training.

    Args:
        config (DictConfig): configuration passed on to `train`
    """
    warnings.filterwarnings('ignore')
    config_yaml = OmegaConf.to_yaml(config)
    logger.info(config_yaml)
    train(config)
Пример #24
0
def print_train_opts(opt):
    """
    Log every training option, one ``--name: value`` line per option.

    Args:
        opt: options object exposing each attribute named in the table below
    """
    option_names = (
        'spec_augment',
        'use_cuda',
        'batch_size',
        'num_workers',
        'num_epochs',
        'init_lr',
        'warmup_steps',
        'max_len',
        'max_grad_norm',
        'teacher_forcing_step',
        'min_teacher_forcing_ratio',
        'seed',
        'save_result_every',
        'checkpoint_every',
        'print_every',
        'resume',
    )

    # Options are logged in the order they are listed above
    for option_name in option_names:
        logger.info('--%s: %s' % (option_name, str(getattr(opt, option_name))))
Пример #25
0
def print_preprocess_opts(opt):
    """
    Log every preprocessing option, one ``--name: value`` line per option.

    Args:
        opt: options object exposing each attribute named in the table below
    """
    option_names = (
        'mode',
        'transform_method',
        'sample_rate',
        'frame_length',
        'frame_shift',
        'n_mels',
        'normalize',
        'del_silence',
        'input_reverse',
        'feature_extract_by',
        'time_mask_para',
        'freq_mask_para',
        'time_mask_num',
        'freq_mask_num',
    )

    # Options are logged in the order they are listed above
    for option_name in option_names:
        logger.info('--%s: %s' % (option_name, str(getattr(opt, option_name))))
Пример #26
0
def main(config: DictConfig) -> None:
    """
    Train with the given configuration and save what `train` returns.

    Warnings are silenced and the configuration is logged as YAML before
    training starts. The result is written to ``last_model_checkpoint.pt``
    in the current working directory.

    Args:
        config (DictConfig): configuration passed on to `train`
    """
    warnings.filterwarnings('ignore')
    logger.info(OmegaConf.to_yaml(config))

    trained = train(config)
    # The working directory is read only after training has finished
    save_path = os.path.join(os.getcwd(), "last_model_checkpoint.pt")
    torch.save(trained, save_path)
Пример #27
0
def print_model_opts(opt):
    """
    Log every model option, one ``--name: value`` line per option.

    Args:
        opt: options object exposing each attribute named in the table below
    """
    option_names = (
        'architecture',
        'use_bidirectional',
        'mask_conv',
        'hidden_dim',
        'dropout',
        'attn_mechanism',
        'num_heads',
        'label_smoothing',
        'num_encoder_layers',
        'num_decoder_layers',
        'extractor',
        'activation',
        'rnn_type',
        'teacher_forcing_ratio',
    )

    # Options are logged in the order they are listed above
    for option_name in option_names:
        logger.info('--%s: %s' % (option_name, str(getattr(opt, option_name))))
Пример #28
0
    def __train_epoches(self, model: nn.Module, epoch: int,
                        epoch_time_step: int, train_begin_time: float,
                        queue: queue.Queue,
                        teacher_forcing_ratio: float) -> Tuple[float, float]:
        """
        Train the model for a single epoch, consuming batches from the queue.

        The loop ends once every loader has delivered its empty end-of-data
        batch. A checkpoint is saved every `self.checkpoint_every` steps and
        once more after the loop.

        Args:
            model (torch.nn.Module): model to train
            epoch (int): number of current epoch
            epoch_time_step (int): total time step in one epoch
            train_begin_time (float): time of train begin
            queue (queue.Queue): training queue, containing input, targets, input_lengths, target_lengths
            teacher_forcing_ratio (float): teacher forcing ratio handed to the 'las' model

        Returns: loss, cer
            - **loss** (float): loss of current epoch
            - **cer** (float): character error rate of current epoch

        Raises:
            ValueError: if `self.architecture` is neither 'las' nor 'transformer'
        """
        cer = 1.0
        loss_sum = 0.
        length_sum = 0
        step = 0

        model.train()

        epoch_start = time.time()
        report_start = epoch_start
        loaders_left = self.num_workers

        while True:
            inputs, targets, input_lengths, target_lengths = queue.get()

            if inputs.shape[0] == 0:
                # A batch without rows signals that one loader has finished
                loaders_left -= 1
                logger.debug('left train_loader: %d' % loaders_left)

                if loaders_left == 0:
                    break
                continue

            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            model = model.to(self.device)

            if self.architecture == 'las':
                # flatten_parameters lives on the wrapped module under DataParallel
                rnn_owner = model.module if isinstance(model, nn.DataParallel) else model
                rnn_owner.flatten_parameters()

                logit = torch.stack(
                    model(inputs=inputs,
                          input_lengths=input_lengths,
                          targets=targets,
                          teacher_forcing_ratio=teacher_forcing_ratio),
                    dim=1,
                ).to(self.device)
                # The loss is computed against the targets without their first token
                targets = targets[:, 1:]

            elif self.architecture == 'transformer':
                logit = model(inputs,
                              input_lengths,
                              targets,
                              return_attns=False)

            else:
                raise ValueError("Unsupported architecture : {0}".format(
                    self.architecture))

            hypothesis = logit.max(-1)[1]
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  targets.contiguous().view(-1))
            loss_sum += loss.item()

            cer = self.metric(targets, hypothesis)
            length_sum += int(input_lengths.sum())

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step(model)

            step += 1
            torch.cuda.empty_cache()

            if step % self.print_every == 0:
                now = time.time()
                since_report = now - report_start
                epoch_minutes = (now - epoch_start) / 60.0
                train_hours = (now - train_begin_time) / 3600.0

                logger.info(
                    'timestep: {:4d}/{:4d}, loss: {:.4f}, cer: {:.2f}, elapsed: {:.2f}s {:.2f}m {:.2f}h, lr: {:.5f}'
                    .format(step, epoch_time_step,
                            loss_sum / length_sum, cer, since_report,
                            epoch_minutes, train_hours,
                            self.optimizer.get_lr()))
                # Restart the per-report stopwatch
                report_start = time.time()

            if step % self.save_result_every == 0:
                self.__save_step_result(self.train_step_result,
                                        loss_sum / length_sum, cer)

            if step % self.checkpoint_every == 0:
                Checkpoint(model, self.optimizer, self.trainset_list,
                           self.validset, epoch).save()

            # Drop the batch references before the next queue.get()
            del inputs, input_lengths, targets, logit, loss, hypothesis

        Checkpoint(model, self.optimizer, self.trainset_list, self.validset,
                   epoch).save()

        logger.info('train() completed')
        return loss_sum / length_sum, cer
Пример #29
0
def print_eval_opts(opt):
    """
    Log every evaluation option, one ``--name: value`` line per option.

    Args:
        opt: options object exposing each attribute named in the table below
    """
    option_names = (
        'dataset_path',
        'data_list_path',
        'label_path',
        'num_workers',
        'use_cuda',
        'model_path',
        'batch_size',
        'decode',
        'k',
        'print_every',
    )

    # Options are logged in the order they are listed above
    for option_name in option_names:
        logger.info('--%s: %s' % (option_name, str(getattr(opt, option_name))))
    def train(self,
              model,
              batch_size,
              epoch_time_step,
              num_epochs,
              teacher_forcing_ratio=0.99,
              resume=False):
        """
        Run training for a given model.

        Each epoch shuffles the training sets, trains for one epoch, saves a
        checkpoint, lowers the teacher forcing ratio, validates, and stores
        the per-epoch results.

        Args:
            model (torch.nn.Module): model to train
            batch_size (int): batch size for experiment
            epoch_time_step (int): number of time step for training
            num_epochs (int): number of epochs for training
            teacher_forcing_ratio (float): teaching forcing ratio (default 0.99)
            resume(bool, optional): resume training with the latest checkpoint, (default False)

        Returns:
            the trained model (the one loaded from the checkpoint when `resume` is set)
        """
        start_epoch = 0
        prev_train_cer = 1.

        if resume:
            checkpoint = Checkpoint()
            latest_checkpoint_path = checkpoint.get_latest_checkpoint()
            resume_checkpoint = checkpoint.load(latest_checkpoint_path)
            model = resume_checkpoint.model
            self.optimizer = resume_checkpoint.optimizer
            self.criterion = resume_checkpoint.criterion
            self.trainset_list = resume_checkpoint.trainset_list
            self.validset = resume_checkpoint.validset
            start_epoch = resume_checkpoint.epoch

            # The restored training sets replace the caller's, so the number
            # of steps per epoch is recomputed from them
            total_samples = sum(len(trainset) for trainset in self.trainset_list)
            epoch_time_step = math.ceil(total_samples / batch_size)

            # Resumed runs restart from a fixed learning rate
            for g in self.optimizer.optimizer.param_groups:
                g['lr'] = 1e-04

            print("Learning rate : %f" % self.optimizer.get_lr())

        logger.info('start')
        train_begin_time = time.time()

        for epoch in range(start_epoch, num_epochs):
            train_queue = queue.Queue(self.num_workers << 1)
            for trainset in self.trainset_list:
                trainset.shuffle()

            # Training
            train_loader = MultiAudioLoader(self.trainset_list, train_queue,
                                            batch_size, self.num_workers)
            train_loader.start()
            train_loss, train_cer = self.train_epoches(model, epoch,
                                                       epoch_time_step,
                                                       train_begin_time,
                                                       train_queue,
                                                       teacher_forcing_ratio)
            train_loader.join()

            Checkpoint(model, self.optimizer, self.criterion,
                       self.trainset_list, self.validset, epoch).save()
            logger.info('Epoch %d (Training) Loss %0.4f CER %0.4f' %
                        (epoch, train_loss, train_cer))

            # Install an exponential decay schedule once the training CER
            # improves by less than the threshold
            if prev_train_cer - train_cer < self.decay_threshold:
                self.optimizer.set_scheduler(
                    ExponentialDecayLR(self.optimizer.optimizer,
                                       self.optimizer.get_lr(),
                                       self.low_plateau_lr,
                                       self.exp_decay_period),
                    self.exp_decay_period)

            prev_train_cer = train_cer
            teacher_forcing_ratio -= self.teacher_forcing_step
            teacher_forcing_ratio = max(self.min_teacher_forcing_ratio,
                                        teacher_forcing_ratio)

            # Validation
            valid_queue = queue.Queue(self.num_workers << 1)
            valid_loader = AudioLoader(self.validset, valid_queue, batch_size,
                                       0)
            valid_loader.start()

            valid_cer = self.validate(model, valid_queue)
            valid_loader.join()

            # validate() returns only the CER here, so the loss is reported as 0.0
            logger.info('Epoch %d (Validate) Loss %0.4f CER %0.4f' %
                        (epoch, 0.0, valid_cer))
            self._save_epoch_result(
                train_result=[self.train_dict, train_loss, train_cer],
                valid_result=[self.valid_dict, 0.0, valid_cer])
            logger.info(
                'Epoch %d Training result saved as a csv file complete !!' %
                epoch)

        return model