def test_empty_inputs():
    tp = FullyConnectedTensorProduct('0e + 1e', '0e + 1e', '0e + 1e')

    out = tp(torch.randn(2, 1, 0, 1, 4), torch.randn(1, 2, 0, 3, 4))
    assert out.shape == (2, 2, 0, 3, 4)

    out = tp.right(torch.randn(1, 2, 0, 3, 4))
    assert out.shape == (1, 2, 0, 3, 4, 4)
def test_input_weights_python():
    irreps_in1 = Irreps("1e + 2e + 3x3o")
    irreps_in2 = Irreps("1e + 2e + 3x3o")
    irreps_out = Irreps("1e + 2e + 3x3o")

    # - shared_weights = False -
    m = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out,
                                    internal_weights=False, shared_weights=False)
    bdim = random.randint(1, 3)
    x1 = irreps_in1.randn(bdim, -1)
    x2 = irreps_in2.randn(bdim, -1)
    w = [torch.randn((bdim,) + ins.path_shape) for ins in m.instructions if ins.has_weight]
    m(x1, x2, w)

    # - shared_weights = True -
    m = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out,
                                    internal_weights=False, shared_weights=True)
    bdim = random.randint(1, 3)
    x1 = irreps_in1.randn(bdim, -1)
    x2 = irreps_in2.randn(bdim, -1)
    w = [torch.randn(ins.path_shape) for ins in m.instructions if ins.has_weight]
    m(x1, x2, w)
def __init__(self, irreps_out, num_z, lmax) -> None:
    super().__init__()
    self.num_z = num_z
    self.irreps_sh = o3.Irreps.spherical_harmonics(lmax)

    # multiply the edge-type one-hot with the spherical harmonics to get the edge attributes
    self.mul = TensorProduct(
        [(num_z**2, "0e")],
        self.irreps_sh,
        [(num_z**2, ir) for _, ir in self.irreps_sh],
        [(0, l, l, "uvu", False) for l in range(lmax + 1)]
    )
    irreps_attr = self.mul.irreps_out

    irreps_mid = o3.Irreps("64x0e + 24x1e + 24x1o + 16x2e + 16x2o")
    irreps_out = o3.Irreps(irreps_out)

    self.tp1 = FullyConnectedTensorProduct(
        irreps_in1=self.irreps_sh,
        irreps_in2=irreps_attr,
        irreps_out=irreps_mid,
    )
    self.tp2 = FullyConnectedTensorProduct(
        irreps_in1=irreps_mid,
        irreps_in2=irreps_attr,
        irreps_out=irreps_out,
    )
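# Standalone sketch (an assumption for illustration, not part of the module above): how
# a `mul` TensorProduct like the one in the __init__ combines an edge-type one-hot with
# spherical harmonics to produce edge attributes. All names below are hypothetical.
import torch
from e3nn import o3
from e3nn.o3 import TensorProduct

num_z, lmax, num_edges = 4, 2, 10
irreps_sh = o3.Irreps.spherical_harmonics(lmax)
mul = TensorProduct(
    [(num_z**2, "0e")],
    irreps_sh,
    [(num_z**2, ir) for _, ir in irreps_sh],
    [(0, l, l, "uvu", False) for l in range(lmax + 1)],
)
edge_type = torch.nn.functional.one_hot(torch.randint(num_z**2, (num_edges,)), num_z**2).float()
edge_vec = torch.randn(num_edges, 3)
edge_sh = o3.spherical_harmonics(irreps_sh, edge_vec, normalize=True, normalization='component')
edge_attr = mul(edge_type, edge_sh)  # [num_edges, mul.irreps_out.dim]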
def __init__(self, irreps_node_input, irreps_node_attr, irreps_edge_attr, irreps_node_output,
             fc_neurons, num_neighbors) -> None:
    super().__init__()
    self.irreps_node_input = o3.Irreps(irreps_node_input)
    self.irreps_node_attr = o3.Irreps(irreps_node_attr)
    self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)
    self.irreps_node_output = o3.Irreps(irreps_node_output)
    self.num_neighbors = num_neighbors

    self.sc = FullyConnectedTensorProduct(self.irreps_node_input, self.irreps_node_attr, self.irreps_node_output)
    self.lin1 = FullyConnectedTensorProduct(self.irreps_node_input, self.irreps_node_attr, self.irreps_node_input)

    irreps_mid = []
    instructions = []
    for i, (mul, ir_in) in enumerate(self.irreps_node_input):
        for j, (_, ir_edge) in enumerate(self.irreps_edge_attr):
            for ir_out in ir_in * ir_edge:
                if ir_out in self.irreps_node_output or ir_out == o3.Irrep(0, 1):
                    k = len(irreps_mid)
                    irreps_mid.append((mul, ir_out))
                    instructions.append((i, j, k, 'uvu', True))
    irreps_mid = o3.Irreps(irreps_mid)
    irreps_mid, p, _ = irreps_mid.sort()

    assert irreps_mid.dim > 0, (
        f"irreps_node_input={self.irreps_node_input} times irreps_edge_attr={self.irreps_edge_attr} "
        f"produces nothing in irreps_node_output={self.irreps_node_output}"
    )

    instructions = [(i_1, i_2, p[i_out], mode, train) for i_1, i_2, i_out, mode, train in instructions]

    tp = TensorProduct(
        self.irreps_node_input,
        self.irreps_edge_attr,
        irreps_mid,
        instructions,
        internal_weights=False,
        shared_weights=False,
    )
    self.fc = FullyConnectedNet(fc_neurons + [tp.weight_numel], torch.nn.functional.silu)
    self.tp = tp

    self.lin2 = FullyConnectedTensorProduct(irreps_mid, self.irreps_node_attr, self.irreps_node_output)

    # inspired by https://arxiv.org/pdf/2002.10444.pdf
    self.alpha = FullyConnectedTensorProduct(irreps_mid, self.irreps_node_attr, "0e")
    with torch.no_grad():
        self.alpha.weight.zero_()
    assert self.alpha.output_mask[0] == 1.0, (
        f"irreps_mid={irreps_mid} and irreps_node_attr={self.irreps_node_attr} are not able to generate scalars"
    )
def test_input_weights_jit():
    irreps_in1 = Irreps("1e + 2e + 3x3o")
    irreps_in2 = Irreps("1e + 2e + 3x3o")
    irreps_out = Irreps("1e + 2e + 3x3o")

    # - shared_weights = False -
    m = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out,
                                    internal_weights=False, shared_weights=False)
    traced = assert_auto_jitable(m)
    x1 = irreps_in1.randn(2, -1)
    x2 = irreps_in2.randn(2, -1)
    w = torch.randn(2, m.weight_numel)
    with pytest.raises((RuntimeError, torch.jit.Error)):
        m(x1, x2)  # it should require weights
    with pytest.raises((RuntimeError, torch.jit.Error)):
        traced(x1, x2)  # it should also require weights
    with pytest.raises((RuntimeError, torch.jit.Error)):
        traced(x1, x2, w[0])  # it should reject insufficient weights
    # Does the trace give right results?
    assert torch.allclose(m(x1, x2, w), traced(x1, x2, w))

    # Confirm that weird batch dimensions give the same results
    for f in (m, traced):
        x1 = irreps_in1.randn(2, 1, 4, -1)
        x2 = irreps_in2.randn(2, 3, 1, -1)
        w = torch.randn(3, 4, f.weight_numel)
        assert torch.allclose(
            f(x1, x2, w).reshape(24, -1),
            f(
                x1.expand(2, 3, 4, -1).reshape(24, -1),
                x2.expand(2, 3, 4, -1).reshape(24, -1),
                w[None].expand(2, 3, 4, -1).reshape(24, -1)
            )
        )
        assert torch.allclose(
            f.right(x2, w).reshape(24, -1),
            f.right(
                x2.expand(2, 3, 4, -1).reshape(24, -1),
                w[None].expand(2, 3, 4, -1).reshape(24, -1)
            ).reshape(24, -1)
        )

    # - shared_weights = True -
    m = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out,
                                    internal_weights=False, shared_weights=True)
    traced = assert_auto_jitable(m)
    w = torch.randn(m.weight_numel)
    with pytest.raises((RuntimeError, torch.jit.Error)):
        m(x1, x2)  # it should require weights
    with pytest.raises((RuntimeError, torch.jit.Error)):
        traced(x1, x2)  # it should also require weights
    with pytest.raises((RuntimeError, torch.jit.Error)):
        traced(x1, x2, torch.randn(2, m.weight_numel))  # it should reject too many weights
    # Does the trace give right results?
    assert torch.allclose(m(x1, x2, w), traced(x1, x2, w))
def __init__(self, irreps_node_input, irreps_node_attr, irreps_edge_attr, irreps_node_output,
             fc_neurons, num_neighbors) -> None:
    super().__init__()
    self.irreps_node_input = o3.Irreps(irreps_node_input)
    self.irreps_node_attr = o3.Irreps(irreps_node_attr)
    self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)
    self.irreps_node_output = o3.Irreps(irreps_node_output)
    self.num_neighbors = num_neighbors

    self.sc = FullyConnectedTensorProduct(self.irreps_node_input, self.irreps_node_attr, self.irreps_node_output)
    self.lin1 = FullyConnectedTensorProduct(self.irreps_node_input, self.irreps_node_attr, self.irreps_node_input)

    irreps_mid = []
    instructions = []
    for i, (mul, ir_in) in enumerate(self.irreps_node_input):
        for j, (_, ir_edge) in enumerate(self.irreps_edge_attr):
            for ir_out in ir_in * ir_edge:
                if ir_out in self.irreps_node_output or ir_out == o3.Irrep(0, 1):
                    k = len(irreps_mid)
                    irreps_mid.append((mul, ir_out))
                    instructions.append((i, j, k, 'uvu', True))
    irreps_mid = o3.Irreps(irreps_mid)
    irreps_mid, p, _ = irreps_mid.sort()

    instructions = [(i_1, i_2, p[i_out], mode, train) for i_1, i_2, i_out, mode, train in instructions]

    tp = TensorProduct(
        self.irreps_node_input,
        self.irreps_edge_attr,
        irreps_mid,
        instructions,
        internal_weights=False,
        shared_weights=False,
    )
    self.fc = FullyConnectedNet(fc_neurons + [tp.weight_numel], torch.nn.functional.silu)
    self.tp = tp

    self.lin2 = FullyConnectedTensorProduct(irreps_mid, self.irreps_node_attr, self.irreps_node_output)
    self.lin3 = FullyConnectedTensorProduct(irreps_mid, self.irreps_node_attr, "0e")
def __init__(self) -> None:
    super().__init__()
    self.irreps_sh = o3.Irreps.spherical_harmonics(3)
    irreps_mid = o3.Irreps("64x0e + 24x1e + 24x1o + 16x2e + 16x2o")
    irreps_out = o3.Irreps("0o + 6x0e")

    self.tp1 = FullyConnectedTensorProduct(
        irreps_in1=self.irreps_sh,
        irreps_in2=self.irreps_sh,
        irreps_out=irreps_mid,
    )
    self.tp2 = FullyConnectedTensorProduct(
        irreps_in1=irreps_mid,
        irreps_in2=self.irreps_sh,
        irreps_out=irreps_out,
    )
    self.irreps_out = self.tp2.irreps_out
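# Minimal usage sketch (assumption: the __init__ above belongs to a torch.nn.Module
# whose forward, not shown here, chains tp1 and tp2 on spherical-harmonic features).
import torch
from e3nn import o3
from e3nn.o3 import FullyConnectedTensorProduct

irreps_sh = o3.Irreps.spherical_harmonics(3)
sh = o3.spherical_harmonics(irreps_sh, torch.randn(16, 3), normalize=True, normalization='component')
tp1 = FullyConnectedTensorProduct(irreps_sh, irreps_sh, "64x0e + 24x1e + 24x1o + 16x2e + 16x2o")
tp2 = FullyConnectedTensorProduct(tp1.irreps_out, irreps_sh, "0o + 6x0e")
out = tp2(tp1(sh, sh), sh)  # [16, 7]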
def test_weight_view_for_instruction():
    irreps_in1 = Irreps("1e + 2e + 3x3o")
    irreps_in2 = Irreps("1e + 2e + 3x3o")
    irreps_out = Irreps("1e + 2e + 3x3o")
    x1 = irreps_in1.randn(2, -1)
    x2 = irreps_in2.randn(2, -1)
    m = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out)

    # Find all paths to the first output
    ins_indexes = [i for i, ins in enumerate(m.instructions) if ins.i_out == 0]
    with torch.no_grad():
        for i in ins_indexes:
            m.weight_view_for_instruction(i).zero_()

    out = m(x1, x2)
    assert torch.all(out[:, :1] == 0.0)
    assert torch.any(out[:, 1:] > 0.0)
def __init__(self, irreps_in, irreps_out, irreps_sh, dim_key):
    super().__init__()
    self.irreps_in = irreps_in.simplify()
    self.irreps_out = irreps_out.simplify()
    self.irreps_sh = irreps_sh.simplify()

    # self.si = Linear(self.irreps_in, self.irreps_out, internal_weights=True, shared_weights=True)
    self.si = FullyConnectedTensorProduct(self.irreps_in, o3.Irreps("5x0e"), self.irreps_out)

    # self.lin1 = Linear(self.irreps_in, self.irreps_in, internal_weights=True, shared_weights=True)
    self.lin1 = FullyConnectedTensorProduct(self.irreps_in, o3.Irreps("5x0e"), self.irreps_in)

    instr = []
    irreps = []
    for i_1, (mul_1, (l_1, p_1)) in enumerate(self.irreps_in):
        for i_2, (_, (l_2, p_2)) in enumerate(self.irreps_sh):
            for l_out in range(abs(l_1 - l_2), l_1 + l_2 + 1):
                p_out = p_1 * p_2
                if (l_out, p_out) in [(l, p) for _, (l, p) in self.irreps_out]:
                    r = (mul_1, l_out, p_out)
                    if r in irreps:
                        i_out = irreps.index(r)
                    else:
                        i_out = len(irreps)
                        irreps.append(r)
                    instr += [(i_1, i_2, i_out, 'uvu', True)]
    irreps = o3.Irreps(irreps)

    self.tp = TensorProduct(self.irreps_in, self.irreps_sh, irreps, instr,
                            internal_weights=False, shared_weights=False)
    self.tp_weight = torch.nn.Parameter(torch.randn(dim_key, self.tp.weight_numel))

    # self.lin2 = Linear(irreps, self.irreps_out, internal_weights=True, shared_weights=True)
    self.lin2 = FullyConnectedTensorProduct(irreps, o3.Irreps("5x0e"), self.irreps_out)
def test_fully_connected():
    irreps_in1 = o3.Irreps("1e + 2e + 3x3o")
    irreps_in2 = o3.Irreps("1e + 2e + 3x3o")
    irreps_out = o3.Irreps("1e + 2e + 3x3o")

    m = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out)
    print(m)
    m(torch.randn(irreps_in1.dim), torch.randn(irreps_in2.dim))

    assert_equivariant(m)
    assert_auto_jitable(m)
def WeightBalancedIrreps(irreps_in1_scalar, irreps_in2, sh=True):
    """
    Determine an irreps_in1 of order irreps_in2.lmax such that the tensor product
    irreps_in1 x irreps_in2 -> irreps_in1 has (at least) the same number of weights as a
    standard linear layer, i.e. the tensor product irreps_in1_scalar x "1x0e" -> irreps_in1_scalar.
    """
    n = 1
    lmax = irreps_in2.lmax
    irreps_in1 = (Irreps.spherical_harmonics(lmax) * n).sort().irreps.simplify() if sh else BalancedIrreps(lmax, n)
    weight_numel1 = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_in1).weight_numel
    weight_numel_scalar = FullyConnectedTensorProduct(irreps_in1_scalar, Irreps("1x0e"), irreps_in1_scalar).weight_numel
    while weight_numel1 < weight_numel_scalar:  # TODO: somewhat suboptimal implementation...
        n += 1
        irreps_in1 = (Irreps.spherical_harmonics(lmax) * n).sort().irreps.simplify() if sh else BalancedIrreps(lmax, n)
        weight_numel1 = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_in1).weight_numel
    print('Determined irrep type:', irreps_in1)

    return Irreps(irreps_in1)
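# Usage sketch (assumption: Irreps is e3nn.o3.Irreps, as in the function above).
from e3nn.o3 import Irreps

scalar_width = Irreps("128x0e")              # width of the reference linear layer
attr_irreps = Irreps.spherical_harmonics(2)  # attributes up to l = 2
hidden_irreps = WeightBalancedIrreps(scalar_width, attr_irreps, sh=True)
# prints the balanced irreps it determined, some multiple of 1x0e + 1x1o + 1x2e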
def __init__(self, irreps_in, irreps_sh, irreps_out, num_neighbors) -> None:
    super().__init__()
    self.num_neighbors = num_neighbors

    tp = FullyConnectedTensorProduct(
        irreps_in1=irreps_in,
        irreps_in2=irreps_sh,
        irreps_out=irreps_out,
        internal_weights=False,
        shared_weights=False,
    )
    self.fc = FullyConnectedNet([3, 256, tp.weight_numel], torch.relu)
    self.tp = tp
    self.irreps_out = self.tp.irreps_out
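# Hedged sketch of a matching forward pass (an assumption: it follows the usual e3nn
# point-convolution pattern; `scatter` is torch_scatter's, and edge_src/edge_dst/
# edge_attr/edge_scalars are hypothetical caller-provided tensors, with edge_scalars
# carrying the 3 radial features the FullyConnectedNet above expects).
from torch_scatter import scatter

def conv_forward(conv, node_features, edge_src, edge_dst, edge_attr, edge_scalars):
    weight = conv.fc(edge_scalars)  # one weight vector per edge: [num_edges, tp.weight_numel]
    edge_features = conv.tp(node_features[edge_src], edge_attr, weight)
    # sum messages per destination node, normalized by the typical number of neighbors
    return scatter(edge_features, edge_dst, dim=0,
                   dim_size=node_features.shape[0]).div(conv.num_neighbors**0.5)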
def test_weight_views():
    irreps_in1 = Irreps("1e + 2e + 3x3o")
    irreps_in2 = Irreps("1e + 2e + 3x3o")
    irreps_out = Irreps("1e + 2e + 3x3o")
    batchdim = 3
    x1 = irreps_in1.randn(batchdim, -1)
    x2 = irreps_in2.randn(batchdim, -1)

    # shared weights
    m = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out)
    with torch.no_grad():
        for w in m.weight_views():
            w.zero_()
    assert torch.all(m(x1, x2) == 0.0)

    # unshared weights
    m = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out, shared_weights=False)
    weights = torch.randn(batchdim, m.weight_numel)
    with torch.no_grad():
        for w in m.weight_views(weights):
            w.zero_()
    assert torch.all(m(x1, x2, weights) == 0.0)
class Convolution(torch.nn.Module):
    r"""convolution on voxels

    Parameters
    ----------
    irreps_in : `e3nn.o3.Irreps`
        input irreps

    irreps_out : `e3nn.o3.Irreps`
        output irreps

    irreps_sh : `e3nn.o3.Irreps`
        set typically to ``o3.Irreps.spherical_harmonics(lmax)``

    diameter : float
        diameter of the filter in physical units

    num_radial_basis : int
        number of radial basis functions

    steps : tuple of float
        size of the pixel in physical units
    """
    def __init__(self, irreps_in, irreps_out, irreps_sh, diameter, num_radial_basis, steps=(1.0, 1.0, 1.0), **kwargs):
        super().__init__()

        self.irreps_in = o3.Irreps(irreps_in)
        self.irreps_out = o3.Irreps(irreps_out)
        self.irreps_sh = o3.Irreps(irreps_sh)
        self.num_radial_basis = num_radial_basis

        # self-connection
        self.sc = Linear(self.irreps_in, self.irreps_out)

        # connection with neighbors
        r = diameter / 2

        s = math.floor(r / steps[0])
        x = torch.arange(-s, s + 1.0) * steps[0]

        s = math.floor(r / steps[1])
        y = torch.arange(-s, s + 1.0) * steps[1]

        s = math.floor(r / steps[2])
        z = torch.arange(-s, s + 1.0) * steps[2]

        lattice = torch.stack(torch.meshgrid(x, y, z), dim=-1)  # [x, y, z, R^3]
        self.register_buffer('lattice', lattice)

        if 'padding' not in kwargs:
            kwargs['padding'] = tuple(s // 2 for s in lattice.shape[:3])
        self.kwargs = kwargs

        emb = soft_one_hot_linspace(
            x=lattice.norm(dim=-1),
            start=0.0,
            end=r,
            number=self.num_radial_basis,
            basis='smooth_finite',
            cutoff=True,
        )
        self.register_buffer('emb', emb)

        sh = o3.spherical_harmonics(
            l=self.irreps_sh,
            x=lattice,
            normalize=True,
            normalization='component'
        )  # [x, y, z, irreps_sh.dim]
        self.register_buffer('sh', sh)

        self.tp = FullyConnectedTensorProduct(self.irreps_in, self.irreps_sh, self.irreps_out, shared_weights=False)

        self.weight = torch.nn.Parameter(torch.randn(self.num_radial_basis, self.tp.weight_numel))

    def kernel(self):
        weight = self.emb @ self.weight
        weight = weight / (self.sh.shape[0] * self.sh.shape[1] * self.sh.shape[2])
        kernel = self.tp.right(self.sh, weight)  # [x, y, z, irreps_in.dim, irreps_out.dim]
        kernel = torch.einsum('xyzio->oixyz', kernel)
        return kernel

    def forward(self, x):
        r"""
        Parameters
        ----------
        x : `torch.Tensor`
            tensor of shape ``(batch, irreps_in.dim, x, y, z)``

        Returns
        -------
        `torch.Tensor`
            tensor of shape ``(batch, irreps_out.dim, x, y, z)``
        """
        sc = self.sc(x.transpose(1, 4)).transpose(1, 4)

        return sc + torch.nn.functional.conv3d(x, self.kernel(), **self.kwargs)
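# Usage sketch (assumption: the helpers the class above relies on, e.g. Linear,
# soft_one_hot_linspace, math, torch and o3, are imported as in its source file).
conv = Convolution(
    irreps_in="0e + 1e",
    irreps_out="4x0e",
    irreps_sh=o3.Irreps.spherical_harmonics(2),
    diameter=5.0,
    num_radial_basis=3,
)
voxels = torch.randn(1, conv.irreps_in.dim, 16, 16, 16)  # (batch, irreps_in.dim, x, y, z)
out = conv(voxels)  # (1, conv.irreps_out.dim, 16, 16, 16)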
def main():
    parser = argparse.ArgumentParser(prog="tensor_product_benchmark")
    parser.add_argument("--jit", type=t_or_f, default=True)
    parser.add_argument("--irreps", type=str, default="8x0e + 8x1e + 8x2e + 8x3o")
    parser.add_argument("--irreps-in1", type=str, default=None)
    parser.add_argument("--irreps-in2", type=str, default=None)
    parser.add_argument("--irreps-out", type=str, default=None)
    parser.add_argument("--cuda", type=t_or_f, default=True)
    parser.add_argument("--backward", type=t_or_f, default=True)
    parser.add_argument("--opt-ein", type=t_or_f, default=True)
    parser.add_argument("--specialized-code", type=t_or_f, default=True)
    parser.add_argument("--elementwise", action='store_true')
    parser.add_argument("-n", type=int, default=1000)
    parser.add_argument("--batch", type=int, default=10)

    args = parser.parse_args()

    device = 'cuda' if (torch.cuda.is_available() and args.cuda) else 'cpu'
    args.cuda = device == 'cuda'

    print("======= Benchmark with settings: ======")
    for key, val in vars(args).items():
        print(f"{key:>18} : {val}")
    print("=" * 40)

    irreps_in1 = Irreps(args.irreps_in1 if args.irreps_in1 else args.irreps)
    irreps_in2 = Irreps(args.irreps_in2 if args.irreps_in2 else args.irreps)
    irreps_out = Irreps(args.irreps_out if args.irreps_out else args.irreps)

    if args.elementwise:
        tp = ElementwiseTensorProduct(
            irreps_in1, irreps_in2,
            _specialized_code=args.specialized_code,
            _optimize_einsums=args.opt_ein
        )
        if args.backward:
            print("Elementwise TP has no weights, cannot backward. Setting --backward False.")
            args.backward = False
    else:
        tp = FullyConnectedTensorProduct(
            irreps_in1, irreps_in2, irreps_out,
            _specialized_code=args.specialized_code,
            _optimize_einsums=args.opt_ein
        )
    tp = tp.to(device=device)
    assert len(tp.instructions) > 0, "Bad irreps, no instructions"
    print(f"Tensor product: {tp}")
    print("Instructions:")
    for ins in tp.instructions:
        print(f"  {ins}")

    # from https://pytorch.org/docs/master/_modules/torch/utils/benchmark/utils/timer.html#Timer.timeit
    warmup = max(int(args.n // 100), 1)

    inputs = iter([
        (irreps_in1.randn(args.batch, -1).to(device=device),
         irreps_in2.randn(args.batch, -1).to(device=device))
        for _ in range(args.n + warmup)
    ])

    # compile
    if args.jit:
        tp = compile(tp)

    print("starting...")

    # tanh() forces it to realize the grad as a full size matrix rather than expanded (stride 0) ones
    t = Timer(
        stmt=(
            "tp.zero_grad()\n"
            "out = tp(*next(inputs))\n"
            + ("out.tanh().sum().backward()\n" if args.backward else '')
        ),
        globals={'tp': tp, 'inputs': inputs}
    )

    perloop = t.timeit(args.n)

    print()
    print(perloop)
class O3TensorProduct(torch.nn.Module):
    def __init__(self, irreps_in1, irreps_out, irreps_in2=None, tp_rescale=True) -> None:
        super().__init__()

        self.irreps_in1 = irreps_in1
        self.irreps_out = irreps_out
        # Init irreps_in2
        if irreps_in2 is None:
            self.irreps_in2_provided = False
            self.irreps_in2 = Irreps("1x0e")
        else:
            self.irreps_in2_provided = True
            self.irreps_in2 = irreps_in2
        self.tp_rescale = tp_rescale

        # Build the layers
        self.tp = FullyConnectedTensorProduct(
            irreps_in1=self.irreps_in1,
            irreps_in2=self.irreps_in2,
            irreps_out=self.irreps_out,
            shared_weights=True,
            normalization='component')

        # For each zeroth order output irrep we need a bias
        # So first determine the order for each output tensor and their dims
        self.irreps_out_orders = [int(irrep_str[-2]) for irrep_str in str(irreps_out).split('+')]
        self.irreps_out_dims = [int(irrep_str.split('x')[0]) for irrep_str in str(irreps_out).split('+')]
        self.irreps_out_slices = irreps_out.slices()
        # Store tuples of slices and corresponding biases in a list
        self.biases = []
        self.biases_slices = []
        self.biases_slice_idx = []
        for slice_idx in range(len(self.irreps_out_orders)):
            if self.irreps_out_orders[slice_idx] == 0:
                out_slice = irreps_out.slices()[slice_idx]
                out_bias = torch.nn.Parameter(
                    torch.zeros(self.irreps_out_dims[slice_idx], dtype=self.tp.weight.dtype))
                self.biases += [out_bias]
                self.biases_slices += [out_slice]
                self.biases_slice_idx += [slice_idx]
        self.biases = torch.nn.ParameterList(self.biases)

        # Initialize the correction factors
        self.slices_sqrt_k = {}

        # Initialize similar to the torch.nn.Linear
        self.tensor_product_init()

    def tensor_product_init(self) -> None:
        with torch.no_grad():
            # Determine fan_in for each slice, it could be that each output slice is updated via several instructions
            slices_fan_in = {}  # fan_in per slice
            for weight, instr in zip(self.tp.weight_views(), self.tp.instructions):
                slice_idx = instr[2]
                mul_1, mul_2, mul_out = weight.shape
                fan_in = mul_1 * mul_2
                slices_fan_in[slice_idx] = (slices_fan_in[slice_idx] + fan_in
                                            if slice_idx in slices_fan_in.keys() else fan_in)

            # Do the initialization of the weights in each instruction
            for weight, instr in zip(self.tp.weight_views(), self.tp.instructions):
                # The tensor product in e3nn already normalizes proportional to 1 / sqrt(fan_in), and the weights are by
                # default initialized with unif(-1,1). However, we want to be consistent with torch.nn.Linear and
                # initialize the weights with unif(-sqrt(k),sqrt(k)), with k = 1 / fan_in
                slice_idx = instr[2]
                if self.tp_rescale:
                    sqrt_k = 1 / slices_fan_in[slice_idx] ** 0.5
                else:
                    sqrt_k = 1.
                weight.data.uniform_(-sqrt_k, sqrt_k)
                self.slices_sqrt_k[slice_idx] = (self.irreps_out_slices[slice_idx], sqrt_k)

            # Initialize the biases
            for (out_slice_idx, out_slice, out_bias) in zip(self.biases_slice_idx, self.biases_slices, self.biases):
                sqrt_k = 1 / slices_fan_in[out_slice_idx] ** 0.5
                out_bias.uniform_(-sqrt_k, sqrt_k)

    def forward_tp_rescale_bias(self, data_in1, data_in2=None) -> torch.Tensor:
        if data_in2 is None:
            data_in2 = torch.ones_like(data_in1[:, 0:1])

        data_out = self.tp(data_in1, data_in2)
        # Apply corrections
        if self.tp_rescale:
            for (slice, slice_sqrt_k) in self.slices_sqrt_k.values():
                data_out[:, slice] /= slice_sqrt_k
        # Add the biases
        for (_, slice, bias) in zip(self.biases_slice_idx, self.biases_slices, self.biases):
            data_out[:, slice] += bias
        # Return result
        return data_out

    def forward(self, data_in1, data_in2=None) -> torch.Tensor:
        # Apply the tensor product, the rescaling and the bias
        data_out = self.forward_tp_rescale_bias(data_in1, data_in2)
        return data_out
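# Usage sketch (assumption: Irreps and FullyConnectedTensorProduct come from e3nn.o3,
# as the class above requires; shapes are illustrative).
import torch
from e3nn.o3 import Irreps

layer = O3TensorProduct(Irreps("8x0e + 4x1o"), Irreps("8x0e + 4x1o"), Irreps("1x0e + 1x1o"))
x = Irreps("8x0e + 4x1o").randn(16, -1)     # batch of 16 steerable feature vectors
attr = Irreps("1x0e + 1x1o").randn(16, -1)  # e.g. node attributes
out = layer(x, attr)                        # [16, 20]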
def main():
    parser = argparse.ArgumentParser(prog="tensor_product_benchmark")
    parser.add_argument("--jit", type=t_or_f, default=True)
    parser.add_argument("--irreps-in1", type=str, default="8x0e + 8x1e + 8x2e + 8x3e")
    parser.add_argument("--irreps-in2", type=str, default="8x0e + 8x1e + 8x2e + 8x3e")
    parser.add_argument("--irreps-out", type=str, default="8x0e + 8x1e + 8x2e + 8x3e")
    parser.add_argument("--cuda", type=t_or_f, default=True)
    parser.add_argument("--backward", type=t_or_f, default=True)
    parser.add_argument("--opt-ein", type=t_or_f, default=True)
    parser.add_argument("--specialized-code", type=t_or_f, default=True)
    parser.add_argument("-w", type=int, default=10)
    parser.add_argument("-n", type=int, default=3)
    parser.add_argument("--batch", type=int, default=10)

    args = parser.parse_args()

    device = 'cuda' if (torch.cuda.is_available() and args.cuda) else 'cpu'
    args.cuda = device == 'cuda'

    if args.cuda:
        # Workaround for CUDA driver issues
        # See https://github.com/pytorch/pytorch/issues/60158#issuecomment-866294291
        with torch.profiler.profile() as _:
            pass

    print("======= Benchmark with settings: ======")
    for key, val in vars(args).items():
        print(f"{key:>18} : {val}")
    print("=" * 40)

    irreps_in1 = Irreps(args.irreps_in1)
    irreps_in2 = Irreps(args.irreps_in2)
    irreps_out = Irreps(args.irreps_out)

    tp = FullyConnectedTensorProduct(
        irreps_in1, irreps_in2, irreps_out,
        _specialized_code=args.specialized_code,
        _optimize_einsums=args.opt_ein
    )
    tp = tp.to(device=device)

    inputs = [
        (irreps_in1.randn(args.batch, -1).to(device=device),
         irreps_in2.randn(args.batch, -1).to(device=device))
        for _ in range(1 + args.w + args.n)
    ]
    if args.backward:
        for tmp in inputs:
            for t in tmp:
                t.requires_grad_(True)
    inputs = iter(inputs)

    # compile
    if args.jit:
        print("JITing...")
        tp = compile(tp)

    print("starting...")

    called_num = [0]

    def trace_handler(p):
        print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
        p.export_chrome_trace("test_trace_" + str(called_num[0]) + ".json")
        called_num[0] += 1

    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA
        ],
        schedule=torch.profiler.schedule(wait=1, warmup=args.w, active=args.n),
        on_trace_ready=trace_handler
    ) as p:
        for _ in range(1 + args.w + args.n):
            out = tp(*next(inputs))
            if args.backward:
                # tanh() forces it to realize the grad as a full size matrix rather than expanded (stride 0) ones
                out.tanh().sum().backward()
            p.step()
def __init__(self, muls=(256, 16, 0), lmax=1, num_layers=3, cutoff=10.0, rad_gaussians=50,
             rad_hs=(128, 128), num_neighbors=20, num_atoms=20, mean=None, std=None, scale=None, atomref=None):
    super().__init__()

    self.cutoff = cutoff
    self.mean = mean
    self.std = std
    self.scale = scale
    self.num_neighbors = num_neighbors
    self.num_atoms = num_atoms
    self.rad_gaussians = rad_gaussians

    self.radial = FullyConnectedNet((rad_gaussians,) + rad_hs, swish, variance_in=1 / rad_gaussians, out_act=True)
    self.irreps_sh = o3.Irreps.spherical_harmonics(lmax)  # spherical harmonics representation
    # self.irreps_edge = o3.Irreps([(25, l, (-1)**l) for l in range(lmax + 1)])
    self.irreps_edge = self.irreps_sh
    # self.mul = TensorProduct(
    #     [(25, "0e", 1.0)],
    #     [(1, ir, 1.0) for _, ir in self.irreps_sh],
    #     [(25, ir, 1.0) for _, ir in self.irreps_sh],
    #     [
    #         (0, l, l, "uvu", False, 1.0)
    #         for l in range(lmax + 1)
    #     ]
    # )

    irreps = o3.Irreps([(muls[0], (0, 1)), (muls[1], (1, -1)), (muls[2], (2, 1))])
    self.mul_node = FullyConnectedTensorProduct([(5, "0e")], self.irreps_sh, irreps)

    modules = []
    for _ in range(num_layers):
        act = make_gated_block(irreps, muls, self.irreps_sh)
        conv = Conv(irreps, act.irreps_in, self.irreps_edge, rad_hs[-1])
        irreps = act.irreps_out.simplify()

        modules += [torch.nn.ModuleList([conv, act])]

    self.layers = torch.nn.ModuleList(modules)

    self.irreps_out = o3.Irreps("0e + 0o")
    self.layers.append(Conv(irreps, self.irreps_out, self.irreps_edge, rad_hs[-1]))

    self.register_buffer('atomref', atomref)
def test_empty_irreps():
    tp = FullyConnectedTensorProduct('0e + 1e', Irreps([]), '0e + 1e')
    out = tp(torch.randn(1, 2, 4), torch.randn(2, 1, 0))
    assert out.shape == (2, 2, 4)