def __init__(self, ls):
    """Generate the FX graph that fills one output slot per ``(l, m)`` pair.

    The graph takes two tensors ``z`` and ``y`` and writes, for every ``l`` in
    ``ls`` and every ``m`` in ``-l..l``, the polynomial returned by
    ``_poly_legendre(l, |m|)`` into consecutive positions of a new trailing
    dimension of size ``sum(2 * l + 1 for l in ls)``.
    """
    super().__init__(self, fx.Graph())

    # == generate code ==
    graph = self.graph
    z = fx.Proxy(graph.placeholder('z', torch.Tensor))
    y = fx.Proxy(graph.placeholder('y', torch.Tensor))

    out = z.new_zeros(z.shape + (sum(2 * l + 1 for l in ls), ))

    slot = 0
    for l in ls:
        # One polynomial per non-negative order; negative orders reuse |m|.
        per_order = []
        for m in range(l + 1):
            # Each term is ((power of z, power of y), coefficient).
            terms = list(_poly_legendre(l, m).items())
            (zn, yn), c = terms[0]
            poly = float(c) * z**zn * y**yn
            for (zn, yn), c in terms[1:]:
                poly += float(c) * z**zn * y**yn
            per_order.append(poly.unsqueeze(-1))

        for m in range(-l, l + 1):
            out.narrow(-1, slot, 1).copy_(per_order[abs(m)])
            slot += 1

    graph.output(out.node, torch.Tensor)
    self.recompile()
def decompose(model: torch.nn.Module, example_inputs) -> torch.nn.Module:
    """
    Rewrite `model` so that every function listed in `decomposition_rules`
    is replaced by its constituent operations.

    The model is traced, shapes are propagated with `example_inputs`, and a
    new graph is built node by node. The whole procedure is repeated a fixed
    number of times so that rules whose expansion contains other decomposable
    functions are expanded as well. Returns the resulting `nn.Module`.
    """
    # Run it multiple times so we converge to a fixed point.
    num_passes = 5
    for _ in range(num_passes):
        model = fx.symbolic_trace(model)
        ShapeProp(model).propagate(*example_inputs)
        rewritten = fx.Graph()
        env = {}

        def as_proxy(n):
            # Wrap the already-copied counterpart of `n` so that calling the
            # decomposition rule records its operations into `rewritten`.
            return fx.Proxy(env[n.name])

        for node in model.graph.nodes:
            if node.op != 'call_function' or node.target not in decomposition_rules:
                # Nothing to decompose: copy the node, remapping its inputs.
                env[node.name] = rewritten.node_copy(node, lambda x: env[x.name])
                continue

            # Decompose through `Proxy` objects; see
            # https://pytorch.org/docs/master/fx.html#proxy-retracing
            proxy_args = map_arg(node.args, as_proxy)
            proxy_kwargs = map_arg(node.kwargs, as_proxy)
            rule = decomposition_rules[node.target]
            env[node.name] = rule(*proxy_args, **proxy_kwargs).node

        model = fx.GraphModule(model, rewritten)
    return model
def remove_suffix(cur_graph, cur_inps):
    """Try to cut the graph short at some node while keeping it failing.

    Candidate cut points are node indices divisible by a gap that starts at
    the largest power of two not exceeding the node count and is halved
    after every full pass. Returns ``((graph, inputs), True)`` with the
    truncated graph on success, or ``((cur_graph, cur_inps), False)``.
    """
    print("Strategy: Remove suffix")
    assert graph_fails(cur_graph, cur_inps)

    gap = 2**math.floor(math.log2(len(cur_graph.nodes)))
    tested = set()
    while gap >= 1:
        new_graph = fx.Graph()
        env = {}
        for idx, node in enumerate(cur_graph.nodes):
            new_node = new_graph.node_copy(node, lambda x: env[x])
            env[node] = new_node

            if node.op in ['placeholder', 'output']:
                continue
            if idx % gap != 0 or idx in tested:
                continue

            # Temporarily make this node the graph output and test it.
            output_node = new_graph.output((new_node, ))
            if graph_fails(new_graph, cur_inps) and len(
                    new_graph.nodes) < len(cur_graph.nodes):
                print()
                print(f"SUCCESS: Removed [{idx}:{len(cur_graph.nodes)})")
                return (new_graph, cur_inps), True

            # Not a useful cut: remember it and keep copying.
            tested.add(idx)
            new_graph.erase_node(output_node)
        gap //= 2

    print("FAIL: Could not remove suffix")
    return (cur_graph, cur_inps), False
def vmap(model: torch.nn.Module, in_axes: Tuple[Optional[int], ...],
         example_args: Tuple[Any, ...]) -> torch.nn.Module:
    """Return a module that applies `model` to batched versions of its inputs.

    `in_axes` gives, per input placeholder and in order, the batch dimension
    of that input (or None when it is not batched). `example_args` are
    example inputs for the original module; they are used to run shape
    propagation because the batching rules need shape (rank) information.
    """
    in_axes = iter(in_axes)
    fx_model = fx.symbolic_trace(model)
    # Annotate the traced graph with shape information.
    ShapeProp(fx_model).propagate(*example_args)

    # The whole graph is rewritten, so build a brand new one and keep a
    # name -> new node mapping for everything copied so far.
    new_graph: fx.Graph = fx.Graph()
    env = {}

    def lookup_env(args):
        return fx.node.map_aggregate(
            args, lambda x: env[x.name] if isinstance(x, fx.Node) else x)

    for node in fx_model.graph.nodes:
        if node.op == 'output':
            new_graph.output(env[node.args[0].name])
            continue

        if node.op == 'placeholder':
            # Inputs are copied over and tagged with their batch dimension.
            new_node = new_graph.placeholder(node.name)
            new_node.bdim = next(in_axes)
        elif node.op == 'call_function':
            new_args = lookup_env(node.args)
            # Use a batching rule as soon as one input carries a batch
            # dimension; otherwise the node is copied unchanged.
            batched_flags = [
                x.bdim is not None for x in new_args
                if isinstance(x, fx.Node)
            ]
            if any(batched_flags):
                new_node = gen_batching_rule_function(node.target, *new_args)
            else:
                new_node = new_graph.node_copy(node, lambda x: env[x.name])
                new_node.bdim = None
        else:
            raise RuntimeError("Not yet implemented")

        new_node.meta = node.meta
        env[node.name] = new_node

    res = fx.GraphModule(fx_model, new_graph)
    print(res.code)
    res.graph.lint()
    return res
def grad(model: torch.nn.Module,
         example_inps: Tuple[Any, ...],
         get_value=True) -> torch.nn.Module:
    """Build a module that computes gradients of `model` w.r.t. its inputs.

    The traced graph is copied into a new graph and then walked backwards;
    for every `call_function` node the rule in `vjp_map` is invoked with the
    accumulated incoming gradient and shape-annotated proxies of the node's
    arguments. With `get_value` the returned module outputs
    ``(original output, gradients)``, otherwise only the gradients. A single
    input yields a bare gradient instead of a list.

    Raises RuntimeError for functions missing from `vjp_map` and for
    `call_method` nodes.
    """
    fx_model = fx.symbolic_trace(model)
    ShapeProp(fx_model).propagate(*example_inps)

    # Start from a copy of the forward graph and append the backward pass.
    val_map = {}
    new_graph: fx.Graph = fx.Graph()
    orig_output = new_graph.graph_copy(fx_model.graph, val_map)

    def shape_proxy(node):
        # Proxy for the copied node, carrying the shape recorded by ShapeProp.
        proxy = fx.Proxy(val_map[node])
        proxy.shape = node.meta['shape']
        proxy.dim = lambda: len(proxy.shape)
        return proxy

    input_grads = []
    seed = new_graph.create_node('call_function', torch.ones, ([], ))

    for node in reversed(fx_model.graph.nodes):
        kind = node.op
        if kind == 'call_method':
            raise RuntimeError("doesn't support methods since i'm lazy")

        if kind == 'output':
            assert (len(node.args) == 1)
            val_map[node.args[0]].grad = [fx.Proxy(seed)]
        elif kind == 'placeholder':
            input_grads.append(sum(val_map[node].grad).node)
        elif kind == 'call_function':
            upstream = sum(val_map[node].grad)
            proxied_args = [
                shape_proxy(a) if isinstance(a, fx.Node) else a
                for a in node.args
            ]
            if node.target not in vjp_map:
                raise RuntimeError("vjp not yet implemented")
            arg_grads = vjp_map[node.target](upstream, *proxied_args)
            if not isinstance(arg_grads, tuple):
                arg_grads = (arg_grads, )
            # Accumulate each contribution on the node it flows into.
            for arg_grad, arg in zip(arg_grads, proxied_args):
                if not isinstance(arg, fx.Proxy):
                    continue
                if not hasattr(arg.node, 'grad'):
                    arg.node.grad = []
                arg.node.grad.append(arg_grad)

    # Placeholders were visited in reverse, so restore the input order.
    result = input_grads[0] if len(input_grads) == 1 else input_grads[::-1]
    new_graph.output((orig_output, result) if get_value else result)

    res = fx.GraphModule(fx_model, new_graph)
    res.graph.lint()
    return res
