Example 1
    def run(self):
        """Worker loop: batch items from MelSpectrogramDataset and feed the queue."""
        logger.debug('loader %d start' % self.thread_id)

        while True:
            batch_items = []

            for _ in range(self.batch_size):
                if self.index >= self.dataset_count:
                    break  # dataset exhausted

                feature_vector, transcript = self.dataset.get_item(self.index)
                if feature_vector is not None:
                    # Items that failed to load (None) are silently skipped.
                    batch_items.append((feature_vector, transcript))
                self.index += 1

            if not batch_items:
                # An empty batch is the sentinel telling consumers this
                # loader has finished.
                self.queue.put(self._create_empty_batch())
                break

            self.queue.put(self.collate_fn(batch_items, self.pad_id))

        logger.debug('loader %d stop' % self.thread_id)
Example 2
def load_audio(audio_path: str, del_silence: bool = False, extension: str = 'pcm') -> np.ndarray:
    """
    Load an audio file into a float32 signal.

    For raw PCM files the int16 samples are read via ``numpy.memmap`` and
    normalized to [-1, 1]. If ``del_silence`` is True, all segments below
    30 dB are removed first. WAV/FLAC files are loaded through librosa at
    16 kHz. Returns ``None`` on read/decode failure or an unsupported
    extension.

    Args:
        audio_path (str): path of the audio file
        del_silence (bool): if True, eliminate all sounds below 30dB (PCM only)
        extension (str): one of 'pcm', 'wav', 'flac'

    Returns:
        np.ndarray: the loaded signal, or None on failure
    """
    try:
        if extension == 'pcm':
            signal = np.memmap(audio_path, dtype='h', mode='r').astype('float32')

            if del_silence:
                non_silence_indices = split(signal, top_db=30)
                signal = np.concatenate([signal[start:end] for start, end in non_silence_indices])

            return signal / 32767  # normalize int16 full scale to [-1, 1]

        elif extension in ('wav', 'flac'):
            signal, _ = librosa.load(audio_path, sr=16000)
            return signal

        return None  # unsupported extension (was an implicit fall-through)

    # One handler for all three failure modes; the message reproduces the
    # per-type text of the original separate clauses.
    except (ValueError, RuntimeError, IOError) as e:
        logger.debug('{0} in {1}'.format(type(e).__name__, audio_path))
        return None
Example 3
def load_audio(audio_path, del_silence):
    """
    Load a raw PCM audio file into a float32 signal normalized to [-1, 1].

    If ``del_silence`` is True, all segments below 30 dB are removed before
    normalization. Returns ``None`` if reading or decoding fails.

    Args:
        audio_path (str): path of the PCM file
        del_silence (bool): if True, eliminate all sounds below 30dB

    Returns:
        np.ndarray: the loaded signal, or None on failure
    """
    try:
        signal = np.memmap(audio_path, dtype='h', mode='r').astype('float32')

        if del_silence:
            non_silence_indices = split(signal, top_db=30)
            signal = np.concatenate([signal[start:end] for start, end in non_silence_indices])

        return signal / 32767  # normalize int16 full scale to [-1, 1]

    # One handler for all three failure modes; the message reproduces the
    # per-type text of the original separate clauses.
    except (ValueError, RuntimeError, IOError) as e:
        logger.debug('{0} in {1}'.format(type(e).__name__, audio_path))
        return None
