def test_is_defun(self):
  """is_defun accepts Defun-wrapped callables and rejects everything else."""
  # Both an untyped and a typed Defun wrapper must be recognized.
  untyped_defun = function.Defun()(lambda x: None)
  typed_defun = function.Defun(tf.int32)(lambda x: None)
  self.assertTrue(function_utils.is_defun(untyped_defun))
  self.assertTrue(function_utils.is_defun(typed_defun))
  # The decorator class itself, a bare lambda and None are not defuns.
  for candidate in (function.Defun, lambda x: None, None):
    self.assertFalse(function_utils.is_defun(candidate))
def _scale_gradient_op(dtype):
  """Returns a Defun op that is the identity forward and scales gradients.

  The op is built with `func_name="ScaleGradient_<dtype name>"`, so each
  dtype gets its own named function. TensorFlow reuses a Defun that has an
  already-seen `func_name` rather than creating a fresh op, so repeated
  calls with the same `dtype` share one op.

  The scale is passed as the second forward argument purely so that the
  backward function can read it from `op.inputs[1]`.

  Args:
    dtype: dtype of the tensor whose gradient is being scaled; used for both
      Defun inputs.

  Returns:
    The Defun-wrapped op, callable as `op(x, scale)`.
  """

  def scale_gradient_forward(x, scale):
    # Forward pass is the identity; `scale` only matters for the backward pass.
    # Note that if this created ops, memoization of the Defun would be needed.
    del scale  # Unused.
    return x

  def scale_gradient_backward(op, grad):
    # Gradient for `x` is grad * scale; `scale` itself receives no gradient.
    return grad * op.inputs[1], None

  op_name = "ScaleGradient_{}".format(dtype.name)
  defun = function.Defun(
      dtype, dtype,
      python_grad_func=scale_gradient_backward,
      func_name=op_name)
  return defun(scale_gradient_forward)
def test_get_defun_argspec_with_untyped_non_eager_defun(self):
  # Without an input signature, a non-eager defun is restricted exactly like
  # a typed one: positional args plus *varargs, no defaults or **kwargs.
  untyped_defun = function.Defun()(lambda x, y, *z: None)
  expected_argspec = inspect.ArgSpec(
      args=['x', 'y'], varargs='z', keywords=None, defaults=None)
  self.assertEqual(function_utils.get_argspec(untyped_defun), expected_argspec)
def testFunctionWithResourcesOnDifferentDevices(self):
  if not test_util.is_gpu_available():
    self.skipTest("No GPUs available.")

  # One resource variable on each of two CPUs and one on the GPU.
  placements = (
      ("/cpu:0", "v_cpu_zero"),
      ("/cpu:1", "v_cpu_one"),
      ("/gpu:0", "v_gpu"),
  )
  resources = []
  for device_name, var_name in placements:
    with ops.device(device_name):
      resources.append(
          resource_variable_ops.ResourceVariable(
              [0.0, 1.0, 2.0], name=var_name))

  def sum_gather():
    # Per variable: sum of the elements at indices 1 and 2.
    return tuple(
        math_ops.reduce_sum(array_ops.gather(var, [1, 2]))
        for var in resources)

  defined = function.Defun()(sum_gather)
  session_config = config_pb2.ConfigProto(
      allow_soft_placement=False,
      log_device_placement=True,
      device_count={"CPU": 2})
  with self.test_session(config=session_config) as sess:
    sess.run(variables.global_variables_initializer())
    # The plain Python graph is the reference for the partitioned call.
    expected = sess.run(sum_gather())
    partitioned = functional_ops.partitioned_call(
        args=defined.captured_inputs, f=defined)
    self.assertAllEqual(expected, sess.run(partitioned))
def _instantiate_op(dtype):
  """Builds the constrain-to-range Defun op for the given `dtype`.

  Forward: `tf.clip_by_value(x, clip_value_min, clip_value_max)`.
  Backward: the upstream gradient is passed to `x`, except that it is zeroed
  where `x < clip_value_min` and `grad < 0`, or where `x > clip_value_max`
  and `grad > 0`. The two bounds receive no gradient.

  NOTE(review): whether the zeroed entries are the ones that would push `x`
  further out of range depends on whether the caller descends or ascends
  on this gradient; confirm against the call site.

  Args:
    dtype: dtype used for all three Defun inputs; also part of the
      `func_name` ("ConstrainToRange_<dtype name>").

  Returns:
    The Defun-wrapped op, callable as `op(x, clip_value_min, clip_value_max)`.
  """

  def constrain_to_range_forward(x, clip_value_min, clip_value_max):
    return tf.clip_by_value(x, clip_value_min, clip_value_max)

  def constrain_to_range_backward(op, grad):
    """Forwards the gradients moving the inputs within the valid range."""
    x = op.inputs[0]
    lower = op.inputs[1]
    upper = op.inputs[2]
    zeros = tf.zeros_like(grad)
    # The masks are applied in sequence; the second uses the updated grad.
    below_and_negative = tf.logical_and(x < lower, grad < 0)
    grad = tf.where(below_and_negative, zeros, grad)
    above_and_positive = tf.logical_and(x > upper, grad > 0)
    grad = tf.where(above_and_positive, zeros, grad)
    return grad, None, None

  defun = function.Defun(
      dtype, dtype, dtype,
      python_grad_func=constrain_to_range_backward,
      func_name='ConstrainToRange_{}'.format(dtype.name))
  return defun(constrain_to_range_forward)
