def _eval_iid(self, dist_name, vals, dims, prob, iid):
  """ Returns a distribution for vals/prob, handling any IID variables. """
  # No IID variables: package the evaluation directly as a distribution
  if not iid:
    return PD(dist_name, vals, dims=dims, prob=prob, pscale=self._pscale)

  # Find the highest dimension in use (None if every variable is scalar)
  highest_dim = max((dim for dim in dims.values() if dim is not None),
                    default=None)

  # If scalar, or prob already has the expected shape, take the product here
  if highest_dim is None or highest_dim == prob.ndim - 1:
    joint = PD(dist_name, vals, dims=dims, prob=prob, pscale=self._pscale)
    return joint.prod(iid)

  # Otherwise the user function performs the IID product itself: collapse
  # each IID variable to a unit set of its sample count with no dimension
  for key in iid:
    vals[key] = {len(vals[key])}
    dims[key] = None
  return PD(dist_name, vals, dims=dims, prob=prob, pscale=self._pscale)
def __call__(self, values=None):
  """ Returns a probability distribution for the quantities in values. """
  distribution_name = self.eval_dist_name(values)
  vals = self.evaluate(values)
  prob = self.eval_prob(vals)
  # A scalar evaluation carries no dimension; arrays occupy dimension 0
  if isscalar(vals[self.name]):
    dims = {self._name: None}
  else:
    dims = {self._name: 0}
  return PD(distribution_name, vals, dims=dims, prob=prob,
            pscale=self._pscale)
def step(self, *args, **kwds):
  """ Returns a proposal distribution p(args[1]) given args[0], depending on
  whether using self._prop, that denotes a simple proposal distribution,
  or self._tran, that denotes a transitional distribution. """
  # Optional keyword flag to evaluate the conditional in the reverse direction
  reverse = False if 'reverse' not in kwds else kwds.pop('reverse')

  # Disambiguate (predecessor, successor) from *args: either a single
  # two-element list/tuple, a lone predecessor, or two positional arguments
  pred_vals, succ_vals = None, None
  if len(args) == 1:
    if isinstance(args[0], (list, tuple)) and len(args[0]) == 2:
      pred_vals, succ_vals = args[0][0], args[0][1]
    else:
      pred_vals = args[0]
  elif len(args) == 2:
    pred_vals, succ_vals = args[0], args[1]

  # Evaluate predecessor values (a non-dict input is broadcast to every key)
  if not isinstance(pred_vals, dict):
    pred_vals = {key: pred_vals for key in self._keyset}
  pred_vals = self.parse_args(pred_vals, pass_all=True)
  dist_pred_name = self.eval_dist_name(pred_vals)
  pred_vals, pred_dims = self.evaluate(pred_vals)

  # Default successor values if None and delta is None
  if succ_vals is None and self._delta is None:
    pred_values = list(pred_vals.values())
    if all([isscalar(pred_value) for pred_value in pred_values]):
      # {0} is a unit set int — presumably requests a single sampled
      # successor value from eval_step; TODO confirm
      succ_vals = {0}
    else:
      succ_vals = pred_vals

  # Evaluate the step to obtain successor values, dimensions and kwargs
  vals, dims, kwargs = self.eval_step(pred_vals, succ_vals, reverse=reverse)

  # Successor keys carry a trailing prime; strip it to name the successor
  succ_vals = {key[:-1]: val for key, val in vals.items()
               if key[-1] == "'"}
  cond = self.eval_tran(vals, **kwargs)
  dist_succ_name = self.eval_dist_name(succ_vals, "'")
  dist_name = '|'.join([dist_succ_name, dist_pred_name])
  return PD(dist_name, vals, dims=dims, prob=cond, pscale=self._pscale)
def propose(self, *args, **kwds):
  """ Returns a proposal distribution p(args[0]) for values """
  # The successor suffix defaults to a prime but may be overridden
  suffix = kwds.pop('suffix') if 'suffix' in kwds else "'"

  # A single non-dict positional argument is broadcast to every key
  if len(args) == 1 and not kwds and not isinstance(args[0], dict):
    args = ({key: args[0] for key in self._keyset},)
  values = self.parse_args(*args, **kwds)
  dist_name = self.eval_dist_name(values, suffix)
  vals, dims = self.evaluate(values, _skip_parsing=True)

  # Use the proposal function if set, otherwise fall back to the probability
  if self._prop is not None:
    prop = self.eval_prop(vals)
  else:
    prop = self.eval_prob(vals, dims)

  # Append the suffix to every variable key (and its dimension entry)
  if suffix:
    for key in list(vals.keys()):
      suffixed = key + suffix
      vals[suffixed] = vals.pop(key)
      if key in dims:
        dims[suffixed] = dims.pop(key)
  return PD(dist_name, vals, dims=dims, prob=prop, pscale=self._pscale)
