def _test_no_adjustments(class_, device, batch_size, input_stream, input_channels, depth, stream, basepoint,
                         inverse, initial, path_grad, scalar_term):
    path = h.get_path(batch_size, input_stream, input_channels, device, path_grad)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)
    initial = h.get_initial(batch_size, input_channels, device, depth, initial, scalar_term)

    path_clone = path.clone()
    if isinstance(basepoint, torch.Tensor):
        basepoint_clone = basepoint.clone()
    if isinstance(initial, torch.Tensor):
        initial_clone = initial.clone()

    signature = signatory_signature(class_, path, depth, stream, basepoint, inverse, initial, scalar_term)

    can_backward = path_grad or (isinstance(basepoint, torch.Tensor) and basepoint.requires_grad) or \
                   (isinstance(initial, torch.Tensor) and initial.requires_grad)
    if can_backward:
        grad = torch.rand_like(signature)

        signature_clone = signature.clone()
        grad_clone = grad.clone()

        signature.backward(grad)
    else:
        assert signature.grad_fn is None

    # Check that none of the inputs, the output, or the gradient have been modified in-place
    h.diff(path, path_clone)
    if isinstance(basepoint, torch.Tensor):
        h.diff(basepoint, basepoint_clone)
    if isinstance(initial, torch.Tensor):
        h.diff(initial, initial_clone)
    if can_backward:
        h.diff(signature, signature_clone)
        h.diff(grad, grad_clone)

def _test_no_adjustments(class_, device, batch_size, input_stream, input_channels, depth, stream, basepoint,
                         inverse, mode, path_grad):
    path = h.get_path(batch_size, input_stream, input_channels, device, path_grad)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)

    path_clone = path.clone()
    if isinstance(basepoint, torch.Tensor):
        basepoint_clone = basepoint.clone()

    logsignature = signatory_logsignature(class_, path, depth, stream, basepoint, inverse, mode)

    can_backward = path_grad or (isinstance(basepoint, torch.Tensor) and basepoint.requires_grad)
    if can_backward:
        grad = torch.rand_like(logsignature)

        logsignature_clone = logsignature.clone()
        grad_clone = grad.clone()

        logsignature.backward(grad)
    else:
        assert logsignature.grad_fn is None

    # Check that none of the inputs, the output, or the gradient have been modified in-place
    h.diff(path, path_clone)
    if isinstance(basepoint, torch.Tensor):
        h.diff(basepoint, basepoint_clone)
    if can_backward:
        h.diff(logsignature, logsignature_clone)
        h.diff(grad, grad_clone)

def _test_forward(class_, device, path_grad, batch_size, input_stream, input_channels, depth, stream, basepoint,
                  inverse, initial, scalar_term):
    path = h.get_path(batch_size, input_stream, input_channels, device, path_grad)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)
    expected_exception = (batch_size < 1) or (input_channels < 1) or (basepoint is False and input_stream < 2) or \
                         (input_stream < 1)
    try:
        initial = h.get_initial(batch_size, input_channels, device, depth, initial, scalar_term)
        signature = signatory_signature(class_, path, depth, stream, basepoint, inverse, initial, scalar_term)
    except ValueError:
        if expected_exception:
            return
        else:
            raise
    else:
        assert not expected_exception

    _test_shape(signature, stream, basepoint, batch_size, input_stream, input_channels, depth, scalar_term)
    h.diff(signature, iisignature_signature(path, depth, stream, basepoint, inverse, initial, scalar_term))

    # Check that the 'ctx' object is properly garbage collected and we don't have a memory leak
    # (Easy to accidentally happen due to PyTorch bug 25340)
    if path_grad or (isinstance(basepoint, torch.Tensor) and basepoint.requires_grad) or \
            (isinstance(initial, torch.Tensor) and initial.requires_grad):
        ctx = signature.grad_fn
        if stream:
            ctx = ctx.next_functions[0][0]
        assert type(ctx).__name__ in ('_SignatureFunctionBackward', '_SignatureCombineFunctionBackward')
        ref = weakref.ref(ctx)
        del ctx
        del signature
        gc.collect()
        assert ref() is None
    else:
        assert signature.grad_fn is None

