예제 #1
0
 def eval_tran(self, values, **kwargs):
     """ Evaluates the transitional (conditional) probability for values.

     :param values: dict of values keyed by variable name; keys ending with a
                    prime (') are collected as successor values, the others
                    as predecessor values.
     :param kwargs: optional keywords:
                    'cond': if present, it is returned immediately as is.
                    'reverse': selects the indexed transition function
                               (defaults to False).

     :return: the conditional probability.
     """
     # An explicitly supplied conditional overrides any evaluation
     if 'cond' in kwargs:
         return kwargs['cond']

     default_cond = DEFAULT_CONDITIONAL_PROBABILITY[iscomplex(self._pscale)]
     reverse = kwargs.get('reverse', False)
     tran = self._tran

     # A transition object is set: use it if it can be called directly
     if tran is not None:
         if not tran.callable or tran.isscipy:
             return default_cond
         if self._sym_tran:
             return tran(values)
         return tran[int(reverse)](values)

     # tfun without tran means no conditional probability
     if self._tfun is not None:
         return default_cond

     # A lone member variable with its own transition handles the call
     members = self._varlist
     if len(members) == 1 and members[0]._tran is not None:
         return members[0].eval_tran(values, **kwargs)

     # Otherwise split known keys into predecessor and successor values
     pred_vals, succ_vals = {}, {}
     for label, val in values.items():
         primed = label[-1] == "'"
         name = label[:-1] if primed else label
         if name not in self._keylist:
             continue
         if primed:
             succ_vals[name] = val
         else:
             pred_vals[name] = val
     cond, _ = rv_prod_rule(pred_vals, succ_vals,
                            rvs=members, pscale=self._pscale)
     return cond
예제 #2
0
    def eval_prob(self, values=None, dims=None):
        """ Evaluates the probability for a dictionary of values.

        :param values: dict keyed by variable name (None is treated as an
                       empty dict and skips the key checks).
        :param dims: dimensions forwarded to the parent implementation only
                     when self._passdims is set.

        :return: the evaluated probability.
        """
        if values is not None:
            assert isinstance(values, dict), \
                "Input to eval_prob() requires values dict"
            assert set(values.keys()) == self._keyset, \
                "Sample dictionary keys {} mismatch with RV names {}".format(
                    values.keys(), self._keylist)
        else:
            values = {}

        # Without a specified probability, treat members as independent
        if self._prob is None or self.__def_prob:
            members = self._varlist
            if len(members) == 1 and members[0]._prob is not None:
                only = members[0]
                return only.eval_prob(values[only.name])
            prob, _ = rv_prod_rule(values, rvs=members, pscale=self._pscale)
            return prob

        # Non-callable probabilities are evaluated without values
        if not self._callable:
            return self._call()

        # Sympy expressions select the partial by probability scale
        if self.issympy:
            partial_key = 'logp' if iscomplex(self._pscale) else 'prob'
            return self._partials[partial_key](values)

        # Dimensions are passed on to the parent only when requested
        extra = {'dims': dims} if self._passdims else {}
        return super().eval_prob(values, **extra)
예제 #3
0
def rv_prod_rule(*args, rvs, pscale=None):
    """ Returns the probability product treating all rvs as independent.

    :param args: one or more value dictionaries keyed by RV name. The first
                 is evaluated at the requested pscale; any further ones are
                 combined with it (see below) to play nicely with
                 conditionals.
    :param rvs: list of RVs, each providing .name, .pscale and .eval_prob().
    :param pscale: output probability scale; if None it is derived from the
                   pscales of the rvs using prod_pscale().

    :return: tuple (prob, pscale).
    """
    values = args[0]
    pscales = [rv.pscale for rv in rvs]
    # Test against None explicitly: a log scale with zero offset (0.j) is
    # falsy, so `pscale or ...` would silently discard it.
    if pscale is None:
        pscale = prod_pscale(pscales)
    use_logs = iscomplex(pscale)
    probs = [rv.eval_prob(values[rv.name]) for rv in rvs]
    prob, pscale = prod_rule(*tuple(probs), pscales=pscales, pscale=pscale)

    # This section below is there just to play nicely with conditionals
    if len(args) > 1:
        # Work at the unit scale (offset 0 for logs, coefficient 1 otherwise)
        unit = 0.j if use_logs else 1.
        prob = rescale(prob, pscale, unit)
        for arg in args[1:]:
            term, _ = rv_prod_rule(arg, rvs=rvs, pscale=unit)
            prob = prob + term if use_logs else prob * term

        # Average over the number of dictionaries: arithmetic mean of logs,
        # or equivalently the geometric mean of linear probabilities
        count = float(len(args))
        prob = prob / count if use_logs else prob**(1. / count)
        prob = rescale(prob, unit, pscale)
    return prob, pscale