def reval_tran(self, dist):
  """ Evaluates the conditional reverse-transition function for the
  corresponding transition conditional distribution dist. This requires a
  tuple input for self.set_tran() to evaluate a new conditional.
  """
  assert isinstance(dist, PD), \
      "Input must be a distribution, not {} type.".format(type(dist))
  # Reversal: the input's conditionals become the output's marginals
  marg, cond = dist.cond, dist.marg
  name = margcond_str(marg, cond)
  vals = collections.OrderedDict(dist)
  # This next line needs to be modified to handle new API:
  # prob = dist.prob if self._sym_tran or self._tran is None \
  #                  else self._tran[1](dist)
  prob = dist.prob
  return PD(name, vals, dims=dist.dims, prob=prob, pscale=dist.pscale)
def step(self, *args, reverse=False):
  """ Returns a conditional probability distribution for quantities in args.

  :param *args: predecessor, successor values to evaluate conditionals.
  :param reverse: Boolean flag to evaluate conditional probability in reverse.

  :return a Dist instance of the conditional probability distribution
  """
  # Disambiguate (predecessor, successor) from *args: either a single
  # two-element list/tuple, a lone predecessor, or two positional arguments
  pred_vals, succ_vals = None, None
  if len(args) == 1:
    if isinstance(args[0], (list, tuple)) and len(args[0]) == 2:
      pred_vals, succ_vals = args[0][0], args[0][1]
    else:
      pred_vals = args[0]
  elif len(args) == 2:
    pred_vals, succ_vals = args[0], args[1]
  dist_pred_name = self.eval_dist_name(pred_vals)
  dist_succ_name = None
  # NOTE(review): succ_vals is necessarily None inside this branch, so the
  # successor is named from default values for non-float vtypes — confirm
  if pred_vals is None and succ_vals is None and \
      self._vtype not in VTYPES[float]:
    dist_succ_name = self.eval_dist_name(succ_vals, "'")
  pred_vals = self.evaluate(pred_vals)

  # Evaluate the step to obtain successor values, dimensions and kwargs
  vals, dims, kwargs = self.eval_step(pred_vals, succ_vals, reverse=reverse)
  cond = self.eval_tran(vals, **kwargs)
  if dist_succ_name is None:
    dist_succ_name = self.eval_dist_name(vals[self.__prime_key], "'")
  dist_name = '|'.join([dist_succ_name, dist_pred_name])

  # TODO - distinguish between probabilistic and non-probabistic outputs
  # A None conditional is broadcast as unity over any array-valued variables
  if cond is None:
    cond = 1.
    for val in vals.values():
      if isinstance(val, np.ndarray):
        cond = cond * np.ones(val.shape, dtype=float)
  return PD(dist_name, vals, dims=dims, prob=cond, pscale=self._pscale)
