Example #1
            def augmented_dynamics(t, y_aug):
                # Dynamics of the original system augmented with
                # the adjoint wrt y, and an integrator wrt t and args.
                y = y_aug[1]
                adj_y = y_aug[2]
                # ignore gradients wrt time and parameters

                with torch.enable_grad():
                    t_ = t.detach()
                    t = t_.requires_grad_(True)
                    y = y.detach().requires_grad_(True)

                    # If using an adaptive solver we don't want to waste time resolving dL/dt unless we need it (which
                    # doesn't necessarily even exist if there is piecewise structure in time), so turning off gradients
                    # wrt t here means we won't compute that if we don't need it.
                    func_eval = func(t if t_requires_grad else t_, y)

                    # Workaround for PyTorch bug #39784
                    _t = torch.as_strided(t, (), ())
                    _y = torch.as_strided(y, (), ())
                    _params = tuple(torch.as_strided(param, (), ()) for param in adjoint_params)

                    vjp_t, vjp_y, *vjp_params = torch.autograd.grad(
                        func_eval, (t, y) + adjoint_params, -adj_y,
                        allow_unused=True, retain_graph=True
                    )

                # autograd.grad returns None if no gradient, set to zero.
                vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t
                vjp_y = torch.zeros_like(y) if vjp_y is None else vjp_y
                vjp_params = [torch.zeros_like(param) if vjp_param is None else vjp_param
                              for param, vjp_param in zip(adjoint_params, vjp_params)]

                return (vjp_t, func_eval, vjp_y, *vjp_params)
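A minimal standalone sketch of the vector-Jacobian product computed inside augmented_dynamics, using a hypothetical toy right-hand side (toy_func) in place of func:

import torch

def toy_func(t, y):
    # hypothetical ODE right-hand side: dy/dt = -t * y
    return -t * y

t = torch.tensor(0.5, requires_grad=True)
y = torch.randn(3, requires_grad=True)
adj_y = torch.randn(3)          # adjoint state a(t)

func_eval = toy_func(t, y)
vjp_t, vjp_y = torch.autograd.grad(
    func_eval, (t, y), -adj_y, allow_unused=True, retain_graph=True)
# vjp_y is -a(t)^T df/dy, the quantity the adjoint method integrates backwards;
# vjp_t is the corresponding scalar sensitivity with respect to time.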
Example #2
def complex_view(x, dim=-1, squeeze=True):
    r"""Returns real and imaginary views into a complex tensor stored in
    interleaved layout (double-real, i.e. re-im). Each returned view is half
    the size along the specified dimension and is not, in general,
    contiguous in memory.

    Arguments
    ---------
    x : torch.tensor
        The tensor holding interleaved real and imaginary parts.
    dim : int
        The axis along which the real and imaginary parts are interleaved,
        e.g. ``dim=-1``.
    squeeze : bool
        If True and the interleaved dimension has size exactly two, that
        dimension is dropped from the returned views.

    Returns
    -------
    real : torch.tensor
        The view into the real part of the tensor.
    imag : torch.tensor
        The view into the imaginary part of the tensor.
    """
    dim = fix_dim(dim, x.dim())
    shape, strides = list(x.size()), list(x.stride())
    offset = x.storage_offset()

    # compute the halved size along dim; warn if a trailing element is dropped
    size, rem = divmod(shape[dim], 2)
    if rem != 0:
        warnings.warn(
            "Odd dimension size for the complex data unpacking: "
            "the trailing element is ignored.", RuntimeWarning)

    # new shape and stride structure
    if shape[dim] == 2 and squeeze:
        # if the complex dimension is exactly two, then just drop it
        shape_view = shape[:dim] + shape[dim + 1:]
        strides_view = strides[:dim] + strides[dim + 1:]

    else:
        # otherwise, halve the size and double the stride
        shape_view = shape[:dim] + [size] + shape[dim + 1:]
        strides_view = strides[:dim] + [2 * strides[dim]] + strides[dim + 1:]

    # differentiable strided view into real and imaginary parts
    real = torch.as_strided(x, shape_view, strides_view, offset)
    imag = torch.as_strided(x, shape_view, strides_view, offset + strides[dim])

    return real, imag
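A self-contained sketch of the strided trick complex_view relies on, shown directly with torch.as_strided on a small interleaved buffer:

import torch

z = torch.tensor([1., 10., 2., 20., 3., 30.])     # re, im interleaved
real = torch.as_strided(z, (3,), (2,), 0)          # tensor([1., 2., 3.])
imag = torch.as_strided(z, (3,), (2,), 1)          # tensor([10., 20., 30.])
# Both are zero-copy views: writing into them modifies `z`.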
Example #3
def preprocess_audio_batch(audio, sr, center=True, hop_size=0.1):
    if audio.dim() == 2:
        audio = torch.mean(audio, axis=1)

    if sr != TARGET_SR:
        audio = julius.resample_frac(audio, sr, TARGET_SR)

    audio_len = audio.size()[0]
    frame_len = TARGET_SR
    hop_len = int(hop_size * TARGET_SR)

    if center:
        audio = center_audio(audio, frame_len)

    audio = pad_audio(audio, frame_len, hop_len)

    n_frames = 1 + int((len(audio) - frame_len) / float(hop_len))
    x = torch.as_strided(
        audio,
        size=(frame_len, n_frames),
        stride=(1, hop_len),
    )
    x = torch.transpose(x, 0, 1)
    x = x.unsqueeze(1)
    return x
Example #4
def gen_sequence_batch(batch_size: int,
                       input_size: int,
                       min_len: int,
                       max_len: int,
                       device: str = 'cpu') -> Tuple[List[Tensor], List[int]]:
    """Randomly generate one sequence batch on the device.

    Returns:
        A tuple ``(batch_list, seq_len)``. ``batch_list`` is a List[Tensor]
        whose elements are 2D strided views of shape [sequence_length,
        input_dim] into a single flat buffer; ``seq_len`` is the list of
        sampled sequence lengths.
    """
    random.seed(1234)
    torch.manual_seed(1234)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(1234)
    seq_len = [random.randint(min_len, max_len) for _ in range(batch_size)]
    batch = torch.randn(sum(seq_len), input_size, device=device)

    offset = 0
    batch_list = []
    for i in range(batch_size):
        a_seq = torch.as_strided(batch,
                                 size=(seq_len[i], input_size),
                                 stride=(input_size, 1),
                                 storage_offset=offset)
        offset += seq_len[i] * input_size
        batch_list.append(a_seq)
    return batch_list, seq_len
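A small usage sketch for the function above (it relies on the same torch/random/typing imports as the function itself):

seqs, lengths = gen_sequence_batch(batch_size=4, input_size=8, min_len=2, max_len=5)
for seq, n in zip(seqs, lengths):
    # each element is a zero-copy strided view of shape [n, 8] into the flat buffer
    assert seq.shape == (n, 8)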
Example #5
    def test_mismatching_stride_no_check(self, device):
        a = torch.rand((2, 2), device=device)
        b = torch.as_strided(a.clone().t().contiguous(), a.shape,
                             a.stride()[::-1])

        for fn in self.assert_fns():
            fn(a, b, check_stride=False)
Example #6
 def test_mismatching_stride_no_check(self, device):
     actual = torch.rand((2, 2), device=device)
     expected = torch.as_strided(actual.clone().t().contiguous(),
                                 actual.shape,
                                 actual.stride()[::-1])
     for fn in self.assert_fns_with_inputs(actual, expected):
         fn(check_stride=False)
Example #7
    def test_mismatching_stride(self):
        actual = torch.empty((2, 2))
        expected = torch.as_strided(actual.clone().t().contiguous(), actual.shape, actual.stride()[::-1])

        for fn in assert_fns_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, "stride"):
                fn()
Example #8
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Returns the :py:class:`torch.Tensor` after adding context frames.

        Args:
            x: :py:class:`torch.Tensor` with size ``(1, features, seq_len)``.

        Returns:
            A :py:class:`torch.Tensor` with size ``(2*n_context + 1, features,
            seq_len)``.
        """
        # Pad to ensure first and last n_context frames in original sequence
        # have at least n_context frames to their left and right respectively.
        assert x.size(0) == 1
        x = x.squeeze().T
        steps, features = x.shape
        padding = torch.zeros((self.n_context, features), dtype=x.dtype)
        x = torch.cat((padding, x, padding))

        window_size = self.n_context + 1 + self.n_context
        strides = x.stride()
        strided_x = torch.as_strided(
            x,
            # Shape of the new array.
            (steps, window_size, features),
            # Strides of the new array (elements to step in each dim).
            (strides[0], strides[0], strides[1]),
        )

        return strided_x.clone().detach().permute(1, 2, 0)
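A standalone sketch of the context-window construction used in __call__ above; n_context, steps and features are hypothetical values chosen for illustration:

import torch

n_context, steps, features = 2, 5, 3
x = torch.arange(float(steps * features)).view(steps, features)
padding = torch.zeros((n_context, features), dtype=x.dtype)
x = torch.cat((padding, x, padding))

s0, s1 = x.stride()
window_size = 2 * n_context + 1
windows = torch.as_strided(x, (steps, window_size, features), (s0, s0, s1))
# windows[i] stacks the zero-padded frames i-2 .. i+2 around original frame i
assert windows.shape == (steps, window_size, features)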
Example #9
def _pool1d(tensor, kernel_size: int = 2, stride: int = 2, mode="max"):
    output_shape = (
        (tensor.shape[0] - kernel_size) // stride + 1,
    )
    kernel_size = (kernel_size,)
    # when torch.Tensor.stride() is supported for `tensor`, replace this
    # workaround with tensor.stride()
    b = torch.ones(tensor.shape)
    a_strides = b.stride()
    a_w = torch.as_strided(
        tensor,
        size=output_shape + kernel_size,
        stride=(stride * a_strides[0],) + a_strides,
    )
    a_w = a_w.reshape(-1, *kernel_size)
    result = []
    if mode == "max":
        for channel in range(a_w.shape[0]):
            result.append(a_w[channel].max())
    elif mode == "mean":
        for channel in range(a_w.shape[0]):
            result.append(torch.mean(a_w[channel]))
    else:
        raise ValueError("unknown pooling mode")

    result = torch.stack(result).reshape(output_shape)
    return result
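A quick sanity-check sketch: for a 1D signal the strided pooling above should agree with the torch.nn.functional pooling ops.

import torch
import torch.nn.functional as F

t = torch.arange(8.0)
assert torch.allclose(_pool1d(t, kernel_size=2, stride=2, mode="max"),
                      F.max_pool1d(t.view(1, 1, -1), 2, 2).view(-1))
assert torch.allclose(_pool1d(t, kernel_size=2, stride=2, mode="mean"),
                      F.avg_pool1d(t.view(1, 1, -1), 2, 2).view(-1))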
Example #10
    def test_mismatching_stride(self, device):
        a = torch.empty((2, 2), device=device)
        b = torch.as_strided(a.clone().t().contiguous(), a.shape,
                             a.stride()[::-1])

        for fn in self.assert_fns():
            with self.assertRaisesRegex(AssertionError, "stride"):
                fn(a, b)
Example #11
 def max_pool2d_layer_numpy(self, x):
     b, c, h, w = x.shape
     b_strided, c_strided, h_strided, w_strided = x.stride()
     x_strided = as_strided(x, (b, c, h // 2, w // 2, 2, 2),
                            (b_strided, c_strided, h_strided * 2,
                             w_strided * 2, h_strided, w_strided))
     result = torch.amax(x_strided, dim=(-1, -2))
     return result
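A sanity-check sketch of the same strided max-pool construction against F.max_pool2d, assuming even spatial dimensions:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 4, 6)
b, c, h, w = x.shape
bs, cs, hs, ws = x.stride()
x_strided = torch.as_strided(x, (b, c, h // 2, w // 2, 2, 2),
                             (bs, cs, hs * 2, ws * 2, hs, ws))
assert torch.equal(torch.amax(x_strided, dim=(-1, -2)), F.max_pool2d(x, 2))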
Example #12
    def __call__(self, feat_data, item):
        sliced = super(AsFramedSlice, self).__call__(feat_data, item)
        if self.as_strided:
            if isinstance(sliced, np.ndarray):
                sliced = torch.from_numpy(sliced)
            with torch.no_grad():
                return torch.as_strided(
                    sliced,
                    size=(self.length - self.frame_size + 1, self.frame_size),
                    stride=(1, 1))
        else:
            return sliced.reshape(-1, self.frame_size)
Example #13
def signal_framing(signal, frame_length, frame_step):
    shape = list(signal.size())
    shape = shape[:-1] + [
        (shape[-1] - frame_length + frame_step) // frame_step, frame_length
    ]
    strides = list(signal.stride())
    strides.insert(-1, frame_step * strides[-1])
    signal = torch.as_strided(signal, size=shape, stride=strides)
    return signal
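Usage sketch for signal_framing on a short 1D signal:

import torch

sig = torch.arange(10.0)
frames = signal_framing(sig, frame_length=4, frame_step=2)
# frames.shape == (4, 4); row i starts at sample 2*i, e.g. frames[1] is sig[2:6]
assert torch.equal(frames[1], sig[2:6])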
Example #14
def unfold(tensor, size, step, dilation=1):
    assert tensor.dim() == 1
    o_stride = tensor.stride(0)
    numel = tensor.numel()
    new_stride = (step * o_stride, dilation * o_stride)
    new_size = ((numel - (dilation * (size - 1) + 1)) // step + 1, size)
    if new_size[0] < 1:
        new_size = (0, size)
    return torch.as_strided(tensor, new_size, new_stride)
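Usage sketch showing the effect of `dilation` in the unfold above:

import torch

t = torch.arange(10)
w = unfold(t, size=3, step=2, dilation=2)
# w: tensor([[0, 2, 4],
#            [2, 4, 6],
#            [4, 6, 8]])
assert torch.equal(w[0], t[0:6:2])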
Example #15
    def forward(self, y, y_hat):
        """Calculate time domain loss

        Args:
            y (Tensor): real waveform
            y_hat (Tensor): fake waveform
        Return: 
            total_loss (Tensor): total loss of time domain
            
        """

        # Energy loss & Time loss & Phase loss
        loss_e = torch.zeros(self.len).to(y)
        loss_t = torch.zeros(self.len).to(y)
        loss_p = torch.zeros(self.len).to(y)

        for i in range(self.len):
            y_tmp = torch.as_strided(y, self.shapes[i], self.strides[i])
            y_hat_tmp = torch.as_strided(y_hat, self.shapes[i],
                                         self.strides[i])

            loss_e[i] = F.l1_loss(torch.mean(y_tmp**2, dim=-1),
                                  torch.mean(y_hat_tmp**2, dim=-1))
            loss_t[i] = F.l1_loss(torch.mean(y_tmp, dim=-1),
                                  torch.mean(y_hat_tmp, dim=-1))
            if i == 0:
                y_phase = F.pad(y_tmp.transpose(1, 2),
                                (1, 0), "constant", 0) - F.pad(
                                    y_tmp.transpose(1, 2),
                                    (0, 1), "constant", 0)
                y_hat_phase = F.pad(y_hat_tmp.transpose(1, 2),
                                    (1, 0), "constant", 0) - F.pad(
                                        y_hat_tmp.transpose(1, 2),
                                        (0, 1), "constant", 0)
            else:
                y_phase = F.pad(y_tmp, (1, 0), "constant", 0) - F.pad(
                    y_tmp, (0, 1), "constant", 0)
                y_hat_phase = F.pad(y_hat_tmp, (1, 0), "constant", 0) - F.pad(
                    y_hat_tmp, (0, 1), "constant", 0)
            loss_p[i] = F.l1_loss(y_phase, y_hat_phase)

        total_loss = torch.sum(loss_e) + torch.sum(loss_t) + torch.sum(loss_p)

        return total_loss
Example #16
    def _create_views(self, tensor):
        views = []

        for stride in self.strides:
            outdim = tensor.size(0) - stride + 1
            view = torch.as_strided(tensor, (outdim, stride, self.dim),
                                    (self.dim, self.dim, 1))
            views.append(view)

        return views
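A standalone sketch of the view construction above for a single stride value; T, dim and stride are hypothetical stand-ins for self.dim / self.strides, and the input is assumed contiguous:

import torch

T, dim, stride = 6, 4, 3
tensor = torch.randn(T, dim)
view = torch.as_strided(tensor, (T - stride + 1, stride, dim), (dim, dim, 1))
# view[i] is the window tensor[i:i + stride], without copying any data
assert torch.equal(view[1], tensor[1:1 + stride])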
Example #17
def unfold(tensor, size, step, dilation=1):
    assert tensor.dim() == 1
    o_stride = tensor.stride(0)
    numel = tensor.numel()
    new_stride = (step * o_stride, dilation * o_stride)
    new_size = ((numel - (dilation * (size - 1) + 1)) // step + 1, size)
    if new_size[0] < 1:
        # instead of the usual empty result (new_size = (0, size)), do not
        # exclude videos shorter than `size`: return the whole tensor as one window
        return tensor.unsqueeze(0)
    return torch.as_strided(tensor, new_size, new_stride)
Example #18
def window_view(x, dim, size, stride, at=None):
    r"""Returns a sliding window view into the tensor.

    Similar to ``torch.Tensor.unfold()``, but the window dimension of size
    `size` is placed right after `dim` (by default) rather than appended.

    Arguments
    ---------
    x : torch.tensor
        The tensor to take sliding windows over.
    dim : int
        The axis along which the sliding window is taken, e.g. ``dim=-1``.
    size : int
        The size of the sliding windows.
    stride : int
        The step between two sliding windows.
    at : int, optional
        The dimension at which to put the slice of each window.

    Returns
    -------
    x_view : torch.tensor
        The view into a sliding window. The returned tensor is not, in
        general, contiguous in memory.
    """
    if size <= 0:
        raise ValueError("`size` must be a positive integer.")

    if stride <= 0:
        raise ValueError("`stride` must be a positive integer.")

    dim = fix_dim(dim, x.dim())
    if x.shape[dim] < size:
        raise ValueError(f"""`x` at dim {dim} is too short ({x.shape[dim]}) """
                         f"""for this window size ({size}).""")

    if at is None:
        at = dim + 1
    at = fix_dim(at, x.dim() + 1)

    # compute new shape and strides
    shape, strides = list(x.size()), list(x.stride())
    strided_size = ((shape[dim] - size + 1) + stride - 1) // stride

    # new shape and stride structure
    shape_view = shape[:dim] + [strided_size] + shape[dim+1:]
    shape_view.insert(at, size)

    strides_view = strides[:dim] + [strides[dim] * stride] + strides[dim+1:]
    strides_view.insert(at, strides[dim])

    # differentiable strided view
    return torch.as_strided(x, shape_view, strides_view)
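Usage sketch for window_view; a stand-in fix_dim that simply normalizes negative axes is defined here so the snippet is self-contained:

import torch

fix_dim = lambda d, ndim: d % ndim     # stand-in for the module's helper

x = torch.arange(8.0)
v = window_view(x, dim=-1, size=3, stride=2)
# v: tensor([[0., 1., 2.],
#            [2., 3., 4.],
#            [4., 5., 6.]])
assert torch.equal(v[1], x[2:5])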
Example #19
def vjp(outputs, inputs, **kwargs):
    if torch.is_tensor(inputs):
        inputs = [inputs]
    _dummy_inputs = [torch.as_strided(i, (), ())
                     for i in inputs]  # Workaround for PyTorch bug #39784.

    if torch.is_tensor(outputs):
        outputs = [outputs]
    outputs = make_seq_requires_grad(outputs)

    _vjp = torch.autograd.grad(outputs, inputs, **kwargs)
    return convert_none_to_zeros(_vjp, inputs)
Example #20
    def forward(self, y, y_hat):
        """Calculate time domain loss

        Args:
            y (Tensor): real waveform
            y_hat (Tensor): fake waveform
        Return: 
            total_loss (Tensor): total loss of time domain
            
        """

        # Energy loss
        loss_e = torch.zeros(self.len).to(y)
        for i in range(self.len):
            y_energy = torch.as_strided(y**2, self.shapes[i], self.strides[i])
            y_hat_energy = torch.as_strided(y_hat**2, self.shapes[i],
                                            self.strides[i])
            loss_e[i] = F.l1_loss(torch.mean(y_energy, dim=-1),
                                  torch.mean(y_hat_energy, dim=-1))

        # Time loss
        loss_t = torch.zeros(self.len).to(y)
        for i in range(self.len):
            y_time = torch.as_strided(y, self.shapes[i], self.strides[i])
            y_hat_time = torch.as_strided(y_hat, self.shapes[i],
                                          self.strides[i])
            loss_t[i] = F.l1_loss(torch.mean(y_time, dim=-1),
                                  torch.mean(y_hat_time, dim=-1))

        # Phase loss
        y_phase = F.pad(y,
                        (1, 0), "constant", 0) - F.pad(y,
                                                       (0, 1), "constant", 0)
        y_hat_phase = F.pad(y_hat, (1, 0), "constant", 0) - F.pad(
            y_hat, (0, 1), "constant", 0)
        loss_p = F.l1_loss(y_phase, y_hat_phase)

        total_loss = torch.sum(loss_e) + torch.sum(loss_t) + loss_p

        return total_loss
Example #21
 def tensor_creation_ops(self):
     i = torch.tensor([[0, 1, 1], [2, 0, 2]])
     v = torch.tensor([3, 4, 5], dtype=torch.float32)
     real = torch.tensor([1, 2], dtype=torch.float32)
     imag = torch.tensor([3, 4], dtype=torch.float32)
     inp = torch.tensor([-1.5, 0.0, 2.0])
     values = torch.tensor([0.5])
     quantized = torch.quantize_per_channel(
         torch.tensor([[-1.0, 0.0], [1.0, 2.0]]),
         torch.tensor([0.1, 0.01]),
         torch.tensor([10, 0]),
         0,
         torch.quint8,
     )
     return (
         torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]]),
         # torch.sparse_coo_tensor(i, v, [2, 3]), # not work for iOS
         torch.as_tensor([1, 2, 3]),
         torch.as_strided(torch.randn(3, 3), (2, 2), (1, 2)),
         torch.zeros(2, 3),
         torch.zeros((2, 3)),
         torch.zeros([2, 3], out=i),
         torch.zeros(5),
         torch.zeros_like(torch.empty(2, 3)),
         torch.ones(2, 3),
         torch.ones((2, 3)),
         torch.ones([2, 3]),
         torch.ones(5),
         torch.ones_like(torch.empty(2, 3)),
         torch.arange(5),
         torch.arange(1, 4),
         torch.arange(1, 2.5, 0.5),
         torch.range(1, 4),
         torch.range(1, 4, 0.5),
         torch.linspace(3.0, 3.0, steps=1),
         torch.logspace(start=2, end=2, steps=1, base=2.0),
         torch.eye(3),
         torch.empty(2, 3),
         torch.empty_like(torch.empty(2, 3), dtype=torch.int64),
         torch.empty_strided((2, 3), (1, 2)),
         torch.full((2, 3), 3.141592),
         torch.full_like(torch.full((2, 3), 3.141592), 2.71828),
         torch.quantize_per_tensor(
             torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8
         ),
         torch.dequantize(quantized),
         torch.complex(real, imag),
         torch.polar(real, imag),
         torch.heaviside(inp, values),
     )
Example #22
def get_frame(audio, step_size, center):
    if center:
        audio = nn.functional.pad(audio, pad=(512, 512))
    # make 1024-sample frames of the audio with hop length of 10 milliseconds
    hop_length = int(16000 * step_size / 1000)
    n_frames = 1 + int((len(audio) - 1024) / hop_length)
    assert audio.dtype == torch.float32
    # PyTorch strides are counted in elements (not bytes), so the per-sample stride is 1
    frames = torch.as_strided(audio, size=(1024, n_frames), stride=(1, hop_length))
    frames = frames.transpose(0, 1).clone()

    frames -= (torch.mean(frames, axis=1).unsqueeze(-1))
    frames /= (torch.std(frames, axis=1).unsqueeze(-1))
    return frames
Example #23
 def convolution(self, m):
     f = self.weight
     m_h, m_w = m.shape[-2:]
     f_h, f_w = f.shape[-2:]
     batch_size = m.shape[0]
     m_c = m.shape[1]
     Hout = m_h - f_h + 1
     Wout = m_w - f_w + 1
     stride_batch_size, stride_c, stride_h, stride_w = m.stride()
     m_strided = as_strided(m, (batch_size, Hout, Wout, m_c, f_h, f_w),
                            (stride_batch_size, stride_h, stride_w,
                             stride_c, stride_h, stride_w))
     result = einsum('bmncuv,kcuv->bkmn', m_strided, f)
     return result
Example #24
 def test_stride(self):
     cpu_ones = torch.ones(3, 3)
     ort_ones = cpu_ones.to("ort")
     y = torch.as_strided(ort_ones, (2, 2), (1, 2))
     assert y.size() == (2, 2)
     assert y.is_contiguous() == False
     contiguous_y = y.contiguous()
     w = torch.ones((2, 3))
     ort_w = w.to("ort")
     z = torch.zeros((2, 3))
     ort_z = z.to("ort")
     ort_z = torch.addmm(ort_z, contiguous_y, ort_w)
     cpu_z = torch.addmm(z, torch.ones(2, 2), w)
     assert torch.allclose(ort_z.cpu(), cpu_z)
Example #25
def preprocess_audio_batch(audio, sr, center=True, hop_size=0.1, sampler="julian"):
    if audio.ndim == 3:
        audio = torch.mean(audio, axis=2)

    if sr != TARGET_SR:
        if sampler == "julian":
            audio = julius.resample_frac(audio, sr, TARGET_SR)

        elif sampler == "resampy":
            audio = torch.tensor(
                resampy.resample(
                    audio.detach().cpu().numpy(),
                    sr_orig=sr,
                    sr_new=TARGET_SR,
                    filter="kaiser_best",
                ),
                dtype=audio.dtype,
                device=audio.device,
            )

        else:
            raise ValueError("Only julian and resampy works!")

    frame_len = TARGET_SR
    hop_len = int(hop_size * TARGET_SR)
    if center:
        audio = center_audio(audio, frame_len)

    audio = pad_audio(audio, frame_len, hop_len)
    n_frames = 1 + int((audio.size()[1] - frame_len) / float(hop_len))
    x = []
    xframes_shape = None
    for i in range(audio.shape[0]):
        xframes = (
            torch.as_strided(
                audio[i],
                size=(frame_len, n_frames),
                stride=(1, hop_len),
            )
            .transpose(0, 1)
            .unsqueeze(1)
        )
        if xframes_shape is None:
            xframes_shape = xframes.shape
        assert xframes.shape == xframes_shape
        x.append(xframes)
    x = torch.vstack(x)
    return x
Example #26
def unfold(tensor, frames_per_clip, frames_between_clips, dilation=1):
    """
    similar to tensor.unfold, but with the dilation
    and specialized for 1d tensors
    Returns all consecutive windows of `size` elements, with
    `step` between windows. The distance between each element
    in a window is given by `dilation`.
    """
    assert tensor.dim() == 1
    o_stride = tensor.stride(0)
    numel = tensor.numel()
    new_stride = (frames_between_clips * o_stride, dilation * o_stride)
    new_size = ((numel - (dilation * (frames_per_clip - 1) + 1)) // frames_between_clips + 1, frames_per_clip)
    if new_size[0] < 1:
        new_size = (0, frames_per_clip)
    return torch.as_strided(tensor, new_size, new_stride)
Example #27
def t_window(x, size, shift=None, stride=1):
    try:
        nd = len(size)
    except TypeError:
        size = tuple(size for i in x.shape)
        nd = len(size)
    if nd != x.ndimension():
        raise ValueError("size has length {0} instead of "
                         "x.ndim which is {1}".format(len(size),
                                                      x.ndimension()))
    out_shape = tuple(xi - wi + 1 for xi, wi in zip(x.shape, size)) + size
    if not all(i > 0 for i in out_shape):
        raise ValueError("size is bigger than input array along at "
                         "least one dimension")
    out_strides = x.stride() * 2
    return t.as_strided(x, out_shape, out_strides)
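Usage sketch for t_window (the function above assumes torch is imported as `t`):

import torch as t

x = t.arange(12.).view(3, 4)
w = t_window(x, (2, 2))
# w.shape == (2, 3, 2, 2) and w[i, j] is the 2x2 patch x[i:i+2, j:j+2]
assert t.equal(w[1, 2], x[1:3, 2:4])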
Example #28
            def batch_stripe(a):
                """
                Get a diagonal stripe of a matrix m x n, where n > m
                this implementation also takes into account batched matrices,
                so the stripe is calculated over a batch x for a matrix of size[x, m, n]
                """
                # NOTE: reversing with a negative step (a[::-1]) raises
                # "ValueError: negative step not yet supported" in PyTorch,
                # so the stripe is built top-left to bottom-right instead.

                b, i, j = a.size()
                assert i > j
                b_s, k, l = a.stride()

                # left top to right bottom
                return torch.as_strided(a, (b, i - j + 1, j), (b_s, k, k + l))
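Usage sketch, treating batch_stripe as a standalone function rather than a closure:

import torch

a = torch.arange(24.).view(2, 4, 3)      # batch of tall 4x3 matrices (i > j)
stripe = batch_stripe(a)                 # shape (2, 2, 3)
# stripe[b, k, n] == a[b, k + n, n]: row k of the stripe is the diagonal
# of a[b] that starts at matrix row k
assert torch.equal(stripe[0, 1],
                   torch.stack([a[0, 1, 0], a[0, 2, 1], a[0, 3, 2]]))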
Example #29
def unfold(tensor: torch.Tensor,
           size: int,
           step: int,
           dilation: int = 1) -> torch.Tensor:
    """
    similar to tensor.unfold, but with the dilation
    and specialized for 1d tensors

    Returns all consecutive windows of `size` elements, with
    `step` between windows. The distance between each element
    in a window is given by `dilation`.
    """
    assert tensor.dim() == 1
    o_stride = tensor.stride(0)
    numel = tensor.numel()
    new_stride = (step * o_stride, dilation * o_stride)
    new_size = ((numel - (dilation * (size - 1) + 1)) // step + 1, size)
    if new_size[0] < 1:
        new_size = (0, size)
    return torch.as_strided(tensor, new_size, new_stride)
Example #30
def jvp(outputs, inputs, grad_inputs=None, **kwargs):
    # Unlike `torch.autograd.functional.jvp`, this function avoids repeating forward computation.
    if torch.is_tensor(inputs):
        inputs = [inputs]
    _dummy_inputs = [torch.as_strided(i, (), ())
                     for i in inputs]  # Workaround for PyTorch bug #39784.

    if torch.is_tensor(outputs):
        outputs = [outputs]
    outputs = make_seq_requires_grad(outputs)

    dummy_outputs = [torch.zeros_like(o, requires_grad=True) for o in outputs]
    _vjp = torch.autograd.grad(outputs,
                               inputs,
                               grad_outputs=dummy_outputs,
                               create_graph=True,
                               allow_unused=True)
    _vjp = make_seq_requires_grad(convert_none_to_zeros(_vjp, inputs))

    _jvp = torch.autograd.grad(_vjp,
                               dummy_outputs,
                               grad_outputs=grad_inputs,
                               **kwargs)
    return convert_none_to_zeros(_jvp, dummy_outputs)
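A minimal standalone sketch of the double-VJP trick jvp uses (two calls to autograd.grad through a dummy cotangent), shown on a toy function rather than through the helpers above:

import torch

x = torch.randn(3, requires_grad=True)
y = (x ** 2).sum()                       # toy output
v = torch.ones_like(x)                   # direction of the JVP

u = torch.zeros_like(y, requires_grad=True)               # dummy cotangent
(g,) = torch.autograd.grad(y, x, grad_outputs=u, create_graph=True)
(jvp_val,) = torch.autograd.grad(g, u, grad_outputs=v)    # J @ v via d(u^T J)/du
# For y = sum(x^2), the JVP along v is 2 * <x, v>.
assert torch.allclose(jvp_val, 2 * (x * v).sum())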