Пример #1
0
    def predict_method(
            self,
            data: List[List[str]],
            max_seq_len: int = 128,
            batch_size: int = 1,
            use_gpu: bool = False
    ):
        """
        Serving entry point that dispatches a request according to ``self.task``.

        The task is taken from the serving config and selects one of:
        1. seq-cls: sequence classification;
        2. token-cls: sequence labeling;
        3. None: embedding.

        Args:
            data (obj:`List(List(str))`): Processed inputs; each element is a list holding a single text or a pair of texts.
            max_seq_len (:obj:`int`, `optional`, defaults to 128):
                Maximum total sequence length, forwarded to `predict` / `get_embedding`.
            batch_size(obj:`int`, defaults to 1): The number of batch.
            use_gpu(obj:`bool`, defaults to `False`): Whether to use gpu to run or not.

        Returns:
            For a supported classification task, the list produced by `self.predict`
            (trimmed per sample when the task is 'token-cls').
            For the embedding task (``self.task is None``), the tuple
            ``(token_results, sentence_results)`` with the token results trimmed per sample.
            ``None`` when the task is unknown (an error is logged in that case).
        """

        def _trim(per_sample):
            # Keep positions 1 .. len(first text) of every sample, which drops the
            # leading [CLS] slot and whatever follows the text (pad positions).
            trimmed = []
            for idx, sequence in enumerate(per_sample):
                keep = len(data[idx][0])
                trimmed.append(sequence[1:keep + 1])
            return trimmed

        # Classification services (sequence- or token-level).
        if self.task in self._tasks_supported:
            if self.label_map:
                # A label_map decoded from JSON has string keys; convert them back to ints.
                self.label_map = dict((int(key), label) for key, label in self.label_map.items())
            predictions = self.predict(data, max_seq_len, batch_size, use_gpu)
            if self.task == 'token-cls':
                predictions = _trim(predictions)
            return predictions

        # Embedding service.
        if self.task is None:
            token_embs, sentence_embs = self.get_embedding(data, max_seq_len, batch_size, use_gpu)
            return _trim(token_embs), sentence_embs

        # Anything else is an unknown service: report it and return nothing.
        logger.error(
            f'Unknown task {self.task}, current tasks supported:\n'
            '1. seq-cls: sequence classification service;\n'
            '2. token-cls: sequence labeling service;\n'
            '3. None: embedding service'
        )
        return None
Пример #2
0
def extract_melspectrogram(y,
                           sample_rate: int = 32000,
                           window_size: int = 1024,
                           hop_size: int = 320,
                           mel_bins: int = 64,
                           fmin: int = 50,
                           fmax: int = 14000,
                           window: str = 'hann',
                           center: bool = True,
                           pad_mode: str = 'reflect',
                           ref: float = 1.0,
                           amin: float = 1e-10,
                           top_db: float = None):
    '''
    Extract a dB-scaled Mel spectrogram from a waveform.

    Pipeline: STFT -> power spectrum -> Mel filter bank -> power_to_db -> transpose.

    Args:
        y: Waveform handed to `librosa.stft`.
        sample_rate: Sampling rate used to build the Mel filter bank.
        window_size: Used both as the FFT size and as the window length of the STFT.
        hop_size: Hop length of the STFT.
        mel_bins: Number of Mel bands in the filter bank.
        fmin: Lowest frequency of the Mel filter bank.
        fmax: Highest frequency of the Mel filter bank.
        window: Window specification forwarded to `librosa.stft`.
        center: Forwarded to `librosa.stft`.
        pad_mode: Forwarded to `librosa.stft`.
        ref: Reference value forwarded to `librosa.power_to_db`.
        amin: Minimum threshold forwarded to `librosa.power_to_db`.
        top_db: Forwarded to `librosa.power_to_db`; `None` (the default)
            leaves the output dynamic range unclamped.

    Returns:
        The dB-scaled Mel spectrogram, transposed relative to what
        `librosa.power_to_db` returns (presumably time-major, i.e.
        (frames, mel_bins) -- confirm against callers).

    Raises:
        Whatever exception importing librosa raised; it is logged and re-raised.
    '''
    # librosa is imported lazily so the module can be loaded without it.
    try:
        import librosa
    except Exception:
        logger.error(
            'Failed to import librosa. Please check that librosa and numba are correctly installed.'
        )
        raise

    s = librosa.stft(y,
                     n_fft=window_size,
                     hop_length=hop_size,
                     win_length=window_size,
                     window=window,
                     center=center,
                     pad_mode=pad_mode)

    # Squared magnitude of the complex STFT gives the power spectrum.
    power = np.abs(s)**2
    melW = librosa.filters.mel(sr=sample_rate,
                               n_fft=window_size,
                               n_mels=mel_bins,
                               fmin=fmin,
                               fmax=fmax)
    mel = np.matmul(melW, power)
    # Forward the caller's top_db; it was previously hard-coded to None here,
    # which silently ignored the argument.
    db = librosa.power_to_db(mel, ref=ref, amin=amin, top_db=top_db)
    return db.transpose()
Пример #3
0
    def __init__(self, extract_embedding: bool = True, checkpoint: str = None):
        """
        Build the CNN14 layers and optionally restore parameters from a checkpoint.

        Args:
            extract_embedding: Stored on the instance as ``self.extract_embedding``;
                it is not otherwise used in this method.
            checkpoint: Path to a parameter file. It is loaded only when it is not
                None and names an existing file; otherwise an error is logged and
                no parameters are loaded.
        """
        super(CNN14, self).__init__()
        self.bn0 = nn.BatchNorm2D(64)

        # conv_block1 .. conv_block6, created in order; each block's out_channels
        # is the next block's in_channels.
        widths = (1, 64, 128, 256, 512, 1024, 2048)
        for index in range(1, len(widths)):
            block = ConvBlock(in_channels=widths[index - 1], out_channels=widths[index])
            setattr(self, f'conv_block{index}', block)

        # NOTE(review): emb_size is not assigned in this method; presumably a
        # class-level attribute -- verify.
        self.fc1 = nn.Linear(2048, self.emb_size)
        self.fc_audioset = nn.Linear(self.emb_size, 527)

        checkpoint_usable = checkpoint is not None and os.path.isfile(checkpoint)
        if not checkpoint_usable:
            logger.error(
                'No valid checkpoints for CNN14. Start training from scratch.')
        else:
            self.set_state_dict(paddle.load(checkpoint))
            logger.info(
                f'Loaded CNN14 pretrained parameters from: {checkpoint}')

        self.extract_embedding = extract_embedding