def test_abc_b_a(self):
    A, B, C = 3, 5, 7
    EQUATION_STR = 'abc,b->a'
    SIZES = [(A, B, C), (B,)]
    OUTPUT_SIZE = (A,)
    args = [
        torch.nn.Parameter(torch.empty(size, device=self.device))
        for size in SIZES
    ]
    for arg in args:
        arg.data.uniform_(-10.0, 10.0, generator=self.generator)
    grad = torch.empty(OUTPUT_SIZE, device=self.device)
    grad.uniform_(-5.0, 5.0, generator=self.generator)
    # Reference: exponentiate, run ordinary einsum, and take the log.
    exp_args = [torch.exp(arg) for arg in args]
    exp_result = torch.einsum(EQUATION_STR, *exp_args)
    expected_output = torch.log(exp_result)
    expected_output.backward(grad)
    expected_grads = [arg.grad.clone() for arg in args]
    arg_grads = log_einsum_backward(
        compile_equation(EQUATION_STR),
        [arg.detach() for arg in args],
        [True for arg in args],
        grad,
        block_size=3)
    for arg_grad, arg_size in zip(arg_grads, SIZES):
        self.assertEqual(arg_grad.size(), arg_size)
    for arg_grad, expected_grad in zip(arg_grads, expected_grads):
        numpy.testing.assert_allclose(arg_grad, expected_grad, rtol=1e-5)
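# Optional cross-check (a sketch, not part of the original suite): verify the
# gradient of the differentiable log_einsum wrapper against finite differences
# with torch.autograd.gradcheck. This assumes log_einsum accepts
# double-precision inputs; the equation and sizes mirror test_abc_b_a above.
def check_log_einsum_gradients():
    equation = compile_equation('abc,b->a')
    inputs = (
        torch.randn(3, 5, 7, dtype=torch.double, requires_grad=True),
        torch.randn(5, dtype=torch.double, requires_grad=True),
    )
    return torch.autograd.gradcheck(
        lambda x, y: log_einsum(equation, x, y, block_size=3),
        inputs)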
def test_real_einsum_forward(self):
    args = [
        torch.rand(size, device=self.device, generator=self.generator)
        for size in SIZES
    ]
    expected_result = torch.einsum(EQUATION_STR, *args)
    self.assertEqual(expected_result.size(), OUTPUT_SIZE)
    result = real_einsum_forward(
        compile_equation(EQUATION_STR),
        *args,
        block_size=3)
    self.assertEqual(result.size(), OUTPUT_SIZE)
    numpy.testing.assert_allclose(result, expected_result, rtol=1e-6)
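# Usage sketch (an assumption about intended behavior, not a test from the
# suite): block_size only trades memory for speed, so different block sizes
# should produce the same result up to floating-point error. The equation
# 'ab,b->a' and the helper name are made up for illustration.
def check_block_size_invariance():
    equation = compile_equation('ab,b->a')
    x, y = torch.rand(4, 6), torch.rand(6)
    small = real_einsum_forward(equation, x, y, block_size=1)
    large = real_einsum_forward(equation, x, y, block_size=6)
    assert torch.allclose(small, large)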
def test_log_einsum_forward(self):
    args = [torch.empty(size, device=self.device) for size in SIZES]
    for arg in args:
        arg.uniform_(-10.0, 10.0, generator=self.generator)
    # Reference: exponentiate, run ordinary einsum, and take the log.
    exp_args = [torch.exp(arg) for arg in args]
    exp_result = torch.einsum(EQUATION_STR, *exp_args)
    expected_result = torch.log(exp_result)
    self.assertEqual(expected_result.size(), OUTPUT_SIZE)
    result = log_einsum_forward(
        compile_equation(EQUATION_STR),
        *args,
        block_size=3)
    self.assertEqual(result.size(), OUTPUT_SIZE)
    numpy.testing.assert_allclose(result, expected_result)
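# Reference sketch (not the library's implementation): computing an einsum in
# log space without leaving log space, via torch.logsumexp, for the made-up
# equation 'ab,b->a'. This is the standard max-shift trick that avoids the
# exp -> einsum -> log round trip used as the oracle above.
def log_einsum_reference_ab_b_a(log_x, log_y):
    # scores[a, b] = log_x[a, b] + log_y[b]
    scores = log_x + log_y.unsqueeze(0)
    # log sum_b exp(scores[a, b]), computed stably.
    return torch.logsumexp(scores, dim=1)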
def test_log_viterbi_einsum_forward(self):
    args = [torch.empty(size, device=self.device) for size in SIZES]
    for arg in args:
        arg.uniform_(-10.0, 10.0, generator=self.generator)
    expected_maxval, expected_argmax = reference_log_viterbi_einsum(
        *args, self.device)
    self.assertEqual(expected_maxval.size(), OUTPUT_SIZE)
    self.assertEqual(expected_argmax.size(), (*OUTPUT_SIZE, 3))
    maxval, argmax = log_viterbi_einsum_forward(
        compile_equation(EQUATION_STR),
        *args,
        block_size=3)
    # Check the sizes of the actual outputs.
    self.assertEqual(maxval.size(), OUTPUT_SIZE)
    self.assertEqual(argmax.size(), (*OUTPUT_SIZE, 3))
    numpy.testing.assert_allclose(maxval, expected_maxval)
    self.assertTrue(torch.equal(argmax, expected_argmax))
def test_einsum(self):
    args = [
        torch.nn.Parameter(
            torch.rand(size, device=self.device, generator=self.generator))
        for size in SIZES
    ]
    expected_output = torch.einsum(EQUATION_STR, *args)
    expected_loss = expected_output.sum()
    expected_loss.backward()
    expected_grads = [arg.grad.clone() for arg in args]
    for arg in args:
        arg.grad.zero_()
    output = einsum(compile_equation(EQUATION_STR), *args, block_size=3)
    loss = output.sum()
    loss.backward()
    grads = [arg.grad.clone() for arg in args]
    for grad, expected_grad in zip(grads, expected_grads):
        numpy.testing.assert_allclose(grad, expected_grad, rtol=1e-6)
def test_real_einsum_backward(self):
    args = [
        torch.nn.Parameter(
            torch.rand(size, device=self.device, generator=self.generator))
        for size in SIZES
    ]
    grad = torch.empty(OUTPUT_SIZE, device=self.device)
    grad.uniform_(-5.0, 5.0, generator=self.generator)
    expected_output = torch.einsum(EQUATION_STR, *args)
    expected_output.backward(grad)
    expected_grads = [arg.grad.clone() for arg in args]
    arg_grads = real_einsum_backward(
        compile_equation(EQUATION_STR),
        [arg.detach() for arg in args],
        [True for arg in args],
        grad,
        block_size=3)
    for arg_grad, arg_size in zip(arg_grads, SIZES):
        self.assertEqual(arg_grad.size(), arg_size)
    for arg_grad, expected_grad in zip(arg_grads, expected_grads):
        numpy.testing.assert_allclose(arg_grad, expected_grad, rtol=1e-3)
def test_log_einsum_overflow(self):
    # Test that log einsum does not overflow when dealing with large
    # values.
    args = [
        torch.nn.Parameter(torch.empty(size, device=self.device))
        for size in SIZES
    ]
    for arg in args:
        arg.data.uniform_(0.0, 100.0, generator=self.generator)
    # Make sure the arguments would cause exp() to overflow.
    for arg in args:
        self.assertTrue(torch.isinf(torch.exp(arg)).sum().ne(0).item())
    output = log_einsum(compile_equation(EQUATION_STR), *args, block_size=3)
    # The output should not have inf or nan.
    self.assertTrue(torch.isfinite(output).prod().eq(1).item())
    loss = output.sum()
    loss.backward()
    for arg in args:
        # The gradients should not have inf or nan.
        self.assertTrue(torch.isfinite(arg.grad).prod().eq(1).item())
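# Illustration (not part of the test): why the naive exp -> sum -> log route
# overflows while a logsumexp-style computation stays finite. 100.0 is well
# above log(float32 max) ~ 88.7, so exp() overflows in float32.
x = torch.full((4,), 100.0)
naive = torch.log(torch.exp(x).sum())   # inf
stable = torch.logsumexp(x, dim=0)      # ~101.39, finite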
# Custom "domination" semiring expressed through semiring_einsum_forward.
# The wrapper defs and add_in_place were missing from the fragment and are
# reconstructed here; add_in_place is assumed to be an elementwise minimum,
# matching the amin used in sum_block.
def dominate_semiring(equation, *args, block_size):
    def func(compute_sum):
        def add_in_place(a, b):
            torch.minimum(a, b, out=a)
        def sum_block(a, dims):
            if not dims:
                return a
            return a.amin(dim=dims)
        def multiply_in_place(a, b):
            a[:, :, :] = (a >= b).float()
        return compute_sum(add_in_place, sum_block, multiply_in_place)
    return torch_semiring_einsum.semiring_einsum_forward(
        equation, args, block_size, func)

equation = 'bij,bik->bjk'
equation = torch_semiring_einsum.compile_equation(equation)
mats = numpy.random.uniform(size=(256, 16, 256))
mats_torch = torch.from_numpy(mats).cuda()
t0 = time.time()
output = dominate_semiring(equation, mats_torch, mats_torch, block_size=10)
x1 = output.cpu().numpy()
print(time.time() - t0)

# method 3: broadcast
import numpy
import time
import torch

mats = numpy.random.uniform(size=(256, 16, 512))
mats_torch = torch.from_numpy(mats).cuda()
t0 = time.time()
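# For comparison, a sketch (under the same callback contract used by
# dominate_semiring above) of the ordinary real semiring: blocks are summed
# and terms are multiplied elementwise.
def real_semiring(equation, *args, block_size):
    def func(compute_sum):
        def add_in_place(a, b):
            a += b
        def sum_block(a, dims):
            if not dims:
                return a
            return a.sum(dim=dims)
        def multiply_in_place(a, b):
            a *= b
        return compute_sum(add_in_place, sum_block, multiply_in_place)
    return torch_semiring_einsum.semiring_einsum_forward(
        equation, args, block_size, func)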
def main():
    parser = argparse.ArgumentParser(
        description=
        'Generate data for the time and space complexity plots included in '
        'the documentation.')
    parser.add_argument('output', type=pathlib.Path)
    parser.add_argument('-A', type=int, default=10)
    parser.add_argument('--steps', type=int, default=10)
    parser.add_argument('--step-size', type=int, default=10000)
    parser.add_argument('--block-sizes', type=int, nargs='+',
                        default=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 50])
    args = parser.parse_args()
    device = torch.device('cuda')
    EQUATION_STR = 'ak,ak,ak->a'
    COMPILED_EQUATION = compile_equation(EQUATION_STR)
    block_sizes = args.block_sizes + ['unbounded']
    pytorch_results = []
    blocked_results = [(block_size, []) for block_size in block_sizes]
    for i in range(args.steps + 1):
        # Run the first iteration twice to warm things up.
        ignore = (i == 0)
        if ignore:
            i = 1
        K = i * args.step_size
        if ignore:
            print('warming up')
        else:
            print(f'K = {K}')
        x1, x2, x3 = [
            torch.rand((args.A, K), device=device)
            for _ in range(3)
        ]
        torch.cuda.synchronize(device)
        base_memory = torch.cuda.memory_allocated(device)

        def measure_cost(einsum, equation):
            torch.cuda.synchronize(device)
            torch.cuda.reset_max_memory_allocated(device)
            start_time = datetime.datetime.now()
            y = einsum(equation, x1, x2, x3)
            torch.cuda.synchronize(device)
            duration = (datetime.datetime.now() - start_time).total_seconds()
            memory = torch.cuda.max_memory_allocated(device) - base_memory
            return {'K': K, 'duration': duration, 'memory': memory}

        result = measure_cost(torch.einsum, EQUATION_STR)
        if not ignore:
            pytorch_results.append(result)
        for block_size, results in blocked_results:
            print(f' block size = {block_size}')
            if block_size == 'unbounded':
                # Just use a big number.
                block_size = 9999999999999999
            result = measure_cost(
                lambda *args: einsum(*args, block_size=block_size),
                COMPILED_EQUATION)
            if not ignore:
                results.append(result)
    with args.output.open('w') as fout:
        json.dump({
            'pytorch': pytorch_results,
            'blocked': blocked_results
        }, fout)
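# A sketch (not part of the original script) of how the emitted JSON could be
# plotted; assumes matplotlib is installed and that the file was produced by
# main() above, with the 'pytorch' and 'blocked' keys written by json.dump.
def plot_durations(path):
    import json
    import matplotlib.pyplot as plt
    with open(path) as fin:
        data = json.load(fin)
    plt.plot([r['K'] for r in data['pytorch']],
             [r['duration'] for r in data['pytorch']],
             label='torch.einsum')
    for block_size, results in data['blocked']:
        plt.plot([r['K'] for r in results],
                 [r['duration'] for r in results],
                 label=f'block_size={block_size}')
    plt.xlabel('K')
    plt.ylabel('duration (seconds)')
    plt.legend()
    plt.show()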
def test_zero_dim(self):
    eq = compile_equation('->')
    ans = einsum(eq, torch.tensor(1.), block_size=1)
    self.assertAlmostEqual(ans.item(), 1.)
    ans = log_einsum(eq, torch.tensor(2.), block_size=1)
    self.assertAlmostEqual(ans.item(), 2.)
def test_compile_equation(self):
    equation = compile_equation(EQUATION_STR)
    equation.prepare_for_forward()
    equation.prepare_for_backward()