Exemple #1
0
 def test_abc_b_a(self):
     """Compare log_einsum_backward with autograd for 'abc,b->a'.

     The reference gradients come from backpropagating a random upstream
     gradient through log(einsum(exp(x))) with torch autograd. The blocked
     backward pass must return one gradient per input, with the input's
     shape and values that agree to rtol=1e-5.
     """
     A, B, C = 3, 5, 7
     equation_str = 'abc,b->a'
     input_sizes = [(A, B, C), (B, )]
     output_size = (A, )
     params = []
     for input_size in input_sizes:
         param = torch.nn.Parameter(
             torch.empty(input_size, device=self.device))
         param.data.uniform_(-10.0, 10.0, generator=self.generator)
         params.append(param)
     upstream = torch.empty(output_size, device=self.device)
     upstream.uniform_(-5.0, 5.0, generator=self.generator)
     # Reference: the naive formulation, differentiated by autograd.
     exp_params = [param.exp() for param in params]
     reference = torch.einsum(equation_str, *exp_params).log()
     reference.backward(upstream)
     expected_grads = [param.grad.clone() for param in params]
     actual_grads = log_einsum_backward(
         compile_equation(equation_str),
         [param.detach() for param in params],
         [True] * len(params),
         upstream,
         block_size=3)
     for actual, input_size in zip(actual_grads, input_sizes):
         self.assertEqual(actual.size(), input_size)
     for actual, expected in zip(actual_grads, expected_grads):
         numpy.testing.assert_allclose(actual, expected, rtol=1e-5)
Exemple #2
0
 def test_real_einsum_forward(self):
     """Check real_einsum_forward against torch.einsum on random inputs.

     Both the reference and the blocked result must have OUTPUT_SIZE, and
     their values must agree to rtol=1e-6.
     """
     inputs = []
     for size in SIZES:
         inputs.append(
             torch.rand(size, device=self.device, generator=self.generator))
     reference = torch.einsum(EQUATION_STR, *inputs)
     self.assertEqual(reference.size(), OUTPUT_SIZE)
     compiled = compile_equation(EQUATION_STR)
     actual = real_einsum_forward(compiled, *inputs, block_size=3)
     self.assertEqual(actual.size(), OUTPUT_SIZE)
     numpy.testing.assert_allclose(actual, reference, rtol=1e-6)
Exemple #3
0
 def test_log_einsum_forward(self):
     """Check log_einsum_forward against log(einsum(exp(x))).

     Inputs are drawn uniformly between -10 and 10, which is small enough
     for the naive reference to stay finite.
     """
     inputs = [torch.empty(size, device=self.device) for size in SIZES]
     for tensor in inputs:
         tensor.uniform_(-10.0, 10.0, generator=self.generator)
     exp_inputs = [tensor.exp() for tensor in inputs]
     reference = torch.einsum(EQUATION_STR, *exp_inputs).log()
     self.assertEqual(reference.size(), OUTPUT_SIZE)
     actual = log_einsum_forward(
         compile_equation(EQUATION_STR), *inputs, block_size=3)
     self.assertEqual(actual.size(), OUTPUT_SIZE)
     numpy.testing.assert_allclose(actual, reference)
Exemple #4
0
 def test_log_viterbi_einsum_forward(self):
     """Check log_viterbi_einsum_forward against the reference implementation.

     The reference's max values must have OUTPUT_SIZE, and its argmax must
     have OUTPUT_SIZE plus a trailing dimension of size 3. The blocked
     implementation must produce outputs of those same shapes, max values
     that agree with the reference, and exactly equal argmax indices.
     """
     args = [torch.empty(size, device=self.device) for size in SIZES]
     for arg in args:
         arg.uniform_(-10.0, 10.0, generator=self.generator)
     expected_maxval, expected_argmax = reference_log_viterbi_einsum(
         *args, self.device)
     self.assertEqual(expected_maxval.size(), OUTPUT_SIZE)
     self.assertEqual(expected_argmax.size(), (*OUTPUT_SIZE, 3))
     maxval, argmax = log_viterbi_einsum_forward(
         compile_equation(EQUATION_STR), *args, block_size=3)
     # Check the shapes of the implementation's own outputs, not the
     # reference's a second time.
     self.assertEqual(maxval.size(), OUTPUT_SIZE)
     self.assertEqual(argmax.size(), (*OUTPUT_SIZE, 3))
     numpy.testing.assert_allclose(maxval, expected_maxval)
     self.assertTrue(torch.equal(argmax, expected_argmax))