def _test_forward(class_, device, path_grad, batch_size, input_stream, input_channels, depth, stream, basepoint,
                  inverse, mode):
    path = h.get_path(batch_size, input_stream, input_channels, device, path_grad)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)
    expected_exception = (batch_size < 1) or (input_channels < 1) or (basepoint is False and input_stream < 2) or \
                         (input_stream < 1)
    try:
        logsignature = signatory_logsignature(class_, path, depth, stream, basepoint, inverse, mode)
    except ValueError:
        if expected_exception:
            return
        else:
            raise
    else:
        assert not expected_exception

    _test_shape(logsignature, mode, batch_size, input_stream, input_channels, depth, stream, basepoint)
    h.diff(logsignature, iisignature_logsignature(path, depth, stream, basepoint, inverse, mode))

    # Check that the 'ctx' object is properly garbage collected and we don't have a memory leak
    # (Easy to accidentally happen due to PyTorch bug 25340)
    if path_grad or (isinstance(basepoint, torch.Tensor) and basepoint.requires_grad):
        ctx = logsignature.grad_fn
        if stream:
            ctx = ctx.next_functions[0][0]
        assert type(ctx).__name__ == '_SignatureToLogsignatureFunctionBackward'
        ref = weakref.ref(ctx)
        del ctx
        del logsignature
        gc.collect()
        assert ref() is None
    else:
        assert logsignature.grad_fn is None

def _test_repeat_and_memory_leaks(class_, path_grad, batch_size, input_stream, input_channels, depth, stream,
                                  basepoint, inverse, initial, scalar_term):
    cpu_path = h.get_path(batch_size, input_stream, input_channels, device='cpu', path_grad=False)
    cpu_basepoint = h.get_basepoint(batch_size, input_channels, device='cpu', basepoint=basepoint)
    cpu_initial = h.get_initial(batch_size, input_channels, device='cpu', depth=depth, initial=initial,
                                scalar_term=scalar_term)
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', message="Argument 'initial' has been set but argument 'basepoint' has "
                                                  "not.", category=UserWarning)
        if class_:
            signature_instance = signatory.Signature(depth, stream=stream, inverse=inverse, scalar_term=scalar_term)
            cpu_signature = signature_instance(cpu_path, basepoint=cpu_basepoint, initial=cpu_initial)
        else:
            cpu_signature = signatory.signature(cpu_path, depth, stream=stream, basepoint=cpu_basepoint,
                                                inverse=inverse, initial=cpu_initial, scalar_term=scalar_term)
    cpu_grad = torch.rand_like(cpu_signature)

    def one_iteration():
        # Run a single forward (and possibly backward) pass on the GPU and return the peak memory allocated
        gc.collect()
        torch.cuda.synchronize()
        torch.cuda.reset_max_memory_allocated()
        cuda_path = cpu_path.to('cuda')
        if path_grad:
            cuda_path.requires_grad_()
        if isinstance(cpu_basepoint, torch.Tensor):
            cuda_basepoint = cpu_basepoint.cuda()
            if basepoint is h.with_grad:
                cuda_basepoint.requires_grad_()
        else:
            cuda_basepoint = basepoint
        if isinstance(cpu_initial, torch.Tensor):
            cuda_initial = cpu_initial.cuda()
            if initial is h.with_grad:
                cuda_initial.requires_grad_()
        else:
            cuda_initial = initial
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', message="Argument 'initial' has been set but argument 'basepoint' has "
                                                      "not.", category=UserWarning)
            if class_:
                cuda_signature = signature_instance(cuda_path, basepoint=cuda_basepoint, initial=cuda_initial)
            else:
                cuda_signature = signatory.signature(cuda_path, depth, stream=stream, basepoint=cuda_basepoint,
                                                     inverse=inverse, initial=cuda_initial, scalar_term=scalar_term)
        h.diff(cuda_signature.cpu(), cpu_signature)
        if path_grad:
            cuda_grad = cpu_grad.cuda()
            cuda_signature.backward(cuda_grad)
        torch.cuda.synchronize()
        return torch.cuda.max_memory_allocated()

    memory_used = one_iteration()

    # Repeated iterations should not use more peak memory than the first; if they do then we have a memory leak
    for repeat in range(10):
        assert one_iteration() <= memory_used

