def _test_no_adjustments(class_, device, batch_size, input_stream, input_channels, depth, stream, basepoint,
                         inverse, initial, path_grad, scalar_term):
    # Tests that the signature computation leaves all of its inputs unchanged.
    path = h.get_path(batch_size, input_stream, input_channels, device, path_grad)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)
    initial = h.get_initial(batch_size, input_channels, device, depth, initial, scalar_term)

    path_clone = path.clone()
    if isinstance(basepoint, torch.Tensor):
        basepoint_clone = basepoint.clone()
    if isinstance(initial, torch.Tensor):
        initial_clone = initial.clone()

    signature = signatory_signature(class_, path, depth, stream, basepoint, inverse, initial, scalar_term)

    can_backward = path_grad or (isinstance(basepoint, torch.Tensor) and basepoint.requires_grad) or \
                   (isinstance(initial, torch.Tensor) and initial.requires_grad)
    if can_backward:
        grad = torch.rand_like(signature)

        signature_clone = signature.clone()
        grad_clone = grad.clone()

        signature.backward(grad)
    else:
        assert signature.grad_fn is None

    h.diff(path, path_clone)
    if isinstance(basepoint, torch.Tensor):
        h.diff(basepoint, basepoint_clone)
    if isinstance(initial, torch.Tensor):
        h.diff(initial, initial_clone)

    if can_backward:
        h.diff(signature, signature_clone)
        h.diff(grad, grad_clone)
def _test_forward(class_, device, path_grad, batch_size, input_stream, input_channels, depth, stream, basepoint,
                  inverse, initial, scalar_term):
    # Tests that the forward computation raises on invalid arguments, produces the expected shape, agrees with
    # iisignature, and does not keep the autograd context alive via a reference cycle.
    path = h.get_path(batch_size, input_stream, input_channels, device, path_grad)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)

    expected_exception = (batch_size < 1) or (input_channels < 1) or (basepoint is False and input_stream < 2) or \
                         (input_stream < 1)
    try:
        initial = h.get_initial(batch_size, input_channels, device, depth, initial, scalar_term)
        signature = signatory_signature(class_, path, depth, stream, basepoint, inverse, initial, scalar_term)
    except ValueError:
        if expected_exception:
            return
        else:
            raise
    else:
        assert not expected_exception

    _test_shape(signature, stream, basepoint, batch_size, input_stream, input_channels, depth, scalar_term)
    h.diff(signature, iisignature_signature(path, depth, stream, basepoint, inverse, initial, scalar_term))

    if path_grad or (isinstance(basepoint, torch.Tensor) and basepoint.requires_grad) or \
            (isinstance(initial, torch.Tensor) and initial.requires_grad):
        ctx = signature.grad_fn
        if stream:
            ctx = ctx.next_functions[0][0]
        assert type(ctx).__name__ in ('_SignatureFunctionBackward', '_SignatureCombineFunctionBackward')
        ref = weakref.ref(ctx)
        del ctx
        del signature
        gc.collect()
        assert ref() is None
    else:
        assert signature.grad_fn is None
def _test_repeat_and_memory_leaks(class_, path_grad, batch_size, input_stream, input_channels, depth, stream,
                                  basepoint, inverse, initial, scalar_term):
    # Tests that repeated CUDA computations give the same result as the CPU computation and do not leak GPU memory.
    cpu_path = h.get_path(batch_size, input_stream, input_channels, device='cpu', path_grad=False)
    cpu_basepoint = h.get_basepoint(batch_size, input_channels, device='cpu', basepoint=basepoint)
    cpu_initial = h.get_initial(batch_size, input_channels, device='cpu', depth=depth, initial=initial,
                                scalar_term=scalar_term)
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', message="Argument 'initial' has been set but argument 'basepoint' has "
                                                  "not.", category=UserWarning)
        if class_:
            signature_instance = signatory.Signature(depth, stream=stream, inverse=inverse, scalar_term=scalar_term)
            cpu_signature = signature_instance(cpu_path, basepoint=cpu_basepoint, initial=cpu_initial)
        else:
            cpu_signature = signatory.signature(cpu_path, depth, stream=stream, basepoint=cpu_basepoint,
                                                inverse=inverse, initial=cpu_initial, scalar_term=scalar_term)
    cpu_grad = torch.rand_like(cpu_signature)

    def one_iteration():
        gc.collect()
        torch.cuda.synchronize()
        torch.cuda.reset_max_memory_allocated()
        cuda_path = cpu_path.to('cuda')
        if path_grad:
            cuda_path.requires_grad_()
        if isinstance(cpu_basepoint, torch.Tensor):
            cuda_basepoint = cpu_basepoint.cuda()
            if basepoint is h.with_grad:
                cuda_basepoint.requires_grad_()
        else:
            cuda_basepoint = basepoint
        if isinstance(cpu_initial, torch.Tensor):
            cuda_initial = cpu_initial.cuda()
            if initial is h.with_grad:
                cuda_initial.requires_grad_()
        else:
            cuda_initial = initial
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', message="Argument 'initial' has been set but argument 'basepoint' "
                                                      "has not.", category=UserWarning)
            if class_:
                cuda_signature = signature_instance(cuda_path, basepoint=cuda_basepoint, initial=cuda_initial)
            else:
                cuda_signature = signatory.signature(cuda_path, depth, stream=stream, basepoint=cuda_basepoint,
                                                     inverse=inverse, initial=cuda_initial, scalar_term=scalar_term)
        h.diff(cuda_signature.cpu(), cpu_signature)
        if path_grad:
            cuda_grad = cpu_grad.cuda()
            cuda_signature.backward(cuda_grad)
        torch.cuda.synchronize()
        return torch.cuda.max_memory_allocated()

    memory_used = one_iteration()

    for repeat in range(10):
        assert one_iteration() <= memory_used
