def _test_vectorization(self, dist_name, dist):
  """Checks `log_prob` under `tf.vectorized_map` against batched `log_prob`.

  Three samples are drawn from `dist`; unless `dist_name` is listed in
  `LOGPROB_AUTOVECTORIZATION_IS_BROKEN`, the per-sample log-probs computed
  through `tf.vectorized_map` must match the batched ones within the
  per-distribution tolerances.

  Args:
    dist_name: Key into the brokenness lists and tolerance tables.
    dist: Distribution instance under test.
  """
  seed = test_util.test_seed()
  num_samples = 3

  # TODO(b/171752261): New stateless samplers don't work with pfor.
  enable_auto_vectorized_sampling = False
  sample_with_pfor = (
      enable_auto_vectorized_sampling and
      dist_name not in SAMPLE_AUTOVECTORIZATION_IS_BROKEN)

  if sample_with_pfor:
    # The loop index is deliberately unused: every iteration draws one sample.
    draws = tf.vectorized_map(
        lambda i: dist.sample(seed=seed),
        tf.range(num_samples),
        fallback_to_while_loop=False)
  else:
    draws = dist.sample(num_samples, seed=seed)
  sample = self.evaluate(draws)
  hp.note('Drew samples {}'.format(sample))

  if dist_name in LOGPROB_AUTOVECTORIZATION_IS_BROKEN:
    return

  pfor_lp = tf.vectorized_map(
      dist.log_prob,
      tf.convert_to_tensor(sample),
      fallback_to_while_loop=False)
  batch_lp = dist.log_prob(sample)
  pfor_lp_, batch_lp_ = self.evaluate((pfor_lp, batch_lp))
  self.assertAllClose(
      pfor_lp_,
      batch_lp_,
      atol=VECTORIZED_LOGPROB_ATOL[dist_name],
      rtol=VECTORIZED_LOGPROB_RTOL[dist_name])
def testAutoVectorization(self, bijector_name, data):
  """Checks bijector methods agree with their `tf.vectorized_map` versions.

  Exercises `forward`, `forward_log_det_jacobian`, `inverse` and
  `inverse_log_det_jacobian` in turn. Does nothing in eager mode.

  Args:
    bijector_name: Name of the bijector to draw; also keys the tolerances.
    data: Hypothesis data object used for all draws.
  """
  # TODO(b/150161911): reconcile numeric behavior of eager and graph mode.
  if tf.executing_eagerly():
    return

  allowed = (set(bhps.INSTANTIABLE_BIJECTORS) -
             set(AUTOVECTORIZATION_IS_BROKEN))
  bijector, event_dim = self._draw_bijector(
      bijector_name,
      data,
      batch_shape=[],  # Avoid conflict with vmap sample dimension.
      validate_args=False,  # Work around lack of `If` support in vmap.
      allowed_bijectors=allowed)
  atol = AUTOVECTORIZATION_ATOL[bijector_name]
  rtol = AUTOVECTORIZATION_RTOL[bijector_name]
  n = 3

  def vmap(fn, elems):
    # Never silently fall back to a while loop: that would hide pfor gaps.
    return tf.vectorized_map(fn, elems, fallback_to_while_loop=False)

  def assert_matches(batched, vectorized):
    batched_, vectorized_ = self.evaluate((batched, vectorized))
    self.assertAllClose(batched_, vectorized_, atol=atol, rtol=rtol)

  # Forward
  xs = self._draw_domain_tensor(bijector, data, event_dim, sample_shape=[n])
  ys = bijector.forward(xs)
  assert_matches(ys, vmap(bijector.forward, xs))

  # FLDJ
  fldj_event_ndims = data.draw(
      hps.integers(
          min_value=bijector.forward_min_event_ndims,
          max_value=ps.rank_from_shape(xs.shape) - 1))
  fldj_fn = functools.partial(
      bijector.forward_log_det_jacobian, event_ndims=fldj_event_ndims)
  vectorized_fldj = vmap(fldj_fn, xs)
  fldj = tf.broadcast_to(fldj_fn(xs), tf.shape(vectorized_fldj))
  assert_matches(fldj, vectorized_fldj)

  # Inverse
  ys = self._draw_codomain_tensor(bijector, data, event_dim, sample_shape=[n])
  xs = bijector.inverse(ys)
  assert_matches(xs, vmap(bijector.inverse, ys))

  # ILDJ
  ildj_event_ndims = data.draw(
      hps.integers(
          min_value=bijector.inverse_min_event_ndims,
          max_value=ps.rank_from_shape(ys.shape) - 1))
  ildj_fn = functools.partial(
      bijector.inverse_log_det_jacobian, event_ndims=ildj_event_ndims)
  vectorized_ildj = vmap(ildj_fn, ys)
  ildj = tf.broadcast_to(ildj_fn(ys), tf.shape(vectorized_ildj))
  assert_matches(ildj, vectorized_ildj)
def testPForInterop(self):
  """Checks `tf.vectorized_map` returns ndarrays of the expected shapes."""
  batch_size = 100
  a = np.ones((batch_size, 32, 32))

  def outer_product(a):
    return np.tensordot(a, a, axes=0)

  def transpose(x):
    return x.T

  expected_shapes = (
      (outer_product, (batch_size, 32, 32, 32, 32)),
      (transpose, (batch_size, 32, 32)),
  )
  for fn, expected_shape in expected_shapes:
    c = tf.vectorized_map(fn, a)
    self.assertIsInstance(c, np.ndarray)
    self.assertEqual(c.shape, expected_shape)
def call(self, inputs):
  """Replaces every value in `inputs` with the index of its bucket.

  The bucket boundaries are `self.bins`, squeezed and cast to float32.
  Ragged and sparse inputs keep their structure and only their values are
  bucketized. Dense inputs are flattened after the leading dimension,
  bucketized row by row, and reshaped back to their original shape.

  Args:
    inputs: A dense `Tensor`, a `tf.RaggedTensor` or a `tf.SparseTensor`.

  Returns:
    A tensor of the same kind and shape as `inputs` holding bucket indices.

  Raises:
    NotImplementedError: If a dense input has a statically unknown
      non-batch dimension.
  """
  bins = [tf.cast(tf.compat.v1.squeeze(self.bins), tf.float32)]

  def _bucketize_fn(inputs):
    return tf.raw_ops.BoostedTreesBucketize(
        float_values=[tf.cast(inputs, tf.float32)],
        bucket_boundaries=bins)[0]

  if tf_utils.is_ragged(inputs):
    integer_buckets = tf.ragged.map_flat_values(_bucketize_fn, inputs)
    # Ragged map_flat_values doesn't touch the non-values tensors in the
    # ragged composite tensor. If this op is the only op in a Keras model,
    # this can cause errors in Graph mode, so wrap the tensor in an identity.
    return tf.identity(integer_buckets)

  if isinstance(inputs, tf.SparseTensor):
    return tf.SparseTensor(
        indices=tf.identity(inputs.indices),
        values=_bucketize_fn(inputs.values),
        dense_shape=tf.identity(inputs.dense_shape))

  static_shape = inputs.get_shape()
  if any(dim is None for dim in static_shape.as_list()[1:]):
    # The two literals are concatenated, so the separating space must be
    # part of the first one; otherwise the message reads "shape,found".
    raise NotImplementedError(
        "Discretization Layer requires known non-batch shape, "
        "found {}".format(static_shape))

  dynamic_shape = tf.shape(inputs)
  # BoostedTreesBucketize only handles rank 1 inputs. We need to flatten our
  # inputs after batch size and vectorized_map over each sample.
  reshaped = tf.reshape(inputs, [dynamic_shape[0], -1])
  return tf.reshape(
      tf.vectorized_map(_bucketize_fn, reshaped), dynamic_shape)
def test_sampled_scale_follows_correct_distribution(self):
  """Compares moments of `_resample_scale` draws to the analytic posterior."""
  strm = test_util.test_seed_stream()
  prior = tfd.InverseGamma(concentration=0.1, scale=0.1)
  num_timesteps = 100
  series_shape = [2, num_timesteps]
  observed_samples = tf.random.normal(series_shape, seed=strm()) * 3.
  is_missing = tf.random.uniform(series_shape, seed=strm()) > 0.9

  def draw_scale(seed):
    return gibbs_sampler._resample_scale(
        prior=prior,
        observed_residuals=observed_samples,
        is_missing=is_missing,
        seed=seed)

  # Check that posterior variance samples have the moments of the correct
  # InverseGamma distribution.
  seeds = tfp.random.split_seed(strm(), tf.constant(10000))
  posterior_scale_samples = tf.vectorized_map(draw_scale, seeds)

  num_observed = tf.reduce_sum(1 - tf.cast(is_missing, tf.float32), axis=-1)
  concentration = prior.concentration + num_observed/2.
  observed_residuals = observed_samples * tf.cast(~is_missing, tf.float32)
  sum_of_squares = tf.reduce_sum(observed_residuals**2, axis=-1)
  scale = prior.scale + sum_of_squares/2.

  posterior_scale_samples_, concentration_, scale_ = self.evaluate(
      (posterior_scale_samples, concentration, scale))
  sampled_variances = posterior_scale_samples_**2

  expected_mean = scale_ / (concentration_ - 1)
  expected_std = scale_ / (
      (concentration_ - 1) * np.sqrt(concentration_ - 2))
  self.assertAllClose(
      np.mean(sampled_variances, axis=0), expected_mean, atol=0.05)
  self.assertAllClose(
      np.std(sampled_variances, axis=0), expected_std, atol=0.05)