def _test_backward(class_, device, batch_size, input_stream, input_channels, depth, stream, basepoint, inverse,
                   initial):
    path = h.get_path(batch_size, input_stream, input_channels, device, path_grad=True)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)
    initial = h.get_initial(batch_size, input_channels, device, depth, initial)

    # This is the test we'd like to run. Unfortunately it takes forever, so we do something else instead.
    #
    # if class_:
    #     def check_fn(path, basepoint, initial):
    #         return signatory.Signature(depth, stream=stream, inverse=inverse)(path, basepoint=basepoint,
    #                                                                           initial=initial)
    # else:
    #     def check_fn(path, basepoint, initial):
    #         return signatory.signature(path, depth, stream=stream, basepoint=basepoint, inverse=inverse,
    #                                    initial=initial)
    # try:
    #     autograd.gradcheck(check_fn, (path, basepoint, initial), atol=2e-05, rtol=0.002)
    # except RuntimeError:
    #     pytest.fail()

    signature = signatory_signature(class_, path, depth, stream, basepoint, inverse, initial)
    grad = torch.rand_like(signature)
    signature.backward(grad)

    path_grad = path.grad.clone()
    path.grad.zero_()
    if isinstance(basepoint, torch.Tensor) and basepoint.requires_grad:
        basepoint_grad = basepoint.grad.clone()
        basepoint.grad.zero_()
    if isinstance(initial, torch.Tensor) and initial.requires_grad:
        initial_grad = initial.grad.clone()
        initial.grad.zero_()

    iisignature_signature_result = iisignature_signature(path, depth, stream, basepoint, inverse, initial)
    iisignature_signature_result.backward(grad)

    # iisignature uses float32 for this calculation, so we need a weaker tolerance here
    h.diff(path.grad, path_grad, atol=1e-4)
    if isinstance(basepoint, torch.Tensor) and basepoint.requires_grad:
        h.diff(basepoint.grad, basepoint_grad, atol=1e-4)
    if isinstance(initial, torch.Tensor) and initial.requires_grad:
        h.diff(initial.grad, initial_grad, atol=1e-4)

def _test_path(device, path_grad, batch_size, input_stream, input_channels, depth, basepoint, update_lengths,
               update_grads, extrarandom):
    path = h.get_path(batch_size, input_stream, input_channels, device, path_grad)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)
    path_obj = signatory.Path(path, depth, basepoint=basepoint)

    if isinstance(basepoint, torch.Tensor):
        full_path = torch.cat([basepoint.unsqueeze(1), path], dim=1)
    elif basepoint is True:
        full_path = torch.cat([torch.zeros(batch_size, 1, input_channels, device=device, dtype=torch.double), path],
                              dim=1)
    else:
        full_path = path

    # First of all test a Path with no updates
    _test_signature(path_obj, full_path, depth, extrarandom)
    _test_logsignature(path_obj, full_path, depth, extrarandom)
    assert path_obj.depth == depth

    if len(update_lengths) > 1:
        # Then test Path with variable amounts of updates
        for length, grad in zip(update_lengths, update_grads):
            new_path = torch.rand(batch_size, length, input_channels, dtype=torch.double, device=device,
                                  requires_grad=grad)
            path_obj.update(new_path)
            full_path = torch.cat([full_path, new_path], dim=1)
            _test_signature(path_obj, full_path, depth, extrarandom)
            _test_logsignature(path_obj, full_path, depth, extrarandom)
            assert path_obj.depth == depth

def _test_backward(class_, device, batch_size, input_stream, input_channels, depth, stream, basepoint, inverse,
                   mode):
    path = h.get_path(batch_size, input_stream, input_channels, device, path_grad=True)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)

    # This is the test we'd like to run, but it takes forever
    #
    # if class_:
    #     def check_fn(path, basepoint):
    #         return signatory.LogSignature(depth, stream=stream, inverse=inverse, mode=mode)(path,
    #                                                                                         basepoint=basepoint)
    # else:
    #     def check_fn(path, basepoint):
    #         return signatory.logsignature(path, depth, stream=stream, basepoint=basepoint, inverse=inverse,
    #                                       mode=mode)
    # try:
    #     autograd.gradcheck(check_fn, (path, basepoint), atol=2e-05, rtol=0.002)
    # except RuntimeError:
    #     pytest.fail()

    logsignature = signatory_logsignature(class_, path, depth, stream, basepoint, inverse, mode)
    grad = torch.rand_like(logsignature)
    logsignature.backward(grad)

    path_grad = path.grad.clone()
    path.grad.zero_()
    if isinstance(basepoint, torch.Tensor) and basepoint.requires_grad:
        basepoint_grad = basepoint.grad.clone()
        basepoint.grad.zero_()

    iisignature_logsignature_result = iisignature_logsignature(path, depth, stream, basepoint, inverse, mode)
    iisignature_logsignature_result.backward(grad)

    # iisignature uses float32 for this calculation, so we need a weaker tolerance here
    h.diff(path.grad, path_grad, atol=1e-6)
    if isinstance(basepoint, torch.Tensor) and basepoint.requires_grad:
        h.diff(basepoint.grad, basepoint_grad, atol=1e-6)

