def test_prune_inv_multiple_inv(backendopt):
    for datatype in backendopt:
        A0 = ad.Variable(name="A0", shape=[2, 2])
        A1 = ad.Variable(name="A1", shape=[2, 2])
        A2 = ad.Variable(name="A2", shape=[2, 2])

        out = ad.einsum('ab,bc,cd,de,ef,fg,gh->ah', A0, A1, A1,
                        ad.tensorinv(ad.einsum('ab,bc->ac', A1, A1), ind=1),
                        A2, A2,
                        ad.tensorinv(ad.einsum('ab,bc->ac', A2, A2), ind=1))
        new_out = prune_inv_node(out)

        for node in new_out.inputs:
            assert not isinstance(node, ad.EinsumNode)
        assert tree_eq(out, new_out, [A0, A1, A2], tol=1e-6)

def test_prune_inv_different_num_inputs_no_pruning(backendopt):
    A = ad.Variable(name="A", shape=[2, 2])
    inv_input = ad.einsum('ab,bc->ac', A, A)
    output = ad.einsum('ab,bc->ac', ad.tensorinv(inv_input, ind=1), A)

    new_output = prune_inv_node(output)
    assert new_output is output

def test_prune_inv_set_not_match(backendopt):
    A = ad.Variable(name="A", shape=[2, 2])
    B = ad.Variable(name="B", shape=[2, 2])
    inv = ad.tensorinv(ad.einsum('ab,bc->ac', A, B), ind=1)
    output = ad.einsum('ab,bc->ac', inv, A)

    new_output = prune_inv_node(output)
    assert new_output is output

def test_kronecker_product_non_even(backendopt):
    A = ad.Variable(name="A", shape=[4, 4, 2, 2])
    B = ad.Variable(name="B", shape=[2, 2])
    out = ad.einsum("abcd,ef->abcdef", A, B)
    inv = ad.tensorinv(out, ind=2)

    newinv = optimize_inverse(inv)
    assert inv is newinv

def test_tensorinv_tensor(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        x = ad.Variable(name="x", shape=[3, 2, 3, 2])
        inv_x = ad.tensorinv(x)
        executor = ad.Executor([inv_x])

        x_val = T.random([3, 2, 3, 2])
        inv_x_val, = executor.run(feed_dict={x: x_val})
        assert T.array_equal(inv_x_val, T.tensorinv(x_val))

def test_kronecker_product_nondecomposable(backendopt):
    A = ad.Variable(name="A", shape=[2, 3])
    B = ad.Variable(name="B", shape=[3, 2])
    out = ad.einsum("ab,cd->acbd", A, B)
    inv = ad.tensorinv(out)

    newinv = optimize_inverse(inv)
    assert inv is newinv

def test_prune_inv_nonmatmul_no_pruning(backendopt):
    A = ad.Variable(name="A", shape=[2, 2])
    B = ad.Variable(name="B", shape=[2, 2])
    inv_input = ad.einsum('ab,bc->ac', A, B)
    # Elementwise product of inv(inv_input) and inv_input; cannot be pruned.
    output = ad.einsum('ac,ab,bc->ac', ad.tensorinv(inv_input, ind=1), A, B)

    new_output = prune_inv_node(output)
    assert new_output is output

def optimize_inverse(inv_node):
    """ Optimize the inverse of an einsum expression.

    Parameters
    ----------
    inv_node: The inverse of a fused einsum node.

    Returns
    -------
    If the input node cannot be optimized, return the input node.
    If it can be optimized, return the optimized node.
    """
    assert isinstance(inv_node, ad.TensorInverseNode)
    # Note: currently, the optimization algorithm only works when
    # 1. the matrix rows and columns span the same number of dimensions,
    # 2. the matrix is square,
    # 3. each corresponding dimension in the rows and columns has the same
    #    size.
    if inv_node.input_indices_length * 2 != len(inv_node.shape):
        logger.info("Dimension length doesn't agree, can't optimize inverse")
        return inv_node
    matrix_dim = int(len(inv_node.shape) / 2)
    assert np.prod(inv_node.shape[:matrix_dim]) == np.prod(
        inv_node.shape[matrix_dim:])

    shape_diff_list = [
        inv_node.shape[i] - inv_node.shape[i + matrix_dim]
        for i in range(matrix_dim)
    ]
    if any(shape_diff != 0 for shape_diff in shape_diff_list):
        logger.info(
            "Corresponding row and column dimensions don't have the same "
            "size, can't optimize inverse")
        return inv_node

    input_node = inv_node.inputs[0]
    if isinstance(input_node, ad.EinsumNode):
        return split_inv_einsum(inv_node)
    if isinstance(input_node, ad.AddNode) and (
            input_node.inputs[0].name == input_node.inputs[1].name):
        # inv(X + X) = 0.5 * inv(X).
        inverse_node = optimize_inverse(ad.tensorinv(input_node.inputs[0]))
        subscript = "".join(
            [chr(ord('a') + i) for i in range(len(inverse_node.shape))])
        return ad.einsum(f",{subscript}->{subscript}", ad.ScalarNode(0.5),
                         inverse_node)
    return inv_node

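# A minimal usage sketch (the example function below is illustrative, not
# part of the library). It mirrors the Kronecker-product tests in this file:
# the inverse of a Kronecker-product einsum is rewritten into an einsum over
# the inverses of the individual factors, following the identity
# (A ⊗ B)^{-1} = A^{-1} ⊗ B^{-1}.
def example_optimize_inverse_kronecker():
    A = ad.Variable(name="A", shape=[2, 2])
    B = ad.Variable(name="B", shape=[2, 2])
    kron = ad.einsum("ab,cd->acbd", A, B)  # A ⊗ B as a [2, 2, 2, 2] tensor
    inv = ad.tensorinv(kron)
    new_inv = optimize_inverse(inv)
    # new_inv is an EinsumNode whose inputs are TensorInverseNodes of the
    # individual factors, instead of one fused 4-d inverse.
    return new_inv
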
def split_inv_einsum(inv_node):
    """ Optimize the inverse of an einsum expression so that the inverse is
        applied to several smaller tensors.

    Parameters
    ----------
    inv_node: The inverse of a fused einsum node.

    Returns
    -------
    If the input node cannot be optimized, return the input node.
    If it can be optimized, return the optimized node.
    """
    einsum_node = inv_node.inputs[0]
    assert isinstance(einsum_node, ad.EinsumNode)
    # einsum_node is a fused einsum
    for node in einsum_node.inputs:
        assert not isinstance(node, ad.EinsumNode)

    in_subs, out_subs, _ = _parse_einsum_input(
        (einsum_node.einsum_subscripts, *einsum_node.inputs))
    in_subs_list = in_subs.split(',')

    p_einsum_node = PseudoNode(node=einsum_node, subscript=out_subs)
    p_in_nodes = []
    for i, node in enumerate(einsum_node.inputs):
        p_in_nodes.append(PseudoNode(node=node, subscript=in_subs_list[i]))

    dsets = inv_disjoint_sets(p_einsum_node, p_in_nodes)
    # If the node cannot be decomposed, just return the input node.
    if len(dsets) == 1:
        return inv_node

    new_inputs = []
    for dset in dsets:
        input_decomp_einsum = list(
            filter(lambda node: any(char in dset for char in node.subscript),
                   p_in_nodes))
        out_subs = "".join(
            [char for char in p_einsum_node.subscript if char in dset])

        decomp_node = generate_new_einsum(input_decomp_einsum, out_subs)
        decomp_node.set_in_indices_length(int(len(out_subs) / 2))

        input_node = PseudoNode(node=ad.tensorinv(decomp_node),
                                subscript=out_subs)
        new_inputs.append(input_node)

    return generate_new_einsum(new_inputs, p_einsum_node.subscript)

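# A hedged walkthrough of the decomposition (illustrative example, not part
# of the library), reusing the structure of test_complex_product_inv below:
# in "ab,bc,de,ef->adcf" the output indices split into the disjoint groups
# {a, b, c} and {d, e, f}, so the fused inverse factors into two smaller
# inverses that a new einsum recombines.
def example_split_inv_einsum():
    A = ad.Variable(name="A", shape=[2, 2])
    B = ad.Variable(name="B", shape=[2, 2])
    C = ad.Variable(name="C", shape=[2, 2])
    D = ad.Variable(name="D", shape=[2, 2])
    inv = ad.tensorinv(ad.einsum("ab,bc,de,ef->adcf", A, B, C, D))
    # optimize_inverse dispatches to split_inv_einsum here and returns the
    # equivalent of:
    #   einsum('ac,df->adcf',
    #          tensorinv(einsum('ab,bc->ac', A, B), ind=1),
    #          tensorinv(einsum('de,ef->df', C, D), ind=1))
    return optimize_inverse(inv)
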
def test_prune_inv_nodes_transpose(backendopt):
    for datatype in backendopt:
        A = ad.Variable(name="A", shape=[2, 2])
        B = ad.Variable(name="B", shape=[2, 2])
        inv_input = ad.einsum('ab,bc->ca', A, B)
        # inv(inv_input.T) @ inv_input.T
        output = ad.einsum('ca,cd,de->ae', ad.tensorinv(inv_input, ind=1), A,
                           B)

        new_output = prune_inv_node(output)
        assert isinstance(new_output, ad.IdentityNode)
        assert tree_eq(output, new_output, [A, B], tol=1e-6)

def test_get_common_ancestor_w_inv(backendopt):
    A = ad.Variable(name="A", shape=[3, 3])
    X = ad.Variable(name="X", shape=[3, 3, 3])

    inv = ad.tensorinv(ad.einsum("ab,ac->bc", A, A), ind=1)
    einsum_node = ad.einsum('abc,ad,ce->bce', X, A, inv)
    opt_einsum = generate_optimal_tree(einsum_node)
    sub_einsum = get_common_ancestor(opt_einsum, einsum_node.inputs, A)
    # sub_einsum should be ad.einsum('ad,abc->dbc', A, X) and shouldn't
    # include the inv node.
    assert sorted(get_all_inputs(sub_einsum),
                  key=lambda node: node.name) == sorted(
                      [A, X], key=lambda node: node.name)

def test_tensorinv_odd_dim(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        x = ad.Variable(name="x", shape=[24, 8, 3])
        inv_x = ad.tensorinv(x, ind=1)
        assert inv_x.shape == [8, 3, 24]
        assert inv_x.input_indices_length == 2

        executor = ad.Executor([inv_x])
        x_val = T.random([24, 8, 3])
        inv_x_val, = executor.run(feed_dict={x: x_val})
        assert T.array_equal(inv_x_val, T.tensorinv(x_val, ind=1))

def cpd_als_shared_exec(dim, size, rank, num_iter, input_val=[]):
    A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)

    full_hessian = ad.hessian(loss, A_list)
    hessians = [full_hessian[i][i] for i in range(len(full_hessian))]
    grads = ad.gradients(loss, A_list)

    updates = [
        ad.tensordot(ad.tensorinv(hes), grad, [[2, 3], [0, 1]])
        for (hes, grad) in zip(hessians, grads)
    ]

    new_A_list = [simplify(A - update) for (A, update) in zip(A_list, updates)]
    new_A_list = generate_sequential_optimal_tree(new_A_list, A_list)

    executor = ad.Executor(new_A_list)
    executor_loss = ad.Executor([simplify(loss)])

    if input_val == []:
        A_val_list, input_tensor_val = init_rand_cp(dim, size, rank)
    else:
        A_val_list, input_tensor_val = input_val

    for iter in range(num_iter):
        t0 = time.time()
        # ALS iterations
        for i in range(len(A_list)):

            feed_dict = dict(zip(A_list, A_val_list))
            feed_dict.update({input_tensor: input_tensor_val})

            if i == 0:
                A_val_list[0], = executor.run(feed_dict=feed_dict,
                                              out_nodes=[new_A_list[0]])
            else:
                A_val_list[i], = executor.run(feed_dict=feed_dict,
                                              reset_graph=False,
                                              evicted_inputs=[A_list[i - 1]],
                                              out_nodes=[new_A_list[i]])

        feed_dict = dict(zip(A_list, A_val_list))
        feed_dict.update({input_tensor: input_tensor_val})
        loss_val, = executor_loss.run(feed_dict=feed_dict)

        print(f'At iteration {iter} the loss is: {loss_val}')
        t1 = time.time()
        print(f"[ {iter} ] Sweep took {t1 - t0} seconds")

    return A_val_list

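# Why the axes [[2, 3], [0, 1]] in the updates above (a hedged reading of
# the conventions visible in this file): for a factor matrix of shape
# [size, rank], ad.hessian yields a 4-d tensor of shape
# [size, rank, size, rank], and ad.tensorinv (default ind=2) inverts it as a
# (size * rank) x (size * rank) matrix. Contracting its last two axes with
# the gradient's two axes gives the exact ALS solve hes^{-1} @ grad,
# reshaped back to [size, rank]. A minimal invocation sketch with
# illustrative parameters (random init when input_val is left empty):
def example_cpd_als_shared_exec():
    return cpd_als_shared_exec(dim=3, size=10, rank=5, num_iter=5)
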
def test_simplify_inv_w_identity(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        A = ad.Variable(name="A", shape=[2, 2])
        out = ad.einsum("ab,cd->acbd", A, ad.tensorinv(ad.identity(3)))
        newout = simplify(out)

        assert isinstance(newout, ad.EinsumNode)
        assert isinstance(newout.inputs[1], ad.IdentityNode)
        assert tree_eq(out, newout, [A], tol=1e-6)

def test_s2s_tensorinv(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        A = ad.Variable(name="A", shape=[2, 2])
        B = ad.tensorinv(A)

        A_val = T.tensor([[1., 0.], [0., 1.]])

        StS = SourceToSource()
        fwd_str = StS.forward([B], function_name='fwd', backend=datatype)
        m = import_code(fwd_str)
        out, = m.fwd([A_val])
        assert T.array_equal(A_val, out)

def test_simplify_inv_w_redundant_einsum(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        A = ad.Variable(name="A", shape=[2, 2])
        out = ad.einsum("ab,cd->abcd", A,
                        ad.tensorinv(ad.einsum("ab->ab", A)))
        newout = simplify(out)

        inv_node = newout.inputs[1]
        assert isinstance(inv_node.inputs[0], ad.VariableNode)
        assert tree_eq(out, newout, [A], tol=1e-6)

def test_kronecker_product_repeated_inputs(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        A = ad.Variable(name="A", shape=[2, 2])
        out = ad.einsum("ab,cd->acbd", A, A)
        inv = ad.tensorinv(out)

        newinv = optimize_inverse(inv)
        assert isinstance(newinv, ad.EinsumNode)
        for node in newinv.inputs:
            assert isinstance(node, ad.TensorInverseNode)
        assert tree_eq(inv, newinv, [A], tol=1e-5)

def test_inv_multiple_decompositions(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        A = ad.Variable(name="A", shape=[2, 2])
        B = ad.Variable(name="B", shape=[2, 2])
        C = ad.Variable(name="C", shape=[2, 2])
        out = ad.einsum("ab,cd,ef->acebdf", A, B, C)
        inv = ad.tensorinv(out)

        newinv = optimize_inverse(inv)
        assert isinstance(newinv, ad.EinsumNode)
        for node in newinv.inputs:
            assert isinstance(node, ad.TensorInverseNode)
        assert len(newinv.inputs) == 3
        assert tree_eq(inv, newinv, [A, B, C], tol=1e-5)

def tucker_als_graph_shared_exec(dim, size, rank):
    """ Build the graph used for Tucker ALS with shared execution.

    Parameters
    ----------
    dim: dimensionality of the input tensor
    size: the size of each dim of the input tensor
    rank: the rank of the decomposition

    Returns
    -------
    tg: a TuckerGraph object
    executor_updates: a shared executor for the update graphs
    executor_loss: an executor for the Tucker loss
    loss: the optimized graph for the Tucker loss
    updates: a list containing the update graph for each dimension
    intermediates: list of einsum nodes; each node is the objective
        that each Tucker ALS step optimizes
    """
    tg = TuckerGraph(dim, size, rank)

    updates = []
    for i in range(dim):

        core_A = tg.intermediates[i]
        hes = ad.hessian(tg.losses[i], [core_A])
        hes = hes[0][0]
        grad, = ad.gradients(tg.losses[i], [core_A])

        new_core_A = core_A - ad.tensordot(
            ad.tensorinv(hes), grad,
            [[i + dim for i in range(dim)], [i for i in range(dim)]])

        updates.append(simplify(new_core_A))

    loss = simplify(tg.losses[0])
    for i in range(1, len(tg.losses)):
        assert loss.name == simplify(tg.losses[i]).name

    updates = generate_sequential_optimal_tree(updates, tg.A_list)
    executor_updates = ad.Executor(updates)
    executor_loss = ad.Executor([loss])

    return tg, executor_updates, executor_loss, loss, updates, tg.intermediates

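# A construction sketch (illustrative parameter values; the example function
# is not part of the library). Only graph construction is shown; running the
# returned executors additionally needs feed dicts built from tg's variables.
def example_tucker_als_graph_shared_exec():
    tg, executor_updates, executor_loss, loss, updates, intermediates = \
        tucker_als_graph_shared_exec(dim=3, size=5, rank=3)
    # `updates` holds one optimized update graph per mode; they share
    # intermediate results through the single executor_updates.
    return tg, executor_updates, executor_loss, loss, updates, intermediates
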
def test_prune_inv_nodes_cpd(backendopt):
    for datatype in backendopt:
        A = ad.Variable(name="A", shape=[2, 2])
        B = ad.Variable(name="B", shape=[2, 2])
        C = ad.Variable(name="C", shape=[2, 2])

        inv_input = ad.einsum('ab,dc,ac,db->bc', B, C, B, C)
        output = ad.einsum('ed,ea,cd,ba,ca,gd->bg', C, C, B, A, B,
                           ad.tensorinv(inv_input, ind=1))
        new_output = prune_inv_node(output)
        # T.einsum('ba,ag->bg', A, T.identity(2))
        assert len(new_output.inputs) == 2
        for node in new_output.inputs:
            if isinstance(node, ad.VariableNode):
                assert node == A
        assert tree_eq(output, new_output, [A, B, C], tol=1e-6)

def test_high_dim_inv(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        A = ad.Variable(name="A", shape=[2, 2, 2, 2])
        B = ad.Variable(name="B", shape=[2, 2, 2, 2])
        out = ad.einsum("aceg,dbhf->abcdefgh", A, B)
        inv = ad.tensorinv(out)
        # T.einsum('aceg,bdfh->abcdefgh',
        #          T.tensorinv(T.einsum('aceg->aceg', A), ind=2),
        #          T.tensorinv(T.einsum('dbhf->bdfh', B), ind=2))
        newinv = optimize_inverse(inv)
        assert isinstance(newinv, ad.EinsumNode)
        for node in newinv.inputs:
            assert isinstance(node, ad.TensorInverseNode)
        assert tree_eq(inv, newinv, [A, B], tol=1e-6)

def tucker_als_graph(dim, size, rank):
    """ Build the graph used for Tucker ALS.

    Parameters
    ----------
    dim: dimensionality of the input tensor
    size: the size of each dim of the input tensor
    rank: the rank of the decomposition

    Returns
    -------
    tg: a TuckerGraph object
    executors_update: list of executors. Each executor is used for one step
        of Tucker ALS
    executor_loss: an executor for the Tucker loss
    intermediates: list of einsum nodes; each node is the objective
        that each Tucker ALS step optimizes
    """
    tg = TuckerGraph(dim, size, rank)

    executors_update = []
    for i in range(dim):

        core_A = tg.intermediates[i]
        hes = ad.hessian(tg.losses[i], [core_A])
        hes = hes[0][0]
        grad, = ad.gradients(tg.losses[i], [core_A])

        new_core_A = core_A - ad.tensordot(
            ad.tensorinv(hes), grad,
            [[i + dim for i in range(dim)], [i for i in range(dim)]])

        executor = ad.Executor([simplify(new_core_A)])
        executors_update.append(executor)

    executor_loss = ad.Executor([simplify(tg.losses[0])])

    return tg, executors_update, executor_loss, tg.intermediates

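# Baseline construction sketch (illustrative values; the example function is
# not part of the library): unlike the shared-execution variant above, each
# Tucker ALS step gets its own executor and no intermediates are shared.
def example_tucker_als_graph():
    tg, executors_update, executor_loss, intermediates = tucker_als_graph(
        dim=3, size=5, rank=3)
    return tg, executors_update, executor_loss, intermediates
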
def test_complex_product_inv(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        A = ad.Variable(name="A", shape=[2, 2])
        B = ad.Variable(name="B", shape=[2, 2])
        C = ad.Variable(name="C", shape=[2, 2])
        D = ad.Variable(name="D", shape=[2, 2])
        out = ad.einsum("ab,bc,de,ef->adcf", A, B, C, D)
        inv = ad.tensorinv(out)
        # T.einsum('ac,df->adcf',
        #          T.tensorinv(T.einsum('ab,bc->ac', A, B), ind=1),
        #          T.tensorinv(T.einsum('de,ef->df', C, D), ind=1))
        newinv = optimize_inverse(inv)
        assert isinstance(newinv, ad.EinsumNode)
        for node in newinv.inputs:
            assert isinstance(node, ad.TensorInverseNode)
        assert tree_eq(inv, newinv, [A, B, C, D], tol=1e-5)

def cpd_als(dim, size, rank, num_iter, input_val=[]):
    A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)

    full_hessian = ad.hessian(loss, A_list)
    hessians = [full_hessian[i][i] for i in range(len(full_hessian))]
    grads = ad.gradients(loss, A_list)

    updates = [
        ad.tensordot(ad.tensorinv(hes), grad, [[2, 3], [0, 1]])
        for (hes, grad) in zip(hessians, grads)
    ]

    new_A_list = [simplify(A - update) for (A, update) in zip(A_list, updates)]

    executor = ad.Executor(new_A_list)
    executor_loss = ad.Executor([simplify(loss)])

    if input_val == []:
        A_val_list, input_tensor_val = init_rand_cp(dim, size, rank)
    else:
        A_val_list, input_tensor_val = input_val

    for iter in range(num_iter):
        # ALS iterations
        for i in range(len(A_list)):

            feed_dict = dict(zip(A_list, A_val_list))
            feed_dict.update({input_tensor: input_tensor_val})

            A_val_list[i], = executor.run(feed_dict=feed_dict,
                                          out_nodes=[new_A_list[i]])

        feed_dict = dict(zip(A_list, A_val_list))
        feed_dict.update({input_tensor: input_tensor_val})
        loss_val, = executor_loss.run(feed_dict=feed_dict)

        print(f'At iteration {iter} the loss is: {loss_val}')

    return A_val_list

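# cpd_als is the baseline driver: unlike cpd_als_shared_exec above, it skips
# generate_sequential_optimal_tree, so factor updates share no intermediates
# across a sweep. A minimal call sketch (illustrative parameters; random
# initialization is used when input_val is left empty):
def example_cpd_als():
    return cpd_als(dim=3, size=10, rank=5, num_iter=5)
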
def p_expression_tensorinv(t):
    'expression : EINSUM_INVERSE LPAREN expression COMMA INV_INDEX NUMBER RPAREN'
    t[0] = ad.tensorinv(t[3], ind=int(t[6]))

def test_tensorinv():
    A = ad.Variable(name="A", shape=[3, 3])
    y = ad.tensorinv(A)
    assert AutodiffParser.parse(y.name, [A]).name == y.name