コード例 #1
0
def run():
    """Run the prosody labeller on the file named on the command line.

    All inputs come from the module-level ``args`` namespace: the input file,
    an optional label file (defaults to the input file's path with a ``.lab``
    extension), the label level to keep, the wavelet settings
    (``nb_scales``, ``scale_dist``, ``scale_factor``) and the ``plot`` switch.

    One line per label is printed with its prominence and boundary value.
    The process exits with status -1 when the label file does not exist.
    """
    global args

    # Resolve the label file: explicit option first, otherwise derive it
    # from the input file name.
    if args.label:
        lab_f = args.label
    else:
        lab_f = os.path.splitext(args.input_file)[0]+".lab"

    if not os.path.exists(lab_f):
        logging.error("Label file \"%s\" doesn't exist" % lab_f)
        sys.exit(-1)

    labels = lab.read_htk_label(lab_f)
    labels = labels[args.level] # Filter by level

    # Extract parameters
    (params, pitch, energy_smooth, rate) = extract_params(args.input_file, labels)

    # perform wavelet transform
    (cwt, scales) = cwt_utils.cwt_analysis(params, mother_name="mexican_hat", period=2,
                                           num_scales=args.nb_scales, scale_distance=args.scale_dist,
                                           apply_coi=True)
    scales *= args.scale_factor

    # Labelling prominences and boundaries
    (prominences, boundaries, pos_loma, neg_loma) = label_prosody(scales, cwt, labels)

    print("========================================================")
    print("label\tprominence\tboundary")
    for i in range(len(labels)):
        print("%s\t%f\t%f" %(labels[i][-1], prominences[i][-1], boundaries[i][-1]))

    if args.plot:
        # Plotting can't deal with complex, but we don't care.
        # ComplexWarning is no longer exposed as np.ComplexWarning in
        # NumPy 2.0; look it up in numpy.exceptions when that module exists
        # and fall back to the top-level name on older releases.
        complex_warning = getattr(getattr(np, "exceptions", np), "ComplexWarning")
        warnings.simplefilter("ignore", complex_warning)
        plot(labels, rate, energy_smooth, pitch, params, cwt, boundaries, prominences, pos_loma, neg_loma)
