def disc_noise(x, keep=0.9, temp=5.0):
    """Add noise to an input that is either a continuous value or a
    probability distribution over discrete categorical values.

    Args:

    x: tf.Tensor
        a tensor that will have noise added to it, such that the resulting
        tensor is sound given its definition
    keep: float
        the amount of probability mass to keep on the element that is
        activated; the rest is redistributed evenly to all elements
    temp: float
        the temperature of the Gumbel distribution that is used to corrupt
        the input probabilities x

    Returns:

    noisy_x: tf.Tensor
        a tensor that has noise added to it, which has the interpretation of
        the original tensor (such as a probability distribution)

    """

    p = tf.ones_like(x)
    p = p / tf.reduce_sum(p, axis=-1, keepdims=True)
    p = keep * x + (1.0 - keep) * p
    return tfpd.RelaxedOneHotCategorical(temp, probs=p).sample()
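# Hedged usage sketch for disc_noise (not part of the original source). It
# assumes `tf` is TensorFlow 2.x, `tfpd` aliases
# tensorflow_probability.distributions, and the input is a batch of one-hot or
# probability vectors whose last axis indexes the categories.
def _disc_noise_usage_example():
    x = tf.one_hot([0, 2, 1], depth=4)           # [batch, num_categories]
    noisy_x = disc_noise(x, keep=0.9, temp=5.0)  # same shape; each row remains simplex-valued
    return noisy_x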
def sample(self, y, **kwargs):
    """Generate samples of designs X that have a score y, where y is the
    score that the generator conditions on.

    Args:

    y: tf.Tensor
        a batch of scalar scores wherein the generator is trained to produce
        designs that have score y

    Returns:

    x_fake: tf.Tensor
        a design the generator is trained to sample from a distribution
        conditioned on the score y achieved by that design

    """

    temp = kwargs.pop("temp", 1.0)
    z = tf.random.normal([tf.shape(y)[0], self.latent_size])
    x = tf.cast(z, tf.float32)
    y = tf.cast(y, tf.float32)
    y_embed = self.embed_0(y, **kwargs)

    x = self.dense_0(tf.concat([x, y_embed], 1), **kwargs)
    x = tf.nn.leaky_relu(self.ln_0(x), alpha=0.2)
    x = self.dense_1(tf.concat([x, y_embed], 1), **kwargs)
    x = tf.nn.leaky_relu(self.ln_1(x), alpha=0.2)
    x = self.dense_2(tf.concat([x, y_embed], 1), **kwargs)
    x = tf.nn.leaky_relu(self.ln_2(x), alpha=0.2)
    x = self.dense_3(tf.concat([x, y_embed], 1), **kwargs)

    logits = tf.reshape(x, [tf.shape(y)[0], *self.design_shape])
    return tfpd.RelaxedOneHotCategorical(
        temp, logits=tf.math.log_softmax(logits)).sample()
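# Hedged usage sketch (names below are assumptions, not from the original
# source): given a trained conditional generator `gen` built with the layers
# referenced above (embed_0, dense_0..dense_3, ln_0..ln_2) and attributes
# `latent_size` and `design_shape`, sampling relaxed one-hot designs for a
# batch of target scores might look like:
#
#     target_scores = tf.ones([32, 1])              # desired scores y
#     x_fake = gen.sample(target_scores, temp=0.5)  # [32, *gen.design_shape]
#
# Lowering `temp` pushes the relaxed samples closer to discrete one-hot designs.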
def __call__(self, features):
    raw_init_std = np.log(np.exp(self._init_std) - 1)
    x = features
    for index in range(self._layers):
        x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x)
    if self._dist == 'tanh_normal':
        # https://www.desmos.com/calculator/rcmcf5jwe7
        x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x)
        mean, std = tf.split(x, 2, -1)
        mean = self._mean_scale * tf.tanh(mean / self._mean_scale)
        std = tf.nn.softplus(std + raw_init_std) + self._min_std
        dist = tfd.Normal(mean, std)
        dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
        dist = tfd.Independent(dist, 1)
        dist = tools.SampleDist(dist)
    elif self._dist == 'onehot':
        x = self.get(f'hout', tfkl.Dense, self._size)(x)
        dist = tools.OneHotDist(x)
    elif self._dist == 'gumbel':
        x = self.get(f'hout', tfkl.Dense, self._size)(x)
        dist = tfd.RelaxedOneHotCategorical(temperature=1e-1, logits=x)
        dist = tools.SampleDist(dist)
    else:
        raise NotImplementedError
    return dist
def surrogate_design_matrix_model_fn(
        self, qF_logits_var, qF_temperature_var,
        qF_confounders_loc_var, qF_confounders_softplus_scale_var):
    qF_test = yield JDCRoot(
        Independent(
            tfd.RelaxedOneHotCategorical(
                temperature=qF_temperature_var,
                logits=qF_logits_var)))
def variational_model_fn(self):
    qw = yield JDCRoot(
        Independent(
            tfd.Normal(loc=self.qw_loc_var,
                       scale=tf.nn.softplus(self.qw_softplus_scale_var))))

    qz_scale = yield JDCRoot(
        Independent(
            SoftplusNormal(loc=self.qz_scale_loc_var,
                           scale=tf.nn.softplus(
                               self.qz_scale_softplus_scale_var))))

    qF_test = yield JDCRoot(
        Independent(
            tfd.RelaxedOneHotCategorical(
                temperature=self.qF_temperature_var,
                logits=self.qF_logits_var)))

    # qz = yield JDCRoot(Independent(tfd.Normal(
    #     loc=qz_loc_var,
    #     scale=tf.nn.softplus(qz_softplus_scale_var))))

    qz = yield JDCRoot(Independent(tfd.Deterministic(loc=self.qz_loc_var)))

    qx_bias = yield JDCRoot(
        Independent(
            tfd.Normal(loc=self.qx_bias_loc_var,
                       scale=tf.nn.softplus(
                           self.qx_bias_softplus_scale_var))))

    qx_scale_concentration_c = yield JDCRoot(
        Independent(
            tfd.Deterministic(loc=tf.nn.softplus(
                self.qx_scale_concentration_c_loc_var))))

    qx_scale_mode_c = yield JDCRoot(
        Independent(
            tfd.Deterministic(
                loc=tf.nn.softplus(self.qx_scale_mode_c_loc_var))))

    qx_scale = yield JDCRoot(
        Independent(
            SoftplusNormal(loc=self.qx_scale_loc_var,
                           scale=tf.nn.softplus(
                               self.qx_scale_softplus_scale_var))))

    if self.use_point_estimates:
        qx = yield JDCRoot(
            Independent(tfd.Deterministic(loc=self.qx_loc_var)))
    else:
        qx = yield JDCRoot(
            Independent(
                tfd.Normal(loc=self.qx_loc_var,
                           scale=tf.nn.softplus(
                               self.qx_softplus_scale_var))))

    qrnaseq_reads = yield JDCRoot(
        Independent(tfd.Deterministic(tf.zeros([self.num_samples]))))
def sample(self, time, outputs, state, name=None):
    """Returns `sample_id` of shape `[batch_size, vocab_size]`. If
    `straight_through` is False, these are Gumbel-Softmax distributions over
    the vocabulary with temperature `tau`. If `straight_through` is True,
    these are one-hot vectors of the greedy samples.
    """
    # Draw relaxed (Gumbel-Softmax) samples over the vocabulary.
    sample_ids = tfpd.RelaxedOneHotCategorical(
        self._tau, logits=outputs).sample()
    if self._straight_through:
        size = tf.shape(sample_ids)[-1]
        sample_ids_hard = tf.cast(
            tf.one_hot(tf.argmax(sample_ids, -1), size), sample_ids.dtype)
        sample_ids = tf.stop_gradient(sample_ids_hard - sample_ids) \
            + sample_ids
    return sample_ids
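# Hedged note on the straight-through branch above (a sketch, not original
# code): wrapping `sample_ids_hard - sample_ids` in tf.stop_gradient makes the
# forward pass emit exact one-hot vectors while gradients flow through the
# relaxed Gumbel-Softmax sample, e.g.:
#
#     logits = tf.constant([[2.0, 0.5, -1.0]])
#     soft = tfpd.RelaxedOneHotCategorical(0.5, logits=logits).sample()
#     hard = tf.one_hot(tf.argmax(soft, -1), 3, dtype=soft.dtype)
#     st = tf.stop_gradient(hard - soft) + soft  # forward: hard, backward: soft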
def prepare_output_embed(decoder_output,
                         temperature_starter, decay_rate, decay_steps):
    # [VOCAB x word_dim]
    embeddings = vocab.get_embeddings()

    # Extend embedding matrix to support oov tokens
    unk_id = vocab.get_token_id(vocab.UNKNOWN_TOKEN)
    unk_embed = tf.expand_dims(vocab.embed_tokens(unk_id), 0)
    unk_embeddings = tf.tile(unk_embed, [50, 1])

    # [VOCAB+50 x word_dim]
    embeddings_extended = tf.concat([embeddings, unk_embeddings], axis=0)

    global_step = tf.train.get_global_step()
    temperature = tf.train.exponential_decay(temperature_starter,
                                             global_step,
                                             decay_steps, decay_rate,
                                             name='temperature')
    tf.summary.scalar('temper', temperature, ['extra'])

    # [batch x max_len x VOCAB+50], softmax probabilities
    outputs = decoder.rnn_output(decoder_output)

    # substitute values less than or equal to 0 for numerical stability
    outputs = tf.where(tf.less_equal(outputs, 0),
                       tf.ones_like(outputs) * 1e-10, outputs)

    # convert softmax probabilities to one_hot vectors
    dist = tfd.RelaxedOneHotCategorical(temperature, probs=outputs)

    # [batch x max_len x VOCAB+50], one_hot
    outputs_one_hot = dist.sample()

    # [batch x max_len x word_dim], one_hot^T * embedding_matrix
    outputs_embed = tf.einsum("btv,vd->btd", outputs_one_hot, embeddings_extended)

    return outputs_embed
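# Standalone sketch of the Gumbel-Softmax embedding lookup above (toy shapes;
# assumes `tf` is TensorFlow and `tfd` aliases
# tensorflow_probability.distributions): relaxed one-hot samples over the
# vocabulary are mapped to embeddings with the same "btv,vd->btd" einsum.
def _relaxed_embedding_lookup_example():
    probs = tf.nn.softmax(tf.random.normal([2, 5, 7]))  # [batch, max_len, vocab]
    embedding_matrix = tf.random.normal([7, 3])         # [vocab, word_dim]
    one_hot = tfd.RelaxedOneHotCategorical(0.5, probs=probs).sample()
    return tf.einsum("btv,vd->btd", one_hot, embedding_matrix)  # [2, 5, 3]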
def gumbel_softmax_bottleneck(dist,
                              vector_quantizer,
                              temperature=0.5,
                              num_iaf_flows=0,
                              use_transformer_for_iaf_parameters=False,
                              num_samples=1,
                              sum_over_latents=True,
                              summary=True):
    """Gumbel-Softmax discrete bottleneck.

    Args:
      dist: Distances between encoder outputs and codebook entries, to be used
        as categorical logits. A float Tensor of shape
        [batch_size, latent_size, code_size].
      vector_quantizer: An instance of the VectorQuantizer class.
      temperature: Temperature parameter used for Gumbel-Softmax distribution.
      num_iaf_flows: Number of inverse-autoregressive flows to perform.
      use_transformer_for_iaf_parameters: Whether to use a Transformer instead
        of a lower-triangular mat-mul to generate IAF parameters.
      num_samples: Number of categorical samples.
      sum_over_latents: Whether to sum over latent dimension when computing
        entropy.
      summary: Whether to log summary histogram.

    Returns:
      one_hot_assignments: Simplex-valued assignments sampled from categorical.
      neg_q_entropy: Negative entropy of categorical distribution.
    """
    latent_size = dist.shape[1]

    # TODO(vafa): Consider randomly setting high temperature to help training.
    one_hot_assignments = tfd.RelaxedOneHotCategorical(
        temperature=temperature, logits=-dist).sample(num_samples)
    one_hot_assignments = tf.clip_by_value(one_hot_assignments, 1e-6, 1 - 1e-6)

    # Approximate density with multinomial distribution.
    q_dist = tfd.Multinomial(total_count=1., logits=-dist)
    neg_q_entropy = q_dist.log_prob(one_hot_assignments)
    if summary:
        tf.summary.histogram(
            "neg_q_entropy_0",
            tf.reshape(tf.reduce_sum(neg_q_entropy, axis=-1), [-1]))

    # Perform IAF flows
    for flow_num in range(num_iaf_flows):
        with tf.variable_scope("iaf_variables", reuse=tf.AUTO_REUSE):
            # Pad the one_hot_assignments by zeroing out the first latent
            # dimension and shifting the rest down by one (and removing the
            # last dimension).
            shifted_codes = shift_assignments(one_hot_assignments)

            if use_transformer_for_iaf_parameters:
                unconstrained_scale = iaf_scale_from_transformer(
                    shifted_codes,
                    vector_quantizer.code_size,
                    name=str(flow_num))
            else:
                unconstrained_scale = iaf_scale_from_matmul(
                    shifted_codes, name=str(flow_num))

            # Initialize scale bias to be log(e/2 - 1) so initial
            # scale + scale_bias evaluates to 1.
            initial_scale_bias = tf.fill(
                [latent_size, vector_quantizer.num_codes], INITIAL_SCALE_BIAS)
            scale_bias = tf.get_variable("scale_bias_" + str(flow_num),
                                         initializer=initial_scale_bias)

            one_hot_assignments, inverse_log_det_jacobian = iaf_flow(
                one_hot_assignments,
                unconstrained_scale,
                scale_bias,
                summary=summary)
            neg_q_entropy += inverse_log_det_jacobian

    if sum_over_latents:
        neg_q_entropy = tf.reduce_sum(neg_q_entropy, axis=-1)
    neg_q_entropy = tf.reduce_mean(neg_q_entropy)

    return one_hot_assignments, neg_q_entropy
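# Hedged usage sketch for gumbel_softmax_bottleneck (not original code;
# `vector_quantizer` is assumed to expose `num_codes` and `code_size`, and
# with num_iaf_flows=0 none of the IAF helpers above are invoked):
#
#     distances = tf.random.uniform([8, 16, 64])  # [batch, latent, codes]
#     assignments, neg_entropy = gumbel_softmax_bottleneck(
#         distances, vector_quantizer, temperature=0.5,
#         num_iaf_flows=0, summary=False)
#     # assignments: [num_samples, 8, 16, 64] relaxed one-hot codes
#     # neg_entropy: scalar estimate of the negative assignment entropy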
def construct_model(self, learning_rate=None):
    if learning_rate is None:
        learning_rate = self.learning_rate

    with self.graph.as_default():
        self.sess.close()
        self.sess = tf.compat.v1.InteractiveSession()
        self.sess.as_default()

        self.x = tf.convert_to_tensor(self.rescaled_features, dtype=tf.float32)
        self.y = tf.convert_to_tensor(self.targets, dtype=tf.float32)

        # construct precision rescaling
        self.tau_rescaling = np.zeros((self.num_obs, self.bnn_output_size))
        kernel_ranges = self.config.kernel_ranges
        for obs_index in range(self.num_obs):
            self.tau_rescaling[obs_index] += kernel_ranges
        self.tau_rescaling = self.tau_rescaling**2

        # construct weight and bias shapes
        activations = [tf.nn.tanh]
        weight_shapes, bias_shapes = [[self.feature_size, self.hidden_shape]], [[self.hidden_shape]]
        for _ in range(1, self.num_layers - 1):
            activations.append(tf.nn.tanh)
            weight_shapes.append([self.hidden_shape, self.hidden_shape])
            bias_shapes.append([self.hidden_shape])
        activations.append(lambda x: x)
        weight_shapes.append([self.hidden_shape, self.bnn_output_size])
        bias_shapes.append([self.bnn_output_size])

        # ---------------
        # construct prior
        # ---------------
        self.prior_layer_outputs = [self.x]
        self.priors = {}
        for layer_index in range(self.num_layers):
            weight_shape, bias_shape = weight_shapes[layer_index], bias_shapes[layer_index]
            activation = activations[layer_index]

            weight = tfd.Normal(loc=tf.zeros(weight_shape) + self.weight_loc,
                                scale=tf.zeros(weight_shape) + self.weight_scale)
            bias = tfd.Normal(loc=tf.zeros(bias_shape) + self.bias_loc,
                              scale=tf.zeros(bias_shape) + self.bias_scale)
            self.priors['weight_%d' % layer_index] = weight
            self.priors['bias_%d' % layer_index] = bias

            prior_layer_output = activation(
                tf.matmul(self.prior_layer_outputs[-1], weight.sample()) + bias.sample())
            self.prior_layer_outputs.append(prior_layer_output)

        self.prior_bnn_output = self.prior_layer_outputs[-1]
        # draw precisions from gamma distribution
        self.prior_tau_normed = tfd.Gamma(
            12 * (self.num_obs / self.frac_feas)**2 + tf.zeros((self.num_obs, self.bnn_output_size)),
            tf.ones((self.num_obs, self.bnn_output_size)),
        )
        self.prior_tau = self.prior_tau_normed.sample() / self.tau_rescaling
        self.prior_scale = tfd.Deterministic(1. / tf.sqrt(self.prior_tau))

        # -------------------
        # construct posterior
        # -------------------
        self.post_layer_outputs = [self.x]
        self.posteriors = {}
        for layer_index in range(self.num_layers):
            weight_shape, bias_shape = weight_shapes[layer_index], bias_shapes[layer_index]
            activation = activations[layer_index]

            weight = tfd.Normal(loc=tf.Variable(tf.random.normal(weight_shape)),
                                scale=tf.nn.softplus(tf.Variable(tf.zeros(weight_shape))))
            bias = tfd.Normal(loc=tf.Variable(tf.random.normal(bias_shape)),
                              scale=tf.nn.softplus(tf.Variable(tf.zeros(bias_shape))))
            self.posteriors['weight_%d' % layer_index] = weight
            self.posteriors['bias_%d' % layer_index] = bias

            post_layer_output = activation(
                tf.matmul(self.post_layer_outputs[-1], weight.sample()) + bias.sample())
            self.post_layer_outputs.append(post_layer_output)

        self.post_bnn_output = self.post_layer_outputs[-1]
        self.post_tau_normed = tfd.Gamma(
            12 * (self.num_obs / self.frac_feas)**2 + tf.Variable(tf.zeros((self.num_obs, self.bnn_output_size))),
            tf.nn.softplus(tf.Variable(tf.ones((self.num_obs, self.bnn_output_size)))),
        )
        self.post_tau = self.post_tau_normed.sample() / self.tau_rescaling
        self.post_sqrt_tau = tf.sqrt(self.post_tau)
        self.post_scale = tfd.Deterministic(1. / self.post_sqrt_tau)

        # map bnn output to prediction
        post_kernels = {}
        targets_dict = {}
        inferences = []

        target_element_index = 0
        kernel_element_index = 0

        while kernel_element_index < len(self.config.kernel_names):

            kernel_type = self.config.kernel_types[kernel_element_index]
            kernel_size = self.config.kernel_sizes[kernel_element_index]

            feature_begin, feature_end = target_element_index, target_element_index + 1
            kernel_begin, kernel_end = kernel_element_index, kernel_element_index + kernel_size

            prior_relevant = self.prior_bnn_output[:, kernel_begin:kernel_end]
            post_relevant = self.post_bnn_output[:, kernel_begin:kernel_end]

            if kernel_type == 'continuous':
                target = self.y[:, kernel_begin:kernel_end]
                lowers = self.config.kernel_lowers[kernel_begin:kernel_end]
                uppers = self.config.kernel_uppers[kernel_begin:kernel_end]

                prior_support = (uppers - lowers) * (1.2 * tf.nn.sigmoid(prior_relevant) - 0.1) + lowers
                post_support = (uppers - lowers) * (1.2 * tf.nn.sigmoid(post_relevant) - 0.1) + lowers

                prior_predict = tfd.Normal(prior_support,
                                           self.prior_scale[:, kernel_begin:kernel_end].sample())
                post_predict = tfd.Normal(post_support,
                                          self.post_scale[:, kernel_begin:kernel_end].sample())

                targets_dict[prior_predict] = target
                post_kernels['param_%d' % target_element_index] = {
                    'loc': tfd.Deterministic(post_support),
                    'sqrt_prec': tfd.Deterministic(self.post_sqrt_tau[:, kernel_begin:kernel_end]),
                    'scale': tfd.Deterministic(self.post_scale[:, kernel_begin:kernel_end].sample())}

                inference = {'pred': post_predict, 'target': target}
                inferences.append(inference)

            elif kernel_type in ['categorical', 'discrete']:
                target = tf.cast(self.y[:, kernel_begin:kernel_end], tf.int32)

                prior_temperature = 0.5 + 10.0 / (self.num_obs / self.frac_feas)
                # prior_temperature = 1.0
                post_temperature = prior_temperature

                prior_support = prior_relevant
                post_support = post_relevant

                prior_predict_relaxed = tfd.RelaxedOneHotCategorical(prior_temperature, prior_support)
                prior_predict = tfd.OneHotCategorical(probs=prior_predict_relaxed.sample())

                post_predict_relaxed = tfd.RelaxedOneHotCategorical(post_temperature, post_support)
                post_predict = tfd.OneHotCategorical(probs=post_predict_relaxed.sample())

                targets_dict[prior_predict] = target
                post_kernels['param_%d' % target_element_index] = {'probs': post_predict_relaxed}

                inference = {'pred': post_predict, 'target': target}
                inferences.append(inference)

                '''
                Temperature annealing schedule:
                    - temperature of 100 yields 1e-2 deviation from uniform
                    - temperature of 10 yields 1e-1 deviation from uniform
                    - temperature of 1 yields *almost* perfect agreement with expectation
                    - temperature of 0.1 yields perfect agreement with expectation
                '''

            else:
                GryffinUnknownSettingsError(f'did not understand kernel type: {kernel_type}')

            target_element_index += 1
            kernel_element_index += kernel_size

        self.post_kernels = post_kernels
        self.targets_dict = targets_dict

        self.loss = 0.
        for inference in inferences:
            self.loss += -tf.reduce_sum(inference['pred'].log_prob(inference['target']))

        self.optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate)
        self.train_op = self.optimizer.minimize(self.loss)

        tf.compat.v1.global_variables_initializer().run()
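# Worked example of the categorical-kernel temperature above (a note, not
# original code): with e.g. num_obs = 100 observations and frac_feas = 1.0,
# prior_temperature = 0.5 + 10.0 / (100 / 1.0) = 0.6. As more observations
# accumulate, the temperature approaches 0.5 from above, so the Gumbel-Softmax
# relaxation sharpens toward discrete one-hot samples.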
def __init__(self, logits=None, probs=None, temperature=1.0):
    # temperature is the Gumbel-Softmax relaxation temperature; it is a
    # required argument of tfd.RelaxedOneHotCategorical, so a default is
    # exposed here.
    self._dist = tfd.RelaxedOneHotCategorical(temperature,
                                              logits=logits, probs=probs)
    self._num_classes = self.mean().shape[-1]
    self._dtype = prec.global_policy().compute_dtype
def decoder_outputs_to_edit_vector(decoder_output, temperature_starter,
                                   decay_rate, decay_steps, edit_dim,
                                   enc_hidden_dim, enc_num_layers,
                                   dense_layers, swap_memory):
    with tf.variable_scope(OPS_NAME):
        # [VOCAB x word_dim]
        embeddings = vocab.get_embeddings()

        # Extend embedding matrix to support oov tokens
        unk_id = vocab.get_token_id(vocab.UNKNOWN_TOKEN)
        unk_embed = tf.expand_dims(vocab.embed_tokens(unk_id), 0)
        unk_embeddings = tf.tile(unk_embed, [50, 1])

        # [VOCAB+50 x word_dim]
        embeddings_extended = tf.concat([embeddings, unk_embeddings], axis=0)

        global_step = tf.train.get_global_step()
        temperature = tf.train.exponential_decay(temperature_starter,
                                                 global_step,
                                                 decay_steps, decay_rate,
                                                 name='temperature')
        tf.summary.scalar('temper', temperature, ['extra'])

        # [batch x max_len x VOCAB+50], softmax probabilities
        outputs = decoder.rnn_output(decoder_output)

        # substitute values less than or equal to 0 for numerical stability
        outputs = tf.where(tf.less_equal(outputs, 0),
                           tf.ones_like(outputs) * 1e-10, outputs)

        # convert softmax probabilities to one_hot vectors
        dist = tfd.RelaxedOneHotCategorical(temperature, probs=outputs)

        # [batch x max_len x VOCAB+50], one_hot
        outputs_one_hot = dist.sample()

        # [batch x max_len x word_dim], one_hot^T * embedding_matrix
        outputs_embed = tf.einsum("btv,vd->btd", outputs_one_hot, embeddings_extended)

        # [batch]
        outputs_length = decoder.seq_length(decoder_output)

        # [batch x max_len x hidden], [batch x hidden]
        hidden_states, sentence_embedding = encoder.source_sent_encoder(
            outputs_embed,
            outputs_length,
            enc_hidden_dim,
            enc_num_layers,
            use_dropout=False,
            dropout_keep=1.0,
            swap_memory=swap_memory)

        h = sentence_embedding
        for l in dense_layers:
            h = tf.layers.dense(h, l, activation='relu', name='hidden_%s' % (l))

        # [batch x edit_dim]
        edit_vector = tf.layers.dense(h, edit_dim, activation=None, name='linear')

    return edit_vector