def VerifySampleAndPdfConsistency(self, pspherical, rtol=0.075):
    """Verifies samples are consistent with the PDF using importance sampling.

    In particular, we verify an estimate of the surface area of the
    n-dimensional hypersphere, and the surface areas of the spherical caps
    demarcated by a handful of survival rates.

    Args:
      pspherical: A `PowerSpherical` distribution instance.
      rtol: Relative difference tolerable.
    """
    dim = tf.compat.dimension_value(pspherical.event_shape[-1])
    nsamples = int(1e5)
    samples = pspherical.sample(sample_shape=[nsamples],
                                seed=test_util.test_seed())
    samples = tf.debugging.check_numerics(samples, 'samples')
    log_prob = pspherical.log_prob(samples)
    # Evaluate the samples together with their log-densities in a single call
    # so the numpy arrays below are guaranteed to describe the same draw;
    # evaluating `samples` in a separate call could re-run the sampling op.
    samples, log_prob = self.evaluate([samples, log_prob])
    # Check that the log_prob is not nan or +inf. It can be -inf since
    # if we sample a direction diametrically opposite to the mean direction,
    # we'll get an inner product of -1.
    self.assertFalse(np.any(np.isnan(log_prob)))
    self.assertFalse(np.any(np.isposinf(log_prob)))
    # Importance weights 1 / p(x); their mean over samples x ~ p estimates
    # the surface area of the region where p > 0.
    log_importance = -log_prob
    sphere_surface_area_estimate, importance = self.evaluate([
        tf.reduce_mean(tf.math.exp(log_importance), axis=0),
        tf.exp(log_importance)
    ])
    # Surface area of the unit sphere in R^dim: 2 * pi^(dim/2) / Gamma(dim/2).
    true_sphere_surface_area = 2 * (np.pi)**(dim / 2) * self.evaluate(
        tf.exp(-tf.math.lgamma(dim / 2)))
    # Broadcast to correct size
    true_sphere_surface_area += np.zeros_like(sphere_surface_area_estimate)
    # Highly concentrated distributions do not get enough coverage to provide
    # a reasonable full-sphere surface area estimate. These are covered below
    # by CDF-based hypersphere cap surface area estimates.
    # Because the PowerSpherical distribution has zero mass at
    # -`mean_direction` (and points close to -`mean_direction` due to floating
    # point), we only compute this at concentration = 0, which has guaranteed
    # mass everywhere.
    self.assertAllClose(true_sphere_surface_area[0],
                        sphere_surface_area_estimate[0],
                        rtol=rtol)

    # The mean direction and the per-sample dot products with it do not
    # depend on the survival rate, so compute them once, outside the loop.
    mean_dir = self.evaluate(pspherical.mean_direction)
    dotprods = np.sum(samples * mean_dir, -1)

    # Assert surface area of hyperspherical cap for some CDFs in [.05, .45]
    # (h must be greater than 0 for the hypersphere cap surface area
    # calculation to hold).
    for survival_rate in 0.95, .9, .75, .6:
        cdf = (1 - survival_rate)
        # Empirical estimate of the effective dot-product of the threshold that
        # selects for a given CDF level, that is the cosine of the largest
        # passable angle, or the minimum cosine for a within-CDF sample.
        dotprod_thresh = np.percentile(dotprods,
                                       100 * survival_rate,
                                       axis=0,
                                       keepdims=True)
        # We mask this sum because it is possible for the log_prob to be -inf
        # when the sample is -mean_dir.
        importance_masked = np.ma.array(importance,
                                        mask=dotprods <= dotprod_thresh)
        sphere_cap_surface_area_ests = (cdf * (importance_masked).sum(0) /
                                        (dotprods > dotprod_thresh).sum(0))
        # Height of the cap above the threshold plane.
        h = (1 - dotprod_thresh)
        self.assertGreaterEqual(h.min(),
                                0)  # h must be >= 0 for the eqn below
        true_sphere_cap_surface_area = (
            0.5 * true_sphere_surface_area *
            self.evaluate(tf.math.betainc(
                (dim - 1) / 2, 0.5, 2 * h - h**2)))
        if dim == 3:  # For 3-d we have a simpler form we can double-check.
            self.assertAllClose(2 * np.pi * h,
                                true_sphere_cap_surface_area)

        self.assertAllClose(true_sphere_cap_surface_area,
                            sphere_cap_surface_area_ests +
                            np.zeros_like(true_sphere_cap_surface_area),
                            rtol=rtol)
  def testTrainAndServe(self, use_adapt):
    """Trains a small model with KPL preprocessing, then saves and serves it.

    Training runs under `self.coordinator.strategy.scope()` with steps
    scheduled through `self.coordinator`. The trained model is saved with a
    custom `serving_default` signature, reloaded, and queried with unbatched
    string inputs whose outputs must be "yes" or "no".

    Args:
      use_adapt: Forwarded to `self.define_kpls_for_training`.
    """

    with self.coordinator.strategy.scope():

      # Feature and label preprocessing layers (construction depends on
      # `use_adapt`).
      feature_ps, label_ps = self.define_kpls_for_training(use_adapt)

      def dataset_fn():

        def feature_and_label_gen():
          # Infinite stream: 3 distinct random vocab entries per example; the
          # label is "yes" iff "avenger" is among them.
          while True:
            features = random.sample(FEATURE_VOCAB, 3)
            label = ["yes"] if "avenger" in features else ["no"]
            yield {"features": features, "label": label}

        # The dataset will be created on the coordinator.
        raw_dataset = tf.data.Dataset.from_generator(
            feature_and_label_gen,
            output_signature={
                "features": tf.TensorSpec([3], tf.string),
                "label": tf.TensorSpec([1], tf.string)
            }).shuffle(100).batch(32)

        # Apply the preprocessing layers; `model_input` below expects int64.
        train_dataset = raw_dataset.map(lambda x: (  # pylint: disable=g-long-lambda
            {
                "features": feature_ps(x["features"])
            }, label_ps(x["label"])))
        return train_dataset

      # Create the model. The input needs to be compatible with KPLs.
      model_input = keras.layers.Input(
          shape=(3,), dtype=tf.int64, name="model_input")

      # input_dim includes a mask token and an oov token.
      emb_output = keras.layers.Embedding(
          input_dim=len(FEATURE_VOCAB) + 2, output_dim=20)(
              model_input)
      # Average the 3 feature embeddings into one vector per example.
      emb_output = tf.reduce_mean(emb_output, axis=1)
      dense_output = keras.layers.Dense(
          units=1, activation="sigmoid")(
              emb_output)
      model = keras.Model({"features": model_input}, dense_output)

      # Optimizer and metric are also created inside the strategy scope.
      optimizer = rmsprop.RMSprop(learning_rate=0.1)
      accuracy = keras.metrics.Accuracy()

    # One training step per call; dispatched to workers via
    # `self.coordinator.schedule` below.
    @tf.function
    def worker_fn(iterator):

      def replica_fn(iterator):
        batch_data, labels = next(iterator)
        with tf.GradientTape() as tape:
          pred = model(batch_data, training=True)
          # Per-example binary cross-entropy (no reduction), averaged with
          # `tf.nn.compute_average_loss`.
          loss = tf.nn.compute_average_loss(
              keras.losses.BinaryCrossentropy(
                  reduction=losses_utils.ReductionV2.NONE)(labels, pred))
          # NOTE(review): the gradient is computed inside the tape context; it
          # could be moved after the `with` block.
          gradients = tape.gradient(loss, model.trainable_variables)

        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # Threshold sigmoid outputs at 0.5 to get 0/1 predictions.
        actual_pred = tf.cast(tf.greater(pred, 0.5), tf.int64)
        accuracy.update_state(labels, actual_pred)

      self.coordinator.strategy.run(replica_fn, args=(iterator,))

    distributed_dataset = self.coordinator.create_per_worker_dataset(dataset_fn)
    distributed_iterator = iter(distributed_dataset)
    # 4 rounds of 7 scheduled steps. `accuracy` is reset at the start of each
    # round, so the assertion below reflects only the last round.
    for _ in range(4):
      accuracy.reset_state()
      for _ in range(7):
        self.coordinator.schedule(worker_fn, args=(distributed_iterator,))
      self.coordinator.join()
    self.assertGreater(accuracy.result().numpy(), 0.5)

    # Create a saved model.
    # Attach the preprocessing and inverse-lookup layers to the model so the
    # serving function below can reach them through `model`.
    model.feature_ps = feature_ps
    model.label_ps = label_ps
    model.label_inverse_lookup_layer = self.define_reverse_lookup_layer()

    def create_serving_signature(model):

      @tf.function
      def serve_fn(raw_features):
        # Add a batch dimension, preprocess, predict, then drop it again.
        raw_features = tf.compat.v1.expand_dims(raw_features, axis=0)
        transformed_features = model.feature_ps(raw_features)
        outputs = model(transformed_features)
        outputs = tf.compat.v1.squeeze(outputs, axis=0)
        outputs = tf.cast(tf.greater(outputs, 0.5), tf.int64)
        # Map the 0/1 prediction back to a label string.
        decoded_outputs = model.label_inverse_lookup_layer(outputs)
        return tf.compat.v1.squeeze(decoded_outputs, axis=0)

      # serving does NOT have batch dimension
      return serve_fn.get_concrete_function(
          tf.TensorSpec(
              shape=(3), dtype=tf.string, name="example"))

    serving_fn = create_serving_signature(model)

    saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
    model.save(saved_model_dir, signatures={"serving_default": serving_fn})

    # Test the saved_model.
    loaded_serving_fn = keras.saving.save.load_model(
        saved_model_dir).signatures["serving_default"]

    # check the result w/ and w/o avenger.
    prediction0 = loaded_serving_fn(
        tf.constant(["avenger", "ironman", "avenger"]))["output_0"]
    self.assertIn(prediction0, ("yes", "no"))

    # "unkonwn" is presumably meant as an out-of-vocabulary token.
    prediction1 = loaded_serving_fn(
        tf.constant(["ironman", "ironman", "unkonwn"]))["output_0"]
    self.assertIn(prediction1, ("yes", "no"))
def _compute_loss(logits, labels):
    """Return the mean sparse softmax cross-entropy of `logits` vs `labels`.

    Args:
      logits: Unnormalized class scores.
      labels: Integer class indices matching the leading dims of `logits`.

    Returns:
      Scalar mean of the per-example cross-entropy.
    """
    per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    return tf.reduce_mean(per_example_loss)
Ejemplo n.º 4
0
 def sample_estimate(*parameters):
   """Monte Carlo estimate of a mixture statistic along the sample axis.

   Builds the mixture via `mixture_func(*parameters)` and draws `num_samples`
   values (both from the enclosing scope). When `function == 'variance'` the
   squared deviations from `mixture.mean()` are averaged; otherwise the raw
   draws are averaged (a sample-mean estimate).
   """
   mixture = mixture_func(*parameters)
   draws = mixture.sample(num_samples, seed=test_util.test_seed())
   draws = (tf.math.squared_difference(draws, mixture.mean())
            if function == 'variance' else draws)
   return tf.reduce_mean(draws, axis=0)
