Example no. 1
    def load_from_configuration(configuration):
        use_cuda = configuration['use_cuda'] and torch.cuda.is_available()  # Use cuda only if specified and available
        default_device = 'cuda' if use_cuda else 'cpu'  # Use the default cuda device if possible, otherwise the cpu
        device = configuration['use_device'] if configuration['use_device'] is not None else default_device  # Use an explicitly defined device if specified
        gpu_ids = list(range(torch.cuda.device_count())) if configuration['use_data_parallel'] else [0]  # Resolve the gpu ids if gpu parallelization is specified
        if configuration['use_device'] and ':' in configuration['use_device']:
            gpu_ids = [int(configuration['use_device'].split(':')[1])]

        use_data_parallel = bool(configuration['use_data_parallel'] and use_cuda and len(gpu_ids) > 1)

        ConsoleLogger.status('The used device is: {}'.format(device))
        ConsoleLogger.status('The gpu ids are: {}'.format(gpu_ids))

        # Sanity checks
        if not use_cuda and configuration['use_cuda']:
            ConsoleLogger.warn(
                "The configuration file specified use_cuda=True but cuda isn't available"
            )
        if configuration['use_data_parallel'] and len(gpu_ids) < 2:
            ConsoleLogger.warn(
                'The configuration file specified use_data_parallel=True but there is only {} GPU available'
                .format(len(gpu_ids)))

        return DeviceConfiguration(use_cuda, device, gpu_ids, use_data_parallel)
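
A minimal usage sketch of the resolved device configuration: the attribute names device, use_data_parallel and gpu_ids are assumed to mirror the constructor arguments above, and the helper itself is hypothetical, not part of the original code.

import torch
import torch.nn as nn

def apply_device_configuration(model, device_configuration):
    # Hypothetical helper: move the model to the resolved device
    model = model.to(device_configuration.device)
    # Wrap it in DataParallel over the resolved GPU ids when data parallelism was requested
    if device_configuration.use_data_parallel:
        model = nn.DataParallel(model, device_ids=device_configuration.gpu_ids)
    return model
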
    def compute_dataset_stats(self):
        initial_index = 0
        attempts = 10
        current_attempt = 0
        total_length = len(self._training_loader)
        train_mfccs = list()
        while current_attempt < attempts:
            try:
                i = initial_index
                train_bar = tqdm(self._training_loader, initial=initial_index)
                for data in train_bar:
                    input_features = data['input_features']
                    train_mfccs.append(input_features.detach().view(input_features.size(1), input_features.size(2)).numpy())

                    i += 1

                    if i == total_length:
                        train_bar.update(total_length)
                        break

                train_bar.close()
                break

            except KeyboardInterrupt:
                train_bar.close()
                ConsoleLogger.warn('Keyboard interrupt detected. Leaving the function...')
                return
            except Exception:
                error_message = 'An error occurred in the data loader at {}. Current attempt: {}/{}'.format(i, current_attempt + 1, attempts)
                self._logger.exception(error_message)
                ConsoleLogger.error(error_message)
                initial_index = i
                current_attempt += 1
                continue


        ConsoleLogger.status('Computing mean of the mfccs training set...')
        concatenated_train_mfccs = np.concatenate(train_mfccs)
        train_mean = concatenated_train_mfccs.mean(axis=0)

        ConsoleLogger.status('Computing std of the mfccs training set...')
        train_std = concatenated_train_mfccs.std(axis=0)

        stats = {
            'train_mean': train_mean,
            'train_std': train_std
        }

        ConsoleLogger.status('Writing stats to file...')
        with open(self._normalizer_path, 'wb') as file: # TODO: do not use hardcoded path
            pickle.dump(stats, file)

        train_mfccs_norm = (train_mfccs[0] - train_mean) / train_std

        ConsoleLogger.status('Computing example plot...')
        _, axs = plt.subplots(2, sharex=True)
        axs[0].imshow(train_mfccs[0].T, aspect='auto', origin='lower')
        axs[0].set_ylabel('Unnormalized')
        axs[1].imshow(train_mfccs_norm.T, aspect='auto', origin='lower')
        axs[1].set_ylabel('Normalized')
        plt.savefig('mfcc_normalization_comparaison.png') # TODO: do not use hardcoded path
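
A minimal sketch of reusing the stats written above; the pickle layout with the 'train_mean' and 'train_std' keys matches compute_dataset_stats(), but the helper name and signature are made up for illustration.

import pickle

def normalize_mfcc(mfcc, normalizer_path):
    # Load the stats pickled by compute_dataset_stats() and normalize each coefficient
    with open(normalizer_path, 'rb') as f:
        stats = pickle.load(f)
    return (mfcc - stats['train_mean']) / stats['train_std']
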
    def load_configuration_and_checkpoints(experiments_path, experiment_name):
        configuration_file, checkpoint_files = CheckpointUtils.search_configuration_and_checkpoints_files(
            experiments_path, experiment_name)

        # Check if a configuration file was found
        if not configuration_file:
            raise ValueError(
                'No configuration file found with name: {}'.format(
                    experiment_name))

        # Check if at least one checkpoint file was found
        if len(checkpoint_files) == 0:
            ConsoleLogger.warn(
                'No checkpoint files found with name: {}'.format(
                    experiment_name))

        return configuration_file, checkpoint_files
Example no. 4
        def process(loader, output_dir, input_features_name,
                    output_features_name, rate, input_filters_number,
                    output_filters_number, input_target_shape,
                    augment_output_features, export_one_hot_features):

            initial_index = 0
            attempts = 10
            current_attempt = 0
            total_length = len(loader)

            while current_attempt < attempts:
                try:
                    i = initial_index
                    bar = tqdm(loader, initial=initial_index)
                    for data in bar:
                        (preprocessed_audio, one_hot, speaker_id, quantized,
                         wav_filename, sampling_rate, shifting_time,
                         random_starting_index, preprocessed_length,
                         top_db) = data

                        output_path = output_dir + os.sep + str(i) + '.pickle'
                        if os.path.isfile(output_path):
                            if os.path.getsize(output_path) == 0:
                                bar.set_description(
                                    '{} already exists but is empty. Computing it again...'
                                    .format(output_path))
                                os.remove(output_path)
                            else:
                                bar.set_description(
                                    '{} already exists'.format(output_path))
                            i += 1
                            continue

                        input_features = SpeechFeatures.features_from_name(
                            name=input_features_name,
                            signal=preprocessed_audio,
                            rate=rate,
                            filters_number=input_filters_number)

                        if input_features.shape[0] != input_target_shape[0] or \
                                input_features.shape[1] != input_target_shape[1]:
                            ConsoleLogger.warn(
                                "Raw features number {} with invalid dimension {} will not be saved. Target shape: {}"
                                .format(i, input_features.shape,
                                        input_target_shape))
                            i += 1
                            continue

                        output_features = SpeechFeatures.features_from_name(
                            name=output_features_name,
                            signal=preprocessed_audio,
                            rate=rate,
                            filters_number=output_filters_number,
                            augmented=augment_output_features)

                        # TODO: add an option in configuration to save quantized/one_hot or not
                        output = {
                            'preprocessed_audio': preprocessed_audio,
                            'wav_filename': wav_filename,
                            'input_features': input_features,
                            'one_hot': one_hot if export_one_hot_features else np.array([]),
                            'quantized': np.array([]),
                            'speaker_id': speaker_id,
                            'output_features': output_features,
                            'shifting_time': shifting_time,
                            'random_starting_index': random_starting_index,
                            'preprocessed_length': preprocessed_length,
                            'sampling_rate': sampling_rate,
                            'top_db': top_db
                        }

                        with open(output_path, 'wb') as file:
                            pickle.dump(output, file)

                        bar.set_description('{} saved'.format(output_path))

                        i += 1

                        if i == total_length:
                            bar.update(total_length)
                            break

                    bar.close()
                    break
                except KeyboardInterrupt:
                    bar.close()
                    ConsoleLogger.warn(
                        'Keyboard interrupt detected. Leaving the function...')
                    return
                except Exception:
                    error_message = 'An error occurred in the data loader at {}/{}. Current attempt: {}/{}'.format(
                        output_dir, i, current_attempt + 1, attempts)
                    self._logger.exception(error_message)
                    ConsoleLogger.error(error_message)
                    initial_index = i
                    current_attempt += 1
                    continue
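
A minimal sketch of reading back one of the per-example pickles exported by process() above; the dictionary keys match the output dict written there, but the reader itself is hypothetical.

import os
import pickle

def load_exported_example(output_dir, index):
    # Hypothetical reader for the '<index>.pickle' files written by process()
    path = output_dir + os.sep + str(index) + '.pickle'
    with open(path, 'rb') as f:
        example = pickle.load(f)
    return example['input_features'], example['output_features']
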
Example no. 5
    def _many_to_one_mapping(self):
        # TODO: fix it for batch size greater than one

        tokens_selections = list()
        val_speaker_ids = set()

        with tqdm(self._data_stream.validation_loader) as bar:
            for data in bar:
                valid_originals = data['input_features'].to(
                    self._device).permute(0, 2, 1).contiguous().float()
                speaker_ids = data['speaker_id'].to(self._device)
                shifting_times = data['shifting_time'].to(self._device)
                wav_filenames = data['wav_filename']

                speaker_id = wav_filenames[0][0].split(os.sep)[-2]
                val_speaker_ids.add(speaker_id)

                if speaker_id not in os.listdir(self._vctk.raw_folder +
                                                os.sep + 'VCTK-Corpus' +
                                                os.sep + 'phonemes'):
                    # TODO: log the missing folders
                    continue

                z = self._model.encoder(valid_originals)
                z = self._model.pre_vq_conv(z)
                _, quantized, _, encodings, _, encoding_indices, _, \
                    _, _, _, _ = self._model.vq(z)
                valid_reconstructions = self._model.decoder(
                    quantized, self._data_stream.speaker_dic, speaker_ids)
                B = valid_reconstructions.size(0)

                encoding_indices = encoding_indices.view(B, -1, 1)

                for i in range(len(valid_reconstructions)):
                    wav_filename = wav_filenames[0][i]
                    utterence_key = wav_filename.split('/')[-1].replace(
                        '.wav', '')
                    phonemes_alignment_path = os.sep.join(wav_filename.split('/')[:-3]) + os.sep + 'phonemes' + os.sep + utterence_key.split('_')[0] + os.sep \
                        + utterence_key + '.TextGrid'
                    tg = textgrid.TextGrid()
                    tg.read(phonemes_alignment_path)
                    entry = {
                        'encoding_indices': encoding_indices[i].detach().cpu().numpy(),
                        'groundtruth': tg.tiers[1],
                        'shifting_time': shifting_times[i].detach().cpu().item()
                    }
                    tokens_selections.append(entry)

        ConsoleLogger.status(val_speaker_ids)

        ConsoleLogger.status('{} token selections retrieved'.format(len(tokens_selections)))

        phonemes_mapping = dict()
        # For each token selection (i.e. one entry per validation example)
        for entry in tokens_selections:
            encoding_indices = entry['encoding_indices']
            unified_encoding_indices_time_scale = self._compute_unified_time_scale(
                encoding_indices.shape[0], downsampling_factor=2)  # Compute the time scale array for each token
            """
            Search for the groundtruth phoneme whose interval contains the
            selected token's position on the time scale.
            Then add the selected token index to the list of indices collected
            for that phoneme in the phonemes_mapping dictionary.
            """
            for i in range(len(unified_encoding_indices_time_scale)):
                index_time_scale = unified_encoding_indices_time_scale[i] + entry['shifting_time']
                corresponding_phoneme = None
                for interval in entry['groundtruth']:
                    # TODO: replace that by nearest interpolation
                    if index_time_scale >= interval.minTime and index_time_scale <= interval.maxTime:
                        corresponding_phoneme = interval.mark
                        break
                if not corresponding_phoneme:
                    ConsoleLogger.warn(
                        "Corresponding phoneme not found. unified_encoding_indices_time_scale[{}]: {} "
                        "entry['shifting_time']: {} index_time_scale: {}".format(
                            i, unified_encoding_indices_time_scale[i],
                            entry['shifting_time'], index_time_scale))
                if corresponding_phoneme not in phonemes_mapping:
                    phonemes_mapping[corresponding_phoneme] = list()
                phonemes_mapping[corresponding_phoneme].append(
                    encoding_indices[i][0])

        ConsoleLogger.status('phonemes_mapping: {}'.format(phonemes_mapping))

        tokens_mapping = dict()  # dictionary containing, for each token, its distribution over the phonemes it fits
        """
        Fill tokens_mapping such that for each token index (key) it contains
        the list of tuples (phoneme, prob), where prob is the probability that
        the token fits this phoneme.
        """
        for phoneme, indices in phonemes_mapping.items():
            for index in list(set(indices)):
                if index not in tokens_mapping:
                    tokens_mapping[index] = list()
                tokens_mapping[index].append(
                    (phoneme, indices.count(index) / len(indices)))

        # Sort the probabilities for each token
        for index, distribution in tokens_mapping.items():
            tokens_mapping[index] = list(
                sorted(distribution, key=lambda x: x[1], reverse=True))

        ConsoleLogger.status('tokens_mapping: {}'.format(tokens_mapping))

        with open(
                self._results_path + os.sep + self._experiment_name +
                '_phonemes_mapping.pickle', 'wb') as f:
            pickle.dump(phonemes_mapping, f)

        with open(
                self._results_path + os.sep + self._experiment_name +
                '_tokens_mapping.pickle', 'wb') as f:
            pickle.dump(tokens_mapping, f)
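
A toy, self-contained rerun of the phonemes_mapping -> tokens_mapping inversion above, with made-up data, to make the probability definition concrete: token 3 covers two of the three frames aligned with 'ah', so tokens_mapping[3] contains ('ah', 2/3).

phonemes_mapping = {'ah': [3, 3, 7], 'iy': [7]}  # made-up token indices collected per phoneme
tokens_mapping = {}
for phoneme, indices in phonemes_mapping.items():
    for index in set(indices):
        tokens_mapping.setdefault(index, []).append((phoneme, indices.count(index) / len(indices)))
# Sort the probabilities for each token, highest first
for index, distribution in tokens_mapping.items():
    tokens_mapping[index] = sorted(distribution, key=lambda x: x[1], reverse=True)
# tokens_mapping -> {3: [('ah', 0.67)], 7: [('iy', 1.0), ('ah', 0.33)]} (probabilities rounded)
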
    def compute_groundtruth_alignments(self):
        ConsoleLogger.status(
            'Computing groundtruth alignments of VCTK val dataset...')

        desired_time_interval = 0.02
        extended_alignment_dataset = list()
        possible_phonemes = set()
        phonemes_counter = dict()
        total_phonemes_apparations = 0
        data_length = self._configuration['length'] / self._configuration['sampling_rate']

        with tqdm(self._data_stream.validation_loader) as bar:
            for data in bar:
                speaker_ids = data['speaker_id'].to(self._device)
                wav_filenames = data['wav_filename']
                shifting_times = data['shifting_time'].to(self._device)
                loader_indices = data['index'].to(self._device)

                speaker_id = wav_filenames[0][0].split('/')[-2]
                if speaker_id not in os.listdir(self._vctk.raw_folder +
                                                os.sep + 'VCTK-Corpus' +
                                                os.sep + 'phonemes'):
                    # TODO: log the missing folders
                    continue

                for i in range(len(shifting_times)):
                    wav_filename = wav_filenames[0][i]
                    utterence_key = wav_filename.split('/')[-1].replace(
                        '.wav', '')
                    phonemes_alignment_path = os.sep.join(wav_filename.split('/')[:-3]) + os.sep + 'phonemes' + os.sep + utterence_key.split('_')[0] + os.sep \
                        + utterence_key + '.TextGrid'
                    if not os.path.isfile(phonemes_alignment_path):
                        # TODO: log this warn instead of print it
                        #ConsoleLogger.warn('File {} not found'.format(phonemes_alignment_path))
                        break

                    shifting_time = shifting_times[0].detach().cpu().item()
                    target_time_scale = np.arange(
                        (data_length / desired_time_interval) +
                        1) * desired_time_interval + shifting_time
                    shifted_indices = np.where(
                        target_time_scale >= shifting_time)
                    tg = textgrid.TextGrid()
                    tg.read(phonemes_alignment_path)
                    """if target_time_scale[-1] > tg.tiers[1][-1].maxTime:
                        ConsoleLogger.error('Shifting time error at {}.pickle: shifting_time:{}' \
                            ' target_time_scale[-1]:{} > tg.tiers[1][-1].maxTime:{}'.format(
                            loader_indices[i].detach().cpu().item(),
                            shifting_time,
                            target_time_scale[-1],
                            tg.tiers[1][-1].maxTime))
                        continue"""

                    phonemes = list()
                    current_target_time_index = 0
                    for interval in tg.tiers[1]:
                        if interval.mark in ['', '-', "'"]:
                            if interval == tg.tiers[1][-1] and len(
                                    phonemes) != int(
                                        data_length / desired_time_interval):
                                previous_interval = tg.tiers[1][-2]
                                ConsoleLogger.warn(
                                    "{}/{} phonemes aligned. Add the last valid phoneme '{}' in the list to have the correct number.\n"
                                    "Sanity checks to find the possible cause:\n"
                                    "current_target_time_index < (data_length / desired_time_interval): {}\n"
                                    "target_time_scale[current_target_time_index] >= interval.minTime: {}\n"
                                    "target_time_scale[current_target_time_index] <= interval.maxTime: {}"
                                    .format(
                                        len(phonemes),
                                        int(data_length /
                                            desired_time_interval),
                                        previous_interval.mark,
                                        current_target_time_index <
                                        (data_length / desired_time_interval),
                                        target_time_scale[
                                            current_target_time_index] >=
                                        previous_interval.minTime,
                                        target_time_scale[
                                            current_target_time_index] <=
                                        previous_interval.maxTime))
                                phonemes.append(previous_interval.mark)
                            continue
                        interval.minTime = float(interval.minTime)
                        interval.maxTime = float(interval.maxTime)
                        if interval.maxTime < shifting_time:
                            continue
                        interval.mark = interval.mark[:-1] if interval.mark[-1].isdigit() else interval.mark
                        possible_phonemes.add(interval.mark)
                        if interval.mark not in phonemes_counter:
                            phonemes_counter[interval.mark] = 0
                        phonemes_counter[interval.mark] += 1
                        total_phonemes_apparations += 1
                        while current_target_time_index < (data_length / desired_time_interval) and \
                            target_time_scale[current_target_time_index] >= interval.minTime and \
                            target_time_scale[current_target_time_index] <= interval.maxTime:
                            phonemes.append(interval.mark)
                            current_target_time_index += 1
                        if len(phonemes) == int(data_length /
                                                desired_time_interval):
                            break
                    if len(phonemes) != int(
                            data_length / desired_time_interval):
                        intervals = [
                            'min:{} max:{} mark:{}'.format(
                                interval.minTime, interval.maxTime,
                                interval.mark) for interval in tg.tiers[1]
                        ]
                        ConsoleLogger.error(
                            'Error - min:{} max:{} shifting:{} target_time_scale: {} intervals: {}'
                            .format(interval.minTime, interval.maxTime,
                                    shifting_time, target_time_scale,
                                    intervals))
                        ConsoleLogger.error('#phonemes:{} phonemes:{}'.format(
                            len(phonemes), phonemes))
                    else:
                        extended_alignment_dataset.append(
                            (utterence_key, phonemes))

        with open(
                self._results_path + os.sep +
                'vctk_groundtruth_alignments.pickle', 'wb') as f:
            pickle.dump(
                {
                    'desired_time_interval': desired_time_interval,
                    'extended_alignment_dataset': extended_alignment_dataset,
                    'possible_phonemes': list(possible_phonemes),
                    'phonemes_counter': phonemes_counter,
                    'total_phonemes_apparations': total_phonemes_apparations
                }, f)
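
A minimal sketch of consuming the alignment pickle written above; the dictionary keys match what compute_groundtruth_alignments() dumps, but the helper itself is hypothetical.

import os
import pickle

def load_phoneme_frequencies(results_path):
    # Hypothetical reader for the groundtruth alignments pickle written above
    with open(results_path + os.sep + 'vctk_groundtruth_alignments.pickle', 'rb') as f:
        alignments = pickle.load(f)
    total = alignments['total_phonemes_apparations']
    # Relative frequency of each phoneme across the validation set
    return {phoneme: count / total for phoneme, count in alignments['phonemes_counter'].items()}
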
    def compute_empirical_bigrams_matrix(self, wo_diag=True):
        ConsoleLogger.status(
            'Computing empirical bigrams matrix of VCTK val dataset {}...'.
            format('without diagonal' if wo_diag else 'with diagonal'))

        alignments_dic = None
        with open(
                self._results_path + os.sep + self._experiment_name +
                '_vctk_empirical_alignments.pickle', 'rb') as f:
            alignments_dic = pickle.load(f)

        all_alignments = alignments_dic['all_alignments']
        encodings_counter = alignments_dic['encodings_counter']
        desired_time_interval = alignments_dic['desired_time_interval']
        total_indices_apparations = alignments_dic['total_indices_apparations']
        num_embeddings = alignments_dic['num_embeddings']

        if num_embeddings > 100:
            ConsoleLogger.warn(
                'Stopping the computation of the empirical bigrams matrix because the number of embeddings ({}) is too large'
                .format(num_embeddings))
            return

        bigrams = np.zeros((num_embeddings, num_embeddings), dtype=int)
        previous_index_counter = np.zeros((num_embeddings), dtype=int)

        for _, alignment in all_alignments:
            previous_encoding_index = alignment[0]
            for i in range(len(alignment)):
                current_encoding_index = alignment[i]
                bigrams[current_encoding_index][previous_encoding_index] += 1
                previous_index_counter[previous_encoding_index] += 1
                previous_encoding_index = current_encoding_index

        if wo_diag:
            np.fill_diagonal(bigrams, 0)  # Zeroes the diagonal values
        previous_index_counter[previous_index_counter == 0] = 1  # Replace zero counts by one to avoid division by zero
        bigrams = normalize(bigrams / previous_index_counter,
                            axis=1,
                            norm='l1')
        round_bigrams = np.around(bigrams.copy(), decimals=2)

        fig, ax = plt.subplots(figsize=(20, 20))

        im = ax.matshow(round_bigrams)
        ax.set_xticks(np.arange(num_embeddings))
        ax.set_yticks(np.arange(num_embeddings))
        ax.set_xticklabels(np.arange(num_embeddings))
        ax.set_yticklabels(np.arange(num_embeddings))
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=0.05)
        fig.colorbar(im, cax=cax)

        for i in range(num_embeddings):
            for j in range(num_embeddings):
                text = ax.text(j,
                               i,
                               round_bigrams[i, j],
                               ha='center',
                               va='center',
                               color='w')

        output_path = self._results_path + os.sep + self._experiment_name + '_vctk_empirical_bigrams_{}{}ms'.format(
            'wo_diag_' if wo_diag else '', int(desired_time_interval * 1000))

        fig.tight_layout()
        fig.savefig(output_path + '.png')
        plt.close(fig)

        with open(output_path + '.npy', 'wb') as f:
            np.save(f, bigrams)
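
A hypothetical follow-up (the file path is assumed for illustration, following the naming scheme above): reload the saved row-normalized bigram matrix and inspect, for each token, its most frequent predecessor. As in the counting loop above, rows index the current token and columns the previous one.

import numpy as np

# Assumed output path, following the '<experiment>_vctk_empirical_bigrams_wo_diag_20ms.npy' naming scheme
with open('results/experiment_vctk_empirical_bigrams_wo_diag_20ms.npy', 'rb') as f:
    bigrams = np.load(f)
# Each row's argmax is that token's most frequent predecessor
most_frequent_predecessor = bigrams.argmax(axis=1)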