def _clip_gradient_op(dtype):
  """Returns a Defun op that is the identity forward and clips gradients.

  The op is named "ClipGradient_<dtype name>". TensorFlow reuses a Defun
  whose `func_name` it has already seen, so calls with the same `dtype`
  share one op while each new `dtype` creates a new one.

  The clip bounds are forward arguments only so that the backward function
  can read them from `op.inputs[1]` and `op.inputs[2]`.

  Args:
    dtype: dtype of the tensor whose gradient is being clipped; used for all
      three Defun inputs.

  Returns:
    The Defun-wrapped op, callable as `op(x, clip_value_min, clip_value_max)`.
  """

  def clip_gradient_forward(x, clip_value_min, clip_value_max):
    # Identity forward; the bounds are consumed only by the backward pass.
    del clip_value_min  # Unused.
    del clip_value_max  # Unused.
    return x

  def clip_gradient_backward(op, grad):
    # Clip the incoming gradient for `x`; the bounds get no gradient.
    lower, upper = op.inputs[1], op.inputs[2]
    return tf.clip_by_value(grad, lower, upper), None, None

  defun = function.Defun(
      dtype, dtype, dtype,
      python_grad_func=clip_gradient_backward,
      func_name="ClipGradient_{}".format(dtype.name))
  return defun(clip_gradient_forward)
def testFunctionGradientWithGradFuncAndRegistration(self):
  graph = ops.Graph()
  with graph.as_default():
    explicit_grad = function.Defun(
        x=tf.float32, b=tf.float32, g=tf.float32)(self.XSquarePlusBGradient)
    # Supplying both grad_func and python_grad_func must be rejected when
    # the function is added to a graph.
    with self.assertRaisesRegexp(ValueError, "Gradient defined twice"):
      defun = self._GetFunc(
          grad_func=explicit_grad, python_grad_func=self._PythonGradient)
      defun.add_to_graph(tf.Graph())
def testWhileCapturedInputs(self):
  unary_signature = [dtypes.float32] * 2
  binary_signature = [dtypes.float32] * 3
  for use_gpu in (True, False):
    with ops.Graph().as_default() as g:
      v = variables.Variable(1.0)

      def TestCond(n, *args):
        del args
        return n < 10

      # Both loop bodies capture the variable `v` from the enclosing graph.
      @function.Defun(*unary_signature)
      def TestUnary(n, x):
        return math_ops.add(n, 1), x + n + v

      @function.Defun(*binary_signature)
      def TestBinary(n, x, x2):
        return math_ops.add(n, 1), x + n + v, x2 + v

      with self.session(graph=g, use_gpu=use_gpu):
        result_unary = functional_ops.While(
            [1.0, 0.], function.Defun(*unary_signature)(TestCond), TestUnary)
        result_binary = functional_ops.While(
            [1.0, 0., 0.], function.Defun(*binary_signature)(TestCond),
            TestBinary)
        self.evaluate(variables.global_variables_initializer())
        assert len(result_unary) == 2
        self.assertEqual([10.0, 54.0], self.evaluate(result_unary))
        assert len(result_binary) == 3
        self.assertEqual([10.0, 54.0, 9.0], self.evaluate(result_binary))

        def TestCondCapture(n, *args):
          del args
          return math_ops.to_float(n) + v < 10

        # A condition that captures `v` is expected to be rejected.
        with self.assertRaises(ValueError):
          _ = functional_ops.While(
              [1],
              function.Defun(dtypes.int32)(TestCondCapture),
              function.Defun(dtypes.int32, dtypes.float32)(TestUnary))
def test_get_defun_argspec_with_typed_non_eager_defun(self):
  # A typed non-eager defun forbids **kwargs and default values, but allows
  # *args, and its input signature may extend into the *args portion.
  input_signature = (tf.int32, tf.bool, tf.float32, tf.float32)
  typed_defun = function.Defun(*input_signature)(lambda x, y, *z: None)
  expected_argspec = inspect.ArgSpec(
      args=['x', 'y'], varargs='z', keywords=None, defaults=None)
  self.assertEqual(function_utils.get_argspec(typed_defun), expected_argspec)
def testFunctionGradientsComposition(self):
  with ops.Graph().as_default():
    square_plus_one = function.Defun(x=tf.float32)(self.XSquarePlusOne)
    two = tf.constant([2.0], name="two")
    composed = square_plus_one(square_plus_one(two))
    # Build gradient graph (should add SymbolicGradient node for function).
    grads = gradients.gradients(composed, two)
    with self.test_session() as sess:
      # Assuming XSquarePlusOne is x*x + 1: d/dx f(f(x)) at x=2 is
      # 2 * f(2) * 2 * 2 = 40.
      self.assertAllEqual([40.0], sess.run(grads)[0])