Example 4
    def _train_epoches(
        self,
        model: nn.Module,
        epoch: int,
        epoch_time_step: int,
        train_begin_time: float,
        queue: queue.Queue,
        teacher_forcing_ratio: float,
    ) -> Tuple[nn.Module, float, float]:
        """
        Run one training epoch, consuming batches from the loader queue
        until every loader worker has signaled completion.

        Args:
            model (torch.nn.Module): model to train
            epoch (int): number of current epoch
            epoch_time_step (int): total time step in one epoch
            train_begin_time (float): time of train begin
            queue (queue.Queue): training queue, containing input, targets, input_lengths, target_lengths
            teacher_forcing_ratio (float): teacher forcing ratio (default 0.99)

        Returns: model, loss, cer
            - **model** (torch.nn.Module): the trained model
            - **loss** (float): average loss of current epoch (total loss / total input length)
            - **cer** (float): character error rate of current epoch
        """
        # Resolve the concrete conformer variant: transducer ('conformer_t')
        # when a decoder module is attached, plain CTC otherwise. DataParallel
        # wraps the real model in `.module`, hence the two branches.
        architecture = self.architecture
        if self.architecture == 'conformer':
            if isinstance(model, nn.DataParallel):
                architecture = 'conformer_t' if model.module.decoder is not None else 'conformer_ctc'
            else:
                architecture = 'conformer_t' if model.decoder is not None else 'conformer_ctc'

        cer = 1.0  # worst-case CER until first measurement
        epoch_loss_total = 0.
        total_num = 0  # running sum of input lengths; normalizes the loss
        timestep = 0

        model.train()

        begin_time = epoch_begin_time = time.time()
        num_workers = self.num_workers  # local count of loader workers still alive

        while True:
            inputs, targets, input_lengths, target_lengths = queue.get()

            if inputs.shape[0] == 0:
                # Empty feats means closing one loader
                num_workers -= 1
                logger.debug('left train_loader: %d' % num_workers)

                if num_workers == 0:
                    break
                else:
                    continue

            self.optimizer.zero_grad()

            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            input_lengths = input_lengths.to(self.device)
            target_lengths = torch.as_tensor(target_lengths).to(self.device)

            model = model.to(self.device)
            # Dispatches to the architecture-specific forward; ctc_loss and
            # cross_entropy_loss are only meaningful for joint CTC-attention.
            output, loss, ctc_loss, cross_entropy_loss = self._model_forward(
                teacher_forcing_ratio=teacher_forcing_ratio,
                inputs=inputs,
                input_lengths=input_lengths,
                targets=targets,
                target_lengths=target_lengths,
                model=model,
                architecture=architecture,
            )

            # Transducer outputs are not greedy-decodable this way, so CER is
            # only tracked for non-transducer architectures.
            if architecture not in ('rnnt', 'conformer_t'):
                y_hats = output.max(-1)[1]
                # targets[:, 1:] drops the leading token — presumably SOS; verify against dataset
                cer = self.metric(targets[:, 1:], y_hats)

            loss.backward()
            # NOTE(review): project optimizer wrapper takes the model
            # (non-standard signature vs torch.optim).
            self.optimizer.step(model)

            total_num += int(input_lengths.sum())
            epoch_loss_total += loss.item()

            timestep += 1
            torch.cuda.empty_cache()  # release cached GPU memory each step

            if timestep % self.print_every == 0:
                current_time = time.time()
                elapsed = current_time - begin_time          # seconds since last print
                epoch_elapsed = (current_time - epoch_begin_time) / 60.0   # minutes
                train_elapsed = (current_time - train_begin_time) / 3600.0  # hours

                # Transducer, joint CTC-attention, and plain models each use
                # their own log format with different fields.
                if architecture in ('rnnt', 'conformer_t'):
                    logger.info(
                        self.rnnt_log_format.format(
                            timestep,
                            epoch_time_step,
                            loss,
                            elapsed,
                            epoch_elapsed,
                            train_elapsed,
                            self.optimizer.get_lr(),
                        ))
                else:
                    if self.joint_ctc_attention:
                        logger.info(
                            self.log_format.format(
                                timestep,
                                epoch_time_step,
                                loss,
                                ctc_loss,
                                cross_entropy_loss,
                                cer,
                                elapsed,
                                epoch_elapsed,
                                train_elapsed,
                                self.optimizer.get_lr(),
                            ))
                    else:
                        logger.info(
                            self.log_format.format(
                                timestep,
                                epoch_time_step,
                                loss,
                                cer,
                                elapsed,
                                epoch_elapsed,
                                train_elapsed,
                                self.optimizer.get_lr(),
                            ))
                begin_time = time.time()

            if timestep % self.save_result_every == 0:
                self._save_step_result(self.train_step_result,
                                       epoch_loss_total / total_num, cer)

            if timestep % self.checkpoint_every == 0:
                Checkpoint(model, self.optimizer, self.trainset_list,
                           self.validset, epoch).save()

            # Drop batch tensors before the next queue.get() to keep peak
            # GPU memory down.
            del inputs, input_lengths, targets, output, loss

        # Final end-of-epoch checkpoint.
        Checkpoint(model, self.optimizer, self.trainset_list, self.validset,
                   epoch).save()
        logger.info('train() completed')

        return model, epoch_loss_total / total_num, cer
