def _decode(self, image_buffer, image_format):
  """Decodes an encoded image buffer into an image tensor.

  Args:
    image_buffer: The tensor holding the encoded image.
    image_format: The format of the image in `image_buffer`. When it equals
      `raw` or `RAW` the buffer is decoded with `decode_raw`; any other value
      falls through to `decode_image`, which picks the decoder from the image
      headers (png or jpg).

  Returns:
    The decoded image tensor, reshaped to `self._shape` when that is set,
    otherwise with static shape (?, ?, self._channels).
  """

  def _from_headers():
    """Decodes a png or jpg, chosen from the buffer's headers."""
    return image_ops.decode_image(image_buffer, self._channels)

  def _from_raw():
    """Reinterprets the buffer's bytes as `self._dtype` values."""
    return parsing_ops.decode_raw(image_buffer, out_type=self._dtype)

  is_raw = math_ops.logical_or(
      math_ops.equal(image_format, 'raw'),
      math_ops.equal(image_format, 'RAW'))

  decoded = control_flow_ops.case(
      {is_raw: _from_raw}, default=_from_headers, exclusive=True)

  decoded.set_shape([None, None, self._channels])
  if self._shape is None:
    return decoded
  return array_ops.reshape(decoded, self._shape)
def rot90(image, k=1, name=None):
  """Rotates an image counter-clockwise in multiples of 90 degrees.

  Args:
    image: A 3-D tensor of shape `[height, width, channels]`.
    k: A scalar integer; how many 90-degree rotations to apply. It is reduced
      modulo 4 before use.
    name: Optional name for this operation.

  Returns:
    A rotated 3-D tensor of the same type as `image`. Its static height and
    width are left unknown; the channel dimension is copied from `image`.
  """
  with ops.name_scope(name, 'rot90', [image, k]) as scope:
    image = ops.convert_to_tensor(image, name='image')
    _Check3DImage(image, require_static=False)
    k = ops.convert_to_tensor(k, dtype=dtypes.int32, name='k')
    k.get_shape().assert_has_rank(0)
    k = math_ops.mod(k, 4)

    def _quarter_turn():
      flipped = array_ops.reverse_v2(image, [1])
      return array_ops.transpose(flipped, [1, 0, 2])

    def _half_turn():
      return array_ops.reverse_v2(image, [0, 1])

    def _three_quarter_turn():
      swapped = array_ops.transpose(image, [1, 0, 2])
      return array_ops.reverse_v2(swapped, [1])

    def _unchanged():
      return image

    # Predicates are created in the order k == 1, k == 2, k == 3.
    turn_fns = ((1, _quarter_turn), (2, _half_turn), (3, _three_quarter_turn))
    branches = [(math_ops.equal(k, turns), fn) for turns, fn in turn_fns]

    rotated = control_flow_ops.case(
        branches, default=_unchanged, exclusive=True, name=scope)
    channels = image.get_shape()[2]
    rotated.set_shape([None, None, channels])
    return rotated
def _decode(self, image_buffer, image_format):
  """Decodes an encoded image buffer according to its declared format.

  Args:
    image_buffer: The tensor holding the encoded image.
    image_format: The format of the image in `image_buffer`. `png`/`PNG`
      selects the PNG decoder, `raw`/`RAW` selects `decode_raw` (as uint8),
      and anything else is decoded as JPEG.

  Returns:
    The decoded image tensor, reshaped to `self._shape` when that is set,
    otherwise with static shape (?, ?, self._channels).
  """

  def _png():
    return image_ops.decode_png(image_buffer, self._channels)

  def _raw():
    return parsing_ops.decode_raw(image_buffer, dtypes.uint8)

  def _jpeg():
    return image_ops.decode_jpeg(image_buffer, self._channels)

  # The png predicate is built before the raw predicate.
  is_png = math_ops.logical_or(
      math_ops.equal(image_format, 'png'),
      math_ops.equal(image_format, 'PNG'))
  is_raw = math_ops.logical_or(
      math_ops.equal(image_format, 'raw'),
      math_ops.equal(image_format, 'RAW'))

  decoded = control_flow_ops.case(
      {is_png: _png, is_raw: _raw}, default=_jpeg, exclusive=True)

  decoded.set_shape([None, None, self._channels])
  if self._shape is None:
    return decoded
  return array_ops.reshape(decoded, self._shape)
def _testReturnValues(self, fn_true, fn_false, expected_value_true,
                      expected_value_false, strict=False, check_cond=True,
                      feed_dict=None):
  """Runs `cond` and `case` on both branches and checks their results.

  Args:
    fn_true: Callable used when the condition is True.
    fn_false: Callable used when the condition is False (the `case` default).
    expected_value_true: Expected result when the condition is fed True.
    expected_value_false: Expected result when the condition is fed False.
    strict: Forwarded as `strict` to both `cond` and `case`.
    check_cond: When False, the `case` result is not compared.
    feed_dict: Optional extra feeds merged into every `sess.run` call.
  """
  extra_feeds = {} if feed_dict is None else feed_dict

  condition = array_ops.placeholder(dtypes.bool)
  output_cond = control_flow_ops.cond(
      condition, fn_true, fn_false, strict=strict)
  output_case = control_flow_ops.case(
      [(condition, fn_true)], fn_false, strict=strict)

  with self.test_session() as sess:
    variables.global_variables_initializer().run()

    # The True branch is exercised first, then the False branch.
    scenarios = ((True, expected_value_true), (False, expected_value_false))
    for flag, expected in scenarios:
      feeds = {condition: flag}
      feeds.update(extra_feeds)
      result_cond, result_case = sess.run(
          [output_cond, output_case], feed_dict=feeds)
      self.assertAllEqualNested(result_cond, expected)
      if check_cond:
        self.assertAllEqualNested(result_case, expected)
def test_inv_update_thunks(self):
  """Ensures inverse update ops run once per global_step."""
  with self._graph.as_default(), self.test_session() as sess:
    fisher_estimator = estimator.FisherEstimator(
        damping_fn=lambda: 0.2,
        variables=[self.weights],
        layer_collection=self.layer_collection,
        cov_ema_decay=0.0)

    # Build a single op that dispatches on global_step: step i runs thunk i.
    global_step = training_util.get_or_create_global_step()
    inv_matrices = []
    for factor in self.layer_collection.get_factors():
      inv_matrices.extend(factor._inverses_by_damping.values())

    inv_update_op_thunks = fisher_estimator.inv_update_thunks
    step_to_thunk = []
    for index, thunk in enumerate(inv_update_op_thunks):
      step_to_thunk.append((math_ops.equal(global_step, index), thunk))
    inv_update_op = control_flow_ops.case(step_to_thunk)
    increment_global_step = global_step.assign_add(1)

    sess.run(variables.global_variables_initializer())
    initial_inv_values = sess.run(inv_matrices)

    # There should be one update per inverse matrix. This holds as long as
    # there's no fan-in/fan-out or parameter re-use.
    self.assertEqual(len(inv_matrices), len(inv_update_op_thunks))

    # The test would be a no-op with only one inverse matrix.
    assert len(inv_matrices) > 1

    # Give every covariance matrix a value other than the identity so that
    # the inverses are updated to something different as well.
    cov_matrices = [
        factor.get_cov() for factor in self.layer_collection.get_factors()
    ]
    reset_ops = []
    for cov_matrix in cov_matrices:
      dim = int(cov_matrix.shape[0])
      reset_ops.append(cov_matrix.assign(2 * linalg_ops.eye(dim)))
    sess.run(reset_ops)

    total = len(inv_matrices)
    for step in range(total):
      # Count how many inverses still hold their initial value.
      new_inv_values = sess.run(inv_matrices)
      num_inv_equal = sum(
          np.allclose(before, after)
          for before, after in zip(initial_inv_values, new_inv_values))

      # Exactly one inverse matrix should have changed per completed step.
      self.assertEqual(num_inv_equal, total - step)

      # Run the update selected by the current global_step, then advance it.
      sess.run(inv_update_op)
      sess.run(increment_global_step)
def testCase_dict(self):
  """Checks that `case` accepts a dict mapping predicates to callables."""
  x = constant_op.constant(2)

  is_one = math_ops.equal(x, 1)
  is_two = math_ops.equal(x, 2)
  branches = {}
  branches[is_one] = lambda: constant_op.constant(2)
  branches[is_two] = lambda: constant_op.constant(4)

  result = control_flow_ops.case(branches, exclusive=True)
  self.assertEqual(4, self.evaluate(result))
def testCase_withoutDefault_oneCondition(self):
  """Checks a one-branch `case` with no default on a hit and on a miss."""
  x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
  only_branch = (math_ops.equal(x, 1), lambda: constant_op.constant(2))
  output = control_flow_ops.case([only_branch], exclusive=True)

  with self.test_session() as sess:
    matched = sess.run(output, feed_dict={x: 1})
    self.assertEqual(matched, 2)

    # No predicate is true for x == 4 and there is no default to fall back on.
    with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"):
      sess.run(output, feed_dict={x: 4})