def _BuildForward(self, weights, inp, mode="cell"):
  """Builds the unrolled LSTM forward pass in one of several layouts.

  Args:
    weights: weight tensor passed to every LSTMCell invocation.
    inp: input tensor, unpacked into `self.NUM_UNROLL` steps.
    mode: "complete" (plain Python graph), "cell" (each cell is a Defun),
      "loop" (whole unrolled loop is one Defun) or "loop10" (an outer Defun
      calling a Defun that runs 10 cell steps at a time).

  Returns:
    The final `m` output, or None for an unrecognized `mode`.
  """

  def _InitState(inputs):
    """Unpacks `inputs` into NUM_UNROLL steps and builds zero m/c states."""
    steps = tf.unpack(inputs, self.NUM_UNROLL)
    return steps, tf.zeros_like(steps[0]), tf.zeros_like(steps[0])

  def _Unroll(cell_fn, w, inputs):
    """Applies `cell_fn` once per step and returns the final m."""
    steps, m, c = _InitState(inputs)
    for step in steps:
      m, c = cell_fn(step, m, c, w)
    return m

  cell = UnrollLSTMTest.LSTMCell
  if mode == "complete":
    # Constructs the complete graph in python.
    return _Unroll(cell, weights, inp)

  cell = function.Defun(
      x=tf.float32, mprev=tf.float32, cprev=tf.float32,
      weights=tf.float32)(cell)

  if mode == "cell":
    # Just represent the LSTM as a function.
    return _Unroll(cell, weights, inp)

  if mode == "loop":
    # Wraps the whole loop as a function.
    @function.Defun(tf.float32, tf.float32)
    def LSTMLoop(w, i):
      return _Unroll(cell, w, i)

    return LSTMLoop(weights, inp)

  if mode == "loop10":
    # Wraps 10 lstm steps into one function, and the whole loop into another
    # function that calls it once per group of 10 steps.
    @function.Defun(tf.float32, tf.float32, tf.float32, *([tf.float32] * 10))
    def Loop10(w, m, c, *args):
      for x in args:
        m, c = cell(x, m, c, w)
      return m, c

    @function.Defun(tf.float32, tf.float32)
    def LSTMLoop10(weights, inp):
      x, m, c = _InitState(inp)
      assert self.NUM_UNROLL % 10 == 0
      for start in range(0, self.NUM_UNROLL, 10):
        m, c = Loop10(weights, m, c, *x[start:start + 10])
      return m

    return LSTMLoop10(weights, inp)

  return None
def generic_input(processor, *args, **kwargs):
  """Calls gen_x_ops.generic_input with `processor` as a Defun.

  A plain Python callable is wrapped as a Defun taking one tf.string input.
  The processor's last output must be tf.int32 (asserted); it is excluded
  from `out_types` passed to the op.

  Args:
    processor: a `function._DefinedFunction` or a Python callable.
    *args: forwarded to gen_x_ops.generic_input.
    **kwargs: forwarded to gen_x_ops.generic_input.

  Returns:
    Whatever gen_x_ops.generic_input returns.
  """
  # pylint: disable=protected-access
  is_defined_function = isinstance(processor, function._DefinedFunction)
  if not is_defined_function:
    # Helper if processor is a python callable.
    processor = function.Defun(tf.string)(processor)
  signature = processor.definition.signature
  out_types = [tf.DType(arg.type) for arg in signature.output_arg]
  assert out_types[-1] == tf.int32, ('%s is not expected.' % out_types[-1])
  return gen_x_ops.generic_input(
      processor=processor, out_types=out_types[:-1], *args, **kwargs)
def testFunctionGradientsWithGradFunc(self):
  graph = ops.Graph()
  with graph.as_default():
    doubling_grad = function.Defun(
        x=tf.float32, b=tf.float32, g=tf.float32)(self.XSquarePlusBGradient)
    defun = self._GetFunc(grad_func=doubling_grad)
    # The SymbolicGradient node uses the grad_func above, which multiplies
    # every gradient by 2.
    grads = self._GetFuncGradients(defun, [2.0], [1.0])
    self.assertAllEqual([2 * 4.0], grads[0])
    self.assertAllEqual([2 * 1.0], grads[1])
def testFunctionGradientsBasic(self):
  """Gradients through a Defun of XSquarePlusB match the analytic values.

  Assuming XSquarePlusB computes x * x + b, at x=2 the expected values are
  dy/dx = 2x = 4 and dy/db = 1.
  """
  g = ops.Graph()
  with g.as_default():
    f = function.Defun(x=tf.float32, b=tf.float32)(self.XSquarePlusB)
    x = tf.constant([2.0], name="x")
    b = tf.constant([1.0], name="b")
    y = f(x, b)
    # Build gradient graph (should add SymbolicGradient node for function).
    grads = gradients.gradients(y, [x, b])
    with self.test_session() as sess:
      # Evaluate both gradients in one run instead of running the graph twice.
      dx, db = sess.run(grads)
      self.assertAllEqual([4.0], dx)
      self.assertAllEqual([1.0], db)
def testFunctionGradientsComposition(self):
  """Gradients through a composition f(f(x, b1), b2) of XSquarePlusB.

  Assuming XSquarePlusB computes x * x + b:
  y = (x^2 + b1)^2 + b2, so at x=2, b1=1 the expected values are
  dy/dx = 2 * (x^2 + b1) * 2x = 40 and dy/db1 = 2 * (x^2 + b1) = 10.
  """
  with ops.Graph().as_default():
    f = function.Defun(x=tf.float32, b=tf.float32)(self.XSquarePlusB)
    x = tf.constant([2.0], name="x")
    b1 = tf.constant([1.0], name="b1")
    b2 = tf.constant([1.0], name="b2")
    y = f(f(x, b1), b2)
    # Build gradient graph (should add SymbolicGradient node for function).
    grads = gradients.gradients(y, [x, b1])
    with self.test_session() as sess:
      # Evaluate both gradients in one run instead of running the graph twice.
      dx, db1 = sess.run(grads)
      self.assertAllEqual([40.0], dx)
      self.assertAllEqual([10.0], db1)
def _instantiate_op(dtype):
  """Builds the pass-through-gradients Defun op for the given `dtype`.

  Forward: returns `tf.identity(moving_avg)` and ignores `x`.
  Backward: the incoming gradient is routed to `x`; `moving_avg` receives
  None.

  Args:
    dtype: dtype used for both Defun inputs; also part of the `func_name`
      ("PassThroughGradients_<dtype name>").

  Returns:
    The Defun-wrapped op, callable as `op(x, moving_avg)`.
  """

  def _backward(op, grad):
    """Sends the whole incoming gradient to the first input only."""
    del op
    return grad, None

  def _forward(x, moving_avg):
    del x
    return tf.identity(moving_avg)

  op_name = "PassThroughGradients_{}".format(dtype.name)
  defun = function.Defun(
      dtype, dtype, python_grad_func=_backward, func_name=op_name)
  return defun(_forward)
