def test_custom_random_greedy():
    """RandomGreedy rejects bad objectives and accumulates trials across runs."""
    eq, shapes = oe.helpers.rand_equation(10, 4, seed=42)
    arrays = [np.ones(shape) for shape in shapes]

    with pytest.raises(ValueError):
        oe.RandomGreedy(minimize='something')

    opt = oe.RandomGreedy(max_repeats=10, minimize='flops')

    def check_consistent(found_path, info):
        # The path and PathInfo reported by opt_einsum must agree with the
        # optimizer's recorded best trial.
        assert found_path == opt.path
        assert info.largest_intermediate == opt.best['size']
        assert info.opt_cost == opt.best['flops']

    found_path, info = oe.contract_path(eq, *arrays, optimize=opt)
    assert len(opt.costs) == 10
    assert len(opt.sizes) == 10
    assert found_path == opt.path
    assert opt.best['flops'] == min(opt.costs)
    check_consistent(found_path, info)

    # Settings can be changed in place; trials from both runs are kept (10 + 6).
    opt.temperature = 0.0
    opt.max_repeats = 6
    found_path, info = oe.contract_path(eq, *arrays, optimize=opt)
    assert len(opt.costs) == 16
    assert len(opt.sizes) == 16
    assert found_path == opt.path
    assert opt.best['size'] == min(opt.sizes)
    check_consistent(found_path, info)
def test_large_path(num_symbols):
    """Greedy path search on a long chain of two-index terms must not crash."""
    symbols = ''.join(map(oe.get_symbol, range(num_symbols)))
    sizes = itertools.cycle([2, 3, 4])
    dimension_dict = dict(zip(symbols, sizes))

    # Chain of neighbouring pairs: "ab,bc,cd,..."
    terms = [symbols[k:k + 2] for k in range(num_symbols - 1)]
    expression = ','.join(terms)
    views = oe.helpers.build_views(expression, dimension_dict=dimension_dict)

    # Only checks that path construction completes without raising.
    oe.contract_path(expression, *views, optimize='greedy')
def test_optimal_edge_cases():
    """Under memory_limit='max_input', greedy and optimal choose the same path."""
    expression = 'a,ac,ab,ad,cd,bd,bc->'
    views = oe.helpers.build_views(
        expression, dimension_dict={"a": 20, "b": 20, "c": 20, "d": 20})
    expected = [(0, 1), (0, 1, 2, 3, 4, 5)]

    for method in ('greedy', 'optimal'):
        found_path, _ = oe.contract_path(
            expression, *views, optimize=method, memory_limit='max_input')
        assert check_path(found_path, expected)
def test_greedy_edge_cases():
    """Greedy respects memory_limit: 'max_input' forces one all-at-once step."""
    expression = "abc,cfd,dbe,efa"
    indices = expression.replace(",", "")
    tensors = oe.helpers.build_views(
        expression, dimension_dict=dict.fromkeys(indices, 20))

    cases = [
        ('max_input', [(0, 1, 2, 3)]),
        (-1, [(0, 1), (0, 2), (0, 1)]),
    ]
    for limit, expected in cases:
        found_path, _ = oe.contract_path(
            expression, *tensors, optimize='greedy', memory_limit=limit)
        assert check_path(found_path, expected)
def test_chain_sharing(size, backend):
    """Sharing intermediates across every single-index target must save work."""
    xs = [np.random.rand(2, 2) for _ in range(size)]
    alphabet = ''.join(get_symbol(i) for i in range(size + 1))
    inputs = ','.join(alphabet[i:i + 2] for i in range(size))

    # Baseline: a fresh cache for each target, so nothing is reused.
    cost_without = 0
    for target in alphabet:
        with shared_intermediates() as cache:
            equation = '{}->{}'.format(inputs, target)
            compiled = contract_expression(equation, *(x.shape for x in xs))
            compiled(*xs, backend=backend)
            cost_without += _compute_cost(cache)

    # One cache shared by all targets.
    with shared_intermediates() as cache:
        print(inputs)
        for target in alphabet:
            equation = '{}->{}'.format(inputs, target)
            print(contract_path(equation, *xs)[1])
            compiled = contract_expression(equation, *(x.shape for x in xs))
            compiled(*xs, backend=backend)
        cost_with = _compute_cost(cache)

    print('-' * 40)
    print('Without sharing: {} expressions'.format(cost_without))
    print('With sharing: {} expressions'.format(cost_with))
    assert cost_without > cost_with
def test_memory_paths():
    """Optimal and greedy agree both under a tiny limit and with no limit."""
    expression = "abc,bdef,fghj,cem,mhk,ljk->adgl"
    views = oe.helpers.build_views(expression)

    cases = (
        # Tiny memory limit: every intermediate is too big, so contract all at once.
        (5, [(0, 1, 2, 3, 4, 5)]),
        # No limit (-1): both strategies find the same pairwise path.
        (-1, [(0, 3), (0, 4), (0, 2), (0, 2), (0, 1)]),
    )
    for memory_limit, expected in cases:
        for method in ("optimal", "greedy"):
            result = oe.contract_path(
                expression, *views, optimize=method, memory_limit=memory_limit)
            assert check_path(result[0], expected)