Ejemplo n.º 5
0
    def fit_actor(self, states, actions, next_states, rewards, masks, discount,
                  target_entropy, init_states):
        """Updates actor parameters and, if `self.learn_alpha`, alpha.

        Args:
          states: A batch of states.
          actions: A batch of actions.
          next_states: A batch of next states.
          rewards: A batch of rewards.
          masks: A batch of masks indicating the end of the episodes.
          discount: An MDP discount factor.
          target_entropy: Target entropy value for alpha.
          init_states: A batch of init states from the MDP.

        Returns:
          A tuple `(actor_loss, alpha_loss, -next_log_probs)`: the actor loss
          (including orthogonal regularization of `self.actor.trunk`), the
          alpha loss, and the negated log-probabilities of the actions the
          actor sampled at `next_states`.
        """
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(self.actor.variables)
            _, init_actions, _ = self.actor(init_states)
            _, next_actions, next_log_probs = self.actor(next_states)

            if self.use_dqn:
                # Two critic heads: build an entropy-regularized bootstrapped
                # target for each, then average the per-head actor losses.
                target_q1, target_q2 = self.critic_mix(next_states,
                                                       next_actions)
                target_q1 = target_q1 - self.alpha * next_log_probs
                target_q2 = target_q2 - self.alpha * next_log_probs
                target_q1 = rewards + discount * masks * target_q1
                target_q2 = rewards + discount * masks * target_q2

                q1, q2 = self.critic(states, actions)
                init_q1, init_q2 = self.critic(init_states, init_actions)

                actor_loss1 = self._algae_actor_loss(target_q1, q1, init_q1,
                                                     discount)
                actor_loss2 = self._algae_actor_loss(target_q2, q2, init_q2,
                                                     discount)
                actor_loss = (actor_loss1 + actor_loss2) / 2.0
            else:
                target_q = self.critic_mix(next_states, next_actions)
                target_q = target_q - self.alpha * next_log_probs
                target_q = rewards + discount * masks * target_q

                q = self.critic(states, actions)
                init_q = self.critic(init_states, init_actions)

                actor_loss = self._algae_actor_loss(target_q, q, init_q,
                                                    discount)
            actor_loss += keras_utils.orthogonal_regularization(
                self.actor.trunk)

        actor_grads = tape.gradient(actor_loss, self.actor.variables)
        self.actor_optimizer.apply_gradients(
            zip(actor_grads, self.actor.variables))

        # Alpha loss; only `self.log_alpha` is watched, so the gradient flows
        # to it alone.
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch([self.log_alpha])
            alpha_loss = tf.reduce_mean(self.alpha *
                                        (-next_log_probs - target_entropy))

        if self.learn_alpha:
            alpha_grads = tape.gradient(alpha_loss, [self.log_alpha])
            self.alpha_optimizer.apply_gradients(
                zip(alpha_grads, [self.log_alpha]))

        return actor_loss, alpha_loss, -next_log_probs

    def _algae_actor_loss(self, target_q, q, init_q, discount):
        """Actor loss for a single critic head (shared by `fit_actor`).

        Args:
          target_q: Bootstrapped target values for `(states, actions)`.
          q: Critic values for `(states, actions)`.
          init_q: Critic values for `(init_states, init_actions)`; only used
            when `discount != 1`.
          discount: An MDP discount factor.

        Returns:
          Scalar actor loss.
        """
        if discount == 1:
            # Undiscounted case: the `fgrad` argument is shifted by
            # `self._lambda + self.algae_alpha` and no init-state term is added.
            return -tf.reduce_mean(
                tf.stop_gradient(
                    self.fgrad(self._lambda + self.algae_alpha +
                               target_q - q)) * (target_q - q))
        # Discounted case: add the (1 - discount)-weighted init-state value.
        return -tf.reduce_mean(
            tf.stop_gradient(self.fgrad(target_q - q)) *
            (target_q - q) +
            (1 - discount) * init_q * self.algae_alpha)
Ejemplo n.º 6
0
  def get_loss_tensors(self, f0_candidates, freqs, amps):
    """Get traces of loss to estimate fundamental frequency.

    Args:
      f0_candidates: Frequencies of candidates in hertz. [batch, time, freq].
      freqs: Frequencies of sinusoids in hertz. [batch, time, freq].
      amps: Amplitudes of sinusoids, greater than 0. [batch, time, freq].

    Returns:
      sinusoids_loss: -log p(sinusoids|harmonics), [batch, time, f0_candidate].
        NLL of each sinusoid/candidate frequency ratio, averaged over
        sinusoids with weights given by their amplitudes.
      harmonics_loss: -log p(harmonics|sinusoids), [batch, time, f0_candidate].
        Prior-weighted NLL of each candidate's harmonics, averaged over the
        harmonics that lie below Nyquist.
    """
    # ==========================================================================
    # P(sinusoids | candidate_harmonics).
    # ==========================================================================
    p_sinusoids_given_harmonics = self.get_p_sinusoids_given_harmonics()

    # Treat each partial as a candidate.
    # Get the ratio of each partial to each candidate.
    # -> [batch, time, candidate, partial]
    freq_ratios = safe_divide(freqs[:, :, tf.newaxis, :],
                              f0_candidates[:, :, :, tf.newaxis])
    nll_sinusoids = - p_sinusoids_given_harmonics.log_prob(freq_ratios)

    a = tf.convert_to_tensor(amps[:, :, tf.newaxis, :])

    # Amplitude-weighted average over sinusoids.
    # -> [batch, time, candidate]
    sinusoids_loss = safe_divide(tf.reduce_sum(nll_sinusoids * a, axis=-1),
                                 tf.reduce_sum(a, axis=-1))

    # ==========================================================================
    # P(candidate_harmonics | sinusoids)
    # ==========================================================================
    p_harm_given_sin = self.get_p_harmonics_given_sinusoids(freqs, amps)
    harmonics = self.get_candidate_harmonics(f0_candidates, as_midi=True)

    # Need to rearrange shape as tfp expects, [sample_sh, batch_sh, event_sh].
    # -> [candidate, harmonic, batch, time]
    harmonics_transpose = tf.transpose(harmonics, [2, 3, 0, 1])
    nll_harmonics_transpose = - p_harm_given_sin.log_prob(harmonics_transpose)
    # -> [batch, time, candidate, harm]
    nll_harmonics = tf.transpose(nll_harmonics_transpose, [2, 3, 0, 1])

    # Prior decreasing importance of upper harmonics: weights fall linearly
    # from 1 to 1 / n_harmonic_points.
    amps_prior = tf.linspace(
        1.0, 1.0 / self.n_harmonic_points, self.n_harmonic_points)
    harmonics_loss = (nll_harmonics *
                      amps_prior[tf.newaxis, tf.newaxis, tf.newaxis, :])

    # Don't count loss for harmonics above nyquist.
    # Reweight by the number of harmonics below nyquist,
    # (so it doesn't just pick the highest frequency possible).
    nyquist_midi = hz_to_midi(self.sample_rate / 2.0)
    nyquist_mask = tf.where(harmonics < nyquist_midi,
                            tf.ones_like(harmonics_loss),
                            tf.zeros_like(harmonics_loss))
    harmonics_loss *= safe_divide(
        nyquist_mask, tf.reduce_mean(nyquist_mask, axis=-1, keepdims=True))

    # Average over harmonics. Combined with the mask / mean(mask) reweighting
    # above, this is the mean loss over the harmonics below Nyquist.
    harmonics_loss = tf.reduce_mean(harmonics_loss, axis=-1)

    return sinusoids_loss, harmonics_loss
 def loss_fn():
   """Draw the next training batch and return `(nll + kl, (nll, kl))`.

   `nll` is the negated mean (over the last axis) of `bnn(x).log_prob(y)`;
   `kl` is `tfn.losses.compute_extra_loss(bnn)` scaled by `1 / n`. Both
   `train_iter`, `bnn` and `n` come from the enclosing scope.
   """
   features, targets = next(train_iter)
   log_likelihood = bnn(features).log_prob(targets)
   nll = -tf.reduce_mean(log_likelihood, axis=-1)
   kl = tfn.losses.compute_extra_loss(bnn) / n
   total = nll + kl
   return total, (nll, kl)
Ejemplo n.º 8
0
def get_ac_loss(learner_agent_output, env_output, actor_agent_output,
                actor_action, reward_clipping, discounting, baseline_cost,
                entropy_cost, num_steps):
    """Computes the V-trace actor-critic loss.

    Args:
      learner_agent_output: A nested structure of type `AgentOutput`. The
        tensors are expected to have shape [num_timesteps, batch, ....]
      env_output: A nested structure of type `EnvOutput`. The tensors are
        expected to have shape [num_timesteps, batch, ...].
      actor_agent_output: A nested structure of type `AgentOutput`. The
        tensors are expected to have shape [num_timesteps, batch, ....]
      actor_action: An instance of `ActorAction` containing indices of the
        actions chosen by actor. The total number of actions available to
        actor at any point is equal to
        actor_agent_output.policy_logits.shape()[-1].
      reward_clipping: A string denoting the clipping strategy to be applied
        to rewards: 'abs_one' or 'soft_asymmetric'. Any other value
        (including the empty string) means no clipping is applied.
      discounting: The discount factor.
      baseline_cost: A multiplier for baseline loss.
      entropy_cost: A multiplier for entropy.
      num_steps: An int to be used as step arg for summaries.

    Returns:
      A tensor of shape [num_timesteps - 1, batch_size] which contains the
      computed actor-critic loss per timestep per element.
    """
    # The learner's final baseline value bootstraps the V-trace targets.
    bootstrap_value = learner_agent_output.baseline[-1]

    # Environment outputs at step `t` are the inputs that produced the learner
    # outputs at step `t`. Dropping the first step of actor/env data and the
    # last step of learner data aligns each action with the environment
    # output it caused.
    def drop_first(tensor):
        return tensor[1:]

    def drop_last(tensor):
        return tensor[:-1]

    actor_agent_output = tf.nest.map_structure(drop_first, actor_agent_output)
    rewards, done, _, _ = tf.nest.map_structure(drop_first, env_output)
    actor_action_idx = drop_first(actor_action.chosen_action_idx)
    learner_agent_output = tf.nest.map_structure(drop_last,
                                                 learner_agent_output)

    if reward_clipping == 'abs_one':
        clipped_rewards = tf.clip_by_value(rewards, -1, 1)
    elif reward_clipping == 'soft_asymmetric':
        squeezed = tf.tanh(rewards / 5.0)
        # Negative rewards are given less weight than positive rewards.
        clipped_rewards = tf.where(rewards < 0, .3 * squeezed, squeezed) * 5.
    else:
        clipped_rewards = rewards

    # Terminal steps (done) get a zero discount.
    discounts = tf.cast(~done, tf.float32) * discounting

    learner_logits = learner_agent_output.policy_logits
    learner_baseline = learner_agent_output.baseline

    vtrace_returns = vtrace.from_logits(
        behaviour_policy_logits=actor_agent_output.policy_logits,
        target_policy_logits=learner_logits,
        actions=actor_action_idx,
        discounts=discounts,
        rewards=clipped_rewards,
        values=learner_baseline,
        bootstrap_value=bootstrap_value)

    pg_advantages = vtrace_returns.pg_advantages
    v_advantages = vtrace_returns.vs - learner_baseline
    tf.summary.histogram('pg_advantages', pg_advantages, step=num_steps)
    tf.summary.histogram('v_advantages', v_advantages, step=num_steps)

    # Weighted sum of policy-gradient, baseline and entropy terms.
    pg_loss = _compute_policy_gradient_loss(learner_logits,
                                            actor_action_idx,
                                            pg_advantages,
                                            step=num_steps)
    baseline_loss = _compute_baseline_loss(v_advantages, step=num_steps)
    entropy = _compute_entropy_loss(learner_logits, step=num_steps)

    total_loss = (pg_loss + baseline_cost * baseline_loss +
                  entropy_cost * entropy)
    tf.summary.scalar('loss/ac_loss', tf.reduce_mean(total_loss),
                      step=num_steps)
    return total_loss
def cross_entropy(logits, targets):
    """Mean two-class softmax cross-entropy.

    `targets` is treated as the probability of class 1; the two-column label
    matrix `[1 - targets, targets]` is built along axis 1.
    NOTE(review): assumes `logits` has 2 classes on its last axis — confirm.
    """
    two_class_labels = tf.stack([1 - targets, targets], axis=1)
    per_example = tf.nn.softmax_cross_entropy_with_logits(
        labels=two_class_labels, logits=logits)
    return tf.reduce_mean(per_example)
