def test_chain_sharing(size, backend):
    xs = [np.random.rand(2, 2) for _ in range(size)]
    alphabet = ''.join(get_symbol(i) for i in range(size + 1))
    names = [alphabet[i:i + 2] for i in range(size)]
    inputs = ','.join(names)

    num_exprs_nosharing = 0
    for i in range(size + 1):
        with shared_intermediates() as cache:
            target = alphabet[i]
            eq = '{}->{}'.format(inputs, target)
            expr = contract_expression(eq, *(x.shape for x in xs))
            expr(*xs, backend=backend)
            num_exprs_nosharing += _compute_cost(cache)

    with shared_intermediates() as cache:
        print(inputs)
        for i in range(size + 1):
            target = alphabet[i]
            eq = '{}->{}'.format(inputs, target)
            path_info = contract_path(eq, *xs)
            print(path_info[1])
            expr = contract_expression(eq, *(x.shape for x in xs))
            expr(*xs, backend=backend)
        num_exprs_sharing = _compute_cost(cache)

    print('-' * 40)
    print('Without sharing: {} expressions'.format(num_exprs_nosharing))
    print('With sharing: {} expressions'.format(num_exprs_sharing))
    assert num_exprs_nosharing > num_exprs_sharing

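# --- Illustrative sketch (not part of the test suite above) ---
# A minimal, self-contained example of the shared_intermediates() pattern the
# tests above exercise: contractions over the same operands inside one sharing
# block can reuse cached intermediate results, which is what keeps the counted
# number of expressions lower than running each contraction separately.
# Assumes only numpy and opt_einsum are installed.
import numpy as np
from opt_einsum import contract, shared_intermediates
from opt_einsum.sharing import count_cached_ops

x, y, z = (np.random.rand(4, 4) for _ in range(3))
with shared_intermediates() as cache:
    contract('ab,bc,cd->a', x, y, z)
    contract('ab,bc,cd->b', x, y, z)  # may reuse intermediates cached by the first call
print(count_cached_ops(cache))  # Counter of cached einsum/tensordot/transpose calls
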
def test_sharing_modulo_commutativity(eq, backend):
    ops = helpers.build_views(eq)
    ops = [to_backend[backend](x) for x in ops]
    inputs, output, _ = parse_einsum_input([eq] + ops)
    inputs = inputs.split(',')

    print('-' * 40)
    print('Without sharing:')
    with shared_intermediates() as cache:
        _einsum(eq, *ops, backend=backend)
        expected = count_cached_ops(cache)

    print('-' * 40)
    print('With sharing:')
    with shared_intermediates() as cache:
        for permuted in itertools.permutations(zip(inputs, ops)):
            permuted_inputs = [p[0] for p in permuted]
            permuted_ops = [p[1] for p in permuted]
            permuted_eq = '{}->{}'.format(','.join(permuted_inputs), output)
            _einsum(permuted_eq, *permuted_ops, backend=backend)
        actual = count_cached_ops(cache)

    print('-' * 40)
    print('Without sharing: {} expressions'.format(expected))
    print('With sharing: {} expressions'.format(actual))
    assert actual == expected

def test_partial_sharing(backend):
    eq = 'ab,bc,de->'
    x, y, z1 = helpers.build_views(eq)
    z2 = 2.0 * z1 - 1.0
    expr = contract_expression(eq, x.shape, y.shape, z1.shape)

    print('-' * 40)
    print('Without sharing:')
    num_exprs_nosharing = Counter()
    with shared_intermediates() as cache:
        expr(x, y, z1, backend=backend)
        num_exprs_nosharing.update(count_cached_ops(cache))
    with shared_intermediates() as cache:
        expr(x, y, z2, backend=backend)
        num_exprs_nosharing.update(count_cached_ops(cache))

    print('-' * 40)
    print('With sharing:')
    with shared_intermediates() as cache:
        expr(x, y, z1, backend=backend)
        expr(x, y, z2, backend=backend)
        num_exprs_sharing = count_cached_ops(cache)

    print('-' * 40)
    print('Without sharing: {} expressions'.format(num_exprs_nosharing))
    print('With sharing: {} expressions'.format(num_exprs_sharing))
    assert num_exprs_nosharing['einsum'] > num_exprs_sharing['einsum']

def test_no_sharing_separate_cache(backend):
    eq = 'ab,bc,cd->'
    views = helpers.build_views(eq)
    expr = contract_expression(eq, *(v.shape for v in views))

    print('-' * 40)
    print('Without sharing:')
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
        expected = count_cached_ops(cache)
        expected.update(count_cached_ops(cache))  # we expect double

    print('-' * 40)
    print('With sharing:')
    with shared_intermediates() as cache1:
        expr(*views, backend=backend)
        actual = count_cached_ops(cache1)
    with shared_intermediates() as cache2:
        expr(*views, backend=backend)
        actual.update(count_cached_ops(cache2))

    print('-' * 40)
    print('Without sharing: {} expressions'.format(expected))
    print('With sharing: {} expressions'.format(actual))
    assert actual == expected

def _compute_dice_elbo(model_trace, guide_trace):
    # Accumulate marginal model costs.
    marginal_costs, log_factors, ordering, sum_dims, scale = _compute_model_factors(
        model_trace, guide_trace)

    if log_factors:
        # Note that while most applications of tensor message passing use the
        # contract_to_tensor() interface and can be easily refactored to use ubersum(),
        # the application here relies on contract_tensor_tree() to extract the dependency
        # structure of different log_prob terms, which is used by Dice to eliminate
        # zero-expectation terms. One possible refactoring would be to replace
        # contract_to_tensor() with a RaggedTensor -> Tensor contraction operation, and
        # replace contract_tensor_tree() with a RaggedTensor -> RaggedTensor contraction
        # that preserves some dependency structure.
        with shared_intermediates() as cache:
            log_factors = contract_tensor_tree(log_factors, sum_dims, cache=cache)
        for t, log_factors_t in log_factors.items():
            marginal_costs_t = marginal_costs.setdefault(t, [])
            for term in log_factors_t:
                term = packed.scale_and_mask(term, scale=scale)
                marginal_costs_t.append(term)
    costs = marginal_costs

    # Accumulate negative guide costs.
    for name, site in guide_trace.nodes.items():
        if site["type"] == "sample":
            cost = packed.neg(site["packed"]["log_prob"])
            costs.setdefault(ordering[name], []).append(cost)

    return Dice(guide_trace, ordering).compute_expectation(costs)

def test_sharing_with_constants(backend):
    inputs = 'ij,jk,kl'
    outputs = 'ijkl'
    equations = ['{}->{}'.format(inputs, output) for output in outputs]
    shapes = (2, 3), (3, 4), (4, 5)
    constants = {0, 2}
    ops = [np.random.rand(*shp) if i in constants else shp
           for i, shp in enumerate(shapes)]
    var = np.random.rand(*shapes[1])

    expected = [contract_expression(eq, *shapes)(ops[0], var, ops[2])
                for eq in equations]

    with shared_intermediates():
        actual = [contract_expression(eq, *ops, constants=constants)(var)
                  for eq in equations]

    for dim, expected_dim, actual_dim in zip(outputs, expected, actual):
        assert np.allclose(expected_dim, actual_dim), 'error at {}'.format(dim)

def _compute_marginals(model_trace, guide_trace):
    args = _compute_model_factors(model_trace, guide_trace)
    marginal_costs, log_factors, ordering, sum_dims, scale = args

    marginal_dists = OrderedDict()
    with shared_intermediates() as cache:
        for name, site in model_trace.nodes.items():
            if (site["type"] != "sample" or
                    name in guide_trace.nodes or
                    site["infer"].get("_enumerate_dim") is None):
                continue

            enum_dim = site["infer"]["_enumerate_dim"]
            enum_symbol = site["infer"]["_enumerate_symbol"]
            ordinal = _find_ordinal(model_trace, site)
            logits = contract_to_tensor(log_factors, sum_dims,
                                        target_ordinal=ordinal,
                                        target_dims={enum_symbol},
                                        cache=cache)
            logits = packed.unpack(logits, model_trace.symbol_to_dim)
            logits = logits.unsqueeze(-1).transpose(-1, enum_dim - 1)
            while logits.shape[0] == 1:
                logits = logits.squeeze(0)
            marginal_dists[name] = _make_dist(site["fn"], logits)
    return marginal_dists

def fn():
    X, Y, Z = helpers.build_views('ab,bc,cd')

    with shared_intermediates():
        contract('ab,bc,cd->a', X, Y, Z)
        contract('ab,bc,cd->b', X, Y, Z)

        return len(get_sharing_cache())

def _logistic_regression_inner(genotype_pdf: pd.DataFrame, log_reg_state: LogRegState,
                               C: NDArray[(Any, Any), Float], Y: NDArray[(Any, Any), Float],
                               Y_mask: NDArray[(Any, Any), bool],
                               Q: Optional[NDArray[(Any, Any), Float]], correction: str,
                               pvalue_threshold: float,
                               phenotype_names: pd.Series) -> pd.DataFrame:
    '''
    Tests a block of genotypes for association with binary traits. We first residualize
    the genotypes based on the null model fit, then perform a fast score test to check
    for possible significance.

    We use semantic indices for the einsum expressions:
    s, i: sample (or individual)
    g: genotype
    p: phenotype
    c, d: covariate
    '''
    X = np.column_stack(genotype_pdf[_VALUES_COLUMN_NAME].array)

    # For approximate Firth correction, we perform a linear residualization
    if correction == correction_approx_firth:
        X = gwas_fx._residualize_in_place(X, Q)

    with oe.shared_intermediates():
        X_res = _logistic_residualize(X, C, Y_mask, log_reg_state.gamma,
                                      log_reg_state.inv_CtGammaC)
        num = gwas_fx._einsum('sgp,sp->pg', X_res, log_reg_state.Y_res) ** 2
        denom = gwas_fx._einsum('sgp,sgp,sp->pg', X_res, X_res, log_reg_state.gamma)
    chisq = np.ravel(num / denom)
    p_values = stats.chi2.sf(chisq, 1)

    del genotype_pdf[_VALUES_COLUMN_NAME]
    out_df = pd.concat([genotype_pdf] * log_reg_state.Y_res.shape[1])
    out_df['chisq'] = list(np.ravel(chisq))
    out_df['pvalue'] = list(np.ravel(p_values))
    out_df['phenotype'] = phenotype_names.repeat(genotype_pdf.shape[0]).tolist()

    if correction != correction_none:
        out_df['correctionSucceeded'] = None
        correction_indices = list(np.where(out_df['pvalue'] < pvalue_threshold)[0])
        if correction == correction_approx_firth:
            out_df['effect'] = np.nan
            out_df['stderror'] = np.nan
            for correction_idx in correction_indices:
                snp_idx = correction_idx % X.shape[1]
                pheno_idx = int(correction_idx / X.shape[1])
                approx_firth_snp_fit = af.correct_approx_firth(
                    X[:, snp_idx], Y[:, pheno_idx],
                    log_reg_state.firth_offset[:, pheno_idx], Y_mask[:, pheno_idx])
                if approx_firth_snp_fit is None:
                    out_df.correctionSucceeded.iloc[correction_idx] = False
                else:
                    out_df.correctionSucceeded.iloc[correction_idx] = True
                    out_df.effect.iloc[correction_idx] = approx_firth_snp_fit.effect
                    out_df.stderror.iloc[correction_idx] = approx_firth_snp_fit.stderror
                    out_df.chisq.iloc[correction_idx] = approx_firth_snp_fit.chisq
                    out_df.pvalue.iloc[correction_idx] = approx_firth_snp_fit.pvalue

    return out_df

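# --- Illustrative sketch (shapes and values are made up, not Glow's helpers) ---
# The docstring above names the einsum indices: s = sample, g = genotype,
# p = phenotype. With those conventions, the score-test numerator and
# denominator are two contractions over the same residualized genotype tensor
# X_res, which is why the function wraps them in one oe.shared_intermediates()
# block. Plain opt_einsum.contract stands in here for gwas_fx._einsum.
import numpy as np
import opt_einsum as oe

n_samples, n_genotypes, n_phenotypes = 10, 3, 2
X_res = np.random.rand(n_samples, n_genotypes, n_phenotypes)  # 'sgp'
Y_res = np.random.rand(n_samples, n_phenotypes)               # 'sp'
gamma = np.random.rand(n_samples, n_phenotypes)               # 'sp'

with oe.shared_intermediates():
    num = oe.contract('sgp,sp->pg', X_res, Y_res) ** 2
    denom = oe.contract('sgp,sgp,sp->pg', X_res, X_res, gamma)
chisq = np.ravel(num / denom)  # one score statistic per (phenotype, genotype) pair
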
def compute_expectation(self, costs):
    """
    Returns a differentiable expected cost, summing over costs at given ordinals.

    :param dict costs: A dict mapping ordinals to lists of cost tensors
    :returns: a scalar expected cost
    :rtype: torch.Tensor or float
    """
    # Share computation across all cost terms.
    with shared_intermediates() as cache:
        ring = MarginalRing(cache=cache)
        expected_cost = 0.
        for ordinal, cost_terms in costs.items():
            log_factors = self._get_log_factors(ordinal)
            scale = math.exp(sum(x for x in log_factors
                                 if not isinstance(x, torch.Tensor)))
            log_factors = [x for x in log_factors if isinstance(x, torch.Tensor)]

            # Collect log_prob terms to query for marginal probability.
            queries = {frozenset(cost._pyro_dims): None for cost in cost_terms}
            for log_factor in log_factors:
                key = frozenset(log_factor._pyro_dims)
                if queries.get(key, False) is None:
                    queries[key] = log_factor
            # Ensure a query exists for each cost term.
            for cost in cost_terms:
                key = frozenset(cost._pyro_dims)
                if queries[key] is None:
                    query = cost.new_zeros(cost.shape)
                    query._pyro_dims = cost._pyro_dims
                    log_factors.append(query)
                    queries[key] = query

            # Perform sum-product contraction. Note that plates never need to be
            # product-contracted due to our plate-based dependency ordering.
            sum_dims = set().union(*(x._pyro_dims for x in log_factors)) - ordinal
            for query in queries.values():
                require_backward(query)
            root = ring.sumproduct(log_factors, sum_dims)
            root._pyro_backward()
            probs = {key: query._pyro_backward_result.exp()
                     for key, query in queries.items()}

            # Aggregate prob * cost terms.
            for cost in cost_terms:
                key = frozenset(cost._pyro_dims)
                prob = probs[key]
                prob._pyro_dims = queries[key]._pyro_dims
                mask = prob > 0
                if torch._C._get_tracing_state() or not mask.all():
                    mask._pyro_dims = prob._pyro_dims
                    cost, prob, mask = packed.broadcast_all(cost, prob, mask)
                    prob = prob[mask]
                    cost = cost[mask]
                else:
                    cost, prob = packed.broadcast_all(cost, prob)
                expected_cost = expected_cost + scale * torch.tensordot(prob, cost, prob.dim())

    LAST_CACHE_SIZE[0] = count_cached_ops(cache)
    return expected_cost

def test_sharing_value(eq, backend):
    views = helpers.build_views(eq)
    shapes = [v.shape for v in views]
    expr = contract_expression(eq, *shapes)

    expected = expr(*views, backend=backend)
    with shared_intermediates():
        actual = expr(*views, backend=backend)

    assert (actual == expected).all()

def log_prob(self, model_trace):
    """
    Returns the log pdf of `model_trace` by appropriately handling
    enumerated log prob factors.

    :return: log pdf of the trace.
    """
    if not self.has_enumerable_sites:
        return model_trace.log_prob_sum()
    log_probs = self._get_log_factors(model_trace)
    with shared_intermediates() as cache:
        return contract_to_tensor(log_probs, self._enum_dims, cache=cache)

def log_prob(self, model_trace):
    """
    Returns the log pdf of `model_trace` by appropriately handling
    enumerated log prob factors.

    :return: log pdf of the trace.
    """
    with shared_intermediates():
        if not self.has_enumerable_sites:
            return model_trace.log_prob_sum()
        self._compute_log_prob_terms(model_trace)
        return self._aggregate_log_probs(ordinal=frozenset()).sum()

def test_complete_sharing(backend):
    eq = 'ab,bc,cd->'
    views = helpers.build_views(eq)
    expr = contract_expression(eq, *(v.shape for v in views))

    print('-' * 40)
    print('Without sharing:')
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
        expected = count_cached_ops(cache)

    print('-' * 40)
    print('With sharing:')
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
        expr(*views, backend=backend)
        actual = count_cached_ops(cache)

    print('-' * 40)
    print('Without sharing: {} expressions'.format(expected))
    print('With sharing: {} expressions'.format(actual))
    assert actual == expected

def method2(views):
    with shared_intermediates():
        y = contract_expression(eqs[2], *shapes)(*views, backend=backend)
        z = contract_expression(eqs[3], *shapes)(*views, backend=backend)
        refs['y'] = y
        refs['z'] = z
        result = contract_expression('c,d->', y.shape, z.shape)(y, z, backend=backend)

        result = result + method1(views)  # nest method1 in method2

        del y, z
        assert 'y' in refs
        assert 'z' in refs

    assert 'y' not in refs
    assert 'z' not in refs

def method1(views):
    with shared_intermediates():
        w = contract_expression(eqs[0], *shapes)(*views, backend=backend)
        x = contract_expression(eqs[2], *shapes)(*views, backend=backend)
        result = contract_expression('a,b->', w.shape, x.shape)(w, x, backend=backend)

        refs['w'] = w
        refs['x'] = x
        del w, x
        assert 'w' in refs
        assert 'x' in refs

    assert 'w' not in refs, 'cache leakage'
    assert 'x' not in refs, 'cache leakage'
    return result

def test_sharing_reused_cache(backend):
    eq = "ab,bc,cd->"
    views = helpers.build_views(eq)
    expr = contract_expression(eq, *(v.shape for v in views))

    print("-" * 40)
    print("Without sharing:")
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
        expected = count_cached_ops(cache)

    print("-" * 40)
    print("With sharing:")
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
    with shared_intermediates(cache):
        expr(*views, backend=backend)
        actual = count_cached_ops(cache)

    print("-" * 40)
    print("Without sharing: {} expressions".format(expected))
    print("With sharing: {} expressions".format(actual))
    assert actual == expected

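# --- Illustrative sketch (standalone, mirrors test_sharing_reused_cache above) ---
# Passing a previously populated cache back into shared_intermediates(cache)
# lets a later sharing block reuse intermediates computed in an earlier one,
# so repeating the same contraction on the same operands adds no new cached
# operations. Assumes only numpy and opt_einsum are installed.
import numpy as np
from opt_einsum import contract, shared_intermediates
from opt_einsum.sharing import count_cached_ops

x, y, z = (np.random.rand(4, 4) for _ in range(3))
with shared_intermediates() as cache:
    contract('ab,bc,cd->', x, y, z)
ops_after_first = count_cached_ops(cache)

with shared_intermediates(cache):  # hand the same cache to a later block
    contract('ab,bc,cd->', x, y, z)
assert count_cached_ops(cache) == ops_after_first  # nothing new was computed
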
def log_prob(self, model_trace):
    """
    almost identical to that of TraceEinsumEvaluator,
    but uses log_prob instead of log_prob_sum
    """
    if not self.has_enumerable_sites:
        log_prob = 0
        for name in model_trace.stochastic_nodes:
            dist = model_trace.nodes[name]['fn']
            value = model_trace.nodes[name]['value']
            site_log_prob = dist.log_prob(value)
            log_prob = log_prob + site_log_prob
        return log_prob
    log_probs = self._get_log_factors(model_trace)
    with shared_intermediates() as cache:
        return contract_to_tensor(log_probs, self._enum_dims, cache=cache)

def test_chain_2(size, backend):
    xs = [np.random.rand(2, 2) for _ in range(size)]
    shapes = [x.shape for x in xs]
    alphabet = ''.join(get_symbol(i) for i in range(size + 1))
    names = [alphabet[i:i + 2] for i in range(size)]
    inputs = ','.join(names)

    with shared_intermediates():
        print(inputs)
        for i in range(size):
            target = alphabet[i:i + 2]
            eq = '{}->{}'.format(inputs, target)
            path_info = contract_path(eq, *xs)
            print(path_info[1])
            expr = contract_expression(eq, *shapes)
            expr(*xs, backend=backend)
        print('-' * 40)

def _pyro_sample(self, msg):
    enum_msg = self.enum_trace.nodes.get(msg["name"])
    if enum_msg is None:
        return
    enum_symbol = enum_msg["infer"].get("_enumerate_symbol")
    if enum_symbol is None:
        return
    enum_dim = enum_msg["infer"]["_enumerate_dim"]
    with shared_intermediates(self.cache):
        ordinal = _find_ordinal(self.enum_trace, msg)
        logits = contract_to_tensor(self.log_factors, self.sum_dims,
                                    target_ordinal=ordinal,
                                    target_dims={enum_symbol},
                                    cache=self.cache)
        logits = packed.unpack(logits, self.enum_trace.symbol_to_dim)
        logits = logits.unsqueeze(-1).transpose(-1, enum_dim - 1)
        while logits.shape[0] == 1:
            logits = logits.squeeze(0)
    msg["fn"] = _make_dist(msg["fn"], logits)

def test_chain_2_growth(backend):
    sizes = list(range(1, 21))
    costs = []
    for size in sizes:
        xs = [np.random.rand(2, 2) for _ in range(size)]
        alphabet = ''.join(get_symbol(i) for i in range(size + 1))
        names = [alphabet[i:i + 2] for i in range(size)]
        inputs = ','.join(names)

        with shared_intermediates() as cache:
            for i in range(size):
                target = alphabet[i:i + 2]
                eq = '{}->{}'.format(inputs, target)
                expr = contract_expression(eq, *(x.shape for x in xs))
                expr(*xs, backend=backend)
            costs.append(_compute_cost(cache))

    print('sizes = {}'.format(repr(sizes)))
    print('costs = {}'.format(repr(costs)))
    for size, cost in zip(sizes, costs):
        print('{}\t{}'.format(size, cost))

def ubersum(equation, *operands, **kwargs):
    """
    Generalized batched sum-product algorithm via tensor message passing.

    This generalizes :func:`~pyro.ops.einsum.contract` in two ways:

    1. Multiple outputs are allowed, and intermediate results can be shared.
    2. Inputs and outputs can be batched along symbols given in ``batch_dims``;
       reductions along ``batch_dims`` are product reductions.

    The best way to understand this function is to try the examples below,
    which show how :func:`ubersum` calls can be implemented as multiple calls
    to :func:`~pyro.ops.einsum.contract` (which is generally more expensive).

    To illustrate multiple outputs, note that the following are equivalent::

        z1, z2, z3 = ubersum('ab,bc->a,b,c', x, y)  # multiple outputs

        backend = 'pyro.ops.einsum.torch_log'
        z1 = contract('ab,bc->a', x, y, backend=backend)
        z2 = contract('ab,bc->b', x, y, backend=backend)
        z3 = contract('ab,bc->c', x, y, backend=backend)

    To illustrate batched inputs, note that the following are equivalent::

        assert len(x) == 3 and len(y) == 3
        z = ubersum('ab,ai,bi->b', w, x, y, batch_dims='i')

        z = contract('ab,a,a,a,b,b,b->b', w, *x, *y, backend=backend)

    When a sum dimension `a` always appears with a batch dimension `i`, then
    `a` corresponds to a distinct symbol for each slice of `a`. Thus the
    following are equivalent::

        assert len(x) == 3 and len(y) == 3
        z = ubersum('ai,ai->', x, y, batch_dims='i')

        z = contract('a,b,c,a,b,c->', *x, *y, backend=backend)

    When such a sum dimension appears in the output, it must be accompanied by
    all of its batch dimensions, e.g. the following are equivalent::

        assert len(x) == 3 and len(y) == 3
        z = ubersum('abi,abi->bi', x, y, batch_dims='i')

        z0 = contract('ab,ac,ad,ab,ac,ad->b', *x, *y, backend=backend)
        z1 = contract('ab,ac,ad,ab,ac,ad->c', *x, *y, backend=backend)
        z2 = contract('ab,ac,ad,ab,ac,ad->d', *x, *y, backend=backend)
        z = torch.stack([z0, z1, z2])

    Note that each batch slice through the output is multilinear in all batch
    slices through all inputs, thus e.g. batch matrix multiply would be
    implemented *without* ``batch_dims``, so the following are all equivalent::

        xy = ubersum('abc,acd->abd', x, y, batch_dims='')
        xy = torch.stack([xa.mm(ya) for xa, ya in zip(x, y)])
        xy = torch.bmm(x, y)

    Among all valid equations, some computations are polynomial in the sizes of
    the input tensors and other computations are exponential in the sizes of
    the input tensors. This function raises :py:class:`NotImplementedError`
    whenever the computation is exponential.

    :param str equation: An einsum equation, optionally with multiple outputs.
    :param torch.Tensor operands: A collection of tensors.
    :param str batch_dims: An optional string of batch dims.
    :param dict cache: An optional :func:`~opt_einsum.shared_intermediates`
        cache.
    :param bool modulo_total: Optionally allow ubersum to arbitrarily scale
        each result batch, which can significantly reduce computation. This is
        safe to set whenever each result batch denotes a nonnormalized
        probability distribution whose total is not of interest.
    :return: a tuple of tensors of requested shape, one entry per output.
    :rtype: tuple
    :raises ValueError: if tensor sizes mismatch or an output requests a
        batched dim without that dim's batch dims.
    :raises NotImplementedError: if contraction would have cost exponential in
        the size of any input tensor.
    """
    # Extract kwargs.
    cache = kwargs.pop('cache', None)
    batch_dims = kwargs.pop('batch_dims', '')
    backend = kwargs.pop('backend', 'pyro.ops.einsum.torch_log')
    modulo_total = kwargs.pop('modulo_total', False)
    try:
        Ring = BACKEND_TO_RING[backend]
    except KeyError:
        raise NotImplementedError('\n'.join(
            ['Only the following pyro backends are currently implemented:'] +
            list(BACKEND_TO_RING)))

    # Parse generalized einsum equation.
    if '.' in equation:
        raise NotImplementedError('ubersum does not yet support ellipsis notation')
    inputs, outputs = equation.split('->')
    inputs = inputs.split(',')
    outputs = outputs.split(',')
    assert len(inputs) == len(operands)
    assert all(isinstance(x, torch.Tensor) for x in operands)
    if not modulo_total and any(outputs):
        raise NotImplementedError(
            'Try setting modulo_total=True and ensuring that your use case '
            'allows an arbitrary scale factor on each result batch.')
    if len(operands) != len(set(operands)):
        operands = [x[...] for x in operands]  # ensure tensors are unique

    # Check sizes.
    with ignore_jit_warnings():
        dim_to_size = {}
        for dims, term in zip(inputs, operands):
            for dim, size in zip(dims, map(int, term.shape)):
                old = dim_to_size.setdefault(dim, size)
                if old != size:
                    raise ValueError(
                        u"Dimension size mismatch at dim '{}': {} vs {}"
                        .format(dim, size, old))

    # Construct a tensor tree shared by all outputs.
    tensor_tree = OrderedDict()
    batch_dims = frozenset(batch_dims)
    for dims, term in zip(inputs, operands):
        assert len(dims) == term.dim()
        term._pyro_dims = dims
        ordinal = batch_dims.intersection(dims)
        tensor_tree.setdefault(ordinal, []).append(term)

    # Compute outputs, sharing intermediate computations.
    results = []
    with shared_intermediates(cache) as cache:
        ring = Ring(cache, dim_to_size=dim_to_size)
        for output in outputs:
            sum_dims = set(output).union(*inputs) - set(batch_dims)
            term = contract_to_tensor(tensor_tree, sum_dims,
                                      target_ordinal=batch_dims.intersection(output),
                                      target_dims=sum_dims.intersection(output),
                                      ring=ring)
            if term._pyro_dims != output:
                term = term.permute(*map(term._pyro_dims.index, output))
                term._pyro_dims = output
            results.append(term)
    return tuple(results)

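# --- Illustrative sketch of calling the function above (shapes are made up) ---
# The docstring's first generalization is "multiple outputs ... intermediate
# results can be shared": all outputs are computed inside one
# shared_intermediates(cache) block. This sketch assumes torch is installed and
# that ubersum() is imported from the module that defines it above. Because the
# outputs are non-empty, modulo_total=True is required, so each returned batch
# (in log space under the default torch_log backend) may be arbitrarily scaled.
import torch

x = torch.randn(2, 3)  # 'ab'
y = torch.randn(3, 4)  # 'bc'
z1, z2, z3 = ubersum('ab,bc->a,b,c', x, y, modulo_total=True)
print(z1.shape, z2.shape, z3.shape)  # torch.Size([2]) torch.Size([3]) torch.Size([4])
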
def _sample_posterior_from_trace(model, enum_trace, temperature, *args, **kwargs):
    plate_to_symbol = enum_trace.plate_to_symbol

    # Collect a set of query sample sites to which the backward algorithm will propagate.
    sum_dims = set()
    queries = []
    dim_to_size = {}
    cost_terms = OrderedDict()
    enum_terms = OrderedDict()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample":
            ordinal = frozenset(plate_to_symbol[f.name]
                                for f in node["cond_indep_stack"]
                                if f.vectorized and f.size > 1)
            # For sites that depend on an enumerated variable, we need to apply
            # the mask but not the scale when sampling.
            if "masked_log_prob" not in node["packed"]:
                node["packed"]["masked_log_prob"] = packed.scale_and_mask(
                    node["packed"]["unscaled_log_prob"], mask=node["packed"]["mask"])
            log_prob = node["packed"]["masked_log_prob"]
            sum_dims.update(frozenset(log_prob._pyro_dims) - ordinal)
            if sum_dims.isdisjoint(log_prob._pyro_dims):
                continue
            dim_to_size.update(zip(log_prob._pyro_dims, log_prob.shape))
            if node["infer"].get("_enumerate_dim") is None:
                cost_terms.setdefault(ordinal, []).append(log_prob)
            else:
                enum_terms.setdefault(ordinal, []).append(log_prob)
            # Note we mark all sample sites with require_backward to gather
            # enumerated sites and adjust cond_indep_stack of all sample sites.
            if not node["is_observed"]:
                queries.append(log_prob)
                require_backward(log_prob)

    # We take special care to match the term ordering in
    # pyro.infer.traceenum_elbo._compute_model_factors() to allow
    # contract_tensor_tree() to use shared_intermediates() inside
    # TraceEnumSample_ELBO. The special ordering is: first all cost terms in
    # order of model_trace, then all enum_terms in order of model trace.
    log_probs = cost_terms
    for ordinal, terms in enum_terms.items():
        log_probs.setdefault(ordinal, []).extend(terms)

    # Run forward-backward algorithm, collecting the ordinal of each connected component.
    cache = getattr(enum_trace, "_sharing_cache", {})
    ring = _make_ring(temperature, cache, dim_to_size)
    with shared_intermediates(cache):
        log_probs = contract_tensor_tree(log_probs, sum_dims, ring=ring)  # run forward algorithm
    query_to_ordinal = {}
    pending = object()  # a constant value for pending queries
    for query in queries:
        query._pyro_backward_result = pending
    for ordinal, terms in log_probs.items():
        for term in terms:
            if hasattr(term, "_pyro_backward"):
                term._pyro_backward()  # run backward algorithm
        # Note: this is quadratic in number of ordinals
        for query in queries:
            if query not in query_to_ordinal and query._pyro_backward_result is not pending:
                query_to_ordinal[query] = ordinal

    # Construct a collapsed trace by gathering and adjusting cond_indep_stack.
    collapsed_trace = poutine.Trace()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample" and not node["is_observed"]:
            # TODO move this into a Leaf implementation somehow
            new_node = {
                "type": "sample",
                "name": node["name"],
                "is_observed": False,
                "infer": node["infer"].copy(),
                "cond_indep_stack": node["cond_indep_stack"],
                "value": node["value"],
            }
            log_prob = node["packed"]["masked_log_prob"]
            if hasattr(log_prob, "_pyro_backward_result"):
                # Adjust the cond_indep_stack.
                ordinal = query_to_ordinal[log_prob]
                new_node["cond_indep_stack"] = tuple(
                    f for f in node["cond_indep_stack"]
                    if not (f.vectorized and f.size > 1)
                    or plate_to_symbol[f.name] in ordinal)

                # Gather if node depended on an enumerated value.
                sample = log_prob._pyro_backward_result
                if sample is not None:
                    new_value = packed.pack(node["value"],
                                            node["infer"]["_dim_to_symbol"])
                    for index, dim in zip(jit_iter(sample), sample._pyro_sample_dims):
                        if dim in new_value._pyro_dims:
                            index._pyro_dims = sample._pyro_dims[1:]
                            new_value = packed.gather(new_value, index, dim)
                    new_node["value"] = packed.unpack(new_value, enum_trace.symbol_to_dim)

            collapsed_trace.add_node(node["name"], **new_node)

    # Replay the model against the collapsed trace.
    with SamplePosteriorMessenger(trace=collapsed_trace):
        return model(*args, **kwargs)