def piecewise_constant(x, boundaries, values, name=None):
  """Piecewise constant from boundaries and interval values.

  Example: use a learning rate that's 1.0 for the first 100000 steps, 0.5
  for steps 100001 to 110000, and 0.1 for any additional steps.

  ```python
  global_step = tf.Variable(0, trainable=False)
  boundaries = [100000, 110000]
  values = [1.0, 0.5, 0.1]
  learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)

  # Later, whenever we perform an optimization step, we increment global_step.
  ```

  Args:
    x: A 0-D scalar `Tensor`. Must be one of the following types: `float32`,
      `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`.
    boundaries: A non-empty list of `Tensor`s or `int`s or `float`s with
      strictly increasing entries, and with all elements having the same type
      as `x`.
    values: A list of `Tensor`s or `float`s or `int`s that specifies the values
      for the intervals defined by `boundaries`. It must have exactly one more
      element than `boundaries`, and all elements should have the same type.
    name: A string. Optional name of the operation. Defaults to
      'PiecewiseConstant'.

  Returns:
    A 0-D Tensor. Its value is `values[0]` when `x <= boundaries[0]`,
    `values[1]` when `x > boundaries[0]` and `x <= boundaries[1]`, ...,
    and values[-1] when `x > boundaries[-1]`.

  Raises:
    ValueError: if `boundaries` is empty, if an element of `boundaries` does
      not have the same dtype as `x`, if the elements of `values` do not all
      share one dtype, or if `values` does not have exactly one more element
      than `boundaries`.
  """
  with ops.name_scope(name, 'PiecewiseConstant',
                      [x, boundaries, values, name]) as name:
    x = ops.convert_to_tensor(x)
    # Avoid explicit conversion to x's dtype. This could result in faulty
    # comparisons, for example if floats are converted to integers.
    boundaries = ops.convert_n_to_tensor(boundaries)
    if not boundaries:
      # Without this check the first predicate below fails with a bare
      # IndexError on boundaries[0].
      raise ValueError('boundaries must contain at least one element.')
    if not all(b.dtype == x.dtype for b in boundaries):
      raise ValueError('boundaries must have the same dtype as x.')

    # TODO(rdipietro): Ensure that boundaries' elements are strictly increasing.
    values = ops.convert_n_to_tensor(values)
    if not all(v.dtype == values[0].dtype for v in values):
      raise ValueError('values must have elements all with the same dtype.')
    if len(values) != len(boundaries) + 1:
      # The zip below would otherwise silently drop intervals and build a
      # schedule that does not match the caller's request.
      raise ValueError(
          'values must have exactly one more element than boundaries '
          '(got %d values and %d boundaries).' %
          (len(values), len(boundaries)))

    # Use a list of (predicate, fn) pairs rather than a dict keyed by tensors:
    # it does not require tensors to be hashable and keeps a defined order.
    pred_fn_pairs = [
        (x <= boundaries[0], lambda: values[0]),
        (x > boundaries[-1], lambda: values[-1]),
    ]
    for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]):
      # Need to bind v here; can do this with lambda v=v: ...
      pred = (x > low) & (x <= high)
      pred_fn_pairs.append((pred, lambda v=v: v))

    # The default isn't needed here because our conditions are mutually
    # exclusive and exhaustive, but tf.case requires it.
    default = lambda: values[0]
    return control_flow_ops.case(pred_fn_pairs, default, exclusive=True)
def testCase_withDefault(self):
  """Checks that `case` picks the matching branch or falls to the default."""
  x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
  branches = [
      (math_ops.equal(x, 1), lambda: constant_op.constant(2)),
      (math_ops.equal(x, 2), lambda: constant_op.constant(4)),
  ]

  def _fallback():
    return constant_op.constant(6)

  output = control_flow_ops.case(branches, _fallback, exclusive=True)

  with self.test_session() as sess:
    # (fed value, expected output); 3 matches no branch, so the default runs.
    for fed, expected in ((1, 2), (2, 4), (3, 6)):
      self.assertEqual(sess.run(output, feed_dict={x: fed}), expected)
def _decode(self, image_buffer, image_format):
  """Decodes an encoded image buffer and casts it to `self._dtype`.

  Args:
    image_buffer: The tensor holding the encoded image.
    image_format: The format of the image in `image_buffer`. `raw`/`RAW`
      selects `decode_raw`; any other value is routed through a jpeg check,
      so a mix of `jpg` and `png` inputs can be handled.

  Returns:
    The decoded image tensor, reshaped to `self._shape` when that is set,
    otherwise with static shape (?, ?, self._channels).
  """

  def _as_target_dtype(decoded):
    return math_ops.cast(decoded, self._dtype)

  def _generic():
    """Decodes an image based on its headers."""
    return _as_target_dtype(
        image_ops.decode_image(image_buffer, channels=self._channels))

  def _jpeg():
    """Decodes a jpeg image using the configured `_dct_method`."""
    return _as_target_dtype(
        image_ops.decode_jpeg(
            image_buffer,
            channels=self._channels,
            dct_method=self._dct_method))

  def _jpeg_or_generic():
    """Chooses between the jpeg decoder and the generic decoder."""
    # jpeg goes through decode_jpeg directly, rather than decode_image, so
    # that the jpeg-specific 'dct_method' parameter can be passed.
    looks_like_jpeg = image_ops.is_jpeg(image_buffer)
    return control_flow_ops.cond(
        looks_like_jpeg, _jpeg, _generic, name='cond_jpeg')

  def _raw():
    """Decodes a raw image."""
    return parsing_ops.decode_raw(image_buffer, out_type=self._dtype)

  is_raw = math_ops.logical_or(
      math_ops.equal(image_format, 'raw'),
      math_ops.equal(image_format, 'RAW'))

  decoded = control_flow_ops.case(
      {is_raw: _raw}, default=_jpeg_or_generic, exclusive=True)

  decoded.set_shape([None, None, self._channels])
  if self._shape is None:
    return decoded
  return array_ops.reshape(decoded, self._shape)
def test_singleton_strict(self):
  """Checks that strict `cond`/`case` reject mismatched branch structures."""

  def _bare_tensor():
    return constant_op.constant(1)

  def _in_list():
    return [constant_op.constant(2)]

  def _in_tuple():
    return (constant_op.constant(3),)

  # (expected error, true/first branch, false/default branch)
  mismatches = (
      (ValueError, _bare_tensor, _in_list),
      (TypeError, _in_list, _in_tuple),
  )

  for expected_error, first_fn, second_fn in mismatches:
    with self.assertRaises(expected_error):
      control_flow_ops.cond(
          constant_op.constant(True), first_fn, second_fn, strict=True)

  for expected_error, first_fn, second_fn in mismatches:
    with self.assertRaises(expected_error):
      control_flow_ops.case(
          [(constant_op.constant(True), first_fn)], second_fn, strict=True)
def _testShape(self, fn_true, fn_false, expected_shape, strict=False):
  """Checks the nested static shape produced by `cond` and by `case`.

  Args:
    fn_true: Callable used when the condition is True.
    fn_false: Callable used when the condition is False (the `case` default).
    expected_shape: The nested shape both outputs are expected to have.
    strict: Forwarded as `strict` to both `cond` and `case`.
  """
  condition = array_ops.placeholder(dtypes.bool)

  output_cond = control_flow_ops.cond(
      condition, fn_true, fn_false, strict=strict)
  cond_shape = _RawNestedShape(_GetNestedShape(output_cond))
  self.assertEqual(cond_shape, _RawNestedShape(expected_shape))

  output_case = control_flow_ops.case(
      [(condition, fn_true)], fn_false, strict=strict)
  case_shape = _RawNestedShape(_GetNestedShape(output_case))
  self.assertEqual(case_shape, _RawNestedShape(expected_shape))
def testCase_multiple_matches_exclusive(self):
  """Checks that an exclusive `case` fails when two predicates are true."""
  x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
  # The last two predicates are deliberately identical.
  branches = [
      (math_ops.equal(x, 1), lambda: constant_op.constant(2)),
      (math_ops.equal(x, 2), lambda: constant_op.constant(4)),
      (math_ops.equal(x, 2), lambda: constant_op.constant(6)),
  ]

  def _fallback():
    return constant_op.constant(8)

  output = control_flow_ops.case(branches, _fallback, exclusive=True)

  with self.test_session() as sess:
    for fed, expected in ((1, 2), (3, 8)):
      self.assertEqual(sess.run(output, feed_dict={x: fed}), expected)

    # x == 2 satisfies both duplicate predicates.
    with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"):
      sess.run(output, feed_dict={x: 2})
def _decode(self, image_buffer, image_format):
  """Decodes an encoded image buffer according to its declared format.

  Args:
    image_buffer: The tensor holding the encoded image.
    image_format: The format of the image in `image_buffer`.

  Returns:
    The decoded image tensor, reshaped to `self._shape` when that is set,
    otherwise with static shape (?, ?, self._channels).
  """

  def _png():
    return image_ops.decode_png(
        image_buffer, self._channels, dtype=self._dtype)

  def _raw():
    return parsing_ops.decode_raw(image_buffer, out_type=self._dtype)

  def _jpeg():
    # The check runs when this branch function is invoked, not up front.
    if self._dtype != dtypes.uint8:
      raise ValueError(
          'jpeg decoder can only be used to decode to tf.uint8 but %s was '
          'requested for a jpeg image.' % self._dtype)
    return image_ops.decode_jpeg(image_buffer, self._channels)

  pred_fn_pairs = {}
  if self._channels > 3:
    # For RGBA images JPEG is not a valid decoder option, so PNG is the
    # fallback and only the raw format needs a predicate.
    default_decoder = _png
  else:
    is_png = math_ops.logical_or(
        math_ops.equal(image_format, 'png'),
        math_ops.equal(image_format, 'PNG'))
    pred_fn_pairs[is_png] = _png
    default_decoder = _jpeg

  # The raw predicate is created after the png predicate, when there is one.
  is_raw = math_ops.logical_or(
      math_ops.equal(image_format, 'raw'),
      math_ops.equal(image_format, 'RAW'))
  pred_fn_pairs[is_raw] = _raw

  decoded = control_flow_ops.case(
      pred_fn_pairs, default=default_decoder, exclusive=True)

  decoded.set_shape([None, None, self._channels])
  if self._shape is None:
    return decoded
  return array_ops.reshape(decoded, self._shape)