Ejemplo n.º 10
0
def csiszar_vimco(f,
                  p_log_prob,
                  q,
                  num_draws,
                  num_batch_draws=1,
                  seed=None,
                  name=None):
  """Use VIMCO to lower the variance of gradient[csiszar_function(log(Avg(u)))].

  This function generalizes VIMCO [(Mnih and Rezende, 2016)][1] to Csiszar
  f-Divergences.

  Note: if `q.reparameterization_type = tfd.FULLY_REPARAMETERIZED`,
  consider using `monte_carlo_variational_loss`.

  The VIMCO loss is:

  ```none
  vimco = f(log(Avg{u[i] : i=0,...,m-1}))
  where,
    logu[i] = log( p(x, h[i]) / q(h[i] | x) )
    h[i] iid~ q(H | x)
  ```

  Interestingly, the VIMCO gradient is not the naive gradient of `vimco`.
  Rather, it is characterized by:

  ```none
  grad[vimco] - variance_reducing_term
  where,
    variance_reducing_term = Sum{ grad[log q(h[i] | x)] *
                                    (vimco - f(log Avg{h[j;i] : j=0,...,m-1}))
                                 : i=0, ..., m-1 }
    h[j;i] = { u[j]                             j!=i
             { GeometricAverage{ u[k] : k!=i}   j==i
  ```

  (We omitted `stop_gradient` for brevity. See implementation for more details.)

  The `Avg{h[j;i] : j}` term is a kind of "swap-out average" where the `i`-th
  element has been replaced by the leave-`i`-out Geometric-average.

  This implementation prefers numerical precision over efficiency, i.e.,
  `O(num_draws * num_batch_draws * prod(batch_shape) * prod(event_shape))`.
  (The constant may be fairly large, perhaps around 12.)

  Args:
    f: Python `callable` representing a Csiszar-function in log-space.
    p_log_prob: Python `callable` representing the natural-log of the
      probability under distribution `p`. (In variational inference `p` is the
      joint distribution.)
    q: `tf.Distribution`-like instance; must implement: `sample(n, seed)`, and
      `log_prob(x)`. (In variational inference `q` is the approximate posterior
      distribution.)
    num_draws: Integer scalar number of draws used to approximate the
      f-Divergence expectation. Must be at least 2.
    num_batch_draws: Integer scalar number of independent VIMCO estimates,
      each using `num_draws` draws; the returned objective is their average.
    seed: Python `int` seed for `q.sample`.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    vimco: The Csiszar f-Divergence generalized VIMCO objective.

  Raises:
    ValueError: if `num_draws < 2`.

  #### References

  [1]: Andriy Mnih and Danilo Rezende. Variational Inference for Monte Carlo
       objectives. In _International Conference on Machine Learning_, 2016.
       https://arxiv.org/abs/1602.06725
  """
  with tf.name_scope(name or 'csiszar_vimco'):
    if num_draws < 2:
      raise ValueError('Must specify num_draws > 1.')
    stop = tf.stop_gradient  # For readability.

    # Shape [num_draws, num_batch_draws, ...]: axis 0 indexes the draws that
    # one VIMCO estimate averages over, axis 1 the independent estimates.
    q_sample = q.sample(sample_shape=[num_draws, num_batch_draws], seed=seed)
    # Samples are stopped, so gradients w.r.t. q's parameters flow only
    # through `logqx`.
    x = tf.nest.map_structure(stop, q_sample)
    logqx = q.log_prob(x)
    logu = nest_util.call_fn(p_log_prob, x) - logqx
    # `log_soomean_exp` reduces over the draws axis; presumably it returns the
    # log swap-out-one means and the log mean of exp(logu). `f` is applied to
    # each.
    f_log_sooavg_u, f_log_avg_u = map(f, log_soomean_exp(logu, axis=0))

    dotprod = tf.reduce_sum(
        logqx * stop(f_log_avg_u - f_log_sooavg_u),
        axis=0)  # Sum over iid samples.
    # We now rewrite f_log_avg_u so that:
    #   `grad[f_log_avg_u] := grad[f_log_avg_u + dotprod]`.
    # To achieve this, we use a trick that
    #   `f(x) - stop(f(x)) == zeros_like(f(x))`
    # but its gradient is grad[f(x)].
    # Note that IEEE754 specifies that `x - x == 0.` and `x + 0. == x`, hence
    # this trick loses no precision. For more discussion regarding the relevant
    # portions of the IEEE754 standard, see the StackOverflow question,
    # "Is there a floating point value of x, for which x-x == 0 is false?"
    # http://stackoverflow.com/q/2686644
    # Following is same as adding zeros_like(dotprod).
    f_log_avg_u = f_log_avg_u + dotprod - stop(dotprod)
    return tf.reduce_mean(f_log_avg_u, axis=0)  # Avg over batches.
Ejemplo n.º 11
0
def _compute_entropy_loss(logits, step):
    """Return the per-element entropy regularization term and log its mean.

    The returned value is `-mean_a(-pi(a) * log pi(a))` over the last axis,
    i.e. the negated per-action average of the entropy summands. Adding it
    (times a positive cost) to a minimized loss therefore encourages higher
    entropy. The batch mean of this (negative) value is written to the
    'loss/entropy' summary.
    """
    log_policy = tf.nn.log_softmax(logits)
    policy = tf.nn.softmax(logits)
    entropy_terms = -policy * log_policy
    neg_entropy = -tf.reduce_mean(entropy_terms, axis=-1)
    tf.summary.scalar('loss/entropy', tf.reduce_mean(neg_entropy), step=step)
    return neg_entropy
 def loss_fn(*params, seed=None):
     """Monte Carlo estimate of `mean(log q(z) - target_log_prob(z))`.

     Builds the surrogate posterior from `params`, draws 10 samples together
     with their surrogate log-probabilities, and averages the log-ratio over
     the sample axis (axis 0).
     """
     posterior = build_surrogate_posterior_fn(*params)
     draws, surrogate_log_prob = posterior.experimental_sample_and_log_prob(
         10, seed=seed)
     log_ratio = surrogate_log_prob - target_log_prob(*draws)
     return tf.reduce_mean(log_ratio, axis=0)
Ejemplo n.º 13
0
 def call(self, y_true, y_pred):
   """Mean of `metrics.quantile_error` using this loss's quantile and power."""
   error = metrics.quantile_error(
       y_true, y_pred, quantile=self._quantile, power=self._power)
   return tf.reduce_mean(error)
Ejemplo n.º 14
0
 def call(self, y_true, y_pred):
   """Mean of `metrics.trimmed_error` over this loss's quantile window."""
   error = metrics.trimmed_error(
       y_true, y_pred, self._start_quantile, self._end_quantile,
       power=self._power)
   return tf.reduce_mean(error)
def main(argv):
    """Train the Bayesian CNN on MNIST, periodically evaluating on held-out data.

    Wipes and recreates `FLAGS.model_dir`, trains for `FLAGS.num_epochs`
    epochs with `train_on_batch`, prints running loss/accuracy every 100
    batches and, every `FLAGS.viz_steps` batches, averages
    `FLAGS.num_monte_carlo` predictive passes over the held-out set to report
    its mean log probability (plus plots when `HAS_SEABORN`).

    Args:
      argv: Unused command-line arguments.
    """
    del argv  # unused
    model_dir = FLAGS.model_dir
    if tf.io.gfile.exists(model_dir):
        tf.compat.v1.logging.warning(
            'Warning: deleting old log directory at {}'.format(model_dir))
        tf.io.gfile.rmtree(model_dir)
    tf.io.gfile.makedirs(model_dir)

    batch_size = FLAGS.batch_size
    if FLAGS.fake_data:
        train_seq = MNISTSequence(batch_size=batch_size,
                                  fake_data_size=NUM_TRAIN_EXAMPLES)
        heldout_seq = MNISTSequence(batch_size=batch_size,
                                    fake_data_size=NUM_HELDOUT_EXAMPLES)
    else:
        train_set, heldout_set = tf.keras.datasets.mnist.load_data()
        train_seq = MNISTSequence(data=train_set, batch_size=batch_size)
        heldout_seq = MNISTSequence(data=heldout_set, batch_size=batch_size)

    model = create_model()
    # TODO(b/149259388): understand why Keras does not automatically build the
    # model correctly.
    model.build(input_shape=[None, 28, 28, 1])

    print(' ... Training convolutional neural network')
    for epoch in range(FLAGS.num_epochs):
        epoch_losses = []
        epoch_accuracies = []
        for step, (batch_x, batch_y) in enumerate(train_seq):
            batch_loss, batch_accuracy = model.train_on_batch(batch_x, batch_y)
            epoch_accuracies.append(batch_accuracy)
            epoch_losses.append(batch_loss)

            if step % 100 == 0:
                print('Epoch: {}, Batch index: {}, '
                      'Loss: {:.3f}, Accuracy: {:.3f}'.format(
                          epoch, step, tf.reduce_mean(epoch_losses),
                          tf.reduce_mean(epoch_accuracies)))

            if (step + 1) % FLAGS.viz_steps != 0:
                continue

            # Compute log prob of heldout set by averaging draws from the model:
            # p(heldout | train) = int_model p(heldout|model) p(model|train)
            #                   ~= 1/n * sum_{i=1}^n p(heldout | model_i)
            # where model_i is a draw from the posterior p(model|train).
            print(' ... Running monte carlo inference')
            probs = tf.stack(
                [model.predict(heldout_seq, verbose=1)
                 for _ in range(FLAGS.num_monte_carlo)],
                axis=0)
            mean_probs = tf.reduce_mean(probs, axis=0)
            heldout_log_prob = tf.reduce_mean(tf.math.log(mean_probs))
            print(' ... Held-out nats: {:.3f}'.format(heldout_log_prob))

            if not HAS_SEABORN:
                continue

            flipout_layers = [
                layer for layer in model.layers if 'flipout' in layer.name
            ]
            names = [layer.name for layer in flipout_layers]
            qm_vals = [layer.kernel_posterior.mean().numpy()
                       for layer in flipout_layers]
            qs_vals = [layer.kernel_posterior.stddev().numpy()
                       for layer in flipout_layers]
            plot_weight_posteriors(
                names,
                qm_vals,
                qs_vals,
                fname=os.path.join(
                    model_dir,
                    'epoch{}_step{:05d}_weights.png'.format(epoch, step)))
            plot_heldout_prediction(
                heldout_seq.images,
                probs.numpy(),
                fname=os.path.join(
                    model_dir,
                    'epoch{}_step{}_pred.png'.format(epoch, step)),
                title='mean heldout logprob {:.2f}'.format(heldout_log_prob))
Ejemplo n.º 16
0
 def _mean(self, samples=None):
     """Average `samples` (default: the stored `self._samples`) over the sample axis."""
     data = tf.convert_to_tensor(self._samples) if samples is None else samples
     return tf.reduce_mean(data, axis=self._samples_axis)
