Example #1
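This helper checks that computing a signature makes no in-place adjustments to its arguments: the path, basepoint and initial tensors (and, when a backward pass is run, the output and the incoming gradient) are cloned beforehand and compared against the originals afterwards with h.diff. Like the other examples below, it appears to come from a test suite for Signatory and assumes the usual imports (torch, gc, warnings, weakref, signatory) plus a helpers module imported as h and the signatory_signature / iisignature_signature wrappers defined elsewhere in that suite.
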
def _test_no_adjustments(class_, device, batch_size, input_stream, input_channels, depth, stream, basepoint,
                         inverse, initial, path_grad, scalar_term):
    path = h.get_path(batch_size, input_stream, input_channels, device, path_grad)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)
    initial = h.get_initial(batch_size, input_channels, device, depth, initial, scalar_term)

    path_clone = path.clone()
    if isinstance(basepoint, torch.Tensor):
        basepoint_clone = basepoint.clone()
    if isinstance(initial, torch.Tensor):
        initial_clone = initial.clone()

    signature = signatory_signature(class_, path, depth, stream, basepoint, inverse, initial, scalar_term)

    can_backward = path_grad or (isinstance(basepoint, torch.Tensor) and basepoint.requires_grad) or \
                   (isinstance(initial, torch.Tensor) and initial.requires_grad)
    if can_backward:
        grad = torch.rand_like(signature)

        signature_clone = signature.clone()
        grad_clone = grad.clone()

        signature.backward(grad)
    else:
        assert signature.grad_fn is None

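    # The inputs (and, when backward ran, the output and gradient) must be left
    # unchanged: compare each against the clone taken above.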
    h.diff(path, path_clone)
    if isinstance(basepoint, torch.Tensor):
        h.diff(basepoint, basepoint_clone)
    if isinstance(initial, torch.Tensor):
        h.diff(initial, initial_clone)
    if can_backward:
        h.diff(signature, signature_clone)
        h.diff(grad, grad_clone)
Example #2
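This helper exercises the forward pass: degenerate input shapes are expected to raise a ValueError, and otherwise the output's shape is checked, its values are compared against an iisignature-based reference, and the autograd graph is verified to be freed once all references to it are dropped.
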
def _test_forward(class_, device, path_grad, batch_size, input_stream, input_channels, depth, stream, basepoint,
                  inverse, initial, scalar_term):
    path = h.get_path(batch_size, input_stream, input_channels, device, path_grad)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)

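    # These degenerate input shapes should cause the signature call to raise a ValueError.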
    expected_exception = (batch_size < 1) or (input_channels < 1) or (basepoint is False and input_stream < 2) or \
                         (input_stream < 1)
    try:
        initial = h.get_initial(batch_size, input_channels, device, depth, initial, scalar_term)
        signature = signatory_signature(class_, path, depth, stream, basepoint, inverse, initial, scalar_term)
    except ValueError:
        if expected_exception:
            return
        else:
            raise
    else:
        assert not expected_exception

    _test_shape(signature, stream, basepoint, batch_size, input_stream, input_channels, depth, scalar_term)
    h.diff(signature, iisignature_signature(path, depth, stream, basepoint, inverse, initial, scalar_term))

    if path_grad or (isinstance(basepoint, torch.Tensor) and basepoint.requires_grad) or \
            (isinstance(initial, torch.Tensor) and initial.requires_grad):
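        # Grab the backward node and check that, once our references are dropped,
        # the autograd graph is garbage collected rather than kept alive by a
        # reference cycle.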
        ctx = signature.grad_fn
        if stream:
            ctx = ctx.next_functions[0][0]
        assert type(ctx).__name__ in ('_SignatureFunctionBackward', '_SignatureCombineFunctionBackward')
        ref = weakref.ref(ctx)
        del ctx
        del signature
        gc.collect()
        assert ref() is None
    else:
        assert signature.grad_fn is None
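
For orientation, here is a minimal, self-contained sketch of the API these helpers drive, assuming only that torch and signatory are installed; the shapes follow the (batch, stream, channels) convention used by h.get_path.

import torch
import signatory

# A batch of 4 paths, each with 10 points in 3 channels.
path = torch.rand(4, 10, 3)

# Signature truncated at depth 3: a flat vector of 3 + 3**2 + 3**3 = 39 channels per path.
sig = signatory.signature(path, 3)
print(sig.shape)         # torch.Size([4, 39])

# stream=True additionally returns the signature of every partial path.
sig_stream = signatory.signature(path, 3, stream=True)
print(sig_stream.shape)  # torch.Size([4, 9, 39])

# The module form corresponds to the class_=True branch in these tests.
sig_module = signatory.Signature(3, stream=True)(path)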
Example #3
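This helper guards against CUDA memory leaks: a reference signature is first computed on the CPU, then the same computation is repeated on the GPU, asserting that the peak CUDA memory of each later iteration never exceeds that of the first.
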
def _test_repeat_and_memory_leaks(class_, path_grad, batch_size, input_stream, input_channels, depth, stream, basepoint,
                                  inverse, initial, scalar_term):
    cpu_path = h.get_path(batch_size, input_stream, input_channels, device='cpu', path_grad=False)
    cpu_basepoint = h.get_basepoint(batch_size, input_channels, device='cpu', basepoint=basepoint)
    cpu_initial = h.get_initial(batch_size, input_channels, device='cpu', depth=depth, initial=initial,
                                scalar_term=scalar_term)

    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', message="Argument 'initial' has been set but argument 'basepoint' has "
                                                  "not.", category=UserWarning)
        if class_:
            signature_instance = signatory.Signature(depth, stream=stream, inverse=inverse, scalar_term=scalar_term)
            cpu_signature = signature_instance(cpu_path, basepoint=cpu_basepoint, initial=cpu_initial)
        else:
            cpu_signature = signatory.signature(cpu_path, depth, stream=stream, basepoint=cpu_basepoint,
                                                inverse=inverse, initial=cpu_initial, scalar_term=scalar_term)
    cpu_grad = torch.rand_like(cpu_signature)

    def one_iteration():
        gc.collect()
        torch.cuda.synchronize()
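        # Note: torch.cuda.reset_max_memory_allocated() is deprecated in more recent
        # PyTorch releases in favour of torch.cuda.reset_peak_memory_stats().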
        torch.cuda.reset_max_memory_allocated()
        cuda_path = cpu_path.to('cuda')
        if path_grad:
            cuda_path.requires_grad_()
        if isinstance(cpu_basepoint, torch.Tensor):
            cuda_basepoint = cpu_basepoint.cuda()
            if basepoint is h.with_grad:
                cuda_basepoint.requires_grad_()
        else:
            cuda_basepoint = basepoint
        if isinstance(cpu_initial, torch.Tensor):
            cuda_initial = cpu_initial.cuda()
            if initial is h.with_grad:
                cuda_initial.requires_grad_()
        else:
            cuda_initial = initial

        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', message="Argument 'initial' has been set but argument 'basepoint' has "
                                                      "not.", category=UserWarning)
            if class_:
                cuda_signature = signature_instance(cuda_path, basepoint=cuda_basepoint, initial=cuda_initial)
            else:
                cuda_signature = signatory.signature(cuda_path, depth, stream=stream, basepoint=cuda_basepoint,
                                                     inverse=inverse, initial=cuda_initial, scalar_term=scalar_term)

        h.diff(cuda_signature.cpu(), cpu_signature)

        if path_grad:
            cuda_grad = cpu_grad.cuda()
            cuda_signature.backward(cuda_grad)
        torch.cuda.synchronize()
        return torch.cuda.max_memory_allocated()

    memory_used = one_iteration()

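    # If memory were leaking (for example via reference cycles keeping CUDA
    # tensors alive), peak usage would creep upwards on the later iterations.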
    for repeat in range(10):
        assert one_iteration() <= memory_used
Example #4
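This helper checks the backward pass: gradients with respect to the path, basepoint and initial condition are compared against those produced by an iisignature-based reference implementation. As the in-code comment notes, a full autograd.gradcheck would be far too slow here.
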
def _test_backward(class_, device, batch_size, input_stream, input_channels,
                   depth, stream, basepoint, inverse, initial):
    path = h.get_path(batch_size,
                      input_stream,
                      input_channels,
                      device,
                      path_grad=True)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)
    initial = h.get_initial(batch_size, input_channels, device, depth, initial)

    # This is the test we'd like to run. Unfortunately it takes forever, so we do something else instead.
    #
    # if class_:
    #     def check_fn(path, basepoint, initial):
    #         return signatory.Signature(depth, stream=stream, inverse=inverse)(path, basepoint=basepoint,
    #                                                                           initial=initial)
    # else:
    #     def check_fn(path, basepoint, initial):
    #         return signatory.signature(path, depth, stream=stream, basepoint=basepoint, inverse=inverse,
    #                                    initial=initial)
    # try:
    #     autograd.gradcheck(check_fn, (path, basepoint, initial), atol=2e-05, rtol=0.002)
    # except RuntimeError:
    #     pytest.fail()

    signature = signatory_signature(class_, path, depth, stream, basepoint,
                                    inverse, initial)

    grad = torch.rand_like(signature)
    signature.backward(grad)

    path_grad = path.grad.clone()
    path.grad.zero_()
    if isinstance(basepoint, torch.Tensor) and basepoint.requires_grad:
        basepoint_grad = basepoint.grad.clone()
        basepoint.grad.zero_()
    if isinstance(initial, torch.Tensor) and initial.requires_grad:
        initial_grad = initial.grad.clone()
        initial.grad.zero_()

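    # Recompute the signature with the iisignature-based reference, backpropagate
    # the same gradient, and compare against the gradients saved above.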
    iisignature_signature_result = iisignature_signature(
        path, depth, stream, basepoint, inverse, initial)
    iisignature_signature_result.backward(grad)

    # iisignature uses float32 internally for this calculation, so compare with a
    # looser tolerance than usual
    h.diff(path.grad, path_grad, atol=1e-4)
    if isinstance(basepoint, torch.Tensor) and basepoint.requires_grad:
        h.diff(basepoint.grad, basepoint_grad, atol=1e-4)
    if isinstance(initial, torch.Tensor) and initial.requires_grad:
        h.diff(initial.grad, initial_grad, atol=1e-4)
Example #5
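This helper tests the internal "batch trick" optimisation: the signature is first computed with parallelism disabled, then via the private _signature_batch_trick code path, and both the values and the resulting gradients are required to agree.
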
def _test_batch_trick(class_, device, path_grad, batch_size, input_stream,
                      input_channels, depth, stream, basepoint, inverse,
                      initial):
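    # The batch trick splits work across parallel workers, so it is only worth
    # testing when enough parallelism is available; shrink the batch if needed
    # so that the trick can actually trigger.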
    if device == 'cuda':
        threshold = 512
    else:
        from signatory import impl
        threshold = impl.hardware_concurrency()
        if threshold < 2:
            return  # can't test the batch trick in this case
    if round(float(threshold) / batch_size) < 2:
        batch_size = int(threshold / 2)

    path = h.get_path(batch_size, input_stream, input_channels, device,
                      path_grad)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)
    initial = h.get_initial(batch_size, input_channels, device, depth, initial)

    current_parallelism = signatory.max_parallelism()
    try:
        signatory.max_parallelism(1)  # disable batch trick
        signature = signatory_signature(class_, path, depth, stream, basepoint,
                                        inverse, initial)
    finally:
        signatory.max_parallelism(current_parallelism)  # enable batch trick

    batch_trick_signature = signatory.signature.__globals__[
        '_signature_batch_trick'](path,
                                  depth,
                                  stream=stream,
                                  basepoint=basepoint,
                                  inverse=inverse,
                                  initial=initial)

    assert batch_trick_signature is not None  # check that the batch trick is viable in this case

    h.diff(signature, batch_trick_signature)

    can_backward = path_grad or (isinstance(basepoint, torch.Tensor) and basepoint.requires_grad) or \
                   (isinstance(initial, torch.Tensor) and initial.requires_grad)
    try:
        grad = torch.rand_like(signature)
        signature.backward(grad)
    except RuntimeError:
        assert not can_backward
        return
    else:
        assert can_backward

    if path_grad:
        path_grad_ = path.grad.clone()
        path.grad.zero_()
    if isinstance(basepoint, torch.Tensor) and basepoint.requires_grad:
        basepoint_grad = basepoint.grad.clone()
        basepoint.grad.zero_()
    if isinstance(initial, torch.Tensor) and initial.requires_grad:
        initial_grad = initial.grad.clone()
        initial.grad.zero_()
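    # Backpropagating through the batch-trick result with the same gradient should
    # reproduce the gradients obtained from the direct computation above.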
    batch_trick_signature.backward(grad)
    if path_grad:
        h.diff(path_grad_, path.grad)
        path.grad.zero_()
    if isinstance(basepoint, torch.Tensor) and basepoint.requires_grad:
        h.diff(basepoint_grad, basepoint.grad)
        basepoint.grad.zero_()
    if isinstance(initial, torch.Tensor) and initial.requires_grad:
        h.diff(initial_grad, initial.grad)
        initial.grad.zero_()
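
Finally, the basepoint and initial arguments exercised throughout these tests support computing a signature piecewise; the following sketch of that chaining behaviour assumes Signatory's documented semantics for initial.

import torch
import signatory

path = torch.rand(4, 10, 3)
depth = 3

# Signature of the whole path in one go.
full = signatory.signature(path, depth)

# The same signature computed in two pieces: pass the first piece's signature as
# `initial` and its final point as `basepoint` when computing the second piece.
first = signatory.signature(path[:, :5], depth)
combined = signatory.signature(path[:, 5:], depth,
                               basepoint=path[:, 4], initial=first)

assert torch.allclose(full, combined, atol=1e-5)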