def testCase_multiple_matches_exclusive(self):
  """Checks that an exclusive `case` fails when two predicates are true."""
  x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
  # The last two predicates are deliberately identical.
  branches = [
      (math_ops.equal(x, 1), lambda: constant_op.constant(2)),
      (math_ops.equal(x, 2), lambda: constant_op.constant(4)),
      (math_ops.equal(x, 2), lambda: constant_op.constant(6)),
  ]

  def _fallback():
    return constant_op.constant(8)

  output = control_flow_ops.case(branches, _fallback, exclusive=True)

  with self.test_session() as sess:
    for fed, expected in ((1, 2), (3, 8)):
      self.assertEqual(sess.run(output, feed_dict={x: fed}), expected)

    # x == 2 satisfies both duplicate predicates.
    with self.assertRaisesRegexp(
        errors.InvalidArgumentError,
        "More than one condition evaluated as True"):
      sess.run(output, feed_dict={x: 2})
def test_cov_update_thunks(self):
  """Ensures covariance update ops run once per global_step."""
  with self._graph.as_default(), self.test_session() as sess:
    fisher_estimator = estimator.FisherEstimatorRoundRobin(
        variables=[self.weights],
        layer_collection=self.layer_collection,
        damping=0.2,
        cov_ema_decay=0.0)

    # Build a single op that dispatches on global_step: step i runs thunk i.
    global_step = training_util.get_or_create_global_step()
    cov_variable_thunks, cov_update_op_thunks, _, _ = (
        fisher_estimator.create_ops_and_vars_thunks())
    for make_variables in cov_variable_thunks:
      make_variables()

    cov_matrices = [
        factor.get_cov() for factor in self.layer_collection.get_factors()
    ]
    step_to_thunk = []
    for index, thunk in enumerate(cov_update_op_thunks):
      step_to_thunk.append((math_ops.equal(global_step, index), thunk))
    cov_update_op = control_flow_ops.case(step_to_thunk)
    increment_global_step = global_step.assign_add(1)

    sess.run(variables.global_variables_initializer())
    initial_cov_values = sess.run(cov_matrices)

    # There should be one update per covariance matrix.
    self.assertEqual(len(cov_matrices), len(cov_update_op_thunks))

    # The test would be a no-op with only one covariance matrix.
    assert len(cov_matrices) > 1

    total = len(cov_matrices)
    for step in range(total):
      # Count how many covariances still hold their initial value.
      new_cov_values = sess.run(cov_matrices)
      num_cov_equal = sum(
          np.allclose(before, after)
          for before, after in zip(initial_cov_values, new_cov_values))

      # Exactly one covariance matrix should have changed per completed step.
      self.assertEqual(num_cov_equal, total - step)

      # Run the update selected by the current global_step, then advance it.
      sess.run(cov_update_op)
      sess.run(increment_global_step)
def _decode(self, image_buffer, image_format):
  """Decodes an encoded image buffer into an image tensor.

  Args:
    image_buffer: The tensor holding the encoded image.
    image_format: The format of the image in `image_buffer`. `raw`/`RAW`
      selects `decode_raw`; any other value is routed through a jpeg check,
      so a mix of `jpg` and `png` inputs can be handled.

  Returns:
    The decoded image tensor, reshaped to `self._shape` when that is set,
    otherwise with static shape (?, ?, self._channels).
  """

  def _generic():
    """Decodes an image based on its headers."""
    return image_ops.decode_image(image_buffer, channels=self._channels)

  def _jpeg():
    """Decodes a jpeg image using the configured `_dct_method`."""
    return image_ops.decode_jpeg(
        image_buffer,
        channels=self._channels,
        dct_method=self._dct_method)

  def _jpeg_or_generic():
    """Chooses between the jpeg decoder and the generic decoder."""
    # jpeg goes through decode_jpeg directly, rather than decode_image, so
    # that the jpeg-specific 'dct_method' parameter can be passed.
    looks_like_jpeg = image_ops.is_jpeg(image_buffer)
    return control_flow_ops.cond(
        looks_like_jpeg, _jpeg, _generic, name='cond_jpeg')

  def _raw():
    """Decodes a raw image."""
    return parsing_ops.decode_raw(image_buffer, out_type=self._dtype)

  is_raw = math_ops.logical_or(
      math_ops.equal(image_format, 'raw'),
      math_ops.equal(image_format, 'RAW'))

  decoded = control_flow_ops.case(
      {is_raw: _raw}, default=_jpeg_or_generic, exclusive=True)

  decoded.set_shape([None, None, self._channels])
  if self._shape is None:
    return decoded
  return array_ops.reshape(decoded, self._shape)
def testConvertCase(self):
  """Tests that a v1 case() construction converts properly.

  Builds a graph with two non-resource variables and a `case` over them with
  control flow v2 disabled, converts the variables to constants, and checks
  that the resulting graph contains the expected Const/Identity/Switch/Merge
  nodes.
  """
  with ops.Graph().as_default():
    with variable_scope.variable_scope("", use_resource=False):
      # v1 control flow is requested explicitly; the expected graph below
      # names Switch/Merge nodes.
      control_flow_v2_toggles.disable_control_flow_v2()
      x = variable_scope.get_variable("x", initializer=1.0)
      y = variable_scope.get_variable("y", initializer=2.0)
      # Only the graph structure matters here, so the result is discarded.
      _ = control_flow_ops.case([(gen_math_ops.less(x, y), lambda: x)],
                                default=lambda: y)
    with session_lib.Session() as sess:
      sess.run(variables.global_variables_initializer())
      variable_graph_def = sess.graph.as_graph_def()
      # NOTE(review): the last argument is presumably the list of output node
      # names to keep -- confirm against convert_to_constants.
      constant_graph_def = (
          convert_to_constants
          .convert_variables_to_constants_from_session_graph(
              sess, variable_graph_def, ["case/cond/Merge"]))
      # Expected nodes, written as a text-format graph description.
      self._assertGraphContains(
          constant_graph_def,
          """ node { name: "x" op: "Const" attr { key: "dtype" value { type: DT_FLOAT } } attr { key: "value" value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}} } node { name: "y" op: "Const" attr { key: "dtype" value { type: DT_FLOAT } } attr { key: "value" value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 2 }}} } node {name: "x/read" op: "Identity" input: "x"} node {name: "y/read" op: "Identity" input: "y"} node {name: "Less" op: "Less" input: "x/read" input: "y/read"} node {name: "case/cond/pred_id" op: "Identity" input: "Less"} node { name: "case/cond/Switch_1" op: "Switch" input: "case/cond/pred_id" input: "x/read" } node { name: "case/cond/Switch_2" op: "Switch" input: "case/cond/pred_id" input: "y/read" } node { name: "case/cond/Merge" op: "Merge" input: "case/cond/Switch_2" input: "case/cond/Switch_1:1" attr {key: "T" value {type: DT_FLOAT}} }""")
def test_cov_update_thunks(self):
  """Ensures covariance update ops run once per global_step."""
  with self._graph.as_default(), self.test_session() as sess:
    fisher_estimator = estimator.FisherEstimatorRoundRobin(
        variables=[self.weights],
        layer_collection=self.layer_collection,
        damping=0.2,
        cov_ema_decay=0.0)

    # Build a single op that dispatches on global_step: step i runs thunk i.
    global_step = training_util.get_or_create_global_step()
    cov_variable_thunks, cov_update_op_thunks, _, _ = (
        fisher_estimator.create_ops_and_vars_thunks())
    for make_variables in cov_variable_thunks:
      make_variables()

    cov_matrices = [
        factor.get_cov() for factor in self.layer_collection.get_factors()
    ]
    step_to_thunk = []
    for index, thunk in enumerate(cov_update_op_thunks):
      step_to_thunk.append((math_ops.equal(global_step, index), thunk))
    cov_update_op = control_flow_ops.case(step_to_thunk)
    increment_global_step = global_step.assign_add(1)

    sess.run(variables.global_variables_initializer())
    initial_cov_values = sess.run(cov_matrices)

    # There should be one update per covariance matrix.
    self.assertEqual(len(cov_matrices), len(cov_update_op_thunks))

    # The test would be a no-op with only one covariance matrix.
    assert len(cov_matrices) > 1

    total = len(cov_matrices)
    for step in range(total):
      # Count how many covariances still hold their initial value.
      new_cov_values = sess.run(cov_matrices)
      num_cov_equal = sum(
          np.allclose(before, after)
          for before, after in zip(initial_cov_values, new_cov_values))

      # Exactly one covariance matrix should have changed per completed step.
      self.assertEqual(num_cov_equal, total - step)

      # Run the update selected by the current global_step, then advance it.
      sess.run(cov_update_op)
      sess.run(increment_global_step)
def control_map_fn(x, y):
  """Halves `x` when `y` is 2 or 3, and doubles it otherwise.

  Args:
    x: The value to transform.
    y: The selector compared against 2 and 3.

  Returns:
    The result of `case`: `x // 2` if `y` equals 2 or 3, else `x * 2`.
  """

  def _double():
    return x * 2

  def _halve():
    return x // 2

  y_is_2_or_3 = math_ops.logical_or(
      math_ops.equal(y, 2), math_ops.equal(y, 3))
  return control_flow_ops.case(
      {y_is_2_or_3: _halve}, default=_double, exclusive=True)
