Example #1
    def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
        validate(cuts)
        # Sort the cuts by duration for more efficient batching and padding.
        cuts = cuts.sort_by_duration()
        # Optional CutSet transforms, e.g. padding or speed perturbation
        # that adjusts the supervision boundaries.
        for tfnm in self.cut_transforms:
            cuts = tfnm(cuts)
        # Collate into a batched tensor of features or audio samples;
        # collation performs auto-padding, if necessary.
        inputs, input_lens = self.input_strategy(cuts)
        # Optional input transforms, e.g. global MVN or SpecAugment.
        for tfnm in self.input_transforms:
            inputs = tfnm(inputs)
        return {
            "inputs": inputs,
            "input_lens": input_lens,
            # Binary masks marking the frames/samples covered by a supervision (voice activity).
            "is_voice": self.input_strategy.supervision_masks(cuts),
            "cut": cuts,
        }
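
This __getitem__ is the whole map-style dataset contract: the sampler passes in a CutSet and a collated mini-batch dict comes out. Below is a minimal usage sketch, assuming the method belongs to a Lhotse-style VadDataset and that a cut manifest with pre-computed features exists at "cuts.jsonl.gz" (the class name, sampler choice, and path are assumptions for illustration, not taken from the snippet):

from torch.utils.data import DataLoader
from lhotse import CutSet
from lhotse.dataset import SimpleCutSampler
from lhotse.dataset.vad import VadDataset

cuts = CutSet.from_file("cuts.jsonl.gz")  # hypothetical manifest path
dataset = VadDataset()                    # class assumed to contain the __getitem__ above
sampler = SimpleCutSampler(cuts, max_duration=200.0)

# The dataset receives a whole CutSet per step, so the DataLoader must not
# do any batching of its own (batch_size=None).
dloader = DataLoader(dataset, sampler=sampler, batch_size=None)

batch = next(iter(dloader))
# batch["inputs"] is (B, T, F), batch["input_lens"] is (B,),
# batch["is_voice"] is a per-frame voice-activity mask.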
Example #2
    def __getitem__(self,
                    cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]:
        """
        Return a new batch, with the batch size automatically determined using the constraints
        of max_frames and max_cuts.
        """
        validate_for_asr(cuts)

        self.hdf5_fix.update()

        # Sort the cuts by duration so that the first one determines the batch time dimensions.
        cuts = cuts.sort_by_duration(ascending=False)

        # Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts
        # the supervision boundaries.
        for tnfm in self.cut_transforms:
            cuts = tnfm(cuts)

        # Get a tensor with batched feature matrices, shape (B, T, F)
        # Collation performs auto-padding, if necessary.
        input_tpl = self.input_strategy(cuts)
        if len(input_tpl) == 3:
            # An input strategy with fault tolerant audio reading mode.
            # "cuts" may be a subset of the original "cuts" variable,
            # that only has cuts for which we succesfully read the audio.
            inputs, _, cuts = input_tpl
        else:
            inputs, _ = input_tpl

        # Get a dict of tensors that encode the positional information about supervisions
        # in the batch of feature matrices. The tensors are named "sequence_idx",
        # "start_frame/sample" and "num_frames/samples".
        supervision_intervals = self.input_strategy.supervision_intervals(cuts)

        # Apply all available transforms on the inputs, i.e. either audio or features.
        # This could be feature extraction, global MVN, SpecAugment, etc.
        segments = torch.stack(list(supervision_intervals.values()), dim=1)
        for tnfm in self.input_transforms:
            inputs = tnfm(inputs, supervision_segments=segments)

        batch = {
            "inputs": inputs,
            "supervisions": default_collate([
                {"text": supervision.text}
                for sequence_idx, cut in enumerate(cuts)
                for supervision in cut.supervisions
            ]),
        }
        # Update the 'supervisions' field with sequence_idx and start/num frames/samples
        batch["supervisions"].update(supervision_intervals)
        if self.return_cuts:
            batch["supervisions"]["cut"] = [
                cut for cut in cuts for sup in cut.supervisions
            ]

        has_word_alignments = all(
            s.alignment is not None and "word" in s.alignment
            for c in cuts
            for s in c.supervisions
        )
        if has_word_alignments:
            # TODO: might need to refactor BatchIO API to move the following conditional logic
            #       into these objects (e.g. use like: self.input_strategy.convert_timestamp(),
            #       that returns either num_frames or num_samples depending on the strategy).
            words, starts, ends = [], [], []
            frame_shift = cuts[0].frame_shift
            sampling_rate = cuts[0].sampling_rate
            if frame_shift is None:
                try:
                    frame_shift = self.input_strategy.extractor.frame_shift
                except AttributeError:
                    raise ValueError(
                        "Can't determine the frame_shift -- it is not present either in cuts or the input_strategy. "
                    )
            for c in cuts:
                for s in c.supervisions:
                    words.append(
                        [aliword.symbol for aliword in s.alignment["word"]])
                    starts.append([
                        compute_num_frames(
                            aliword.start,
                            frame_shift=frame_shift,
                            sampling_rate=sampling_rate,
                        ) for aliword in s.alignment["word"]
                    ])
                    ends.append([
                        compute_num_frames(
                            aliword.end,
                            frame_shift=frame_shift,
                            sampling_rate=sampling_rate,
                        ) for aliword in s.alignment["word"]
                    ])
            batch["supervisions"]["word"] = words
            batch["supervisions"]["word_start"] = starts
            batch["supervisions"]["word_end"] = ends

        return batch
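
The ASR variant is consumed the same way; here is a minimal sketch, assuming the method above belongs to a Lhotse-style K2SpeechRecognitionDataset and that a feature-equipped cut manifest exists at "cuts.jsonl.gz" (class name, sampler choice, and path are assumptions, not taken from the snippet). Note that the max_frames / max_cuts style constraints mentioned in the docstring are enforced by the sampler, not by the dataset itself:

from torch.utils.data import DataLoader
from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler, K2SpeechRecognitionDataset

cuts = CutSet.from_file("cuts.jsonl.gz")  # hypothetical manifest path
dataset = K2SpeechRecognitionDataset(return_cuts=True)

# Batch size is determined by the sampler's constraints (e.g. total duration per batch),
# mirroring the behaviour described in the docstring above.
sampler = DynamicBucketingSampler(cuts, max_duration=200.0, shuffle=True)
dloader = DataLoader(dataset, sampler=sampler, batch_size=None)

for batch in dloader:
    feats = batch["inputs"]                        # (B, T, F)
    texts = batch["supervisions"]["text"]          # list of transcripts
    starts = batch["supervisions"]["start_frame"]  # from supervision_intervals(), feature-based strategy
    lens = batch["supervisions"]["num_frames"]
    break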