Exemple #5
0
 def test_einsum(self):
     """Gradients of sum(einsum(...)) must match those of torch.einsum.

     Autograd runs once through torch.einsum to get reference gradients.
     The gradients are then cleared and autograd runs again through the
     blocked einsum; the two sets must agree to rtol=1e-6.
     """
     params = [
         torch.nn.Parameter(
             torch.rand(size, device=self.device, generator=self.generator))
         for size in SIZES
     ]
     torch.einsum(EQUATION_STR, *params).sum().backward()
     expected_grads = [param.grad.clone() for param in params]
     # .grad accumulates across backward passes, so reset it first.
     for param in params:
         param.grad.zero_()
     blocked = einsum(compile_equation(EQUATION_STR), *params, block_size=3)
     blocked.sum().backward()
     actual_grads = [param.grad.clone() for param in params]
     for actual, expected in zip(actual_grads, expected_grads):
         numpy.testing.assert_allclose(actual, expected, rtol=1e-6)
Exemple #6
0
 def test_real_einsum_backward(self):
     """real_einsum_backward must reproduce torch.einsum's autograd gradients.

     A random upstream gradient is backpropagated through torch.einsum to
     get reference gradients. The blocked backward pass, given the same
     upstream gradient, must return same-shaped gradients that agree to
     rtol=1e-3.
     """
     params = [
         torch.nn.Parameter(
             torch.rand(size, device=self.device, generator=self.generator))
         for size in SIZES
     ]
     upstream = torch.empty(OUTPUT_SIZE, device=self.device)
     upstream.uniform_(-5.0, 5.0, generator=self.generator)
     torch.einsum(EQUATION_STR, *params).backward(upstream)
     expected_grads = [param.grad.clone() for param in params]
     needs_grad = [True] * len(params)
     actual_grads = real_einsum_backward(
         compile_equation(EQUATION_STR),
         [param.detach() for param in params],
         needs_grad,
         upstream,
         block_size=3)
     for actual, size in zip(actual_grads, SIZES):
         self.assertEqual(actual.size(), size)
     for actual, expected in zip(actual_grads, expected_grads):
         numpy.testing.assert_allclose(actual, expected, rtol=1e-3)
Exemple #7
0
 def test_log_einsum_overflow(self):
     """log_einsum must stay finite, forward and backward, on large inputs.

     Inputs are drawn between 0 and 100. The test first asserts that a naive
     exp() really overflows to inf for every argument. It then checks that
     the blocked log-space computation yields only finite outputs and
     finite gradients.
     """
     params = [
         torch.nn.Parameter(torch.empty(size, device=self.device))
         for size in SIZES
     ]
     for param in params:
         param.data.uniform_(0.0, 100.0, generator=self.generator)
     # Precondition: exp() overflows somewhere in every argument.
     for param in params:
         self.assertTrue(torch.isinf(torch.exp(param)).any().item())
     output = log_einsum(
         compile_equation(EQUATION_STR), *params, block_size=3)
     # No inf or nan in the forward result...
     self.assertTrue(torch.isfinite(output).all().item())
     output.sum().backward()
     # ...nor in any gradient.
     for param in params:
         self.assertTrue(torch.isfinite(param.grad).all().item())
Exemple #8
0
        def sum_block(a, dims):
            # Reduce ``a`` with an element-wise minimum (amin) over ``dims``;
            # when there are no dimensions to reduce, return ``a`` as is.
            return a.amin(dim=dims) if dims else a

        def multiply_in_place(a, b):
            # Overwrite ``a`` in place with the float indicator (a >= b).
            # ``a[...]`` addresses every element whatever the rank of ``a``;
            # the previous ``a[:, :, :]`` raised IndexError when ``a`` had
            # fewer than three dimensions.
            a[...] = (a >= b).float()

        return compute_sum(add_in_place, sum_block, multiply_in_place)

    return torch_semiring_einsum.semiring_einsum_forward(
        equation, args, block_size, func)


# Time one dominate_semiring einsum over random batched matrices on the GPU,
# using the same tensor as both operands.
equation = 'bij,bik->bjk'
equation = torch_semiring_einsum.compile_equation(equation)
# numpy.random.uniform defaults to float64 values in [0, 1).
mats = numpy.random.uniform(size=(256, 16, 256))
mats_torch = torch.from_numpy(mats).cuda()
# NOTE(review): single run with no warm-up; the timing may include one-off
# start-up costs.
t0 = time.time()
output = dominate_semiring(equation, mats_torch, mats_torch, block_size=10)
# Copying the result to the host waits for the GPU work to finish, so the
# elapsed time printed below covers the computation.
x1 = output.cpu().numpy()
print(time.time() - t0)

# method 3: broadcast
import numpy
import time
import torch

