Example 1
def test_chain_sharing(size, backend):
    xs = [np.random.rand(2, 2) for _ in range(size)]
    alphabet = ''.join(get_symbol(i) for i in range(size + 1))
    names = [alphabet[i:i + 2] for i in range(size)]
    inputs = ','.join(names)

    num_exprs_nosharing = 0
    for i in range(size + 1):
        with shared_intermediates() as cache:
            target = alphabet[i]
            eq = '{}->{}'.format(inputs, target)
            expr = contract_expression(eq, *(x.shape for x in xs))
            expr(*xs, backend=backend)
            num_exprs_nosharing += _compute_cost(cache)

    with shared_intermediates() as cache:
        print(inputs)
        for i in range(size + 1):
            target = alphabet[i]
            eq = '{}->{}'.format(inputs, target)
            path_info = contract_path(eq, *xs)
            print(path_info[1])
            expr = contract_expression(eq, *(x.shape for x in xs))
            expr(*xs, backend=backend)
        num_exprs_sharing = _compute_cost(cache)

    print('-' * 40)
    print('Without sharing: {} expressions'.format(num_exprs_nosharing))
    print('With sharing: {} expressions'.format(num_exprs_sharing))
    assert num_exprs_nosharing > num_exprs_sharing
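A note on the common pattern: every example here wraps its contractions in opt_einsum's shared_intermediates() context, which records intermediate results in a cache so that overlapping work across contractions is computed once and reused. A minimal, self-contained sketch of that pattern (assuming only NumPy and the public contract/shared_intermediates API; the shapes are illustrative, not taken from the tests):

import numpy as np
from opt_einsum import contract, shared_intermediates

x, y, z = (np.random.rand(8, 8) for _ in range(3))

with shared_intermediates() as cache:
    # both marginals traverse the same chain of factors, so intermediates
    # cached by the first contraction can be reused by the second
    a = contract('ab,bc,cd->a', x, y, z)
    b = contract('ab,bc,cd->b', x, y, z)
    print(len(cache))  # number of cached entries so far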
Example 2
def test_sharing_modulo_commutativity(eq, backend):
    ops = helpers.build_views(eq)
    ops = [to_backend[backend](x) for x in ops]
    inputs, output, _ = parse_einsum_input([eq] + ops)
    inputs = inputs.split(',')

    print('-' * 40)
    print('Without sharing:')
    with shared_intermediates() as cache:
        _einsum(eq, *ops, backend=backend)
        expected = count_cached_ops(cache)

    print('-' * 40)
    print('With sharing:')
    with shared_intermediates() as cache:
        for permuted in itertools.permutations(zip(inputs, ops)):
            permuted_inputs = [p[0] for p in permuted]
            permuted_ops = [p[1] for p in permuted]
            permuted_eq = '{}->{}'.format(','.join(permuted_inputs), output)
            _einsum(permuted_eq, *permuted_ops, backend=backend)
        actual = count_cached_ops(cache)

    print('-' * 40)
    print('Without sharing: {} expressions'.format(expected))
    print('With sharing: {} expressions'.format(actual))
    assert actual == expected
Example 3
def test_partial_sharing(backend):
    eq = 'ab,bc,de->'
    x, y, z1 = helpers.build_views(eq)
    z2 = 2.0 * z1 - 1.0
    expr = contract_expression(eq, x.shape, y.shape, z1.shape)

    print('-' * 40)
    print('Without sharing:')
    num_exprs_nosharing = Counter()
    with shared_intermediates() as cache:
        expr(x, y, z1, backend=backend)
        num_exprs_nosharing.update(count_cached_ops(cache))
    with shared_intermediates() as cache:
        expr(x, y, z2, backend=backend)
        num_exprs_nosharing.update(count_cached_ops(cache))

    print('-' * 40)
    print('With sharing:')
    with shared_intermediates() as cache:
        expr(x, y, z1, backend=backend)
        expr(x, y, z2, backend=backend)
        num_exprs_sharing = count_cached_ops(cache)

    print('-' * 40)
    print('Without sharing: {} expressions'.format(num_exprs_nosharing))
    print('With sharing: {} expressions'.format(num_exprs_sharing))
    assert num_exprs_nosharing['einsum'] > num_exprs_sharing['einsum']
Example 4
def test_no_sharing_separate_cache(backend):
    eq = 'ab,bc,cd->'
    views = helpers.build_views(eq)
    expr = contract_expression(eq, *(v.shape for v in views))

    print('-' * 40)
    print('Without sharing:')
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
        expected = count_cached_ops(cache)
        expected.update(count_cached_ops(cache))  # we expect double

    print('-' * 40)
    print('With sharing:')
    with shared_intermediates() as cache1:
        expr(*views, backend=backend)
        actual = count_cached_ops(cache1)
    with shared_intermediates() as cache2:
        expr(*views, backend=backend)
        actual.update(count_cached_ops(cache2))

    print('-' * 40)
    print('Without sharing: {} expressions'.format(expected))
    print('With sharing: {} expressions'.format(actual))
    assert actual == expected
Example 5
def _compute_dice_elbo(model_trace, guide_trace):
    # Accumulate marginal model costs.
    marginal_costs, log_factors, ordering, sum_dims, scale = _compute_model_factors(
        model_trace, guide_trace)
    if log_factors:
        # Note that while most applications of tensor message passing use the
        # contract_to_tensor() interface and can be easily refactored to use ubersum(),
        # the application here relies on contract_tensor_tree() to extract the dependency
        # structure of different log_prob terms, which is used by Dice to eliminate
        # zero-expectation terms. One possible refactoring would be to replace
        # contract_to_tensor() with a RaggedTensor -> Tensor contraction operation, but
        # replace contract_tensor_tree() with a RaggedTensor -> RaggedTensor contraction
        # that preserves some dependency structure.
        with shared_intermediates() as cache:
            log_factors = contract_tensor_tree(log_factors,
                                               sum_dims,
                                               cache=cache)
        for t, log_factors_t in log_factors.items():
            marginal_costs_t = marginal_costs.setdefault(t, [])
            for term in log_factors_t:
                term = packed.scale_and_mask(term, scale=scale)
                marginal_costs_t.append(term)
    costs = marginal_costs

    # Accumulate negative guide costs.
    for name, site in guide_trace.nodes.items():
        if site["type"] == "sample":
            cost = packed.neg(site["packed"]["log_prob"])
            costs.setdefault(ordering[name], []).append(cost)

    return Dice(guide_trace, ordering).compute_expectation(costs)
Example 6
def test_sharing_with_constants(backend):
    inputs = 'ij,jk,kl'
    outputs = 'ijkl'
    equations = ['{}->{}'.format(inputs, output) for output in outputs]
    shapes = (2, 3), (3, 4), (4, 5)
    constants = {0, 2}
    ops = [
        np.random.rand(*shp) if i in constants else shp
        for i, shp in enumerate(shapes)
    ]
    var = np.random.rand(*shapes[1])

    expected = [
        contract_expression(eq, *shapes)(ops[0], var, ops[2])
        for eq in equations
    ]

    with shared_intermediates():
        actual = [
            contract_expression(eq, *ops, constants=constants)(var)
            for eq in equations
        ]

    for dim, expected_dim, actual_dim in zip(outputs, expected, actual):
        assert np.allclose(expected_dim, actual_dim), 'error at {}'.format(dim)
Example 7
def _compute_marginals(model_trace, guide_trace):
    args = _compute_model_factors(model_trace, guide_trace)
    marginal_costs, log_factors, ordering, sum_dims, scale = args

    marginal_dists = OrderedDict()
    with shared_intermediates() as cache:
        for name, site in model_trace.nodes.items():
            if (site["type"] != "sample" or name in guide_trace.nodes
                    or site["infer"].get("_enumerate_dim") is None):
                continue

            enum_dim = site["infer"]["_enumerate_dim"]
            enum_symbol = site["infer"]["_enumerate_symbol"]
            ordinal = _find_ordinal(model_trace, site)
            logits = contract_to_tensor(log_factors,
                                        sum_dims,
                                        target_ordinal=ordinal,
                                        target_dims={enum_symbol},
                                        cache=cache)
            logits = packed.unpack(logits, model_trace.symbol_to_dim)
            logits = logits.unsqueeze(-1).transpose(-1, enum_dim - 1)
            while logits.shape[0] == 1:
                logits = logits.squeeze(0)
            marginal_dists[name] = _make_dist(site["fn"], logits)
    return marginal_dists
Example 8
    def fn():
        X, Y, Z = helpers.build_views('ab,bc,cd')

        with shared_intermediates():
            contract('ab,bc,cd->a', X, Y, Z)
            contract('ab,bc,cd->b', X, Y, Z)

            return len(get_sharing_cache())
Example 9
def _logistic_regression_inner(genotype_pdf: pd.DataFrame, log_reg_state: LogRegState,
                               C: NDArray[(Any, Any), Float], Y: NDArray[(Any, Any), Float],
                               Y_mask: NDArray[(Any, Any),
                                               bool], Q: Optional[NDArray[(Any, Any),
                                                                          Float]], correction: str,
                               pvalue_threshold: float, phenotype_names: pd.Series) -> pd.DataFrame:
    '''
    Tests a block of genotypes for association with binary traits. We first residualize
    the genotypes based on the null model fit, then perform a fast score test to check for
    possible significance.

    We use semantic indices for the einsum expressions:
    s, i: sample (or individual)
    g: genotype
    p: phenotype
    c, d: covariate
    '''
    X = np.column_stack(genotype_pdf[_VALUES_COLUMN_NAME].array)

    # For approximate Firth correction, we perform a linear residualization
    if correction == correction_approx_firth:
        X = gwas_fx._residualize_in_place(X, Q)

    with oe.shared_intermediates():
        X_res = _logistic_residualize(X, C, Y_mask, log_reg_state.gamma, log_reg_state.inv_CtGammaC)
        num = gwas_fx._einsum('sgp,sp->pg', X_res, log_reg_state.Y_res)**2
        denom = gwas_fx._einsum('sgp,sgp,sp->pg', X_res, X_res, log_reg_state.gamma)
    chisq = np.ravel(num / denom)
    p_values = stats.chi2.sf(chisq, 1)

    del genotype_pdf[_VALUES_COLUMN_NAME]
    out_df = pd.concat([genotype_pdf] * log_reg_state.Y_res.shape[1])
    out_df['chisq'] = list(np.ravel(chisq))
    out_df['pvalue'] = list(np.ravel(p_values))
    out_df['phenotype'] = phenotype_names.repeat(genotype_pdf.shape[0]).tolist()

    if correction != correction_none:
        out_df['correctionSucceeded'] = None
        correction_indices = list(np.where(out_df['pvalue'] < pvalue_threshold)[0])
        if correction == correction_approx_firth:
            out_df['effect'] = np.nan
            out_df['stderror'] = np.nan
            for correction_idx in correction_indices:
                snp_idx = correction_idx % X.shape[1]
                pheno_idx = int(correction_idx / X.shape[1])
                approx_firth_snp_fit = af.correct_approx_firth(
                    X[:, snp_idx], Y[:, pheno_idx], log_reg_state.firth_offset[:, pheno_idx],
                    Y_mask[:, pheno_idx])
                if approx_firth_snp_fit is None:
                    out_df.correctionSucceeded.iloc[correction_idx] = False
                else:
                    out_df.correctionSucceeded.iloc[correction_idx] = True
                    out_df.effect.iloc[correction_idx] = approx_firth_snp_fit.effect
                    out_df.stderror.iloc[correction_idx] = approx_firth_snp_fit.stderror
                    out_df.chisq.iloc[correction_idx] = approx_firth_snp_fit.chisq
                    out_df.pvalue.iloc[correction_idx] = approx_firth_snp_fit.pvalue

    return out_df
Example 10
    def compute_expectation(self, costs):
        """
        Returns a differentiable expected cost, summing over costs at given ordinals.

        :param dict costs: A dict mapping ordinals to lists of cost tensors
        :returns: a scalar expected cost
        :rtype: torch.Tensor or float
        """
        # Share computation across all cost terms.
        with shared_intermediates() as cache:
            ring = MarginalRing(cache=cache)
            expected_cost = 0.
            for ordinal, cost_terms in costs.items():
                log_factors = self._get_log_factors(ordinal)
                scale = math.exp(sum(x for x in log_factors if not isinstance(x, torch.Tensor)))
                log_factors = [x for x in log_factors if isinstance(x, torch.Tensor)]

                # Collect log_prob terms to query for marginal probability.
                queries = {frozenset(cost._pyro_dims): None for cost in cost_terms}
                for log_factor in log_factors:
                    key = frozenset(log_factor._pyro_dims)
                    if queries.get(key, False) is None:
                        queries[key] = log_factor
                # Ensure a query exists for each cost term.
                for cost in cost_terms:
                    key = frozenset(cost._pyro_dims)
                    if queries[key] is None:
                        query = cost.new_zeros(cost.shape)
                        query._pyro_dims = cost._pyro_dims
                        log_factors.append(query)
                        queries[key] = query

                # Perform sum-product contraction. Note that plates never need to be
                # product-contracted due to our plate-based dependency ordering.
                sum_dims = set().union(*(x._pyro_dims for x in log_factors)) - ordinal
                for query in queries.values():
                    require_backward(query)
                root = ring.sumproduct(log_factors, sum_dims)
                root._pyro_backward()
                probs = {key: query._pyro_backward_result.exp() for key, query in queries.items()}

                # Aggregate prob * cost terms.
                for cost in cost_terms:
                    key = frozenset(cost._pyro_dims)
                    prob = probs[key]
                    prob._pyro_dims = queries[key]._pyro_dims
                    mask = prob > 0
                    if torch._C._get_tracing_state() or not mask.all():
                        mask._pyro_dims = prob._pyro_dims
                        cost, prob, mask = packed.broadcast_all(cost, prob, mask)
                        prob = prob[mask]
                        cost = cost[mask]
                    else:
                        cost, prob = packed.broadcast_all(cost, prob)
                    expected_cost = expected_cost + scale * torch.tensordot(prob, cost, prob.dim())

        LAST_CACHE_SIZE[0] = count_cached_ops(cache)
        return expected_cost
Example 11
def test_sharing_value(eq, backend):
    views = helpers.build_views(eq)
    shapes = [v.shape for v in views]
    expr = contract_expression(eq, *shapes)

    expected = expr(*views, backend=backend)
    with shared_intermediates():
        actual = expr(*views, backend=backend)

    assert (actual == expected).all()
Example 12
 def log_prob(self, model_trace):
     """
     Returns the log pdf of `model_trace` by appropriately handling
     enumerated log prob factors.
     :return: log pdf of the trace.
     """
     if not self.has_enumerable_sites:
         return model_trace.log_prob_sum()
     log_probs = self._get_log_factors(model_trace)
     with shared_intermediates() as cache:
         return contract_to_tensor(log_probs, self._enum_dims, cache=cache)
Example 13
    def log_prob(self, model_trace):
        """
        Returns the log pdf of `model_trace` by appropriately handling
        enumerated log prob factors.

        :return: log pdf of the trace.
        """
        with shared_intermediates():
            if not self.has_enumerable_sites:
                return model_trace.log_prob_sum()
            self._compute_log_prob_terms(model_trace)
            return self._aggregate_log_probs(ordinal=frozenset()).sum()
Example 14
def test_complete_sharing(backend):
    eq = 'ab,bc,cd->'
    views = helpers.build_views(eq)
    expr = contract_expression(eq, *(v.shape for v in views))

    print('-' * 40)
    print('Without sharing:')
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
        expected = count_cached_ops(cache)

    print('-' * 40)
    print('With sharing:')
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
        expr(*views, backend=backend)
        actual = count_cached_ops(cache)

    print('-' * 40)
    print('Without sharing: {} expressions'.format(expected))
    print('With sharing: {} expressions'.format(actual))
    assert actual == expected
Example 15
 def method2(views):
     with shared_intermediates():
         y = contract_expression(eqs[2], *shapes)(*views, backend=backend)
         z = contract_expression(eqs[3], *shapes)(*views, backend=backend)
         refs['y'] = y
         refs['z'] = z
         result = contract_expression('c,d->', y.shape, z.shape)(y, z, backend=backend)
         result = result + method1(views)  # nest method1 in method2
         del y, z
         assert 'y' in refs
         assert 'z' in refs
     assert 'y' not in refs
     assert 'z' not in refs
Example 16
 def method1(views):
     with shared_intermediates():
         w = contract_expression(eqs[0], *shapes)(*views, backend=backend)
         x = contract_expression(eqs[2], *shapes)(*views, backend=backend)
         result = contract_expression('a,b->', w.shape, x.shape)(w, x, backend=backend)
         refs['w'] = w
         refs['x'] = x
         del w, x
         assert 'w' in refs
         assert 'x' in refs
     assert 'w' not in refs, 'cache leakage'
     assert 'x' not in refs, 'cache leakage'
     return result
Example 17
def test_sharing_reused_cache(backend):
    eq = "ab,bc,cd->"
    views = helpers.build_views(eq)
    expr = contract_expression(eq, *(v.shape for v in views))

    print("-" * 40)
    print("Without sharing:")
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
        expected = count_cached_ops(cache)

    print("-" * 40)
    print("With sharing:")
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
    with shared_intermediates(cache):
        expr(*views, backend=backend)
        actual = count_cached_ops(cache)

    print("-" * 40)
    print("Without sharing: {} expressions".format(expected))
    print("With sharing: {} expressions".format(actual))
    assert actual == expected
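The test above shows that re-entering shared_intermediates(cache) with a previously populated cache behaves like continuing the original sharing block. A minimal sketch of that reuse pattern, assuming a hypothetical marginal() helper and module-level _CACHE dict that are not part of opt_einsum; entries are keyed by the identity of the operand arrays, so reuse requires passing the same array objects again:

import numpy as np
from opt_einsum import contract, shared_intermediates

_CACHE = {}  # hypothetical long-lived cache, reused across calls

def marginal(output, *factors):
    # re-entering sharing with the same dict lets later calls reuse earlier work
    with shared_intermediates(_CACHE):
        return contract('ab,bc,cd->{}'.format(output), *factors)

x, y, z = (np.random.rand(4, 4) for _ in range(3))
first = marginal('a', x, y, z)   # populates the cache
second = marginal('b', x, y, z)  # may reuse cached intermediates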
Example 18
 def log_prob(self, model_trace):
     """
     almost identical to that of TraceEinsumEvaluator but
     uses log_prob instead of log_prob_sum
     """
     if not self.has_enumerable_sites:
         log_prob = 0
         for name in model_trace.stochastic_nodes:
             dist = model_trace.nodes[name]['fn']
             value = model_trace.nodes[name]['value']
             site_log_prob = dist.log_prob(value)
             log_prob = log_prob + site_log_prob
         return log_prob
     log_probs = self._get_log_factors(model_trace)
     with shared_intermediates() as cache:
         return contract_to_tensor(log_probs, self._enum_dims, cache=cache)
Example 19
def test_chain_2(size, backend):
    xs = [np.random.rand(2, 2) for _ in range(size)]
    shapes = [x.shape for x in xs]
    alphabet = ''.join(get_symbol(i) for i in range(size + 1))
    names = [alphabet[i:i + 2] for i in range(size)]
    inputs = ','.join(names)

    with shared_intermediates():
        print(inputs)
        for i in range(size):
            target = alphabet[i:i + 2]
            eq = '{}->{}'.format(inputs, target)
            path_info = contract_path(eq, *xs)
            print(path_info[1])
            expr = contract_expression(eq, *shapes)
            expr(*xs, backend=backend)
        print('-' * 40)
Example 20
 def _pyro_sample(self, msg):
     enum_msg = self.enum_trace.nodes.get(msg["name"])
     if enum_msg is None:
         return
     enum_symbol = enum_msg["infer"].get("_enumerate_symbol")
     if enum_symbol is None:
         return
     enum_dim = enum_msg["infer"]["_enumerate_dim"]
     with shared_intermediates(self.cache):
         ordinal = _find_ordinal(self.enum_trace, msg)
         logits = contract_to_tensor(self.log_factors, self.sum_dims,
                                     target_ordinal=ordinal, target_dims={enum_symbol},
                                     cache=self.cache)
         logits = packed.unpack(logits, self.enum_trace.symbol_to_dim)
         logits = logits.unsqueeze(-1).transpose(-1, enum_dim - 1)
         while logits.shape[0] == 1:
             logits = logits.squeeze(0)
     msg["fn"] = _make_dist(msg["fn"], logits)
Example 21
def test_chain_2_growth(backend):
    sizes = list(range(1, 21))
    costs = []
    for size in sizes:
        xs = [np.random.rand(2, 2) for _ in range(size)]
        alphabet = ''.join(get_symbol(i) for i in range(size + 1))
        names = [alphabet[i:i + 2] for i in range(size)]
        inputs = ','.join(names)

        with shared_intermediates() as cache:
            for i in range(size):
                target = alphabet[i:i + 2]
                eq = '{}->{}'.format(inputs, target)
                expr = contract_expression(eq, *(x.shape for x in xs))
                expr(*xs, backend=backend)
            costs.append(_compute_cost(cache))

    print('sizes = {}'.format(repr(sizes)))
    print('costs = {}'.format(repr(costs)))
    for size, cost in zip(sizes, costs):
        print('{}\t{}'.format(size, cost))
Example 22
def ubersum(equation, *operands, **kwargs):
    """
    Generalized batched sum-product algorithm via tensor message passing.

    This generalizes :func:`~pyro.ops.einsum.contract` in two ways:

    1.  Multiple outputs are allowed, and intermediate results can be shared.
    2.  Inputs and outputs can be batched along symbols given in ``batch_dims``;
        reductions along ``batch_dims`` are product reductions.

    The best way to understand this function is to try the examples below,
    which show how :func:`ubersum` calls can be implemented as multiple calls
    to :func:`~pyro.ops.einsum.contract` (which is generally more expensive).

    To illustrate multiple outputs, note that the following are equivalent::

        z1, z2, z3 = ubersum('ab,bc->a,b,c', x, y)  # multiple outputs

        backend = 'pyro.ops.einsum.torch_log'
        z1 = contract('ab,bc->a', x, y, backend=backend)
        z2 = contract('ab,bc->b', x, y, backend=backend)
        z3 = contract('ab,bc->c', x, y, backend=backend)

    To illustrate batched inputs, note that the following are equivalent::

        assert len(x) == 3 and len(y) == 3
        z = ubersum('ab,ai,bi->b', w, x, y, batch_dims='i')

        z = contract('ab,a,a,a,b,b,b->b', w, *x, *y, backend=backend)

    When a sum dimension `a` always appears with a batch dimension `i`,
    then `a` corresponds to a distinct symbol for each slice of `i`. Thus
    the following are equivalent::

        assert len(x) == 3 and len(y) == 3
        z = ubersum('ai,ai->', x, y, batch_dims='i')

        z = contract('a,b,c,a,b,c->', *x, *y, backend=backend)

    When such a sum dimension appears in the output, it must be
    accompanied by all of its batch dimensions, e.g. the following are
    equivalent::

        assert len(x) == 3 and len(y) == 3
        z = ubersum('abi,abi->bi', x, y, batch_dims='i')

        z0 = contract('ab,ac,ad,ab,ac,ad->b', *x, *y, backend=backend)
        z1 = contract('ab,ac,ad,ab,ac,ad->c', *x, *y, backend=backend)
        z2 = contract('ab,ac,ad,ab,ac,ad->d', *x, *y, backend=backend)
        z = torch.stack([z0, z1, z2])

    Note that each batch slice through the output is multilinear in all batch
    slices through all inputs, thus e.g. batch matrix multiply would be
    implemented *without* ``batch_dims``, so the following are all equivalent::

        xy = ubersum('abc,acd->abd', x, y, batch_dims='')
        xy = torch.stack([xa.mm(ya) for xa, ya in zip(x, y)])
        xy = torch.bmm(x, y)

    Among all valid equations, some computations are polynomial in the sizes of
    the input tensors and other computations are exponential in the sizes of
    the input tensors. This function raises :py:class:`NotImplementedError`
    whenever the computation is exponential.

    :param str equation: An einsum equation, optionally with multiple outputs.
    :param torch.Tensor operands: A collection of tensors.
    :param str batch_dims: An optional string of batch dims.
    :param dict cache: An optional :func:`~opt_einsum.shared_intermediates`
        cache.
    :param bool modulo_total: Optionally allow ubersum to arbitrarily scale
        each result batch, which can significantly reduce computation. This is
        safe to set whenever each result batch denotes a nonnormalized
        probability distribution whose total is not of interest.
    :return: a tuple of tensors of requested shape, one entry per output.
    :rtype: tuple
    :raises ValueError: if tensor sizes mismatch or an output requests a
        batched dim without that dim's batch dims.
    :raises NotImplementedError: if contraction would have cost exponential in
        the size of any input tensor.
    """
    # Extract kwargs.
    cache = kwargs.pop('cache', None)
    batch_dims = kwargs.pop('batch_dims', '')
    backend = kwargs.pop('backend', 'pyro.ops.einsum.torch_log')
    modulo_total = kwargs.pop('modulo_total', False)
    try:
        Ring = BACKEND_TO_RING[backend]
    except KeyError:
        raise NotImplementedError('\n'.join(
            ['Only the following pyro backends are currently implemented:'] +
            list(BACKEND_TO_RING)))

    # Parse generalized einsum equation.
    if '.' in equation:
        raise NotImplementedError(
            'ubersum does not yet support ellipsis notation')
    inputs, outputs = equation.split('->')
    inputs = inputs.split(',')
    outputs = outputs.split(',')
    assert len(inputs) == len(operands)
    assert all(isinstance(x, torch.Tensor) for x in operands)
    if not modulo_total and any(outputs):
        raise NotImplementedError(
            'Try setting modulo_total=True and ensuring that your use case '
            'allows an arbitrary scale factor on each result batch.')
    if len(operands) != len(set(operands)):
        operands = [x[...] for x in operands]  # ensure tensors are unique

    # Check sizes.
    with ignore_jit_warnings():
        dim_to_size = {}
        for dims, term in zip(inputs, operands):
            for dim, size in zip(dims, map(int, term.shape)):
                old = dim_to_size.setdefault(dim, size)
                if old != size:
                    raise ValueError(
                        u"Dimension size mismatch at dim '{}': {} vs {}".
                        format(dim, size, old))

    # Construct a tensor tree shared by all outputs.
    tensor_tree = OrderedDict()
    batch_dims = frozenset(batch_dims)
    for dims, term in zip(inputs, operands):
        assert len(dims) == term.dim()
        term._pyro_dims = dims
        ordinal = batch_dims.intersection(dims)
        tensor_tree.setdefault(ordinal, []).append(term)

    # Compute outputs, sharing intermediate computations.
    results = []
    with shared_intermediates(cache) as cache:
        ring = Ring(cache, dim_to_size=dim_to_size)
        for output in outputs:
            sum_dims = set(output).union(*inputs) - set(batch_dims)
            term = contract_to_tensor(
                tensor_tree,
                sum_dims,
                target_ordinal=batch_dims.intersection(output),
                target_dims=sum_dims.intersection(output),
                ring=ring)
            if term._pyro_dims != output:
                term = term.permute(*map(term._pyro_dims.index, output))
                term._pyro_dims = output
            results.append(term)
    return tuple(results)
Example 23
def _sample_posterior_from_trace(model, enum_trace, temperature, *args,
                                 **kwargs):
    plate_to_symbol = enum_trace.plate_to_symbol

    # Collect a set of query sample sites to which the backward algorithm will propagate.
    sum_dims = set()
    queries = []
    dim_to_size = {}
    cost_terms = OrderedDict()
    enum_terms = OrderedDict()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample":
            ordinal = frozenset(plate_to_symbol[f.name]
                                for f in node["cond_indep_stack"]
                                if f.vectorized and f.size > 1)
            # For sites that depend on an enumerated variable, we need to apply
            # the mask but not the scale when sampling.
            if "masked_log_prob" not in node["packed"]:
                node["packed"]["masked_log_prob"] = packed.scale_and_mask(
                    node["packed"]["unscaled_log_prob"],
                    mask=node["packed"]["mask"])
            log_prob = node["packed"]["masked_log_prob"]
            sum_dims.update(frozenset(log_prob._pyro_dims) - ordinal)
            if sum_dims.isdisjoint(log_prob._pyro_dims):
                continue
            dim_to_size.update(zip(log_prob._pyro_dims, log_prob.shape))
            if node["infer"].get("_enumerate_dim") is None:
                cost_terms.setdefault(ordinal, []).append(log_prob)
            else:
                enum_terms.setdefault(ordinal, []).append(log_prob)
            # Note we mark all sample sites with require_backward to gather
            # enumerated sites and adjust cond_indep_stack of all sample sites.
            if not node["is_observed"]:
                queries.append(log_prob)
                require_backward(log_prob)

    # We take special care to match the term ordering in
    # pyro.infer.traceenum_elbo._compute_model_factors() to allow
    # contract_tensor_tree() to use shared_intermediates() inside
    # TraceEnumSample_ELBO. The special ordering is: first all cost terms in
    # order of model_trace, then all enum_terms in order of model trace.
    log_probs = cost_terms
    for ordinal, terms in enum_terms.items():
        log_probs.setdefault(ordinal, []).extend(terms)

    # Run forward-backward algorithm, collecting the ordinal of each connected component.
    cache = getattr(enum_trace, "_sharing_cache", {})
    ring = _make_ring(temperature, cache, dim_to_size)
    with shared_intermediates(cache):
        log_probs = contract_tensor_tree(log_probs, sum_dims,
                                         ring=ring)  # run forward algorithm
    query_to_ordinal = {}
    pending = object()  # a constant value for pending queries
    for query in queries:
        query._pyro_backward_result = pending
    for ordinal, terms in log_probs.items():
        for term in terms:
            if hasattr(term, "_pyro_backward"):
                term._pyro_backward()  # run backward algorithm
        # Note: this is quadratic in number of ordinals
        for query in queries:
            if query not in query_to_ordinal and query._pyro_backward_result is not pending:
                query_to_ordinal[query] = ordinal

    # Construct a collapsed trace by gathering and adjusting cond_indep_stack.
    collapsed_trace = poutine.Trace()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample" and not node["is_observed"]:
            # TODO move this into a Leaf implementation somehow
            new_node = {
                "type": "sample",
                "name": node["name"],
                "is_observed": False,
                "infer": node["infer"].copy(),
                "cond_indep_stack": node["cond_indep_stack"],
                "value": node["value"],
            }
            log_prob = node["packed"]["masked_log_prob"]
            if hasattr(log_prob, "_pyro_backward_result"):
                # Adjust the cond_indep_stack.
                ordinal = query_to_ordinal[log_prob]
                new_node["cond_indep_stack"] = tuple(
                    f for f in node["cond_indep_stack"]
                    if not (f.vectorized and f.size > 1)
                    or plate_to_symbol[f.name] in ordinal)

                # Gather if node depended on an enumerated value.
                sample = log_prob._pyro_backward_result
                if sample is not None:
                    new_value = packed.pack(node["value"],
                                            node["infer"]["_dim_to_symbol"])
                    for index, dim in zip(jit_iter(sample),
                                          sample._pyro_sample_dims):
                        if dim in new_value._pyro_dims:
                            index._pyro_dims = sample._pyro_dims[1:]
                            new_value = packed.gather(new_value, index, dim)
                    new_node["value"] = packed.unpack(new_value,
                                                      enum_trace.symbol_to_dim)

            collapsed_trace.add_node(node["name"], **new_node)

    # Replay the model against the collapsed trace.
    with SamplePosteriorMessenger(trace=collapsed_trace):
        return model(*args, **kwargs)