def test_custom_branchbound():
    """BranchBound reports its best path/cost and can be re-run after tweaks."""
    eq, shapes = oe.helpers.rand_equation(8, 4, seed=42)
    arrays = [np.ones(shape) for shape in shapes]
    bb = oe.BranchBound(nbranch=2, cutoff_flops_factor=10, minimize='size')

    def run_and_check():
        found_path, info = oe.contract_path(eq, *arrays, optimize=bb)
        assert found_path == bb.path
        assert info.largest_intermediate == bb.best['size']
        assert info.opt_cost == bb.best['flops']

    run_and_check()

    # Tweak the search settings and run the same optimizer again.
    bb.nbranch = 3
    bb.cutoff_flops_factor = 4
    run_and_check()
def test_parallel_random_greedy():
    """RandomGreedy with a user executor, an int worker count, then ``True``.

    The test creates its own ``ProcessPoolExecutor`` and releases it in a
    ``finally`` block so worker processes are not left running, even when an
    assertion fails.
    """
    from concurrent.futures import ProcessPoolExecutor

    pool = ProcessPoolExecutor(2)
    try:
        eq, shapes = oe.helpers.rand_equation(10, 4, seed=42)
        views = list(map(np.ones, shapes))

        optimizer = oe.RandomGreedy(max_repeats=10, parallel=pool)
        path, path_info = oe.contract_path(eq, *views, optimize=optimizer)

        assert len(optimizer.costs) == 10
        assert len(optimizer.sizes) == 10
        assert path == optimizer.path
        # A supplied executor is used as-is rather than copied or wrapped.
        assert optimizer.parallel is pool
        assert optimizer._executor is pool
        assert optimizer.best['flops'] == min(optimizer.costs)
        assert path_info.largest_intermediate == optimizer.best['size']
        assert path_info.opt_cost == optimizer.best['flops']

        # now switch to max time algorithm
        optimizer.max_repeats = int(1e6)
        optimizer.max_time = 0.2
        optimizer.parallel = 2

        path, path_info = oe.contract_path(eq, *views, optimize=optimizer)

        assert len(optimizer.costs) > 10
        assert len(optimizer.sizes) > 10
        assert path == optimizer.path
        assert optimizer.best['flops'] == min(optimizer.costs)
        assert path_info.largest_intermediate == optimizer.best['size']
        assert path_info.opt_cost == optimizer.best['flops']

        optimizer.parallel = True
        assert optimizer._executor is not None
        assert optimizer._executor is not pool

        are_done = [f.running() or f.done() for f in optimizer._futures]
        assert all(are_done)
    finally:
        # This test owns `pool`; shut it down so its workers don't outlive the test.
        pool.shutdown()
def find_order(self, tn, **kwargs):
    """Yield contraction schemes for ``tn`` forever, planned with opt_einsum.

    ``tn`` is first expanded with ``_expand_and_delta()``. On every iteration
    the subscripts are re-read and ``opt_einsum.contract_path`` is re-run with
    ``self.optimize``. The resulting path is converted into a pairwise order
    and then into a contraction scheme, which is yielded. (NOTE(review):
    re-planning each time presumably lets a stochastic optimizer propose a
    different scheme on each ``next()``; confirm.) ``kwargs`` is unused.
    """
    expanded = tn._expand_and_delta()
    while True:
        lhs, rhs, shapes = expanded.subscripts()
        equation = ','.join(lhs) + '->' + rhs
        path, _ = opt_einsum.contract_path(
            equation, *shapes, shapes=True, optimize=self.optimize)

        order = []
        # A single node needs no pairwise contractions.
        if len(expanded.nodes_by_name) > 1:
            order = defaultOrderResolver.path_to_paired_order(
                [list(expanded.nodes_by_name), '#'], path)
        yield defaultOrderResolver.order_to_contraction_scheme(expanded, order)
def test_reconfigure(forested, parallel, requires):
    """Subtree reconfiguration strictly lowers the flops of a greedy tree."""
    if requires:
        pytest.importorskip(requires)

    eq, shapes = oe.helpers.rand_equation(30, reg=5, seed=42, d_max=3)
    _, greedy_info = oe.contract_path(eq, *shapes, shapes=True, optimize='greedy')
    tree = ContractionTree.from_info(greedy_info)
    assert tree.total_flops() == greedy_info.opt_cost

    if not forested:
        tree.subtree_reconfigure_(progbar=True)
    else:
        tree.subtree_reconfigure_forest_(
            num_trees=2, subtree_size=6, progbar=True, parallel=parallel)
    assert tree.total_flops() < greedy_info.opt_cost

    # Feeding the improved path back to opt_einsum must report the same cost.
    _, new_info = oe.contract_path(eq, *shapes, shapes=True, optimize=tree.path())
    assert tree.total_flops() == new_info.opt_cost
