示例#1
0
    def __call__(self, *args):
        """Invokes the wrapped function on the tensor leaves of ``args``.

        ``args`` is flattened with ``nest.flatten`` and only the leaves that
        are ``ops.Tensor`` instances are passed on; other leaves are dropped.

        If ``tape.should_record`` reports true for those tensors (or, failing
        that, for ``self._extra_inputs``), the call is delegated to
        ``self._backprop_call``, first calling ``self._compute_backprop`` when
        ``self._has_backprop`` is falsy. Otherwise the function runs directly
        via ``execute.execute`` when not in graph mode, or is added to the
        default graph as a "FunctionCall" op when in graph mode.

        Returns:
          The result of ``self._backprop_call`` on the recording path, else
          ``self._build_call_outputs(self._returns, <raw outputs>)``.
        """
        flat_tensors = [
            leaf for leaf in nest.flatten(args) if isinstance(leaf, ops.Tensor)
        ]

        # The gradient-recording path takes precedence over both run modes.
        recording = (tape.should_record(flat_tensors) or
                     tape.should_record(self._extra_inputs))
        if recording:
            if not self._has_backprop:
                self._compute_backprop()
            return self._backprop_call(flat_tensors)

        if not context.in_graph_mode():
            raw_outputs = execute.execute(
                str(self._func_name),
                num_outputs=self._num_outputs,
                inputs=flat_tensors + self._extra_inputs)
            return self._build_call_outputs(self._returns, raw_outputs)

        graph = ops.get_default_graph()
        # Add the FunctionDef to the graph only if it is not already there.
        if self._fdef.name not in graph._functions:  # pylint: disable=protected-access
            graph._add_function(self._fdef)  # pylint: disable=protected-access
        op_def = self._fdef.definition.signature
        call_inputs = [
            ops.convert_to_tensor(value)
            for value in list(flat_tensors) + self._extra_inputs
        ]
        output_dtypes = [dtypes.DType(arg.type) for arg in op_def.output_arg]
        call_op = graph.create_op(
            op_def.name,
            call_inputs,
            output_dtypes,
            op_def=op_def,
            name="FunctionCall",
            compute_shapes=False)
        raw_outputs = call_op.outputs
        # The op is created with compute_shapes=False, so apply the shapes
        # stored in self._output_shapes to the outputs explicitly.
        for idx, shape in enumerate(self._output_shapes):
            raw_outputs[idx].set_shape(shape)
        return self._build_call_outputs(self._returns, raw_outputs)
示例#2
0
  def __call__(self, *args):
    """Runs the function on the ``ops.Tensor`` leaves of ``args``.

    Leaves of ``nest.flatten(args)`` that are not ``ops.Tensor`` instances
    are ignored.

    When ``tape.should_record`` is true for the tensor leaves or for
    ``self._extra_inputs``, the call goes through ``self._backprop_call``
    (``self._compute_backprop`` is invoked first if ``self._has_backprop`` is
    falsy). Otherwise, in graph mode a "FunctionCall" op is created in the
    default graph; outside graph mode the function is run with
    ``execute.execute``.

    Returns:
      The ``self._backprop_call`` result on the recording path; otherwise
      ``self._build_call_outputs(self._returns, <raw outputs>)``.
    """
    call_tensors = [
        item for item in nest.flatten(args) if isinstance(item, ops.Tensor)
    ]

    # Either the explicit tensors or the captured extra inputs can trigger
    # the backprop-aware path.
    if (tape.should_record(call_tensors) or
        tape.should_record(self._extra_inputs)):
      if not self._has_backprop:
        self._compute_backprop()
      return self._backprop_call(call_tensors)

    if not context.in_graph_mode():
      outputs = execute.execute(
          str(self._func_name),
          num_outputs=self._num_outputs,
          inputs=call_tensors + self._extra_inputs)
      return self._build_call_outputs(self._returns, outputs)

    default_graph = ops.get_default_graph()
    # Only add the FunctionDef when the graph does not already contain it.
    if self._fdef.name not in default_graph._functions:  # pylint: disable=protected-access
      default_graph._add_function(self._fdef)  # pylint: disable=protected-access
    op_signature = self._fdef.definition.signature
    converted = [
        ops.convert_to_tensor(value)
        for value in list(call_tensors) + self._extra_inputs
    ]
    result_types = [dtypes.DType(out.type) for out in op_signature.output_arg]
    function_op = default_graph.create_op(
        op_signature.name,
        converted,
        result_types,
        op_def=op_signature,
        name="FunctionCall",
        compute_shapes=False)
    outputs = function_op.outputs
    # Shapes are not computed at op creation (compute_shapes=False), so set
    # them from self._output_shapes.
    for position, static_shape in enumerate(self._output_shapes):
      outputs[position].set_shape(static_shape)
    return self._build_call_outputs(self._returns, outputs)