def fx_graph_cse(fx_g: torch.fx.graph.Graph):
    """Return a copy of `fx_g` with common subexpressions eliminated.

    Two nodes are merged when they have the same target and, after their
    arguments have been remapped into the new graph, equal flattened args,
    kwargs and tree specs. Placeholder, output and get_attr nodes, and nodes
    whose aten target is in `rand_ops`, are always copied unchanged.
    """
    new_graph = fx.Graph()
    env = {}  # node in the old graph -> node in the new graph
    hash_env = {}  # hash key -> first node in the new graph with that key
    token_map = {}  # hash key -> token of that first node

    def substitute(arg_list):
        # Flatten nested args and swap every known old node for its new
        # counterpart; the spec allows rebuilding the nesting if needed.
        flat, spec = tree_flatten(arg_list)
        remapped = [
            env[v] if isinstance(v, torch.fx.node.Node) and v in env else v
            for v in flat
        ]
        return tuple(remapped), spec

    for n in fx_g.nodes:
        # Copied verbatim; random operations must never be merged.
        if n.op in ('placeholder', 'output', 'get_attr') or get_aten_target(n) in rand_ops:
            env[n] = new_graph.node_copy(n, lambda x: env[x])
            continue

        # n.op == 'call_function' from here on; call_module / call_method
        # are not expected.
        args, args_spec = substitute(n.args)
        kwargs, kwargs_spec = substitute(n.kwargs)

        # Nodes with equal tokens are interchangeable.
        token = {
            "target": n.target,
            "args": args,
            "args_spec": args_spec,
            "kwargs": kwargs,
            "kwargs_spec": kwargs_spec
        }

        # Specs are not hashable, so only the remapped values are hashed and
        # the full token is compared to rule out collisions.
        hash_val = (n.target, hash((args, kwargs)))
        already_seen = hash_val in hash_env

        if already_seen and token_map[hash_val] == token:
            env[n] = hash_env[hash_val]
            continue

        new_node = new_graph.node_copy(n, lambda x: env[x])
        env[n] = new_node
        if not already_seen:
            hash_env[hash_val] = new_node
            token_map[hash_val] = token

    return new_graph
def _extract_graph_with_inputs_outputs(joint_graph, inputs, outputs):
    """
    Given a graph, extracts out a subgraph that takes the specified nodes as
    inputs and returns the specified outputs.

    This includes specifying non-placeholder nodes as inputs.

    The general strategy is to initialize all inputs with proxies as we
    encounter them, and trace through the graph, only keeping values which
    take in valid proxies. Then, all dead code is eliminated.

    Args:
        joint_graph: graph to extract from.
        inputs: nodes of `joint_graph` that become the placeholders of the
            new graph, in this order.
        outputs: values returned by the new graph; entries that are not
            `fx.Node` instances are passed through unchanged.

    Returns:
        The new, linted `fx.Graph`.

    Raises:
        RuntimeError: if an output node was never encountered.
    """
    new_graph = fx.Graph()
    env = {}
    # Set of the requested inputs, so the membership test in the main loop
    # is O(1) instead of a linear scan of `inputs` for every node.
    input_nodes = set()

    # Add new placeholder nodes in the order specified by the inputs
    for node in inputs:
        new_node = new_graph.placeholder(node.name)
        # Can't use node_copy here as we may be turning previous
        # call_function into placeholders
        new_node.meta = node.meta
        env[node] = new_node
        input_nodes.add(node)

    for node in joint_graph.nodes:
        if node in input_nodes:
            continue
        elif node.op == 'placeholder':
            # A placeholder that was not requested is unavailable here.
            env[node] = InvalidNode
        elif node.op == 'call_function':
            all_args = pytree.tree_flatten((node.args, node.kwargs))[0]
            depends_on_invalid = [
                isinstance(env[x], InvalidNodeBase) for x in all_args
                if isinstance(x, fx.Node)
            ]
            if any(depends_on_invalid):
                # Anything computed from an unavailable value is unavailable.
                env[node] = InvalidNode
                continue
            env[node] = new_graph.node_copy(node, lambda x: env[x])
        elif node.op == 'get_attr':
            env[node] = new_graph.node_copy(node, lambda x: env[x])
        elif node.op == 'output':
            pass

    output_values = []
    for x in outputs:
        if isinstance(x, fx.Node):
            if x not in env:
                raise RuntimeError(f"Node {x} couldn't be found in env")
            output_values.append(env[x])
        else:
            output_values.append(x)
    new_graph.output(output_values)

    new_graph.eliminate_dead_code()
    new_graph.lint()
    return new_graph
def _consolidate_placeholders(cur_graph):
    """Return a copy of `cur_graph` with all placeholder nodes moved first.

    The relative order among placeholders, and among the remaining nodes,
    is kept.
    """
    new_graph = fx.Graph()
    env = {}
    # sorted() is stable and False < True, so placeholders come first while
    # each group keeps its original order.
    ordered = sorted(cur_graph.nodes, key=lambda n: n.op != 'placeholder')
    for node in ordered:
        env[node] = new_graph.node_copy(node, lambda x: env[x])
    return new_graph
def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node],
                     inputs: List[fx.Node], outputs: List[fx.Node]):
    """
    Build a submodule that executes a subgraph of an existing graph.

    `inputs` become the placeholders of the new graph (named after the
    original nodes), `nodes` are copied in the given order, and the new
    graph returns a list with the counterparts of `outputs`. The result is
    an `fx.GraphModule` rooted at `orig_module`.
    """
    new_graph = fx.Graph()
    # Original node -> its counterpart in the new graph.
    env: Dict[fx.Node, fx.Node] = {
        inp: new_graph.placeholder(inp.name)
        for inp in inputs
    }
    for node in nodes:
        env[node] = new_graph.node_copy(node, lambda x: env[x])
    new_graph.output([env[out_node] for out_node in outputs])
    new_graph.lint()
    return fx.GraphModule(orig_module, new_graph)
def truncate(model, k):
    """Return a module that runs only the first `k` nodes of traced `model`.

    The model is symbolically traced, its first `k` graph nodes (placeholders
    included) are copied, and the value of the k-th node becomes the output.

    If the copy reaches the graph's own output node (i.e. `k` is at least
    the number of nodes), the copy stops there and the full model is
    returned; no second output node is appended. A `k` below 1 also yields
    the full model.
    """
    model = fx.symbolic_trace(model)
    new_graph = fx.Graph()
    env = {}
    for cnt, node in enumerate(model.graph.nodes, start=1):
        new_node = new_graph.node_copy(node, lambda x: env[x.name])
        env[node.name] = new_node
        if node.op == 'output':
            # The original output was just copied; adding another output
            # that refers to it would produce a malformed graph.
            break
        if cnt == k:
            new_graph.output(new_node)
            break
    return fx.GraphModule(model, new_graph)
