def q_net(x, observed=None, n_samples=None, tau=None, is_training=True):
    use_concrete = config.use_concrete_distribution and tau is not None
    logging.info('q_net builder: %r', locals())

    net = BayesianNet(observed=observed)

    # compute the hidden features
    with arg_scope([dense],
                   activation_fn=tf.nn.leaky_relu,
                   kernel_regularizer=l2_regularizer(config.l2_reg)):
        h_x = tf.to_float(x)
        h_x = dense(h_x, 500)
        h_x = dense(h_x, 500)

    # sample y ~ q(y|x)
    y_logits = dense(h_x, config.n_clusters, name='y_logits')
    if use_concrete:
        y = net.add('y', ExpConcrete(tau, y_logits),
                    is_reparameterized=True, n_samples=n_samples)
        y_one_hot = tf.exp(y)
    else:
        y = net.add('y', Categorical(y_logits), n_samples=n_samples)
        y_one_hot = tf.one_hot(y, config.n_clusters, dtype=tf.float32)

    # sample z ~ q(z|y,x)
    with arg_scope([dense],
                   activation_fn=tf.nn.leaky_relu,
                   kernel_regularizer=l2_regularizer(config.l2_reg)):
        if config.mean_field_assumption_for_q:
            # by the mean-field assumption we let q(z|y,x) = q(z|x)
            h_z, s1, s2 = flatten(h_x, 2)
            z_n_samples = n_samples
        else:
            if n_samples is not None:
                h_z = tf.concat(
                    [
                        tf.tile(tf.reshape(h_x, [1, -1, 500]),
                                tf.stack([n_samples, 1, 1])),
                        y_one_hot
                    ],
                    axis=-1
                )
            else:
                h_z = tf.concat([h_x, y_one_hot], axis=-1)
            h_z, s1, s2 = flatten(h_z, 2)
            h_z = dense(h_z, 500)
            z_n_samples = None

    z_mean = dense(h_z, config.z_dim, name='z_mean')
    z_logstd = dense(h_z, config.z_dim, name='z_logstd')
    z = net.add('z',
                Normal(mean=unflatten(z_mean, s1, s2),
                       logstd=unflatten(z_logstd, s1, s2),
                       is_reparameterized=use_concrete),
                n_samples=z_n_samples, group_ndims=1)

    return net
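# Illustrative sketch (not part of the original model code).
# `q_net` above draws `y` either as a hard Categorical sample or, when a
# temperature `tau` is given, as an ExpConcrete sample: a Gumbel-Softmax
# relaxation expressed in log-space, so that `tf.exp(y)` is a soft one-hot
# vector that sharpens towards a true one-hot as `tau` approaches 0.  The
# NumPy sketch below only illustrates that relationship; the helper name and
# its signature are hypothetical and independent of the model code.
import numpy as np

def sample_exp_concrete_sketch(logits, tau, rng=np.random):
    """Draw one relaxed categorical sample in log-space, mimicking ExpConcrete."""
    gumbel = -np.log(-np.log(rng.uniform(size=logits.shape)))
    scores = (logits + gumbel) / tau
    # log-softmax: the returned value lives in log-space, like `y` above
    return scores - np.log(np.sum(np.exp(scores)))

# exp(y) is approximately one-hot for small tau and approximately uniform for
# large tau, e.g.:
#   y = sample_exp_concrete_sketch(np.zeros(10), tau=0.1); np.exp(y)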
def p_net(observed=None, n_z=None, is_training=True):
    logging.info('p_net builder: %r', locals())

    net = BayesianNet(observed=observed)

    # sample z ~ p(z)
    z = net.add('z',
                Normal(mean=tf.zeros([1, config.z_dim]),
                       logstd=tf.zeros([1, config.z_dim])),
                group_ndims=1, n_samples=n_z)

    # compute the hidden features
    with arg_scope([dense],
                   activation_fn=tf.nn.leaky_relu,
                   kernel_regularizer=l2_regularizer(config.l2_reg)):
        h_x, s1, s2 = flatten(z, 2)
        h_x = dense(h_x, 500)
        h_x = dense(h_x, 500)

    # sample x ~ p(x|z)
    x_logits = unflatten(dense(h_x, config.x_dim, name='x_logits'), s1, s2)
    x = net.add('x', Bernoulli(logits=x_logits), group_ndims=1)

    return net
def reinforce_baseline_net(x):
    x, s1, s2 = flatten(tf.to_float(x), 2)
    with arg_scope([dense],
                   kernel_regularizer=l2_regularizer(config.l2_reg),
                   activation_fn=tf.nn.leaky_relu):
        h_x = dense(x, 500)
    h_x = unflatten(tf.reshape(dense(h_x, 1), [-1]), s1, s2)
    return h_x
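# Illustrative sketch (not part of the original training code).
# When `y` is a hard Categorical sample (no concrete relaxation), gradients
# w.r.t. `y_logits` must be estimated with REINFORCE.  The baseline network
# above produces a per-example scalar b(x) that is subtracted from the
# learning signal to reduce the variance of that estimator without biasing
# it, since the expected score function is zero.  The self-contained sketch
# below only illustrates the estimator's form; the names are hypothetical.
import numpy as np

def reinforce_grad_sketch(logits, f, baseline, rng=np.random):
    """One-sample REINFORCE estimate of d E_{y~softmax(logits)}[f(y)] / d logits."""
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    y = rng.choice(len(logits), p=probs)
    # d log p(y) / d logits for a categorical is (one_hot(y) - probs)
    score = np.eye(len(logits))[y] - probs
    return (f(y) - baseline) * score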
def h_for_p_x(z, is_training):
    with arg_scope([dense],
                   activation_fn=tf.nn.leaky_relu,
                   kernel_regularizer=l2_regularizer(config.l2_reg)):
        h_z, s1, s2 = flatten(z, 2)
        h_z = dense(h_z, 500)
        h_z = dense(h_z, 500)
    return {
        'logits': unflatten(dense(h_z, config.x_dim, name='x_logits'), s1, s2)
    }
def p_net(observed=None, n_y=None, n_z=None, tau=None, is_training=True,
          n_samples=None):
    if n_samples is not None:
        warnings.warn('`n_samples` is deprecated, use `n_y` instead.')
        n_y = n_samples

    use_concrete = config.use_concrete_distribution and tau is not None
    logging.info('p_net builder: %r', locals())

    net = BayesianNet(observed=observed)

    # sample y
    if use_concrete:
        y = net.add('y', ExpConcrete(tau, tf.zeros([1, config.n_clusters])),
                    n_samples=n_y, is_reparameterized=True)
    else:
        y = net.add('y', Categorical(tf.zeros([1, config.n_clusters])),
                    n_samples=n_y)

    # sample z ~ p(z|y)
    z = net.add('z',
                gaussian_mixture_prior(y, config.z_dim, config.n_clusters,
                                       use_concrete=use_concrete),
                group_ndims=1, n_samples=n_z,
                is_reparameterized=use_concrete)

    # compute the hidden features for x
    with arg_scope([dense],
                   activation_fn=tf.nn.leaky_relu,
                   kernel_regularizer=l2_regularizer(config.l2_reg)):
        h_x, s1, s2 = flatten(z, 2)
        h_x = dense(h_x, 500)
        h_x = dense(h_x, 500)

    # sample x ~ p(x|z)
    x_logits = unflatten(dense(h_x, config.x_dim, name='x_logits'), s1, s2)
    x = net.add('x', Bernoulli(logits=x_logits), group_ndims=1)

    return net
def h_for_p_x(z, is_training, channels_last):
    with arg_scope([deconv_resnet_block],
                   shortcut_kernel_size=config.shortcut_kernel_size,
                   activation_fn=tf.nn.leaky_relu,
                   kernel_regularizer=l2_regularizer(config.l2_reg),
                   channels_last=channels_last):
        h_z, s1, s2 = flatten(z, 2)
        h_z = tf.reshape(
            dense(h_z, 64 * 7 * 7),
            [-1, 7, 7, 64] if channels_last else [-1, 64, 7, 7]
        )
        h_z = deconv_resnet_block(h_z, 64)  # output: (64, 7, 7)
        h_z = deconv_resnet_block(h_z, 32, strides=2)  # output: (32, 14, 14)
        h_z = deconv_resnet_block(h_z, 32)  # output: (32, 14, 14)
        h_z = deconv_resnet_block(h_z, 16, strides=2)  # output: (16, 28, 28)
        h_z = conv2d(
            h_z, 1, (1, 1), padding='same', name='feature_map_to_pixel',
            channels_last=channels_last)  # output: (1, 28, 28)
        h_z = tf.reshape(h_z, [-1, config.x_dim])
    x_logits = unflatten(h_z, s1, s2)
    return {'logits': x_logits}
def planar_normalizing_flow(
        z, log_qz,
        w_initializer=tf.random_normal_initializer(0., 0.01),
        b_initializer=tf.zeros_initializer(),
        u_initializer=tf.random_normal_initializer(0., 0.01),
        w_regularizer=None, b_regularizer=None, u_regularizer=None,
        trainable=True,
        name='planar_normalizing_flow',
        scope=None):
    """
    Apply a Planar Normalizing Flow transformation along the last axis of `z`.

    .. math::

        z_t = f(z_{t-1}) = z_{t-1} + h(z_{t-1} \\, w_t + b_t) \\, u_t

    with activation function `tanh` as :math:`h`, using the invertibility
    trick from (Danilo 2016).

    Args:
        z: An N-D (N>=2) `float32` Tensor, the samples to be transformed.
        log_qz: An (N-1)-D `float32` Tensor, the log-probabilities of the
            samples.  Its shape should match the first (N-1) dimensions
            of `z`.
        w_initializer: The initializer for parameter `w`.
        b_initializer: The initializer for parameter `b`.
        u_initializer: The initializer for parameter `u`.
        w_regularizer: The regularizer for parameter `w`, optional.
        b_regularizer: The regularizer for parameter `b`, optional.
        u_regularizer: The regularizer for parameter `u`, optional.
        trainable (bool): Whether or not the parameters are trainable.
            (default :obj:`True`)
        name: The default name for the variable scope.
            (default "planar_normalizing_flow")
        scope: The variable scope; will override `name`.

    Returns:
        (tf.Tensor, tf.Tensor): The transformed samples, and the transformed
            log-probability.
    """
    # check `z` and `log_qz`
    z = tf.convert_to_tensor(z)
    log_qz = tf.convert_to_tensor(log_qz)
    dtype = z.dtype.base_dtype
    if not dtype.is_floating:
        raise TypeError('`z` is expected to be a float tensor, but got '
                        'dtype {}.'.format(z.dtype))
    if z.get_shape() is None or len(z.get_shape()) < 2:
        raise ValueError('The rank of `z` must be fixed and must be at '
                         'least 2.')
    n_units = int_shape(z)[-1]
    if n_units is None:
        raise ValueError('The last dimension of `z` must be deterministic.')
    if int_shape(z)[:-1] != int_shape(log_qz):
        raise ValueError(
            'The static shape mismatch between `z` and `log_qz`: {} vs {}'.
            format(z.get_shape()[:-1], log_qz.get_shape())
        )

    # derive the normalizing flow
    with tf.variable_scope(scope, default_name=name):
        # create variables
        w = tf.get_variable(
            'w', shape=[1, n_units], dtype=dtype,
            initializer=w_initializer, regularizer=w_regularizer,
            trainable=trainable
        )
        b = tf.get_variable(
            'b', shape=[1], dtype=dtype,
            initializer=b_initializer, regularizer=b_regularizer,
            trainable=trainable
        )
        u = tf.get_variable(
            'u', shape=[1, n_units], dtype=dtype,
            initializer=u_initializer, regularizer=u_regularizer,
            trainable=trainable
        )

        # flatten z for better performance
        z, s1, s2 = flatten(z, 2)  # z.shape == [?, n_units]

        # enforce invertible mapping
        wu = tf.matmul(w, u, transpose_b=True)  # shape == [1, 1]
        u_hat = u + (-1 + tf.nn.softplus(wu) - wu) * \
            w / tf.reduce_sum(tf.square(w))  # shape == [1, n_units]

        # compute f(z)
        wzb = tf.matmul(z, w, transpose_b=True) + b  # shape == [?, 1]
        tanh_wzb = tf.tanh(wzb)  # shape == [?, 1]
        fz = z + u_hat * tanh_wzb  # shape == [?, n_units]
        fz = unflatten(fz, s1, s2)

        # compute log(det|df/dz|)
        grad = 1. - tf.square(tanh_wzb)  # dtanh(x)/dx = 1 - tanh^2(x)
        phi = grad * w  # shape == [?, n_units]
        u_phi = tf.matmul(phi, u_hat, transpose_b=True)  # shape == [?, 1]
        det_jac = 1. + u_phi  # shape == [?, 1]
        log_det_jac = tf.log(tf.abs(det_jac))  # shape == [?, 1]

        # compute log q(f(z))
        log_q_fz = log_qz - \
            unflatten(tf.squeeze(log_det_jac, -1), s1, s2)

    # return the transformed samples and their log-probability
    return fz, log_q_fz
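# Illustrative sketch (not part of the original flow code).
# The flow above relies on the matrix-determinant lemma:
#     det(I + u_hat phi(z)^T) = 1 + u_hat^T phi(z),
#     phi(z) = (1 - tanh^2(w.z + b)) * w,
# so the log-det-Jacobian costs O(n_units) rather than O(n_units^3).  The
# hypothetical NumPy check below recomputes the same quantity from the full
# Jacobian to show they agree; its names are not from the original code.
import numpy as np

def planar_flow_sketch(z, w, u_hat, b):
    """Apply one planar flow step to a single vector and verify its log-det."""
    wzb = z @ w + b                        # scalar pre-activation
    fz = z + u_hat * np.tanh(wzb)          # transformed sample
    phi = (1. - np.tanh(wzb) ** 2) * w     # gradient of tanh(w.z + b) w.r.t. z
    log_det = np.log(np.abs(1. + u_hat @ phi))
    full_jacobian = np.eye(len(z)) + np.outer(u_hat, phi)
    assert np.allclose(log_det, np.log(np.abs(np.linalg.det(full_jacobian))))
    return fz, log_det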
def gaussian_mixture_prior(y, z_dim, n_clusters, use_concrete):
    if config.p_z_given_y == 'unit':
        if None not in int_shape(y):
            data_shape = int_shape(y)
            if use_concrete:
                data_shape = data_shape[:-1]
            z_shape = data_shape + (z_dim,)
        else:
            data_shape = tf.shape(y)
            if use_concrete:
                data_shape = data_shape[:-1]
            z_shape = tf.concat([data_shape, [z_dim]], axis=0)
        return Normal(mean=tf.zeros(z_shape), std=tf.ones(z_shape))

    elif config.p_z_given_y == 'learnt':
        # derive the learnt z_mean
        prior_mean = tf.get_variable(
            'z_prior_mean', dtype=tf.float32, shape=[n_clusters, z_dim],
            initializer=tf.random_normal_initializer())
        if use_concrete:
            y, s1, s2 = flatten(tf.exp(y), 2)
            z_mean = unflatten(tf.matmul(y, prior_mean), s1, s2)
        else:
            z_mean = tf.nn.embedding_lookup(prior_mean, y)

        # derive the learnt z_std
        z_logstd = z_std = None
        if config.p_z_given_y_std == 'one':
            z_logstd = tf.zeros_like(z_mean)
        else:
            prior_std_or_logstd = tf.get_variable(
                'z_prior_std_or_logstd', dtype=tf.float32,
                shape=[n_clusters, z_dim],
                initializer=tf.zeros_initializer())
            if use_concrete:
                z_std_or_logstd = unflatten(
                    tf.matmul(y, prior_std_or_logstd), s1, s2)
            else:
                z_std_or_logstd = tf.nn.embedding_lookup(
                    prior_std_or_logstd, y)

            if config.p_z_given_y_std == 'one_plus_softplus_std':
                z_std = 1. + tf.nn.softplus(z_std_or_logstd)
            elif config.p_z_given_y_std == 'softplus_logstd':
                z_logstd = tf.nn.softplus(z_std_or_logstd)
            elif config.p_z_given_y_std == 'unbound_logstd':
                z_logstd = z_std_or_logstd
            else:
                raise ValueError(
                    'Unexpected value for config `p_z_given_y_std`: {}'.
                    format(config.p_z_given_y_std))

        return Normal(mean=z_mean, std=z_std, logstd=z_logstd)

    else:
        raise ValueError(
            'Unexpected value for config `p_z_given_y`: {}'.format(
                config.p_z_given_y))
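# Illustrative sketch (not part of the original prior code).
# With the 'learnt' prior, each cluster k owns one row of `z_prior_mean`.
# A hard categorical `y` simply looks its row up, while a concrete-relaxed
# `y` (in log-space) mixes the rows with the soft one-hot weights `exp(y)`;
# the two coincide when `exp(y)` is exactly one-hot.  The NumPy comparison
# below is only an illustration with hypothetical names.
import numpy as np

n_clusters, z_dim = 4, 3
prior_mean = np.random.randn(n_clusters, z_dim)

# hard selection, matching tf.nn.embedding_lookup(prior_mean, y)
y_hard = 2
z_mean_hard = prior_mean[y_hard]

# soft selection, matching tf.matmul(tf.exp(y), prior_mean) for a concrete y
y_soft = np.log(np.array([0., 0., 1., 0.]) + 1e-12)  # near one-hot, in log-space
z_mean_soft = np.exp(y_soft) @ prior_mean

assert np.allclose(z_mean_hard, z_mean_soft, atol=1e-6)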