def _decode(self, image_buffer, image_format):
  """Decodes an encoded image buffer according to its declared format.

  Args:
    image_buffer: The tensor holding the encoded image.
    image_format: The format of the image in `image_buffer`.

  Returns:
    The decoded image tensor, reshaped to `self._shape` when that is set,
    otherwise with static shape (?, ?, self._channels).
  """

  def _png():
    return image_ops.decode_png(image_buffer, self._channels)

  def _raw():
    return parsing_ops.decode_raw(image_buffer, dtypes.uint8)

  def _jpeg():
    return image_ops.decode_jpeg(image_buffer, self._channels)

  pred_fn_pairs = {}
  if self._channels > 3:
    # For RGBA images JPEG is not a valid decoder option, so PNG is the
    # fallback and only the raw format needs a predicate.
    default_decoder = _png
  else:
    is_png = math_ops.logical_or(
        math_ops.equal(image_format, 'png'),
        math_ops.equal(image_format, 'PNG'))
    pred_fn_pairs[is_png] = _png
    default_decoder = _jpeg

  # The raw predicate is created after the png predicate, when there is one.
  is_raw = math_ops.logical_or(
      math_ops.equal(image_format, 'raw'),
      math_ops.equal(image_format, 'RAW'))
  pred_fn_pairs[is_raw] = _raw

  decoded = control_flow_ops.case(
      pred_fn_pairs, default=default_decoder, exclusive=True)

  decoded.set_shape([None, None, self._channels])
  if self._shape is None:
    return decoded
  return array_ops.reshape(decoded, self._shape)
def testCase_withoutDefault(self):
  """Checks a default-less `case` on every branch and on a miss."""
  x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
  # Creation order matters: the error message checked below names the
  # predicates Equal, Equal_1 and Equal_2.
  branches = [
      (math_ops.equal(x, 1), lambda: constant_op.constant(2)),
      (math_ops.equal(x, 2), lambda: constant_op.constant(4)),
      (math_ops.equal(x, 3), lambda: constant_op.constant(6)),
  ]
  output = control_flow_ops.case(branches, exclusive=True)

  with self.test_session() as sess:
    for fed, expected in ((1, 2), (2, 4), (3, 6)):
      self.assertEqual(sess.run(output, feed_dict={x: fed}), expected)

    no_match_message = (
        r"\[None of the conditions evaluated as True. "
        r"Conditions: \(Equal:0, Equal_1:0, Equal_2:0\), Values:\] "
        r"\[0 0 0\]")
    with self.assertRaisesRegexp(
        errors.InvalidArgumentError, no_match_message):
      sess.run(output, feed_dict={x: 4})
def conditional_decoding(keys_to_tensors):
  """See base class.

  Reads 'image/encoded' and 'image/format' from `keys_to_tensors`, decodes the
  buffer as a 3-channel PNG when the format equals 'png' and as a 3-channel
  JPEG otherwise, then reshapes the result to `image_shape` (taken from the
  enclosing scope).
  """
  encoded = keys_to_tensors['image/encoded']
  fmt = keys_to_tensors['image/format']

  def _png():
    return image_ops.decode_png(encoded, 3)

  def _jpeg():
    return image_ops.decode_jpeg(encoded, 3)

  is_png = math_ops.equal(fmt, 'png')
  decoded = control_flow_ops.case(
      {is_png: _png}, default=_jpeg, exclusive=True)
  return array_ops.reshape(decoded, image_shape)
def decayed_lr(x, boundaries, values, name):
  """Helper to recompute learning rate; most helpful in eager-mode.

  Args:
    x: The value (e.g. a step counter) the schedule is evaluated at.
    boundaries: Interval boundaries; each must share `x`'s dtype, except that
      int32 boundaries are promoted when `x` is int64.
    values: The value for each interval; all must share one dtype.
    name: Name for the enclosing name scope (defaults to "PiecewiseConstant").

  Returns:
    The result of `case` over the interval predicates: `values[0]` when
    `x <= boundaries[0]`, `values[-1]` when `x > boundaries[-1]`, and the
    matching middle value in between.

  Raises:
    ValueError: if a boundary's dtype cannot be reconciled with `x`'s, or if
      the elements of `values` do not all share a dtype.
  """
  with ops.name_scope(name, "PiecewiseConstant",
                      [x, boundaries, values, name]) as name:
    boundaries = ops.convert_n_to_tensor(boundaries)
    values = ops.convert_n_to_tensor(values)
    x_recomp = ops.convert_to_tensor(x)

    # Boundaries are not blindly converted to x's dtype: that could produce
    # faulty comparisons, for example if floats were converted to integers.
    x_dtype = x_recomp.dtype.base_dtype
    for idx, boundary in enumerate(boundaries):
      b_dtype = boundary.dtype.base_dtype
      if b_dtype == x_dtype:
        continue
      # int32 boundaries can be promoted to int64 without loss of precision.
      # This covers the most common case where the user passes in boundaries
      # as an array of Python integers.
      if not (b_dtype == dtypes.int32 and x_dtype == dtypes.int64):
        raise ValueError(
            "Boundaries (%s) must have the same dtype as x (%s)." %
            (b_dtype, x_dtype))
      boundaries[idx] = math_ops.cast(boundary, x_dtype)

    # TODO(rdipietro): Ensure that boundaries' elements strictly increases.
    for candidate in values[1:]:
      if candidate.dtype.base_dtype != values[0].dtype.base_dtype:
        raise ValueError(
            "Values must have elements all with the same dtype (%s vs %s)." %
            (values[0].dtype.base_dtype, candidate.dtype.base_dtype))

    # The two open-ended intervals come first, then the interior ones.
    pred_fn_pairs = [
        (x_recomp <= boundaries[0], lambda: values[0]),
        (x_recomp > boundaries[-1], lambda: values[-1]),
    ]
    interior = zip(boundaries[:-1], boundaries[1:], values[1:-1])
    for low, high, v in interior:
      in_interval = (x_recomp > low) & (x_recomp <= high)
      # v is bound as a default argument so each lambda keeps its own value.
      pred_fn_pairs.append((in_interval, lambda v=v: v))

    # The conditions are mutually exclusive and exhaustive, so the default is
    # never needed, but tf.case requires one.
    def _default():
      return values[0]

    return control_flow_ops.case(pred_fn_pairs, _default, exclusive=True)
def testConvertV2ResourceCase(self):
  """Tests that a v2 case() with resource variables converts properly.

  Builds a graph with two resource variables and a `case` over them while
  control flow v2 is enabled, converts the variables to constants, and checks
  that the resulting graph contains the expected Const and If nodes and the
  expected function signatures in its library.
  """
  with ops.Graph().as_default():
    with variable_scope.variable_scope("", use_resource=True):
      # v2 control flow is enabled only while the case() is constructed and is
      # switched off again right after.
      control_flow_v2_toggles.enable_control_flow_v2()
      x = variable_scope.get_variable("x", initializer=1.0)
      y = variable_scope.get_variable("y", initializer=2.0)
      # Only the graph structure matters here, so the result is discarded.
      _ = control_flow_ops.case([(gen_math_ops.less(x, y), lambda: x)],
                                default=lambda: y)
      control_flow_v2_toggles.disable_control_flow_v2()
    with session_lib.Session() as sess:
      sess.run(variables.global_variables_initializer())
      variable_graph_def = sess.graph.as_graph_def()
      # NOTE(review): the last argument is presumably the list of output node
      # names to keep -- confirm against convert_to_constants.
      constant_graph_def = (
          convert_to_constants
          .convert_variables_to_constants_from_session_graph(
              sess, variable_graph_def, ["case/cond"]))
      # Expected nodes and library functions, written as a text-format graph
      # description.
      self._assertGraphContains(
          constant_graph_def,
          """ node {name: "x" op: "Const"} node {name: "y" op: "Const"} node { name: "case/cond" op: "If" input: "Less" input: "x" input: "y" attr {key: "Tcond" value {type: DT_BOOL}} attr {key: "Tin" value {list {type: DT_FLOAT type: DT_FLOAT}}} attr {key: "Tout" value {list {type: DT_FLOAT}}} } library { function { signature { name: "case_cond_false_frozen_0" input_arg {name: "placeholder" type: DT_FLOAT} input_arg {name: "readvariableop_y" type: DT_FLOAT} output_arg {name: "readvariableop" type: DT_FLOAT} } } function { signature { name: "case_cond_true_frozen_0" input_arg {name: "placeholder" type: DT_FLOAT} input_arg {name: "readvariableop_x" type: DT_FLOAT} output_arg {name: "readvariableop" type: DT_FLOAT} } } }""")
def compute_piecewise():
  """Builds the `case` op selecting the value for `x_recomp`'s interval.

  Reads `x_recomp`, `boundaries` and `values` from the enclosing scope.

  Returns:
    The result of `case`: `values[0]` when `x_recomp <= boundaries[0]`,
    `values[-1]` when `x_recomp > boundaries[-1]`, and the matching middle
    value in between.
  """
  # The two open-ended intervals come first, then the interior ones.
  pred_fn_pairs = [
      (x_recomp <= boundaries[0], lambda: values[0]),
      (x_recomp > boundaries[-1], lambda: values[-1]),
  ]
  interior = zip(boundaries[:-1], boundaries[1:], values[1:-1])
  for low, high, v in interior:
    in_interval = (x_recomp > low) & (x_recomp <= high)
    # v is bound as a default argument so each lambda keeps its own value.
    pred_fn_pairs.append((in_interval, lambda v=v: v))

  # The conditions are mutually exclusive and exhaustive, so the default is
  # never needed, but tf.case requires one.
  def _default():
    return values[0]

  return control_flow_ops.case(pred_fn_pairs, _default, exclusive=True)