コード例 #2
0
def analysis(input_file, cfg, logger, annotation_dir=None, output_dir=None, plot=False):
    """Compute prominence and boundary values for one wave file and save them.

    Energy, f0 and a duration/rate contour are extracted, combined into one
    signal and decomposed with a continuous wavelet transform.  Lines of
    maximum amplitude (LOMA) over the transform are converted to per-label
    prominence and boundary values, which are written to
    ``<basename>.prom`` in ``output_dir``.

    Parameters
    ----------
    input_file : str
        Path of the wave file to analyse.
    cfg : dict
        Configuration.  The sections read here are "energy", "f0",
        "duration", "feature_combination", "wavelet", "labels" and "loma".
    logger : logging.Logger
        Used to report where the plot image is written.
    annotation_dir : str, optional
        Directory searched for ``<basename>.TextGrid`` and then
        ``<basename>.lab``.  Defaults to the directory of ``input_file``.
    output_dir : str, optional
        Directory receiving the results; created when missing.  Defaults to
        the directory of ``input_file``.
    plot : int or bool, optional
        0/False: no figure.  Negative: save ``<basename>.png`` in
        ``output_dir``.  Positive: show the figure interactively.

    Raises
    ------
    Exception
        If neither annotation file exists.
    AssertionError
        If a configured duration tier or the annotation tier is not among
        the loaded tiers.
    ValueError
        If label-based duration is requested but no tier was loaded.
    """
    # Load the wave file
    print("Analyzing %s starting..." % input_file)
    orig_sr, sig = misc.read_wav(input_file)

    # extract energy
    energy = energy_processing.extract_energy(sig, orig_sr,
                                              cfg["energy"]["band_min"],
                                              cfg["energy"]["band_max"],
                                              cfg["energy"]["calculation_method"])
    energy = np.cbrt(energy+1)
    if cfg["energy"]["smooth_energy"]:
        energy = smooth_and_interp.peak_smooth(energy, 30, 3)  # FIXME: 30? 3?
        energy = smooth_and_interp.smooth(energy, 10)

    # extract f0
    raw_pitch = f0_processing.extract_f0(sig, orig_sr,
                                         f0_min=cfg["f0"]["min_f0"],
                                         f0_max=cfg["f0"]["max_f0"],
                                         voicing=cfg["f0"]["voicing_threshold"],
                                         #harmonics=cfg["f0"]["harmonics"],
                                         configuration=cfg["f0"]["pitch_tracker"])
    # interpolate, stylize
    pitch = f0_processing.process(raw_pitch)

    # Get annotations: a TextGrid is preferred, an HTK label file is the fallback
    if annotation_dir is None:
        annotation_dir = os.path.dirname(input_file)
    basename = os.path.splitext(os.path.basename(input_file))[0]
    grid = os.path.join(annotation_dir, "%s.TextGrid" % basename)
    if os.path.exists(grid):
        tiers = lab.read_textgrid(grid)
    else:
        grid = os.path.join(annotation_dir, "%s.lab" % basename)
        if not os.path.exists(grid):
            raise Exception("There is no annotations associated with %s" % input_file)
        tiers = lab.read_htk_label(grid)

    # Collect the tiers used for the duration signal.  The list is always
    # defined so the label-based branch below cannot hit an unbound name.
    dur_tiers = []
    if len(tiers) > 0:
        for level in cfg["duration"]["duration_tiers"]:
            assert(level.lower() in tiers), level+" not defined in tiers: check that duration_tiers in config match the actual textgrid tiers"
            try:
                dur_tiers.append(tiers[level.lower()])
            except KeyError:
                # Only reachable when assertions are disabled (python -O)
                print("\nerror: "+"\""+level+"\"" +" not in labels, modify duration_tiers in config\n\n")
                raise

    # Extract duration / speech rate
    if not cfg["duration"]["acoustic_estimation"]:
        if len(tiers) == 0:
            raise ValueError("No annotation tier available to build the duration signal for %s" % input_file)
        rate = duration_processing.get_duration_signal(dur_tiers,
                                                       weights=cfg["duration"]["weights"],
                                                       linear=cfg["duration"]["linear"],
                                                       sil_symbols=cfg["duration"]["silence_symbols"],
                                                       bump=cfg["duration"]["bump"])
    else:
        rate = duration_processing.get_rate(energy)
        rate = smooth_and_interp.smooth(rate, 30)

    if cfg["duration"]["delta_duration"]:
        rate = np.diff(rate)

    # Combine signals: trim everything to the shortest contour first
    min_length = np.min([len(pitch), len(energy), len(rate)])
    pitch = pitch[:min_length]
    energy = energy[:min_length]
    rate = rate[:min_length]

    weights = cfg["feature_combination"]["weights"]
    if cfg["feature_combination"]["type"] == "product":
        pitch = misc.normalize_minmax(pitch) ** weights["f0"]
        energy = misc.normalize_minmax(energy) ** weights["energy"]
        rate = misc.normalize_minmax(rate) ** weights["duration"]
        params = pitch * energy * rate
    else:
        params = misc.normalize_std(pitch) * weights["f0"] + \
                 misc.normalize_std(energy) * weights["energy"] + \
                 misc.normalize_std(rate) * weights["duration"]

    if cfg["feature_combination"]["detrend"]:
        params = smooth_and_interp.remove_bias(params, 800)

    params = misc.normalize_std(params)

    # CWT analysis.  The scales returned here are not used: they are
    # recomputed from the frequencies below.
    (cwt, _, freqs) = cwt_utils.cwt_analysis(params,
                                             mother_name=cfg["wavelet"]["mother_wavelet"],
                                             period=cfg["wavelet"]["period"],
                                             num_scales=cfg["wavelet"]["num_scales"],
                                             scale_distance=cfg["wavelet"]["scale_distance"],
                                             apply_coi=False)
    cwt = np.real(cwt)

    # Compute lines of maximum amplitude
    assert(cfg["labels"]["annotation_tier"].lower() in tiers), \
        cfg["labels"]["annotation_tier"]+" not defined in tiers: check that annotation_tier in config is found in the textgrid tiers"
    labels = tiers[cfg["labels"]["annotation_tier"].lower()]

    # get scale corresponding to avg unit length of selected tier
    scale_dist = cfg["wavelet"]["scale_distance"]
    scales = (1./freqs*200)*0.5 # FIXME: hardcoded values (200 presumably the analysis frame rate -- confirm)
    unit_scale = misc.get_best_scale2(scales, labels)

    # Scale ranges searched for prominences (positive LOMA) and boundaries
    # (negative LOMA), expressed as offsets from the unit scale.
    pos_loma_start_scale = unit_scale + int(cfg["loma"]["prom_start"]/scale_dist)  # three octaves down from average unit length
    pos_loma_end_scale = unit_scale + int(cfg["loma"]["prom_end"]/scale_dist)
    neg_loma_start_scale = unit_scale + int(cfg["loma"]["boundary_start"]/scale_dist)  # two octaves down
    neg_loma_end_scale = unit_scale + int(cfg["loma"]["boundary_end"]/scale_dist)  # one octave up

    pos_loma = loma.get_loma(cwt, scales, pos_loma_start_scale, pos_loma_end_scale)
    neg_loma = loma.get_loma(-cwt, scales, neg_loma_start_scale, neg_loma_end_scale)

    max_loma = loma.get_prominences(pos_loma, labels)
    prominences = np.array(max_loma)
    boundaries = np.array(loma.get_boundaries(max_loma, neg_loma, labels))

    # output results
    if output_dir is None:
        output_dir = os.path.dirname(input_file)
    # dirname() is "" for a bare file name, and os.makedirs("") raises
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    output_filename = os.path.join(output_dir, "%s.prom" % basename)
    print("Saving %s..." % (output_filename))
    loma.save_analyses(output_filename,
                       labels,
                       prominences,
                       boundaries)

    # Plotting
    if plot != 0:
        fig, ax = plt.subplots(6, 1, sharex=True,
                               figsize=(len(labels) / 10 * 8, 8),
                               gridspec_kw={'height_ratios': [1, 1, 1, 2, 4, 1.5]})
        plt.subplots_adjust(hspace=0)

        # Plot individual signals
        ax[0].plot(pitch, linewidth=1)
        ax[0].set_ylabel("Pitch", rotation="horizontal", ha="right", va="center")

        ax[1].plot(energy, linewidth=1)
        ax[1].set_ylabel("Energy", rotation="horizontal", ha="right", va="center")

        ax[2].plot(rate, linewidth=1)
        ax[2].set_ylabel("Speech rate", rotation="horizontal", ha="right", va="center")

        # Plot combined signal
        ax[3].plot(params, linewidth=1)
        ax[3].set_ylabel("Combined \n signal", rotation="horizontal", ha="right", va="center")
        plt.xlim(0, len(params))

        # Wavelet and loma: compress positive values and clip negative ones
        # for display only (the analysis results are already saved)
        cwt[cwt>0] = np.log(cwt[cwt>0]+1.)
        cwt[cwt<-0.1] = -0.1
        ax[4].contourf(cwt, 100, cmap="inferno")
        loma.plot_loma(pos_loma, ax[4], color="black")
        loma.plot_loma(neg_loma, ax[4], color="white")
        ax[4].set_ylabel("Wavelet & \n LOMA", rotation="horizontal", ha="right", va="center")

        # Add labels, text size scaled by relative prominence
        prom_text = prominences[:, 1]/(np.max(prominences[:, 1]))*2.5 + 0.5
        lab.plot_labels(labels, ypos=0.3, size=6, prominences=prom_text, fig=ax[5], boundary=False, background=False)
        ax[5].set_ylabel("Labels", rotation="horizontal", ha="right", va="center")
        for i in range(len(labels)):
            for axis in ax:
                # thin line at the label start
                axis.axvline(x=labels[i][0], color='black',
                             linestyle="-", linewidth=0.2, alpha=0.5)
                # line at the label end, thickened by the boundary value
                axis.axvline(x=labels[i][1], color='black',
                             linestyle="-", linewidth=0.2+boundaries[i][-1] * 2,
                             alpha=0.5)

        plt.xlim(0, cwt.shape[1])

        # Align ylabels and remove axis decorations: x ticks are kept on the
        # bottom panel only, y ticks are hidden everywhere
        fig.align_ylabels(ax)
        for axis in ax[:-1]:
            axis.tick_params(axis='x', which='both',
                             bottom=False, top=False, labelbottom=False)
        for axis in ax:
            axis.tick_params(axis='y', which='both',
                             left=False, right=False, labelleft=False)

        # Save or show
        if plot < 0:
            output_filename = os.path.join(output_dir, "%s.png" % basename)
            logger.info("Save plot %s" % output_filename)
            fig.savefig(output_filename, bbox_inches='tight', dpi=400)
            # Release the figure; otherwise figures pile up in batch runs
            plt.close(fig)
        elif plot > 0:
            plt.show()
