예제 #1
0
    def __init__(self, layer):
        """Stores `layer` and precomputes details of its call signature.

        Sets the layer's call method, whether a `training` argument is used,
        the call spec used for tracing, the layer inputs, an (initially empty)
        weak-valued function cache, and the name of the primary input argument.

        Args:
          layer: The layer whose call functions are wrapped.
        """
        self.layer = layer

        self.layer_call_method = _get_layer_call_method(layer)
        self._expects_training_arg = utils.layer_uses_training_bool(layer)
        self._call_spec = layer._call_spec  # pylint: disable=protected-access

        # When the layer's own call does not take `training` but one of its
        # children does, the call functions are traced with an added
        # `training` keyword argument, so a call spec that includes it is built.
        layer_takes_training = self.layer._expects_training_arg  # pylint: disable=protected-access
        needs_training_spec = self._expects_training_arg and not layer_takes_training
        if needs_training_spec:
            full_argspec = self._call_spec.full_argspec
            patched_argspec = utils.set_training_arg_spec(full_argspec, False)
            self._call_spec = layer_utils.CallFunctionSpec(patched_argspec)

        self._layer_inputs = self._get_layer_inputs(layer)
        # Values in this cache are only weakly referenced.
        self._functions = weakref.WeakValueDictionary()

        # The input argument is named after the first declared argument. A call
        # defined with only varargs has none, so a default name is used.
        declared_names = self._call_spec.arg_names
        self._input_arg_name = declared_names[0] if declared_names else "inputs"
예제 #2
0
파일: save_impl.py 프로젝트: ohsdba/keras
  def __init__(self, layer):
    """Stores `layer` and precomputes details of its call signature.

    Records the layer's call method, whether it uses a `training` argument
    and at which index, whether the call takes kwargs-like parameters, the
    inferred input signature, a weak-valued function cache, a tracing flag,
    and the name of the primary input argument.

    Args:
      layer: The layer whose call functions are wrapped.
    """
    self.layer = layer

    call_method = _get_layer_call_method(layer)
    self.layer_call_method = call_method
    self._expects_training_arg = utils.layer_uses_training_bool(layer)
    self._training_arg_index = utils.get_training_arg_index(call_method)

    # A traced function cannot carry an input signature when the call function
    # accepts kwargs-like parameters: a training arg, arguments with defaults,
    # keyword-only arguments, or **kwargs.
    full_spec = tf_inspect.getfullargspec(call_method)
    kwargs_like = (
        self._expects_training_arg
        or full_spec.defaults
        or full_spec.kwonlyargs
        or full_spec.varkw
    )
    self._has_kwargs = bool(kwargs_like)

    self._input_signature = self._generate_input_signature(layer)
    # Values in this cache are only weakly referenced.
    self._functions = weakref.WeakValueDictionary()
    # True while this object is tracing the layer call functions.
    self.tracing = False

    # Name the input after the first positional parameter, skipping the bound
    # receiver of a method; fall back to a default name when there is none.
    skip = 1 if tf_inspect.ismethod(call_method) else 0
    positional = full_spec.args[skip:]
    self._input_arg_name = positional[0] if positional else 'inputs'
예제 #3
0
    def __init__(self, layer):
        """Stores `layer` and precomputes details of its call signature.

        Records the layer's call method, whether it uses a `training` argument
        and at which index, the layer inputs, a weak-valued function cache,
        and the name of the primary input argument.

        Args:
          layer: The layer whose call functions are wrapped.
        """
        self.layer = layer

        call_method = _get_layer_call_method(layer)
        self.layer_call_method = call_method
        self._expects_training_arg = utils.layer_uses_training_bool(layer)
        self._training_arg_index = utils.get_training_arg_index(call_method)

        self._layer_inputs = self._get_layer_inputs(layer)
        # Values in this cache are only weakly referenced.
        self._functions = weakref.WeakValueDictionary()

        # Name the input after the first positional parameter, skipping the
        # bound receiver of a method; default to 'inputs' when none remains.
        positional = tf_inspect.getfullargspec(call_method).args
        offset = 1 if tf_inspect.ismethod(call_method) else 0
        remaining = positional[offset:]
        self._input_arg_name = remaining[0] if remaining else 'inputs'