def _einsum_contract_path(*operands, **kwargs):
  """Like opt_einsum.contract_path, with support for DimPolynomial shapes.

  Every polymorphic dimension is replaced by one fixed constant, and
  opt_einsum.contract_path is called on the resulting fake operands. This is
  safe only because an error is raised when more than one contraction
  results. In effect, opt_einsum is used here just to parse the specification.
  """

  def fake_dim(d):
    # Constant dimensions pass through unchanged.
    if core.is_constant_dim(d):
      return d
    if not isinstance(d, _DimPolynomial):
      raise TypeError(
          f"Encountered unexpected shape dimension {d}")
    # Every polynomial maps to the same value. Mismatches between different
    # polynomials may be missed here, but they are caught later.
    return 8

  def to_fake(operand):
    # Only array operands (anything with a dtype) are replaced.
    if not hasattr(operand, "dtype"):
      return operand
    fake_shape = tuple(fake_dim(d) for d in np.shape(operand))
    return jax.ShapeDtypeStruct(fake_shape, operand.dtype)

  # opt_einsum needs concrete sizes to cost operands and intermediates.
  fake_ops = [to_fake(operand) for operand in operands]

  contract_fake_ops, contractions = opt_einsum.contract_path(
      *fake_ops, **kwargs)
  if len(contractions) > 1:
    msg = (
        "Shape polymorphism is not yet supported for einsum with more than "
        f"one contraction {contractions}")
    raise ValueError(msg)

  # Map each fake operand opt_einsum returned back to the real operand at the
  # same position, matching by object identity.
  contract_operands = []
  for fake in contract_fake_ops:
    matches = tuple(pos for pos, candidate in enumerate(fake_ops)
                    if fake is candidate)
    assert len(matches) == 1
    contract_operands.append(operands[matches[0]])
  return contract_operands, contractions
def test_hyper(contraction_20_5, optlib, requires, parallel):
    """HyperOptimizer finds a speedup and samples both greedy and kahypar.

    ``requires`` names an optional module needed by ``optlib``. As in
    ``test_reconfigure`` and ``test_slice_and_reconfigure``, it is only
    imported when truthy: ``pytest.importorskip`` errors on an empty or
    ``None`` module name instead of skipping.
    """
    pytest.importorskip('kahypar')
    if requires:
        pytest.importorskip(requires)
    if parallel:
        pytest.importorskip('distributed')

    eq, _, _, arrays = contraction_20_5
    optimizer = ctg.HyperOptimizer(
        max_repeats=16,
        parallel=parallel,
        optlib=optlib,
    )
    _, path_info = oe.contract_path(eq, *arrays, optimize=optimizer)
    assert path_info.speedup > 1
    # Each trial's first field names the method that produced it.
    assert {x[0] for x in optimizer.get_trials()} == {'greedy', 'kahypar'}
    optimizer.print_trials()
def test_contraction_tree_equivalency():
    """Two orders realising the same balanced tree build identical trees."""
    import opt_einsum as oe
    from cotengra.core import ContractionTree

    eq = "a,ab,bc,c->"
    shapes = [(4, ), (4, 2), (2, 5), (5, )]
    # optimal contraction is like:
    #         o
    #        / \
    #       o   o
    #      / \ / \
    orders = ([(0, 1), (0, 1), (0, 1)], [(2, 3), (0, 1), (0, 1)])
    info1, info2 = (
        oe.contract_path(eq, *shapes, shapes=True, optimize=order)[1]
        for order in orders
    )
    # Different step-by-step contraction lists...
    assert info1.contraction_list != info2.contraction_list

    # ...but the same tree, with the same cost.
    tree1, tree2 = (ContractionTree.from_pathinfo(info) for info in (info1, info2))
    assert tree1.total_flops() == tree2.total_flops() == 40
    assert tree1.children == tree2.children
def _get_opt_einsum_contract_path(equation, shaped_inputs_tuple, optimize):
  """Return the contraction steps chosen by opt_einsum.contract_path.

  The result is a tuple of ``(operand_indices, einsum_equation)`` pairs, one
  per contraction step. NOTE(review): the original docstring calls this
  "memoized", but no cache is applied within this block; presumably a
  decorator or caller handles that. Confirm.
  """
  # einsum_call=True is an internal opt_einsum API. It returns the
  # contraction plan without opt_einsum performing the contractions.
  _, contraction_steps = opt_einsum.contract_path(
      equation,
      *shaped_inputs_tuple,
      optimize=optimize,
      einsum_call=True,
      use_blas=True)
  # Build a tuple (immutable), so a cached value cannot be mutated by callers.
  return tuple((step[0], step[2]) for step in contraction_steps)
def test_contraction_tree_equivalency():
    """Different paths for the same balanced tree give equal, complete trees."""
    eq = "a,ab,bc,c->"
    shapes = [(4, ), (4, 2), (2, 5), (5, )]
    # optimal contraction is like:
    #         o
    #        / \
    #       o   o
    #      / \ / \
    paths = ([(0, 1), (0, 1), (0, 1)], [(2, 3), (0, 1), (0, 1)])
    infos = [
        oe.contract_path(eq, *shapes, shapes=True, optimize=p)[1]
        for p in paths
    ]
    info1, info2 = infos
    assert info1.contraction_list != info2.contraction_list

    ct1, ct2 = (ContractionTree.from_info(info, check=True) for info in infos)
    assert ct1.total_flops() == ct2.total_flops() == 40
    assert ct1.children == ct2.children
    for tree in (ct1, ct2):
        assert tree.is_complete()
