コード例 #1
0
    def __init__(
        self,
        csv_file,
        sorting="random",
        reverb_prob=1.0,
        rir_scale_factor=1.0,
        replacements=None,
    ):
        """Store the configuration and build an iterator over RIR waveforms.

        Arguments
        ---------
        csv_file : str
            CSV file handed to ``ExtendedCSVDataset`` to list the RIR
            waveforms.
        sorting : str
            ``"random"`` reads the dataset in ``"original"`` order and lets
            the data loader shuffle it; any other value is forwarded to the
            dataset as its sorting option and shuffling is disabled.
        reverb_prob : float
            Stored as ``self.reverb_prob``; not used in this method.
        rir_scale_factor : float
            Stored as ``self.rir_scale_factor``; not used in this method.
        replacements : dict, optional
            Forwarded to ``ExtendedCSVDataset``. ``None`` (the default)
            means an empty dict, freshly created for this instance.
        """
        super().__init__()
        self.csv_file = csv_file
        self.sorting = sorting
        self.reverb_prob = reverb_prob
        # A literal ``{}`` default would be evaluated once and shared by
        # every instance; allocate a new dict per instance instead.
        self.replacements = {} if replacements is None else replacements
        self.rir_scale_factor = rir_scale_factor

        # Create a data loader for the RIR waveforms. Random ordering is
        # implemented by the loader's shuffle, so the dataset itself keeps
        # its original order in that case.
        dataset = ExtendedCSVDataset(
            csvpath=self.csv_file,
            sorting=self.sorting if self.sorting != "random" else "original",
            replacements=self.replacements,
        )
        self.data_loader = make_dataloader(dataset,
                                           shuffle=(self.sorting == "random"))
        self.rir_data = iter(self.data_loader)
コード例 #2
0
    def _load_noise(self, lengths, max_length):
        """Load a batch of noises"""
        lengths = lengths.long().squeeze(1)
        batch_size = len(lengths)

        # Load a noise batch
        if not hasattr(self, "data_loader"):
            # Set parameters based on input
            self.device = lengths.device

            # Create a data loader for the noise wavforms
            if self.csv_file is not None:
                dataset = ExtendedCSVDataset(
                    csvpath=self.csv_file,
                    output_keys=self.csv_keys,
                    sorting=self.sorting
                    if self.sorting != "random" else "original",
                    replacements=self.replacements,
                )
                self.data_loader = make_dataloader(
                    dataset,
                    batch_size=batch_size,
                    num_workers=self.num_workers,
                    shuffle=(self.sorting == "random"),
                )
                self.noise_data = iter(self.data_loader)

        # Load noise to correct device
        noise_batch, noise_len = self._load_noise_batch_of_size(batch_size)
        noise_batch = noise_batch.to(lengths.device)
        noise_len = noise_len.to(lengths.device)

        # Convert relative length to an index
        noise_len = (noise_len * noise_batch.shape[1]).long()

        # Ensure shortest wav can cover speech signal
        # WARNING: THIS COULD BE SLOW IF THERE ARE VERY SHORT NOISES
        if self.pad_noise:
            while torch.any(noise_len < lengths):
                min_len = torch.min(noise_len)
                prepend = noise_batch[:, :min_len]
                noise_batch = torch.cat((prepend, noise_batch), axis=1)
                noise_len += min_len

        # Ensure noise batch is long enough
        elif noise_batch.size(1) < max_length:
            padding = (0, max_length - noise_batch.size(1))
            noise_batch = torch.nn.functional.pad(noise_batch, padding)

        # Select a random starting location in the waveform
        start_index = self.start_index
        if self.start_index is None:
            start_index = 0
            max_chop = (noise_len - lengths).min().clamp(min=1)
            start_index = torch.randint(high=max_chop,
                                        size=(1, ),
                                        device=lengths.device)

        # Truncate noise_batch to max_length
        noise_batch = noise_batch[:, start_index:start_index + max_length]
        noise_len = (noise_len -
                     start_index).clamp(max=max_length).unsqueeze(1)
        return noise_batch, noise_len