Example #1
  def test_variable_capture(self):
    with variable_utils.record_variable_creation_scope() as variable_list:
      v1 = tf.Variable(1.0)
      v2 = tf.Variable('abc', name='my_test_var')
      v3 = tf.compat.v1.get_variable(
          name='v1_var',
          shape=(),
          initializer=tf.compat.v1.initializers.constant)
      # Explicitly add a variable that is not added to any collections.
      v4 = tf.compat.v1.get_variable(
          name='v1_var_no_collections',
          shape=(),
          initializer=tf.compat.v1.initializers.constant,
          collections=[])
    self.assertEqual([v1, v2, v3, v4], variable_list)
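For reference, `record_variable_creation_scope` can be approximated with
`tf.variable_creator_scope`; the following is a minimal sketch under that
assumption, not necessarily the library's actual implementation:

import contextlib

import tensorflow as tf


@contextlib.contextmanager
def record_variable_creation_scope():
  """Yields a list that records every variable created in the scope."""
  variable_list = []

  def logging_creator(next_creator, **kwargs):
    # Delegate the actual creation to TF, then record the result.
    variable = next_creator(**kwargs)
    variable_list.append(variable)
    return variable

  with tf.variable_creator_scope(logging_creator):
    yield variable_list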
Example #2
  def test_construct_from_keras_converges(self):
    functional_model = functional.functional_model_from_keras(
        keras_model=create_test_keras_model(),
        loss_fn=tf.keras.losses.MeanSquaredError(),
        input_spec=(tf.TensorSpec([None, 1], dtype=tf.float32),
                    tf.TensorSpec([None, 1], dtype=tf.int32)))
    with tf.Graph().as_default() as test_graph:
      # Capture all the variables for later initialization in the session,
      # otherwise it's hard to get our hands on the Keras-owned variables.
      with variable_utils.record_variable_creation_scope(
      ) as captured_variables:
        # Create data satisfying y = 2*x + 1
        dataset = tf.data.Dataset.from_tensor_slices((
            # Features
            [[1.0], [2.0], [3.0]],
            # Labels.
            [[3.0], [5.0], [7.0]],
        )).batch(1)
        variables = tf.nest.map_structure(tf.Variable,
                                          functional_model.initial_weights)
        optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

        @tf.function
        def train():
          weights = tf.nest.map_structure(lambda v: v.read_value(), variables)
          initial_loss = loss = functional_model.forward_pass(
              weights, next(iter(dataset)), training=True).loss
          trainable = variables[0]
          for batch in dataset.repeat(30):
            with tf.GradientTape() as tape:
              weights = tf.nest.map_structure(lambda v: v.read_value(),
                                              variables)
              tape.watch(weights[0])
              batch_output = functional_model.forward_pass(
                  weights, batch, training=True)
            gradients = tape.gradient(batch_output.loss, weights[0])
            optimizer.apply_gradients(zip(gradients, trainable))
            loss = batch_output.loss
          return initial_loss, loss

        initial_loss, final_loss = train()
    with tf.compat.v1.Session(graph=test_graph) as sess:
      sess.run(tf.compat.v1.initializers.variables(captured_variables))
      initial_loss, final_loss = sess.run([initial_loss, final_loss])
    # Expect some amount of convergence after a few epochs of the dataset.
    self.assertGreater(initial_loss, 2.0)
    self.assertLess(final_loss, 0.2)
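The same capture-then-initialize pattern in miniature, independent of the
model machinery (a minimal sketch):

with tf.Graph().as_default() as g:
  with variable_utils.record_variable_creation_scope() as captured:
    v = tf.Variable(41.0)
    plus_one = v + 1.0

with tf.compat.v1.Session(graph=g) as sess:
  # Graph-mode variables must be initialized explicitly before use.
  sess.run(tf.compat.v1.initializers.variables(captured))
  print(sess.run(plus_one))  # 42.0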
Example #3
def tf_computation_serializer(parameter_type: Optional[computation_types.Type],
                              context_stack):
    """Serializes a TF computation with a given parameter type.

  Args:
    parameter_type: The parameter type specification if the target accepts a
      parameter, or `None` if the target doesn't declare any parameters. Either
      an instance of `computation_types.Type`.
    context_stack: The context stack to use.

  Yields:
    The first yielded value will be a Python object (such as a dataset,
    a placeholder, or a `structure.Struct`) to be passed to the function to
    serialize. The result of the function should then be passed to the
    following `send` call.
    The next yielded value will be
    a tuple of (`pb.Computation`, `tff.Type`), where the computation contains
    the instance with the `pb.TensorFlow` variant set, and the type is an
    instance of `tff.Type`, potentially including Python container annotations,
    for use by TensorFlow computation wrappers.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the signature of the target is not compatible with the given
      parameter type.
  """
    # TODO(b/113112108): Support a greater variety of target type signatures,
    # with keyword args or multiple args corresponding to elements of a tuple.
    # Document all accepted forms with examples in the API, and point to there
    # from here.

    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    if parameter_type is not None:
        py_typecheck.check_type(parameter_type, computation_types.Type)

    with tf.Graph().as_default() as graph:
        if parameter_type is not None:
            parameter_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
                'arg', parameter_type, graph)
        else:
            parameter_value = None
            parameter_binding = None
        context = tensorflow_computation_context.TensorFlowComputationContext(
            graph)
        with context_stack.install(context):
            with variable_utils.record_variable_creation_scope(
            ) as all_variables:
                result = yield parameter_value
            initializer_ops = []
            if all_variables:
                # Use a readable but not-too-long name for the init_op.
                name = 'init_op_for_' + '_'.join(
                    [v.name.replace(':0', '') for v in all_variables])
                if len(name) > 50:
                    name = 'init_op_for_{}_variables'.format(
                        len(all_variables))
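                # The resulting name looks like 'init_op_for_my_var_other_var'
                # for a few variables, or 'init_op_for_42_variables' once the
                # joined name would exceed 50 characters.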
                initializer_ops.append(
                    tf.compat.v1.initializers.variables(all_variables,
                                                        name=name))
            initializer_ops.extend(
                tf.compat.v1.get_collection(
                    tf.compat.v1.GraphKeys.TABLE_INITIALIZERS))
            if initializer_ops:
                # Before running the main new init op, run any initializers for sub-
                # computations from context.init_ops. Variables from import_graph_def
                # will not make it into the global collections, and so will not be
                # initialized without this code path.
                with tf.compat.v1.control_dependencies(context.init_ops):
                    init_op_name = tf.group(*initializer_ops,
                                            name='grouped_initializers').name
            elif context.init_ops:
                init_op_name = tf.group(*context.init_ops,
                                        name='subcomputation_init_ops').name
            else:
                init_op_name = None

        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            result, graph)

    type_signature = computation_types.FunctionType(parameter_type,
                                                    result_type)

    tensorflow = pb.TensorFlow(graph_def=serialization_utils.pack_graph_def(
        graph.as_graph_def()),
                               parameter=parameter_binding,
                               result=result_binding,
                               initialize_op=init_op_name)
    yield pb.Computation(
        type=type_serialization.serialize_type(type_signature),
        tensorflow=tensorflow), type_signature
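The two yields above form a send-driven protocol: the caller first receives
the stamped parameter, builds the result, and sends it back. A hedged usage
sketch (`my_tf_fn` is a hypothetical function being serialized):

serializer = tf_computation_serializer(parameter_type, context_stack)
# Phase 1: receive the stamped parameter placeholder (or None).
parameter_value = next(serializer)
# Build the TensorFlow graph by invoking the function being serialized.
result = my_tf_fn(parameter_value)
# Phase 2: send the result back; receive the serialized computation.
computation_proto, type_signature = serializer.send(result)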
Example #4
def functional_model_from_keras(
    keras_model: Union[tf.keras.Model, Callable[[], tf.keras.Model]],
    loss_fn: tf.keras.losses.Loss,
    input_spec: Union[Sequence[Any], Mapping[str, Any]],
) -> FunctionalModel:
    """Converts a `tf.keras.Model` to a `tff.learning.models.FunctionalModel`.

  NOTE: This method only supports models where calling that model with
  `training=True` and `training=False` produce the same graph. Keras layers
  such as batch normalization will fail because they require updating internal
  state when `training=True` which is not suported.

  IMPORTANT: The returned model must only be used in a graph context (for
  example inside a `tff.tf_computation` decorated callable). It will raise an
  error otherwise.

  Args:
    keras_model: A `tf.keras.Model` object, should be uncompiled. If compiled,
      the metrics, optimizer, and loss function will be ignored. Note: models
        that have multiple outputs will send all outputs to the `loss_fn`.
    loss_fn: A `tf.keras.losses.Loss` object.
    input_spec: A structure of `tf.TensorSpec` defining the input to the model.

  Returns:
    A `tff.learning.models.FunctionalModel`.

  Raises:
    KerasFunctionalModelError: the model has a batch normalization layer.
  """
    # We're going to do something fancy here:
    #
    # 1. Get a copy of all the variables, in the order they are created during
    #    model construction, when in a graph context.
    # 2. Use this ordering to construct a type signature of the model weights in
    #    such a way that we can inject TENSORS (those that are coming in as
    #    arguments) in place of variable creation during a call to
    #    `tf.keras.models.clone_model()`, which gives us a newly constructed Keras
    #    model in the context we want.
    # 3. Profit by having variableless graphs!
    #
    # **WARNING** Caveats:
    #
    # 1. This model _must_ be used inside a graph context (e.g. a
    #    `tff.tf_computation` decorated callable, aka a `tff.Computation`). Keras
    #    appears to create extra variables in the eager context that are not part
    #    of the user-specified model, which end up not being compatible.
    #
    # 2. We have found that this trick does NOT work with non-trainable variables
    #    that are updated during training. Namely layers such as
    #    BatchNormalization try to update means/variances during training and are
    #    not compatible with this approach. We generally recommend
    #    GroupNormalization in place of BatchNormalization at the current time.
    #
    # 3. This does not support multiple outputs with different loss functions,
    #    or layerwise regularization losses. TODO(b/156629927).
    if isinstance(keras_model, tf.keras.Model):
        for layer in keras_model.layers:
            # There may be other layers that are problematic; at this time,
            # updating the mean/variance in a batchnorm layer is the only known
            # such instance.
            if isinstance(layer, tf.keras.layers.BatchNormalization):
                raise KerasFunctionalModelError(
                    'Keras model contains a batch normalization layer, which is '
                    'incompatible with `tff.learning.models.FunctionalModel`. Consider '
                    'using group normalization instead.')
        if keras_model.non_trainable_variables:
            raise KerasFunctionalModelError(
                'Received a Keras model with non-trainable variables. Keras models with '
                'non-trainable variables are currently not supported by FunctionalModel'
                '. Most training algorithms (e.g. Federated Averaging) will not '
                'aggregate them, and they are not updated locally by the optimizer. '
                'We can relax this in the future if we have APIs that support updating '
                'non-trainable variables.')
    elif not callable(keras_model):
        raise ValueError(
            '`keras_model` must be a `tf.keras.Model` or a no-arg '
            'callable that returns a `tf.keras.Model`.')

    # Clone the keras model inside a graph context so that we only get the
    # variables for the layers (otherwise keras adds other non-user variables).
    # We also set up ops to inject the current model weights, because the
    # cloned model will be re-initialized from scratch.
    with tf.Graph().as_default() as g:
        with variable_utils.record_variable_creation_scope(
        ) as captured_variables:
            if isinstance(keras_model, tf.keras.Model):
                try:
                    cloned_model = tf.keras.models.clone_model(keras_model)
                except RuntimeError as e:
                    raise KerasFunctionalModelError(
                        'Encountered an error converting the Keras model. Often '
                        'this occurs when the `tf.keras.Model` has a layer that '
                        'receives inputs from other layers directly (e.g. shared '
                        'embeddings). To avoid the problem, wrap the '
                        '`tf.keras.Model` construction in a no-arg callable '
                        '(e.g. lambda) and pass that callable to '
                        '`functional_model_from_keras`.') from e
                if len(cloned_model.variables) != len(keras_model.variables):
                    raise KerasFunctionalModelError(
                        'The input Keras model is likely sharing variables across layers '
                        'which is unsupported. Cloning the model will duplicate these '
                        'variables and result in unexpected training gradients.'
                    )
            else:
                cloned_model = keras_model()
            # Ensure our cloned model has the same weights as the current
            # model. We'll feed the current model weights into the
            # placeholders for assignment in a session below.
            def assign_placeholder(v):
                p = tf.compat.v1.placeholder(dtype=v.dtype)
                return v.assign(p), p

            assign_ops, placeholders = zip(*(assign_placeholder(v)
                                             for v in cloned_model.variables))
    trainable_variables = tuple(v for v in captured_variables if v.trainable)
    non_trainable_variables = tuple(v for v in captured_variables
                                    if not v.trainable)

    # Here we get the initial weights from the incoming keras model in the
    # order they are constructed, and also ensure that the values are set to
    # the incoming model weights rather than their fresh initialization.
    if isinstance(keras_model, tf.keras.Model):
        model_for_variables = keras_model
    else:
        model_for_variables = keras_model()
    current_model_weights = tf.nest.map_structure(
        lambda v: v.read_value().numpy(), model_for_variables.variables)
    with tf.compat.v1.Session(graph=g) as sess:
        sess.run(tf.compat.v1.initializers.variables(captured_variables))
        sess.run(fetches=assign_ops,
                 feed_dict=dict(zip(placeholders, current_model_weights)))
        initial_weights = sess.run(fetches=(trainable_variables,
                                            non_trainable_variables))

    @tf.function
    def predict_on_batch(model_weights: ModelWeights,
                         x: Any,
                         training: bool = True) -> Any:
        with tf.init_scope():
            if tf.executing_eagerly():
                raise KerasFunctionalModelError(
                    'tf.keras.Model used as a FunctionalModel is only usable inside a '
                    'tff.tf_computation decorated callable or a graph context.'
                )
        # Make a copy of the weights container; can't mutate Python containers
        # inside a tf.function.
        trainable, non_trainable = (list(w) for w in model_weights)

        # Here we intercept variable creation requests during the
        # `tf.keras.models.clone_model()` call.
        #
        # Instead of forwarding the variable request to TF core and getting a
        # `tf.Variable` back, we skip that and return only the `tf.Tensor` that
        # corresponds to the `tf.Variable` recreation request (avoiding any variable
        # creation). This works because TF operations that accept `tf.Variable`
        # inputs automatically call `variable.read_value()` and then operate on that
        # resulting tensor. We're relying on shortcutting that and providing the
        # tensor straight away.
        #
        # For example, `tf.matmul` doesn't notice its input is `tf.Variable` or
        # `tf.Tensor`:
        #
        #   v = tf.Variable([[1], [2], [3]])
        #   tf.matmul(v, [[4, 5, 6]])
        #
        #   and
        #
        #   v = tf.constant([[1], [2], [3]])
        #   tf.matmul(v, [[4, 5, 6]])
        #
        #   both result in:
        #
        #   <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
        #   array([[ 4,  5,  6],
        #          [ 8, 10, 12],
        #          [12, 15, 18]], dtype=int32)>
        def swap_tensor_parameter_for_variable(_, **kwargs):
            if kwargs.get('trainable', True):
                return trainable.pop(0)
            else:
                return non_trainable.pop(0)

        with tf.variable_creator_scope(swap_tensor_parameter_for_variable):
            if isinstance(keras_model, tf.keras.Model):
                variableless_model = tf.keras.models.clone_model(keras_model)
            else:
                variableless_model = keras_model()
        return variableless_model(x, training)

    @tf.function
    def forward_pass(model_weights: ModelWeights,
                     batch_input: Any,
                     training: bool = True) -> model_lib.BatchOutput:
        if isinstance(batch_input, collections.abc.Mapping):
            x = batch_input['x']
            y = batch_input['y']
        elif isinstance(batch_input, collections.abc.Sequence):
            x, y = batch_input
        else:
            raise ValueError(
                '`batch_input` must be either a mapping with keys `x` '
                f'and `y` or a sequence of `(x, y)`. Got: {batch_input!r}')
        predictions = predict_on_batch(model_weights, x, training)
        batch_loss = loss_fn(y_true=y, y_pred=predictions)

        # TODO(b/207033265): more work needed to support models with multiple loss
        # functions.

        def nrows(t):
            return t.nrows() if isinstance(t,
                                           tf.RaggedTensor) else tf.shape(t)[0]

        return model_lib.BatchOutput(loss=batch_loss,
                                     predictions=predictions,
                                     num_examples=nrows(
                                         tf.nest.flatten(batch_input)[0]))

    return FunctionalModel(initial_weights=initial_weights,
                           forward_pass_fn=forward_pass,
                           predict_on_batch_fn=predict_on_batch,
                           input_spec=input_spec)
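As the error message in the cloning path suggests, models that share layers or
variables should be passed as a no-arg callable rather than as an instance; a
minimal sketch (the model architecture is illustrative only):

def build_model():
  # Constructing the model inside the callable lets
  # `functional_model_from_keras` clone it within the graph context it owns.
  return tf.keras.Sequential(
      [tf.keras.layers.Dense(1, input_shape=(1,))])


functional_model = functional_model_from_keras(
    keras_model=build_model,  # pass the callable itself, not build_model()
    loss_fn=tf.keras.losses.MeanSquaredError(),
    input_spec=(tf.TensorSpec([None, 1], dtype=tf.float32),
                tf.TensorSpec([None, 1], dtype=tf.float32)))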
Example #5
def serialize_py_fn_as_tf_computation(target, parameter_type, context_stack):
    """Serializes the 'target' as a TF computation with a given parameter type.

  See also `serialize_tf2_as_tf_computation` for TensorFlow 2
  serialization.

  Args:
    target: The entity to convert into and serialize as a TF computation. This
      can currently only be a Python function. In the future, we will add here
      support for serializing the various kinds of non-eager and eager
      functions, and eventually aim at full support for and compliance with TF
      2.0. This function is currently required to declare either zero parameters
      if `parameter_type` is `None`, or exactly one parameter if it's not
      `None`.  The nested structure of this parameter must correspond to the
      structure of the 'parameter_type'. In the future, we may support targets
      with multiple args/keyword args (to be documented in the API and
      referenced from here).
    parameter_type: The parameter type specification if the target accepts a
      parameter, or `None` if the target doesn't declare any parameters. Either
      an instance of `types.Type`, or something that's convertible to it by
      `types.to_type()`.
    context_stack: The context stack to use.

  Returns:
    A tuple of (`pb.Computation`, `tff.Type`), where the computation contains
    the instance with the `pb.TensorFlow` variant set, and the type is an
    instance of `tff.Type`, potentially including Python container annotations,
    for use by TensorFlow computation wrappers.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the signature of the target is not compatible with the given
      parameter type.
  """
    # TODO(b/113112108): Support a greater variety of target type signatures,
    # with keyword args or multiple args corresponding to elements of a tuple.
    # Document all accepted forms with examples in the API, and point to there
    # from here.

    py_typecheck.check_type(target, types.FunctionType)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    parameter_type = computation_types.to_type(parameter_type)
    signature = function_utils.get_signature(target)

    with tf.Graph().as_default() as graph:
        if parameter_type is not None:
            if len(signature.parameters) != 1:
                raise ValueError(
                    'Expected the target to declare exactly one parameter, found {!r}.'
                    .format(signature.parameters))
            parameter_name = next(iter(signature.parameters))
            parameter_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
                parameter_name, parameter_type, graph)
        else:
            if signature.parameters:
                raise ValueError(
                    'Expected the target to declare no parameters, found {!r}.'
                    .format(signature.parameters))
            parameter_value = None
            parameter_binding = None
        context = tensorflow_computation_context.TensorFlowComputationContext(
            graph)
        with context_stack.install(context):
            with variable_utils.record_variable_creation_scope(
            ) as all_variables:
                if parameter_value is not None:
                    result = target(parameter_value)
                else:
                    result = target()
            initializer_ops = []
            if all_variables:
                # Use a readable but not-too-long name for the init_op.
                name = 'init_op_for_' + '_'.join(
                    [v.name.replace(':0', '') for v in all_variables])
                if len(name) > 50:
                    name = 'init_op_for_{}_variables'.format(
                        len(all_variables))
                initializer_ops.append(
                    tf.compat.v1.initializers.variables(all_variables,
                                                        name=name))
            initializer_ops.extend(
                tf.compat.v1.get_collection(
                    tf.compat.v1.GraphKeys.TABLE_INITIALIZERS))
            if initializer_ops:
                # Before running the main new init op, run any initializers for sub-
                # computations from context.init_ops. Variables from import_graph_def
                # will not make it into the global collections, and so will not be
                # initialized without this code path.
                with tf.compat.v1.control_dependencies(context.init_ops):
                    init_op_name = tf.group(*initializer_ops,
                                            name='grouped_initializers').name
            elif context.init_ops:
                init_op_name = tf.group(*context.init_ops,
                                        name='subcomputation_init_ops').name
            else:
                init_op_name = None

        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            result, graph)

    type_signature = computation_types.FunctionType(parameter_type,
                                                    result_type)

    # WARNING: we do not really want to be modifying the graph here if we can
    # avoid it. This is purely to work around performance issues uncovered
    # with the non-standard usage of TensorFlow, and the changes have been
    # discussed with the TensorFlow core team before being added.
    clean_graph_def = _clean_graph_def(graph.as_graph_def())
    tensorflow = pb.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(clean_graph_def),
        parameter=parameter_binding,
        result=result_binding,
        initialize_op=init_op_name)
    return pb.Computation(
        type=type_serialization.serialize_type(type_signature),
        tensorflow=tensorflow), type_signature
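A hedged usage sketch, assuming a target with exactly one parameter as the
docstring requires (`add_one` is hypothetical, and `context_stack` stands for
whatever `ContextStack` instance the caller already holds):

def add_one(x):  # hypothetical target with exactly one parameter
  return x + 1


# tf.int32 is acceptable here because the parameter type is normalized
# via `computation_types.to_type` inside the function.
comp_proto, type_signature = serialize_py_fn_as_tf_computation(
    add_one, tf.int32, context_stack)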