def _test_batch_trick(class_, device, path_grad, batch_size, input_stream, input_channels, depth, stream, basepoint,
                      inverse, initial):
    if device == 'cuda':
        threshold = 512
    else:
        from signatory import impl
        threshold = impl.hardware_concurrency()
    if threshold < 2:
        return  # can't test the batch trick in this case
    if round(float(threshold) / batch_size) < 2:
        batch_size = int(threshold / 2)

    path = h.get_path(batch_size, input_stream, input_channels, device, path_grad)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)
    initial = h.get_initial(batch_size, input_channels, device, depth, initial)

    current_parallelism = signatory.max_parallelism()
    try:
        signatory.max_parallelism(1)  # disable batch trick
        signature = signatory_signature(class_, path, depth, stream, basepoint, inverse, initial)
    finally:
        signatory.max_parallelism(current_parallelism)  # enable batch trick

    batch_trick_signature = signatory.signature.__globals__['_signature_batch_trick'](path, depth, stream=stream,
                                                                                      basepoint=basepoint,
                                                                                      inverse=inverse,
                                                                                      initial=initial)

    assert batch_trick_signature is not None  # check that the batch trick is viable in this case

    h.diff(signature, batch_trick_signature)

    can_backward = path_grad or (isinstance(basepoint, torch.Tensor) and basepoint.requires_grad) or \
                   (isinstance(initial, torch.Tensor) and initial.requires_grad)
    try:
        grad = torch.rand_like(signature)
        signature.backward(grad)
    except RuntimeError:
        assert not can_backward
        return
    else:
        assert can_backward

    if path_grad:
        path_grad_ = path.grad.clone()
        path.grad.zero_()
    if isinstance(basepoint, torch.Tensor) and basepoint.requires_grad:
        basepoint_grad = basepoint.grad.clone()
        basepoint.grad.zero_()
    if isinstance(initial, torch.Tensor) and initial.requires_grad:
        initial_grad = initial.grad.clone()
        initial.grad.zero_()

    batch_trick_signature.backward(grad)

    # Check that the batch trick produces the same gradients as the plain computation
    if path_grad:
        h.diff(path_grad_, path.grad)
        path.grad.zero_()
    if isinstance(basepoint, torch.Tensor) and basepoint.requires_grad:
        h.diff(basepoint_grad, basepoint.grad)
        basepoint.grad.zero_()
    if isinstance(initial, torch.Tensor) and initial.requires_grad:
        h.diff(initial_grad, initial.grad)
        initial.grad.zero_()

