Example #1
def disc_noise(x, keep=0.9, temp=5.0):
    """Add noise to a input that is either a continuous value or a probability
    distribution over discrete categorical values

    Args:

    x: tf.Tensor
        a tensor that will have noise added to it, such that the resulting
        tensor remains valid under its original interpretation
    keep: float
        the amount of probability mass to keep on the element that is activated;
        the rest is redistributed evenly to all elements
    temp: float
        the temperature of the Gumbel distribution that is used to corrupt
        the input probabilities x

    Returns:

    noisy_x: tf.Tensor
        a tensor that has noise added to it, which has the interpretation of
        the original tensor (such as a probability distribution)
    """

    p = tf.ones_like(x)
    p = p / tf.reduce_sum(p, axis=-1, keepdims=True)
    p = keep * x + (1.0 - keep) * p
    return tfpd.RelaxedOneHotCategorical(temp, probs=p).sample()
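
A minimal usage sketch (assuming `tf` is TensorFlow, `tfpd` is `tensorflow_probability.distributions`, and `disc_noise` is the function defined above): corrupt a batch of one-hot activations so that every row remains a valid probability distribution.

import tensorflow as tf
import tensorflow_probability as tfp
tfpd = tfp.distributions

# a batch of 3 one-hot activations over 4 categories
x = tf.one_hot([0, 2, 1], depth=4)

# each row of noisy_x is still a distribution over the 4 categories:
# keep=0.9 leaves most mass on the active element, temp=5.0 controls
# how strongly the Gumbel noise smooths the sample
noisy_x = disc_noise(x, keep=0.9, temp=5.0)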
Example #2
    def sample(self, y, **kwargs):
        """Generate samples of designs X that have a score y where y is
        the score that the generator conditions on

        Args:

        y: tf.Tensor
            a batch of scalar scores; the generator is trained to produce
            designs that achieve score y

        Returns:

        x_fake: tf.Tensor
            a design sampled by the generator from a distribution conditioned
            on the score y that the design is intended to achieve
        """

        temp = kwargs.pop("temp", 1.0)
        z = tf.random.normal([tf.shape(y)[0], self.latent_size])
        x = tf.cast(z, tf.float32)
        y = tf.cast(y, tf.float32)

        y_embed = self.embed_0(y, **kwargs)
        x = self.dense_0(tf.concat([x, y_embed], 1), **kwargs)
        x = tf.nn.leaky_relu(self.ln_0(x), alpha=0.2)
        x = self.dense_1(tf.concat([x, y_embed], 1), **kwargs)

        x = tf.nn.leaky_relu(self.ln_1(x), alpha=0.2)
        x = self.dense_2(tf.concat([x, y_embed], 1), **kwargs)
        x = tf.nn.leaky_relu(self.ln_2(x), alpha=0.2)
        x = self.dense_3(tf.concat([x, y_embed], 1), **kwargs)

        logits = tf.reshape(x, [tf.shape(y)[0], *self.design_shape])
        return tfpd.RelaxedOneHotCategorical(
            temp, logits=tf.math.log_softmax(logits)).sample()
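
The last two lines are the part specific to RelaxedOneHotCategorical: the raw network outputs are normalized with log_softmax and used as the logits of the relaxed distribution. A standalone sketch of just that step, with made-up shapes:

import tensorflow as tf
import tensorflow_probability as tfp
tfpd = tfp.distributions

logits = tf.random.normal([2, 4])                      # hypothetical raw decoder outputs
x_fake = tfpd.RelaxedOneHotCategorical(
    1.0, logits=tf.math.log_softmax(logits)).sample()  # simplex-valued samples, shape [2, 4]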
Example #3
 def __call__(self, features):
     raw_init_std = np.log(np.exp(self._init_std) - 1)
     x = features
     for index in range(self._layers):
         x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x)
     if self._dist == 'tanh_normal':
         # https://www.desmos.com/calculator/rcmcf5jwe7
         x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x)
         mean, std = tf.split(x, 2, -1)
         mean = self._mean_scale * tf.tanh(mean / self._mean_scale)
         std = tf.nn.softplus(std + raw_init_std) + self._min_std
         dist = tfd.Normal(mean, std)
         dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
         dist = tfd.Independent(dist, 1)
         dist = tools.SampleDist(dist)
     elif self._dist == 'onehot':
         x = self.get(f'hout', tfkl.Dense, self._size)(x)
         dist = tools.OneHotDist(x)
     elif self._dist == 'gumbel':
         x = self.get(f'hout', tfkl.Dense, self._size)(x)
         dist = tfd.RelaxedOneHotCategorical(temperature=1e-1, logits=x)
         dist = tools.SampleDist(dist)
     else:
         raise NotImplementedError
     return dist
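
As a point of reference for the 'gumbel' branch (a standalone sketch, not code from this repository): with a low temperature such as 1e-1, the relaxed samples are close to one-hot while remaining differentiable with respect to the logits.

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

logits = tf.constant([[2.0, 0.5, -1.0]])     # hypothetical action logits
dist = tfd.RelaxedOneHotCategorical(temperature=1e-1, logits=logits)
action = dist.sample()                       # shape [1, 3], nearly one-hot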
Example #4
 def surrogate_design_matrix_model_fn(self, qF_logits_var,
                                      qF_temperature_var,
                                      qF_confounders_loc_var,
                                      qF_confounders_softplus_scale_var):
     qF_test = yield JDCRoot(
         Independent(
             tfd.RelaxedOneHotCategorical(temperature=qF_temperature_var,
                                          logits=qF_logits_var)))
Example #5
    def variational_model_fn(self):
        qw = yield JDCRoot(
            Independent(
                tfd.Normal(loc=self.qw_loc_var,
                           scale=tf.nn.softplus(self.qw_softplus_scale_var))))

        qz_scale = yield JDCRoot(
            Independent(
                SoftplusNormal(loc=self.qz_scale_loc_var,
                               scale=tf.nn.softplus(
                                   self.qz_scale_softplus_scale_var))))

        qF_test = yield JDCRoot(
            Independent(
                tfd.RelaxedOneHotCategorical(
                    temperature=self.qF_temperature_var,
                    logits=self.qF_logits_var)))

        # qz = yield JDCRoot(Independent(tfd.Normal(
        #     loc=qz_loc_var,
        #     scale=tf.nn.softplus(qz_softplus_scale_var))))
        qz = yield JDCRoot(Independent(tfd.Deterministic(loc=self.qz_loc_var)))

        qx_bias = yield JDCRoot(
            Independent(
                tfd.Normal(loc=self.qx_bias_loc_var,
                           scale=tf.nn.softplus(
                               self.qx_bias_softplus_scale_var))))

        qx_scale_concentration_c = yield JDCRoot(
            Independent(
                tfd.Deterministic(loc=tf.nn.softplus(
                    self.qx_scale_concentration_c_loc_var))))

        qx_scale_mode_c = yield JDCRoot(
            Independent(
                tfd.Deterministic(
                    loc=tf.nn.softplus(self.qx_scale_mode_c_loc_var))))

        qx_scale = yield JDCRoot(
            Independent(
                SoftplusNormal(loc=self.qx_scale_loc_var,
                               scale=tf.nn.softplus(
                                   self.qx_scale_softplus_scale_var))))

        if self.use_point_estimates:
            qx = yield JDCRoot(
                Independent(tfd.Deterministic(loc=self.qx_loc_var)))
        else:
            qx = yield JDCRoot(
                Independent(
                    tfd.Normal(loc=self.qx_loc_var,
                               scale=tf.nn.softplus(
                                   self.qx_softplus_scale_var))))
            qrnaseq_reads = yield JDCRoot(
                Independent(tfd.Deterministic(tf.zeros([self.num_samples]))))
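
The yield pattern above follows TFP's coroutine-style joint distributions. A minimal, generic sketch of the same idiom using the stock tfd.JointDistributionCoroutine.Root (the JDCRoot, Independent, and SoftplusNormal wrappers from this codebase are not reproduced here):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

def variational_model_fn():
    # each yield contributes one factor to the surrogate posterior
    qw = yield tfd.JointDistributionCoroutine.Root(
        tfd.Independent(tfd.Normal(loc=tf.zeros([3]), scale=tf.ones([3])),
                        reinterpreted_batch_ndims=1))
    qF = yield tfd.JointDistributionCoroutine.Root(
        tfd.Independent(
            tfd.RelaxedOneHotCategorical(temperature=0.5,
                                         logits=tf.zeros([3, 4])),
            reinterpreted_batch_ndims=1))

surrogate = tfd.JointDistributionCoroutine(variational_model_fn)
qw_sample, qF_sample = surrogate.sample()    # one sample per yielded factor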
Example #6
 def sample(self, time, outputs, state, name=None):
     """Returns `sample_id` of shape `[batch_size, vocab_size]`. If
     `straight_through` is False, this is gumbel softmax distributions over
     vocabulary with temperature `tau`. If `straight_through` is True,
     this is one-hot vectors of the greedy samples.
     """
     sample_ids = tfpd.RelaxedOneHotCategorical(
         self._tau, logits=outputs).sample()
     if self._straight_through:
         size = tf.shape(sample_ids)[-1]
         sample_ids_hard = tf.cast(
             tf.one_hot(tf.argmax(sample_ids, -1), size), sample_ids.dtype)
         sample_ids = tf.stop_gradient(sample_ids_hard - sample_ids) \
                      + sample_ids
     return sample_ids
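
The same straight-through trick as a self-contained sketch (illustrative names, outside the helper class): the forward pass returns a hard one-hot vector while gradients flow through the relaxed sample.

import tensorflow as tf
import tensorflow_probability as tfp
tfpd = tfp.distributions

def straight_through_gumbel(logits, tau=1.0):
    # soft, differentiable sample from the Gumbel-Softmax distribution
    soft = tfpd.RelaxedOneHotCategorical(tau, logits=logits).sample()
    # hard one-hot of the argmax of the soft sample
    hard = tf.one_hot(tf.argmax(soft, axis=-1), tf.shape(soft)[-1],
                      dtype=soft.dtype)
    # forward value is `hard`; the gradient is taken through `soft`
    return tf.stop_gradient(hard - soft) + soft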
Example #7
def prepare_output_embed(
    decoder_output,
    temperature_starter,
    decay_rate,
    decay_steps,
):
    # [VOCAB x word_dim]
    embeddings = vocab.get_embeddings()

    # Extend embedding matrix to support oov tokens
    unk_id = vocab.get_token_id(vocab.UNKNOWN_TOKEN)
    unk_embed = tf.expand_dims(vocab.embed_tokens(unk_id), 0)
    unk_embeddings = tf.tile(unk_embed, [50, 1])

    # [VOCAB+50 x word_dim]
    embeddings_extended = tf.concat([embeddings, unk_embeddings], axis=0)

    global_step = tf.train.get_global_step()
    temperature = tf.train.exponential_decay(temperature_starter,
                                             global_step,
                                             decay_steps,
                                             decay_rate,
                                             name='temperature')
    tf.summary.scalar('temper', temperature, ['extra'])

    # [batch x max_len x VOCAB+50], softmax probabilities
    outputs = decoder.rnn_output(decoder_output)

    # replace zero (or negative) probabilities with a small constant for numerical stability
    outputs = tf.where(tf.less_equal(outputs, 0),
                       tf.ones_like(outputs) * 1e-10, outputs)

    # draw soft (relaxed) one-hot vectors from the softmax probabilities
    dist = tfd.RelaxedOneHotCategorical(temperature, probs=outputs)

    # [batch x max_len x VOCAB+50], one_hot
    outputs_one_hot = dist.sample()

    # [batch x max_len x word_dim], one_hot^T * embedding_matrix
    outputs_embed = tf.einsum("btv,vd-> btd", outputs_one_hot,
                              embeddings_extended)

    return outputs_embed
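
Reduced to a standalone sketch with made-up shapes, the core of the function above is: draw soft one-hot vectors from the decoder probabilities and project them through the embedding matrix, which keeps the whole pipeline differentiable.

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

# hypothetical shapes: batch=2, max_len=3, vocab=5, word_dim=4
probs = tf.nn.softmax(tf.random.normal([2, 3, 5]), axis=-1)
embedding_matrix = tf.random.normal([5, 4])

soft_one_hot = tfd.RelaxedOneHotCategorical(0.5, probs=probs).sample()
outputs_embed = tf.einsum("btv,vd->btd", soft_one_hot, embedding_matrix)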
Example #8
def gumbel_softmax_bottleneck(dist,
                              vector_quantizer,
                              temperature=0.5,
                              num_iaf_flows=0,
                              use_transformer_for_iaf_parameters=False,
                              num_samples=1,
                              sum_over_latents=True,
                              summary=True):
  """Gumbel-Softmax discrete bottleneck.

  Args:
    dist: Distances between encoder outputs and codebook entries, to be used as
      categorical logits. A float Tensor of shape [batch_size, latent_size,
      code_size].
    vector_quantizer: An instance of the VectorQuantizer class.
    temperature: Temperature parameter used for Gumbel-Softmax distribution.
    num_iaf_flows: Number of inverse-autoregressive flows to perform.
    use_transformer_for_iaf_parameters: Whether to use a Transformer instead of
      a lower-triangular mat-mul to generate IAF parameters.
    num_samples: Number of categorical samples.
    sum_over_latents: Whether to sum over latent dimension when computing
      entropy.
    summary: Whether to log summary histogram.

  Returns:
    one_hot_assignments: Simplex-valued assignments sampled from categorical.
    neg_q_entropy: Negative entropy of categorical distribution.
  """
  latent_size = dist.shape[1]
  # TODO(vafa): Consider randomly setting high temperature to help training.
  one_hot_assignments = tfd.RelaxedOneHotCategorical(
      temperature=temperature,
      logits=-dist).sample(num_samples)
  one_hot_assignments = tf.clip_by_value(one_hot_assignments, 1e-6, 1-1e-6)

  # Approximate density with multinomial distribution.
  q_dist = tfd.Multinomial(total_count=1., logits=-dist)
  neg_q_entropy = q_dist.log_prob(one_hot_assignments)
  if summary:
    tf.summary.histogram("neg_q_entropy_0",
                         tf.reshape(tf.reduce_sum(neg_q_entropy, axis=-1),
                                    [-1]))

  # Perform IAF flows
  for flow_num in range(num_iaf_flows):
    with tf.variable_scope("iaf_variables", reuse=tf.AUTO_REUSE):
      # Pad the one_hot_assignments by zeroing out the first latent dimension
      # and shifting the rest down by one (and removing the last dimension).
      shifted_codes = shift_assignments(one_hot_assignments)
      if use_transformer_for_iaf_parameters:
        unconstrained_scale = iaf_scale_from_transformer(
            shifted_codes,
            vector_quantizer.code_size,
            name=str(flow_num))
      else:
        unconstrained_scale = iaf_scale_from_matmul(shifted_codes,
                                                    name=str(flow_num))
      # Initialize scale bias to be log(e/2 - 1) so initial scale + scale_bias
      # evaluates to 1.
      initial_scale_bias = tf.fill([latent_size, vector_quantizer.num_codes],
                                   INITIAL_SCALE_BIAS)
      scale_bias = tf.get_variable("scale_bias_" + str(flow_num),
                                   initializer=initial_scale_bias)

    one_hot_assignments, inverse_log_det_jacobian = iaf_flow(
        one_hot_assignments,
        unconstrained_scale,
        scale_bias,
        summary=summary)
    neg_q_entropy += inverse_log_det_jacobian

  if sum_over_latents:
    neg_q_entropy = tf.reduce_sum(neg_q_entropy, axis=-1)
  neg_q_entropy = tf.reduce_mean(neg_q_entropy)
  return one_hot_assignments, neg_q_entropy
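
The sampling and entropy-approximation steps of the bottleneck, isolated into a toy-shaped sketch (no VectorQuantizer, no IAF flows):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

# hypothetical shapes: batch_size=2, latent_size=4, num_codes=8
dist = tf.random.uniform([2, 4, 8])          # distances, used as negated logits

one_hot_assignments = tfd.RelaxedOneHotCategorical(
    temperature=0.5, logits=-dist).sample(1)
one_hot_assignments = tf.clip_by_value(one_hot_assignments, 1e-6, 1. - 1e-6)

# entropy of the relaxed assignments is approximated with a Multinomial
# density, exactly as in the function above
neg_q_entropy = tfd.Multinomial(total_count=1., logits=-dist).log_prob(
    one_hot_assignments)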
Example #9
    def construct_model(self, learning_rate=None):

        if learning_rate is None:
            learning_rate = self.learning_rate

        with self.graph.as_default():

            self.sess.close()
            self.sess = tf.compat.v1.InteractiveSession()
            self.sess.as_default()

            self.x = tf.convert_to_tensor(self.rescaled_features, dtype=tf.float32)
            self.y = tf.convert_to_tensor(self.targets, dtype=tf.float32)

            # construct precision rescaling factors
            self.tau_rescaling = np.zeros((self.num_obs, self.bnn_output_size))
            kernel_ranges      = self.config.kernel_ranges
            for obs_index in range(self.num_obs):
                self.tau_rescaling[obs_index] += kernel_ranges
            self.tau_rescaling = self.tau_rescaling**2

            # construct weight and bias shapes
            activations = [tf.nn.tanh]
            weight_shapes, bias_shapes = [[self.feature_size, self.hidden_shape]], [[self.hidden_shape]]
            for _ in range(1, self.num_layers - 1):
                activations.append(tf.nn.tanh)
                weight_shapes.append([self.hidden_shape, self.hidden_shape])
                bias_shapes.append([self.hidden_shape])
            activations.append(lambda x: x)
            weight_shapes.append([self.hidden_shape, self.bnn_output_size])
            bias_shapes.append([self.bnn_output_size])

            # ---------------
            # construct prior
            # ---------------
            self.prior_layer_outputs = [self.x]
            self.priors = {}
            for layer_index in range(self.num_layers):
                weight_shape, bias_shape = weight_shapes[layer_index], bias_shapes[layer_index]
                activation = activations[layer_index]

                weight = tfd.Normal(loc=tf.zeros(weight_shape) + self.weight_loc, scale=tf.zeros(weight_shape) + self.weight_scale)
                bias = tfd.Normal(loc=tf.zeros(bias_shape) + self.bias_loc, scale=tf.zeros(bias_shape) + self.bias_scale)
                self.priors['weight_%d' % layer_index] = weight
                self.priors['bias_%d' % layer_index] = bias

                prior_layer_output = activation(tf.matmul(self.prior_layer_outputs[-1], weight.sample()) + bias.sample())
                self.prior_layer_outputs.append(prior_layer_output)

            self.prior_bnn_output = self.prior_layer_outputs[-1]
            # draw precisions from gamma distribution
            self.prior_tau_normed = tfd.Gamma(
                            12*(self.num_obs/self.frac_feas)**2 + tf.zeros((self.num_obs, self.bnn_output_size)),
                            tf.ones((self.num_obs, self.bnn_output_size)),
                        )
            self.prior_tau        = self.prior_tau_normed.sample() / self.tau_rescaling
            self.prior_scale      = tfd.Deterministic(1. / tf.sqrt(self.prior_tau))

            # -------------------
            # construct posterior
            # -------------------
            self.post_layer_outputs = [self.x]
            self.posteriors = {}
            for layer_index in range(self.num_layers):
                weight_shape, bias_shape = weight_shapes[layer_index], bias_shapes[layer_index]
                activation = activations[layer_index]

                weight = tfd.Normal(loc=tf.Variable(tf.random.normal(weight_shape)), scale=tf.nn.softplus(tf.Variable(tf.zeros(weight_shape))))
                bias = tfd.Normal(loc=tf.Variable(tf.random.normal(bias_shape)), scale=tf.nn.softplus(tf.Variable(tf.zeros(bias_shape))))

                self.posteriors['weight_%d' % layer_index] = weight
                self.posteriors['bias_%d' % layer_index] = bias

                post_layer_output = activation(tf.matmul(self.post_layer_outputs[-1], weight.sample()) + bias.sample())
                self.post_layer_outputs.append(post_layer_output)

            self.post_bnn_output = self.post_layer_outputs[-1]
            self.post_tau_normed = tfd.Gamma(
                                12*(self.num_obs/self.frac_feas)**2 + tf.Variable(tf.zeros((self.num_obs, self.bnn_output_size))),
                                tf.nn.softplus(tf.Variable(tf.ones((self.num_obs, self.bnn_output_size)))),
                            )
            self.post_tau        = self.post_tau_normed.sample() / self.tau_rescaling
            self.post_sqrt_tau   = tf.sqrt(self.post_tau)
            self.post_scale      = tfd.Deterministic(1. / self.post_sqrt_tau)

            # map bnn output to prediction
            post_kernels = {}
            targets_dict = {}
            inferences = []

            target_element_index = 0
            kernel_element_index = 0

            while kernel_element_index < len(self.config.kernel_names):

                kernel_type = self.config.kernel_types[kernel_element_index]
                kernel_size = self.config.kernel_sizes[kernel_element_index]

                feature_begin, feature_end = target_element_index, target_element_index + 1
                kernel_begin, kernel_end   = kernel_element_index, kernel_element_index + kernel_size

                prior_relevant = self.prior_bnn_output[:, kernel_begin: kernel_end]
                post_relevant  = self.post_bnn_output[:,  kernel_begin: kernel_end]

                if kernel_type == 'continuous':

                    target = self.y[:, kernel_begin: kernel_end]
                    lowers, uppers = self.config.kernel_lowers[kernel_begin: kernel_end], self.config.kernel_uppers[kernel_begin : kernel_end]

                    prior_support = (uppers - lowers) * (1.2 * tf.nn.sigmoid(prior_relevant) - 0.1) + lowers
                    post_support = (uppers - lowers) * (1.2 * tf.nn.sigmoid(post_relevant) - 0.1) + lowers

                    prior_predict = tfd.Normal(prior_support, self.prior_scale[:, kernel_begin: kernel_end].sample())
                    post_predict = tfd.Normal(post_support,  self.post_scale[:,  kernel_begin: kernel_end].sample())

                    targets_dict[prior_predict] = target
                    post_kernels['param_%d' % target_element_index] = {
                        'loc':       tfd.Deterministic(post_support),
                        'sqrt_prec': tfd.Deterministic(self.post_sqrt_tau[:, kernel_begin: kernel_end]),
                        'scale':     tfd.Deterministic(self.post_scale[:, kernel_begin: kernel_end].sample())}

                    inference = {'pred': post_predict, 'target': target}
                    inferences.append(inference)

                elif kernel_type in ['categorical', 'discrete']:
                    target = tf.cast(self.y[:, kernel_begin: kernel_end], tf.int32)

                    prior_temperature = 0.5 + 10.0 / (self.num_obs / self.frac_feas)
                    #prior_temperature = 1.0
                    post_temperature = prior_temperature

                    prior_support = prior_relevant
                    post_support = post_relevant

                    prior_predict_relaxed = tfd.RelaxedOneHotCategorical(prior_temperature, prior_support)
                    prior_predict = tfd.OneHotCategorical(probs=prior_predict_relaxed.sample())

                    post_predict_relaxed = tfd.RelaxedOneHotCategorical(post_temperature, post_support)
                    post_predict = tfd.OneHotCategorical(probs=post_predict_relaxed.sample())

                    targets_dict[prior_predict] = target
                    post_kernels['param_%d' % target_element_index] = {'probs': post_predict_relaxed}

                    inference = {'pred': post_predict, 'target': target}
                    inferences.append(inference)

                    '''
                        Temperature annealing schedule:
                            - temperature of 100   yields 1e-2 deviation from uniform
                            - temperature of  10   yields 1e-1 deviation from uniform
                            - temperature of   1   yields *almost* perfect agreement with expectation
                            - temperature of   0.1 yields perfect agreement with expectation
                    '''

                else:
                    raise GryffinUnknownSettingsError(f'did not understand kernel type: {kernel_type}')

                target_element_index += 1
                kernel_element_index += kernel_size

            self.post_kernels = post_kernels
            self.targets_dict = targets_dict

            self.loss = 0.
            for inference in inferences:
                self.loss += - tf.reduce_sum(inference['pred'].log_prob(inference['target']))

            self.optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate)
            self.train_op = self.optimizer.minimize(self.loss)

            tf.compat.v1.global_variables_initializer().run()
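
For the categorical kernels, the relevant pattern stripped of the Gryffin-specific plumbing is roughly this (toy shapes): a relaxed sample provides the probability vector of an OneHotCategorical, whose log-probability of the observed one-hot target enters the loss.

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

logits = tf.random.normal([4, 3])            # hypothetical: 4 observations, 3 categories
relaxed = tfd.RelaxedOneHotCategorical(0.5, logits=logits)
predict = tfd.OneHotCategorical(probs=relaxed.sample())

targets = tf.one_hot([0, 2, 1, 0], depth=3)  # observed categories, one-hot encoded
loss = -tf.reduce_sum(predict.log_prob(targets))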
Example #10
 def __init__(self, logits=None, probs=None):
     self._dist = tfd.RelaxedOneHotCategorical(logits=logits, probs=probs)
     self._num_classes = self.mean().shape[-1]
     self._dtype = prec.global_policy().compute_dtype
Example #11
def decoder_outputs_to_edit_vector(decoder_output, temperature_starter,
                                   decay_rate, decay_steps, edit_dim,
                                   enc_hidden_dim, enc_num_layers,
                                   dense_layers, swap_memory):
    with tf.variable_scope(OPS_NAME):
        # [VOCAB x word_dim]
        embeddings = vocab.get_embeddings()

        # Extend embedding matrix to support oov tokens
        unk_id = vocab.get_token_id(vocab.UNKNOWN_TOKEN)
        unk_embed = tf.expand_dims(vocab.embed_tokens(unk_id), 0)
        unk_embeddings = tf.tile(unk_embed, [50, 1])

        # [VOCAB+50 x word_dim]
        embeddings_extended = tf.concat([embeddings, unk_embeddings], axis=0)

        global_step = tf.train.get_global_step()
        temperature = tf.train.exponential_decay(temperature_starter,
                                                 global_step,
                                                 decay_steps,
                                                 decay_rate,
                                                 name='temperature')
        tf.summary.scalar('temper', temperature, ['extra'])

        # [batch x max_len x VOCAB+50], softmax probabilities
        outputs = decoder.rnn_output(decoder_output)

        # replace zero (or negative) probabilities with a small constant for numerical stability
        outputs = tf.where(tf.less_equal(outputs, 0),
                           tf.ones_like(outputs) * 1e-10, outputs)

        # draw soft (relaxed) one-hot vectors from the softmax probabilities
        dist = tfd.RelaxedOneHotCategorical(temperature, probs=outputs)

        # [batch x max_len x VOCAB+50], one_hot
        outputs_one_hot = dist.sample()

        # [batch x max_len x word_dim], one_hot^T * embedding_matrix
        outputs_embed = tf.einsum("btv,vd-> btd", outputs_one_hot,
                                  embeddings_extended)

        # [batch]
        outputs_length = decoder.seq_length(decoder_output)

        # [batch x max_len x hidden], [batch x hidden]
        hidden_states, sentence_embedding = encoder.source_sent_encoder(
            outputs_embed,
            outputs_length,
            enc_hidden_dim,
            enc_num_layers,
            use_dropout=False,
            dropout_keep=1.0,
            swap_memory=swap_memory)

        h = sentence_embedding
        for l in dense_layers:
            h = tf.layers.dense(h,
                                l,
                                activation='relu',
                                name='hidden_%s' % (l))

        # [batch x edit_dim]
        edit_vector = tf.layers.dense(h,
                                      edit_dim,
                                      activation=None,
                                      name='linear')

    return edit_vector