Example #1
 def test_is_defun(self):
   self.assertTrue(function_utils.is_defun(function.Defun()(lambda x: None)))
   self.assertTrue(
       function_utils.is_defun(function.Defun(tf.int32)(lambda x: None)))
   self.assertFalse(function_utils.is_defun(function.Defun))
   self.assertFalse(function_utils.is_defun(lambda x: None))
   self.assertFalse(function_utils.is_defun(None))
Example #2
def _scale_gradient_op(dtype):
  """Create an op that scales gradients using a Defun.

  The tensorflow Defun decorator creates an op, and tensorflow caches these ops
  automatically according to `func_name`. Using a Defun decorator twice with the
  same `func_name` does not create a new op; instead, the cached op is used.

  This method produces a new op the first time it is called with a given `dtype`
  argument, and then uses the cached op each time it is called after that with
  the same `dtype`. The scale value is given as an argument for the forward pass
  method so that it can be used in the backward pass.

  Args:
    dtype: the dtype of the net whose gradient is being scaled.

  Returns:
    The op that scales gradients.
  """

  def scale_gradient_backward(op, grad):
    scale = op.inputs[1]
    scaled_grad = grad * scale
    return scaled_grad, None

  # Note that if the forward pass implementation involved the creation of ops,
  # _scale_gradient_op would require some memoization mechanism.
  def scale_gradient_forward(x, scale):
    del scale  # Unused.
    return x

  func_name = "ScaleGradient_{}".format(dtype.name)
  return function.Defun(
      dtype, dtype,
      python_grad_func=scale_gradient_backward,
      func_name=func_name)(scale_gradient_forward)
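
A minimal usage sketch, not part of the original example: assuming graph-mode TF 1.x, the returned op acts as an identity in the forward pass while multiplying the incoming gradient by `scale` in the backward pass. The `net` placeholder and the 0.5 scale below are hypothetical.

scale_gradient = _scale_gradient_op(tf.float32)
net = tf.placeholder(tf.float32, shape=[None, 128])   # hypothetical input
scaled = scale_gradient(net, tf.constant(0.5))         # forward pass: identity of net
grad = tf.gradients(tf.reduce_sum(scaled), net)[0]     # backward pass: ones * 0.5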
Example #3
 def test_get_defun_argspec_with_untyped_non_eager_defun(self):
   # In a non-eager defun with no input signature, the same restrictions as in
   # a typed defun apply.
   self.assertEqual(
       function_utils.get_argspec(function.Defun()(lambda x, y, *z: None)),
       inspect.ArgSpec(
           args=['x', 'y'], varargs='z', keywords=None, defaults=None))
Example #4
  def testFunctionWithResourcesOnDifferentDevices(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPUs available.")

    with ops.device("/cpu:0"):
      v_cpu_zero = resource_variable_ops.ResourceVariable(
          [0.0, 1.0, 2.0], name="v_cpu_zero")

    with ops.device("/cpu:1"):
      v_cpu_one = resource_variable_ops.ResourceVariable(
          [0.0, 1.0, 2.0], name="v_cpu_one")

    with ops.device("/gpu:0"):
      v_gpu = resource_variable_ops.ResourceVariable(
          [0.0, 1.0, 2.0], name="v_gpu")

    def sum_gather():
      cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu_zero, [1, 2]))
      also_cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu_one, [1, 2]))
      gpu_result = math_ops.reduce_sum(array_ops.gather(v_gpu, [1, 2]))
      return cpu_result, also_cpu_result, gpu_result

    defined = function.Defun()(sum_gather)
    with self.test_session(
        config=config_pb2.ConfigProto(
            allow_soft_placement=False,
            log_device_placement=True,
            device_count={"CPU": 2})) as sess:
      sess.run(variables.global_variables_initializer())
      expected = sess.run(sum_gather())
      result = sess.run(
          functional_ops.partitioned_call(
              args=defined.captured_inputs, f=defined))
      self.assertAllEqual(expected, result)
Example #5
    def _instantiate_op(dtype):
        """Instantiate constrain to range op constructor for given dtype."""
        def constrain_to_range_forward(x, clip_value_min, clip_value_max):
            return tf.clip_by_value(x, clip_value_min, clip_value_max)

        def constrain_to_range_backward(op, grad):
            """Forwards the gradients moving the inputs within the valid range."""
            x = op.inputs[0]
            clip_value_min = op.inputs[1]
            clip_value_max = op.inputs[2]
            zeros = tf.zeros_like(grad)

            condition = tf.logical_and(x < clip_value_min, grad < 0)
            grad = tf.where(condition, zeros, grad)

            condition = tf.logical_and(x > clip_value_max, grad > 0)
            grad = tf.where(condition, zeros, grad)
            return grad, None, None

        func_name = 'ConstrainToRange_{}'.format(dtype.name)
        return function.Defun(dtype,
                              dtype,
                              dtype,
                              python_grad_func=constrain_to_range_backward,
                              func_name=func_name)(constrain_to_range_forward)
Example #6
def _clip_gradient_op(dtype):
    """Create an op that clips gradients using a Defun.

  The tensorflow Defun decorator creates an op, and tensorflow caches these ops
  automatically according to `func_name`. Using a Defun decorator twice with the
  same `func_name` does not create a new op; instead, the cached op is used.

  This method produces a new op the first time it is called with a given `dtype`
  argument, and then uses the cached op each time it is called after that with
  the same `dtype`. The min and max clip values are given as arguments for the
  forward pass method so that they can be used in the backward pass.

  Args:
    dtype: the dtype of the net whose gradient is being clipped.

  Returns:
    The op that clips gradients.
  """
    def clip_gradient_backward(op, grad):
        clip_value_min = op.inputs[1]
        clip_value_max = op.inputs[2]
        clipped_grad = tf.clip_by_value(grad, clip_value_min, clip_value_max)
        return clipped_grad, None, None

    def clip_gradient_forward(x, clip_value_min, clip_value_max):
        del clip_value_min  # Unused.
        del clip_value_max  # Unused.
        return x

    func_name = "ClipGradient_{}".format(dtype.name)
    return function.Defun(dtype,
                          dtype,
                          dtype,
                          python_grad_func=clip_gradient_backward,
                          func_name=func_name)(clip_gradient_forward)
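
A hedged sketch of the effect, assuming graph-mode TF 1.x; the constants are illustrative only. The forward pass is an identity, while the gradient flowing back through the op is clipped to the given range.

clip_gradient = _clip_gradient_op(tf.float32)
x = tf.constant([3.0])
y = clip_gradient(x, tf.constant(-1.0), tf.constant(1.0)) * 5.0
dy_dx = tf.gradients(y, x)[0]  # the upstream gradient of 5.0 is clipped to 1.0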
Example #7
 def testFunctionGradientWithGradFuncAndRegistration(self):
   g = ops.Graph()
   with g.as_default():
     grad_func = function.Defun(x=tf.float32, b=tf.float32, g=tf.float32)(
         self.XSquarePlusBGradient)
     with self.assertRaisesRegexp(ValueError, "Gradient defined twice"):
       f = self._GetFunc(grad_func=grad_func,
                         python_grad_func=self._PythonGradient)
       f.add_to_graph(tf.Graph())
Example #8
    def testWhileCapturedInputs(self):
        for use_gpu in (True, False):
            with ops.Graph().as_default() as g:
                v = variables.Variable(1.0)

                def TestCond(n, *args):
                    del args
                    return n < 10

                @function.Defun(*[dtypes.float32] * 2)
                def TestUnary(n, x):
                    return math_ops.add(n, 1), x + n + v

                @function.Defun(*[dtypes.float32] * 3)
                def TestBinary(n, x, x2):
                    return math_ops.add(n, 1), x + n + v, x2 + v

                with self.session(graph=g, use_gpu=use_gpu) as sess:
                    result_unary = functional_ops.While(
                        [1.0, 0.],
                        function.Defun(*[dtypes.float32] * 2)(TestCond),
                        TestUnary)
                    result_binary = functional_ops.While(
                        [1.0, 0., 0.],
                        function.Defun(*[dtypes.float32] * 3)(TestCond),
                        TestBinary)
                    self.evaluate(variables.global_variables_initializer())
                    assert len(result_unary) == 2
                    self.assertEqual([10.0, 54.0], self.evaluate(result_unary))
                    assert len(result_binary) == 3
                    self.assertEqual([10.0, 54.0, 9.0],
                                     self.evaluate(result_binary))

                    def TestCondCapture(n, *args):
                        del args
                        return math_ops.to_float(n) + v < 10

                    with self.assertRaises(ValueError):
                        _ = functional_ops.While(
                            [1],
                            function.Defun(dtypes.int32)(TestCondCapture),
                            function.Defun(dtypes.int32,
                                           dtypes.float32)(TestUnary))
Example #9
 def test_get_defun_argspec_with_typed_non_eager_defun(self):
   # In a non-eager defun with a defined input signature, **kwargs or default
   # values are not allowed, but *args are, and the input signature may
   # overlap with *args.
   self.assertEqual(
       function_utils.get_argspec(
           function.Defun(tf.int32, tf.bool, tf.float32,
                          tf.float32)(lambda x, y, *z: None)),
       inspect.ArgSpec(
           args=['x', 'y'], varargs='z', keywords=None, defaults=None))
Example #10
    def testFunctionGradientsComposition(self):
        with ops.Graph().as_default():
            f = function.Defun(x=tf.float32)(self.XSquarePlusOne)
            two = tf.constant([2.0], name="two")
            y = f(f(two))
            # Build gradient graph (should add SymbolicGradient node for function).
            grads = gradients.gradients(y, two)

            with self.test_session() as sess:
                self.assertAllEqual([40.0], sess.run(grads)[0])
Example #11
  def _BuildForward(self, weights, inp, mode="cell"):

    def Loop(cell, w, i):
      x = tf.unpack(i, self.NUM_UNROLL)
      m = tf.zeros_like(x[0])
      c = tf.zeros_like(x[0])
      for i in range(self.NUM_UNROLL):
        m, c = cell(x[i], m, c, w)
      return m

    cell = UnrollLSTMTest.LSTMCell
    if mode == "complete":
      # Constructs the complete graph in python.
      return Loop(cell, weights, inp)

    cell = function.Defun(x=tf.float32,
                          mprev=tf.float32,
                          cprev=tf.float32,
                          weights=tf.float32)(cell)
    if mode == "cell":
      # Just represent the LSTM as a function.
      return Loop(cell, weights, inp)

    if mode == "loop":
      # Wraps the whole loop as a function.
      @function.Defun(tf.float32, tf.float32)
      def LSTMLoop(w, i):
        return Loop(cell, w, i)

      return LSTMLoop(weights, inp)

    if mode == "loop10":
      # Wraps 10 lstm steps into one function, and the whole loop
      # into another that calls the former.

      # Groups 10 steps at a time.
      @function.Defun(tf.float32, tf.float32, tf.float32,
                      *([tf.float32] * 10))
      def Loop10(w, m, c, *args):
        for x in args:
          m, c = cell(x, m, c, w)
        return m, c

      @function.Defun(tf.float32, tf.float32)
      def LSTMLoop10(weights, inp):
        x = tf.unpack(inp, self.NUM_UNROLL)
        m = tf.zeros_like(x[0])
        c = tf.zeros_like(x[0])
        assert self.NUM_UNROLL % 10 == 0
        for i in range(0, self.NUM_UNROLL, 10):
          m, c = Loop10(weights, m, c, *x[i:i + 10])
        return m

      return LSTMLoop10(weights, inp)
Example #12
def generic_input(processor, *args, **kwargs):
  # pylint: disable=protected-access
  if not isinstance(processor, function._DefinedFunction):
    # Helper if processor is a python callable.
    processor = function.Defun(tf.string)(processor)
  out_types = [
      tf.DType(a.type) for a in processor.definition.signature.output_arg
  ]
  assert out_types[-1] == tf.int32, ('%s is not expected.' % out_types[-1])
  return gen_x_ops.generic_input(
      processor=processor, out_types=out_types[:-1], *args, **kwargs)
Example #13
 def testFunctionGradientsWithGradFunc(self):
   g = ops.Graph()
   with g.as_default():
     grad_func = function.Defun(x=tf.float32, b=tf.float32, g=tf.float32)(
         self.XSquarePlusBGradient)
     f = self._GetFunc(grad_func=grad_func)
     # Get gradients (should add SymbolicGradient node for function, which
     # uses the grad_func above, which multiplies all gradients by 2).
     grads = self._GetFuncGradients(f, [2.0], [1.0])
     self.assertAllEqual([4.0 * 2], grads[0])
     self.assertAllEqual([1.0 * 2], grads[1])
Example #14
  def testFunctionGradientsBasic(self):
    g = ops.Graph()
    with g.as_default():
      f = function.Defun(x=tf.float32, b=tf.float32)(self.XSquarePlusB)
      x = tf.constant([2.0], name="x")
      b = tf.constant([1.0], name="b")

      y = f(x, b)
      # Build gradient graph (should add SymbolicGradient node for function).
      grads = gradients.gradients(y, [x, b])
      with self.test_session() as sess:
        self.assertAllEqual([4.0], sess.run(grads)[0])
        self.assertAllEqual([1.0], sess.run(grads)[1])
Example #15
  def testFunctionGradientsComposition(self):
    with ops.Graph().as_default():
      f = function.Defun(x=tf.float32, b=tf.float32)(self.XSquarePlusB)
      x = tf.constant([2.0], name="x")
      b1 = tf.constant([1.0], name="b1")
      b2 = tf.constant([1.0], name="b2")

      y = f(f(x, b1), b2)
      # Build gradient graph (should add SymbolicGradient node for function).
      grads = gradients.gradients(y, [x, b1])

      with self.test_session() as sess:
        self.assertAllEqual([40.0], sess.run(grads)[0])
        self.assertAllEqual([10.0], sess.run(grads)[1])
Example #16
  def _instantiate_op(dtype):
    """Instantiate pass through gradients op constructor for given dtype."""
    def _forward(x, moving_avg):
      del x
      return tf.identity(moving_avg)

    def _backward(op, grad):
      """Forwards the gradients to the op's inputs."""
      del op
      return grad, None

    func_name = "PassThroughGradients_{}".format(dtype.name)
    return function.Defun(
        dtype, dtype, python_grad_func=_backward,
        func_name=func_name)(_forward)
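
A brief, hedged usage sketch, assuming graph-mode TF 1.x and hypothetical `x` and `moving_avg` tensors of matching shape: the forward pass returns the moving average, while the backward pass routes the gradient to `x` unchanged (a straight-through style substitution).

pass_through = _instantiate_op(tf.float32)
x = tf.placeholder(tf.float32, shape=[3])          # hypothetical input
moving_avg = tf.constant([0.1, 0.2, 0.3])          # hypothetical moving average
out = pass_through(x, moving_avg)                  # forward pass: moving_avg
grad_x = tf.gradients(tf.reduce_sum(out), x)[0]    # backward pass: ones, w.r.t. x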
Example #17
 def register(self, input_types, op, grad_op, name=None):
     """
 :param list[tf.DType] input_types:
 :param (tf.Tensor) -> tf.Tensor op:
 :param (tf.Operation, tf.Tensor) -> tf.Tensor grad_op:
 :param str name: optional func_name
 :return: op
 :rtype: (tf.Tensor) -> tf.Tensor
 """
     if op in self.registered_ops:
         return self.registered_ops[op]
     from tensorflow.python.framework import function
     op_with_new_grad = function.Defun(*input_types,
                                       python_grad_func=grad_op,
                                       func_name=name)(op)
     self.registered_ops[op] = op_with_new_grad
     # We need to add one instance of the new op to the graph now because of:
     # https://github.com/tensorflow/tensorflow/issues/6804
     op_with_new_grad(*[tf.placeholder(dtype) for dtype in input_types])
     return op_with_new_grad
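
A hedged usage sketch; the `OpRegistry` instance and the rounding op are hypothetical, not taken from the source. It registers an op whose forward pass rounds and whose gradient is passed straight through.

def round_op(x):
    return tf.round(x)

def round_grad(op, grad):
    del op  # unused; the incoming gradient is forwarded unchanged
    return grad

registry = OpRegistry()  # hypothetical class providing `register` and `registered_ops`
round_with_grad = registry.register([tf.float32], round_op, round_grad,
                                    name="RoundStraightThrough")
y = round_with_grad(tf.placeholder(tf.float32))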
Example #18
 def _BuildForward(self, use_func=True, num_unroll=100):
     batch_size = 16
     lstm_dims = 32
     cell = FunctionTest.LSTMCell
     if use_func:
         cell = function.Defun(x=tf.float32,
                               mprev=tf.float32,
                               cprev=tf.float32,
                               weights=tf.float32)(cell)
     m = tf.zeros(shape=[batch_size, lstm_dims])
     c = tf.zeros(shape=[batch_size, lstm_dims])
     weights = tf.random_uniform([2 * lstm_dims, 4 * lstm_dims],
                                 -1,
                                 1,
                                 seed=123456)
     inputs = tf.random_uniform([num_unroll, batch_size, lstm_dims],
                                seed=654321)
     x = tf.unpack(inputs)
     for i in range(num_unroll):
         m, c = cell(x[i], m, c, weights)
     return weights, m, c
Example #19
def BatchMatMul(a, b):
    use_fp32_batch_matmul = (os.environ.get("use_fp32_batch_matmul") == "true")
    xla_compile = (os.environ.get("xla_compile") == "true")
    if use_fp32_batch_matmul:

        def DoFn(a, b):
            dtype = a.dtype
            a = tf.to_float(a)
            b = tf.to_float(b)
            return tf.cast(tf.matmul(a, b), dtype)

        # If using xla_compile, the fwd and bak per tower are wrapped in xla_compile
        if not xla_compile:
            DoFn = function.Defun(noinline=True)(DoFn)
            res = DoFn(a, b)
            res.set_shape((None, None, b.shape[-1].value))
        else:
            # If xla_compile, leave to xla to handle the casts.
            res = DoFn(a, b)
    else:
        res = tf.matmul(a, b)
    return res
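
A short, hedged sketch of driving the environment-variable switches this helper reads; the placeholder shapes are illustrative only, and `import os` / `import tensorflow as tf` are assumed.

os.environ["use_fp32_batch_matmul"] = "true"
os.environ["xla_compile"] = "false"
a = tf.placeholder(tf.float16, [None, None, 64])
b = tf.placeholder(tf.float16, [None, 64, 32])
c = BatchMatMul(a, b)  # matmul computed in float32, result cast back to float16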
Example #20
    def testFunctionWithResourcesOnDifferentDevices(self):
        # TODO(akshayka): Remove the `skipTest` once we can whitelist ops as
        # safe to be invoked with resources on different devices.
        self.skipTest("The Placer disallows ops with resource inputs "
                      "on different devices.")

        with ops.device("/cpu:0"):
            v_cpu_zero = resource_variable_ops.ResourceVariable(
                [0.0, 1.0, 2.0], name="v_cpu_zero")

        with ops.device("/cpu:1"):
            v_cpu_one = resource_variable_ops.ResourceVariable(
                [0.0, 1.0, 2.0], name="v_cpu_one")

        with ops.device("/gpu:0"):
            v_gpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0],
                                                           name="v_gpu")

        def sum_gather():
            cpu_result = math_ops.reduce_sum(
                array_ops.gather(v_cpu_zero, [1, 2]))
            also_cpu_result = math_ops.reduce_sum(
                array_ops.gather(v_cpu_one, [1, 2]))
            gpu_result = math_ops.reduce_sum(array_ops.gather(v_gpu, [1, 2]))
            return cpu_result, also_cpu_result, gpu_result

        defined = function.Defun()(sum_gather)
        with self.test_session(config=config_pb2.ConfigProto(
                allow_soft_placement=False,
                log_device_placement=True,
                device_count={"CPU": 2})) as sess:
            sess.run(variables.global_variables_initializer())
            expected = sess.run(sum_gather())
            result = sess.run(
                functional_ops.partitioned_call(args=defined.captured_inputs,
                                                f=defined))
            self.assertAllEqual(expected, result)
Example #21
    def _BuildForward(self, weights, inp, mode="cell"):
        def Loop(cell, w, i):
            x = tf.unpack(i, self.NUM_UNROLL)
            m = tf.zeros_like(x[0])
            c = tf.zeros_like(x[0])
            for i in range(self.NUM_UNROLL):
                m, c = cell(x[i], m, c, w)
            return m

        cell = UnrollLSTMTest.LSTMCell
        if mode == "complete":
            # Constructs the complete graph in python.
            return Loop(cell, weights, inp)

        cell = function.Defun(x=tf.float32,
                              mprev=tf.float32,
                              cprev=tf.float32,
                              weights=tf.float32)(cell)
        if mode == "cell":
            # Just represent the LSTM as a function.
            return Loop(cell, weights, inp)

        if mode == "loop":
            # Wraps the whole loop as a function.
            @function.Defun(w=tf.float32, i=tf.float32)
            def LSTMLoop(w, i):
                return Loop(cell, w, i)

            return LSTMLoop(weights, inp)

        if mode == "loop10":
            # Wraps 10 lstm steps into one function, and the whole loop
            # into another that calls the former.

            # Groups 10 steps at a time.
            # TODO(zhifengc): Any way to make the syntax less hideous?
            @function.Defun(m=tf.float32,
                            c=tf.float32,
                            w=tf.float32,
                            x0=tf.float32,
                            x1=tf.float32,
                            x2=tf.float32,
                            x3=tf.float32,
                            x4=tf.float32,
                            x5=tf.float32,
                            x6=tf.float32,
                            x7=tf.float32,
                            x8=tf.float32,
                            x9=tf.float32)
            def Loop10(w, m, c, x0, x1, x2, x3, x4, x5, x6, x7, x8, x9):
                for x in [x0, x1, x2, x3, x4, x5, x6, x7, x8, x9]:
                    m, c = cell(x, m, c, w)
                return m, c

            @function.Defun(weights=tf.float32, inp=tf.float32)
            def LSTMLoop10(weights, inp):
                x = tf.unpack(inp, self.NUM_UNROLL)
                m = tf.zeros_like(x[0])
                c = tf.zeros_like(x[0])
                assert self.NUM_UNROLL % 10 == 0
                for i in range(0, self.NUM_UNROLL, 10):
                    m, c = Loop10(weights, m, c, *x[i:i + 10])
                return m

            return LSTMLoop10(weights, inp)
Example #22
def GenericInput(processor, *args, **kwargs):
    """Builds a generic input pipeline.

  Example usage::

    def ParseRecord(record):
      # Given a tf.string record, return a (NestedMap, bucketing key) pair.
      feature_map = ...
      features = tf.parse_single_example(record, feature_map)
      # Each example is represented by a NestedMap of tensors (without a
      # batch dimension).
      example = py_utils.NestedMap(field1=..., field2=...)
      # bucketing_key is a scalar convertible to tf.int32.
      # Use 1 if all examples are of the same size.
      bucketing_key = 1
      return example, bucketing_key

    input_batch, bucket_keys = GenericInput(ParseRecord, file_pattern=..., ...)
    # input_batch is a NestedMap of tensors, where dim 0 of each tensor
    # represents the batch dimension.
    input_batch.field1 = ...

  Args:
    processor: a function that takes a string record as input and returns a
      tuple (output, bucketing_key). `output` must be a NestedMap or a list of
      tensors representing one example. The `bucketing_key` must be a scalar
      convertible to a tf.int32 tensor that represents the bucketing key (e.g.,
      sequence length for sequence inputs).
    *args: additional args for x_ops.generic_input.
    **kwargs: additional keyword args for x_ops.generic_input.

  Returns:
    A tuple of (outputs, bucket_keys):
    - outputs: a NestedMap or a list of tensors, similar to `processor`'s
      return, except every tensor will have an additional dimension 0 that
      represents the batch dimension.
    - bucket_keys: a tf.int32 vector.
  """
    output_tmpl = py_utils.NestedMap()

    def _FlatOutputProcessor(inputs):
        """Returns a flattened list of 'processor(inputs)'."""
        output, bucketing_key = processor(inputs)
        if isinstance(output, list):
            assert output
            assert all(isinstance(x, tf.Tensor)
                       for x in output), '{}'.format(output)
        else:
            assert isinstance(output, py_utils.NestedMap), '{}'.format(output)
            assert output
            assert all(isinstance(x, tf.Tensor)
                       for x in output.Flatten()), '{}'.format(
                           output.DebugString())
        bucketing_key = tf.to_int32(bucketing_key)
        tf.logging.debug('Processor outputs=%s bucketing_key=%s', output,
                         bucketing_key)
        output_tmpl.values = output
        flat_output_tmpl = output_tmpl.Flatten()
        tf.logging.debug('Processor flat outputs=%s', flat_output_tmpl)
        tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                         function.get_extra_inputs(),
                         function.get_extra_args(), function.get_extra_vars())
        assert not function.get_extra_args(), (
            'fns {} is not pure: extra_args={}'.format(
                processor, function.get_extra_args()))
        return flat_output_tmpl + [bucketing_key]

    proc_fn = function.Defun(tf.string)(_FlatOutputProcessor)

    out_types = [
        tf.DType(a.type) for a in proc_fn.definition.signature.output_arg
    ]
    assert out_types[-1] == tf.int32, ('%s is not expected.' % out_types[-1])
    flat_outputs, bucket_keys = ops.gen_x_ops.generic_input(
        processor=proc_fn, out_types=out_types[:-1], *args, **kwargs)
    tf.logging.debug('x_ops.generic_input flat_outputs=%s', flat_outputs)
    # Pack flat_outputs to outputs.
    outputs = output_tmpl.Pack(flat_outputs).values
    tf.logging.debug('x_ops.generic_input outputs=%s', outputs)
    return outputs, bucket_keys
Example #23
def GenericInput(processor, *args, **kwargs):
    """Builds a generic input pipeline.

  Example usage::

    def ParseRecord(record):
      # Given a tf.string record, return a (NestedMap, bucketing key) pair.
      feature_map = ...
      features = tf.parse_single_example(record, feature_map)
      # Each example is represented by a NestedMap of tensors (without a
      # batch dimension).
      example = py_utils.NestedMap(field1=..., field2=...)
      # bucketing_key is an int scalar tensor.
      # Use 1 if all examples are of the same size.
      bucketing_key = tf.to_int32(1)
      return example, bucketing_key

    input_batch = GenericInput(ParseRecord, file_pattern=..., ...)
    # input_batch is a NestedMap of tensors, where dim 0 of each tensor
    # represents the batch dimension.
    input_batch.field1 = ...

  Args:
    processor: a function that takes a string record as input and returns a list
      of tensors or NestedMaps representing one example. The last return value
      of processor must be an int32 scalar tensor that represents the bucketing
      key (e.g., sequence length for sequence inputs).
    *args: additional args for x_ops.generic_input.
    **kwargs: additional keyword args for x_ops.generic_input.

  Returns:
    A list of tensors or NestedMaps, similar to `processor`'s return, except:
      * The bucket key is not included in the output.
      * Every tensor will have an additional dimension 0 that represents the
        batch dimension.
  """
    output_tmpl = py_utils.NestedMap()

    def _FlatOutputProcessor(inputs):
        """Returns a flattened list of 'processor(inputs)'."""
        outputs = processor(inputs)
        tf.logging.debug('Processor outputs=%s', outputs)
        assert len(outputs) > 1, outputs
        # Add 'outputs' as a list so that each element will be flattened.
        output_tmpl.values = list(outputs)
        flat_outputs = output_tmpl.Flatten()
        tf.logging.debug('Processor flat outputs=%s', flat_outputs)
        tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                         function.get_extra_inputs(),
                         function.get_extra_args(), function.get_extra_vars())
        assert not function.get_extra_args(), (
            'fns {} is not pure: extra_args={}'.format(
                processor, function.get_extra_args()))
        return flat_outputs

    proc_fn = function.Defun(tf.string)(_FlatOutputProcessor)

    out_types = [
        tf.DType(a.type) for a in proc_fn.definition.signature.output_arg
    ]
    assert out_types[-1] == tf.int32, ('%s is not expected.' % out_types[-1])
    flat_outputs = ops.gen_x_ops.generic_input(processor=proc_fn,
                                               out_types=out_types[:-1],
                                               *args,
                                               **kwargs)
    tf.logging.debug('x_ops.generic_input flat_outputs=%s', flat_outputs)
    if not output_tmpl:
        return flat_outputs
    # Pack flat_outputs to outputs.
    output_tmpl.values.pop(-1)
    outputs = output_tmpl.Pack(flat_outputs).values
    tf.logging.debug('x_ops.generic_input outputs=%s', outputs)
    return outputs
Example #24
 def _GetFunc(cls, **kwargs):
     return framework_function.Defun(dtypes.float32, dtypes.float32,
                                     **kwargs)(cls.XSquarePlusB)
Example #25
def _bahdanau_score(processed_query, keys, normalize, v, g, b):
    """Implements Bahdanau-style (additive) scoring function.

  This attention has two forms.  The first is Bahdanau attention,
  as described in:

  Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio.
  "Neural Machine Translation by Jointly Learning to Align and Translate."
  ICLR 2015. https://arxiv.org/abs/1409.0473

  The second is the normalized form.  This form is inspired by the
  weight normalization article:

  Tim Salimans, Diederik P. Kingma.
  "Weight Normalization: A Simple Reparameterization to Accelerate
   Training of Deep Neural Networks."
  https://arxiv.org/abs/1602.07868

  To enable the second form, set `normalize=True`.

  Args:
    processed_query: Tensor, shape `[batch_size, num_units]` to compare to keys.
    keys: Processed memory, shape `[batch_size, max_time, num_units]`.
    normalize: Whether to normalize the score function.
    v: Attention vector, shape `[num_units]`.
    g: Scalar gain for weight normalization (used only when `normalize=True`).
    b: Bias added inside the tanh (used only when `normalize=True`).

  Returns:
    A `[batch_size, max_time]` tensor of unnormalized score values.
  """
    # Reshape from [batch_size, ...] to [batch_size, 1, ...] for broadcasting.
    processed_query = tf.expand_dims(processed_query, 1)

    if normalize:
        # normed_v = g * v / ||v||
        def NormalizedAttenFwd(keys, processed_query, g, v, b):
            """normalized atten."""
            normed_v = g * v * tf.rsqrt(tf.reduce_sum(tf.square(v)))
            batch = tf.shape(keys)[0]
            max_time = tf.shape(keys)[1]
            units = tf.shape(keys)[-1]

            # [batch, time, dim]
            activation = tf.tanh(keys + processed_query + b)
            # [batch * time, dim]
            activation = tf.reshape(activation, [batch * max_time, units])

            # [dim, 1]
            v = tf.expand_dims(normed_v, -1)
            # [batch * time, 1]  -> [batch * time]
            y = tf.squeeze(tf.matmul(activation, v), axis=1)
            y = tf.reshape(y, [batch, max_time])
            return y

        use_xla = os.environ.get("use_xla") == "true"

        def NormalizedAtten(keys, processed_query, g, v, b):
            return NormalizedAttenFwd(keys, processed_query, g, v, b)

        fn = NormalizedAtten
        if os.environ.get("use_defun") == "true":
            fn = function.Defun(compiled=use_xla)(fn)
        res = fn(keys, processed_query, g, v, b)
        res.set_shape((None, keys.shape[1]))
        return res
    else:

        def _Atten(keys, processed_query, v):
            """atten."""
            batch = tf.shape(keys)[0]
            max_time = tf.shape(keys)[1]
            units = tf.shape(keys)[-1]

            activation = tf.tanh(keys + processed_query)
            activation = tf.reshape(activation, [batch * max_time, units])

            v = tf.expand_dims(v, -1)
            y = tf.squeeze(tf.matmul(activation, v), axis=1)
            y = tf.reshape(y, [batch, max_time])
            return y

        fn = _Atten
        if os.environ.get("use_defun") == "true":
            fn = function.Defun()(fn)
        return fn(keys, processed_query, v)
Example #26
    def _softmax_cross_entropy_loss(self, logits, labels, label_smoothing):
        """Compute softmax loss or sampled softmax loss."""
        use_defun = os.environ["use_defun"] == "true"
        use_xla = os.environ["use_xla"] == "true"

        # @function.Defun(noinline=True, compiled=use_xla)
        def ComputePositiveCrossent(labels, logits):
            crossent = math_utils.sparse_softmax_crossent_with_logits(
                labels=labels, logits=logits)
            return crossent

        crossent = ComputePositiveCrossent(labels, logits)
        assert crossent.dtype == tf.float32

        def _safe_shape_div(x, y):
            """Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`."""
            return x // tf.maximum(y, 1)

        @function.Defun(tf.float32, tf.float32, compiled=use_xla)
        def ReduceSumGrad(x, grad):
            """docstring."""
            input_shape = tf.shape(x)
            # TODO(apassos) remove this once device placement for eager ops makes more
            # sense.
            with tf.colocate_with(input_shape):
                output_shape_kept_dims = math_ops.reduced_shape(
                    input_shape, -1)
                tile_scaling = _safe_shape_div(input_shape,
                                               output_shape_kept_dims)
            grad = tf.reshape(grad, output_shape_kept_dims)
            return tf.tile(grad, tile_scaling)

        def ReduceSum(x):
            """docstring."""
            return tf.reduce_sum(x, axis=-1)

        if use_defun:
            ReduceSum = function.Defun(tf.float32,
                                       compiled=use_xla,
                                       noinline=True,
                                       grad_func=ReduceSumGrad)(ReduceSum)

        if abs(label_smoothing) > 1e-3:
            # pylint:disable=invalid-name
            def ComputeNegativeCrossentFwd(logits):
                """docstring."""
                # [time, batch, dim]
                # [time, batch]
                max_logits = tf.reduce_max(logits, axis=-1)
                # [time, batch, dim]
                shifted_logits = logits - tf.expand_dims(max_logits, axis=-1)
                # Always compute loss in fp32
                shifted_logits = tf.to_float(shifted_logits)
                # [time, batch]
                log_sum_exp = tf.log(ReduceSum(tf.exp(shifted_logits)))
                # [time, batch, dim] - [time, batch, 1] --> reduce_sum(-1) -->
                # [time, batch]
                neg_crossent = ReduceSum(shifted_logits -
                                         tf.expand_dims(log_sum_exp, axis=-1))
                return neg_crossent

            def ComputeNegativeCrossent(logits):
                return ComputeNegativeCrossentFwd(logits)

            if use_defun:
                ComputeNegativeCrossent = function.Defun(
                    compiled=use_xla)(ComputeNegativeCrossent)

            neg_crossent = ComputeNegativeCrossent(logits)
            neg_crossent = tf.to_float(neg_crossent)
            num_labels = logits.shape[-1].value
            crossent = (1.0 - label_smoothing) * crossent - (
                label_smoothing / tf.to_float(num_labels) * neg_crossent)
            # pylint:enable=invalid-name

        return crossent
Example #27
def _scan(fn,
          elems,
          initial,
          reverse=False,
          inclusive=False,
          final_only=False):
    """Repeatedly applies callable `fn` to a sequence of elements.

  Implemented with functional_ops.While; TPU friendly; no gradient.

  This is similar to functional_ops.scan but significantly faster on tpu/gpu
  for the forward backward use case.

  Examples:
    scan(lambda a, e: a + e, [1.0, 2.0, 3.0], 1.0) => [2.0, 4.0, 7.0]

    Multiple accumulators:
      scan(lambda a, e: (a[0] + e, a[1] * e), [1.0, 2.0, 3.0], (0.0, 1.0))

    Multiple inputs:
      scan(lambda a, e: a + (e[0] * e[1]), (elems1, elems2), 0.0)

  Args:
    fn: callable, fn(accumulators, element) return new accumulator values. The
      (possibly nested) sequence of accumulators is the same as `initial` and
      the return value must have the same structure.
    elems: A (possibly nested) tensor which will be unpacked along the first
      dimension. The resulting slices will be the second argument to fn. The
      first dimension of all nested input tensors must be the same.
    initial: A tensor or (possibly nested) sequence of tensors with initial
      values for the accumulators.
    reverse: (optional) If True, scan and output elems in reverse order.
    inclusive: (optional) True includes the initial accumulator values in the
      output. Length of output will be len(elem sequence) + 1. Not meaningful if
      final_only is True.
    final_only: (optional) When True, return only the final accumulated values,
      not the concatenation of accumulated values for each input.

  Returns:
    A (possibly nested) sequence of tensors with the results of applying fn
    to tensors unpacked from elems and previous accumulator values.
  """

    flat_elems = [ops.convert_to_tensor(x) for x in nest.flatten(elems)]
    num_elems = array_ops.shape(flat_elems[0])[0]
    pack_elems = lambda x: nest.pack_sequence_as(structure=elems,
                                                 flat_sequence=x)
    flat_initial = [ops.convert_to_tensor(x) for x in nest.flatten(initial)]
    pack = lambda x: nest.pack_sequence_as(structure=initial, flat_sequence=x)
    accum_dtypes = [x.dtype for x in flat_initial]
    num_accums = len(flat_initial)

    # Types for counter, [outputs], [accumulators] loop arguments.
    if final_only:
        loop_dtypes = [dtypes.int32, dtypes.int32] + accum_dtypes
    else:
        loop_dtypes = [dtypes.int32, dtypes.int32
                       ] + accum_dtypes + accum_dtypes

    # TODO(tombagby): Update to tfe.defun
    def cond(i, num_elems, *args):
        del args
        return i >= 0 if reverse else i < num_elems

    # The loop *args are [output tensors] + [accumulator tensors] which must
    # be paired. Each output corresponds to one accumulator.
    def body(i, num_elems, *args):
        """Loop body."""
        i.set_shape([])
        if final_only:
            accum = args
        else:
            out, accum = args[:num_accums], args[num_accums:]
        slices = [array_ops.gather(e, i) for e in flat_elems]
        accum = fn(pack(accum), pack_elems(slices))
        flat_accum = nest.flatten(accum)
        if final_only:
            new_out = []
        else:
            update_i = i + 1 if inclusive and not reverse else i
            new_out = [
                inplace_ops.alias_inplace_update(x, update_i, y)
                for x, y in zip(out, flat_accum)
            ]
        i = i - 1 if reverse else i + 1
        return [i, num_elems] + new_out + flat_accum

    init_i = (array_ops.shape(flat_elems[0])[0] -
              1 if reverse else constant_op.constant(0, dtype=dtypes.int32))
    outputs = []
    if not final_only:
        num_outputs = array_ops.shape(
            flat_elems[0])[0] + (1 if inclusive else 0)
        for initial_accum in flat_initial:
            out_shape = array_ops.concat(
                [[num_outputs], array_ops.shape(initial_accum)], 0)
            out = inplace_ops.empty(out_shape,
                                    dtype=initial_accum.dtype,
                                    init=True)
            if inclusive:
                out = inplace_ops.alias_inplace_add(
                    out, init_i + (1 if reverse else 0), initial_accum)
            outputs.append(out)
    loop_in = [init_i, num_elems] + outputs + flat_initial
    hostmem = [
        i for i, x in enumerate(loop_in)
        if x.dtype.base_dtype in (dtypes.int32, dtypes.int64)
    ]

    if context.executing_eagerly():
        loop_results = loop_in
        while cond(*loop_results):
            loop_results = body(*loop_results)
    else:
        # TODO(tombagby): Update to while_v2.
        cond = function.Defun(*loop_dtypes)(cond)
        body = function.Defun(*loop_dtypes)(body)
        loop_results = functional_ops.While(loop_in,
                                            cond,
                                            body,
                                            hostmem=hostmem)
    out = loop_results[2:num_accums + 2]
    return pack(out)
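
A minimal, hedged sketch of calling `_scan` as the docstring describes, assuming a graph-mode TF 1.x session and `import tensorflow as tf`; it computes a running sum with an initial accumulator of 1.0.

elems = tf.constant([1.0, 2.0, 3.0])
running_sum = _scan(lambda a, e: a + e, elems, tf.constant(1.0))
with tf.Session() as sess:
    print(sess.run(running_sum))  # per the docstring example: [2.0, 4.0, 7.0]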
Example #28
 def _GetFunc(cls, **kwargs):
   return function.Defun(x=tf.float32, b=tf.float32, **kwargs)(
       cls.XSquarePlusB)