Exemplo n.º 1
0
def test_input_weights_python():
    """Eager tensor products accept externally supplied per-instruction weight lists.

    The test is run once with ``shared_weights=False`` and once with
    ``shared_weights=True``:

    * unshared: each weight tensor is prefixed with the batch dimension;
    * shared: each weight tensor has exactly the instruction's ``path_shape``.

    Only weighted instructions (``ins.has_weight``) receive a tensor.
    """
    spec = "1e + 2e + 3x3o"
    irreps_in1 = Irreps(spec)
    irreps_in2 = Irreps(spec)
    irreps_out = Irreps(spec)

    for shared in (False, True):
        tp = FullyConnectedTensorProduct(irreps_in1,
                                         irreps_in2,
                                         irreps_out,
                                         internal_weights=False,
                                         shared_weights=shared)
        batch = random.randint(1, 3)
        lhs = irreps_in1.randn(batch, -1)
        rhs = irreps_in2.randn(batch, -1)
        # Unshared weights carry a leading batch axis; shared ones do not.
        lead = () if shared else (batch, )
        weight_list = [
            torch.randn(lead + ins.path_shape)
            for ins in tp.instructions
            if ins.has_weight
        ]
        tp(lhs, rhs, weight_list)
Exemplo n.º 2
0
def test_input_weights_jit():
    """Compiled tensor products validate external weights and match eager results.

    Checks, for both ``shared_weights=False`` and ``shared_weights=True``:

    * calling without weights raises (eager and compiled);
    * badly shaped weights are rejected by the compiled module;
    * compiled and eager outputs agree.

    The unshared case also checks that broadcastable, unusual batch shapes give
    the same results as explicitly expanded and flattened inputs.
    """
    expected_errors = (RuntimeError, torch.jit.Error)
    spec = "1e + 2e + 3x3o"
    irreps_in1 = Irreps(spec)
    irreps_in2 = Irreps(spec)
    irreps_out = Irreps(spec)

    # ---- per-sample (unshared) weights ----
    tp = FullyConnectedTensorProduct(irreps_in1,
                                     irreps_in2,
                                     irreps_out,
                                     internal_weights=False,
                                     shared_weights=False)
    tp_jit = assert_auto_jitable(tp)
    a = irreps_in1.randn(2, -1)
    b = irreps_in2.randn(2, -1)
    weights = torch.randn(2, tp.weight_numel)

    with pytest.raises(expected_errors):
        tp(a, b)  # weights are mandatory
    with pytest.raises(expected_errors):
        tp_jit(a, b)  # ...for the compiled module as well
    with pytest.raises(expected_errors):
        tp_jit(a, b, weights[0])  # a single un-batched weight vector is not enough
    assert torch.allclose(tp(a, b, weights), tp_jit(a, b, weights))

    # Broadcastable batch shapes must agree with the explicitly expanded batch.
    for fn in (tp, tp_jit):
        a = irreps_in1.randn(2, 1, 4, -1)
        b = irreps_in2.randn(2, 3, 1, -1)
        weights = torch.randn(3, 4, fn.weight_numel)
        flat_a = a.expand(2, 3, 4, -1).reshape(24, -1)
        flat_b = b.expand(2, 3, 4, -1).reshape(24, -1)
        flat_w = weights[None].expand(2, 3, 4, -1).reshape(24, -1)
        assert torch.allclose(
            fn(a, b, weights).reshape(24, -1), fn(flat_a, flat_b, flat_w))
        assert torch.allclose(
            fn.right(b, weights).reshape(24, -1),
            fn.right(flat_b, flat_w).reshape(24, -1))

    # ---- shared weights ----
    # NOTE(review): `a` and `b` intentionally still hold the broadcast-shaped
    # inputs from the last loop iteration, exactly as in the original test.
    tp = FullyConnectedTensorProduct(irreps_in1,
                                     irreps_in2,
                                     irreps_out,
                                     internal_weights=False,
                                     shared_weights=True)
    tp_jit = assert_auto_jitable(tp)
    weights = torch.randn(tp.weight_numel)

    with pytest.raises(expected_errors):
        tp(a, b)  # weights are mandatory
    with pytest.raises(expected_errors):
        tp_jit(a, b)  # ...for the compiled module as well
    with pytest.raises(expected_errors):
        tp_jit(a, b, torch.randn(2, tp.weight_numel))  # batched weights are too many
    assert torch.allclose(tp(a, b, weights), tp_jit(a, b, weights))