def product(*args, **kwds):
  """ Multiplies two or more PDs subject to the following:
  1. They must not share the same marginal variables.
  2. Conditional variables must be identical unless contained as marginal from
     another distribution.
  """
  from probayes.pd import PD

  # Check pscales, scalars, possible fasttrack
  if not len(args):
    return None
  kwds = dict(kwds)
  pscales = [arg.pscale for arg in args]
  pscale = kwds.get('pscale', None) or prod_pscale(pscales)
  aresingleton = [arg.issingleton for arg in args]
  # Fast-tracking requires all-singleton inputs with a common linear/log pscale
  maybe_fasttrack = all(aresingleton) and \
                    np.all(pscale == np.array(pscales)) and \
                    pscale in [0, 1.]

  # Collate vals, probs, marg_names, and cond_names as lists
  vals = [collections.OrderedDict(arg) for arg in args]
  probs = [arg.prob for arg in args]
  marg_names = [list(arg.marg.values()) for arg in args]
  cond_names = [list(arg.cond.values()) for arg in args]

  # Detect uniqueness in marginal keys and identical conditionals
  all_marg_keys = []
  for arg in args:
    all_marg_keys.extend(list(arg.marg.keys()))
  marg_sets = None
  if len(all_marg_keys) != len(set(all_marg_keys)):
    # Shared marginals are only supported when every arg has identical
    # marginal/conditional keys and matching unit-set-int patterns
    marg_keys, cond_keys, marg_sets, = None, None, None
    for arg in args:
      if marg_keys is None:
        marg_keys = list(arg.marg.keys())
      elif marg_keys != list(arg.marg.keys()):
        marg_keys = None
        break
      if cond_keys is None:
        cond_keys = list(arg.cond.keys())
      elif cond_keys != list(arg.cond.keys()):
        marg_keys = None
        break
      if marg_keys:
        are_marg_sets = np.array([isunitsetint(arg[marg_key])
                                  for marg_key in marg_keys])
        if marg_sets is None:
          if np.any(are_marg_sets):
            marg_sets = are_marg_sets
          else:
            marg_keys = None
            break
        elif not np.all(marg_sets == are_marg_sets):
          marg_keys = None
          break
    assert marg_keys is not None and marg_sets is not None, \
        "Non-unique marginal variables for currently not supported: {}".\
        format(all_marg_keys)
    maybe_fasttrack = True

  # Maybe fast-track identical conditionals
  if maybe_fasttrack:
    marg_same = True
    cond_same = True
    if marg_sets is None: # no need to recheck if not None (I think)
      marg_same = True
      for name in marg_names[1:]:
        if marg_names[0] != name:
          marg_same = False
          break
    # Conditionals match if all empty, or all identical to the first
    cond_same = not any(cond_names)
    if not cond_same:
      cond_same = True
      for name in cond_names[1:]:
        if cond_names[0] != name:
          cond_same = False
          break
    if marg_same and cond_same:
      marg_names = marg_names[0]
      cond_names = cond_names[0]
      prod_marg_name = ','.join(marg_names)
      prod_cond_name = ','.join(cond_names)
      prod_name = '|'.join([prod_marg_name, prod_cond_name])
      prod_vals = collections.OrderedDict()
      for i, val in enumerate(vals):
        areunitsetints = np.array([isunitsetint(_val)
                                   for _val in val.values()])
        if not np.any(areunitsetints):
          prod_vals.update(val)
        else:
          assert marg_sets is not None, "Variable mismatch"
          assert np.all(marg_sets == areunitsetints[:len(marg_sets)]), \
              "Variable mismatch"
          if not len(prod_vals):
            prod_vals.update(collections.OrderedDict(val))
          else:
            # Accumulate unit-set-int counts across the distributions
            for j, key in enumerate(prod_vals.keys()):
              if areunitsetints[j]:
                prod_vals.update({key: {list(prod_vals[key])[0] + \
                                        list(val[key])[0]}})
      if marg_sets is not None:
        prob, pscale = prod_rule(*tuple(probs), pscales=pscales, pscale=pscale)
        return PD(prod_name, prod_vals, dims=args[0].dims,
                  prob=prob, pscale=pscale)
      else:
        # Singleton probabilities sum in log-space, multiply in linear space
        prod_prob = float(sum(probs)) if iscomplex(pscale) else \
                    float(np.prod(probs))
        return PD(prod_name, prod_vals, prob=prod_prob, pscale=pscale)

  # Check cond->marg accounts for all differences between conditionals
  prod_marg = [name for dist_marg_names in marg_names \
               for name in dist_marg_names]
  prod_marg_name = ','.join(prod_marg)
  flat_cond_names = [name for dist_cond_names in cond_names \
                     for name in dist_cond_names]
  # Conditionals that appear as another distribution's marginal are absorbed
  cond2marg = [cond_name for cond_name in flat_cond_names \
               if cond_name in prod_marg]
  prod_cond = [cond_name for cond_name in flat_cond_names \
               if cond_name not in cond2marg]
  cond2marg_set = set(cond2marg)

  # Check conditionals compatible
  prod_cond_set = set(prod_cond)
  cond2marg_dict = {name: None for name in cond2marg}
  for i, arg in enumerate(args):
    cond_set = set(cond_names[i]) - cond2marg_set
    if cond_set:
      assert prod_cond_set == cond_set, \
          "Incompatible product conditional {} for conditional set {}: ".format(
              prod_cond_set, cond_set)
    # Absorbed conditionals must hold identical values everywhere they appear
    for name in cond2marg:
      if name in arg.keys():
        values = arg[name]
        if not isscalar(values):
          values = np.ravel(values)
        if cond2marg_dict[name] is None:
          cond2marg_dict[name] = values
        elif not np.allclose(cond2marg_dict[name], values):
          raise ValueError("Mismatch in values for condition {}".format(name))

  # Establish product name, values, and dimensions
  prod_keys = str2key(prod_marg + prod_cond)
  prod_nkeys = len(prod_keys)
  prod_aresingleton = np.zeros(prod_nkeys, dtype=bool)
  prod_areunitsetints = np.zeros(prod_nkeys, dtype=bool)
  prod_cond_name = ','.join(prod_cond)
  prod_name = prod_marg_name if not len(prod_cond_name) \
              else '|'.join([prod_marg_name, prod_cond_name])
  prod_vals = collections.OrderedDict()
  for i, key in enumerate(prod_keys):
    values = None
    for val in vals:
      if key in val.keys():
        values = val[key]
        prod_areunitsetints[i] = isunitsetint(val[key])
        if prod_areunitsetints[i]:
          # Start the count at zero; accumulated in the pass below
          values = {0}
        break
    assert values is not None, "Values for key {} not found".format(key)
    prod_aresingleton[i] = issingleton(values)
    prod_vals.update({key: values})
  if np.any(prod_areunitsetints):
    # Sum the unit-set-int counts for each key over all distributions
    for i, key in enumerate(prod_keys):
      if prod_areunitsetints[i]:
        for val in vals:
          if key in val:
            assert isunitsetint(val[key]), "Mismatch in variables {} vs {}".\
                format(prod_vals, val)
            prod_vals.update({key: {list(prod_vals[key])[0] + \
                                    list(val[key])[0]}})
  prod_newdims = np.array(np.logical_not(prod_aresingleton))
  dims_shared = False
  for arg in args:
    argdims = [dim for dim in arg.dims.values() if dim is not None]
    if len(argdims) != len(set(argdims)):
      dims_shared = True

  # Shared dimensions limit product dimensionality
  if dims_shared:
    seen_keys = set()
    for i, key in enumerate(prod_keys):
      if prod_newdims[i] and key not in seen_keys:
        for arg in args:
          if key in arg.dims:
            dim = arg.dims[key]
            seen_keys.add(key)
            # Any other key sharing this dimension is not a new dimension
            for argkey, argdim in arg.dims.items():
              seen_keys.add(argkey)
              if argkey != key and argdim is not None:
                if dim == argdim:
                  index = prod_keys.index(argkey)
                  prod_newdims[index] = False
  prod_cdims = np.cumsum(prod_newdims)
  prod_ndims = prod_cdims[-1]

  # Fast-track scalar products
  if maybe_fasttrack and prod_ndims == 0:
    prob = float(sum(probs)) if iscomplex(pscale) else float(np.prod(probs))
    return PD(prod_name, prod_vals, prob=prob, pscale=pscale)

  # Reshape values - they require no axes swapping
  ones_ndims = np.ones(prod_ndims, dtype=int)
  prod_shape = np.ones(prod_ndims, dtype=int)
  scalarset = set()
  prod_dims = collections.OrderedDict()
  for i, key in enumerate(prod_keys):
    if prod_aresingleton[i]:
      scalarset.add(key)
    else:
      # Each non-singleton key occupies its own broadcastable axis
      values = prod_vals[key]
      re_shape = np.copy(ones_ndims)
      dim = prod_cdims[i]-1
      prod_dims.update({key: dim})
      re_shape[dim] = values.size
      prod_shape[dim] = values.size
      prod_vals.update({key: values.reshape(re_shape)})

  # Match probability axes and shapes with axes swapping then reshaping
  for i in range(len(args)):
    prob = probs[i]
    if not isscalar(prob):
      # Map each source axis to its axis in the product dimensionality
      dims = collections.OrderedDict()
      for key, val in args[i].dims.items():
        if val is not None:
          dims.update({val: prod_dims[key]})
      old_dims = []
      new_dims = []
      for key, val in dims.items():
        if key not in old_dims:
          old_dims.append(key)
          new_dims.append(val)
      if len(old_dims) > 1 and not old_dims == new_dims:
        max_dims_inc = max(new_dims) + 1
        while prob.ndim < max_dims_inc:
          prob = np.expand_dims(prob, -1)
        prob = np.moveaxis(prob, old_dims, new_dims)
      re_shape = np.copy(ones_ndims)
      for dim in new_dims:
        re_shape[dim] = prod_shape[dim]
      probs[i] = prob.reshape(re_shape)

  # Multiply the probabilities and output the result as a distribution instance
  prob, pscale = prod_rule(*tuple(probs), pscales=pscales, pscale=pscale)
  return PD(prod_name, prod_vals, dims=prod_dims, prob=prob, pscale=pscale)