def test_batching(self, input_batch_shape, kernel_batch_shape):
  """Compares the batched conv fn to per-example `tf.nn.conv2d_transpose`.

  Args:
    input_batch_shape: Batch shape used when building the input.
    kernel_batch_shape: Batch shape used when building the kernel.
  """
  rank = 2
  input_shape = (12, 12, 2)
  filter_shape = (3, 3)
  channels_out = 4
  strides = 2
  dilations = (1, 1)
  padding = 'SAME'

  x, k = _make_input_and_kernel(
      self.make_input,
      input_batch_shape=input_batch_shape,
      input_shape=input_shape,
      kernel_batch_shape=kernel_batch_shape,
      filter_shape=filter_shape,
      channels_out=channels_out,
      dtype=self.dtype)

  conv_fn = self.make_conv_fn(filter_shape, strides, padding, dilations)
  y_batched = conv_fn(x, k)

  # Broadcast input and kernel to the common batch shape.
  broadcast_batch_shape = ps.broadcast_shape(
      input_batch_shape, kernel_batch_shape)
  full_input_shape = ps.concat([broadcast_batch_shape, input_shape], axis=0)
  full_kernel_shape = ps.concat(
      [broadcast_batch_shape, ps.shape(k)[-2:]], axis=0)
  broadcasted_input = tf.broadcast_to(x, shape=full_input_shape)
  broadcasted_kernel = tf.broadcast_to(k, shape=full_kernel_shape)

  def with_flat_batch(trailing_shape):
    # Prepend a single `-1` so all batch dimensions collapse into one.
    return ps.pad(trailing_shape, paddings=[[1, 0]], constant_values=-1)

  flat_y = tf.reshape(
      y_batched, shape=with_flat_batch(ps.shape(y_batched)[-3:]))
  flat_x = tf.reshape(broadcasted_input, shape=with_flat_batch(input_shape))

  unpacked_kernel_shape = ps.concat(
      [(-1,), filter_shape, (input_shape[-1], channels_out)], axis=0)
  unpacked_kernel = tf.reshape(
      broadcasted_kernel, shape=unpacked_kernel_shape)
  # Swap the two trailing kernel axes.
  flat_tf_kernel = tf.einsum('...ij->...ji', unpacked_kernel)

  output_shape, strides_ = convolution_util._get_output_shape(
      rank=rank,
      strides=(strides,) * rank,
      padding=padding,
      dilations=dilations,
      input_shape=input_shape,
      output_size=channels_out,
      filter_shape=filter_shape)

  def transpose_conv_single(args):
    return tf.nn.conv2d_transpose(
        args[0][tf.newaxis],
        args[1],
        output_shape=ps.concat([[1], output_shape], axis=0),
        strides=strides_,
        padding=padding)

  y_expected = tf.vectorized_map(
      transpose_conv_single, elems=(flat_x, flat_tf_kernel))

  [y_actual_, y_expected_] = self.evaluate(
      [flat_y, tf.squeeze(y_expected, axis=1)])
  self.assertAllClose(y_expected_, y_actual_, rtol=1e-5, atol=0)
def _test_vectorization(self, dist_name, dist):
  """Checks `log_prob` under `tf.vectorized_map` against batched `log_prob`.

  Args:
    dist_name: Key into the brokenness lists and the tolerance table.
    dist: Distribution instance under test.
  """
  seed = test_util.test_seed()
  num_samples = 3

  if dist_name not in SAMPLE_AUTOVECTORIZATION_IS_BROKEN:
    # The loop index is deliberately unused: every iteration draws one sample.
    draws = tf.vectorized_map(
        lambda i: dist.sample(seed=seed), tf.range(num_samples))
  else:
    draws = dist.sample(num_samples, seed=seed)
  sample = self.evaluate(draws)
  hp.note('Drew samples {}'.format(sample))

  if dist_name in LOGPROB_AUTOVECTORIZATION_IS_BROKEN:
    return

  pfor_lp = tf.vectorized_map(dist.log_prob, tf.convert_to_tensor(sample))
  batch_lp = dist.log_prob(sample)
  pfor_lp_, batch_lp_ = self.evaluate((pfor_lp, batch_lp))
  self.assertAllClose(
      pfor_lp_, batch_lp_, atol=VECTORIZED_LOGPROB_ATOL[dist_name])
def test_can_return_distribution_from_vectorized_map(self):
  """Checks a distribution returned from `tf.vectorized_map` is batched."""

  def fn(x):
    dist = AutoNormal(loc=x, scale=[3., 5.])
    ones_like_batch_shape = tf.ones_like(dist.batch_shape_tensor())
    return dist._broadcast_parameters_with_batch_shape(ones_like_batch_shape)

  locs = tf.convert_to_tensor([1., 2., 3.])
  batch_dist = tf.vectorized_map(fn, locs)
  self.assertAllEqual(batch_dist.batch_shape, [3, 2])
def _f(*args):
  """Applies the enclosing `f` across the leading axis of every argument.

  Each argument is converted with `tf_np.asarray` and unwrapped through
  `.data` before being handed to `tf.vectorized_map`; every output is
  wrapped back with `tf_np.asarray`.
  """

  def unwrap(x):
    return tf_np.asarray(x).data

  tf_args = tf.nest.map_structure(unwrap, args)
  # `vectorized_map` passes one structure per iteration; unpack it for `f`.
  outputs = tf.vectorized_map(lambda x: f(*x), tf_args)
  return tf.nest.map_structure(tf_np.asarray, outputs)
def test_vectorized_map(self):
  """Checks `tf.vectorized_map` over a converted `TransformedVariable`."""
  initial_value = tf.ones([5, 3])
  x = tfp.util.TransformedVariable(initial_value, tfb.Sigmoid())

  def add_two(v):
    return v + 2.

  # TODO(emilyaf): Remove `convert_to_tensor` after tf.Variables are
  # CompositeTensor.
  x_tensor = tf.convert_to_tensor(x)
  y = tf.vectorized_map(add_two, x_tensor)

  initializers = [v.initializer for v in x.trainable_variables]
  self.evaluate(initializers)
  self.assertAllClose(self.evaluate(y), initial_value + 2.)
def test_sampled_weights_follow_correct_distribution(self):
  """Compares moments of `_resample_weights` draws to the analytic posterior."""
  seed = test_util.test_seed(sampler_type='stateless')
  design_seed, true_weights_seed, sampled_weights_seed = samplers.split_seed(
      seed, 3, 'test_sampled_weights_follow_correct_distribution')
  num_timesteps = 10
  num_features = 2
  batch_shape = [3, 1]

  design_matrix = self.evaluate(
      samplers.normal(
          batch_shape + [num_timesteps, num_features], seed=design_seed))
  true_weights = self.evaluate(
      samplers.normal(
          batch_shape + [num_features, 1], seed=true_weights_seed) * 10.0)
  targets = np.matmul(design_matrix, true_weights)
  is_missing = np.array([False, False, False, True, True,
                         False, False, True, False, False])
  prior_scale = tf.convert_to_tensor(5.)
  likelihood_scale = tf.convert_to_tensor(0.1)

  # Analytically compute the true posterior distribution on weights.
  observed = ~is_missing
  valid_design_matrix = design_matrix[..., observed, :]
  valid_targets = targets[..., observed, :]
  num_valid_observations = tf.shape(valid_design_matrix)[-2]
  observation_noise = tfd.MultivariateNormalDiag(
      loc=tf.zeros([num_valid_observations]),
      scale_diag=likelihood_scale * tf.ones([num_valid_observations]))
  weights_posterior_mean, weights_posterior_cov, _ = linear_gaussian_update(
      prior_mean=tf.zeros([num_features, 1]),
      prior_cov=tf.eye(num_features) * prior_scale**2,
      observation_matrix=tfl.LinearOperatorFullMatrix(valid_design_matrix),
      observation_noise=observation_noise,
      x_observed=valid_targets)

  # Check that the empirical moments of sampled weights match the true values.
  def draw_weights(seed):
    return gibbs_sampler._resample_weights(
        design_matrix=tf.where(is_missing[..., tf.newaxis],
                               tf.zeros_like(design_matrix),
                               design_matrix),
        target_residuals=targets[..., 0],
        observation_noise_scale=likelihood_scale,
        weights_prior_scale=tf.linalg.LinearOperatorScaledIdentity(
            num_features, prior_scale),
        seed=seed)

  seeds = tfp.random.split_seed(sampled_weights_seed, tf.constant(10000))
  sampled_weights = tf.vectorized_map(draw_weights, seeds)

  sampled_weights_mean = tf.reduce_mean(sampled_weights, axis=0)
  centered_weights = sampled_weights - weights_posterior_mean[..., 0]
  outer_products = (centered_weights[..., :, tf.newaxis] *
                    centered_weights[..., tf.newaxis, :])
  sampled_weights_cov = tf.reduce_mean(outer_products, axis=0)

  (sampled_weights_mean_, weights_posterior_mean_,
   sampled_weights_cov_, weights_posterior_cov_) = self.evaluate((
       sampled_weights_mean, weights_posterior_mean[..., 0],
       sampled_weights_cov, weights_posterior_cov))

  moment_pairs = (
      (sampled_weights_mean_, weights_posterior_mean_),
      (sampled_weights_cov_, weights_posterior_cov_),
  )
  for sampled, analytic in moment_pairs:
    self.assertAllClose(sampled, analytic, atol=0.01, rtol=0.05)