Exemplo n.º 3
0
def test_specialized_code(normalization, mode, weighted, float_tolerance):
    """The specialized code path must numerically match the generic path.

    Two ``TensorProduct`` modules are built with identical irreps and
    instructions, differing only in ``_specialized_code``. The weights of the
    first are copied into the second. Both ``forward`` and ``right`` must then
    agree to within ``float_tolerance``. The parameters ``normalization``,
    ``mode``, ``weighted`` and ``float_tolerance`` are presumably supplied by
    pytest parametrization/fixtures outside this view.
    """
    irreps_in1 = Irreps('4x0e + 4x1e + 4x2e')
    irreps_in2 = Irreps('5x0e + 5x1e + 5x2e')
    irreps_out = Irreps('6x0e + 6x1e + 6x2e')

    # Adjust the irreps so their multiplicities satisfy the connection mode.
    if mode in ('uuu', 'uuw'):
        irreps_in2 = irreps_in1
    if mode in ('uvu', 'uuu'):
        irreps_out = irreps_in1
    elif mode == 'uvv':
        irreps_out = irreps_in2
    elif mode == 'uuw' and not weighted:
        # Unweighted uuw is a plain sum over u, which needs an output multiplicity of 1.
        irreps_out = Irreps([(1, ir) for _, ir in irreps_out])

    paths = [
        (0, 0, 0), (0, 1, 1), (1, 0, 1),
        (1, 1, 0), (1, 1, 1), (0, 2, 2),
        (2, 0, 2), (2, 2, 0), (2, 1, 1),
    ]
    instructions = [(i1, i2, i_out, mode, weighted, 1.0) for i1, i2, i_out in paths]

    tp_generic = TensorProduct(irreps_in1,
                               irreps_in2,
                               irreps_out,
                               instructions,
                               normalization=normalization,
                               _specialized_code=False)
    tp_special = TensorProduct(irreps_in1,
                               irreps_in2,
                               irreps_out,
                               instructions,
                               normalization=normalization,
                               _specialized_code=True)
    with torch.no_grad():
        tp_special.weight.copy_(tp_generic.weight)

    lhs = irreps_in1.randn(3, -1)
    rhs = irreps_in2.randn(3, -1)
    forward_gap = (tp_generic(lhs, rhs) - tp_special(lhs, rhs)).abs().max()
    assert forward_gap < float_tolerance
    right_gap = (tp_generic.right(rhs) - tp_special.right(rhs)).abs().max()
    assert right_gap < float_tolerance
Exemplo n.º 4
0
def test_weight_view_for_instruction():
    """Zeroing every weight that feeds output irrep 0 zeroes exactly that output.

    ``weight_view_for_instruction(i)`` is expected to return a writable view
    into the module's weights for instruction ``i``. Zeroing the views of all
    instructions whose ``i_out == 0`` must make every component of output
    irrep 0 vanish, while the remaining outputs stay non-zero.
    """
    irreps_in1 = Irreps("1e + 2e + 3x3o")
    irreps_in2 = Irreps("1e + 2e + 3x3o")
    irreps_out = Irreps("1e + 2e + 3x3o")
    x1 = irreps_in1.randn(2, -1)
    x2 = irreps_in2.randn(2, -1)
    m = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out)

    # Find all paths to the first output
    ins_indices = [i for i, ins in enumerate(m.instructions) if ins.i_out == 0]
    with torch.no_grad():
        for i in ins_indices:
            m.weight_view_for_instruction(i).zero_()

    out = m(x1, x2)
    # Output irrep 0 is "1e", which has 3 components, so check its full slice
    # rather than only the first column.
    first = irreps_out.slices()[0]
    assert torch.all(out[:, first] == 0.0)
    # Everything after it is untouched, so (with random inputs) something must be non-zero.
    assert torch.any(out[:, first.stop:] != 0.0)
Exemplo n.º 5
0
def test_weight_views():
    """Zeroing every view yielded by ``weight_views`` zeroes the whole output.

    Checked twice:

    * for a module with its own (shared) weights, where the views are taken on
      the module itself;
    * for ``shared_weights=False``, where the views are taken on an external,
      per-sample weight tensor that is then passed to ``forward``.
    """
    spec = "1e + 2e + 3x3o"
    irreps_in1 = Irreps(spec)
    irreps_in2 = Irreps(spec)
    irreps_out = Irreps(spec)
    batch = 3
    lhs = irreps_in1.randn(batch, -1)
    rhs = irreps_in2.randn(batch, -1)

    def _zero_views(module, *weight_arg):
        # Obtain and zero the views entirely inside no_grad, mirroring the
        # original test.
        with torch.no_grad():
            for view in module.weight_views(*weight_arg):
                view.zero_()

    # shared weights
    tp = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out)
    _zero_views(tp)
    assert torch.all(tp(lhs, rhs) == 0.0)

    # unshared weights
    tp = FullyConnectedTensorProduct(irreps_in1,
                                     irreps_in2,
                                     irreps_out,
                                     shared_weights=False)
    external = torch.randn(batch, tp.weight_numel)
    _zero_views(tp, external)
    assert torch.all(tp(lhs, rhs, external) == 0.0)