def testGraphWithSwitch(self):
  """Freezes a graph which contains a Switch with type RESOURCE_DT."""
  with ops.Graph().as_default():
    with variable_scope.variable_scope("", use_resource=True):
      x = variable_scope.get_variable("var_x", initializer=1.0)
      y = variable_scope.get_variable("var_y", initializer=2.0)

      # Each branch creates its own variable when it is invoked.
      def _branch_taken():
        return variable_scope.get_variable("var_f1", initializer=17.0)

      def _branch_default():
        return variable_scope.get_variable("var_f2", initializer=23.0)

      x_below_y = gen_math_ops.less(x, y)
      cond_node = control_flow_ops.case(
          [(x_below_y, _branch_taken)], default=_branch_default)
      _ = math_ops_lib.multiply(cond_node, 2.0, name="output_node")

      with session.Session() as sess:
        sess.run(variables.global_variables_initializer())
        variable_graph_def = sess.graph.as_graph_def()

        frozen_graph_def = graph_util.convert_variables_to_constants(
            sess, variable_graph_def, ["output_node"])

    self._ensure_no_variables_in_graph(frozen_graph_def)
def _testReturnValues(self, fn_true, fn_false, expected_value_true,
                      expected_value_false, strict=False, check_cond=True):
  """Runs `cond` and `case` on both branches and checks their results.

  Args:
    fn_true: Callable used when the condition is True.
    fn_false: Callable used when the condition is False (the `case` default).
    expected_value_true: Expected result when the condition is fed True.
    expected_value_false: Expected result when the condition is fed False.
    strict: Forwarded as `strict` to both `cond` and `case`.
    check_cond: When False, the `case` result is not compared.
  """
  condition = array_ops.placeholder(dtypes.bool)
  output_cond = control_flow_ops.cond(
      condition, fn_true, fn_false, strict=strict)
  output_case = control_flow_ops.case(
      [(condition, fn_true)], fn_false, strict=strict)

  with self.test_session() as sess:
    variables.global_variables_initializer().run()

    # The True branch is exercised first, then the False branch.
    scenarios = ((True, expected_value_true), (False, expected_value_false))
    for flag, expected in scenarios:
      result_cond, result_case = sess.run(
          [output_cond, output_case], feed_dict={condition: flag})
      self.assertAllEqualNested(result_cond, expected)
      if check_cond:
        self.assertAllEqualNested(result_case, expected)
def decayed_lr(x, boundaries, values, name):
  """Helper to recompute learning rate; most helpful in eager-mode.

  Args:
    x: The value (e.g. a step counter) the schedule is evaluated at.
    boundaries: Interval boundaries; each must share `x`'s dtype, except that
      int32 boundaries are promoted when `x` is int64.
    values: The value for each interval; all must share one dtype.
    name: Name for the enclosing name scope (defaults to "PiecewiseConstant").

  Returns:
    The result of `case` over the interval predicates: `values[0]` when
    `x <= boundaries[0]`, `values[-1]` when `x > boundaries[-1]`, and the
    matching middle value in between.

  Raises:
    ValueError: if a boundary's dtype cannot be reconciled with `x`'s, or if
      the elements of `values` do not all share a dtype.
  """
  with ops.name_scope(name, "PiecewiseConstant",
                      [x, boundaries, values, name]) as name:
    boundaries = ops.convert_n_to_tensor(boundaries)
    values = ops.convert_n_to_tensor(values)
    x_recomp = ops.convert_to_tensor(x)

    # Boundaries are not blindly converted to x's dtype: that could produce
    # faulty comparisons, for example if floats were converted to integers.
    x_dtype = x_recomp.dtype.base_dtype
    for idx, boundary in enumerate(boundaries):
      b_dtype = boundary.dtype.base_dtype
      if b_dtype == x_dtype:
        continue
      # int32 boundaries can be promoted to int64 without loss of precision.
      # This covers the most common case where the user passes in boundaries
      # as an array of Python integers.
      if not (b_dtype == dtypes.int32 and x_dtype == dtypes.int64):
        raise ValueError(
            "Boundaries (%s) must have the same dtype as x (%s)." %
            (b_dtype, x_dtype))
      boundaries[idx] = math_ops.cast(boundary, x_dtype)

    # TODO(rdipietro): Ensure that boundaries' elements strictly increases.
    for candidate in values[1:]:
      if candidate.dtype.base_dtype != values[0].dtype.base_dtype:
        raise ValueError(
            "Values must have elements all with the same dtype (%s vs %s)." %
            (values[0].dtype.base_dtype, candidate.dtype.base_dtype))

    # The two open-ended intervals come first, then the interior ones.
    pred_fn_pairs = [
        (x_recomp <= boundaries[0], lambda: values[0]),
        (x_recomp > boundaries[-1], lambda: values[-1]),
    ]
    interior = zip(boundaries[:-1], boundaries[1:], values[1:-1])
    for low, high, v in interior:
      in_interval = (x_recomp > low) & (x_recomp <= high)
      # v is bound as a default argument so each lambda keeps its own value.
      pred_fn_pairs.append((in_interval, lambda v=v: v))

    # The conditions are mutually exclusive and exhaustive, so the default is
    # never needed, but tf.case requires one.
    def _default():
      return values[0]

    return control_flow_ops.case(pred_fn_pairs, _default, exclusive=True)
def __call__(self, step):
  """Returns the piecewise-constant value selected by `step`.

  Args:
    step: Value compared against `self.boundaries`; converted to a tensor.
      Boundaries whose base dtype differs from that of `step` are cast to it.

  Returns:
    The result of a `case` op: `values[0]` when `step <= boundaries[0]`,
    `values[-1]` when `step > boundaries[-1]`, and the matching interior
    value for each `(low, high]` interval in between.
  """
  with ops.name_scope_v2(self.name or "PiecewiseConstant"):
    boundaries = ops.convert_n_to_tensor(self.boundaries)
    values = ops.convert_n_to_tensor(self.values)
    step_tensor = ops.convert_to_tensor_v2(step)
    step_dtype = step_tensor.dtype.base_dtype

    # Unlike a strict dtype check, every mismatching boundary is simply cast
    # to the dtype of the step.
    for idx, boundary in enumerate(boundaries):
      if boundary.dtype.base_dtype != step_dtype:
        boundaries[idx] = math_ops.cast(boundary, step_dtype)

    branches = [
        (step_tensor <= boundaries[0], lambda: values[0]),
        (step_tensor > boundaries[-1], lambda: values[-1]),
    ]
    # `v=v` binds the current value; a plain closure would see only the last.
    branches += [
        ((step_tensor > low) & (step_tensor <= high), lambda v=v: v)
        for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1])
    ]

    # The predicates are mutually exclusive and exhaustive, so the default is
    # never needed; it is supplied only because tf.case requires one.
    fallback = lambda: values[0]
    return control_flow_ops.case(branches, fallback, exclusive=True)
def piecewise_linear_schedule(global_step, boundaries, values, name=None):
  """Linearly interpolates `values` between consecutive `boundaries`.

  Args:
    global_step: Step value; cast to float32 before any comparison.
    boundaries: Sequence of breakpoints, the same length as `values`.
    values: Sequence of schedule values; `values[i]` is the value reached at
      `boundaries[i]`.
    name: Optional name for the enclosing name scope. Defaults to
      "piecewise_linear_schedule".

  Returns:
    The result of a `case` op: `values[0]` when the step is at or below
    `boundaries[0]`, `values[-1]` when it is above `boundaries[-1]`, and the
    linear interpolation between the two neighbouring values otherwise.

  Raises:
    ValueError: If `global_step` is None.
    AssertionError: If `boundaries` and `values` differ in length (only when
      assertions are enabled).
  """
  if global_step is None:
    raise ValueError("global_step is required for piecewise_linear_schedule.")
  assert len(boundaries) == len(values), (
      "boundaries length ({}) should equal values length ({})".format(
          len(boundaries), len(values)))

  with ops.name_scope(name, "piecewise_linear_schedule",
                      [global_step, boundaries, values]) as name:
    step = math_ops.cast(global_step, tf.float32)

    branches = [
        (step <= boundaries[0], lambda: values[0]),
        (step > boundaries[-1], lambda: values[-1]),
    ]
    segments = zip(boundaries[:-1], boundaries[1:], values[:-1], values[1:])
    for low, high, low_value, high_value in segments:
      in_segment = (step > low) & (step <= high)
      # Fraction of the way from `low` to `high`.
      fraction = (step - low) / (high - low)
      interpolated = fraction * high_value + (1 - fraction) * low_value
      # Bind the current tensor as a default argument so every branch returns
      # its own interpolation rather than the last one computed.
      branches.append((in_segment, lambda v=interpolated: v))

    # The predicates are mutually exclusive and exhaustive, so the default is
    # never needed; it is supplied only because tf.case requires one.
    fallback = lambda: values[0]
    return control_flow_ops.case(branches, fallback, exclusive=True)
