# Assumed context (not shown in this excerpt): random, pytest, torch, and the e3nn
# tensor-product classes / test utilities are imported at the top of the file.


def test_input_weights_python():
    irreps_in1 = Irreps("1e + 2e + 3x3o")
    irreps_in2 = Irreps("1e + 2e + 3x3o")
    irreps_out = Irreps("1e + 2e + 3x3o")

    # - shared_weights = False -
    m = FullyConnectedTensorProduct(
        irreps_in1,
        irreps_in2,
        irreps_out,
        internal_weights=False,
        shared_weights=False
    )
    bdim = random.randint(1, 3)
    x1 = irreps_in1.randn(bdim, -1)
    x2 = irreps_in2.randn(bdim, -1)
    w = [
        torch.randn((bdim,) + ins.path_shape)
        for ins in m.instructions
        if ins.has_weight
    ]
    m(x1, x2, w)

    # - shared_weights = True -
    m = FullyConnectedTensorProduct(
        irreps_in1,
        irreps_in2,
        irreps_out,
        internal_weights=False,
        shared_weights=True
    )
    bdim = random.randint(1, 3)
    x1 = irreps_in1.randn(bdim, -1)
    x2 = irreps_in2.randn(bdim, -1)
    w = [
        torch.randn(ins.path_shape)
        for ins in m.instructions
        if ins.has_weight
    ]
    m(x1, x2, w)
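# Note (added, not in the original tests): for modules built with internal_weights=False,
# external weights can be passed either as a per-instruction list, as above, or as a single
# flat tensor of m.weight_numel entries (with a leading batch dimension when
# shared_weights=False). The flat-tensor form is the one exercised by
# test_input_weights_jit below.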
def test_input_weights_jit():
    irreps_in1 = Irreps("1e + 2e + 3x3o")
    irreps_in2 = Irreps("1e + 2e + 3x3o")
    irreps_out = Irreps("1e + 2e + 3x3o")

    # - shared_weights = False -
    m = FullyConnectedTensorProduct(
        irreps_in1,
        irreps_in2,
        irreps_out,
        internal_weights=False,
        shared_weights=False
    )
    traced = assert_auto_jitable(m)
    x1 = irreps_in1.randn(2, -1)
    x2 = irreps_in2.randn(2, -1)
    w = torch.randn(2, m.weight_numel)
    with pytest.raises((RuntimeError, torch.jit.Error)):
        m(x1, x2)  # it should require weights
    with pytest.raises((RuntimeError, torch.jit.Error)):
        traced(x1, x2)  # it should also require weights
    with pytest.raises((RuntimeError, torch.jit.Error)):
        traced(x1, x2, w[0])  # it should reject insufficient weights
    # Does the trace give right results?
    assert torch.allclose(m(x1, x2, w), traced(x1, x2, w))

    # Confirm that weird batch dimensions give the same results
    for f in (m, traced):
        x1 = irreps_in1.randn(2, 1, 4, -1)
        x2 = irreps_in2.randn(2, 3, 1, -1)
        w = torch.randn(3, 4, f.weight_numel)
        assert torch.allclose(
            f(x1, x2, w).reshape(24, -1),
            f(
                x1.expand(2, 3, 4, -1).reshape(24, -1),
                x2.expand(2, 3, 4, -1).reshape(24, -1),
                w[None].expand(2, 3, 4, -1).reshape(24, -1)
            )
        )
        assert torch.allclose(
            f.right(x2, w).reshape(24, -1),
            f.right(
                x2.expand(2, 3, 4, -1).reshape(24, -1),
                w[None].expand(2, 3, 4, -1).reshape(24, -1)
            ).reshape(24, -1)
        )

    # - shared_weights = True -
    m = FullyConnectedTensorProduct(
        irreps_in1,
        irreps_in2,
        irreps_out,
        internal_weights=False,
        shared_weights=True
    )
    traced = assert_auto_jitable(m)
    w = torch.randn(m.weight_numel)
    with pytest.raises((RuntimeError, torch.jit.Error)):
        m(x1, x2)  # it should require weights
    with pytest.raises((RuntimeError, torch.jit.Error)):
        traced(x1, x2)  # it should also require weights
    with pytest.raises((RuntimeError, torch.jit.Error)):
        traced(x1, x2, torch.randn(2, m.weight_numel))  # it should reject too many weights
    # Does the trace give right results?
    assert torch.allclose(m(x1, x2, w), traced(x1, x2, w))
def test_specialized_code(normalization, mode, weighted, float_tolerance):
    irreps_in1 = Irreps('4x0e + 4x1e + 4x2e')
    irreps_in2 = Irreps('5x0e + 5x1e + 5x2e')
    irreps_out = Irreps('6x0e + 6x1e + 6x2e')

    if mode == 'uvu':
        irreps_out = irreps_in1
    elif mode == 'uvv':
        irreps_out = irreps_in2
    elif mode == 'uuu':
        irreps_in2 = irreps_in1
        irreps_out = irreps_in1
    elif mode == 'uuw':
        irreps_in2 = irreps_in1
        # When unweighted, uuw is a plain sum over u and requires an output mul of 1
        if not weighted:
            irreps_out = Irreps([(1, ir) for _, ir in irreps_out])

    ins = [
        (0, 0, 0, mode, weighted, 1.0),
        (0, 1, 1, mode, weighted, 1.0),
        (1, 0, 1, mode, weighted, 1.0),
        (1, 1, 0, mode, weighted, 1.0),
        (1, 1, 1, mode, weighted, 1.0),
        (0, 2, 2, mode, weighted, 1.0),
        (2, 0, 2, mode, weighted, 1.0),
        (2, 2, 0, mode, weighted, 1.0),
        (2, 1, 1, mode, weighted, 1.0),
    ]
    tp1 = TensorProduct(
        irreps_in1, irreps_in2, irreps_out, ins,
        normalization=normalization,
        _specialized_code=False
    )
    tp2 = TensorProduct(
        irreps_in1, irreps_in2, irreps_out, ins,
        normalization=normalization,
        _specialized_code=True
    )
    with torch.no_grad():
        tp2.weight[:] = tp1.weight

    x = irreps_in1.randn(3, -1)
    y = irreps_in2.randn(3, -1)
    assert (tp1(x, y) - tp2(x, y)).abs().max() < float_tolerance
    assert (tp1.right(y) - tp2.right(y)).abs().max() < float_tolerance
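# Note (added): ``normalization``, ``mode``, ``weighted`` and ``float_tolerance`` are supplied
# to test_specialized_code by pytest; the corresponding parametrization and fixtures are not
# part of this excerpt. A hypothetical sketch of how they could be provided:
#
#   @pytest.mark.parametrize("normalization", ["component", "norm"])
#   @pytest.mark.parametrize("mode", ["uvw", "uvu", "uvv", "uuu", "uuw"])
#   @pytest.mark.parametrize("weighted", [True, False])
#   def test_specialized_code(normalization, mode, weighted, float_tolerance):
#       ...
#
# with ``float_tolerance`` coming from a dtype-dependent fixture defined elsewhere.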
def test_weight_view_for_instruction():
    irreps_in1 = Irreps("1e + 2e + 3x3o")
    irreps_in2 = Irreps("1e + 2e + 3x3o")
    irreps_out = Irreps("1e + 2e + 3x3o")
    x1 = irreps_in1.randn(2, -1)
    x2 = irreps_in2.randn(2, -1)
    m = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out)

    # Find all paths to the first output and zero out their weights
    ins_indexes = [i for i, ins in enumerate(m.instructions) if ins.i_out == 0]
    with torch.no_grad():
        for i in ins_indexes:
            m.weight_view_for_instruction(i).zero_()
    out = m(x1, x2)
    assert torch.all(out[:, :1] == 0.0)
    assert torch.any(out[:, 1:] > 0.0)
def test_weight_views():
    irreps_in1 = Irreps("1e + 2e + 3x3o")
    irreps_in2 = Irreps("1e + 2e + 3x3o")
    irreps_out = Irreps("1e + 2e + 3x3o")
    batchdim = 3
    x1 = irreps_in1.randn(batchdim, -1)
    x2 = irreps_in2.randn(batchdim, -1)

    # shared weights
    m = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out)
    with torch.no_grad():
        for w in m.weight_views():
            w.zero_()
    assert torch.all(m(x1, x2) == 0.0)

    # unshared weights
    m = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out, shared_weights=False)
    weights = torch.randn(batchdim, m.weight_numel)
    with torch.no_grad():
        for w in m.weight_views(weights):
            w.zero_()
    assert torch.all(m(x1, x2, weights) == 0.0)
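# Note (added): the two benchmark scripts below pass ``t_or_f`` as an argparse ``type`` to
# parse boolean command-line flags; its definition is not included in this excerpt. A minimal
# stand-in is sketched here (hypothetical -- the original helper may differ):
def t_or_f(arg):
    # Accept "t"/"true"/"T"/"True" etc. as True, "f"/"false" etc. as False.
    ua = str(arg).upper()
    if "TRUE".startswith(ua):
        return True
    if "FALSE".startswith(ua):
        return False
    # argparse reports ValueError from a type callable as an invalid argument value
    raise ValueError(f"expected a boolean, got {arg!r}")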
def main():
    parser = argparse.ArgumentParser(prog="tensor_product_benchmark")
    parser.add_argument("--jit", type=t_or_f, default=True)
    parser.add_argument("--irreps-in1", type=str, default="8x0e + 8x1e + 8x2e + 8x3e")
    parser.add_argument("--irreps-in2", type=str, default="8x0e + 8x1e + 8x2e + 8x3e")
    parser.add_argument("--irreps-out", type=str, default="8x0e + 8x1e + 8x2e + 8x3e")
    parser.add_argument("--cuda", type=t_or_f, default=True)
    parser.add_argument("--backward", type=t_or_f, default=True)
    parser.add_argument("--opt-ein", type=t_or_f, default=True)
    parser.add_argument("--specialized-code", type=t_or_f, default=True)
    parser.add_argument("-w", type=int, default=10)
    parser.add_argument("-n", type=int, default=3)
    parser.add_argument("--batch", type=int, default=10)
    args = parser.parse_args()

    device = 'cuda' if (torch.cuda.is_available() and args.cuda) else 'cpu'
    args.cuda = device == 'cuda'

    if args.cuda:
        # Workaround for CUDA driver issues
        # See https://github.com/pytorch/pytorch/issues/60158#issuecomment-866294291
        with torch.profiler.profile() as _:
            pass

    print("======= Benchmark with settings: ======")
    for key, val in vars(args).items():
        print(f"{key:>18} : {val}")
    print("=" * 40)

    irreps_in1 = Irreps(args.irreps_in1)
    irreps_in2 = Irreps(args.irreps_in2)
    irreps_out = Irreps(args.irreps_out)
    tp = FullyConnectedTensorProduct(
        irreps_in1, irreps_in2, irreps_out,
        _specialized_code=args.specialized_code,
        _optimize_einsums=args.opt_ein
    )
    tp = tp.to(device=device)
    inputs = [
        (irreps_in1.randn(args.batch, -1).to(device=device),
         irreps_in2.randn(args.batch, -1).to(device=device))
        for _ in range(1 + args.w + args.n)
    ]
    if args.backward:
        for tmp in inputs:
            for t in tmp:
                t.requires_grad_(True)
    inputs = iter(inputs)

    # compile
    if args.jit:
        print("JITing...")
        tp = compile(tp)

    print("starting...")

    called_num = [0]

    def trace_handler(p):
        print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
        p.export_chrome_trace("test_trace_" + str(called_num[0]) + ".json")
        called_num[0] += 1

    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA
        ],
        schedule=torch.profiler.schedule(wait=1, warmup=args.w, active=args.n),
        on_trace_ready=trace_handler
    ) as p:
        for _ in range(1 + args.w + args.n):
            out = tp(*next(inputs))
            if args.backward:
                # tanh() forces it to realize the grad as a full size matrix
                # rather than expanded (stride 0) ones
                out.tanh().sum().backward()
            p.step()
def main():
    parser = argparse.ArgumentParser(prog="tensor_product_benchmark")
    parser.add_argument("--jit", type=t_or_f, default=True)
    parser.add_argument("--irreps", type=str, default="8x0e + 8x1e + 8x2e + 8x3o")
    parser.add_argument("--irreps-in1", type=str, default=None)
    parser.add_argument("--irreps-in2", type=str, default=None)
    parser.add_argument("--irreps-out", type=str, default=None)
    parser.add_argument("--cuda", type=t_or_f, default=True)
    parser.add_argument("--backward", type=t_or_f, default=True)
    parser.add_argument("--opt-ein", type=t_or_f, default=True)
    parser.add_argument("--specialized-code", type=t_or_f, default=True)
    parser.add_argument("--elementwise", action='store_true')
    parser.add_argument("-n", type=int, default=1000)
    parser.add_argument("--batch", type=int, default=10)
    args = parser.parse_args()

    device = 'cuda' if (torch.cuda.is_available() and args.cuda) else 'cpu'
    args.cuda = device == 'cuda'

    print("======= Benchmark with settings: ======")
    for key, val in vars(args).items():
        print(f"{key:>18} : {val}")
    print("=" * 40)

    irreps_in1 = Irreps(args.irreps_in1 if args.irreps_in1 else args.irreps)
    irreps_in2 = Irreps(args.irreps_in2 if args.irreps_in2 else args.irreps)
    irreps_out = Irreps(args.irreps_out if args.irreps_out else args.irreps)

    if args.elementwise:
        tp = ElementwiseTensorProduct(
            irreps_in1, irreps_in2,
            _specialized_code=args.specialized_code,
            _optimize_einsums=args.opt_ein
        )
        if args.backward:
            print("Elementwise TP has no weights, cannot backward. Setting --backward False.")
            args.backward = False
    else:
        tp = FullyConnectedTensorProduct(
            irreps_in1, irreps_in2, irreps_out,
            _specialized_code=args.specialized_code,
            _optimize_einsums=args.opt_ein
        )
    tp = tp.to(device=device)
    assert len(tp.instructions) > 0, "Bad irreps, no instructions"
    print(f"Tensor product: {tp}")
    print("Instructions:")
    for ins in tp.instructions:
        print(f"  {ins}")

    # from https://pytorch.org/docs/master/_modules/torch/utils/benchmark/utils/timer.html#Timer.timeit
    warmup = max(int(args.n // 100), 1)
    inputs = iter([
        (irreps_in1.randn(args.batch, -1).to(device=device),
         irreps_in2.randn(args.batch, -1).to(device=device))
        for _ in range(args.n + warmup)
    ])

    # compile
    if args.jit:
        tp = compile(tp)

    print("starting...")

    # tanh() forces it to realize the grad as a full size matrix rather than expanded (stride 0) ones
    t = Timer(
        stmt=(
            "tp.zero_grad()\n"
            "out = tp(*next(inputs))\n"
            + ("out.tanh().sum().backward()\n" if args.backward else '')
        ),
        globals={'tp': tp, 'inputs': inputs}
    )

    perloop = t.timeit(args.n)

    print()
    print(perloop)
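# Note (added): example invocation, assuming the script ends with the usual
# ``if __name__ == "__main__": main()`` guard (not shown in this excerpt):
#
#   python tensor_product_benchmark.py --irreps "16x0e + 16x1e" --batch 32 -n 2000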