def register(self, input_types, op, grad_op, name=None):
  """
  Wraps `op` in a Defun whose gradient is `grad_op`, memoized per `op`.

  If `op` was registered before, the previously created wrapper is
  returned and `grad_op`/`name` are ignored.

  :param list[tf.DType] input_types: Defun input dtypes, one per op argument
  :param (tf.Tensor) -> tf.Tensor op: forward function to wrap
  :param (tf.Operation, tf.Tensor) -> tf.Tensor grad_op: python gradient function
  :param str name: optional func_name for the Defun
  :return: op
  :rtype: (tf.Tensor) -> tf.Tensor
  """
  registry = self.registered_ops
  if op in registry:
    return registry[op]

  from tensorflow.python.framework import function
  wrapped = function.Defun(
      *input_types, python_grad_func=grad_op, func_name=name)(op)
  registry[op] = wrapped
  # One instance of the new op must be added to the graph right away, see
  # https://github.com/tensorflow/tensorflow/issues/6804
  dummy_inputs = [tf.placeholder(dtype) for dtype in input_types]
  wrapped(*dummy_inputs)
  return wrapped
def _BuildForward(self, use_func=True, num_unroll=100):
  """Builds `num_unroll` LSTM steps on seeded random inputs.

  Args:
    use_func: wrap FunctionTest.LSTMCell in a Defun when True.
    num_unroll: number of time steps.

  Returns:
    A (weights, m, c) tuple with the final m/c states.
  """
  batch_size, lstm_dims = 16, 32

  cell = FunctionTest.LSTMCell
  if use_func:
    cell = function.Defun(
        x=tf.float32, mprev=tf.float32, cprev=tf.float32,
        weights=tf.float32)(cell)

  state_shape = [batch_size, lstm_dims]
  m = tf.zeros(shape=state_shape)
  c = tf.zeros(shape=state_shape)
  weights = tf.random_uniform(
      [2 * lstm_dims, 4 * lstm_dims], -1, 1, seed=123456)
  inputs = tf.random_uniform(
      [num_unroll, batch_size, lstm_dims], seed=654321)
  # tf.unpack yields exactly `num_unroll` per-step tensors.
  for step_input in tf.unpack(inputs):
    m, c = cell(step_input, m, c, weights)
  return weights, m, c
def BatchMatMul(a, b):
  """Batched matmul, optionally computed in float32.

  Reads two environment variables (compared against the string "true"):
    use_fp32_batch_matmul: cast inputs to float32, multiply, and cast back
      to `a`'s dtype.
    xla_compile: when set, the float32 path is not wrapped in a noinline
      Defun (the casts are left to XLA).
  """
  use_fp32_batch_matmul = (os.environ.get("use_fp32_batch_matmul") == "true")
  xla_compile = (os.environ.get("xla_compile") == "true")

  if not use_fp32_batch_matmul:
    return tf.matmul(a, b)

  def DoFn(a, b):
    out_dtype = a.dtype
    product = tf.matmul(tf.to_float(a), tf.to_float(b))
    return tf.cast(product, out_dtype)

  if xla_compile:
    # If xla_compile, leave to xla to handle the casts.
    return DoFn(a, b)

  # Without xla_compile, the casts are wrapped in a noinline Defun, whose
  # output needs its static shape restored.
  res = function.Defun(noinline=True)(DoFn)(a, b)
  res.set_shape((None, None, b.shape[-1].value))
  return res
def testFunctionWithResourcesOnDifferentDevices(self):
  # TODO(akshayka): Remove the `skipTest` once we can whitelist ops as
  # safe to be invoked with resources on different devices.
  self.skipTest("The Placer disallows ops with resource inputs "
                "on different devices.")

  # One resource variable on each of two CPUs and one on the GPU.
  placements = (
      ("/cpu:0", "v_cpu_zero"),
      ("/cpu:1", "v_cpu_one"),
      ("/gpu:0", "v_gpu"),
  )
  resources = []
  for device_name, var_name in placements:
    with ops.device(device_name):
      resources.append(
          resource_variable_ops.ResourceVariable(
              [0.0, 1.0, 2.0], name=var_name))

  def sum_gather():
    # Per variable: sum of the elements at indices 1 and 2.
    return tuple(
        math_ops.reduce_sum(array_ops.gather(var, [1, 2]))
        for var in resources)

  defined = function.Defun()(sum_gather)
  session_config = config_pb2.ConfigProto(
      allow_soft_placement=False,
      log_device_placement=True,
      device_count={"CPU": 2})
  with self.test_session(config=session_config) as sess:
    sess.run(variables.global_variables_initializer())
    expected = sess.run(sum_gather())
    partitioned = functional_ops.partitioned_call(
        args=defined.captured_inputs, f=defined)
    self.assertAllEqual(expected, sess.run(partitioned))