Ejemplo n.º 17
0
    def _head(self, env_output, neck_outputs):
        """Computes image/text similarity logits and classification logits.

        One text state and one image state are selected per batch entry,
        projected, L2-normalized and compared with a dot product, giving a
        [batch, batch] similarity matrix (with a leading time dim of 1).

        Args:
          env_output: Environment output whose `observation` holds
            `constants.PATH_ID` and 'label' tensors; both are transposed
            below from [time, batch] to [batch, time].
          neck_outputs: Mapping with `constants.DISC_MASK`, 'text_state' and
            'visual_state'. The state entries are nested lists of (h, c)
            tensors transposed below from [time, batch, hidden] to
            [batch, time, hidden].

        Returns:
          `common.AgentOutput` whose `policy_logits` is a dict holding
          'similarity' ([1, batch, batch] scaled scores), 'similarity_mask'
          ([1, batch, batch] bool; False for off-diagonal pairs that share a
          path_id and are both positively labeled) and 'labels' ([1, batch]),
          and whose `baseline` holds the classification logits ([1, batch]).
        """
        # Discriminator mask laid out as [time, batch].
        disc_mask = tf.reshape(
            neck_outputs[constants.DISC_MASK],
            [self._current_num_timesteps, self._current_batch_size])
        # Get first_true time step for text states as it's the same for all steps
        # in a path.
        # Shape = [time, batch] for both disc_mask and first_true
        first_true = utils.get_first_true_column(disc_mask)
        # Transpose to [batch, time] to ensure correct batch order for boolean_mask.
        first_true = tf.transpose(first_true, perm=[1, 0])

        # Transpose a list of n_lstm_layers (h, c) states to batch major.
        raw_text_state = tf.nest.map_structure(
            lambda t: tf.transpose(t, perm=[1, 0, 2]),
            neck_outputs['text_state'])
        # Sanity check on the batch-major layout; the text hidden size is
        # hard-coded to 512 here.
        tf.debugging.assert_equal(
            raw_text_state[0][0].shape,
            [self._current_batch_size, self._current_num_timesteps, 512])
        # Take the first step's text state since it's the same for all steps.
        # Selected state has shape [batch, hidden]
        text_state = self._select_by_mask(raw_text_state, first_true)

        # Projected shape: [batch, hidden_dim].
        text_feature = self._get_final_projection(
            self._instruction_feature_projection, text_state)

        # Get last_true mask for image states, i.e., state at end of sequence.
        # Shape = [time, batch] for both disc_mask and last_true
        last_true = utils.get_last_true_column(disc_mask)
        last_true = tf.transpose(last_true, perm=[1, 0])
        # Sanity check: ensure the first and last text states in a path are same
        # (compares the h state of the last LSTM layer).
        text_state_last_true = self._select_by_mask(raw_text_state, last_true)
        tf.debugging.assert_equal(text_state[-1][0],
                                  text_state_last_true[-1][0])

        # Transpose image states, a list of (h, c) states, into batch major. Each
        # state has shape [batch, time_step, hidden_dim]
        raw_image_state = tf.nest.map_structure(
            lambda t: tf.transpose(t, perm=[1, 0, 2]),
            neck_outputs['visual_state'])
        if self._average_image_states_of_all_steps:
            # Shape = [batch, time_step, 1]
            float_disc_mask = tf.expand_dims(tf.cast(tf.transpose(disc_mask),
                                                     tf.float32),
                                             axis=2)
            # Shape of each reduced state: [batch, hidden_dim]
            # NOTE(review): reduce_mean divides by the full number of time
            # steps, not by the number of unmasked steps, so paths with fewer
            # valid steps get proportionally smaller averages — confirm this
            # is intended.
            image_state = tf.nest.map_structure(
                lambda x: tf.reduce_mean(x * float_disc_mask, 1),
                raw_image_state)
        else:
            # Selected state has shape [batch, hidden_dim].
            image_state = self._select_by_mask(raw_image_state, last_true)
        # Projected shape: [batch, hidden].
        visual_feature = self._get_final_projection(
            self._image_feature_projection, image_state)

        # Normalize features so the dot products below are cosine similarities.
        visual_feature = tf.nn.l2_normalize(visual_feature, axis=-1)
        text_feature = tf.nn.l2_normalize(text_feature, axis=-1)

        # Select path_ids for current batch.
        # Transposed shape = [batch, time].
        raw_path_ids = tf.transpose(env_output.observation[constants.PATH_ID])
        # Shape = [batch].
        path_ids = self._select_by_mask(raw_path_ids, first_true)
        # Asserts first true and last true are referring to the same path.
        path_ids_last_true = self._select_by_mask(raw_path_ids, last_true)
        tf.debugging.assert_equal(path_ids, path_ids_last_true)

        # Shape = [time, batch]
        raw_labels = tf.cast(env_output.observation['label'], tf.float32)
        raw_labels = tf.transpose(raw_labels)
        # Shape = [batch]
        labels = self._select_by_mask(raw_labels, first_true)
        # Labels must also be constant along a path.
        tf.debugging.assert_equal(labels,
                                  self._select_by_mask(raw_labels, last_true))
        # Add time dimension as required by actor. Shape = [1, batch]
        labels = tf.expand_dims(labels, axis=0)

        # Rows index images, columns index texts. Shape: [batch, batch]
        similarity = tf.matmul(visual_feature,
                               tf.transpose(text_feature, perm=[1, 0]))
        # Add time dim as required by actor. Shape = [1, batch, batch]
        similarity = tf.expand_dims(similarity, axis=0)

        # Make similarity mask to exclude multiple positive matching labels
        diag_mask = tf.eye(self._current_batch_size, dtype=tf.bool)
        # path_id mask: True where row and column path_ids match, excluding
        # the diagonal pairs.
        rows = tf.tile(tf.reshape(path_ids, [self._current_batch_size, 1]),
                       [1, self._current_batch_size])
        cols = tf.tile(tf.reshape(path_ids, [1, self._current_batch_size]),
                       [self._current_batch_size, 1])
        path_id_mask = tf.logical_and(tf.equal(rows, cols),
                                      tf.logical_not(diag_mask))
        # Filter the mask by label. Positive labels are 1.
        row_labels = tf.tile(tf.reshape(labels, [self._current_batch_size, 1]),
                             [1, self._current_batch_size])
        col_labels = tf.tile(tf.reshape(labels, [1, self._current_batch_size]),
                             [self._current_batch_size, 1])
        label_mask = tf.logical_and(tf.cast(row_labels, tf.bool),
                                    tf.cast(col_labels, tf.bool))

        # M[i, j] is False (i != j) if path_id_mask[i, j] and label_mask[i, j]
        # are both True; every other entry is True.
        similarity_mask = tf.logical_not(
            tf.logical_and(path_id_mask, label_mask))
        # Add timestep dim as required by actor. Shape = [1, batch, batch]
        similarity_mask = tf.expand_dims(similarity_mask, axis=0)

        # Computes logits by transforming similarity from [-1, 1] to unbound.
        # Shape: [time, batch, batch]
        similarity_logits = self.similarity_scaler * similarity

        output_logits = {
            'similarity': similarity_logits,
            'similarity_mask': similarity_mask,
            'labels': labels
        }

        # Logits for classification loss: an affine map of the matched
        # (diagonal) image/text similarities. Shape = [time, batch]
        classification_logits = (
            self.affine_a * tf.linalg.diag_part(similarity) + self.affine_b)

        return common.AgentOutput(policy_logits=output_logits,
                                  baseline=classification_logits)
Ejemplo n.º 18
0
 def _stddev(self):
     """Standard deviation of the stored samples along `self._samples_axis`.

     Uses the population (divide-by-N) variance: the mean squared deviation
     of the samples from `self._mean(samples)`.
     """
     data = tf.convert_to_tensor(self._samples)
     sample_axis = self._samples_axis
     center = tf.expand_dims(self._mean(data), axis=sample_axis)
     mean_sq_dev = tf.reduce_mean(tf.square(data - center), axis=sample_axis)
     return tf.sqrt(mean_sq_dev)
Ejemplo n.º 19
0
def main(_):
    """Trains and evaluates a small MNIST classifier, printing test accuracy.

    Depending on flags, first enables TensorFlow numeric checking
    (`--check_numerics`) or debug-info dumping (`--dump_dir`); the two are
    mutually exclusive. With `--fake_data`, random uniform images and labels
    replace the real MNIST dataset.

    Args:
      _: Unused positional argument (presumably argv passed by the app
        runner — confirm).

    Raises:
      ValueError: If both `--check_numerics` and `--dump_dir` are set.
    """
    if FLAGS.check_numerics and FLAGS.dump_dir:
        raise ValueError(
            "The --check_numerics and --dump_dir flags are mutually "
            "exclusive.")
    if FLAGS.check_numerics:
        tf.debugging.enable_check_numerics()
    elif FLAGS.dump_dir:
        tf.debugging.experimental.enable_dump_debug_info(
            FLAGS.dump_dir,
            tensor_debug_mode=FLAGS.dump_tensor_debug_mode,
            circular_buffer_size=FLAGS.dump_circular_buffer_size)

    # Import data
    # With --fake_data, the same 1000 random (image, label) pairs serve as
    # both the training and the test set.
    if FLAGS.fake_data:
        imgs = tf.random.uniform(maxval=256,
                                 shape=(1000, 28, 28),
                                 dtype=tf.int32)
        labels = tf.random.uniform(maxval=10, shape=(1000, ), dtype=tf.int32)
        mnist_train = imgs, labels
        mnist_test = imgs, labels
    else:
        mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()

    @tf.function
    def format_example(imgs, labels):
        """Formats each training and test example to work with our model."""
        # Flatten to (batch, 784), scale pixels to [0, 1], one-hot the labels.
        imgs = tf.reshape(imgs, [-1, 28 * 28])
        imgs = tf.cast(imgs, tf.float32) / 255.0
        labels = tf.one_hot(labels, depth=10, dtype=tf.float32)
        return imgs, labels

    # The shuffle buffer is sized to cover every example needed for max_steps.
    # NOTE(review): there is no .repeat(), so next(train_batches) below raises
    # StopIteration once the data runs out (e.g. with --fake_data's 1000
    # examples when train_batch_size * max_steps > 1000) — confirm max_steps
    # stays within one epoch.
    train_ds = tf.data.Dataset.from_tensor_slices(mnist_train).shuffle(
        FLAGS.train_batch_size * FLAGS.max_steps,
        seed=RAND_SEED).batch(FLAGS.train_batch_size)
    train_ds = train_ds.map(format_example)

    # Every test batch is the entire test set; .repeat() makes it possible to
    # call next() on it once per training step.
    test_ds = tf.data.Dataset.from_tensor_slices(mnist_test).repeat().batch(
        len(mnist_test[0]))
    test_ds = test_ds.map(format_example)

    def get_dense_weights(input_dim, output_dim):
        """Initializes the parameters for a single dense layer."""
        initial_kernel = tf.keras.initializers.TruncatedNormal(mean=0.0,
                                                               stddev=0.1,
                                                               seed=RAND_SEED)
        kernel = tf.Variable(initial_kernel([input_dim, output_dim]))
        bias = tf.Variable(tf.constant(0.1, shape=[output_dim]))

        return kernel, bias

    @tf.function
    def dense_layer(weights, input_tensor, act=tf.nn.relu):
        """Runs the forward computation for a single dense layer."""
        kernel, bias = weights
        preactivate = tf.matmul(input_tensor, kernel) + bias

        activations = act(preactivate)
        return activations

    # init model
    hidden_weights = get_dense_weights(IMAGE_SIZE**2, HIDDEN_SIZE)
    output_weights = get_dense_weights(HIDDEN_SIZE, NUM_LABELS)
    # Flat tuple: (hidden kernel, hidden bias, output kernel, output bias).
    variables = hidden_weights + output_weights

    @tf.function
    def model(x):
        """Feed forward function of the model.

    Args:
      x: a (?, 28*28) tensor consisting of the feature inputs for a batch of
        examples.

    Returns:
      A (?, 10) tensor containing the class scores for each example.
    """
        hidden_act = dense_layer(hidden_weights, x)
        logits_act = dense_layer(output_weights, hidden_act, tf.identity)
        y = tf.nn.softmax(logits_act)
        return y

    @tf.function
    def loss(probs, labels):
        """Calculates cross entropy loss.

    Args:
      probs: Class probabilities predicted by the model. The shape is expected
        to be (?, 10).
      labels: Truth labels for the classes, as one-hot encoded vectors. The
        shape is expected to be the same as `probs`.

    Returns:
      A scalar loss tensor.
    """
        # NOTE(review): tf.math.log(probs) is -inf when a softmax output
        # underflows to 0, and 0 * -inf is NaN; --check_numerics (enabled
        # above) would surface this. Confirm whether this unstable form is
        # intentional, e.g. as a debugging demo.
        diff = -labels * tf.math.log(probs)
        # Averages over every element of the (batch, 10) matrix, i.e. 1/10 of
        # the usual sum-over-classes, mean-over-batch cross entropy.
        loss = tf.reduce_mean(diff)
        return loss

    train_batches = iter(train_ds)
    test_batches = iter(test_ds)
    optimizer = tf.optimizers.Adam(learning_rate=FLAGS.learning_rate)
    for i in range(FLAGS.max_steps):
        x_train, y_train = next(train_batches)
        x_test, y_test = next(test_batches)

        # Train Step
        with tf.GradientTape() as tape:
            y = model(x_train)
            loss_val = loss(y, y_train)
        grads = tape.gradient(loss_val, variables)

        optimizer.apply_gradients(zip(grads, variables))

        # Evaluation Step: accuracy on the full test batch after every step.
        y = model(x_test)
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_test, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        print("Accuracy at step %d: %s" % (i, accuracy.numpy()))
Ejemplo n.º 20
0
def model_fn(features, labels, mode, params, config):
    """Build the model function for use in an estimator.

  Args:
    features: The input features for the estimator. Presumably a
      [num_documents, vocabulary_size] bag-of-words count matrix: its second
      dimension sizes the decoder and it is summed over axis 1 to obtain words
      per document — confirm against the input_fn.
    labels: The labels, unused here.
    mode: Signifies whether it is train or test or predict.
    params: Some hyperparameters as a dictionary. Keys read here:
      "activation", "num_topics", "layer_sizes", "prior_initial_value",
      "learning_rate", "prior_burn_in_steps" and "vocabulary".
    config: The RunConfig, unused here.
  Returns:
    EstimatorSpec: A tf.estimator.EstimatorSpec instance.
  """
    del labels, config

    encoder = make_encoder(params["activation"], params["num_topics"],
                           params["layer_sizes"])
    decoder, topics_words = make_decoder(params["num_topics"],
                                         features.shape[1])
    topics_prior = make_prior(params["num_topics"],
                              params["prior_initial_value"])

    alpha = topics_prior.concentration

    # Amortized posterior over topic proportions; a single sample (fixed op
    # seed 234) feeds the decoder.
    topics_posterior = encoder(features)
    topics = topics_posterior.sample(seed=234)
    random_reconstruction = decoder(topics)

    reconstruction = random_reconstruction.log_prob(features)
    tf1.summary.scalar("reconstruction", tf.reduce_mean(reconstruction))

    # Compute the KL-divergence between two Dirichlets analytically.
    # The sampled KL does not work well for "sparse" distributions
    # (see Appendix D of [2]).
    kl = tfd.kl_divergence(topics_posterior, topics_prior)
    tf1.summary.scalar("kl", tf.reduce_mean(kl))

    # Ensure that the KL is non-negative (up to a very small slack).
    # Negative KL can happen due to numerical instability.
    with tf.control_dependencies(
        [tf.debugging.assert_greater(kl, -1e-3, message="kl")]):
        kl = tf.identity(kl)

    elbo = reconstruction - kl
    avg_elbo = tf.reduce_mean(elbo)
    tf1.summary.scalar("elbo", avg_elbo)
    loss = -avg_elbo

    # Perform variational inference by minimizing the -ELBO.
    global_step = tf1.train.get_or_create_global_step()
    optimizer = tf1.train.AdamOptimizer(params["learning_rate"])

    # This implements the "burn-in" for prior parameters (see Appendix D of [2]).
    # For the first prior_burn_in_steps steps they are fixed, and then trained
    # jointly with the other parameters.
    grads_and_vars = optimizer.compute_gradients(loss)
    grads_and_vars_except_prior = [
        x for x in grads_and_vars if x[1] not in topics_prior.variables
    ]

    def train_op_except_prior():
        return optimizer.apply_gradients(grads_and_vars_except_prior,
                                         global_step=global_step)

    def train_op_all():
        return optimizer.apply_gradients(grads_and_vars,
                                         global_step=global_step)

    train_op = tf.cond(pred=global_step < params["prior_burn_in_steps"],
                       true_fn=train_op_except_prior,
                       false_fn=train_op_all)

    # The perplexity is an exponent of the average negative ELBO per word.
    words_per_document = tf.reduce_sum(features, axis=1)
    log_perplexity = -elbo / words_per_document
    tf1.summary.scalar("perplexity", tf.exp(tf.reduce_mean(log_perplexity)))
    # Streaming metric: accumulate the mean log-perplexity, exponentiate at
    # read time.
    (log_perplexity_tensor,
     log_perplexity_update) = tf1.metrics.mean(log_perplexity)
    perplexity_tensor = tf.exp(log_perplexity_tensor)

    # Obtain the topics summary. Implemented as a py_func for simplicity.
    # Note: this rebinds `topics` (previously the sampled topic proportions)
    # to the string summary tensor.
    topics = tf1.py_func(functools.partial(get_topics_strings,
                                           vocabulary=params["vocabulary"]),
                         [topics_words, alpha],
                         tf.string,
                         stateful=False)
    tf1.summary.text("topics", topics)

    # "topics" is exposed as a pseudo-metric with a no-op update op.
    return tf1.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        eval_metric_ops={
            "elbo": tf1.metrics.mean(elbo),
            "reconstruction": tf1.metrics.mean(reconstruction),
            "kl": tf1.metrics.mean(kl),
            "perplexity": (perplexity_tensor, log_perplexity_update),
            "topics": (topics, tf.no_op()),
        },
    )
