예제 #1
0
    def getSummary(self, avianz_batch, CNN=False, CNNmodel=None):
        """Score the detector against the ground truth in self.testDir.

        Walks self.testDir and processes every non-empty .wav file that has
        both a '<file>.tmpdata' file and a '<name>-res<window>sec.txt' file in
        the same directory. For each such file the recording is loaded through
        avianz_batch, the segments returned by self.findCTsegments() are
        post-processed for each call type, and the surviving segments are
        converted to a 0/1 vector with one entry per second of audio. That
        vector is compared with the ground truth from self.loadGT() using
        WaveletSegment.fBetaScore(), and the TP/FP/TN/FN counts are summed
        over all files.

        A report is written to self.outfile.

        Args:
            avianz_batch: batch-processing object. Its `filename` attribute is
                overwritten and its loadFile() called for each file; its
                `audiodata` and `sampleRate` are then read.
            CNN: when true the report is titled as a CNN summary, and (if
                CNNmodel is also truthy) post.CNN() is run on the segments.
            CNNmodel: forwarded unchanged to Segment.PostProcess.

        Returns:
            The summary text as a string, or None if no usable test data was
            found (an error message is printed in that case).
        """
        autoSegNum = 0
        autoSegCT = [[] for _ in self.calltypes]
        ws = WaveletSegment.WaveletSegment()
        TP = FP = TN = FN = 0

        for root, _dirs, files in os.walk(self.testDir):
            for file in files:
                # Guard clauses, evaluated in the same order as the checks
                # they replace: extension, size, .tmpdata, ground truth file.
                if not file.lower().endswith('.wav'):
                    continue
                wavFile = os.path.join(root, file)
                if os.stat(wavFile).st_size == 0 or file + '.tmpdata' not in files:
                    continue
                gtName = file[:-4] + '-res' + str(float(self.window)) + 'sec.txt'
                if gtName not in files:
                    continue

                avianz_batch.filename = wavFile
                avianz_batch.loadFile([self.filtname], anysound=False)
                duration = int(
                    np.ceil(len(avianz_batch.audiodata) / avianz_batch.sampleRate))
                # 0/1 detection vector for this file, filled as segments of
                # each call type come out of post-processing.
                det01 = np.zeros(duration)

                for i in range(len(self.calltypes)):
                    ctsegments = self.findCTsegments(avianz_batch.filename, i)
                    subfilter = self.currfilt['Filters'][i]
                    post = Segment.PostProcess(
                        configdir=self.configdir,
                        audioData=avianz_batch.audiodata,
                        sampleRate=avianz_batch.sampleRate,
                        tgtsampleRate=self.sampleRate,
                        segments=ctsegments,
                        subfilter=subfilter,
                        CNNmodel=CNNmodel,
                        cert=50)
                    post.wind()
                    if CNN and CNNmodel:
                        post.CNN()
                    if 'F0' in subfilter and 'F0Range' in subfilter and subfilter["F0"]:
                        print("Checking for fundamental frequency...")
                        post.fundamentalFrq()
                    post.joinGaps(maxgap=subfilter['TimeRange'][3])
                    post.deleteShort(minlength=subfilter['TimeRange'][0])

                    # `or ()` keeps the original "skip if falsy" behaviour.
                    for seg in post.segments or ():
                        bounds = seg[0]
                        autoSegCT[i].append(bounds)
                        autoSegNum += 1
                        # back-convert to 0/1:
                        det01[int(bounds[0]):int(bounds[1])] = 1

                # get and parse the agreement metrics
                GT = self.loadGT(os.path.join(root, gtName), duration)
                _, _, tp, fp, tn, fn = ws.fBetaScore(GT, det01)
                TP += tp
                FP += fp
                TN += tn
                FN += fn

        # Summary
        total = TP + FP + TN + FN
        if total == 0:
            print("ERROR: failed to find any testing data")
            return

        def ratio(numerator, denominator):
            # Metrics with an empty denominator are reported as 0.
            return numerator / denominator if denominator != 0 else 0

        recall = ratio(TP, TP + FN)
        precision = ratio(TP, TP + FP)
        specificity = ratio(TN, TN + FP)
        accuracy = (TP + TN) / total

        if CNN:
            title = "Wavelet Pre-Processor + CNN detection summary"
            leading = "\n\n"
        else:
            title = "Wavelet Pre-Processor detection summary"
            leading = "\n"

        out = self.outfile
        out.write(leading + "-- " + title + " --\n")
        out.write(
            "TP | FP | TN | FN seconds:\t %.2f | %.2f | %.2f | %.2f\n" %
            (TP, FP, TN, FN))
        out.write("Specificity:\t\t%.2f %%\n" % (specificity * 100))
        out.write("Recall (sensitivity):\t%.2f %%\n" % (recall * 100))
        out.write("Precision (PPV):\t%.2f %%\n" % (precision * 100))
        out.write("Accuracy:\t\t%.2f %%\n\n" % (accuracy * 100))
        out.write("Manually labelled segments:\t%d\n" % (self.manSegNum))
        for calltype, segs in zip(self.calltypes, autoSegCT):
            out.write("Auto suggested '%s' segments:\t%d\n" %
                      (calltype, len(segs)))
        out.write("Total auto suggested segments:\t%d\n\n" % (autoSegNum))

        # The returned text differs between the two modes only in its title.
        body = ("\tTrue Positives:\t%d seconds (%.2f %%)\n"
                "\tFalse Positives:\t%d seconds (%.2f %%)\n"
                "\tTrue Negatives:\t%d seconds (%.2f %%)\n"
                "\tFalse Negatives:\t%d seconds (%.2f %%)\n\n"
                "\tSpecificity:\t%.2f %%\n"
                "\tRecall:\t\t%.2f %%\n"
                "\tPrecision:\t%.2f %%\n"
                "\tAccuracy:\t%.2f %%\n") % (
                    TP, TP * 100 / total,
                    FP, FP * 100 / total,
                    TN, TN * 100 / total,
                    FN, FN * 100 / total,
                    specificity * 100, recall * 100,
                    precision * 100, accuracy * 100)
        text = title + "\n\n" + body
        return text