def test_insane_nested(parallel_backend):
    """Nested forested slicing/reconfiguration respects the target size."""
    # The 'dask' backend is provided by the `distributed` package.
    required = 'distributed' if parallel_backend == 'dask' else parallel_backend
    pytest.importorskip(required)

    eq, shapes = oe.helpers.rand_equation(30, reg=5, seed=42, d_max=3)

    target_size = 2**20
    inner_reconf = {
        'forested': True,
        'num_trees': 2,
        'subtree_size': 6,
    }
    slicing_opts = {
        'target_size': target_size,
        'forested': True,
        'max_repeats': 4,
        'num_trees': 2,
        'reconf_opts': inner_reconf,
    }
    optimizer = ctg.HyperOptimizer(
        max_repeats=16,
        parallel=parallel_backend,
        optlib='random',
        progbar=True,
        slicing_reconf_opts=slicing_opts,
    )
    oe.contract_path(eq, *shapes, shapes=True, optimize=optimizer)
    assert optimizer.get_tree().max_size() <= target_size
def test_custom_random_greedy():
    """RandomGreedy validation, re-running with new settings, and reuse guard."""
    eq, shapes = oe.helpers.rand_equation(10, 4, seed=42)
    views = [np.ones(shape) for shape in shapes]

    # An unknown objective is rejected up front.
    with pytest.raises(ValueError):
        oe.RandomGreedy(minimize='something')

    optimizer = oe.RandomGreedy(max_repeats=10, minimize='flops')
    found_path, info = oe.contract_path(eq, *views, optimize=optimizer)
    assert len(optimizer.costs) == 10
    assert len(optimizer.sizes) == 10
    assert found_path == optimizer.path
    assert optimizer.best['flops'] == min(optimizer.costs)
    assert info.largest_intermediate == optimizer.best['size']
    assert info.opt_cost == optimizer.best['flops']

    # Change settings in place: trials from both runs are kept (10 + 6).
    optimizer.temperature = 0.0
    optimizer.max_repeats = 6
    found_path, info = oe.contract_path(eq, *views, optimize=optimizer)
    assert len(optimizer.costs) == 16
    assert len(optimizer.sizes) == 16
    assert found_path == optimizer.path
    assert optimizer.best['size'] == min(optimizer.sizes)
    assert info.largest_intermediate == optimizer.best['size']
    assert info.opt_cost == optimizer.best['flops']

    # Reusing the same optimizer on a different expression must raise.
    other_eq, other_shapes = oe.helpers.rand_equation(10, 4, seed=41)
    other_views = [np.ones(shape) for shape in other_shapes]
    with pytest.raises(ValueError):
        oe.contract_path(other_eq, *other_views, optimize=optimizer)
def test_optimizer_registration():
    """A custom path function can be registered, used by name, and removed.

    The registration mutates opt_einsum's module-level ``_PATH_OPTIONS``. The
    entry is removed in a ``finally`` block so a failing assertion cannot leak
    ``'custom'`` into other tests.
    """
    def custom_optimizer(inputs, output, size_dict, memory_limit):
        # Always contract the first two operands together.
        return [(0, 1)] * (len(inputs) - 1)

    # Built-in names cannot be overwritten.
    with pytest.raises(KeyError):
        oe.paths.register_path_fn('optimal', custom_optimizer)

    oe.paths.register_path_fn('custom', custom_optimizer)
    try:
        assert 'custom' in oe.paths._PATH_OPTIONS

        eq = 'ab,bc,cd'
        shapes = [(2, 3), (3, 4), (4, 5)]
        path, path_info = oe.contract_path(eq, *shapes, shapes=True, optimize='custom')
        assert path == [(0, 1), (0, 1)]
    finally:
        # pop(..., None) so cleanup never masks the original failure.
        oe.paths._PATH_OPTIONS.pop('custom', None)
def subscript_to_path(self, lhs, rhs, shapes):
    """Plan a contraction of ``lhs -> rhs`` and return ``(path, cost)``.

    A path is first found with opt_einsum's ``'auto'`` strategy. If there are
    more than two terms and the estimated cost exceeds ``self.thres``, a more
    thorough search is tried instead: ``'optimal'`` when there are fewer than
    ``self.optimal`` terms, otherwise ``'dp'`` when there are fewer than
    ``self.dp`` terms.

    Args:
        lhs: Sequence of input subscript strings (its ``len`` is used).
        rhs: Output subscript string.
        shapes: Shapes of the inputs, one per entry of ``lhs``.

    Returns:
        ``(path, ContractionCost(opt_cost, largest_intermediate))``.

    Raises:
        ValueError: propagated from opt_einsum; the equation and shapes are
            printed first to aid debugging.
    """
    # Build the equation once; it is reused for every contract_path call below.
    eq = ','.join(lhs) + '->' + rhs
    try:
        path, path_info = opt_einsum.contract_path(eq, *shapes, shapes=True, optimize='auto')
    except ValueError:
        print(eq, shapes)
        # Bare `raise` re-raises with the original traceback untouched.
        raise

    optimize = None
    if len(lhs) > 2 and path_info.opt_cost > self.thres:
        if len(lhs) < self.optimal:
            optimize = 'optimal'
        elif len(lhs) < self.dp:
            optimize = 'dp'
    if optimize is not None:
        path, path_info = opt_einsum.contract_path(eq, *shapes, shapes=True, optimize=optimize)
    return path, ContractionCost(path_info.opt_cost, path_info.largest_intermediate)