Example 5
    def __train_epoches(self, model: nn.Module, epoch: int,
                        epoch_time_step: int, train_begin_time: float,
                        queue: queue.Queue,
                        teacher_forcing_ratio: float) -> Tuple[float, float]:
        """
        Run one training epoch, consuming batches from the loader queue
        until every loader worker has signaled completion.

        Args:
            model (torch.nn.Module): model to train
            epoch (int): number of current epoch
            epoch_time_step (int): total time step in one epoch
            train_begin_time (float): time of train begin
            queue (queue.Queue): training queue, containing input, targets, input_lengths, target_lengths
            teacher_forcing_ratio (float): teacher forcing ratio (default 0.99)

        Returns: loss, cer
            - **loss** (float): average loss of current epoch (total loss / total input length)
            - **cer** (float): character error rate of current epoch
        """
        cer = 1.0  # worst-case CER until first measurement
        epoch_loss_total = 0.
        total_num = 0  # running sum of input lengths; normalizes the loss
        timestep = 0

        model.train()

        begin_time = epoch_begin_time = time.time()
        num_workers = self.num_workers  # local count of loader workers still alive

        while True:
            inputs, targets, input_lengths, target_lengths = queue.get()

            if inputs.shape[0] == 0:
                # Empty feats means closing one loader
                num_workers -= 1
                logger.debug('left train_loader: %d' % num_workers)

                if num_workers == 0:
                    break
                else:
                    continue

            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            model = model.to(self.device)

            if self.architecture == 'las':
                # RNN weights must be flattened after a device move for
                # efficient cuDNN execution; DataParallel hides the real
                # model behind `.module`.
                if isinstance(model, nn.DataParallel):
                    model.module.flatten_parameters()
                else:
                    model.flatten_parameters()

                logit = model(inputs=inputs,
                              input_lengths=input_lengths,
                              targets=targets,
                              teacher_forcing_ratio=teacher_forcing_ratio)
                # LAS returns per-step outputs; stack them into (batch, time, vocab).
                logit = torch.stack(logit, dim=1).to(self.device)
                # Drop the leading token — presumably SOS; verify against dataset.
                targets = targets[:, 1:]

            elif self.architecture == 'transformer':
                logit = model(inputs,
                              input_lengths,
                              targets,
                              return_attns=False)

            else:
                raise ValueError("Unsupported architecture : {0}".format(
                    self.architecture))

            hypothesis = logit.max(-1)[1]  # greedy decode
            # Flatten (batch, time, vocab) -> (batch*time, vocab) for the criterion.
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  targets.contiguous().view(-1))
            epoch_loss_total += loss.item()

            cer = self.metric(targets, hypothesis)
            total_num += int(input_lengths.sum())

            self.optimizer.zero_grad()
            loss.backward()
            # NOTE(review): project optimizer wrapper takes the model
            # (non-standard signature vs torch.optim).
            self.optimizer.step(model)

            timestep += 1
            torch.cuda.empty_cache()  # release cached GPU memory each step

            if timestep % self.print_every == 0:
                current_time = time.time()
                elapsed = current_time - begin_time          # seconds since last print
                epoch_elapsed = (current_time - epoch_begin_time) / 60.0   # minutes
                train_elapsed = (current_time - train_begin_time) / 3600.0  # hours

                logger.info(
                    'timestep: {:4d}/{:4d}, loss: {:.4f}, cer: {:.2f}, elapsed: {:.2f}s {:.2f}m {:.2f}h, lr: {:.5f}'
                    .format(timestep, epoch_time_step,
                            epoch_loss_total / total_num, cer, elapsed,
                            epoch_elapsed, train_elapsed,
                            self.optimizer.get_lr()))
                begin_time = time.time()

            if timestep % self.save_result_every == 0:
                self.__save_step_result(self.train_step_result,
                                        epoch_loss_total / total_num, cer)

            if timestep % self.checkpoint_every == 0:
                Checkpoint(model, self.optimizer, self.trainset_list,
                           self.validset, epoch).save()

            # Drop batch tensors before the next queue.get() to keep peak
            # GPU memory down.
            del inputs, input_lengths, targets, logit, loss, hypothesis

        # Final end-of-epoch checkpoint.
        Checkpoint(model, self.optimizer, self.trainset_list, self.validset,
                   epoch).save()

        logger.info('train() completed')
        return epoch_loss_total / total_num, cer
    def train_epoches(self, model, epoch, epoch_time_step, train_begin_time,
                      queue, teacher_forcing_ratio):
        """
        Run one training epoch, consuming batches from the loader queue
        until every loader worker has signaled completion.

        Args:
            model (torch.nn.Module): model to train
            epoch (int): number of current epoch
            epoch_time_step (int): total time step in one epoch
            train_begin_time (int): time of train begin
            queue (queue.Queue): training queue, containing input, targets, input_lengths, target_lengths
            teacher_forcing_ratio (float): teacher forcing ratio (default 0.99)

        Returns: loss, cer
            - **loss** (float): average loss of current epoch (total loss / total input length)
            - **cer** (float): character error rate of current epoch
        """
        cer = 1.0  # worst-case CER until first measurement
        epoch_loss_total = 0.
        total_num = 0  # running sum of input lengths; normalizes the loss
        timestep = 0

        model.train()
        begin_time = epoch_begin_time = time.time()

        while True:
            inputs, scripts, input_lengths, target_lengths = queue.get()

            if inputs.shape[0] == 0:
                # Empty feats means closing one loader
                # NOTE(review): mutates self.num_workers directly, so this
                # method is single-use per trainer instance unless the count
                # is reset by the caller.
                self.num_workers -= 1
                logger.debug('left train_loader: %d' % self.num_workers)

                if self.num_workers == 0:
                    break
                else:
                    continue

            inputs = inputs.to(self.device)
            scripts = scripts.to(self.device)
            # Drop the leading token — presumably SOS; verify against dataset.
            targets = scripts[:, 1:]

            # Assumes model is wrapped in DataParallel (unconditional .module).
            model.module.flatten_parameters()
            # [0] selects the decoder output sequence from the model's return.
            output = model(inputs,
                           input_lengths,
                           scripts,
                           teacher_forcing_ratio=teacher_forcing_ratio)[0]

            # Stack per-step outputs into (batch, time, vocab).
            logit = torch.stack(output, dim=1).to(self.device)
            hypothesis = logit.max(-1)[1]  # greedy decode

            # Flatten (batch, time, vocab) -> (batch*time, vocab) for the criterion.
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  targets.contiguous().view(-1))
            epoch_loss_total += loss.item()

            cer = self.metric(targets, hypothesis)
            total_num += int(input_lengths.sum())

            self.optimizer.zero_grad()
            loss.backward()
            # NOTE(review): project optimizer wrapper takes the model and the
            # scalar loss (non-standard signature vs torch.optim).
            self.optimizer.step(model, loss.item())

            timestep += 1
            torch.cuda.empty_cache()  # release cached GPU memory each step

            if timestep % self.print_every == 0:
                current_time = time.time()
                elapsed = current_time - begin_time          # seconds since last print
                epoch_elapsed = (current_time - epoch_begin_time) / 60.0   # minutes
                train_elapsed = (current_time - train_begin_time) / 3600.0  # hours

                logger.info(
                    'timestep: {:4d}/{:4d}, loss: {:.4f}, cer: {:.2f}, elapsed: {:.2f}s {:.2f}m {:.2f}h'
                    .format(timestep, epoch_time_step,
                            epoch_loss_total / total_num, cer, elapsed,
                            epoch_elapsed, train_elapsed))
                begin_time = time.time()

            if timestep % self.save_result_every == 0:
                self._save_step_result(self.train_step_result,
                                       epoch_loss_total / total_num, cer)

            if timestep % self.checkpoint_every == 0:
                Checkpoint(model, self.optimizer, self.criterion,
                           self.trainset_list, self.validset, epoch).save()

            # Drop batch tensors before the next queue.get() to keep peak
            # GPU memory down.
            del inputs, input_lengths, scripts, targets, output, logit, loss, hypothesis

        # Final end-of-epoch checkpoint.
        Checkpoint(model, self.optimizer, self.criterion, self.trainset_list,
                   self.validset, epoch).save()

        logger.info('train() completed')
        return epoch_loss_total / total_num, cer