Example No. 1
    def __call__(self, item):
        """Extract features

        Parameters
        ----------
        item : dict

        Returns
        -------
        features : SlidingWindowFeature

        """

        # --- load audio file
        y, sample_rate = read_audio(item,
                                    sample_rate=self.sample_rate,
                                    mono=True)

        data = self.process(y, sample_rate)

        if np.any(np.isnan(data)):
            uri = get_unique_identifier(item)
            msg = 'Features extracted from "{uri}" contain NaNs.'
            warnings.warn(msg.format(uri=uri))

        return SlidingWindowFeature(data.T, self.sliding_window_)
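
A hypothetical usage sketch for the extractor above. The dictionary keys and the feature_extraction instance are illustrative assumptions, not taken from the original:

# pyannote.database protocols yield one dictionary per audio file;
# the exact keys expected by read_audio() are an assumption here.
item = {
    'uri': 'session01',             # unique resource identifier
    'database': 'MyCorpus',         # database name
    'audio': '/data/session01.wav'  # path to the audio file
}
features = feature_extraction(item)  # calls the __call__ shown above
print(features.data.shape)           # typically (n_frames, n_dimensions)
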
Example No. 2
    def __call__(self, item):

        path = self.get_path(self.root_dir, item)
        if not os.path.exists(path):
            uri = get_unique_identifier(item)
            msg = 'No precomputed features for "{uri}".'
            raise PyannoteFeatureExtractionError(msg.format(uri=uri))

        f = h5py.File(path, mode='r')
        data = np.array(f['array'])
        f.close()

        return SlidingWindowFeature(data, self.sliding_window_)
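
For context, a self-contained sketch of the on-disk layout this loader expects: each file's features stored as a single 'array' dataset in an HDF5 file (the path below is illustrative only):

import h5py
import numpy as np

# write features the way the loader above expects to find them
data = np.random.randn(500, 35)                       # (n_frames, n_dimensions)
with h5py.File('/tmp/demo_features.h5', mode='w') as f:
    f.create_dataset('array', data=data)

# read them back, as __call__ does
with h5py.File('/tmp/demo_features.h5', mode='r') as f:
    loaded = np.array(f['array'])

assert loaded.shape == (500, 35)
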
Example No. 3
    def files(self) -> Iterator[ProtocolFile]:
        """Iterate over all files in `protocol`"""

        # imported here to avoid circular imports
        from pyannote.database.util import get_unique_identifier

        yielded_uris = set()

        for method in [
            "development",
            "development_enrolment",
            "development_trial",
            "test",
            "test_enrolment",
            "test_trial",
            "train",
            "train_enrolment",
            "train_trial",
        ]:

            if not hasattr(self, method):
                continue

            def iterate():
                try:
                    for file in getattr(self, method)():
                        yield file
                except (AttributeError, NotImplementedError):
                    return

            for current_file in iterate():

                # skip "files" that do not contain a "uri" entry.
                # this happens for speaker verification trials that contain
                # two nested files "file1" and "file2"
                # see https://github.com/pyannote/pyannote-db-voxceleb/issues/4
                if "uri" not in current_file:
                    continue

                for current_file_ in current_file.files():

                    # corner case when the same file is yielded several times
                    uri = get_unique_identifier(current_file_)
                    if uri in yielded_uris:
                        continue

                    yield current_file_

                    yielded_uris.add(uri)
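
A hypothetical way to call this iterator; the protocol name is an assumption and requires the corresponding pyannote.database plugin to be installed:

from pyannote.database import get_protocol

# hypothetical protocol name; any registered protocol would do
protocol = get_protocol('AMI.SpeakerDiarization.MixHeadset')

for current_file in protocol.files():
    # each file is yielded exactly once, even if it appears in several subsets
    print(current_file['database'], current_file['uri'])
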
Example No. 4
    def __call__(self, item):

        path = Path(self.get_path(item))

        if not path.exists():
            uri = get_unique_identifier(item)
            msg = f'No precomputed features for "{uri}".'
            raise PyannoteFeatureExtractionError(msg)

        if self.use_memmap:
            data = np.load(str(path), mmap_mode='r')
        else:
            data = np.load(str(path))

        return SlidingWindowFeature(data, self.sliding_window_)
Example No. 5
    def __call__(self, item):

        path = Path(self.get_path(item))

        if not path.exists():
            uri = get_unique_identifier(item)
            msg = f'No precomputed features for "{uri}".'
            raise PyannoteFeatureExtractionError(msg)

        if self.use_memmap:
            data = np.load(str(path), mmap_mode='r')
        else:
            data = np.load(str(path))

        return SlidingWindowFeature(data, self.sliding_window_)
Example No. 6
    def _validation_set(self, protocol_name, subset='development'):
        # this generator is hacked to generate y_true
        # (which is stored in its internal preprocessed_ attribute)
        batch_generator = SpeechActivityDetectionBatchGenerator(
            self.feature_extraction_)
        batch_generator.cache_preprocessed_ = True

        # iterate over each test file and generate y_true
        protocol = get_protocol(protocol_name,
                                progress=False,
                                preprocessors=self.preprocessors_)
        file_generator = getattr(protocol, subset)()
        for current_file in file_generator:
            identifier = get_unique_identifier(current_file)
            batch_generator.preprocess(current_file, identifier=identifier)

        return batch_generator.preprocessed_['y']
Example No. 7
    def objective_function(parameters, beta=1.0):

        epoch, alpha = parameters

        weights_h5 = WEIGHTS_H5.format(epoch=epoch)
        sequence_embedding = SequenceEmbedding.from_disk(
            architecture_yml, weights_h5)

        segmentation = Segmentation(
            sequence_embedding, feature_extraction,
            duration=duration, step=0.100)

        if epoch not in predictions:
            predictions[epoch] = {}

        purity = SegmentationPurity()
        coverage = SegmentationCoverage()

        f, n = 0., 0
        for dev_file in getattr(protocol, subset)():

            uri = get_unique_identifier(dev_file)
            reference = dev_file['annotation']
            n += 1

            if uri in predictions[epoch]:
                prediction = predictions[epoch][uri]
            else:
                prediction = segmentation.apply(dev_file)
                predictions[epoch][uri] = prediction

            peak = Peak(alpha=alpha)
            hypothesis = peak.apply(prediction)

            p = purity(reference, hypothesis)
            c = coverage(reference, hypothesis)
            f += f_measure(c, p, beta=beta)

        return 1 - (f / n)
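
The objective takes an (epoch, alpha) pair and returns 1 minus the average F-measure, so lower is better. A minimal sketch of probing it with a grid search; the grid values are illustrative and the original presumably passes this function to a hyper-parameter optimizer instead:

import itertools

best_loss, best_params = None, None
for epoch, alpha in itertools.product([10, 20, 50], [0.1, 0.3, 0.5]):
    loss = objective_function((epoch, alpha), beta=1.0)
    if best_loss is None or loss < best_loss:
        best_loss, best_params = loss, (epoch, alpha)

print(best_params, 1 - best_loss)  # best (epoch, alpha) and its F-measure
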
Example No. 8
def get_path(root_dir, item):
    uri = get_unique_identifier(item)
    path = '{root_dir}/{uri}.htk'.format(root_dir=root_dir, uri=uri)
    return path
Example No. 9
    def from_files(self, file_generator, infinite=False,
                   robust=False, incomplete=False):
        """Generate batches by looping over a (possibly infinite) set of files

        Parameters
        ----------
        file_generator : iterable
            File generator yielding dictionaries at least containing the 'uri'
            key (uri = uniform resource identifier). Typically, one would use
            the 'train' method of a protocol available in pyannote.database.
        infinite : boolean, optional
            Loop over the file generator indefinitely, in random order.
            Default behavior is to exhaust the file generator once, then stop.
        robust : boolean, optional
            Set to True to skip files for which preprocessing fails.
            Default behavior is to raise an error.
        incomplete : boolean, optional
            Set to True to yield the final batch, even if it is incomplete
            (i.e. smaller than the requested batch size). Default behavior is
            not to yield an incomplete final batch. Has no effect when
            infinite is True.

        See also
        --------
        pyannote.database
        """

        # create new empty batch
        self.batch_ = self.init(self.signature)
        batch_size = 0

        if infinite:
            file_generator = forever(file_generator, shuffle=True)

        for current_file in file_generator:

            try:
                preprocessed_file = self.preprocess(current_file)
            except Exception as e:
                if robust:
                    uri = get_unique_identifier(current_file)
                    msg = 'Cannot preprocess file "{uri}".'
                    warnings.warn(msg.format(uri=uri))
                    continue
                else:
                    raise

            for fragment in self.generator.from_file(preprocessed_file):

                # add item to batch
                self.push(fragment, self.signature,
                          current_file=preprocessed_file)
                batch_size += 1

                # fixed batch size
                if self.batch_size > 0 and batch_size == self.batch_size:
                    batch = self.pack(self.signature)
                    yield self.postprocess(batch)
                    self.batch_ = self.init(self.signature)
                    batch_size = 0

            # mono-batch
            if self.batch_size < 1:
                batch = self.pack(self.signature)
                yield self.postprocess(batch)
                self.batch_ = self.init(self.signature)
                batch_size = 0

        # yield incomplete final batch
        if batch_size > 0 and batch_size < self.batch_size and incomplete:
            batch = self.pack(self.signature)
            yield self.postprocess(batch)
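
A hypothetical call pattern following the docstring above; batch_generator, protocol, and the training step are placeholders, not part of the original:

# batch_generator is an instance of the class defining from_files();
# protocol comes from pyannote.database (e.g. get_protocol(...)).
for batch in batch_generator.from_files(protocol.train(),
                                        infinite=True,
                                        robust=True):
    model.train_on_batch(batch)  # placeholder training step
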
Example No. 10
def embed(protocol, tune_dir, apply_dir, subset='test', step=None,
          internal=None, aggregate=False):

    mkdir_p(apply_dir)

    train_dir = os.path.dirname(os.path.dirname(tune_dir))

    duration, _, _, heterogeneous = \
        path_to_duration(os.path.basename(train_dir))

    config_dir = os.path.dirname(os.path.dirname(os.path.dirname(train_dir)))
    config_yml = config_dir + '/config.yml'
    with open(config_yml, 'r') as fp:
        config = yaml.load(fp, Loader=yaml.SafeLoader)

    # -- FEATURE EXTRACTION --
    feature_extraction_name = config['feature_extraction']['name']
    features = __import__('pyannote.audio.features',
                          fromlist=[feature_extraction_name])
    FeatureExtraction = getattr(features, feature_extraction_name)
    feature_extraction = FeatureExtraction(
        **config['feature_extraction'].get('params', {}))

    # -- HYPER-PARAMETERS --
    tune_yml = tune_dir + '/tune.yml'
    with open(tune_yml, 'r') as fp:
        tune = yaml.load(fp, Loader=yaml.SafeLoader)

    architecture_yml = train_dir + '/architecture.yml'
    WEIGHTS_H5 = train_dir + '/weights/{epoch:04d}.h5'
    weights_h5 = WEIGHTS_H5.format(epoch=tune['epoch'])

    sequence_embedding = SequenceEmbedding.from_disk(
        architecture_yml, weights_h5)

    extraction = Extraction(sequence_embedding, feature_extraction,
                            duration=duration, step=step,
                            internal=internal, aggregate=aggregate)

    dimension = extraction.dimension
    sliding_window = extraction.sliding_window

    # create metadata file at root that contains
    # sliding window and dimension information
    path = Precomputed.get_config_path(apply_dir)
    f = h5py.File(path, mode='a')
    f.attrs['start'] = sliding_window.start
    f.attrs['duration'] = sliding_window.duration
    f.attrs['step'] = sliding_window.step
    f.attrs['dimension'] = dimension
    f.close()

    for item in getattr(protocol, subset)():

        uri = get_unique_identifier(item)
        path = Precomputed.get_path(apply_dir, item)

        extracted = extraction.apply(item)

        # create parent directory
        mkdir_p(os.path.dirname(path))

        f = h5py.File(path, mode='a')
        f.attrs['start'] = sliding_window.start
        f.attrs['duration'] = sliding_window.duration
        f.attrs['step'] = sliding_window.step
        f.attrs['dimension'] = dimension
        f.create_dataset('features', data=extracted.data)
        f.close()
Example No. 11
def get_path(root_dir, item):
    uri = get_unique_identifier(item)
    file_name = uri.split('.')[0].split('/')[1]
    path = '{root_dir}/{file_name}/audio/{file_name}.Mix-Headset.hdf5'.format(
        root_dir=root_dir, file_name=file_name)
    return path
Example No. 12
    def get_path(self, item):
        uri = get_unique_identifier(item)
        path = "{root_dir}/{uri}.npy".format(root_dir=self.root_dir, uri=uri)
        return path
Example No. 13
def get_path(root_dir, item):
    uri = get_unique_identifier(item)
    path = '{root_dir}/{uri}.h5'.format(root_dir=root_dir, uri=uri)
    return path
Example No. 14
    def from_files(self,
                   file_generator,
                   infinite=False,
                   robust=False,
                   incomplete=False):
        """Generate batches by looping over a (possibly infinite) set of files

        Parameters
        ----------
        file_generator : iterable
            File generator yielding dictionaries at least containing the 'uri'
            key (uri = uniform resource identifier). Typically, one would use
            the 'train' method of a protocol available in pyannote.database.
        infinite : boolean, optional
            Loop over the file generator indefinitely, in random order.
            Default behavior is to exhaust the file generator once, then stop.
        robust : boolean, optional
            Set to True to skip files for which preprocessing fails.
            Default behavior is to raise an error.
        incomplete : boolean, optional
            Set to True to yield the final batch, even if it is incomplete
            (i.e. smaller than the requested batch size). Default behavior is
            not to yield an incomplete final batch. Has no effect when
            infinite is True.

        See also
        --------
        pyannote.database
        """

        # create new empty batch
        self.batch_ = self.init(self.signature)
        batch_size = 0

        if infinite:
            file_generator = forever(file_generator, shuffle=True)

        for current_file in file_generator:

            try:
                preprocessed_file = self.preprocess(current_file)
            except Exception as e:
                if robust:
                    uri = get_unique_identifier(current_file)
                    msg = 'Cannot preprocess file "{uri}".'
                    warnings.warn(msg.format(uri=uri))
                    continue
                else:
                    raise

            for fragment in self.generator.from_file(preprocessed_file):

                # add item to batch
                self.push(fragment,
                          self.signature,
                          current_file=preprocessed_file)
                batch_size += 1

                # fixed batch size
                if self.batch_size > 0 and batch_size == self.batch_size:
                    batch = self.pack(self.signature)
                    yield self.postprocess(batch)
                    self.batch_ = self.init(self.signature)
                    batch_size = 0

            # mono-batch
            if self.batch_size < 1:
                batch = self.pack(self.signature)
                yield self.postprocess(batch)
                self.batch_ = self.init(self.signature)
                batch_size = 0

        # yield incomplete final batch
        if batch_size > 0 and batch_size < self.batch_size and incomplete:
            batch = self.pack(self.signature)
            yield self.postprocess(batch)
Example No. 15
    def __call__(self, item):
        """Extract features

        Parameters
        ----------
        item : dict

        Returns
        -------
        features : SlidingWindowFeature

        """

        # --- load audio file
        y, sample_rate = read_audio(item,
                                    sample_rate=self.sample_rate,
                                    mono=True)

        # --- update data_flow every time sample rate changes
        if not hasattr(self, 'sample_rate_') or self.sample_rate_ != sample_rate:
            self.sample_rate_ = sample_rate
            feature_plan = yaafelib.FeaturePlan(sample_rate=self.sample_rate_)
            for name, recipe in self.definition():
                assert feature_plan.addFeature(
                    "{name}: {recipe}".format(name=name, recipe=recipe))
            data_flow = feature_plan.getDataFlow()
            self.engine_.load(data_flow)

        # Yaafe needs this: float64, column-contiguous, 2-dimensional
        y = np.array(y, dtype=np.float64, order='C').reshape((1, -1))

        # --- extract features
        features = self.engine_.processAudio(y)
        data = np.hstack([features[name] for name, _ in self.definition()])

        # --- stack features
        n_samples, n_features = data.shape
        zero_padding = self.stack // 2
        if self.stack % 2 == 0:
            expanded_data = np.concatenate(
                (np.zeros((zero_padding, n_features)) + data[0],
                 data,
                 np.zeros((zero_padding - 1, n_features)) + data[-1]))
        else:
            expanded_data = np.concatenate((
                np.zeros((zero_padding, n_features)) + data[0],
                data,
                np.zeros((zero_padding, n_features)) + data[-1]))

        data = np.lib.stride_tricks.as_strided(
            expanded_data,
            shape=(n_samples, n_features * self.stack),
            strides=data.strides)

        self.engine_.reset()

        # --- return as SlidingWindowFeature
        if np.any(np.isnan(data)):
            uri = get_unique_identifier(item)
            msg = 'Features extracted from "{uri}" contain NaNs.'
            warnings.warn(msg.format(uri=uri))

        return SlidingWindowFeature(data, self.sliding_window_)
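
The stacking step above concatenates each frame with its neighbours through a zero-copy as_strided view. A self-contained toy demonstration of that trick, with illustrative array sizes:

import numpy as np

n_samples, n_features, stack = 5, 2, 3
data = np.arange(n_samples * n_features, dtype=np.float64).reshape(n_samples, n_features)

# replicate the first/last frames so every frame has enough neighbours
pad = stack // 2
expanded = np.concatenate((np.repeat(data[:1], pad, axis=0),
                           data,
                           np.repeat(data[-1:], stack - 1 - pad, axis=0)))

# each output row is `stack` consecutive frames, flattened, without copying
stacked = np.lib.stride_tricks.as_strided(
    expanded,
    shape=(n_samples, n_features * stack),
    strides=data.strides)

print(stacked.shape)  # (5, 6)
print(stacked[0])     # frames 0, 0, 1 concatenated (first frame replicated)
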
Example No. 16
    def __call__(self, item):
        """Extract features

        Parameters
        ----------
        item : dict

        Returns
        -------
        features : SlidingWindowFeature

        """

        # --- load audio file
        y, sample_rate = read_audio(item,
                                    sample_rate=self.sample_rate,
                                    mono=True)

        # --- update data_flow every time sample rate changes
        if not hasattr(self,
                       'sample_rate_') or self.sample_rate_ != sample_rate:
            self.sample_rate_ = sample_rate
            feature_plan = yaafelib.FeaturePlan(sample_rate=self.sample_rate_)
            for name, recipe in self.definition():
                assert feature_plan.addFeature("{name}: {recipe}".format(
                    name=name, recipe=recipe))
            data_flow = feature_plan.getDataFlow()
            self.engine_.load(data_flow)

        # Yaafe needs this: float64, column-contiguous, 2-dimensional
        y = np.array(y, dtype=np.float64, order='C').reshape((1, -1))

        # --- extract features
        features = self.engine_.processAudio(y)
        data = np.hstack([features[name] for name, _ in self.definition()])

        # --- stack features
        n_samples, n_features = data.shape
        zero_padding = self.stack // 2
        if self.stack % 2 == 0:
            expanded_data = np.concatenate(
                (np.zeros((zero_padding, n_features)) + data[0], data,
                 np.zeros((zero_padding - 1, n_features)) + data[-1]))
        else:
            expanded_data = np.concatenate(
                (np.zeros((zero_padding, n_features)) + data[0], data,
                 np.zeros((zero_padding, n_features)) + data[-1]))

        data = np.lib.stride_tricks.as_strided(expanded_data,
                                               shape=(n_samples,
                                                      n_features * self.stack),
                                               strides=data.strides)

        self.engine_.reset()

        # --- return as SlidingWindowFeature
        if np.any(np.isnan(data)):
            uri = get_unique_identifier(item)
            msg = 'Features extracted from "{uri}" contain NaNs.'
            warnings.warn(msg.format(uri=uri))

        return SlidingWindowFeature(data, self.sliding_window_)
Example No. 17
def extract(database_name,
            task_name,
            protocol_name,
            preprocessors,
            experiment_dir,
            robust=False):

    database = get_database(database_name, preprocessors=preprocessors)
    protocol = database.get_protocol(task_name, protocol_name, progress=True)

    if task_name == 'SpeakerDiarization':
        items = itertools.chain(protocol.train(), protocol.development(),
                                protocol.test())

    elif task_name == 'SpeakerRecognition':
        items = itertools.chain(protocol.train(yield_name=False),
                                protocol.development_enroll(yield_name=False),
                                protocol.development_test(yield_name=False),
                                protocol.test_enroll(yield_name=False),
                                protocol.test_test(yield_name=False))

    # load configuration file
    config_yml = experiment_dir + '/config.yml'
    with open(config_yml, 'r') as fp:
        config = yaml.load(fp, Loader=yaml.SafeLoader)

    feature_extraction_name = config['feature_extraction']['name']
    features = __import__('pyannote.audio.features',
                          fromlist=[feature_extraction_name])
    FeatureExtraction = getattr(features, feature_extraction_name)
    feature_extraction = FeatureExtraction(
        **config['feature_extraction'].get('params', {}))

    sliding_window = feature_extraction.sliding_window()
    dimension = feature_extraction.dimension()

    # create metadata file at root that contains
    # sliding window and dimension information
    path = Precomputed.get_config_path(experiment_dir)
    f = h5py.File(path, mode='a')
    f.attrs['start'] = sliding_window.start
    f.attrs['duration'] = sliding_window.duration
    f.attrs['step'] = sliding_window.step
    f.attrs['dimension'] = dimension
    f.close()

    for item in items:

        uri = get_unique_identifier(item)
        path = Precomputed.get_path(experiment_dir, item)

        if os.path.exists(path):
            continue

        try:
            # NOTE item contains the 'channel' key
            features = feature_extraction(item)
        except PyannoteFeatureExtractionError as e:
            if robust:
                msg = 'Feature extraction failed for file "{uri}".'
                msg = msg.format(uri=uri)
                warnings.warn(msg)
                continue
            else:
                raise

        if features is None:
            msg = 'Feature extraction returned None for file "{uri}".'
            msg = msg.format(uri=uri)
            if not robust:
                raise PyannoteFeatureExtractionError(msg)
            warnings.warn(msg)
            continue

        data = features.data

        if np.any(np.isnan(data)):
            msg = 'Feature extraction returned NaNs for file "{uri}".'
            msg = msg.format(uri=uri)
            if not robust:
                raise PyannoteFeatureExtractionError(msg)
            warnings.warn(msg)
            continue

        # create parent directory
        mkdir_p(os.path.dirname(path))

        f = h5py.File(path, mode='a')
        f.attrs['start'] = sliding_window.start
        f.attrs['duration'] = sliding_window.duration
        f.attrs['step'] = sliding_window.step
        f.attrs['dimension'] = dimension
        f.create_dataset('features', data=data)
        f.close()
Example No. 18
    def validate(self, protocol_name, subset='development'):

        # prepare paths
        validate_dir = self.VALIDATE_DIR.format(train_dir=self.train_dir_,
                                                protocol=protocol_name)
        validate_txt = self.VALIDATE_TXT.format(validate_dir=validate_dir,
                                                subset=subset)
        validate_png = self.VALIDATE_PNG.format(validate_dir=validate_dir,
                                                subset=subset)
        validate_eps = self.VALIDATE_EPS.format(validate_dir=validate_dir,
                                                subset=subset)

        # create validation directory
        mkdir_p(validate_dir)

        # Build validation set
        y = self._validation_set(protocol_name, subset=subset)

        # list of equal error rates, and current epoch
        eers, epoch = [], 0

        desc_format = ('EER = {eer:.2f}% @ epoch #{epoch:d} ::'
                       ' Best EER = {best_eer:.2f}% @ epoch #{best_epoch:d} :')
        progress_bar = tqdm(unit='epoch', total=1000)

        with open(validate_txt, mode='w') as fp:

            # watch and evaluate forever
            while True:

                weights_h5 = LoggingCallback.WEIGHTS_H5.format(
                    log_dir=self.train_dir_, epoch=epoch)

                # wait until weight file is available
                if not isfile(weights_h5):
                    time.sleep(60)
                    continue

                # load model for current epoch
                sequence_labeling = SequenceLabeling.from_disk(
                    self.train_dir_, epoch)

                # initialize sequence labeling
                duration = self.config_['sequences']['duration']
                step = duration  # hack to make things faster
                # step = self.config_['sequences']['step']
                aggregation = SequenceLabelingAggregation(
                    sequence_labeling,
                    self.feature_extraction_,
                    duration=duration,
                    step=step)
                aggregation.cache_preprocessed_ = False

                # estimate equal error rate (average of all files)
                eers_ = []
                protocol = get_protocol(protocol_name,
                                        progress=False,
                                        preprocessors=self.preprocessors_)
                file_generator = getattr(protocol, subset)()
                for current_file in file_generator:
                    identifier = get_unique_identifier(current_file)
                    uem = get_annotated(current_file)
                    y_true = y[identifier].crop(uem)[:, 1]
                    counts = Counter(y_true)
                    if counts[0] * counts[1] == 0:
                        continue
                    y_pred = aggregation.apply(current_file).crop(uem)[:, 1]

                    _, _, _, eer = det_curve(y_true, y_pred, distances=False)

                    eers_.append(eer)
                eer = np.mean(eers_)
                eers.append(eer)

                # save equal error rate to file
                fp.write(
                    self.VALIDATE_TXT_TEMPLATE.format(epoch=epoch, eer=eer))
                fp.flush()

                # keep track of best epoch so far
                best_epoch, best_eer = np.argmin(eers), np.min(eers)

                progress_bar.set_description(
                    desc_format.format(epoch=epoch,
                                       eer=100 * eer,
                                       best_epoch=best_epoch,
                                       best_eer=100 * best_eer))
                progress_bar.update(1)

                # plot
                fig = plt.figure()
                plt.plot(eers, 'b')
                plt.plot([best_epoch], [best_eer], 'bo')
                plt.plot([0, epoch], [best_eer, best_eer], 'k--')
                plt.grid(True)
                plt.xlabel('epoch')
                plt.ylabel('EER on {subset}'.format(subset=subset))
                TITLE = '{best_eer:.5g} @ epoch #{best_epoch:d}'
                title = TITLE.format(best_eer=best_eer,
                                     best_epoch=best_epoch,
                                     subset=subset)
                plt.title(title)
                plt.tight_layout()
                plt.savefig(validate_png, dpi=75)
                plt.savefig(validate_eps)
                plt.close(fig)

                # validate next epoch
                epoch += 1

        progress_bar.close()
Example No. 19
def test(protocol, tune_dir, apply_dir, subset='test', beta=1.0):

    os.makedirs(apply_dir)

    train_dir = os.path.dirname(os.path.dirname(os.path.dirname(tune_dir)))

    duration = float(os.path.basename(train_dir))
    config_dir = os.path.dirname(os.path.dirname(os.path.dirname(train_dir)))
    config_yml = config_dir + '/config.yml'
    with open(config_yml, 'r') as fp:
        config = yaml.load(fp, Loader=yaml.SafeLoader)

    # -- FEATURE EXTRACTION --
    feature_extraction_name = config['feature_extraction']['name']
    features = __import__('pyannote.audio.features',
                          fromlist=[feature_extraction_name])
    FeatureExtraction = getattr(features, feature_extraction_name)
    feature_extraction = FeatureExtraction(
        **config['feature_extraction'].get('params', {}))

    # -- HYPER-PARAMETERS --
    tune_yml = tune_dir + '/tune.yml'
    with open(tune_yml, 'r') as fp:
        tune = yaml.load(fp, Loader=yaml.SafeLoader)

    architecture_yml = train_dir + '/architecture.yml'
    WEIGHTS_H5 = train_dir + '/weights/{epoch:04d}.h5'
    weights_h5 = WEIGHTS_H5.format(epoch=tune['epoch'])

    sequence_embedding = SequenceEmbedding.from_disk(
        architecture_yml, weights_h5)

    segmentation = Segmentation(
        sequence_embedding, feature_extraction,
        duration=duration, step=0.100)

    peak = Peak(alpha=tune['alpha'])

    HARD_JSON = apply_dir + '/{uri}.hard.json'
    SOFT_PKL = apply_dir + '/{uri}.soft.pkl'

    eval_txt = apply_dir + '/eval.txt'
    TEMPLATE = '{uri} {purity:.5f} {coverage:.5f} {f_measure:.5f}\n'
    purity = SegmentationPurity()
    coverage = SegmentationCoverage()
    fscore = []

    for test_file in getattr(protocol, subset)():

        soft = segmentation.apply(test_file)
        hard = peak.apply(soft)

        uri = get_unique_identifier(test_file)

        path = SOFT_PKL.format(uri=uri)
        mkdir_p(os.path.dirname(path))
        with open(path, 'wb') as fp:
            pickle.dump(soft, fp)

        path = HARD_JSON.format(uri=uri)
        mkdir_p(os.path.dirname(path))
        with open(path, 'w') as fp:
            pyannote.core.json.dump(hard, fp)

        try:
            reference = test_file['annotation']
            uem = test_file['annotated']
        except KeyError:
            continue

        p = purity(reference, hard)
        c = coverage(reference, hard)
        f = f_measure(c, p, beta=beta)
        fscore.append(f)

        line = TEMPLATE.format(
            uri=uri, purity=p, coverage=c, f_measure=f)
        with open(eval_txt, 'a') as fp:
            fp.write(line)

    p = abs(purity)
    c = abs(coverage)
    f = np.mean(fscore)
    line = TEMPLATE.format(
        uri='ALL', purity=p, coverage=c, f_measure=f)
    with open(eval_txt, 'a') as fp:
        fp.write(line)
Example No. 20
    def get_path(self, item):
        uri = get_unique_identifier(item)
        path = '{root_dir}/{uri}.npy'.format(root_dir=self.root_dir, uri=uri)
        return path