def gen_figure_from_output(self, id_name, labels, hidden, hparams):
        """Save a figure comparing the generated and the original waveform.

        Both signals are drawn into one panel with a y-range of [-1, 1] and
        the figure is written to
        ``<out_dir>/<utterance>.<model>.Raw<gen_figure_ext>``, where
        ``<utterance>`` is ``id_name`` without its directory part and last
        extension and ``<model>`` is the base name of ``hparams.model_name``.

        :param id_name:  Identifier used to load the original waveform.
        :param labels:   Generated output, postprocessed through
                         ``self.dataset_train`` (per the original note it
                         arrives as T x C).
        :param hidden:   Not used by this method.
        :param hparams:  Reads ``model_name``, ``out_dir``,
                         ``frame_rate_output_Hz`` and ``gen_figure_ext``.
        """
        # Both loads use the full id_name; it is only shortened afterwards.
        generated = self.dataset_train.postprocess_sample(labels)
        original = RawWaveformLabelGen.load_sample(
            id_name, self.OutputGen.frame_rate_output_Hz)

        plotter = DataPlotter()
        model_tag = os.path.basename(hparams.model_name)
        utterance = os.path.basename(id_name).rsplit('.', 1)[0]
        out_stem = os.path.join(hparams.out_dir, utterance + "." + model_tag)
        plotter.set_title(utterance + " - " + model_tag)

        panel = 0
        curves = [(original, 'Org'), (generated, 'Wavenet')]
        plotter.set_data_list(grid_idx=panel, data_list=curves)
        plotter.set_linewidth(grid_idx=panel, linewidth=[0.1])
        plotter.set_colors(grid_idx=panel, alpha=0.8)
        plotter.set_lim(panel, ymin=-1, ymax=1)
        x_axis = 'frames [' + str(hparams.frame_rate_output_Hz) + ' Hz]'
        plotter.set_label(grid_idx=panel, xlabel=x_axis, ylabel='raw')

        plotter.gen_plot()
        plotter.save_to_file(out_stem + '.Raw' + hparams.gen_figure_ext)