예제 #2
0
    def detectFile(self, speciesStr, filters):
        """ Actual worker for a file in the detection loop.
            Does not return anything - for use with external try/catch

            The already-loaded audio (self.audiodata, self.datalength samples)
            is processed in fixed-size pages. For each page one of three paths
            is taken:
              * speciesStr == "Any sound": median-clipping segmentation on a
                spectrogram, followed by optional wind removal, gap joining,
                short-segment deletion and long-segment splitting;
              * self.method == "Click": click search, then (outside testmode)
                CNN classification of the click spectrograms into one label
                for the whole page;
              * otherwise: wavelet segmentation per recogniser, then (outside
                testmode) per-subfilter post-processing.
            Results are handed to self.makeSegments().

            Args:
                speciesStr: the string "Any sound" selects the generic
                    segmenter; any other value selects the recogniser path.
                filters: sequence of recogniser dicts. Entries are read for
                    the keys 'Filters', 'SampleRate', 'species' and,
                    optionally, 'CNN' (with a 'CNN_name' looked up in
                    self.CNNDicts).
        """
        # Segment over pages separately, to allow dealing with large files smoothly:
        # TODO: page size fixed for now
        # NOTE(review): presumably 900 s (the "15min pages" mentioned below) at
        # 16 kHz. Page length in seconds is computed with self.sampleRate, so at
        # other sample rates a page is not 15 min long - confirm this is intended.
        samplesInPage = 900 * 16000
        # (ceil division for large integers)
        numPages = (self.datalength - 1) // samplesInPage + 1

        # Actual segmentation happens here:
        for page in range(numPages):
            print("Segmenting page %d / %d" % (page + 1, numPages))
            start = page * samplesInPage
            end = min(start + samplesInPage, self.datalength)
            # page length in seconds
            thisPageLen = (end - start) / self.sampleRate

            # Pages shorter than 2 s are skipped, except in Click mode
            if thisPageLen < 2 and self.method != "Click":
                print("Warning: can't process short file ends (%.2f s)" %
                      thisPageLen)
                continue

            # Process
            if speciesStr == "Any sound":
                # Create spectrogram for median clipping etc
                # (the SignalProc object is created once and reused across pages)
                if not hasattr(self, 'sp'):
                    self.sp = SignalProc.SignalProc(
                        self.config['window_width'], self.config['incr'])
                self.sp.data = self.audiodata[start:end]
                self.sp.sampleRate = self.sampleRate
                _ = self.sp.spectrogram(window='Hann',
                                        mean_normalise=True,
                                        onesided=True,
                                        multitaper=False,
                                        need_even=False)
                self.seg = Segment.Segmenter(self.sp, self.sampleRate)
                # thisPageSegs = self.seg.bestSegments()
                thisPageSegs = self.seg.medianClip(thr=3.5)
                # Post-process
                # 1. Delete windy segments
                # 2. Merge neighbours
                # 3. Delete short segments
                print("Segments detected: ", len(thisPageSegs))
                print("Post-processing...")
                # The three values are divided by 1000 before use
                # (presumably entered in ms and used in s - TODO confirm)
                maxgap = int(self.maxgap.value()) / 1000
                minlen = int(self.minlen.value()) / 1000
                maxlen = int(self.maxlen.value()) / 1000
                post = Segment.PostProcess(configdir=self.configdir,
                                           audioData=self.audiodata[start:end],
                                           sampleRate=self.sampleRate,
                                           segments=thisPageSegs,
                                           subfilter={},
                                           cert=0)
                if self.wind:
                    post.wind()
                post.joinGaps(maxgap)
                post.deleteShort(minlen)
                # avoid extra long segments (for Isabel)
                post.splitLong(maxlen)

                # adjust segment starts for 15min "pages"
                # (shift both ends by the page offset, in seconds)
                if start != 0:
                    for seg in post.segments:
                        seg[0][0] += start / self.sampleRate
                        seg[0][1] += start / self.sampleRate
                # attach mandatory "Don't Know"s etc and put on self.segments
                self.makeSegments(post.segments)
                # drop the per-page segmenter before moving to the next page
                del self.seg
                gc.collect()
            else:
                if self.method != "Click":
                    # read in the page and resample as needed
                    self.ws.readBatch(self.audiodata[start:end],
                                      self.sampleRate,
                                      d=False,
                                      spInfo=filters,
                                      wpmode="new")

                # Defaults for the Click path; overwritten by ClickSearch below
                data_test = []
                click_label = 'None'
                for speciesix in range(len(filters)):
                    print("Working with recogniser:", filters[speciesix])
                    if self.method != "Click":
                        # note: using 'recaa' mode = partial antialias
                        # NOTE(review): the note above mentions 'recaa' but the
                        # call passes wpmode="new" - the note looks stale.
                        thisPageSegs = self.ws.waveletSegment(speciesix,
                                                              wpmode="new")
                    else:
                        # NOTE(review): self.sp is only created in this method on
                        # the "Any sound" path; presumably it is set up elsewhere
                        # before Click mode reaches here - confirm.
                        click_label, data_test, gen_spec = self.ClickSearch(
                            self.sp.sg, self.filename)
                        print('number of detected clicks = ', gen_spec)
                        thisPageSegs = []

                    # Post-process:
                    # CNN-classify, delete windy, rainy segments, check for FundFreq, merge gaps etc.
                    print("Segments detected (all subfilters): ", thisPageSegs)
                    if not self.testmode:
                        print("Post-processing...")
                    # postProcess currently operates on single-level list of segments,
                    # so we run it over subfilters for wavelets:
                    spInfo = filters[speciesix]
                    for filtix in range(len(spInfo['Filters'])):
                        if not self.testmode:
                            # TODO THIS IS FULL POST-PROC PIPELINE FOR BIRDS AND BATS
                            # -- Need to check how this should interact with the testmode
                            # CNNmodel stays None unless the recogniser names a CNN
                            # that is present in self.CNNDicts
                            CNNmodel = None
                            if 'CNN' in spInfo:
                                if spInfo['CNN'][
                                        'CNN_name'] in self.CNNDicts.keys():
                                    # This list contains the model itself, plus parameters for running it
                                    CNNmodel = self.CNNDicts[spInfo['CNN']
                                                             ['CNN_name']]

                            if self.method == "Click":
                                # bat-style CNN:
                                # NOTE(review): CNNmodel can still be None here (see
                                # above), in which case the indexing below raises
                                # TypeError - confirm Click recognisers always
                                # have a loaded CNN.
                                model = CNNmodel[0]
                                thr1 = CNNmodel[5][0]
                                thr2 = CNNmodel[5][1]
                                if click_label == 'Click':
                                    # we enter in the cnn only if we got a click
                                    # One 2-D array per entry of data_test, sized
                                    # from the first entry's spectrogram
                                    sg_test = np.ndarray(
                                        shape=(np.shape(data_test)[0],
                                               np.shape(data_test[0][0])[0],
                                               np.shape(data_test[0][0])[1]),
                                        dtype=float)
                                    spec_id = []
                                    print('Number of file spectrograms = ',
                                          np.shape(data_test)[0])
                                    for j in range(np.shape(data_test)[0]):
                                        # scale each spectrogram by its own maximum
                                        maxg = np.max(data_test[j][0][:])
                                        sg_test[
                                            j][:] = data_test[j][0][:] / maxg
                                        spec_id.append(data_test[j][1:3])

                                    # CNN classification of clicks
                                    # (input reshaped to (n, 6, 512, 1), float32)
                                    x_test = sg_test
                                    test_images = x_test.reshape(
                                        x_test.shape[0], 6, 512, 1)
                                    test_images = test_images.astype('float32')

                                    # recovering labels
                                    predictions = model.predict(test_images)
                                    # predictions is an array #imagesX #of classes which entries are the probabilities for each class

                                    # Create a label (list of dicts with species, certs) for the single segment
                                    print('Assessing file label...')
                                    label = self.File_label(predictions,
                                                            thr1=thr1,
                                                            thr2=thr2)
                                    print('CNN detected: ', label)
                                    if len(label) > 0:
                                        # Convert the annotation into a full segment in self.segments
                                        # (page start and page length, both in seconds)
                                        thisPageStart = start / self.sampleRate
                                        self.makeSegments([
                                            thisPageStart, thisPageLen, label
                                        ])
                                else:
                                    # do not create any segments
                                    print("Nothing detected")
                            else:
                                # bird-style CNN and other processing:
                                post = Segment.PostProcess(
                                    configdir=self.configdir,
                                    audioData=self.audiodata[start:end],
                                    sampleRate=self.sampleRate,
                                    tgtsampleRate=spInfo["SampleRate"],
                                    segments=thisPageSegs[filtix],
                                    subfilter=spInfo['Filters'][filtix],
                                    CNNmodel=CNNmodel,
                                    cert=50)
                                print("Segments detected after WF: ",
                                      len(thisPageSegs[filtix]))
                                # wind filter only if enabled and self.useWindF()
                                # accepts this subfilter's frequency range
                                if self.wind and self.useWindF(
                                        spInfo['Filters'][filtix]['FreqRange']
                                    [0], spInfo['Filters'][filtix]['FreqRange']
                                    [1]):
                                    post.wind()

                                if CNNmodel:
                                    print('Post-processing with CNN')
                                    post.CNN()
                                # fundamental frequency check needs both keys
                                # present and a truthy "F0"
                                if 'F0' in spInfo['Filters'][
                                        filtix] and 'F0Range' in spInfo[
                                            'Filters'][filtix]:
                                    if spInfo['Filters'][filtix]["F0"]:
                                        print(
                                            "Checking for fundamental frequency..."
                                        )
                                        post.fundamentalFrq()

                                post.joinGaps(maxgap=spInfo['Filters'][filtix]
                                              ['TimeRange'][3])
                                post.deleteShort(minlength=spInfo['Filters']
                                                 [filtix]['TimeRange'][0])

                                # adjust segment starts for 15min "pages"
                                if start != 0:
                                    for seg in post.segments:
                                        seg[0][0] += start / self.sampleRate
                                        seg[0][1] += start / self.sampleRate
                                # attach filter info and put on self.segments:
                                self.makeSegments(post.segments,
                                                  self.species[speciesix],
                                                  spInfo["species"],
                                                  spInfo['Filters'][filtix])
                        else:
                            # TODO: THIS IS testmode. NOT USING ANY BAT STUFF THEN
                            # I.E. testmode not adapted to bats
                            # NOTE(review): in Click mode thisPageSegs is [] here,
                            # so thisPageSegs[filtix] below would raise IndexError.
                            # No post-processing step is run on `post` in this
                            # branch; its segments are only page-shifted and stored.
                            post = Segment.PostProcess(
                                configdir=self.configdir,
                                audioData=self.audiodata[start:end],
                                sampleRate=self.sampleRate,
                                tgtsampleRate=spInfo["SampleRate"],
                                segments=thisPageSegs[filtix],
                                subfilter=spInfo['Filters'][filtix],
                                CNNmodel=None,
                                cert=50)
                            # adjust segment starts for 15min "pages"
                            if start != 0:
                                for seg in post.segments:
                                    seg[0][0] += start / self.sampleRate
                                    seg[0][1] += start / self.sampleRate
                            # attach filter info and put on self.segments:
                            self.makeSegments(post.segments,
                                              self.species[speciesix],
                                              spInfo["species"],
                                              spInfo['Filters'][filtix])