# Random float64 input in [0, 1) for the broadcast method, copied to the GPU.
# NOTE(review): the last dimension here is 512, while the dominate_semiring
# run above uses 256 -- confirm this is intended if the outputs are compared.
mats = numpy.random.uniform(size=(256, 16, 512))
mats_torch = torch.from_numpy(mats).cuda()
# NOTE(review): the code that stops this timer is not visible here.
t0 = time.time()
def _parse_args():
    """Parse the complexity benchmark's command-line options."""
    parser = argparse.ArgumentParser(
        description=
        'Generate data for the time and space complexity plots included in '
        'the documentation.')
    parser.add_argument('output', type=pathlib.Path)
    parser.add_argument('-A', type=int, default=10)
    parser.add_argument('--steps', type=int, default=10)
    parser.add_argument('--step-size', type=int, default=10000)
    parser.add_argument('--block-sizes',
                        type=int,
                        nargs='+',
                        default=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 50])
    return parser.parse_args()


def _measure_cost(einsum_fn, equation, inputs, device, base_memory, K):
    """Time one einsum call on ``device`` and record its peak extra memory.

    Args:
        einsum_fn: callable invoked as ``einsum_fn(equation, *inputs)``.
        equation: equation passed through to ``einsum_fn``.
        inputs: tensors passed after the equation.
        device: CUDA device to synchronize and read memory statistics from.
        base_memory: bytes allocated before the call. It is subtracted from
            the peak so only memory attributable to the call is reported.
        K: problem size, copied into the returned record.

    Returns:
        dict with keys 'K', 'duration' (seconds) and 'memory' (bytes).
    """
    import time  # perf_counter is monotonic, unlike datetime.now()

    torch.cuda.synchronize(device)
    # reset_max_memory_allocated() is deprecated; reset_peak_memory_stats()
    # resets the peak counter that max_memory_allocated() reads.
    torch.cuda.reset_peak_memory_stats(device)
    start = time.perf_counter()
    # Hold a reference to the output until the peak has been read.
    output = einsum_fn(equation, *inputs)
    # CUDA kernels run asynchronously; wait before stopping the clock.
    torch.cuda.synchronize(device)
    duration = time.perf_counter() - start
    memory = torch.cuda.max_memory_allocated(device) - base_memory
    del output
    return {'K': K, 'duration': duration, 'memory': memory}


def main():
    """Benchmark torch.einsum against the blocked einsum and write JSON.

    For K = step_size, 2 * step_size, ..., steps * step_size, the equation
    'ak,ak,ak->a' is evaluated on three random (A, K) CUDA tensors. It runs
    once with torch.einsum and once with the blocked einsum for each
    requested block size, plus 'unbounded'. Duration and peak extra memory
    are recorded for every run. An extra first pass at K = step_size serves
    as a warm-up, and its results are discarded.

    The output file holds ``{'pytorch': [...], 'blocked': [[block_size,
    [...]], ...]}``, where every list entry is a _measure_cost record.
    """
    args = _parse_args()

    device = torch.device('cuda')
    EQUATION_STR = 'ak,ak,ak->a'
    COMPILED_EQUATION = compile_equation(EQUATION_STR)

    block_sizes = args.block_sizes + ['unbounded']

    pytorch_results = []
    blocked_results = [(block_size, []) for block_size in block_sizes]
    for step in range(args.steps + 1):
        # Step 0 runs at the first size purely to warm things up.
        warmup = step == 0
        K = max(step, 1) * args.step_size
        print('warming up' if warmup else f'K = {K}')
        inputs = [torch.rand((args.A, K), device=device) for _ in range(3)]
        torch.cuda.synchronize(device)
        base_memory = torch.cuda.memory_allocated(device)

        result = _measure_cost(torch.einsum, EQUATION_STR, inputs, device,
                               base_memory, K)
        if not warmup:
            pytorch_results.append(result)
        for block_size, results in blocked_results:
            print(f'  block size = {block_size}')
            if block_size == 'unbounded':
                # Just use a big number.
                block_size = 9999999999999999
            # The lambda is called immediately, so late binding of
            # block_size is harmless here.
            result = _measure_cost(
                lambda equation, *tensors: einsum(
                    equation, *tensors, block_size=block_size),
                COMPILED_EQUATION, inputs, device, base_memory, K)
            if not warmup:
                results.append(result)
    with args.output.open('w') as fout:
        json.dump({
            'pytorch': pytorch_results,
            'blocked': blocked_results
        }, fout)
Exemple #10
0
 def test_zero_dim(self):
     """The scalar equation '->' should return a 0-dim input's value.

     Covered for both einsum (input 1.0) and log_einsum (input 2.0).
     """
     scalar_eq = compile_equation('->')
     real_result = einsum(scalar_eq, torch.tensor(1.), block_size=1)
     self.assertAlmostEqual(real_result.item(), 1.)
     log_result = log_einsum(scalar_eq, torch.tensor(2.), block_size=1)
     self.assertAlmostEqual(log_result.item(), 2.)
Exemple #11
0
 def test_compile_equation(self):
     """Smoke test: compile EQUATION_STR and prepare it for both passes.

     There are no assertions; the test fails only if compiling or either
     prepare_* call raises.
     """
     equation = compile_equation(EQUATION_STR)
     equation.prepare_for_forward()
     equation.prepare_for_backward()