def MakeIAFBijectorFn(
    num_dims,
    num_stages,
    hidden_layers,
    scale=1.0,
    activation=tf.nn.elu,
    train=False,
    dropout_rate=0.0,
    learn_scale=False,
):
  """Builds a chain of inverse autoregressive flow (IAF) stages.

  Each stage is a `tfb.MaskedAutoregressiveFlow` wrapped in `tfb.Invert`.
  The stage's shift/log-scale network is a `utils.DenseAR` named "iaf_<i>".
  A permutation that reverses the dimension order sits between consecutive
  stages, never after the last one. A `tfb.Scale` bijector is placed at the
  end of the list handed to `tfb.Chain`. Before returning, one forward pass
  on a zeros batch of shape [1, num_dims] is run so that the networks'
  variables are created.

  Args:
    num_dims: Number of event dimensions. It is used for the reversal
      permutation, the static input shape given to each network, and the
      warm-up batch.
    num_stages: Number of IAF stages. With 0 stages only the scale bijector
      is left in the chain.
    hidden_layers: Forwarded to `utils.DenseAR`.
    scale: Scale value of the trailing `tfb.Scale`. When `learn_scale` is
      True, it is instead the initial value of a learned scale.
    activation: Forwarded to `utils.DenseAR`.
    train: Forwarded to `utils.DenseAR`.
    dropout_rate: Forwarded to `utils.DenseAR`.
    learn_scale: If True, create the variable "isp_global_scale" with
      `tf.get_variable` and use its softplus as the scale.

  Returns:
    The constructed `tfb.Chain`.
  """
  # Permutation (num_dims-1, ..., 1, 0): reverses the order of the dimensions.
  # The same Permute object is reused between every pair of stages.
  reverse_dims = tfb.Permute(permutation=np.arange(num_dims - 1, -1, -1))

  layers = []
  for stage in range(num_stages):
    # Put a reversal between stages, but not before the first one.
    if stage > 0:
      layers.append(reverse_dims)

    network = utils.DenseAR(
        "iaf_%d" % stage,
        hidden_layers=hidden_layers,
        activation=activation,
        kernel_initializer=utils.L2HMCInitializer(factor=0.01),
        dropout_rate=dropout_rate,
        train=train)

    # `t=network` binds this stage's network now, avoiding the late-binding
    # closure pitfall inside the loop.
    def shift_and_log_scale(x, t=network):
      # NOTE(review): the incoming tensor appears to lose its static shape
      # somewhere inside TFP (the original author did not know why).
      # Re-pinning it to [None, num_dims] is kept as a workaround; confirm the
      # cause.
      x.set_shape([None, num_dims])
      return t(x)

    layers.append(
        tfb.Invert(
            tfb.MaskedAutoregressiveFlow(
                shift_and_log_scale_fn=shift_and_log_scale)))

  if learn_scale:
    # Learn an unconstrained variable and pass it through softplus so the
    # scale is always positive. softplus_inverse makes the initial value
    # equal to `scale`.
    raw_scale = tf.get_variable(
        "isp_global_scale",
        initializer=tfp.math.softplus_inverse(scale))
    scale = tf.nn.softplus(raw_scale)

  flow = tfb.Chain(layers + [tfb.Scale(scale=scale)])
  # Warm-up call: its only purpose is to create the networks' variables now.
  flow.forward(tf.zeros([1, num_dims]))
  return flow
def MakeRNVPBijectorFn(num_dims,
                       num_stages,
                       hidden_layers,
                       scale=1.0,
                       activation=tf.nn.elu,
                       train=False,
                       learn_scale=False,
                       dropout_rate=0.0):
  """Builds a chain of RealNVP coupling stages.

  Each stage is a `tfb.RealNVP` with `num_masked = num_dims // 2`. The
  stage's shift/log-scale network is a `utils.DenseShiftLogScale` named
  "rnvp_<i>". A permutation that reverses the dimension order sits between
  consecutive stages, never after the last one. A `tfb.Scale` bijector is
  placed at the end of the list handed to `tfb.Chain`. Before returning, one
  forward pass on a zeros batch of shape [1, num_dims] is run so that the
  networks' variables are created.

  Args:
    num_dims: Number of event dimensions. It is used for the reversal
      permutation, the masked split, and the warm-up batch.
    num_stages: Number of RealNVP stages. With 0 stages only the scale
      bijector is left in the chain.
    hidden_layers: Forwarded to `utils.DenseShiftLogScale`.
    scale: Scale value of the trailing `tfb.Scale`. When `learn_scale` is
      True, it is instead the initial value of a learned scale.
    activation: Forwarded to `utils.DenseShiftLogScale`.
    train: Forwarded to `utils.DenseShiftLogScale`.
    learn_scale: If True, create the variable "isp_global_scale" with
      `tf.get_variable` and use its softplus as the scale.
    dropout_rate: Forwarded to `utils.DenseShiftLogScale`.

  Returns:
    The constructed `tfb.Chain`.
  """
  # Permutation (num_dims-1, ..., 1, 0): reverses the order of the dimensions.
  # The same Permute object is reused between every pair of stages.
  reverse_dims = tfb.Permute(permutation=np.arange(num_dims - 1, -1, -1))
  num_masked = num_dims // 2

  layers = []
  for stage in range(num_stages):
    # Put a reversal between stages, but not before the first one.
    if stage > 0:
      layers.append(reverse_dims)

    network = utils.DenseShiftLogScale(
        "rnvp_%d" % stage,
        hidden_layers=hidden_layers,
        activation=activation,
        kernel_initializer=utils.L2HMCInitializer(factor=0.01),
        dropout_rate=dropout_rate,
        train=train)

    # RealNVP calls this as fn(x, output_units). `t=network` binds this
    # stage's network now, avoiding the late-binding closure pitfall.
    def shift_and_log_scale(x, output_units, t=network):
      # NOTE(review): the incoming tensor appears to lose its static shape
      # somewhere inside TFP (the original author did not know why).
      # Re-pin it to the width of the conditioning slice,
      # num_dims - output_units. Confirm the cause.
      x.set_shape([None, num_dims - output_units])
      return t(x, output_units)

    layers.append(
        tfb.RealNVP(num_masked=num_masked,
                    shift_and_log_scale_fn=shift_and_log_scale))

  if learn_scale:
    # Learn an unconstrained variable and pass it through softplus so the
    # scale is always positive. softplus_inverse makes the initial value
    # equal to `scale`.
    raw_scale = tf.get_variable(
        "isp_global_scale",
        initializer=tfp.math.softplus_inverse(scale))
    scale = tf.nn.softplus(raw_scale)

  flow = tfb.Chain(layers + [tfb.Scale(scale=scale)])
  # Warm-up call: its only purpose is to create the networks' variables now.
  flow.forward(tf.zeros([1, num_dims]))
  return flow