def _decode(self, image_buffer, image_format):
  """Decodes the image buffer as PNG or JPEG.

  Args:
    image_buffer: The tensor representing the encoded image.
    image_format: The image format for the image in `image_buffer`. The
      buffer is decoded as PNG when this equals 'png' and as JPEG otherwise.

  Returns:
    The decoded image, reshaped to `self._shape` when that is not None.
  """
  def _as_png():
    return image_ops.decode_png(image_buffer, self._channels)

  def _as_jpeg():
    return image_ops.decode_jpeg(image_buffer, self._channels)

  is_png = math_ops.equal(image_format, 'png')
  decoded = control_flow_ops.case(
      {is_png: _as_png}, default=_as_jpeg, exclusive=True)

  if self._shape is None:
    return decoded
  return array_ops.reshape(decoded, self._shape)
def rot90(image, k=1, name=None):
  """Rotate an image counter-clockwise by 90 degrees.

  Args:
    image: A 3-D tensor of shape `[height, width, channels]`.
    k: A scalar integer. The number of times the image is rotated by 90
      degrees; it is reduced modulo 4 before use.
    name: A name for this operation (optional).

  Returns:
    A rotated 3-D tensor of the same type and shape as `image`.
  """
  with ops.name_scope(name, 'rot90', [image, k]) as scope:
    image = ops.convert_to_tensor(image, name='image')
    _Check3DImage(image, require_static=False)
    k = ops.convert_to_tensor(k, dtype=dtypes.int32, name='k')
    k.get_shape().assert_has_rank(0)
    k = math_ops.mod(k, 4)

    # Boolean masks selecting which of the three dimensions to reverse.
    flip_width = [False, True, False]
    flip_height_and_width = [True, True, False]
    swap_height_width = [1, 0, 2]

    def _quarter_turn():
      flipped = array_ops.reverse(image, flip_width)
      return array_ops.transpose(flipped, swap_height_width)

    def _half_turn():
      return array_ops.reverse(image, flip_height_and_width)

    def _three_quarter_turn():
      transposed = array_ops.transpose(image, swap_height_width)
      return array_ops.reverse(transposed, flip_width)

    is_one = math_ops.equal(k, 1)
    is_two = math_ops.equal(k, 2)
    is_three = math_ops.equal(k, 3)
    branches = [
        (is_one, _quarter_turn),
        (is_two, _half_turn),
        (is_three, _three_quarter_turn),
    ]

    # When k is 0 (mod 4) no predicate matches and the image passes through.
    rotated = control_flow_ops.case(
        branches, default=lambda: image, exclusive=True, name=scope)
    # Height and width may have been swapped, so only channels stay static.
    rotated.set_shape([None, None, image.get_shape()[2]])
    return rotated
def control_map_fn(x, y):
  """Transforms `x` through a `case` that nests a `cond`, selected by `y`.

  When `y` equals 2 or 3 the result is `x * 2` if `x` is even and `x // 2`
  otherwise. For any other `y` the result is `x * 2`.

  Args:
    x: Value to transform.
    y: Value that selects which transformation is applied.

  Returns:
    The output of the `case` op.
  """
  def _double():
    return x * 2

  def _halve():
    return x // 2

  def _double_if_even_else_halve():
    x_is_even = math_ops.equal(math_ops.mod(x, 2), 0)
    return control_flow_ops.cond(x_is_even, _double, _halve, name="cond_mult")

  y_is_two_or_three = math_ops.logical_or(
      math_ops.equal(y, 2), math_ops.equal(y, 3))
  branches = {y_is_two_or_three: _double_if_even_else_halve}
  return control_flow_ops.case(branches, default=_double, exclusive=True)
def compute_damping():
  """Adapts damping parameter based on "reduction ratio".

  Reduction ratio captures how closely the quadratic approximation to the
  loss function approximates the actual loss within a trust region. The
  damping update tries to make the damping as small as possible while
  maintaining the property that the quadratic model remains a good local
  approximation to the loss function.

  Returns:
    An Op to assign newly computed damping value to `self._damping`.
  """
  # `self` and `prev_batch` are not parameters; they are read from the
  # enclosing scope.
  prev_batch_loss = self._loss_fn(prev_batch)
  with ops.control_dependencies([prev_batch_loss]):
    # rho = (loss on the previous batch - stored previous loss) divided by
    # the stored change in the quadratic model.
    rho_assign = self._rho.assign(
        (prev_batch_loss - self._prev_loss) / self._q_model_change)
    # The predicates below read `self._rho`, so they are built under a
    # dependency on its assignment.
    with ops.control_dependencies([rho_assign]):
      # rho < 0.25: damping / omega; rho > 0.75: damping * omega; otherwise
      # the damping is left unchanged.
      new_damping = control_flow_ops.case(
          [(self._rho < 0.25, lambda: self.damping / self._omega),
           (self._rho > 0.75, lambda: self.damping * self._omega)],
          lambda: self.damping)
      with ops.control_dependencies([new_damping]):
        # Never let the damping drop below the configured minimum.
        new_damping_min = math_ops.maximum(new_damping, self._min_damping)
        return control_flow_ops.group(self._damping.assign(new_damping_min))
def _decode(self, image_buffer, image_format):
  """Decodes the image buffer as PNG, raw bytes, or JPEG.

  Args:
    image_buffer: The tensor representing the encoded image.
    image_format: The image format for the image in `image_buffer`. 'png' or
      'PNG' selects PNG decoding, 'raw' or 'RAW' selects raw uint8 decoding,
      and anything else is decoded as JPEG.

  Returns:
    The decoded image with static shape `(?, ?, self._channels)`, reshaped
    to `self._shape` when that is not None.
  """
  def _as_png():
    return image_ops.decode_png(image_buffer, self._channels)

  def _as_raw():
    return parsing_ops.decode_raw(image_buffer, dtypes.uint8)

  def _as_jpeg():
    return image_ops.decode_jpeg(image_buffer, self._channels)

  is_png = math_ops.logical_or(
      math_ops.equal(image_format, 'png'),
      math_ops.equal(image_format, 'PNG'))
  is_raw = math_ops.logical_or(
      math_ops.equal(image_format, 'raw'),
      math_ops.equal(image_format, 'RAW'))

  decoded = control_flow_ops.case(
      {is_png: _as_png, is_raw: _as_raw},
      default=_as_jpeg,
      exclusive=True)
  decoded.set_shape([None, None, self._channels])

  if self._shape is None:
    return decoded
  return array_ops.reshape(decoded, self._shape)
def _decode(self, image_buffer, image_format, image_height, image_width):
  """Decodes the image buffer and reshapes it to the given height and width.

  Args:
    image_buffer: The tensor representing the encoded image tensor.
    image_format: The image format for the image in `image_buffer`. 'raw' or
      'RAW' selects raw decoding; any other format is decoded from the image
      headers. Only consulted when `self._dtype` is uint8.
    image_height: Height used for the final reshape.
    image_width: Width used for the final reshape.

  Returns:
    The decoded image reshaped to `[image_height, image_width, 3]`. Note
    that the channel count of the reshape is fixed at 3 rather than taken
    from `self._channels`.
  """
  def _from_headers():
    """Decodes a png or jpg based on the headers."""
    return image_ops.decode_image(image_buffer, self._channels)

  def _from_raw():
    """Decodes a raw image."""
    return parsing_ops.decode_raw(image_buffer, out_type=self._dtype)

  is_raw = math_ops.logical_or(
      math_ops.equal(image_format, 'raw'),
      math_ops.equal(image_format, 'RAW'))
  branches = {is_raw: _from_raw}

  if self._dtype != dtypes.uint8:
    # Any other dtype is always decoded as raw bytes, whatever the format.
    decoded = _from_raw()
  else:
    decoded = control_flow_ops.case(
        branches, default=_from_headers, exclusive=True)

  target_shape = tf.stack([image_height, image_width, 3])
  return array_ops.reshape(decoded, target_shape)
def piecewise_constant(x, boundaries, values, name=None):
  """Piecewise constant from boundaries and interval values.

  Example: use an exploration rate that's 1.0 for the first 100000 steps, 0.5
    for steps 100001 to 110000, and 0.1 for any additional steps.

  ```python
  timestep = tf.Variable(0, trainable=False)
  boundaries = [100000, 110000]
  values = [1.0, 0.5, 0.1]
  exploration_rate = tf.train.piecewise_constant(timestep, boundaries, values)

  # Later, whenever we perform an optimization step, we increment timestep.
  ```

  Args:
    x: A 0-D scalar `Tensor`. Must be one of the following types: `float32`,
      `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`.
    boundaries: A list of `Tensor`s or `int`s or `float`s with strictly
      increasing entries, and with all elements having the same type as `x`.
    values: A list of `Tensor`s or `float`s or `int`s that specifies the
      values for the intervals defined by `boundaries`. It should have one
      more element than `boundaries`, and all elements should have the same
      type.
    name: A string. Optional name of the operation. Defaults to
      'PiecewiseConstant'.

  Returns:
    A 0-D Tensor. Its value is `values[0]` when `x <= boundaries[0]`,
    `values[1]` when `x > boundaries[0]` and `x <= boundaries[1]`, ...,
    and values[-1] when `x > boundaries[-1]`.

  Raises:
    ValueError: if types of `x` and `boundaries` do not match, or types of all
        `values` do not match or
        the number of elements in the lists does not match.
  """
  # Without this check a length mismatch is silently absorbed by `zip`
  # below, which drops intervals and lets them fall through to the default.
  if len(boundaries) != len(values) - 1:
    raise ValueError(
        "The length of boundaries should be 1 less than the length of values")
  with get_name_scope(name=name, scope="PiecewiseConstant",
                      values=[x, boundaries, values, name]) as name:
    x = ops.convert_to_tensor(x)
    # Avoid explicit conversion to x's dtype. This could result in faulty
    # comparisons, for example if floats are converted to integers.
    boundaries = ops.convert_n_to_tensor(boundaries)
    for b in boundaries:
      if b.dtype != x.dtype:
        raise ValueError(
            "Boundaries (%s) must have the same dtype as x (%s)."
            % (b.dtype, x.dtype))
    values = ops.convert_n_to_tensor(values)
    for v in values[1:]:
      if v.dtype != values[0].dtype:
        raise ValueError(
            "Values must have elements all with the same dtype (%s vs %s)."
            % (values[0].dtype, v.dtype))

    # Use a list of (predicate, fn) pairs rather than a dict keyed by the
    # predicate tensors: a dict requires hashable keys, and Tensors are not
    # hashable once Tensor equality is enabled.
    pred_fn_pairs = [
        (x <= boundaries[0], lambda: values[0]),
        (x > boundaries[-1], lambda: values[-1]),
    ]
    for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]):
      # Need to bind v here; can do this with lambda v=v: ...
      pred = (x > low) & (x <= high)
      pred_fn_pairs.append((pred, lambda v=v: v))

    # The default isn't needed here because our conditions are mutually
    # exclusive and exhaustive, but tf.case requires it.
    default = lambda: values[0]
    return control_flow_ops.case(pred_fn_pairs, default, exclusive=True)