def remove_suffix(cur_graph, cur_inps, granularity):
    """Try to cut the graph short at a node while keeping it failing.

    Nodes are copied one by one into a new graph. At every non-placeholder,
    non-output node whose index is an odd multiple of `granularity`, the
    node is temporarily made the graph output; if the truncated graph is
    smaller than `cur_graph` and `graph_fails` still reports a failure, a
    `ReproState` for it is returned. Returns None when no cut point works.
    """
    new_graph = fx.Graph()
    env = {}
    for idx, node in enumerate(cur_graph.nodes):
        new_node = new_graph.node_copy(node, lambda x: env[x])
        env[node] = new_node

        if node.op in ('placeholder', 'output'):
            continue
        # If idx is divisible by (granularity * 2), it would have been
        # checked already.
        if idx % granularity != 0 or idx % (granularity * 2) == 0:
            continue

        output_node = new_graph.output((new_node, ))
        # Cheap size comparison first; graph_fails only runs on real cuts.
        if len(new_graph.nodes) < len(cur_graph.nodes) and graph_fails(
                new_graph, cur_inps):
            return ReproState(new_graph, cur_inps)
        # Not a useful cut: drop the temporary output and keep copying.
        new_graph.erase_node(output_node)
    return None
def invert(model: torch.nn.Module) -> torch.nn.Module:
    """Build a module that applies the inverse of each function of `model`
    in reverse order.

    The traced graph is walked backwards: its output becomes the input
    placeholder, every `call_function` node is replaced by the function
    registered for its target in `invert_mapping`, and the original input
    placeholder becomes the output. Any other node kind raises RuntimeError.
    """
    fx_model = fx.symbolic_trace(model)
    new_graph = fx.Graph()
    # Maps the name of an original node to the new node holding its value.
    env = {}
    for node in reversed(fx_model.graph.nodes):
        kind = node.op
        if kind == "output":
            # The output turns into the input placeholder.
            env[node.args[0].name] = new_graph.placeholder(node.name)
        elif kind == "call_function":
            # Feed the value already recovered for this node through the
            # inverse function to recover the value of its first argument.
            inverse_fn = invert_mapping[node.target]
            recovered = new_graph.call_function(inverse_fn,
                                                (env[node.name], ))
            env[node.args[0].name] = recovered
        elif kind == "placeholder":
            # The input placeholder turns into the output.
            new_graph.output(env[node.name])
        else:
            raise RuntimeError("Not implemented")

    new_graph.lint()
    return fx.GraphModule(fx_model, new_graph)
def _codegen_linear(
    irreps_in: o3.Irreps,
    irreps_out: o3.Irreps,
    instructions: List[Instruction],
    biases: List[bool],
    f_in: Optional[int] = None,
    f_out: Optional[int] = None,
    shared_weights: bool = False,
    optimize_einsums: bool = True,
) -> Tuple[fx.GraphModule, int, int]:
    """Generate the FX module implementing the forward pass of a linear map.

    The generated graph takes three tensors: ``x`` (input), ``w`` (flattened
    weights) and ``b`` (flattened biases).

    Parameters
    ----------
    irreps_in, irreps_out
        Input and output irreps; their ``dim`` gives the size of the last
        dimension of the input and output.
    instructions
        Paths to compute; each provides ``i_in``, ``i_out`` and
        ``path_shape``. Instructions with a zero in ``path_shape`` are
        dropped.
    biases
        One flag per entry of ``irreps_out`` telling whether it gets a bias.
    f_in, f_out
        Optional extra feature dimensions; when ``f_in`` is None the input
        is reshaped to ``(-1, irreps_in.dim)``, otherwise to
        ``(-1, f_in, irreps_in.dim)``.
    shared_weights
        When False the einsum equations give the weights a leading ``z``
        (batch) index.
    optimize_einsums
        When True the module is passed through ``optimize_einsums_full``
        with zero-filled example inputs and then through ``jitable``.

    Returns
    -------
    ``(graph module, weight_numel, bias_numel)``
    """
    graph_out = fx.Graph()

    # = Function definitions =
    x = fx.Proxy(graph_out.placeholder('x', torch.Tensor))
    ws = fx.Proxy(graph_out.placeholder('w', torch.Tensor))
    bs = fx.Proxy(graph_out.placeholder('b', torch.Tensor))

    # Leading (batch) dimensions of the input, kept to restore the output
    # shape at the end.
    if f_in is None:
        size = x.shape[:-1]
        outsize = size + (irreps_out.dim, )
    else:
        size = x.shape[:-2]
        outsize = size + (
            f_out,
            irreps_out.dim,
        )

    bias_numel = sum(mul_ir.dim for bias, mul_ir in zip(biases, irreps_out)
                     if bias)
    if bias_numel > 0:
        if f_out is None:
            bs = bs.reshape(-1, bias_numel)
        else:
            bs = bs.reshape(-1, f_out, bias_numel)

    # = Short-circuit for nothing to do =
    # We produce no code for empty instructions
    instructions = [ins for ins in instructions if 0 not in ins.path_shape]

    if len(instructions) == 0 and bias_numel == 0:
        out = x.new_zeros(outsize)

        graph_out.output(out.node, torch.Tensor)
        # Short circuit: no weights and no biases, so both counts are 0
        return fx.GraphModule({}, graph_out, "linear_forward"), 0, 0

    if f_in is None:
        x = x.reshape(-1, irreps_in.dim)
    else:
        x = x.reshape(-1, f_in, irreps_in.dim)
    batch_out = x.shape[0]

    # One (possibly empty) list of bias terms per output irrep, indexed like
    # irreps_out so it can be concatenated with the path outputs below.
    out_bias_list = []
    bias_index = 0
    for bias, mul_ir_out in zip(biases, irreps_out):
        if bias:
            if sum(biases) == 1:
                # Single bias: the whole tensor is used, no slicing needed
                b = bs
            else:
                b = bs.narrow(-1, bias_index, mul_ir_out.dim)
                bias_index += mul_ir_out.dim
            out_bias_list += [[
                b.expand(batch_out, -1) if f_out is None else b.expand(
                    batch_out, f_out, -1)
            ]]
        else:
            out_bias_list += [[]]

    weight_numel = sum(prod(ins.path_shape) for ins in instructions)
    if weight_numel > 0:
        ws = ws.reshape(-1, weight_numel) if f_in is None else ws.reshape(
            -1, f_in, f_out, weight_numel)

    # = extract individual input irreps =
    if len(irreps_in) == 1:
        x_list = [
            x.reshape(batch_out, *(() if f_in is None else (f_in, )),
                      irreps_in[0].mul, irreps_in[0].ir.dim)
        ]
    else:
        x_list = [
            x.narrow(-1, i.start, mul_ir.dim).reshape(batch_out, *(() if f_in is None else (f_in, )),
                                                      mul_ir.mul, mul_ir.ir.dim)
            for i, mul_ir in zip(irreps_in.slices(), irreps_in)
        ]

    # Einsum index prepended to the weights when they carry a batch dimension
    z = '' if shared_weights else 'z'

    # Current offset into the flattened weight tensor
    flat_weight_index = 0

    out_list = []

    for ins in instructions:
        mul_ir_in = irreps_in[ins.i_in]
        mul_ir_out = irreps_out[ins.i_out]

        # Short-circuit for empty irreps
        # NOTE(review): a skipped instruction adds nothing to out_list, while
        # the final zip pairs instructions with out_list positionally -
        # confirm this branch cannot be taken after the path_shape filter.
        if mul_ir_in.dim == 0 or mul_ir_out.dim == 0:
            continue

        # Extract the weight from the flattened weight tensor
        path_nweight = prod(ins.path_shape)
        if len(instructions) == 1:
            # Avoid unnecessary view when there is only one weight
            w = ws
        else:
            w = ws.narrow(-1, flat_weight_index, path_nweight)
        w = w.reshape((() if shared_weights else (-1, )) +
                      (() if f_in is None else (f_in, f_out)) +
                      ins.path_shape)
        flat_weight_index += path_nweight

        if f_in is None:
            ein_out = torch.einsum(f"{z}uw,zui->zwi", w, x_list[ins.i_in])
        else:
            ein_out = torch.einsum(f"{z}xyuw,zxui->zywi", w,
                                   x_list[ins.i_in])
        # Normalize by the input feature count, the input multiplicity and
        # the number of instructions feeding the same output.
        alpha = 1.0 / math.sqrt((f_in or 1) * mul_ir_in.mul *
                                sum(1 if other_ins.i_out == ins.i_out else 0
                                    for other_ins in instructions))
        ein_out = alpha * ein_out

        out_list += [
            ein_out.reshape(batch_out, *(() if f_out is None else (f_out, )),
                            mul_ir_out.dim)
        ]

    # = Return the result =
    # For every non-empty output irrep, sum the path outputs that target it
    # together with its bias terms.
    out = [
        _sum_tensors([
            out for ins, out in zip(instructions, out_list)
            if ins.i_out == i_out
        ] + out_bias_list[i_out],
                     shape=(batch_out, *(() if f_out is None else (f_out, )),
                            mul_ir_out.dim),
                     like=x)
        for i_out, mul_ir_out in enumerate(irreps_out) if mul_ir_out.mul > 0
    ]
    if len(out) > 1:
        out = torch.cat(out, dim=-1)
    else:
        out = out[0]

    out = out.reshape(outsize)

    graph_out.output(out.node, torch.Tensor)

    # check graphs
    graph_out.lint()

    graphmod_out = fx.GraphModule({}, graph_out, "linear_forward")

    # TODO: when eliminate_dead_code() is in PyTorch stable, use that
    if optimize_einsums:
        # See _tensor_product/_codegen.py for notes
        batchdim = 4
        # Zero-filled inputs for x, w and b, in that order
        example_inputs = (
            torch.zeros((batchdim, *(() if f_in is None else (f_in, )),
                         irreps_in.dim)),
            torch.zeros(
                1 if shared_weights else batchdim,
                f_in or 1,
                f_out or 1,
                weight_numel,
            ),
            torch.zeros(
                1 if shared_weights else batchdim,
                f_out or 1,
                bias_numel,
            ),
        )
        graphmod_out = jitable(
            optimize_einsums_full(graphmod_out, example_inputs))

    return graphmod_out, weight_numel, bias_numel