示例#3
0
    def __call__(self, *args):
        """Calls the function with the tensor leaves of ``args``.

        Every variable in ``self._variables`` whose ``_trainable`` flag is set
        is passed to ``tape.watch_variable`` first. Only the ``ops.Tensor``
        leaves of ``nest.flatten(args)`` are used as call inputs.

        If ``tape.should_record`` is true for those tensors or for
        ``self._extra_inputs``, the call is delegated to
        ``self._backprop_call`` (running ``self._compute_backprop`` first when
        ``self._has_backprop`` is falsy). Otherwise, outside graph mode the
        function is run with ``execute.execute``; in graph mode the function
        and every function in ``self._graph._functions`` are added to the
        default graph if missing, and a "FunctionCall" op is created.

        Returns:
          The ``self._backprop_call`` result on the recording path; the created
          op itself when it has no outputs; otherwise
          ``self._build_call_outputs(<raw outputs>)``.
        """
        # Ask the tape to watch every trainable captured variable.
        for variable in self._variables:
            if variable._trainable:  # pylint: disable=protected-access
                tape.watch_variable(variable)

        flat_tensors = [
            leaf for leaf in nest.flatten(args) if isinstance(leaf, ops.Tensor)
        ]
        needs_backprop = (tape.should_record(flat_tensors) or
                          tape.should_record(self._extra_inputs))
        if needs_backprop:
            if not self._has_backprop:
                self._compute_backprop()
            return self._backprop_call(flat_tensors)

        ctx = context.context()
        if not ctx.in_graph_mode():
            outputs = execute.execute(
                str(self._func_name),
                num_outputs=self._num_outputs,
                inputs=flat_tensors + self._extra_inputs,
                attrs=None,
                ctx=ctx)
            return self._build_call_outputs(outputs)

        graph = ops.get_default_graph()
        # Register this function, then each function held by self._graph,
        # skipping any name the default graph already has.
        if self._function_def.name not in graph._functions:  # pylint: disable=protected-access
            graph._add_function(self._function_def)  # pylint: disable=protected-access
        for library_fn in self._graph._functions.values():  # pylint: disable=protected-access
            if library_fn.name not in graph._functions:  # pylint: disable=protected-access
                graph._add_function(library_fn)  # pylint: disable=protected-access
        op_def = self._function_def.definition.signature
        call_inputs = [
            ops.internal_convert_to_tensor(value, ctx=ctx)
            for value in list(flat_tensors) + self._extra_inputs
        ]
        output_types = tuple(
            dtypes_module.DType(arg.type) for arg in op_def.output_arg)
        call_op = graph.create_op(
            op_def.name,
            call_inputs,
            output_types,
            op_def=op_def,
            name="FunctionCall",
            compute_shapes=False)
        outputs = call_op.outputs
        if not outputs:
            # Nothing to wrap: hand back the created op itself.
            return call_op
        # compute_shapes=False above, so apply the stored static shapes.
        for idx, shape in enumerate(self._output_shapes):
            outputs[idx].set_shape(shape)
        return self._build_call_outputs(outputs)