def testPForInterop(self):
  """Checks the output shape of `tf.vectorized_map` on an outer product."""
  batch_size = 100
  a = np.ones((batch_size, 32, 32))

  def outer_product(a):
    return np.tensordot(a, a, axes=0)

  c = tf.vectorized_map(outer_product, a)
  # TODO(nareshmodi): vectorized_map doesn't rewrap tensors in ndarray.
  # Once it does, also assert that `c` is an instance of `np.ndarray`.
  expected_shape = (batch_size, 32, 32, 32, 32)
  self.assertEqual(c.shape, expected_shape)
def testReproduceVmap1(self, dtype):
  """Regression test for b/145554459: vmapped vs batched `log_prob`."""

  def const(value):
    return tf.constant(value, dtype=dtype)

  loc = const(-200.)
  scale = const(2.188274e+01)
  high = const(113.33857)
  low = const(102.94414)
  # Not validating args b/c the assertions confuse pfor.
  dist = tfd.TruncatedNormal(loc, scale, low, high, validate_args=False)
  sample = const([102.950745, 103.87256, 107.78299])

  batch_lp = dist.log_prob(sample)
  pfor_lp = tf.vectorized_map(dist.log_prob, sample)
  batch_lp_, pfor_lp_ = self.evaluate((batch_lp, pfor_lp))
  self.assertAllClose(batch_lp_, pfor_lp_, atol=1e-6)
def test_automatic_conversion_to_tensor(self):
  """Checks distributions pass through `tf.vectorized_map` and `tf.cond`."""
  # A distribution given directly as the `elems` of `vectorized_map`.
  v = tf.Variable(tf.ones([5]))
  d = tfd.Normal(tf.zeros([5]), v)
  x = tf.convert_to_tensor([3.])

  def log_prob_of_x(z):
    return z.log_prob(x)

  vectorized_log_prob = tf.vectorized_map(log_prob_of_x, d)
  log_prob = d.log_prob(x)
  self.evaluate(v.initializer)
  self.assertAllClose(vectorized_log_prob[:, 0], log_prob)

  # A distribution returned from both branches of `tf.cond`.
  loc = tf.Variable(0.)
  self.evaluate(loc.initializer)

  def true_branch():
    return tfd.Normal(loc, 1.)

  def false_branch():
    return tfd.Normal(0., 1.)

  cond_dist = tf.cond(tf.convert_to_tensor(True), true_branch, false_branch)
  self.assertIsInstance(cond_dist, tfd.Normal)
def _get_mode(samples):
  """Returns an index into `samples` at which its most frequent element sits.

  Args:
    samples: Tensor whose entries along axis 0 are compared for uniqueness.

  Returns:
    An int32 index `i` such that `samples[i]` equals the unique element with
    the highest count.
  """
  _, unique_ids, counts = tf.raw_ops.UniqueWithCountsV2(x=samples, axis=[0])

  # TODO(b/161402486): Remove this hack for fixing the wrong static shape
  # of `idx` in graph mode.
  def first_entry(x):
    return tf.reshape(x, [-1])[0]

  unique_ids = tf.vectorized_map(first_entry, unique_ids)

  # NOTE:
  # - `counts` has shape `[K]`, where `K` is the number of unique elements,
  #   and `counts[j]` is the number of times the j-th unique element occurs
  #   in `samples`.
  # - `unique_ids` has shape `[samples.shape[0]]`, and `unique_ids[i] == j`
  #   means that `samples[i]` is equal to the `j`-th unique element.
  most_common_id = tf.argmax(counts, output_type=tf.int32)

  # Return an index `i` for which `unique_ids[i] == most_common_id`.
  is_most_common = tf.math.equal(unique_ids, most_common_id)
  return tf.argmax(
      tf.cast(is_most_common, dtype=tf.int32), output_type=tf.int32)
def per_example_test_nll(params):
  """Computes the negative log-likelihood of every test example.

  Args:
    params: Value passed to `test_joint_dist.sample_distributions`; the shape
      of `params[0]` provides the leading dimensions of the result.

  Returns:
    Tensor of shape `params[0].shape + [number of test examples]`.
  """
  # Zero-based (student, question) index pairs, one row per test example.
  student_idx = dataset.test_student_ids - 1
  question_idx = dataset.test_question_ids - 1
  test_y_idx = np.stack([student_idx, question_idx], axis=-1)

  distributions = test_joint_dist.sample_distributions(value=params)[0]
  observation_dist = distributions[-1].distribution
  dense_nll = -observation_dist.log_prob(test_dense_y)
  vectorized_dense_nll = tf.reshape(
      dense_nll, [-1, num_students, num_questions])

  def gather_test_entries(nll):
    return tf.gather_nd(nll, test_y_idx)

  # TODO(siege): Avoid using vmap here.
  test_nll = tf.vectorized_map(gather_test_entries, vectorized_dense_nll)
  result_shape = list(params[0].shape) + [test_y_idx.shape[0]]
  return tf.reshape(test_nll, result_shape)
def _vectorize_parameters(f, params, use_pfor, dtype):
  """Calls `f` once per scalar parameter entry with a one-hot mask.

  Args:
    f: Callable taking a list of tensors shaped like `params`.
    params: List of tensors. Their sizes determine how many times `f` runs.
    use_pfor: If true, loop with `tf.vectorized_map`; otherwise `tf.map_fn`.
    dtype: Output dtype handed to `tf.map_fn`. Unused when `use_pfor` is
      true.

  Returns:
    The stacked results of `f` over all `sum(size(p) for p in params)` masks.
  """
  sizes = [tf.size(param) for param in params]
  total_size = tf.math.add_n(sizes)

  def _wrapper(index):
    # One global one-hot vector, cut into per-parameter pieces and reshaped.
    pieces = tf.split(tf.one_hot(index, total_size), sizes)
    tangents = []
    for param, piece in zip(params, pieces):
      tangents.append(tf.reshape(piece, tf.shape(param)))
    return f(tangents)

  indices = tf.range(total_size)
  if not use_pfor:
    return tf.map_fn(_wrapper, indices, dtype)
  return tf.vectorized_map(_wrapper, indices)
def testReproduceVmap2(self, dtype):
  """Regression test for b/150811273: vmapped vs batched `log_prob`."""
  if dtype == np.float32:
    raise unittest.SkipTest('b/150811273')

  def const(value):
    return tf.constant(value, dtype=dtype)

  seed = test_util.test_seed()
  loc = const(-12.500191)
  scale = const(1e-06)
  high = const(-12.502851)
  low = const(-187.50009)
  # Not validating args b/c the assertions confuse pfor.
  dist = tfd.TruncatedNormal(loc, scale, low, high, validate_args=False)

  # At the default seed, the sample comes out as [-12.502851 -12.502851
  # -12.502851], but that's also weird. At a scale of 1e-6, the samples
  # should cluster more tightly around the location, which is -12.500191.
  sample = self.evaluate(dist.sample(3, seed=seed))

  batch_lp = dist.log_prob(sample)
  pfor_lp = tf.vectorized_map(dist.log_prob, tf.convert_to_tensor(sample))
  batch_lp_, pfor_lp_ = self.evaluate((batch_lp, pfor_lp))
  self.assertAllClose(batch_lp_, pfor_lp_, atol=1e-6)
def test_vectorized_map(self):
  """Checks shapes of per-example gradients from `tf.vectorized_map`."""
  batch_size = 10
  num_features = 32
  layer = tf.keras.layers.Dense(1)

  def model_fn(arg):
    inp, label = arg
    with tf.GradientTape() as g:
      # Each example is handed over without its batch axis; restore one.
      batched_inp = tf.expand_dims(inp, 0)
      batched_label = tf.expand_dims(label, 0)
      prediction = layer(batched_inp)
      loss = tf.nn.l2_loss(batched_label - prediction)
    return g.gradient(loss, (layer.kernel, layer.bias))

  inputs = tf.random.uniform([batch_size, num_features])
  labels = tf.random.uniform([batch_size, 1])
  per_example_gradients = tf.vectorized_map(model_fn, (inputs, labels))

  kernel_grads, bias_grads = per_example_gradients[0], per_example_gradients[1]
  self.assertEqual(kernel_grads.shape, (batch_size, num_features, 1))
  self.assertEqual(bias_grads.shape, (batch_size, 1))