def __init__(self,
             irreps_in,
             irreps_outs,
             instructions,
             squeeze_out: bool = False):
    r"""Extract sub sets of irreps

    Parameters
    ----------
    irreps_in : `e3nn.o3.Irreps`
        representation of the input

    irreps_outs : list of `e3nn.o3.Irreps`
        list of representation of the outputs

    instructions : list of tuple of int
        list of tuples, one per output, each containing
        ``len(irreps_outs[i])`` int (indices into ``irreps_in``)

    squeeze_out : bool, default False
        if ``squeeze_out`` and only one output exists, a ``torch.Tensor``
        will be returned instead of a ``Tuple[torch.Tensor]``

    Examples
    --------

    >>> c = Extract('1e + 0e + 0e', ['0e', '0e'], [(1,), (2,)])
    >>> c(torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0]))
    (tensor([1.]), tensor([2.]))
    """
    super().__init__(self, fx.Graph())
    self.irreps_in = o3.Irreps(irreps_in)
    self.irreps_outs = tuple(o3.Irreps(irreps) for irreps in irreps_outs)
    self.instructions = instructions

    assert len(self.irreps_outs) == len(self.instructions)
    for irreps_out, ins in zip(self.irreps_outs, self.instructions):
        assert len(irreps_out) == len(ins)

    # == generate code ==
    graph = self.graph
    x = fx.Proxy(graph.placeholder('x', torch.Tensor))
    torch._assert(x.shape[-1] == self.irreps_in.dim, "invalid input shape")

    # One zero-initialised buffer per requested output.
    out = [
        x.new_zeros(x.shape[:-1] + (irreps.dim, ))
        for irreps in self.irreps_outs
    ]

    # Instruction that selects every input irrep in order: plain copy.
    take_everything = tuple(range(len(self.irreps_in)))

    for buf, irreps_out, ins in zip(out, self.irreps_outs,
                                    self.instructions):
        if ins == take_everything:
            buf.copy_(x)
            continue
        for s_out, i_in in zip(irreps_out.slices(), ins):
            i_start = self.irreps_in[:i_in].dim
            i_len = self.irreps_in[i_in].dim
            dst = buf.narrow(-1, s_out.start, s_out.stop - s_out.start)
            src = x.narrow(-1, i_start, i_len)
            dst.copy_(src)

    out = tuple(e.node for e in out)
    if squeeze_out and len(out) == 1:
        graph.output(out[0], torch.Tensor)
    else:
        graph.output(out, Tuple[(torch.Tensor, ) * len(self.irreps_outs)])

    self.recompile()
