def _call_func(self, args, kwargs):
    """Run the wrapped template function and check what variables it added.

    On the first invocation the call is wrapped in
    ``checkpointable_util.capture_dependencies(template=self)``; the original
    comment says this restores variables if necessary (via Checkpointable).
    On later invocations the function is called directly, and the global and
    trainable variable collections are compared against their sizes before
    the call.

    Args:
        args: Positional arguments forwarded to ``self._func``.
        kwargs: Keyword arguments forwarded to ``self._func``.

    Returns:
        Whatever ``self._func`` returns.

    Raises:
        ValueError: On a non-first call, if the length of the
            TRAINABLE_VARIABLES collection changed during the call.
        Exception: Any exception raised inside the ``try`` is re-raised after
            its first argument is extended with "originally defined at:" and
            the frames produced by ``_skip_common_stack_elements`` from
            ``self._stacktrace`` and the current stack.
    """
    try:
        global_count_before = len(
            ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES))
        trainable_count_before = len(
            ops.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES))

        if not self._variables_created:
            # First invocation: restore variables if necessary (via
            # Checkpointable) while the function runs.
            with checkpointable_util.capture_dependencies(template=self):
                result = self._func(*args, **kwargs)
        else:
            result = self._func(*args, **kwargs)

        # The flag is deliberately re-read here rather than cached above:
        # the wrapped function may itself have changed it.
        if not self._variables_created:
            self._variables_created = True
            return result

        # Not the first call: creating a trainable variable now is almost
        # certainly an error (e.g. tf.Variable instead of tf.get_variable).
        trainable_collection = ops.get_collection_ref(
            ops.GraphKeys.TRAINABLE_VARIABLES)
        if len(trainable_collection) != trainable_count_before:
            raise ValueError(
                "Trainable variable created when calling a template "
                "after the first time, perhaps you used tf.Variable "
                "when you meant tf.get_variable: %s" %
                (trainable_collection[trainable_count_before:],))

        # New non-trainable variables can be legitimate (advanced use-case),
        # so they are only logged.
        global_collection = ops.get_collection_ref(
            ops.GraphKeys.GLOBAL_VARIABLES)
        if len(global_collection) != global_count_before:
            logging.info(
                "New variables created when calling a template after "
                "the first time, perhaps you used tf.Variable when you "
                "meant tf.get_variable: %s",
                global_collection[global_count_before:])
        return result
    except Exception as exc:
        # Re-raise the same exception object, with the template's original
        # definition site appended to its first argument.
        exc_args = exc.args
        first_arg = exc_args[0] if exc_args else ""
        # format_stack() is called directly in this frame on purpose, so the
        # captured stack ends at this method.
        definition_trace = "".join(
            _skip_common_stack_elements(self._stacktrace,
                                        traceback.format_stack()))
        first_arg = "%s\n\noriginally defined at:\n%s" % (first_arg,
                                                          definition_trace)
        exc.args = (first_arg,) + tuple(exc_args[1:])
        raise
def _call_func(self, args, kwargs):
    """Invoke ``self._func`` and police variable creation across calls.

    The first invocation runs inside
    ``checkpointable_util.capture_dependencies(template=self)`` (per the
    original comment, to restore variables via Checkpointable if necessary)
    and then sets ``self._variables_created``. Any later invocation must not
    grow the trainable-variable collection. Growth of the global collection
    is only logged.

    Args:
        args: Positional arguments passed through to ``self._func``.
        kwargs: Keyword arguments passed through to ``self._func``.

    Returns:
        The value returned by ``self._func``.

    Raises:
        ValueError: When a non-first call changes the size of the
            TRAINABLE_VARIABLES collection.
        Exception: Whatever was raised inside the ``try``, re-raised with
            "originally defined at:" and a trace derived from
            ``self._stacktrace`` appended to its first argument.
    """
    try:
        n_global_before = len(
            ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES))
        n_trainable_before = len(
            ops.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES))

        if not self._variables_created:
            # First run: restore variables if necessary (via Checkpointable).
            with checkpointable_util.capture_dependencies(template=self):
                result = self._func(*args, **kwargs)
        else:
            result = self._func(*args, **kwargs)

        # Re-read the flag (do not cache it): self._func may have set it.
        if not self._variables_created:
            self._variables_created = True
            return result

        trainable_now = ops.get_collection_ref(
            ops.GraphKeys.TRAINABLE_VARIABLES)
        if len(trainable_now) != n_trainable_before:
            # A trainable side-effect variable on a repeat call is almost
            # certainly a mistake.
            raise ValueError("Trainable variable created when calling a template "
                             "after the first time, perhaps you used tf.Variable "
                             "when you meant tf.get_variable: %s" %
                             (trainable_now[n_trainable_before:],))

        global_now = ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES)
        if len(global_now) != n_global_before:
            # Non-trainable tracking variables can be legitimate; just log.
            logging.info("New variables created when calling a template after "
                         "the first time, perhaps you used tf.Variable when you "
                         "meant tf.get_variable: %s",
                         global_now[n_global_before:])
        return result
    except Exception as exc:
        # Annotate and re-raise the original exception object. Keep the
        # format_stack() call in this frame so the captured stack is unchanged.
        original_args = exc.args
        head = original_args[0] if original_args else ""
        trace = "".join(_skip_common_stack_elements(self._stacktrace,
                                                    traceback.format_stack()))
        head = "%s\n\noriginally defined at:\n%s" % (head, trace)
        exc.args = (head,) + tuple(original_args[1:])
        raise
def __call__(self):
    """Build the object inside a "ManualScope" variable scope.

    The entered scope is published on ``self.variable_scope``, and then
    ``self._build()`` runs under
    ``checkpointable_utils.capture_dependencies(template=self)``.

    Returns:
        The value returned by ``self._build()``.
    """
    scope_manager = variable_scope.variable_scope("ManualScope")
    with scope_manager as manual_scope:
        # Assign before entering capture_dependencies: that helper receives
        # `self` and may read `self.variable_scope`
        # (NOTE(review): confirm against capture_dependencies).
        self.variable_scope = manual_scope
        dependency_capture = checkpointable_utils.capture_dependencies(
            template=self)
        with dependency_capture:
            # Return from inside both contexts, as before.
            return self._build()