Example #1
  def adapt(self, data, reset_state=True):
    """Adapt the state of the layers of the preprocessing stage to the data.

    Args:
      data: A batched Dataset object, or a NumPy array, or an EagerTensor.
        Data to be iterated over to adapt the state of the layers in this
        preprocessing stage.
      reset_state: Whether this call to `adapt` should reset the state of
        the layers in this preprocessing stage.
    """
    if not isinstance(data,
                      (dataset_ops.DatasetV2, np.ndarray, ops.EagerTensor)):
      raise ValueError(
          '`adapt()` requires a batched Dataset, an EagerTensor, '
          'or a Numpy array as input, '
          'got {}'.format(type(data)))
    if isinstance(data, dataset_ops.DatasetV2):
      # Validate the dataset to try to ensure we haven't been passed one with
      # infinite size. That would cause an infinite loop here.
      if tf_utils.dataset_is_infinite(data):
        raise ValueError(
            'The dataset passed to `adapt()` has an infinite number of '
            'elements. Please use dataset.take(...) to make the number '
            'of elements finite.')

    for current_layer_index in range(len(self.layers)):
      if not hasattr(self.layers[current_layer_index], 'adapt'):
        # Skip any layer that does not need adapting.
        continue

      def map_fn(x):
        """Maps `PreprocessingStage` inputs to inputs at `current_layer_index`.

        Args:
          x: Batch of inputs as seen at the entry of the `PreprocessingStage`
            instance.

        Returns:
          Batch of inputs to be processed by layer
            `self.layers[current_layer_index]`
        """
        if current_layer_index == 0:  # pylint: disable=cell-var-from-loop
          return x
        for i in range(current_layer_index):  # pylint: disable=cell-var-from-loop
          x = self.layers[i](x)
        return x

      if isinstance(data, dataset_ops.DatasetV2):
        current_layer_data = data.map(map_fn)
      else:
        current_layer_data = map_fn(data)
      self.layers[current_layer_index].adapt(current_layer_data,
                                             reset_state=reset_state)
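
To see the per-layer pattern above end to end, here is a minimal, hedged usage sketch. It is a hypothetical stand-in for the stage (not the `PreprocessingStage` API itself) and assumes TF 2.6+, where `Normalization` and `Discretization` live under `tf.keras.layers`: each adaptable layer is adapted on data that has already been mapped through the layers preceding it, mirroring `map_fn` above.

import numpy as np
import tensorflow as tf

# Hypothetical stand-in for a sequential preprocessing stage.
layers = [tf.keras.layers.Normalization(),
          tf.keras.layers.Discretization(num_bins=3)]
data = tf.data.Dataset.from_tensor_slices(
    np.random.rand(8, 4).astype('float32')).batch(2)

for index, layer in enumerate(layers):
  if not hasattr(layer, 'adapt'):
    # Skip any layer that does not need adapting.
    continue
  # Map the stage inputs to the inputs seen by `layer`: earlier layers have
  # already been adapted, so they can be applied inside `Dataset.map`.
  mapped = data
  for earlier in layers[:index]:
    mapped = mapped.map(earlier)
  layer.adapt(mapped)
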
Example #2
    def adapt(self, data, reset_state=True):
        """Fits the state of the preprocessing layer to the data being passed.

        Args:
          data: The data to train on. It can be passed either as a tf.data
            Dataset, or as a numpy array.
          reset_state: Optional argument specifying whether to clear the state
            of the layer at the start of the call to `adapt`, or whether to
            start from the existing state. Subclasses may choose to raise if
            `reset_state` is set to `False`.
        """
        if reset_state:
            accumulator = None
        else:
            accumulator = self._combiner.restore(self._restore_updates())
        if isinstance(data, (list, tuple)):
            data = ops.convert_to_tensor_v2_with_dispatch(data)
        if not isinstance(data, (dataset_ops.DatasetV2, np.ndarray, ops.Tensor,
                                 ragged_tensor.RaggedTensor)):
            raise ValueError('`adapt()` requires a batched Dataset, a Tensor, '
                             'or a Numpy array as input, '
                             'got {}'.format(type(data)))

        if isinstance(data, dataset_ops.DatasetV2):
            # Validate that the dataset only contains single-tensor elements.
            if not isinstance(data.element_spec, type_spec.TypeSpec):
                raise TypeError(
                    'The dataset should yield single-Tensor elements. Use `dataset.map` '
                    'to select the element of interest.\n'
                    'Got dataset.element_spec=' + str(data.element_spec))
            # Validate the dataset to try to ensure we haven't been passed one
            # with infinite size. That would cause an infinite loop here.
            if tf_utils.dataset_is_infinite(data):
                raise ValueError(
                    'The dataset passed to `adapt()` has an infinite number of '
                    'elements. Please use `dataset.take(...)` to make the number '
                    'of elements finite.')
            next_data = self._get_dataset_iterator(data)
            # TODO(fchollet): consider checking if the dataset is already batched
            # and otherwise batching it.
        elif isinstance(data, (ops.Tensor, ragged_tensor.RaggedTensor)):
            next_data = self._get_dataset_iterator(
                dataset_ops.Dataset.from_tensor_slices(data).batch(512))
        else:
            generator, _ = training_generator_v1.convert_to_generator_like(
                data, batch_size=512)
            # If the data is not a dataset, we can iterate over it using next(foo);
            # here, we wrap that into a callable.
            next_data = lambda: next(generator)

        # TODO(momernick): Some sort of status bar?
        # TODO(momernick): Implement parallel processing here?
        try:
            data_element = next_data()

            # First, see if the layer is built or not. If it is not, then we must
            # build it.
            if not self.built:
                try:
                    # If this is a Numpy array or tensor, we can get shape from .shape.
                    # If not, an attribute error will be thrown.
                    data_shape = data_element.shape
                    data_shape_nones = tuple([None] * len(data_element.shape))
                except AttributeError:
                    # The input has an unknown number of dimensions.
                    data_shape = None
                    data_shape_nones = None

                # TODO (b/159261555): move this to base layer build.
                batch_input_shape = getattr(self, '_batch_input_shape', None)
                if batch_input_shape is None:
                    # Set the number of dimensions.
                    self._batch_input_shape = data_shape_nones

                self.build(data_shape)

            # Once we have built the Layer, we can process the input data. We do so
            # until we've gotten an exception indicating that we have no more data.
            while True:
                accumulator = self._combiner.compute(data_element, accumulator)
                data_element = next_data()
        # Note that this belongs to the outer indentation of 'try' - we need to
        # catch exceptions resulting from the first 'next_data()' invocation as
        # well.
        except (StopIteration, errors.OutOfRangeError):
            pass

        updates = self._combiner.extract(accumulator)
        self._set_state_variables(updates)
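
The accumulate-then-extract loop above relies on a combiner object with `compute`/`extract` (and `restore`) methods. The sketch below is a self-contained, hypothetical combiner (the name `MeanCombiner` and its state layout are illustrative, not the Keras combiner API) that folds batches into an accumulator and extracts a single running mean, the same contract the loop assumes:

import numpy as np

class MeanCombiner:
  """Hypothetical combiner: fold batches into a sum/count accumulator."""

  def compute(self, batch, accumulator=None):
    batch = np.asarray(batch, dtype='float64')
    if accumulator is None:
      accumulator = {'sum': 0.0, 'count': 0}
    accumulator['sum'] += batch.sum()
    accumulator['count'] += batch.size
    return accumulator

  def extract(self, accumulator):
    # Turn the final accumulator into the state updates to set on the layer.
    return {'mean': accumulator['sum'] / accumulator['count']}

combiner = MeanCombiner()
accumulator = None
for batch in [np.ones((2, 3)), np.full((2, 3), 3.0)]:
  accumulator = combiner.compute(batch, accumulator)
print(combiner.extract(accumulator))  # {'mean': 2.0}
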
Example #3
  def adapt(self, data, reset_state=True):
    """Adapt the state of the layers of the preprocessing stage to the data.

    Args:
      data: A batched Dataset object, a NumPy array, an EagerTensor, or a list,
        dict or nested structure of NumPy arrays or EagerTensors. The elements
        of a Dataset object need to conform to the inputs of the stage. The
        first dimension of NumPy arrays or EagerTensors is understood to be the
        batch dimension. Data to be iterated over to adapt the state of the
        layers in this preprocessing stage.
      reset_state: Whether this call to `adapt` should reset the state of the
        layers in this preprocessing stage.

    Examples:

    >>> # For a stage with dict input
    >>> inputs = {'x2': tf.keras.Input(shape=(5,)),
    ...           'x1': tf.keras.Input(shape=(1,))}
    >>> outputs = [inputs['x1'], inputs['x2']]
    >>> stage = FunctionalPreprocessingStage(inputs, outputs)
    >>> ds = tf.data.Dataset.from_tensor_slices({'x1': tf.ones((4,5)),
    ...                                          'x2': tf.ones((4,1))})
    >>> sorted(ds.element_spec.items()) # Check element_spec
    [('x1', TensorSpec(shape=(5,), dtype=tf.float32, name=None)),
     ('x2', TensorSpec(shape=(1,), dtype=tf.float32, name=None))]
    >>> stage.adapt(ds)
    >>> data_np = {'x1': np.ones((4, 5)), 'x2': np.ones((4, 1))}
    >>> stage.adapt(data_np)

    """
    if not isinstance(data, dataset_ops.Dataset):
      data = self._flatten_to_reference_inputs(data)
      if any(not isinstance(datum, (np.ndarray, ops.EagerTensor))
             for datum in data):
        raise ValueError(
            '`adapt()` requires a batched Dataset, a list of EagerTensors '
            'or Numpy arrays as input, got {}'.format(type(data)))
      ds_input = [
          dataset_ops.Dataset.from_tensor_slices(x).batch(1) for x in data
      ]

    if isinstance(data, dataset_ops.Dataset):
      # Validate the dataset to try to ensure we haven't been passed one with
      # infinite size. That would cause an infinite loop here.
      if tf_utils.dataset_is_infinite(data):
        raise ValueError(
            'The dataset passed to `adapt()` has an infinite number of '
            'elements. Please use dataset.take(...) to make the number '
            'of elements finite.')
      # Unzip the dataset object into a list of single-input datasets.
      ds_input = _unzip_dataset(data)

    # Dictionary mapping reference tensors to datasets
    ds_dict = {}
    tensor_usage_count = self._tensor_usage_count
    for x, y in zip(self.inputs, ds_input):
      x_id = str(id(x))
      ds_dict[x_id] = [y] * tensor_usage_count[x_id]

    nodes_by_depth = self._nodes_by_depth
    depth_keys = sorted(nodes_by_depth.keys(), reverse=True)

    def build_map_fn(node, args, kwargs):
      if not isinstance(args.element_spec, tuple):

        def map_fn(*x):
          return nest.flatten(node.layer(*x, **kwargs))
      else:

        def map_fn(*x):
          return nest.flatten(node.layer(x, **kwargs))

      return map_fn

    for depth in depth_keys:
      for node in nodes_by_depth[depth]:
        # Input node
        if node.is_input:
          continue

        # Node with input not computed yet
        if any(t_id not in ds_dict for t_id in node.flat_input_ids):
          continue

        args, kwargs = node.map_arguments(ds_dict)
        args = dataset_ops.Dataset.zip(nest.list_to_tuple(*args))

        if hasattr(node.layer, 'adapt'):
          node.layer.adapt(args, reset_state=reset_state)

        map_fn = build_map_fn(node, args, kwargs)
        outputs = args.map(map_fn)
        outputs = _unzip_dataset(outputs)

        # Update ds_dict.
        for x_id, y in zip(node.flat_output_ids, outputs):
          ds_dict[x_id] = [y] * tensor_usage_count[x_id]
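
The depth-ordered walk above feeds each non-input node a dataset built by zipping the datasets of its inputs, adapts the node's layer if it can be adapted, and then maps the zipped dataset through the layer to produce the datasets for downstream nodes. A small, hedged illustration of that dataflow (using a stateless `Concatenate` layer purely as a stand-in for `node.layer`):

import tensorflow as tf

# Per-input datasets, as they would appear in `ds_dict`.
ds_x1 = tf.data.Dataset.from_tensor_slices(tf.ones((4, 1))).batch(1)
ds_x2 = tf.data.Dataset.from_tensor_slices(tf.ones((4, 5))).batch(1)

# Zip the input datasets into argument tuples and map them through the layer.
concat = tf.keras.layers.Concatenate()
args = tf.data.Dataset.zip((ds_x1, ds_x2))
outputs = args.map(lambda x1, x2: concat([x1, x2]))

for batch in outputs.take(1):
  print(batch.shape)  # (1, 6)
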
Example #4
    def adapt(self, data, reset_state=True):
        """Fits the state of the preprocessing layer to the data being passed.

        Arguments:
          data: The data to train on. It can be passed either as a tf.data
            Dataset, or as a numpy array.
          reset_state: Optional argument specifying whether to clear the state
            of the layer at the start of the call to `adapt`, or whether to
            start from the existing state. Subclasses may choose to raise if
            `reset_state` is set to `False`.
        """
        if reset_state:
            accumulator = None
        else:
            accumulator = self._combiner.restore(self._restore_updates())

        if not isinstance(
                data, (dataset_ops.DatasetV2, np.ndarray, ops.EagerTensor)):
            raise ValueError(
                '`adapt()` requires a batched Dataset, an EagerTensor, '
                'or a Numpy array as input, '
                'got {}'.format(type(data)))

        if isinstance(data, dataset_ops.DatasetV2):
            # Validate the dataset to try to ensure we haven't been passed one
            # with infinite size. That would cause an infinite loop here.
            if tf_utils.dataset_is_infinite(data):
                raise ValueError(
                    'The dataset passed to `adapt()` has an infinite number of '
                    'elements. Please use `dataset.take(...)` to make the number '
                    'of elements finite.')
            next_data = self._get_dataset_iterator(data)
        else:
            generator, _ = training_generator.convert_to_generator_like(
                data, batch_size=len(data))
            # If the data is not a dataset, we can iterate over it using next(foo);
            # here, we wrap that into a callable.
            next_data = lambda: next(generator)

        # TODO(momernick): Some sort of status bar?
        # TODO(momernick): Implement parallel processing here?
        try:
            data_element = next_data()

            # First, see if the layer is built or not. If it is not, then we must
            # build it.
            if not self.built:
                try:
                    # If this is a Numpy array or tensor, we can get shape from .shape.
                    # If not, an attribute error will be thrown (and we can assume the
                    # input data is a scalar with shape None).
                    shape = data_element.shape
                except AttributeError:
                    shape = None
                self.build(shape)

            # Once we have built the Layer, we can process the input data. We do so
            # until we've gotten an exception indicating that we have no more data.
            while True:
                accumulator = self._combiner.compute(data_element, accumulator)
                data_element = next_data()
        # Note that this belongs to the outer indentation of 'try' - we need to
        # catch exceptions resulting from the first 'next_data()' invocation as
        # well.
        except (StopIteration, errors.OutOfRangeError):
            pass

        updates = self._combiner.extract(accumulator)
        self._set_state_variables(updates)
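
Both variants back the same public contract: call `adapt()` on a preprocessing layer with either a NumPy array or a batched Dataset before using the layer. A minimal usage sketch, assuming TF 2.6+ where `Normalization` lives under `tf.keras.layers` (earlier versions expose it under `tf.keras.layers.experimental.preprocessing`, and whether `reset_state` is accepted depends on the version):

import numpy as np
import tensorflow as tf

layer = tf.keras.layers.Normalization()

# Adapting from a NumPy array...
layer.adapt(np.array([[0.0], [2.0], [4.0]], dtype='float32'))

# ...or, equivalently, from a batched tf.data Dataset.
ds = tf.data.Dataset.from_tensor_slices(
    np.array([[0.0], [2.0], [4.0]], dtype='float32')).batch(2)
layer.adapt(ds)

# After adapting, the layer centers and scales its inputs.
print(layer(np.array([[2.0]], dtype='float32')))  # values near zero
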