Ejemplo n.º 21
0
def loss_fn(params, inputs, targets):
  """Mean squared error of the affine model `params[0] * inputs + params[1]`.

  Returns the scalar loss wrapped with `tf_np.asarray`.
  """
  slope, intercept = params[0], params[1]
  residuals = slope * inputs + intercept - targets
  mse = tf.reduce_mean(input_tensor=tf.square(residuals))
  return tf_np.asarray(mse)
Ejemplo n.º 22
0
def soft_multivariate_quantiles(x,
                                quantiles,
                                quantile_width=None,
                                **kwargs):
  """Computes soft multivariate quantiles via optimal transport.

  After centering, the input values in x are transported onto 2^d + 1
  weighted target points: the origin (which absorbs the quantile mass) and
  the vertices of the hypercube {-1, 1}^d. Target weights are chosen so that
  the values in x transported to the origin correspond to those concentrating
  around the quantile of interest.

  Args:
   x: Tensor<float> of shape [batch, N, d]
   quantiles: Tensor<float> of shape [r, d], r targeted quantiles of dimension d
   quantile_width: (float) mass given to the bucket supposed to attract points
     whose value concentrate around the desired quantile value. Bigger width
     means that we allow the soft quantile to be a mixture of more points
     further away from the quantile. If None, the width is set to 2/N where N
     is the number of points per batch entry (x.shape[1]).
   **kwargs: see sinkhorn.autodiff_sinkhorn for possible extra parameters.

  Returns:
    A Tensor<float> of shape [batch, r, d]: the r multivariate quantiles of
    each batch entry.
  """
  quantiles = tf.constant(quantiles, tf.float32)
  batch_size = x.shape[0]
  # Keep the point count as a Python int for shape arguments and as a float
  # tensor for arithmetic. Putting the float tensor inside a shape list only
  # works through an implicit int32 cast.
  num_points = x.shape[1]
  n = tf.cast(num_points, tf.float32)
  d = x.shape[2]
  if quantile_width is None:
    quantile_width = 2 / n
  num_quantiles = tf.shape(quantiles)[0]
  hypercube_vertices = tf.constant(
      list(itertools.product([-1, 1], repeat=d)), tf.float32)
  # Weights attached to vertices for each quantile: per-dimension factors of
  # shape [num_quantiles, 2^d, d], multiplied below into [num_quantiles, 2^d].
  # A -1 coordinate contributes q, a +1 coordinate contributes (1 - q).
  weights = quantiles[:, tf.newaxis, :]**(
      0.5 * (1 - hypercube_vertices))[tf.newaxis, Ellipsis]
  weights *= (1 - quantiles)[:, tf.newaxis, :]**(
      0.5 * (1 + hypercube_vertices))[tf.newaxis, Ellipsis]

  weights = (1 - quantile_width) * tf.reduce_prod(weights, axis=2)
  # Add the weight of the quantile bucket itself (in position 0).
  weights = tf.concat((quantile_width * tf.ones((num_quantiles, 1)), weights),
                      axis=1)
  # Tile across the batch and format as [batch_size, 2^d + 1, num_quantiles].
  weights = tf.reshape(
      tf.tile(tf.transpose(weights), [batch_size, 1]),
      [batch_size, 2**d + 1, num_quantiles])
  # Target locations: the origin (absorbs the quantile mass) followed by the
  # hypercube vertices, tiled to [batch_size, 2^d + 1, d].
  y = tf.concat((tf.zeros((1, d), dtype=tf.float32), hypercube_vertices),
                axis=0)
  y = tf.reshape(tf.tile(y, [batch_size, 1]), [batch_size, 2**d + 1, d])
  # Center x so that the origin target sits at the data mean.
  x_mean = tf.reduce_mean(x, axis=1)
  x = x - x_mean[:, tf.newaxis, :]
  # Uniform source weights 1/N per point, for every quantile.
  transports = sinkhorn.autodiff_sinkhorn(
      x, y,
      tf.ones([batch_size, num_points, num_quantiles], dtype=tf.float32) / n,
      weights,
      **kwargs)

  # Recover the convex combinations resulting from transporting to the
  # central point, in all batches and quantile variations.
  transports = 1 / quantile_width * tf.reshape(transports[:, :, 0, :],
                                               [batch_size, num_points, -1])
  # Apply these convex combinations to the data points, then re-center.
  all_soft_quantiles = tf.reduce_sum(
      transports[:, :, :, tf.newaxis] *
      x[:, :, tf.newaxis, :],
      axis=1) + x_mean[:, tf.newaxis, :]
  # Reshape the quantiles after having applied the convex combinations.
  return tf.reshape(all_soft_quantiles, [batch_size, num_quantiles, d])
Ejemplo n.º 23
0
    def fit_critic(self, states, actions, next_states, rewards, masks,
                   discount, init_states):
        """Takes one optimizer step on the critic variables and `_lambda`.

        Args:
          states: A batch of states.
          actions: A batch of actions.
          next_states: A batch of next states.
          rewards: A batch of rewards.
          masks: A batch of masks indicating the end of the episodes.
          discount: An MDP discount factor.
          init_states: A batch of init states from the MDP.

        Returns:
          Critic loss (the sum over both critic heads when `self.use_dqn`).
        """
        _, init_actions, _ = self.actor(init_states)
        _, next_actions, next_log_probs = self.actor(next_states)

        def soft_backup(next_q):
            # Entropy-regularized Bellman target from the mixed critic output.
            next_value = next_q - self.alpha * next_log_probs
            return rewards + discount * masks * next_value

        def head_loss(target_q, q, init_q):
            # Loss for a single critic head. With discount == 1 the
            # initial-state term is replaced by the `_lambda` offset.
            if discount == 1:
                offset_residual = (self._lambda + self.algae_alpha + target_q -
                                   q)
                return tf.reduce_mean(
                    self.f(offset_residual) - self.algae_alpha * self._lambda)
            return tf.reduce_mean(
                self.f(target_q - q) +
                (1 - discount) * init_q * self.algae_alpha)

        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(self.critic.variables + [self._lambda])

            if not self.use_dqn:
                target_q = soft_backup(
                    self.critic_mix(next_states, next_actions))
                q = self.critic(states, actions)
                init_q = self.critic(init_states, init_actions)
                critic_loss = head_loss(target_q, q, init_q)
            else:
                mixed_q1, mixed_q2 = self.critic_mix(next_states, next_actions)
                target_q1 = soft_backup(mixed_q1)
                target_q2 = soft_backup(mixed_q2)
                q1, q2 = self.critic(states, actions)
                init_q1, init_q2 = self.critic(init_states, init_actions)
                critic_loss = (head_loss(target_q1, q1, init_q1) +
                               head_loss(target_q2, q2, init_q2))

        critic_grads = tape.gradient(critic_loss,
                                     self.critic.variables + [self._lambda])

        self.critic_optimizer.apply_gradients(
            zip(critic_grads, self.critic.variables + [self._lambda]))

        return critic_loss
    def testSampleEndtoEndXLA(self):
        """An end-to-end test of sampling using SMC.

        Fits a robust linear regression (a per-observation mixture of an
        inlier regression component and an outlier Normal component) to 20
        standardized data points with sequential Monte Carlo under XLA
        compilation. The posterior moments are checked against reference
        values, which per the comment below come from a calibrated HMC run.
        Skipped when running eagerly.
        """
        if tf.executing_eagerly(
        ) or tf.config.experimental_functions_run_eagerly():
            self.skipTest('No need to test XLA under all execution regimes.')

        seed = test_util.test_seed()
        dtype = tf.float32
        # Set up data.
        predictors = np.asarray([
            201., 244., 47., 287., 203., 58., 210., 202., 198., 158., 165.,
            201., 157., 131., 166., 160., 186., 125., 218., 146.
        ])
        obs = np.asarray([
            592., 401., 583., 402., 495., 173., 479., 504., 510., 416., 393.,
            442., 317., 311., 400., 337., 423., 334., 533., 344.
        ])
        y_sigma = np.asarray([
            61., 25., 38., 15., 21., 15., 27., 14., 30., 16., 14., 25., 52.,
            16., 34., 31., 42., 26., 16., 22.
        ])
        # Standardize: center and divide by twice the standard deviation.
        # y_sigma is rescaled by the same factor as obs (computed from the raw
        # obs, before obs is overwritten on the next line).
        y_sigma = tf.cast(y_sigma / (2 * obs.std(axis=0)), dtype)
        obs = tf.cast((obs - obs.mean(axis=0)) / (2 * obs.std(axis=0)), dtype)
        predictors = tf.cast((predictors - predictors.mean(axis=0)) /
                             (2 * predictors.std(axis=0)), dtype)

        hyper_mean = tf.cast(0, dtype)
        hyper_scale = tf.cast(2.5, dtype)
        # Generate model prior_log_prob_fn and likelihood_log_prob_fn.
        # Priors, in order, for (b0, b1, mu_out, sigma_out, weight).
        prior_jd = tfd.JointDistributionSequential([
            tfd.Normal(loc=hyper_mean, scale=hyper_scale),
            tfd.Normal(loc=hyper_mean, scale=hyper_scale),
            tfd.Normal(loc=hyper_mean, scale=hyper_scale),
            tfd.HalfNormal(scale=tf.cast(.5, dtype)),
            tfd.Uniform(low=tf.cast(0, dtype), high=.5),
        ],
                                                   validate_args=True)

        # Each observation is an inlier Normal around b0 + b1 * predictors
        # with probability 1 - weight, or an outlier Normal around mu_out with
        # scale inflated by sigma_out with probability weight. Independent
        # sums the log-probs over the 20 observations.
        def likelihood_log_prob_fn(b0, b1, mu_out, sigma_out, weight):
            return tfd.Independent(
                tfd.Mixture(
                    tfd.Categorical(probs=tf.stack([
                        tf.repeat(1 - weight[..., tf.newaxis], 20, axis=-1),
                        tf.repeat(weight[..., tf.newaxis], 20, axis=-1)
                    ], -1)), [
                        tfd.Normal(loc=b0[..., tf.newaxis] +
                                   b1[..., tf.newaxis] * predictors,
                                   scale=y_sigma),
                        tfd.Normal(loc=mu_out[..., tf.newaxis],
                                   scale=y_sigma + sigma_out[..., tf.newaxis])
                    ]), 1).log_prob(obs)

        # Bijectors mapping unconstrained reals onto each prior's support:
        # identity for the Normals, Softplus for the HalfNormal, and a
        # Sigmoid onto (0, 0.5) for the Uniform weight.
        unconstraining_bijectors = [
            tfb.Identity(),
            tfb.Identity(),
            tfb.Identity(),
            tfb.Softplus(),
            tfb.Sigmoid(tf.constant(0., dtype), .5),
        ]
        make_transform_hmc_kernel_fn = gen_make_transform_hmc_kernel_fn(
            unconstraining_bijectors, num_leapfrog_steps=5)

        @tf.function(autograph=False, experimental_compile=True)
        def run_smc():
            # Ensure we're really in graph mode.
            assert hasattr(tf.constant([]), 'graph')

            # Initial particles: a [1000, 5] batch of draws from the prior.
            return tfp.experimental.mcmc.sample_sequential_monte_carlo(
                prior_jd.log_prob,
                likelihood_log_prob_fn,
                prior_jd.sample([1000, 5], seed=seed),
                make_kernel_fn=make_transform_hmc_kernel_fn,
                tuning_fn=functools.partial(simple_heuristic_tuning,
                                            optimal_accept=.6),
                min_num_steps=5,
                seed=seed)

        n_stage, (b0, b1, mu_out, sigma_out, weight), _ = run_smc()

        (n_stage, b0, b1, mu_out, sigma_out, weight) = self.evaluate(
            (n_stage, b0, b1, mu_out, sigma_out, weight))

        # NOTE(review): assertTrue's second argument is the failure message,
        # so this only checks that n_stage is truthy (non-zero), not that it
        # equals 10. If a specific stage count was intended, use
        # assertEqual/assertGreaterEqual — confirm the expected value.
        self.assertTrue(n_stage, 10)

        # Compare the SMC posterior with the result from a calibrated HMC.
        self.assertAllClose(tf.reduce_mean(b0), 0.016, atol=0.005, rtol=0.005)
        self.assertAllClose(tf.reduce_mean(b1), 1.245, atol=0.005, rtol=0.035)
        self.assertAllClose(tf.reduce_mean(weight), 0.28, atol=0.03, rtol=0.02)
        self.assertAllClose(tf.reduce_mean(mu_out), 0.13, atol=0.2, rtol=0.2)
        self.assertAllClose(tf.reduce_mean(sigma_out),
                            0.46,
                            atol=0.5,
                            rtol=0.5)

        self.assertAllClose(tf.math.reduce_std(b0),
                            0.031,
                            atol=0.015,
                            rtol=0.3)
        self.assertAllClose(tf.math.reduce_std(b1), 0.068, atol=0.1, rtol=0.1)
        self.assertAllClose(tf.math.reduce_std(weight),
                            0.1,
                            atol=0.1,
                            rtol=0.1)
