コード例 #1
0
    def __call__(self, batch, device):
        """Convert a two-speaker mini-batch into padded tensors on a device.

        :param list batch: one-element list whose single item unpacks into
            ``(xs, ys)``; ``xs`` holds numpy feature arrays and every item of
            ``ys`` is indexed with 0 and 1 (one label array per speaker)
        :param torch.device device: device the resulting tensors are moved to
        :return: a tuple ``(xs_pad, ilens, ys_pad)``; ``ys_pad`` is produced by
            ``view(2, -1, Tmax).transpose(0, 1)``, so the speaker axis is second
        :rtype: (torch.Tensor, torch.Tensor, torch.Tensor)
        """
        # the mini-batch is expected to arrive wrapped in a one-element list
        assert len(batch) == 1
        feats, labels = batch[0]
        # may be a one-shot iterator (e.g. zip); it is traversed twice below
        labels = list(labels)

        # keep every ``subsampling_factor``-th frame along the first axis
        step = self.subsampling_factor
        if step > 1:
            feats = [feat[::step, :] for feat in feats]

        # number of frames per utterance, measured after subsampling
        n_frames = np.array([feat.shape[0] for feat in feats])

        # inputs: float tensors padded with 0
        feat_tensors = [torch.from_numpy(feat).float() for feat in feats]
        xs_pad = pad_list(feat_tensors, 0).to(device)
        ilens = torch.from_numpy(n_frames).to(device)

        # targets: every speaker-0 sequence first, then every speaker-1
        # sequence, all padded together with ``ignore_id``
        target_tensors = [
            torch.from_numpy(label[spk]).long()
            for spk in (0, 1)
            for label in labels
        ]
        stacked = pad_list(target_tensors, self.ignore_id)
        # regroup the rows as (2, B, Tmax), then swap to (B, 2, Tmax)
        # (assumes pad_list returns a 2-D (rows, Tmax) tensor -- TODO confirm)
        ys_pad = stacked.view(2, -1, stacked.size(1)).transpose(0, 1).to(device)

        return xs_pad, ilens, ys_pad
コード例 #2
0
ファイル: asr_mix.py プロジェクト: espnet/espnet
    def __call__(self, batch, device=torch.device("cpu")):
        """Turn a mini-batch into padded tensors placed on ``device``.

        Args:
            batch (list): One-element list. Element 0 of its single item is
                used as the input features; the last ``self.num_spkrs``
                elements of that item are used as the targets.
            device (torch.device): Destination device for every tensor.

        Returns:
            tuple: ``(xs_pad, ilens, ys_pad)``. ``xs_pad`` is a tensor for
            real-valued input, or a dict ``{"real": ..., "imag": ...}`` of two
            tensors when the first input array has a complex dtype. ``ys_pad``
            is the padded targets; in the multi-speaker branch it is reshaped
            with ``view(num_spkrs, -1, Tmax).transpose(0, 1)``.

        """

        def _pad_float(arrays):
            # zero-pad as float tensors, then move to the target device/dtype
            tensors = [torch.from_numpy(arr).float() for arr in arrays]
            return pad_list(tensors, 0).to(device, dtype=self.dtype)

        # the mini-batch is expected to arrive wrapped in a one-element list
        assert len(batch) == 1
        sample = batch[0]
        xs = sample[0]
        ys = sample[-self.num_spkrs :]

        # keep every ``subsampling_factor``-th frame along the first axis
        step = self.subsampling_factor
        if step > 1:
            xs = [x[::step, :] for x in xs]

        # number of frames per utterance, measured after subsampling
        n_frames = np.array([x.shape[0] for x in xs])

        if xs[0].dtype.kind == "c":
            # Complex input is carried as separate real/imag tensors.
            # Note(kamo):
            # {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E.
            # Don't create ComplexTensor and give it to E2E here
            # because torch.nn.DataParallel can't handle it.
            real_part = _pad_float([x.real for x in xs])
            imag_part = _pad_float([x.imag for x in xs])
            xs_pad = {"real": real_part, "imag": imag_part}
        else:
            xs_pad = _pad_float(xs)

        ilens = torch.from_numpy(n_frames).to(device)

        if isinstance(ys[0], np.ndarray):
            # targets are directly a sequence of label arrays
            label_tensors = [torch.from_numpy(y).long() for y in ys]
            ys_pad = pad_list(label_tensors, self.ignore_id).to(device)
        else:
            # one group of label arrays per speaker: flatten speaker-major,
            # pad everything to a common length, then regroup
            label_tensors = [
                torch.from_numpy(y).long() for spk_ys in ys for y in spk_ys
            ]
            padded = pad_list(label_tensors, self.ignore_id)
            # (num_spkrs, B, Tmax) -> (B, num_spkrs, Tmax)
            # (assumes pad_list returns a 2-D (rows, Tmax) tensor -- TODO confirm)
            grouped = padded.view(self.num_spkrs, -1, padded.size(1))
            ys_pad = grouped.transpose(0, 1).to(device)

        return xs_pad, ilens, ys_pad