def call(self, inputs):
  """Replaces every value in `inputs` with the index of its bucket.

  The bucket boundaries are `self.bins`, squeezed and cast to float32.
  Ragged and sparse inputs keep their structure and only their values are
  bucketized. Dense inputs are flattened after the leading dimension,
  bucketized row by row, and reshaped back.

  Args:
    inputs: A dense `Tensor`, a `tf.RaggedTensor` or a `tf.SparseTensor`.

  Returns:
    A tensor of the same kind as `inputs` holding bucket indices.

  Raises:
    NotImplementedError: If a dense input has a statically unknown
      non-batch dimension.
  """

  def _bucketize_op(bins):
    bins = [tf.cast(bins, tf.float32)]
    return lambda inputs: tf.raw_ops.BoostedTreesBucketize(  # pylint: disable=g-long-lambda
        float_values=[tf.cast(inputs, tf.float32)],
        bucket_boundaries=bins)[0]

  if tf_utils.is_ragged(inputs):
    integer_buckets = tf.ragged.map_flat_values(
        _bucketize_op(tf.compat.v1.squeeze(self.bins)), inputs)
    # Ragged map_flat_values doesn't touch the non-values tensors in the
    # ragged composite tensor. If this op is the only op in a Keras model,
    # this can cause errors in Graph mode, so wrap the tensor in an identity.
    return tf.identity(integer_buckets)

  if isinstance(inputs, tf.SparseTensor):
    # Reuse the shared op builder rather than repeating the raw-op call.
    bucketize = _bucketize_op(tf.compat.v1.squeeze(self.bins))
    return tf.SparseTensor(
        indices=tf.identity(inputs.indices),
        values=bucketize(inputs.values),
        dense_shape=tf.identity(inputs.dense_shape))

  input_shape = inputs.get_shape()
  non_batch_dims = input_shape.as_list()[1:]
  if any(dim is None for dim in non_batch_dims):
    # The two literals are concatenated, so the separating space must be
    # part of the first one; otherwise the message reads "shape,found".
    raise NotImplementedError(
        "Discretization Layer requires known non-batch shape, "
        "found {}".format(input_shape))

  # Flatten everything after the leading dimension so each row is rank 1.
  reshaped = tf.reshape(
      inputs, [-1, tf.raw_ops.Prod(input=non_batch_dims, axis=0)])
  return tf.reshape(
      tf.vectorized_map(
          _bucketize_op(tf.compat.v1.squeeze(self.bins)), reshaped),
      tf.constant([-1] + non_batch_dims))
def vectorized_fn(*args):
  """Vectorized version of `fn` that accepts arguments of any rank.

  `fn`, `core_ndims` and `name` are read from the enclosing scope, which is
  not part of this block.

  Args:
    *args: Arguments for `fn`. Their structure is matched against
      `core_ndims` (a single non-nested `core_ndims` is tiled across all of
      them).

  Returns:
    The result of `fn`, with every `Tensor` in it reshaped so that its leading
    dimensions are the broadcast batch shape of the vectorized arguments.
  """
  with tf.name_scope(name or 'make_rank_polymorphic'):
    # If we got a single value for core_ndims, tile it across all args.
    core_ndims_structure = (core_ndims if tf.nest.is_nested(core_ndims) else
                            tf.nest.map_structure(
                                lambda _: core_ndims, args))

    # Build flat lists of all argument parts and their corresponding core
    # ndims.
    flat_core_ndims = tf.nest.flatten(core_ndims_structure)
    flat_args = nest.flatten_up_to(core_ndims_structure,
                                   args,
                                   check_types=False)

    # Filter to only the `Tensor`-valued args (taken to be those with `None`
    # values for `core_ndims`). Other args will be passed through to `fn`
    # unmodified.
    (vectorized_arg_core_ndims, vectorized_args,
     fn_of_vectorized_args) = _lock_in_non_vectorized_args(
         fn,
         arg_structure=core_ndims_structure,
         flat_core_ndims=flat_core_ndims,
         flat_args=flat_args)

    # `vectorized_map` requires all inputs to have a single, common batch
    # dimension `[n]`. So we broadcast all input parts to a common
    # batch shape, then flatten it down to a single dimension.
    vectorized_arg_shapes = [ps.shape(arg) for arg in vectorized_args]
    vectorized_arg_actual_core_ndims = []
    batch_shapes, core_shapes = [], []
    for (arg_shape, core_nd) in zip(vectorized_arg_shapes,
                                    vectorized_arg_core_ndims):
      arg_nd = ps.rank_from_shape(arg_shape)
      # Shrink 'core' ndims of rank-deficient args. This guarantees that
      # `batch_ndims` is always nonnegative.
      actual_core_nd = ps.minimum(arg_nd, core_nd)
      vectorized_arg_actual_core_ndims.append(actual_core_nd)

      # Leading `batch_ndims` dimensions are batch; the rest are core.
      batch_ndims = arg_nd - actual_core_nd
      batch_shapes.append(arg_shape[:batch_ndims])
      core_shapes.append(arg_shape[batch_ndims:])

    # Flatten all of the batch dimensions into one.
    broadcast_batch_shape = (functools.reduce(ps.broadcast_shape,
                                              batch_shapes, []))
    n = ps.cast(ps.reduce_prod(broadcast_batch_shape), tf.int32)

    # `static_n` is `None` when the batch size is not statically known, in
    # which case the general `vectorized_map` path below is taken.
    static_n = tf.get_static_value(n)
    if static_n == 1:
      # We can bypass `vectorized_map` if the batch shape is `[]`, `[1]`,
      # `[1, 1]`, etc., just by flattening to batch shape `[]`.
      # The result then has no flattened batch dimension to strip.
      result_batch_dims = 0
      batched_result = fn_of_vectorized_args(
          tf.nest.map_structure(lambda x, nd: tf.reshape(
              x, ps.shape(x)[ps.rank(x) - nd:]),
                                vectorized_args,
                                vectorized_arg_actual_core_ndims,
                                check_types=False))
    else:
      # Pad all input parts to the common shape, then flatten
      # into the single leading dimension `[n]`.
      # TODO(b/145227909): If/when vmap supports broadcasting, use nested vmap
      # when batch rank is static so that we can exploit broadcasting.
      broadcast_vectorized_args = [
          tf.broadcast_to(
              part, ps.concat([broadcast_batch_shape, core_shape], axis=0))
          for (part, core_shape) in zip(vectorized_args, core_shapes)
      ]
      vectorized_args_with_flattened_batch_dim = [
          tf.reshape(part, ps.concat([[n], core_shape], axis=0))
          for (part, core_shape
              ) in zip(broadcast_vectorized_args, core_shapes)
      ]
      # The result carries one leading `[n]` dimension to be unflattened.
      result_batch_dims = 1
      batched_result = tf.vectorized_map(
          fn_of_vectorized_args, vectorized_args_with_flattened_batch_dim)

    # Unflatten any `Tensor`s in the result.
    unflatten = lambda x: tf.reshape(
        x,
        ps.concat(
            [  # pylint: disable=g-long-lambda
                broadcast_batch_shape,
                ps.shape(x)[result_batch_dims:]
            ],
            axis=0))
    result = tf.nest.map_structure(lambda x: unflatten(x)
                                   if tf.is_tensor(x) else x,
                                   batched_result,
                                   expand_composites=True)
  return result
def test_vectorized_map(self):
  """Checks a `TransformedVariable` works as `tf.vectorized_map` input."""
  initial_value = tf.ones([5, 3])
  x = tfp.util.TransformedVariable(initial_value, tfb.Sigmoid())

  def add_two(v):
    return v + 2.

  y = tf.vectorized_map(add_two, x)

  initializers = [v.initializer for v in x.trainable_variables]
  self.evaluate(initializers)
  self.assertAllClose(self.evaluate(y), initial_value + 2.)
def test_vectorized_map(self):
  """Checks a `DeferredTensor` works as `tf.vectorized_map` input."""
  pretransformed_input = tf.Variable(tf.ones([5, 3]))
  x = tfp.util.DeferredTensor(pretransformed_input, tfb.Scale([5]))

  def add_two(v):
    return v + 2.

  y = tf.vectorized_map(add_two, x)

  initializers = [v.initializer for v in x.trainable_variables]
  self.evaluate(initializers)
  expected = 5. * pretransformed_input + 2.
  self.assertAllClose(self.evaluate(y), expected)
