def _find_variational_and_priors(model,
                                 variational_with_prior,
                                 require_prior=True):
    """Find upstream StochasticTensors and match with registered priors.

    Args:
        model: Tensor whose upstream stochastic nodes are looked up (via
            `sg._upstream_stochastic_nodes`) when `variational_with_prior`
            is None.
        variational_with_prior: Mapping from StochasticTensor to its prior
            Distribution (or None), or None to build that mapping from the
            upstream stochastic nodes of `model` and the `VI_PRIORS`
            collection.
        require_prior: If True (the default), every discovered upstream
            node must have a non-None prior registered in `VI_PRIORS`.
            Only consulted when `variational_with_prior` is None.

    Returns:
        The mapping from StochasticTensor to prior Distribution (or None).
        This is either the caller's mapping or the freshly built dict.

    Raises:
        ValueError: If `model` has no upstream stochastic nodes, or if
            `require_prior` is True and an upstream node has no prior.
        TypeError: If any key is not a StochasticTensor, or any value is
            neither None nor a Distribution.
    """
    if variational_with_prior is None:
        # pylint: disable=protected-access
        upstreams = sg._upstream_stochastic_nodes([model])
        # pylint: enable=protected-access
        upstreams = list(upstreams[model])
        if not upstreams:
            # Format eagerly: passing `model` as a second exception argument
            # would make the message a tuple instead of interpolating it.
            raise ValueError(
                "No upstream stochastic nodes found for tensor: %s" % (model,))
        # Presumably VI_PRIORS holds (variational, prior) pairs, since it is
        # fed to dict() -- TODO confirm against where priors are registered.
        prior_map = dict(ops.get_collection(VI_PRIORS))
        variational_with_prior = {}
        for q in upstreams:
            # A missing key and an explicit None prior are treated alike.
            prior = prior_map.get(q)
            if require_prior and prior is None:
                raise ValueError(
                    "No prior specified for StochasticTensor: %s" % (q,))
            variational_with_prior[q] = prior

    if not all(isinstance(q, st.StochasticTensor)
               for q in variational_with_prior):
        raise TypeError("variationals must be StochasticTensors")
    if not all(p is None or isinstance(p, distributions.Distribution)
               for p in variational_with_prior.values()):
        raise TypeError("priors must be Distributions")

    return variational_with_prior
def _find_variational_and_priors(model,
                                 variational_with_prior,
                                 require_prior=True):
  """Find upstream StochasticTensors and match with registered priors.

  Args:
    model: Tensor whose upstream stochastic nodes are looked up (via
      `sg._upstream_stochastic_nodes`) when `variational_with_prior` is None.
    variational_with_prior: Mapping from StochasticTensor to its prior
      Distribution (or None), or None to build that mapping from the upstream
      stochastic nodes of `model` and the `VI_PRIORS` collection.
    require_prior: If True (the default), every discovered upstream node must
      have a non-None prior registered in `VI_PRIORS`. Only consulted when
      `variational_with_prior` is None.

  Returns:
    The mapping from StochasticTensor to prior Distribution (or None). This is
    either the caller's mapping or the freshly built dict.

  Raises:
    ValueError: If `model` has no upstream stochastic nodes, or if
      `require_prior` is True and an upstream node has no prior.
    TypeError: If any key is not a StochasticTensor, or any value is neither
      None nor a Distribution.
  """
  # NOTE(review): another definition with this same name appears earlier in
  # the source as shown; if both live in one module, this later one wins.
  # Confirm which should be kept.
  if variational_with_prior is None:
    # pylint: disable=protected-access
    upstreams = sg._upstream_stochastic_nodes([model])
    # pylint: enable=protected-access
    upstreams = list(upstreams[model])
    if not upstreams:
      # Format eagerly: passing `model` as a second exception argument would
      # make the message a tuple instead of interpolating it.
      raise ValueError(
          "No upstream stochastic nodes found for tensor: %s" % (model,))
    # Presumably VI_PRIORS holds (variational, prior) pairs, since it is fed
    # to dict() -- TODO confirm against where priors are registered.
    prior_map = dict(ops.get_collection(VI_PRIORS))
    variational_with_prior = {}
    for q in upstreams:
      # A missing key and an explicit None prior are treated alike.
      prior = prior_map.get(q)
      if require_prior and prior is None:
        raise ValueError("No prior specified for StochasticTensor: %s" % (q,))
      variational_with_prior[q] = prior

  if not all(isinstance(q, st.StochasticTensor)
             for q in variational_with_prior):
    raise TypeError("variationals must be StochasticTensors")
  if not all(p is None or isinstance(p, distributions.Distribution)
             for p in variational_with_prior.values()):
    raise TypeError("priors must be Distributions")

  return variational_with_prior