def testConvertV2UnconvertedResourceNestedCase(self):
  """Tests unconverted variable propagation through nested functions.

  Builds two nested `case` ops over resource variables `x`, `y` and `z` with
  control flow v2 enabled, freezes the graph with `y` on the denylist, and
  checks the converted graph and its function library against the expected
  proto text.
  """
  with ops.Graph().as_default():
    with variable_scope.variable_scope("", use_resource=True):
      # Control flow v2 is switched on only while the case ops are built and
      # switched off again right after.
      control_flow_v2_toggles.enable_control_flow_v2()
      x = variable_scope.get_variable("x", initializer=1.0)
      y = variable_scope.get_variable("y", initializer=2.0)
      z = variable_scope.get_variable("z", initializer=3.0)
      # The default branch of the outer case is itself a case, so `y` is
      # referenced from both nesting levels.
      # pylint: disable=g-long-lambda
      _ = control_flow_ops.case(
          [(gen_math_ops.less(x, y), lambda: x)],
          default=lambda: control_flow_ops.case(
              [(gen_math_ops.less(z, y), lambda: z)], default=lambda: y))
      # pylint: enable=g-long-lambda
      control_flow_v2_toggles.disable_control_flow_v2()

    with session_lib.Session() as sess:
      sess.run(variables.global_variables_initializer())
      variable_graph_def = sess.graph.as_graph_def()
      # `y` is denylisted, so the expected graph below keeps it as a
      # VarHandleOp / DT_RESOURCE input while `x` and `z` appear as Const.
      constant_graph_def = (
          convert_to_constants
          .convert_variables_to_constants_from_session_graph(
              sess, variable_graph_def, ["case/cond"],
              variable_names_denylist=["y"]))
      self._assertGraphContains(
          constant_graph_def, """ node {name: "x" op: "Const"} node {name: "y" op: "VarHandleOp"} node {name: "z" op: "Const"} node {name: "Less/ReadVariableOp" op: "Identity" input: "x"} node {name: "Less/ReadVariableOp_1" op: "ReadVariableOp" input: "y"} node { name: "case/cond" op: "If" input: "x" input: "z" input: "y" attr { key: "Tin" value {list {type: DT_FLOAT type: DT_FLOAT type: DT_RESOURCE}}} attr { key: "_read_only_resource_inputs" value {list {i: 1 i: 2 i: 3}}} attr {key: "then_branch" value {func {name: "case_cond_true_frozen_0"}}} attr {key: "else_branch" value {func {name: "case_cond_false_frozen_0"}}} attr {key: "output_shapes" value {list {shape {}}}} } library { function { signature { name: "case_cond_true_frozen_0" input_arg {name: "placeholder" type: DT_FLOAT} input_arg {name: "placeholder_1" type: DT_RESOURCE} input_arg {name: "readvariableop_x" type: DT_FLOAT} output_arg {name: "readvariableop" type: DT_FLOAT} is_stateful: true } node_def 
{name: "ReadVariableOp" op: "Identity" input: "readvariableop_x"}} function { signature { name: "case_cond_false_frozen_0" input_arg {name: "placeholder" type: DT_FLOAT} input_arg {name: "less_readvariableop_1_y" type: DT_RESOURCE} input_arg {name: "less_readvariableop_z" type: DT_FLOAT} output_arg {name: "case_cond_identity" type: DT_FLOAT} is_stateful: true } node_def {name: "Less/ReadVariableOp_1" op: "ReadVariableOp" input: "less_readvariableop_1_y"} node_def {name: "Less/ReadVariableOp" op: "Identity" input: "less_readvariableop_z"} node_def {name: "case/cond" op: "If" input: "less_readvariableop_z" input: "less_readvariableop_1_y" attr { key: "Tin" value {list {type: DT_FLOAT type: DT_RESOURCE}}} attr {key: "then_branch" value {func {name: "case_cond_true_frozen_1"}}} attr {key: "else_branch" value {func {name: "case_cond_false_frozen_1"}}} attr { key: "_read_only_resource_inputs" value {list {i: 1 i: 2}}}}} function { signature { name: "case_cond_false_frozen_1" input_arg {name: "placeholder" type: DT_FLOAT} input_arg {name: "readvariableop_y" type: DT_RESOURCE} output_arg {name: "readvariableop" type: DT_FLOAT} is_stateful: true } node_def {name: "ReadVariableOp" op: "ReadVariableOp" input: "readvariableop_y"}} function { signature { name: "case_cond_true_frozen_1" input_arg {name: "placeholder" type: DT_RESOURCE} input_arg {name: "readvariableop_z" type: DT_FLOAT} output_arg {name: "readvariableop" type: DT_FLOAT} is_stateful: true } node_def {name: "ReadVariableOp" op: "Identity" input: "readvariableop_z"}}}""")
def step(self, time, inputs, state, name=None):
  """Perform a decoding step.

  Besides the regular beam-search step, the cell outputs are adjusted before
  the search: when `time` equals 1, `1e10` is subtracted from the entry at
  `end_token` along the last axis.

  Args:
    time: scalar `int32` tensor.
    inputs: A (structure of) input tensors.
    state: A (structure of) state tensors and TensorArrays.
    name: Name scope for any created operations.

  Returns:
    `(outputs, next_state, next_inputs, finished)`.
  """
  batch_size = self._batch_size
  beam_width = self._beam_width
  end_token = self._end_token
  length_penalty_weight = self._length_penalty_weight

  with ops.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)):
    # Merge the batch and beam dimensions before calling the cell.
    merged_inputs = nest.map_structure(
        lambda inp: self._merge_batch_beams(inp, s=inp.shape[2:]), inputs)
    merged_cell_state = nest.map_structure(
        self._maybe_merge_batch_beams, state.cell_state,
        self._cell.state_size)

    raw_outputs, raw_next_state = self._cell(merged_inputs, merged_cell_state)

    # Split them apart again for the beam search.
    logits = nest.map_structure(
        lambda out: self._split_batch_beams(out, out.shape[1:]), raw_outputs)
    next_cell_state = nest.map_structure(
        self._maybe_split_batch_beams, raw_next_state, self._cell.state_size)

    if self._output_layer is not None:
      logits = self._output_layer(logits)

    # One-hot over the last axis, set at the end token.
    end_token_mask = array_ops.one_hot(
        end_token, array_ops.shape(logits)[-1], dtype=dtypes.float32)

    # Entry i is the amount subtracted at the end token when time == i;
    # index 0 is never used.
    # reduce_ratio = [0, 1e10, 6.1, 5.5, 3, 2, 1, 0.5]
    reduce_ratio = [0, 1e10]
    # The ratio is bound as a default argument so that each branch keeps its
    # own value.
    penalty_branches = [
        (math_ops.equal(time, step_index),
         lambda ratio=ratio: logits - end_token_mask * ratio)
        for step_index, ratio in enumerate(reduce_ratio[1:], start=1)
    ]
    adjusted_logits = control_flow_ops.case(
        pred_fn_pairs=penalty_branches,
        default=lambda: logits,
        exclusive=True)

    beam_search_output, beam_search_state = _beam_search_step(
        time=time,
        logits=adjusted_logits,
        next_cell_state=next_cell_state,
        beam_state=state,
        batch_size=batch_size,
        beam_width=beam_width,
        end_token=end_token,
        length_penalty_weight=length_penalty_weight)

    finished = beam_search_state.finished
    sample_ids = beam_search_output.predicted_ids
    # Once every beam has finished, feed the start inputs instead of
    # embedding the sampled ids.
    all_finished = math_ops.reduce_all(finished)
    next_inputs = control_flow_ops.cond(
        all_finished,
        lambda: self._start_inputs,
        lambda: self._embedding_fn(sample_ids))

  return (beam_search_output, beam_search_state, next_inputs, finished)
