def evaluate(self, evaluation_options):
        self._model.eval()

        if evaluation_options['plot_exp'] or \
            evaluation_options['plot_quantized_embedding_spaces'] or \
            evaluation_options['plot_distances_histogram']:
            evaluation_entry = self._evaluate_once()

        if evaluation_options['plot_exp']:
            self._compute_comparaison_plot(evaluation_entry)

        if evaluation_options['plot_quantized_embedding_spaces']:
            EmbeddingSpaceStats.compute_and_plot_quantized_embedding_space_projections(
                self._results_path, self._experiment_name, evaluation_entry,
                self._model.vq.embedding,
                self._data_stream.validation_batch_size)

        if evaluation_options['plot_distances_histogram']:
            self._plot_distances_histogram(evaluation_entry)

        #self._test_denormalization(evaluation_entry)

        if evaluation_options['compute_many_to_one_mapping']:
            self._many_to_one_mapping()

        if evaluation_options['compute_alignments'] or \
            evaluation_options['compute_clustering_metrics'] or \
            evaluation_options['compute_groundtruth_average_phonemes_number']:
            alignment_stats = AlignmentStats(
                self._data_stream, self._vctk, self._configuration,
                self._device, self._model, self._results_path,
                self._experiment_name, evaluation_options['alignment_subset'])
        if evaluation_options['compute_alignments']:
            groundtruth_alignments_path = self._results_path + os.sep + \
                'vctk_{}_groundtruth_alignments.pickle'.format(evaluation_options['alignment_subset'])
            if not os.path.isfile(groundtruth_alignments_path):
                alignment_stats.compute_groundtruth_alignments()
                alignment_stats.compute_groundtruth_bigrams_matrix(
                    wo_diag=True)
                alignment_stats.compute_groundtruth_bigrams_matrix(
                    wo_diag=False)
                alignment_stats.compute_groundtruth_phonemes_frequency()
            else:
                ConsoleLogger.status('Groundtruth alignments already exist')

            empirical_alignments_path = self._results_path + os.sep + self._experiment_name + \
                '_vctk_{}_empirical_alignments.pickle'.format(evaluation_options['alignment_subset'])
            if not os.path.isfile(empirical_alignments_path):
                alignment_stats.compute_empirical_alignments()
                alignment_stats.compute_empirical_bigrams_matrix(wo_diag=True)
                alignment_stats.compute_empirical_bigrams_matrix(wo_diag=False)
                alignment_stats.compute_empirical_encodings_frequency()
            else:
                ConsoleLogger.status('Empirical alignments already exist')

        if evaluation_options['compute_clustering_metrics']:
            alignment_stats.compute_clustering_metrics()

        if evaluation_options['compute_groundtruth_average_phonemes_number']:
            alignment_stats.compute_groundtruth_average_phonemes_number()
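A minimal usage sketch (assumptions: an `evaluator` instance built by the surrounding pipeline; the option keys are exactly those read by evaluate() above):

# Hypothetical driver for evaluate(); every key below is consumed by the method above.
evaluation_options = {
    'plot_exp': True,
    'plot_quantized_embedding_spaces': True,
    'plot_distances_histogram': False,
    'compute_many_to_one_mapping': False,
    'compute_alignments': True,
    'compute_clustering_metrics': True,
    'compute_groundtruth_average_phonemes_number': False,
    'alignment_subset': 'val'  # assumption: the VCTK validation subset
}
evaluator.evaluate(evaluation_options)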
Example #2
    def fetch(self):
        ConsoleLogger.status('Begin fetch...')
        font1 = {
            'family': 'Times New Roman',
            'weight': 'normal',
            'size': 40,
        }
        indices = torch.LongTensor(np.load('indices.npy')).to(self._device)
        idx = [32, 128, 512]
        title = [0.0625, 0.25, 1.0]
        os.makedirs('./audio', exist_ok=True)
        os.makedirs('./img', exist_ok=True)

        for i in range(len(indices)):

            fig, axs = plt.subplots(len(idx), 1, figsize=(35, 30), sharex=True)

            for j in range(len(idx)):
                axs[j].set_title('Fraction of Dimensions:' + str(title[j]),
                                 font1)
                x = self._model.indices_fetch(indices[i][:idx[j]].unsqueeze(
                    0))[0, 0].detach().cpu().numpy()
                axs[j].plot(np.arange(len(x)), x)
                write(
                    "./audio/sample_" + str(i) + "_" + str(title[j]) + ".wav",
                    16000, x)

            plt.savefig('./img/sample_' + str(i),
                        bbox_inches='tight',
                        pad_inches=0)
            plt.close(fig)  # close the figure instead of clearing it, so figures don't accumulate across iterations
    def compute_empirical_encodings_frequency(self):
        ConsoleLogger.status(
            'Computing empirical encodings frequency of VCTK val dataset...')

        alignments_dic = None
        with open(
                self._results_path + os.sep + self._experiment_name +
                '_vctk_empirical_alignments.pickle', 'rb') as f:
            alignments_dic = pickle.load(f)

        encodings_counter = alignments_dic['encodings_counter']
        desired_time_interval = alignments_dic['desired_time_interval']
        total_indices_apparations = alignments_dic['total_indices_apparations']

        encodings_frequency = dict()
        for key, value in encodings_counter.items():
            encodings_frequency[key] = value * 100 / total_indices_apparations

        encodings_frequency_sorted_keys = sorted(encodings_frequency,
                                                 key=encodings_frequency.get,
                                                 reverse=True)
        values = [
            encodings_frequency[key] for key in encodings_frequency_sorted_keys
        ]

        # TODO: add title
        fig, ax = plt.subplots(figsize=(20, 1))
        ax.bar(encodings_frequency_sorted_keys, values)
        fig.savefig(self._results_path + os.sep + self._experiment_name +
                    '_vctk_empirical_frequency_{}ms.png'.format(
                        int(desired_time_interval * 1000)),
                    bbox_inches='tight',
                    pad_inches=0)
        plt.close(fig)
    def compute_groundtruth_phonemes_frequency(self, wo_diag=True):
        ConsoleLogger.status(
            'Computing groundtruth phonemes frequency of VCTK val dataset...')

        alignments_dic = None
        with open(
                self._results_path + os.sep +
                'vctk_groundtruth_alignments.pickle', 'rb') as f:
            alignments_dic = pickle.load(f)

        desired_time_interval = alignments_dic['desired_time_interval']
        phonemes_counter = alignments_dic['phonemes_counter']
        total_phonemes_apparations = alignments_dic[
            'total_phonemes_apparations']

        phonemes_frequency = dict()
        for key, value in phonemes_counter.items():
            phonemes_frequency[key] = value * 100 / total_phonemes_apparations

        phonemes_frequency_sorted_keys = sorted(phonemes_frequency,
                                                key=phonemes_frequency.get,
                                                reverse=True)
        values = [
            phonemes_frequency[key] for key in phonemes_frequency_sorted_keys
        ]

        # TODO: add title
        fig, ax = plt.subplots(figsize=(20, 1))
        ax.bar(phonemes_frequency_sorted_keys, values)
        fig.savefig(self._results_path + os.sep +
                    'vctk_groundtruth_phonemes_frequency_{}ms.png'.format(
                        int(desired_time_interval * 1000)),
                    bbox_inches='tight',
                    pad_inches=0)
        plt.close(fig)
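The two plotting methods above repeat the same counter-to-percentage bar-chart pattern; a small helper along these lines (hypothetical, not part of the original class) could factor it out:

    def _plot_frequency_bar(self, counter, total, output_path):
        # Convert raw counts into percentages and plot them in descending order
        frequency = {key: value * 100 / total for key, value in counter.items()}
        sorted_keys = sorted(frequency, key=frequency.get, reverse=True)
        fig, ax = plt.subplots(figsize=(20, 1))
        ax.bar(sorted_keys, [frequency[key] for key in sorted_keys])
        fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
        plt.close(fig)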
Example #5
    def _plot_merged_all_losses_type(self,
                                     all_results_paths,
                                     all_experiments_names,
                                     all_train_losses,
                                     all_train_perplexities,
                                     all_latest_epochs,
                                     colormap_name='tab20'):

        latest_epoch = all_latest_epochs[0]
        for i in range(1, len(all_latest_epochs)):
            if all_latest_epochs[i] != latest_epoch:
                raise ValueError(
                    'All experiments must have the same number of epochs to merge them'
                )

        results_path = all_results_paths[0]

        all_train_losses_smooth = dict()
        for i in range(len(all_train_losses)):
            for loss_name in all_train_losses[i].keys():
                if loss_name == 'loss':
                    continue
                if loss_name not in all_train_losses_smooth:
                    all_train_losses_smooth[loss_name] = list()
                all_train_losses_smooth[loss_name].append(
                    self._smooth_curve(all_train_losses[i][loss_name]))

        for loss_name in all_train_losses_smooth.keys():
            n_colors = len(all_train_losses_smooth[loss_name])
            colors = self._get_colors_from_cmap(colormap_name, n_colors)

            train_losses_smooth = all_train_losses_smooth[loss_name]
            all_train_loss_smooth = np.asarray(train_losses_smooth)
            all_train_loss_smooth = np.reshape(
                all_train_loss_smooth,
                (n_colors, latest_epoch,
                 all_train_loss_smooth.shape[1] // latest_epoch))

            fig, ax = plt.subplots(figsize=(8, 8))

            for j in range(len(all_train_loss_smooth)):
                ax = self._plot_fill_between(ax, colors[j],
                                             all_train_loss_smooth[j],
                                             all_experiments_names[j])
            ax = self._configure_ax(ax,
                                    title='Smoothed ' +
                                    loss_name.replace('_', ' '),
                                    xlabel='Epochs',
                                    ylabel='Loss',
                                    legend=True)
            output_plot_path = results_path + os.sep + loss_name + '.png'

            fig.savefig(output_plot_path)
            plt.close(fig)

            ConsoleLogger.success(
                "Saved figure at path '{}'".format(output_plot_path))
    def search_latest_checkpoint_file(checkpoint_files):
        # Search the latest checkpoint file
        ConsoleLogger.status('Searching the latest checkpoint file')
        latest_checkpoint_file = checkpoint_files[0]
        latest_epoch = int(checkpoint_files[0].split('_')[1])
        for i in range(1, len(checkpoint_files)):
            epoch = int(checkpoint_files[i].split('_')[1])
            if epoch > latest_epoch:
                latest_checkpoint_file = checkpoint_files[i]
                latest_epoch = epoch

        return latest_checkpoint_file, latest_epoch
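Since the epoch is parsed as the second underscore-separated token, the method assumes file names of the form '<experiment>_<epoch>_checkpoint.pth' (experiment names containing underscores would break this parsing). A quick sketch under that assumption:

# Hypothetical checkpoint list following the '<experiment>_<epoch>_checkpoint.pth' convention.
checkpoint_files = ['baseline_3_checkpoint.pth', 'baseline_12_checkpoint.pth', 'baseline_7_checkpoint.pth']
latest_file, latest_epoch = CheckpointUtils.search_latest_checkpoint_file(checkpoint_files)
# latest_file == 'baseline_12_checkpoint.pth', latest_epoch == 12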
Example #7
    def _plot_loss_and_perplexity_figures(self, all_results_paths,
                                          all_experiments_names,
                                          all_train_losses,
                                          all_train_perplexities,
                                          all_latest_epochs, n_colors, colors):

        for i in range(len(all_experiments_names)):
            results_path = all_results_paths[i]
            experiment_name = all_experiments_names[i]
            output_plot_path = results_path + os.sep + experiment_name + '_loss-and-perplexity.png'

            train_loss_smooth = self._smooth_curve(all_train_losses[i]['loss'])
            train_perplexity_smooth = self._smooth_curve(
                all_train_perplexities[i])

            latest_epoch = all_latest_epochs[i]

            train_loss_smooth = np.asarray(train_loss_smooth)
            train_perplexity_smooth = np.asarray(train_perplexity_smooth)
            train_loss_smooth = np.reshape(
                train_loss_smooth,
                (latest_epoch, train_loss_smooth.shape[0] // latest_epoch))
            train_perplexity_smooth = np.reshape(
                train_perplexity_smooth,
                (latest_epoch,
                 train_perplexity_smooth.shape[0] // latest_epoch))

            fig = plt.figure(figsize=(16, 8))

            ax = fig.add_subplot(1, 2, 1)
            ax = self._plot_fill_between(ax, colors[i], train_loss_smooth,
                                         all_experiments_names[i])
            ax = self._configure_ax(ax,
                                    title='Smoothed loss',
                                    xlabel='Epochs',
                                    ylabel='Loss',
                                    legend=False)

            ax = fig.add_subplot(1, 2, 2)
            ax = self._plot_fill_between(ax, colors[i],
                                         train_perplexity_smooth,
                                         all_experiments_names[i])
            ax = self._configure_ax(ax,
                                    title='Smoothed average codebook usage',
                                    xlabel='Epochs',
                                    ylabel='Perplexity',
                                    legend=False)

            fig.savefig(output_plot_path)
            plt.close(fig)

            ConsoleLogger.success(
                "Saved figure at path '{}'".format(output_plot_path))
Example #8
    def load_from_configuration(configuration):
        use_cuda = configuration['use_cuda'] and torch.cuda.is_available()  # Use cuda if specified and available
        default_device = 'cuda' if use_cuda else 'cpu'  # Default to the cuda device if possible, otherwise use the cpu
        device = configuration['use_device'] if configuration['use_device'] is not None else default_device  # Use a defined device if specified
        gpu_ids = list(range(torch.cuda.device_count())) if configuration['use_data_parallel'] else [0]  # Resolve the gpu ids if gpu parallelization is specified
        if configuration['use_device'] and ':' in configuration['use_device']:
            gpu_ids = [int(configuration['use_device'].split(':')[1])]

        use_data_parallel = configuration['use_data_parallel'] and use_cuda and len(gpu_ids) > 1

        ConsoleLogger.status('The used device is: {}'.format(device))
        ConsoleLogger.status('The gpu ids are: {}'.format(gpu_ids))

        # Sanity checks
        if not use_cuda and configuration['use_cuda']:
            ConsoleLogger.warn(
                "The configuration file specified use_cuda=True but cuda isn't available"
            )
        if configuration['use_data_parallel'] and len(gpu_ids) < 2:
            ConsoleLogger.warn(
                'The configuration file specified use_data_parallel=True but there is only {} GPU available'
                .format(len(gpu_ids)))

        return DeviceConfiguration(use_cuda, device, gpu_ids,
                                   use_data_parallel)
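A sketch of the only three configuration keys the resolver reads (values illustrative):

# Illustrative configuration for DeviceConfiguration.load_from_configuration().
configuration = {
    'use_cuda': True,           # request CUDA when available
    'use_device': 'cuda:1',     # optional explicit device; None falls back to the default
    'use_data_parallel': False  # when True, resolve all visible GPU ids
}
device_configuration = DeviceConfiguration.load_from_configuration(configuration)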
    def test_global_conditioning(self):
        configuration = None
        with open('../../configurations/vctk_features.yaml', 'r') as configuration_file:
            configuration = yaml.load(configuration_file, Loader=yaml.FullLoader)
        device_configuration = DeviceConfiguration.load_from_configuration(configuration)
        data_stream = VCTKSpeechStream(configuration, device_configuration.gpu_ids, device_configuration.use_cuda)
        (x_enc, x_dec, speaker_id, _, _) = next(iter(data_stream.training_loader))

        ConsoleLogger.status('x_enc.size(): {}'.format(x_enc.size()))
        ConsoleLogger.status('x_dec.size(): {}'.format(x_dec.size()))

        x = x_dec.squeeze(-1)
        global_conditioning = GlobalConditioning.compute(
            speaker_dic=data_stream.speaker_dic,
            speaker_ids=speaker_id,
            x_one_hot=x,
            expand=False
        )
        self.assertEqual(global_conditioning.size(), torch.Size([1, 128, 1]))
        ConsoleLogger.success('global_conditioning.size(): {}'.format(global_conditioning.size()))

        expanded_global_conditioning = GlobalConditioning.compute(
            speaker_dic=data_stream.speaker_dic,
            speaker_ids=speaker_id,
            x_one_hot=x,
            expand=True
        )
        self.assertEqual(expanded_global_conditioning.size(), torch.Size([1, 128, 7680]))
        ConsoleLogger.success('expanded_global_conditioning.size(): {}'.format(expanded_global_conditioning.size()))
Example #10
    def plot_training_losses(self, experiments, experiments_path):
        all_train_losses = list()
        all_train_perplexities = list()
        all_results_paths = list()
        all_experiments_names = list()
        all_latest_epochs = list()

        for experiment in experiments:
            try:
                train_res_losses, train_res_perplexities, latest_epoch = \
                    CheckpointUtils.retrieve_losses_values(experiments_path, experiment)
                all_train_losses.append(train_res_losses)
                all_train_perplexities.append(train_res_perplexities)
                all_results_paths.append(experiment.results_path)
                all_experiments_names.append(experiment.name)
                all_latest_epochs.append(latest_epoch)
            except Exception:
                ConsoleLogger.error(
                    "Failed to retrieve losses of experiment '{}'".format(
                        experiment.name))

        n_final_losses_colors = len(all_train_losses)
        final_losses_colors = self._get_colors_from_cmap(
            self._colormap_name, n_final_losses_colors)

        # for each experiment: final loss + perplexity
        self._plot_loss_and_perplexity_figures(
            all_results_paths, all_experiments_names, all_train_losses,
            all_train_perplexities, all_latest_epochs, n_final_losses_colors,
            final_losses_colors)

        # merged experiment: merged final losses + merged perplexities
        self._plot_merged_losses_and_perplexities_figure(
            all_results_paths, all_experiments_names, all_train_losses,
            all_train_perplexities, all_latest_epochs, n_final_losses_colors,
            final_losses_colors)

        # for each experiment: all possible losses
        self._plot_merged_all_losses_figures(all_results_paths,
                                             all_experiments_names,
                                             all_train_losses,
                                             all_train_perplexities,
                                             all_latest_epochs)

        # merged losses of a single type in all experiments
        self._plot_merged_all_losses_type(all_results_paths,
                                          all_experiments_names,
                                          all_train_losses,
                                          all_train_perplexities,
                                          all_latest_epochs)
    def compute_groundtruth_average_phonemes_number(self):
        alignments_dic = None
        with open(
                self._results_path + os.sep +
                'vctk_groundtruth_alignments.pickle', 'rb') as f:
            alignments_dic = pickle.load(f)

        extended_alignment_dataset = alignments_dic[
            'extended_alignment_dataset']

        phonemes_number = list()
        for _, alignment in extended_alignment_dataset:
            phonemes_number.append(len(np.unique(alignment)))
        ConsoleLogger.success(
            'The average number of phonemes per alignment for {} alignments is: {}'
            .format(len(extended_alignment_dataset),
                    round(np.mean(phonemes_number), 2)))
Example #12
    def eval(self):
        ConsoleLogger.status('start epoch: {}'.format(
            self._configuration['start_epoch']))
        ConsoleLogger.status('num epoch: {}'.format(
            self._configuration['num_epochs']))
        for index in [16, 32, 48, 64]:
            recons_loss = 0.0
            cnt = 0.0
            with tqdm(self._data_stream.training_loader) as train_bar:
                train_res_recon_error = list()  # FIXME: record as a global metric
                train_res_perplexity = list()  # FIXME: record as a global metric
                iteration = 0
                max_iterations_number = len(train_bar)
                iterations = list(
                    np.arange(max_iterations_number,
                              step=(max_iterations_number / self._iterations_to_record) - 1,
                              dtype=int))
                for data in train_bar:
                    if len(data['one_hot']) == 1:
                        continue
                    losses, perplexity_value, sample_index = self.iterate_deconv(
                        data,
                        0,
                        iteration,
                        iterations,
                        train_bar,
                        eval=True,
                        designate=index)
                    if losses is None or perplexity_value is None:
                        continue
                    train_res_recon_error.append(losses)
                    train_res_perplexity.append(perplexity_value)
                    iteration += 1
                    recons_loss += losses['reconstruction_loss']
                    cnt += len(data['one_hot'])

            print("Index:{},Train_res_recon_error:{}".format(
                index, recons_loss / cnt))
            self.save_eval(
                0, **{
                    'train_res_recon_error': train_res_recon_error,
                    'train_res_perplexity': train_res_perplexity
                })
    def load_configuration_and_checkpoints(experiments_path, experiment_name):
        configuration_file, checkpoint_files = CheckpointUtils.search_configuration_and_checkpoints_files(
            experiments_path, experiment_name)

        # Check if a configuration file was found
        if not configuration_file:
            raise ValueError(
                'No configuration file found with name: {}'.format(
                    experiment_name))

        # Check if at least one checkpoint file was found
        if len(checkpoint_files) == 0:
            ConsoleLogger.warn(
                'No checkpoint files found with name: {}'.format(
                    experiment_name))

        return configuration_file, checkpoint_files
    def merge_experiment_losses(experiment_path, checkpoint_files, device_configuration):
        train_res_losses = {}
        train_res_perplexities = []

        sorted_checkpoint_files = sorted(checkpoint_files, key=lambda x: int(x.split('.')[0].split('_')[-2]))
        for checkpoint_file in sorted_checkpoint_files:
            # Load the checkpoint file
            checkpoint_path = experiment_path + os.sep + checkpoint_file
            ConsoleLogger.status("Loading the checkpoint file '{}'".format(checkpoint_path))
            checkpoint = torch.load(checkpoint_path, map_location=device_configuration.device)
            for loss_entry in checkpoint['train_res_recon_error']:
                for key in loss_entry.keys():
                    if key not in train_res_losses:
                        train_res_losses[key] = list()
                    train_res_losses[key].append(loss_entry[key])
            train_res_perplexities += checkpoint['train_res_perplexity']

        return train_res_losses, train_res_perplexities
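For reference, the sort key above takes the second-to-last underscore token of the file stem as the epoch (e.g. 'baseline_3_checkpoint.pth' -> 3), and each checkpoint is assumed to hold the two keys read in the loop. A sketch of that expected structure:

# Assumed shape of a checkpoint dict as consumed by merge_experiment_losses (values illustrative).
checkpoint = {
    'train_res_recon_error': [                   # one dict of named losses per recorded iteration
        {'loss': 1.23, 'reconstruction_loss': 0.98}
    ],
    'train_res_perplexity': [42.0]               # one perplexity value per recorded iteration
}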
    def plot_gradient_flow_over_epochs(gradient_stats_entries,
                                       output_file_name):
        epoch_number, iteration_number = set(), set()
        for epoch, iteration, _ in gradient_stats_entries:
            epoch_number.add(epoch)
            iteration_number.add(iteration)

        epoch_number = len(epoch_number)
        iteration_number = len(iteration_number)

        fig, axs = plt.subplots(epoch_number,
                                iteration_number,
                                figsize=(epoch_number * 8,
                                         iteration_number * 8),
                                sharey=True,
                                sharex=True)
        k = 0
        for i in trange(epoch_number):
            for j in trange(iteration_number):
                _, _, gradient_stats_entry = gradient_stats_entries[k]
                GradientStats.plot_gradient_flow(
                    gradient_stats_entry['model'],
                    axs[i][j],
                    set_xticks=True if i + 1 == epoch_number else False,
                    set_ylabels=True if j == 0 else False)
                k += 1

        ConsoleLogger.status('Saving gradient flow plot...')
        fig.suptitle('Gradient flow', fontsize='x-large')
        fig.legend(
            [
                Line2D([0], [0], color='c', lw=4),
                Line2D([0], [0], color='b', lw=4),
                Line2D([0], [0], color='k', lw=4)
            ],
            ['max-gradient', 'mean-gradient', 'zero-gradient'],
            loc="center right",  # Position of legend
            borderaxespad=0.1  # Small spacing around legend box
        )
        fig.savefig(output_file_name,
                    bbox_inches='tight',
                    pad_inches=0,
                    dpi=200)
        plt.close(fig)
Example #16
    def train(self):
        ConsoleLogger.status("Running the experiment called '{}'".format(
            self._name))
        ConsoleLogger.status('Beginning to train the model')
        self._trainer.train()
        ConsoleLogger.success(
            "Successfully ran the experiment called '{}'".format(self._name))
Example #17
    def evaluate(self, evaluation_options):
        ConsoleLogger.status("Running the experiment called '{}'".format(
            self._name))
        ConsoleLogger.status('Beginning to evaluate the model')
        self._evaluator.evaluate(evaluation_options)
        ConsoleLogger.success(
            "Successfully ran the experiment called '{}'".format(self._name))
    def forward(self, x, speaker_dic, speaker_id, full=False, designate=None):

        if self._verbose:
            #print("max:{},min:{}".format(torch.max(x), torch.min(x)))
            ConsoleLogger.status('[ConvVQVAE] _encoder input size: {}'.format(x.size()))
        x = x.permute(0, 2, 1).contiguous().float()
        z = self._encoder(x)
        if self._verbose:
            ConsoleLogger.status('[ConvVQVAE] _encoder output size: {}'.format(z.size()))

        #z = self._pre_vq_conv(z)
        if self._verbose:
            ConsoleLogger.status('[ConvVQVAE] _pre_vq_conv output size: {}'.format(z.size()))

        z_q_x_st, z_q_x, indices = self._vq.straight_through(z)
        area = z.size(1)
        sample_index = area
        if not full:
            sample_index = np.random.randint(area) + 1
        if designate is not None:
            sample_index = designate
        # Follow the input's device instead of hard-coding .cuda()
        zero_out = torch.zeros((z_q_x_st.size(0), area - sample_index, z_q_x_st.size(2)),
                               device=z_q_x_st.device)
        z_q_x_st = torch.cat((z_q_x_st[:, :sample_index], zero_out), dim=1)
        reconstructed_x = self._decoder(z_q_x_st, speaker_dic, speaker_id)
        output_features_size = reconstructed_x.size(2)
        reconstructed_x = reconstructed_x.view(-1, 1, output_features_size)
        return reconstructed_x, z, z_q_x, sample_index, indices
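The zero-out above keeps the first sample_index timesteps of the straight-through codes and blanks the remainder. A self-contained illustration of the same masking (device taken from the input tensor rather than hard-coded):

import torch

z = torch.arange(12.0).view(1, 4, 3)  # (batch, time, channels)
sample_index = 2                      # keep only the first two timesteps
zero_out = torch.zeros((z.size(0), z.size(1) - sample_index, z.size(2)), device=z.device)
masked = torch.cat((z[:, :sample_index], zero_out), dim=1)
# masked[:, :2] equals z[:, :2]; masked[:, 2:] is all zeros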
Example #19
    def train(self):
        ConsoleLogger.status('start epoch: {}'.format(
            self._configuration['start_epoch']))
        ConsoleLogger.status('num epoch: {}'.format(
            self._configuration['num_epochs']))

        for epoch in range(self._configuration['start_epoch'],
                           self._configuration['num_epochs']):
            with tqdm(self._data_stream.training_loader) as train_bar:
                train_res_recon_error = list()  # FIXME: record as a global metric
                train_res_perplexity = list()  # FIXME: record as a global metric
                train_res_recon_error_index = dict()
                index_cnt = dict()
                iteration = 0
                loss_sum = 0.0
                max_iterations_number = len(train_bar)
                iterations = list(
                    np.arange(max_iterations_number,
                              step=(max_iterations_number / self._iterations_to_record) - 1,
                              dtype=int))

                for data in train_bar:
                    if len(data['one_hot']) == 1:
                        continue
                    if self._configuration['decoder_type'] == 'deconvolutional':
                        _, loss_res = self.iterate_deconv(data,
                                                          epoch,
                                                          iteration,
                                                          iterations,
                                                          train_bar,
                                                          eval=False,
                                                          return_loss=True)
                        loss_sum += loss_res
                        iteration += 1
                print("Average loss per iteration:", loss_sum / iteration)
                self.save(epoch)
    def retrieve_losses_values(experiment_path, experiment):
        experiment_name = experiment.name

        ConsoleLogger.status("Searching configuration and checkpoints of experiment '{}' at path '{}'".format(experiment_name, experiment_path))
        configuration_file, checkpoint_files = CheckpointUtils.search_configuration_and_checkpoints_files(
            experiment_path,
            experiment_name
        )

        # Check if a configuration file was found
        if not configuration_file:
            raise ValueError('No configuration file found with name: {}'.format(experiment_name))

        # Check if at least one checkpoint file was found
        if len(checkpoint_files) == 0:
            raise ValueError('No checkpoint files found with name: {}'.format(experiment_name))

        # Load the configuration file
        configuration_path = experiment_path + os.sep + configuration_file
        ConsoleLogger.status("Loading the configuration file '{}'".format(configuration_path))
        configuration = None
        with open(configuration_path, 'r') as file:
            configuration = yaml.load(file, Loader=yaml.FullLoader)
        
        # Load the device configuration from the configuration state
        device_configuration = DeviceConfiguration.load_from_configuration(configuration)

        ConsoleLogger.status("Merge {} checkpoint losses of experiment '{}'".format(len(checkpoint_files), experiment_name))
        train_res_losses, train_res_perplexities = CheckpointUtils.merge_experiment_losses(
            experiment_path,
            checkpoint_files,
            device_configuration
        )

        return train_res_losses, train_res_perplexities, len(checkpoint_files)
    def search_configuration_and_checkpoints_files(experiment_path, experiment_name):
        # Check if the specified experiment path exists
        ConsoleLogger.status("Checking if the experiment path '{}' exists".format(experiment_path))
        if not os.path.isdir(experiment_path):
            raise ValueError("Specified experiment path '{}' doesn't not exist".format(experiment_path))

        # List all the files from this directory and raise an error if it's empty
        ConsoleLogger.status('Listing the specified experiment path directory')
        files = os.listdir(experiment_path)
        if not files:
            raise ValueError("Specified experiment path '{}' is empty".format(experiment_path))

        # Search the configuration file and the checkpoint files of the specified experiment
        ConsoleLogger.status('Searching the configuration file and the checkpoint files')
        checkpoint_files = list()
        configuration_file = None
        for file in files:
            # Check if the file is a checkpoint or config file by looking at the extension
            if '.pth' not in file and '.yaml' not in file:
                continue
            split_file = file.split('_')
            if len(split_file) > 1 and split_file[0] == experiment_name and split_file[1] == 'configuration.yaml':
                configuration_file = file
            elif len(split_file) > 1 and split_file[0] == experiment_name and split_file[1] != 'configuration.yaml':
                checkpoint_files.append(file)

        return configuration_file, checkpoint_files
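Given the split on '_', the lookup assumes names such as '<experiment>_configuration.yaml' for the configuration and '<experiment>_<epoch>_checkpoint.pth' for checkpoints; a sketch of the expected outcome:

# Hypothetical directory listing matching the parsing above.
files = ['myexp_configuration.yaml', 'myexp_3_checkpoint.pth', 'notes.txt']
# -> configuration_file == 'myexp_configuration.yaml' (second token is 'configuration.yaml')
# -> checkpoint_files == ['myexp_3_checkpoint.pth']   ('notes.txt' is skipped by the extension check)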
    def __init__(self, name, experiments_path, results_path, global_configuration,
        experiment_configuration, seed):

        self._name = name
        self._experiments_path = experiments_path
        self._results_path = results_path
        self._global_configuration = global_configuration
        self._experiment_configuration = experiment_configuration
        self._seed = seed

        # Create the experiments path directory if it doesn't exist
        if not os.path.isdir(experiments_path):
            ConsoleLogger.status('Creating experiments directory at path: {}'.format(experiments_path))
            os.mkdir(experiments_path)
        else:
            ConsoleLogger.status('Experiments directory already created at path: {}'.format(experiments_path))

        # Create the results path directory if it doesn't exist
        if not os.path.isdir(results_path):
            ConsoleLogger.status('Creating results directory at path: {}'.format(results_path))
            os.mkdir(results_path)
        else:
            ConsoleLogger.status('Results directory already created at path: {}'.format(results_path))
        experiments_configuration_path = experiments_path + os.sep + name + '_configuration.yaml'
        configuration_file_already_exists = os.path.isfile(experiments_configuration_path)
        if not configuration_file_already_exists:
            self._device_configuration = DeviceConfiguration.load_from_configuration(global_configuration)

            # Create a new configuration state from the default and the experiment specific aspects
            self._configuration = copy.deepcopy(self._global_configuration)
            for experiment_key in experiment_configuration.keys():
                if experiment_key in self._configuration:
                    self._configuration[experiment_key] = experiment_configuration[experiment_key]

            # Save the configuration of the experiments
            with open(experiments_configuration_path, 'w') as file:
                yaml.dump(self._configuration, file)
        else:
            with open(experiments_configuration_path, 'r') as file:
                self._configuration = yaml.load(file, Loader=yaml.FullLoader)
                self._device_configuration = DeviceConfiguration.load_from_configuration(self._configuration)
        if configuration_file_already_exists:
            self._trainer, self._evaluator, self._configuration, self._device_configuration = PipelineFactory.load(self._experiments_path, self._name, self._results_path)
        else:
            self._trainer, self._evaluator = PipelineFactory.build(self._configuration,
                self._device_configuration, self._experiments_path, self._name, self._results_path)
Example #23
    def forward(self, x, speaker_dic, speaker_id):
        x = x.permute(0, 2, 1).contiguous().float()

        z = self._encoder(x)
        if self._verbose:
            ConsoleLogger.status('[ConvVQVAE] _encoder output size: {}'.format(z.size()))

        z = self._pre_vq_conv(z)
        if self._verbose:
            ConsoleLogger.status('[ConvVQVAE] _pre_vq_conv output size: {}'.format(z.size()))

        vq_loss, quantized, perplexity, _, _, encoding_indices, \
            losses, _, _, _, concatenated_quantized = self._vq(z, record_codebook_stats=self._record_codebook_stats)

        reconstructed_x = self._decoder(quantized, speaker_dic, speaker_id)

        input_features_size = x.size(2)
        output_features_size = reconstructed_x.size(2)

        reconstructed_x = reconstructed_x.view(-1, self._output_features_filters, output_features_size)
        # Trim the decoder padding so the output length matches the input length
        # (guarded so that an empty :-0 slice cannot occur when the sizes already match)
        if output_features_size > input_features_size:
            reconstructed_x = reconstructed_x[:, :, :-(output_features_size - input_features_size)]

        return reconstructed_x, vq_loss, losses, perplexity, encoding_indices, concatenated_quantized
Example #24
    def train(self):
        ConsoleLogger.status('start epoch: {}'.format(self._configuration['start_epoch']))
        ConsoleLogger.status('num epoch: {}'.format(self._configuration['num_epochs']))

        for epoch in range(self._configuration['start_epoch'], self._configuration['num_epochs']):

            with tqdm(self._data_stream.training_loader) as train_bar:
                train_res_recon_error = list() # FIXME: record as a global metric
                train_res_perplexity = list() # FIXME: record as a global metric

                iteration = 0
                max_iterations_number = len(train_bar)
                iterations = list(np.arange(max_iterations_number, step=(max_iterations_number / self._iterations_to_record) - 1, dtype=int))

                for data in train_bar:
                    losses, perplexity_value = self.iterate(data, epoch, iteration, iterations, train_bar)
                    if losses is None or perplexity_value is None:
                        continue
                    train_res_recon_error.append(losses)
                    train_res_perplexity.append(perplexity_value)
                    iteration += 1

                self.save(epoch, **{'train_res_recon_error': train_res_recon_error, 'train_res_perplexity': train_res_perplexity})
Example #25
    def _plot_distances_histogram(self, evaluation_entry):
        encoding_distances = evaluation_entry['encoding_distances'][0].detach().cpu().numpy()
        embedding_distances = evaluation_entry['embedding_distances'].detach().cpu().numpy()
        frames_vs_embedding_distances = evaluation_entry['frames_vs_embedding_distances'] \
            .detach()[0].cpu().transpose(0, 1).numpy().ravel()

        if self._configuration['verbose']:
            ConsoleLogger.status('encoding_distances[0].size(): {}'.format(
                encoding_distances.shape))
            ConsoleLogger.status('embedding_distances.size(): {}'.format(
                embedding_distances.shape))
            ConsoleLogger.status(
                'frames_vs_embedding_distances[0].shape: {}'.format(
                    frames_vs_embedding_distances.shape))

        fig, axs = plt.subplots(3, 1, figsize=(30, 20), sharex=True)

        axs[0].set_title('\n'.join(
            wrap('Histogram of the distances between the'
                 ' encodings vectors', 60)))
        sns.distplot(encoding_distances,
                     hist=True,
                     kde=False,
                     ax=axs[0],
                     norm_hist=True)

        axs[1].set_title('\n'.join(
            wrap(
                'Histogram of the distances between the'
                ' embeddings vectors', 60)))
        sns.distplot(embedding_distances,
                     hist=True,
                     kde=False,
                     ax=axs[1],
                     norm_hist=True)

        axs[2].set_title(
            'Histogram of the distances computed in'
            ' VQ\n($||z_e(x) - e_i||^2_2$ with $z_e(x)$ the output of the encoder'
            ' prior to quantization)')
        sns.distplot(frames_vs_embedding_distances,
                     hist=True,
                     kde=False,
                     ax=axs[2],
                     norm_hist=True)

        output_path = self._results_path + os.sep + self._experiment_name + '_distances-histogram-plot.png'
        fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
        plt.close(fig)
    def load(experiments_path,
             experiment_name,
             results_path,
             data_path='../data'):
        error_caught = False

        try:
            configuration_file, checkpoint_files = PipelineFactory.load_configuration_and_checkpoints(
                experiments_path, experiment_name)
        except Exception:
            ConsoleLogger.error(
                'Failed to load existing configuration. Building a new model...'
            )
            error_caught = True

        # Load the configuration file
        ConsoleLogger.status('Loading the configuration file')
        configuration = None
        with open(experiments_path + os.sep + configuration_file, 'r') as file:
            configuration = yaml.load(file, Loader=yaml.FullLoader)
        device_configuration = DeviceConfiguration.load_from_configuration(
            configuration)

        if error_caught or len(checkpoint_files) == 0:
            trainer, evaluator = PipelineFactory.build(configuration,
                                                       device_configuration,
                                                       experiments_path,
                                                       experiment_name,
                                                       results_path)
        else:
            latest_checkpoint_file, latest_epoch = CheckpointUtils.search_latest_checkpoint_file(
                checkpoint_files)
            # Update the epoch number to begin with for the future training
            configuration['start_epoch'] = latest_epoch
            configuration['num_epochs'] = 60
            #latest_checkpoint_file = 'baseline_15_checkpoint.pth'
            #print(latest_checkpoint_file)
            # Load the checkpoint file
            checkpoint_path = experiments_path + os.sep + latest_checkpoint_file
            ConsoleLogger.status(
                "Loading the checkpoint file '{}'".format(checkpoint_path))
            checkpoint = torch.load(checkpoint_path,
                                    map_location=device_configuration.device)

            # Load the data stream
            ConsoleLogger.status('Loading the data stream')
            data_stream = VCTKFeaturesStream('/atlas/u/xuyilun/vctk',
                                             configuration,
                                             device_configuration.gpu_ids,
                                             device_configuration.use_cuda)

            def load_state_dicts(model, checkpoint, model_name,
                                 optimizer_name):
                # Load the state dict from the checkpoint to the model
                model.load_state_dict(checkpoint[model_name])
                # Create an Adam optimizer using the model parameters
                optimizer = optim.Adam(model.parameters())
                # Load the state dict from the checkpoint to the optimizer
                optimizer.load_state_dict(checkpoint[optimizer_name])
                # Map the optimizer memory into the specified device
                for state in optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.to(device_configuration.device)
                return model, optimizer

            # If the decoder type is a deconvolutional
            if configuration['decoder_type'] == 'deconvolutional':
                # Create the model and map it to the specified device
                vqvae_model = ConvolutionalVQVAE(
                    configuration, device_configuration.device).to(
                        device_configuration.device)
                evaluator = Evaluator(device_configuration.device, vqvae_model,
                                      data_stream, configuration, results_path,
                                      experiment_name)

                # Load the model and optimizer state dicts
                vqvae_model, vqvae_optimizer = load_state_dicts(
                    vqvae_model, checkpoint, 'model', 'optimizer')
            elif configuration['decoder_type'] == 'wavenet':
                vqvae_model = WaveNetVQVAE(configuration,
                                           data_stream.speaker_dic,
                                           device_configuration.device).to(
                                               device_configuration.device)
                evaluator = Evaluator(device_configuration.device, vqvae_model,
                                      data_stream, configuration, results_path,
                                      experiment_name)
                # Load the model and optimizer state dicts
                vqvae_model, vqvae_optimizer = load_state_dicts(
                    vqvae_model, checkpoint, 'model', 'optimizer')
            else:
                raise NotImplementedError(
                    "Decoder type '{}' isn't implemented for now".format(
                        configuration['decoder_type']))

            # Temporary backward compatibility
            if 'trainer_type' not in configuration:
                ConsoleLogger.error(
                    "trainer_type was not found in configuration file. Using 'convolutional' by default."
                )
                configuration['trainer_type'] = 'convolutional'

            if configuration['trainer_type'] == 'convolutional':
                trainer = ConvolutionalTrainer(
                    device_configuration.device, data_stream, configuration,
                    experiments_path, experiment_name, **{
                        'model': vqvae_model,
                        'optimizer': vqvae_optimizer
                    })
            else:
                raise NotImplementedError(
                    "Trainer type '{}' isn't implemented for now".format(
                        configuration['trainer_type']))

            # Use data parallelization if needed and available
            # (restored from the comment's intent; the original line here was a no-op assignment;
            # assumes DeviceConfiguration exposes use_data_parallel alongside gpu_ids)
            if device_configuration.use_data_parallel:
                vqvae_model = torch.nn.DataParallel(vqvae_model, device_ids=device_configuration.gpu_ids)

        return trainer, evaluator, configuration, device_configuration
Example #27
    def forward(self, inputs):
        if self._verbose:
            ConsoleLogger.status('inputs size: {}'.format(inputs.size()))

        x_conv_1 = F.relu(self._conv_1(inputs))
        if self._verbose:
            ConsoleLogger.status('x_conv_1 output size: {}'.format(
                x_conv_1.size()))

        x = F.relu(self._conv_2(x_conv_1)) + x_conv_1
        if self._verbose:
            ConsoleLogger.status('_conv_2 output size: {}'.format(x.size()))

        x_conv_3 = F.relu(self._conv_3(x))
        if self._verbose:
            ConsoleLogger.status('_conv_3 output size: {}'.format(
                x_conv_3.size()))

        x_conv_4 = F.relu(self._conv_4(x_conv_3)) + x_conv_3
        if self._verbose:
            ConsoleLogger.status('_conv_4 output size: {}'.format(
                x_conv_4.size()))

        x_conv_5 = F.relu(self._conv_5(x_conv_4)) + x_conv_4
        if self._verbose:
            ConsoleLogger.status('x_conv_5 output size: {}'.format(
                x_conv_5.size()))

        x = self._residual_stack(x_conv_5) + x_conv_5
        if self._verbose:
            ConsoleLogger.status('_residual_stack output size: {}'.format(
                x.size()))

        return x
Example #28
    def dump(self):
        ConsoleLogger.status('start epoch: {}'.format(
            self._configuration['start_epoch']))
        ConsoleLogger.status('num epoch: {}'.format(
            self._configuration['num_epochs']))
        print('Dumping Codebook')
        name = 'wave_bigger_baseline'
        with tqdm(self._data_stream.training_loader) as train_bar:
            iteration = 0
            max_iterations_number = len(train_bar)
            iterations = list(
                np.arange(max_iterations_number,
                          step=(max_iterations_number / self._iterations_to_record) - 1,
                          dtype=int))
            lst = []
            for data in train_bar:
                if len(data['one_hot']) == 1:
                    continue
                if self._configuration['decoder_type'] == 'deconvolutional':
                    indices = self.iterate_deconv(data,
                                                  0,
                                                  iteration,
                                                  iterations,
                                                  train_bar,
                                                  eval=True).view(len(data['one_hot']), -1)
                    iteration += 1
                    lst.append(indices.cpu())
        lst = torch.cat(lst, 0)
        os.makedirs(os.path.join('./data-bin', name), exist_ok=True)
        torch.save(lst, os.path.join('./data-bin', name, 'train.pt'))
        print('Dumped training Codebook')

        with tqdm(self._data_stream.validation_loader) as train_bar:
            iteration = 0
            max_iterations_number = len(train_bar)
            iterations = list(
                np.arange(max_iterations_number,
                          step=(max_iterations_number / self._iterations_to_record) - 1,
                          dtype=int))
            lst = []
            for data in train_bar:
                if len(data['one_hot']) == 1:
                    continue
                if self._configuration['decoder_type'] == 'deconvolutional':
                    indices = self.iterate_deconv(data,
                                                  0,
                                                  iteration,
                                                  iterations,
                                                  train_bar,
                                                  eval=True).view(len(data['one_hot']), -1)
                    iteration += 1
                    lst.append(indices.cpu())
        lst = torch.cat(lst, 0)
        torch.save(lst, os.path.join('./data-bin', name, 'valid.pt'))
        print('Dumped validation Codebook')
Example #29
    def export_to_features(self, vctk_path, configuration):
        if not os.path.isdir(vctk_path):
            raise ValueError(
                "VCTK dataset not found at path '{}'".format(vctk_path))

        # Create the features path directory if it doesn't exist
        features_path = vctk_path + os.sep + configuration['features_path']
        if not os.path.isdir(features_path):
            ConsoleLogger.status(
                'Creating features directory at path: {}'.format(
                    features_path))
            os.mkdir(features_path)
        else:
            ConsoleLogger.status(
                'Features directory already created at path: {}'.format(
                    features_path))

        # Create the features path directory if it doesn't exist
        train_features_path = features_path + os.sep + 'train'
        if not os.path.isdir(train_features_path):
            ConsoleLogger.status(
                'Creating train features directory at path: {}'.format(
                    train_features_path))
            os.mkdir(train_features_path)
        else:
            ConsoleLogger.status(
                'Train features directory already created at path: {}'.format(
                    train_features_path))

        # Create the features path directory if it doesn't exist
        val_features_path = features_path + os.sep + 'val'
        if not os.path.isdir(val_features_path):
            ConsoleLogger.status(
                'Creating val features directory at path: {}'.format(
                    val_features_path))
            os.mkdir(val_features_path)
        else:
            ConsoleLogger.status(
                'Val features directory already created at path: {}'.format(
                    val_features_path))

        def process(loader, output_dir, input_features_name,
                    output_features_name, rate, input_filters_number,
                    output_filters_number, input_target_shape,
                    augment_output_features, export_one_hot_features):

            initial_index = 0
            attempts = 10
            current_attempt = 0
            total_length = len(loader)

            while current_attempt < attempts:
                try:
                    i = initial_index
                    bar = tqdm(loader, initial=initial_index)
                    for data in bar:
                        (preprocessed_audio, one_hot, speaker_id, quantized,
                         wav_filename, sampling_rate, shifting_time,
                         random_starting_index, preprocessed_length,
                         top_db) = data

                        output_path = output_dir + os.sep + str(i) + '.pickle'
                        if os.path.isfile(output_path):
                            if os.path.getsize(output_path) == 0:
                                bar.set_description(
                                    '{} already exists but is empty. Computing it again...'
                                    .format(output_path))
                                os.remove(output_path)
                            else:
                                bar.set_description(
                                    '{} already exists'.format(output_path))
                            i += 1
                            continue

                        input_features = SpeechFeatures.features_from_name(
                            name=input_features_name,
                            signal=preprocessed_audio,
                            rate=rate,
                            filters_number=input_filters_number)

                        if input_features.shape[0] != input_target_shape[0] \
                                or input_features.shape[1] != input_target_shape[1]:
                            ConsoleLogger.warn(
                                "Raw features number {} with invalid dimension {} will not be saved. Target shape: {}"
                                .format(i, input_features.shape,
                                        input_target_shape))
                            i += 1
                            continue

                        output_features = SpeechFeatures.features_from_name(
                            name=output_features_name,
                            signal=preprocessed_audio,
                            rate=rate,
                            filters_number=output_filters_number,
                            augmented=augment_output_features)

                        # TODO: add an option in configuration to save quantized/one_hot or not
                        output = {
                            'preprocessed_audio': preprocessed_audio,
                            'wav_filename': wav_filename,
                            'input_features': input_features,
                            'one_hot': one_hot if export_one_hot_features else np.array([]),
                            'quantized': np.array([]),
                            'speaker_id': speaker_id,
                            'output_features': output_features,
                            'shifting_time': shifting_time,
                            'random_starting_index': random_starting_index,
                            'preprocessed_length': preprocessed_length,
                            'sampling_rate': sampling_rate,
                            'top_db': top_db
                        }

                        with open(output_path, 'wb') as file:
                            pickle.dump(output, file)

                        bar.set_description('{} saved'.format(output_path))

                        i += 1

                        if i == total_length:
                            bar.update(total_length)
                            break

                    bar.close()
                    break
                except KeyboardInterrupt:
                    bar.close()
                    ConsoleLogger.warn(
                        'Keyboard interrupt detected. Leaving the function...')
                    return
                except Exception:
                    error_message = 'An error occurred in the data loader at {}/{}. Current attempt: {}/{}'.format(
                        output_dir, i, current_attempt + 1, attempts)
                    self._logger.exception(error_message)
                    ConsoleLogger.error(error_message)
                    initial_index = i
                    current_attempt += 1
                    continue

        try:
            ConsoleLogger.status('Processing training part')
            process(
                loader=self._training_loader,
                output_dir=train_features_path,
                input_features_name=configuration['input_features_type'],
                output_features_name=configuration['output_features_type'],
                rate=configuration['sampling_rate'],
                input_filters_number=configuration['input_features_filters'],
                output_filters_number=configuration['output_features_filters'],
                input_target_shape=(configuration['input_features_dim'],
                                    configuration['input_features_filters'] *
                                    3),
                augment_output_features=configuration[
                    'augment_output_features'],
                export_one_hot_features=configuration[
                    'export_one_hot_features'])
            ConsoleLogger.success('Training part processed')
        except Exception:
            ConsoleLogger.error(
                'An error occurred during training features generation')

        try:
            ConsoleLogger.status('Processing validation part')
            process(
                loader=self._validation_loader,
                output_dir=val_features_path,
                input_features_name=configuration['input_features_type'],
                output_features_name=configuration['output_features_type'],
                rate=configuration['sampling_rate'],
                input_filters_number=configuration['input_features_filters'],
                output_filters_number=configuration['output_features_filters'],
                input_target_shape=(configuration['input_features_dim'],
                                    configuration['input_features_filters'] *
                                    3),
                augment_output_features=configuration[
                    'augment_output_features'],
                export_one_hot_features=configuration[
                    'export_one_hot_features'])
            ConsoleLogger.success('Validation part processed')
        except Exception:
            ConsoleLogger.error(
                'An error occurred during validation features generation')
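
For context, the two process(...) calls above read every parameter from a single configuration mapping. Below is a minimal illustrative sketch of that mapping; the key names are taken from the calls themselves, but the values are assumptions rather than the project's actual defaults:

configuration = {
    'input_features_type': 'mfcc',       # assumed feature name
    'output_features_type': 'logfbank',  # assumed feature name
    'sampling_rate': 16000,              # assumed rate, in Hz
    'input_features_filters': 13,        # target shape becomes (dim, 13 * 3)
    'output_features_filters': 40,
    'input_features_dim': 95,            # assumed number of frames per entry
    'augment_output_features': True,
    'export_one_hot_features': False,
}
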
Example #30
        def process(loader, output_dir, input_features_name,
                    output_features_name, rate, input_filters_number,
                    output_filters_number, input_target_shape,
                    augment_output_features, export_one_hot_features):

            initial_index = 0
            attempts = 10
            current_attempt = 0
            total_length = len(loader)

            while current_attempt < attempts:
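                # Each attempt resumes from `initial_index`, the entry at which
                # the previous pass stopped.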
                try:
                    i = initial_index
                    bar = tqdm(loader, initial=initial_index)
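                    # `initial` only offsets the displayed progress count; on a
                    # retry the loader is iterated again from its first item and
                    # files saved earlier are skipped below.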
                    for data in bar:
                        (preprocessed_audio, one_hot, speaker_id, quantized,
                         wav_filename, sampling_rate, shifting_time,
                         random_starting_index, preprocessed_length,
                         top_db) = data

                        output_path = output_dir + os.sep + str(i) + '.pickle'
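                        # Resume support: skip entries already serialized by a
                        # previous run, but recompute files that exist yet are
                        # empty (e.g. after an interrupted write).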
                        if os.path.isfile(output_path):
                            if os.path.getsize(output_path) == 0:
                                bar.set_description(
                                    '{} already exists but is empty. Computing it again...'
                                    .format(output_path))
                                os.remove(output_path)
                            else:
                                bar.set_description(
                                    '{} already exists'.format(output_path))
                            i += 1
                            continue

                        input_features = SpeechFeatures.features_from_name(
                            name=input_features_name,
                            signal=preprocessed_audio,
                            rate=rate,
                            filters_number=input_filters_number)

                        if input_features.shape[0] != input_target_shape[0] \
                                or input_features.shape[1] != input_target_shape[1]:
                            ConsoleLogger.warn(
                                'Entry {} has invalid feature shape {} and will'
                                ' not be saved. Target shape: {}'.format(
                                    i, input_features.shape,
                                    input_target_shape))
                            i += 1
                            continue

                        output_features = SpeechFeatures.features_from_name(
                            name=output_features_name,
                            signal=preprocessed_audio,
                            rate=rate,
                            filters_number=output_filters_number,
                            augmented=augment_output_features)

                        # TODO: add a configuration option to choose whether to save quantized/one_hot features
                        output = {
                            'preprocessed_audio': preprocessed_audio,
                            'wav_filename': wav_filename,
                            'input_features': input_features,
                            'one_hot': one_hot if export_one_hot_features else np.array([]),
                            'quantized': np.array([]),
                            'speaker_id': speaker_id,
                            'output_features': output_features,
                            'shifting_time': shifting_time,
                            'random_starting_index': random_starting_index,
                            'preprocessed_length': preprocessed_length,
                            'sampling_rate': sampling_rate,
                            'top_db': top_db
                        }

                        with open(output_path, 'wb') as file:
                            pickle.dump(output, file)

                        bar.set_description('{} saved'.format(output_path))

                        i += 1

                        if i == total_length:
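                            # Every entry has been saved or skipped, so stop
                            # early instead of exhausting the loader.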
                            bar.update(total_length)
                            break

                    bar.close()
                    break
                except KeyboardInterrupt:
                    bar.close()
                    ConsoleLogger.warn(
                        'Keyboard interrupt detected. Leaving the function...')
                    return
                except Exception:
                    error_message = 'An error occurred in the data loader at {} (entry {}). Attempt {}/{}'.format(
                        output_dir, i, current_attempt + 1, attempts)
                    self._logger.exception(error_message)
                    ConsoleLogger.error(error_message)
                    initial_index = i
                    current_attempt += 1
                    continue
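
Once generated, each pickle file holds the dictionary assembled in process(). A minimal read-back sketch, assuming a hypothetical output directory and entry index:

import pickle

# Hypothetical path: process() writes <output_dir>/<index>.pickle.
with open('train_features/0.pickle', 'rb') as f:
    entry = pickle.load(f)

# The available keys mirror the output dict built above.
print(entry['wav_filename'], entry['speaker_id'])
print(entry['input_features'].shape, entry['output_features'].shape)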