def test_chain_2(size, backend):
    """Contract a chain to each adjacent-pair output while sharing intermediates."""
    xs = [np.random.rand(2, 2) for _ in range(size)]
    shapes = [x.shape for x in xs]
    alphabet = ''.join(get_symbol(i) for i in range(size + 1))
    pairs = [alphabet[i:i+2] for i in range(size)]
    inputs = ','.join(pairs)

    with shared_intermediates():
        print(inputs)
        # Each output is one of the input pairs, e.g. "ab", "bc", ...
        for target in pairs:
            equation = '{}->{}'.format(inputs, target)
            print(contract_path(equation, *xs)[1])
            contract_expression(equation, *shapes)(*xs, backend=backend)
        print('-' * 40)
def test_chain_2(size, backend):
    """Compute every adjacent-pair output of a matrix chain inside one shared cache."""
    xs = [np.random.rand(2, 2) for _ in range(size)]
    shapes = [x.shape for x in xs]
    letters = ''.join(get_symbol(k) for k in range(size + 1))
    terms = [letters[k:k + 2] for k in range(size)]
    inputs = ','.join(terms)

    with shared_intermediates():
        print(inputs)
        for k in range(size):
            eq = '{}->{}'.format(inputs, terms[k])
            info = contract_path(eq, *xs)
            print(info[1])
            compiled = contract_expression(eq, *shapes)
            compiled(*xs, backend=backend)
        print('-' * 40)
def test_slice_and_reconfigure(forested, parallel, requires):
    """Slicing plus reconfiguration brings the largest intermediate under target."""
    if requires:
        pytest.importorskip(requires)

    eq, shapes = oe.helpers.rand_equation(30, reg=5, seed=42, d_max=2)
    _, greedy_info = oe.contract_path(eq, *shapes, shapes=True, optimize='greedy')
    tree = ContractionTree.from_info(greedy_info)

    # Aim for a largest intermediate 32x smaller than greedy's.
    target_size = tree.max_size() // 32
    if not forested:
        tree.slice_and_reconfigure_(target_size, progbar=True)
    else:
        tree.slice_and_reconfigure_forest_(
            target_size, num_trees=2, progbar=True, parallel=parallel)

    assert tree.max_size() <= target_size
def test_slicer():
    """SliceFinder meets its size target and the sliced result matches."""
    limit = 1_000_000
    eq, shapes = oe.helpers.rand_equation(30, reg=5, seed=42, d_max=3)
    arrays = [np.random.uniform(size=shape) for shape in shapes]

    path, info = oe.contract_path(eq, *shapes, shapes=True)
    expected = oe.contract(eq, *arrays, optimize=path)

    finder = ctg.SliceFinder(info, target_size=limit, target_overhead=None)
    inds, ccost = finder.search()

    # Slicing was needed, achieved the target, and cost extra flops.
    assert info.largest_intermediate > limit
    assert ccost.size <= limit
    assert ccost.total_flops > info.opt_cost
    assert len(inds) > 1

    contractor = finder.SlicedContractor(arrays)
    assert contractor.total_flops == ccost.total_flops
    assert contractor.contract_all() == pytest.approx(expected)
def logeinsumexp(formula, *tensors):
  """Compute `einsum` in logarithmic space.

  Args:
    formula: an `einsum`-compatible formula. Ellipses are not supported.
    *tensors: the tensors to contract.

  Returns:
    The equivalent of
    `tf.math.log(einsum(formula, *map(tf.math.exp, tensors)))`, computed in a
    more numerically stable way.

  Notes:
    Intended mainly for internal use by the `marginalize` library; the formula
    is assumed to be well-formed.
  """
  # opt_einsum plans the pairwise contraction order; the second element of its
  # result is a PathInfo whose contraction_list drives the log-space execution.
  path_info = oe.contract_path(formula, *tensors)[1]
  steps = path_info.contraction_list
  return _execute_contract_path(steps, list(tensors))
def __init__(
    self,
    eq,
    arrays,
    sliced,
    optimize='auto',
    size_dict=None,
):
    """Prepare contraction of ``eq`` with the ``sliced`` indices removed.

    Args:
        eq: Einsum equation with an explicit output, e.g. ``'ab,bc->ac'``
            (it is split on ``'->'``).
        arrays: Input arrays, one per term of ``eq``.
        sliced: Iterable of single-character indices to slice over. Any
            iterable works, including a one-shot iterator; it is consumed once.
        optimize: Passed to ``contract_path`` when planning a single slice.
        size_dict: Optional index -> size mapping. Built from ``arrays`` with
            ``create_size_dict`` when omitted.
    """
    # basic info
    lhs, self.output = eq.split('->')
    self.inputs = lhs.split(',')
    self.arrays = tuple(arrays)
    # Order sliced indices by first appearance in `eq`. Everything below uses
    # this materialized copy, not the raw argument: `sorted` has already
    # consumed an iterator argument, and a later `c not in sliced` on it would
    # silently match nothing, leaving every index unsliced.
    self.sliced = tuple(sorted(sliced, key=eq.index))
    sliced_set = frozenset(self.sliced)

    if size_dict is None:
        size_dict = create_size_dict(self.inputs, self.arrays)
    self.size_dict = size_dict

    # Split input positions into terms touching a sliced index ('changing')
    # and terms that don't ('constant').
    self.constant, self.changing = [], []
    for i, term in enumerate(self.inputs):
        if any(ix in sliced_set for ix in term):
            self.changing.append(i)
        else:
            self.constant.append(i)

    # information about the contraction of a single slice
    self.eq_sliced = "".join(c for c in eq if c not in sliced_set)
    self.sliced_sizes = tuple(self.size_dict[i] for i in self.sliced)
    self.nslices = compute_size_by_dict(self.sliced, self.size_dict)
    self.shapes_sliced = tuple(
        tuple(self.size_dict[i] for i in term)
        for term in self.eq_sliced.split('->')[0].split(',')
    )
    self.path, self.info_sliced = contract_path(
        self.eq_sliced, *self.shapes_sliced, shapes=True, optimize=optimize
    )

    # generate the contraction expression
    self._expr = contract_expression(
        self.eq_sliced, *self.shapes_sliced, optimize=self.path
    )
