def gen_figure_from_output(self, id_name, labels, hidden, hparams):
    """Plot the reference raw waveform together with the post-processed output.

    The figure has a single panel with two curves ('Org' and 'Wavenet') and
    is saved as ``<hparams.out_dir>/<id>.<net>.Raw<hparams.gen_figure_ext>``,
    where ``<id>`` is the base name of ``id_name`` without its last extension
    and ``<net>`` is the base name of ``hparams.model_name``.

    Args:
        id_name: Identifier (possibly a path) of the sample to plot.
        labels: Network output, handed to
            ``self.dataset_train.postprocess_sample``.
        hidden: Not used by this method.
        hparams: Hyper-parameter container; ``model_name``, ``out_dir``,
            ``frame_rate_output_Hz`` and ``gen_figure_ext`` are read.
    """
    # Labels come in as T x C.
    predicted = self.dataset_train.postprocess_sample(labels)
    # The reference is loaded with the unmodified id_name, before it is
    # reduced to its bare base name below.
    reference = RawWaveformLabelGen.load_sample(
        id_name, self.OutputGen.frame_rate_output_Hz)

    figure = DataPlotter()

    net_name = os.path.basename(hparams.model_name)
    id_name = os.path.basename(id_name).rsplit('.', 1)[0]
    out_path = os.path.join(hparams.out_dir, ".".join((id_name, net_name)))
    figure.set_title(" - ".join((id_name, net_name)))

    panel = 0
    curves = [(reference, 'Org'), (predicted, 'Wavenet')]
    figure.set_data_list(grid_idx=panel, data_list=curves)
    figure.set_linewidth(grid_idx=panel, linewidth=[0.1])
    figure.set_colors(grid_idx=panel, alpha=0.8)
    figure.set_lim(panel, ymin=-1, ymax=1)
    x_axis_text = 'frames [{} Hz]'.format(str(hparams.frame_rate_output_Hz))
    figure.set_label(grid_idx=panel, xlabel=x_axis_text, ylabel='raw')

    figure.gen_plot()
    figure.save_to_file(out_path + '.Raw' + hparams.gen_figure_ext)