예제 #4
0
파일: prob.py 프로젝트: Bhumbra/probayes
    def pscale(self, pscale=None):
        """ Sets the probability scaling constant used for probabilities.

        :param pscale: can be None, a real number, a complex number, or 'log'.

           if pscale is None (default) the normalisation constant is set as 1.
           if pscale is real, this defines the normalisation constant.
           if pscale is complex, this defines the offset for log probabilities.
           if pscale is 'log', this denotes a logarithmic scale with an offset
           of 0.

        The input is interpreted by eval_pscale(); the result is stored in
        self._pscale and self._logp records whether it is complex.

        :return: pscale (either as a real or complex number)
        """
        self._pscale = eval_pscale(pscale)
        scale = self._pscale
        self._logp = iscomplex(scale)
        return scale
예제 #5
0
def uniform_prob(*args, prob=None, inside=None, pscale=1.):
    """ Uniform probability function for discrete and continuous vtypes.

    :param args: one or more value inputs (scalar or array); a first input of
                 None returns the default probability directly.
    :param prob: probability assigned to values inside the set; defaults to
                 0. for log scales and 1. otherwise.
    :param inside: either a callable returning a boolean (mask) for values, or
                   a collection: for float vtypes its min and max define an
                   inclusive range, otherwise it is the set of permitted
                   values.
    :param pscale: probability scale; a complex pscale denotes log
                   probabilities.

    :return: probability (scalar or array matching the first input).
    """

    # Detect ptype, default to prob if no values, otherwise detect vtype
    assert len(args) >= 1, "Minimum of a single positional argument"
    pscale = eval_pscale(pscale)
    use_logs = iscomplex(pscale)
    if prob is None:
        prob = 0. if use_logs else 1.
    vals = args[0]
    if vals is None:
        return prob
    vtype = eval_vtype(vals) if callable(inside) else eval_vtype(inside)

    # Set inside function by vtype if not specified. The collection (and its
    # bounds) must be held in separate names: a lambda that referred to
    # `inside` would see the lambda itself once `inside` is rebound below.
    if not callable(inside):
        members = inside
        if vtype in VTYPES[float]:
            lower, upper = min(members), max(members)
            inside = lambda x: np.logical_and(x >= lower, x <= upper)
        else:
            inside = lambda x: np.isin(x, members)

    # If scalar, check within variable set
    p_zero = NEARLY_NEGATIVE_INF if use_logs else 0.
    if isscalar(vals):
        prob = prob if inside(vals) else p_zero

    # Otherwise treat as uniform within range
    else:
        p_true = prob
        prob = np.tile(p_zero, vals.shape)
        prob[inside(vals)] = p_true

    # This section below is there just to play nicely with conditionals
    if len(args) > 1:
        for arg in args[1:]:
            if use_logs:
                prob = prob + uniform_prob(arg, inside=inside, pscale=0.j)
            else:
                prob = prob * uniform_prob(arg, inside=inside)
    return prob
예제 #6
0
def call_scipy_prob(func, pscale, *args, **kwds):
    """ Calls one of a pair of functions chosen by probability scale.

    :param func: indexable pair of callables; index 1 is used when pscale is
                 complex, index 0 otherwise.
    :param pscale: probability scale tested with iscomplex().
    :param args: positional arguments forwarded to the chosen callable.
    :param kwds: keyword arguments forwarded to the chosen callable.

    :return: the output of the chosen callable.
    """
    if iscomplex(pscale):
        chosen = func[1]
    else:
        chosen = func[0]
    return chosen(*args, **kwds)