def test_einsum():
    """Smoke-test log_einsum_np in real-space (_test) mode against np.einsum."""
    equation = 'tea,tfb,tabcd,tgc,thd->tefgh'
    core_tensor = np.random.rand(1, 10, 10, 10, 10)
    factor = np.random.rand(1, 10, 10)
    operands = (factor, factor, core_tensor, factor, factor)

    # The planned path is not used further; this only exercises the call.
    contract_path(equation, *operands, einsum_call=True)

    ans = log_einsum_np(equation, *operands, _test=True)
    reference = np.einsum(equation, *operands)
    print((np.sin(ans)**2).sum() - (np.sin(reference)**2).sum())
def test_path_edge_cases(alg, expression, order):
    """The given algorithm finds the expected path for an edge-case expression."""
    views = oe.helpers.build_views(expression)
    # No memory limit is applied; the algorithm runs with its defaults.
    found_path = oe.contract_path(expression, *views, optimize=alg)[0]
    assert check_path(found_path, order)
# Benchmark np.einsum against oe.contract on 200 random contractions.
# Each entry appended to `out` is [status, sum_string, index_size-or-path, einsum_time, contract_time].
for x in range(200):
    sum_string, views, index_size = random_contraction()

    try:
        ein = np.einsum(sum_string, *views)
    except Exception:
        out.append(['Einsum failed', sum_string, index_size, 0, 0])
        continue

    try:
        # Pass the path via `optimize=`, consistent with the contract_path call
        # below. The legacy `path=` keyword is rejected by current opt_einsum
        # (TypeError), which this handler would record as a spurious failure.
        opt = oe.contract(sum_string, *views, optimize=opt_path)
    except Exception:
        out.append(['Opt_einsum failed', sum_string, index_size, 0, 0])
        continue

    current_opt_path = oe.contract_path(sum_string, *views, optimize=opt_path)[0]
    if not np.allclose(ein, opt):
        out.append(['Comparison failed', sum_string, index_size, 0, 0])
        continue

    # timeit runs in a fresh namespace, so the operands are imported from __main__.
    setup = ("import numpy as np; import opt_einsum as oe; "
             "from __main__ import sum_string, views, current_opt_path")
    einsum_string = "np.einsum(sum_string, *views)"
    contract_string = "oe.contract(sum_string, *views, optimize=current_opt_path)"

    e_n = 1
    o_n = 1
    einsum_time = timeit.timeit(einsum_string, setup=setup, number=e_n) / e_n
    contract_time = timeit.timeit(contract_string, setup=setup, number=o_n) / o_n

    out.append([True, sum_string, current_opt_path, einsum_time, contract_time])
# Benchmark np.einsum against oe.contract on 200 random contractions.
for x in range(200):
    sum_string, views, index_size = random_contraction()

    try:
        ein = np.einsum(sum_string, *views)
    except Exception:
        out.append(["Einsum failed", sum_string, index_size, 0, 0])
        continue

    try:
        # Pass the path via `optimize=`, consistent with the contract_path call
        # below. The legacy `path=` keyword is rejected by current opt_einsum
        # (TypeError), which this handler would record as a spurious failure.
        opt = oe.contract(sum_string, *views, optimize=opt_path)
    except Exception:
        out.append(["Opt_einsum failed", sum_string, index_size, 0, 0])
        continue

    current_opt_path = oe.contract_path(sum_string, *views, optimize=opt_path)[0]
    if not np.allclose(ein, opt):
        out.append(["Comparison failed", sum_string, index_size, 0, 0])
        continue

    # timeit runs in a fresh namespace, so the operands are imported from __main__.
    setup = ("import numpy as np; import opt_einsum as oe; "
             "from __main__ import sum_string, views, current_opt_path")
    einsum_string = "np.einsum(sum_string, *views)"
    contract_string = "oe.contract(sum_string, *views, optimize=current_opt_path)"

    e_n = 1
    o_n = 1
    einsum_time = timeit.timeit(einsum_string, setup=setup, number=e_n) / e_n
    contract_time = timeit.timeit(contract_string, setup=setup, number=o_n) / o_n
    # NOTE(review): within this visible span the timings are not appended to
    # `out` (the sibling loop appends [True, sum_string, path, times]); confirm.