示例#4
0
  def __call__(self, *args):
    """Calls the function with the ``ops.Tensor`` leaves of ``args``.

    Trainable variables in ``self._variables`` are first handed to
    ``tape.watch_variable``. Non-tensor leaves of ``nest.flatten(args)`` are
    dropped.

    When ``tape.should_record`` is true for the tensor leaves or for
    ``self._extra_inputs``, the call goes through ``self._backprop_call``
    (``self._compute_backprop`` runs first if ``self._has_backprop`` is
    falsy). Otherwise, outside graph mode the function is run with
    ``execute.execute``; in graph mode it is added to the default graph via
    ``self.add_to_graph`` and invoked as a "FunctionCall" op.

    Returns:
      The ``self._backprop_call`` result on the recording path; the created
      op when it has no outputs; otherwise
      ``self._build_call_outputs(<raw outputs>)``.
    """
    for var in self._variables:
      if var._trainable:  # pylint: disable=protected-access
        tape.watch_variable(var)

    inputs = [leaf for leaf in nest.flatten(args)
              if isinstance(leaf, ops.Tensor)]
    if (tape.should_record(inputs) or
        tape.should_record(self._extra_inputs)):
      if not self._has_backprop:
        self._compute_backprop()
      return self._backprop_call(inputs)

    ctx = context.context()
    if not ctx.in_graph_mode():
      raw = execute.execute(
          str(self._func_name),
          num_outputs=self._num_outputs,
          inputs=inputs + self._extra_inputs,
          attrs=None,
          ctx=ctx)
      return self._build_call_outputs(raw)

    graph = ops.get_default_graph()
    self.add_to_graph(graph)
    op_def = self._function_def.definition.signature
    op_inputs = [
        ops.internal_convert_to_tensor(value, ctx=ctx)
        for value in list(inputs) + self._extra_inputs
    ]
    op_output_types = tuple(
        dtypes_module.DType(arg.type) for arg in op_def.output_arg)
    call_op = graph.create_op(
        op_def.name,
        op_inputs,
        op_output_types,
        op_def=op_def,
        name="FunctionCall",
        compute_shapes=False)
    raw = call_op.outputs
    if not raw:
      # No outputs to wrap; return the op so the caller still gets it.
      return call_op
    # Shapes were not computed at creation (compute_shapes=False).
    for slot, shape in enumerate(self._output_shapes):
      raw[slot].set_shape(shape)
    return self._build_call_outputs(raw)
示例#5
0
  def __call__(self, *args):
    """Calls the function with the ``ops.Tensor`` leaves of ``args``.

    Variables in ``self._variables`` with ``trainable`` set are first handed
    to ``tape.watch_variable``. Non-tensor leaves of ``nest.flatten(args)``
    are dropped.

    When ``tape.should_record`` is true for the tensor leaves or for
    ``self._extra_inputs``, the call goes through ``self._backprop_call``,
    building the backward function first if ``self._backward_function`` is
    None. Otherwise, when ``ctx.executing_eagerly()`` is true the function is
    run with ``execute.execute``; when it is false the function is added to
    the default graph and invoked as a "FunctionCall" op.

    Returns:
      The ``self._backprop_call`` result on the recording path; the created
      op when it has no outputs; otherwise
      ``self._build_call_outputs(<raw outputs>)``.
    """
    for var in self._variables:
      if var.trainable:
        tape.watch_variable(var)

    inputs = [
        leaf for leaf in nest.flatten(args) if isinstance(leaf, ops.Tensor)
    ]
    must_record = (tape.should_record(inputs) or
                   tape.should_record(self._extra_inputs))
    if must_record:
      if self._backward_function is None:
        self._construct_backprop_function()
      return self._backprop_call(inputs)

    ctx = context.context()
    if not ctx.executing_eagerly():
      graph = ops.get_default_graph()
      self.add_to_graph(graph)
      op_def = self._function_def.definition.signature
      op_inputs = [
          ops.internal_convert_to_tensor(value, ctx=ctx)
          for value in list(inputs) + self._extra_inputs
      ]
      op_output_types = tuple(
          dtypes_module.DType(arg.type) for arg in op_def.output_arg)
      call_op = graph.create_op(
          op_def.name,
          op_inputs,
          op_output_types,
          op_def=op_def,
          name="FunctionCall",
          compute_shapes=False)
      raw = call_op.outputs
      if not raw:
        # No outputs to wrap; return the op itself.
        return call_op
      # compute_shapes=False above, so apply the stored static shapes.
      for slot, shape in enumerate(self._output_shapes):
        raw[slot].set_shape(shape)
      return self._build_call_outputs(raw)

    raw = execute.execute(
        str(self._func_name),
        num_outputs=self._num_outputs,
        inputs=inputs + self._extra_inputs,
        attrs=None,
        ctx=ctx)
    return self._build_call_outputs(raw)