Example #1
def _test_no_adjustments(class_, device, batch_size, input_stream, input_channels, depth, stream, basepoint,
                         inverse, initial, path_grad, scalar_term):
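    # Check that computing the signature does not modify its inputs (path, basepoint, initial)
    # in place, and that calling backward does not modify the output or the gradient tensor either.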
    path = h.get_path(batch_size, input_stream, input_channels, device, path_grad)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)
    initial = h.get_initial(batch_size, input_channels, device, depth, initial, scalar_term)

    path_clone = path.clone()
    if isinstance(basepoint, torch.Tensor):
        basepoint_clone = basepoint.clone()
    if isinstance(initial, torch.Tensor):
        initial_clone = initial.clone()

    signature = signatory_signature(class_, path, depth, stream, basepoint, inverse, initial, scalar_term)

    can_backward = path_grad or (isinstance(basepoint, torch.Tensor) and basepoint.requires_grad) or \
                   (isinstance(initial, torch.Tensor) and initial.requires_grad)
    if can_backward:
        grad = torch.rand_like(signature)

        signature_clone = signature.clone()
        grad_clone = grad.clone()

        signature.backward(grad)
    else:
        assert signature.grad_fn is None

    h.diff(path, path_clone)
    if isinstance(basepoint, torch.Tensor):
        h.diff(basepoint, basepoint_clone)
    if isinstance(initial, torch.Tensor):
        h.diff(initial, initial_clone)
    if can_backward:
        h.diff(signature, signature_clone)
        h.diff(grad, grad_clone)
Example #2
def _test_no_adjustments(class_, device, batch_size, input_stream,
                         input_channels, depth, stream, mode, signature_grad):
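    # Check that converting a signature to a logsignature does not modify the input signature,
    # the output logsignature or the gradient tensor in place.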
    path = h.get_path(batch_size,
                      input_stream,
                      input_channels,
                      device,
                      path_grad=False)
    signature = signatory.signature(path, depth, stream=stream)
    signature_clone = signature.clone()
    if signature_grad:
        signature.requires_grad_()
    with warnings.catch_warnings():
        warnings.filterwarnings(
            'ignore',
            message=
            "The logsignature with mode='brackets' has been requested on the "
            "GPU.",
            category=UserWarning)
        logsignature = signatory_signature_to_logsignature(
            class_, signature, input_channels, depth, stream, mode)

    if signature_grad:
        grad = torch.rand_like(logsignature)
        logsignature_clone = logsignature.clone()
        grad_clone = grad.clone()
        logsignature.backward(grad)
    else:
        assert logsignature.grad_fn is None

    h.diff(signature, signature_clone)
    if signature_grad:
        h.diff(logsignature, logsignature_clone)
        h.diff(grad, grad_clone)
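For orientation, here is a minimal sketch (not part of the test suite) of the conversion these tests exercise, using signatory's public functional and module forms; the tensor sizes are illustrative assumptions:

import torch
import signatory

path = torch.rand(2, 10, 3)            # (batch, stream, channels), illustrative sizes
sig = signatory.signature(path, 3)     # depth 3

# Functional form
logsig = signatory.signature_to_logsignature(sig, 3, 3, mode='words')

# Equivalent nn.Module form
module = signatory.SignatureToLogsignature(3, 3, mode='words')
assert torch.allclose(module(sig), logsig)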
Example #3
def _test_forward(class_, device, path_grad, batch_size, input_stream, input_channels, depth, stream, basepoint,
                  inverse, initial, scalar_term):
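    # Check the forward computation of the signature: expected exceptions for degenerate inputs,
    # output shape, agreement with the iisignature reference, and that the backward 'ctx' object
    # is garbage collected (no reference cycles, hence no memory leak).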
    path = h.get_path(batch_size, input_stream, input_channels, device, path_grad)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)

    expected_exception = (batch_size < 1) or (input_channels < 1) or (basepoint is False and input_stream < 2) or \
                         (input_stream < 1)
    try:
        initial = h.get_initial(batch_size, input_channels, device, depth, initial, scalar_term)
        signature = signatory_signature(class_, path, depth, stream, basepoint, inverse, initial, scalar_term)
    except ValueError:
        if expected_exception:
            return
        else:
            raise
    else:
        assert not expected_exception

    _test_shape(signature, stream, basepoint, batch_size, input_stream, input_channels, depth, scalar_term)
    h.diff(signature, iisignature_signature(path, depth, stream, basepoint, inverse, initial, scalar_term))

    if path_grad or (isinstance(basepoint, torch.Tensor) and basepoint.requires_grad) or \
            (isinstance(initial, torch.Tensor) and initial.requires_grad):
        ctx = signature.grad_fn
        if stream:
            ctx = ctx.next_functions[0][0]
        assert type(ctx).__name__ in ('_SignatureFunctionBackward', '_SignatureCombineFunctionBackward')
        ref = weakref.ref(ctx)
        del ctx
        del signature
        gc.collect()
        assert ref() is None
    else:
        assert signature.grad_fn is None
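The signatory_signature helper used above is defined elsewhere in the test module; judging from the dispatch in Example #6 below, it presumably amounts to something like the following sketch (a hypothetical reconstruction, not the actual helper):

def signatory_signature(class_, path, depth, stream, basepoint, inverse, initial, scalar_term=False):
    # class_ selects between the nn.Module form and the functional form of the same computation.
    if class_:
        instance = signatory.Signature(depth, stream=stream, inverse=inverse, scalar_term=scalar_term)
        return instance(path, basepoint=basepoint, initial=initial)
    else:
        return signatory.signature(path, depth, stream=stream, basepoint=basepoint,
                                   inverse=inverse, initial=initial, scalar_term=scalar_term)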
Example #4
def _test_no_adjustments(class_, device, batch_size, input_stream, input_channels, depth, stream, basepoint, inverse,
                         mode, path_grad):
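    # Check that computing the logsignature does not modify the path, the basepoint, the output
    # or the gradient tensor in place.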

    path = h.get_path(batch_size, input_stream, input_channels, device, path_grad)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)

    path_clone = path.clone()
    if isinstance(basepoint, torch.Tensor):
        basepoint_clone = basepoint.clone()

    logsignature = signatory_logsignature(class_, path, depth, stream, basepoint, inverse, mode)

    can_backward = path_grad or (isinstance(basepoint, torch.Tensor) and basepoint.requires_grad)
    if can_backward:
        grad = torch.rand_like(logsignature)

        logsignature_clone = logsignature.clone()
        grad_clone = grad.clone()

        logsignature.backward(grad)
    else:
        assert logsignature.grad_fn is None

    h.diff(path, path_clone)
    if isinstance(basepoint, torch.Tensor):
        h.diff(basepoint, basepoint_clone)
    if can_backward:
        h.diff(logsignature, logsignature_clone)
        h.diff(grad, grad_clone)
Example #5
def _test_forward(class_, device, path_grad, batch_size, input_stream, input_channels, depth, stream, basepoint,
                  inverse, mode):
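    # Check the forward computation of the logsignature: expected exceptions for degenerate
    # inputs, output shape, agreement with the iisignature reference, and that the backward
    # 'ctx' object is garbage collected.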

    path = h.get_path(batch_size, input_stream, input_channels, device, path_grad)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)

    expected_exception = (batch_size < 1) or (input_channels < 1) or (basepoint is False and input_stream < 2) or \
                         (input_stream < 1)
    try:
        logsignature = signatory_logsignature(class_, path, depth, stream, basepoint, inverse, mode)
    except ValueError:
        if expected_exception:
            return
        else:
            raise
    else:
        assert not expected_exception

    _test_shape(logsignature, mode, batch_size, input_stream, input_channels, depth, stream, basepoint)
    h.diff(logsignature, iisignature_logsignature(path, depth, stream, basepoint, inverse, mode))

    # Check that the 'ctx' object is properly garbage collected and we don't have a memory leak
    # (a leak is easy to introduce accidentally because of PyTorch bug 25340)
    if path_grad or (isinstance(basepoint, torch.Tensor) and basepoint.requires_grad):
        ctx = logsignature.grad_fn
        if stream:
            ctx = ctx.next_functions[0][0]
        assert type(ctx).__name__ == '_SignatureToLogsignatureFunctionBackward'
        ref = weakref.ref(ctx)
        del ctx
        del logsignature
        gc.collect()
        assert ref() is None
    else:
        assert logsignature.grad_fn is None
Example #6
def _test_repeat_and_memory_leaks(class_, path_grad, batch_size, input_stream, input_channels, depth, stream, basepoint,
                                  inverse, initial, scalar_term):
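    # Compute the signature once on the CPU, then repeatedly on the GPU, checking that the
    # results agree and that peak GPU memory usage does not grow across repeats (no leak).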
    cpu_path = h.get_path(batch_size, input_stream, input_channels, device='cpu', path_grad=False)
    cpu_basepoint = h.get_basepoint(batch_size, input_channels, device='cpu', basepoint=basepoint)
    cpu_initial = h.get_initial(batch_size, input_channels, device='cpu', depth=depth, initial=initial,
                                scalar_term=scalar_term)

    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', message="Argument 'initial' has been set but argument 'basepoint' has "
                                                  "not.", category=UserWarning)
        if class_:
            signature_instance = signatory.Signature(depth, stream=stream, inverse=inverse, scalar_term=scalar_term)
            cpu_signature = signature_instance(cpu_path, basepoint=cpu_basepoint, initial=cpu_initial)
        else:
            cpu_signature = signatory.signature(cpu_path, depth, stream=stream, basepoint=cpu_basepoint,
                                                inverse=inverse, initial=cpu_initial, scalar_term=scalar_term)
    cpu_grad = torch.rand_like(cpu_signature)

    def one_iteration():
        gc.collect()
        torch.cuda.synchronize()
        torch.cuda.reset_max_memory_allocated()
        cuda_path = cpu_path.to('cuda')
        if path_grad:
            cuda_path.requires_grad_()
        if isinstance(cpu_basepoint, torch.Tensor):
            cuda_basepoint = cpu_basepoint.cuda()
            if basepoint is h.with_grad:
                cuda_basepoint.requires_grad_()
        else:
            cuda_basepoint = basepoint
        if isinstance(cpu_initial, torch.Tensor):
            cuda_initial = cpu_initial.cuda()
            if initial is h.with_grad:
                cuda_initial.requires_grad_()
        else:
            cuda_initial = initial

        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', message="Argument 'initial' has been set but argument 'basepoint' has "
                                                      "not.", category=UserWarning)
            if class_:
                cuda_signature = signature_instance(cuda_path, basepoint=cuda_basepoint, initial=cuda_initial)
            else:
                cuda_signature = signatory.signature(cuda_path, depth, stream=stream, basepoint=cuda_basepoint,
                                                     inverse=inverse, initial=cuda_initial, scalar_term=scalar_term)

        h.diff(cuda_signature.cpu(), cpu_signature)

        if path_grad:
            cuda_grad = cpu_grad.cuda()
            cuda_signature.backward(cuda_grad)
        torch.cuda.synchronize()
        return torch.cuda.max_memory_allocated()

    memory_used = one_iteration()

    for repeat in range(10):
        assert one_iteration() <= memory_used
Example #7
def _test_repeat_and_memory_leaks(class_, batch_size, input_stream,
                                  input_channels, depth, stream, mode,
                                  signature_grad):
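    # As above, but for the signature-to-logsignature conversion: repeated GPU runs should agree
    # with the CPU result and peak memory usage should stay bounded across repeats.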
    cpu_path = h.get_path(batch_size,
                          input_stream,
                          input_channels,
                          device='cpu',
                          path_grad=False)
    cpu_signature = signatory.signature(cpu_path, depth, stream=stream)
    if class_:
        signature_to_logsignature_instance = signatory.SignatureToLogsignature(
            input_channels, depth, stream=stream, mode=mode)
        cpu_logsignature = signature_to_logsignature_instance(cpu_signature)
    else:
        cpu_logsignature = signatory.signature_to_logsignature(cpu_signature,
                                                               input_channels,
                                                               depth,
                                                               stream=stream,
                                                               mode=mode)
    cpu_grad = torch.rand_like(cpu_logsignature)

    def one_iteration():
        gc.collect()
        torch.cuda.synchronize()
        torch.cuda.reset_max_memory_allocated()
        cuda_signature = cpu_signature.to('cuda')
        if signature_grad:
            cuda_signature.requires_grad_()
        with warnings.catch_warnings():
            warnings.filterwarnings(
                'ignore',
                message=
                "The logsignature with mode='brackets' has been requested on the "
                "GPU.",
                category=UserWarning)
            if class_:
                cuda_logsignature = signature_to_logsignature_instance(
                    cuda_signature)
            else:
                cuda_logsignature = signatory.signature_to_logsignature(
                    cuda_signature,
                    input_channels,
                    depth,
                    stream=stream,
                    mode=mode)

        h.diff(cuda_logsignature.cpu(), cpu_logsignature)

        if signature_grad:
            cuda_grad = cpu_grad.cuda()
            cuda_logsignature.backward(cuda_grad)
        torch.cuda.synchronize()
        return torch.cuda.max_memory_allocated()

    memory_used = one_iteration()
    for repeat in range(10):
        # This one is a bit inconsistent in how much memory it uses from run to run, so we give
        # some leeway by doubling
        assert one_iteration() <= 2 * memory_used
Example #8
def _test_backward(class_, device, batch_size, input_stream, input_channels,
                   depth, stream, mode, scalar_term):
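    # Check the backward pass of signature_to_logsignature by comparing the gradient of the path
    # obtained via signature -> logsignature against computing the logsignature directly.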

    # This test (in the comment below) runs out of memory! So we don't do this, and do something else instead.
    #
    # path = h.get_path(batch_size, input_stream, input_channels, device, path_grad=False)
    # signature = signatory.signature(path, depth, stream=stream)
    # signature.requires_grad_()
    # if class_:
    #     def check_fn(signature):
    #         return signatory.SignatureToLogSignature(input_channels, depth, stream=stream, mode=mode)(signature)
    # else:
    #     def check_fn(signature):
    #         return signatory.signature_to_logsignature(signature, input_channels, depth, stream=stream, mode=mode)
    # try:
    #     autograd.gradcheck(check_fn, (signature,), atol=2e-05, rtol=0.002)
    # except RuntimeError:
    #     pytest.fail()

    path = h.get_path(batch_size,
                      input_stream,
                      input_channels,
                      device,
                      path_grad=True)
    signature = signatory.signature(path,
                                    depth,
                                    stream=stream,
                                    scalar_term=scalar_term)
    with warnings.catch_warnings():
        warnings.filterwarnings(
            'ignore',
            message=
            "The logsignature with mode='brackets' has been requested on the "
            "GPU.",
            category=UserWarning)
        logsignature = signatory_signature_to_logsignature(
            class_, signature, input_channels, depth, stream, mode,
            scalar_term)

    grad = torch.rand_like(logsignature)
    logsignature.backward(grad)

    path_grad = path.grad.clone()
    path.grad.zero_()

    with warnings.catch_warnings():
        warnings.filterwarnings(
            'ignore',
            message=
            "The logsignature with mode='brackets' has been requested on the "
            "GPU.",
            category=UserWarning)
        true_logsignature = signatory.logsignature(path,
                                                   depth,
                                                   stream=stream,
                                                   mode=mode)
    true_logsignature.backward(grad)
    h.diff(path.grad, path_grad)
Example #9
def _test_backward(class_, device, batch_size, input_stream, input_channels,
                   depth, stream, basepoint, inverse, initial):
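    # Check the backward pass of the signature by comparing the gradients of path, basepoint and
    # initial against the iisignature reference implementation.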
    path = h.get_path(batch_size,
                      input_stream,
                      input_channels,
                      device,
                      path_grad=True)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)
    initial = h.get_initial(batch_size, input_channels, device, depth, initial)

    # This is the test we'd like to run. Unfortunately it takes forever, so we do something else instead.
    #
    # if class_:
    #     def check_fn(path, basepoint, initial):
    #         return signatory.Signature(depth, stream=stream, inverse=inverse)(path, basepoint=basepoint,
    #                                                                           initial=initial)
    # else:
    #     def check_fn(path, basepoint, initial):
    #         return signatory.signature(path, depth, stream=stream, basepoint=basepoint, inverse=inverse,
    #                                    initial=initial)
    # try:
    #     autograd.gradcheck(check_fn, (path, basepoint, initial), atol=2e-05, rtol=0.002)
    # except RuntimeError:
    #     pytest.fail()

    signature = signatory_signature(class_, path, depth, stream, basepoint,
                                    inverse, initial)

    grad = torch.rand_like(signature)
    signature.backward(grad)

    path_grad = path.grad.clone()
    path.grad.zero_()
    if isinstance(basepoint, torch.Tensor) and basepoint.requires_grad:
        basepoint_grad = basepoint.grad.clone()
        basepoint.grad.zero_()
    if isinstance(initial, torch.Tensor) and initial.requires_grad:
        initial_grad = initial.grad.clone()
        initial.grad.zero_()

    iisignature_signature_result = iisignature_signature(
        path, depth, stream, basepoint, inverse, initial)
    iisignature_signature_result.backward(grad)

    # iisignature uses float32 for this calculation, so we need a looser tolerance
    h.diff(path.grad, path_grad, atol=1e-4)
    if isinstance(basepoint, torch.Tensor) and basepoint.requires_grad:
        h.diff(basepoint.grad, basepoint_grad, atol=1e-4)
    if isinstance(initial, torch.Tensor) and initial.requires_grad:
        h.diff(initial.grad, initial_grad, atol=1e-4)
Example #10
def _test_forward(class_, device, batch_size, input_stream, input_channels,
                  depth, stream, mode, signature_grad, scalar_term):
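    # Check the forward computation of signature_to_logsignature against computing the
    # logsignature directly, and that the backward 'ctx' object is garbage collected.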
    path = h.get_path(batch_size,
                      input_stream,
                      input_channels,
                      device,
                      path_grad=False)
    signature = signatory.signature(path,
                                    depth,
                                    stream=stream,
                                    scalar_term=scalar_term)
    if signature_grad:
        signature.requires_grad_()
    with warnings.catch_warnings():
        warnings.filterwarnings(
            'ignore',
            message=
            "The logsignature with mode='brackets' has been requested on the "
            "GPU.",
            category=UserWarning)
        logsignature = signatory_signature_to_logsignature(
            class_,
            signature,
            input_channels,
            depth,
            stream,
            mode,
            scalar_term=scalar_term)
        true_logsignature = signatory.logsignature(path,
                                                   depth,
                                                   stream=stream,
                                                   mode=mode)
    h.diff(logsignature, true_logsignature)

    if signature_grad:
        ctx = logsignature.grad_fn
        if stream:
            ctx = ctx.next_functions[0][0]
        assert type(ctx).__name__ == '_SignatureToLogsignatureFunctionBackward'
        ref = weakref.ref(ctx)
        del ctx
        del logsignature
        gc.collect()
        assert ref() is None
    else:
        assert logsignature.grad_fn is None
Example #11
def _test_path(device, path_grad, batch_size, input_stream, input_channels,
               depth, basepoint, update_lengths, update_grads, extrarandom):
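    # Check that a signatory.Path object gives correct signatures and logsignatures, both before
    # and after being extended with update().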
    path = h.get_path(batch_size, input_stream, input_channels, device,
                      path_grad)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)
    path_obj = signatory.Path(path, depth, basepoint=basepoint)

    if isinstance(basepoint, torch.Tensor):
        full_path = torch.cat([basepoint.unsqueeze(1), path], dim=1)
    elif basepoint is True:
        full_path = torch.cat([
            torch.zeros(batch_size,
                        1,
                        input_channels,
                        device=device,
                        dtype=torch.double), path
        ],
                              dim=1)
    else:
        full_path = path

    # First of all test a Path with no updates
    _test_signature(path_obj, full_path, depth, extrarandom)
    _test_logsignature(path_obj, full_path, depth, extrarandom)
    assert path_obj.depth == depth

    if len(update_lengths) > 1:
        # Then test Path with variable amounts of updates
        for length, grad in zip(update_lengths, update_grads):
            new_path = torch.rand(batch_size,
                                  length,
                                  input_channels,
                                  dtype=torch.double,
                                  device=device,
                                  requires_grad=grad)
            path_obj.update(new_path)
            full_path = torch.cat([full_path, new_path], dim=1)

        _test_signature(path_obj, full_path, depth, extrarandom)
        _test_logsignature(path_obj, full_path, depth, extrarandom)
        assert path_obj.depth == depth
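For orientation, a minimal usage sketch (not from the test suite) of the signatory.Path object being tested here, assuming its signature() and update() methods behave as their names suggest:

import torch
import signatory

path = torch.rand(2, 10, 3, dtype=torch.double)   # (batch, stream, channels), illustrative sizes
path_obj = signatory.Path(path, 3)                # depth 3

sig = path_obj.signature()                        # signature over the whole stored path
more = torch.rand(2, 5, 3, dtype=torch.double)
path_obj.update(more)                             # append further stream data

# Should agree with computing directly on the concatenated path
full = torch.cat([path, more], dim=1)
assert torch.allclose(path_obj.signature(), signatory.signature(full, 3))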
Example #12
def _test_backward(class_, device, batch_size, input_stream, input_channels, depth, stream, basepoint, inverse, mode):
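    # Check the backward pass of the logsignature by comparing the gradients of path and
    # basepoint against the iisignature reference implementation.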
    path = h.get_path(batch_size, input_stream, input_channels, device, path_grad=True)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)

    # This is the test we'd like to run, but it takes forever
    #
    # if class_:
    #     def check_fn(path, basepoint):
    #         return signatory.LogSignature(depth, stream=stream, inverse=inverse, mode=mode)(path,
    #                                                                                         basepoint=basepoint)
    # else:
    #     def check_fn(path, basepoint):
    #         return signatory.logsignature(path, depth, stream=stream, basepoint=basepoint, inverse=inverse,
    #                                       mode=mode)
    # try:
    #     autograd.gradcheck(check_fn, (path, basepoint), atol=2e-05, rtol=0.002)
    # except RuntimeError:
    #     pytest.fail()

    logsignature = signatory_logsignature(class_, path, depth, stream, basepoint, inverse, mode)

    grad = torch.rand_like(logsignature)
    logsignature.backward(grad)

    path_grad = path.grad.clone()
    path.grad.zero_()
    if isinstance(basepoint, torch.Tensor) and basepoint.requires_grad:
        basepoint_grad = basepoint.grad.clone()
        basepoint.grad.zero_()

    iisignature_logsignature_result = iisignature_logsignature(path, depth, stream, basepoint, inverse, mode)
    iisignature_logsignature_result.backward(grad)

    # iisignature uses float32 for this calculation, so we need a looser tolerance
    h.diff(path.grad, path_grad, atol=1e-6)
    if isinstance(basepoint, torch.Tensor) and basepoint.requires_grad:
        h.diff(basepoint.grad, basepoint_grad, atol=1e-6)
Example #13
def _test_batch_trick(class_, device, path_grad, batch_size, input_stream,
                      input_channels, depth, stream, basepoint, inverse,
                      initial):
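    # Check that signatory's internal 'batch trick' (a parallelisation strategy it can apply for
    # small batch sizes) produces the same forward values and gradients as the computation with
    # parallelism disabled via max_parallelism(1).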
    if device == 'cuda':
        threshold = 512
    else:
        from signatory import impl
        threshold = impl.hardware_concurrency()
        if threshold < 2:
            return  # can't test the batch trick in this case
    if round(float(threshold) / batch_size) < 2:
        batch_size = int(threshold / 2)

    path = h.get_path(batch_size, input_stream, input_channels, device,
                      path_grad)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)
    initial = h.get_initial(batch_size, input_channels, device, depth, initial)

    current_parallelism = signatory.max_parallelism()
    try:
        signatory.max_parallelism(1)  # disable batch trick
        signature = signatory_signature(class_, path, depth, stream, basepoint,
                                        inverse, initial)
    finally:
        signatory.max_parallelism(current_parallelism)  # enable batch trick

    batch_trick_signature = signatory.signature.__globals__[
        '_signature_batch_trick'](path,
                                  depth,
                                  stream=stream,
                                  basepoint=basepoint,
                                  inverse=inverse,
                                  initial=initial)

    assert batch_trick_signature is not None  # i.e. the batch trick is viable in this case

    h.diff(signature, batch_trick_signature)

    can_backward = path_grad or (isinstance(basepoint, torch.Tensor) and basepoint.requires_grad) or \
                   (isinstance(initial, torch.Tensor) and initial.requires_grad)
    try:
        grad = torch.rand_like(signature)
        signature.backward(grad)
    except RuntimeError:
        assert not can_backward
        return
    else:
        assert can_backward

    if path_grad:
        path_grad_ = path.grad.clone()
        path.grad.zero_()
    if isinstance(basepoint, torch.Tensor) and basepoint.requires_grad:
        basepoint_grad = basepoint.grad.clone()
        basepoint.grad.zero_()
    if isinstance(initial, torch.Tensor) and initial.requires_grad:
        initial_grad = initial.grad.clone()
        initial.grad.zero_()
    batch_trick_signature.backward(grad)
    if path_grad:
        h.diff(path_grad_, path.grad)
        path.grad.zero_()
    if isinstance(basepoint, torch.Tensor) and basepoint.requires_grad:
        h.diff(basepoint_grad, basepoint.grad)
        basepoint.grad.zero_()
    if isinstance(initial, torch.Tensor) and initial.requires_grad:
        h.diff(initial_grad, initial.grad)
        initial.grad.zero_()
Example #14
def _test_path(device, path_grad, batch_size, input_stream, input_channels,
               depth, basepoint, update_lengths, update_grads, scalar_term,
               extrarandom, which):
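    # Check signatory.Path together with the objects derived from it (copies, shuffles, and
    # integer/slice/tensor/array/list indexing), verifying that the derived objects behave like
    # the corresponding slice of the original and that the original is left unchanged.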
    path = h.get_path(batch_size, input_stream, input_channels, device,
                      path_grad)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)
    path_obj = signatory.Path(path,
                              depth,
                              basepoint=basepoint,
                              scalar_term=scalar_term)

    if isinstance(basepoint, torch.Tensor):
        full_path = torch.cat([basepoint.unsqueeze(1), path], dim=1)
    elif basepoint is True:
        full_path = torch.cat([
            torch.zeros(batch_size,
                        1,
                        input_channels,
                        device=device,
                        dtype=torch.double), path
        ],
                              dim=1)
    else:
        full_path = path

    if not path_grad and not (isinstance(basepoint, torch.Tensor)
                              and basepoint.requires_grad):
        backup_path_obj = copy.deepcopy(path_obj)

        # derived objects to test
        copy_path_obj = copy.copy(path_obj)
        shuffle_path_obj1, perm1 = path_obj.shuffle()
        shuffle_path_obj2, perm2 = copy.deepcopy(path_obj).shuffle_()
        getitem1 = _randint(batch_size)
        getitem_path_obj1 = path_obj[getitem1]  # integer

        all_derived = [(copy_path_obj, slice(None)),
                       (shuffle_path_obj1, perm1), (shuffle_path_obj2, perm2),
                       (getitem_path_obj1, getitem1)]

        start = _randint(batch_size)
        end = _randint(batch_size)
        getitem2 = slice(start, end)
        getitem3 = torch.randint(low=0,
                                 high=batch_size,
                                 size=(_randint(int(1.5 * batch_size)), ))
        getitem4 = torch.randint(low=0,
                                 high=batch_size,
                                 size=(_randint(int(1.5 *
                                                    batch_size)), )).numpy()
        getitem5 = torch.randint(low=0,
                                 high=batch_size,
                                 size=(_randint(int(1.5 *
                                                    batch_size)), )).tolist()
        try:
            getitem_path_obj2 = path_obj[
                getitem2]  # slice, perhaps a 'null' slice
        except IndexError as e:
            if start >= end:
                pass
            else:
                pytest.fail(str(e))
        else:
            all_derived.append((getitem_path_obj2, getitem2))
        try:
            getitem_path_obj3 = path_obj[getitem3]  # 1D tensor
        except IndexError as e:
            if len(getitem3) == 0:
                pass
            else:
                pytest.fail(str(e))
        else:
            all_derived.append((getitem_path_obj3, getitem3))
        try:
            getitem_path_obj4 = path_obj[getitem4]  # array
        except IndexError as e:
            if len(getitem4) == 0:
                pass
            else:
                pytest.fail(str(e))
        else:
            all_derived.append((getitem_path_obj4, getitem4))
        try:
            getitem_path_obj5 = path_obj[getitem5]  # list
        except IndexError as e:
            if len(getitem5) == 0:
                pass
            else:
                pytest.fail(str(e))
        else:
            all_derived.append((getitem_path_obj5, getitem5))

        if which == 'random':
            all_derived = [random.choice(all_derived)]
        elif which == 'none':
            all_derived = []

        for derived_path_obj, derived_index in all_derived:
            # tests that the derived objects do what they claim to do
            _test_derived(path_obj, derived_path_obj, derived_index,
                          extrarandom)
            # tests that the derived objects are consistent wrt themselves
            full_path_ = full_path[derived_index]
            if isinstance(derived_index, int):
                full_path_ = full_path_.unsqueeze(0)
            _test_path_obj(full_path_.size(0), input_channels, device,
                           derived_path_obj, full_path_, depth, update_lengths,
                           update_grads, scalar_term, extrarandom)
        # tests that the changes to the derived objects have not affected the original
        assert path_obj == backup_path_obj

    # finally test the original object
    _test_path_obj(batch_size, input_channels, device, path_obj, full_path,
                   depth, update_lengths, update_grads, scalar_term,
                   extrarandom)
Example #15
def _test_repeat_and_memory_leaks(class_, path_grad, batch_size, input_stream, input_channels, depth, stream, basepoint,
                                  inverse, mode):
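    # Compute the logsignature once on the CPU, then repeatedly on the GPU, checking that the
    # results agree and that peak GPU memory usage does not grow across repeats.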
    cpu_path = h.get_path(batch_size, input_stream, input_channels, device='cpu', path_grad=False)
    cpu_basepoint = h.get_basepoint(batch_size, input_channels, device='cpu', basepoint=basepoint)

    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', message="The logsignature with mode='brackets' has been requested on the "
                                                  "GPU.", category=UserWarning)
        if class_:
            logsignature_instance = signatory.LogSignature(depth, stream=stream, inverse=inverse, mode=mode)
            cpu_logsignature = logsignature_instance(cpu_path, basepoint=cpu_basepoint)
        else:
            cpu_logsignature = signatory.logsignature(cpu_path, depth, stream=stream, basepoint=cpu_basepoint,
                                                      inverse=inverse, mode=mode)
    cpu_grad = torch.rand_like(cpu_logsignature)

    def one_iteration():
        gc.collect()
        torch.cuda.synchronize()
        torch.cuda.reset_max_memory_allocated()
        # device='cuda' because it's just easier to keep track of GPU memory
        cuda_path = cpu_path.to('cuda')
        if path_grad:
            cuda_path.requires_grad_()
        if isinstance(cpu_basepoint, torch.Tensor):
            cuda_basepoint = cpu_basepoint.cuda()
            if basepoint is h.with_grad:
                cuda_basepoint.requires_grad_()
        else:
            cuda_basepoint = basepoint

        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', message="The logsignature with mode='brackets' has been requested on the "
                                                      "GPU.", category=UserWarning)
            if class_:
                cuda_logsignature = logsignature_instance(cuda_path, basepoint=cuda_basepoint)
            else:
                cuda_logsignature = signatory.logsignature(cuda_path, depth, stream=stream, basepoint=cuda_basepoint,
                                                           inverse=inverse, mode=mode)

        h.diff(cuda_logsignature.cpu(), cpu_logsignature)

        if path_grad:
            cuda_grad = cpu_grad.cuda()
            cuda_logsignature.backward(cuda_grad)
        torch.cuda.synchronize()
        return torch.cuda.max_memory_allocated()

    memory_used = one_iteration()

    # The calculations are essentially parallel and therefore not quite deterministic in the stream==True case. This
    # means that they sometimes use a bit of extra peak memory.
    if stream:
        memory_used *= 2

    for repeat in range(10):
        try:
            assert one_iteration() <= memory_used
        except AssertionError:
            print(repeat)
            raise
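Finally, a small sketch (not from the test suite) of the output sizes that shape checks such as _test_shape rely on, using signatory's public channel-counting helpers; the concrete sizes here are illustrative assumptions:

import torch
import signatory

batch, length, channels, depth = 2, 10, 3, 4
path = torch.rand(batch, length, channels)

sig = signatory.signature(path, depth)
assert sig.shape == (batch, signatory.signature_channels(channels, depth))

logsig = signatory.logsignature(path, depth, mode='words')
assert logsig.shape == (batch, signatory.logsignature_channels(channels, depth))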