Example #1
    def fetch(self):
        ConsoleLogger.status('Begin fetch...')
        font1 = {
            'family': 'Times New Roman',
            'weight': 'normal',
            'size': 40,
        }
        indices = torch.LongTensor(np.load('indices.npy')).to(self._device)
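        # Numbers of leading code indices to decode and, presumably, the matching
        # fractions of the full code length (32/512, 128/512, 512/512) used as plot titles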
        idx = [32, 128, 512]
        title = [0.0625, 0.25, 1.0]
        os.makedirs('./audio', exist_ok=True)
        os.makedirs('./img', exist_ok=True)

        for i in range(len(indices)):

            fig, axs = plt.subplots(len(idx), 1, figsize=(35, 30), sharex=True)

            for j in range(len(idx)):
                axs[j].set_title('Fraction of Dimensions:' + str(title[j]),
                                 font1)
                x = self._model.indices_fetch(indices[i][:idx[j]].unsqueeze(
                    0))[0, 0].detach().cpu().numpy()
                axs[j].plot(np.arange(len(x)), x)
                write(
                    "./audio/sample_" + str(i) + "_" + str(title[j]) + ".wav",
                    16000, x)

            plt.savefig('./img/sample_' + str(i),
                        bbox_inches='tight',
                        pad_inches=0)
            plt.clf()
Example #2
 def train(self):
     ConsoleLogger.status("Running the experiment called '{}'".format(
         self._name))
     ConsoleLogger.status('Beginning to train the model')
     self._trainer.train()
     ConsoleLogger.success(
         "Successfully ran the experiment called '{}'".format(self._name))
Example #3
 def evaluate(self, evaluation_options):
     ConsoleLogger.status("Running the experiment called '{}'".format(
         self._name))
     ConsoleLogger.status('Beginning to evaluate the model')
     self._evaluator.evaluate(evaluation_options)
     ConsoleLogger.success(
         "Successfully ran the experiment called '{}'".format(self._name))
Example #4
    def test_global_conditioning(self):
        configuration = None
        with open('../../configurations/vctk_features.yaml', 'r') as configuration_file:
            configuration = yaml.load(configuration_file, Loader=yaml.FullLoader)
        device_configuration = DeviceConfiguration.load_from_configuration(configuration)
        data_stream = VCTKSpeechStream(configuration, device_configuration.gpu_ids, device_configuration.use_cuda)
        (x_enc, x_dec, speaker_id, _, _) = next(iter(data_stream.training_loader))

        ConsoleLogger.status('x_enc.size(): {}'.format(x_enc.size()))
        ConsoleLogger.status('x_dec.size(): {}'.format(x_dec.size()))

        x = x_dec.squeeze(-1)
        global_conditioning = GlobalConditioning.compute(
            speaker_dic=data_stream.speaker_dic,
            speaker_ids=speaker_id,
            x_one_hot=x,
            expand=False
        )
        self.assertEqual(global_conditioning.size(), torch.Size([1, 128, 1]))
        ConsoleLogger.success('global_conditioning.size(): {}'.format(global_conditioning.size()))

        expanded_global_conditioning = GlobalConditioning.compute(
            speaker_dic=data_stream.speaker_dic,
            speaker_ids=speaker_id,
            x_one_hot=x,
            expand=True
        )
        self.assertEqual(expanded_global_conditioning.size(), torch.Size([1, 128, 7680]))
        ConsoleLogger.success('expanded_global_conditioning.size(): {}'.format(expanded_global_conditioning.size()))
Example #5
    def comupte_empirical_encodings_frequency(self):
        ConsoleLogger.status(
            'Computing empirical encodings frequency of VCTK val dataset...')

        alignments_dic = None
        with open(
                self._results_path + os.sep + self._experiment_name +
                '_vctk_empirical_alignments.pickle', 'rb') as f:
            alignments_dic = pickle.load(f)

        encodings_counter = alignments_dic['encodings_counter']
        desired_time_interval = alignments_dic['desired_time_interval']
        total_indices_apparations = alignments_dic['total_indices_apparations']

        encodings_frequency = dict()
        for key, value in encodings_counter.items():
            encodings_frequency[key] = value * 100 / total_indices_apparations

        encodings_frequency_sorted_keys = sorted(encodings_frequency,
                                                 key=encodings_frequency.get,
                                                 reverse=True)
        values = [
            encodings_frequency[key] for key in encodings_frequency_sorted_keys
        ]

        # TODO: add title
        fig, ax = plt.subplots(figsize=(20, 1))
        ax.bar(encodings_frequency_sorted_keys, values)
        fig.savefig(self._results_path + os.sep + self._experiment_name +
                    '_vctk_empirical_frequency_{}ms.png'.format(
                        int(desired_time_interval * 1000)),
                    bbox_inches='tight',
                    pad_inches=0)
        plt.close(fig)
Example #6
    def compute_groundtruth_phonemes_frequency(self, wo_diag=True):
        ConsoleLogger.status(
            'Computing groundtruth phonemes frequency of VCTK val dataset...')

        alignments_dic = None
        with open(
                self._results_path + os.sep +
                'vctk_groundtruth_alignments.pickle', 'rb') as f:
            alignments_dic = pickle.load(f)

        desired_time_interval = alignments_dic['desired_time_interval']
        phonemes_counter = alignments_dic['phonemes_counter']
        total_phonemes_apparations = alignments_dic[
            'total_phonemes_apparations']

        phonemes_frequency = dict()
        for key, value in phonemes_counter.items():
            phonemes_frequency[key] = value * 100 / total_phonemes_apparations

        phonemes_frequency_sorted_keys = sorted(phonemes_frequency,
                                                key=phonemes_frequency.get,
                                                reverse=True)
        values = [
            phonemes_frequency[key] for key in phonemes_frequency_sorted_keys
        ]

        # TODO: add title
        fig, ax = plt.subplots(figsize=(20, 1))
        ax.bar(phonemes_frequency_sorted_keys, values)
        fig.savefig(self._results_path + os.sep +
                    'vctk_groundtruth_phonemes_frequency_{}ms.png'.format(
                        int(desired_time_interval * 1000)),
                    bbox_inches='tight',
                    pad_inches=0)
        plt.close(fig)
Example #7
    def load_from_configuration(configuration):
        # Use CUDA only if it is both requested in the configuration and available
        use_cuda = configuration['use_cuda'] and torch.cuda.is_available()
        default_device = 'cuda' if use_cuda else 'cpu'
        # Use the explicitly configured device if one is specified
        device = configuration['use_device'] if configuration['use_device'] is not None else default_device
        # Resolve the GPU ids when data parallelism is requested
        gpu_ids = list(range(torch.cuda.device_count())) if configuration['use_data_parallel'] else [0]
        if configuration['use_device'] and ':' in configuration['use_device']:
            gpu_ids = [int(configuration['use_device'].split(':')[1])]

        use_data_parallel = bool(configuration['use_data_parallel'] and use_cuda and len(gpu_ids) > 1)

        ConsoleLogger.status('The used device is: {}'.format(device))
        ConsoleLogger.status('The gpu ids are: {}'.format(gpu_ids))

        # Sanity checks
        if not use_cuda and configuration['use_cuda']:
            ConsoleLogger.warn(
                "The configuration file specified use_cuda=True but cuda isn't available"
            )
        if configuration['use_data_parallel'] and len(gpu_ids) < 2:
            ConsoleLogger.warn(
                'The configuration file specified use_data_parallel=True but there is only {} GPU available'
                .format(len(gpu_ids)))

        return DeviceConfiguration(use_cuda, device, gpu_ids,
                                   use_data_parallel)
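
# Usage sketch (not part of the original example): a hypothetical configuration
# dictionary containing the three keys this static method reads, assuming
# DeviceConfiguration and ConsoleLogger are importable from this project.
configuration = {
    'use_cuda': True,            # fall back to CPU when CUDA is unavailable
    'use_device': None,          # e.g. 'cuda:1' to pin a specific GPU
    'use_data_parallel': False,  # True to spread batches over all visible GPUs
}
device_configuration = DeviceConfiguration.load_from_configuration(configuration)
ConsoleLogger.status(device_configuration.device)  # 'cuda' or 'cpu'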
Example #8
    def evaluate(self, evaluation_options):
        self._model.eval()

        if evaluation_options['plot_exp'] or \
            evaluation_options['plot_quantized_embedding_spaces'] or \
            evaluation_options['plot_distances_histogram']:
            evaluation_entry = self._evaluate_once()

        if evaluation_options['plot_exp']:
            self._compute_comparaison_plot(evaluation_entry)

        if evaluation_options['plot_quantized_embedding_spaces']:
            EmbeddingSpaceStats.compute_and_plot_quantized_embedding_space_projections(
                self._results_path, self._experiment_name, evaluation_entry,
                self._model.vq.embedding,
                self._data_stream.validation_batch_size)

        if evaluation_options['plot_distances_histogram']:
            self._plot_distances_histogram(evaluation_entry)

        #self._test_denormalization(evaluation_entry)

        if evaluation_options['compute_many_to_one_mapping']:
            self._many_to_one_mapping()

        if evaluation_options['compute_alignments'] or \
            evaluation_options['compute_clustering_metrics'] or \
            evaluation_options['compute_groundtruth_average_phonemes_number']:
            alignment_stats = AlignmentStats(
                self._data_stream, self._vctk, self._configuration,
                self._device, self._model, self._results_path,
                self._experiment_name, evaluation_options['alignment_subset'])
        if evaluation_options['compute_alignments']:
            groundtruth_alignments_path = self._results_path + os.sep + \
                'vctk_{}_groundtruth_alignments.pickle'.format(evaluation_options['alignment_subset'])
            if not os.path.isfile(groundtruth_alignments_path):
                alignment_stats.compute_groundtruth_alignments()
                alignment_stats.compute_groundtruth_bigrams_matrix(
                    wo_diag=True)
                alignment_stats.compute_groundtruth_bigrams_matrix(
                    wo_diag=False)
                alignment_stats.compute_groundtruth_phonemes_frequency()
            else:
                ConsoleLogger.status('Groundtruth alignments already exist')

            empirical_alignments_path = self._results_path + os.sep + self._experiment_name + \
                '_vctk_{}_empirical_alignments.pickle'.format(evaluation_options['alignment_subset'])
            if not os.path.isfile(empirical_alignments_path):
                alignment_stats.compute_empirical_alignments()
                alignment_stats.compute_empirical_bigrams_matrix(wo_diag=True)
                alignment_stats.compute_empirical_bigrams_matrix(wo_diag=False)
                alignment_stats.comupte_empirical_encodings_frequency()
            else:
                ConsoleLogger.status('Empirical alignments already exist')

        if evaluation_options['compute_clustering_metrics']:
            alignment_stats.compute_clustering_metrics()

        if evaluation_options['compute_groundtruth_average_phonemes_number']:
            alignment_stats.compute_groundtruth_average_phonemes_number()
Example #9
    def search_latest_checkpoint_file(checkpoint_files):
        # Search the latest checkpoint file
        ConsoleLogger.status('Searching the latest checkpoint file')
        latest_checkpoint_file = checkpoint_files[0]
        latest_epoch = int(checkpoint_files[0].split('_')[1])
        for i in range(1, len(checkpoint_files)):
            epoch = int(checkpoint_files[i].split('_')[1])
            if epoch > latest_epoch:
                latest_checkpoint_file = checkpoint_files[i]
                latest_epoch = epoch

        return latest_checkpoint_file, latest_epoch
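
# Usage sketch (not from the original source), assuming checkpoint file names
# follow the '<experiment>_<epoch>_...' pattern implied by the parsing above and
# that this helper is exposed on CheckpointUtils like the other helpers below.
checkpoint_files = ['baseline_1_checkpoint.pth', 'baseline_12_checkpoint.pth', 'baseline_3_checkpoint.pth']
latest_file, latest_epoch = CheckpointUtils.search_latest_checkpoint_file(checkpoint_files)
# latest_file == 'baseline_12_checkpoint.pth', latest_epoch == 12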
Example #10
    def compute_dataset_stats(self):
        initial_index = 0
        attempts = 10
        current_attempt = 0
        total_length = len(self._training_loader)
        train_mfccs = list()
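        # Retry up to `attempts` times when the data loader raises, so a transient
        # error does not abort the statistics computation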
        while current_attempt < attempts:
            try:
                i = initial_index
                train_bar = tqdm(self._training_loader, initial=initial_index)
                for data in train_bar:
                    input_features = data['input_features']
                    train_mfccs.append(input_features.detach().view(input_features.size(1), input_features.size(2)).numpy())

                    i += 1

                    if i == total_length:
                        train_bar.update(total_length)
                        break

                train_bar.close()
                break

            except KeyboardInterrupt:
                train_bar.close()
                ConsoleLogger.warn('Keyboard interrupt detected. Leaving the function...')
                return
            except Exception:
                error_message = 'An error occurred in the data loader at {}. Current attempt: {}/{}'.format(i, current_attempt+1, attempts)
                self._logger.exception(error_message)
                ConsoleLogger.error(error_message)
                initial_index = i
                current_attempt += 1
                continue


        ConsoleLogger.status('Compute mean of mfccs training set...')
        train_mean = np.concatenate(train_mfccs).mean(axis=0)

        ConsoleLogger.status('Compute std of mfccs training set...')
        train_std = np.concatenate(train_mfccs).std(axis=0)

        stats = {
            'train_mean': train_mean,
            'train_std': train_std
        }

        ConsoleLogger.status('Writing stats in file...')
        with open(self._normalizer_path, 'wb') as file: # TODO: do not use hardcoded path
            pickle.dump(stats, file)

        train_mfccs_norm = (train_mfccs[0] - train_mean) / train_std

        ConsoleLogger.status('Computing example plot...')
        _, axs = plt.subplots(2, sharex=True)
        axs[0].imshow(train_mfccs[0].T, aspect='auto', origin='lower')
        axs[0].set_ylabel('Unnormalized')
        axs[1].imshow(train_mfccs_norm.T, aspect='auto', origin='lower')
        axs[1].set_ylabel('Normalized')
        plt.savefig('mfcc_normalization_comparaison.png') # TODO: do not use hardcoded path
Example #11
    def eval(self):
        ConsoleLogger.status('start epoch: {}'.format(
            self._configuration['start_epoch']))
        ConsoleLogger.status('num epoch: {}'.format(
            self._configuration['num_epochs']))
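        # Evaluate the reconstruction loss when only the first `index` quantized
        # frames are kept (passed as `designate`; presumably forwarded to the
        # model's forward pass, which zeroes out the remaining code frames)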
        for index in [16, 32, 48, 64]:
            recons_loss = 0.0
            cnt = 0.0
            with tqdm(self._data_stream.training_loader) as train_bar:
                train_res_recon_error = list()  # FIXME: record as a global metric
                train_res_perplexity = list()  # FIXME: record as a global metric
                iteration = 0
                max_iterations_number = len(train_bar)
                iterations = list(
                    np.arange(max_iterations_number,
                              step=(max_iterations_number /
                                    self._iterations_to_record) - 1,
                              dtype=int))
                for data in train_bar:
                    if len(data['one_hot']) == 1:
                        continue
                    losses, perplexity_value, sample_index = self.iterate_deconv(
                        data,
                        0,
                        iteration,
                        iterations,
                        train_bar,
                        eval=True,
                        designate=index)
                    if losses is None or perplexity_value is None:
                        continue
                    train_res_recon_error.append(losses)
                    train_res_perplexity.append(perplexity_value)
                    iteration += 1
                    recons_loss += losses['reconstruction_loss']
                    cnt += len(data['one_hot'])

            print("Index:{},Train_res_recon_error:{}".format(
                index, recons_loss / cnt))
            self.save_eval(
                0, **{
                    'train_res_recon_error': train_res_recon_error,
                    'train_res_perplexity': train_res_perplexity
                })
Example #12
    def merge_experiment_losses(experiment_path, checkpoint_files, device_configuration):
        train_res_losses = {}
        train_res_perplexities = []

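        # Sort the checkpoints by the epoch number embedded in the file name
        # (e.g. '<experiment>_<epoch>_checkpoint.pth')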
        sorted_checkpoint_files = sorted(checkpoint_files, key=lambda x: int(x.split('.')[0].split('_')[-2]))
        for checkpoint_file in sorted_checkpoint_files:
            # Load the checkpoint file
            checkpoint_path = experiment_path + os.sep + checkpoint_file
            ConsoleLogger.status("Loading the checkpoint file '{}'".format(checkpoint_path))
            checkpoint = torch.load(checkpoint_path, map_location=device_configuration.device)
            for loss_entry in checkpoint['train_res_recon_error']:
                for key in loss_entry.keys():
                    if key not in train_res_losses:
                        train_res_losses[key] = list()
                    train_res_losses[key].append(loss_entry[key])
            train_res_perplexities += checkpoint['train_res_perplexity']

        return train_res_losses, train_res_perplexities
Example #13
    def plot_gradient_flow_over_epochs(gradient_stats_entries,
                                       output_file_name):
        epoch_number, iteration_number = set(), set()
        for epoch, iteration, _ in gradient_stats_entries:
            epoch_number.add(epoch)
            iteration_number.add(iteration)

        epoch_number = len(epoch_number)
        iteration_number = len(iteration_number)

        fig, axs = plt.subplots(epoch_number,
                                iteration_number,
                                figsize=(epoch_number * 8,
                                         iteration_number * 8),
                                sharey=True,
                                sharex=True)
        k = 0
        for i in trange(epoch_number):
            for j in trange(iteration_number):
                _, _, gradient_stats_entry = gradient_stats_entries[k]
                GradientStats.plot_gradient_flow(
                    gradient_stats_entry['model'],
                    axs[i][j],
                    set_xticks=True if i + 1 == epoch_number else False,
                    set_ylabels=True if j == 0 else False)
                k += 1

        ConsoleLogger.status('Saving gradient flow plot...')
        fig.suptitle('Gradient flow', fontsize='x-large')
        fig.legend(
            [
                Line2D([0], [0], color='c', lw=4),
                Line2D([0], [0], color='b', lw=4),
                Line2D([0], [0], color='k', lw=4)
            ],
            ['max-gradient', 'mean-gradient', 'zero-gradient'],
            loc="center right",  # Position of legend
            borderaxespad=0.1  # Small spacing around legend box
        )
        fig.savefig(output_file_name,
                    bbox_inches='tight',
                    pad_inches=0,
                    dpi=200)
        plt.close(fig)
Example #14
    def train(self):
        ConsoleLogger.status('start epoch: {}'.format(
            self._configuration['start_epoch']))
        ConsoleLogger.status('num epoch: {}'.format(
            self._configuration['num_epochs']))

        for epoch in range(self._configuration['start_epoch'],
                           self._configuration['num_epochs']):
            with tqdm(self._data_stream.training_loader) as train_bar:
                train_res_recon_error = list()  # FIXME: record as a global metric
                train_res_perplexity = list()  # FIXME: record as a global metric
                train_res_recon_error_index = dict()
                index_cnt = dict()
                iteration = 0
                loss_sum = 0.0
                max_iterations_number = len(train_bar)
                iterations = list(
                    np.arange(max_iterations_number,
                              step=(max_iterations_number /
                                    self._iterations_to_record) - 1,
                              dtype=int))

                for data in train_bar:
                    if len(data['one_hot']) == 1:
                        continue
                    if self._configuration[
                            'decoder_type'] == 'deconvolutional':
                        _, loss_res = self.iterate_deconv(data,
                                                          epoch,
                                                          iteration,
                                                          iterations,
                                                          train_bar,
                                                          eval=False,
                                                          return_loss=True)
                        loss_sum += loss_res
                        iteration += 1
                print("Average loss per iteration:", loss_sum / iteration)
                self.save(epoch)
Example #15
    def retreive_losses_values(experiment_path, experiment):
        experiment_name = experiment.name

        ConsoleLogger.status("Searching configuration and checkpoints of experiment '{}' at path '{}'".format(experiment_name, experiment_path))
        configuration_file, checkpoint_files = CheckpointUtils.search_configuration_and_checkpoints_files(
            experiment_path,
            experiment_name
        )

        # Check if a configuration file was found
        if not configuration_file:
            raise ValueError('No configuration file found with name: {}'.format(experiment_name))

        # Check if at least one checkpoint file was found
        if len(checkpoint_files) == 0:
            raise ValueError('No checkpoint files found with name: {}'.format(experiment_name))

        # Load the configuration file
        configuration_path = experiment_path + os.sep + configuration_file
        ConsoleLogger.status("Loading the configuration file '{}'".format(configuration_path))
        configuration = None
        with open(configuration_path, 'r') as file:
            configuration = yaml.load(file, Loader=yaml.FullLoader)
        
        # Load the device configuration from the configuration state
        device_configuration = DeviceConfiguration.load_from_configuration(configuration)

        ConsoleLogger.status("Merge {} checkpoint losses of experiment '{}'".format(len(checkpoint_files), experiment_name))
        train_res_losses, train_res_perplexities = CheckpointUtils.merge_experiment_losses(
            experiment_path,
            checkpoint_files,
            device_configuration
        )

        return train_res_losses, train_res_perplexities, len(checkpoint_files)
Example #16
    def search_configuration_and_checkpoints_files(experiment_path, experiment_name):
        # Check if the specified experiment path exists
        ConsoleLogger.status("Checking if the experiment path '{}' exists".format(experiment_path))
        if not os.path.isdir(experiment_path):
            raise ValueError("Specified experiment path '{}' doesn't not exist".format(experiment_path))

        # List all the files from this directory and raise an error if it's empty
        ConsoleLogger.status('Listing the specified experiment path directory')
        files = os.listdir(experiment_path)
        if not files:
            raise ValueError("Specified experiment path '{}' is empty".format(experiment_path))

        # Search the configuration file and the checkpoint files of the specified experiment
        ConsoleLogger.status('Searching the configuration file and the checkpoint files')
        checkpoint_files = list()
        configuration_file = None
        for file in files:
            # Check if the file is a checkpoint or config file by looking at the extension
            if '.pth' not in file and '.yaml' not in file:
                continue
            split_file = file.split('_')
            if len(split_file) > 1 and split_file[0] == experiment_name and split_file[1] == 'configuration.yaml':
                configuration_file = file
            elif len(split_file) > 1 and split_file[0] == experiment_name and split_file[1] != 'configuration.yaml':
                checkpoint_files.append(file)

        return configuration_file, checkpoint_files
Example #17
    def forward(self, x, speaker_dic, speaker_id, full=False, designate=None):

        if self._verbose:
            #print("max:{},min:{}".format(torch.max(x), torch.min(x)))
            ConsoleLogger.status('[ConvVQVAE] _encoder input size: {}'.format(x.size()))
        x = x.permute(0, 2, 1).contiguous().float()
        z = self._encoder(x)
        if self._verbose:
            ConsoleLogger.status('[ConvVQVAE] _encoder output size: {}'.format(z.size()))

        #z = self._pre_vq_conv(z)
        if self._verbose:
            ConsoleLogger.status('[ConvVQVAE] _pre_vq_conv output size: {}'.format(z.size()))

        z_q_x_st, z_q_x, indices = self._vq.straight_through(z)
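        # The block below keeps only the first `sample_index` quantized frames and
        # zeroes out the rest, so the decoder reconstructs from a fraction of the code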
        area = z.size(1)
        sample_index = area
        if not full:
            sample_index = np.array((np.random.randint(area) + 1))
        if designate is not None:
            sample_index = designate
        zero_out = torch.zeros((z_q_x_st.size(0), area - sample_index, z_q_x_st.size(2))).cuda()
        z_q_x_st = torch.cat((z_q_x_st[:, :sample_index], zero_out), dim=1)
        reconstructed_x = self._decoder(z_q_x_st, speaker_dic, speaker_id)
        output_features_size = reconstructed_x.size(2)
        reconstructed_x = reconstructed_x.view(-1, 1, output_features_size)
        return reconstructed_x, z, z_q_x, sample_index, indices
Example #18
    def __init__(self, name, experiments_path, results_path, global_configuration,
        experiment_configuration, seed):

        self._name = name
        self._experiments_path = experiments_path
        self._results_path = results_path
        self._global_configuration = global_configuration
        self._experiment_configuration = experiment_configuration
        self._seed = seed

        # Create the experiments path directory if it doesn't exist
        if not os.path.isdir(experiments_path):
            ConsoleLogger.status('Creating experiments directory at path: {}'.format(experiments_path))
            os.mkdir(experiments_path)
        else:
            ConsoleLogger.status('Experiments directory already created at path: {}'.format(experiments_path))

        # Create the results path directory if it doesn't exist
        if not os.path.isdir(results_path):
            ConsoleLogger.status('Creating results directory at path: {}'.format(results_path))
            os.mkdir(results_path)
        else:
            ConsoleLogger.status('Results directory already created at path: {}'.format(results_path))



        experiments_configuration_path = experiments_path + os.sep + name + '_configuration.yaml'
        configuration_file_already_exists = os.path.isfile(experiments_configuration_path)
        if not configuration_file_already_exists:
            self._device_configuration = DeviceConfiguration.load_from_configuration(global_configuration)

            # Create a new configuration state from the default and the experiment specific aspects
            self._configuration = copy.deepcopy(self._global_configuration)
            for experiment_key in experiment_configuration.keys():
                if experiment_key in self._configuration:
                    self._configuration[experiment_key] = experiment_configuration[experiment_key]

            # Save the configuration of the experiments
            with open(experiments_configuration_path, 'w') as file:
                yaml.dump(self._configuration, file)
        else:
            with open(experiments_configuration_path, 'r') as file:
                self._configuration = yaml.load(file, Loader=yaml.FullLoader)
                self._device_configuration = DeviceConfiguration.load_from_configuration(self._configuration)
        if configuration_file_already_exists:
            self._trainer, self._evaluator, self._configuration, self._device_configuration = PipelineFactory.load(self._experiments_path, self._name, self._results_path)
        else:
            self._trainer, self._evaluator = PipelineFactory.build(self._configuration,
                self._device_configuration, self._experiments_path, self._name, self._results_path)
Example #19
    def train(self):
        ConsoleLogger.status('start epoch: {}'.format(self._configuration['start_epoch']))
        ConsoleLogger.status('num epoch: {}'.format(self._configuration['num_epochs']))

        for epoch in range(self._configuration['start_epoch'], self._configuration['num_epochs']):

            with tqdm(self._data_stream.training_loader) as train_bar:
                train_res_recon_error = list() # FIXME: record as a global metric
                train_res_perplexity = list() # FIXME: record as a global metric

                iteration = 0
                max_iterations_number = len(train_bar)
                iterations = list(np.arange(max_iterations_number, step=(max_iterations_number / self._iterations_to_record) - 1, dtype=int))

                for data in train_bar:
                    losses, perplexity_value = self.iterate(data, epoch, iteration, iterations, train_bar)
                    if losses is None or perplexity_value is None:
                        continue
                    train_res_recon_error.append(losses)
                    train_res_perplexity.append(perplexity_value)
                    iteration += 1

                self.save(epoch, **{'train_res_recon_error': train_res_recon_error, 'train_res_perplexity': train_res_perplexity})
Example #20
    def forward(self, x, speaker_dic, speaker_id):
        x = x.permute(0, 2, 1).contiguous().float()

        z = self._encoder(x)
        if self._verbose:
            ConsoleLogger.status('[ConvVQVAE] _encoder output size: {}'.format(z.size()))

        z = self._pre_vq_conv(z)
        if self._verbose:
            ConsoleLogger.status('[ConvVQVAE] _pre_vq_conv output size: {}'.format(z.size()))

        vq_loss, quantized, perplexity, _, _, encoding_indices, \
            losses, _, _, _, concatenated_quantized = self._vq(z, record_codebook_stats=self._record_codebook_stats)

        reconstructed_x = self._decoder(quantized, speaker_dic, speaker_id)

        input_features_size = x.size(2)
        output_features_size = reconstructed_x.size(2)

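        # Reshape the decoder output and trim the extra trailing frames produced by
        # the transposed convolutions so it matches the input feature length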
        reconstructed_x = reconstructed_x.view(-1, self._output_features_filters, output_features_size)
        reconstructed_x = reconstructed_x[:, :, :-(output_features_size-input_features_size)]
        
        return reconstructed_x, vq_loss, losses, perplexity, encoding_indices, concatenated_quantized
Example #21
    def _plot_distances_histogram(self, evaluation_entry):
        encoding_distances = evaluation_entry['encoding_distances'][0].detach().cpu().numpy()
        embedding_distances = evaluation_entry['embedding_distances'].detach().cpu().numpy()
        frames_vs_embedding_distances = evaluation_entry[
            'frames_vs_embedding_distances'].detach()[0].cpu().transpose(0, 1).numpy().ravel()

        if self._configuration['verbose']:
            ConsoleLogger.status('encoding_distances[0].size(): {}'.format(
                encoding_distances.shape))
            ConsoleLogger.status('embedding_distances.size(): {}'.format(
                embedding_distances.shape))
            ConsoleLogger.status(
                'frames_vs_embedding_distances[0].shape: {}'.format(
                    frames_vs_embedding_distances.shape))

        fig, axs = plt.subplots(3, 1, figsize=(30, 20), sharex=True)

        axs[0].set_title('\n'.join(
            wrap('Histogram of the distances between the'
                 ' encodings vectors', 60)))
        sns.distplot(encoding_distances,
                     hist=True,
                     kde=False,
                     ax=axs[0],
                     norm_hist=True)

        axs[1].set_title('\n'.join(
            wrap(
                'Histogram of the distances between the'
                ' embeddings vectors', 60)))
        sns.distplot(embedding_distances,
                     hist=True,
                     kde=False,
                     ax=axs[1],
                     norm_hist=True)

        axs[2].set_title(
            'Histogram of the distances computed in'
            ' VQ\n($||z_e(x) - e_i||^2_2$ with $z_e(x)$ the output of the encoder'
            ' prior to quantization)')
        sns.distplot(frames_vs_embedding_distances,
                     hist=True,
                     kde=False,
                     ax=axs[2],
                     norm_hist=True)

        output_path = self._results_path + os.sep + self._experiment_name + '_distances-histogram-plot.png'
        fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
        plt.close(fig)
Example #22
    def forward(self, inputs):
        if self._verbose:
            ConsoleLogger.status('inputs size: {}'.format(inputs.size()))

        x_conv_1 = F.relu(self._conv_1(inputs))
        if self._verbose:
            ConsoleLogger.status('x_conv_1 output size: {}'.format(
                x_conv_1.size()))

        x = F.relu(self._conv_2(x_conv_1)) + x_conv_1
        if self._verbose:
            ConsoleLogger.status('_conv_2 output size: {}'.format(x.size()))

        x_conv_3 = F.relu(self._conv_3(x))
        if self._verbose:
            ConsoleLogger.status('_conv_3 output size: {}'.format(
                x_conv_3.size()))

        x_conv_4 = F.relu(self._conv_4(x_conv_3)) + x_conv_3
        if self._verbose:
            ConsoleLogger.status('_conv_4 output size: {}'.format(
                x_conv_4.size()))

        x_conv_5 = F.relu(self._conv_5(x_conv_4)) + x_conv_4
        if self._verbose:
            ConsoleLogger.status('x_conv_5 output size: {}'.format(
                x_conv_5.size()))

        x = self._residual_stack(x_conv_5) + x_conv_5
        if self._verbose:
            ConsoleLogger.status('_residual_stack output size: {}'.format(
                x.size()))

        return x
Example #23
    def dump(self):
        ConsoleLogger.status('start epoch: {}'.format(
            self._configuration['start_epoch']))
        ConsoleLogger.status('num epoch: {}'.format(
            self._configuration['num_epochs']))
        print('Dumping Codebook')
        name = 'wave_bigger_baseline'
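        # Run the model in eval mode over the training set, collect the discrete
        # codebook indices of every batch, then save them to ./data-bin/<name>/train.pt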
        with tqdm(self._data_stream.training_loader) as train_bar:
            iteration = 0
            max_iterations_number = len(train_bar)
            iterations = list(
                np.arange(
                    max_iterations_number,
                    step=(max_iterations_number / self._iterations_to_record) -
                    1,
                    dtype=int))
            lst = []
            for data in train_bar:
                if len(data['one_hot']) == 1:
                    continue
                if self._configuration['decoder_type'] == 'deconvolutional':
                    indices = self.iterate_deconv(data,
                                                  0,
                                                  iteration,
                                                  iterations,
                                                  train_bar,
                                                  eval=True).view(
                                                      len(data['one_hot']), -1)
                    iteration += 1
                    lst.append(indices.cpu())
        lst = torch.cat(lst, 0)
        os.makedirs(os.path.join('./data-bin', name), exist_ok=True)
        torch.save(lst, os.path.join('./data-bin', name, 'train.pt'))
        print('Dumped training Codebook')

        with tqdm(self._data_stream.validation_loader) as train_bar:
            iteration = 0
            max_iterations_number = len(train_bar)
            iterations = list(
                np.arange(
                    max_iterations_number,
                    step=(max_iterations_number / self._iterations_to_record) -
                    1,
                    dtype=int))
            lst = []
            for data in train_bar:
                if len(data['one_hot']) == 1:
                    continue
                if self._configuration['decoder_type'] == 'deconvolutional':
                    indices = self.iterate_deconv(data,
                                                  0,
                                                  iteration,
                                                  iterations,
                                                  train_bar,
                                                  eval=True).view(
                                                      len(data['one_hot']), -1)
                    iteration += 1
                    lst.append(indices.cpu())
        lst = torch.cat(lst, 0)
        torch.save(lst, os.path.join('./data-bin', name, 'valid.pt'))
        print('Dumped validation Codebook')
Example #24
    def _compute_comparaison_plot(self, evaluation_entry):
        print("entry:", evaluation_entry['wav_filename'])
        sample = evaluation_entry['sample']
        utterence_key = evaluation_entry['wav_filename'].split(
            '/')[-1].replace('.wav', '')
        utterence = self._vctk.utterences[utterence_key].replace('\n', '')
        phonemes_alignment_path = os.sep.join(evaluation_entry['wav_filename'].split('/')[:-3]) \
            + os.sep + 'phonemes' + os.sep + utterence_key.split('_')[0] + os.sep \
            + utterence_key + '.TextGrid'

        speaker = evaluation_entry['speaker_ids']
        ConsoleLogger.status('Original utterence: {}, speaker id:{}'.format(
            utterence, speaker))

        if self._configuration['verbose']:
            ConsoleLogger.status('utterence: {}'.format(utterence))
        spectrogram_parser = SpectrogramParser()
        preprocessed_audio = evaluation_entry['preprocessed_audio'].detach(
        ).cpu()[sample].numpy().squeeze()

        sns.set(style='darkgrid', font_scale=5)
        preprocessed_audio = preprocessed_audio[:15360]
        valid_reconstructions = evaluation_entry['valid_reconstructions']
        fig, axs = plt.subplots(1,
                                len(valid_reconstructions),
                                figsize=(120, 20),
                                sharex=True)
        # Waveform of the original speech signal
        # axs[0].set_title('Waveform of the original speech signal')
        # axs[0].plot(np.arange(len(preprocessed_audio)), preprocessed_audio)
        write("original.wav", 16000, preprocessed_audio)
        # # Spectrogram of the original speech signal
        # axs[1].set_title('Spectrogram of the original speech signal')
        # self._plot_pcolormesh(spectrogram, fig, x=self._compute_unified_time_scale(spectrogram.shape[1]), axis=axs[1])
        #
        # # MFCC + d + a of the original speech signal
        # axs[2].set_title('Augmented MFCC + d + a #filters=13+13+13 of the original speech signal')
        # self._plot_pcolormesh(valid_originals, fig, x=self._compute_unified_time_scale(valid_originals.shape[1]), axis=axs[2])
        #
        # # Softmax of distances computed in VQ
        # axs[3].set_title('Softmax of distances computed in VQ\n($||z_e(x) - e_i||^2_2$ with $z_e(x)$ the output of the encoder prior to quantization)')
        # self._plot_pcolormesh(probs, fig, x=self._compute_unified_time_scale(probs.shape[1], downsampling_factor=2), axis=axs[3])
        #
        # encodings = evaluation_entry['encodings'].detach().cpu().numpy()
        # axs[4].set_title('Encodings')
        # self._plot_pcolormesh(encodings[0].transpose(), fig, x=self._compute_unified_time_scale(encodings[0].transpose().shape[1],
        #     downsampling_factor=2), axis=axs[4])

        # Actual reconstruction
        idx = [0.0625, 0.25, 1.0]
        for i in range(1, len(valid_reconstructions) + 1):

            axs[i - 1].set_title('Fraction of full code length:' +
                                 str(idx[i - 1]))
            print("Reconstruction size:", valid_reconstructions[i - 1].size())
            valid_reconstructions[i - 1] = valid_reconstructions[
                i - 1].detach().cpu().numpy()[0]
            d_1 = {
                ' ': np.arange(len(valid_reconstructions[i - 1])),
                '  ': valid_reconstructions[i - 1]
            }
            pdnumsqr_1 = pd.DataFrame(d_1)
            sns.lineplot(x=' ', y='  ', data=pdnumsqr_1, ax=axs[i - 1])
            write("reconstruction_" + str(i - 1) + ".wav", 16000,
                  valid_reconstructions[i - 1])

        output_path = '_evaluation-comparaison-plot_3.pdf'
        print("Output path:", output_path)
        plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
        plt.close()
Example #25
    def _many_to_one_mapping(self):
        # TODO: fix it for batch size greater than one

        tokens_selections = list()
        val_speaker_ids = set()

        with tqdm(self._data_stream.validation_loader) as bar:
            for data in bar:
                valid_originals = data['input_features'].to(
                    self._device).permute(0, 2, 1).contiguous().float()
                speaker_ids = data['speaker_id'].to(self._device)
                shifting_times = data['shifting_time'].to(self._device)
                wav_filenames = data['wav_filename']

                speaker_id = wav_filenames[0][0].split(os.sep)[-2]
                val_speaker_ids.add(speaker_id)

                if speaker_id not in os.listdir(self._vctk.raw_folder +
                                                os.sep + 'VCTK-Corpus' +
                                                os.sep + 'phonemes'):
                    # TODO: log the missing folders
                    continue

                z = self._model.encoder(valid_originals)
                z = self._model.pre_vq_conv(z)
                _, quantized, _, encodings, _, encoding_indices, _, \
                    _, _, _, _ = self._model.vq(z)
                valid_reconstructions = self._model.decoder(
                    quantized, self._data_stream.speaker_dic, speaker_ids)
                B = valid_reconstructions.size(0)

                encoding_indices = encoding_indices.view(B, -1, 1)

                for i in range(len(valid_reconstructions)):
                    wav_filename = wav_filenames[0][i]
                    utterence_key = wav_filename.split('/')[-1].replace(
                        '.wav', '')
                    phonemes_alignment_path = os.sep.join(wav_filename.split('/')[:-3]) + os.sep + 'phonemes' + os.sep + utterence_key.split('_')[0] + os.sep \
                        + utterence_key + '.TextGrid'
                    tg = textgrid.TextGrid()
                    tg.read(phonemes_alignment_path)
                    entry = {
                        'encoding_indices':
                        encoding_indices[i].detach().cpu().numpy(),
                        'groundtruth':
                        tg.tiers[1],
                        'shifting_time':
                        shifting_times[i].detach().cpu().item()
                    }
                    tokens_selections.append(entry)

        ConsoleLogger.status(val_speaker_ids)

        ConsoleLogger.status('{} token selections retrieved'.format(
            len(tokens_selections)))

        phonemes_mapping = dict()
        # For each token selection (i.e. each evaluation entry)
        for entry in tokens_selections:
            encoding_indices = entry['encoding_indices']
            unified_encoding_indices_time_scale = self._compute_unified_time_scale(
                encoding_indices.shape[0], downsampling_factor=2
            )  # Compute the time scale array for each token
            """
            Search the grountruth phoneme where the selected token index time scale
            is within the groundtruth interval.
            Then, it adds the selected token index in the list of indices selected for
            the a specific token in the tokens mapping dictionnary.
            """
            for i in range(len(unified_encoding_indices_time_scale)):
                index_time_scale = unified_encoding_indices_time_scale[
                    i] + entry['shifting_time']
                corresponding_phoneme = None
                for interval in entry['groundtruth']:
                    # TODO: replace that by nearest interpolation
                    if index_time_scale >= interval.minTime and index_time_scale <= interval.maxTime:
                        corresponding_phoneme = interval.mark
                        break
                if not corresponding_phoneme:
                    ConsoleLogger.warn(
                        "Corresponding phoneme not found. unified_encoding_indices_time_scale[{}]: {}"
                        "entry['shifting_time']: {} index_time_scale: {}".
                        format(i, unified_encoding_indices_time_scale[i],
                               entry['shifting_time'], index_time_scale))
                if corresponding_phoneme not in phonemes_mapping:
                    phonemes_mapping[corresponding_phoneme] = list()
                phonemes_mapping[corresponding_phoneme].append(
                    encoding_indices[i][0])

        ConsoleLogger.status('phonemes_mapping: {}'.format(phonemes_mapping))

        tokens_mapping = dict()  # for each token, the distribution over the phonemes it maps to
        """
        Fill the tokens_mapping such that for each token index (key)
        it contains the list of tuple of (phoneme, prob) where prob
        is the probability that the token fits this phoneme.
        """
        for phoneme, indices in phonemes_mapping.items():
            for index in list(set(indices)):
                if index not in tokens_mapping:
                    tokens_mapping[index] = list()
                tokens_mapping[index].append(
                    (phoneme, indices.count(index) / len(indices)))

        # Sort the probabilities for each token
        for index, distribution in tokens_mapping.items():
            tokens_mapping[index] = list(
                sorted(distribution, key=lambda x: x[1], reverse=True))

        ConsoleLogger.status('tokens_mapping: {}'.format(tokens_mapping))

        with open(
                self._results_path + os.sep + self._experiment_name +
                '_phonemes_mapping.pickle', 'wb') as f:
            pickle.dump(phonemes_mapping, f)

        with open(
                self._results_path + os.sep + self._experiment_name +
                '_tokens_mapping.pickle', 'wb') as f:
            pickle.dump(tokens_mapping, f)
Example #26
    def _compute_comparaison_plot(self, evaluation_entry):
        utterence_key = evaluation_entry['wav_filename'].split(
            '/')[-1].replace('.wav', '')
        utterence = self._vctk.utterences[utterence_key].replace('\n', '')
        phonemes_alignment_path = os.sep.join(evaluation_entry['wav_filename'].split('/')[:-3]) \
            + os.sep + 'phonemes' + os.sep + utterence_key.split('_')[0] + os.sep \
            + utterence_key + '.TextGrid'
        #tg = textgrid.TextGrid()
        #tg.read(phonemes_alignment_path)
        #for interval in tg.tiers[0]:

        ConsoleLogger.status('Original utterence: {}'.format(utterence))

        if self._configuration['verbose']:
            ConsoleLogger.status('utterence: {}'.format(utterence))

        spectrogram_parser = SpectrogramParser()
        preprocessed_audio = evaluation_entry['preprocessed_audio'].detach(
        ).cpu()[0].numpy().squeeze()
        spectrogram = spectrogram_parser.parse_audio(
            preprocessed_audio).contiguous()

        spectrogram = spectrogram.detach().cpu().numpy()

        valid_originals = evaluation_entry['valid_originals'].detach().cpu(
        )[0].numpy()

        probs = F.softmax(-evaluation_entry['distances'][0],
                          dim=1).detach().cpu().transpose(0, 1).contiguous()

        #target = self._target.detach().cpu()[0].numpy()

        valid_reconstructions = evaluation_entry[
            'valid_reconstructions'].detach().cpu().numpy()

        fig, axs = plt.subplots(6, 1, figsize=(35, 30), sharex=True)

        # Waveform of the original speech signal
        axs[0].set_title('Waveform of the original speech signal')
        axs[0].plot(
            np.arange(len(preprocessed_audio)) /
            float(self._configuration['sampling_rate']), preprocessed_audio)

        # TODO: Add number of encoding indices at the same rate of the tokens with _compute_unified_time_scale()
        """
        # Example of vertical red lines
        xposition = [0.3, 0.4, 0.45]
        for xc in xposition:
            plt.axvline(x=xc, color='r', linestyle='-', linewidth=1)
        """

        # Spectrogram of the original speech signal
        axs[1].set_title('Spectrogram of the original speech signal')
        self._plot_pcolormesh(spectrogram,
                              fig,
                              x=self._compute_unified_time_scale(
                                  spectrogram.shape[1]),
                              axis=axs[1])

        # MFCC + d + a of the original speech signal
        axs[2].set_title(
            'Augmented MFCC + d + a #filters=13+13+13 of the original speech signal'
        )
        self._plot_pcolormesh(valid_originals,
                              fig,
                              x=self._compute_unified_time_scale(
                                  valid_originals.shape[1]),
                              axis=axs[2])

        # Softmax of distances computed in VQ
        axs[3].set_title(
            'Softmax of distances computed in VQ\n($||z_e(x) - e_i||^2_2$ with $z_e(x)$ the output of the encoder prior to quantization)'
        )
        self._plot_pcolormesh(probs,
                              fig,
                              x=self._compute_unified_time_scale(
                                  probs.shape[1], downsampling_factor=2),
                              axis=axs[3])

        encodings = evaluation_entry['encodings'].detach().cpu().numpy()
        axs[4].set_title('Encodings')
        self._plot_pcolormesh(encodings[0].transpose(),
                              fig,
                              x=self._compute_unified_time_scale(
                                  encodings[0].transpose().shape[1],
                                  downsampling_factor=2),
                              axis=axs[4])

        # Actual reconstruction
        axs[5].set_title('Actual reconstruction')
        self._plot_pcolormesh(valid_reconstructions,
                              fig,
                              x=self._compute_unified_time_scale(
                                  valid_reconstructions.shape[1]),
                              axis=axs[5])

        output_path = self._results_path + os.sep + self._experiment_name + '_evaluation-comparaison-plot.png'
        plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
        plt.close()
Example #27
        'compute_alignments': args.compute_alignments,
        'alignment_subset': args.alignment_subset,
        'compute_clustering_metrics': args.compute_clustering_metrics,
        'compute_groundtruth_average_phonemes_number':
        args.compute_groundtruth_average_phonemes_number,
        'plot_clustering_metrics_evolution':
        args.plot_clustering_metrics_evolution,
        'check_clustering_metrics_stability_over_seeds':
        args.check_clustering_metrics_stability_over_seeds,
        'plot_gradient_stats': args.plot_gradient_stats
    }

    # If specified, print the summary of the model using the CPU device
    if args.summary:
        configuration = load_configuration(args.summary)
        ConsoleLogger.status('Printing the summary of the model...')
        device_configuration = DeviceConfiguration.load_from_configuration(
            configuration)
        model = PipelineFactory.build(configuration, device_configuration,
                                      default_experiments_path,
                                      default_experiment_name,
                                      default_results_path)
        print(model)
        sys.exit(0)

    if args.plot_experiments_losses:
        LossesPlotter().plot_training_losses(
            Experiments.load(args.experiments_configuration_path).experiments,
            args.experiments_path)
        sys.exit(0)
    if args.eval:
Example #28
    def export_to_features(self, vctk_path, configuration):
        if not os.path.isdir(vctk_path):
            raise ValueError(
                "VCTK dataset not found at path '{}'".format(vctk_path))

        # Create the features path directory if it doesn't exist
        features_path = vctk_path + os.sep + configuration['features_path']
        if not os.path.isdir(features_path):
            ConsoleLogger.status(
                'Creating features directory at path: {}'.format(
                    features_path))
            os.mkdir(features_path)
        else:
            ConsoleLogger.status(
                'Features directory already created at path: {}'.format(
                    features_path))

        # Create the features path directory if it doesn't exist
        train_features_path = features_path + os.sep + 'train'
        if not os.path.isdir(train_features_path):
            ConsoleLogger.status(
                'Creating train features directory at path: {}'.format(
                    train_features_path))
            os.mkdir(train_features_path)
        else:
            ConsoleLogger.status(
                'Train features directory already created at path: {}'.format(
                    train_features_path))

        # Create the features path directory if it doesn't exist
        val_features_path = features_path + os.sep + 'val'
        if not os.path.isdir(val_features_path):
            ConsoleLogger.status(
                'Creating val features directory at path: {}'.format(
                    val_features_path))
            os.mkdir(val_features_path)
        else:
            ConsoleLogger.status(
                'Val features directory already created at path: {}'.format(
                    val_features_path))

        def process(loader, output_dir, input_features_name,
                    output_features_name, rate, input_filters_number,
                    output_filters_number, input_target_shape,
                    augment_output_features, export_one_hot_features):

            initial_index = 0
            attempts = 10
            current_attempt = 0
            total_length = len(loader)

            while current_attempt < attempts:
                try:
                    i = initial_index
                    bar = tqdm(loader, initial=initial_index)
                    for data in bar:
                        (preprocessed_audio, one_hot, speaker_id, quantized,
                         wav_filename, sampling_rate, shifting_time,
                         random_starting_index, preprocessed_length,
                         top_db) = data

                        output_path = output_dir + os.sep + str(i) + '.pickle'
                        if os.path.isfile(output_path):
                            if os.path.getsize(output_path) == 0:
                                bar.set_description(
                                    '{} already exists but is empty. Computing it again...'
                                    .format(output_path))
                                os.remove(output_path)
                            else:
                                bar.set_description(
                                    '{} already exists'.format(output_path))
                            i += 1
                            continue

                        input_features = SpeechFeatures.features_from_name(
                            name=input_features_name,
                            signal=preprocessed_audio,
                            rate=rate,
                            filters_number=input_filters_number)

                        if (input_features.shape[0] != input_target_shape[0]
                                or input_features.shape[1] != input_target_shape[1]):
                            ConsoleLogger.warn(
                                'Raw features entry {} with invalid shape {} will not be saved. Target shape: {}'
                                .format(i, input_features.shape,
                                        input_target_shape))
                            i += 1
                            continue

                        output_features = SpeechFeatures.features_from_name(
                            name=output_features_name,
                            signal=preprocessed_audio,
                            rate=rate,
                            filters_number=output_filters_number,
                            augmented=augment_output_features)

                        # TODO: add an option in configuration to save quantized/one_hot or not
                        output = {
                            'preprocessed_audio': preprocessed_audio,
                            'wav_filename': wav_filename,
                            'input_features': input_features,
                            'one_hot': one_hot if export_one_hot_features else np.array([]),
                            'quantized': np.array([]),
                            'speaker_id': speaker_id,
                            'output_features': output_features,
                            'shifting_time': shifting_time,
                            'random_starting_index': random_starting_index,
                            'preprocessed_length': preprocessed_length,
                            'sampling_rate': sampling_rate,
                            'top_db': top_db
                        }

                        with open(output_path, 'wb') as file:
                            pickle.dump(output, file)

                        bar.set_description('{} saved'.format(output_path))

                        i += 1

                        if i == total_length:
                            bar.update(total_length)
                            break

                    bar.close()
                    break
                except KeyboardInterrupt:
                    bar.close()
                    ConsoleLogger.warn(
                        'Keyboard interrupt detected. Leaving the function...')
                    return
                except Exception:
                    error_message = 'An error occurred in the data loader at {}/{}. Current attempt: {}/{}'.format(
                        output_dir, i, current_attempt + 1, attempts)
                    self._logger.exception(error_message)
                    ConsoleLogger.error(error_message)
                    initial_index = i
                    current_attempt += 1
                    continue

        try:
            ConsoleLogger.status('Processing training part')
            process(
                loader=self._training_loader,
                output_dir=train_features_path,
                input_features_name=configuration['input_features_type'],
                output_features_name=configuration['output_features_type'],
                rate=configuration['sampling_rate'],
                input_filters_number=configuration['input_features_filters'],
                output_filters_number=configuration['output_features_filters'],
                input_target_shape=(configuration['input_features_dim'],
                                    configuration['input_features_filters'] *
                                    3),
                augment_output_features=configuration[
                    'augment_output_features'],
                export_one_hot_features=configuration[
                    'export_one_hot_features'])
            ConsoleLogger.success('Training part processed')
        except Exception:
            ConsoleLogger.error(
                'An error occurred during training features generation')

        try:
            ConsoleLogger.status('Processing validation part')
            process(
                loader=self._validation_loader,
                output_dir=val_features_path,
                input_features_name=configuration['input_features_type'],
                output_features_name=configuration['output_features_type'],
                rate=configuration['sampling_rate'],
                input_filters_number=configuration['input_features_filters'],
                output_filters_number=configuration['output_features_filters'],
                input_target_shape=(configuration['input_features_dim'],
                                    configuration['input_features_filters'] *
                                    3),
                augment_output_features=configuration[
                    'augment_output_features'],
                export_one_hot_features=configuration[
                    'export_one_hot_features'])
            ConsoleLogger.success('Validation part processed')
        except Exception:
            ConsoleLogger.error(
                'An error occurred during validation features generation')
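
# Illustrative only: a minimal sketch of how one pickled feature file written
# by process() above could be read back. The directory and file name
# ('features/train/0.pickle') are assumptions for the example; the dictionary
# keys are the ones stored in the 'output' dict above.
import pickle

with open('features/train/0.pickle', 'rb') as feature_file:
    sample = pickle.load(feature_file)

# input_features was validated against (input_features_dim, input_features_filters * 3)
print(sample['input_features'].shape, sample['output_features'].shape)
print(sample['wav_filename'], sample['sampling_rate'], sample['speaker_id'])
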
    @staticmethod
    def load(experiments_path,
             experiment_name,
             results_path,
             data_path='../data'):
        error_caught = False

        try:
            configuration_file, checkpoint_files = PipelineFactory.load_configuration_and_checkpoints(
                experiments_path, experiment_name)
        except Exception:
            ConsoleLogger.error(
                'Failed to load existing configuration. Building a new model...')
            error_caught = True

        # Load the configuration file
        ConsoleLogger.status('Loading the configuration file')
        configuration = None
        with open(experiments_path + os.sep + configuration_file, 'r') as file:
            configuration = yaml.load(file, Loader=yaml.FullLoader)
        device_configuration = DeviceConfiguration.load_from_configuration(
            configuration)

        if error_caught or len(checkpoint_files) == 0:
            trainer, evaluator = PipelineFactory.build(configuration,
                                                       device_configuration,
                                                       experiments_path,
                                                       experiment_name,
                                                       results_path)
        else:
            latest_checkpoint_file, latest_epoch = CheckpointUtils.search_latest_checkpoint_file(
                checkpoint_files)
            # Update the epoch number from which the future training will resume
            configuration['start_epoch'] = latest_epoch
            configuration['num_epochs'] = 60
            # Load the checkpoint file
            checkpoint_path = experiments_path + os.sep + latest_checkpoint_file
            ConsoleLogger.status(
                "Loading the checkpoint file '{}'".format(checkpoint_path))
            checkpoint = torch.load(checkpoint_path,
                                    map_location=device_configuration.device)

            # Load the data stream
            ConsoleLogger.status('Loading the data stream')
            data_stream = VCTKFeaturesStream('/atlas/u/xuyilun/vctk',
                                             configuration,
                                             device_configuration.gpu_ids,
                                             device_configuration.use_cuda)

            def load_state_dicts(model, checkpoint, model_name,
                                 optimizer_name):
                # Load the state dict from the checkpoint to the model
                model.load_state_dict(checkpoint[model_name])
                # Create an Adam optimizer using the model parameters
                optimizer = optim.Adam(model.parameters())
                # Load the state dict from the checkpoint to the optimizer
                optimizer.load_state_dict(checkpoint[optimizer_name])
                # Move the optimizer state tensors onto the specified device
                for state in optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.to(device_configuration.device)
                return model, optimizer

            # If the decoder type is deconvolutional
            if configuration['decoder_type'] == 'deconvolutional':
                # Create the model and map it to the specified device
                vqvae_model = ConvolutionalVQVAE(
                    configuration, device_configuration.device).to(
                        device_configuration.device)
                evaluator = Evaluator(device_configuration.device, vqvae_model,
                                      data_stream, configuration, results_path,
                                      experiment_name)

                # Load the model and optimizer state dicts
                vqvae_model, vqvae_optimizer = load_state_dicts(
                    vqvae_model, checkpoint, 'model', 'optimizer')
            elif configuration['decoder_type'] == 'wavenet':
                vqvae_model = WaveNetVQVAE(configuration,
                                           data_stream.speaker_dic,
                                           device_configuration.device).to(
                                               device_configuration.device)
                evaluator = Evaluator(device_configuration.device, vqvae_model,
                                      data_stream, configuration, results_path,
                                      experiment_name)
                # Load the model and optimizer state dicts
                vqvae_model, vqvae_optimizer = load_state_dicts(
                    vqvae_model, checkpoint, 'model', 'optimizer')
            else:
                raise NotImplementedError(
                    "Decoder type '{}' isn't implemented for now".format(
                        configuration['decoder_type']))

            # Temporary backward compatibility
            if 'trainer_type' not in configuration:
                ConsoleLogger.error(
                    "trainer_type was not found in configuration file. Use 'convolutional' by default."
                )
                configuration['trainer_type'] = 'convolutional'

            if configuration['trainer_type'] == 'convolutional':
                trainer = ConvolutionalTrainer(
                    device_configuration.device, data_stream, configuration,
                    experiments_path, experiment_name, **{
                        'model': vqvae_model,
                        'optimizer': vqvae_optimizer
                    })
            else:
                raise NotImplementedError(
                    "Trainer type '{}' isn't implemented for now".format(
                        configuration['trainer_type']))

        return trainer, evaluator, configuration, device_configuration
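
# Illustrative only: a minimal sketch of how load() above might be invoked to
# resume an experiment. The paths and the experiment name are assumptions for
# the example, and load() is assumed to be exposed by PipelineFactory (which
# it references internally).
trainer, evaluator, configuration, device_configuration = PipelineFactory.load(
    experiments_path='experiments',
    experiment_name='my_experiment',
    results_path='results')
# The returned trainer and evaluator can then be used to resume training or to
# run an evaluation with the loaded checkpoint.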
Example #30
    def forward(self, inputs, speaker_dic, speaker_id):
        x = inputs
        if self._verbose:
            ConsoleLogger.status('[FEATURES_DEC] input size: {}'.format(
                x.size()))

        if self._use_jitter and self.training:
            x = self._jitter(x)

        if self._use_speaker_conditioning:
            speaker_embedding = GlobalConditioning.compute(speaker_dic,
                                                           speaker_id,
                                                           x,
                                                           device=self._device,
                                                           gin_channels=40,
                                                           expand=True)
            x = torch.cat([x, speaker_embedding], dim=1).to(self._device)

        x = self._conv_1(x)
        if self._verbose:
            ConsoleLogger.status(
                '[FEATURES_DEC] _conv_1 output size: {}'.format(x.size()))

        x = self._upsample(x)
        if self._verbose:
            ConsoleLogger.status(
                '[FEATURES_DEC] _upsample output size: {}'.format(x.size()))

        x = self._residual_stack(x)
        if self._verbose:
            ConsoleLogger.status(
                '[FEATURES_DEC] _residual_stack output size: {}'.format(
                    x.size()))

        x = F.relu(self._conv_trans_1(x))
        if self._verbose:
            ConsoleLogger.status(
                '[FEATURES_DEC] _conv_trans_1 output size: {}'.format(
                    x.size()))

        x = F.relu(self._conv_trans_2(x))
        if self._verbose:
            ConsoleLogger.status(
                '[FEATURES_DEC] _conv_trans_2 output size: {}'.format(
                    x.size()))

        x = self._conv_trans_3(x)
        if self._verbose:
            ConsoleLogger.status(
                '[FEATURES_DEC] _conv_trans_3 output size: {}'.format(
                    x.size()))

        return x
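
# Illustrative only: a self-contained sketch of the speaker conditioning step
# used in forward() above, i.e. concatenating a speaker embedding with the
# feature map along the channel dimension. The tensor sizes (1, 64, 95) are
# assumptions for the example; the 40 conditioning channels match the
# gin_channels value passed to GlobalConditioning.compute().
import torch

features = torch.randn(1, 64, 95)  # (batch, channels, time)
speaker_embedding = torch.randn(1, 40, 1).expand(-1, -1, features.size(-1))
conditioned = torch.cat([features, speaker_embedding], dim=1)
print(conditioned.size())  # torch.Size([1, 104, 95])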