Example No. 1
  def make_op(self):
    if self.cache_key in self.op_cache:
      return self.op_cache[self.cache_key]
    mod = self._make_mod()
    op = getattr(mod, _camel_case_to_snake_case(self.op_name))
    self.op_cache[self.cache_key] = op

    if self.description.is_grad_defined:
      grad_description = self.description.grad()
      grad_op_maker = OpMaker(description=grad_description, compiler_opts=self.compiler_opts)
      grad_op = grad_op_maker.make_op()

      from tensorflow.python.framework import ops
      def grad_wrapper(fwd_op, *bwd_grads):
        """
        :param tf.Operation fwd_op: for fwd_op.inputs and fwd_op.outputs
        :param list[tf.Tensor] bwd_grads:
        :return: list of tensors of gradients for each input
        :rtype: list[tf.Tensor]
        """
        assert len(bwd_grads) == len(fwd_op.outputs)

        grad_inputs = list(fwd_op.inputs) + list(fwd_op.outputs) + list(bwd_grads)
        grad_inputs = self.description._filter_grad_inputs(grad_inputs)
        grad_outputs = TFUtil.make_var_tuple(grad_op(*grad_inputs))
        if grad_description.num_dummy_outs > 0:
          grad_outputs = grad_outputs[:-grad_description.num_dummy_outs]
        grad_outputs = self.description.make_results_of_gradient(grad_outputs)
        return grad_outputs

      grad_wrapper.__name__ = grad_description.name
      ops.RegisterGradient(self.name)(grad_wrapper)

    return op
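
The snippet above registers the generated gradient wrapper for a compiled native op. For reference, here is a minimal standalone sketch of the same ops.RegisterGradient idea (TF1 graph mode assumed); the "CustomClipGrad" registry name is made up, and the gradient is attached to a plain Identity op via gradient_override_map rather than to a native op:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

@tf.RegisterGradient("CustomClipGrad")  # hypothetical registry name
def _custom_clip_grad(op, grad):
    # Pass the incoming gradient through, clipped to [-1, 1].
    return tf.clip_by_value(grad, -1.0, 1.0)

g = tf.get_default_graph()
x = tf.constant([3.0, -4.0])
with g.gradient_override_map({"Identity": "CustomClipGrad"}):
    y = tf.identity(x)

loss = tf.reduce_sum(y * y)
dx = tf.gradients(loss, x)[0]  # uses _custom_clip_grad instead of the default Identity gradient
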
Example No. 2
    def __init__(self, parent_graph=None):
        """Initializes an ImperativeGraph.

    Args:
      parent_graph: (Optional) An ImperativeGraph.
    """
        self._parent_graph = parent_graph
        # Whether the create_op function should augment an op with extra logic for
        # imperative execution.
        self._return_as_is = False
        # Operation -> list of Tensors map. Used for overriding the op.outputs
        # property, useful during gradient computation.
        self._outputs_map = {}
        # Operation -> function map. Used for overriding the gradient function
        # for an op.
        self._gradient_function_map = {}
        # Unique name for the graph. Used for naming the container in which
        # temporary variables are placed.
        self._name = uuid.uuid4().hex
        # Names for op types used for marking ops so we can override their
        # gradient functions.
        self._merge_op_type = 'ImperativeMerge' + self._name
        self._imperative_op_type = 'ImperativeOp' + self._name
        # The list of 'assign' ops that initialize variables.
        self._init_ops = []
        # Names of variables whose init ops have already been recorded in _init_ops.
        self._init_variable_names = set()
        # A flag to indicate whether a variable and the corresponding initialization
        # ops are being created. Typically set by the initializer of the Variable class.
        self._in_variable_creation = False
        self._variable_cleanup_ops = []
        # Call the parent's initializer.
        super(ImperativeGraph, self).__init__()

        # Register a simple 'pass through' function to be used for ops that have
        # _merge_op_type as the _gradient_op_type attribute.
        ops.RegisterGradient(
            self._merge_op_type)(lambda op, grad, _: [grad] * len(op.inputs))

        # For ops that have _imperative_op_grad as the _gradient_op_type attribute,
        # temporarily replace their outputs with the values in _output_map before
        # calling the original gradient function.
        def _imperative_op_grad(op, *grad):
            with self.replace_outputs(op):
                return self._gradient_function_map[op.name](op, *grad)

        ops.RegisterGradient(self._imperative_op_type)(_imperative_op_grad)
Example No. 3
    def testNoGradientForStringOutputs(self):
        with ops.Graph().as_default():

            def _TestOpGrad(_, float_grad, string_grad):
                """Gradient function for TestStringOutput."""
                self.assertEquals(float_grad.dtype, dtypes.float32)
                self.assertFalse(string_grad)
                return float_grad

            ops.RegisterGradient("TestStringOutput")(_TestOpGrad)

            c = constant(1.0)
            x, _ = test_ops.test_string_output(c)
            z = x * 2.0
            w = z * 3.0
            grads = gradients.gradients(z, [c])
            self.assertTrue(isinstance(grads[0], ops.Tensor))
            grads = gradients.gradients(w, [c])
            self.assertTrue(isinstance(grads[0], ops.Tensor))
Example No. 4
 def call(self, inputs):
     if isinstance(inputs, list):
         inputs = [ops.convert_to_tensor(one_input) for one_input in inputs]
     else:
         inputs = [ops.convert_to_tensor(inputs)]
     # Register and override the gradients
     ops.RegisterGradient(self._id)(self.backward_tensor)
     g = ops.get_default_graph()
     with g.gradient_override_map({
             "PyFunc": self._id,
             "pyfunc_0": self._id,
             "PyFuncStateless": self._id
     }):
         res = script_ops.py_func(self.forward,
                                  inputs,
                                  self.Tout,
                                  name=self.name)
         oshape = self._output_shape([inp.get_shape() for inp in inputs])
         if isinstance(res, list):
             for i in range(len(res)):
                 res[i].set_shape(oshape[i])
         return res
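
Example No. 4 relies on the classic py_func + gradient_override_map trick: the forward pass runs arbitrary Python, and a hand-written gradient is attached by remapping the "PyFunc"/"PyFuncStateless" op types to a custom registry name. A minimal, self-contained sketch of that trick (TF1 graph mode assumed; the "SquareViaPyFuncGrad" name and the forward/backward functions are made up for illustration):

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

def _square_forward(x):
    # Runs in Python/NumPy; TensorFlow cannot differentiate through this.
    return np.square(x).astype(np.float32)

@tf.RegisterGradient("SquareViaPyFuncGrad")  # hypothetical registry name
def _square_backward(op, grad):
    # Hand-written gradient of x**2: dy/dx = 2x.
    return grad * 2.0 * op.inputs[0]

g = tf.get_default_graph()
x = tf.constant([1.0, 2.0, 3.0])
with g.gradient_override_map({"PyFunc": "SquareViaPyFuncGrad",
                              "PyFuncStateless": "SquareViaPyFuncGrad"}):
    y = tf.py_func(_square_forward, [x], tf.float32, stateful=False)
y.set_shape(x.get_shape())

dx = tf.gradients(tf.reduce_sum(y), x)[0]  # evaluates to 2 * x
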
Example No. 5
    def testNoGradientForStringOutputs(self):
        # This test can't be run twice because the TestStringOutput gradient can
        # only be registered once. Just run with the C API enabled.
        if not ops._USE_C_API: return

        with ops.Graph().as_default():

            def _TestOpGrad(_, float_grad, string_grad):
                """Gradient function for TestStringOutput."""
                self.assertEquals(float_grad.dtype, dtypes.float32)
                self.assertFalse(string_grad)
                return float_grad

            ops.RegisterGradient("TestStringOutput")(_TestOpGrad)

            c = constant(1.0)
            x, _ = test_ops.test_string_output(c)
            z = x * 2.0
            w = z * 3.0
            grads = gradients.gradients(z, [c])
            self.assertTrue(isinstance(grads[0], ops.Tensor))
            grads = gradients.gradients(w, [c])
            self.assertTrue(isinstance(grads[0], ops.Tensor))
Example No. 6
def _PadGrad(op, grad):
    """Gradient for Pad: slice the incoming gradient back to the input's shape."""
    x = op.inputs[0]
    a = op.inputs[1]  # [Rank(x), 2]
    # Takes a slice of a. The 1st column. [Rank(x), 1].
    pad_before = array_ops.slice(a, [0, 0],
                                 array_ops.stack([array_ops.rank(x), 1]))
    # Make it a 1-D tensor.
    begin = array_ops.reshape(pad_before, [-1])
    sizes = array_ops.shape(x)
    x_grad = array_ops.slice(grad, begin, sizes)
    if len(op.inputs) == 3:
        return x_grad, None, None
    else:
        return x_grad, None


ops.RegisterGradient("Pad")(_PadGrad)
ops.RegisterGradient("PadV2")(_PadGrad)


# ReverseSequence is just a permutation.  The gradient permutes back.
@ops.RegisterGradient("ReverseSequence")
def _ReverseSequenceGrad(op, grad):
    seq_lengths = op.inputs[1]
    return [
        array_ops.reverse_sequence(grad,
                                   batch_axis=op.get_attr("batch_dim"),
                                   seq_axis=op.get_attr("seq_dim"),
                                   seq_lengths=seq_lengths), None
    ]
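
Both gradients above are pure index manipulations. A quick eager-mode check of the Pad case (a sketch using TF2's GradientTape, independent of the registrations above) shows that the gradient is simply the upstream gradient sliced back to the input's shape:

import tensorflow as tf

x = tf.reshape(tf.range(6.0), [2, 3])
paddings = [[1, 0], [2, 1]]                      # one row before; two cols before, one after
with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.pad(x, paddings)                      # shape [3, 6]
    w = tf.reshape(tf.range(18.0), tf.shape(y))  # arbitrary upstream weights
    loss = tf.reduce_sum(w * y)
dx = tape.gradient(loss, x)

# dx equals w sliced at the "pad before" offsets with x's shape: w[1:3, 2:5].
assert bool(tf.reduce_all(dx == w[1:3, 2:5]))
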

Example No. 7
            # This is the first time this Switch is visited. It comes from the
            # Identity branch. Such a Switch has `None` gradient for the Exit branch,
            # meaning the output is not differentiable.
            return None, None
    elif isinstance(op_ctxt, CondContext):
        good_grad = grad[op_ctxt.branch]
        zero_grad = grad[1 - op_ctxt.branch]
        # At this point, we have created zero_grad guarded by the right switch.
        return merge([good_grad, zero_grad], name="cond_grad")[0], None
    else:
        false_grad = switch(grad[0], op.inputs[1])[0]
        true_grad = switch(grad[1], op.inputs[1])[1]
        return merge([false_grad, true_grad])[0], None


ops.RegisterGradient("Switch")(_SwitchGrad)
ops.RegisterGradient("RefSwitch")(_SwitchGrad)


@ops.RegisterGradient("Merge")
def _MergeGrad(op, grad, _):
    """Gradients for a Merge op are calculated using a Switch op."""
    input_op = op.inputs[0].op
    graph = ops.get_default_graph()
    # pylint: disable=protected-access
    op_ctxt = control_flow_ops._GetOutputContext(input_op)
    grad_ctxt = graph._get_control_flow_context()
    # pylint: enable=protected-access
    if isinstance(op_ctxt, WhileContext):
        # pylint: disable=protected-access
        return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot)
Example No. 8
def _block_lstm_grad(op, *grads):
    """Gradient for BlockLSTM and BlockLSTMV2."""
    seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b = op.inputs
    i, cs, f, o, ci, co, h = op.outputs
    _, cs_grad, _, _, _, _, h_grad = grads
    (x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad, wco_grad,
     b_grad) = gen_rnn_ops.block_lstm_grad(
         seq_len_max=seq_len_max,
         x=x,
         cs_prev=cs_prev,
         h_prev=h_prev,
         w=w,
         wci=wci,
         wcf=wcf,
         wco=wco,
         b=b,
         i=i,
         cs=cs,
         f=f,
         o=o,
         ci=ci,
         co=co,
         h=h,
         cs_grad=cs_grad,
         h_grad=h_grad,
         use_peephole=op.get_attr("use_peephole"))
    return (None, x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad,
            wcf_grad, wco_grad, b_grad)


ops.RegisterGradient("BlockLSTM")(_block_lstm_grad)
ops.RegisterGradient("BlockLSTMV2")(_block_lstm_grad)
Example No. 9
def _irfft_grad_helper(rank, rfft_fn):
    """Returns a gradient function for an IRFFT of the given rank."""
    def _grad(op, grad):
        """A gradient function for IRFFT with the provided `rank` and `rfft_fn`."""
        # Generate a simple mask like [1.0, 2.0, ..., 2.0, 1.0] for even-length FFTs
        # and [1.0, 2.0, ..., 2.0] for odd-length FFTs. To reduce extra ops in the
        # graph we special-case the situation where the FFT length and last
        # dimension of the input are known at graph construction time.
        fft_length = op.inputs[1]
        is_odd = _math_ops.mod(fft_length[-1], 2)
        input_last_dimension = _array_ops.shape(op.inputs[0])[-1]
        mask = _array_ops.concat(
            [[1.0], 2.0 * _array_ops.ones([input_last_dimension - 2 + is_odd]),
             _array_ops.ones([1 - is_odd])], 0)

        rsize = _math_ops.reciprocal(
            _math_ops.to_float(_fft_size_for_grad(grad, rank)))

        # The gradient of IRFFT is the RFFT of the incoming gradient times a scaling
        # factor and a mask. The mask scales the gradient for the Hermitian
        # symmetric components of the RFFT by a factor of two, since these
        # components are de-duplicated in the RFFT.
        the_rfft = rfft_fn(grad, fft_length)
        return the_rfft * _math_ops.cast(rsize * mask, _dtypes.complex64), None

    return _grad


_ops.RegisterGradient("RFFT")(_rfft_grad_helper(1, irfft))
_ops.RegisterGradient("IRFFT")(_irfft_grad_helper(1, rfft))
_ops.RegisterGradient("RFFT2D")(_rfft_grad_helper(2, irfft2d))
_ops.RegisterGradient("IRFFT2D")(_irfft_grad_helper(2, rfft2d))
Example No. 10
def load_function_def_library(library,
                              load_shared_name_suffix=None,
                              wrapper_function=None):
    """Load a set of functions as concrete functions without captured inputs.

  Functions names are manipulated during load such that they do not overlap
  with previously created ones.

  Gradients are re-registered under new names. Ops that reference the gradients
  are updated to reflect the new registered names.

  Args:
    library: FunctionDefLibrary proto message.
    load_shared_name_suffix: If specified, used to uniquify shared
      names. Otherwise, a unique name is generated.
    wrapper_function: An object that will be wrapped on newly created functions.

  Returns:
    Map of original function names in the library to instances of
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if functions dependencies have a cycle.
  """
    library_function_names = set(fdef.signature.name
                                 for fdef in library.function)
    functions = {}
    renamed_functions = {}

    # Our graph building code currently requires functions to be registered with
    # some tf.Graph in order to import functions using the
    # op-name-is-function-name calling convention. To avoid leaking memory into
    # the global default graph when executing eagerly, we create a temporary
    # Graph.
    #
    # TODO(allenl): Make this Graph creation unnecessary when executing eagerly by
    # fixing function_def_to_graph_def.
    if ops.executing_eagerly_outside_functions():
        graph = ops.Graph()
    else:
        graph = ops.get_default_graph()

    if load_shared_name_suffix is None:
        load_shared_name_suffix = "_load_{}".format(ops.uid())

    # Custom gradient functions must be re-registered under new UIDs.
    library_gradient_names = {}  # Maps old op type to old function name
    new_gradient_op_types = {}  # Maps old gradient op type to new op type.
    gradients_to_register = {}  # Maps old function name to new op type
    for gdef in library.registered_gradients:
        if gdef.registered_op_type:
            new_op_type = custom_gradient.generate_name()
            old_op_type = compat.as_bytes(gdef.registered_op_type)

            library_gradient_names[old_op_type] = gdef.gradient_func
            new_gradient_op_types[old_op_type] = new_op_type
            gradients_to_register[gdef.gradient_func] = new_op_type

    function_deps = {}
    for fdef in library.function:
        function_deps[fdef.signature.name] = _list_function_deps(
            fdef, library_function_names, library_gradient_names)

    loaded_gradients = {}
    for fdef in _sort_function_defs(library, function_deps):
        copy = _fix_fdef(fdef, functions, load_shared_name_suffix,
                         new_gradient_op_types)

        # There is no need to copy all functions into the function def graph: doing
        # so leads to an O(n^2) increase in memory when importing functions, and the
        # extra function definitions are a no-op since they were already imported as
        # functions earlier and are passed in explicitly (due to the topological-sort
        # import).
        with graph.as_default():
            func_graph = function_def_lib.function_def_to_graph(copy)
        # Restores gradients for function-call ops (not the same as ops that use
        # custom gradients)
        _restore_gradient_functions(func_graph, renamed_functions,
                                    loaded_gradients)

        for dep in function_deps[fdef.signature.name]:
            functions[dep].add_to_graph(func_graph)

        # We do not initialize the new ConcreteFunction's function_spec and/or
        # arg_keywords here (which are used to parse the structured and flat
        # signatures, respectively). ConcreteFunctions that are part of a saved
        # function are set up later by recreate_function(); bare ConcreteFunctions
        # are set up by setup_bare_concrete_function().
        # However, we copy the FunctionDef attributes to the new ConcreteFunction,
        # excluding "_input_shapes", which may cause an error during input shape
        # initialization at a later stage.
        if "_input_shapes" in copy.attr:
            del copy.attr["_input_shapes"]
        func = function_lib.ConcreteFunction(func_graph, attrs=copy.attr)
        if wrapper_function:
            func = wrapper_function(func)
        func.add_to_graph(graph)

        functions[fdef.signature.name] = func
        renamed_functions[func.name] = func
        if any(op.type == "TRTEngineOp" for op in func_graph.get_operations()):
            # TODO(b/150708051): Remove this hack once TensorRT SavedModel integration
            # is fixed. Currently it's leaking memory to maintain bug compatibility
            # with previous behavior.
            func.add_to_graph(ops.get_default_graph())

        if fdef.signature.name in gradients_to_register:
            gradient_op_type = gradients_to_register[fdef.signature.name]
            loaded_gradients[compat.as_bytes(gradient_op_type)] = func
            ops.RegisterGradient(gradient_op_type)(_gen_gradient_func(func))

    return functions
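
The re-registration bookkeeping above exists because the gradient registry is keyed by op type and rejects duplicate keys, so a library cannot be reloaded under its original gradient names. A standalone sketch of that idea (the helper name and base string below are made up for illustration):

import uuid

from tensorflow.python.framework import ops

def _register_under_fresh_name(grad_fn, base="LoadedGradient"):
    # Hypothetical helper mirroring the generate_name()/uid() idea above:
    # pick a registry key that cannot collide with earlier registrations.
    op_type = "{}_{}".format(base, uuid.uuid4().hex)
    ops.RegisterGradient(op_type)(grad_fn)
    return op_type

def _identity_grad(op, grad):
    return grad

name_a = _register_under_fresh_name(_identity_grad)
name_b = _register_under_fresh_name(_identity_grad)
assert name_a != name_b  # both registrations succeed because the keys differ
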
Example No. 11
def _IRFFTGradHelper(rank, rfft_fn):
    """Returns a gradient function for an IRFFT of the given rank."""
    def _Grad(op, grad):
        """A gradient function for IRFFT with the provided `rank` and `rfft_fn`."""
        # Generate a simple mask like [1.0, 2.0, ..., 2.0, 1.0] for even-length FFTs
        # and [1.0, 2.0, ..., 2.0] for odd-length FFTs. To reduce extra ops in the
        # graph we special-case the situation where the FFT length and last
        # dimension of the input are known at graph construction time.
        fft_length = op.inputs[1]
        is_odd = math_ops.mod(fft_length[-1], 2)
        input_last_dimension = array_ops.shape(op.inputs[0])[-1]
        mask = array_ops.concat(
            [[1.0], 2.0 * array_ops.ones([input_last_dimension - 2 + is_odd]),
             array_ops.ones([1 - is_odd])], 0)

        rsize = math_ops.reciprocal(
            math_ops.to_float(_FFTSizeForGrad(grad, rank)))

        # The gradient of IRFFT is the RFFT of the incoming gradient times a scaling
        # factor and a mask. The mask scales the gradient for the Hermitian
        # symmetric components of the RFFT by a factor of two, since these
        # components are de-duplicated in the RFFT.
        rfft = rfft_fn(grad, fft_length)
        return rfft * math_ops.cast(rsize * mask, dtypes.complex64), None

    return _Grad


ops.RegisterGradient("RFFT")(_RFFTGradHelper(1, spectral_ops.irfft))
ops.RegisterGradient("IRFFT")(_IRFFTGradHelper(1, spectral_ops.rfft))
ops.RegisterGradient("RFFT2D")(_RFFTGradHelper(2, spectral_ops.irfft2d))
ops.RegisterGradient("IRFFT2D")(_IRFFTGradHelper(2, spectral_ops.rfft2d))