def run():
    """Main function which wraps the prosody labeller tool.

    Reads its options from the module-level ``args`` object
    (``label``, ``input_file``, ``level``, ``nb_scales``, ``scale_dist``,
    ``scale_factor`` and ``plot``), then:

    1. loads the HTK label file and keeps the tier selected by ``args.level``;
    2. extracts the acoustic parameters of the input file;
    3. runs the wavelet transform on the combined parameters;
    4. labels prominences and boundaries and prints one tab-separated
       ``label / prominence / boundary`` row per label;
    5. optionally plots the analysis.

    Exits the process with status ``-1`` when the label file does not exist.
    """
    global args

    # Extract labels: an explicit label file wins, otherwise use the input
    # file name with its extension replaced by ".lab"
    if args.label:
        lab_f = args.label
    else:
        lab_f = os.path.splitext(args.input_file)[0] + ".lab"

    if not os.path.exists(lab_f):
        logging.error("Label file \"%s\" doesn't exist" % lab_f)
        sys.exit(-1)

    labels = lab.read_htk_label(lab_f)
    labels = labels[args.level]  # Filter by level

    # Extract parameters
    (params, pitch, energy_smooth, rate) = extract_params(args.input_file, labels)

    # perform wavelet transform
    # NOTE(review): other callers of cwt_utils.cwt_analysis unpack three
    # values (cwt, scales, freqs) - confirm that this two-value unpacking
    # matches the cwt_utils version this tool is used with.
    (cwt, scales) = cwt_utils.cwt_analysis(params,
                                           mother_name="mexican_hat",
                                           period=2,
                                           num_scales=args.nb_scales,
                                           scale_distance=args.scale_dist,
                                           apply_coi=True)
    scales *= args.scale_factor

    # Labelling prominences and boundaries
    (prominences, boundaries, pos_loma, neg_loma) = label_prosody(scales, cwt, labels)

    print("========================================================")
    print("label\tprominence\tboundary")
    for i, label in enumerate(labels):
        print("%s\t%f\t%f" % (label[-1], prominences[i][-1], boundaries[i][-1]))

    if args.plot:
        # Plotting can't deal with complex, but we don't care.
        # NumPy 2.0 removed ``np.ComplexWarning`` from the main namespace;
        # the class now lives in ``np.exceptions`` (NumPy >= 1.25), so fall
        # back to that location when the old attribute is missing.
        complex_warning = getattr(np, "ComplexWarning", None)
        if complex_warning is None:
            complex_warning = np.exceptions.ComplexWarning
        warnings.simplefilter("ignore", complex_warning)

        plot(labels, rate, energy_smooth, pitch, params, cwt,
             boundaries, prominences, pos_loma, neg_loma)