def __init__(self, formula, filter_ir_out=None, filter_ir_mid=None, eps=1e-9, **irreps):
    """Build the change of basis and the FX graph for an index formula.

    Parameters
    ----------
    formula
        Passed to ``germinate_formulas``, which returns the base index
        string ``f0`` and the list of ``(sign, permutation)`` pairs.
    filter_ir_out
        Optional iterable of irreps; when given, only output irreps in this
        list are kept.
    filter_ir_mid
        Forwarded to ``_wigner_nj``.
    eps
        Tolerance used for all "is zero" / "is different" comparisons.
    **irreps
        Irreps for the indices of the formula, keyed by the single-character
        index name.

    Raises
    ------
    TypeError
        If a keyword in ``irreps`` is not a single character.
    RuntimeError
        If two indices related by a permutation have different irreps, if an
        index of the formula has no irreps, or if an irreps is given for an
        index that does not appear in the formula.

    Sets the buffer ``change_of_basis``, the submodules ``tp{i}``, and the
    attributes ``irreps_in``, ``irreps_out`` and ``graph``.
    """
    super().__init__(self, fx.Graph())

    if filter_ir_out is not None:
        filter_ir_out = [o3.Irrep(ir) for ir in filter_ir_out]

    f0, formulas = germinate_formulas(formula)

    irreps = {i: o3.Irreps(irs) for i, irs in irreps.items()}

    for i in irreps:
        if len(i) != 1:
            raise TypeError(f"got an unexpected keyword argument '{i}'")

    # Indices mapped onto each other by a permutation must share their
    # irreps; propagate the known ones in both directions.
    for _sign, p in formulas:
        f = "".join(f0[i] for i in p)
        for i, j in zip(f0, f):
            if i in irreps and j in irreps and irreps[i] != irreps[j]:
                raise RuntimeError(f'irreps of {i} and {j} should be the same')
            if i in irreps:
                irreps[j] = irreps[i]
            if j in irreps:
                irreps[i] = irreps[j]

    for i in f0:
        if i not in irreps:
            raise RuntimeError(f'index {i} has no irreps associated to it')

    for i in irreps:
        if i not in f0:
            raise RuntimeError(f'index {i} has an irreps but does not appear in the fomula')

    base_perm, _ = reduce_permutation(
        f0,
        formulas,
        dtype=torch.float64,
        **{i: irs.dim for i, irs in irreps.items()}
    )

    # Output irrep -> list of (path, basis tensor) candidates
    Ps = collections.defaultdict(list)

    for ir, path, base_o3 in _wigner_nj(*[irreps[i] for i in f0], filter_ir_mid=filter_ir_mid, dtype=torch.float64):
        if filter_ir_out is None or ir in filter_ir_out:
            P = base_o3.flatten(1) @ base_perm.flatten(1).T
            if P.norm() > eps:  # if this Irrep is present in the permutation basis we keep it
                Ps[ir].append((path, base_o3))

    outputs = []
    change_of_basis = []
    irreps_out = []

    for ir in Ps:
        mul = len(Ps[ir])
        paths = [path for path, _ in Ps[ir]]
        base_o3 = torch.stack([R for _, R in Ps[ir]])

        R = base_o3.flatten(2)  # [multiplicity, ir, input basis] (u,j,omega)
        P = base_perm.flatten(1)  # [permutation basis, input basis] (a,omega)

        Xs = []
        for j in range(ir.dim):
            RR = R[:, j] @ R[:, j].T  # (u,u)
            PP = P @ P.T  # (a,a)
            RP = R[:, j] @ P.T  # (u,a)

            prob = torch.cat([
                torch.cat([RR, -RP], dim=1),
                torch.cat([-RP.T, PP], dim=1)
            ], dim=0)
            eigenvalues, eigenvectors = torch.linalg.eigh(prob)
            # Keep the eigenvectors whose eigenvalue is below eps
            X = eigenvectors[:, eigenvalues < eps][:mul].T  # [solutions, multiplicity]
            X = torch.linalg.qr(X, mode='r').R
            # Fix the sign of each row from its first entry (starting at the
            # diagonal) that is not within eps of zero.
            # NOTE(review): the inner loop variable reuses the name `j` of
            # the enclosing loop; the outer `j` is not read again in this
            # iteration, so this looks harmless - confirm.
            for i, x in enumerate(X):
                for j in range(i, mul):
                    if x[j] < eps:
                        x.neg_()
                    if x[j] > eps:
                        break

            X[X.abs() < eps] = 0
            # Sort the rows so that the result can be compared across j
            X = sorted([[x.item() for x in line] for line in X])
            X = torch.tensor(X, dtype=torch.float64)

            Xs.append(X)

        # The solution must be the same for every component j of the irrep
        for X in Xs:
            assert (X - Xs[0]).abs().max() < eps
        X = Xs[0]

        for x in X:
            C = torch.einsum("u,ui...->i...", x, base_o3)
            # Rescale so that the sum of squares of C equals ir.dim
            correction = (ir.dim / C.pow(2).sum())**0.5
            C = correction * C

            outputs.append([((correction * v).item(), p) for v, p in zip(x, paths) if v.abs() > eps])
            change_of_basis.append(C)
            irreps_out.append((1, ir))

    dtype, _ = explicit_default_types(None, None)
    self.register_buffer('change_of_basis', torch.cat(change_of_basis).to(dtype=dtype))

    # Collect the distinct operations needed by the kept paths
    tps = set()
    for vp_list in outputs:
        for v, p in vp_list:
            for op in _get_ops(p):
                tps.add(op)

    # One TensorProduct submodule per distinct operation, named tp{i}
    tps = list(tps)
    for i, op in enumerate(tps):
        tp = o3.TensorProduct(op[0], op[1], op[2], [(0, 0, 0, 'uuu', False)])
        setattr(self, f'tp{i}', tp)

    graph = fx.Graph()
    inputs = [
        fx.Proxy(graph.placeholder(f"x{i}", torch.Tensor))
        for i in f0
    ]

    self.irreps_in = [irreps[i] for i in f0]
    self.irreps_out = o3.Irreps(irreps_out).simplify()

    # Cache of already generated sub-expressions, keyed by path
    values = dict()

    def evaluate(path):
        # Return the proxy computing `path`, generating graph nodes on the
        # first request only.
        # NOTE(review): `out` is unbound if `path` is neither _INPUT nor
        # _TP - presumably those are the only two path types; confirm.
        if path in values:
            return values[path]

        if isinstance(path, _INPUT):
            out = inputs[path.tensor]
            # Slice only when the path does not cover the whole input
            if (path.start, path.stop) != (0, self.irreps_in[path.tensor].dim):
                out = out.narrow(-1, path.start, path.stop - path.start)
        if isinstance(path, _TP):
            x1 = evaluate(path.args[0]).node
            x2 = evaluate(path.args[1]).node
            out = fx.Proxy(graph.call_module(f'tp{tps.index(path.op)}', (x1, x2)))
        values[path] = out
        return out

    # Each output is a weighted sum of paths; a multiplication is emitted
    # only when the weight differs from 1 by more than eps.
    outs = []
    for vp_list in outputs:
        v, p = vp_list[0]
        out = evaluate(p)
        if abs(v - 1.0) > eps:
            out = v * out
        for v, p in vp_list[1:]:
            t = evaluate(p)
            if abs(v - 1.0) > eps:
                t = v * t
            out = out + t
        outs.append(out)

    out = torch.cat(outs, dim=-1)
    graph.output(out.node)

    self.graph = graph
    self.recompile()