def gen_figure_from_output(self, id_name, labels, hidden, hparams, clustering=None, filters_out=None):
    """Plot internal commands, filter outputs and LF0 for one utterance.

    A three-panel figure ('command', 'filtered', 'LF0') is always written to
    ``<hparams.out_dir>/<id_name>.<net>.PHRASE<hparams.gen_figure_ext>``.
    When ``clustering`` is given, a second figure with the per-cluster sums
    of commands and filter outputs is written with the ``.CLUSTERS`` suffix.
    The F0 RMSE (computed on ``exp`` of the LF0 curves, over frames the
    reference marks as voiced) is logged through ``self.logger``.

    Args:
        id_name: Identifier of the utterance.
        labels: 2-D network output. Columns 0 and 1 are post-processed and
            read as LF0 and V/UV, all further columns are plotted as
            commands. If ``None`` (or if ``filters_out`` is ``None``) both
            are recomputed from ``self.InputGen[id_name]``.
        hidden: Not used by this method.
        hparams: Hyper-parameter container; ``model_name``, ``out_dir``,
            ``frame_size_ms`` and ``gen_figure_ext`` are read.
        clustering: Optional iterable of index collections; each collection
            selects (via ``np.take`` on the last axis) the entries that form
            one cluster.
        filters_out: 2-D filter outputs, one column per plotted curve.

    Returns:
        None.
    """
    if labels is None or filters_out is None:
        input_labels = self.InputGen[id_name][:, None, ...]
        labels = self.model_handler.forward(input_labels, hparams)[0][:, 0]
        filters_out = self.filters_forward(input_labels, hparams)[:, 0, ...]

    intern_amps = labels[:, 2:]
    labels = labels[:, :2]

    # Retrieve data from label.
    labels_post = self.OutputGen.postprocess_sample(labels)
    # Threshold with a comparison instead of writing 0/1 back into the array:
    # labels_post may share memory with the caller's labels, which must not
    # be modified just to draw a figure.
    output_vuv = labels_post[:, 1] >= 0.5
    output_lf0 = labels_post[:, 0]

    # Load original lf0 and vuv.
    output_gen = self.OutputGen
    org_labels = output_gen.load_sample(id_name, output_gen.dir_labels)
    original_lf0, original_vuv = output_gen.convert_to_world_features(org_labels)
    original_lf0 = original_lf0[:len(output_lf0)]

    # Compare only the frames both curves have, so a prediction that is
    # longer than the reference does not break the subtraction.
    num_frames = min(len(original_lf0), len(output_lf0))
    voiced = original_vuv[:num_frames]
    f0_mse = (np.exp(original_lf0[:num_frames]) - np.exp(output_lf0[:num_frames]))**2
    num_voiced = voiced.sum()
    if num_voiced > 0:
        f0_rmse = math.sqrt((f0_mse * voiced).sum() / num_voiced)
    else:
        # No voiced reference frame: the error is undefined.
        f0_rmse = float('nan')
    self.logger.info("RMSE of {}: {} Hz.".format(id_name, f0_rmse))

    atom_output_gen = self.flat_trainer.atom_trainer.OutputGen
    org_labels = atom_output_gen.load_sample(
        id_name,
        atom_output_gen.dir_labels,
        len(atom_output_gen.theta_interval),
        atom_output_gen.dir_world_labels)
    org_vuv = org_labels[:, 0, 0].astype(bool)

    thetas = self.model_handler.model.thetas_approx()

    net_name = os.path.basename(hparams.model_name)
    filename = str(os.path.join(hparams.out_dir, id_name + '.' + net_name))

    def add_curve_panel(plotter, grid_idx, graphs, area, ylabel, amp_max):
        # One panel of curves on a symmetric y-axis with a single shaded area.
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs)
        plotter.set_area_list(grid_idx=grid_idx, area_list=[area])
        plotter.set_label(grid_idx=grid_idx, ylabel=ylabel)
        plotter.set_lim(grid_idx=grid_idx, ymin=-amp_max, ymax=amp_max)

    def add_lf0_panel(plotter, grid_idx, ylabel, ymin, ymax):
        # Reference vs. predicted LF0 with the reference unvoiced regions shaded.
        plotter.set_data_list(grid_idx=grid_idx,
                              data_list=[(original_lf0, "Original"),
                                         (output_lf0, "Predicted")])
        plotter.set_hatchstyles(grid_idx=grid_idx, hatchstyles=['\\\\'])
        plotter.set_area_list(grid_idx=grid_idx,
                              area_list=[(np.invert(org_vuv), '0.75', 1.0, 'Reference unvoiced')])
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [' + str(hparams.frame_size_ms) + ' ms]',
                          ylabel=ylabel)
        plotter.set_lim(grid_idx=grid_idx, ymin=ymin, ymax=ymax)
        plotter.set_linestyles(grid_idx=grid_idx, linestyles=['-.', '-'])
        plotter.set_colors(grid_idx=grid_idx, colors=['C3', 'C2', 'C0'], alpha=1)

    # ---- Figure 1: every command / filter output individually. ----
    plotter = DataPlotter()

    graphs_intern = [
        (intern_amps[:, idx], r'$\theta$={0:.3f}'.format(thetas[idx]))
        for idx in reversed(range(intern_amps.shape[1]))
    ]
    add_curve_panel(plotter, 0, graphs_intern,
                    (np.invert(output_vuv), '0.75', 1.0),
                    'command', 0.04)

    graphs_filters = [
        (filters_out[:, idx], )
        for idx in reversed(range(filters_out.shape[1]))
    ]
    add_curve_panel(plotter, 1, graphs_filters,
                    (np.invert(output_vuv), '0.75', 1.0, 'Unvoiced'),
                    'filtered', 0.1)

    add_lf0_panel(plotter, 2, 'LF0', 3, 6)

    plotter.gen_plot()
    plotter.save_to_file(filename + ".PHRASE" + hparams.gen_figure_ext)

    if clustering is None:
        return

    # ---- Figure 2: commands / filter outputs summed per cluster. ----
    plotter = DataPlotter()

    # Materialise once: the groups are walked three times below, which would
    # silently yield nothing after the first pass for a one-shot iterable.
    groups = list(clustering)

    def cluster(array, mean=False):
        # Reduce the last axis of array per group: mean of the selected
        # entries if mean is set, otherwise their sum along the last axis.
        if mean:
            return np.array([
                np.take(array, group, axis=-1).mean() for group in groups
            ]).transpose()
        return np.array([
            np.take(array, group, axis=-1).sum(-1) for group in groups
        ]).transpose()

    clustered_amps = cluster(intern_amps)
    clustered_thetas = cluster(thetas, True)
    clustered_filters = cluster(filters_out)

    graphs_intern = [
        (clustered_amps[:, idx], r'$\theta$={0:.3f}'.format(clustered_thetas[idx]))
        for idx in reversed(range(clustered_amps.shape[1]))
    ]
    add_curve_panel(plotter, 0, graphs_intern,
                    (np.invert(output_vuv), '0.75', 1.0, 'Unvoiced'),
                    'cluster command', 0.04)

    graphs_filters = [
        (clustered_filters[:, idx], )
        for idx in reversed(range(clustered_filters.shape[1]))
    ]
    add_curve_panel(plotter, 1, graphs_filters,
                    (np.invert(output_vuv), '0.75', 1.0),
                    'filtered', 0.175)

    # NOTE(review): this panel draws the same two LF0 curves as the first
    # figure, which uses y-limits 3..6; the +/-1 limits here may be a
    # leftover from a variant that subtracted a phrase curve from the
    # reference - confirm before changing. Kept as in the original.
    amp_lim = 1
    add_lf0_panel(plotter, 2, 'lf0', -amp_lim, amp_lim)

    plotter.gen_plot()
    plotter.save_to_file(filename + ".CLUSTERS" + hparams.gen_figure_ext)