    def testAlignedEventDims(self):
        x = [tf.ones((3,), dtype=tf.float32), tf.ones((2, 2), dtype=tf.float32)]
        op = self.build_operator()
        bijector = tfb.ScaleMatvecLinearOperatorBlock(op, validate_args=True)
        with self.assertRaisesRegexp(ValueError, 'equal for all elements'):
            self.evaluate(
                bijector.forward_log_det_jacobian(x, event_ndims=[1, 2]))
    def testOperatorBroadcast(self):
        x = [
            tf.ones((1, 1, 1, 4), dtype=tf.float32),
            tf.ones((1, 1, 1, 3), dtype=tf.float32)
        ]
        op = self.build_batched_operator()
        bijector = tfb.ScaleMatvecLinearOperatorBlock(op, validate_args=True)

        self.assertAllEqual(
            self.evaluate(
                tf.shape(bijector.forward_log_det_jacobian(x, [1, 1]))),
            self.evaluate(op.batch_shape_tensor()))

        # Broadcasting of event shape components with batched LinearOperators
        # raises.
        with self.assertRaisesRegexp(ValueError,
                                     'bijector parameters changes'):
            self.evaluate(
                bijector.forward_log_det_jacobian(x, event_ndims=[2, 2]))

        # Broadcasting of event shape components with batched LinearOperators
        # raises for `ldj_reduce_ndims > batch_ndims`.
        with self.assertRaisesRegexp(ValueError,
                                     'bijector parameters changes'):
            self.evaluate(
                bijector.forward_log_det_jacobian(x, event_ndims=[3, 3]))
    def testBijector(self):
        x = [
            np.array([4., 3., 3.]).astype(np.float32),
            np.array([0., -5.]).astype(np.float32)
        ]
        op = self.build_operator()
        y = self.evaluate(op.matvec(x))
        ldj = self.evaluate(op.log_abs_determinant())

        bijector = tfb.ScaleMatvecLinearOperatorBlock(scale=op,
                                                      validate_args=True)
        self.assertStartsWith(bijector.name,
                              'scale_matvec_linear_operator_block')

        f_x = bijector.forward(x)
        self.assertAllClose(y, self.evaluate(f_x))

        inv_y = self.evaluate(bijector.inverse(y))
        self.assertAllClose(x, inv_y)

        # Calling `inverse` on an output of `bijector.forward` (that is equal to
        # `y`) is a cache hit and returns the original, non-broadcasted input `x`.
        for x_, z_ in zip(x, bijector.inverse(f_x)):
            self.assertIs(x_, z_)

        ldj_ = self.evaluate(
            bijector.forward_log_det_jacobian(x, event_ndims=[1, 1]))
        self.assertAllClose(ldj, ldj_)
        self.assertEmpty(ldj_.shape)

        self.assertAllClose(
            ldj_,
            self.evaluate(
                -bijector.inverse_log_det_jacobian(y, event_ndims=[1, 1])))
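
The test methods above call `self.build_operator()` and `self.build_batched_operator()`, which are not included in this excerpt. Below is a minimal hypothetical sketch of what they might return, assuming block-diagonal `tf.linalg.LinearOperator`s whose block sizes match the inputs used above (event sizes 3 and 2 unbatched; 4 and 3 with batch shape [2, 3] batched); the exact values in the original test class may differ.

    def build_operator(self):
        # Hypothetical helper: block-diagonal operator with blocks of size 3 and 2,
        # matching the [3]- and [2]-dimensional inputs used in the tests above.
        return tf.linalg.LinearOperatorBlockDiag(
            [tf.linalg.LinearOperatorDiag(tf.constant([2., 2., 2.])),
             tf.linalg.LinearOperatorFullMatrix([[1., 2.], [3., 4.]])],
            is_non_singular=True)

    def build_batched_operator(self):
        # Hypothetical helper: blocks of size 4 and 3 with a nontrivial batch shape,
        # so that `op.batch_shape_tensor()` is [2, 3]. Diagonal entries are kept in
        # [1, 2) so the operator is non-singular.
        return tf.linalg.LinearOperatorBlockDiag(
            [tf.linalg.LinearOperatorDiag(
                tf.random.uniform((2, 3, 4), minval=1., maxval=2., dtype=tf.float32)),
             tf.linalg.LinearOperatorIdentity(num_rows=3, dtype=tf.float32)],
            is_non_singular=True)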
Example #4
import typing as tp
from collections import defaultdict

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Note: DEFAULT_FLOAT_DTYPE_TF and the helpers referenced below
# (get_default_event_space_bijector, event_shape_fn, get_mean_field_operator_classes,
# build_trainable_linear_operator_block, joint_inverse_with_nones,
# get_trainable_shift_bijector, get_base_distribution) are assumed to be defined
# elsewhere in the same module.


def get_mean_field_approximation(
    model: tfd.JointDistribution,
    init_loc=None,
    dtype=DEFAULT_FLOAT_DTYPE_TF,
    joint_bijector_func: tp.Callable[
        [tfd.JointDistribution],
        tfb.Composition] = get_default_event_space_bijector,
    event_shape_fn: tp.Callable[[tfd.JointDistribution],
                                object] = event_shape_fn,
) -> tfd.Distribution:
    event_shape = event_shape_fn(model)
    flat_event_shape = tf.nest.flatten(event_shape)
    flat_event_size = tf.nest.map_structure(tf.reduce_prod, flat_event_shape)
    operator_classes = get_mean_field_operator_classes(flat_event_size)
    linear_operator_block = build_trainable_linear_operator_block(
        operator_classes, flat_event_size, dtype=dtype)
    scale_bijector = tfb.ScaleMatvecLinearOperatorBlock(linear_operator_block)

    if init_loc is None:
        init_loc = tf.nest.map_structure(lambda _: None, flat_event_shape)
    else:
        init_loc = defaultdict(lambda: None, init_loc)  # TODO: Handle nesting

    event_space_bijector = joint_bijector_func(model)
    unflatten_bijector = tfb.Restructure(
        tf.nest.pack_sequence_as(event_shape, range(len(flat_event_shape))))
    reshape_bijector = tfb.JointMap(
        tf.nest.map_structure(tfb.Reshape, flat_event_shape))

    init_loc_unconstrained = joint_inverse_with_nones(event_space_bijector,
                                                      init_loc)
    init_loc_flat = unflatten_bijector.inverse(init_loc_unconstrained)
    init_loc_1d = joint_inverse_with_nones(reshape_bijector, init_loc_flat)
    loc_bijector = get_trainable_shift_bijector(flat_event_size,
                                                init_loc_1d,
                                                dtype=dtype)

    base_standard_dist = get_base_distribution(flat_event_size, dtype=dtype)
    chain_bijector = tfb.Chain([
        event_space_bijector,
        unflatten_bijector,
        reshape_bijector,
        loc_bijector,
        scale_bijector,
    ])
    distribution = tfd.TransformedDistribution(base_standard_dist,
                                               chain_bijector)
    return distribution
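
A sketch of how `get_mean_field_approximation` might be used, assuming the project-level helpers it references are available; the toy model, optimizer, and step count are illustrative only, and the surrogate is fit with `tfp.vi.fit_surrogate_posterior`.

# Illustrative usage (not part of the original module): build a small joint model,
# construct the mean-field surrogate, and fit it by maximizing the ELBO.
model = tfd.JointDistributionNamed(dict(
    scale=tfd.HalfNormal(1.),
    loc=tfd.Sample(tfd.Normal(0., 1.), 3),
))
surrogate = get_mean_field_approximation(model)
losses = tfp.vi.fit_surrogate_posterior(
    target_log_prob_fn=model.log_prob,  # here simply the joint prior, for illustration
    surrogate_posterior=surrogate,
    optimizer=tf.optimizers.Adam(learning_rate=0.05),
    num_steps=200)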
Example #5
  def testEventShapeBroadcast(self):
    op = self.build_operator()
    bijector = tfb.ScaleMatvecLinearOperatorBlock(
        op, validate_args=True)
    x = [tf.broadcast_to(tf.constant(1., dtype=tf.float32), [2, 3, 3]),
         tf.broadcast_to(tf.constant(2., dtype=tf.float32), [2, 1, 2])]

    # Forward/inverse event shape methods return the correct value.
    self.assertAllEqual(
        self.evaluate(bijector.forward_event_shape_tensor(
            [tf.shape(x_) for x_ in x])),
        [self.evaluate(tf.shape(y_)) for y_ in bijector.forward(x)])
    self.assertAllEqual(
        bijector.inverse_event_shape([x_.shape for x_ in x]),
        [y_.shape for y_ in bijector.inverse(x)])

    # Broadcasting of inputs within `ldj_reduce_shape` raises.
    with self.assertRaisesRegexp(ValueError, 'left of `min_event_ndims`'):
      self.evaluate(bijector.forward_log_det_jacobian(x, event_ndims=[2, 2]))
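
Outside the test harness, the bijector acts directly on list-valued inputs; a minimal standalone sketch using the same block sizes and input values as `testBijector` above.

# Minimal standalone sketch (not from the test class above).
op = tf.linalg.LinearOperatorBlockDiag(
    [tf.linalg.LinearOperatorDiag(tf.constant([2., 2., 2.])),
     tf.linalg.LinearOperatorFullMatrix([[1., 2.], [3., 4.]])],
    is_non_singular=True)
bijector = tfb.ScaleMatvecLinearOperatorBlock(op)
x = [tf.constant([4., 3., 3.]), tf.constant([0., -5.])]
y = bijector.forward(x)  # list of two tensors with shapes [3] and [2]
ldj = bijector.forward_log_det_jacobian(x, event_ndims=[1, 1])  # scalar log|det(scale)|
x_back = bijector.inverse(y)  # recovers `x` (exactly, via the bijector cache)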