def _BuildForward(self, weights, inp, mode="cell"):
  """Builds the unrolled LSTM forward pass in one of several layouts.

  Defun input types are given by keyword, matching each function's argument
  names.

  Args:
    weights: weight tensor passed to every LSTMCell invocation.
    inp: input tensor, unpacked into `self.NUM_UNROLL` steps.
    mode: "complete", "cell", "loop" or "loop10".

  Returns:
    The final `m` output, or None for an unrecognized `mode`.
  """

  def _InitState(inputs):
    """Unpacks `inputs` into NUM_UNROLL steps and builds zero m/c states."""
    steps = tf.unpack(inputs, self.NUM_UNROLL)
    return steps, tf.zeros_like(steps[0]), tf.zeros_like(steps[0])

  def _Unroll(cell_fn, w, inputs):
    """Applies `cell_fn` once per step and returns the final m."""
    steps, m, c = _InitState(inputs)
    for step in steps:
      m, c = cell_fn(step, m, c, w)
    return m

  cell = UnrollLSTMTest.LSTMCell
  if mode == "complete":
    # Constructs the complete graph in python.
    return _Unroll(cell, weights, inp)

  cell = function.Defun(
      x=tf.float32, mprev=tf.float32, cprev=tf.float32,
      weights=tf.float32)(cell)

  if mode == "cell":
    # Just represent the LSTM as a function.
    return _Unroll(cell, weights, inp)

  if mode == "loop":
    # Wraps the whole loop as a function.
    @function.Defun(w=tf.float32, i=tf.float32)
    def LSTMLoop(w, i):
      return _Unroll(cell, w, i)

    return LSTMLoop(weights, inp)

  if mode == "loop10":
    # Wraps 10 lstm steps into one function, and the whole loop into another
    # function that calls it. Groups 10 steps at a time.
    # TODO(zhifengc): Any way to make the syntax less hideous?
    @function.Defun(m=tf.float32,
                    c=tf.float32,
                    w=tf.float32,
                    x0=tf.float32,
                    x1=tf.float32,
                    x2=tf.float32,
                    x3=tf.float32,
                    x4=tf.float32,
                    x5=tf.float32,
                    x6=tf.float32,
                    x7=tf.float32,
                    x8=tf.float32,
                    x9=tf.float32)
    def Loop10(w, m, c, x0, x1, x2, x3, x4, x5, x6, x7, x8, x9):
      for x in [x0, x1, x2, x3, x4, x5, x6, x7, x8, x9]:
        m, c = cell(x, m, c, w)
      return m, c

    @function.Defun(weights=tf.float32, inp=tf.float32)
    def LSTMLoop10(weights, inp):
      x, m, c = _InitState(inp)
      assert self.NUM_UNROLL % 10 == 0
      for start in range(0, self.NUM_UNROLL, 10):
        m, c = Loop10(weights, m, c, *x[start:start + 10])
      return m

    return LSTMLoop10(weights, inp)

  return None
def GenericInput(processor, *args, **kwargs):
  """Builds a generic input pipeline from a per-record `processor`.

  Example usage::

    def ParseRecord(record):
      # Given a tf.string record, return a (NestedMap, bucketing key) pair.
      feature_map = ...
      features = tf.parse_single_example(record, feature_map)
      # Each example is represented by a NestedMap of tensors (without a
      # batch dimension).
      example = py_utils.NestedMap(field1=..., field2=...)
      # bucketing_key is a scalar convertible to tf.int32.
      # Use 1 if all examples are of the same size.
      bucketing_key = 1
      return example, bucketing_key

    input_batch, bucket_keys = GenericInput(ParseRecord, file_pattern=..., ...)
    # input_batch is a NestedMap of tensors, where dim 0 of each tensor
    # represents the batch dimension.
    input_batch.field1 = ...

  Args:
    processor: a function that takes a string record and returns a tuple
      (output, bucketing_key). `output` must be a non-empty NestedMap or a
      non-empty list of tensors describing one example. `bucketing_key` must
      be a scalar convertible to a tf.int32 tensor (e.g. sequence length for
      sequence inputs).
    *args: additional args for x_ops.generic_input.
    **kwargs: additional keyword args for x_ops.generic_input.

  Returns:
    A tuple (outputs, bucket_keys):
      - outputs: same structure as `processor`'s `output`, with an extra
        leading batch dimension on every tensor.
      - bucket_keys: a tf.int32 vector.
  """
  # Filled in while the processor Defun is traced; used to re-pack the flat
  # op outputs into the processor's structure.
  output_tmpl = py_utils.NestedMap()

  def _FlatOutputProcessor(inputs):
    """Returns a flattened list of 'processor(inputs)'."""
    output, bucketing_key = processor(inputs)
    if not isinstance(output, list):
      assert isinstance(output, py_utils.NestedMap), '{}'.format(output)
      assert output
      assert all(isinstance(leaf, tf.Tensor) for leaf in output.Flatten()), (
          '{}'.format(output.DebugString()))
    else:
      assert output
      assert all(isinstance(leaf, tf.Tensor) for leaf in output), (
          '{}'.format(output))
    bucketing_key = tf.to_int32(bucketing_key)
    tf.logging.debug('Processor outputs=%s bucketing_key=%s', output,
                     bucketing_key)
    output_tmpl.values = output
    flat_output_tmpl = output_tmpl.Flatten()
    tf.logging.debug('Processor flat outputs=%s', flat_output_tmpl)
    tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                     function.get_extra_inputs(), function.get_extra_args(),
                     function.get_extra_vars())
    assert not function.get_extra_args(), (
        'fns {} is not pure: extra_args={}'.format(
            processor, function.get_extra_args()))
    # The bucketing key is always the last output.
    return flat_output_tmpl + [bucketing_key]

  proc_fn = function.Defun(tf.string)(_FlatOutputProcessor)
  signature = proc_fn.definition.signature
  out_types = [tf.DType(arg.type) for arg in signature.output_arg]
  assert out_types[-1] == tf.int32, ('%s is not expected.' % out_types[-1])
  flat_outputs, bucket_keys = ops.gen_x_ops.generic_input(
      processor=proc_fn, out_types=out_types[:-1], *args, **kwargs)
  tf.logging.debug('x_ops.generic_input flat_outputs=%s', flat_outputs)
  # Pack flat_outputs to outputs.
  outputs = output_tmpl.Pack(flat_outputs).values
  tf.logging.debug('x_ops.generic_input outputs=%s', outputs)
  return outputs, bucket_keys