def test_inv_update_thunks(self):
  """Ensures inverse update ops run once per global_step."""
  with self._graph.as_default(), self.test_session() as sess:
    fisher_estimator = estimator.FisherEstimator(
        variables=[self.weights],
        layer_collection=self.layer_collection,
        damping=0.2,
        cov_ema_decay=0.0)

    # Construct op that updates one inverse per global step.
    global_step = training_util.get_or_create_global_step()
    thunks = fisher_estimator.create_ops_and_vars_thunks()
    cov_variable_thunks, _, inv_variable_thunks, inv_update_op_thunks = thunks

    # Create the covariance variables first, then the inverse variables.
    for variable_thunks in (cov_variable_thunks, inv_variable_thunks):
      for make_variables in variable_thunks:
        make_variables()

    factors = self.layer_collection.get_factors()
    inv_matrices = []
    for factor in factors:
      inv_matrices.extend(factor._matpower_by_exp_and_damping.values())

    # The i-th update thunk is selected when global_step == i.
    step_branches = [
        (math_ops.equal(global_step, index), update_thunk)
        for index, update_thunk in enumerate(inv_update_op_thunks)
    ]
    inv_update_op = control_flow_ops.case(step_branches)
    increment_global_step = global_step.assign_add(1)

    sess.run(variables.global_variables_initializer())
    initial_inv_values = sess.run(inv_matrices)
    num_inverses = len(inv_matrices)

    # Ensure there's one update per inverse matrix. This is true as long as
    # there's no fan-in/fan-out or parameter re-use.
    self.assertEqual(num_inverses, len(inv_update_op_thunks))

    # Test is no-op if only 1 inverse matrix.
    assert num_inverses > 1

    # Assign each covariance matrix a value other than the identity. This
    # ensures that the inverse matrices are updated to something different as
    # well.
    cov_matrices = [
        factor.get_cov() for factor in self.layer_collection.get_factors()
    ]
    assign_ops = []
    for cov_matrix in cov_matrices:
      dim = int(cov_matrix.shape[0])
      assign_ops.append(cov_matrix.assign(2 * linalg_ops.eye(dim)))
    sess.run(assign_ops)

    for steps_taken in range(num_inverses):
      # Compare new and old inverse values.
      current_inv_values = sess.run(inv_matrices)
      unchanged = [
          np.allclose(before, after)
          for before, after in zip(initial_inv_values, current_inv_values)
      ]

      # Ensure exactly one inverse matrix changes per step.
      self.assertEqual(sum(unchanged), num_inverses - steps_taken)

      # Run the update selected by the current global step, then advance it.
      sess.run(inv_update_op)
      sess.run(increment_global_step)
def piecewise_constant(x, boundaries, values, name=None):
  """Piecewise constant from boundaries and interval values.

  Example: use a learning rate that's 1.0 for the first 100000 steps, 0.5
    for steps 100001 to 110000, and 0.1 for any additional steps.

  ```python
  global_step = tf.Variable(0, trainable=False)
  boundaries = [100000, 110000]
  values = [1.0, 0.5, 0.1]
  learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)

  # Later, whenever we perform an optimization step, we increment global_step.
  ```

  Args:
    x: A 0-D scalar `Tensor`. Must be one of the following types: `float32`,
      `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`.
    boundaries: A list of `Tensor`s or `int`s or `float`s with strictly
      increasing entries, and with all elements having the same type as `x`.
    values: A list of `Tensor`s or `float`s or `int`s that specifies the
      values for the intervals defined by `boundaries`. It should have one
      more element than `boundaries`, and all elements should have the same
      type.
    name: A string. Optional name of the operation. Defaults to
      'PiecewiseConstant'.

  Returns:
    A 0-D Tensor. Its value is `values[0]` when `x <= boundaries[0]`,
    `values[1]` when `x > boundaries[0]` and `x <= boundaries[1]`, ...,
    and values[-1] when `x > boundaries[-1]`.

  Raises:
    ValueError: if types of `x` and `boundaries` do not match, or types of all
        `values` do not match or
        the number of elements in the lists does not match.
  """
  if len(boundaries) != len(values) - 1:
    raise ValueError(
        "The length of boundaries should be 1 less than the length of values"
    )
  with ops.name_scope(name, "PiecewiseConstant",
                      [x, boundaries, values, name]) as name:
    x = ops.convert_to_tensor(x)
    # Avoid explicit conversion to x's dtype. This could result in faulty
    # comparisons, for example if floats are converted to integers.
    boundaries = ops.convert_n_to_tensor(boundaries)
    for i, b in enumerate(boundaries):
      if b.dtype.base_dtype != x.dtype.base_dtype:
        # We can promote int32 boundaries to int64 without loss of precision.
        # This covers the most common case where the user passes in boundaries
        # as an array of Python integers.
        if (b.dtype.base_dtype == dtypes.int32 and
            x.dtype.base_dtype == dtypes.int64):
          b = math_ops.cast(b, x.dtype.base_dtype)
          boundaries[i] = b
        else:
          raise ValueError(
              "Boundaries (%s) must have the same dtype as x (%s)." %
              (b.dtype.base_dtype, x.dtype.base_dtype))
    # TODO(rdipietro): Ensure that boundaries' elements are strictly
    # increasing.
    values = ops.convert_n_to_tensor(values)
    for v in values[1:]:
      if v.dtype.base_dtype != values[0].dtype.base_dtype:
        raise ValueError(
            "Values must have elements all with the same dtype (%s vs %s)." %
            (values[0].dtype.base_dtype, v.dtype.base_dtype))

    # Use a list of (predicate, fn) pairs rather than a dict keyed by the
    # predicate tensors: a dict requires hashable keys, and Tensors are not
    # hashable once Tensor equality is enabled.
    pred_fn_pairs = [
        (x <= boundaries[0], lambda: values[0]),
        (x > boundaries[-1], lambda: values[-1]),
    ]
    for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]):
      # Need to bind v here; can do this with lambda v=v: ...
      pred = (x > low) & (x <= high)
      pred_fn_pairs.append((pred, lambda v=v: v))

    # The default isn't needed here because our conditions are mutually
    # exclusive and exhaustive, but tf.case requires it.
    default = lambda: values[0]
    return control_flow_ops.case(pred_fn_pairs, default, exclusive=True)
def piecewise_constant(x, boundaries, values, name=None):
  """Piecewise constant from boundaries and interval values.

  Example: use a learning rate that's 1.0 for the first 100000 steps, 0.5
    for steps 100001 to 110000, and 0.1 for any additional steps.

  ```python
  global_step = tf.Variable(0, trainable=False)
  boundaries = [100000, 110000]
  values = [1.0, 0.5, 0.1]
  learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)

  # Later, whenever we perform an optimization step, we increment global_step.
  ```

  Args:
    x: A 0-D scalar `Tensor`. Must be one of the following types: `float32`,
      `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`.
    boundaries: A list of `Tensor`s or `int`s or `float`s with strictly
      increasing entries, and with all elements having the same type as `x`
      (int32 boundaries are promoted when `x` is int64).
    values: A list of `Tensor`s or `float`s or `int`s that specifies the
      values for the intervals defined by `boundaries`. It should have one
      more element than `boundaries`, and all elements should have the same
      type.
    name: A string. Optional name of the operation. Defaults to
      'PiecewiseConstant'.

  Returns:
    A 0-D Tensor. Its value is `values[0]` when `x <= boundaries[0]`,
    `values[1]` when `x > boundaries[0]` and `x <= boundaries[1]`, ...,
    and values[-1] when `x > boundaries[-1]`.

  Raises:
    ValueError: if types of `x` and `boundaries` do not match, or types of all
        `values` do not match or
        the number of elements in the lists does not match.
  """
  if len(boundaries) != len(values) - 1:
    raise ValueError(
        "The length of boundaries should be 1 less than the length of values")

  with ops.name_scope(name, "PiecewiseConstant",
                      [x, boundaries, values, name]) as name:
    x = ops.convert_to_tensor(x)
    x_dtype = x.dtype.base_dtype

    # Boundaries are not blindly cast to the dtype of `x`: doing so could
    # give faulty comparisons, e.g. if floats were converted to integers.
    boundaries = ops.convert_n_to_tensor(boundaries)
    for idx, boundary in enumerate(boundaries):
      boundary_dtype = boundary.dtype.base_dtype
      if boundary_dtype == x_dtype:
        continue
      # Only int32 -> int64 promotion is accepted, since it loses no
      # precision. It covers the common case of boundaries given as a list
      # of Python integers.
      if not (boundary_dtype == dtypes.int32 and x_dtype == dtypes.int64):
        raise ValueError(
            "Boundaries (%s) must have the same dtype as x (%s)." % (
                boundary_dtype, x_dtype))
      boundaries[idx] = math_ops.cast(boundary, x_dtype)

    # TODO(rdipietro): Ensure that boundaries' elements are strictly
    # increasing.
    values = ops.convert_n_to_tensor(values)
    first_dtype = values[0].dtype.base_dtype
    for value in values[1:]:
      if value.dtype.base_dtype != first_dtype:
        raise ValueError(
            "Values must have elements all with the same dtype (%s vs %s)." % (
                first_dtype, value.dtype.base_dtype))

    branches = [
        (x <= boundaries[0], lambda: values[0]),
        (x > boundaries[-1], lambda: values[-1]),
    ]
    # `v=v` binds the current value; a plain closure would see only the last.
    branches += [
        ((x > low) & (x <= high), lambda v=v: v)
        for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1])
    ]

    # The predicates are mutually exclusive and exhaustive, so the default is
    # never needed; it is supplied only because tf.case requires one.
    fallback = lambda: values[0]
    return control_flow_ops.case(branches, fallback, exclusive=True)