Пример #1
0
    def _call_func(self, args, kwargs):
        """Run the wrapped function, policing variable creation on repeat calls.

        Args:
            args: Positional arguments forwarded to `self._func`.
            kwargs: Keyword arguments forwarded to `self._func`.

        Returns:
            Whatever `self._func` returns.

        Raises:
            ValueError: If the TRAINABLE_VARIABLES collection changes length
                during a call made after `self._variables_created` was set.
            Exception: Anything raised in the body is re-raised after the
                template's definition stack is appended to its first argument.
        """
        try:
            # Snapshot the collection sizes so that growth caused by this call
            # can be detected afterwards.
            global_count_before = len(
                ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES))
            trainable_count_before = len(
                ops.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES))

            if not self._variables_created:
                # First invocation: run under dependency capture so variables
                # can be restored if necessary (via Checkpointable).
                with checkpointable_util.capture_dependencies(template=self):
                    result = self._func(*args, **kwargs)
            else:
                result = self._func(*args, **kwargs)

            # The flag is deliberately read again after the call rather than
            # cached before it.
            if not self._variables_created:
                self._variables_created = True
                return result

            # Not the first call: a trainable variable appearing now is almost
            # certainly a mistake, so fail loudly.
            trainable_now = ops.get_collection_ref(
                ops.GraphKeys.TRAINABLE_VARIABLES)
            if len(trainable_now) != trainable_count_before:
                message = (
                    "Trainable variable created when calling a template "
                    "after the first time, perhaps you used tf.Variable "
                    "when you meant tf.get_variable: %s" %
                    (trainable_now[trainable_count_before:], ))
                raise ValueError(message)

            # New non-trainable variables can be legitimate (e.g. tracking
            # variables), but that is an advanced use-case, so only log it.
            globals_now = ops.get_collection_ref(
                ops.GraphKeys.GLOBAL_VARIABLES)
            if len(globals_now) != global_count_before:
                logging.info(
                    "New variables created when calling a template after "
                    "the first time, perhaps you used tf.Variable when you "
                    "meant tf.get_variable: %s",
                    globals_now[global_count_before:])
            return result
        except Exception as exc:
            # Re-raise the same exception, with the place where the template
            # was originally defined appended to its first argument.
            original_args = exc.args
            first_arg = original_args[0] if original_args else ""
            relevant_frames = _skip_common_stack_elements(
                self._stacktrace, traceback.format_stack())
            definition_trace = "".join(relevant_frames)
            augmented = "%s\n\noriginally defined at:\n%s" % (
                first_arg, definition_trace)
            exc.args = (augmented, ) + tuple(original_args[1:])
            raise
Пример #2
0
  def _call_func(self, args, kwargs):
    """Calls `self._func` and checks variable creation on later calls.

    Args:
      args: Positional arguments passed through to `self._func`.
      kwargs: Keyword arguments passed through to `self._func`.

    Returns:
      The value returned by `self._func`.

    Raises:
      ValueError: If the TRAINABLE_VARIABLES collection changes length during
        a call made after `self._variables_created` was set.
      Exception: Any exception raised in the body is re-raised after the
        template's definition stack is appended to its first argument.
    """
    try:
      # Record collection sizes up front so growth during this call is visible.
      num_global_before = len(
          ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES))
      num_trainable_before = len(
          ops.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES))

      if not self._variables_created:
        # First run: capture dependencies so variables can be restored if
        # necessary (via Checkpointable).
        with checkpointable_util.capture_dependencies(template=self):
          result = self._func(*args, **kwargs)
      else:
        result = self._func(*args, **kwargs)

      # The flag is read again here, after the call, instead of being cached.
      if not self._variables_created:
        self._variables_created = True
        return result

      # Variables already existed, so this is a repeat call. A new trainable
      # variable at this point is almost certainly an error.
      current_trainable = ops.get_collection_ref(
          ops.GraphKeys.TRAINABLE_VARIABLES)
      if len(current_trainable) != num_trainable_before:
        error_text = ("Trainable variable created when calling a template "
                      "after the first time, perhaps you used tf.Variable "
                      "when you meant tf.get_variable: %s" %
                      (current_trainable[num_trainable_before:],))
        raise ValueError(error_text)

      # Non-trainable tracking variables may legitimately appear, but it is a
      # relatively advanced use-case, so it is logged rather than rejected.
      current_global = ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES)
      if len(current_global) != num_global_before:
        logging.info("New variables created when calling a template after "
                     "the first time, perhaps you used tf.Variable when you "
                     "meant tf.get_variable: %s",
                     current_global[num_global_before:])
      return result
    except Exception as exc:
      # Re-raise the same exception with the original definition site appended
      # to its first argument.
      exc_args = exc.args
      head = exc_args[0] if exc_args else ""
      frames = _skip_common_stack_elements(self._stacktrace,
                                           traceback.format_stack())
      defined_at = "".join(frames)
      head = "%s\n\noriginally defined at:\n%s" % (head, defined_at)
      exc.args = (head,) + tuple(exc_args[1:])
      raise
 def __call__(self):
   """Build inside the "ManualScope" variable scope and return the result.

   The entered scope is stored on `self.variable_scope` before
   `self._build()` runs under dependency capture with this object as the
   template.
   """
   with variable_scope.variable_scope("ManualScope") as manual_scope:
     # Store the scope before dependency capture is set up for this object.
     self.variable_scope = manual_scope
     with checkpointable_utils.capture_dependencies(template=self):
       built = self._build()
       return built
 def __call__(self):
   """Run `self._build()` within the "ManualScope" variable scope.

   The opened scope is assigned to `self.variable_scope`, then the build
   happens under dependency capture with this object as the template. The
   value produced by `self._build()` is returned.
   """
   with variable_scope.variable_scope("ManualScope") as entered_scope:
     # Assign the scope first; the capture context is created afterwards.
     self.variable_scope = entered_scope
     with checkpointable_utils.capture_dependencies(template=self):
       outcome = self._build()
       return outcome