Example #1
# Imports assumed by both examples below (not shown in the original
# snippet). The TF1-style calls (tf.get_variable, name-scoped templates)
# suggest TF1 / tf.compat.v1 behavior:
import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp

from neutra import utils  # assumption: provides DenseAR, DenseShiftLogScale,
                          # and L2HMCInitializer as used below

tfb = tfp.bijectors
def MakeIAFBijectorFn(
    num_dims,
    num_stages,
    hidden_layers,
    scale=1.0,
    activation=tf.nn.elu,
    train=False,
    dropout_rate=0.0,
    learn_scale=False,
):
    swap = tfb.Permute(permutation=np.arange(num_dims - 1, -1, -1))

    bijectors = []
    for i in range(num_stages):
        _iaf_template = utils.DenseAR(
            "iaf_%d" % i,
            hidden_layers=hidden_layers,
            activation=activation,
            kernel_initializer=utils.L2HMCInitializer(factor=0.01),
            dropout_rate=dropout_rate,
            train=train)

        def iaf_template(x, t=_iaf_template):
            # The default argument binds this stage's template at definition
            # time, avoiding Python's late binding of the loop variable.
            # TODO: I don't understand why the shape gets lost.
            x.set_shape([None, num_dims])
            return t(x)

        bijectors.append(
            tfb.Invert(
                tfb.MaskedAutoregressiveFlow(
                    shift_and_log_scale_fn=iaf_template)))
        bijectors.append(swap)
    # Drop the last swap.
    bijectors = bijectors[:-1]
    if learn_scale:
        scale = tf.nn.softplus(
            tf.get_variable("isp_global_scale",
                            initializer=tfp.math.softplus_inverse(scale)))
    bijectors.append(tfb.Scale(scale=scale))

    bijector = tfb.Chain(bijectors)

    # Run a dummy forward pass so the templates construct their variables.
    _ = bijector.forward(tf.zeros([1, num_dims]))

    return bijector
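
A minimal usage sketch, not from the original source: the returned bijector can be composed with a base distribution via tfd.TransformedDistribution to get a flow you can sample from and score under. The dimension, stage count, and layer sizes below are hypothetical.

# Hypothetical usage of MakeIAFBijectorFn; all sizes are illustrative only.
tfd = tfp.distributions

num_dims = 2
bijector = MakeIAFBijectorFn(
    num_dims=num_dims,
    num_stages=2,
    hidden_layers=[32, 32])

base = tfd.MultivariateNormalDiag(loc=tf.zeros(num_dims))
flow = tfd.TransformedDistribution(distribution=base, bijector=bijector)

samples = flow.sample(4)            # Tensor of shape [4, num_dims]
log_prob = flow.log_prob(samples)   # Tensor of shape [4]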
Example #2
def MakeRNVPBijectorFn(num_dims,
                       num_stages,
                       hidden_layers,
                       scale=1.0,
                       activation=tf.nn.elu,
                       train=False,
                       learn_scale=False,
                       dropout_rate=0.0):
    swap = tfb.Permute(permutation=np.arange(num_dims - 1, -1, -1))

    bijectors = []
    for i in range(num_stages):
        _rnvp_template = utils.DenseShiftLogScale(
            "rnvp_%d" % i,
            hidden_layers=hidden_layers,
            activation=activation,
            kernel_initializer=utils.L2HMCInitializer(factor=0.01),
            dropout_rate=dropout_rate,
            train=train)

        def rnvp_template(x, output_units, t=_rnvp_template):
            # The default argument binds this stage's template at definition
            # time, avoiding Python's late binding of the loop variable.
            # TODO: I don't understand why the shape gets lost.
            x.set_shape([None, num_dims - output_units])
            return t(x, output_units)

        bijectors.append(
            tfb.RealNVP(num_masked=num_dims // 2,
                        shift_and_log_scale_fn=rnvp_template))
        bijectors.append(swap)
    # Drop the last swap.
    bijectors = bijectors[:-1]
    if learn_scale:
        scale = tf.nn.softplus(
            tf.get_variable("isp_global_scale",
                            initializer=tfp.math.softplus_inverse(scale)))
    bijectors.append(tfb.Scale(scale=scale))

    bijector = tfb.Chain(bijectors)

    # Run a dummy forward pass so the templates construct their variables.
    _ = bijector.forward(tf.zeros([1, num_dims]))

    return bijector
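
As a sketch of how the result behaves (again with hypothetical sizes): the chained bijector is invertible, so forward and inverse should round-trip, and forward_log_det_jacobian is available for density computations.

# Hypothetical check of the RNVP bijector; all sizes are illustrative only.
num_dims = 2
bijector = MakeRNVPBijectorFn(
    num_dims=num_dims,
    num_stages=2,
    hidden_layers=[32, 32])

x = tf.random.normal([8, num_dims])
y = bijector.forward(x)
x_back = bijector.inverse(y)  # should recover x up to numerical error
ldj = bijector.forward_log_det_jacobian(x, event_ndims=1)  # shape [8]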