コード例 #3
0
def run():
    """Main entry function

    This function contains the code needed to achieve the analysis and/or the synthesis.

    Everything is driven by the module-level ``args`` namespace:

    * ``args.mode == 0`` (analysis): load f0 from ``args.input_file``,
      decompose it with a continuous wavelet transform, combine adjacent
      scales and write the resulting matrix to ``args.output_file``.
    * ``args.mode == 1`` (synthesis): read such a matrix from
      ``args.input_file``, reconstruct f0 using ``args.mean_f0`` and write
      it to ``args.output_file``.

    The configuration is layered: packaged ``default.yaml``, then packaged
    ``synthesis.yaml``, then the optional user file.  The process exits with
    status 1 when the user file cannot be read.
    """
    global args

    # Silence FutureWarning noise raised by the libraries used below
    warnings.simplefilter("ignore", FutureWarning)

    config_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                              "configs")

    # NOTE: yaml.safe_load is used throughout.  A bare yaml.load(f) can
    # construct arbitrary Python objects from the file and is rejected by
    # PyYAML >= 6, where the Loader argument is mandatory.

    # Loading default configuration
    configuration = defaultdict()
    with open(os.path.join(config_dir, "default.yaml"), 'r') as f:
        configuration = apply_configuration(
            configuration, defaultdict(lambda: False, yaml.safe_load(f)))
        logging.debug("default configuration")
        logging.debug(configuration)

    # Loading dedicated analysis.synthesis configuration
    with open(os.path.join(config_dir, "synthesis.yaml"), 'r') as f:
        configuration = apply_configuration(
            configuration, defaultdict(lambda: False, yaml.safe_load(f)))
        logging.debug("configuration filled with synthesis part")
        logging.debug(configuration)

    # Loading user configuration
    if args.configuration_file:
        try:
            with open(args.configuration_file, 'r') as f:
                configuration = apply_configuration(
                    configuration,
                    defaultdict(lambda: False, yaml.safe_load(f)))
                logging.debug("configuration filled with user part")
                logging.debug(configuration)
        except IOError as ex:
            # Report the file that was actually opened; IOError has no
            # ``msg`` attribute, so log its string form instead.
            logging.error("configuration file " + args.configuration_file +
                          " could not be loaded:")
            logging.error(str(ex))
            sys.exit(1)

    # Analysis Mode
    if args.mode == 0:
        raw_f0 = load_f0(args.input_file, args.binary_mode, configuration)

        logging.info("Processing f0")
        f0 = f0_processing.process(raw_f0)
        # FIXME: reintegrated
        if args.plot:
            plt.title("F0 preprocessing and interpolation")
            plt.plot(f0, color="red", alpha=0.5, linewidth=3)
            plt.plot(raw_f0, color="gray", alpha=0.5)
            plt.show()

        # # FIXME: read this?
        # logging.info("writing interpolated lf0\t" + output_file + ".interp")
        # np.savetxt(output_file + ".interp", f0.astype('float'),
        #            fmt="%f", delimiter="\n")

        # Perform continuous wavelet transform of the mean-subtracted f0
        logging.info(
            "Starting analysis with (num_scale=%d, scale_distance=%f, mother_name=%s)"
            % (configuration["wavelet"]["num_scales"],
               configuration["wavelet"]["scale_distance"],
               configuration["wavelet"]["mother_wavelet"]))
        full_scales, _, _ = cwt_utils.cwt_analysis(
            f0 - np.mean(f0),
            mother_name=configuration["wavelet"]["mother_wavelet"],
            period=configuration["wavelet"]["period"],
            num_scales=configuration["wavelet"]["num_scales"],
            scale_distance=configuration["wavelet"]["scale_distance"],
            apply_coi=False)
        full_scales = np.real(full_scales)

        # SSW parameterization, adjacent scales combined (with extra scales to handle long utterances)
        scales = cwt_utils.combine_scales(
            full_scales, configuration["wavelet"]["combined_scales"])
        for i, scale in enumerate(scales):
            logging.debug("Mean scale[%d]: %s" % (i, str(np.mean(scale))))

        # Saving matrix (one row per frame, one column per combined scale)
        logging.info("writing wavelet matrix in \"%s\"" % args.output_file)
        if args.binary_mode:
            with open(args.output_file, "wb") as f_out:
                scales.T.astype(np.float32).tofile(f_out)
        else:
            np.savetxt(args.output_file,
                       scales.T.astype('float'),
                       fmt="%f",
                       delimiter=",")

    # Synthesis mode
    if args.mode == 1:
        if args.binary_mode:
            # Raw float32 stream: the column count comes from the configuration
            scales = np.fromfile(args.input_file, dtype=np.float32)
            scales = scales.reshape(
                -1, len(configuration["wavelet"]["combined_scales"])).T
        else:
            scales = np.loadtxt(args.input_file,
                                delimiter=",").T  # FIXME: hardcoded

        rec = cwt_utils.cwt_synthesis(scales, args.mean_f0)

        logging.info("Save reconstructed f0 in %s" % args.output_file)
        if args.binary_mode:
            with open(args.output_file, "wb") as f_out:
                rec.astype(np.float32).tofile(f_out)
        else:
            np.savetxt(args.output_file, rec, fmt="%f")

    # Debugging /plotting part
    if args.plot:
        nb_sub = 2
        if args.mode == 0:
            nb_sub = 3

        ax = plt.subplot(nb_sub, 1, 1)
        # pylab.title("CWT decomposition to % scales and reconstructed signal" % len(configuration["wavelet"]["combined_scales"]))

        if args.mode == 0:
            # In analysis mode, reconstruct from the scales just computed so
            # the round trip can be compared with the input f0
            plt.plot(f0, linewidth=1, color="red")
            rec = cwt_utils.cwt_synthesis(scales, np.mean(f0))

        plt.plot(rec, color="blue", alpha=0.3)

        # Stack the combined scales above each other with a fixed offset
        plt.subplot(nb_sub, 1, 2, sharex=ax)
        base_offset = max(rec) * 1.5
        for i, scale in enumerate(scales):
            plt.plot(scale + base_offset + i * 75,
                     color="blue",
                     alpha=0.5)

        if args.mode == 0:
            plt.subplot(nb_sub, 1, 3, sharex=ax)
            plt.contourf(full_scales,
                         100,
                         norm=colors.SymLogNorm(linthresh=0.2,
                                                linscale=0.05,
                                                vmin=np.min(full_scales),
                                                vmax=np.max(full_scales)),
                         cmap="jet")
        plt.show()