def analysis(input_file, cfg, logger, annotation_dir=None, output_dir=None, plot=False):
    """Run the prosody analysis of one wave file and save the result.

    The function extracts energy, f0 and a duration/rate signal, combines
    them into one parameter signal, runs a continuous wavelet transform on
    it, derives prominences and boundaries for the configured annotation
    tier and writes them to ``<output_dir>/<basename>.prom``.

    :param input_file: path of the wave file to analyse
    :param cfg: nested configuration mapping; the sections ``energy``,
        ``f0``, ``duration``, ``feature_combination``, ``wavelet``,
        ``labels`` and ``loma`` are read
    :param logger: logger, used here only to report the saved plot file
    :param annotation_dir: directory searched for ``<basename>.TextGrid``
        and then ``<basename>.lab``; defaults to the directory of
        ``input_file``
    :param output_dir: directory receiving the ``.prom`` (and ``.png``)
        file; defaults to the directory of ``input_file`` and is created
        when missing
    :param plot: ``0``/``False`` disables plotting, a negative value saves
        the figure as ``<basename>.png``, a positive value shows it
    :raises Exception: when neither a TextGrid nor a lab file is found
    :raises AssertionError: when a configured tier is not in the annotations
    """
    # Load the wave file
    print("Analyzing %s starting..." % input_file)
    orig_sr, sig = misc.read_wav(input_file)

    # extract energy
    energy = energy_processing.extract_energy(sig, orig_sr,
                                              cfg["energy"]["band_min"],
                                              cfg["energy"]["band_max"],
                                              cfg["energy"]["calculation_method"])
    energy = np.cbrt(energy+1)
    if cfg["energy"]["smooth_energy"]:
        energy = smooth_and_interp.peak_smooth(energy, 30, 3)  # FIXME: 30? 3?
        energy = smooth_and_interp.smooth(energy, 10)

    # extract f0
    raw_pitch = f0_processing.extract_f0(sig, orig_sr,
                                         f0_min=cfg["f0"]["min_f0"],
                                         f0_max=cfg["f0"]["max_f0"],
                                         voicing=cfg["f0"]["voicing_threshold"],
                                         #harmonics=cfg["f0"]["harmonics"],
                                         configuration=cfg["f0"]["pitch_tracker"])
    # interpolate, stylize
    pitch = f0_processing.process(raw_pitch)

    # extract speech rate (placeholder, replaced below)
    rate = np.zeros(len(pitch))

    # Get annotations (if available): TextGrid first, HTK lab as fallback
    tiers = []
    if annotation_dir is None:
        annotation_dir = os.path.dirname(input_file)
    basename = os.path.splitext(os.path.basename(input_file))[0]
    grid = os.path.join(annotation_dir, "%s.TextGrid" % basename)
    if os.path.exists(grid):
        tiers = lab.read_textgrid(grid)
    else:
        grid = os.path.join(annotation_dir, "%s.lab" % basename)
        if not os.path.exists(grid):
            raise Exception("There is no annotations associated with %s" % input_file)
        tiers = lab.read_htk_label(grid)

    # Extract duration
    # dur_tiers is initialised up front so it is always defined, even when
    # the annotation reader returned no tier at all.
    dur_tiers = []
    if len(tiers) > 0:
        for level in cfg["duration"]["duration_tiers"]:
            assert(level.lower() in tiers), level+" not defined in tiers: check that duration_tiers in config match the actual textgrid tiers"
            try:
                dur_tiers.append(tiers[level.lower()])
            except (LookupError, TypeError):
                # Only reachable when assertions are disabled (python -O)
                print("\nerror: "+"\""+level+"\"" +" not in labels, modify duration_tiers in config\n\n")
                raise

    if not cfg["duration"]["acoustic_estimation"]:
        rate = duration_processing.get_duration_signal(dur_tiers,
                                                       weights=cfg["duration"]["weights"],
                                                       linear=cfg["duration"]["linear"],
                                                       sil_symbols=cfg["duration"]["silence_symbols"],
                                                       bump=cfg["duration"]["bump"])
    else:
        rate = duration_processing.get_rate(energy)
        rate = smooth_and_interp.smooth(rate, 30)

    if cfg["duration"]["delta_duration"]:
        rate = np.diff(rate)

    # Combine signals: truncate all three to the shortest one first
    min_length = np.min([len(pitch), len(energy), len(rate)])
    pitch = pitch[:min_length]
    energy = energy[:min_length]
    rate = rate[:min_length]

    weights = cfg["feature_combination"]["weights"]
    if cfg["feature_combination"]["type"] == "product":
        pitch = misc.normalize_minmax(pitch) ** weights["f0"]
        energy = misc.normalize_minmax(energy) ** weights["energy"]
        rate = misc.normalize_minmax(rate) ** weights["duration"]
        params = pitch * energy * rate
    else:
        params = misc.normalize_std(pitch) * weights["f0"] + \
                 misc.normalize_std(energy) * weights["energy"] + \
                 misc.normalize_std(rate) * weights["duration"]

    if cfg["feature_combination"]["detrend"]:
        params = smooth_and_interp.remove_bias(params, 800)

    params = misc.normalize_std(params)

    # CWT analysis
    (cwt, scales, freqs) = cwt_utils.cwt_analysis(params,
                                                  mother_name=cfg["wavelet"]["mother_wavelet"],
                                                  period=cfg["wavelet"]["period"],
                                                  num_scales=cfg["wavelet"]["num_scales"],
                                                  scale_distance=cfg["wavelet"]["scale_distance"],
                                                  apply_coi=False)
    cwt = np.real(cwt)
    # NOTE: this in-place scaling is superseded below, where ``scales`` is
    # recomputed from ``freqs``.
    scales *= 200  # FIXME: why 200?

    # Compute lines of maximum amplitude
    assert(cfg["labels"]["annotation_tier"].lower() in tiers), \
        cfg["labels"]["annotation_tier"]+" not defined in tiers: check that annotation_tier in config is found in the textgrid tiers"
    labels = tiers[cfg["labels"]["annotation_tier"].lower()]

    # get scale corresponding to avg unit length of selected tier
    scale_dist = cfg["wavelet"]["scale_distance"]
    scales = (1./freqs*200)*0.5  # FIXME: hardcoded vales
    unit_scale = misc.get_best_scale2(scales, labels)

    # Define the scale information (FIXME: description)
    pos_loma_start_scale = unit_scale + int(cfg["loma"]["prom_start"]/scale_dist)  # three octaves down from average unit length
    pos_loma_end_scale = unit_scale + int(cfg["loma"]["prom_end"]/scale_dist)
    neg_loma_start_scale = unit_scale + int(cfg["loma"]["boundary_start"]/scale_dist)  # two octaves down
    neg_loma_end_scale = unit_scale + int(cfg["loma"]["boundary_end"]/scale_dist)  # one octave up

    pos_loma = loma.get_loma(cwt, scales, pos_loma_start_scale, pos_loma_end_scale)
    neg_loma = loma.get_loma(-cwt, scales, neg_loma_start_scale, neg_loma_end_scale)

    max_loma = loma.get_prominences(pos_loma, labels)
    prominences = np.array(max_loma)
    boundaries = np.array(loma.get_boundaries(max_loma, neg_loma, labels))

    # output results
    if output_dir is None:
        output_dir = os.path.dirname(input_file)
    # os.path.dirname() returns "" for a bare file name and os.makedirs("")
    # raises FileNotFoundError; "" already means the current directory, so
    # there is nothing to create in that case.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    output_filename = os.path.join(output_dir, "%s.prom" % basename)
    print("Saving %s..." % (output_filename))
    loma.save_analyses(output_filename, labels, prominences, boundaries)

    # Plotting
    if plot != 0:
        fig, ax = plt.subplots(6, 1, sharex=True,
                               figsize=(len(labels) / 10 * 8, 8),
                               gridspec_kw={'height_ratios': [1, 1, 1, 2, 4, 1.5]})
        plt.subplots_adjust(hspace=0)

        # Plot individual signals
        ax[0].plot(pitch, linewidth=1)
        ax[0].set_ylabel("Pitch", rotation="horizontal", ha="right", va="center")

        ax[1].plot(energy, linewidth=1)
        ax[1].set_ylabel("Energy", rotation="horizontal", ha="right", va="center")

        ax[2].plot(rate, linewidth=1)
        ax[2].set_ylabel("Speech rate", rotation="horizontal", ha="right", va="center")

        # Plot combined signal
        ax[3].plot(params, linewidth=1)
        ax[3].set_ylabel("Combined \n signal", rotation="horizontal", ha="right", va="center")
        plt.xlim(0, len(params))

        # Wavelet and loma (cwt is compressed in place for display only;
        # the prominences and boundaries were computed above)
        cwt[cwt>0] = np.log(cwt[cwt>0]+1.)
        cwt[cwt<-0.1] = -0.1
        ax[4].contourf(cwt, 100, cmap="inferno")
        loma.plot_loma(pos_loma, ax[4], color="black")
        loma.plot_loma(neg_loma, ax[4], color="white")
        ax[4].set_ylabel("Wavelet & \n LOMA", rotation="horizontal", ha="right", va="center")

        # Add labels
        prom_text = prominences[:, 1]/(np.max(prominences[:, 1]))*2.5 + 0.5
        lab.plot_labels(labels, ypos=0.3, size=6, prominences=prom_text, fig=ax[5], boundary=False, background=False)
        ax[5].set_ylabel("Labels", rotation="horizontal", ha="right", va="center")

        # Mark label start/end on every panel; the end line gets thicker
        # with the boundary strength
        for i, label in enumerate(labels):
            for axis in ax:
                axis.axvline(x=label[0], color='black',
                             linestyle="-", linewidth=0.2, alpha=0.5)
                axis.axvline(x=label[1], color='black',
                             linestyle="-", linewidth=0.2+boundaries[i][-1] * 2,
                             alpha=0.5)

        plt.xlim(0, cwt.shape[1])

        # Align ylabels and remove axis
        fig.align_ylabels(ax)
        # x ticks are hidden on every panel but the last one
        for axis in ax[:-1]:
            axis.tick_params(axis='x', which='both',
                             bottom=False, top=False, labelbottom=False)
        # y ticks are hidden on every panel
        for axis in ax:
            axis.tick_params(axis='y', which='both',
                             left=False, right=False, labelleft=False)

        # Plot
        if plot < 0:
            output_filename = os.path.join(output_dir, "%s.png" % basename)
            logger.info("Save plot %s" % output_filename)
            fig.savefig(output_filename, bbox_inches='tight', dpi=400)
        elif plot > 0:
            plt.show()