def _test_path(device, path_grad, batch_size, input_stream, input_channels, depth, basepoint, update_lengths,
               update_grads, scalar_term, extrarandom, which):
    path = h.get_path(batch_size, input_stream, input_channels, device, path_grad)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)
    path_obj = signatory.Path(path, depth, basepoint=basepoint, scalar_term=scalar_term)

    if isinstance(basepoint, torch.Tensor):
        full_path = torch.cat([basepoint.unsqueeze(1), path], dim=1)
    elif basepoint is True:
        full_path = torch.cat([torch.zeros(batch_size, 1, input_channels, device=device, dtype=torch.double), path],
                              dim=1)
    else:
        full_path = path

    if not path_grad and not (isinstance(basepoint, torch.Tensor) and basepoint.requires_grad):
        backup_path_obj = copy.deepcopy(path_obj)

        # derived objects to test
        copy_path_obj = copy.copy(path_obj)
        shuffle_path_obj1, perm1 = path_obj.shuffle()
        shuffle_path_obj2, perm2 = copy.deepcopy(path_obj).shuffle_()
        getitem1 = _randint(batch_size)
        getitem_path_obj1 = path_obj[getitem1]  # integer
        all_derived = [(copy_path_obj, slice(None)),
                       (shuffle_path_obj1, perm1),
                       (shuffle_path_obj2, perm2),
                       (getitem_path_obj1, getitem1)]

        start = _randint(batch_size)
        end = _randint(batch_size)
        getitem2 = slice(start, end)
        getitem3 = torch.randint(low=0, high=batch_size, size=(_randint(int(1.5 * batch_size)),))
        getitem4 = torch.randint(low=0, high=batch_size, size=(_randint(int(1.5 * batch_size)),)).numpy()
        getitem5 = torch.randint(low=0, high=batch_size, size=(_randint(int(1.5 * batch_size)),)).tolist()

        try:
            getitem_path_obj2 = path_obj[getitem2]  # slice, perhaps a 'null' slice
        except IndexError as e:
            if start >= end:
                pass
            else:
                pytest.fail(str(e))
        else:
            all_derived.append((getitem_path_obj2, getitem2))

        try:
            getitem_path_obj3 = path_obj[getitem3]  # 1D tensor
        except IndexError as e:
            if len(getitem3) == 0:
                pass
            else:
                pytest.fail(str(e))
        else:
            all_derived.append((getitem_path_obj3, getitem3))

        try:
            getitem_path_obj4 = path_obj[getitem4]  # array
        except IndexError as e:
            if len(getitem4) == 0:
                pass
            else:
                pytest.fail(str(e))
        else:
            all_derived.append((getitem_path_obj4, getitem4))

        try:
            getitem_path_obj5 = path_obj[getitem5]  # list
        except IndexError as e:
            if len(getitem5) == 0:
                pass
            else:
                pytest.fail(str(e))
        else:
            all_derived.append((getitem_path_obj5, getitem5))

        if which == 'random':
            all_derived = [random.choice(all_derived)]
        elif which == 'none':
            all_derived = []

        for derived_path_obj, derived_index in all_derived:
            # tests that the derived objects do what they claim to do
            _test_derived(path_obj, derived_path_obj, derived_index, extrarandom)
            # tests that the derived objects are consistent wrt themselves
            full_path_ = full_path[derived_index]
            if isinstance(derived_index, int):
                full_path_ = full_path_.unsqueeze(0)
            _test_path_obj(full_path_.size(0), input_channels, device, derived_path_obj, full_path_, depth,
                           update_lengths, update_grads, scalar_term, extrarandom)

        # tests that the changes to the derived objects have not affected the original
        assert path_obj == backup_path_obj

    # finally test the original object
    _test_path_obj(batch_size, input_channels, device, path_obj, full_path, depth, update_lengths, update_grads,
                   scalar_term, extrarandom)

def _test_repeat_and_memory_leaks(class_, path_grad, batch_size, input_stream, input_channels, depth, stream,
                                  basepoint, inverse, mode):
    cpu_path = h.get_path(batch_size, input_stream, input_channels, device='cpu', path_grad=False)
    cpu_basepoint = h.get_basepoint(batch_size, input_channels, device='cpu', basepoint=basepoint)
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', message="The logsignature with mode='brackets' has been requested on the "
                                                  "GPU.", category=UserWarning)
        if class_:
            logsignature_instance = signatory.LogSignature(depth, stream=stream, inverse=inverse, mode=mode)
            cpu_logsignature = logsignature_instance(cpu_path, basepoint=cpu_basepoint)
        else:
            cpu_logsignature = signatory.logsignature(cpu_path, depth, stream=stream, basepoint=cpu_basepoint,
                                                      inverse=inverse, mode=mode)
    cpu_grad = torch.rand_like(cpu_logsignature)

    def one_iteration():
        gc.collect()
        torch.cuda.synchronize()
        torch.cuda.reset_max_memory_allocated()
        # device='cuda' because it's just easier to keep track of GPU memory
        cuda_path = cpu_path.to('cuda')
        if path_grad:
            cuda_path.requires_grad_()
        if isinstance(cpu_basepoint, torch.Tensor):
            cuda_basepoint = cpu_basepoint.cuda()
            if basepoint is h.with_grad:
                cuda_basepoint.requires_grad_()
        else:
            cuda_basepoint = basepoint
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', message="The logsignature with mode='brackets' has been requested on "
                                                      "the GPU.", category=UserWarning)
            if class_:
                cuda_logsignature = logsignature_instance(cuda_path, basepoint=cuda_basepoint)
            else:
                cuda_logsignature = signatory.logsignature(cuda_path, depth, stream=stream, basepoint=cuda_basepoint,
                                                           inverse=inverse, mode=mode)
        h.diff(cuda_logsignature.cpu(), cpu_logsignature)
        if path_grad:
            cuda_grad = cpu_grad.cuda()
            cuda_logsignature.backward(cuda_grad)
        torch.cuda.synchronize()
        return torch.cuda.max_memory_allocated()

    memory_used = one_iteration()

    # The calculations are essentially parallel and therefore not quite deterministic in the stream==True case. This
    # means that they sometimes use a bit of extra peak memory.
    if stream:
        memory_used *= 2

    for repeat in range(10):
        try:
            assert one_iteration() <= memory_used
        except AssertionError:
            print(repeat)
            raise