コード例 #4
0
    def analysis(self):
        """Recompute and redraw every analysis stage flagged in ``self.fUpdate``.

        The stages run in a fixed order -- spectrogram, energy, f0, duration,
        feature combination, wavelet transform, lines of maximum amplitude
        (LOMA) and label drawing -- and each is skipped unless its flag is
        set.  When a label tier is selected, the prominence and boundary
        values are saved to ``<wave file without extension>.prom``.  Finally
        the axes are decorated, the canvas is redrawn and all flags are reset
        to False.

        Returns immediately when no wave file is loaded.
        """
        prev_zoom = None

        # Remember the current view of the wavelet panel so it can be
        # restored after redrawing, unless a new wave file is being loaded.
        if not self.fUpdate["wav"]:
            prev_zoom = self.ax[3].axis()

        if not self.cur_wav:
            return

        self.refresh_updates()

        # show spectrogram
        if self.fUpdate['wav']:
            self.toolbar.update()
            self.logger.debug("plot specgram")

            self.ax[0].cla()
            self.orig_sr, self.sig = misc.read_wav(self.cur_wav)
            # signal length expressed at the plotting rate
            self.plot_len = int(len(self.sig) * (PLOT_SR / self.orig_sr))
            self.ax[0].specgram(self.sig,
                                NFFT=200,
                                noverlap=40,
                                Fs=self.orig_sr,
                                xextent=[0, self.plot_len],
                                cmap="jet")

        if self.fUpdate['energy']:
            # 'energy' is just a smoothed envelope here
            self.logger.debug("analyzing energy..")
            self.energy = energy_processing.extract_energy(
                self.sig, self.orig_sr,
                self.configuration["energy"]["band_min"],
                self.configuration["energy"]["band_max"],
                self.configuration["energy"]["calculation_method"])

            if self.configuration["energy"]["smooth_energy"]:
                self.energy_smooth = smooth_and_interp.peak_smooth(
                    self.energy, 30, 3)  # FIXME: 30? 3?
            else:
                self.energy_smooth = self.energy

        raw_pitch = None

        if self.fUpdate['f0']:
            self.ax[1].cla()
            self.pitch = None
            raw_pitch = None

            # if f0 file is provided, use that
            if self.bUseExistingF0.isChecked():
                raw_pitch = f0_processing.read_f0(self.cur_wav)

            # otherwise (or when reading returned None) extract f0 from the signal
            if raw_pitch is None:
                # analyze pitch
                self.logger.debug("analyzing pitch..")
                min_f0 = float(str(self.min_f0.text()))
                max_f0 = float(str(self.max_f0.text()))
                # keep the range sane: max >= 10 and min strictly below max
                max_f0 = np.max([max_f0, 10.])
                min_f0 = np.min([max_f0 - 1., min_f0])

                raw_pitch = f0_processing.extract_f0(
                    self.sig, self.orig_sr, min_f0, max_f0,
                    float(self.harmonics.value()), float(self.voicing.value()),
                    self.configuration["f0"]["pitch_tracker"])

            # FIXME: fix errors, smooth and interpolate
            try:
                self.pitch = f0_processing.process(raw_pitch)
            except Exception as ex:
                exception_log(
                    self.logger, "no idea!!!", ex,
                    logging.DEBUG)  # FIXME: more human friendly message
                # f0_processing.process crashes if raw_pitch is all zeros, kludge
                self.pitch = raw_pitch

            self.ax[1].plot(raw_pitch, color='black', linewidth=1)
            self.ax[1].plot(self.pitch, color='black', linewidth=2)
            self.ax[1].set_ylim(
                np.min(self.pitch) * 0.75,
                np.max(self.pitch) * 1.2)

        if self.fUpdate['duration']:

            # flat default, kept when no tier is selected below
            self.rate = np.zeros(len(self.pitch))

            self.logger.debug("analyzing duration...")

            # signal method for speech rate, quite shaky
            if self.signalRate.isChecked():
                self.rate = duration_processing.get_rate(self.energy)
                self.rate = smooth_and_interp.smooth(self.rate, 30)

            # word / syllable / segment duration from labels
            else:
                sig_tiers = [self.tiers[item.text()]
                             for item in self.signalTiers.selectedItems()]

                try:
                    # Only if some tiers are selected
                    if (len(sig_tiers)) > 0:
                        self.rate = duration_processing.get_duration_signal(
                            sig_tiers,
                            sil_symbols=self.configuration["duration"]
                            ["silence_symbols"])
                except Exception as ex:
                    exception_log(self.logger,
                                  "Duration signal construction failed", ex,
                                  logging.ERROR)

            if self.diffDur.isChecked():
                self.rate = np.diff(self.rate, 1)

            # Bring the rate to the length of the pitch contour: pad with the
            # edge value when shorter; when padding fails (e.g. the rate is
            # longer, giving a negative pad width) truncate instead.
            try:
                self.rate = np.pad(self.rate,
                                   (0, len(self.pitch) - len(self.rate)),
                                   'edge')
            except Exception:
                self.rate = self.rate[0:len(self.pitch)]

        # combine acoustic features by normalizing, fixing lengths and summing (or multiplying)
        if self.fUpdate['params']:
            self.ax[2].cla()
            self.ax[3].cla()

            # individual contours, vertically offset so they do not overlap
            self.ax[2].plot(misc.normalize_std(self.pitch) + 12, label="F0")
            self.ax[2].plot(misc.normalize_std(self.energy_smooth) + 8,
                            label="Energy")
            self.ax[2].plot(misc.normalize_std(self.rate) + 4,
                            label="Duration")

            # Trim all three contours to one common length so they can be
            # combined element-wise whichever of them is the shortest.
            common_len = np.min([len(self.pitch),
                                 len(self.energy_smooth),
                                 len(self.rate)])
            self.energy_smooth = self.energy_smooth[:common_len]
            self.pitch = self.pitch[:common_len]
            self.rate = self.rate[:common_len]

            if self.mul_feats.isChecked():
                # A feature with zero weight or no variance contributes a
                # neutral factor of one to the product.
                pitch = np.ones(len(self.pitch))
                energy = np.ones(len(self.pitch))
                duration = np.ones(len(self.pitch))

                if float(self.wF0.text()) > 0 and np.std(self.pitch) > 0:
                    pitch = misc.normalize_minmax(self.pitch) + float(
                        self.wF0.text())
                if float(self.wEnergy.text()) > 0 and np.std(
                        self.energy_smooth) > 0:
                    energy = misc.normalize_minmax(self.energy_smooth) + float(
                        self.wEnergy.text())
                if float(self.wDuration.text()) > 0 and np.std(self.rate) > 0:
                    duration = misc.normalize_minmax(self.rate) + float(
                        self.wDuration.text())

                params = pitch * energy * duration
            else:
                params = misc.normalize_std(self.pitch) * float(self.wF0.text()) + \
                         misc.normalize_std(self.energy_smooth) * float(self.wEnergy.text()) + \
                         misc.normalize_std(self.rate) * float(self.wDuration.text())

            if self.configuration["feature_combination"]["detrend"]:
                params = smooth_and_interp.remove_bias(params,
                                                       800)  # FIXME: 800?

            self.params = misc.normalize_std(params)
            # NOTE(review): the curve drawn is the combination before the
            # final normalization stored in self.params -- confirm intended.
            self.ax[2].plot(params,
                            color="black",
                            linewidth=2,
                            label="Combined")

        # Currently selected label tier, or None when it cannot be resolved.
        # NOTE(review): assumes ``unicode`` is defined at module level (e.g.
        # aliased to ``str`` on Python 3); otherwise the NameError is
        # swallowed here and ``labels`` is always None.
        try:
            labels = self.tiers[unicode(self.tierlist.currentText())]
        except Exception:
            labels = None

        if self.fUpdate['tiers']:
            self.ax[3].cla()

        # do wavelet analysis
        # Always take the scale layout from the configuration so that the
        # axis limits below match the transform even when LOMA is not
        # recomputed in this call.
        n_scales = self.configuration["wavelet"]["num_scales"]
        scale_dist = self.configuration["wavelet"]["scale_distance"]

        if self.fUpdate['cwt']:
            self.logger.debug("wavelet transform...")

            (self.cwt, self.scales, self.freqs) = cwt_utils.cwt_analysis(
                self.params,
                mother_name=self.configuration["wavelet"]["mother_wavelet"],
                period=self.configuration["wavelet"]["period"],
                num_scales=n_scales,
                scale_distance=scale_dist,
                apply_coi=True)
            if self.configuration["wavelet"]["magnitude"]:
                self.cwt = np.log(np.abs(self.cwt) + 1.)
            else:
                self.cwt = np.real(self.cwt)

            # a new transform invalidates the lines of maximum amplitude
            self.fUpdate['loma'] = True
            # operate on frames, not time
            self.scales *= PLOT_SR
        if self.fUpdate['tiers'] or self.fUpdate['cwt']:
            import matplotlib.colors as colors
            self.ax[-1].contourf(np.real(self.cwt),
                                 100,
                                 norm=colors.SymLogNorm(linthresh=0.01,
                                                        linscale=0.05,
                                                        vmin=-1.0,
                                                        vmax=1.0),
                                 cmap="jet")

        # calculate lines of maximum and minimum amplitude
        if self.fUpdate['loma'] and labels:
            self.logger.debug("lines of maximum amplitude...")

            # get scale corresponding to avg unit length of selected tier,
            # clamped to [8, n_scales - 2]
            unit_scale = misc.get_best_scale2(self.scales, labels)

            unit_scale = np.max([8, unit_scale])
            unit_scale = np.min([n_scales - 2, unit_scale])
            self.logger.debug("unit scale: %s", unit_scale)

            # Define the scale information (FIXME: description)
            pos_loma_start_scale = unit_scale + int(
                self.configuration["loma"]["prom_start"] /
                scale_dist)  # three octaves down from average unit length
            pos_loma_end_scale = unit_scale + int(
                self.configuration["loma"]["prom_end"] / scale_dist)
            neg_loma_start_scale = unit_scale + int(
                self.configuration["loma"]["boundary_start"] /
                scale_dist)  # two octaves down
            neg_loma_end_scale = unit_scale + int(
                self.configuration["loma"]["boundary_end"] /
                scale_dist)  # one octave up

            # some bug if starting from 0-3 scales
            pos_loma_start_scale = np.max([4, pos_loma_start_scale])
            neg_loma_start_scale = np.max([4, neg_loma_start_scale])
            pos_loma_end_scale = np.min([n_scales, pos_loma_end_scale])
            neg_loma_end_scale = np.min([n_scales, neg_loma_end_scale])

            pos_loma = loma.get_loma(np.real(self.cwt), self.scales,
                                     pos_loma_start_scale, pos_loma_end_scale)
            loma.plot_loma(pos_loma, self.ax[-1], color="black")
            neg_loma = loma.get_loma(-np.real(self.cwt), self.scales,
                                     neg_loma_start_scale, neg_loma_end_scale)
            loma.plot_loma(neg_loma, self.ax[-1], color="white")

            max_loma = loma.get_prominences(pos_loma, labels)
            self.prominences = np.array(max_loma)
            self.boundaries = np.array(
                loma.get_boundaries(max_loma, neg_loma, labels))

            # new prominences/boundaries require the labels to be redrawn
            self.fUpdate['tiers'] = True

        # plot labels
        # NOTE(review): from here on self.prominences / self.boundaries are
        # read even when LOMA was not recomputed in this call; this relies on
        # an earlier pass having set them -- confirm.
        if self.fUpdate['tiers'] and labels:
            labels = self.tiers[unicode(self.tierlist.currentText())]
            # text size scaled by prominence relative to the maximum
            text_prominence = self.prominences[:, 1] / (np.max(
                self.prominences[:, 1])) * 2.5 + 0.5

            lab.plot_labels(labels,
                            ypos=1,
                            fig=self.ax[-1],
                            size=5,
                            prominences=text_prominence,
                            boundary=True)

            # vertical line at each label end, thickened by the boundary value
            for i in range(len(labels)):
                self.ax[-1].axvline(x=labels[i][1],
                                    color='black',
                                    linestyle="-",
                                    linewidth=self.boundaries[i][-1] * 4,
                                    alpha=0.3)

        #
        # save analyses
        if labels:
            # FIXME: ????
            loma.save_analyses(
                os.path.splitext(unicode(self.cur_wav))[0] + ".prom", labels,
                self.prominences, self.boundaries, PLOT_SR)

        self.ax[-1].set_ylim(0, n_scales)
        self.ax[-1].set_xlim(0, len(self.params))
        self.ax[0].set_ylabel("Spec (Hz)")
        self.ax[1].set_ylabel("F0 (Hz)")
        self.ax[2].set_ylabel("Signals")

        self.ax[2].set_yticklabels(["sum", "dur", "en", "f0"])
        self.ax[3].set_ylabel("Wavelet scale (Hz)")

        plt.setp([a.get_xticklabels() for a in self.ax[0:-1]], visible=False)
        # x axis of the bottom panel: frame index -> seconds
        ticks_x = ticker.FuncFormatter(
            lambda val, pos: '{:1.2f}'.format(float(val / PLOT_SR)))
        self.ax[-1].xaxis.set_major_formatter(ticks_x)

        # can't comprehend matplotlib ticks.. construct frequency axis manually
        # NOTE(review): len(self.freqs) tick positions but one label fewer;
        # recent Matplotlib releases may reject the mismatch -- confirm.
        self.ax[3].set_yticks(np.linspace(0, len(self.freqs), len(self.freqs)))
        self.ax[3].set_yticklabels(np.around(self.freqs[:-1], 2).astype('str'))

        # show only every fourth frequency label, and never the first
        for index, label in enumerate(self.ax[3].yaxis.get_ticklabels()):
            if index % 4 != 0 or index == 0:
                label.set_visible(False)

        for i in range(0, 3):
            nbins = len(self.ax[i].get_yticklabels())
            self.ax[i].yaxis.set_major_locator(
                MaxNLocator(nbins=nbins, prune='lower'))

        self.figure.subplots_adjust(hspace=0, wspace=0)

        if prev_zoom:
            self.ax[3].axis(prev_zoom)

        self.canvas.draw()
        self.canvas.show()

        # everything is up to date now
        self.fUpdate = dict.fromkeys(self.fUpdate, False)
