def __call__(self, *args, **kwargs):
  # Flatten arguments.
  flat_args = nest.flatten(args, expand_composites=True)
  flat_kwargs = nest.flatten(kwargs, expand_composites=True)
  all_args = flat_args + flat_kwargs

  # Trace
  outer_ctx = context_lib.get_default()
  ctx = NewTracingContext(self.name)
  with context_lib.set_default(ctx):
    # TODO(srbs): Iterating over list of inputs is a known performance
    # bottleneck. Add a pybind API for this.
    inputs = [ctx.AddParameter(arg.DataType()) for arg in all_args]
    structured_args = nest.pack_sequence_as(args, inputs[:len(flat_args)])
    structured_kwargs = nest.pack_sequence_as(kwargs, inputs[len(flat_args):])
    structured_outputs = self._python_func(*structured_args,
                                           **structured_kwargs)
    py_outputs = nest.flatten(structured_outputs, expand_composites=True)
    num_outputs = len(py_outputs)
    # TODO(srbs): Drop Nones before calling Finalize.
    finalized_f = ctx.Finalize(py_outputs)
    outer_ctx.RegisterFunction(finalized_f)

  # Build call op
  call_op = outer_ctx.CreateOperation(self.name, "")
  call_op.SetOpName(self.name)
  for arg in all_args:
    call_op.AddInput(arg)
  call_op_outputs = call_op.Execute(num_outputs)

  # Cleanup
  outer_ctx.RemoveFunction(self.name)
  return nest.pack_sequence_as(structured_outputs, call_op_outputs)

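# Usage sketch, not part of the module above. It assumes the class defining
# __call__ is exposed as `Function`, stores the wrapped callable in
# `self._python_func` and its name in `self.name` (the constructor signature is
# an assumption), and that the eager op wrappers defined later in this file set
# (e.g. `sub`, `neg`) dispatch through the default context and therefore trace
# when called under `context_lib.set_default(ctx)`.
def _example_function_call(t1, t2):
  def sub_and_negate(x, y):
    # These wrappers look up context.get_default(), so inside tracing they add
    # ops to the TracingContext rather than executing eagerly.
    return neg(sub(x, y))

  f = Function(sub_and_negate)  # hypothetical constructor: Function(python_func)
  # Traces sub_and_negate once, registers the finalized function in the outer
  # context, executes it as a call op, and returns the packed outputs.
  return f(t1, t2)
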
def gradient(self, targets, sources, output_gradients=None):
  # Compute gradients of `targets` w.r.t. `sources` using the C++ tape, then
  # repack the flat result to match the structure of `sources`.
  ctx = context_stack.get_default()
  flat_targets = nest.flatten(targets)
  flat_sources = nest.flatten(sources)
  out_grads = self._c_tape.ComputeGradient(ctx, flat_targets, flat_sources,
                                           output_gradients or [])
  return nest.pack_sequence_as(sources, out_grads)

def __init__(self, persistent=False):
  # Create the pybind Tape and a TapeContext that records ops executed while
  # this tape is active.
  self._c_tape = _tape.Tape(persistent)
  ctx = context_stack.get_default()
  self._tape_context = _tape.TapeContext(
      ctx, self._c_tape, gradient_registry.get_global_registry())
  self._ctx_manager = None

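# Usage sketch, under stated assumptions: the class above is exposed as
# `GradientTape`, it also defines __enter__/__exit__ (hinted at by
# self._ctx_manager) that push self._tape_context as the default context, a
# `watch` method exists to mark differentiable sources, and the tape-aware op
# wrappers below record onto tape_stack.get_default().
def _example_tape(x, w):
  with GradientTape(persistent=False) as tape:
    tape.watch(w)          # hypothetical watch API: mark `w` as a source
    y = mat_mul(x, w)      # tape-aware variant defined below
  # Flattens targets/sources, calls ComputeGradient on the C++ tape, and
  # repacks the result to match `sources`.
  return tape.gradient(y, w)
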
def sub(a, b, name=None):
  ctx = context.get_default()
  return _math_ops.sub(ctx, a, b, name)

def neg(a, name=None):
  ctx = context.get_default()
  return _math_ops.neg(ctx, a, name)

def mat_mul(a, b, name=None):
  ctx = context.get_default()
  return _math_ops.mat_mul(ctx, a, b, name)

def div_no_nan(a, b, name=None):
  ctx = context.get_default()
  return _math_ops.div_no_nan(ctx, a, b, name)

def log1p(a, name=None):
  ctx = context.get_default()
  return _math_ops.log1p(ctx, a, name)

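# The wrappers above are thin eager-mode shims: each fetches the default
# execution context and forwards to the generated pybind binding in _math_ops.
# A composed sketch, assuming `t1` and `t2` are tensor handles already placed
# in that context:
def _example_math(t1, t2):
  diff = sub(t1, t2)            # _math_ops.sub(ctx, t1, t2, None)
  ratio = div_no_nan(diff, t2)  # yields 0 where t2 == 0 instead of inf/nan
  return log1p(neg(ratio))      # log(1 + (-ratio))
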
def mat_mul(a, b, name=None):
  # Tape-aware variant: passes the active tape and the gradient registry so the
  # op can be recorded for backprop.
  ctx = context.get_default()
  tape = tape_stack.get_default()
  grad_registry = gradient_registry.get_global_registry()
  return _math_ops.mat_mul(ctx, a, b, name, tape, grad_registry)

def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
  ctx = context.get_default()
  return _nn_ops.sparse_softmax_cross_entropy_with_logits(
      ctx, logits, labels, name)

def relu(a, name=None):
  ctx = context.get_default()
  return _nn_ops.relu(ctx, a, name)

def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
  # Tape-aware variant; see the note on mat_mul above.
  ctx = context.get_default()
  tape = tape_stack.get_default()
  grad_registry = gradient_registry.get_global_registry()
  return _nn_ops.sparse_softmax_cross_entropy_with_logits(
      ctx, logits, labels, name, tape, grad_registry)

def relu(a, name=None):
  # Tape-aware variant; see the note on mat_mul above.
  ctx = context.get_default()
  tape = tape_stack.get_default()
  grad_registry = gradient_registry.get_global_registry()
  return _nn_ops.relu(ctx, a, name, tape, grad_registry)

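# Sketch of a one-layer forward/backward pass built from the tape-aware
# wrappers above. Assumptions: `GradientTape` from earlier pushes itself onto
# `tape_stack` while active, a `watch` method exists, and the loss wrapper's
# return structure mirrors the underlying kernel (shown here as a single loss
# value); `x`, `w`, and `labels` are hypothetical input tensors.
def _example_train_step(x, w, labels):
  with GradientTape() as tape:
    tape.watch(w)                        # hypothetical watch API
    logits = relu(mat_mul(x, w))         # both ops record onto the active tape
    loss = sparse_softmax_cross_entropy_with_logits(logits, labels)
  return tape.gradient(loss, w)          # gradient of the loss w.r.t. the weights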