def _parse_fn(record):
  """Parses a serialized record into a dataset of spec-suffix examples.

  `spec_vocab_table`, `separator_id`, `eos_id` and `max_target_length` are
  read from the enclosing scope, which is not part of this block.

  Args:
    record: Serialized example with string features 'i/o' and
      'program_encoding'.

  Returns:
    A `tf.data.Dataset` built with `from_tensor_slices` over a dict with keys
    'inputs', 'outputs', 'spec_parts', 'start_index' and 'end_index'.
  """
  feature_values = tf.io.parse_single_example(
      serialized=record,
      features={
          'i/o':
              tf.io.FixedLenFeature([], tf.string, default_value=''),
          'program_encoding':
              tf.io.FixedLenFeature([], tf.string, default_value=''),
      })

  # Split on '>' then '<'; after merging, even positions are taken as inputs
  # and odd positions as outputs.
  ios = tf.strings.split(tf.strings.split(feature_values['i/o'], sep='>'),
                         sep='<')
  inputs, outputs = ios.merge_dims(0, 1)[::2], ios.merge_dims(0, 1)[1::2]

  # Parse inputs into tokens.
  inputs = tf.strings.unicode_split(inputs, 'UTF-8').to_tensor()
  inputs = spec_vocab_table.lookup(inputs)  # Map characters to tokens.

  # Parse outputs into tokens.
  # `outputs_with_separators` keeps the '|' characters; `outputs` drops them
  # by splitting on '|' before tokenizing.
  outputs_with_separators = (tf.strings.unicode_split(
      outputs, 'UTF-8').to_tensor())
  outputs_with_separators = spec_vocab_table.lookup(
      outputs_with_separators)
  split_outputs = tf.strings.unicode_split(
      tf.strings.split(outputs, sep='|'), 'UTF-8')
  outputs = split_outputs.merge_dims(1, 2).to_tensor()
  outputs = spec_vocab_table.lookup(outputs)

  # Compute indices for the start of each part of the spec, w.r.t. the
  # original spec.
  separator_indices = tf.where(
      tf.equal(outputs_with_separators, separator_id))[:, 1]
  # NOTE(review): this reshape presumably relies on every example having the
  # same number of separators — confirm against the data format.
  separator_indices = tf.reshape(
      separator_indices, (tf.shape(outputs_with_separators)[0], -1))
  # Subtract each separator's ordinal to get positions in the
  # separator-free `outputs`.
  start_indices = separator_indices - tf.expand_dims(
      tf.range(tf.shape(separator_indices)[1], dtype=tf.int64), 0)
  # The first part always starts at index 0.
  start_indices = tf.concat((tf.zeros(
      (tf.shape(start_indices)[0], 1), dtype=tf.int64), start_indices),
                            axis=1)
  num_examples = tf.shape(start_indices)[0]
  num_parts = tf.shape(start_indices)[1]

  # Construct the shifted spec suffixes.
  flat_start_indices = tf.reshape(start_indices, (-1, ))
  # Zero out everything before each start index, then roll left so the
  # suffix begins at position 0.
  prefix_mask = (1 - tf.sequence_mask(
      flat_start_indices, maxlen=tf.shape(outputs)[-1], dtype=tf.int64))
  masked_outputs = tf.repeat(outputs, num_parts, axis=0) * prefix_mask
  output_suffixes = tf.vectorized_map(
      fn=lambda x: tf.roll(x[0], x[1], axis=0),
      elems=(masked_outputs, -flat_start_indices))

  # Compute indices for the start/end of spec parts, w.r.t. the shifted spec
  # suffixes.
  ground_truth_start_indices = tf.zeros((num_examples * num_parts, ),
                                        dtype=tf.int64)
  cumulative_end_indices = tf.concat(
      (start_indices, tf.math.count_nonzero(outputs, axis=-1, keepdims=True)),
      axis=1)
  # Differences of consecutive boundaries give each part's length.
  ground_truth_end_indices = tf.reshape(
      cumulative_end_indices[:, 1:] - cumulative_end_indices[:, :-1], (-1, ))

  # Construct the actual spec parts to predict.
  range_indices = tf.expand_dims(tf.range(tf.shape(output_suffixes)[-1],
                                          dtype=tf.int64),
                                 axis=0)
  part_mask = tf.where(
      tf.logical_and(
          range_indices >= tf.expand_dims(ground_truth_start_indices, axis=1),
          range_indices < tf.expand_dims(ground_truth_end_indices, axis=1)),
      1, 0)
  output_parts = output_suffixes * tf.cast(part_mask, tf.int64)
  output_parts = tf.pad(output_parts, [[0, 0], [0, 1]])  # Make room for sep.
  # TODO(kshi): roll output_parts leftward by start_indices for SCAN.
  # Write `separator_id` at the first zero position of every part.
  first_zero_index = tf.math.count_nonzero(output_parts, axis=-1)
  output_parts += tf.one_hot(first_zero_index,
                             depth=tf.shape(output_parts)[-1],
                             dtype=tf.int64) * separator_id

  # Reshape everything so that different spec suffixes become different
  # dataset elements.
  output_suffixes_reshaped = tf.transpose(
      tf.reshape(output_suffixes, (num_examples, num_parts, -1)), (1, 0, 2))
  output_parts_reshaped = tf.transpose(
      tf.reshape(output_parts, (num_examples, num_parts, -1)), (1, 0, 2))
  inputs_reshaped = tf.reshape(tf.tile(inputs, (num_parts, 1)),
                               (num_parts, num_examples, -1))
  ground_truth_start_indices_reshaped = tf.transpose(
      tf.reshape(ground_truth_start_indices, (num_examples, num_parts)))
  ground_truth_end_indices_reshaped = tf.transpose(
      tf.reshape(ground_truth_end_indices, (num_examples, num_parts)))

  # Combine spec parts from all examples into one sequence with separator
  # tokens between examples and ending in EOS.
  shifts = tf.cumsum(tf.concat((tf.zeros(
      (num_parts, 1), dtype=tf.int64), ground_truth_end_indices_reshaped[:, :-1]
                                + 1), 1),
                     axis=-1)
  flat_shifts = tf.reshape(shifts, (-1, ))
  output_len = tf.shape(output_parts_reshaped)[-1]
  flat_spec_parts = tf.reshape(output_parts_reshaped, (-1, output_len))
  # NOTE(review): the pad amount is negative if `output_len` exceeds
  # `max_target_length` — confirm callers guarantee it does not.
  flat_spec_parts = tf.pad(flat_spec_parts,
                           [[0, 0], [0, max_target_length - output_len]])
  # Roll each part right to its offset, then sum the parts of one row.
  combined_spec_parts = tf.vectorized_map(
      fn=lambda x: tf.roll(x[0], x[1], axis=0),
      elems=(flat_spec_parts, flat_shifts))
  combined_spec_parts = tf.reshape(combined_spec_parts,
                                   (num_parts, num_examples, -1))
  combined_spec_parts = tf.reduce_sum(combined_spec_parts, axis=1)
  # Write `eos_id` at the first zero position of every combined row.
  first_zero_index = tf.math.count_nonzero(combined_spec_parts, axis=-1)
  combined_spec_parts += tf.one_hot(
      first_zero_index, depth=tf.shape(combined_spec_parts)[-1],
      dtype=tf.int64) * eos_id

  # Create a dataset containing data for all spec suffixes.
  dataset = tf.data.Dataset.from_tensor_slices({
      'inputs': inputs_reshaped,
      'outputs': output_suffixes_reshaped,
      'spec_parts': combined_spec_parts,
      'start_index': ground_truth_start_indices_reshaped,
      'end_index': ground_truth_end_indices_reshaped
  })
  return dataset
