def eager_integrate(delta, integrand, reduced_vars):
    if not reduced_vars & delta.fresh:
        return None
    subs = tuple((name, point) for name, (point, log_density) in delta.terms
                 if name in reduced_vars)
    new_integrand = Subs(integrand, subs)
    new_log_measure = Subs(delta, subs)
    result = Integrate(new_log_measure, new_integrand, reduced_vars - delta.fresh)
    return result
def adjoint_subs_gaussianmixture_discrete(adj_redop, adj_binop, out_adj, arg, subs):

    if any(v.dtype == 'real' and not isinstance(v, Variable) for k, v in subs):
        raise NotImplementedError(
            "TODO implement adjoint for substitution into Gaussian real variable")

    # invert renaming
    renames = tuple((v.name, k) for k, v in subs if isinstance(v, Variable))
    out_adj = Subs(out_adj, renames)

    # inverting advanced indexing
    slices = tuple((k, v) for k, v in subs if not isinstance(v, Variable))

    arg_int_inputs = OrderedDict(
        (k, v) for k, v in arg.inputs.items() if v.dtype != 'real')

    zeros_like_out = Subs(
        Tensor(
            arg.terms[1].info_vec.new_full(
                arg.terms[1].info_vec.shape[:-1], ops.UNITS[adj_binop]),
            arg_int_inputs),
        slices)
    out_adj = adj_binop(out_adj, zeros_like_out)

    in_adj_discrete = adjoint_ops(Subs, adj_redop, adj_binop,
                                  out_adj, arg.terms[0], subs)[arg.terms[0]]

    # invert the slicing for the Gaussian term even though the message does not affect the values
    in_adj_info_vec = list(
        adjoint_ops(Subs, adj_redop, adj_binop,  # ops.add, ops.mul,
                    zeros_like_out,
                    Tensor(arg.terms[1].info_vec, arg_int_inputs),
                    slices).values())[0]

    in_adj_precision = list(
        adjoint_ops(Subs, adj_redop, adj_binop,  # ops.add, ops.mul,
                    zeros_like_out,
                    Tensor(arg.terms[1].precision, arg_int_inputs),
                    slices).values())[0]

    assert isinstance(in_adj_info_vec, Tensor)
    assert isinstance(in_adj_precision, Tensor)

    in_adj_gaussian = Gaussian(
        in_adj_info_vec.data, in_adj_precision.data, arg.inputs.copy())

    in_adj = in_adj_gaussian + in_adj_discrete
    return {arg: in_adj}
def eager_add(op, joint, delta):
    # Update with a degenerate distribution, typically a monte carlo sample.
    if delta.name in joint.inputs:
        joint = Subs(joint, ((delta.name, delta.point),))
        if not isinstance(joint, Joint):
            return joint + delta
    for d in joint.deltas:
        if d.name in delta.inputs:
            delta = Subs(delta, ((d.name, d.point),))
    deltas = joint.deltas + (delta,)
    return Joint(deltas, joint.discrete, joint.gaussian)
def _eager_subs_int(self, subs, remaining_subs):
    # Perform integer substitution, i.e. slicing into a batch.
    int_inputs = OrderedDict(
        (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
    real_inputs = OrderedDict(
        (k, d) for k, d in self.inputs.items() if d.dtype == 'real')
    tensors = [self.info_vec, self.precision]
    funsors = [Subs(Tensor(x, int_inputs), subs) for x in tensors]
    inputs = funsors[0].inputs.copy()
    inputs.update(real_inputs)
    int_result = Gaussian(funsors[0].data, funsors[1].data, inputs)
    return Subs(int_result, remaining_subs) if remaining_subs else int_result
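# A minimal sketch (plain torch, hypothetical shapes) of what the integer
# substitution above amounts to: slicing into a batched Gaussian just indexes
# the batch dimensions of its info_vec and precision parameters, while the
# real (event) dimensions are left untouched.
def _sketch_int_subs():
    import torch
    info_vec = torch.randn(4, 3)        # batch of 4 Gaussians over a 3-dim real input
    precision = torch.randn(4, 3, 3)
    i = 2                               # substitute the bint(4) input with the constant 2
    info_i, prec_i = info_vec[i], precision[i]
    assert info_i.shape == (3,) and prec_i.shape == (3, 3)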
def eager_add(op, delta, other):
    if delta.name in other.inputs:
        other = Subs(other, ((delta.name, delta.point),))
        assert isinstance(other, (Number, Tensor, Gaussian))
    if isinstance(other, (Number, Tensor)):
        return Joint((delta,), discrete=other)
    else:
        return Joint((delta,), gaussian=other)
def distribute_subs_contraction(arg, subs):
    new_terms = tuple(
        Subs(v, tuple((name, sub) for name, sub in subs if name in v.inputs))
        if any(name in v.inputs for name, sub in subs)
        else v
        for v in arg.terms)
    return Contraction(arg.red_op, arg.bin_op, arg.reduced_vars, *new_terms)
def _eager_subs_var(self, subs, remaining_subs):
    # Perform variable substitution, i.e. renaming of inputs.
    rename = {k: v.name for k, v in subs}
    inputs = OrderedDict(
        (rename.get(k, k), d) for k, d in self.inputs.items())
    if len(inputs) != len(self.inputs):
        raise ValueError("Variable substitution name conflict")
    var_result = Gaussian(self.info_vec, self.precision, inputs)
    return Subs(var_result, remaining_subs) if remaining_subs else var_result
def test_affine_subs():
    # This was recorded from test_pyro_convert.
    x = Subs(
        Gaussian(
            torch.tensor(
                [1.3027106523513794, 1.4167094230651855, -0.9750942587852478,
                 0.5321089029312134, -0.9039931297302246],
                dtype=torch.float32),  # noqa
            torch.tensor(
                [[1.0199567079544067, 0.9840421676635742, -0.473368763923645, 0.34206756949424744, -0.7562517523765564],
                 [0.9840421676635742, 1.511502742767334, -1.7593903541564941, 0.6647964119911194, -0.5119513273239136],
                 [-0.4733688533306122, -1.7593903541564941, 3.2386727333068848, -0.9345928430557251, -0.1534711718559265],
                 [0.34206756949424744, 0.6647964119911194, -0.9345928430557251, 0.3141004145145416, -0.12399007380008698],
                 [-0.7562517523765564, -0.5119513273239136, -0.1534711718559265, -0.12399007380008698, 0.6450173854827881]],
                dtype=torch.float32),  # noqa
            (('state_1_b6', reals(3,),),
             ('obs_b2', reals(2,),),)),
        (('obs_b2',
          Contraction(
              ops.nullop,
              ops.add,
              frozenset(),
              (Variable('bias_b5', reals(2,)),
               Tensor(
                   torch.tensor([-2.1787893772125244, 0.5684312582015991],
                                dtype=torch.float32),  # noqa
                   (),
                   'real'),)),),))
    assert isinstance(x, (Gaussian, Contraction)), x.pretty()
def eager_add(op, joint, other):
    # Update with a delayed gaussian random variable.
    subs = tuple(
        (d.name, d.point) for d in joint.deltas if d.name in other.inputs)
    if subs:
        other = Subs(other, subs)
    if joint.gaussian is not Number(0):
        other = joint.gaussian + other
    if not isinstance(other, Gaussian):
        return Joint(joint.deltas, joint.discrete) + other
    return Joint(joint.deltas, joint.discrete, other)
def _simplify_integrate(fn, joint, integrand, reduced_vars):
    if any(d.name in reduced_vars for d in joint.deltas):
        subs = tuple(
            (d.name, d.point) for d in joint.deltas if d.name in reduced_vars)
        deltas = tuple(d for d in joint.deltas if d.name not in reduced_vars)
        log_measure = Joint(deltas, joint.discrete, joint.gaussian)
        integrand = Subs(integrand, subs)
        reduced_vars = reduced_vars - frozenset(name for name, point in subs)
        return Integrate(log_measure, integrand, reduced_vars)
    return fn(joint, integrand, reduced_vars)
def eager_markov_product(sum_op, prod_op, trans, time, step, step_names):
    if step:
        result = sequential_sum_product(sum_op, prod_op, trans, time, dict(step))
    elif time.name in trans.inputs:
        result = trans.reduce(prod_op, time.name)
    elif prod_op is ops.add:
        result = trans * time.size
    elif prod_op is ops.mul:
        result = trans ** time.size
    else:
        raise NotImplementedError('https://github.com/pyro-ppl/funsor/issues/233')
    return Subs(result, step_names)
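# A quick numeric check (plain Python) of the two constant-transition branches
# above: when ``trans`` does not depend on ``time``, reducing it over T time
# steps collapses to repeated application of prod_op, i.e. trans * T under
# ops.add and trans ** T under ops.mul.
def _sketch_constant_markov_product():
    import functools
    import operator
    trans, T = 0.5, 4
    assert sum(trans for _ in range(T)) == trans * T
    assert functools.reduce(operator.mul, [trans] * T) == trans ** T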
def eager_subs(self, subs):
    assert isinstance(subs, tuple)
    # Eagerly rename variables.
    rename = {k: v.name for k, v in subs if isinstance(v, Variable)}
    if not rename:
        return None
    step_names = frozenset(
        (k, rename.get(v, v)) for k, v in self.step_names.items())
    result = MarkovProduct(self.sum_op, self.prod_op, self.trans,
                           self.time, self.step, step_names)
    lazy = tuple((k, v) for k, v in subs if not isinstance(v, Variable))
    if lazy:
        result = Subs(result, lazy)
    return result
def adjoint_subs_tensor(adj_redop, adj_binop, out_adj, arg, subs):

    assert all(isinstance(v, Funsor) for k, v in subs)

    # invert renaming
    renames = tuple((v.name, k) for k, v in subs if isinstance(v, Variable))
    out_adj = Subs(out_adj, renames)

    # inverting advanced indexing
    slices = tuple((k, v) for k, v in subs if not isinstance(v, Variable))

    # TODO avoid reifying these zero/one tensors by using symbolic constants

    # ones for things that weren't sliced away
    ones_like_out = Subs(
        Tensor(torch.full_like(arg.data, ops.UNITS[adj_binop]),
               arg.inputs.copy(), arg.output.dtype),
        slices)
    arg_adj = adj_binop(out_adj, ones_like_out)

    # ones for things that were sliced away
    ones_like_arg = Tensor(torch.full_like(arg.data, ops.UNITS[adj_binop]),
                           arg.inputs.copy(), arg.output.dtype)
    arg_adj = _scatter(arg_adj, ones_like_arg, slices)

    return {arg: arg_adj}
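# A minimal sketch (plain torch) of the "inverting advanced indexing" step
# above: the adjoint of a gather ``out = x[idx]`` scatter-adds the output
# adjoint back into the slots of ``x`` that were read, leaving unread slots
# at the unit of the accumulation op (zero for ops.add).
def _sketch_scatter_adjoint():
    import torch
    x = torch.randn(5)
    idx = torch.tensor([0, 2, 2])
    out_adj = torch.tensor([1.0, 10.0, 100.0])   # adjoint of x[idx]
    x_adj = torch.zeros_like(x)
    x_adj.index_add_(0, idx, out_adj)            # duplicate reads accumulate
    assert torch.allclose(x_adj, torch.tensor([1.0, 0.0, 110.0, 0.0, 0.0]))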
def adjoint_subs_gaussianmixture_gaussianmixture(adj_redop, adj_binop, out_adj, arg, subs):

    if any(v.dtype == 'real' and not isinstance(v, Variable) for k, v in subs):
        raise NotImplementedError(
            "TODO implement adjoint for substitution into Gaussian real variable")

    # invert renaming
    renames = tuple((v.name, k) for k, v in subs if isinstance(v, Variable))
    out_adj = Subs(out_adj, renames)

    # inverting advanced indexing
    slices = tuple((k, v) for k, v in subs if not isinstance(v, Variable))

    assert len(slices + renames) == len(subs)

    in_adj_discrete = adjoint_ops(Subs, adj_redop, adj_binop,
                                  out_adj.terms[0], arg.terms[0], subs)[arg.terms[0]]

    arg_int_inputs = OrderedDict(
        (k, v) for k, v in arg.inputs.items() if v.dtype != 'real')
    out_adj_int_inputs = OrderedDict(
        (k, v) for k, v in out_adj.inputs.items() if v.dtype != 'real')

    arg_real_inputs = OrderedDict(
        (k, v) for k, v in arg.inputs.items() if v.dtype == 'real')

    align_inputs = OrderedDict(
        (k, v) for k, v in out_adj.terms[1].inputs.items() if v.dtype != 'real')
    align_inputs.update(arg_real_inputs)
    out_adj_info_vec, out_adj_precision = align_gaussian(align_inputs, out_adj.terms[1])

    in_adj_info_vec = list(
        adjoint_ops(Subs, adj_redop, adj_binop,  # ops.add, ops.mul,
                    Tensor(out_adj_info_vec, out_adj_int_inputs),
                    Tensor(arg.terms[1].info_vec, arg_int_inputs),
                    slices).values())[0]

    in_adj_precision = list(
        adjoint_ops(Subs, adj_redop, adj_binop,  # ops.add, ops.mul,
                    Tensor(out_adj_precision, out_adj_int_inputs),
                    Tensor(arg.terms[1].precision, arg_int_inputs),
                    slices).values())[0]

    assert isinstance(in_adj_info_vec, Tensor)
    assert isinstance(in_adj_precision, Tensor)

    in_adj_gaussian = Gaussian(
        in_adj_info_vec.data, in_adj_precision.data, arg.inputs.copy())

    in_adj = in_adj_gaussian + in_adj_discrete
    return {arg: in_adj}
def _eager_subs_affine(self, subs, remaining_subs):
    # Extract an affine representation.
    affine = OrderedDict()
    for k, v in subs:
        const, coeffs = extract_affine(v)
        if (isinstance(const, Tensor) and
                all(isinstance(coeff, Tensor) for coeff, _ in coeffs.values())):
            affine[k] = const, coeffs
        else:
            remaining_subs += (k, v),
    if not affine:
        return reflect(Subs, self, remaining_subs)

    # Align integer dimensions.
    old_int_inputs = OrderedDict(
        (k, v) for k, v in self.inputs.items() if v.dtype != 'real')
    tensors = [Tensor(self.info_vec, old_int_inputs),
               Tensor(self.precision, old_int_inputs)]
    for const, coeffs in affine.values():
        tensors.append(const)
        tensors.extend(coeff for coeff, _ in coeffs.values())
    new_int_inputs, tensors = align_tensors(*tensors, expand=True)
    tensors = (Tensor(x, new_int_inputs) for x in tensors)
    old_info_vec = next(tensors).data
    old_precision = next(tensors).data
    for old_k, (const, coeffs) in affine.items():
        const = next(tensors)
        for new_k, (coeff, eqn) in coeffs.items():
            coeff = next(tensors)
            coeffs[new_k] = coeff, eqn
        affine[old_k] = const, coeffs
    batch_shape = old_info_vec.shape[:-1]

    # Align real dimensions.
    old_real_inputs = OrderedDict(
        (k, v) for k, v in self.inputs.items() if v.dtype == 'real')
    new_real_inputs = old_real_inputs.copy()
    for old_k, (const, coeffs) in affine.items():
        del new_real_inputs[old_k]
        for new_k, (coeff, eqn) in coeffs.items():
            new_shape = coeff.shape[:len(eqn.split('->')[0].split(',')[1])]
            new_real_inputs[new_k] = Reals[new_shape]
    old_offsets, old_dim = _compute_offsets(old_real_inputs)
    new_offsets, new_dim = _compute_offsets(new_real_inputs)
    new_inputs = new_int_inputs.copy()
    new_inputs.update(new_real_inputs)

    # Construct a blockwise affine representation of the substitution.
    subs_vector = BlockVector(batch_shape + (old_dim,))
    subs_matrix = BlockMatrix(batch_shape + (new_dim, old_dim))
    for old_k, old_offset in old_offsets.items():
        old_size = old_real_inputs[old_k].num_elements
        old_slice = slice(old_offset, old_offset + old_size)
        if old_k in new_real_inputs:
            new_offset = new_offsets[old_k]
            new_slice = slice(new_offset, new_offset + old_size)
            subs_matrix[..., new_slice, old_slice] = \
                ops.new_eye(self.info_vec, batch_shape + (old_size,))
            continue
        const, coeffs = affine[old_k]
        old_shape = old_real_inputs[old_k].shape
        assert const.data.shape == batch_shape + old_shape
        subs_vector[..., old_slice] = const.data.reshape(batch_shape + (old_size,))
        for new_k, new_offset in new_offsets.items():
            if new_k in coeffs:
                coeff, eqn = coeffs[new_k]
                new_size = new_real_inputs[new_k].num_elements
                new_slice = slice(new_offset, new_offset + new_size)
                assert coeff.shape == new_real_inputs[new_k].shape + old_shape
                subs_matrix[..., new_slice, old_slice] = \
                    coeff.data.reshape(batch_shape + (new_size, old_size))
    subs_vector = subs_vector.as_tensor()
    subs_matrix = subs_matrix.as_tensor()
    subs_matrix_t = ops.transpose(subs_matrix, -1, -2)

    # Construct the new funsor. Suppose the old Gaussian funsor g has density
    #   g(x) = < x | i - 1/2 P x >
    # Now define a new funsor f by substituting x = A y + B:
    #   f(y) = g(A y + B)
    #        = < A y + B | i - 1/2 P (A y + B) >
    #        = < y | At (i - P B) - 1/2 At P A y > + < B | i - 1/2 P B >
    #       =: < y | i' - 1/2 P' y > + C
    # where P' = At P A and i' = At (i - P B) parametrize a new Gaussian
    # and C = < B | i - 1/2 P B > parametrizes a new constant Tensor.
    precision = subs_matrix @ old_precision @ subs_matrix_t
    info_vec = _mv(subs_matrix, old_info_vec - _mv(old_precision, subs_vector))
    const = _vv(subs_vector, old_info_vec - 0.5 * _mv(old_precision, subs_vector))
    result = Gaussian(info_vec, precision, new_inputs) + Tensor(const, new_int_inputs)
    return Subs(result, remaining_subs) if remaining_subs else result
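# A numeric check (plain torch, small unbatched dims) of the identity derived
# in the comment above: with g(x) = <x, i> - 1/2 <x, P x> and x = A y + B,
#   g(A y + B) = <y, i'> - 1/2 <y, P' y> + C
# for P' = At P A, i' = At (i - P B), C = <B, i - 1/2 P B>.
def _check_affine_subs_identity():
    import torch
    torch.manual_seed(0)
    n, m = 3, 2
    P = torch.randn(n, n)
    P = P @ P.T + n * torch.eye(n)      # symmetric positive definite precision
    i, B, y = torch.randn(n), torch.randn(n), torch.randn(m)
    A = torch.randn(n, m)
    x = A @ y + B
    lhs = x @ i - 0.5 * (x @ (P @ x))
    P2 = A.T @ P @ A
    i2 = A.T @ (i - P @ B)
    C = B @ (i - 0.5 * (P @ B))
    rhs = y @ i2 - 0.5 * (y @ (P2 @ y)) + C
    assert torch.allclose(lhs, rhs, atol=1e-5)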
def _eager_subs_real(self, subs, remaining_subs):
    # Broadcast all component tensors.
    subs = OrderedDict(subs)
    int_inputs = OrderedDict(
        (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
    tensors = [Tensor(self.info_vec, int_inputs),
               Tensor(self.precision, int_inputs)]
    tensors.extend(subs.values())
    int_inputs, tensors = align_tensors(*tensors)
    batch_dim = len(tensors[0].shape) - 1
    batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
    (info_vec, precision), values = tensors[:2], tensors[2:]
    offsets, event_size = _compute_offsets(self.inputs)
    slices = [(k, slice(offset, offset + self.inputs[k].num_elements))
              for k, offset in offsets.items()]

    # Expand all substituted values.
    values = OrderedDict(zip(subs, values))
    for k, value in values.items():
        value = value.reshape(value.shape[:batch_dim] + (-1,))
        if not get_tracing_state():
            assert value.shape[-1] == self.inputs[k].num_elements
        values[k] = ops.expand(value, batch_shape + value.shape[-1:])

    # Try to perform a complete substitution of all real variables, resulting in a Tensor.
    if all(k in subs for k, d in self.inputs.items() if d.dtype == 'real'):
        # Form the concatenated value.
        value = BlockVector(batch_shape + (event_size,))
        for k, i in slices:
            if k in values:
                value[..., i] = values[k]
        value = value.as_tensor()
        # Evaluate the non-normalized log density.
        result = _vv(value, info_vec - 0.5 * _mv(precision, value))
        result = Tensor(result, int_inputs)
        assert result.output == Real
        return Subs(result, remaining_subs) if remaining_subs else result

    # Perform a partial substitution of a subset of real variables, resulting in a Joint.
    # We split real inputs into two sets: a for the preserved and b for the substituted.
    b = frozenset(k for k, v in subs.items())
    a = frozenset(k for k, d in self.inputs.items()
                  if d.dtype == 'real' and k not in b)
    prec_aa = ops.cat(-2, *[
        ops.cat(-1, *[precision[..., i1, i2] for k2, i2 in slices if k2 in a])
        for k1, i1 in slices if k1 in a])
    prec_ab = ops.cat(-2, *[
        ops.cat(-1, *[precision[..., i1, i2] for k2, i2 in slices if k2 in b])
        for k1, i1 in slices if k1 in a])
    prec_bb = ops.cat(-2, *[
        ops.cat(-1, *[precision[..., i1, i2] for k2, i2 in slices if k2 in b])
        for k1, i1 in slices if k1 in b])
    info_a = ops.cat(-1, *[info_vec[..., i] for k, i in slices if k in a])
    info_b = ops.cat(-1, *[info_vec[..., i] for k, i in slices if k in b])
    value_b = ops.cat(-1, *[values[k] for k, i in slices if k in b])
    info_vec = info_a - _mv(prec_ab, value_b)
    log_scale = _vv(value_b, info_b - 0.5 * _mv(prec_bb, value_b))
    precision = ops.expand(prec_aa, info_vec.shape + info_vec.shape[-1:])
    inputs = int_inputs.copy()
    for k, d in self.inputs.items():
        if k not in subs:
            inputs[k] = d
    result = Gaussian(info_vec, precision, inputs) + Tensor(log_scale, int_inputs)
    return Subs(result, remaining_subs) if remaining_subs else result
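# A numeric check (plain torch, hypothetical block sizes) of the partial
# substitution above: splitting x = (x_a, v) in information form gives
#   <x, i> - 1/2 <x, P x>
#     = <x_a, i_a - P_ab v> - 1/2 <x_a, P_aa x_a> + <v, i_b - 1/2 P_bb v>,
# i.e. a smaller Gaussian over the preserved block a plus a constant log_scale.
def _check_partial_subs_identity():
    import torch
    torch.manual_seed(0)
    na, nb = 2, 3
    P = torch.randn(na + nb, na + nb)
    P = P @ P.T + (na + nb) * torch.eye(na + nb)   # symmetric positive definite
    i = torch.randn(na + nb)
    v = torch.randn(nb)                  # the substituted value for block b
    x_a = torch.randn(na)                # an arbitrary point of block a
    P_aa, P_ab, P_bb = P[:na, :na], P[:na, na:], P[na:, na:]
    i_a, i_b = i[:na], i[na:]
    new_info = i_a - P_ab @ v
    log_scale = v @ (i_b - 0.5 * (P_bb @ v))
    x = torch.cat([x_a, v])
    direct = x @ i - 0.5 * (x @ (P @ x))
    blockwise = x_a @ new_info - 0.5 * (x_a @ (P_aa @ x_a)) + log_scale
    assert torch.allclose(direct, blockwise, atol=1e-5)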
def test_bart(analytic_kl):
    global call_count
    call_count = 0

    with interpretation(reflect):
        q = Independent(
            Independent(
                Contraction(
                    ops.nullop,
                    ops.add,
                    frozenset(),
                    (
                        Tensor(
                            torch.tensor(
                                [[-0.6077086925506592, -1.1546266078948975, -0.7021151781082153, -0.5303535461425781,
                                  -0.6365622282028198, -1.2423288822174072, -0.9941254258155823, -0.6287292242050171],
                                 [-0.6987162828445435, -1.0875964164733887, -0.7337473630905151, -0.4713417589664459,
                                  -0.6674002408981323, -1.2478348016738892, -0.8939017057418823, -0.5238542556762695]],
                                dtype=torch.float32),  # noqa
                            (('time_b4', bint(2),),
                             ('_event_1_b2', bint(8),),),
                            'real'),
                        Gaussian(
                            torch.tensor(
                                [[[-0.3536059558391571], [-0.21779225766658783], [0.2840439975261688],
                                  [0.4531521499156952], [-0.1220812276005745], [-0.05519985035061836],
                                  [0.10932210087776184], [0.6656699776649475]],
                                 [[-0.39107921719551086], [-0.20241987705230713], [0.2170514464378357],
                                  [0.4500560462474823], [0.27945515513420105], [-0.0490039587020874],
                                  [-0.06399798393249512], [0.846565842628479]]],
                                dtype=torch.float32),  # noqa
                            torch.tensor(
                                [[[[1.984686255455017]], [[0.6699360013008118]], [[1.6215802431106567]],
                                  [[2.372016668319702]], [[1.77385413646698]], [[0.526767373085022]],
                                  [[0.8722561597824097]], [[2.1879124641418457]]],
                                 [[[1.6996612548828125]], [[0.7535632252693176]], [[1.4946647882461548]],
                                  [[2.642792224884033]], [[1.7301604747772217]], [[0.5203893780708313]],
                                  [[1.055436372756958]], [[2.8370864391326904]]]],
                                dtype=torch.float32),  # noqa
                            (('time_b4', bint(2),),
                             ('_event_1_b2', bint(8),),
                             ('value_b1', reals(),),)),
                    )),
                'gate_rate_b3',
                '_event_1_b2',
                'value_b1'),
            'gate_rate_t',
            'time_b4',
            'gate_rate_b3')
        p_prior = Contraction(
            ops.logaddexp,
            ops.add,
            frozenset({'state(time=1)_b11', 'state_b10'}),
            (
                MarkovProduct(
                    ops.logaddexp,
                    ops.add,
                    Contraction(
                        ops.nullop,
                        ops.add,
                        frozenset(),
                        (
                            Tensor(
                                torch.tensor(2.7672932147979736, dtype=torch.float32),
                                (),
                                'real'),
                            Gaussian(
                                torch.tensor([-0.0, -0.0, 0.0, 0.0], dtype=torch.float32),
                                torch.tensor(
                                    [[98.01002502441406, 0.0, -99.0000228881836, -0.0],
                                     [0.0, 98.01002502441406, -0.0, -99.0000228881836],
                                     [-99.0000228881836, -0.0, 100.0000228881836, 0.0],
                                     [-0.0, -99.0000228881836, 0.0, 100.0000228881836]],
                                    dtype=torch.float32),  # noqa
                                (('state_b7', reals(2,),),
                                 ('state(time=1)_b8', reals(2,),),)),
                            Subs(
                                AffineNormal(
                                    Tensor(
                                        torch.tensor(
                                            [[0.03488487750291824, 0.07356668263673782, 0.19946961104869843,
                                              0.5386509299278259, -0.708323061466217, 0.24411526322364807,
                                              -0.20855577290058136, -0.2421337217092514],
                                             [0.41762110590934753, 0.5272183418273926, -0.49835553765296936,
                                              -0.0363837406039238, -0.0005282597267068923, 0.2704298794269562,
                                              -0.155222088098526, -0.44802337884902954]],
                                            dtype=torch.float32),  # noqa
                                        (),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[-0.003566693514585495, -0.2848514914512634, 0.037103548645973206,
                                              0.12648648023605347, -0.18501518666744232, -0.20899859070777893,
                                              0.04121830314397812, 0.0054807960987091064],
                                             [0.0021788496524095535, -0.18700894713401794, 0.08187370002269745,
                                              0.13554862141609192, -0.10477752983570099, -0.20848378539085388,
                                              -0.01393645629286766, 0.011670656502246857]],
                                            dtype=torch.float32),  # noqa
                                        (('time_b9', bint(2),),),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[0.5974780917167664, 0.864071786403656, 1.0236268043518066,
                                              0.7147538065910339, 0.7423890233039856, 0.9462157487869263,
                                              1.2132389545440674, 1.0596832036972046],
                                             [0.5787821412086487, 0.9178534150123596, 0.9074794054031372,
                                              0.6600189208984375, 0.8473222255706787, 0.8426999449729919,
                                              1.194266438484192, 1.0471148490905762]],
                                            dtype=torch.float32),  # noqa
                                        (('time_b9', bint(2),),),
                                        'real'),
                                    Variable('state(time=1)_b8', reals(2,)),
                                    Variable('gate_rate_b6', reals(8,))),
                                (('gate_rate_b6',
                                  Binary(
                                      ops.GetitemOp(0),
                                      Variable('gate_rate_t', reals(2, 8)),
                                      Variable('time_b9', bint(2)))),)),
                        )),
                    Variable('time_b9', bint(2)),
                    frozenset({('state_b7', 'state(time=1)_b8')}),
                    frozenset({('state(time=1)_b8', 'state(time=1)_b11'),
                               ('state_b7', 'state_b10')})),  # noqa
                Subs(
                    dist.MultivariateNormal(
                        Tensor(torch.tensor([0.0, 0.0], dtype=torch.float32), (), 'real'),
                        Tensor(
                            torch.tensor([[10.0, 0.0], [0.0, 10.0]], dtype=torch.float32),
                            (),
                            'real'),
                        Variable('value_b5', reals(2,))),
                    (('value_b5', Variable('state_b10', reals(2,))),)),
            ))
        p_likelihood = Contraction(
            ops.add,
            ops.nullop,
            frozenset({'time_b17', 'destin_b16', 'origin_b15'}),
            (
                Contraction(
                    ops.logaddexp,
                    ops.add,
                    frozenset({'gated_b14'}),
                    (
                        dist.Categorical(
                            Binary(
                                ops.GetitemOp(0),
                                Binary(
                                    ops.GetitemOp(0),
                                    Subs(
                                        Function(
                                            unpack_gate_rate_0,
                                            reals(2, 2, 2),
                                            (Variable('gate_rate_b12', reals(8,)),)),
                                        (('gate_rate_b12',
                                          Binary(
                                              ops.GetitemOp(0),
                                              Variable('gate_rate_t', reals(2, 8)),
                                              Variable('time_b17', bint(2)))),)),
                                    Variable('origin_b15', bint(2))),
                                Variable('destin_b16', bint(2))),
                            Variable('gated_b14', bint(2))),
                        Stack(
                            'gated_b14',
                            (
                                dist.Poisson(
                                    Binary(
                                        ops.GetitemOp(0),
                                        Binary(
                                            ops.GetitemOp(0),
                                            Subs(
                                                Function(
                                                    unpack_gate_rate_1,
                                                    reals(2, 2),
                                                    (Variable('gate_rate_b13', reals(8,)),)),
                                                (('gate_rate_b13',
                                                  Binary(
                                                      ops.GetitemOp(0),
                                                      Variable('gate_rate_t', reals(2, 8)),
                                                      Variable('time_b17', bint(2)))),)),
                                            Variable('origin_b15', bint(2))),
                                        Variable('destin_b16', bint(2))),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]],
                                             [[0.0, 6.0], [19.0, 3.0]]],
                                            dtype=torch.float32),  # noqa
                                        (('time_b17', bint(2),),
                                         ('origin_b15', bint(2),),
                                         ('destin_b16', bint(2),),),
                                        'real')),
                                dist.Delta(
                                    Tensor(torch.tensor(0.0, dtype=torch.float32), (), 'real'),
                                    Tensor(torch.tensor(0.0, dtype=torch.float32), (), 'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]],
                                             [[0.0, 6.0], [19.0, 3.0]]],
                                            dtype=torch.float32),  # noqa
                                        (('time_b17', bint(2),),
                                         ('origin_b15', bint(2),),
                                         ('destin_b16', bint(2),),),
                                        'real')),
                            )),
                    )),
            ))

    if analytic_kl:
        exact_part = funsor.Integrate(q, p_prior - q, "gate_rate_t")
        with interpretation(monte_carlo):
            approx_part = funsor.Integrate(q, p_likelihood, "gate_rate_t")
        elbo = exact_part + approx_part
    else:
        p = p_prior + p_likelihood
        with interpretation(monte_carlo):
            elbo = Integrate(q, p - q, "gate_rate_t")

    assert isinstance(elbo, Tensor), elbo.pretty()
    assert call_count == 1
def eager_integrate(delta, integrand, reduced_vars):
    assert delta.name in reduced_vars
    integrand = Subs(integrand, ((delta.name, delta.point),))
    log_measure = delta.log_density
    reduced_vars -= frozenset([delta.name])
    return Integrate(log_measure, integrand, reduced_vars)
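# The rule above implements the defining property of a Dirac measure,
#   ∫ δ(x - p) f(x) dx = f(p),
# so integrating against a Delta reduces to substituting the point into the
# integrand (with the Delta's log_density kept as the remaining measure).
# A trivial plain-Python illustration:
def _sketch_delta_integrate():
    f = lambda x: x ** 2 + 1.0
    point = 3.0
    assert f(point) == 10.0    # the "integral" collapses to evaluation at the point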
def eager_add(op, joint, other):
    # Update with a delayed discrete random variable.
    subs = tuple((d.name, d.point) for d in joint.deltas if d.name in other.inputs)
    if subs:
        return joint + Subs(other, subs)
    return Joint(joint.deltas, joint.discrete + other, joint.gaussian)
def eager_subs(self, subs):
    assert isinstance(subs, tuple)
    subs = tuple((k, materialize(to_funsor(v, self.inputs[k])))
                 for k, v in subs if k in self.inputs)
    if not subs:
        return self

    # Constants and Variables are eagerly substituted;
    # everything else is lazily substituted.
    lazy_subs = tuple((k, v) for k, v in subs
                      if not isinstance(v, (Number, Tensor, Variable)))
    var_subs = tuple((k, v) for k, v in subs if isinstance(v, Variable))
    int_subs = tuple((k, v) for k, v in subs if isinstance(v, (Number, Tensor))
                     if v.dtype != 'real')
    real_subs = tuple((k, v) for k, v in subs if isinstance(v, (Number, Tensor))
                      if v.dtype == 'real')
    if not (var_subs or int_subs or real_subs):
        return reflect(Subs, self, lazy_subs)

    # First perform any variable substitutions.
    if var_subs:
        rename = {k: v.name for k, v in var_subs}
        inputs = OrderedDict(
            (rename.get(k, k), d) for k, d in self.inputs.items())
        if len(inputs) != len(self.inputs):
            raise ValueError("Variable substitution name conflict")
        var_result = Gaussian(self.loc, self.precision, inputs)
        return Subs(var_result, int_subs + real_subs + lazy_subs)

    # Next perform any integer substitution, i.e. slicing into a batch.
    if int_subs:
        int_inputs = OrderedDict(
            (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
        real_inputs = OrderedDict(
            (k, d) for k, d in self.inputs.items() if d.dtype == 'real')
        tensors = [self.loc, self.precision]
        funsors = [Subs(Tensor(x, int_inputs), int_subs) for x in tensors]
        inputs = funsors[0].inputs.copy()
        inputs.update(real_inputs)
        int_result = Gaussian(funsors[0].data, funsors[1].data, inputs)
        return Subs(int_result, real_subs + lazy_subs)

    # Try to perform a complete substitution of all real variables, resulting in a Tensor.
    real_subs = OrderedDict(subs)
    assert real_subs and not int_subs
    if all(k in real_subs for k, d in self.inputs.items() if d.dtype == 'real'):
        # Broadcast all component tensors.
        int_inputs = OrderedDict(
            (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
        tensors = [Tensor(self.loc, int_inputs),
                   Tensor(self.precision, int_inputs)]
        tensors.extend(real_subs.values())
        inputs, tensors = align_tensors(*tensors)
        batch_dim = tensors[0].dim() - 1
        batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
        (loc, precision), values = tensors[:2], tensors[2:]

        # Form the concatenated value.
        offsets, event_size = _compute_offsets(self.inputs)
        value = BlockVector(batch_shape + (event_size,))
        for k, value_k in zip(real_subs, values):
            offset = offsets[k]
            value_k = value_k.reshape(value_k.shape[:batch_dim] + (-1,))
            if not torch._C._get_tracing_state():
                assert value_k.size(-1) == self.inputs[k].num_elements
            value_k = value_k.expand(batch_shape + value_k.shape[-1:])
            value[..., offset:offset + self.inputs[k].num_elements] = value_k
        value = value.as_tensor()

        # Evaluate the non-normalized log density.
        result = -0.5 * _vmv(precision, value - loc)
        result = Tensor(result, inputs)
        assert result.output == reals()
        return Subs(result, lazy_subs)

    # Perform a partial substitution of a subset of real variables, resulting in a Joint.
    # See "The Matrix Cookbook" (November 15, 2012) ss. 8.1.3 eq. 353.
    # http://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf
    raise NotImplementedError(
        'TODO implement partial substitution of real variables')
def normalize_fuse_subs(arg, subs):
    # a(b)(c) -> a(b(c), c)
    arg_subs = tuple(arg.subs.items()) if isinstance(arg.subs, OrderedDict) else arg.subs
    new_subs = subs + tuple((k, Subs(v, subs)) for k, v in arg_subs)
    return Subs(arg.arg, new_subs)
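# A small sketch of the fusion law a(b)(c) -> a(b(c), c) above, using sympy
# expressions as stand-ins for funsor terms (sympy is an assumption here, not
# part of this module): substituting b and then c equals one simultaneous
# substitution of c together with b-with-c-already-applied.
def _sketch_fuse_subs():
    import sympy
    x, y = sympy.symbols('x y')
    a = x + y
    b = {x: 2 * y}
    c = {y: 3}
    lhs = a.subs(b).subs(c)                                    # a(b)(c)
    fused = dict(c, **{k: v.subs(c) for k, v in b.items()})    # c + {x: b[x](c)}
    rhs = a.subs(fused, simultaneous=True)
    assert lhs == rhs == 9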
def eager_subs(self, subs):
    assert isinstance(subs, tuple)
    subs = tuple((k, v if isinstance(v, (Variable, Slice)) else materialize(v))
                 for k, v in subs if k in self.inputs)
    if not subs:
        return self

    # Constants and Variables are eagerly substituted;
    # everything else is lazily substituted.
    lazy_subs = tuple((k, v) for k, v in subs
                      if not isinstance(v, (Number, Tensor, Variable, Slice)))
    var_subs = tuple((k, v) for k, v in subs if isinstance(v, Variable))
    int_subs = tuple((k, v) for k, v in subs if isinstance(v, (Number, Tensor, Slice))
                     if v.dtype != 'real')
    real_subs = tuple((k, v) for k, v in subs if isinstance(v, (Number, Tensor))
                      if v.dtype == 'real')
    if not (var_subs or int_subs or real_subs):
        return reflect(Subs, self, lazy_subs)

    # First perform any variable substitutions.
    if var_subs:
        rename = {k: v.name for k, v in var_subs}
        inputs = OrderedDict(
            (rename.get(k, k), d) for k, d in self.inputs.items())
        if len(inputs) != len(self.inputs):
            raise ValueError("Variable substitution name conflict")
        var_result = Gaussian(self.info_vec, self.precision, inputs)
        return Subs(var_result, int_subs + real_subs + lazy_subs)

    # Next perform any integer substitution, i.e. slicing into a batch.
    if int_subs:
        int_inputs = OrderedDict(
            (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
        real_inputs = OrderedDict(
            (k, d) for k, d in self.inputs.items() if d.dtype == 'real')
        tensors = [self.info_vec, self.precision]
        funsors = [Subs(Tensor(x, int_inputs), int_subs) for x in tensors]
        inputs = funsors[0].inputs.copy()
        inputs.update(real_inputs)
        int_result = Gaussian(funsors[0].data, funsors[1].data, inputs)
        return Subs(int_result, real_subs + lazy_subs)

    # Broadcast all component tensors.
    real_subs = OrderedDict(subs)
    assert real_subs and not int_subs
    int_inputs = OrderedDict(
        (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
    tensors = [Tensor(self.info_vec, int_inputs),
               Tensor(self.precision, int_inputs)]
    tensors.extend(real_subs.values())
    int_inputs, tensors = align_tensors(*tensors)
    batch_dim = tensors[0].dim() - 1
    batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
    (info_vec, precision), values = tensors[:2], tensors[2:]
    offsets, event_size = _compute_offsets(self.inputs)
    slices = [(k, slice(offset, offset + self.inputs[k].num_elements))
              for k, offset in offsets.items()]

    # Expand all substituted values.
    values = OrderedDict(zip(real_subs, values))
    for k, value in values.items():
        value = value.reshape(value.shape[:batch_dim] + (-1,))
        if not torch._C._get_tracing_state():
            assert value.size(-1) == self.inputs[k].num_elements
        values[k] = value.expand(batch_shape + value.shape[-1:])

    # Try to perform a complete substitution of all real variables, resulting in a Tensor.
    if all(k in real_subs for k, d in self.inputs.items() if d.dtype == 'real'):
        # Form the concatenated value.
        value = BlockVector(batch_shape + (event_size,))
        for k, i in slices:
            if k in values:
                value[..., i] = values[k]
        value = value.as_tensor()
        # Evaluate the non-normalized log density.
        result = _vv(value, info_vec - 0.5 * _mv(precision, value))
        result = Tensor(result, int_inputs)
        assert result.output == reals()
        return Subs(result, lazy_subs)

    # Perform a partial substitution of a subset of real variables, resulting in a Joint.
    # We split real inputs into two sets: a for the preserved and b for the substituted.
    b = frozenset(k for k, v in real_subs.items())
    a = frozenset(k for k, d in self.inputs.items()
                  if d.dtype == 'real' and k not in b)
    prec_aa = torch.cat([
        torch.cat([precision[..., i1, i2] for k2, i2 in slices if k2 in a], dim=-1)
        for k1, i1 in slices if k1 in a], dim=-2)
    prec_ab = torch.cat([
        torch.cat([precision[..., i1, i2] for k2, i2 in slices if k2 in b], dim=-1)
        for k1, i1 in slices if k1 in a], dim=-2)
    prec_bb = torch.cat([
        torch.cat([precision[..., i1, i2] for k2, i2 in slices if k2 in b], dim=-1)
        for k1, i1 in slices if k1 in b], dim=-2)
    info_a = torch.cat([info_vec[..., i] for k, i in slices if k in a], dim=-1)
    info_b = torch.cat([info_vec[..., i] for k, i in slices if k in b], dim=-1)
    value_b = torch.cat([values[k] for k, i in slices if k in b], dim=-1)
    info_vec = info_a - _mv(prec_ab, value_b)
    log_scale = _vv(value_b, info_b - 0.5 * _mv(prec_bb, value_b))
    precision = prec_aa.expand(info_vec.shape + (-1,))
    inputs = int_inputs.copy()
    for k, d in self.inputs.items():
        if k not in real_subs:
            inputs[k] = d
    return Gaussian(info_vec, precision, inputs) + Tensor(log_scale, int_inputs)