# Example #2
0
    def gen_figure_from_output(self,
                               id_name,
                               labels,
                               hidden,
                               hparams,
                               clustering=None,
                               filters_out=None):
        """Plot atom commands, filtered commands and LF0 for one utterance.

        Writes ``<out_dir>/<id_name>.<model>.PHRASE<gen_figure_ext>``. When
        ``clustering`` is given, it additionally writes
        ``<out_dir>/<id_name>.<model>.CLUSTERS<gen_figure_ext>``, in which
        atoms are grouped. The F0 RMSE in Hz is logged; it is computed over
        frames marked voiced in the reference.

        :param id_name:      Utterance identifier.
        :param labels:       Network output. Columns 0-1 are postprocessed
                             into LF0 and V/UV; the remaining columns are the
                             internal atom amplitudes. If ``None`` (or if
                             ``filters_out`` is ``None``), both are
                             recomputed from the model input.
        :param hidden:       Not used by this method.
        :param hparams:      Reads ``model_name``, ``out_dir``,
                             ``frame_size_ms`` and ``gen_figure_ext``. It is
                             also forwarded to the model's forward passes.
        :param clustering:   Optional sequence of index groups over the atom
                             axis. Each group is summed (amplitudes, filter
                             outputs) or averaged (thetas) into one curve.
        :param filters_out:  Filter outputs matching ``labels``. Recomputed
                             together with ``labels`` if ``None``.
        """
        if labels is None or filters_out is None:
            # Recompute both from the model's input features.
            input_labels = self.InputGen[id_name][:, None, ...]
            labels = self.model_handler.forward(input_labels, hparams)[0][:, 0]
            filters_out = self.filters_forward(input_labels,
                                               hparams)[:, 0, ...]

        intern_amps = labels[:, 2:]
        labels = labels[:, :2]

        # Retrieve LF0 and V/UV from the label.
        labels_post = self.OutputGen.postprocess_sample(labels)
        output_lf0 = labels_post[:, 0]
        # Binarise V/UV with a comparison instead of in-place assignment.
        # labels_post[:, 1] is a view; writing into it would also modify the
        # caller's `labels` whenever postprocess_sample does not copy.
        output_vuv = labels_post[:, 1] >= 0.5

        # Load original lf0 and vuv.
        org_labels = self.OutputGen.load_sample(id_name,
                                                self.OutputGen.dir_labels)
        original_lf0, original_vuv = self.OutputGen.convert_to_world_features(
            org_labels)
        # original_lf0, _ = interpolate_lin(original_lf0)

        # NOTE: phrase-curve removal is disabled. Both LF0 curves are
        # therefore absolute log-F0 values, which the fixed LF0 panel limits
        # below rely on.
        # phrase_curve = self.OutputGen.get_phrase_curve(id_name)
        # original_lf0 -= phrase_curve[:len(original_lf0)]
        num_frames = len(output_lf0)
        original_lf0 = original_lf0[:num_frames]
        reference_voiced = original_vuv[:num_frames]

        # F0 RMSE in Hz, restricted to frames voiced in the reference.
        f0_se = (np.exp(original_lf0) - np.exp(output_lf0))**2
        num_voiced = reference_voiced.sum()
        if num_voiced > 0:
            f0_rmse = math.sqrt((f0_se * reference_voiced).sum() / num_voiced)
        else:
            f0_rmse = float('nan')  # No voiced reference frames to score.
        self.logger.info("RMSE of {}: {} Hz.".format(id_name, f0_rmse))

        # Reference V/UV used for the hatched area in the LF0 panels.
        atom_output_gen = self.flat_trainer.atom_trainer.OutputGen
        atom_labels = atom_output_gen.load_sample(
            id_name, atom_output_gen.dir_labels,
            len(atom_output_gen.theta_interval),
            atom_output_gen.dir_world_labels)
        org_vuv = atom_labels[:, 0, 0].astype(bool)

        thetas = self.model_handler.model.thetas_approx()

        net_name = os.path.basename(hparams.model_name)
        filename = os.path.join(hparams.out_dir, id_name + '.' + net_name)

        def theta_label(theta):
            # Legend entry for one atom (or cluster) curve.
            return r'$\theta$={0:.3f}'.format(theta)

        def add_amplitude_panel(plotter, grid_idx, data_list, ylabel, amp_max,
                                area_label=None):
            # Symmetric-range panel with the predicted unvoiced regions shaded.
            area = (np.invert(output_vuv), '0.75', 1.0)
            if area_label is not None:
                area += (area_label, )
            plotter.set_data_list(grid_idx=grid_idx, data_list=data_list)
            plotter.set_area_list(grid_idx=grid_idx, area_list=[area])
            plotter.set_label(grid_idx=grid_idx, ylabel=ylabel)
            plotter.set_lim(grid_idx=grid_idx, ymin=-amp_max, ymax=amp_max)

        def add_lf0_panel(plotter, grid_idx, ylabel):
            # Original vs. predicted LF0 with the reference unvoiced regions
            # hatched. Both figures use the same limits because both curves
            # are absolute log-F0 (see the note on phrase-curve removal).
            plotter.set_data_list(grid_idx=grid_idx,
                                  data_list=[(original_lf0, "Original"),
                                             (output_lf0, "Predicted")])
            plotter.set_hatchstyles(grid_idx=grid_idx, hatchstyles=['\\\\'])
            plotter.set_area_list(grid_idx=grid_idx,
                                  area_list=[(np.invert(org_vuv), '0.75', 1.0,
                                              'Reference unvoiced')])
            plotter.set_label(grid_idx=grid_idx,
                              xlabel='frames [' + str(hparams.frame_size_ms) +
                              ' ms]',
                              ylabel=ylabel)
            plotter.set_lim(grid_idx=grid_idx, ymin=3, ymax=6)
            plotter.set_linestyles(grid_idx=grid_idx, linestyles=['-.', '-'])
            plotter.set_colors(grid_idx=grid_idx,
                               colors=['C3', 'C2', 'C0'],
                               alpha=1)

        # Figure 1: every atom separately.
        plotter = DataPlotter()
        add_amplitude_panel(
            plotter, 0,
            [(intern_amps[:, idx], theta_label(thetas[idx]))
             for idx in reversed(range(intern_amps.shape[1]))],
            'command', 0.04)
        add_amplitude_panel(
            plotter, 1,
            [(filters_out[:, idx], )
             for idx in reversed(range(filters_out.shape[1]))],
            'filtered', 0.1, area_label='Unvoiced')
        add_lf0_panel(plotter, 2, 'LF0')
        plotter.gen_plot()
        plotter.save_to_file(filename + ".PHRASE" + hparams.gen_figure_ext)

        if clustering is None:
            return

        # Figure 2: atoms grouped according to `clustering`.
        plotter = DataPlotter()

        def cluster(array, mean=False):
            # One column per index group: the mean over the group when `mean`
            # is set, otherwise the sum over the last axis.
            if mean:
                return np.array([
                    np.take(array, i, axis=-1).mean() for i in clustering
                ]).transpose()
            return np.array([
                np.take(array, i, axis=-1).sum(-1) for i in clustering
            ]).transpose()

        clustered_amps = cluster(intern_amps)
        clustered_thetas = cluster(thetas, True)
        clustered_filters = cluster(filters_out)

        add_amplitude_panel(
            plotter, 0,
            [(clustered_amps[:, idx], theta_label(clustered_thetas[idx]))
             for idx in reversed(range(clustered_amps.shape[1]))],
            'cluster command', 0.04, area_label='Unvoiced')
        add_amplitude_panel(
            plotter, 1,
            [(clustered_filters[:, idx], )
             for idx in reversed(range(clustered_filters.shape[1]))],
            'filtered', 0.175)
        add_lf0_panel(plotter, 2, 'lf0')
        plotter.gen_plot()
        plotter.save_to_file(filename + ".CLUSTERS" + hparams.gen_figure_ext)