def log_einsum_np(contract, *args, contraction_list=None, _test=False):
    """Evaluate an einsum contraction in log space.

    Adapted from opt_einsum's contract.py
    (https://github.com/dgasmith/opt_einsum/blob/master/opt_einsum/contract.py).
    Instead of (multiply, sum), each step does (sum, logsumexp), so this is
    einsum in log space.

    Args:
        contract: The einsum subscripts string to evaluate.
        args: The tensors to contract.
        contraction_list: Pre-computed contraction list in the format returned
            by ``contract_path(..., einsum_call=True)``; computed when None.
        _test: If true, use the real-space ``real_product``/``real_integrate``
            helpers and assert the result matches ``np.einsum``.

    Returns:
        The result of the contraction.
    """
    # Make it easy to test against a correct einsum implementation
    if _test:
        def product(x, y):
            return real_product(x, y)

        def integrate(x, axis):
            return real_integrate(x, axis)
    else:
        def product(x, y):
            return log_product(x, y)

        def integrate(x, axis):
            return log_integrate(x, axis)

    # If we haven't been given the contraction list, find it
    if contraction_list is None:
        _, contraction_list = contract_path(contract, *args, einsum_call=True, optimize='auto')

    operands = list(args)

    # Every index letter in the equation, sorted; all intermediates are aligned to this order.
    unique_letters = ''.join(sorted(set(contract))).replace(',', '').replace(
        '-', '').replace('>', '')
    n_unique_letters = len(unique_letters)
    transpose_back = [0 for _ in unique_letters]

    # Start contraction loop
    for inds, idx_rm, einsum_str, _remaining, _ in contraction_list:
        # Retrieve the current operands and split the step's equation
        tmp_operands = [operands.pop(x) for x in inds]
        input_str, results_index = einsum_str.split('->')

        # Check if we should multiply and then contract
        if len(inds) > 1:
            left_operand, right_operand = tmp_operands
            input_left, input_right = input_str.split(',')

            # Pad each operand with the letters it lacks so both can be
            # transposed into `unique_letters` order and broadcast together.
            not_in_left = ''.join(
                [letter for letter in unique_letters if letter not in input_left])
            not_in_right = ''.join(
                [letter for letter in unique_letters if letter not in input_right])
            left_shape = input_left + not_in_left
            right_shape = input_right + not_in_right

            # Align operands on the correct axes in order to do the sum
            transpose_left = tuple(
                [left_shape.index(letter) for letter in unique_letters])
            transpose_right = tuple(
                [right_shape.index(letter) for letter in unique_letters])

            # Extend the operands with size-1 axes and transpose them
            shape_left = list(left_operand.shape) + [
                1 for _ in range(len(left_operand.shape), n_unique_letters)
            ]
            shape_right = list(right_operand.shape) + [
                1 for _ in range(len(right_operand.shape), n_unique_letters)
            ]
            reshaped_left = left_operand.reshape(
                tuple(shape_left)).transpose(transpose_left)
            reshaped_right = right_operand.reshape(
                tuple(shape_right)).transpose(transpose_right)

            # Sum up the terms (a product in real space)
            summed = product(reshaped_left, reshaped_right)

            # Transpose back: result indices first, all other letters on the trailing axes
            not_in_result = ''.join(
                [letter for letter in unique_letters if letter not in results_index])
            full_results_index = results_index + not_in_result
            for i, letter in enumerate(full_results_index):
                transpose_back[i] = unique_letters.index(letter)
            swapped_summed = summed.transpose(tuple(transpose_back))

            # Every axis past the result indices is either summed out
            # (idx_rm) or a size-1 padding axis from the reshape above.
            trailing_axes = tuple(range(len(results_index), n_unique_letters))
            if len(idx_rm) > 0:
                new_view = integrate(swapped_summed, axis=trailing_axes)
            elif len(trailing_axes) == 0:
                new_view = swapped_summed
            else:
                # Nothing is summed out, so the trailing axes are exactly the
                # padding and are all size 1; squeeze only those. Squeezing
                # every size-1 axis would also drop genuine size-1 result
                # dimensions (e.g. a batch of 1) and break later steps.
                new_view = swapped_summed.squeeze(axis=trailing_axes)
        else:
            # Then we just need to do an integration step
            remove_idx = tuple([input_str.index(letter) for letter in idx_rm])
            new_view = integrate(tmp_operands[0], axis=remove_idx)

        # Append new items and dereference what we can
        operands.append(new_view)
        del tmp_operands, new_view

    if _test:
        check = np.einsum(contract, *args)
        assert np.allclose(check, operands[0])

    return operands[0]
def test_reconfigure_with_n_smaller_than_subtree_size():
    """subtree_reconfigure must cope with a subtree size larger than the network."""
    eq, shapes = oe.helpers.rand_equation(10, 3)
    info = oe.contract_path(eq, *shapes, shapes=True)[1]
    tree = ContractionTree.from_info(info)
    # Subtree size 12 exceeds the number of inputs requested above (10).
    tree.subtree_reconfigure(12)
def test_can_optimize_outer_products(optimize):
    """The optimizer contracts the shared-index pair before the outer products."""
    a, b, c = (np.random.randn(10, 10) for _ in range(3))
    d = np.random.randn(10, 2)
    found_path, _ = oe.contract_path("ab,cd,ef,fg", a, b, c, d, optimize=optimize)
    assert found_path == [(2, 3), (0, 2), (0, 1)]
def test_compressed_rank(optimize):
    """Compressed rank (chi=1) is below log2 of the tree's largest intermediate."""
    eq, shapes = oe.helpers.rand_equation(30, reg=5, seed=42, d_max=2)
    _, info = oe.contract_path(eq, *shapes, shapes=True, optimize=optimize)
    tree = ContractionTree.from_info(info)
    rank = tree.compressed_rank(1)
    assert rank < math.log2(tree.max_size())