Exemplo n.º 6
0
def main():
    """Profile a ``FullyConnectedTensorProduct`` with ``torch.profiler``.

    Command-line options choose:

    * the irreps;
    * the device (``--cuda``);
    * whether to JIT-compile the module;
    * whether to run a backward pass;
    * how many warmup (``-w``) and recorded (``-n``) profiler steps to run.

    When the recorded steps finish, the profiler prints a table of operator
    timings and writes a Chrome trace to ``test_trace_<k>.json``.
    """
    parser = argparse.ArgumentParser(prog="tensor_product_benchmark")
    # t_or_f is defined elsewhere; presumably it parses a true/false string.
    parser.add_argument("--jit", type=t_or_f, default=True)
    parser.add_argument("--irreps-in1",
                        type=str,
                        default="8x0e + 8x1e + 8x2e + 8x3e")
    parser.add_argument("--irreps-in2",
                        type=str,
                        default="8x0e + 8x1e + 8x2e + 8x3e")
    parser.add_argument("--irreps-out",
                        type=str,
                        default="8x0e + 8x1e + 8x2e + 8x3e")
    parser.add_argument("--cuda", type=t_or_f, default=True)
    parser.add_argument("--backward", type=t_or_f, default=True)
    parser.add_argument("--opt-ein", type=t_or_f, default=True)
    parser.add_argument("--specialized-code", type=t_or_f, default=True)
    parser.add_argument("-w", type=int, default=10)
    parser.add_argument("-n", type=int, default=3)
    parser.add_argument("--batch", type=int, default=10)

    args = parser.parse_args()

    device = 'cuda' if (torch.cuda.is_available() and args.cuda) else 'cpu'
    args.cuda = device == 'cuda'

    if args.cuda:
        # Workaround for CUDA driver issues
        # See https://github.com/pytorch/pytorch/issues/60158#issuecomment-866294291
        with torch.profiler.profile() as _:
            pass

    print("======= Benchmark with settings: ======")
    for key, val in vars(args).items():
        print(f"{key:>18} : {val}")
    print("=" * 40)

    irreps_in1 = Irreps(args.irreps_in1)
    irreps_in2 = Irreps(args.irreps_in2)
    irreps_out = Irreps(args.irreps_out)
    tp = FullyConnectedTensorProduct(irreps_in1,
                                     irreps_in2,
                                     irreps_out,
                                     _specialized_code=args.specialized_code,
                                     _optimize_einsums=args.opt_ein)
    tp = tp.to(device=device)

    # One input pair per profiler step: 1 "wait" step + w warmup + n active steps.
    n_steps = 1 + args.w + args.n
    inputs = [(irreps_in1.randn(args.batch, -1).to(device=device),
               irreps_in2.randn(args.batch, -1).to(device=device))
              for _ in range(n_steps)]
    if args.backward:
        for pair in inputs:
            for t in pair:
                t.requires_grad_(True)
    inputs = iter(inputs)

    # compile (`compile` here is the project's JIT helper imported elsewhere,
    # not the Python builtin)
    if args.jit:
        print("JITing...")
        tp = compile(tp)

    print("starting...")

    # Only request CUDA activity, and rank operators by CUDA time, when the run
    # is actually on the GPU. On CPU there are no CUDA kernels to record, so
    # sorting by CUDA time would give a meaningless order.
    activities = [torch.profiler.ProfilerActivity.CPU]
    if args.cuda:
        activities.append(torch.profiler.ProfilerActivity.CUDA)
    sort_by = "self_cuda_time_total" if args.cuda else "self_cpu_time_total"

    called_num = [0]  # mutable counter shared with the nested handler

    def trace_handler(p):
        print(p.key_averages().table(sort_by=sort_by, row_limit=-1))
        p.export_chrome_trace(f"test_trace_{called_num[0]}.json")
        called_num[0] += 1

    with torch.profiler.profile(
            activities=activities,
            schedule=torch.profiler.schedule(wait=1,
                                             warmup=args.w,
                                             active=args.n),
            on_trace_ready=trace_handler) as p:
        for _ in range(n_steps):
            out = tp(*next(inputs))
            if args.backward:
                # tanh() forces it to realize the grad as a full size matrix rather than expanded (stride 0) ones
                out.tanh().sum().backward()
            p.step()