def summate(*args):
  """ Quick and dirty concatenation """
  from probayes.pd import PD
  if not len(args):
    return None
  pscales = [arg.pscale for arg in args]
  vals = [dict(arg) for arg in args]
  probs = [arg.prob for arg in args]

  # Check pscales are the same
  pscale = pscales[0]
  for _pscale in pscales[1:]:
    assert pscale == _pscale, \
        "Cannot summate distributions with different pscales"

  # Check marginal and conditional keys
  marg_keys = list(args[0].marg.keys())
  cond_keys = list(args[0].cond.keys())
  for arg in args[1:]:
    assert marg_keys == list(arg.marg.keys()), \
        "Marginal variable names not identical across distributions: {}"
    assert cond_keys == list(arg.cond.keys()), \
        "Conditional variable names not identical across distributions: {}"
  sum_keys = marg_keys + cond_keys
  sum_name = ','.join(marg_keys)
  if cond_keys:
    sum_name += '|' + ','.join(cond_keys)

  # If all singleton, concatenate in dimension 0
  if all([arg.issingleton for arg in args]):
    # Unit-set-int values accumulate as counts; others collect into arrays
    unitsets = {key: isunitsetint(args[0][key]) for key in sum_keys}
    sum_dims = {key: None if unitsets[key] else 0 for key in sum_keys}
    sum_vals = {key: 0 if unitsets[key] else [] for key in sum_keys}
    sum_prob = []
    for arg in args:
      for key, val in arg.items():
        if unitsets[key]:
          assert isunitsetint(val), \
              "Cannot mix unspecified set and specified values"
          sum_vals[key] += list(val)[0]
        else:
          assert not isunitsetint(val), \
              "Cannot mix unspecified set and specified values"
          sum_vals[key].append(val)
      sum_prob.append(arg.prob)
    for key in sum_keys:
      if unitsets[key]:
        # Re-wrap the accumulated count as a unit set int
        sum_vals[key] = {sum_vals[key]}
      else:
        sum_vals[key] = np.ravel(sum_vals[key])
    sum_prob = np.ravel(sum_prob)
    return PD(sum_name, sum_vals, dims=sum_dims, prob=sum_prob, pscale=pscale)

  # 2. all identical but in one dimension: concatenate in that dimension
  # TODO: fix the remaining code of this function below
  sum_vals = collections.OrderedDict(args[0])
  # One candidate concatenation axis per distribution after the first
  sum_dims = [None] * (len(args) - 1)
  for i, arg in enumerate(args):
    if i == 0:
      continue
    for key in marg_keys:
      if sum_dims[i-1] is not None:
        continue
      elif not arg.singleton(key):
        key_vals = arg[key]
        if key_vals.size == sum_vals[key].size:
          if np.allclose(key_vals, sum_vals[key]):
            continue
        sum_dims[i-1] = arg.dims[key]
  # NOTE(review): this condition looks inverted — a single common axis
  # yields a one-element set, which fails here; confirm intended behavior
  assert len(set(sum_dims)) > 1, "Cannot find unique concatenation axis"
  sum_dim = sum_dims[0]
  sum_dims = args[0].dims
  # NOTE(review): indexes marg_keys by a dimension number — verify mapping
  key = marg_keys[sum_dim]
  sum_prob = np.copy(probs[0])
  for i, val in enumerate(vals):
    if i == 0:
      continue
    sum_vals[key] = np.concatenate([sum_vals[key], val[key]], axis=sum_dim)
    sum_prob = np.concatenate([sum_prob, probs[i]], axis=sum_dim)
  return PD(sum_name, sum_vals, dims=sum_dims, prob=sum_prob, pscale=pscale)