def vectorized_fn(*args):
  """Vectorized version of `fn` that accepts arguments of any rank.

  `fn`, `core_ndims`, `name` and `validate_args` are read from the enclosing
  scope, which is not part of this block.

  Args:
    *args: Arguments for `fn`. Their structure is matched against
      `core_ndims` (a single non-nested `core_ndims` is tiled across all of
      them).

  Returns:
    The result of `fn`, with every `Tensor` in it reshaped so that its leading
    dimensions are the broadcast batch shape of the vectorized arguments.

  Raises:
    ValueError: If an argument's statically known rank is lower than its
      `core_ndims`.
  """
  with tf.name_scope(name or 'make_rank_polymorphic'):
    assertions = []
    # If we got a single value for core_ndims, tile it across all args.
    core_ndims_structure = (
        core_ndims if tf.nest.is_nested(core_ndims)
        else tf.nest.map_structure(lambda _: core_ndims, args))

    # Build flat lists of all argument parts and their corresponding core
    # ndims.
    flat_core_ndims = tf.nest.flatten(core_ndims_structure)
    flat_args = nest.flatten_up_to(
        core_ndims_structure, args, check_types=False)

    # Filter to only the `Tensor`-valued args (taken to be those with `None`
    # values for `core_ndims`). Other args will be passed through to `fn`
    # unmodified.
    (vectorized_arg_core_ndims,
     vectorized_args,
     fn_of_vectorized_args) = _lock_in_non_vectorized_args(
         fn, arg_structure=core_ndims_structure,
         flat_core_ndims=flat_core_ndims, flat_args=flat_args)

    # `vectorized_map` requires all inputs to have a single, common batch
    # dimension `[n]`. So we broadcast all input parts to a common
    # batch shape, then flatten it down to a single dimension.

    # First, compute how many 'extra' (batch) ndims each part has. This must
    # be nonnegative.
    vectorized_arg_shapes = [tf.shape(arg) for arg in vectorized_args]
    batch_ndims = [
        ps.rank_from_shape(arg_shape) - nd
        for (arg_shape, nd) in zip(
            vectorized_arg_shapes, vectorized_arg_core_ndims)]
    # Entries are `None` where the rank is not statically known; those are
    # skipped by the `nd and ...` test and left to the runtime assertions.
    static_ndims = [tf.get_static_value(nd) for nd in batch_ndims]
    if any([nd and nd < 0 for nd in static_ndims]):
      raise ValueError('Cannot broadcast a Tensor having lower rank than the '
                       'specified `core_ndims`! (saw input ranks {}, '
                       '`core_ndims` {}).'.format(
                           tf.nest.map_structure(
                               ps.rank_from_shape,
                               vectorized_arg_shapes),
                           vectorized_arg_core_ndims))
    if validate_args:
      for nd, part, core_nd in zip(
          batch_ndims, vectorized_args, vectorized_arg_core_ndims):
        assertions.append(tf.debugging.assert_non_negative(
            nd, message='Cannot broadcast a Tensor having lower rank than '
            'the specified `core_ndims`! (saw {} vs minimum rank {}).'.format(
                part, core_nd)))

    # Next, split each part's shape into batch and core shapes, and
    # broadcast the batch shapes.
    # NOTE(review): the extent of this scope was reconstructed from a
    # whitespace-collapsed listing — confirm against the original file.
    with tf.control_dependencies(assertions):
      # Defaults used when there are no vectorized args at all.
      empty_shape = np.zeros([0], dtype=np.int32)
      batch_shapes, core_shapes = empty_shape, empty_shape
      if vectorized_arg_shapes:
        batch_shapes, core_shapes = zip(*[
            (arg_shape[:nd], arg_shape[nd:])
            for (arg_shape, nd) in zip(vectorized_arg_shapes, batch_ndims)])
      broadcast_batch_shape = (
          functools.reduce(ps.broadcast_shape, batch_shapes, []))

    # Flatten all of the batch dimensions into one.
    n = tf.cast(ps.reduce_prod(broadcast_batch_shape), tf.int32)
    static_n = tf.get_static_value(n)
    if static_n == 1:
      # NOTE(review): `fn` receives the original `args` here, so a batch shape
      # such as `[1]` or `[1, 1]` reaches `fn` unflattened — confirm `fn`
      # tolerates the extra singleton dimensions.
      result = fn(*args)
    else:
      # Pad all input parts to the common shape, then flatten
      # into the single leading dimension `[n]`.
      # TODO(b/145227909): If/when vmap supports broadcasting, use nested vmap
      # when batch rank is static so that we can exploit broadcasting.
      broadcast_vectorized_args = [
          tf.broadcast_to(part, ps.concat(
              [broadcast_batch_shape, core_shape], axis=0))
          for (part, core_shape) in zip(vectorized_args, core_shapes)]
      vectorized_args_with_flattened_batch_dim = [
          tf.reshape(part, ps.concat([[n], core_shape], axis=0))
          for (part, core_shape) in zip(
              broadcast_vectorized_args, core_shapes)]
      batched_result = tf.vectorized_map(
          fn_of_vectorized_args, vectorized_args_with_flattened_batch_dim)

      # Unflatten any `Tensor`s in the result.
      unflatten = lambda x: tf.reshape(x, ps.concat([  # pylint: disable=g-long-lambda
          broadcast_batch_shape, ps.shape(x)[1:]], axis=0))
      result = tf.nest.map_structure(
          lambda x: unflatten(x) if tf.is_tensor(x) else x, batched_result,
          expand_composites=True)
  return result
def vectorized_fn(*args):
  """Vectorized version of `fn` that accepts arguments of any rank.

  `fn`, `core_ndims`, `validate_args` and `name` are read from the enclosing
  scope. Every argument part is converted to a `Tensor` and its shape is split
  into leading "batch" dimensions and trailing "core" dimensions (the number
  of core dimensions is given by `core_ndims`). The batch shapes of all parts
  are broadcast against each other, flattened into a single leading dimension
  `[n]`, and `fn` is mapped over that dimension with `tf.vectorized_map`.

  Args:
    *args: Positional arguments for `fn`. Their structure must match
      `core_ndims` (or `core_ndims` must be a single non-nested value, in
      which case it is applied to every argument).

  Returns:
    result: `fn(*args)` itself when the flattened batch size is statically
      known to be `1` (this includes the case of no argument parts at all);
      otherwise the structure returned by `tf.vectorized_map`, with the
      leading `[n]` dimension of each entry reshaped back to the broadcast
      batch shape.

  Raises:
    ValueError: if the number of flattened argument parts differs from the
      number of flattened `core_ndims` entries.
    ValueError: if some part statically has lower rank than its `core_ndims`.
  """
  with tf.name_scope(name or 'make_rank_polymorphic'):
    assertions = []

    # If we got a single value for core_ndims, tile it across all args.
    core_ndims_structure = (
        core_ndims if nest.is_nested(core_ndims)
        else nest.map_structure(lambda _: core_ndims, args))

    # Build flat lists of all argument parts and their corresponding core
    # ndims.
    flat_core_ndims = nest.flatten(core_ndims_structure)
    parts = tf.nest.flatten(nest.map_structure_up_to(
        core_ndims_structure, tf.convert_to_tensor, args, check_types=False))
    if len(parts) != len(flat_core_ndims):
      raise ValueError('Number of args does not match `core_ndims` '
                       '({} vs {}). Saw argument parts {}; core '
                       'ndims {}.'.format(len(parts), len(flat_core_ndims),
                                          parts, flat_core_ndims))

    # `vectorized_map` requires all inputs to have a single, common batch
    # dimension `[n]`. So we broadcast all input parts to a common
    # batch shape, then flatten it down to a single dimension.

    # First, compute how many 'extra' (batch) ndims each part has. This must
    # be nonnegative.
    part_shapes = [tf.shape(part) for part in parts]
    batch_ndims = [
        prefer_static.rank_from_shape(part_shape) - nd
        for (part_shape, nd) in zip(part_shapes, flat_core_ndims)]
    static_ndims = [tf.get_static_value(nd) for nd in batch_ndims]
    # A static value of `None` means the batch rank is only known at runtime;
    # those parts are covered by the optional assertions below.
    if any(nd is not None and nd < 0 for nd in static_ndims):
      raise ValueError('Cannot broadcast a Tensor having lower rank than the '
                       'specified `core_ndims`! (saw input ranks {}, '
                       '`core_ndims` {}).'.format(
                           tf.nest.map_structure(
                               prefer_static.rank_from_shape,
                               part_shapes),
                           flat_core_ndims))
    if validate_args:
      for nd, part, core_nd in zip(batch_ndims, parts, flat_core_ndims):
        assertions.append(tf.debugging.assert_non_negative(
            nd, message='Cannot broadcast a Tensor having lower rank than '
            'the specified `core_ndims`! (saw {} vs minimum rank {}).'.format(
                part, core_nd)))

    # Next, split each part's shape into batch and core shapes, and
    # broadcast the batch shapes.
    with tf.control_dependencies(assertions):
      # `zip(*[])` yields nothing to unpack, so the no-argument case needs
      # explicit empty defaults; it then reduces to a batch size of 1 below.
      batch_shapes, core_shapes = (), ()
      if part_shapes:
        batch_shapes, core_shapes = zip(*[
            (part_shape[:nd], part_shape[nd:])
            for (part_shape, nd) in zip(part_shapes, batch_ndims)])
      broadcast_batch_shape = functools.reduce(
          prefer_static.broadcast_shape, batch_shapes, [])

    # Flatten all of the batch dimensions into one.
    n = tf.cast(prefer_static.reduce_prod(broadcast_batch_shape), tf.int32)
    static_n = tf.get_static_value(n)
    if static_n == 1:
      # Nothing to map over: call `fn` directly on the original arguments.
      result = fn(*args)
    else:
      # Pad all input parts to the common shape, then flatten
      # into the single leading dimension `[n]`.
      # TODO(b/145227909): If/when vmap supports broadcasting, use nested vmap
      # when batch rank is static so that we can exploit broadcasting.
      broadcast_parts = [
          tf.broadcast_to(part, prefer_static.concat(
              [broadcast_batch_shape, core_shape], axis=0))
          for (part, core_shape) in zip(parts, core_shapes)]
      parts_with_flattened_batch_dim = [
          tf.reshape(part, prefer_static.concat([[n], core_shape], axis=0))
          for (part, core_shape) in zip(broadcast_parts, core_shapes)]

      # Run the vectorized computation.
      batched_result = tf.vectorized_map(
          lambda args: fn(*args),
          nest.pack_sequence_as(args, parts_with_flattened_batch_dim))

      # Unflatten the result: replace the leading `[n]` with the broadcast
      # batch shape, keeping each result's remaining dimensions.
      result = nest.map_structure(
          lambda x: tf.reshape(x, prefer_static.concat([  # pylint: disable=g-long-lambda
              broadcast_batch_shape, prefer_static.shape(x)[1:]], axis=0)),
          batched_result)
    return result
