def _test_vectorization(self, dist_name, dist):
        seed = test_util.test_seed()

        # TODO(b/171752261): New stateless samplers don't work with pfor.
        enable_auto_vectorized_sampling = False

        num_samples = 3
        if (not enable_auto_vectorized_sampling
                or dist_name in SAMPLE_AUTOVECTORIZATION_IS_BROKEN):
            sample = self.evaluate(dist.sample(num_samples, seed=seed))
        else:
            sample = self.evaluate(
                tf.vectorized_map(lambda i: dist.sample(seed=seed),
                                  tf.range(num_samples),
                                  fallback_to_while_loop=False))
        hp.note('Drew samples {}'.format(sample))

        if dist_name not in LOGPROB_AUTOVECTORIZATION_IS_BROKEN:
            pfor_lp = tf.vectorized_map(dist.log_prob,
                                        tf.convert_to_tensor(sample),
                                        fallback_to_while_loop=False)
            batch_lp = dist.log_prob(sample)
            pfor_lp_, batch_lp_ = self.evaluate((pfor_lp, batch_lp))
            self.assertAllClose(pfor_lp_,
                                batch_lp_,
                                atol=VECTORIZED_LOGPROB_ATOL[dist_name],
                                rtol=VECTORIZED_LOGPROB_RTOL[dist_name])
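A minimal, hedged sketch of the pfor-vs-batch log_prob consistency check above, using a concrete distribution instead of a Hypothesis-drawn one. `tfd.Normal` is an assumption standing in for the drawn `dist`; the tolerance dictionaries are omitted.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

dist = tfd.Normal(loc=0., scale=1.)
sample = dist.sample(3)
pfor_lp = tf.vectorized_map(dist.log_prob, tf.convert_to_tensor(sample))
batch_lp = dist.log_prob(sample)
# The two log-prob tensors should agree elementwise up to numerical tolerance.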
Example 2
  def testAutoVectorization(self, bijector_name, data):

    # TODO(b/150161911): reconcile numeric behavior of eager and graph mode.
    if tf.executing_eagerly():
      return

    bijector, event_dim = self._draw_bijector(
        bijector_name, data,
        batch_shape=[],  # Avoid conflict with vmap sample dimension.
        validate_args=False,  # Work around lack of `If` support in vmap.
        allowed_bijectors=(set(bhps.INSTANTIABLE_BIJECTORS) -
                           set(AUTOVECTORIZATION_IS_BROKEN)))
    atol = AUTOVECTORIZATION_ATOL[bijector_name]
    rtol = AUTOVECTORIZATION_RTOL[bijector_name]

    # Forward
    n = 3
    xs = self._draw_domain_tensor(bijector, data, event_dim, sample_shape=[n])
    ys = bijector.forward(xs)
    vectorized_ys = tf.vectorized_map(bijector.forward, xs,
                                      fallback_to_while_loop=False)
    self.assertAllClose(*self.evaluate((ys, vectorized_ys)),
                        atol=atol, rtol=rtol)

    # FLDJ
    event_ndims = data.draw(
        hps.integers(
            min_value=bijector.forward_min_event_ndims,
            max_value=ps.rank_from_shape(xs.shape) - 1))
    fldj_fn = functools.partial(bijector.forward_log_det_jacobian,
                                event_ndims=event_ndims)
    vectorized_fldj = tf.vectorized_map(fldj_fn, xs,
                                        fallback_to_while_loop=False)
    fldj = tf.broadcast_to(fldj_fn(xs), tf.shape(vectorized_fldj))
    self.assertAllClose(*self.evaluate((fldj, vectorized_fldj)),
                        atol=atol, rtol=rtol)

    # Inverse
    ys = self._draw_codomain_tensor(bijector, data, event_dim, sample_shape=[n])
    xs = bijector.inverse(ys)
    vectorized_xs = tf.vectorized_map(bijector.inverse, ys,
                                      fallback_to_while_loop=False)
    self.assertAllClose(*self.evaluate((xs, vectorized_xs)),
                        atol=atol, rtol=rtol)

    # ILDJ
    event_ndims = data.draw(
        hps.integers(
            min_value=bijector.inverse_min_event_ndims,
            max_value=ps.rank_from_shape(ys.shape) - 1))
    ildj_fn = functools.partial(bijector.inverse_log_det_jacobian,
                                event_ndims=event_ndims)
    vectorized_ildj = tf.vectorized_map(ildj_fn, ys,
                                        fallback_to_while_loop=False)
    ildj = tf.broadcast_to(ildj_fn(ys), tf.shape(vectorized_ildj))
    self.assertAllClose(*self.evaluate((ildj, vectorized_ildj)),
                        atol=atol, rtol=rtol)
Example 3
    def testPForInterop(self):
        def outer_product(a):
            return np.tensordot(a, a, 0)

        batch_size = 100
        a = np.ones((batch_size, 32, 32))
        c = tf.vectorized_map(outer_product, a)

        self.assertIsInstance(c, np.ndarray)
        self.assertEqual(c.shape, (batch_size, 32, 32, 32, 32))

        c = tf.vectorized_map(lambda x: x.T, a)

        self.assertIsInstance(c, np.ndarray)
        self.assertEqual(c.shape, (batch_size, 32, 32))
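For reference, a plain-TensorFlow sketch of the same batching semantics: `tf.vectorized_map` traces the callable once and maps it over axis 0 of its input (the shapes here are illustrative assumptions).

import tensorflow as tf

a = tf.ones([4, 3, 3])
# Per-element outer product; the result has shape [4, 3, 3, 3, 3].
c = tf.vectorized_map(lambda x: tf.tensordot(x, x, axes=0), a)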
  def call(self, inputs):
    bins = [tf.cast(tf.compat.v1.squeeze(self.bins), tf.float32)]

    def _bucketize_fn(inputs):
      return tf.raw_ops.BoostedTreesBucketize(
          float_values=[tf.cast(inputs, tf.float32)],
          bucket_boundaries=bins)[0]

    if tf_utils.is_ragged(inputs):
      integer_buckets = tf.ragged.map_flat_values(
          _bucketize_fn, inputs)
      # Ragged map_flat_values doesn't touch the non-values tensors in the
      # ragged composite tensor. If this op is the only op in a Keras model,
      # this can cause errors in Graph mode, so wrap the tensor in an identity.
      return tf.identity(integer_buckets)
    elif isinstance(inputs, tf.SparseTensor):
      return tf.SparseTensor(
          indices=tf.identity(inputs.indices),
          values=_bucketize_fn(inputs.values),
          dense_shape=tf.identity(inputs.dense_shape))
    else:
      static_shape = inputs.get_shape()
      if any(dim is None for dim in static_shape.as_list()[1:]):
        raise NotImplementedError(
            "Discretization Layer requires known non-batch shape,"
            "found {}".format(static_shape))

      dynamic_shape = tf.shape(inputs)
      # BoostedTreesBucketize only handles rank 1 inputs. We need to flatten
      # the inputs after the batch dimension and vectorized_map over each
      # sample.
      reshaped = tf.reshape(inputs, [dynamic_shape[0], -1])
      return tf.reshape(
          tf.vectorized_map(_bucketize_fn, reshaped),
          dynamic_shape)
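A minimal standalone sketch of the bucketize op this layer relies on; the keyword arguments mirror the calls above, and the boundary values are illustrative assumptions (they must be sorted).

import tensorflow as tf

values = tf.constant([0.1, 0.6, 2.5])
boundaries = tf.constant([0.5, 1.0, 2.0])
bucket_ids = tf.raw_ops.BoostedTreesBucketize(
    float_values=[values], bucket_boundaries=[boundaries])[0]
# bucket_ids holds, for each value, the index of the bucket it falls into.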
Example 5
  def test_sampled_scale_follows_correct_distribution(self):
    strm = test_util.test_seed_stream()
    prior = tfd.InverseGamma(concentration=0.1, scale=0.1)

    num_timesteps = 100
    observed_samples = tf.random.normal([2, num_timesteps], seed=strm()) * 3.
    is_missing = tf.random.uniform([2, num_timesteps], seed=strm()) > 0.9

    # Check that posterior variance samples have the moments of the correct
    # InverseGamma distribution.
    posterior_scale_samples = tf.vectorized_map(
        lambda seed: gibbs_sampler._resample_scale(  # pylint: disable=g-long-lambda
            prior=prior,
            observed_residuals=observed_samples,
            is_missing=is_missing,
            seed=seed),
        tfp.random.split_seed(strm(), tf.constant(10000)))

    concentration = prior.concentration + tf.reduce_sum(
        1 - tf.cast(is_missing, tf.float32), axis=-1)/2.
    scale = prior.scale + tf.reduce_sum(
        (observed_samples * tf.cast(~is_missing, tf.float32))**2, axis=-1)/2.
    posterior_scale_samples_, concentration_, scale_ = self.evaluate(
        (posterior_scale_samples, concentration, scale))
    self.assertAllClose(np.mean(posterior_scale_samples_**2, axis=0),
                        scale_ / (concentration_ - 1), atol=0.05)
    self.assertAllClose(
        np.std(posterior_scale_samples_**2, axis=0),
        scale_ / ((concentration_ - 1) * np.sqrt(concentration_ - 2)),
        atol=0.05)
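A hedged sketch of the seed-splitting pattern used above: split one stateless seed into many and map a stateless-seeded computation over them with `tf.vectorized_map` (which by default may fall back to a while_loop for ops pfor cannot vectorize). The seed value and count are illustrative assumptions.

import tensorflow as tf
import tensorflow_probability as tfp

seeds = tfp.random.split_seed([1, 2], n=tf.constant(8))  # Tensor of shape [8, 2].
draws = tf.vectorized_map(
    lambda s: tfp.distributions.Normal(0., 1.).sample(seed=s), seeds)
# draws has shape [8]; each element was drawn with an independent seed.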
Example 6
  def test_batching(self, input_batch_shape, kernel_batch_shape):
    input_shape = (12, 12, 2)
    filter_shape = (3, 3)
    channels_out = 4
    strides = 2
    dilations = (1, 1)
    padding = 'SAME'

    x, k = _make_input_and_kernel(
        self.make_input,
        input_batch_shape=input_batch_shape,
        input_shape=input_shape,
        kernel_batch_shape=kernel_batch_shape,
        filter_shape=filter_shape,
        channels_out=channels_out,
        dtype=self.dtype)

    conv_fn = self.make_conv_fn(filter_shape, strides, padding, dilations)
    y_batched = conv_fn(x, k)

    broadcast_batch_shape = ps.broadcast_shape(
        input_batch_shape, kernel_batch_shape)
    broadcasted_input = tf.broadcast_to(
        x, shape=ps.concat([broadcast_batch_shape, input_shape], axis=0))
    broadcasted_kernel = tf.broadcast_to(
        k, shape=ps.concat([broadcast_batch_shape, ps.shape(k)[-2:]], axis=0))

    flat_y = tf.reshape(
        y_batched,
        shape=ps.pad(
            ps.shape(y_batched)[-3:], paddings=[[1, 0]], constant_values=-1))
    flat_x = tf.reshape(
        broadcasted_input,
        shape=ps.pad(input_shape, paddings=[[1, 0]], constant_values=-1))
    flat_tf_kernel = tf.einsum(
        '...ij->...ji',
        tf.reshape(
            broadcasted_kernel,
            shape=ps.concat(
                [(-1,), filter_shape, (input_shape[-1], channels_out)],
                axis=0)))

    rank = 2
    output_shape, strides_ = convolution_util._get_output_shape(
        rank=rank, strides=(strides,) * rank, padding=padding,
        dilations=dilations, input_shape=input_shape, output_size=channels_out,
        filter_shape=filter_shape)

    y_expected = tf.vectorized_map(
        lambda args: tf.nn.conv2d_transpose(  # pylint: disable=g-long-lambda
            args[0][tf.newaxis],
            args[1],
            output_shape=ps.concat([[1], output_shape], axis=0),
            strides=strides_,
            padding=padding),
        elems=(flat_x, flat_tf_kernel))

    [y_actual_, y_expected_] = self.evaluate(
        [flat_y, tf.squeeze(y_expected, axis=1)])
    self.assertAllClose(y_expected_, y_actual_, rtol=1e-5, atol=0)
  def _test_vectorization(self, dist_name, dist):
    seed = test_util.test_seed()

    num_samples = 3
    if dist_name in SAMPLE_AUTOVECTORIZATION_IS_BROKEN:
      sample = self.evaluate(dist.sample(num_samples, seed=seed))
    else:
      sample = self.evaluate(tf.vectorized_map(
          lambda i: dist.sample(seed=seed), tf.range(num_samples)))
    hp.note('Drew samples {}'.format(sample))

    if dist_name not in LOGPROB_AUTOVECTORIZATION_IS_BROKEN:
      pfor_lp = tf.vectorized_map(dist.log_prob, tf.convert_to_tensor(sample))
      batch_lp = dist.log_prob(sample)
      pfor_lp_, batch_lp_ = self.evaluate((pfor_lp, batch_lp))
      self.assertAllClose(pfor_lp_, batch_lp_,
                          atol=VECTORIZED_LOGPROB_ATOL[dist_name])
Example 8
    def test_can_return_distribution_from_vectorized_map(self):
        def fn(x):
            dist = AutoNormal(loc=x, scale=[3., 5.])
            return dist._broadcast_parameters_with_batch_shape(
                tf.ones_like(dist.batch_shape_tensor()))

        batch_dist = tf.vectorized_map(fn, tf.convert_to_tensor([1., 2., 3.]))
        self.assertAllEqual(batch_dist.batch_shape, [3, 2])
Example 9
    def _f(*args):
        tf_args = tf.nest.map_structure(lambda x: tf_np.asarray(x).data, args)

        def tf_f(x):
            return f(*x)

        outputs = tf.vectorized_map(tf_f, tf_args)
        return tf.nest.map_structure(tf_np.asarray, outputs)
    def test_vectorized_map(self):
        initial_value = tf.ones([5, 3])
        x = tfp.util.TransformedVariable(initial_value, tfb.Sigmoid())

        # TODO(emilyaf): Remove `convert_to_tensor` after tf.Variables are
        # CompositeTensor.
        y = tf.vectorized_map(lambda v: v + 2., tf.convert_to_tensor(x))
        self.evaluate([v.initializer for v in x.trainable_variables])
        self.assertAllClose(self.evaluate(y), initial_value + 2.)
Example 11
  def test_sampled_weights_follow_correct_distribution(self):
    seed = test_util.test_seed(sampler_type='stateless')
    design_seed, true_weights_seed, sampled_weights_seed = samplers.split_seed(
        seed, 3, 'test_sampled_weights_follow_correct_distribution')
    num_timesteps = 10
    num_features = 2
    batch_shape = [3, 1]
    design_matrix = self.evaluate(samplers.normal(
        batch_shape + [num_timesteps, num_features], seed=design_seed))
    true_weights = self.evaluate(samplers.normal(
        batch_shape + [num_features, 1], seed=true_weights_seed) * 10.0)
    targets = np.matmul(design_matrix, true_weights)
    is_missing = np.array([False, False, False, True, True,
                           False, False, True, False, False])
    prior_scale = tf.convert_to_tensor(5.)
    likelihood_scale = tf.convert_to_tensor(0.1)

    # Analytically compute the true posterior distribution on weights.
    valid_design_matrix = design_matrix[..., ~is_missing, :]
    valid_targets = targets[..., ~is_missing, :]
    num_valid_observations = tf.shape(valid_design_matrix)[-2]
    weights_posterior_mean, weights_posterior_cov, _ = linear_gaussian_update(
        prior_mean=tf.zeros([num_features, 1]),
        prior_cov=tf.eye(num_features) * prior_scale**2,
        observation_matrix=tfl.LinearOperatorFullMatrix(valid_design_matrix),
        observation_noise=tfd.MultivariateNormalDiag(
            loc=tf.zeros([num_valid_observations]),
            scale_diag=likelihood_scale * tf.ones([num_valid_observations])),
        x_observed=valid_targets)

    # Check that the empirical moments of sampled weights match the true values.
    sampled_weights = tf.vectorized_map(
        lambda seed: gibbs_sampler._resample_weights(  # pylint: disable=g-long-lambda
            design_matrix=tf.where(is_missing[..., tf.newaxis],
                                   tf.zeros_like(design_matrix),
                                   design_matrix),
            target_residuals=targets[..., 0],
            observation_noise_scale=likelihood_scale,
            weights_prior_scale=tf.linalg.LinearOperatorScaledIdentity(
                num_features, prior_scale),
            seed=seed),
        tfp.random.split_seed(sampled_weights_seed, tf.constant(10000)))
    sampled_weights_mean = tf.reduce_mean(sampled_weights, axis=0)
    centered_weights = sampled_weights - weights_posterior_mean[..., 0]
    sampled_weights_cov = tf.reduce_mean(centered_weights[..., :, tf.newaxis] *
                                         centered_weights[..., tf.newaxis, :],
                                         axis=0)

    (sampled_weights_mean_, weights_posterior_mean_,
     sampled_weights_cov_, weights_posterior_cov_) = self.evaluate((
         sampled_weights_mean, weights_posterior_mean[..., 0],
         sampled_weights_cov, weights_posterior_cov))
    self.assertAllClose(sampled_weights_mean_, weights_posterior_mean_,
                        atol=0.01, rtol=0.05)
    self.assertAllClose(sampled_weights_cov_, weights_posterior_cov_,
                        atol=0.01, rtol=0.05)
Example 12
    def testPForInterop(self):
        def outer_product(a):
            return np.tensordot(a, a, 0)

        batch_size = 100
        a = np.ones((batch_size, 32, 32))
        c = tf.vectorized_map(outer_product, a)

        # # TODO(nareshmodi): vectorized_map doesn't rewrap tensors in ndarray.
        # self.assertIsInstance(c, np.ndarray)
        self.assertEqual(c.shape, (batch_size, 32, 32, 32, 32))
Example 13
 def testReproduceVmap1(self, dtype):
     # Regression test for b/145554459
     loc = tf.constant(-200., dtype=dtype)
     scale = tf.constant(2.188274e+01, dtype=dtype)
     high = tf.constant(113.33857, dtype=dtype)
     low = tf.constant(102.94414, dtype=dtype)
     # Not validating args b/c the assertions confuse pfor.
     dist = tfd.TruncatedNormal(loc, scale, low, high, validate_args=False)
     sample = tf.constant([102.950745, 103.87256, 107.78299], dtype=dtype)
     batch_lp = dist.log_prob(sample)
     pfor_lp = tf.vectorized_map(dist.log_prob, sample)
     batch_lp_, pfor_lp_ = self.evaluate((batch_lp, pfor_lp))
     self.assertAllClose(batch_lp_, pfor_lp_, atol=1e-6)
Example 14
    def test_automatic_conversion_to_tensor(self):
        v = tf.Variable(tf.ones([5]))
        d = tfd.Normal(tf.zeros([5]), v)
        x = tf.convert_to_tensor([3.])

        vectorized_log_prob = tf.vectorized_map(lambda z: z.log_prob(x), d)
        log_prob = d.log_prob(x)
        self.evaluate(v.initializer)
        self.assertAllClose(vectorized_log_prob[:, 0], log_prob)

        loc = tf.Variable(0.)
        self.evaluate(loc.initializer)
        cond_dist = tf.cond(tf.convert_to_tensor(True),
                            lambda: tfd.Normal(loc, 1.),
                            lambda: tfd.Normal(0., 1.))
        self.assertIsInstance(cond_dist, tfd.Normal)
Example 15
 def _get_mode(samples):
     _, idx, count = tf.raw_ops.UniqueWithCountsV2(x=samples, axis=[0])
     # TODO(b/161402486): Remove this hack for fixing the wrong static shape
     # of `idx` in graph mode.
     idx = tf.vectorized_map(lambda x: tf.reshape(x, [-1])[0], idx)
     # NOTE:
     #  - `count` has shape `[K]`, where `K` is the number of unique elements,
     #    and `count[j]` is the number of times the j-th unique element occurs
     #    in `samples`.
     #  - `idx` has shape `[samples.shape[0]]`, and `idx[i] == j` means that
     #    `samples[i]` is equal to the `j`-th unique element.
     max_count_idx = tf.argmax(count, output_type=tf.int32)
     # Return an index `i` for which `idx[i] == max_count_idx`.
     return tf.argmax(tf.cast(tf.math.equal(idx, max_count_idx),
                              dtype=tf.int32),
                      output_type=tf.int32)
            def per_example_test_nll(params):
                """Computes per-example test NLL."""
                test_y_idx = np.stack([
                    dataset.test_student_ids - 1, dataset.test_question_ids - 1
                ],
                                      axis=-1)

                dense_nll = (-test_joint_dist.sample_distributions(
                    value=params)[0][-1].distribution.log_prob(test_dense_y))
                vectorized_dense_nll = tf.reshape(
                    dense_nll, [-1, num_students, num_questions])
                # TODO(siege): Avoid using vmap here.
                log_prob_y = tf.vectorized_map(
                    lambda nll: tf.gather_nd(nll, test_y_idx),
                    vectorized_dense_nll)
                return tf.reshape(
                    log_prob_y,
                    list(params[0].shape) + [test_y_idx.shape[0]])
Example 17
def _vectorize_parameters(f, params, use_pfor, dtype):
    """Loop over `params`, providing a one-hot mask to `f` for each."""
    parameter_sizes = [tf.size(param) for param in params]
    total_size = tf.math.add_n(parameter_sizes)

    def _wrapper(index):
        full_onehot = tf.one_hot(index, total_size)
        split_onehot = tf.split(full_onehot, parameter_sizes)
        tangents = [
            tf.reshape(v, tf.shape(param))
            for param, v in zip(params, split_onehot)
        ]
        return f(tangents)

    if use_pfor:
        return tf.vectorized_map(_wrapper, tf.range(total_size))
    else:
        return tf.map_fn(_wrapper, tf.range(total_size), dtype)
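A hypothetical usage sketch for `_vectorize_parameters` above; the parameter shapes and the callable `f` are made up for illustration. A real caller would typically feed the one-hot tangents to a forward-mode autodiff accumulator rather than reducing them directly.

import tensorflow as tf

params = [tf.ones([2, 3]), tf.ones([4])]

def f(tangents):
    # Placeholder body: `tangents` is a list of one-hot masks shaped like
    # `params`; here we just reduce them to a scalar per call.
    return tf.add_n([tf.reduce_sum(t) for t in tangents])

rows = _vectorize_parameters(f, params, use_pfor=True, dtype=tf.float32)
# One output per scalar parameter entry: `rows` has shape [10] (= 2*3 + 4).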
 def testReproduceVmap2(self, dtype):
   # Regression test for b/150811273
   if dtype == np.float32:
     raise unittest.SkipTest('b/150811273')
   seed = test_util.test_seed()
   loc = tf.constant(-12.500191, dtype=dtype)
   scale = tf.constant(1e-06, dtype=dtype)
   high = tf.constant(-12.502851, dtype=dtype)
   low = tf.constant(-187.50009, dtype=dtype)
   # Not validating args b/c the assertions confuse pfor.
   dist = tfd.TruncatedNormal(loc, scale, low, high, validate_args=False)
   # At the default seed, the sample comes out as [-12.502851 -12.502851
   # -12.502851], but that's also weird.  At a scale of 1e-6, the samples
   # should cluster more tightly around the location, which is -12.500191.
   sample = self.evaluate(dist.sample(3, seed=seed))
   batch_lp = dist.log_prob(sample)
   pfor_lp = tf.vectorized_map(dist.log_prob, tf.convert_to_tensor(sample))
   batch_lp_, pfor_lp_ = self.evaluate((batch_lp, pfor_lp))
   self.assertAllClose(batch_lp_, pfor_lp_, atol=1e-6)
Example 19
    def test_vectorized_map(self):
        batch_size = 10
        num_features = 32
        layer = tf.keras.layers.Dense(1)

        def model_fn(arg):
            with tf.GradientTape() as g:
                inp, label = arg
                inp = tf.expand_dims(inp, 0)
                label = tf.expand_dims(label, 0)
                prediction = layer(inp)
                loss = tf.nn.l2_loss(label - prediction)
            return g.gradient(loss, (layer.kernel, layer.bias))

        inputs = tf.random.uniform([batch_size, num_features])
        labels = tf.random.uniform([batch_size, 1])
        per_example_gradients = tf.vectorized_map(model_fn, (inputs, labels))
        self.assertEqual(per_example_gradients[0].shape,
                         (batch_size, num_features, 1))
        self.assertEqual(per_example_gradients[1].shape, (batch_size, 1))
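        # Added sketch (an assumption, not part of the original test): because
        # tf.nn.l2_loss sums over the batch, summing the per-example gradients
        # along axis 0 recovers the gradients of the loss computed on the full
        # batch in a single GradientTape pass.
        batch_kernel_grad = tf.reduce_sum(per_example_gradients[0], axis=0)
        batch_bias_grad = tf.reduce_sum(per_example_gradients[1], axis=0)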
Example 20
    def call(self, inputs):
        def _bucketize_op(bins):
            bins = [tf.cast(bins, tf.float32)]
            return lambda inputs: tf.raw_ops.BoostedTreesBucketize(  # pylint: disable=g-long-lambda
                float_values=[tf.cast(inputs, tf.float32)],
                bucket_boundaries=bins)[0]

        if tf_utils.is_ragged(inputs):
            integer_buckets = tf.ragged.map_flat_values(
                _bucketize_op(tf.compat.v1.squeeze(self.bins)), inputs)
            # Ragged map_flat_values doesn't touch the non-values tensors in the
            # ragged composite tensor. If this op is the only op in a Keras
            # model, this can cause errors in Graph mode, so wrap the tensor in
            # an identity.
            return tf.identity(integer_buckets)
        elif isinstance(inputs, tf.SparseTensor):
            integer_buckets = tf.raw_ops.BoostedTreesBucketize(
                float_values=[tf.cast(inputs.values, tf.float32)],
                bucket_boundaries=[
                    tf.cast(tf.compat.v1.squeeze(self.bins), tf.float32)
                ])[0]
            return tf.SparseTensor(indices=tf.identity(inputs.indices),
                                   values=integer_buckets,
                                   dense_shape=tf.identity(inputs.dense_shape))
        else:
            input_shape = inputs.get_shape()
            if any(dim is None for dim in input_shape.as_list()[1:]):
                raise NotImplementedError(
                    "Discretization Layer requires known non-batch shape,"
                    "found {}".format(input_shape))

            reshaped = tf.reshape(
                inputs,
                [-1,
                 tf.raw_ops.Prod(input=input_shape.as_list()[1:], axis=0)])

            return tf.reshape(
                tf.vectorized_map(
                    _bucketize_op(tf.compat.v1.squeeze(self.bins)), reshaped),
                tf.constant([-1] + input_shape.as_list()[1:]))
Example 21
    def vectorized_fn(*args):
        """Vectorized version of `fn` that accepts arguments of any rank."""
        with tf.name_scope(name or 'make_rank_polymorphic'):
            # If we got a single value for core_ndims, tile it across all args.
            core_ndims_structure = (core_ndims if tf.nest.is_nested(core_ndims)
                                    else tf.nest.map_structure(
                                        lambda _: core_ndims, args))

            # Build flat lists of all argument parts and their corresponding core
            # ndims.
            flat_core_ndims = tf.nest.flatten(core_ndims_structure)
            flat_args = nest.flatten_up_to(core_ndims_structure,
                                           args,
                                           check_types=False)

            # Filter to only the `Tensor`-valued args (taken to be those with `None`
            # values for `core_ndims`). Other args will be passed through to `fn`
            # unmodified.
            (vectorized_arg_core_ndims, vectorized_args,
             fn_of_vectorized_args) = _lock_in_non_vectorized_args(
                 fn,
                 arg_structure=core_ndims_structure,
                 flat_core_ndims=flat_core_ndims,
                 flat_args=flat_args)

            # `vectorized_map` requires all inputs to have a single, common batch
            # dimension `[n]`. So we broadcast all input parts to a common
            # batch shape, then flatten it down to a single dimension.
            vectorized_arg_shapes = [ps.shape(arg) for arg in vectorized_args]

            vectorized_arg_actual_core_ndims = []
            batch_shapes, core_shapes = [], []
            for (arg_shape, core_nd) in zip(vectorized_arg_shapes,
                                            vectorized_arg_core_ndims):
                arg_nd = ps.rank_from_shape(arg_shape)
                # Shrink 'core' ndims of rank-deficient args. This guarantees that
                # `batch_ndims` is always nonnegative.
                actual_core_nd = ps.minimum(arg_nd, core_nd)
                vectorized_arg_actual_core_ndims.append(actual_core_nd)

                batch_ndims = arg_nd - actual_core_nd
                batch_shapes.append(arg_shape[:batch_ndims])
                core_shapes.append(arg_shape[batch_ndims:])

            # Flatten all of the batch dimensions into one.
            broadcast_batch_shape = (functools.reduce(ps.broadcast_shape,
                                                      batch_shapes, []))
            n = ps.cast(ps.reduce_prod(broadcast_batch_shape), tf.int32)

            static_n = tf.get_static_value(n)
            if static_n == 1:
                # We can bypass `vectorized_map` if the batch shape is `[]`, `[1]`,
                # `[1, 1]`, etc., just by flattening to batch shape `[]`.
                result_batch_dims = 0
                batched_result = fn_of_vectorized_args(
                    tf.nest.map_structure(lambda x, nd: tf.reshape(
                        x,
                        ps.shape(x)[ps.rank(x) - nd:]),
                                          vectorized_args,
                                          vectorized_arg_actual_core_ndims,
                                          check_types=False))
            else:
                # Pad all input parts to the common shape, then flatten
                # into the single leading dimension `[n]`.
                # TODO(b/145227909): If/when vmap supports broadcasting, use nested vmap
                # when batch rank is static so that we can exploit broadcasting.
                broadcast_vectorized_args = [
                    tf.broadcast_to(
                        part,
                        ps.concat([broadcast_batch_shape, core_shape], axis=0))
                    for (part, core_shape) in zip(vectorized_args, core_shapes)
                ]
                vectorized_args_with_flattened_batch_dim = [
                    tf.reshape(part, ps.concat([[n], core_shape], axis=0))
                    for (part, core_shape
                         ) in zip(broadcast_vectorized_args, core_shapes)
                ]
                result_batch_dims = 1
                batched_result = tf.vectorized_map(
                    fn_of_vectorized_args,
                    vectorized_args_with_flattened_batch_dim)

            # Unflatten any `Tensor`s in the result.
            unflatten = lambda x: tf.reshape(
                x,
                ps.concat(
                    [  # pylint: disable=g-long-lambda
                        broadcast_batch_shape,
                        ps.shape(x)[result_batch_dims:]
                    ],
                    axis=0))
            result = tf.nest.map_structure(lambda x: unflatten(x)
                                           if tf.is_tensor(x) else x,
                                           batched_result,
                                           expand_composites=True)
        return result
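In miniature, the broadcast -> flatten -> `tf.vectorized_map` -> unflatten pattern that `vectorized_fn` implements above; the shapes are illustrative assumptions.

import tensorflow as tf

batch_shape = [2, 3]
x = tf.ones(batch_shape + [4])       # core shape [4], batch shape [2, 3]
flat_x = tf.reshape(x, [-1, 4])      # collapse the batch dims into one axis
flat_y = tf.vectorized_map(tf.reduce_sum, flat_x)
y = tf.reshape(flat_y, batch_shape)  # restore the original batch shape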
Example 22
 def test_vectorized_map(self):
     initial_value = tf.ones([5, 3])
     x = tfp.util.TransformedVariable(initial_value, tfb.Sigmoid())
     y = tf.vectorized_map(lambda v: v + 2., x)
     self.evaluate([v.initializer for v in x.trainable_variables])
     self.assertAllClose(self.evaluate(y), initial_value + 2.)
Example 23
 def test_vectorized_map(self):
     pretransformed_input = tf.Variable(tf.ones([5, 3]))
     x = tfp.util.DeferredTensor(pretransformed_input, tfb.Scale([5]))
     y = tf.vectorized_map(lambda v: v + 2., x)
     self.evaluate([v.initializer for v in x.trainable_variables])
     self.assertAllClose(self.evaluate(y), 5. * pretransformed_input + 2.)
Example 24
    def _parse_fn(record):
        """Parses a record into a feature_dict."""
        feature_values = tf.io.parse_single_example(
            serialized=record,
            features={
                'i/o':
                tf.io.FixedLenFeature([], tf.string, default_value=''),
                'program_encoding':
                tf.io.FixedLenFeature([], tf.string, default_value=''),
            })

        ios = tf.strings.split(tf.strings.split(feature_values['i/o'],
                                                sep='>'),
                               sep='<')

        inputs, outputs = ios.merge_dims(0, 1)[::2], ios.merge_dims(0, 1)[1::2]

        # Parse inputs into tokens.
        inputs = tf.strings.unicode_split(inputs, 'UTF-8').to_tensor()
        inputs = spec_vocab_table.lookup(inputs)  # Map characters to tokens.

        # Parse outputs into tokens.
        outputs_with_separators = (tf.strings.unicode_split(
            outputs, 'UTF-8').to_tensor())
        outputs_with_separators = spec_vocab_table.lookup(
            outputs_with_separators)
        split_outputs = tf.strings.unicode_split(
            tf.strings.split(outputs, sep='|'), 'UTF-8')
        outputs = split_outputs.merge_dims(1, 2).to_tensor()
        outputs = spec_vocab_table.lookup(outputs)

        # Compute indices for the start of each part of the spec, w.r.t. the
        # original spec.
        separator_indices = tf.where(
            tf.equal(outputs_with_separators, separator_id))[:, 1]
        separator_indices = tf.reshape(
            separator_indices, (tf.shape(outputs_with_separators)[0], -1))
        start_indices = separator_indices - tf.expand_dims(
            tf.range(tf.shape(separator_indices)[1], dtype=tf.int64), 0)
        start_indices = tf.concat((tf.zeros(
            (tf.shape(start_indices)[0], 1), dtype=tf.int64), start_indices),
                                  axis=1)

        num_examples = tf.shape(start_indices)[0]
        num_parts = tf.shape(start_indices)[1]

        # Construct the shifted spec suffixes.
        flat_start_indices = tf.reshape(start_indices, (-1, ))
        prefix_mask = (1 - tf.sequence_mask(
            flat_start_indices, maxlen=tf.shape(outputs)[-1], dtype=tf.int64))
        masked_outputs = tf.repeat(outputs, num_parts, axis=0) * prefix_mask
        output_suffixes = tf.vectorized_map(
            fn=lambda x: tf.roll(x[0], x[1], axis=0),
            elems=(masked_outputs, -flat_start_indices))

        # Compute indices for the start/end of spec parts, w.r.t. the shifted spec
        # suffixes.
        ground_truth_start_indices = tf.zeros((num_examples * num_parts, ),
                                              dtype=tf.int64)
        cumulative_end_indices = tf.concat(
            (start_indices,
             tf.math.count_nonzero(outputs, axis=-1, keepdims=True)),
            axis=1)
        ground_truth_end_indices = tf.reshape(
            cumulative_end_indices[:, 1:] - cumulative_end_indices[:, :-1],
            (-1, ))

        # Construct the actual spec parts to predict.
        range_indices = tf.expand_dims(tf.range(tf.shape(output_suffixes)[-1],
                                                dtype=tf.int64),
                                       axis=0)
        part_mask = tf.where(
            tf.logical_and(
                range_indices >= tf.expand_dims(ground_truth_start_indices,
                                                axis=1),
                range_indices < tf.expand_dims(ground_truth_end_indices,
                                               axis=1)), 1, 0)
        output_parts = output_suffixes * tf.cast(part_mask, tf.int64)
        output_parts = tf.pad(output_parts,
                              [[0, 0], [0, 1]])  # Make room for sep.
        # TODO(kshi): roll output_parts leftward by start_indices for SCAN.
        first_zero_index = tf.math.count_nonzero(output_parts, axis=-1)
        output_parts += tf.one_hot(first_zero_index,
                                   depth=tf.shape(output_parts)[-1],
                                   dtype=tf.int64) * separator_id

        # Reshape everything so that different spec suffixes become different
        # dataset elements.
        output_suffixes_reshaped = tf.transpose(
            tf.reshape(output_suffixes, (num_examples, num_parts, -1)),
            (1, 0, 2))
        output_parts_reshaped = tf.transpose(
            tf.reshape(output_parts, (num_examples, num_parts, -1)), (1, 0, 2))
        inputs_reshaped = tf.reshape(tf.tile(inputs, (num_parts, 1)),
                                     (num_parts, num_examples, -1))
        ground_truth_start_indices_reshaped = tf.transpose(
            tf.reshape(ground_truth_start_indices, (num_examples, num_parts)))
        ground_truth_end_indices_reshaped = tf.transpose(
            tf.reshape(ground_truth_end_indices, (num_examples, num_parts)))

        # Combine spec parts from all examples into one sequence with separator
        # tokens between examples and ending in EOS.
        shifts = tf.cumsum(tf.concat((tf.zeros(
            (num_parts, 1),
            dtype=tf.int64), ground_truth_end_indices_reshaped[:, :-1] + 1),
                                     1),
                           axis=-1)
        flat_shifts = tf.reshape(shifts, (-1, ))
        output_len = tf.shape(output_parts_reshaped)[-1]
        flat_spec_parts = tf.reshape(output_parts_reshaped, (-1, output_len))
        flat_spec_parts = tf.pad(flat_spec_parts,
                                 [[0, 0], [0, max_target_length - output_len]])
        combined_spec_parts = tf.vectorized_map(
            fn=lambda x: tf.roll(x[0], x[1], axis=0),
            elems=(flat_spec_parts, flat_shifts))
        combined_spec_parts = tf.reshape(combined_spec_parts,
                                         (num_parts, num_examples, -1))
        combined_spec_parts = tf.reduce_sum(combined_spec_parts, axis=1)
        first_zero_index = tf.math.count_nonzero(combined_spec_parts, axis=-1)
        combined_spec_parts += tf.one_hot(
            first_zero_index,
            depth=tf.shape(combined_spec_parts)[-1],
            dtype=tf.int64) * eos_id

        # Create a dataset containing data for all spec suffixes.
        dataset = tf.data.Dataset.from_tensor_slices({
            'inputs':
            inputs_reshaped,
            'outputs':
            output_suffixes_reshaped,
            'spec_parts':
            combined_spec_parts,
            'start_index':
            ground_truth_start_indices_reshaped,
            'end_index':
            ground_truth_end_indices_reshaped
        })
        return dataset
  def vectorized_fn(*args):
    """Vectorized version of `fn` that accepts arguments of any rank."""
    with tf.name_scope(name or 'make_rank_polymorphic'):
      assertions = []

      # If we got a single value for core_ndims, tile it across all args.
      core_ndims_structure = (
          core_ndims
          if tf.nest.is_nested(core_ndims)
          else tf.nest.map_structure(lambda _: core_ndims, args))

      # Build flat lists of all argument parts and their corresponding core
      # ndims.
      flat_core_ndims = tf.nest.flatten(core_ndims_structure)
      flat_args = nest.flatten_up_to(
          core_ndims_structure, args, check_types=False)

      # Filter to only the `Tensor`-valued args (taken to be those with `None`
      # values for `core_ndims`). Other args will be passed through to `fn`
      # unmodified.
      (vectorized_arg_core_ndims,
       vectorized_args,
       fn_of_vectorized_args) = _lock_in_non_vectorized_args(
           fn,
           arg_structure=core_ndims_structure,
           flat_core_ndims=flat_core_ndims,
           flat_args=flat_args)

      # `vectorized_map` requires all inputs to have a single, common batch
      # dimension `[n]`. So we broadcast all input parts to a common
      # batch shape, then flatten it down to a single dimension.

      # First, compute how many 'extra' (batch) ndims each part has. This must
      # be nonnegative.
      vectorized_arg_shapes = [tf.shape(arg) for arg in vectorized_args]
      batch_ndims = [
          ps.rank_from_shape(arg_shape) - nd
          for (arg_shape, nd) in zip(
              vectorized_arg_shapes, vectorized_arg_core_ndims)]
      static_ndims = [tf.get_static_value(nd) for nd in batch_ndims]
      if any([nd and nd < 0 for nd in static_ndims]):
        raise ValueError('Cannot broadcast a Tensor having lower rank than the '
                         'specified `core_ndims`! (saw input ranks {}, '
                         '`core_ndims` {}).'.format(
                             tf.nest.map_structure(
                                 ps.rank_from_shape,
                                 vectorized_arg_shapes),
                             vectorized_arg_core_ndims))
      if validate_args:
        for nd, part, core_nd in zip(
            batch_ndims, vectorized_args, vectorized_arg_core_ndims):
          assertions.append(tf.debugging.assert_non_negative(
              nd, message='Cannot broadcast a Tensor having lower rank than '
              'the specified `core_ndims`! (saw {} vs minimum rank {}).'.format(
                  part, core_nd)))

      # Next, split each part's shape into batch and core shapes, and
      # broadcast the batch shapes.
      with tf.control_dependencies(assertions):
        empty_shape = np.zeros([0], dtype=np.int32)
        batch_shapes, core_shapes = empty_shape, empty_shape
        if vectorized_arg_shapes:
          batch_shapes, core_shapes = zip(*[
              (arg_shape[:nd], arg_shape[nd:])
              for (arg_shape, nd) in zip(vectorized_arg_shapes, batch_ndims)])
        broadcast_batch_shape = (
            functools.reduce(ps.broadcast_shape, batch_shapes, []))

      # Flatten all of the batch dimensions into one.
      n = tf.cast(ps.reduce_prod(broadcast_batch_shape), tf.int32)
      static_n = tf.get_static_value(n)
      if static_n == 1:
        result = fn(*args)
      else:
        # Pad all input parts to the common shape, then flatten
        # into the single leading dimension `[n]`.
        # TODO(b/145227909): If/when vmap supports broadcasting, use nested vmap
        # when batch rank is static so that we can exploit broadcasting.
        broadcast_vectorized_args = [
            tf.broadcast_to(part, ps.concat(
                [broadcast_batch_shape, core_shape], axis=0))
            for (part, core_shape) in zip(vectorized_args, core_shapes)]
        vectorized_args_with_flattened_batch_dim = [
            tf.reshape(part, ps.concat([[n], core_shape], axis=0))
            for (part, core_shape) in zip(
                broadcast_vectorized_args, core_shapes)]
        batched_result = tf.vectorized_map(
            fn_of_vectorized_args, vectorized_args_with_flattened_batch_dim)

        # Unflatten any `Tensor`s in the result.
        unflatten = lambda x: tf.reshape(x, ps.concat([  # pylint: disable=g-long-lambda
            broadcast_batch_shape, ps.shape(x)[1:]], axis=0))
        result = tf.nest.map_structure(
            lambda x: unflatten(x) if tf.is_tensor(x) else x, batched_result,
            expand_composites=True)
    return result
Example 26
  def vectorized_fn(*args):
    """Vectorized version of `fn` that accepts arguments of any rank."""
    with tf.name_scope(name or 'make_rank_polymorphic'):
      assertions = []

      # If we got a single value for core_ndims, tile it across all args.
      core_ndims_structure = (
          core_ndims
          if nest.is_nested(core_ndims)
          else nest.map_structure(lambda _: core_ndims, args))

      # Build flat lists of all argument parts and their corresponding core
      # ndims.
      flat_core_ndims = nest.flatten(core_ndims_structure)
      parts = tf.nest.flatten(nest.map_structure_up_to(
          core_ndims_structure, tf.convert_to_tensor, args, check_types=False))
      if len(parts) != len(flat_core_ndims):
        raise ValueError('Number of args does not match `core_ndims` '
                         '({} vs {}). Saw argument parts {}; core '
                         'ndims {}.'.format(len(parts), len(flat_core_ndims),
                                            parts, flat_core_ndims))

      # `vectorized_map` requires all inputs to have a single, common batch
      # dimension `[n]`. So we broadcast all input parts to a common
      # batch shape, then flatten it down to a single dimension.

      # First, compute how many 'extra' (batch) ndims each part has. This must
      # be nonnegative.
      part_shapes = [tf.shape(part) for part in parts]
      batch_ndims = [
          prefer_static.rank_from_shape(part_shape) - nd
          for (part_shape, nd) in zip(part_shapes, flat_core_ndims)]
      static_ndims = [tf.get_static_value(nd) for nd in batch_ndims]
      if any([nd and nd < 0 for nd in static_ndims]):
        raise ValueError('Cannot broadcast a Tensor having lower rank than the '
                         'specified `core_ndims`! (saw input ranks {}, '
                         '`core_ndims` {}).'.format(
                             tf.nest.map_structure(
                                 prefer_static.rank_from_shape, part_shapes),
                             flat_core_ndims))
      if validate_args:
        for nd, part, core_nd in zip(batch_ndims, parts, flat_core_ndims):
          assertions.append(tf.debugging.assert_non_negative(
              nd, message='Cannot broadcast a Tensor having lower rank than '
              'the specified `core_ndims`! (saw {} vs minimum rank {}).'.format(
                  part, core_nd)))

      # Next, split each part's shape into batch and core shapes, and
      # broadcast the batch shapes.
      with tf.control_dependencies(assertions):
        batch_shapes, core_shapes = zip(*[
            (part_shape[:nd], part_shape[nd:])
            for (part_shape, nd) in zip(part_shapes, batch_ndims)])
        broadcast_batch_shape = functools.reduce(
            prefer_static.broadcast_shape, batch_shapes, [])

      # Flatten all of the batch dimensions into one.
      n = tf.cast(prefer_static.reduce_prod(broadcast_batch_shape), tf.int32)
      static_n = tf.get_static_value(n)
      if static_n == 1:
        result = fn(*args)
      else:
        # Pad all input parts to the common shape, then flatten
        # into the single leading dimension `[n]`.
        # TODO(b/145227909): If/when vmap supports broadcasting, use nested vmap
        # when batch rank is static so that we can exploit broadcasting.
        broadcast_parts = [
            tf.broadcast_to(part, prefer_static.concat([broadcast_batch_shape,
                                                        core_shape], axis=0))
            for (part, core_shape) in zip(parts, core_shapes)]
        parts_with_flattened_batch_dim = [
            tf.reshape(part, prefer_static.concat([[n], core_shape], axis=0))
            for (part, core_shape) in zip(broadcast_parts, core_shapes)]

        # Run the vectorized computation
        batched_result = tf.vectorized_map(lambda args: fn(*args),
                                           nest.pack_sequence_as(
                                               args,
                                               parts_with_flattened_batch_dim))

        # Unflatten the result
        result = nest.map_structure(
            lambda x: tf.reshape(x, prefer_static.concat([  # pylint: disable=g-long-lambda
                broadcast_batch_shape, prefer_static.shape(x)[1:]], axis=0)),
            batched_result)
    return result
Example 27
def expected_calibration_error_quantiles(hit,
                                         pred_log_prob,
                                         num_buckets=20,
                                         axis=0,
                                         log_space_buckets=False,
                                         name=None):
    """Expected calibration error via `quantiles(exp(pred_log_prob),num_buckets)`.

  Calibration is a measure of how well a model reports its own uncertainty. A
  model is said to be "calibrated" if buckets of predicted probabilities have
  the same within-bucket average accuracy as within-bucket average predicted
  probability. The expected calibration error is the count-weighted average
  absolute difference between predicted probability and (bucket) average
  accuracy. That is:

  ```python
  bucket_weight = bucket_count / tf.reduce_sum(bucket_count, axis=0)
  bucket_error = abs(bucket_accuracy - bucket_confidence)
  ece = tf.reduce_sum(bucket_weight * bucket_error, axis=0)
  ```

  where `bucket_accuracy, bucket_confidence, bucket_count` are statistics
  aggregated by `num_buckets`-quantiles of `tf.math.exp(pred_log_prob)`. Note:
  `bucket_*` always have size `num_buckets` in the zeroth dimension.

  Args:
    hit: `bool` `Tensor` where `True` means the model prediction was correct
      and `False` means the model prediction was incorrect. Shape must
      broadcast with `pred_log_prob`.
    pred_log_prob: `Tensor` representing the model's predicted log probability
      for the given `hit`. Shape must broadcast with `hit`.
    num_buckets: `int` representing the number of buckets over which to
      aggregate hits. Buckets are quantiles of `exp(pred_log_prob)`.
      Default value: `20`.
    axis: Dimension over which to compute buckets and aggregate stats.
      Default value: `0`.
    log_space_buckets: When `False` bucket edges are computed from
      `tf.math.exp(pred_log_prob)`; when `True` bucket edges are computed from
      `pred_log_prob`.
      Default value: `False`.
    name: Python `str` name used for ops created by this function.
      Default value: `None` (i.e.,
      `"expected_calibration_error_quantiles"`).

  Returns:
    ece: Expected calibration error; `tf.reduce_sum(abs(bucket_accuracy -
      bucket_confidence) * bucket_count, axis=0) / tf.reduce_sum(bucket_count,
      axis=0)`.
    bucket_accuracy: `Tensor` representing the within bucket average hits, i.e.,
      total bucket hits divided by bucket count. Has shape
      `tf.concat([[num_buckets], tf.shape(tf.reduce_sum(pred_log_prob,
      axis=axis))], axis=0)`.
    bucket_confidence: `Tensor` representing the within bucket average
      probability, i.e., total bucket predicted probability divided by bucket
      count. Has shape `tf.concat([[num_buckets],
      tf.shape(tf.reduce_sum(pred_log_prob, axis=axis))], axis=0)`.
    bucket_count: `Tensor` representing the total number of observations in each
      bucket. Has shape `tf.concat([[num_buckets],
      tf.shape(tf.reduce_sum(pred_log_prob, axis=axis))], axis=0)`.
    bucket_pred_log_prob: `Tensor` representing `pred_log_prob` bucket edges.
      Always in log space, regardless of the value of `log_space_buckets`.
    bucket: `int` `Tensor` representing the bucket within which `pred_log_prob`
      lies.

  #### Examples

  ```python
  # Example 1: Generic use.
  label = tf.cast([0, 0, 1, 0, 1, 1], dtype=tf.bool)
  log_pred = tf.math.log([0.1, 0.05, 0.5, 0.2, 0.99, 0.99])
  (
    ece,
    acc,
    conf,
    cnt,
    edges,
    bucket,
  ) = tfp.stats.expected_calibration_error_quantiles(
      label, log_pred, num_buckets=3)
  # ece  ==> tf.Tensor(0.145, shape=(), dtype=float32)
  # acc  ==> tf.Tensor([0. 0. 1.], shape=(3,), dtype=float32)
  # conf ==> tf.Tensor([0.075, 0.2, 0.826665], shape=(3,), dtype=float32)
  # cnt  ==> tf.Tensor([2. 1. 3.], shape=(3,), dtype=float32)
  ```

  ```python
  # Example 2: Categorical classification.
  # Assume we have evidence `x`, targets `y`, and model function `dnn`.
  d = tfd.Categorical(logits=dnn(x))
  def all_categories(d):
    num_classes = tf.shape(d.logits_parameter())[-1]
    batch_ndims = tf.size(d.batch_shape_tensor())
    expand_shape = tf.pad(
        [num_classes], paddings=[[0, batch_ndims]], constant_values=1)
    return tf.reshape(tf.range(num_classes, dtype=d.dtype), expand_shape)
  all_pred_log_prob = d.log_prob(all_categories(d))
  yhat = tf.argmax(all_pred_log_prob, axis=0)
  def rollaxis(x, shift):
    return tf.transpose(x, tf.roll(tf.range(tf.rank(x)), shift=shift, axis=0))
  pred_log_prob = tf.gather(rollaxis(all_pred_log_prob, shift=-1),
                            yhat,
                            batch_dims=len(d.batch_shape))
  hit = tf.equal(y, yhat)
  (
    ece,
    acc,
    conf,
    cnt,
    edges,
    bucket,
  ) = tfp.stats.expected_calibration_error_quantiles(
      hit, pred_log_prob, num_buckets=10)
  ```

  """
    with tf.name_scope(name or 'expected_calibration_error_quantiles'):
        pred_log_prob = tf.convert_to_tensor(pred_log_prob,
                                             dtype_hint=tf.float32,
                                             name='pred_log_prob')
        dtype = pred_log_prob.dtype
        hit = tf.cast(hit, dtype, name='hit')
        # Make sure to compute quantiles in "prob" space not "log(prob)".
        if log_space_buckets:
            bucket_pred_log_prob = quantiles_lib.quantiles(
                pred_log_prob, num_quantiles=num_buckets, axis=axis)
        else:
            bucket_pred_log_prob = tf.math.log(
                quantiles_lib.quantiles(tf.math.exp(pred_log_prob),
                                        num_quantiles=num_buckets,
                                        axis=axis))
        bucket = _find_bins(pred_log_prob, bucket_pred_log_prob, axis)

        def _fn(i):
            """`map_fn` body."""
            keep = tf.equal(i, bucket)
            total_hit = tf.math.reduce_sum(tf.where(keep, hit,
                                                    tf.constant(0., dtype)),
                                           axis=axis)
            total_count = tf.math.reduce_sum(tf.cast(keep, dtype), axis=axis)
            log_total_pred_prob = tf.math.reduce_logsumexp(tf.where(
                keep, pred_log_prob, tf.constant(-np.inf, dtype)),
                                                           axis=axis)
            return total_hit, log_total_pred_prob, total_count

        # On the following line, we use vectorized_map instead of map_fn not for
        # efficiency reasons but because at the time of writing, map_fn doesn't
        # work correctly on the JAX substrate.  Specifically, it does not like that
        # _fn returns a tuple.
        bucket_total_hit, bucket_log_total_pred_prob, bucket_count = (
            tf.vectorized_map(fn=_fn,
                              elems=tf.range(num_buckets, dtype=bucket.dtype)))
        n = tf.maximum(bucket_count, 1.)
        bucket_accuracy = bucket_total_hit / n
        bucket_confidence = tf.math.exp(bucket_log_total_pred_prob -
                                        tf.math.log(n))
        bucket_error = abs(bucket_accuracy - bucket_confidence)
        n = ps.cast(ps.shape(pred_log_prob)[axis], dtype)
        ece = tf.math.reduce_sum(bucket_count * bucket_error, axis=0) / n
        return (
            ece,
            bucket_accuracy,
            bucket_confidence,
            bucket_count,
            bucket_pred_log_prob,
            bucket,
        )
Example 28
def _update_confusion_matrix_variables_optimized(
    variables_to_update,
    y_true,
    y_pred,
    thresholds,
    multi_label=False,
    sample_weights=None,
    label_weights=None,
    thresholds_with_epsilon=False,
):
    """Update confusion matrix variables with memory efficient alternative.

    Note that the thresholds need to be evenly distributed within the list, eg,
    the diff between consecutive elements are the same.

    To compute TP/FP/TN/FN, we are measuring a binary classifier
      C(t) = (predictions >= t)
    at each threshold 't'. So we have
      TP(t) = sum( C(t) * true_labels )
      FP(t) = sum( C(t) * false_labels )

    But, computing C(t) requires computation for each t. To make it fast,
    observe that C(t) is a cumulative integral, and so if we have
      thresholds = [t_0, ..., t_{n-1}];  t_0 < ... < t_{n-1}
    where n = num_thresholds, and if we can compute the bucket function
      B(i) = Sum( (predictions == t), t_i <= t < t_{i+1} )
    then we get
      C(t_i) = sum( B(j), j >= i )
    which is the reversed cumulative sum in tf.cumsum().

    We can compute B(i) efficiently by taking advantage of the fact that
    our thresholds are evenly distributed, in that
      width = 1.0 / (num_thresholds - 1)
      thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
    Given a prediction value p, we can map it to its bucket by
      bucket_index(p) = floor( p * (num_thresholds - 1) )
    so we can use tf.math.unsorted_segment_sum() to update the buckets in one
    pass.

    Consider following example:
    y_true = [0, 0, 1, 1]
    y_pred = [0.1, 0.5, 0.3, 0.9]
    thresholds = [0.0, 0.5, 1.0]
    num_buckets = 2   # [0.0, 1.0], (1.0, 2.0]
    bucket_index(y_pred) = tf.math.floor(y_pred * num_buckets)
                         = tf.math.floor([0.2, 1.0, 0.6, 1.8])
                         = [0, 0, 0, 1]
    # The meaning of this bucket is that if any of the labels is true,
    # then 1 will be added to the corresponding bucket with the index.
    # E.g., if the label for 0.2 is true, then 1 will be added to bucket 0. If
    # the label for 1.8 is true, then 1 will be added to bucket 1.
    #
    # Note the second item "1.0" is mapped to bucket 0, since the value needs
    # to be strictly larger than the bucket lower bound.
    # In the implementation, we use tf.math.ceil() - 1 to achieve this.
    tp_bucket_value = tf.math.unsorted_segment_sum(true_labels, bucket_indices,
                                                   num_segments=num_thresholds)
                    = [1, 1, 0]
    # For [1, 1, 0] here, it means there is 1 true value contributed by bucket
    # 0, and 1 true value contributed by bucket 1. When we aggregate them
    # together, the result becomes [a + b + c, b + c, c], since larger
    # thresholds always contribute to the value for smaller thresholds.
    true_positive = tf.math.cumsum(tp_bucket_value, reverse=True)
                  = [2, 1, 0]

    This implementation exhibits a run time and space complexity of O(T + N),
    where T is the number of thresholds and N is the size of predictions.
    Metrics that rely on the standard implementation instead exhibit a
    complexity of O(T * N).

    Args:
      variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid keys
        and corresponding variables to update as values.
      y_true: A floating point `Tensor` whose shape matches `y_pred`. Will be cast
        to `bool`.
      y_pred: A floating point `Tensor` of arbitrary shape and whose values are in
        the range `[0, 1]`.
      thresholds: A sorted floating point `Tensor` with values in `[0, 1]`.
        It needs to be evenly distributed (the diff between consecutive
        elements needs to be the same).
      multi_label: Optional boolean indicating whether multidimensional
        prediction/labels should be treated as multilabel responses, or flattened
        into a single label. When True, the values of `variables_to_update` must
        have a second dimension equal to the number of labels in y_true and
        y_pred, and those tensors must not be RaggedTensors.
      sample_weights: Optional `Tensor` whose rank is either 0, or the same rank
        as `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions
        must be either `1`, or the same as the corresponding `y_true` dimension).
      label_weights: Optional tensor of non-negative weights for multilabel
        data. The weights are applied when calculating TP, FP, FN, and TN without
        explicit multilabel handling (i.e. when the data is to be flattened).
      thresholds_with_epsilon: Optional boolean indicating whether the leading
        and trailing thresholds have any epsilon added for floating point
        imprecision. It will change how we handle the leading and trailing
        buckets.

    Returns:
      Update op.
    """
    num_thresholds = thresholds.shape.as_list()[0]

    if sample_weights is None:
        sample_weights = 1.0
    else:
        sample_weights = tf.__internal__.ops.broadcast_weights(
            tf.cast(sample_weights, dtype=y_pred.dtype), y_pred)
        if not multi_label:
            sample_weights = tf.reshape(sample_weights, [-1])
    if label_weights is None:
        label_weights = 1.0
    else:
        label_weights = tf.expand_dims(label_weights, 0)
        label_weights = tf.__internal__.ops.broadcast_weights(
            label_weights, y_pred)
        if not multi_label:
            label_weights = tf.reshape(label_weights, [-1])
    weights = tf.multiply(sample_weights, label_weights)

    # We shouldn't need this, but in case there are prediction values that are
    # out of the range [0.0, 1.0].
    y_pred = tf.clip_by_value(y_pred, clip_value_min=0.0, clip_value_max=1.0)

    y_true = tf.cast(tf.cast(y_true, tf.bool), y_true.dtype)
    if not multi_label:
        y_true = tf.reshape(y_true, [-1])
        y_pred = tf.reshape(y_pred, [-1])

    true_labels = tf.multiply(y_true, weights)
    false_labels = tf.multiply((1.0 - y_true), weights)

    # Compute the bucket indices for each prediction value.
    # Since the prediction value has to be strictly greater than the threshold
    # (e.g., with buckets [0, 0.5], (0.5, 1], the value 0.5 belongs to the
    # first bucket), we have to use math.ceil(val) - 1 for the bucket index.
    bucket_indices = tf.math.ceil(y_pred * (num_thresholds - 1)) - 1

    if thresholds_with_epsilon:
        # In this case, the first bucket should actually be taken into account,
        # since any prediction in [0.0, 1.0] is larger than the first
        # threshold. We change the bucket value from -1 to 0.
        bucket_indices = tf.nn.relu(bucket_indices)

    bucket_indices = tf.cast(bucket_indices, tf.int32)

    if multi_label:
        # We need to run the bucket segment sum for each of the label classes.
        # In the multi_label case, the rank of the label is 2. We first
        # transpose it so that the label dim becomes the first and we can run
        # through them in parallel.
        true_labels = tf.transpose(true_labels)
        false_labels = tf.transpose(false_labels)
        bucket_indices = tf.transpose(bucket_indices)

        def gather_bucket(label_and_bucket_index):
            label, bucket_index = (
                label_and_bucket_index[0],
                label_and_bucket_index[1],
            )
            return tf.math.unsorted_segment_sum(
                data=label,
                segment_ids=bucket_index,
                num_segments=num_thresholds,
            )

        tp_bucket_v = tf.vectorized_map(gather_bucket,
                                        (true_labels, bucket_indices))
        fp_bucket_v = tf.vectorized_map(gather_bucket,
                                        (false_labels, bucket_indices))
        tp = tf.transpose(tf.cumsum(tp_bucket_v, reverse=True, axis=1))
        fp = tf.transpose(tf.cumsum(fp_bucket_v, reverse=True, axis=1))
    else:
        tp_bucket_v = tf.math.unsorted_segment_sum(
            data=true_labels,
            segment_ids=bucket_indices,
            num_segments=num_thresholds,
        )
        fp_bucket_v = tf.math.unsorted_segment_sum(
            data=false_labels,
            segment_ids=bucket_indices,
            num_segments=num_thresholds,
        )
        tp = tf.cumsum(tp_bucket_v, reverse=True)
        fp = tf.cumsum(fp_bucket_v, reverse=True)

    # fn = sum(true_labels) - tp
    # tn = sum(false_labels) - fp
    if (ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
            or ConfusionMatrix.FALSE_NEGATIVES in variables_to_update):
        if multi_label:
            total_true_labels = tf.reduce_sum(true_labels, axis=1)
            total_false_labels = tf.reduce_sum(false_labels, axis=1)
        else:
            total_true_labels = tf.reduce_sum(true_labels)
            total_false_labels = tf.reduce_sum(false_labels)

    update_ops = []
    if ConfusionMatrix.TRUE_POSITIVES in variables_to_update:
        variable = variables_to_update[ConfusionMatrix.TRUE_POSITIVES]
        update_ops.append(variable.assign_add(tp))
    if ConfusionMatrix.FALSE_POSITIVES in variables_to_update:
        variable = variables_to_update[ConfusionMatrix.FALSE_POSITIVES]
        update_ops.append(variable.assign_add(fp))
    if ConfusionMatrix.TRUE_NEGATIVES in variables_to_update:
        variable = variables_to_update[ConfusionMatrix.TRUE_NEGATIVES]
        tn = total_false_labels - fp
        update_ops.append(variable.assign_add(tn))
    if ConfusionMatrix.FALSE_NEGATIVES in variables_to_update:
        variable = variables_to_update[ConfusionMatrix.FALSE_NEGATIVES]
        fn = total_true_labels - tp
        update_ops.append(variable.assign_add(fn))
    return tf.group(update_ops)
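A minimal sketch of the multi_label bucketing path above: `tf.vectorized_map` maps a `gather_bucket`-style segment sum over the leading (label) dimension. The shapes and values are illustrative assumptions.

import tensorflow as tf

labels = tf.constant([[1., 0., 1.], [0., 1., 1.]])          # [num_labels, N]
bucket_ids = tf.constant([[0, 1, 1], [2, 0, 1]], tf.int32)  # [num_labels, N]
per_label_counts = tf.vectorized_map(
    lambda args: tf.math.unsorted_segment_sum(
        args[0], args[1], num_segments=3),
    (labels, bucket_ids))
# per_label_counts[i, k] sums labels[i] over positions whose bucket index is k.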