Exemplo n.º 7
0
def main():
    """Time a tensor product with ``torch.utils.benchmark.Timer``.

    By default the module is a ``FullyConnectedTensorProduct``; with
    ``--elementwise`` it is an ``ElementwiseTensorProduct``, which has no
    weights, so backward is disabled. ``--irreps`` is used for any of
    ``--irreps-in1/--irreps-in2/--irreps-out`` that is not given. Each timed
    statement consumes a fresh input pair; the per-call timing measurement is
    printed at the end.
    """
    import itertools  # local import: only this script entry point needs it

    parser = argparse.ArgumentParser(prog="tensor_product_benchmark")
    # t_or_f is defined elsewhere; presumably it parses a true/false string.
    parser.add_argument("--jit", type=t_or_f, default=True)
    parser.add_argument("--irreps",
                        type=str,
                        default="8x0e + 8x1e + 8x2e + 8x3o")
    parser.add_argument("--irreps-in1", type=str, default=None)
    parser.add_argument("--irreps-in2", type=str, default=None)
    parser.add_argument("--irreps-out", type=str, default=None)
    parser.add_argument("--cuda", type=t_or_f, default=True)
    parser.add_argument("--backward", type=t_or_f, default=True)
    parser.add_argument("--opt-ein", type=t_or_f, default=True)
    parser.add_argument("--specialized-code", type=t_or_f, default=True)
    parser.add_argument("--elementwise", action='store_true')
    parser.add_argument("-n", type=int, default=1000)
    parser.add_argument("--batch", type=int, default=10)

    args = parser.parse_args()

    device = 'cuda' if (torch.cuda.is_available() and args.cuda) else 'cpu'
    args.cuda = device == 'cuda'

    print("======= Benchmark with settings: ======")
    for key, val in vars(args).items():
        print(f"{key:>18} : {val}")
    print("=" * 40)

    irreps_in1 = Irreps(args.irreps_in1 if args.irreps_in1 else args.irreps)
    irreps_in2 = Irreps(args.irreps_in2 if args.irreps_in2 else args.irreps)
    irreps_out = Irreps(args.irreps_out if args.irreps_out else args.irreps)

    if args.elementwise:
        tp = ElementwiseTensorProduct(irreps_in1,
                                      irreps_in2,
                                      _specialized_code=args.specialized_code,
                                      _optimize_einsums=args.opt_ein)
        if args.backward:
            print(
                "Elementwise TP has no weights, cannot backward. Setting --backward False."
            )
            args.backward = False
    else:
        tp = FullyConnectedTensorProduct(
            irreps_in1,
            irreps_in2,
            irreps_out,
            _specialized_code=args.specialized_code,
            _optimize_einsums=args.opt_ein)
    tp = tp.to(device=device)
    assert len(tp.instructions) > 0, "Bad irreps, no instructions"
    print(f"Tensor product: {tp}")
    print("Instructions:")
    for ins in tp.instructions:
        print(f"  {ins}")

    # Timer.timeit runs its own warmup calls before the `n` timed calls
    # (see https://pytorch.org/docs/master/_modules/torch/utils/benchmark/utils/timer.html#Timer.timeit).
    # The warmup count is an internal detail of torch.utils.benchmark, so the
    # pool is sized generously and cycled: `next(inputs)` can never raise
    # StopIteration part-way through the benchmark.
    warmup = max(int(args.n // 100), 2)
    inputs = itertools.cycle([(irreps_in1.randn(args.batch, -1).to(device=device),
                               irreps_in2.randn(args.batch, -1).to(device=device))
                              for _ in range(args.n + warmup)])

    # compile (`compile` here is the project's JIT helper imported elsewhere,
    # not the Python builtin)
    if args.jit:
        tp = compile(tp)

    print("starting...")

    # tanh() forces it to realize the grad as a full size matrix rather than expanded (stride 0) ones
    t = Timer(
        stmt=("tp.zero_grad()\n"
              "out = tp(*next(inputs))\n" +
              ("out.tanh().sum().backward()\n" if args.backward else '')),
        globals={
            'tp': tp,
            'inputs': inputs
        })

    perloop = t.timeit(args.n)

    print()
    print(perloop)