Ejemplo n.º 25
0
    def testLangevin3DNormalDynamicVolatility(self):
        """Samples a 3-D correlated Normal with MALA and dynamic volatility.

        The state is split into a 2-component part `x` and a 1-component part
        `y`, and the volatility depends on the current state. The pooled
        sample mean and covariance are compared against the true moments.
        """
        dtype = np.float32
        true_mean = dtype([1, 2, 7])
        true_cov = dtype([[1, 0.25, 0.25], [0.25, 1, 0.25], [0.25, 0.25, 1]])
        num_results = 500
        num_chains = 500

        # Target distribution is defined through the Cholesky decomposition
        chol = tf.linalg.cholesky(true_cov)
        target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=chol)

        # The state is passed as a list of two tensors, `x` (2 components) and
        # `y` (1 component); see `init_state`. The target log-density joins
        # them into a single 3-D point.
        def target_log_prob(x, y):
            # Concatenate the state parts along the last axis.
            z = tf.concat([x, y], axis=-1)
            return target.log_prob(z)

        # Non-constant volatility: one volatility tensor per state part.
        def volatility_fn(x, y):
            # `x + y` broadcasts y's single column across both columns of x.
            return [
                1. / (0.5 + 0.1 * tf.abs(x + y)), 1. / (0.5 + 0.1 * tf.abs(y))
            ]

        # Initial state: `num_chains` chains, every component starting at 1.
        init_state = [
            np.ones([num_chains, 2], dtype=dtype),
            np.ones([num_chains, 1], dtype=dtype)
        ]

        # Run the Metropolis-adjusted Langevin algorithm (MALA) with the
        # state-dependent volatility for `num_results` iterations of
        # `num_chains` independent chains.
        states = tfp.mcmc.sample_chain(
            num_results=num_results,
            current_state=init_state,
            kernel=tfp.mcmc.MetropolisAdjustedLangevinAlgorithm(
                target_log_prob_fn=target_log_prob,
                volatility_fn=volatility_fn,
                step_size=.1),
            num_burnin_steps=200,
            num_steps_between_results=1,
            trace_fn=None,
            seed=test_util.test_seed())

        # Join the state parts along the last axis, then pool axes 0 and 1
        # (results and chains) to estimate the mean and covariance.
        states = tf.concat(states, axis=-1)
        sample_mean = tf.reduce_mean(states, axis=[0, 1])
        x = (states - sample_mean)[..., tf.newaxis]
        sample_cov = tf.reduce_mean(tf.matmul(x, x, transpose_b=True),
                                    axis=[0, 1])

        sample_mean_, sample_cov_ = self.evaluate([sample_mean, sample_cov])

        self.assertAllClose(true_mean,
                            np.squeeze(sample_mean_),
                            atol=0.1,
                            rtol=0.1)
        self.assertAllClose(true_cov,
                            np.squeeze(sample_cov_),
                            atol=0.1,
                            rtol=0.1)
Ejemplo n.º 26
0
    def run_test_sample_consistent_mean_covariance(self,
                                                   sess_run_fn,
                                                   dist,
                                                   num_samples=int(1e5),
                                                   seed=None,
                                                   rtol=1e-2,
                                                   atol=0.1,
                                                   cov_rtol=None,
                                                   cov_atol=None):
        """Tests that sample/mean/covariance are consistent with each other.

    "Consistency" means that `sample`, `mean`, `covariance`, etc all correspond
    to the same distribution.

    Args:
      sess_run_fn: Python `callable` taking `list`-like of `Tensor`s and
        returning a list of results after running one "step" of TensorFlow
        computation, typically set to `sess.run`.
      dist: Distribution instance or object which implements `sample`,
        `log_prob`, `event_shape_tensor` and `batch_shape_tensor`.
      num_samples: Python `int` scalar indicating the number of Monte-Carlo
        samples to draw from `dist`.
      seed: Python `int` indicating the seed to use when sampling from `dist`.
        In general it is not recommended to use `None` during a test as this
        increases the likelihood of spurious test failure.
      rtol: Python `float`-type indicating the admissible relative error between
        analytical and sample statistics.
      atol: Python `float`-type indicating the admissible absolute error between
        analytical and sample statistics.
      cov_rtol: Python `float`-type indicating the admissible relative error
        between analytical and sample covariance. Defaults to `rtol` when
        `None`; an explicit `0.` is honored.
      cov_atol: Python `float`-type indicating the admissible absolute error
        between analytical and sample covariance. Defaults to `atol` when
        `None`; an explicit `0.` is honored.
    """
        # Compare against None explicitly: `cov_rtol or rtol` would silently
        # replace a deliberate tolerance of 0. with the shared tolerance.
        if cov_rtol is None:
            cov_rtol = rtol
        if cov_atol is None:
            cov_atol = atol

        x = dist.sample(num_samples,
                        seed=test_seed_stream(hardcoded_seed=seed))
        sample_mean = tf.reduce_mean(x, axis=0)
        # Divide-by-N (biased) sample covariance around the sample mean.
        sample_covariance = tf.reduce_mean(_vec_outer_square(x - sample_mean),
                                           axis=0)
        sample_variance = tf.linalg.diag_part(sample_covariance)
        sample_stddev = tf.sqrt(sample_variance)

        [
            sample_mean_, sample_covariance_, sample_variance_, sample_stddev_,
            mean_, covariance_, variance_, stddev_
        ] = sess_run_fn([
            sample_mean,
            sample_covariance,
            sample_variance,
            sample_stddev,
            dist.mean(),
            dist.covariance(),
            dist.variance(),
            dist.stddev(),
        ])

        self.assertAllClose(mean_, sample_mean_, rtol=rtol, atol=atol)
        self.assertAllClose(covariance_,
                            sample_covariance_,
                            rtol=cov_rtol,
                            atol=cov_atol)
        self.assertAllClose(variance_, sample_variance_, rtol=rtol, atol=atol)
        self.assertAllClose(stddev_, sample_stddev_, rtol=rtol, atol=atol)
Ejemplo n.º 27
0
 def call(self, y_true, y_pred):
     """Mean of the element-wise squared differences of two ragged tensors."""
     def _squared_error(truth, prediction):
         return tf.math.squared_difference(truth, prediction)

     per_element = tf.ragged.map_flat_values(_squared_error, y_true, y_pred)
     mean_loss = tf.reduce_mean(per_element)
     return mean_loss