def expected_calibration_error_quantiles(hit,
                                         pred_log_prob,
                                         num_buckets=20,
                                         axis=0,
                                         log_space_buckets=False,
                                         name=None):
  """Expected calibration error via `quantiles(exp(pred_log_prob),num_buckets)`.

  Calibration measures how well a model reports its own uncertainty. A model
  is "calibrated" when, within each bucket of predicted probabilities, the
  average predicted probability equals the average accuracy. The expected
  calibration error is the average absolute difference between predicted
  probability and (bucket) average accuracy. That is:

  ```python
  bucket_weight = bucket_count / tf.reduce_sum(bucket_count, axis=0)
  bucket_error = abs(bucket_accuracy - bucket_confidence)
  ece = tf.reduce_sum(bucket_weight * bucket_error, axis=0)
  ```

  where `bucket_accuracy, bucket_confidence, bucket_count` are statistics
  aggregated by `num_buckets`-quantiles of `tf.math.exp(pred_log_prob)`.

  Note: `bucket_*` always have `num_buckets` size for the zero-th dimension.

  Args:
    hit: `bool` `Tensor` where `True` means the model prediction was correct
      and `False` means the model prediction was incorrect. Shape must
      broadcast with pred_log_prob.
    pred_log_prob: `Tensor` representing the model's predicted log probability
      for the given `hit`. Shape must broadcast with `hit`.
    num_buckets: `int` representing the number of buckets over which to
      aggregate hits. Buckets are quantiles of `exp(pred_log_prob)`.
      Default value: `20`.
    axis: Dimension over which to compute buckets and aggregate stats.
      Default value: `0`.
    log_space_buckets: When `False` bucket edges are computed from
      `tf.math.exp(pred_log_prob)`; when `True` bucket edges are computed from
      `pred_log_prob`.
      Default value: `False`.
    name: Prefer `str` name used for ops created by this function.
      Default value: `None` (i.e.,
      `"expected_calibration_error_quantiles"`).

  Returns:
    ece: Expected calibration error; `tf.reduce_sum(abs(bucket_accuracy -
      bucket_confidence) * bucket_count, axis=0) / tf.reduce_sum(bucket_count,
      axis=0)`.
    bucket_accuracy: `Tensor` representing the within bucket average hits,
      i.e., total bucket hits divided by bucket count. Has shape
      `tf.concat([[num_buckets], tf.shape(tf.reduce_sum(pred_log_prob,
      axis=axis))], axis=0)`.
    bucket_confidence: `Tensor` representing the within bucket average
      probability, i.e., total bucket predicted probability divided by bucket
      count. Has shape `tf.concat([[num_buckets],
      tf.shape(tf.reduce_sum(pred_log_prob, axis=axis))], axis=0)`.
    bucket_count: `Tensor` representing the total number of observations in
      each bucket. Has shape `tf.concat([[num_buckets],
      tf.shape(tf.reduce_sum(pred_log_prob, axis=axis))], axis=0)`.
    bucket_pred_log_prob: `Tensor` representing `pred_log_prob` bucket edges.
      Always in log space, regardless of the value of `log_space_buckets`.
    bucket: `int` `Tensor` representing the bucket within which
      `pred_log_prob` lies.

  #### Examples

  ```python
  # Example 1: Generic use.
  label = tf.cast([0, 0, 1, 0, 1, 1], dtype=tf.bool)
  log_pred = tf.math.log([0.1, 0.05, 0.5, 0.2, 0.99, 0.99])
  (
    ece,
    acc,
    conf,
    cnt,
    edges,
    bucket,
  ) = tfp.stats.expected_calibration_error_quantiles(
      label, log_pred, num_buckets=3)
  # ece  ==> tf.Tensor(0.145, shape=(), dtype=float32)
  # acc  ==> tf.Tensor([0. 0. 1.], shape=(3,), dtype=float32)
  # conf ==> tf.Tensor([0.075, 0.2, 0.826665], shape=(3,), dtype=float32)
  # cnt  ==> tf.Tensor([2. 1. 3.], shape=(3,), dtype=float32)
  ```

  ```python
  # Example 2: Categorical classification.
  # Assume we have evidence `x`, targets `y`, and model function `dnn`.
  d = tfd.Categorical(logits=dnn(x))
  def all_categories(d):
    num_classes = tf.shape(d.logits_parameter())[-1]
    batch_ndims = tf.size(d.batch_shape_tensor())
    expand_shape = tf.pad(
        [num_classes], paddings=[[0, batch_ndims]], constant_values=1)
    return tf.reshape(tf.range(num_classes, dtype=d.dtype), expand_shape)
  all_pred_log_prob = d.log_prob(all_categories(d))
  yhat = tf.argmax(all_pred_log_prob, axis=0)
  def rollaxis(x, shift):
    return tf.transpose(x, tf.roll(tf.range(tf.rank(x)), shift=shift, axis=0))
  pred_log_prob = tf.gather(rollaxis(all_pred_log_prob, shift=-1),
                            yhat,
                            batch_dims=len(d.batch_shape))
  hit = tf.equal(y, yhat)
  (
    ece,
    acc,
    conf,
    cnt,
    edges,
    bucket,
  ) = tfp.stats.expected_calibration_error_quantiles(
      hit, pred_log_prob, num_buckets=10)
  ```

  """
  with tf.name_scope(name or 'expected_calibration_error_quantiles'):
    pred_log_prob = tf.convert_to_tensor(
        pred_log_prob, dtype_hint=tf.float32, name='pred_log_prob')
    dtype = pred_log_prob.dtype
    hit = tf.cast(hit, dtype, name='hit')

    # Bucket edges are always stored in log space. Unless the caller asked
    # for log-space buckets, the quantiles themselves are taken in
    # probability space and mapped back with `log`.
    if log_space_buckets:
      bucket_pred_log_prob = quantiles_lib.quantiles(
          pred_log_prob, num_quantiles=num_buckets, axis=axis)
    else:
      prob_space_edges = quantiles_lib.quantiles(
          tf.math.exp(pred_log_prob), num_quantiles=num_buckets, axis=axis)
      bucket_pred_log_prob = tf.math.log(prob_space_edges)
    bucket = _find_bins(pred_log_prob, bucket_pred_log_prob, axis)

    def _bucket_totals(bucket_index):
      """Per-bucket totals: hits, log of summed probability, and count."""
      in_bucket = tf.equal(bucket_index, bucket)
      hits_in_bucket = tf.where(in_bucket, hit, tf.constant(0., dtype))
      # Entries outside the bucket get `-inf` so they add nothing to the
      # log-sum-exp.
      log_prob_in_bucket = tf.where(
          in_bucket, pred_log_prob, tf.constant(-np.inf, dtype))
      hit_sum = tf.math.reduce_sum(hits_in_bucket, axis=axis)
      member_count = tf.math.reduce_sum(tf.cast(in_bucket, dtype), axis=axis)
      log_prob_sum = tf.math.reduce_logsumexp(log_prob_in_bucket, axis=axis)
      return hit_sum, log_prob_sum, member_count

    # `vectorized_map` is used here instead of `map_fn` not for efficiency
    # but because, at the time of writing, `map_fn` doesn't work correctly on
    # the JAX substrate: it does not like that the mapped function returns a
    # tuple.
    bucket_ids = tf.range(num_buckets, dtype=bucket.dtype)
    bucket_total_hit, bucket_log_total_pred_prob, bucket_count = (
        tf.vectorized_map(fn=_bucket_totals, elems=bucket_ids))

    # Clamp the divisor to at least one so empty buckets do not divide by
    # zero.
    safe_count = tf.maximum(bucket_count, 1.)
    bucket_accuracy = bucket_total_hit / safe_count
    bucket_confidence = tf.math.exp(
        bucket_log_total_pred_prob - tf.math.log(safe_count))
    bucket_error = abs(bucket_accuracy - bucket_confidence)

    num_observations = ps.cast(ps.shape(pred_log_prob)[axis], dtype)
    weighted_error = tf.math.reduce_sum(bucket_count * bucket_error, axis=0)
    ece = weighted_error / num_observations

    return (
        ece,
        bucket_accuracy,
        bucket_confidence,
        bucket_count,
        bucket_pred_log_prob,
        bucket,
    )