def GenericInput(processor, *args, **kwargs):
  """Builds a generic input pipeline from a per-record `processor`.

  Example usage::

    def ParseRecord(record):
      # Given a tf.string record, return a (NestedMap, bucketing key) pair.
      feature_map = ...
      features = tf.parse_single_example(record, feature_map)
      # Each example is represented by a NestedMap of tensors (without a
      # batch dimension).
      example = py_utils.NestedMap(field1=..., field2=...)
      # bucketing_key is an int scalar tensor.
      # Use 1 if all examples are of the same size.
      bucketing_key = tf.to_int32(1)
      return example, bucketing_key

    input_batch = GenericInput(ParseRecord, file_pattern=..., ...)
    # input_batch is a NestedMap of tensors, where dim 0 of each tensor
    # represents the batch dimension.
    input_batch.field1 = ...

  Args:
    processor: a function that takes a string record and returns a sequence
      (of length > 1) of tensors or NestedMaps describing one example. Its
      last element must be an int32 scalar tensor holding the bucketing key
      (e.g. sequence length for sequence inputs).
    *args: additional args for x_ops.generic_input.
    **kwargs: additional keyword args for x_ops.generic_input.

  Returns:
    A list of tensors or NestedMaps, similar to `processor`'s return, except:
      * The bucket key is not included in the output.
      * Every tensor has an additional leading dimension 0 that represents
        the batch dimension.
  """
  # Filled in while the processor Defun is traced; used to re-pack the flat
  # op outputs into the processor's structure.
  output_tmpl = py_utils.NestedMap()

  def _FlatOutputProcessor(inputs):
    """Returns a flattened list of 'processor(inputs)'."""
    outputs = processor(inputs)
    tf.logging.debug('Processor outputs=%s', outputs)
    assert len(outputs) > 1, outputs
    # Store a list so that each element is flattened separately.
    output_tmpl.values = list(outputs)
    flat_outputs = output_tmpl.Flatten()
    tf.logging.debug('Processor flat outputs=%s', flat_outputs)
    tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                     function.get_extra_inputs(), function.get_extra_args(),
                     function.get_extra_vars())
    assert not function.get_extra_args(), (
        'fns {} is not pure: extra_args={}'.format(
            processor, function.get_extra_args()))
    return flat_outputs

  proc_fn = function.Defun(tf.string)(_FlatOutputProcessor)
  signature = proc_fn.definition.signature
  out_types = [tf.DType(arg.type) for arg in signature.output_arg]
  assert out_types[-1] == tf.int32, ('%s is not expected.' % out_types[-1])
  flat_outputs = ops.gen_x_ops.generic_input(
      processor=proc_fn, out_types=out_types[:-1], *args, **kwargs)
  tf.logging.debug('x_ops.generic_input flat_outputs=%s', flat_outputs)

  if output_tmpl:
    # The op does not return the bucketing key, so drop it from the template
    # before packing the flat outputs back into structure.
    output_tmpl.values.pop(-1)
    outputs = output_tmpl.Pack(flat_outputs).values
    tf.logging.debug('x_ops.generic_input outputs=%s', outputs)
    return outputs
  return flat_outputs
def _GetFunc(cls, **kwargs):
  """Returns `cls.XSquarePlusB` as a Defun with two float32 inputs.

  Extra keyword arguments (e.g. grad_func) are forwarded to the Defun.
  """
  input_types = (dtypes.float32, dtypes.float32)
  decorator = framework_function.Defun(*input_types, **kwargs)
  return decorator(cls.XSquarePlusB)
def _bahdanau_score(processed_query, keys, normalize, v, g, b):
  """Implements Bahdanau-style (additive) scoring function.

  This attention has two forms. The first is Bahdanau attention,
  as described in:

  Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio.
  "Neural Machine Translation by Jointly Learning to Align and Translate."
  ICLR 2015. https://arxiv.org/abs/1409.0473

  The second is the normalized form. This form is inspired by the
  weight normalization article:

  Tim Salimans, Diederik P. Kingma.
  "Weight Normalization: A Simple Reparameterization to Accelerate
   Training of Deep Neural Networks."
  https://arxiv.org/abs/1602.07868

  To enable the second form, set `normalize=True`.

  Environment variables (compared against the string "true"):
    use_defun: wrap the score computation in a `function.Defun`.
    use_xla: in the normalized form, request XLA compilation of that Defun.

  Args:
    processed_query: Tensor, shape `[batch_size, num_units]` to compare to keys.
    keys: Processed memory, shape `[batch_size, max_time, num_units]`.
    normalize: Whether to normalize the score function.
    v: Attention vector of shape `[num_units]`; it is matrix-multiplied with
      the `tanh` activations.
    g: Gain applied to `v / ||v||`. Only used when `normalize=True`.
    b: Bias added inside the `tanh`. Only used when `normalize=True`.

  Returns:
    A `[batch_size, max_time]` tensor of unnormalized score values.
  """
  # Reshape from [batch_size, ...] to [batch_size, 1, ...] for broadcasting.
  processed_query = tf.expand_dims(processed_query, 1)
  if normalize:
    # normed_v = g * v / ||v||
    def NormalizedAttenFwd(keys, processed_query, g, v, b):
      """normalized atten."""
      normed_v = g * v * tf.rsqrt(tf.reduce_sum(tf.square(v)))
      batch = tf.shape(keys)[0]
      max_time = tf.shape(keys)[1]
      units = tf.shape(keys)[-1]

      # [batch, time, dim]
      activation = tf.tanh(keys + processed_query + b)
      # [batch * time, dim]
      activation = tf.reshape(activation, [batch * max_time, units])
      # [dim, 1]
      v_col = tf.expand_dims(normed_v, -1)
      # [batch * time, 1] -> [batch * time]
      y = tf.squeeze(tf.matmul(activation, v_col), axis=1)
      y = tf.reshape(y, [batch, max_time])
      return y

    use_xla = os.environ.get("use_xla") == "true"

    def NormalizedAtten(keys, processed_query, g, v, b):
      return NormalizedAttenFwd(keys, processed_query, g, v, b)

    fn = NormalizedAtten
    if os.environ.get("use_defun") == "true":
      fn = function.Defun(compiled=use_xla)(fn)
    res = fn(keys, processed_query, g, v, b)
    # Restore the static shape (a Defun output has no static shape).
    res.set_shape((None, keys.shape[1]))
    return res
  else:

    def _Atten(keys, processed_query, v):
      """atten."""
      batch = tf.shape(keys)[0]
      max_time = tf.shape(keys)[1]
      units = tf.shape(keys)[-1]
      activation = tf.tanh(keys + processed_query)
      activation = tf.reshape(activation, [batch * max_time, units])
      v_col = tf.expand_dims(v, -1)
      y = tf.squeeze(tf.matmul(activation, v_col), axis=1)
      y = tf.reshape(y, [batch, max_time])
      return y

    fn = _Atten
    if os.environ.get("use_defun") == "true":
      fn = function.Defun()(fn)
    return fn(keys, processed_query, v)