for i in (j - 1, j + 2, j, j - 2, j + 1)) einsum_str += "{}{}{},{}{}{},".format(m, ul, ur, m, ll, lr) # finish with last site # --O # | # --O i = n - 1 j = 3 * i ul, m, ll, = (oe.get_symbol(i) for i in (j - 1, j, j - 2)) einsum_str += "{}{},{}{}".format(m, ul, m, ll) def gen_shapes(): yield (phys_dim, bond_dim) yield (phys_dim, bond_dim) for i in range(1, n - 1): yield (phys_dim, bond_dim, bond_dim) yield (phys_dim, bond_dim, bond_dim) yield (phys_dim, bond_dim) yield (phys_dim, bond_dim) shapes = tuple(gen_shapes()) print(shapes) arrays = [np.random.randn(*shp) / 4 for shp in shapes] print(oe.contract_path(einsum_str, *arrays, memory_limit=-1)[0]) print(oe.contract_path(einsum_str, *arrays, memory_limit=-1)[1])
def test_linear_vs_ssa(equation):
    """linear -> SSA -> linear path conversion round-trips exactly."""
    views = helpers.build_views(equation)
    linear, _ = contract_path(equation, *views)
    round_tripped = ssa_to_linear(linear_to_ssa(linear))
    assert round_tripped == linear
def test_contract_path_supply_shapes():
    """contract_path accepts bare shape tuples when shapes=True (smoke test)."""
    shape_list = [(2, 3), (3, 4), (4, 5)]
    contract_path('ab,bc,cd', *shape_list, shapes=True)
def test_dp_edge_cases_dimension_1():
    """The 'dp' optimizer handles every dimension having size 1."""
    unit_shapes = [(1, 1, 1), (1, 1, 1), (1, 1)]
    _, info = oe.contract_path(
        'nlp,nlq,pl->n', *unit_shapes, shapes=True, optimize='dp')
    assert max(info.scale_list) == 3
def test_dp_edge_cases_all_singlet_indices():
    """The 'dp' optimizer handles terms whose indices each appear only once."""
    term_shapes = [(2, ), (2, 2, 2), (2, 2, 2)]
    _, info = oe.contract_path(
        'a,bcd,efg->', *term_shapes, shapes=True, optimize='dp')
    assert max(info.scale_list) == 3
def test_printing():
    """The PathInfo report renders to a string of a fixed, known length."""
    string = "bbd,bda,fc,db->acf"
    views = helpers.build_views(string)
    _, path_info = contract_path(string, *views)
    report = str(path_info)
    assert len(report) == 726
def test_printing():
    """The PathInfo report renders to a string of a fixed, known length."""
    expression = "bbd,bda,fc,db->acf"
    operands = helpers.build_views(expression)
    info = contract_path(expression, *operands)[1]
    assert len(str(info)) == 728
def test_rand_equation(
    n,
    reg,
    n_out,
    n_hyper_in,
    n_hyper_out,
    d_min,
    d_max,
    seed,
    indices_sort,
):
    """ContractionTree.contract matches opt_einsum through a series of tree edits.

    The edits are: plain contraction, subtree reconfiguration, slicing plus
    reconfiguration, and slicing of randomly chosen output indices. All test
    randomness comes from a generator seeded with ``seed`` (as is the
    equation itself), so any failing case is reproducible.
    """
    inputs, output, shapes, size_dict = ctg.utils.rand_equation(
        n=n,
        reg=reg,
        n_out=n_out,
        n_hyper_in=n_hyper_in,
        n_hyper_out=n_hyper_out,
        d_min=d_min,
        d_max=d_max,
        seed=seed,
    )
    # Seeded generator instead of the global np.random state: the choice of
    # sliced output indices below used to differ on every run.
    rng = np.random.default_rng(seed)
    arrays = [rng.normal(size=s) for s in shapes]
    eq = ",".join(map("".join, inputs)) + "->" + "".join(output)

    path, info = oe.contract_path(eq, *arrays, optimize='greedy')
    if info.largest_intermediate > 2**20:
        raise RuntimeError("Contraction too big.")

    x = oe.contract(eq, *arrays, optimize=path)

    tree = ctg.ContractionTree.from_path(inputs, output, size_dict, path=path)

    if indices_sort:
        tree.sort_contraction_indices(indices_sort)

    # base contract
    y1 = tree.contract(arrays, check=True)
    assert_allclose(x, y1)

    # contract after modifying tree
    tree.subtree_reconfigure_()
    y2 = tree.contract(arrays, check=True)
    assert_allclose(x, y2)

    size = tree.max_size()
    if size < 600:
        return

    # contract after slicing and modifying
    tree.slice_and_reconfigure_(target_size=size // 6)
    y3 = tree.contract(arrays, check=True)
    assert_allclose(x, y3)

    # contract after slicing a random subset (possibly empty) of the output indices
    remaining_out = list(tree.output_legs)
    nsout = rng.integers(low=0, high=len(remaining_out) + 1)
    so_ix = rng.choice(remaining_out, replace=False, size=nsout)
    for ind in so_ix:
        tree.remove_ind_(ind)
        if indices_sort:
            tree.sort_contraction_indices(indices_sort)
        y4 = tree.contract(arrays, check=True)
        assert_allclose(x, y4)