def __call__(self, batch, device=torch.device("cpu")):
    """Transform a batch and send it to a device.

    Args:
        batch (list): The batch to transform — a one-element list holding a
            single minibatch ``(xs, ys)``, where each element of ``ys`` pairs
            the two speakers' target sequences.
        device (torch.device): The device to send to.

    Returns:
        tuple(torch.Tensor, torch.Tensor, torch.Tensor):
            Padded inputs ``xs_pad``, input lengths ``ilens``, and padded
            targets ``ys_pad``.
    """
    # batch should be located in list
    assert len(batch) == 1
    xs, ys = batch[0]
    ys = list(ys)  # Convert zip object to list in python 3.x

    # perform subsampling
    if self.subsampling_factor > 1:
        xs = [x[:: self.subsampling_factor, :] for x in xs]

    # get batch of lengths of input sequences
    ilens = np.array([x.shape[0] for x in xs])

    # perform padding and convert to tensor
    xs_pad = pad_list([torch.from_numpy(x).float() for x in xs], 0).to(device)
    ilens = torch.from_numpy(ilens).to(device)

    # Flatten the per-speaker targets speaker-major (all of speaker 0, then
    # all of speaker 1), pad, and fold back so the batch axis comes first.
    ys_pad = [torch.from_numpy(y[0]).long() for y in ys] + [
        torch.from_numpy(y[1]).long() for y in ys
    ]
    ys_pad = pad_list(ys_pad, self.ignore_id)
    # NOTE: the result is (B, num_spkrs, Tmax) — the old comment's
    # "(num_spkrs, B, Tmax)" described the shape *before* the transpose.
    ys_pad = ys_pad.view(2, -1, ys_pad.size(1)).transpose(0, 1).to(device)

    return xs_pad, ilens, ys_pad
def __call__(self, batch, device=torch.device("cpu")):
    """Transform a batch and send it to a device.

    Args:
        batch (list(tuple(str, dict[str, dict[str, Any]]))): The batch to
            transform.
        device (torch.device): The device to send to.

    Returns:
        tuple(torch.Tensor, torch.Tensor, torch.Tensor): Transformed batch.
    """
    # batch should be located in list
    assert len(batch) == 1
    sample = batch[0]
    xs = sample[0]
    ys = sample[-self.num_spkrs :]

    # perform subsampling
    if self.subsampling_factor > 1:
        xs = [x[:: self.subsampling_factor, :] for x in xs]

    # get batch of lengths of input sequences
    ilens = np.array([x.shape[0] for x in xs])

    # perform padding and convert to tensor
    # currently only support real number
    if xs[0].dtype.kind == "c":
        # Note(kamo):
        # {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E.
        # Don't create ComplexTensor and give it to E2E here
        # because torch.nn.DataParallel can't handle it.
        xs_pad = {
            part: pad_list(
                [torch.from_numpy(getattr(x, part)).float() for x in xs], 0
            ).to(device, dtype=self.dtype)
            for part in ("real", "imag")
        }
    else:
        xs_pad = pad_list([torch.from_numpy(x).float() for x in xs], 0).to(
            device, dtype=self.dtype
        )

    ilens = torch.from_numpy(ilens).to(device)

    if isinstance(ys[0], np.ndarray):
        # Single-reference targets: one sequence per utterance.
        ys_pad = pad_list(
            [torch.from_numpy(y).long() for y in ys], self.ignore_id
        ).to(device)
    else:
        # Multi-speaker targets: flatten speaker-major, pad once, then
        # fold the speaker axis back and put batch first.
        flat_targets = [
            torch.from_numpy(y).long() for spkr_ys in ys for y in spkr_ys
        ]
        padded = pad_list(flat_targets, self.ignore_id)
        ys_pad = (
            padded.view(self.num_spkrs, -1, padded.size(1))
            .transpose(0, 1)
            .to(device)
        )  # (B, num_spkrs, Tmax)

    return xs_pad, ilens, ys_pad