Ejemplo n.º 28
0
def contrastive_loss(features,
                     labels=None,
                     temperature=1.0,
                     contrast_mode=enums.LossContrastMode.ALL_VIEWS,
                     summation_location=enums.LossSummationLocation.OUTSIDE,
                     denominator_mode=enums.LossDenominatorMode.ALL,
                     positives_cap=-1,
                     scale_by_temperature=True):
    r"""Contrastive loss over features.

  Implemented as described in: https://arxiv.org/abs/2004.11362, Equation 2.

  Given `num_views` different views of each of `batch_size` samples, let `f_i`
  (i \in [1, 2 ... (num_views * batch_size)]) denote each respective feature
  vector. The contrastive loss then takes the following form:

    L = \sum_{i} L_i

  where each L_i is computed as:

    L_i = -\frac{\tau}{|P(i)|} \sum_{k \in P(i)} \log(p_{ik})    (1)

  where P(i) is the set of positives for entry i (distinct from i), |P(i)| is
  its size, and where:

                       \exp(f_i^T f_k / \tau)
    p_{ik} = ----------------------------------------                        (2)
             \sum_{j \in A(i)} \exp(f_i^T f_j / \tau)

  where A(i) is the set of all positives or negatives (distinct from i). `i` is
  the anchor, and \tau is the temperature. The leading \tau factor in (1) is
  applied only when `scale_by_temperature` is True. This function returns the
  per-sample terms (averaged over anchor views) rather than their sum L.

  This maximizes the likelihood of a given (anchor, positive) pair with
  respect to all possible pairs where the first member is the anchor and the
  second member is a positive or a negative.

  A typical way to define a positive is to define samples from the
  same class (but not the anchor itself) regardless of what view they are from.
  Similarly, a typical way to define a negative is for it to be any view of a
  sample from a different class.

  There are two ways to define which feature pairs should be treated as
  positives and negatives. All views of the same sample are always treated as
  positives. You can declare other samples to be positives by providing `labels`
  such that all samples with the same label will be positives for each other.

  If `labels` is not provided then we default to every sample belonging to its
  own unique class. Therefore, the only positive used is another view of the
  anchor itself. This implements the loss as described in:

    https://arxiv.org/pdf/2002.05709.pdf
    A Simple Framework for Contrastive Learning of Visual Representations
    Chen T., Kornblith S., Norouzi M., Hinton G.

  It is recommended to use features whose L_2 norm is 1, since that ensures
  that the loss does not return NaN values without changing the intended
  behaviour of the loss function.

  In (1) above, note that the summation over positives is located outside of the
  \log(). However, one can permute these two operations. The result is Eq. 3 in
  https://arxiv.org/abs/2004.11362. Users can specify the location of the
  summation relative to the \log() via the `summation_location` argument:
   - LossSummationLocation.OUTSIDE: Eq. 2 in https://arxiv.org/abs/2004.11362.
   - LossSummationLocation.INSIDE : Eq. 3 in https://arxiv.org/abs/2004.11362.

  Additionally, in (2) above, note that the denominator sums over *all* entries
  distinct from i. One can change which terms are included in the denominator
  via the `denominator_mode` argument:
   - LossDenominatorMode.ALL : All entries (i.e., all negatives and all
             positives) distinct from i are included.
   - LossDenominatorMode.ONE_POSITIVE : All negatives are included but only the
             single positive in the numerator of (2) is included. Any other
             positives are excluded.
   - LossDenominatorMode.ONLY_NEGATIVES: All negatives are included but no
             positives are, not even the single positive in the numerator of
             (2).

  On TPUs, this method will internally perform the cross-replica operations that
  enable using the samples from all cores in computing the loss. The inputs to
  this function should be the features and labels from a single core and each
  core will compute the loss using just these features as anchors, but will use
  positives and negatives from the full global batch. Since the loss for each
  anchor is only computed on one TPU core, it's still necessary to have a
  cross-replica reduction in the final loss computation.

  Also, though it is not applicable to multiview contrastive learning, this
  function will work if |features| contains only 1 view. In the high batch size
  limit, the implemented contrastive loss with only 1 view, positives_cap = 1,
  and temperature = 1.0 is equivalent to the N-pairs loss
  (https://papers.nips.cc/paper/6200-improved-deep-metric-learning-with-multi-class-n-pair-loss-objective.pdf)

  Args:
    features: A Tensor of rank at least 3, where the first 2 dimensions are
      batch_size and num_views, and the remaining dimensions are the feature
      shape. Note that when running on TPU, batch_size is the per-core batch
      size.
    labels: One-hot labels to be used to construct the supervised contrastive
      loss. Samples with the same labels are used as positives for each other.
      Labels must have shape [batch_size, num_labels] with numeric dtype and be
      0-1 valued. Note that when running on TPU, batch_size is the per-core
      batch size.
    temperature: Temperature at which softmax evaluation is done. Temperature
      must be a python scalar or scalar Tensor of numeric dtype.
    contrast_mode: LossContrastMode specifying which views get used as anchors
      (f_i in the expression above)
      'ALL_VIEWS': All the views of all samples are used as anchors (f_i in the
        expression above).
      'ONE_VIEW': Just the first view of each sample is used as an anchor (f_i
        in the expression above). This view is called the `core` view against
        which other views are contrasted.
    summation_location: LossSummationLocation specifying location of positives
      summation. See documentation above for more details.
    denominator_mode: LossDenominatorMode specifying which positives to include
      in contrastive denominator. See documentation above for more details.
    positives_cap: Integer maximum number of positives *other* than
      augmentations of anchor. Infinite if < 0. Must be multiple of num_views.
      Including augmentations, a maximum of (positives_cap + num_views - 1)
      positives is possible. This parameter modifies the contrastive numerator
      by selecting which positives are present in the summation, and which
      positives contribute to the denominator if denominator_mode ==
      enums.LossDenominatorMode.ALL.
    scale_by_temperature: Boolean. Whether to scale the loss by `temperature`.
      The loss gradient naturally has a 1/temperature scaling factor, so this
      counteracts it.

  Returns:
    Tensor of contrastive loss values with shape [batch_size] (the per-core
    batch size when running on TPU) and dtype tf.float32. The loss for each
    batch element is the mean over all anchor views. Note: in the 1-view case
    the result goes through `tf.squeeze`, so a batch of size 1 yields a scalar.

  Raises:
    ValueError if the shapes of any of the Tensors are unexpected. Input
    validation is delegated to `_validate_contrastive_loss_inputs`.
  """
    features = tf.convert_to_tensor(features)
    labels = tf.convert_to_tensor(labels) if labels is not None else None

    local_batch_size, num_views = _validate_contrastive_loss_inputs(
        features, labels, contrast_mode, summation_location, denominator_mode,
        positives_cap)

    # Flatten `features` to a single dimension per view per sample so it has shape
    # [local_batch_size, num_views, num_features].
    if features.shape.rank > 3:
        features = tf.reshape(
            features, tf.concat([tf.shape(features)[:2], [-1]], axis=0),
            'flattened_features')
    if features.dtype != tf.float32:
        features = tf.cast(features, tf.float32)

    # Grab the features from all TPU cores. We use the local batch as anchors and
    # the full global batch as contrastives. If not on TPU, global_features is the
    # same as features.
    global_features = utils.cross_replica_concat(features)
    # Static (graph-construction-time) size of the global batch. Presumably this
    # must be statically known, since it is used as a `one_hot` depth and in
    # `ensure_shape` below -- confirm for dynamic-batch callers.
    global_batch_size = tf.compat.dimension_at_index(global_features.shape,
                                                     0).value
    local_replica_id = utils.local_tpu_replica_id()

    # Generate the [local_batch_size, global_batch_size] slice of the
    # [global_batch_size, global_batch_size] identity matrix that corresponds to
    # the current replica.
    diagonal_mask = tf.one_hot(
        tf.range(local_batch_size) + (local_replica_id * local_batch_size),
        global_batch_size)

    # Generate `mask` with shape [local_batch_size, global_batch_size] that
    # indicates which samples should be considered positives for each other.
    if labels is None:
        # Defaults to every sample belonging to its own unique class, containing
        # just that sample and other views of it.
        mask = diagonal_mask
    else:
        labels = tf.cast(labels,
                         tf.float32)  # TPU matmul op unsupported for ints.
        global_labels = utils.cross_replica_concat(labels)
        # Entry (i, j) is the dot product of the label rows; for 0-1 one-hot
        # labels it is 1 exactly when samples i and j share a label.
        mask = tf.linalg.matmul(labels, global_labels, transpose_b=True)
    mask = tf.ensure_shape(mask, [local_batch_size, global_batch_size])

    # To streamline the subsequent TF, the first two dimensions of
    # `global_features` (i.e., global_batch_size and num_views) should be
    # transposed and then flattened. The result has shape
    # [num_views * global_batch_size, num_features], and its first dimension
    # elements are grouped by view, not by sample.
    all_global_features = tf.reshape(
        tf.transpose(global_features, perm=[1, 0, 2]),
        [num_views * global_batch_size, -1])

    if contrast_mode == enums.LossContrastMode.ONE_VIEW:
        anchor_features = features[:, 0]
        num_anchor_views = 1
    else:  # contrast_mode == enums.LossContrastMode.ALL_VIEWS
        # Reshape features to match how global_features is reshaped above.
        anchor_features = tf.reshape(tf.transpose(features, perm=[1, 0, 2]),
                                     [num_views * local_batch_size, -1])
        num_anchor_views = num_views

    # Generate `logits`, the tensor of (temperature-scaled) dot products of the
    # anchor features with all features. It has shape
    # [local_batch_size * num_anchor_views, global_batch_size * num_views]. To
    # improve numerical stability, subtract out the largest |logits| element in
    # each row from all elements in that row. Since |logits| is only ever used as
    # a ratio of exponentials of |logits| values, this subtraction does not change
    # the results correctness. A stop_gradient() is needed because this change is
    # just for numerical precision.
    logits = tf.linalg.matmul(anchor_features,
                              all_global_features,
                              transpose_b=True)
    # Cast so a Python-scalar or non-float32 temperature matches `logits`.
    temperature = tf.cast(temperature, tf.float32)
    logits = logits / temperature
    logits = (logits -
              tf.reduce_max(tf.stop_gradient(logits), axis=1, keepdims=True))
    exp_logits = tf.exp(logits)

    # The following masks are all tiled by the number of views, i.e., they have
    # shape [local_batch_size * num_anchor_views, global_batch_size * num_views].
    positives_mask, negatives_mask = (_create_tiled_masks(
        mask, diagonal_mask, num_views, num_anchor_views, positives_cap))
    num_positives_per_row = tf.reduce_sum(positives_mask, axis=1)

    if denominator_mode == enums.LossDenominatorMode.ALL:
        denominator = tf.reduce_sum(
            exp_logits * negatives_mask, axis=1,
            keepdims=True) + tf.reduce_sum(
                exp_logits * positives_mask, axis=1, keepdims=True)
    elif denominator_mode == enums.LossDenominatorMode.ONE_POSITIVE:
        # Not reduced: every (anchor, k) entry gets its own denominator made of
        # that single exp_logits[anchor, k] plus the row's negatives.
        denominator = exp_logits + tf.reduce_sum(
            exp_logits * negatives_mask, axis=1, keepdims=True)
    else:  # denominator_mode == enums.LossDenominatorMode.ONLY_NEGATIVES
        denominator = tf.reduce_sum(exp_logits * negatives_mask,
                                    axis=1,
                                    keepdims=True)

    # Note that num_positives_per_row can be zero only if 1 view is used. The
    # various tf.math.divide_no_nan() calls below are to handle this case.
    if summation_location == enums.LossSummationLocation.OUTSIDE:
        # Mean over positives of log(p_ik): the summation is outside the log.
        log_probs = (logits - tf.math.log(denominator)) * positives_mask
        log_probs = tf.reduce_sum(log_probs, axis=1)
        log_probs = tf.math.divide_no_nan(log_probs, num_positives_per_row)
    else:  # summation_location == enums.LossSummationLocation.INSIDE
        # log of the mean over positives of p_ik: the summation is inside the log.
        log_probs = exp_logits / denominator * positives_mask
        log_probs = tf.reduce_sum(log_probs, axis=1)
        log_probs = tf.math.divide_no_nan(log_probs, num_positives_per_row)
        log_probs = tf.math.log(log_probs)

    loss = -log_probs
    if scale_by_temperature:
        loss *= temperature
    # Anchor rows are grouped view-major (see the transpose/reshape of the
    # anchors above), so this recovers a [view, sample] layout.
    loss = tf.reshape(loss, [num_anchor_views, local_batch_size])

    if num_views != 1:
        loss = tf.reduce_mean(loss, axis=0)
    else:
        # The 1 view case requires special handling bc, unlike in the > 1 view case,
        # not all samples are guaranteed to have a positive. Also, no reduction over
        # views is needed.
        # NOTE(review): log_probs was already divided by num_positives_per_row
        # above, so rows with more than one positive (possible when `labels` is
        # given) are divided by that count a second time here. Rows with zero
        # positives become 0 via divide_no_nan. Confirm the second division is
        # intended. Also, tf.squeeze turns a [1, 1] result into a scalar when
        # local_batch_size == 1.
        num_valid_views_per_sample = (tf.reshape(num_positives_per_row,
                                                 [1, local_batch_size]))
        loss = tf.squeeze(
            tf.math.divide_no_nan(loss, num_valid_views_per_sample))

    return loss
    def testSample(self):
        """Checks MultivariateNormalDiagPlusLowRank moments, scale and KL.

        Builds the expected scale/covariance in NumPy from the same parameters
        used to construct `dist`. It then compares against them:
          * the sample and analytical mean, covariance, variance and stddev;
          * the dense scale matrix;
          * Monte Carlo KL estimates vs. `tfd.kl_divergence`, both with `dist`
            and with a diagonal `baseline` MVN as the first KL argument.
        """
        # TODO(jvdillon): This test should be the basis of a new test fixture which
        # is applied to every distribution. When we make this fixture, we'll also
        # separate the analytical- and sample-based tests as well as for each
        # function tested. For now, we group things so we can recycle one batch of
        # samples (thus saving resources).

        mu = np.array([-1., 1, 0.5], dtype=np.float32)
        diag_large = np.array([1., 0.5, 0.75], dtype=np.float32)
        diag_small = np.array([-1.1, 1.2], dtype=np.float32)
        v = np.array([[0.7, 0.8], [0.9, 1], [0.5, 0.6]],
                     dtype=np.float32)  # shape: [k, r] = [3, 2]

        # Expected values: scale = diag(diag_large) + v @ diag(diag_small) @ v.T
        # and covariance = scale @ scale.T.
        true_mean = mu
        true_scale = np.diag(diag_large) + np.matmul(
            np.matmul(v, np.diag(diag_small)), v.T)
        true_covariance = np.matmul(true_scale, true_scale.T)
        true_variance = np.diag(true_covariance)
        true_stddev = np.sqrt(true_variance)

        dist = tfd.MultivariateNormalDiagPlusLowRank(
            loc=mu,
            scale_diag=diag_large,
            scale_perturb_factor=v,
            scale_perturb_diag=diag_small,
            validate_args=True)

        # The following distributions will test the KL divergence calculation.
        mvn_identity = tfd.MultivariateNormalDiag(loc=np.array(
            [1., 2, 0.25], dtype=np.float32),
                                                  validate_args=True)
        mvn_scaled = tfd.MultivariateNormalDiag(loc=mvn_identity.loc,
                                                scale_identity_multiplier=2.2,
                                                validate_args=True)
        mvn_diag = tfd.MultivariateNormalDiag(loc=mvn_identity.loc,
                                              scale_diag=np.array(
                                                  [0.5, 1.5, 1.],
                                                  dtype=np.float32),
                                              validate_args=True)
        mvn_chol = tfd.MultivariateNormalTriL(
            loc=np.array([1., 2, -1], dtype=np.float32),
            scale_tril=np.array([[6., 0, 0], [2, 5, 0], [1, 3, 4]],
                                dtype=np.float32) / 10.,
            validate_args=True)

        scale = dist.scale.to_dense()

        n = int(30e3)
        samps = dist.sample(n,
                            seed=tfp_test_util.test_seed(hardcoded_seed=0,
                                                         set_eager_seed=False))
        sample_mean = tf.reduce_mean(input_tensor=samps, axis=0)
        x = samps - sample_mean
        # Divides by n (not n - 1): the maximum-likelihood covariance estimate.
        sample_covariance = tf.matmul(x, x, transpose_a=True) / n

        # Monte Carlo KL(dist || q) = E_{x ~ dist}[log dist(x) - log q(x)],
        # estimated with the samples drawn from `dist` above.
        sample_kl_identity = tf.reduce_mean(input_tensor=dist.log_prob(samps) -
                                            mvn_identity.log_prob(samps),
                                            axis=0)
        analytical_kl_identity = tfd.kl_divergence(dist, mvn_identity)

        sample_kl_scaled = tf.reduce_mean(input_tensor=dist.log_prob(samps) -
                                          mvn_scaled.log_prob(samps),
                                          axis=0)
        analytical_kl_scaled = tfd.kl_divergence(dist, mvn_scaled)

        sample_kl_diag = tf.reduce_mean(input_tensor=dist.log_prob(samps) -
                                        mvn_diag.log_prob(samps),
                                        axis=0)
        analytical_kl_diag = tfd.kl_divergence(dist, mvn_diag)

        sample_kl_chol = tf.reduce_mean(input_tensor=dist.log_prob(samps) -
                                        mvn_chol.log_prob(samps),
                                        axis=0)
        analytical_kl_chol = tfd.kl_divergence(dist, mvn_chol)

        # NOTE: `n` and `samps` are rebound below for the baseline KL checks.
        # The tensors built above already captured the earlier values, so these
        # statements must stay in this order.
        n = int(10e3)
        baseline = tfd.MultivariateNormalDiag(loc=np.array([-1., 0.25, 1.25],
                                                           dtype=np.float32),
                                              scale_diag=np.array(
                                                  [1.5, 0.5, 1.],
                                                  dtype=np.float32),
                                              validate_args=True)
        samps = baseline.sample(n, seed=tfp_test_util.test_seed())

        sample_kl_identity_diag_baseline = tf.reduce_mean(
            input_tensor=baseline.log_prob(samps) -
            mvn_identity.log_prob(samps),
            axis=0)
        analytical_kl_identity_diag_baseline = tfd.kl_divergence(
            baseline, mvn_identity)

        sample_kl_scaled_diag_baseline = tf.reduce_mean(
            input_tensor=baseline.log_prob(samps) - mvn_scaled.log_prob(samps),
            axis=0)
        analytical_kl_scaled_diag_baseline = tfd.kl_divergence(
            baseline, mvn_scaled)

        sample_kl_diag_diag_baseline = tf.reduce_mean(
            input_tensor=baseline.log_prob(samps) - mvn_diag.log_prob(samps),
            axis=0)
        analytical_kl_diag_diag_baseline = tfd.kl_divergence(
            baseline, mvn_diag)

        sample_kl_chol_diag_baseline = tf.reduce_mean(
            input_tensor=baseline.log_prob(samps) - mvn_chol.log_prob(samps),
            axis=0)
        analytical_kl_chol_diag_baseline = tfd.kl_divergence(
            baseline, mvn_chol)

        # Evaluate everything in a single call so all sample-based statistics
        # come from the same batches of samples.
        [
            sample_mean_,
            analytical_mean_,
            sample_covariance_,
            analytical_covariance_,
            analytical_variance_,
            analytical_stddev_,
            scale_,
            sample_kl_identity_,
            analytical_kl_identity_,
            sample_kl_scaled_,
            analytical_kl_scaled_,
            sample_kl_diag_,
            analytical_kl_diag_,
            sample_kl_chol_,
            analytical_kl_chol_,
            sample_kl_identity_diag_baseline_,
            analytical_kl_identity_diag_baseline_,
            sample_kl_scaled_diag_baseline_,
            analytical_kl_scaled_diag_baseline_,
            sample_kl_diag_diag_baseline_,
            analytical_kl_diag_diag_baseline_,
            sample_kl_chol_diag_baseline_,
            analytical_kl_chol_diag_baseline_,
        ] = self.evaluate([
            sample_mean,
            dist.mean(),
            sample_covariance,
            dist.covariance(),
            dist.variance(),
            dist.stddev(),
            scale,
            sample_kl_identity,
            analytical_kl_identity,
            sample_kl_scaled,
            analytical_kl_scaled,
            sample_kl_diag,
            analytical_kl_diag,
            sample_kl_chol,
            analytical_kl_chol,
            sample_kl_identity_diag_baseline,
            analytical_kl_identity_diag_baseline,
            sample_kl_scaled_diag_baseline,
            analytical_kl_scaled_diag_baseline,
            sample_kl_diag_diag_baseline,
            analytical_kl_diag_diag_baseline,
            sample_kl_chol_diag_baseline,
            analytical_kl_chol_diag_baseline,
        ])

        sample_variance_ = np.diag(sample_covariance_)
        sample_stddev_ = np.sqrt(sample_variance_)

        tf1.logging.vlog(2, "true_mean:\n{}  ".format(true_mean))
        tf1.logging.vlog(2, "sample_mean:\n{}".format(sample_mean_))
        tf1.logging.vlog(2, "analytical_mean:\n{}".format(analytical_mean_))

        tf1.logging.vlog(2, "true_covariance:\n{}".format(true_covariance))
        tf1.logging.vlog(2,
                         "sample_covariance:\n{}".format(sample_covariance_))
        tf1.logging.vlog(
            2, "analytical_covariance:\n{}".format(analytical_covariance_))

        tf1.logging.vlog(2, "true_variance:\n{}".format(true_variance))
        tf1.logging.vlog(2, "sample_variance:\n{}".format(sample_variance_))
        tf1.logging.vlog(
            2, "analytical_variance:\n{}".format(analytical_variance_))

        tf1.logging.vlog(2, "true_stddev:\n{}".format(true_stddev))
        tf1.logging.vlog(2, "sample_stddev:\n{}".format(sample_stddev_))
        tf1.logging.vlog(2,
                         "analytical_stddev:\n{}".format(analytical_stddev_))

        tf1.logging.vlog(2, "true_scale:\n{}".format(true_scale))
        tf1.logging.vlog(2, "scale:\n{}".format(scale_))

        tf1.logging.vlog(
            2, "kl_identity:  analytical:{}  sample:{}".format(
                analytical_kl_identity_, sample_kl_identity_))

        tf1.logging.vlog(
            2, "kl_scaled:    analytical:{}  sample:{}".format(
                analytical_kl_scaled_, sample_kl_scaled_))

        tf1.logging.vlog(
            2, "kl_diag:      analytical:{}  sample:{}".format(
                analytical_kl_diag_, sample_kl_diag_))

        tf1.logging.vlog(
            2, "kl_chol:      analytical:{}  sample:{}".format(
                analytical_kl_chol_, sample_kl_chol_))

        tf1.logging.vlog(
            2, "kl_identity_diag_baseline:  analytical:{}  sample:{}".format(
                analytical_kl_identity_diag_baseline_,
                sample_kl_identity_diag_baseline_))

        tf1.logging.vlog(
            2, "kl_scaled_diag_baseline:  analytical:{}  sample:{}".format(
                analytical_kl_scaled_diag_baseline_,
                sample_kl_scaled_diag_baseline_))

        tf1.logging.vlog(
            2, "kl_diag_diag_baseline:  analytical:{}  sample:{}".format(
                analytical_kl_diag_diag_baseline_,
                sample_kl_diag_diag_baseline_))

        tf1.logging.vlog(
            2, "kl_chol_diag_baseline:  analytical:{}  sample:{}".format(
                analytical_kl_chol_diag_baseline_,
                sample_kl_chol_diag_baseline_))

        # Sample-based statistics get a loose relative tolerance (Monte Carlo
        # error); analytical ones must match the NumPy values tightly.
        self.assertAllClose(true_mean, sample_mean_, atol=0., rtol=0.02)
        self.assertAllClose(true_mean, analytical_mean_, atol=0., rtol=1e-6)

        self.assertAllClose(true_covariance,
                            sample_covariance_,
                            atol=0.,
                            rtol=0.02)
        self.assertAllClose(true_covariance,
                            analytical_covariance_,
                            atol=0.,
                            rtol=1e-6)

        self.assertAllClose(true_variance,
                            sample_variance_,
                            atol=0.,
                            rtol=0.02)
        self.assertAllClose(true_variance,
                            analytical_variance_,
                            atol=0.,
                            rtol=1e-6)

        self.assertAllClose(true_stddev, sample_stddev_, atol=0., rtol=0.02)
        self.assertAllClose(true_stddev,
                            analytical_stddev_,
                            atol=0.,
                            rtol=1e-6)

        self.assertAllClose(true_scale, scale_, atol=0., rtol=1e-6)

        self.assertAllClose(sample_kl_identity_,
                            analytical_kl_identity_,
                            atol=0.,
                            rtol=0.02)
        self.assertAllClose(sample_kl_scaled_,
                            analytical_kl_scaled_,
                            atol=0.,
                            rtol=0.02)
        self.assertAllClose(sample_kl_diag_,
                            analytical_kl_diag_,
                            atol=0.,
                            rtol=0.02)
        self.assertAllClose(sample_kl_chol_,
                            analytical_kl_chol_,
                            atol=0.,
                            rtol=0.02)

        self.assertAllClose(sample_kl_identity_diag_baseline_,
                            analytical_kl_identity_diag_baseline_,
                            atol=0.,
                            rtol=0.02)
        self.assertAllClose(sample_kl_scaled_diag_baseline_,
                            analytical_kl_scaled_diag_baseline_,
                            atol=0.,
                            rtol=0.02)
        # This pair uses a looser tolerance than the other KL checks.
        self.assertAllClose(sample_kl_diag_diag_baseline_,
                            analytical_kl_diag_diag_baseline_,
                            atol=0.,
                            rtol=0.04)
        self.assertAllClose(sample_kl_chol_diag_baseline_,
                            analytical_kl_chol_diag_baseline_,
                            atol=0.,
                            rtol=0.02)
Ejemplo n.º 30
0
 def loss(logits, labels):
     """Calculates cross entropy loss.

     Computes the element-wise cross entropy ``-labels * log(logits)`` and
     averages it over every element of the result (all batch and class
     entries).

     NOTE(review): despite its name, ``logits`` is passed straight to ``log``.
     It is therefore presumably a tensor of probabilities (e.g. a softmax
     output), not unnormalized logits -- confirm against callers.

     Args:
       logits: Predicted probabilities. Same dtype as, and broadcast-compatible
         with, ``labels``.
       labels: Target (e.g. one-hot) labels.

     Returns:
       Scalar tensor: the mean of ``-labels * log(logits)`` over all elements.
     """
     # tf.math.xlogy(x, y) equals x * log(y), except that it returns 0 wherever
     # x == 0. The previous `labels * tf.math.log(logits)` produced
     # 0 * log(0) = 0 * -inf = NaN (in both value and gradient) for any entry
     # with a zero label and a zero predicted probability. Everywhere the old
     # expression was finite, the result is unchanged.
     diff = -tf.math.xlogy(labels, logits)
     return tf.reduce_mean(diff)