Example no. 1
def _test_no_adjustments(class_, device, batch_size, input_stream, input_channels, depth, stream, basepoint,
                         inverse, initial, path_grad, scalar_term):
    path = h.get_path(batch_size, input_stream, input_channels, device, path_grad)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)
    initial = h.get_initial(batch_size, input_channels, device, depth, initial, scalar_term)

    path_clone = path.clone()
    if isinstance(basepoint, torch.Tensor):
        basepoint_clone = basepoint.clone()
    if isinstance(initial, torch.Tensor):
        initial_clone = initial.clone()

    signature = signatory_signature(class_, path, depth, stream, basepoint, inverse, initial, scalar_term)

    can_backward = path_grad or (isinstance(basepoint, torch.Tensor) and basepoint.requires_grad) or \
                   (isinstance(initial, torch.Tensor) and initial.requires_grad)
    if can_backward:
        grad = torch.rand_like(signature)

        signature_clone = signature.clone()
        grad_clone = grad.clone()

        signature.backward(grad)
    else:
        assert signature.grad_fn is None

    h.diff(path, path_clone)
    if isinstance(basepoint, torch.Tensor):
        h.diff(basepoint, basepoint_clone)
    if isinstance(initial, torch.Tensor):
        h.diff(initial, initial_clone)
    if can_backward:
        h.diff(signature, signature_clone)
        h.diff(grad, grad_clone)
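
For context, a minimal sketch of the call pattern this test exercises, with arbitrary shapes and depth chosen here rather than taken from the test harness:

import torch
import signatory

# A batch of 4 paths, each with 10 steps in 3 channels.
path = torch.rand(4, 10, 3, requires_grad=True)
basepoint = torch.rand(4, 3, requires_grad=True)

# Depth-2 signature, with the basepoint prepended to every path in the batch.
signature = signatory.signature(path, 2, basepoint=basepoint)

# Backpropagating an arbitrary cotangent should leave path and basepoint unchanged,
# which is what _test_no_adjustments checks via the clones above.
signature.backward(torch.rand_like(signature))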
Example no. 2
def _test_forward(class_, device, path_grad, batch_size, input_stream, input_channels, depth, stream, basepoint,
                  inverse, initial, scalar_term):
    path = h.get_path(batch_size, input_stream, input_channels, device, path_grad)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)

    expected_exception = (batch_size < 1) or (input_channels < 1) or (basepoint is False and input_stream < 2) or \
                         (input_stream < 1)
    try:
        initial = h.get_initial(batch_size, input_channels, device, depth, initial, scalar_term)
        signature = signatory_signature(class_, path, depth, stream, basepoint, inverse, initial, scalar_term)
    except ValueError:
        if expected_exception:
            return
        else:
            raise
    else:
        assert not expected_exception

    _test_shape(signature, stream, basepoint, batch_size, input_stream, input_channels, depth, scalar_term)
    h.diff(signature, iisignature_signature(path, depth, stream, basepoint, inverse, initial, scalar_term))

    if path_grad or (isinstance(basepoint, torch.Tensor) and basepoint.requires_grad) or \
            (isinstance(initial, torch.Tensor) and initial.requires_grad):
        ctx = signature.grad_fn
        if stream:
            ctx = ctx.next_functions[0][0]
        assert type(ctx).__name__ in ('_SignatureFunctionBackward', '_SignatureCombineFunctionBackward')
        ref = weakref.ref(ctx)
        del ctx
        del signature
        gc.collect()
        assert ref() is None
    else:
        assert signature.grad_fn is None
Example no. 3
    def one_iteration():
        gc.collect()
        torch.cuda.synchronize()
        torch.cuda.reset_max_memory_allocated()
        # device='cuda' because it's just easier to keep track of GPU memory
        cuda_path = cpu_path.to('cuda')
        if path_grad:
            cuda_path.requires_grad_()
        if isinstance(cpu_basepoint, torch.Tensor):
            cuda_basepoint = cpu_basepoint.cuda()
            if basepoint is h.with_grad:
                cuda_basepoint.requires_grad_()
        else:
            cuda_basepoint = basepoint

        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', message="The logsignature with mode='brackets' has been requested on the "
                                                      "GPU.", category=UserWarning)
            if class_:
                cuda_logsignature = logsignature_instance(cuda_path, basepoint=cuda_basepoint)
            else:
                cuda_logsignature = signatory.logsignature(cuda_path, depth, stream=stream, basepoint=cuda_basepoint,
                                                           inverse=inverse, mode=mode)

        h.diff(cuda_logsignature.cpu(), cpu_logsignature)

        if path_grad:
            cuda_grad = cpu_grad.cuda()
            cuda_logsignature.backward(cuda_grad)
        torch.cuda.synchronize()
        return torch.cuda.max_memory_allocated()
Example no. 4
    def one_iteration():
        gc.collect()
        torch.cuda.synchronize()
        torch.cuda.reset_max_memory_allocated()
        cuda_signature = cpu_signature.to('cuda')
        if signature_grad:
            cuda_signature.requires_grad_()
        with warnings.catch_warnings():
            warnings.filterwarnings(
                'ignore',
                message=
                "The logsignature with mode='brackets' has been requested on the "
                "GPU.",
                category=UserWarning)
            if class_:
                cuda_logsignature = signature_to_logsignature_instance(
                    cuda_signature)
            else:
                cuda_logsignature = signatory.signature_to_logsignature(
                    cuda_signature,
                    input_channels,
                    depth,
                    stream=stream,
                    mode=mode)

        h.diff(cuda_logsignature.cpu(), cpu_logsignature)

        if signature_grad:
            cuda_grad = cpu_grad.cuda()
            cuda_logsignature.backward(cuda_grad)
        torch.cuda.synchronize()
        return torch.cuda.max_memory_allocated()
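
Both one_iteration helpers above follow the same pattern for measuring the peak GPU memory used by a single call. A standalone sketch of that pattern (the helper name is illustrative, not part of the test suite):

import gc
import torch

def peak_cuda_memory(fn, *args, **kwargs):
    # Collect garbage and reset the statistics so that only fn's allocations are counted.
    gc.collect()
    torch.cuda.synchronize()
    torch.cuda.reset_max_memory_allocated()
    result = fn(*args, **kwargs)
    torch.cuda.synchronize()
    return result, torch.cuda.max_memory_allocated()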
Example no. 5
def _test_forward(class_, device, path_grad, batch_size, input_stream, input_channels, depth, stream, basepoint,
                  inverse, mode):

    path = h.get_path(batch_size, input_stream, input_channels, device, path_grad)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)

    expected_exception = (batch_size < 1) or (input_channels < 1) or (basepoint is False and input_stream < 2) or \
                         (input_stream < 1)
    try:
        logsignature = signatory_logsignature(class_, path, depth, stream, basepoint, inverse, mode)
    except ValueError:
        if expected_exception:
            return
        else:
            raise
    else:
        assert not expected_exception

    _test_shape(logsignature, mode, batch_size, input_stream, input_channels, depth, stream, basepoint)
    h.diff(logsignature, iisignature_logsignature(path, depth, stream, basepoint, inverse, mode))

    # Check that the 'ctx' object is properly garbage collected and we don't have a memory leak
    # (a leak is easy to introduce accidentally due to PyTorch bug 25340)
    if path_grad or (isinstance(basepoint, torch.Tensor) and basepoint.requires_grad):
        ctx = logsignature.grad_fn
        if stream:
            ctx = ctx.next_functions[0][0]
        assert type(ctx).__name__ == '_SignatureToLogsignatureFunctionBackward'
        ref = weakref.ref(ctx)
        del ctx
        del logsignature
        gc.collect()
        assert ref() is None
    else:
        assert logsignature.grad_fn is None
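
The garbage-collection check above guards against the autograd graph being kept alive after the output is dropped (PyTorch issue 25340). In isolation the pattern looks roughly like this, assuming signatory is installed:

import gc
import weakref

import torch
import signatory

path = torch.rand(1, 10, 3, requires_grad=True)
logsignature = signatory.logsignature(path, 2)

# Hold only a weak reference to the backward node, drop the strong references,
# and check that garbage collection actually frees it.
ctx = logsignature.grad_fn
ref = weakref.ref(ctx)
del ctx, logsignature
gc.collect()
assert ref() is None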
Example no. 6
def _test_backward(class_, device, batch_size, input_stream, input_channels,
                   depth, stream, mode, scalar_term):

    # The test in the comment below runs out of memory, so we do something else instead.
    #
    # path = h.get_path(batch_size, input_stream, input_channels, device, path_grad=False)
    # signature = signatory.signature(path, depth, stream=stream)
    # signature.requires_grad_()
    # if class_:
    #     def check_fn(signature):
    #         return signatory.SignatureToLogSignature(input_channels, depth, stream=stream, mode=mode)(signature)
    # else:
    #     def check_fn(signature):
    #         return signatory.signature_to_logsignature(signature, input_channels, depth, stream=stream, mode=mode)
    # try:
    #     autograd.gradcheck(check_fn, (signature,), atol=2e-05, rtol=0.002)
    # except RuntimeError:
    #     pytest.fail()

    path = h.get_path(batch_size,
                      input_stream,
                      input_channels,
                      device,
                      path_grad=True)
    signature = signatory.signature(path,
                                    depth,
                                    stream=stream,
                                    scalar_term=scalar_term)
    with warnings.catch_warnings():
        warnings.filterwarnings(
            'ignore',
            message=
            "The logsignature with mode='brackets' has been requested on the "
            "GPU.",
            category=UserWarning)
        logsignature = signatory_signature_to_logsignature(
            class_, signature, input_channels, depth, stream, mode,
            scalar_term)

    grad = torch.rand_like(logsignature)
    logsignature.backward(grad)

    path_grad = path.grad.clone()
    path.grad.zero_()

    with warnings.catch_warnings():
        warnings.filterwarnings(
            'ignore',
            message=
            "The logsignature with mode='brackets' has been requested on the "
            "GPU.",
            category=UserWarning)
        true_logsignature = signatory.logsignature(path,
                                                   depth,
                                                   stream=stream,
                                                   mode=mode)
    true_logsignature.backward(grad)
    h.diff(path.grad, path_grad)
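
As the commented-out gradcheck above notes, running autograd.gradcheck here is impractical, so the test instead pushes one shared cotangent through both implementations and compares the gradients accumulated on the input. A generic sketch of that check (the helper name is illustrative):

import torch

def gradients_match(fn_a, fn_b, path, **allclose_kwargs):
    # Backpropagate the same cotangent through both implementations and
    # compare the gradients they accumulate on the shared input.
    grads = []
    grad = None
    for fn in (fn_a, fn_b):
        path.grad = None
        out = fn(path)
        if grad is None:
            grad = torch.rand_like(out)
        out.backward(grad)
        grads.append(path.grad.clone())
    return torch.allclose(grads[0], grads[1], **allclose_kwargs)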
Example no. 7
def _test_backward(signature_combine, amount, device, batch_size, input_stream,
                   input_channels, depth, inverse):
    paths = []
    for _ in range(amount):
        paths.append(
            torch.rand(batch_size,
                       input_stream,
                       input_channels,
                       device=device,
                       dtype=torch.double,
                       requires_grad=True))
    signatures = []
    basepoint = False
    for path in paths:
        signature = iisignature_signature(path,
                                          depth,
                                          basepoint=basepoint,
                                          inverse=inverse)
        signatures.append(signature)
        basepoint = path[:, -1]

    # This is the test we'd like to run here, but it takes too long.
    # Furthermore we'd also prefer to only go backwards through the signature combine, not through the signature, but
    # we can't really do that with our faster alternative.
    #
    # if signature_combine:
    #     def check_fn(*signatures):
    #         return signatory.signature_combine(signatures[0], signatures[1], input_channels, depth, inverse=inverse)
    # else:
    #     def check_fn(*signatures):
    #         return signatory.multi_signature_combine(signatures, input_channels, depth, inverse=inverse)
    # try:
    #     autograd.gradcheck(check_fn, tuple(signatures), atol=2e-05, rtol=0.002)
    # except RuntimeError:
    #     pytest.fail()

    if signature_combine:
        combined_signatures = signatory.signature_combine(signatures[0],
                                                          signatures[1],
                                                          input_channels,
                                                          depth,
                                                          inverse=inverse)
    else:
        combined_signatures = signatory.multi_signature_combine(
            signatures, input_channels, depth, inverse=inverse)
    grad = torch.rand_like(combined_signatures)
    combined_signatures.backward(grad)
    path_grads = [path.grad.clone() for path in paths]
    for path in paths:
        path.grad.zero_()

    true_signature = iisignature_signature(torch.cat(paths, dim=1),
                                           depth,
                                           inverse=inverse)
    true_signature.backward(grad)
    for path_grad, path in zip(path_grads, paths):
        h.diff(path_grad, path.grad)
Example no. 8
def _test_forward(signature_combine, signature_grad, amount, device,
                  batch_size, input_stream, input_channels, depth, inverse,
                  scalar_term):
    paths = []
    for _ in range(amount):
        paths.append(
            torch.rand(batch_size,
                       input_stream,
                       input_channels,
                       device=device,
                       dtype=torch.double))
    signatures = []
    basepoint = False
    for path in paths:
        signature = iisignature_signature(path,
                                          depth,
                                          basepoint=basepoint,
                                          inverse=inverse,
                                          scalar_term=scalar_term)
        if signature_grad:
            signature.requires_grad_()
        signatures.append(signature)
        basepoint = path[:, -1]
    if signature_combine:
        combined_signatures = signatory.signature_combine(
            signatures[0],
            signatures[1],
            input_channels,
            depth,
            inverse=inverse,
            scalar_term=scalar_term)
    else:
        combined_signatures = signatory.multi_signature_combine(
            signatures,
            input_channels,
            depth,
            inverse=inverse,
            scalar_term=scalar_term)
    combined_paths = torch.cat(paths, dim=1)
    true_combined_signatures = iisignature_signature(combined_paths,
                                                     depth,
                                                     inverse=inverse,
                                                     scalar_term=scalar_term)
    h.diff(combined_signatures, true_combined_signatures)

    if signature_grad:
        ctx = combined_signatures.grad_fn
        assert type(ctx).__name__ == '_SignatureCombineFunctionBackward'
        ref = weakref.ref(ctx)
        del ctx
        del combined_signatures
        gc.collect()
        assert ref() is None
    else:
        assert combined_signatures.grad_fn is None
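
The combine tests rest on Chen's identity: combining the signatures of consecutive path segments gives the signature of the concatenated path, provided each later segment is basepointed at the end of the previous one. A hedged sketch with arbitrary shapes:

import torch
import signatory

depth, channels = 3, 4
path1 = torch.rand(2, 10, channels, dtype=torch.double)
path2 = torch.rand(2, 10, channels, dtype=torch.double)

signature1 = signatory.signature(path1, depth)
# The second segment must be basepointed at the end of the first.
signature2 = signatory.signature(path2, depth, basepoint=path1[:, -1])

combined = signatory.signature_combine(signature1, signature2, channels, depth)
full = signatory.signature(torch.cat([path1, path2], dim=1), depth)
assert torch.allclose(combined, full)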
Example no. 9
    def one_iteration(start, end):
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.reset_max_memory_allocated()
        try:
            tensor = candidate(start, end)
        except ValueError as e:
            try:
                true(start, end)
            except ValueError:
                return 0
            else:
                pytest.fail(str(e))
        try:
            true_tensor = true(start, end)
        except ValueError as e:
            pytest.fail(str(e))
        h.diff(tensor, true_tensor)  # Test #3

        extra(true_tensor)  # Any extra tests

        if tensor.requires_grad:
            grad = torch.rand_like(tensor)
            tensor.backward(grad)
            path_grads = []
            for path in path_obj.path:
                if path.grad is None:
                    path_grads.append(None)
                else:
                    path_grads.append(path.grad.clone())
                    path.grad.zero_()
            true_tensor.backward(grad)
            for path, path_grad in zip(path_obj.path, path_grads):
                if path_grad is None:
                    assert (path.grad is None) or (path.grad.nonzero().numel()
                                                   == 0)
                else:
                    h.diff(path.grad, path_grad)  # Test #4
                    path.grad.zero_()
            ctx = tensor.grad_fn
            assert type(ctx).__name__ == backward_name
            ref = weakref.ref(ctx)
            del ctx
            del tensor
            gc.collect()
            assert ref() is None  # Test #2
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            return torch.cuda.max_memory_allocated()
        else:
            return 0
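
Example no. 9 appears to exercise signatory.Path, which precomputes a signature over a path and then answers queries on sub-intervals of the stream dimension. A minimal sketch of that interface, assuming the usual signatory.Path API:

import torch
import signatory

path = torch.rand(1, 20, 3)
path_obj = signatory.Path(path, 3)

# Signature and logsignature restricted to a sub-interval of the stream.
sub_signature = path_obj.signature(2, 15)
sub_logsignature = path_obj.logsignature(2, 15)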
Example no. 10
def _test_no_adjustments(class_, device, batch_size, input_stream, input_channels, depth, stream, basepoint, inverse,
                         mode, path_grad):

    path = h.get_path(batch_size, input_stream, input_channels, device, path_grad)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)

    path_clone = path.clone()
    if isinstance(basepoint, torch.Tensor):
        basepoint_clone = basepoint.clone()

    logsignature = signatory_logsignature(class_, path, depth, stream, basepoint, inverse, mode)

    can_backward = path_grad or (isinstance(basepoint, torch.Tensor) and basepoint.requires_grad)
    if can_backward:
        grad = torch.rand_like(logsignature)

        logsignature_clone = logsignature.clone()
        grad_clone = grad.clone()

        logsignature.backward(grad)
    else:
        assert logsignature.grad_fn is None

    h.diff(path, path_clone)
    if isinstance(basepoint, torch.Tensor):
        h.diff(basepoint, basepoint_clone)
    if can_backward:
        h.diff(logsignature, logsignature_clone)
        h.diff(grad, grad_clone)
Example no. 11
    def one_iteration():
        gc.collect()
        torch.cuda.synchronize()
        torch.cuda.reset_max_memory_allocated()
        cuda_path = cpu_path.to('cuda')
        if path_grad:
            cuda_path.requires_grad_()
        if isinstance(cpu_basepoint, torch.Tensor):
            cuda_basepoint = cpu_basepoint.cuda()
            if basepoint is h.with_grad:
                cuda_basepoint.requires_grad_()
        else:
            cuda_basepoint = basepoint
        if isinstance(cpu_initial, torch.Tensor):
            cuda_initial = cpu_initial.cuda()
            if initial is h.with_grad:
                cuda_initial.requires_grad_()
        else:
            cuda_initial = initial

        with warnings.catch_warnings():
            warnings.filterwarnings(
                'ignore',
                message=
                "Argument 'initial' has been set but argument 'basepoint' has "
                "not.",
                category=UserWarning)
            if class_:
                cuda_signature = signature_instance(cuda_path,
                                                    basepoint=cuda_basepoint,
                                                    initial=cuda_initial)
            else:
                cuda_signature = signatory.signature(cuda_path,
                                                     depth,
                                                     stream=stream,
                                                     basepoint=cuda_basepoint,
                                                     inverse=inverse,
                                                     initial=cuda_initial)

        h.diff(cuda_signature.cpu(), cpu_signature)

        if path_grad:
            cuda_grad = cpu_grad.cuda()
            cuda_signature.backward(cuda_grad)
        torch.cuda.synchronize()
        return torch.cuda.max_memory_allocated()
Example no. 12
def _test_forward(class_, device, batch_size, input_stream, input_channels,
                  depth, stream, mode, signature_grad, scalar_term):
    path = h.get_path(batch_size,
                      input_stream,
                      input_channels,
                      device,
                      path_grad=False)
    signature = signatory.signature(path,
                                    depth,
                                    stream=stream,
                                    scalar_term=scalar_term)
    if signature_grad:
        signature.requires_grad_()
    with warnings.catch_warnings():
        warnings.filterwarnings(
            'ignore',
            message=
            "The logsignature with mode='brackets' has been requested on the "
            "GPU.",
            category=UserWarning)
        logsignature = signatory_signature_to_logsignature(
            class_,
            signature,
            input_channels,
            depth,
            stream,
            mode,
            scalar_term=scalar_term)
        true_logsignature = signatory.logsignature(path,
                                                   depth,
                                                   stream=stream,
                                                   mode=mode)
    h.diff(logsignature, true_logsignature)

    if signature_grad:
        ctx = logsignature.grad_fn
        if stream:
            ctx = ctx.next_functions[0][0]
        assert type(ctx).__name__ == '_SignatureToLogsignatureFunctionBackward'
        ref = weakref.ref(ctx)
        del ctx
        del logsignature
        gc.collect()
        assert ref() is None
    else:
        assert logsignature.grad_fn is None
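
In isolation, the conversion exercised above looks roughly like this (a sketch with arbitrary shapes; the default mode for both calls is 'words'):

import torch
import signatory

channels, depth = 3, 4
path = torch.rand(2, 10, channels)

signature = signatory.signature(path, depth)
# Convert an existing signature into a logsignature without going back to the path.
logsignature = signatory.signature_to_logsignature(signature, channels, depth)

# This should match computing the logsignature from the path directly.
true_logsignature = signatory.logsignature(path, depth)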
Example no. 13
def _test_no_adjustments(class_, device, batch_size, input_stream,
                         input_channels, depth, stream, mode, signature_grad):
    path = h.get_path(batch_size,
                      input_stream,
                      input_channels,
                      device,
                      path_grad=False)
    signature = signatory.signature(path, depth, stream=stream)
    signature_clone = signature.clone()
    if signature_grad:
        signature.requires_grad_()
    with warnings.catch_warnings():
        warnings.filterwarnings(
            'ignore',
            message=
            "The logsignature with mode='brackets' has been requested on the "
            "GPU.",
            category=UserWarning)
        logsignature = signatory_signature_to_logsignature(
            class_, signature, input_channels, depth, stream, mode)

    if signature_grad:
        grad = torch.rand_like(logsignature)
        logsignature_clone = logsignature.clone()
        grad_clone = grad.clone()
        logsignature.backward(grad)
    else:
        assert logsignature.grad_fn is None

    h.diff(signature, signature_clone)
    if signature_grad:
        h.diff(logsignature, logsignature_clone)
        h.diff(grad, grad_clone)
Example no. 14
def _test_backward(class_, device, batch_size, input_stream, input_channels, depth, stream, basepoint, inverse, mode):
    path = h.get_path(batch_size, input_stream, input_channels, device, path_grad=True)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)

    # This is the test we'd like to run, but it takes forever
    #
    # if class_:
    #     def check_fn(path, basepoint):
    #         return signatory.LogSignature(depth, stream=stream, inverse=inverse, mode=mode)(path,
    #                                                                                         basepoint=basepoint)
    # else:
    #     def check_fn(path, basepoint):
    #         return signatory.logsignature(path, depth, stream=stream, basepoint=basepoint, inverse=inverse,
    #                                       mode=mode)
    # try:
    #     autograd.gradcheck(check_fn, (path, basepoint), atol=2e-05, rtol=0.002)
    # except RuntimeError:
    #     pytest.fail()

    logsignature = signatory_logsignature(class_, path, depth, stream, basepoint, inverse, mode)

    grad = torch.rand_like(logsignature)
    logsignature.backward(grad)

    path_grad = path.grad.clone()
    path.grad.zero_()
    if isinstance(basepoint, torch.Tensor) and basepoint.requires_grad:
        basepoint_grad = basepoint.grad.clone()
        basepoint.grad.zero_()

    iisignature_logsignature_result = iisignature_logsignature(path, depth, stream, basepoint, inverse, mode)
    iisignature_logsignature_result.backward(grad)

    # iisignature uses float32 for this calculation, so we need a looser tolerance
    h.diff(path.grad, path_grad, atol=1e-6)
    if isinstance(basepoint, torch.Tensor) and basepoint.requires_grad:
        h.diff(basepoint.grad, basepoint_grad, atol=1e-6)
Example no. 15
def _test_no_adjustments(signature_combine, amount, device, batch_size,
                         input_stream, input_channels, depth, inverse,
                         signature_grad, scalar_term):
    paths = []
    for _ in range(amount):
        paths.append(
            torch.rand(batch_size,
                       input_stream,
                       input_channels,
                       device=device,
                       dtype=torch.double))

    signatures = []
    signatures_clone = []
    basepoint = False
    for path in paths:
        signature = iisignature_signature(path,
                                          depth,
                                          basepoint=basepoint,
                                          inverse=inverse,
                                          scalar_term=scalar_term)
        signatures_clone.append(signature.clone())
        if signature_grad:
            signature.requires_grad_()
        signatures.append(signature)
        basepoint = path[:, -1]
    if signature_combine:
        combined_signatures = signatory.signature_combine(
            signatures[0],
            signatures[1],
            input_channels,
            depth,
            inverse=inverse,
            scalar_term=scalar_term)
    else:
        combined_signatures = signatory.multi_signature_combine(
            signatures,
            input_channels,
            depth,
            inverse=inverse,
            scalar_term=scalar_term)

    if signature_grad:
        grad = torch.rand_like(combined_signatures)
        grad_clone = grad.clone()
        combined_signatures_clone = combined_signatures.clone()
        combined_signatures.backward(grad)

    for signature, signature_clone in zip(signatures, signatures_clone):
        h.diff(signature, signature_clone)
    if signature_grad:
        h.diff(grad, grad_clone)
        h.diff(combined_signatures, combined_signatures_clone)
Example no. 16
def _test_backward(class_, device, batch_size, input_stream, input_channels,
                   depth, stream, basepoint, inverse, initial):
    path = h.get_path(batch_size,
                      input_stream,
                      input_channels,
                      device,
                      path_grad=True)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)
    initial = h.get_initial(batch_size, input_channels, device, depth, initial)

    # This is the test we'd like to run. Unfortunately it takes forever, so we do something else instead.
    #
    # if class_:
    #     def check_fn(path, basepoint, initial):
    #         return signatory.Signature(depth, stream=stream, inverse=inverse)(path, basepoint=basepoint,
    #                                                                           initial=initial)
    # else:
    #     def check_fn(path, basepoint, initial):
    #         return signatory.signature(path, depth, stream=stream, basepoint=basepoint, inverse=inverse,
    #                                    initial=initial)
    # try:
    #     autograd.gradcheck(check_fn, (path, basepoint, initial), atol=2e-05, rtol=0.002)
    # except RuntimeError:
    #     pytest.fail()

    signature = signatory_signature(class_, path, depth, stream, basepoint,
                                    inverse, initial)

    grad = torch.rand_like(signature)
    signature.backward(grad)

    path_grad = path.grad.clone()
    path.grad.zero_()
    if isinstance(basepoint, torch.Tensor) and basepoint.requires_grad:
        basepoint_grad = basepoint.grad.clone()
        basepoint.grad.zero_()
    if isinstance(initial, torch.Tensor) and initial.requires_grad:
        initial_grad = initial.grad.clone()
        initial.grad.zero_()

    iisignature_signature_result = iisignature_signature(
        path, depth, stream, basepoint, inverse, initial)
    iisignature_signature_result.backward(grad)

    # iisignature uses float32 for this calculation, so we need a looser tolerance
    h.diff(path.grad, path_grad, atol=1e-4)
    if isinstance(basepoint, torch.Tensor) and basepoint.requires_grad:
        h.diff(basepoint.grad, basepoint_grad, atol=1e-4)
    if isinstance(initial, torch.Tensor) and initial.requires_grad:
        h.diff(initial.grad, initial_grad, atol=1e-4)
Example no. 17
def _test_batch_trick(class_, device, path_grad, batch_size, input_stream,
                      input_channels, depth, stream, basepoint, inverse,
                      initial):
    if device == 'cuda':
        threshold = 512
    else:
        from signatory import impl
        threshold = impl.hardware_concurrency()
        if threshold < 2:
            return  # can't test the batch trick in this case
    if round(float(threshold) / batch_size) < 2:
        batch_size = int(threshold / 2)

    path = h.get_path(batch_size, input_stream, input_channels, device,
                      path_grad)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)
    initial = h.get_initial(batch_size, input_channels, device, depth, initial)

    current_parallelism = signatory.max_parallelism()
    try:
        signatory.max_parallelism(1)  # disable batch trick
        signature = signatory_signature(class_, path, depth, stream, basepoint,
                                        inverse, initial)
    finally:
        signatory.max_parallelism(current_parallelism)  # enable batch trick

    batch_trick_signature = signatory.signature.__globals__[
        '_signature_batch_trick'](path,
                                  depth,
                                  stream=stream,
                                  basepoint=basepoint,
                                  inverse=inverse,
                                  initial=initial)

    assert batch_trick_signature is not None  # i.e. the batch trick is viable in this case

    h.diff(signature, batch_trick_signature)

    can_backward = path_grad or (isinstance(basepoint, torch.Tensor) and basepoint.requires_grad) or \
                   (isinstance(initial, torch.Tensor) and initial.requires_grad)
    try:
        grad = torch.rand_like(signature)
        signature.backward(grad)
    except RuntimeError:
        assert not can_backward
        return
    else:
        assert can_backward

    if path_grad:
        path_grad_ = path.grad.clone()
        path.grad.zero_()
    if isinstance(basepoint, torch.Tensor) and basepoint.requires_grad:
        basepoint_grad = basepoint.grad.clone()
        basepoint.grad.zero_()
    if isinstance(initial, torch.Tensor) and initial.requires_grad:
        initial_grad = initial.grad.clone()
        initial.grad.zero_()
    batch_trick_signature.backward(grad)
    if path_grad:
        h.diff(path_grad_, path.grad)
        path.grad.zero_()
    if isinstance(basepoint, torch.Tensor) and basepoint.requires_grad:
        h.diff(basepoint_grad, basepoint.grad)
        basepoint.grad.zero_()
    if isinstance(initial, torch.Tensor) and initial.requires_grad:
        h.diff(initial_grad, initial.grad)
        initial.grad.zero_()
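
Example no. 17 checks the internal batch trick by temporarily disabling parallelism. Toggling that setting from user code looks like this (a sketch, not taken from the test suite):

import torch
import signatory

path = torch.rand(32, 10, 3)

previous = signatory.max_parallelism()
try:
    # Temporarily force single-threaded execution, which also disables the batch trick.
    signatory.max_parallelism(1)
    signature = signatory.signature(path, 3)
finally:
    # Restore whatever limit was in effect before.
    signatory.max_parallelism(previous)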