def watch_variable(tape, variable):
  """Registers `variable` (or its components) as watched on `tape`.

  Which objects get watched depends on the current distribution setup:
  inside a replica context the variable's value container (from
  `strategy.extended.value_container`) is watched; otherwise each entry of
  `strategy.experimental_local_results(variable)` is watched. Every watched
  object is also passed to `TFE_Py_VariableWatcherVariableAccessed`.

  Args:
    tape: tape wrapper whose private `_tape` handle receives the watch call.
    variable: the variable to watch; may be a distributed variable.
  """
  strategy, replica_ctx = (
      distribution_strategy_context.get_strategy_and_replica_context())
  targets = (
      [strategy.extended.value_container(variable)]
      if replica_ctx else strategy.experimental_local_results(variable))
  for component in targets:
    # `_tape` is read per component, exactly as before, so an empty
    # `targets` never touches the private attribute.
    pywrap_tfe.TFE_Py_TapeWatchVariable(tape._tape, component)  # pylint: disable=protected-access
    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(component)
def variable_accessed(variable):
  """Reports one variable access to all tapes on the stack.

  Each reported object is passed to both `TFE_Py_TapeVariableAccessed` and
  `TFE_Py_VariableWatcherVariableAccessed`.

  Args:
    variable: the variable that was accessed. Inside a replica context its
      value container is reported; otherwise every entry returned by
      `strategy.experimental_local_results(variable)` is reported.
  """
  strategy, replica_ctx = (
      distribution_strategy_context.get_strategy_and_replica_context())
  if not replica_ctx:
    components = strategy.experimental_local_results(variable)
  else:
    components = [strategy.extended.value_container(variable)]
  for component in components:
    pywrap_tfe.TFE_Py_TapeVariableAccessed(component)
    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(component)
def variables_accessed(variables):
  """Reports accesses of several variables to all tapes on the stack.

  Variables whose `trainable` attribute is false are skipped entirely.
  Inside a replica context each remaining variable contributes its value
  container; otherwise it contributes every entry of
  `strategy.experimental_local_results`. Each collected object is then
  passed to `TFE_Py_TapeVariableAccessed` and
  `TFE_Py_VariableWatcherVariableAccessed`.

  Args:
    variables: iterable of variables to mark as accessed.
  """
  strategy, replica_ctx = (
      distribution_strategy_context.get_strategy_and_replica_context())
  # Lazily filtered, so the `trainable` check and the strategy lookup stay
  # interleaved per variable, in the original order.
  trainable = (v for v in variables if v.trainable)
  if replica_ctx:
    accessed = [strategy.extended.value_container(v) for v in trainable]
  else:
    accessed = [
        component
        for v in trainable
        for component in strategy.experimental_local_results(v)
    ]
  for var in accessed:
    pywrap_tfe.TFE_Py_TapeVariableAccessed(var)
    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(var)