def testAlignedEventDims(self):
  x = [tf.ones((3,), dtype=tf.float32), tf.ones((2, 2), dtype=tf.float32)]
  op = self.build_operator()
  bijector = tfb.ScaleMatvecLinearOperatorBlock(op, validate_args=True)
  with self.assertRaisesRegexp(ValueError, 'equal for all elements'):
    self.evaluate(
        bijector.forward_log_det_jacobian(x, event_ndims=[1, 2]))
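# `self.build_operator()` is defined outside the lines shown here. A plausible
# sketch of such a helper (an assumption, not the original): a blockwise
# operator with a 3x3 diagonal block and a 2x2 dense block, matching the
# event shapes these tests use.
def build_operator(self):
  return tf.linalg.LinearOperatorBlockDiag(
      [tf.linalg.LinearOperatorDiag(diag=[2., 2., 2.]),
       tf.linalg.LinearOperatorFullMatrix([[1., 2.], [3., 4.]])],
      is_non_singular=True)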
def testOperatorBroadcast(self):
  x = [
      tf.ones((1, 1, 1, 4), dtype=tf.float32),
      tf.ones((1, 1, 1, 3), dtype=tf.float32)
  ]
  op = self.build_batched_operator()
  bijector = tfb.ScaleMatvecLinearOperatorBlock(op, validate_args=True)
  self.assertAllEqual(
      self.evaluate(
          tf.shape(bijector.forward_log_det_jacobian(x, event_ndims=[1, 1]))),
      self.evaluate(op.batch_shape_tensor()))

  # Broadcasting of event shape components with batched LinearOperators
  # raises.
  with self.assertRaisesRegexp(ValueError, 'bijector parameters changes'):
    self.evaluate(
        bijector.forward_log_det_jacobian(x, event_ndims=[2, 2]))

  # Broadcasting of event shape components with batched LinearOperators
  # raises for `ldj_reduce_ndims > batch_ndims`.
  with self.assertRaisesRegexp(ValueError, 'bijector parameters changes'):
    self.evaluate(
        bijector.forward_log_det_jacobian(x, event_ndims=[3, 3]))
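# `self.build_batched_operator()` is likewise defined outside this excerpt.
# A plausible sketch (an assumption, not the original helper): both blocks
# carry batch shape [1, 2, 3], so `op.batch_shape_tensor()` is [1, 2, 3] and
# the log-det Jacobian varies along the batch dims that `event_ndims=[2, 2]`
# and `[3, 3]` would reduce over, triggering the errors asserted above.
def build_batched_operator(self):
  return tf.linalg.LinearOperatorBlockDiag(
      [tf.linalg.LinearOperatorDiag(
          tf.random.uniform((1, 2, 3, 4), minval=1., maxval=2.)),
       tf.linalg.LinearOperatorDiag(
           tf.random.uniform((1, 2, 3, 3), minval=1., maxval=2.))],
      is_non_singular=True)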
def testBijector(self):
  x = [
      np.array([4., 3., 3.]).astype(np.float32),
      np.array([0., -5.]).astype(np.float32)
  ]
  op = self.build_operator()
  y = self.evaluate(op.matvec(x))
  ldj = self.evaluate(op.log_abs_determinant())

  bijector = tfb.ScaleMatvecLinearOperatorBlock(scale=op, validate_args=True)
  self.assertStartsWith(bijector.name, 'scale_matvec_linear_operator_block')

  f_x = bijector.forward(x)
  self.assertAllClose(y, self.evaluate(f_x))

  inv_y = self.evaluate(bijector.inverse(y))
  self.assertAllClose(x, inv_y)

  # Calling `inverse` on an output of `bijector.forward` (that is equal to
  # `y`) is a cache hit and returns the original, non-broadcasted input `x`.
  for x_, z_ in zip(x, bijector.inverse(f_x)):
    self.assertIs(x_, z_)

  ldj_ = self.evaluate(
      bijector.forward_log_det_jacobian(x, event_ndims=[1, 1]))
  self.assertAllClose(ldj, ldj_)
  self.assertEmpty(ldj_.shape)
  self.assertAllClose(
      ldj_,
      self.evaluate(
          -bijector.inverse_log_det_jacobian(y, event_ndims=[1, 1])))
def get_mean_field_approximation(
    model: tfd.JointDistribution,
    init_loc=None,
    dtype=DEFAULT_FLOAT_DTYPE_TF,
    joint_bijector_func: tp.Callable[
        [tfd.JointDistribution],
        tfb.Composition] = get_default_event_space_bijector,
    event_shape_fn: tp.Callable[[tfd.JointDistribution],
                                object] = event_shape_fn,
) -> tfd.Distribution:
  """Builds a trainable mean-field surrogate posterior for `model`."""
  event_shape = event_shape_fn(model)
  flat_event_shape = tf.nest.flatten(event_shape)
  flat_event_size = tf.nest.map_structure(tf.reduce_prod, flat_event_shape)

  # Trainable block-diagonal scale acting on the flattened, vectorized event.
  operator_classes = get_mean_field_operator_classes(flat_event_size)
  linear_operator_block = build_trainable_linear_operator_block(
      operator_classes, flat_event_size, dtype=dtype)
  scale_bijector = tfb.ScaleMatvecLinearOperatorBlock(linear_operator_block)

  if init_loc is None:
    init_loc = tf.nest.map_structure(lambda _: None, flat_event_shape)
  else:
    init_loc = defaultdict(lambda: None, init_loc)  # TODO: Handle nesting

  # Bijectors that map the base distribution's flat vector blocks back to the
  # model's constrained, structured event space.
  event_space_bijector = joint_bijector_func(model)
  unflatten_bijector = tfb.Restructure(
      tf.nest.pack_sequence_as(event_shape, range(len(flat_event_shape))))
  reshape_bijector = tfb.JointMap(
      tf.nest.map_structure(tfb.Reshape, flat_event_shape))

  # Pull any user-supplied initial locations back through the chain so the
  # trainable shift is initialized in the unconstrained, vectorized space.
  init_loc_unconstrained = joint_inverse_with_nones(event_space_bijector,
                                                    init_loc)
  init_loc_flat = unflatten_bijector.inverse(init_loc_unconstrained)
  init_loc_1d = joint_inverse_with_nones(reshape_bijector, init_loc_flat)
  loc_bijector = get_trainable_shift_bijector(
      flat_event_size, init_loc_1d, dtype=dtype)

  base_standard_dist = get_base_distribution(flat_event_size, dtype=dtype)
  chain_bijector = tfb.Chain([
      event_space_bijector,
      unflatten_bijector,
      reshape_bijector,
      loc_bijector,
      scale_bijector,
  ])
  distribution = tfd.TransformedDistribution(base_standard_dist,
                                             chain_bijector)
  return distribution
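# A minimal usage sketch for `get_mean_field_approximation` (an assumption,
# not from the original source): the toy model, dtype, and sample size below
# are hypothetical, and the helpers the factory relies on are assumed to be
# in scope from the surrounding module.
def _example_mean_field_usage():
  # Toy model with one positive-constrained and one unconstrained variable.
  model = tfd.JointDistributionNamed(dict(
      scale=tfd.HalfNormal(1.),
      loc=tfd.Normal(0., 1.),
  ))
  surrogate = get_mean_field_approximation(model, dtype=tf.float32)
  samples = surrogate.sample(7)  # Structured like `model`'s event.
  return surrogate.log_prob(samples)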
def testEventShapeBroadcast(self):
  op = self.build_operator()
  bijector = tfb.ScaleMatvecLinearOperatorBlock(op, validate_args=True)
  x = [tf.broadcast_to(tf.constant(1., dtype=tf.float32), [2, 3, 3]),
       tf.broadcast_to(tf.constant(2., dtype=tf.float32), [2, 1, 2])]

  # Forward/inverse event shape methods return the correct value.
  self.assertAllEqual(
      self.evaluate(
          bijector.forward_event_shape_tensor([tf.shape(x_) for x_ in x])),
      [self.evaluate(tf.shape(y_)) for y_ in bijector.forward(x)])
  self.assertAllEqual(
      bijector.inverse_event_shape([x_.shape for x_ in x]),
      [y_.shape for y_ in bijector.inverse(x)])

  # Broadcasting of inputs within `ldj_reduce_shape` raises.
  with self.assertRaisesRegexp(ValueError, 'left of `min_event_ndims`'):
    self.evaluate(bijector.forward_log_det_jacobian(x, event_ndims=[2, 2]))