Example #1
0
def watch_variable(tape, variable):
    """Marks `variable` to be watched by `tape`.

    The current distribution strategy and replica context are looked up
    first. Inside a replica context, the variable is resolved to the single
    object returned by `strategy.extended.value_container`. Outside of one,
    every component returned by `strategy.experimental_local_results` is
    watched instead.

    Args:
        tape: tape whose `_tape` attribute is passed to
            `TFE_Py_TapeWatchVariable`.
        variable: variable to be watched.
    """
    strategy, replica_ctx = (
        distribution_strategy_context.get_strategy_and_replica_context())
    if not replica_ctx:
        # No replica context: watch every local component of the variable.
        targets = strategy.experimental_local_results(variable)
    else:
        # Inside a replica context: watch the variable's container instead.
        targets = [strategy.extended.value_container(variable)]
    for target in targets:
        # For each target, call the tape hook first and then the watcher hook.
        pywrap_tfe.TFE_Py_TapeWatchVariable(tape._tape, target)  # pylint: disable=protected-access
        pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(target)
Example #2
0
def variable_accessed(variable):
    """Notifies all tapes in the stack that `variable` has been accessed.

    Inside a replica context, the variable is resolved through
    `strategy.extended.value_container`. Otherwise, each component returned
    by `strategy.experimental_local_results` is reported separately.

    Args:
        variable: variable to be watched.
    """
    strategy, replica_ctx = (
        distribution_strategy_context.get_strategy_and_replica_context())
    # Only the selected branch is evaluated, so exactly one strategy call runs.
    targets = (
        [strategy.extended.value_container(variable)]
        if replica_ctx
        else strategy.experimental_local_results(variable))
    for target in targets:
        pywrap_tfe.TFE_Py_TapeVariableAccessed(target)
        pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(target)
Example #3
0
def variables_accessed(variables):
    """Notifies all tapes in the stack that `variables` have been accessed.

    Only trainable variables are marked as accessed. Entries whose
    `trainable` attribute is falsy are skipped. Inside a replica context,
    each trainable variable contributes the object returned by
    `strategy.extended.value_container`. Otherwise, it contributes every
    component returned by `strategy.experimental_local_results`. All
    variables are resolved before any notification is sent.

    Args:
        variables: iterable of variables to mark as accessed. It is traversed
            exactly once.
    """
    strategy, replica_ctx = (
        distribution_strategy_context.get_strategy_and_replica_context())
    if replica_ctx:
        resolved = [
            strategy.extended.value_container(v)
            for v in variables
            if v.trainable
        ]
    else:
        # Flatten each trainable variable's local results into one ordered list.
        resolved = [
            component
            for v in variables
            if v.trainable
            for component in strategy.experimental_local_results(v)
        ]

    # Notifications start only after every variable has been resolved.
    for target in resolved:
        pywrap_tfe.TFE_Py_TapeVariableAccessed(target)
        pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(target)