示例#1
0
    def gen_figure_from_output(self,
                               id_name,
                               labels,
                               hidden,
                               hparams,
                               clustering=None,
                               filters_out=None):
        """Plot command signals, filter outputs and LF0 of one utterance.

        A figure with three grids (command signal per theta, filter outputs,
        original vs. predicted LF0) is saved as
        ``<out_dir>/<id_name>.<net_name>.PHRASE<gen_figure_ext>``. When
        ``clustering`` is given, a second figure
        ``...CLUSTERS<gen_figure_ext>`` is saved in which command signals and
        filter outputs are summed per cluster and thetas are averaged per
        cluster. The F0 RMSE in Hz, weighted by the reference V/UV flags, is
        written to ``self.logger``.

        :param id_name:     Id of the utterance; used to look up the inputs
                            and reference labels and as part of the file name.
        :param labels:      Network output. Columns 0 and 1 are post-processed
                            into LF0 and V/UV, all remaining columns are
                            plotted as command signals. Recomputed (together
                            with ``filters_out``) when None.
        :param hidden:      Unused.
        :param hparams:     Hyper-parameter container; ``model_name``,
                            ``out_dir``, ``frame_size_ms`` and
                            ``gen_figure_ext`` are read here, and it is
                            forwarded to the model when signals are recomputed.
        :param clustering:  Optional iterable of index collections addressing
                            the last axis of the command signals, the filter
                            outputs and the thetas. No cluster figure is
                            created when None.
        :param filters_out: Filter outputs whose second axis enumerates the
                            filters. Recomputed (together with ``labels``)
                            when None.
        :return:            None
        """

        if labels is None or filters_out is None:
            # A single missing signal triggers recomputation of both of them
            # from the stored input features of this utterance.
            input_labels = self.InputGen[id_name][:, None, ...]
            labels = self.model_handler.forward(input_labels, hparams)[0][:, 0]
            filters_out = self.filters_forward(input_labels, hparams)[:, 0,
                                                                      ...]

        intern_amps = labels[:, 2:]
        labels = labels[:, :2]

        # Retrieve data from label.
        labels_post = self.OutputGen.postprocess_sample(labels)
        output_lf0 = labels_post[:, 0]
        # Threshold V/UV at 0.5 with a comparison instead of assigning 0/1
        # into labels_post: that array may share memory with the array passed
        # in by the caller, which must not be modified by plotting. Everything
        # that is not < 0.5 counts as voiced, exactly as before.
        output_vuv = np.logical_not(labels_post[:, 1] < 0.5)

        # Load original lf0 and vuv.
        org_labels = self.OutputGen.load_sample(id_name,
                                                self.OutputGen.dir_labels)
        original_lf0, original_vuv = self.OutputGen.convert_to_world_features(
            org_labels)
        # Cut the reference to the length of the prediction.
        num_frames = len(output_lf0)
        original_lf0 = original_lf0[:num_frames]
        reference_vuv = original_vuv[:num_frames]

        # Squared error in Hz, averaged over the frames flagged in the
        # reference V/UV.
        f0_mse = (np.exp(original_lf0) - np.exp(output_lf0))**2
        f0_rmse = math.sqrt((f0_mse * reference_vuv).sum() /
                            reference_vuv.sum())
        self.logger.info("RMSE of {}: {} Hz.".format(id_name, f0_rmse))

        # The hatched "Reference unvoiced" area is taken from the labels of
        # the atom trainer.
        atom_gen = self.flat_trainer.atom_trainer.OutputGen
        org_labels = atom_gen.load_sample(id_name, atom_gen.dir_labels,
                                          len(atom_gen.theta_interval),
                                          atom_gen.dir_world_labels)
        org_vuv = org_labels[:, 0, 0].astype(bool)

        thetas = self.model_handler.model.thetas_approx()

        net_name = os.path.basename(hparams.model_name)
        filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))

        def theta_curves(amps, theta_values):
            # One labelled curve per column, last column first.
            return [(amps[:, idx],
                     r'$\theta$={0:.3f}'.format(theta_values[idx]))
                    for idx in reversed(range(amps.shape[1]))]

        def plain_curves(signals):
            # One unlabelled curve per column, last column first.
            return [(signals[:, idx], )
                    for idx in reversed(range(signals.shape[1]))]

        def add_signal_grid(plotter, grid_idx, curves, ylabel, amp_max,
                            area_label=None):
            # Grid with symmetric y-limits and the predicted unvoiced frames
            # marked as area; the area gets a label only when one is given.
            area = (np.invert(output_vuv), '0.75', 1.0)
            if area_label is not None:
                area += (area_label, )
            plotter.set_data_list(grid_idx=grid_idx, data_list=curves)
            plotter.set_area_list(grid_idx=grid_idx, area_list=[area])
            plotter.set_label(grid_idx=grid_idx, ylabel=ylabel)
            plotter.set_lim(grid_idx=grid_idx, ymin=-amp_max, ymax=amp_max)

        def add_lf0_grid(plotter, grid_idx, ylabel):
            # Original vs. predicted LF0 with the reference unvoiced frames
            # hatched.
            plotter.set_data_list(grid_idx=grid_idx,
                                  data_list=[(original_lf0, "Original"),
                                             (output_lf0, "Predicted")])
            plotter.set_hatchstyles(grid_idx=grid_idx, hatchstyles=['\\\\'])
            plotter.set_area_list(grid_idx=grid_idx,
                                  area_list=[(np.invert(org_vuv), '0.75', 1.0,
                                              'Reference unvoiced')])
            plotter.set_label(grid_idx=grid_idx,
                              xlabel='frames [' + str(hparams.frame_size_ms) +
                              ' ms]',
                              ylabel=ylabel)
            # The curves are plotted without the phrase curve being
            # subtracted, so both figures use the same range. The cluster
            # figure used +-1 before, which presumably dates from the
            # phrase-curve-subtracted variant and does not fit these curves.
            plotter.set_lim(grid_idx=grid_idx, ymin=3, ymax=6)
            plotter.set_linestyles(grid_idx=grid_idx, linestyles=['-.', '-'])
            plotter.set_colors(grid_idx=grid_idx,
                               colors=['C3', 'C2', 'C0'],
                               alpha=1)

        # ---- Figure 1: one curve per theta / filter. ----
        plotter = DataPlotter()
        add_signal_grid(plotter, 0, theta_curves(intern_amps, thetas),
                        ylabel='command', amp_max=0.04)
        add_signal_grid(plotter, 1, plain_curves(filters_out),
                        ylabel='filtered', amp_max=0.1,
                        area_label='Unvoiced')
        add_lf0_grid(plotter, 2, ylabel='LF0')
        plotter.gen_plot()
        plotter.save_to_file(filename + ".PHRASE" + hparams.gen_figure_ext)

        if clustering is None:
            return

        # ---- Figure 2: signals merged per cluster. ----
        # Materialise once because the clusters are traversed three times.
        clusters = list(clustering)

        def cluster(array, mean=False):
            # Select each cluster's entries along the last axis, then either
            # average them to a single value (mean=True) or sum them over
            # that axis.
            groups = (np.take(array, indices, axis=-1) for indices in clusters)
            if mean:
                return np.array([group.mean() for group in groups]).transpose()
            return np.array([group.sum(-1) for group in groups]).transpose()

        clustered_amps = cluster(intern_amps)
        clustered_thetas = cluster(thetas, True)
        clustered_filters = cluster(filters_out)

        plotter = DataPlotter()
        add_signal_grid(plotter, 0,
                        theta_curves(clustered_amps, clustered_thetas),
                        ylabel='cluster command', amp_max=0.04,
                        area_label='Unvoiced')
        add_signal_grid(plotter, 1, plain_curves(clustered_filters),
                        ylabel='filtered', amp_max=0.175)
        add_lf0_grid(plotter, 2, ylabel='lf0')
        plotter.gen_plot()
        plotter.save_to_file(filename + ".CLUSTERS" + hparams.gen_figure_ext)
    def gen_figure_from_output(self, id_name, label, hidden, hparams):
        """Plot the network output of one utterance against its references.

        Five grids are drawn: raw amplitude outputs per theta, the raw
        position output, the post-processed peaks, the target labels, and the
        LF0 curves (reconstructed from the reference atoms, original with the
        phrase curve subtracted, and reconstructed from the predicted atoms).
        The figure is saved as
        ``<out_dir>/<id_name>.<net_name>.VUV_DIST_POS<gen_figure_ext>``.

        :param id_name: Id of the utterance; used to load the reference files
                        and as part of the title and file name.
        :param label:   Raw network output. Columns 1..-2 are plotted as
                        amplitudes, the last column as position.
        :param hidden:  Unused.
        :param hparams: Hyper-parameter container; ``k``, ``min_atom_amp``,
                        ``thetas``, ``frame_size_ms``, ``model_name``,
                        ``out_dir``, ``gen_figure_ext`` and (optionally)
                        ``world_dir`` are read.
        :return:        None
        """
        out_gen = self.OutputGen

        # ---- Network output ----
        nn_amps = label[:, 1:-1]
        nn_pos = label[:, -1]
        post_labels = out_gen.postprocess_sample(label)
        nn_vuv = post_labels[:, 0, 1].astype(bool)
        nn_atoms = out_gen.labels_to_atoms(post_labels,
                                           k=hparams.k,
                                           amp_threshold=hparams.min_atom_amp)
        output_lf0 = out_gen.atoms_to_lf0(nn_atoms, len(label))

        # ---- Original LF0 ----
        # Use hparams.world_dir when it exists and is set, otherwise fall
        # back to the extracted acoustic features below the label directory.
        world_dir = getattr(hparams, "world_dir", None)
        if world_dir is None:
            world_dir = os.path.join(out_gen.dir_labels,
                                     self.dir_extracted_acoustic_features)
        world_labels = LF0LabelGen.load_sample(id_name, world_dir)
        original_lf0, _ = LF0LabelGen.convert_to_world_features(world_labels)
        original_lf0, _ = interpolate_lin(original_lf0)

        # Subtract the stored phrase curve (float32 column vector); the
        # slices on both sides cope with the two arrays differing in length.
        phrase_file = os.path.join(out_gen.dir_labels,
                                   id_name + out_gen.ext_phrase)
        phrase_curve = np.fromfile(phrase_file, dtype=np.float32).reshape(-1, 1)
        original_lf0[:len(phrase_curve)] -= phrase_curve[:len(original_lf0)]
        original_lf0 = original_lf0[:len(output_lf0)]

        # ---- Reference atoms ----
        org_labels = out_gen.load_sample(id_name,
                                         out_gen.dir_labels,
                                         len(hparams.thetas),
                                         out_gen.dir_world_labels)
        org_vuv = org_labels[:, 0, 0].astype(bool)
        org_labels = org_labels[:, 1:]
        half_diff = int((len(org_labels) - len(post_labels)) / 2.0)
        # NOTE(review): both calls go through trim_end_sample with the same
        # call form, passing half_diff and half_diff + 1. Confirm against
        # trim_end_sample that this gives the intended alignment with the
        # network output.
        org_labels = out_gen.trim_end_sample(org_labels, half_diff)
        org_labels = out_gen.trim_end_sample(org_labels, half_diff + 1)
        org_atoms = AtomLabelGen.labels_to_atoms(org_labels,
                                                 k=hparams.k,
                                                 frame_size=hparams.frame_size_ms)
        wcad_lf0 = out_gen.atoms_to_lf0(org_atoms, len(org_labels))

        # ---- Plotting ----
        net_name = os.path.basename(hparams.model_name)
        filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))
        plotter = DataPlotter()
        plotter.set_title(id_name + " - " + net_name)

        # Grid 0: raw amplitude outputs, last theta first.
        grid = 0
        amp_curves = [(nn_amps[:, col],
                       r'$\theta$={0:.3f}'.format(hparams.thetas[col]))
                      for col in reversed(range(nn_amps.shape[1]))]
        plotter.set_data_list(grid_idx=grid, data_list=amp_curves)
        plotter.set_label(grid_idx=grid, ylabel='NN amps')
        upper = np.max(nn_amps) * 1.1
        lower = np.min(nn_amps) * 1.1
        plotter.set_lim(grid_idx=grid, ymin=lower, ymax=upper)

        # Grid 1: raw position output.
        grid += 1
        plotter.set_data_list(grid_idx=grid, data_list=[(nn_pos,)])
        plotter.set_label(grid_idx=grid, ylabel='NN pos')

        # Grid 2: post-processed peaks with predicted unvoiced frames shaded.
        grid += 1
        peak_curves = [(post_labels[:, 1 + col, 0],)
                       for col in reversed(range(label.shape[1] - 2))]
        plotter.set_data_list(grid_idx=grid, data_list=peak_curves)
        plotter.set_area_list(grid_idx=grid,
                              area_list=[(np.invert(nn_vuv), '0.75', 1.0,
                                          'Unvoiced')])
        plotter.set_label(grid_idx=grid, ylabel='NN peaks')
        plotter.set_lim(grid_idx=grid, ymin=-1.8, ymax=1.8)

        # Grid 3: target labels with reference unvoiced frames hatched.
        grid += 1
        target_curves = [(org_labels[:, col, 0],)
                         for col in reversed(range(org_labels.shape[1]))]
        plotter.set_data_list(grid_idx=grid, data_list=target_curves)
        plotter.set_hatchstyles(grid_idx=grid, hatchstyles=['\\\\'])
        plotter.set_area_list(grid_idx=grid,
                              area_list=[(np.invert(org_vuv.astype(bool)),
                                          '0.75', 1.0, 'Reference unvoiced')])
        plotter.set_label(grid_idx=grid, ylabel='target')
        plotter.set_lim(grid_idx=grid, ymin=-1.8, ymax=1.8)

        # Grid 4: the three LF0 curves on a symmetric axis.
        grid += 1
        lf0_curves = [(wcad_lf0, "wcad lf0"),
                      (original_lf0, "org lf0"),
                      (output_lf0, "predicted lf0")]
        plotter.set_data_list(grid_idx=grid, data_list=lf0_curves)
        plotter.set_area_list(grid_idx=grid,
                              area_list=[(np.invert(org_vuv.astype(bool)),
                                          '0.75', 1.0)])
        plotter.set_hatchstyles(grid_idx=grid, hatchstyles=['\\\\'])
        plotter.set_label(grid_idx=grid,
                          xlabel='frames [' + str(hparams.frame_size_ms) + ' ms]',
                          ylabel='lf0')
        # The range is derived from the two reconstructed curves only.
        largest = max(np.max(np.abs(wcad_lf0)), np.max(np.abs(output_lf0)))
        amp_lim = largest * 1.1
        plotter.set_lim(grid_idx=grid, ymin=-amp_lim, ymax=amp_lim)
        plotter.set_linestyles(grid_idx=grid, linestyles=[':', '--', '-'])

        # NOTE: adding the F0 RMSE and the V/UV error rate to the title is
        # currently disabled.

        # NOTE(review): the plot is generated twice, first in monochrome and
        # then with default settings, before it is saved - confirm whether
        # the first call is still needed.
        plotter.gen_plot(monochrome=True)
        plotter.gen_plot()
        plotter.save_to_file(filename + ".VUV_DIST_POS" + hparams.gen_figure_ext)