def run():
    """Main entry function

    This function contains the code needed to achieve the analysis and/or
    the synthesis.

    Options are read from the module-level ``args`` object.  The
    configuration is built by layering ``configs/default.yaml``,
    ``configs/synthesis.yaml`` (both located next to this file) and, when
    given, the user file ``args.configuration_file``.

    * ``args.mode == 0`` (analysis): load and process f0, run the wavelet
      transform, combine scales and write the matrix to ``args.output_file``.
    * ``args.mode == 1`` (synthesis): read a scale matrix from
      ``args.input_file``, reconstruct f0 and write it to
      ``args.output_file``.

    Exits the process with status ``1`` when the user configuration file
    cannot be read.
    """
    global args

    # Suppress FutureWarning messages for the whole run
    warnings.simplefilter("ignore", FutureWarning)

    def _merge_yaml(current, path):
        """Merge the YAML file at *path* into *current* and return the result."""
        with open(path, 'r') as f:
            # yaml.load() without an explicit Loader raises TypeError on
            # PyYAML >= 6; safe_load is used because these files are plain
            # configuration data.  An empty file loads as None, hence "or {}".
            # NOTE(review): switch to an explicit full Loader if a config
            # file ever relies on Python-specific YAML tags.
            loaded = yaml.safe_load(f) or {}
        return apply_configuration(current, defaultdict(lambda: False, loaded))

    config_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs")

    # Loading default configuration
    configuration = defaultdict()
    configuration = _merge_yaml(configuration, os.path.join(config_dir, "default.yaml"))
    logging.debug("default configuration")
    logging.debug(configuration)

    # Loading dedicated analysis.synthesis configuration
    configuration = _merge_yaml(configuration, os.path.join(config_dir, "synthesis.yaml"))
    logging.debug("configuration filled with synthesis part")
    logging.debug(configuration)

    # Loading user configuration
    if args.configuration_file:
        try:
            configuration = _merge_yaml(configuration, args.configuration_file)
            logging.debug("configuration filled with user part")
            logging.debug(configuration)
        except IOError as ex:
            # The option is read as args.configuration_file everywhere else,
            # and IOError carries no ``msg`` attribute, so log str(ex).
            logging.error("configuration file " + args.configuration_file + " could not be loaded:")
            logging.error(str(ex))
            sys.exit(1)

    # Analysis Mode
    if args.mode == 0:
        raw_f0 = load_f0(args.input_file, args.binary_mode, configuration)

        logging.info("Processing f0")
        f0 = f0_processing.process(raw_f0)
        # FIXME: reintegrated
        if args.plot:
            plt.title("F0 preprocessing and interpolation")
            plt.plot(f0, color="red", alpha=0.5, linewidth=3)
            plt.plot(raw_f0, color="gray", alpha=0.5)
            plt.show()

        # # FIXME: read this?
        # logging.info("writing interpolated lf0\t" + output_file + ".interp")
        # np.savetxt(output_file + ".interp", f0.astype('float'),
        #            fmt="%f", delimiter="\n")

        # Perform continuous wavelet transform of mean-substracted f0 with 12 scales, one octave apart
        logging.info(
            "Starting analysis with (num_scale=%d, scale_distance=%f, mother_name=%s)"
            % (configuration["wavelet"]["num_scales"],
               configuration["wavelet"]["scale_distance"],
               configuration["wavelet"]["mother_wavelet"]))
        full_scales, widths, _ = cwt_utils.cwt_analysis(
            f0 - np.mean(f0),
            mother_name=configuration["wavelet"]["mother_wavelet"],
            period=configuration["wavelet"]["period"],
            num_scales=configuration["wavelet"]["num_scales"],
            scale_distance=configuration["wavelet"]["scale_distance"],
            apply_coi=False)
        full_scales = np.real(full_scales)

        # SSW parameterization, adjacent scales combined (with extra scales to handle long utterances)
        scales = cwt_utils.combine_scales(
            np.real(full_scales), configuration["wavelet"]["combined_scales"])
        for i in range(0, len(scales)):
            logging.debug("Mean scale[%d]: %s" % (i, str(np.mean(scales[i]))))

        # Saving matrix
        logging.info("writing wavelet matrix in \"%s\"" % args.output_file)
        if args.binary_mode:
            with open(args.output_file, "wb") as f_out:
                scales.T.astype(np.float32).tofile(f_out)
        else:
            np.savetxt(args.output_file, scales.T.astype('float'),
                       fmt="%f", delimiter=",")

    # Synthesis mode
    if args.mode == 1:
        if args.binary_mode:
            # The flat float32 file is reshaped to one column per combined scale
            scales = np.fromfile(args.input_file, dtype=np.float32)
            scales = scales.reshape(
                -1, len(configuration["wavelet"]["combined_scales"])).T
        else:
            scales = np.loadtxt(args.input_file, delimiter=",").T  # FIXME: hardcoded

        rec = cwt_utils.cwt_synthesis(scales, args.mean_f0)

        logging.info("Save reconstructed f0 in %s" % args.output_file)
        if args.binary_mode:
            with open(args.output_file, "wb") as f_out:
                rec.astype(np.float32).tofile(f_out)
        else:
            np.savetxt(args.output_file, rec, fmt="%f")

    # Debugging /plotting part
    if args.plot:
        # Analysis mode gets a third panel for the full scalogram
        nb_sub = 2
        if args.mode == 0:
            nb_sub = 3
        ax = plt.subplot(nb_sub, 1, 1)
        # pylab.title("CWT decomposition to % scales and reconstructed signal" % len(configuration["wavelet"]["combined_scales"]))

        if args.mode == 0:
            plt.plot(f0, linewidth=1, color="red")
            # In analysis mode the reconstruction is computed here, for display only
            rec = cwt_utils.cwt_synthesis(scales, np.mean(f0))

        plt.plot(rec, color="blue", alpha=0.3)

        plt.subplot(nb_sub, 1, 2, sharex=ax)
        for i in range(0, len(scales)):
            # Each scale is shifted vertically so the curves do not overlap
            plt.plot(scales[i] + max(rec) * 1.5 + i * 75,
                     color="blue", alpha=0.5)
            #plt.plot(scales[len(scales)-i-1] + max(rec)*1.5 + i*75,

        if args.mode == 0:
            plt.subplot(nb_sub, 1, 3, sharex=ax)
            plt.contourf(np.real(full_scales), 100,
                         norm=colors.SymLogNorm(linthresh=0.2,
                                                linscale=0.05,
                                                vmin=np.min(full_scales),
                                                vmax=np.max(full_scales)),
                         cmap="jet")
        plt.show()