def _test_backward(class_, device, batch_size, input_stream, input_channels, depth, stream, basepoint, inverse,
                   initial):
    # Tests that the gradients from the backward computation agree with those from iisignature.
    path = h.get_path(batch_size, input_stream, input_channels, device, path_grad=True)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)
    initial = h.get_initial(batch_size, input_channels, device, depth, initial)

    # This is the test we'd like to run. Unfortunately it takes forever, so we do something else instead.
    #
    # if class_:
    #     def check_fn(path, basepoint, initial):
    #         return signatory.Signature(depth, stream=stream, inverse=inverse)(path, basepoint=basepoint,
    #                                                                           initial=initial)
    # else:
    #     def check_fn(path, basepoint, initial):
    #         return signatory.signature(path, depth, stream=stream, basepoint=basepoint, inverse=inverse,
    #                                    initial=initial)
    # try:
    #     autograd.gradcheck(check_fn, (path, basepoint, initial), atol=2e-05, rtol=0.002)
    # except RuntimeError:
    #     pytest.fail()

    signature = signatory_signature(class_, path, depth, stream, basepoint, inverse, initial)
    grad = torch.rand_like(signature)
    signature.backward(grad)

    path_grad = path.grad.clone()
    path.grad.zero_()
    if isinstance(basepoint, torch.Tensor) and basepoint.requires_grad:
        basepoint_grad = basepoint.grad.clone()
        basepoint.grad.zero_()
    if isinstance(initial, torch.Tensor) and initial.requires_grad:
        initial_grad = initial.grad.clone()
        initial.grad.zero_()

    iisignature_signature_result = iisignature_signature(path, depth, stream, basepoint, inverse, initial)
    iisignature_signature_result.backward(grad)

    # iisignature uses float32 for this calculation, so we need a looser tolerance
    h.diff(path.grad, path_grad, atol=1e-4)
    if isinstance(basepoint, torch.Tensor) and basepoint.requires_grad:
        h.diff(basepoint.grad, basepoint_grad, atol=1e-4)
    if isinstance(initial, torch.Tensor) and initial.requires_grad:
        h.diff(initial.grad, initial_grad, atol=1e-4)
def _test_batch_trick(class_, device, path_grad, batch_size, input_stream, input_channels, depth, stream, basepoint,
                      inverse, initial):
    # Tests that the batch trick (parallelising over the batch dimension) gives the same forward values and
    # gradients as the plain computation.
    if device == 'cuda':
        threshold = 512
    else:
        from signatory import impl
        threshold = impl.hardware_concurrency()
    if threshold < 2:
        return  # can't test the batch trick in this case
    if round(float(threshold) / batch_size) < 2:
        batch_size = int(threshold / 2)

    path = h.get_path(batch_size, input_stream, input_channels, device, path_grad)
    basepoint = h.get_basepoint(batch_size, input_channels, device, basepoint)
    initial = h.get_initial(batch_size, input_channels, device, depth, initial)

    current_parallelism = signatory.max_parallelism()
    try:
        signatory.max_parallelism(1)  # disable batch trick
        signature = signatory_signature(class_, path, depth, stream, basepoint, inverse, initial)
    finally:
        signatory.max_parallelism(current_parallelism)  # enable batch trick

    batch_trick_signature = signatory.signature.__globals__['_signature_batch_trick'](path, depth, stream=stream,
                                                                                      basepoint=basepoint,
                                                                                      inverse=inverse,
                                                                                      initial=initial)

    assert batch_trick_signature is not None  # i.e. that the batch trick is viable in this case

    h.diff(signature, batch_trick_signature)

    can_backward = path_grad or (isinstance(basepoint, torch.Tensor) and basepoint.requires_grad) or \
                   (isinstance(initial, torch.Tensor) and initial.requires_grad)
    try:
        grad = torch.rand_like(signature)
        signature.backward(grad)
    except RuntimeError:
        assert not can_backward
        return
    else:
        assert can_backward

    if path_grad:
        path_grad_ = path.grad.clone()
        path.grad.zero_()
    if isinstance(basepoint, torch.Tensor) and basepoint.requires_grad:
        basepoint_grad = basepoint.grad.clone()
        basepoint.grad.zero_()
    if isinstance(initial, torch.Tensor) and initial.requires_grad:
        initial_grad = initial.grad.clone()
        initial.grad.zero_()

    batch_trick_signature.backward(grad)

    if path_grad:
        h.diff(path_grad_, path.grad)
        path.grad.zero_()
    if isinstance(basepoint, torch.Tensor) and basepoint.requires_grad:
        h.diff(basepoint_grad, basepoint.grad)
        basepoint.grad.zero_()
    if isinstance(initial, torch.Tensor) and initial.requires_grad:
        h.diff(initial_grad, initial.grad)
        initial.grad.zero_()