예제 #7
0
def product(*args, **kwds):
  """ Multiplies two or more PDs subject to the following:
  1. They must not share the same marginal variables.
  2. Conditional variables must be identical unless contained as marginal from
     another distribution.

  :param args: distributions to multiply; each is read as a mapping of values
               and must provide .pscale, .issingleton, .prob, .marg, .cond
               and .dims.
  :param kwds: optional keywords; only 'pscale' (the output probability
               scale) is read here. If absent or None, the scale is derived
               from the inputs using prod_pscale().

  :return: None if no distributions are given, otherwise a PD instance.
  """
  # Nothing to multiply (checked first so that no import is required)
  if not args:
    return None

  from probayes.pd import PD

  # Check pscales, scalars, possible fasttrack
  kwds = dict(kwds)
  pscales = [arg.pscale for arg in args]
  # Test against None explicitly: a zero-offset log scale (0.j) is falsy, so
  # `kwds.get('pscale') or ...` would silently discard it.
  pscale = kwds.get('pscale', None)
  if pscale is None:
    pscale = prod_pscale(pscales)
  aresingleton = [arg.issingleton for arg in args]
  maybe_fasttrack = all(aresingleton) and \
                    np.all(pscale == np.array(pscales)) and \
                    pscale in [0, 1.]

  # Collate vals, probs, marg_names, and cond_names as lists
  vals = [collections.OrderedDict(arg) for arg in args]
  probs = [arg.prob for arg in args]
  marg_names = [list(arg.marg.values()) for arg in args]
  cond_names = [list(arg.cond.values()) for arg in args]

  # Detect uniqueness in marginal keys and identical conditionals
  all_marg_keys = []
  for arg in args:
    all_marg_keys.extend(list(arg.marg.keys()))
  marg_sets = None
  if len(all_marg_keys) != len(set(all_marg_keys)):
    # Repeated marginal keys are only accepted if every distribution has the
    # same marginal and conditional keys and the same pattern of unit-set
    # integer values among its marginals
    marg_keys, cond_keys, marg_sets, = None, None, None
    for arg in args:
      if marg_keys is None:
        marg_keys = list(arg.marg.keys())
      elif marg_keys != list(arg.marg.keys()):
        marg_keys = None
        break
      if cond_keys is None:
        cond_keys = list(arg.cond.keys())
      elif cond_keys != list(arg.cond.keys()):
        marg_keys = None
        break
      if marg_keys:
        are_marg_sets = np.array([isunitsetint(arg[marg_key]) for
                                  marg_key in marg_keys])
        if marg_sets is None:
          if np.any(are_marg_sets):
            marg_sets = are_marg_sets
          else:
            marg_keys = None
            break
        elif not np.all(marg_sets == are_marg_sets):
          marg_keys = None
          break
    assert marg_keys is not None and marg_sets is not None, \
      "Non-unique marginal variables for currently not supported: {}".\
      format(all_marg_keys)
    maybe_fasttrack = True

  # Maybe fast-track identical conditionals
  if maybe_fasttrack:
    marg_same = True
    cond_same = True
    if marg_sets is None: # no need to recheck if not None (I think)
      for name in marg_names[1:]:
        if marg_names[0] != name:
          marg_same = False
          break
      cond_same = not any(cond_names)
      if not cond_same:
        cond_same = True
        for name in cond_names[1:]:
          if cond_names[0] != name:
            cond_same = False
            break
    if marg_same and cond_same:
      marg_names = marg_names[0]
      cond_names = cond_names[0]
      prod_marg_name = ','.join(marg_names)
      prod_cond_name = ','.join(cond_names)
      # NOTE(review): unlike the general path further below, this leaves a
      # trailing '|' when there are no conditionals - confirm PD accepts it.
      prod_name = '|'.join([prod_marg_name, prod_cond_name])
      prod_vals = collections.OrderedDict()
      for val in vals:
        areunitsetints = np.array([isunitsetint(_val)
                                   for _val in val.values()])
        if not np.any(areunitsetints):
          prod_vals.update(val)
        else:
          assert marg_sets is not None, "Variable mismatch"
          assert np.all(marg_sets == areunitsetints[:len(marg_sets)]), \
              "Variable mismatch"
          if not len(prod_vals):
            prod_vals.update(collections.OrderedDict(val))
          else:
            # Unit-set integers are accumulated by summing their elements
            for j, key in enumerate(prod_vals.keys()):
              if areunitsetints[j]:
                prod_vals.update({key: {list(prod_vals[key])[0] + \
                                        list(val[key])[0]}})
      if marg_sets is not None:
        prob, pscale = prod_rule(*tuple(probs), pscales=pscales, pscale=pscale)
        return PD(prod_name, prod_vals, dims=args[0].dims, prob=prob, pscale=pscale)
      else:
        # Scalar probabilities: sum for log scales, multiply otherwise
        prod_prob = float(sum(probs)) if iscomplex(pscale) else float(np.prod(probs))
        return PD(prod_name, prod_vals, prob=prod_prob, pscale=pscale)

  # Check cond->marg accounts for all differences between conditionals
  prod_marg = [name for dist_marg_names in marg_names \
                          for name in dist_marg_names]
  prod_marg_name = ','.join(prod_marg)
  flat_cond_names = [name for dist_cond_names in cond_names \
                          for name in dist_cond_names]
  cond2marg = [cond_name for cond_name in flat_cond_names \
                         if cond_name in prod_marg]
  prod_cond = [cond_name for cond_name in flat_cond_names \
                         if cond_name not in cond2marg]
  cond2marg_set = set(cond2marg)

  # Check conditionals compatible
  prod_cond_set = set(prod_cond)
  cond2marg_dict = {name: None for name in cond2marg}
  for i, arg in enumerate(args):
    cond_set = set(cond_names[i]) - cond2marg_set
    if cond_set:
      assert prod_cond_set == cond_set, \
          "Incompatible product conditional {} for conditional set {}: ".format(
              prod_cond_set, cond_set)
    # Values of conditionals that become marginals must agree across inputs
    for name in cond2marg:
      if name in arg.keys():
        values = arg[name]
        if not isscalar(values):
          values = np.ravel(values)
        if cond2marg_dict[name] is None:
          cond2marg_dict[name] = values
        elif not np.allclose(cond2marg_dict[name], values):
          raise ValueError("Mismatch in values for condition {}".format(name))

  # Establish product name, values, and dimensions
  prod_keys = str2key(prod_marg + prod_cond)
  prod_nkeys = len(prod_keys)
  prod_aresingleton = np.zeros(prod_nkeys, dtype=bool)
  prod_areunitsetints = np.zeros(prod_nkeys, dtype=bool)
  prod_cond_name = ','.join(prod_cond)
  prod_name = prod_marg_name if not len(prod_cond_name) \
              else '|'.join([prod_marg_name, prod_cond_name])
  prod_vals = collections.OrderedDict()
  for i, key in enumerate(prod_keys):
    # Take the values from the first distribution containing the key
    values = None
    for val in vals:
      if key in val.keys():
        values = val[key]
        prod_areunitsetints[i] = isunitsetint(val[key])
        if prod_areunitsetints[i]:
          values = {0}
        break
    assert values is not None, "Values for key {} not found".format(key)
    prod_aresingleton[i] = issingleton(values)
    prod_vals.update({key: values})
  if np.any(prod_areunitsetints):
    # Unit-set integers start at {0} (above) and are summed over all inputs
    for i, key in enumerate(prod_keys):
      if prod_areunitsetints[i]:
        for val in vals:
          if key in val:
            assert isunitsetint(val[key]), "Mismatch in variables {} vs {}".\
                format(prod_vals, val)
            prod_vals.update({key: {list(prod_vals[key])[0] + list(val[key])[0]}})
  prod_newdims = np.array(np.logical_not(prod_aresingleton))
  dims_shared = False
  for arg in args:
    argdims = [dim for dim in arg.dims.values() if dim is not None]
    if len(argdims) != len(set(argdims)):
      dims_shared = True

  # Shared dimensions limit product dimensionality
  if dims_shared:
    seen_keys = set()
    for i, key in enumerate(prod_keys):
      if prod_newdims[i] and key not in seen_keys:
        for arg in args:
          if key in arg.dims:
            dim = arg.dims[key]
            seen_keys.add(key)
            for argkey, argdim in arg.dims.items():
              seen_keys.add(argkey)
              if argkey != key and argdim is not None:
                if dim == argdim:
                  index = prod_keys.index(argkey)
                  prod_newdims[index] = False

  prod_cdims = np.cumsum(prod_newdims)
  prod_ndims = prod_cdims[-1]

  # Fast-track scalar products
  if maybe_fasttrack and prod_ndims == 0:
    prob = float(sum(probs)) if iscomplex(pscale) else float(np.prod(probs))
    return PD(prod_name, prod_vals, prob=prob, pscale=pscale)

  # Reshape values - they require no axes swapping
  ones_ndims = np.ones(prod_ndims, dtype=int)
  prod_shape = np.ones(prod_ndims, dtype=int)
  scalarset = set()
  prod_dims = collections.OrderedDict()
  for i, key in enumerate(prod_keys):
    if prod_aresingleton[i]:
      scalarset.add(key)
    else:
      values = prod_vals[key]
      re_shape = np.copy(ones_ndims)
      dim = prod_cdims[i]-1
      prod_dims.update({key: dim})
      re_shape[dim] = values.size
      prod_shape[dim] = values.size
      prod_vals.update({key: values.reshape(re_shape)})

  # Match probability axes and shapes with axes swapping then reshaping
  for i, arg in enumerate(args):
    prob = probs[i]
    if not isscalar(prob):
      # Map each input dimension to its dimension in the product
      dims = collections.OrderedDict()
      for key, val in arg.dims.items():
        if val is not None:
          dims.update({val: prod_dims[key]})
      old_dims = []
      new_dims = []
      for key, val in dims.items():
        if key not in old_dims:
          old_dims.append(key)
          new_dims.append(val)
      if len(old_dims) > 1 and not old_dims == new_dims:
        # Pad with trailing axes so every destination axis exists
        max_dims_inc = max(new_dims) + 1
        while prob.ndim < max_dims_inc:
          prob = np.expand_dims(prob, -1)
        prob = np.moveaxis(prob, old_dims, new_dims)
      re_shape = np.copy(ones_ndims)
      for dim in new_dims:
        re_shape[dim] = prod_shape[dim]
      probs[i] = prob.reshape(re_shape)

  # Multiply the probabilities and output the result as a distribution instance
  prob, pscale = prod_rule(*tuple(probs), pscales=pscales, pscale=pscale)

  return PD(prod_name, prod_vals, dims=prod_dims, prob=prob, pscale=pscale)