def _softmax_cross_entropy_loss(self, logits, labels, label_smoothing):
  """Compute softmax loss or sampled softmax loss.

  Reads two environment variables, compared against the string "true" and
  treated as off when unset:
    use_defun: wrap the last-axis reductions (and the smoothing term) in
      `function.Defun`s.
    use_xla: pass `compiled=True` to those Defuns.

  Args:
    logits: logits tensor; classes are on the last axis, whose static size
      is read when label smoothing is enabled.
    labels: labels passed to `math_utils.sparse_softmax_crossent_with_logits`.
    label_smoothing: smoothing weight; smoothing is applied only when
      `abs(label_smoothing) > 1e-3`.

  Returns:
    The float32 cross-entropy from `sparse_softmax_crossent_with_logits`,
    mixed with the smoothing term when smoothing is enabled.
  """
  # .get() so that an unset variable means "off" rather than raising
  # KeyError, matching how the other env toggles here are read.
  use_defun = os.environ.get("use_defun") == "true"
  use_xla = os.environ.get("use_xla") == "true"

  # @function.Defun(noinline=True, compiled=use_xla)
  def ComputePositiveCrossent(labels, logits):
    crossent = math_utils.sparse_softmax_crossent_with_logits(
        labels=labels, logits=logits)
    return crossent

  crossent = ComputePositiveCrossent(labels, logits)
  assert crossent.dtype == tf.float32

  def _safe_shape_div(x, y):
    """Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`."""
    return x // tf.maximum(y, 1)

  @function.Defun(tf.float32, tf.float32, compiled=use_xla)
  def ReduceSumGrad(x, grad):
    """Gradient of ReduceSum: tiles `grad` back to the shape of `x`."""
    input_shape = tf.shape(x)
    # TODO(apassos) remove this once device placement for eager ops makes more
    # sense.
    with tf.colocate_with(input_shape):
      output_shape_kept_dims = math_ops.reduced_shape(input_shape, -1)
      tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
    grad = tf.reshape(grad, output_shape_kept_dims)
    return tf.tile(grad, tile_scaling)

  def ReduceSum(x):
    """Sums `x` over its last axis."""
    return tf.reduce_sum(x, axis=-1)

  if use_defun:
    ReduceSum = function.Defun(
        tf.float32,
        compiled=use_xla,
        noinline=True,
        grad_func=ReduceSumGrad)(ReduceSum)

  if abs(label_smoothing) > 1e-3:
    # pylint:disable=invalid-name
    def ComputeNegativeCrossentFwd(logits):
      """Sums log_softmax(logits) over the last axis, computed in float32."""
      # [time, batch, dim]
      # [time, batch]
      max_logits = tf.reduce_max(logits, axis=-1)
      # [time, batch, dim]
      shifted_logits = logits - tf.expand_dims(max_logits, axis=-1)
      # Always compute loss in fp32
      shifted_logits = tf.to_float(shifted_logits)
      # [time, batch]
      log_sum_exp = tf.log(ReduceSum(tf.exp(shifted_logits)))
      # [time, batch, dim] - [time, batch, 1] --> reduce_sum(-1) -->
      # [time, batch]
      neg_crossent = ReduceSum(
          shifted_logits - tf.expand_dims(log_sum_exp, axis=-1))
      return neg_crossent

    def ComputeNegativeCrossent(logits):
      return ComputeNegativeCrossentFwd(logits)

    if use_defun:
      ComputeNegativeCrossent = function.Defun(
          compiled=use_xla)(ComputeNegativeCrossent)

    neg_crossent = ComputeNegativeCrossent(logits)
    neg_crossent = tf.to_float(neg_crossent)
    num_labels = logits.shape[-1].value
    crossent = (1.0 - label_smoothing) * crossent - (
        label_smoothing / tf.to_float(num_labels) * neg_crossent)
    # pylint:enable=invalid-name

  return crossent
