Example no. 1
    def __init__(
        self,
        feature_extraction: Wrappable,
        protocol: Protocol,
        subset: Subset = "train",
        duration: float = 1.0,
        min_duration: Optional[float] = None,
        per_turn: int = 1,
        per_label: int = 3,
        per_fold: Optional[int] = None,
        per_epoch: Optional[float] = None,
        label_min_duration: float = 0.0,
    ):

        self.feature_extraction = Wrapper(feature_extraction)
        self.per_turn = per_turn
        self.per_label = per_label
        self.per_fold = per_fold
        self.duration = duration
        self.min_duration = duration if min_duration is None else min_duration
        self.label_min_duration = label_min_duration
        self.weighted_ = True

        total_duration = self._load_metadata(protocol, subset=subset)
        if per_epoch is None:
            # default to one full pass over the training set, expressed in days
            per_epoch = total_duration / (24 * 60 * 60)
        self.per_epoch = per_epoch
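
# A worked sketch of how the parameters above combine (the values below are
# assumptions for illustration, not the defaults): each batch contains
# per_turn x per_label x per_fold chunks, and per_epoch days of audio are
# seen per epoch.
per_turn, per_label, per_fold = 1, 3, 8
batch_size = per_turn * per_label * per_fold       # 24 chunks per batch
seconds_per_epoch = 1.0 * 24 * 60 * 60             # per_epoch = 1.0 day
duration = 1.0                                     # seconds per chunk
batches_per_epoch = seconds_per_epoch / (duration * batch_size)  # 3600.0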
Example no. 2
    def __init__(
        self,
        task: Task,
        feature_extraction: Wrappable,
        protocol: Protocol,
        subset: Subset = "train",
        resolution: Optional[Resolution] = None,
        alignment: Optional[Alignment] = None,
        duration: float = 2.0,
        batch_size: int = 32,
        per_epoch: Optional[float] = None,
        exhaustive: bool = False,
        step: float = 0.1,
        mask: Optional[Text] = None,
        local_labels: bool = False,
    ):

        self.task = task
        self.feature_extraction = Wrapper(feature_extraction)
        self.duration = duration
        self.exhaustive = exhaustive
        self.step = step
        self.mask = mask
        self.local_labels = local_labels

        self.resolution_ = resolution

        if alignment is None:
            alignment = "center"
        self.alignment = alignment

        self.batch_size = batch_size

        # load metadata and estimate total duration of training data
        total_duration = self._load_metadata(protocol, subset=subset)

        if per_epoch is None:

            # 1 epoch = covering the whole training set once
            per_epoch = total_duration / SECONDS_IN_A_DAY

            # when exhaustive is False, this is not completely correct:
            # in practice, audio chunks are sampled at random until their
            # overall duration reaches the duration of the training set,
            # but nothing guarantees that every single part of the training
            # set has been seen exactly once: it might be more than once,
            # it might be less than once. on average, however, after a
            # certain number of epochs, this becomes correct.

            # when exhaustive is True, however, we can actually make sure
            # every single part of the training set has been seen. we just
            # have to make sure we account for the step used by the
            # exhaustive sliding window.
            if self.exhaustive:
                per_epoch *= np.ceil(1 / self.step)

        self.per_epoch = per_epoch
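
# A worked example of the scaling above (step value assumed): with step=0.1,
# consecutive windows overlap by 90%, so every position in the training set
# is covered by about ceil(1 / 0.1) = 10 windows; per_epoch is multiplied
# accordingly to preserve the "1 epoch = 1 full pass" convention.
import numpy as np
step = 0.1
scale = np.ceil(1 / step)  # 10.0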
Example no. 3
    def __init__(self, embedding: Wrappable = None,
                       metric: Optional[str] = 'cosine',
                       method: Optional[str] = 'pool',
                       window_wise: Optional[bool] = False):
        super().__init__()

        if embedding is None:
            embedding = "@emb"
        self.embedding = embedding
        self._embedding = Wrapper(self.embedding)

        self.metric = metric
        self.method = method

        if self.method == 'affinity_propagation':
            self.clustering = AffinityPropagationClustering(
                metric=self.metric)

            # sklearn documentation: Preferences for each point - points with
            # larger values of preferences are more likely to be chosen as
            # exemplars. The number of exemplars, ie of clusters, is influenced by
            # the input preferences value. If the preferences are not passed as
            # arguments, they will be set to the median of the input similarities.

            # NOTE one could set the preference value of each speech turn
            # according to its duration: longer speech turns are expected to
            # have more accurate embeddings and should therefore be preferred
            # as exemplars.

        else:
            self.clustering = HierarchicalAgglomerativeClustering(
                method=self.method, metric=self.metric, use_threshold=True)

        self.window_wise = window_wise
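
# A minimal sketch of the NOTE above (assumptions: `embeddings` is an
# (n_turns, dim) array and `durations` an (n_turns,) array of speech turn
# durations; neither exists in the class above, and this scaling is only one
# of many possible choices):
import numpy as np
from sklearn.cluster import AffinityPropagation

def cluster_with_duration_preference(embeddings, durations):
    # cosine similarity, used as a precomputed affinity
    normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    similarity = normed @ normed.T
    # longer speech turns get a larger preference and are therefore more
    # likely to be chosen as exemplars
    preference = np.median(similarity) * durations / durations.max()
    clustering = AffinityPropagation(affinity="precomputed",
                                     preference=preference)
    return clustering.fit_predict(similarity)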
Example no. 4
    def __init__(self, batch_size=16, threshold=0.15, model="ami", device="cuda", progress_hook=None):
        if batch_size is None:
            batch_size = 16
        
        if threshold is None:
            threshold = 0.15

        self.batch_size = batch_size
        self.threshold = threshold
        self.model = model
        self.device = device
        self.progress_hook = progress_hook

        SAD_AMI_PATH = 'sad_ami/train/ami.train/weights/0140.pt'
        SAD_DIHARD_PATH = 'sad_dihard/train/dihard.train/weights/0231.pt'

        # fall back to the AMI model for unknown values
        model = {"ami": SAD_AMI_PATH, "dihard": SAD_DIHARD_PATH}.get(
            self.model, SAD_AMI_PATH)

        self.pipeline = Wrapper(Path(__file__).resolve().parent / model,
                                batch_size=self.batch_size,
                                device=self.device,
                                progress_hook=self.progress_hook)
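
# A hypothetical usage sketch (the owning class name `SpeechActivityDetector`
# is an assumption; the listing above only shows its __init__):
#
#     sad = SpeechActivityDetector(model="dihard", device="cpu")
#     scores = sad.pipeline({"audio": "meeting.wav"})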
Example no. 5
    def __init__(self, embedding: Wrappable = None, metric: Optional[str] = "cosine"):
        super().__init__()

        if embedding is None:
            embedding = "@emb"
        self.embedding = embedding
        self._embedding = Wrapper(self.embedding)

        self.metric = metric

        self.closest_assignment = ClosestAssignment(metric=self.metric)
Example no. 6
    def __init__(
        self,
        sad: Union[Text, Path, dict] = {"sad": {
            "duration": 2.0,
            "step": 0.1
        }},
        emb: Union[Text, Path] = "emb",
        batch_size: Optional[int] = None,
        only_sad: bool = False,
    ):

        super().__init__()

        self.sad = Wrapper(sad)
        if batch_size is not None:
            self.sad.batch_size = batch_size
        self.sad_speech_index_ = self.sad.classes.index("speech")

        self.sad_threshold_on = Uniform(0.0, 1.0)
        self.sad_threshold_off = Uniform(0.0, 1.0)
        self.sad_min_duration_on = Uniform(0.0, 0.5)
        self.sad_min_duration_off = Uniform(0.0, 0.5)

        self.only_sad = only_sad
        if self.only_sad:
            return

        self.emb = Wrapper(emb)
        if batch_size is not None:
            self.emb.batch_size = batch_size

        max_duration = self.emb.duration
        min_duration = getattr(self.emb, "min_duration", 0.25 * max_duration)
        self.emb_duration = Uniform(min_duration, max_duration)
        self.emb_step_ratio = Uniform(0.1, 1.0)
        self.emb_threshold = Uniform(0.0, 2.0)
Example no. 7
    def __init__(self, scores: Wrappable = None, fscore: bool = False):
        super().__init__()

        if scores is None:
            scores = "@sad_scores"
        self.scores = scores
        self._scores = Wrapper(self.scores)

        self.fscore = fscore

        # hyper-parameters
        self.onset = Uniform(0., 1.)
        self.offset = Uniform(0., 1.)
        self.min_duration_on = Uniform(0., 2.)
        self.min_duration_off = Uniform(0., 2.)
        self.pad_onset = Uniform(-1., 1.)
        self.pad_offset = Uniform(-1., 1.)
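
# A minimal sketch (parameter values are assumptions, not tuned ones) of how
# the Uniform hyper-parameters above get frozen into concrete values through
# pyannote.pipeline's instantiate(), assuming the surrounding class is a
# speech activity detection pipeline:
#
#     pipeline = SpeechActivityDetection(scores="@sad_scores")
#     pipeline.instantiate({
#         "onset": 0.7, "offset": 0.7,
#         "min_duration_on": 0.1, "min_duration_off": 0.1,
#         "pad_onset": 0.0, "pad_offset": 0.0,
#     })
#     speech_regions = pipeline(current_file)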
Example no. 8
    def __init__(self, scores: Wrappable = None,
                       purity: Optional[float] = 0.95,
                       fscore: bool = False,
                       diarization: bool = False):
        super().__init__()

        if scores is None:
            scores = "@scd_scores"
        self.scores = scores
        self._scores = Wrapper(self.scores)

        self.purity = purity
        self.fscore = fscore
        self.diarization = diarization

        # hyper-parameters
        self.alpha = Uniform(0., 1.)
        self.min_duration = Uniform(0., 10.)
Example no. 9
    def __init__(self,
                 scores: Wrappable = None,
                 precision: float = 0.9,
                 fscore: bool = False):
        super().__init__()

        if scores is None:
            scores = "@ovl_scores"
        self.scores = scores
        self._scores = Wrapper(self.scores)

        self.precision = precision
        self.fscore = fscore

        # hyper-parameters
        self.onset = Uniform(0.0, 1.0)
        self.offset = Uniform(0.0, 1.0)
        self.min_duration_on = Uniform(0.0, 2.0)
        self.min_duration_off = Uniform(0.0, 2.0)
        self.pad_onset = Uniform(-1.0, 1.0)
        self.pad_offset = Uniform(-1.0, 1.0)
Example no. 10
class LabelingTaskGenerator(BatchGenerator):
    """Base batch generator for various labeling tasks

    This class should be inherited from; it should not be used directly.

    Parameters
    ----------
    task : Task
        Task
    feature_extraction : Wrappable
        Describes how features should be obtained.
        See pyannote.audio.features.wrapper.Wrapper documentation for details.
    protocol : Protocol
    subset : {'train', 'development', 'test'}, optional
        Protocol and subset.
    resolution : `pyannote.core.SlidingWindow`, optional
        Override `feature_extraction.sliding_window`. This is useful for
        models that include the feature extraction step (e.g. SincNet) and
        therefore output a lower sample rate than that of the input.
        Defaults to `feature_extraction.sliding_window`
    alignment : {'center', 'loose', 'strict'}, optional
        Which mode to use when cropping labels. This is useful for models that
        include the feature extraction step (e.g. SincNet) and therefore use a
        different cropping mode. Defaults to 'center'.
    duration : float, optional
        Duration of audio chunks. Defaults to 2s.
    batch_size : int, optional
        Batch size. Defaults to 32.
    per_epoch : float, optional
        Force total audio duration per epoch, in days.
        Defaults to total duration of protocol subset.
    exhaustive : bool, optional
        Ensure training files are covered exhaustively (useful in case of
        non-uniform label distribution).
    step : `float`, optional
        Ratio of audio chunk duration used as step between two consecutive
        audio chunks. Defaults to 0.1. Has no effect when exhaustive is False.
    mask : str, optional
        When provided, protocol files are expected to contain a key named after
        this `mask` variable and providing a `SlidingWindowFeature` instance.
        Generated batches will contain an additional "mask" key (on top of
        existing "X" and "y" keys) computed as an excerpt of `current_file[mask]`
        time-aligned with "y". Defaults to not add any "mask" key.
    local_labels : bool, optional
        Set to True to yield samples with local (file-level) labels.
        Defaults to use global (protocol-level) labels.
    """

    def __init__(
        self,
        task: Task,
        feature_extraction: Wrappable,
        protocol: Protocol,
        subset: Subset = "train",
        resolution: Optional[Resolution] = None,
        alignment: Optional[Alignment] = None,
        duration: float = 2.0,
        batch_size: int = 32,
        per_epoch: Optional[float] = None,
        exhaustive: bool = False,
        step: float = 0.1,
        mask: Optional[Text] = None,
        local_labels: bool = False,
    ):

        self.task = task
        self.feature_extraction = Wrapper(feature_extraction)
        self.duration = duration
        self.exhaustive = exhaustive
        self.step = step
        self.mask = mask
        self.local_labels = local_labels

        self.resolution_ = resolution

        if alignment is None:
            alignment = "center"
        self.alignment = alignment

        self.batch_size = batch_size

        # load metadata and estimate total duration of training data
        total_duration = self._load_metadata(protocol, subset=subset)

        if per_epoch is None:

            # 1 epoch = covering the whole training set once
            per_epoch = total_duration / SECONDS_IN_A_DAY

            # when exhaustive is False, this is not completely correct:
            # in practice, audio chunks are sampled at random until their
            # overall duration reaches the duration of the training set,
            # but nothing guarantees that every single part of the training
            # set has been seen exactly once: it might be more than once,
            # it might be less than once. on average, however, after a
            # certain number of epochs, this becomes correct.

            # when exhaustive is True, however, we can actually make sure
            # every single part of the training set has been seen. we just
            # have to make sure we account for the step used by the
            # exhaustive sliding window.
            if self.exhaustive:
                per_epoch *= np.ceil(1 / self.step)

        self.per_epoch = per_epoch

    # TODO. use cached property (Python 3.8 only)
    # https://docs.python.org/fr/3/library/functools.html#functools.cached_property
    @property
    def resolution(self):

        if self.resolution_ in [None, RESOLUTION_FRAME]:
            return self.feature_extraction.sliding_window

        if self.resolution_ == RESOLUTION_CHUNK:
            return SlidingWindow(
                duration=self.duration, step=self.step * self.duration
            )

        return self.resolution_

    def postprocess_y(self, Y: np.ndarray) -> np.ndarray:
        """This function does nothing but return its input.
        It should be overridden by subclasses.

        Parameters
        ----------
        Y : (n_samples, n_speakers) numpy.ndarray

        Returns
        -------
        postprocessed : (n_samples, n_speakers) numpy.ndarray
            `Y`, unchanged.
        """
        return Y
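
    # A minimal sketch (an assumption, not part of this base class) of how a
    # subclass could override postprocess_y, e.g. to collapse per-speaker
    # labels into a single speech/non-speech label:
    #
    #     def postprocess_y(self, Y: np.ndarray) -> np.ndarray:
    #         # speech = at least one active speaker
    #         return np.int64(np.sum(Y, axis=1, keepdims=True) > 0)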

    def initialize_y(self, current_file):
        """Precompute y for the whole file

        Parameters
        ----------
        current_file : `dict`
            File as provided by a pyannote.database protocol.

        Returns
        -------
        y : `SlidingWindowFeature`
            Precomputed y for the whole file
        """

        if self.local_labels:
            labels = current_file["annotation"].labels()
        else:
            labels = self.segment_labels_

        y = one_hot_encoding(
            current_file["annotation"],
            get_annotated(current_file),
            self.resolution,
            labels=labels,
            mode="center",
        )

        y.data = self.postprocess_y(y.data)
        return y

    def crop_y(self, y, segment):
        """Extract y for specified segment

        Parameters
        ----------
        y : `pyannote.core.SlidingWindowFeature`
            Output of `initialize_y` above.
        segment : `pyannote.core.Segment`
            Segment for which to obtain y.

        Returns
        -------
        cropped_y : (n_samples, dim) `np.ndarray`
            y for specified `segment`
        """

        return y.crop(segment, mode=self.alignment, fixed=self.duration)

    def _load_metadata(self, protocol, subset: Subset = "train") -> float:
        """Load training set metadata

        This function is called once at instantiation time, returns the total
        training set duration, and populates the following attributes:

        Attributes
        ----------
        data_ : dict

            {'segments': <list of annotated segments>,
             'duration': <total duration of annotated segments>,
             'current_file': <protocol dictionary>,
             'y': <labels as numpy array>}

        segment_labels_ : list
            Sorted list of (unique) labels in protocol.

        file_labels_ : dict of list
            Sorted lists of (unique) file labels in protocol

        Returns
        -------
        duration : float
            Total duration of annotated segments, in seconds.
        """

        self.data_ = {}
        segment_labels, file_labels = set(), dict()

        # loop once on all files
        files = getattr(protocol, subset)()
        for current_file in tqdm(files, desc="Loading labels", unit="file"):

            # ensure annotation/annotated are cropped to actual file duration
            support = Segment(start=0, end=current_file["duration"])
            current_file["annotated"] = get_annotated(current_file).crop(
                support, mode="intersection"
            )
            current_file["annotation"] = current_file["annotation"].crop(
                support, mode="intersection"
            )

            # keep track of unique segment labels
            segment_labels.update(current_file["annotation"].labels())

            # keep track of unique file labels
            for key, value in current_file.items():
                if isinstance(value, (Annotation, Timeline, SlidingWindowFeature)):
                    continue
                if key not in file_labels:
                    file_labels[key] = set()
                file_labels[key].add(value)

            segments = [
                s for s in current_file["annotated"] if s.duration > self.duration
            ]

            # corner case where no segment is long enough
            # and we removed them all...
            if not segments:
                continue

            # total duration of annotated segments in current_file
            # (after removal of short segments).
            duration = sum(s.duration for s in segments)

            # store all these in data_ dictionary
            datum = {
                "segments": segments,
                "duration": duration,
                "current_file": current_file,
            }
            uri = get_unique_identifier(current_file)
            self.data_[uri] = datum

        self.file_labels_ = {k: sorted(file_labels[k]) for k in file_labels}
        self.segment_labels_ = sorted(segment_labels)

        for uri in list(self.data_):
            current_file = self.data_[uri]["current_file"]
            y = self.initialize_y(current_file)
            self.data_[uri]["y"] = y
            if self.mask is not None:
                mask = current_file[self.mask]
                current_file[self.mask] = mask.align(y)

        return sum(datum["duration"] for datum in self.data_.values())

    @property
    def specifications(self):
        """Task & sample specifications

        Returns
        -------
        specs : `dict`
            ['task'] (`pyannote.audio.train.Task`) : task
            ['X']['dimension'] (`int`) : features dimension
            ['y']['classes'] (`list`) : list of classes
        """

        specs = {
            "task": self.task,
            "X": {"dimension": self.feature_extraction.dimension},
        }

        if not self.local_labels:
            specs["y"] = {"classes": self.segment_labels_}

        return specs

    def samples(self):
        if self.exhaustive:
            return self._sliding_samples()
        else:
            return self._random_samples()

    def _random_samples(self):
        """Random samples

        Returns
        -------
        samples : generator
            Generator that yields {'X': ..., 'y': ...} samples indefinitely.
        """

        uris = list(self.data_)
        durations = np.array([self.data_[uri]["duration"] for uri in uris])
        probabilities = durations / np.sum(durations)

        while True:

            # choose file at random with probability
            # proportional to its (annotated) duration
            uri = uris[np.random.choice(len(uris), p=probabilities)]

            datum = self.data_[uri]
            current_file = datum["current_file"]

            # choose one segment at random with probability
            # proportional to its duration
            segment = next(random_segment(datum["segments"], weighted=True))

            # choose fixed-duration subsegment at random
            subsegment = next(random_subsegment(segment, self.duration))

            X = self.feature_extraction.crop(
                current_file, subsegment, mode="center", fixed=self.duration
            )

            y = self.crop_y(datum["y"], subsegment)
            sample = {"X": X, "y": y}

            if self.mask is not None:
                mask = self.crop_y(current_file[self.mask], subsegment)
                sample["mask"] = mask

            for key, classes in self.file_labels_.items():
                sample[key] = classes.index(current_file[key])

            yield sample

    def _sliding_samples(self):

        uris = list(self.data_)
        durations = np.array([self.data_[uri]["duration"] for uri in uris])
        probabilities = durations / np.sum(durations)
        sliding_segments = SlidingWindow(
            duration=self.duration, step=self.step * self.duration
        )

        while True:

            np.random.shuffle(uris)

            # loop on all files
            for uri in uris:

                datum = self.data_[uri]

                # make a copy of current file
                current_file = dict(datum["current_file"])

                # compute features for the whole file
                features = self.feature_extraction(current_file)

                # randomly shift the start time of 'annotated' segments so
                # that we avoid generating exactly the same subsequences twice
                annotated = Timeline()
                for segment in get_annotated(current_file):
                    shifted_segment = Segment(
                        segment.start + np.random.random() * self.duration, segment.end
                    )
                    if shifted_segment:
                        annotated.add(shifted_segment)

                samples = []
                for sequence in sliding_segments(annotated):

                    X = features.crop(sequence, mode="center", fixed=self.duration)
                    y = self.crop_y(datum["y"], sequence)
                    sample = {"X": X, "y": y}

                    if self.mask is not None:

                        # extract mask for current sub-segment
                        mask = current_file[self.mask].crop(
                            sequence, mode="center", fixed=self.duration
                        )

                        # it might happen that "mask" and "y" use different
                        # sliding windows. therefore, we simply resample "mask"
                        # to match "y"
                        if len(mask) != len(y):
                            mask = scipy.signal.resample(mask, len(y), axis=0)
                        sample["mask"] = mask

                    for key, classes in self.file_labels_.items():
                        sample[key] = classes.index(current_file[key])

                    samples.append(sample)

                np.random.shuffle(samples)
                for sample in samples:
                    yield sample

    @property
    def batches_per_epoch(self):
        """Number of batches needed to complete an epoch"""
        duration_per_epoch = self.per_epoch * SECONDS_IN_A_DAY
        duration_per_batch = self.duration * self.batch_size
        return int(np.ceil(duration_per_epoch / duration_per_batch))
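
# A worked example (values assumed): with per_epoch=1.0 day, duration=2.0s
# and batch_size=32, one epoch amounts to ceil(86400 / (2.0 * 32)) = 1350
# batches.
import numpy as np
batches = int(np.ceil(1.0 * 24 * 60 * 60 / (2.0 * 32)))  # 1350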
Example no. 11
class SpeechSegmentGenerator(BatchGenerator):
    """Generate batch of pure speech segments with associated speaker labels

    Parameters
    ----------
    feature_extraction : Wrappable
        Describes how features should be obtained.
        See pyannote.audio.features.wrapper.Wrapper documentation for details.
    protocol : `pyannote.database.Protocol`
    subset : {'train', 'development', 'test'}
    duration : float, optional
        Chunk duration, in seconds. Defaults to 1.
    min_duration : float, optional
        When provided, generate chunks of random duration between `min_duration`
        and `duration`. All chunks in a batch will still use the same duration.
        Defaults to generating fixed duration chunks.
    per_turn : int, optional
        Number of chunks per speech turn. Defaults to 1.
    per_label : int, optional
        Number of speech turns per speaker in each batch.
        Defaults to 3.
    per_fold : int, optional
        Number of different speakers in each batch.
        Defaults to all speakers.
    per_epoch : float, optional
        Force total audio duration per epoch, in days.
        Defaults to total duration of protocol subset.
    label_min_duration : float, optional
        Remove speakers with less than `label_min_duration` seconds of speech.
        Defaults to 0 (i.e. keep it all).
    """
    def __init__(
        self,
        feature_extraction: Wrappable,
        protocol: Protocol,
        subset: Subset = "train",
        duration: float = 1.0,
        min_duration: Optional[float] = None,
        per_turn: int = 1,
        per_label: int = 3,
        per_fold: Optional[int] = None,
        per_epoch: Optional[float] = None,
        label_min_duration: float = 0.0,
    ):

        self.feature_extraction = Wrapper(feature_extraction)
        self.per_turn = per_turn
        self.per_label = per_label
        self.per_fold = per_fold
        self.duration = duration
        self.min_duration = duration if min_duration is None else min_duration
        self.label_min_duration = label_min_duration
        self.weighted_ = True

        total_duration = self._load_metadata(protocol, subset=subset)
        if per_epoch is None:
            # default to one full pass over the training set, expressed in days
            per_epoch = total_duration / (24 * 60 * 60)
        self.per_epoch = per_epoch

    def _load_metadata(self,
                       protocol: Protocol,
                       subset: Subset = "train") -> float:
        """Load training set metadata

        This function is called once at instantiation time, returns the total
        training set duration, and populates the following attributes:

        Attributes
        ----------
        data_ : dict
            Dictionary where keys are speaker labels and values are lists of
            (segments, duration, current_file) tuples where
            - segments is a list of segments by the speaker in the file
            - duration is total duration of speech by the speaker in the file
            - current_file is the file (as ProtocolFile)

        segment_labels_ : list
            Sorted list of (unique) labels in protocol.

        file_labels_ : dict of list
            Sorted lists of (unique) file-level labels in protocol

        Returns
        -------
        duration : float
            Total duration of annotated segments, in seconds.
        """

        self.data_ = {}
        segment_labels, file_labels = set(), dict()

        # loop once on all files
        files = getattr(protocol, subset)()
        for current_file in tqdm(files, desc="Loading labels", unit="file"):

            # keep track of unique file labels
            for key in current_file:
                if key in ["annotation", "annotated", "audio", "duration"]:
                    continue
                if key not in file_labels:
                    file_labels[key] = set()
                file_labels[key].add(current_file[key])

            # get annotation for current file
            # ensure annotation is cropped to actual file duration
            support = Segment(start=0, end=current_file["duration"])
            current_file["annotation"] = current_file["annotation"].crop(
                support, mode="intersection")
            annotation = current_file["annotation"]

            # loop on each label in current file
            for label in annotation.labels():

                # get all segments with current label
                timeline = annotation.label_timeline(label)

                # remove segments shorter than maximum chunk duration
                segments = [s for s in timeline if s.duration > self.duration]

                # corner case where no segment is long enough
                # and we removed them all...
                if not segments:
                    continue

                # total duration of label in current_file (after removal of
                # short segments).
                duration = sum(s.duration for s in segments)

                # store all these in the data_ dictionary
                datum = (segments, duration, current_file)
                self.data_.setdefault(label, []).append(datum)

        # remove labels with less than 'label_min_duration' of speech
        # otherwise those may generate the same segments over and over again
        dropped_labels = set()
        for label, data in self.data_.items():
            total_duration = sum(datum[1] for datum in data)
            if total_duration < self.label_min_duration:
                dropped_labels.add(label)

        for label in dropped_labels:
            self.data_.pop(label)

        self.file_labels_ = {k: sorted(file_labels[k]) for k in file_labels}
        self.segment_labels_ = sorted(self.data_)

        return sum(
            sum(datum[1] for datum in data) for data in self.data_.values())

    def samples(self):

        labels = list(self.data_)

        # batch_counter counts samples in current batch.
        # as soon as it reaches batch_size, a new random duration is selected
        # so that the next batch will use a different chunk duration
        batch_counter = 0
        batch_size = self.batch_size
        batch_duration = self.min_duration + np.random.rand() * (
            self.duration - self.min_duration)

        while True:

            # shuffle labels
            np.random.shuffle(labels)

            # loop on each label
            for label in labels:

                # load data for this label
                segments, durations, files = zip(*self.data_[label])

                # choose 'per_label' files at random with probability
                # proportional to the total duration of 'label' in those files
                probabilities = durations / np.sum(durations)
                chosen = np.random.choice(len(files),
                                          size=self.per_label,
                                          p=probabilities)

                # loop on (randomly) chosen files
                for i in chosen:

                    # choose one segment at random with probability
                    # proportional to its duration
                    segment = next(
                        random_segment(segments[i], weighted=self.weighted_))

                    # choose per_turn chunk(s) at random
                    for chunk in itertools.islice(
                            random_subsegment(segment, batch_duration),
                            self.per_turn):

                        yield {
                            "X": self.feature_extraction.crop(
                                files[i], chunk, mode="center",
                                fixed=batch_duration),
                            "y": self.segment_labels_.index(label),
                        }

                        # increment number of samples in current batch
                        batch_counter += 1

                        # as soon as the batch is complete, a new random
                        # duration is selected so that the next batch will use
                        # a different chunk duration
                        if batch_counter == batch_size:
                            batch_counter = 0
                            batch_duration = self.min_duration + np.random.rand() * (
                                self.duration - self.min_duration)

    @property
    def batch_size(self) -> int:
        if self.per_fold is not None:
            return self.per_turn * self.per_label * self.per_fold
        return self.per_turn * self.per_label * len(self.data_)

    @property
    def batches_per_epoch(self) -> int:

        # duration per epoch
        duration_per_epoch = self.per_epoch * 24 * 60 * 60

        # (average) duration per batch
        duration_per_batch = 0.5 * (self.min_duration +
                                    self.duration) * self.batch_size

        # number of batches per epoch
        return int(np.ceil(duration_per_epoch / duration_per_batch))

    @property
    def specifications(self):
        return {
            "X": {"dimension": self.feature_extraction.dimension},
            "y": {"classes": self.segment_labels_},
            "task": Task(type=TaskType.REPRESENTATION_LEARNING,
                         output=TaskOutput.VECTOR),
        }
Example no. 12
def apply_pretrained(validate_dir: Path,
                     protocol_name: str,
                     subset: Optional[str] = "test",
                     duration: Optional[float] = None,
                     step: float = 0.25,
                     device: Optional[torch.device] = None,
                     batch_size: int = 32,
                     pretrained: Optional[str] = None,
                     Pipeline: Optional[type] = None,
                     **kwargs):
    """Apply pre-trained model

    Parameters
    ----------
    validate_dir : Path
    protocol_name : `str`
    subset : 'train' | 'development' | 'test', optional
        Defaults to 'test'.
    duration : `float`, optional
    step : `float`, optional
    device : `torch.device`, optional
    batch_size : `int`, optional
    pretrained : `str`, optional
    Pipeline : `type`, optional
    """

    if pretrained is None:
        pretrained = Pretrained(validate_dir=validate_dir,
                                duration=duration,
                                step=step,
                                batch_size=batch_size,
                                device=device)
        output_dir = validate_dir / 'apply' / f'{pretrained.epoch_:04d}'
    else:

        if pretrained in torch.hub.list('pyannote/pyannote-audio'):
            output_dir = validate_dir / pretrained
        else:
            output_dir = validate_dir

        pretrained = Wrapper(pretrained,
                             duration=duration,
                             step=step,
                             batch_size=batch_size,
                             device=device)

    params = {}
    try:
        params['classes'] = pretrained.classes
    except AttributeError:
        pass
    try:
        params['dimension'] = pretrained.dimension
    except AttributeError:
        pass

    # create metadata file at root that contains
    # sliding window and dimension information
    precomputed = Precomputed(root_dir=output_dir,
                              sliding_window=pretrained.sliding_window,
                              **params)

    # file generator
    protocol = get_protocol(protocol_name,
                            progress=True,
                            preprocessors=pretrained.preprocessors_)

    for current_file in getattr(protocol, subset)():
        fX = pretrained(current_file)
        precomputed.dump(current_file, fX)

    # do not proceed with the full pipeline
    # when there is no such thing for current task
    if Pipeline is None:
        return

    # do not proceed with the full pipeline when its parameters cannot be loaded.
    # this might happen when applying a model that has not been validated yet
    try:
        pipeline_params = pretrained.pipeline_params_
    except AttributeError:
        return

    # instantiate pipeline
    pipeline = Pipeline(scores=output_dir)
    pipeline.instantiate(pipeline_params)

    # load pipeline metric (when available)
    try:
        metric = pipeline.get_metric()
    except NotImplementedError:
        metric = None

    # apply pipeline and dump output to RTTM files
    output_rttm = output_dir / f'{protocol_name}.{subset}.rttm'
    with open(output_rttm, 'w') as fp:
        for current_file in getattr(protocol, subset)():
            hypothesis = pipeline(current_file)
            pipeline.write_rttm(fp, hypothesis)

            # disable evaluation when no reference annotation is available
            if 'annotation' not in current_file:
                metric = None

            # skip evaluation when no metric is available
            if metric is None:
                continue

            reference = current_file['annotation']
            uem = get_annotated(current_file)
            _ = metric(reference, hypothesis, uem=uem)

    # print pipeline metric (when available)
    if metric is None:
        return

    output_eval = output_dir / f'{protocol_name}.{subset}.eval'
    with open(output_eval, 'w') as fp:
        fp.write(str(metric))
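
# A hypothetical usage sketch (the validation directory and protocol name are
# assumptions, not taken from the code above):
#
#     from pathlib import Path
#     apply_pretrained(
#         Path("train/.../validate"),
#         "AMI.SpeakerDiarization.MixHeadset",
#         subset="test",
#         batch_size=32,
#     )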