Example #1
0
def _init_state_variable_for_rng(model, layout_map):
    """Set up the state variable of each tf.random.Generator in the model.

    Keras' BaseRandomLayer explicitly untracks its tf.random.Generator, so
    the generator's state variable is left as a LazyInitVariable. That leads
    to a runtime error unless it is swapped for a proper DVariable. Users are
    usually unaware that these variables exist and they are tiny, so they are
    simply given a replicated layout.

    Args:
      model: the model whose layers are scanned for BaseRandomLayers.
      layout_map: supplies the default mesh information used to create the
        DVariable.

    Raises:
      ValueError: if a random generator is already built but is not backed
        by a tf.random.Generator.
    """

    def _is_random_layer(obj):
        return isinstance(obj, base_layer.BaseRandomLayer)

    for random_layer in model._flatten(predicate=_is_random_layer):
        rng = random_layer._random_generator

        # A built generator without a backing tf.random.Generator cannot be
        # used with DTensor, so fail early with an actionable message.
        if rng._built and rng._generator is None:
            raise ValueError(
                "Keras is expected to use tf.random.Generator when using "
                "DTensor API. Please call "
                "`tf.keras.backend.experimental.enable_tf_random_generator` at "
                "the beginning of your program.")

        # `hasattr` is checked first so `_generator` is only dereferenced
        # when the attribute exists on the keras generator.
        has_lazy_state = hasattr(rng, "_generator") and _is_lazy_init_variable(
            rng._generator._state_var)

        if not has_lazy_state:
            # The keras generator is not built yet: run its init function
            # under the DTensor device so every variable is created with the
            # default replicated layout.
            with dtensor.run_on(layout_map.get_default_mesh()):
                rng._maybe_init()
            continue

        # Swap the lazily initialized state variable for a DVariable.
        rng._generator._state_var = _create_dvariable(
            layout_map, "", rng._generator._state_var)
Example #2
0
def call_with_layout(fn, layout, *args, **kwargs):
    """Call `fn` and, when a layout is given, relayout what it returns.

    Args:
      fn: the callable to run.
      layout: when provided, `fn` runs on `layout.mesh` and its output is
        passed through `dtensor.relayout` with this layout. Otherwise `fn`
        is called directly and its output is returned untouched.
      *args: positional arguments forwarded to `fn`.
      **kwargs: keyword arguments forwarded to `fn`.

    Returns:
      The output of `fn`, relayouted with `layout` when one is provided.
    """
    # No layout requested: plain call, no DTensor involvement.
    if not layout:
        return fn(*args, **kwargs)

    with dtensor.run_on(layout.mesh):
        output = fn(*args, **kwargs)
        return dtensor.relayout(output, layout)