def codegen_tensor_product( irreps_in1: o3.Irreps, in1_var: List[float], irreps_in2: o3.Irreps, in2_var: List[float], irreps_out: o3.Irreps, out_var: List[float], instructions: List[Instruction], normalization: str = 'component', shared_weights: bool = False, specialized_code: bool = True, optimize_einsums: bool = True, ) -> Tuple[fx.GraphModule, fx.GraphModule]: graph_out = fx.Graph() graph_right = fx.Graph() # = Function definitions = x1s_out = fx.Proxy(graph_out.placeholder('x1', torch.Tensor)) x2s_out = fx.Proxy(graph_out.placeholder('x2', torch.Tensor)) ws_out = fx.Proxy(graph_out.placeholder('w', torch.Tensor)) x2s_right = fx.Proxy(graph_right.placeholder('x2', torch.Tensor)) ws_right = fx.Proxy(graph_right.placeholder('w', torch.Tensor)) empty_out = fx.Proxy( graph_out.call_function(torch.empty, ((), ), dict(device='cpu'))) empty_right = fx.Proxy( graph_right.call_function(torch.empty, ((), ), dict(device='cpu'))) if shared_weights: size_out = torch.broadcast_tensors( empty_out.expand(x1s_out.shape[:-1]), empty_out.expand(x2s_out.shape[:-1]))[0].shape size_right = x2s_right.shape[:-1] else: size_out = torch.broadcast_tensors( empty_out.expand(x1s_out.shape[:-1]), empty_out.expand(x2s_out.shape[:-1]), empty_out.expand(ws_out.shape[:-1]))[0].shape size_right = torch.broadcast_tensors( empty_right.expand(x2s_right.shape[:-1]), empty_right.expand(ws_right.shape[:-1]))[0].shape # = Short-circut for zero dimensional = # We produce no code for empty instructions instructions = [ins for ins in instructions if 0 not in ins.path_shape] if len(instructions) == 0: out_out = x1s_out.new_zeros(size_out + (irreps_out.dim, )) out_right = x2s_right.new_zeros(size_right + ( irreps_in1.dim, irreps_out.dim, )) graph_out.output(out_out.node, torch.Tensor) graph_right.output(out_right.node, torch.Tensor) # Short circut return (fx.GraphModule({}, graph_out, "tp_forward"), fx.GraphModule({}, graph_right, "tp_right")) # = Broadcast inputs = if shared_weights: x1s_out, x2s_out = 
x1s_out.broadcast_to( size_out + (-1, )), x2s_out.broadcast_to(size_out + (-1, )) else: x1s_out, x2s_out, ws_out = x1s_out.broadcast_to( size_out + (-1, )), x2s_out.broadcast_to( size_out + (-1, )), ws_out.broadcast_to(size_out + (-1, )) x2s_right, ws_right = x2s_right.broadcast_to( size_right + (-1, )), ws_right.broadcast_to(size_right + (-1, )) outsize_out = size_out + (irreps_out.dim, ) outsize_right = size_right + ( irreps_in1.dim, irreps_out.dim, ) x1s_out = x1s_out.reshape(-1, irreps_in1.dim) x2s_out = x2s_out.reshape(-1, irreps_in2.dim) x2s_right = x2s_right.reshape(-1, irreps_in2.dim) batch_out = x1s_out.shape[0] batch_right = x2s_right.shape[0] # = Determine number of weights and reshape weights == weight_numel = sum( prod(ins.path_shape) for ins in instructions if ins.has_weight) if weight_numel > 0: ws_out = ws_out.reshape(-1, weight_numel) ws_right = ws_right.reshape(-1, weight_numel) del weight_numel # = book-keeping for wigners = w3j = [] w3j_dict_out = dict() w3j_dict_right = dict() # = extract individual input irreps = # If only one input irrep, can avoid creating a view if len(irreps_in1) == 1: x1_list_out = [ x1s_out.reshape(batch_out, irreps_in1[0].mul, irreps_in1[0].ir.dim) ] else: x1_list_out = [ x1s_out[:, i].reshape(batch_out, mul_ir.mul, mul_ir.ir.dim) for i, mul_ir in zip(irreps_in1.slices(), irreps_in1) ] x2_list_out = [] x2_list_right = [] # If only one input irrep, can avoid creating a view if len(irreps_in2) == 1: x2_list_out.append( x2s_out.reshape(batch_out, irreps_in2[0].mul, irreps_in2[0].ir.dim)) x2_list_right.append( x2s_right.reshape(batch_right, irreps_in2[0].mul, irreps_in2[0].ir.dim)) else: for i, mul_ir in zip(irreps_in2.slices(), irreps_in2): x2_list_out.append(x2s_out[:, i].reshape(batch_out, mul_ir.mul, mul_ir.ir.dim)) x2_list_right.append(x2s_right[:, i].reshape(batch_right, mul_ir.mul, mul_ir.ir.dim)) # The einsum string index to prepend to the weights if the weights are not shared and have a batch dimension z = '' if 
shared_weights else 'z' # Cache of input irrep pairs whose outer products (xx) have already been computed xx_dict = dict() # Current index in the flat weight tensor flat_weight_index = 0 out_list_out = [] out_list_right = [] for ins in instructions: mul_ir_in1 = irreps_in1[ins.i_in1] mul_ir_in2 = irreps_in2[ins.i_in2] mul_ir_out = irreps_out[ins.i_out] assert mul_ir_in1.ir.p * mul_ir_in2.ir.p == mul_ir_out.ir.p assert abs(mul_ir_in1.ir.l - mul_ir_in2.ir.l ) <= mul_ir_out.ir.l <= mul_ir_in1.ir.l + mul_ir_in2.ir.l if mul_ir_in1.dim == 0 or mul_ir_in2.dim == 0 or mul_ir_out.dim == 0: continue alpha = ins.path_weight * out_var[ins.i_out] / sum( in1_var[i.i_in1] * in2_var[i.i_in2] for i in instructions if i.i_out == ins.i_out) # Open the profiler block name = f"{mul_ir_in1} x {mul_ir_in2} = {mul_ir_out} {ins.connection_mode} {ins.has_weight}" handle_out = graph_out.call_function( torch.ops.profiler._record_function_enter, (name, )) handle_right = graph_right.call_function( torch.ops.profiler._record_function_enter, (name, )) x1_out = x1_list_out[ins.i_in1] x2_out = x2_list_out[ins.i_in2] x2_right = x2_list_right[ins.i_in2] e1_right = fx.Proxy( graph_right.call_function( torch.eye, (mul_ir_in1.mul, ), dict(dtype=x2s_right.dtype.node, device=x2s_right.device.node))) e2_right = fx.Proxy( graph_right.call_function( torch.eye, (mul_ir_in2.mul, ), dict(dtype=x2s_right.dtype.node, device=x2s_right.device.node))) i1_right = fx.Proxy( graph_right.call_function( torch.eye, (mul_ir_in1.ir.dim, ), dict(dtype=x2s_right.dtype.node, device=x2s_right.device.node))) assert ins.connection_mode in [ 'uvw', 'uvu', 'uvv', 'uuw', 'uuu', 'uvuv' ] alpha = sqrt( alpha / { 'uvw': (mul_ir_in1.mul * mul_ir_in2.mul), 'uvu': mul_ir_in2.mul, 'uvv': mul_ir_in1.mul, 'uuw': mul_ir_in1.mul, 'uuu': 1, 'uvuv': 1, }[ins.connection_mode]) if ins.has_weight: # Extract the weight from the flattened weight tensor w_out = ws_out[:, flat_weight_index:flat_weight_index + prod(ins.path_shape)].reshape(( () if 
shared_weights else (-1, )) + tuple(ins.path_shape)) w_right = ws_right[:, flat_weight_index:flat_weight_index + prod(ins.path_shape)].reshape( (() if shared_weights else (-1, )) + tuple(ins.path_shape)) flat_weight_index += prod(ins.path_shape) # Construct the general xx in case this instruction isn't specialized # If this isn't used, the dead code will get removed key = (ins.i_in1, ins.i_in2, ins.connection_mode[:2]) if key not in xx_dict: if ins.connection_mode[:2] == 'uv': xx_dict[key] = torch.einsum('zui,zvj->zuvij', x1_out, x2_out) if ins.connection_mode[:2] == 'uu': xx_dict[key] = torch.einsum('zui,zuj->zuij', x1_out, x2_out) xx = xx_dict[key] # Create a proxy & request for the relevant wigner w3j # If not used (because of specialized code), will get removed later. key = (mul_ir_in1.ir.l, mul_ir_in2.ir.l, mul_ir_out.ir.l) if key not in w3j: w3j_dict_out[key] = fx.Proxy( graph_out.get_attr(f"_w3j_{key[0]}_{key[1]}_{key[2]}")) w3j_dict_right[key] = fx.Proxy( graph_right.get_attr(f"_w3j_{key[0]}_{key[1]}_{key[2]}")) w3j.append(key) w3j_out = w3j_dict_out[key] w3j_right = w3j_dict_right[key] exp = {'component': 1, 'norm': -1}[normalization] if ins.connection_mode == 'uvw': assert ins.has_weight if specialized_code and key == (0, 0, 0): ein_out = torch.einsum( f"{z}uvw,zu,zv->zw", w_out, x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out.reshape(batch_out, mul_ir_in2.dim)) ein_right = torch.einsum( f"{z}uvw,zv->zuw", w_right, x2_right.reshape(batch_right, mul_ir_in2.dim)) elif specialized_code and mul_ir_in1.ir.l == 0: ein_out = torch.einsum( f"{z}uvw,zu,zvj->zwj", w_out, x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out) ein_right = torch.einsum(f"{z}uvw,zvi->zuwi", w_right, x2_right) elif specialized_code and mul_ir_in2.ir.l == 0: ein_out = torch.einsum( f"{z}uvw,zui,zv->zwi", w_out, x1_out, x2_out.reshape(batch_out, mul_ir_in2.dim)) ein_right = torch.einsum( f"{z}uvw,ij,zv->zuiwj", w_right, i1_right, x2_right.reshape(batch_right, mul_ir_in2.dim)) elif 
specialized_code and mul_ir_out.ir.l == 0: ein_out = torch.einsum(f"{z}uvw,zui,zvi->zw", w_out, x1_out, x2_out) / sqrt(mul_ir_in1.ir.dim)**exp ein_right = torch.einsum(f"{z}uvw,zvi->zuiw", w_right, x2_right) / sqrt( mul_ir_in1.ir.dim)**exp else: ein_out = torch.einsum(f"{z}uvw,ijk,zuvij->zwk", w_out, w3j_out, xx) ein_right = torch.einsum(f"{z}uvw,ijk,zvj->zuiwk", w_right, w3j_right, x2_right) if ins.connection_mode == 'uvu': assert mul_ir_in1.mul == mul_ir_out.mul if ins.has_weight: if specialized_code and key == (0, 0, 0): ein_out = torch.einsum( f"{z}uv,zu,zv->zu", w_out, x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out.reshape(batch_out, mul_ir_in2.dim)) ein_right = torch.einsum( f"{z}uv,uw,zv->zuw", w_right, e1_right, x2_right.reshape(batch_right, mul_ir_in2.dim)) elif specialized_code and mul_ir_in1.ir.l == 0: ein_out = torch.einsum( f"{z}uv,zu,zvj->zuj", w_out, x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out) ein_right = torch.einsum(f"{z}uv,uw,zvi->zuwi", w_right, e1_right, x2_right) elif specialized_code and mul_ir_in2.ir.l == 0: ein_out = torch.einsum( f"{z}uv,zui,zv->zui", w_out, x1_out, x2_out.reshape(batch_out, mul_ir_in2.dim)) ein_right = torch.einsum( f"{z}uv,ij,uw,zv->zuiwj", w_right, i1_right, e1_right, x2_right.reshape(batch_right, mul_ir_in2.dim)) elif specialized_code and mul_ir_out.ir.l == 0: ein_out = torch.einsum(f"{z}uv,zui,zvi->zu", w_out, x1_out, x2_out) / sqrt( mul_ir_in1.ir.dim)**exp ein_right = torch.einsum(f"{z}uv,uw,zvi->zuiw", w_right, e1_right, x2_right) / sqrt( mul_ir_in1.ir.dim)**exp else: ein_out = torch.einsum(f"{z}uv,ijk,zuvij->zuk", w_out, w3j_out, xx) ein_right = torch.einsum(f"{z}uv,ijk,uw,zvj->zuiwk", w_right, w3j_right, e1_right, x2_right) else: # not so useful operation because v is summed ein_out = torch.einsum("ijk,zuvij->zuk", w3j_out, xx) ein_right = torch.einsum("ijk,uw,zvj->zuiwk", w3j_right, e1_right, x2_right) if ins.connection_mode == 'uvv': assert mul_ir_in2.mul == mul_ir_out.mul if ins.has_weight: if 
specialized_code and key == (0, 0, 0): ein_out = torch.einsum( f"{z}uv,zu,zv->zv", w_out, x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out.reshape(batch_out, mul_ir_in2.dim)) ein_right = torch.einsum( f"{z}uv,vw,zv->zuw", w_right, e2_right, x2_right.reshape(batch_right, mul_ir_in2.dim)) elif specialized_code and mul_ir_in1.ir.l == 0: ein_out = torch.einsum( f"{z}uv,zu,zvj->zvj", w_out, x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out) ein_right = torch.einsum(f"{z}uv,vw,zvi->zuwi", w_right, e2_right, x2_right) elif specialized_code and mul_ir_in2.ir.l == 0: ein_out = torch.einsum( f"{z}uv,zui,zv->zvi", w_out, x1_out, x2_out.reshape(batch_out, mul_ir_in2.dim)) ein_right = torch.einsum( f"{z}uv,ij,vw,zv->zuiwj", w_right, i1_right, e2_right, x2_right.reshape(batch_right, mul_ir_in2.dim)) elif specialized_code and mul_ir_out.ir.l == 0: ein_out = torch.einsum(f"{z}uv,zui,zvi->zv", w_out, x1_out, x2_out) / sqrt( mul_ir_in1.ir.dim)**exp ein_right = torch.einsum(f"{z}uv,vw,zvi->zuiw", w_right, e2_right, x2_right) / sqrt( mul_ir_in1.ir.dim)**exp else: ein_out = torch.einsum(f"{z}uv,ijk,zuvij->zvk", w_out, w3j_out, xx) ein_right = torch.einsum(f"{z}uv,ijk,zvj->zuivk", w_right, w3j_right, x2_right) else: # not so useful operation because u is summed # only specialize out for this path if specialized_code and key == (0, 0, 0): ein_out = torch.einsum( "zu,zv->zv", x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out.reshape(batch_out, mul_ir_in2.dim)) elif specialized_code and mul_ir_in1.ir.l == 0: ein_out = torch.einsum( "zu,zvj->zvj", x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out) elif specialized_code and mul_ir_in2.ir.l == 0: ein_out = torch.einsum( "zui,zv->zvi", x1_out, x2_out.reshape(batch_out, mul_ir_in2.dim)) elif specialized_code and mul_ir_out.ir.l == 0: ein_out = torch.einsum("zui,zvi->zv", x1_out, x2_out) / sqrt( mul_ir_in1.ir.dim)**exp else: ein_out = torch.einsum("ijk,zuvij->zvk", w3j_out, xx) s2ones = fx.Proxy( graph_right.call_function( torch.ones, 
(mul_ir_in1.mul, ), dict(device=x2_right.device.node, dtype=x2_right.dtype.node))) ein_right = torch.einsum("u,ijk,zvj->zuivk", s2ones, w3j_right, x2_right) if ins.connection_mode == 'uuw': assert mul_ir_in1.mul == mul_ir_in2.mul if ins.has_weight: if specialized_code and key == (0, 0, 0): ein_out = torch.einsum( f"{z}uw,zu,zu->zw", w_out, x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out.reshape(batch_out, mul_ir_in2.dim)) elif specialized_code and mul_ir_in1.ir.l == 0: ein_out = torch.einsum( f"{z}uw,zu,zuj->zwj", w_out, x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out) elif specialized_code and mul_ir_in2.ir.l == 0: ein_out = torch.einsum( f"{z}uw,zui,zu->zwi", w_out, x1_out, x2_out.reshape(batch_out, mul_ir_in2.dim)) elif specialized_code and mul_ir_out.ir.l == 0: ein_out = torch.einsum(f"{z}uw,zui,zui->zw", w_out, x1_out, x2_out) / sqrt( mul_ir_in1.ir.dim)**exp else: ein_out = torch.einsum(f"{z}uw,ijk,zuij->zwk", w_out, w3j_out, xx) # TODO: specialize right() ein_right = torch.einsum(f"{z}uw,ijk,zuj->zuiwk", w_right, w3j_right, x2_right) else: # equivalent to tp(x, y, 'uuu').sum('u') assert mul_ir_out.mul == 1 ein_out = torch.einsum("ijk,zuij->zk", w3j_out, xx) ein_right = torch.einsum("ijk,zuj->zuik", w3j_right, x2_right) if ins.connection_mode == 'uuu': assert mul_ir_in1.mul == mul_ir_in2.mul == mul_ir_out.mul if ins.has_weight: if specialized_code and key == (0, 0, 0): ein_out = torch.einsum( f"{z}u,zu,zu->zu", w_out, x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out.reshape(batch_out, mul_ir_in2.dim)) ein_right = torch.einsum( f"{z}u,uw,zu->zuw", w_right, e2_right, x2_right.reshape(batch_right, mul_ir_in2.dim)) elif specialized_code and key == ( 1, 1, 1) and normalization == "component": ein_out = torch.einsum(f"{z}u,zui->zui", w_out, torch.cross(x1_out, x2_out, dim=2)) / sqrt(2) # For cross product, use the general case right() ein_right = torch.einsum(f"{z}u,ijk,uw,zuj->zuiwk", w_right, w3j_right, e1_right, x2_right) elif specialized_code and 
mul_ir_in1.ir.l == 0: ein_out = torch.einsum( f"{z}u,zu,zuj->zuj", w_out, x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out) ein_right = torch.einsum(f"{z}u,uw,zui->zuwi", w_right, e2_right, x2_right) elif specialized_code and mul_ir_in2.ir.l == 0: ein_out = torch.einsum( f"{z}u,zui,zu->zui", w_out, x1_out, x2_out.reshape(batch_out, mul_ir_in2.dim)) ein_right = torch.einsum( f"{z}u,ij,uw,zu->zuiwj", w_right, i1_right, e2_right, x2_right.reshape(batch_right, mul_ir_in2.dim)) elif specialized_code and mul_ir_out.ir.l == 0: ein_out = torch.einsum(f"{z}u,zui,zui->zu", w_out, x1_out, x2_out) / sqrt( mul_ir_in1.ir.dim)**exp ein_right = torch.einsum(f"{z}u,uw,zui->zuiw", w_right, e2_right, x2_right) / sqrt( mul_ir_in1.ir.dim)**exp else: ein_out = torch.einsum(f"{z}u,ijk,zuij->zuk", w_out, w3j_out, xx) ein_right = torch.einsum(f"{z}u,ijk,uw,zuj->zuiwk", w_right, w3j_right, e1_right, x2_right) else: if specialized_code and key == (0, 0, 0): ein_out = torch.einsum( "zu,zu->zu", x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out.reshape(batch_out, mul_ir_in2.dim)) ein_right = torch.einsum( "uw,zu->zuw", e2_right, x2_right.reshape(batch_right, mul_ir_in2.dim)) elif specialized_code and key == ( 1, 1, 1) and normalization == "component": ein_out = torch.cross(x1_out, x2_out, dim=2) * (1.0 / sqrt(2)) # For cross product, use the general case right() ein_right = torch.einsum("ijk,uw,zuj->zuiwk", w3j_right, e1_right, x2_right) elif specialized_code and mul_ir_in1.ir.l == 0: ein_out = torch.einsum( "zu,zuj->zuj", x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out) ein_right = torch.einsum("uw,zui->zuwi", e2_right, x2_right) elif specialized_code and mul_ir_in2.ir.l == 0: ein_out = torch.einsum( "zui,zu->zui", x1_out, x2_out.reshape(batch_out, mul_ir_in2.dim)) ein_right = torch.einsum( "ij,uw,zu->zuiwj", i1_right, e2_right, x2_right.reshape(batch_right, mul_ir_in2.dim)) elif specialized_code and mul_ir_out.ir.l == 0: ein_out = torch.einsum("zui,zui->zu", x1_out, x2_out) / sqrt( 
mul_ir_in1.ir.dim)**exp ein_right = torch.einsum("uw,zui->zuiw", e2_right, x2_right) / sqrt( mul_ir_in1.ir.dim)**exp else: ein_out = torch.einsum("ijk,zuij->zuk", w3j_out, xx) ein_right = torch.einsum("ijk,uw,zuj->zuiwk", w3j_right, e1_right, x2_right) if ins.connection_mode == 'uvuv': assert mul_ir_in1.mul * mul_ir_in2.mul == mul_ir_out.mul if ins.has_weight: # TODO implement specialized code ein_out = torch.einsum(f"{z}uv,ijk,zuvij->zuvk", w_out, w3j_out, xx) ein_right = torch.einsum(f"{z}uv,ijk,uw,zvj->zuiwvk", w_right, w3j_right, e1_right, x2_right) else: # TODO implement specialized code ein_out = torch.einsum("ijk,zuvij->zuvk", w3j_out, xx) ein_right = torch.einsum("ijk,uw,zvj->zuiwvk", w3j_right, e1_right, x2_right) ein_out = alpha * ein_out ein_right = alpha * ein_right out_list_out += [ein_out.reshape(batch_out, mul_ir_out.dim)] out_list_right += [ ein_right.reshape(batch_right, mul_ir_in1.dim, mul_ir_out.dim) ] # Close the profiler block graph_out.call_function(torch.ops.profiler._record_function_exit, (handle_out, )) graph_right.call_function(torch.ops.profiler._record_function_exit, (handle_right, )) # Remove unused w3js: if len(w3j_out.node.users) == 0 and len(w3j_right.node.users) == 0: del w3j[-1] # The w3j nodes are reshapes, so we have to remove them from the graph # Although they are dead code, they try to reshape to dimensions that don't exist # (since the corresponding w3js are not in w3j) # so they screw up the shape propagation, even though they would be removed later as dead code by TorchScript. 
graph_out.erase_node(w3j_dict_out.pop(key).node) graph_right.erase_node(w3j_dict_right.pop(key).node) # = Return the result = out_out = [ _sum_tensors([ out for ins, out in zip(instructions, out_list_out) if ins.i_out == i_out ], shape=(batch_out, mul_ir_out.dim), like=x1s_out) for i_out, mul_ir_out in enumerate(irreps_out) if mul_ir_out.mul > 0 ] if len(out_out) > 1: out_out = torch.cat(out_out, dim=1) else: # Avoid an unnecessary copy in a size one torch.cat out_out = out_out[0] out_right = [ torch.cat([ _sum_tensors([ out for ins, out in zip(instructions, out_list_right) if (ins.i_in1, ins.i_out) == (i_in1, i_out) ], shape=(batch_right, mul_ir_in1.dim, mul_ir_out.dim), like=x2s_right) for i_out, mul_ir_out in enumerate(irreps_out) if mul_ir_out.mul > 0 ], dim=2) for i_in1, mul_ir_in1 in enumerate(irreps_in1) if mul_ir_in1.mul > 0 ] if len(out_right) > 1: out_right = torch.cat(out_right, dim=1) else: out_right = out_right[0] out_out = out_out.reshape(outsize_out) out_right = out_right.reshape(outsize_right) graph_out.output(out_out.node, torch.Tensor) graph_right.output(out_right.node, torch.Tensor) # check graphs graph_out.lint() graph_right.lint() # Make GraphModules wigner_mats = {} for l_1, l_2, l_out in w3j: wig = o3.wigner_3j(l_1, l_2, l_out) if normalization == 'component': wig *= (2 * l_out + 1)**0.5 if normalization == 'norm': wig *= (2 * l_1 + 1)**0.5 * (2 * l_2 + 1)**0.5 wigner_mats[f"_w3j_{l_1}_{l_2}_{l_out}"] = wig # By putting the constants in a Module rather than a dict, # we force FX to copy them as buffers instead of as attributes. # # FX seems to have resolved this issue for dicts in 1.9, but we support all the way back to 1.8.0. 
constants_root = torch.nn.Module() for wkey, wmat in wigner_mats.items(): constants_root.register_buffer(wkey, wmat) graphmod_out = fx.GraphModule(constants_root, graph_out, class_name="tp_forward") graphmod_right = fx.GraphModule(constants_root, graph_right, class_name="tp_right") # == Optimize == # TODO: when eliminate_dead_code() is in PyTorch stable, use that if optimize_einsums: # Note that for our einsums, we can optimize _once_ for _any_ batch dimension # and still get the right path for _all_ batch dimensions. # This is because our einsums are essentially of the form: # zuvw,ijk,zuvij->zwk OR uvw,ijk,zuvij->zwk # In the first case, all but one operands have the batch dimension # => The first contraction gains the batch dimension # => All following contractions have batch dimension # => All possible contraction paths have cost that scales linearly in batch size # => The optimal path is the same for all batch sizes # For the second case, this logic follows as long as the first contraction is not between the first two operands. Since those two operands do not share any indexes, contracting them first is a rare pathological case. See # https://github.com/dgasmith/opt_einsum/issues/158 # for more details. # # TODO: consider the impact maximum intermediate result size on this logic # \- this is the `memory_limit` option in opt_einsum # TODO: allow user to choose opt_einsum parameters? # # We use float32 and zeros to save memory and time, since opt_einsum_fx looks only at traced shapes, not values or dtypes. batchdim = 4 example_inputs = ( torch.zeros((batchdim, irreps_in1.dim)), torch.zeros((batchdim, irreps_in2.dim)), torch.zeros( 1 if shared_weights else batchdim, flat_weight_index, ), ) graphmod_out = jitable( optimize_einsums_full(graphmod_out, example_inputs)) graphmod_right = jitable( optimize_einsums_full(graphmod_right, example_inputs[1:])) return graphmod_out, graphmod_right