def analysis(self):
    """Recompute and redraw the analysis stages flagged in ``self.fUpdate``.

    Each stage (spectrogram, energy, f0, duration, parameter combination,
    wavelet transform, lines of maximum amplitude, label plot) only runs
    when its ``self.fUpdate`` flag is set.  Results are stored on ``self``
    (``energy``, ``pitch``, ``rate``, ``params``, ``cwt``, ``scales``,
    ``freqs``, ``prominences``, ``boundaries``) and drawn on ``self.ax``.
    When labels are available the prominences and boundaries are saved to a
    ``.prom`` file next to the current wave file.  All update flags are
    reset to ``False`` at the end.

    Returns early, doing nothing, when no wave file is selected.
    """
    # Remember the current zoom of the wavelet panel unless a new wave file
    # is being loaded, so that it can be restored after redrawing.
    prev_zoom = None
    if not self.fUpdate["wav"]:
        prev_zoom = self.ax[3].axis()

    if not self.cur_wav:
        return

    self.refresh_updates()

    # show spectrogram
    if self.fUpdate['wav']:
        self.toolbar.update()
        self.logger.debug("plot specgram")

        self.ax[0].cla()
        self.orig_sr, self.sig = misc.read_wav(self.cur_wav)
        # length of the signal expressed at the PLOT_SR rate
        self.plot_len = int(len(self.sig) * (PLOT_SR / self.orig_sr))
        self.ax[0].specgram(self.sig, NFFT=200, noverlap=40,
                            Fs=self.orig_sr,
                            xextent=[0, self.plot_len], cmap="jet")

    if self.fUpdate['energy']:
        # 'energy' is just a smoothed envelope here
        self.logger.debug("analyzing energy..")
        self.energy = energy_processing.extract_energy(
            self.sig, self.orig_sr,
            self.configuration["energy"]["band_min"],
            self.configuration["energy"]["band_max"],
            self.configuration["energy"]["calculation_method"])

        if self.configuration["energy"]["smooth_energy"]:
            self.energy_smooth = smooth_and_interp.peak_smooth(
                self.energy, 30, 3)  # FIXME: 30? 3?
        else:
            self.energy_smooth = self.energy

    raw_pitch = None

    if self.fUpdate['f0']:
        self.ax[1].cla()
        self.pitch = None
        raw_pitch = None

        # if f0 file is provided, use that
        if self.bUseExistingF0.isChecked():
            raw_pitch = f0_processing.read_f0(self.cur_wav)

        # else use reaper
        if raw_pitch is None:
            # analyze pitch
            self.logger.debug("analyzing pitch..")
            min_f0 = float(str(self.min_f0.text()))
            max_f0 = float(str(self.max_f0.text()))
            # keep max_f0 >= 10 and min_f0 strictly below max_f0
            max_f0 = np.max([max_f0, 10.])
            min_f0 = np.min([max_f0 - 1., min_f0])

            raw_pitch = f0_processing.extract_f0(
                self.sig, self.orig_sr, min_f0, max_f0,
                float(self.harmonics.value()),
                float(self.voicing.value()),
                self.configuration["f0"]["pitch_tracker"])

        # FIXME: fix errors, smooth and interpolate
        try:
            self.pitch = f0_processing.process(raw_pitch)
        except Exception as ex:
            exception_log(
                self.logger, "no idea!!!", ex,
                logging.DEBUG)  # FIXME: more human friendly message
            # f0_processing.process crashes if raw_pitch is all zeros, kludge
            self.pitch = raw_pitch

        self.ax[1].plot(raw_pitch, color='black', linewidth=1)
        self.ax[1].plot(self.pitch, color='black', linewidth=2)
        self.ax[1].set_ylim(
            np.min(self.pitch) * 0.75, np.max(self.pitch) * 1.2)

    if self.fUpdate['duration']:
        # default: flat (all-zero) rate signal as long as the pitch
        self.rate = np.zeros(len(self.pitch))
        self.logger.debug("analyzing duration...")

        # signal method for speech rate, quite shaky
        if self.signalRate.isChecked():
            self.rate = duration_processing.get_rate(self.energy)
            self.rate = smooth_and_interp.smooth(self.rate, 30)

        # word / syllable / segment duration from labels
        else:
            sig_tiers = []
            for item in self.signalTiers.selectedItems():
                sig_tiers.append(self.tiers[item.text()])

            try:
                # Only if some tiers are selected
                if (len(sig_tiers)) > 0:
                    self.rate = duration_processing.get_duration_signal(
                        sig_tiers,
                        sil_symbols=self.configuration["duration"]
                        ["silence_symbols"])
            except Exception as ex:
                # on failure the zero rate signal set above is kept
                exception_log(self.logger,
                              "Duration signal construction failed", ex,
                              logging.ERROR)

        if self.diffDur.isChecked():
            self.rate = np.diff(self.rate, 1)

        # Make the rate as long as the pitch: pad with the edge value when
        # shorter; when padding fails (presumably because the rate is longer
        # and the pad width is negative - TODO confirm) truncate instead.
        try:
            self.rate = np.pad(self.rate,
                               (0, len(self.pitch) - len(self.rate)),
                               'edge')
        except Exception:
            self.rate = self.rate[0:len(self.pitch)]

    # combine acoustic features by normalizing, fixing lengths and summing (or multiplying)
    if self.fUpdate['params']:
        self.ax[2].cla()
        self.ax[3].cla()

        # individual signals are drawn with vertical offsets (12, 8, 4)
        self.ax[2].plot(misc.normalize_std(self.pitch) + 12, label="F0")
        self.ax[2].plot(misc.normalize_std(self.energy_smooth) + 8,
                        label="Energy")
        self.ax[2].plot(misc.normalize_std(self.rate) + 4,
                        label="Duration")

        # truncate the signals to a common length
        self.energy_smooth = self.energy_smooth[:np.min(
            [len(self.pitch), len(self.energy_smooth)])]
        self.pitch = self.pitch[:np.min(
            [len(self.pitch), len(self.energy_smooth)])]
        self.rate = self.rate[:np.min([len(self.pitch), len(self.rate)])]

        if self.mul_feats.isChecked():
            # product combination: a feature with a non-positive weight or
            # zero variance contributes a neutral factor of one
            pitch = np.ones(len(self.pitch))
            energy = np.ones(len(self.pitch))
            duration = np.ones(len(self.pitch))

            if float(self.wF0.text()) > 0 and np.std(self.pitch) > 0:
                pitch = misc.normalize_minmax(self.pitch) + float(
                    self.wF0.text())
            if float(self.wEnergy.text()) > 0 and np.std(
                    self.energy_smooth) > 0:
                energy = misc.normalize_minmax(self.energy_smooth) + float(
                    self.wEnergy.text())
            if float(self.wDuration.text()) > 0 and np.std(self.rate) > 0:
                duration = misc.normalize_minmax(self.rate) + float(
                    self.wDuration.text())

            params = pitch * energy * duration

        else:
            # weighted sum of the normalized signals
            params = misc.normalize_std(self.pitch) * float(self.wF0.text()) + \
                     misc.normalize_std(self.energy_smooth) * float(self.wEnergy.text()) + \
                     misc.normalize_std(self.rate) * float(self.wDuration.text())

        if self.configuration["feature_combination"]["detrend"]:
            params = smooth_and_interp.remove_bias(params,
                                                   800)  # FIXME: 800?

        self.params = misc.normalize_std(params)
        self.ax[2].plot(params, color="black", linewidth=2,
                        label="Combined")

    # Labels of the currently selected tier; None when the lookup fails
    # for any reason.
    # NOTE(review): ``unicode`` is not a Python 3 builtin - confirm the
    # module defines it, otherwise this lookup always falls back to None.
    try:
        labels = self.tiers[unicode(self.tierlist.currentText())]
    except Exception:
        labels = None

    if self.fUpdate['tiers']:
        self.ax[3].cla()

    # do wavelet analysis
    # defaults, overwritten from the configuration in the loma stage below
    n_scales = 40
    scale_dist = 0.25
    if self.fUpdate['cwt']:
        self.logger.debug("wavelet transform...")

        (self.cwt, self.scales, self.freqs) = cwt_utils.cwt_analysis(
            self.params,
            mother_name=self.configuration["wavelet"]["mother_wavelet"],
            period=self.configuration["wavelet"]["period"],
            num_scales=self.configuration["wavelet"]["num_scales"],
            scale_distance=self.configuration["wavelet"]["scale_distance"],
            apply_coi=True)

        if self.configuration["wavelet"]["magnitude"]:
            self.cwt = np.log(np.abs(self.cwt) + 1.)
        else:
            self.cwt = np.real(self.cwt)

        # a new transform invalidates the lines of maximum amplitude
        self.fUpdate['loma'] = True

        # operate on frames, not time
        self.scales *= PLOT_SR

    if self.fUpdate['tiers'] or self.fUpdate['cwt']:
        import matplotlib.colors as colors
        self.ax[-1].contourf(np.real(self.cwt), 100,
                             norm=colors.SymLogNorm(linthresh=0.01,
                                                    linscale=0.05,
                                                    vmin=-1.0, vmax=1.0),
                             cmap="jet")

    # calculate lines of maximum and minimum amplitude
    if self.fUpdate['loma'] and labels:
        self.logger.debug("lines of maximum amplitude...")
        n_scales = self.configuration["wavelet"]["num_scales"]
        scale_dist = self.configuration["wavelet"]["scale_distance"]

        # get scale corresponding to avg unit length of selected tier
        unit_scale = misc.get_best_scale2(self.scales, labels)

        # clamp the unit scale to the range [8, n_scales - 2]
        unit_scale = np.max([8, unit_scale])
        unit_scale = np.min([n_scales - 2, unit_scale])
        print(unit_scale)

        # label durations (collected but not used further in this method)
        labdur = []
        for l in labels:
            labdur.append(l[1] - l[0])

        # Define the scale information (FIXME: description)
        pos_loma_start_scale = unit_scale + int(
            self.configuration["loma"]["prom_start"] /
            scale_dist)  # three octaves down from average unit length
        pos_loma_end_scale = unit_scale + int(
            self.configuration["loma"]["prom_end"] / scale_dist)
        neg_loma_start_scale = unit_scale + int(
            self.configuration["loma"]["boundary_start"] /
            scale_dist)  # two octaves down
        neg_loma_end_scale = unit_scale + int(
            self.configuration["loma"]["boundary_end"] /
            scale_dist)  # one octave up

        # some bug if starting from 0-3 scales
        pos_loma_start_scale = np.max([4, pos_loma_start_scale])
        neg_loma_start_scale = np.max([4, neg_loma_start_scale])
        pos_loma_end_scale = np.min([n_scales, pos_loma_end_scale])
        neg_loma_end_scale = np.min([n_scales, neg_loma_end_scale])

        # positive lines are traced on the transform, negative lines on
        # its negation
        pos_loma = loma.get_loma(np.real(self.cwt), self.scales,
                                 pos_loma_start_scale, pos_loma_end_scale)
        loma.plot_loma(pos_loma, self.ax[-1], color="black")

        neg_loma = loma.get_loma(-np.real(self.cwt), self.scales,
                                 neg_loma_start_scale, neg_loma_end_scale)
        loma.plot_loma(neg_loma, self.ax[-1], color="white")

        if labels:
            max_loma = loma.get_prominences(pos_loma, labels)
            self.prominences = np.array(max_loma)
            self.boundaries = np.array(
                loma.get_boundaries(max_loma, neg_loma, labels))

        # new prominences/boundaries require the labels to be redrawn
        self.fUpdate['tiers'] = True

    # plot labels
    if self.fUpdate['tiers'] and labels:
        labels = self.tiers[unicode(self.tierlist.currentText())]
        # prominences rescaled to the range 0.5 .. 3.0 and passed to
        # plot_labels below
        text_prominence = self.prominences[:, 1] / (np.max(
            self.prominences[:, 1])) * 2.5 + 0.5

        lab.plot_labels(labels, ypos=1, fig=self.ax[-1], size=5,
                        prominences=text_prominence, boundary=True)

        # vertical line at each label end, width set by boundary strength
        for i in range(0, len(labels)):
            self.ax[-1].axvline(x=labels[i][1], color='black',
                                linestyle="-",
                                linewidth=self.boundaries[i][-1] * 4,
                                alpha=0.3)

    #
    # save analyses
    if labels:
        pass  # FIXME: ????
        loma.save_analyses(
            os.path.splitext(unicode(self.cur_wav))[0] + ".prom", labels,
            self.prominences, self.boundaries, PLOT_SR)

    # Axis limits, labels and ticks
    self.ax[-1].set_ylim(0, n_scales)
    self.ax[-1].set_xlim(0, len(self.params))
    self.ax[0].set_ylabel("Spec (Hz)")
    self.ax[1].set_ylabel("F0 (Hz)")
    self.ax[2].set_ylabel("Signals")
    self.ax[2].set_yticklabels(["sum", "dur", "en", "f0"])

    self.ax[3].set_ylabel("Wavelet scale (Hz)")
    # x tick labels are shown on the last panel only
    plt.setp([a.get_xticklabels() for a in self.ax[0:-1]], visible=False)
    vals = self.ax[-1].get_xticks()[1:]
    # x tick values are divided by PLOT_SR and shown with two decimals
    ticks_x = ticker.FuncFormatter(
        lambda vals, p: '{:1.2f}'.format(float(vals / PLOT_SR)))
    self.ax[-1].xaxis.set_major_formatter(ticks_x)

    # can't comprehend matplotlib ticks.. construct frequency axis manually
    self.ax[3].set_yticks(np.linspace(0, len(self.freqs), len(self.freqs)))
    self.ax[3].set_yticklabels(np.around(self.freqs[:-1], 2).astype('str'))

    # only every fourth frequency label (and never the first) stays visible
    for index, label in enumerate(self.ax[3].yaxis.get_ticklabels()):
        if index % 4 != 0 or index == 0:
            label.set_visible(False)

    # prune the lowest y tick of the three upper panels
    for i in range(0, 3):
        nbins = len(self.ax[i].get_yticklabels())
        self.ax[i].yaxis.set_major_locator(
            MaxNLocator(nbins=nbins, prune='lower'))

    self.figure.subplots_adjust(hspace=0, wspace=0)

    # restore the zoom saved at the top of the method
    if prev_zoom:
        self.ax[3].axis(prev_zoom)

    self.canvas.draw()
    self.canvas.show()

    # everything is up to date now: clear all update flags
    self.fUpdate = dict.fromkeys(self.fUpdate, False)