def _scan(fn, elems, initial, reverse=False, inclusive=False, final_only=False):
  """Repeatedly applies callable `fn` to a sequence of elements.

  Implemented by functional_ops.While (a Python while-loop when executing
  eagerly); tpu friendly, no gradient. This is similar to functional_ops.scan
  but significantly faster on tpu/gpu for the forward backward use case.

  Examples:
    scan(lambda a, e: a + e, [1.0, 2.0, 3.0], 1.0) => [2.0, 4.0, 7.0]

    Multiple accumulators:
      scan(lambda a, e: (a[0] + e, a[1] * e), [1.0, 2.0, 3.0], (0.0, 1.0))

    Multiple inputs:
      scan(lambda a, e: a + (e[0] * e[1]), (elems1, elems2), 0.0)

  Args:
    fn: callable, fn(accumulators, element) return new accumulator values.
      The (possibly nested) sequence of accumulators is the same as `initial`
      and the return value must have the same structure.
    elems: A (possibly nested) tensor which will be unpacked along the first
      dimension. The resulting slices will be the second argument to fn. The
      first dimension of all nested input tensors must be the same.
    initial: A tensor or (possibly nested) sequence of tensors with initial
      values for the accumulators.
    reverse: (optional) True enables scan and output elems in reverse order.
    inclusive: (optional) True includes the initial accumulator values in the
      output. Length of output will be len(elem sequence) + 1. Not
      meaningful if final_only is True.
    final_only: (optional) When True, return only the final accumulated
      values, not the concatenation of accumulated values for each input.

  Returns:
    A (possibly nested) sequence of tensors with the same structure as
    `initial`. With final_only=True these are the final accumulator values;
    otherwise each tensor stacks the per-step accumulator values along a new
    leading dimension.
  """
  flat_elems = [ops.convert_to_tensor(x) for x in nest.flatten(elems)]
  num_elems = array_ops.shape(flat_elems[0])[0]
  pack_elems = lambda x: nest.pack_sequence_as(structure=elems, flat_sequence=x)
  flat_initial = [ops.convert_to_tensor(x) for x in nest.flatten(initial)]
  pack = lambda x: nest.pack_sequence_as(structure=initial, flat_sequence=x)
  accum_dtypes = [x.dtype for x in flat_initial]
  num_accums = len(flat_initial)

  # Types for the loop arguments: counter i, num_elems, [outputs] (omitted
  # when final_only), [accumulators].
  if final_only:
    loop_dtypes = [dtypes.int32, dtypes.int32] + accum_dtypes
  else:
    loop_dtypes = [dtypes.int32, dtypes.int32] + accum_dtypes + accum_dtypes

  # TODO(tombagby): Update to tfe.defun
  def cond(i, num_elems, *args):
    # Counts down to 0 when reversing, otherwise up to num_elems.
    del args
    return i >= 0 if reverse else i < num_elems

  # The loop *args are [output tensors] + [accumulator tensors] which must
  # be paired. Each output corresponds to one accumulator.
  def body(i, num_elems, *args):
    """Loop body."""
    # Counter must be a scalar for gather / in-place update indexing.
    i.set_shape([])
    if final_only:
      accum = args
    else:
      out, accum = args[:num_accums], args[num_accums:]
    slices = [array_ops.gather(e, i) for e in flat_elems]
    accum = fn(pack(accum), pack_elems(slices))
    flat_accum = nest.flatten(accum)
    if final_only:
      new_out = []
    else:
      # For a forward inclusive scan, row 0 holds the initial value, so step
      # i is written to row i + 1.
      update_i = i + 1 if inclusive and not reverse else i
      new_out = [
          inplace_ops.alias_inplace_update(x, update_i, y)
          for x, y in zip(out, flat_accum)
      ]
    i = i - 1 if reverse else i + 1
    return [i, num_elems] + new_out + flat_accum

  init_i = (array_ops.shape(flat_elems[0])[0] - 1 if reverse
            else constant_op.constant(0, dtype=dtypes.int32))
  outputs = []
  if not final_only:
    # Preallocate one output buffer per accumulator, with one extra row for
    # the initial value when inclusive.
    num_outputs = array_ops.shape(flat_elems[0])[0] + (1 if inclusive else 0)
    for initial_accum in flat_initial:
      out_shape = array_ops.concat(
          [[num_outputs], array_ops.shape(initial_accum)], 0)
      out = inplace_ops.empty(out_shape, dtype=initial_accum.dtype, init=True)
      if inclusive:
        # Initial value goes in row 0 (forward) or the last row (reverse).
        out = inplace_ops.alias_inplace_add(
            out, init_i + (1 if reverse else 0), initial_accum)
      outputs.append(out)
  loop_in = [init_i, num_elems] + outputs + flat_initial
  # Indices of int32/int64 loop variables, passed to While as `hostmem`.
  hostmem = [
      i for i, x in enumerate(loop_in)
      if x.dtype.base_dtype in (dtypes.int32, dtypes.int64)
  ]

  if context.executing_eagerly():
    loop_results = loop_in
    while cond(*loop_results):
      loop_results = body(*loop_results)
  else:
    # TODO(tombagby): Update to while_v2.
    cond = function.Defun(*loop_dtypes)(cond)
    body = function.Defun(*loop_dtypes)(body)
    loop_results = functional_ops.While(loop_in, cond, body, hostmem=hostmem)
  # Skip [i, num_elems]; the next num_accums values are the output buffers,
  # or the final accumulators when final_only.
  out = loop_results[2:num_accums + 2]
  return pack(out)
def _GetFunc(cls, **kwargs):
  """Returns `cls.XSquarePlusB` as a Defun with float32 inputs `x` and `b`.

  Extra keyword arguments (e.g. grad_func) are forwarded to the Defun.
  """
  input_types = dict(x=tf.float32, b=tf.float32)
  decorator = function.Defun(**dict(input_types, **kwargs))
  return decorator(cls.XSquarePlusB)