def _update_confusion_matrix_variables_optimized(
    variables_to_update,
    y_true,
    y_pred,
    thresholds,
    multi_label=False,
    sample_weights=None,
    label_weights=None,
    thresholds_with_epsilon=False,
):
    """Update confusion matrix variables with a memory-efficient alternative.

    Note that the thresholds need to be evenly distributed within the list,
    i.e. the difference between consecutive elements is constant.

    To compute TP/FP/TN/FN, we are measuring a binary classifier
      C(t) = (predictions >= t)
    at each threshold 't'. So we have
      TP(t) = sum( C(t) * true_labels )
      FP(t) = sum( C(t) * false_labels )

    But computing C(t) requires computation for each t. To make it fast,
    observe that C(t) is a cumulative integral, and so if we have
      thresholds = [t_0, ..., t_{n-1}];  t_0 < ... < t_{n-1}
    where n = num_thresholds, and if we can compute the bucket function
      B(i) = Sum( (predictions == t), t_i <= t < t{i+1} )
    then we get
      C(t_i) = sum( B(j), j >= i )
    which is the reversed cumulative sum in tf.cumsum().

    We can compute B(i) efficiently by taking advantage of the fact that
    our thresholds are evenly distributed, in that
      width = 1.0 / (num_thresholds - 1)
      thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
    Given a prediction value p, we can map it to its bucket by
      bucket_index(p) = floor( p * (num_thresholds - 1) )
    so we can use tf.math.unsorted_segment_sum() to update the buckets in one
    pass.

    Consider the following example:
    y_true = [0, 0, 1, 1]
    y_pred = [0.1, 0.5, 0.3, 0.9]
    thresholds = [0.0, 0.5, 1.0]
    num_buckets = 2   # [0.0, 1.0], (1.0, 2.0]
    bucket_index(y_pred) = tf.math.floor(y_pred * num_buckets)
                         = tf.math.floor([0.2, 1.0, 0.6, 1.8])
                         = [0, 0, 0, 1]
    # The meaning of this bucket is that if any of the labels is true,
    # then 1 will be added to the corresponding bucket with the index.
    # E.g., if the label for 0.2 is true, then 1 will be added to bucket 0.
    # If the label for 1.8 is true, then 1 will be added to bucket 1.
    #
    # Note the second item "1.0" is floored to 0, since the value needs to be
    # strictly larger than the bucket lower bound.
    # In the implementation, we use tf.math.ceil() - 1 to achieve this.
    tp_bucket_value = tf.math.unsorted_segment_sum(true_labels, bucket_indices,
                                                   num_segments=num_thresholds)
                    = [1, 1, 0]
    # For [1, 1, 0] here, it means there is 1 true value contributed by
    # bucket 0, and 1 value contributed by bucket 1. When we aggregate them
    # together, the result becomes [a + b + c, b + c, c], since large
    # thresholds will always contribute to the value for smaller thresholds.
    true_positive = tf.math.cumsum(tp_bucket_value, reverse=True)
                  = [2, 1, 0]

    This implementation exhibits a run time and space complexity of O(T + N),
    where T is the number of thresholds and N is the size of predictions.
    Metrics that rely on the standard implementation instead exhibit a
    complexity of O(T * N).

    Args:
      variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid
        keys and corresponding variables to update as values.
      y_true: A floating point `Tensor` whose shape matches `y_pred`. Will be
        cast to `bool`.
      y_pred: A floating point `Tensor` of arbitrary shape and whose values
        are in the range `[0, 1]`.
      thresholds: A sorted floating point `Tensor` with values in `[0, 1]`.
        It needs to be evenly distributed (the difference between each pair
        of consecutive elements needs to be the same).
      multi_label: Optional boolean indicating whether multidimensional
        prediction/labels should be treated as multilabel responses, or
        flattened into a single label. When True, the values of
        `variables_to_update` must have a second dimension equal to the
        number of labels in y_true and y_pred, and those tensors must not be
        RaggedTensors.
      sample_weights: Optional `Tensor` whose rank is either 0, or the same
        rank as `y_true`, and must be broadcastable to `y_true` (i.e., all
        dimensions must be either `1`, or the same as the corresponding
        `y_true` dimension).
      label_weights: Optional tensor of non-negative weights for multilabel
        data. The weights are applied when calculating TP, FP, FN, and TN
        without explicit multilabel handling (i.e. when the data is to be
        flattened).
      thresholds_with_epsilon: Optional boolean indicating whether the
        leading and trailing thresholds have any epsilon added for floating
        point imprecision. It changes how the leading and trailing buckets
        are handled.

    Returns:
      Update op.
    """
    num_thresholds = thresholds.shape.as_list()[0]
    flatten = not multi_label

    # Resolve per-sample weights; a missing weight is the scalar 1.0.
    if sample_weights is None:
        sample_weights = 1.0
    else:
        sample_weights = tf.__internal__.ops.broadcast_weights(
            tf.cast(sample_weights, dtype=y_pred.dtype), y_pred
        )
        if flatten:
            sample_weights = tf.reshape(sample_weights, [-1])

    # Resolve per-label weights the same way.
    if label_weights is None:
        label_weights = 1.0
    else:
        label_weights = tf.expand_dims(label_weights, 0)
        label_weights = tf.__internal__.ops.broadcast_weights(
            label_weights, y_pred
        )
        if flatten:
            label_weights = tf.reshape(label_weights, [-1])
    weights = tf.multiply(sample_weights, label_weights)

    # This should not be needed, but guards against prediction values that
    # fall outside the range [0.0, 1.0].
    y_pred = tf.clip_by_value(y_pred, clip_value_min=0.0, clip_value_max=1.0)

    # Round-trip through bool so every nonzero label becomes exactly 1.
    label_dtype = y_true.dtype
    y_true = tf.cast(tf.cast(y_true, tf.bool), label_dtype)
    if flatten:
        y_true = tf.reshape(y_true, [-1])
        y_pred = tf.reshape(y_pred, [-1])

    true_labels = tf.multiply(y_true, weights)
    false_labels = tf.multiply((1.0 - y_true), weights)

    # Compute the bucket index of each prediction value.
    # A prediction has to be strictly greater than a threshold to count for
    # it, so the buckets look like [0, 0.5], (0.5, 1] and 0.5 belongs to the
    # first bucket. Hence math.ceil(val) - 1 rather than floor.
    scaled_pred = y_pred * (num_thresholds - 1)
    bucket_indices = tf.math.ceil(scaled_pred) - 1

    if thresholds_with_epsilon:
        # In this case the first bucket should actually be taken into
        # account, since any prediction in [0.0, 1.0] should be larger than
        # the first threshold. Change the bucket value from -1 to 0.
        bucket_indices = tf.nn.relu(bucket_indices)

    bucket_indices = tf.cast(bucket_indices, tf.int32)

    if multi_label:
        # The bucket segment sum has to run once per label class. In the
        # multi_label case the labels have rank 2, so transpose first to put
        # the label dimension in front and map over it.
        true_labels = tf.transpose(true_labels)
        false_labels = tf.transpose(false_labels)
        bucket_indices = tf.transpose(bucket_indices)

        def _segment_sum_for_label(labels_and_indices):
            """Segment-sum one label's weights into its threshold buckets."""
            label_values, label_bucket_ids = labels_and_indices
            return tf.math.unsorted_segment_sum(
                data=label_values,
                segment_ids=label_bucket_ids,
                num_segments=num_thresholds,
            )

        tp_bucket_v = tf.vectorized_map(
            _segment_sum_for_label, (true_labels, bucket_indices)
        )
        fp_bucket_v = tf.vectorized_map(
            _segment_sum_for_label, (false_labels, bucket_indices)
        )
        tp = tf.transpose(tf.cumsum(tp_bucket_v, reverse=True, axis=1))
        fp = tf.transpose(tf.cumsum(fp_bucket_v, reverse=True, axis=1))
    else:
        tp_bucket_v = tf.math.unsorted_segment_sum(
            data=true_labels,
            segment_ids=bucket_indices,
            num_segments=num_thresholds,
        )
        fp_bucket_v = tf.math.unsorted_segment_sum(
            data=false_labels,
            segment_ids=bucket_indices,
            num_segments=num_thresholds,
        )
        tp = tf.cumsum(tp_bucket_v, reverse=True)
        fp = tf.cumsum(fp_bucket_v, reverse=True)

    # fn = sum(true_labels) - tp
    # tn = sum(false_labels) - fp
    # The totals are only needed when a negative count is requested.
    wants_true_negatives = ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
    wants_false_negatives = (
        ConfusionMatrix.FALSE_NEGATIVES in variables_to_update
    )
    if wants_true_negatives or wants_false_negatives:
        if multi_label:
            total_true_labels = tf.reduce_sum(true_labels, axis=1)
            total_false_labels = tf.reduce_sum(false_labels, axis=1)
        else:
            total_true_labels = tf.reduce_sum(true_labels)
            total_false_labels = tf.reduce_sum(false_labels)

    update_ops = []
    if ConfusionMatrix.TRUE_POSITIVES in variables_to_update:
        tp_variable = variables_to_update[ConfusionMatrix.TRUE_POSITIVES]
        update_ops.append(tp_variable.assign_add(tp))
    if ConfusionMatrix.FALSE_POSITIVES in variables_to_update:
        fp_variable = variables_to_update[ConfusionMatrix.FALSE_POSITIVES]
        update_ops.append(fp_variable.assign_add(fp))
    if wants_true_negatives:
        tn_variable = variables_to_update[ConfusionMatrix.TRUE_NEGATIVES]
        update_ops.append(tn_variable.assign_add(total_false_labels - fp))
    if wants_false_negatives:
        fn_variable = variables_to_update[ConfusionMatrix.FALSE_NEGATIVES]
        update_ops.append(fn_variable.assign_add(total_true_labels - tp))
    return tf.group(update_ops)