コード例 #5
0
def calc_global_spectrum(wav_file, period=5, n_scales=60, plot=False):
    """Compute the time-averaged wavelet power spectrum of a signal envelope.

    The wave file is resampled to 16 kHz, its energy envelope is extracted,
    compressed with a cube root and normalized, and then decomposed with a
    Morlet continuous wavelet transform.  The squared magnitude of the
    transform is averaged over time (ignoring NaNs) to give one power value
    per scale.

    Parameters
    ----------
    wav_file : str
        Path of the wave file to analyse.
    period : int, optional
        Passed to the wavelet analysis; increase it to get a sharper spectrum.
    n_scales : int, optional
        Number of wavelet scales.
    plot : bool, optional
        When true, show the scalogram next to the global spectrum.

    Returns
    -------
    tuple
        ``(power_spec, freq)``: the mean power per scale and the frequencies
        returned by the wavelet analysis.
    """

    # Extract signal envelope, scale and normalize
    (fs, waveform) = misc.read_wav(wav_file)
    waveform = misc.resample(waveform, fs, 16000)
    energy = energy_processing.extract_energy(waveform,
                                              min_freq=30,
                                              method="hilbert")
    energy[energy < 0] = 0
    energy = np.cbrt(energy + 0.1)
    params = misc.normalize_std(energy)

    # perform continuous wavelet transform on envelope with morlet wavelet

    # increase _period to get sharper spectrum
    matrix, scales, freq = cwt_utils.cwt_analysis(params,
                                                  first_freq=16,
                                                  num_scales=n_scales,
                                                  scale_distance=0.1,
                                                  period=period,
                                                  mother_name="Morlet",
                                                  apply_coi=True)

    # power, arbitrary scaling to prevent underflow
    p_matrix = (abs(matrix)**2).astype('float32') * 1000.0
    power_spec = np.nanmean(p_matrix, axis=1)

    if plot:
        f, wave_pics = plt.subplots(1,
                                    2,
                                    gridspec_kw={'width_ratios': [5, 1]},
                                    sharey=True)
        f.subplots_adjust(hspace=10)
        f.subplots_adjust(wspace=0)
        wave_pics[0].set_ylim(0, n_scales)
        wave_pics[0].set_xlabel("Time(m:s)")
        wave_pics[0].set_ylabel("Frequency(Hz)")
        wave_pics[1].set_xlabel("power")
        wave_pics[1].tick_params(labelright=True)

        fname = os.path.basename(wav_file)
        title = "CWT Morlet(p=" + str(period) + ") global spectrum, " + fname
        wave_pics[0].contourf(p_matrix, 100)
        wave_pics[0].set_title(title, loc="center")
        wave_pics[0].plot(params * 3, color="white", alpha=0.5)

        # Label only "round" frequencies: integers, multiples of 0.01 below
        # 2 and multiples of 0.0001 below 0.5; the rest get an empty label.
        freq_labels = [
            round(x, 3) if
            (np.isclose(x, round(x)) or
             (x < 2 and np.isclose(x * 100., round(x * 100))) or
             (x < 0.5 and np.isclose(x * 10000., round(x * 10000)))) else ""
            for x in freq
        ]

        # one tick per scale row, at the row index
        wave_pics[0].set_yticks(
            np.linspace(0,
                        len(freq_labels) - 1, len(freq_labels)))
        wave_pics[0].set_yticklabels(freq_labels)
        # x axis: frame index -> minutes:seconds
        # (assumes 200 frames per second -- TODO confirm)
        formatter = matplotlib.ticker.FuncFormatter(
            lambda frame, pos: time.strftime('%M:%S',
                                             time.gmtime(frame // 200)))
        wave_pics[0].xaxis.set_major_formatter(formatter)
        wave_pics[1].grid(axis="y")
        # Draw the spectrum at the row indices 0..n-1 so it lines up with
        # the scalogram rows and the shared frequency ticks.
        wave_pics[1].plot(power_spec,
                          np.arange(len(power_spec)),
                          "-")
        plt.show()

    return (power_spec, freq)