def calc_global_spectrum(wav_file, period=5, n_scales=60, plot=False):
    """Compute the time-averaged wavelet power spectrum of a wave file.

    The signal is read, resampled to 16000 Hz and turned into an energy
    envelope (``method="hilbert"``).  The envelope is clipped at zero,
    cube-root compressed and passed through ``misc.normalize_std`` before
    a Morlet continuous wavelet transform is applied.  The squared
    magnitude of the transform is averaged over time (ignoring NaN values)
    to give one power value per scale.

    :param wav_file: path of the wave file to analyse
    :param period: ``period`` argument handed to the wavelet analysis
    :param n_scales: number of wavelet scales
    :param plot: when true, show the scalogram next to the global spectrum
    :return: tuple ``(power_spec, freq)``
    """
    # Extract signal envelope, scale and normalize
    sample_rate, samples = misc.read_wav(wav_file)
    samples = misc.resample(samples, sample_rate, 16000)

    envelope = energy_processing.extract_energy(samples,
                                                min_freq=30,
                                                method="hilbert")
    envelope[envelope < 0] = 0
    envelope = np.cbrt(envelope + 0.1)
    params = misc.normalize_std(envelope)

    # perform continous wavelet transform on envelope with morlet wavelet
    # increase _period to get sharper spectrum
    coefs, scales, freq = cwt_utils.cwt_analysis(
        params,
        first_freq=16,
        num_scales=n_scales,
        scale_distance=0.1,
        period=period,
        mother_name="Morlet",
        apply_coi=True)

    # power, arbitrary scaling to prevent underflow
    power = (abs(coefs)**2).astype('float32') * 1000.0
    power_spec = np.nanmean(power, axis=1)

    if not plot:
        return (power_spec, freq)

    figure, panels = plt.subplots(1, 2,
                                  gridspec_kw={'width_ratios': [5, 1]},
                                  sharey=True)
    scalogram_ax = panels[0]
    spectrum_ax = panels[1]

    figure.subplots_adjust(hspace=10)
    figure.subplots_adjust(wspace=0)

    scalogram_ax.set_ylim(0, n_scales)
    scalogram_ax.set_xlabel("Time(m:s)")
    scalogram_ax.set_ylabel("Frequency(Hz)")
    spectrum_ax.set_xlabel("power")
    spectrum_ax.tick_params(labelright=True)

    fname = os.path.basename(wav_file)
    title = "CWT Morlet(p=" + str(period) + ") global spectrum, " + fname
    scalogram_ax.contourf(power, 100)
    scalogram_ax.set_title(title, loc="center")
    scalogram_ax.plot(params * 3, color="white", alpha=0.5)

    # Only "round" frequencies get a tick label, the others stay blank
    freq_labels = []
    for value in freq:
        if np.isclose(value, round(value)):
            keep = True
        elif value < 2 and np.isclose(value * 100., round(value * 100)):
            keep = True
        elif value < 0.5 and np.isclose(value * 10000., round(value * 10000)):
            keep = True
        else:
            keep = False
        freq_labels.append(round(value, 3) if keep else "")

    n_labels = len(freq_labels)
    scalogram_ax.set_yticks(np.linspace(0, n_labels - 1, n_labels))
    scalogram_ax.set_yticklabels(freq_labels)

    def _as_minutes_seconds(ms, x):
        # x positions are divided by 200 to obtain seconds
        return time.strftime('%M:%S', time.gmtime(ms // 200))

    scalogram_ax.xaxis.set_major_formatter(
        matplotlib.ticker.FuncFormatter(_as_minutes_seconds))

    spectrum_ax.grid(axis="y")
    n_power = len(power_spec)
    spectrum_ax.plot(power_spec, np.linspace(0, n_power, n_power), "-")

    plt.show()

    return (power_spec, freq)