Example #1
    def adapt(self, data, batch_size=None, steps=None, reset_state=True):
        """Fits the state of the preprocessing layer to the data being passed.

    Arguments:
        data: The data to train on. It can be passed either as a
          `tf.data.Dataset` or as a NumPy array.
        batch_size: Integer or `None`.
            Number of samples per state update.
            If unspecified, `batch_size` will default to 32.
            Do not specify the `batch_size` if your data is in the
            form of datasets, generators, or `keras.utils.Sequence` instances
            (since they generate batches).
        steps: Integer or `None`.
            Total number of steps (batches of samples).
            When training with input tensors such as
            TensorFlow data tensors, the default `None` is equal to
            the number of samples in your dataset divided by
            the batch size, or 1 if that cannot be determined. If `data` is a
            `tf.data.Dataset` and `steps` is `None`, the epoch will run until
            the input dataset is exhausted. When passing an infinitely
            repeating dataset, you must specify the `steps` argument. This
            argument is not supported with array inputs.
        reset_state: Optional argument specifying whether to clear the state of
          the layer at the start of the call to `adapt`, or whether to start
          from the existing state. This argument may not be relevant to all
          preprocessing layers: a subclass of `PreprocessingLayer` may choose to
          raise an error if `reset_state` is set to `False`.
    """
        _disallow_inside_tf_function('adapt')
        if not version_utils.should_use_v2():
            raise RuntimeError('`adapt` is only supported in tensorflow v2.')  # pylint: disable=g-doc-exception
        if not self.stateful:
            return
        if not self.streaming and self._is_adapted and not reset_state:
            raise ValueError(
                '{} does not support calling `adapt` twice without '
                'resetting the state.'.format(self.__class__.__name__))
        if not self._is_compiled:
            self.compile()  # Compile with defaults.
        if self.built and reset_state:
            self.reset_state()
        data_handler = data_adapter.DataHandler(
            data,
            batch_size=batch_size,
            steps_per_epoch=steps,
            epochs=1,
            steps_per_execution=self._steps_per_execution,
            distribute=False)
        self._adapt_function = self.make_adapt_function()
        for _, iterator in data_handler.enumerate_epochs():
            with data_handler.catch_stop_iteration():
                for _ in data_handler.steps():
                    self._adapt_function(iterator)
                    if data_handler.should_sync:
                        context.async_wait()
        self.finalize_state()
        self._is_adapted = True
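A minimal usage sketch for `adapt`, assuming TF 2.x and the experimental `Normalization` layer (the array values are illustrative only):

import numpy as np
import tensorflow as tf

layer = tf.keras.layers.experimental.preprocessing.Normalization()
layer.adapt(np.array([[0.], [2.]], dtype=np.float32))   # one pass over the data fits mean/variance
print(layer(np.array([[0.], [1.], [2.]], dtype=np.float32)))  # roughly [[-1.], [0.], [1.]]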
 def test_finite_dataset_without_steps_per_epoch(self):
   data = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]).batch(1)
   data_handler = data_adapter.DataHandler(data, initial_epoch=0, epochs=2)
   returned_data = []
   for _, iterator in data_handler.enumerate_epochs():
     epoch_data = []
     for _ in data_handler.steps():
       epoch_data.append(next(iterator).numpy())
     returned_data.append(epoch_data)
   self.assertEqual(returned_data, [[0, 1, 2], [0, 1, 2]])
 def test_finite_dataset_with_steps_per_epoch(self):
   data = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3]).batch(1)
   # User can choose to only partially consume `Dataset`.
   data_handler = data_adapter.DataHandler(
       data, initial_epoch=0, epochs=2, steps_per_epoch=2)
   self.assertFalse(data_handler._adapter.should_recreate_iterator())
   returned_data = []
   for _, iterator in data_handler.enumerate_epochs():
     epoch_data = []
     for _ in data_handler.steps():
       epoch_data.append(next(iterator).numpy())
     returned_data.append(epoch_data)
   self.assertEqual(returned_data, [[0, 1], [2, 3]])
 def test_list_of_scalars(self):
   data_handler = data_adapter.DataHandler([[0], [1], [2]],
                                           epochs=2,
                                           steps_per_epoch=3)
   returned_data = []
   for _, iterator in data_handler.enumerate_epochs():
     epoch_data = []
     for _ in data_handler.steps():
       epoch_data.append(next(iterator))
     returned_data.append(epoch_data)
   returned_data = self.evaluate(returned_data)
   self.assertEqual(returned_data, [[([0],), ([1],),
                                     ([2],)], [([0],), ([1],), ([2],)]])
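The tests above (and the remaining ones in this listing) all rely on the same consumption contract: `enumerate_epochs()` yields `(epoch_index, iterator)` pairs, `steps()` yields the step indices of the current epoch, and `catch_stop_iteration()` absorbs dataset exhaustion. A minimal standalone sketch of that pattern, assuming the Keras-internal `data_adapter` module and eager execution:

def consume(data_handler):
  """Collects every element yielded by a DataHandler, epoch by epoch."""
  epochs = []
  for _, iterator in data_handler.enumerate_epochs():
    epoch_data = []
    with data_handler.catch_stop_iteration():
      for _ in data_handler.steps():
        epoch_data.append(next(iterator))
    epochs.append(epoch_data)
  return epochs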
Example #5
    def test_class_weight_user_errors(self):
        with self.assertRaisesRegexp(ValueError, 'to be a dict with keys'):
            data_adapter.DataHandler(
                x=[[0], [1], [2]],
                y=[[2], [1], [0]],
                batch_size=1,
                sample_weight=[[1.], [2.], [4.]],
                class_weight={
                    0: 0.5,
                    1: 1.,
                    3: 1.5  # Skips class `2`.
                })

        with self.assertRaisesRegexp(ValueError, 'with a single output'):
            data_adapter.DataHandler(x=np.ones((10, 1)),
                                     y=[np.ones((10, 1)),
                                        np.zeros((10, 1))],
                                     batch_size=2,
                                     class_weight={
                                         0: 0.5,
                                         1: 1.,
                                         2: 1.5
                                     })
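For contrast, a `class_weight` dict with contiguous integer keys starting at 0, used with a single-output `y`, passes both checks; a minimal sketch reusing the hypothetical inputs from the test above:

data_adapter.DataHandler(
    x=[[0], [1], [2]],
    y=[[2], [1], [0]],
    batch_size=1,
    sample_weight=[[1.], [2.], [4.]],
    class_weight={0: 0.5, 1: 1., 2: 1.5})  # keys 0..2 with no gaps -> no error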
 def test_finite_dataset_with_steps_per_epoch_exact_size(self):
   data = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3]).batch(1)
   # If user specifies exact size of `Dataset` as `steps_per_epoch`,
   # create a new iterator each epoch.
   data_handler = data_adapter.DataHandler(
       data, initial_epoch=0, epochs=2, steps_per_epoch=4)
   self.assertTrue(data_handler._adapter.should_recreate_iterator())
   returned_data = []
   for _, iterator in data_handler.enumerate_epochs():
     epoch_data = []
     for _ in data_handler.steps():
       epoch_data.append(next(iterator).numpy())
     returned_data.append(epoch_data)
   self.assertEqual(returned_data, [[0, 1, 2, 3], [0, 1, 2, 3]])
  def test_single_x_input_no_tuple_wrapping(self, use_numpy):
    x = np.ones((10, 1))

    if use_numpy:
      batch_size = 2
    else:
      x = dataset_ops.Dataset.from_tensor_slices(x).batch(2)
      batch_size = None

    data_handler = data_adapter.DataHandler(x, batch_size=batch_size)
    for _, iterator in data_handler.enumerate_epochs():
      for _ in data_handler.steps():
        # Check that single x input is not wrapped in a tuple.
        self.assertIsInstance(next(iterator), ops.Tensor)
 def test_composite_tensor(self):
   st = sparse_tensor.SparseTensor(
       indices=[[0, 0], [1, 0], [2, 0]], values=[0, 1, 2], dense_shape=[3, 1])
   data_handler = data_adapter.DataHandler(st, epochs=2, steps_per_epoch=3)
   returned_data = []
   for _, iterator in data_handler.enumerate_epochs():
     epoch_data = []
     for _ in data_handler.steps():
       epoch_data.append(next(iterator))
     returned_data.append(epoch_data)
   returned_data = self.evaluate(
       nest.map_structure(sparse_ops.sparse_tensor_to_dense, returned_data))
   self.assertEqual(returned_data, [[([0],), ([1],),
                                     ([2],)], [([0],), ([1],), ([2],)]])
 def test_insufficient_data(self):
   ds = dataset_ops.DatasetV2.from_tensor_slices([0, 1])
   ds = ds.filter(lambda *args, **kwargs: True)
   data_handler = data_adapter.DataHandler(
       ds, initial_epoch=0, epochs=2, steps_per_epoch=3)
   returned_data = []
   for _, iterator in data_handler.enumerate_epochs():
     epoch_data = []
     for _ in data_handler.steps():
       with data_handler.catch_stop_iteration():
         epoch_data.append(next(iterator))
     returned_data.append(epoch_data)
   returned_data = self.evaluate(returned_data)
   self.assertTrue(data_handler._insufficient_data)
   self.assertEqual(returned_data, [[0, 1]])
 def test_numpy(self):
   x = np.array([0, 1, 2])
   y = np.array([0, 2, 4])
   sw = np.array([0, 4, 8])
   data_handler = data_adapter.DataHandler(
       x=x, y=y, sample_weight=sw, batch_size=1, epochs=2)
   returned_data = []
   for _, iterator in data_handler.enumerate_epochs():
     epoch_data = []
     for _ in data_handler.steps():
       epoch_data.append(next(iterator))
     returned_data.append(epoch_data)
   returned_data = self.evaluate(returned_data)
   self.assertEqual(returned_data,
                    [[(0, 0, 0), (1, 2, 4),
                      (2, 4, 8)], [(0, 0, 0), (1, 2, 4), (2, 4, 8)]])
  def test_generator(self):

    def generator():
      for _ in range(2):
        for step in range(3):
          yield (ops.convert_to_tensor_v2_with_dispatch([step]),)

    data_handler = data_adapter.DataHandler(
        generator(), epochs=2, steps_per_epoch=3)
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      for _ in data_handler.steps():
        epoch_data.append(next(iterator))
      returned_data.append(epoch_data)
    returned_data = self.evaluate(returned_data)
    self.assertEqual(returned_data, [[([0],), ([1],),
                                      ([2],)], [([0],), ([1],), ([2],)]])
  def test_unknown_cardinality_dataset_with_steps_per_epoch(self):
    ds = dataset_ops.DatasetV2.from_tensor_slices([0, 1, 2, 3, 4, 5, 6])
    filtered_ds = ds.filter(lambda x: x < 4)
    self.assertEqual(
        cardinality.cardinality(filtered_ds).numpy(), cardinality.UNKNOWN)

    # User can choose to only partially consume `Dataset`.
    data_handler = data_adapter.DataHandler(
        filtered_ds, initial_epoch=0, epochs=2, steps_per_epoch=2)
    self.assertFalse(data_handler._adapter.should_recreate_iterator())
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      for _ in data_handler.steps():
        epoch_data.append(next(iterator))
      returned_data.append(epoch_data)
    returned_data = self.evaluate(returned_data)
    self.assertEqual(returned_data, [[0, 1], [2, 3]])
  def test_unknown_cardinality_dataset_without_steps_per_epoch(self):
    ds = dataset_ops.DatasetV2.from_tensor_slices([0, 1, 2, 3, 4, 5, 6])
    filtered_ds = ds.filter(lambda x: x < 4)
    self.assertEqual(
        cardinality.cardinality(filtered_ds).numpy(), cardinality.UNKNOWN)

    data_handler = data_adapter.DataHandler(
        filtered_ds, initial_epoch=0, epochs=2)
    self.assertTrue(data_handler._adapter.should_recreate_iterator())
    returned_data = []
    for _, iterator in data_handler.enumerate_epochs():
      epoch_data = []
      with data_handler.catch_stop_iteration():
        for _ in data_handler.steps():
          epoch_data.append(next(iterator))
      returned_data.append(epoch_data)
    returned_data = self.evaluate(returned_data)
    self.assertEqual(returned_data, [[0, 1, 2, 3], [0, 1, 2, 3]])
    self.assertEqual(data_handler._steps_per_epoch, 4)
Example #14
File: core.py  Project: jxzhangjhu/ucate
 def mc_sample(self,
               x,
               batch_size=None,
               steps=None,
               max_queue_size=10,
               workers=1,
               use_multiprocessing=False):
     outputs = None
     with self.distribute_strategy.scope():
         data_handler = data_adapter.DataHandler(
             x=x,
             batch_size=batch_size,
             steps_per_epoch=steps,
             initial_epoch=0,
             epochs=1,
             max_queue_size=max_queue_size,
             workers=workers,
             use_multiprocessing=use_multiprocessing,
             model=self)
         predict_function = self.make_mc_sample_function()
         for _, iterator in data_handler.enumerate_epochs():
             with data_handler.catch_stop_iteration():
                 for step in data_handler.steps():
                     tmp_batch_outputs = predict_function(iterator)
                     if not data_handler.inferred_steps:
                         context.async_wait()
                     batch_outputs = tmp_batch_outputs
                     if outputs is None:
                         outputs = nest.map_structure(
                             lambda batch_output: [batch_output],
                             batch_outputs)
                     else:
                         nest.map_structure_up_to(
                             batch_outputs,
                             lambda output, batch_output: output.append(
                                 batch_output), outputs, batch_outputs)
     all_outputs = nest.map_structure_up_to(batch_outputs, concat, outputs)
     return tf_utils.to_numpy_or_python_type(all_outputs)
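A hedged usage sketch: `mc_sample` is called like `Model.predict`; `model` and `x_test` below are placeholders for an actual ucate model instance (one that defines `make_mc_sample_function`) and an input batch:

# Placeholders: `model` is a ucate model exposing mc_sample, `x_test` a NumPy batch.
mc_outputs = model.mc_sample(x_test, batch_size=32)  # per-batch outputs concatenated and returned as NumPy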
Example #15
def accuracy_aware_fit(cls_instance,
                       train_dataset,
                       compression_ctrl,
                       nncf_config,
                       callbacks,
                       initial_epoch,
                       uncompressed_model_accuracy,
                       steps_per_epoch=None,
                       batch_size=None,
                       tensorboard_writer=None,
                       log_dir=None,
                       validation_data=None,
                       validation_steps=None,
                       result_dict_to_val_metric_fn=None,
                       **kwargs):
    if result_dict_to_val_metric_fn is None:
        result_dict_to_val_metric_fn = lambda metric: metric

    with cls_instance.distribute_strategy.scope(), \
        training_utils.RespectCompiledTrainableState(cls_instance):
        # pylint: disable=protected-access
        data_handler = data_adapter.DataHandler(
            x=train_dataset,
            y=None,
            sample_weight=None,
            batch_size=batch_size,
            steps_per_epoch=steps_per_epoch,
            initial_epoch=initial_epoch,
            epochs=1,
            shuffle=True,
            class_weight=None,
            max_queue_size=10,
            workers=1,
            use_multiprocessing=False,
            model=cls_instance,
            steps_per_execution=cls_instance._steps_per_execution)

        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(
                callbacks,
                add_history=True,
                model=cls_instance,
                epochs=1,
                verbose=1,
                add_progbar=True,
                steps=data_handler.inferred_steps)

    def train_epoch_fn(compression_ctrl, model, epoch):
        model.reset_metrics()

        if model.train_function is None:
            model.train_function = model.make_train_function()
        _, iterator = next(data_handler.enumerate_epochs())

        callbacks.on_epoch_begin(epoch)
        with data_handler.catch_stop_iteration():
            for step in data_handler.steps():
                with trace.Trace('train',
                                 epoch_num=epoch,
                                 step_num=step,
                                 batch_size=None,
                                 _r=1):
                    callbacks.on_train_batch_begin(step)
                    tmp_logs = model.train_function(iterator)
                    if data_handler.should_sync:
                        context.async_wait()
                    logs = tmp_logs
                    end_step = step + data_handler.step_increment
                    callbacks.on_train_batch_end(end_step, logs)
                    if model.stop_training:
                        break

        if logs is None:
            raise ValueError('Expect x to be a non-empty array or dataset.')
        epoch_logs = copy.copy(logs)
        callbacks.on_epoch_end(epoch, epoch_logs)

    if validation_data is None:
        validation_data = train_dataset

    def validate_fn(model, epoch=None):
        val_x, val_y, val_sample_weight = (
            data_adapter.unpack_x_y_sample_weight(validation_data))
        val_logs = model.evaluate(x=val_x,
                                  y=val_y,
                                  sample_weight=val_sample_weight,
                                  batch_size=None,
                                  steps=validation_steps,
                                  callbacks=callbacks,
                                  return_dict=True)
        return result_dict_to_val_metric_fn(val_logs)

    callbacks.on_train_begin()
    cls_instance.original_model_accuracy = uncompressed_model_accuracy
    acc_aware_training_loop = create_accuracy_aware_training_loop(
        nncf_config, compression_ctrl)
    cls_instance = acc_aware_training_loop.run(
        cls_instance,
        train_epoch_fn=train_epoch_fn,
        validate_fn=validate_fn,
        tensorboard_writer=tensorboard_writer,
        log_dir=log_dir)
    callbacks.on_train_end()
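A hedged call sketch: `model`, `train_ds`, `val_ds`, `compression_ctrl`, and `nncf_config` stand in for objects produced by the usual NNCF wrapping/config setup, and the metric lambda assumes the model was compiled with an `accuracy` metric; none of these names are defined in this listing:

accuracy_aware_fit(model,
                   train_ds,
                   compression_ctrl=compression_ctrl,
                   nncf_config=nncf_config,
                   callbacks=[],
                   initial_epoch=0,
                   uncompressed_model_accuracy=0.95,
                   validation_data=val_ds,
                   result_dict_to_val_metric_fn=lambda logs: logs['accuracy'])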
Example #16
    def fit(self,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose=1,
            callbacks=None,
            validation_split=0.,
            validation_data=None,
            shuffle=True,
            class_weight=None,
            sample_weight=None,
            initial_epoch=0,
            steps_per_epoch=None,
            validation_steps=None,
            validation_batch_size=None,
            validation_freq=1,
            max_queue_size=10,
            workers=1,
            use_multiprocessing=False):
        """ From tf.keras.Model. """
        training._keras_api_gauge.get_cell('fit').set(True)
        # Legacy graph support is contained in `training_v1.Model`.
        version_utils.disallow_legacy_graph('Model', 'fit')
        self._assert_compile_was_called()
        self._check_call_args('fit')
        training._disallow_inside_tf_function('fit')

        if validation_split:
            # Create the validation data using the training data. Only supported for
            # `Tensor` and `NumPy` input.
            (x, y, sample_weight), validation_data = (
                data_adapter.train_validation_split(
                    (x, y, sample_weight), validation_split=validation_split))

        if validation_data:
            val_x, val_y, val_sample_weight = (
                data_adapter.unpack_x_y_sample_weight(validation_data))

        with self.distribute_strategy.scope(), \
             training_utils.RespectCompiledTrainableState(self):
            # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
            data_handler = data_adapter.DataHandler(
                x=x,
                y=y,
                sample_weight=sample_weight,
                batch_size=batch_size,
                steps_per_epoch=steps_per_epoch,
                initial_epoch=initial_epoch,
                epochs=epochs,
                shuffle=shuffle,
                class_weight=class_weight,
                max_queue_size=max_queue_size,
                workers=workers,
                use_multiprocessing=use_multiprocessing,
                model=self,
                steps_per_execution=self._steps_per_execution)

            # Container that configures and calls `tf.keras.Callback`s.
            if not isinstance(callbacks, callbacks_module.CallbackList):
                callbacks = callbacks_module.CallbackList(
                    callbacks,
                    add_history=True,
                    add_progbar=verbose != 0,
                    model=self,
                    verbose=verbose,
                    epochs=epochs,
                    steps=data_handler.inferred_steps)

            self.stop_training = False
            train_function = self.make_train_function()
            self._train_counter.assign(0)
            callbacks.on_train_begin()
            training_logs = None
            # Handle fault-tolerance for multi-worker.
            # TODO(omalleyt): Fix the ordering issues that mean this has to
            # happen after `callbacks.on_train_begin`.
            data_handler._initial_epoch = (  # pylint: disable=protected-access
                self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
            for epoch, iterator in data_handler.enumerate_epochs():
                self.reset_metrics()
                callbacks.on_epoch_begin(epoch)
                with data_handler.catch_stop_iteration():
                    if self._update_cycle > 1:
                        self._grad_accumulator.reset()
                    for step in data_handler.steps():
                        with trace.Trace(
                            'TraceContext',
                            graph_type='train',
                            epoch_num=epoch,
                            step_num=step,
                            batch_size=batch_size):
                            callbacks.on_train_batch_begin(step)
                            if self._update_cycle > 1:
                                for _ in range(self._update_cycle - 1):
                                    self.accumulate_function(iterator)
                            tmp_logs = train_function(iterator)
                            if data_handler.should_sync:
                                context.async_wait()
                            logs = tmp_logs  # No error, now safe to assign to logs.
                            end_step = step + data_handler.step_increment
                            callbacks.on_train_batch_end(end_step, logs)
                epoch_logs = copy.copy(logs)

                # Run validation.
                if validation_data and self._should_eval(epoch, validation_freq):
                    # Create data_handler for evaluation and cache it.
                    if getattr(self, '_eval_data_handler', None) is None:
                        self._eval_data_handler = data_adapter.DataHandler(
                            x=val_x,
                            y=val_y,
                            sample_weight=val_sample_weight,
                            batch_size=validation_batch_size or batch_size,
                            steps_per_epoch=validation_steps,
                            initial_epoch=0,
                            epochs=1,
                            max_queue_size=max_queue_size,
                            workers=workers,
                            use_multiprocessing=use_multiprocessing,
                            model=self,
                            steps_per_execution=self._steps_per_execution)
                    val_logs = self.evaluate(
                        x=val_x,
                        y=val_y,
                        sample_weight=val_sample_weight,
                        batch_size=validation_batch_size or batch_size,
                        steps=validation_steps,
                        callbacks=callbacks,
                        max_queue_size=max_queue_size,
                        workers=workers,
                        use_multiprocessing=use_multiprocessing,
                        return_dict=True)
                    val_logs = {'val_' + name: val for name, val in val_logs.items()}
                    epoch_logs.update(val_logs)

                callbacks.on_epoch_end(epoch, epoch_logs)
                training_logs = epoch_logs
                if self.stop_training:
                    break

            # If eval data_handler exists, delete it after all epochs are done.
            if getattr(self, '_eval_data_handler', None) is not None:
                del self._eval_data_handler
            callbacks.on_train_end(logs=training_logs)
            return self.history
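The only departure from stock `Model.fit` here is the `_update_cycle` / `_grad_accumulator` logic: gradients from `_update_cycle - 1` calls to `accumulate_function` are summed before `train_function` applies a single optimizer step. Those members are not shown in this listing; a minimal sketch of what such an accumulator could look like (an assumption for illustration, not this project's implementation):

import tensorflow as tf

class GradAccumulator:
  """Sums gradients across several batches so one optimizer step can apply them."""

  def __init__(self, variables):
    self._acc = [tf.Variable(tf.zeros_like(v), trainable=False) for v in variables]

  def add(self, grads):
    for acc, g in zip(self._acc, grads):
      if g is not None:
        acc.assign_add(g)

  def gradients(self):
    return [acc.read_value() for acc in self._acc]

  def reset(self):
    for acc in self._acc:
      acc.assign(tf.zeros_like(acc))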
    def adapt(self, data, batch_size=None, steps=None, reset_state=True):
        """Fits the state of the preprocessing layer to the data being passed.

    After calling `adapt` on a layer, a preprocessing layer's state will not
    update during training. In order to make preprocessing layers efficient in
    any distribution context, they are kept constant with respect to any
    compiled `tf.Graph`s that call the layer. This does not affect the layer use
    when adapting each layer only once, but if you adapt a layer multiple times
    you will need to take care to re-compile any compiled functions as follows:

     * If you are adding a preprocessing layer to a `keras.Model`, you need to
       call `model.compile` after each subsequent call to `adapt`.
     * If you are calling a preprocessing layer inside `tf.data.Dataset.map`,
       you should call `map` again on the input `tf.data.Dataset` after each
       `adapt`.
     * If you are using a `tf.function` directly which calls a preprocessing
       layer, you need to call `tf.function` again on your callable after
       each subsequent call to `adapt`.

    `tf.keras.Model` example with multiple adapts:

    >>> layer = tf.keras.layers.experimental.preprocessing.Normalization()
    >>> layer.adapt([0, 2])
    >>> model = tf.keras.Sequential(layer)
    >>> model.predict([0, 1, 2])
    array([[-1.],
           [ 0.],
           [ 1.]], dtype=float32)
    >>> layer.adapt([-1, 1])
    >>> model.compile() # This is needed to re-compile model.predict!
    >>> model.predict([0, 1, 2])
    array([[0.],
           [1.],
           [2.]], dtype=float32)

    `tf.data.Dataset` example with multiple adapts:

    >>> layer = tf.keras.layers.experimental.preprocessing.Normalization()
    >>> layer.adapt([0, 2])
    >>> input_ds = tf.data.Dataset.range(3)
    >>> normalized_ds = input_ds.map(layer)
    >>> list(normalized_ds.as_numpy_iterator())
    [array([[-1.]], dtype=float32),
     array([[0.]], dtype=float32),
     array([[1.]], dtype=float32)]
    >>> layer.adapt([-1, 1])
    >>> normalized_ds = input_ds.map(layer) # Re-map over the input dataset.
    >>> list(normalized_ds.as_numpy_iterator())
    [array([[0.]], dtype=float32),
     array([[1.]], dtype=float32),
     array([[2.]], dtype=float32)]

    Arguments:
        data: The data to train on. It can be passed either as a
          `tf.data.Dataset` or as a NumPy array.
        batch_size: Integer or `None`.
            Number of samples per state update.
            If unspecified, `batch_size` will default to 32.
            Do not specify the `batch_size` if your data is in the
            form of datasets, generators, or `keras.utils.Sequence` instances
            (since they generate batches).
        steps: Integer or `None`.
            Total number of steps (batches of samples).
            When training with input tensors such as
            TensorFlow data tensors, the default `None` is equal to
            the number of samples in your dataset divided by
            the batch size, or 1 if that cannot be determined. If `data` is a
            `tf.data.Dataset` and `steps` is `None`, the epoch will run until
            the input dataset is exhausted. When passing an infinitely
            repeating dataset, you must specify the `steps` argument. This
            argument is not supported with array inputs.
        reset_state: Optional argument specifying whether to clear the state of
          the layer at the start of the call to `adapt`, or whether to start
          from the existing state. This argument may not be relevant to all
          preprocessing layers: a subclass of `PreprocessingLayer` may choose to
          raise an error if `reset_state` is set to `False`.
    """
        _disallow_inside_tf_function('adapt')
        if not version_utils.should_use_v2():
            raise RuntimeError('`adapt` is only supported in tensorflow v2.')  # pylint: disable=g-doc-exception
        if not self.streaming and self._is_adapted and not reset_state:
            raise ValueError(
                '{} does not support calling `adapt` twice without '
                'resetting the state.'.format(self.__class__.__name__))
        if not self._is_compiled:
            self.compile()  # Compile with defaults.
        if self.built and reset_state:
            self.reset_state()
        data_handler = data_adapter.DataHandler(
            data,
            batch_size=batch_size,
            steps_per_epoch=steps,
            epochs=1,
            steps_per_execution=self._steps_per_execution,
            distribute=False)
        self._adapt_function = self.make_adapt_function()
        for _, iterator in data_handler.enumerate_epochs():
            with data_handler.catch_stop_iteration():
                for _ in data_handler.steps():
                    self._adapt_function(iterator)
                    if data_handler.should_sync:
                        context.async_wait()
        self.finalize_state()
        self._is_adapted = True
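The docstring's third bullet (calling a preprocessing layer from a `tf.function` directly) is not illustrated by the examples above; a minimal sketch of re-wrapping the callable after a second `adapt`, assuming TF 2.x:

import tensorflow as tf

layer = tf.keras.layers.experimental.preprocessing.Normalization()
layer.adapt([0, 2])
fn = tf.function(lambda x: layer(x))
print(fn(tf.constant([[0.], [1.], [2.]])))  # state from the first adapt: ~[[-1.], [0.], [1.]]
layer.adapt([-1, 1])
fn = tf.function(lambda x: layer(x))        # re-wrap so the new state is traced
print(fn(tf.constant([[0.], [1.], [2.]])))  # ~[[0.], [1.], [2.]]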
Example #18
    def predict_dataset(self, dataset: tf.data.Dataset) -> Iterable[Sample]:
        """
        Apply the prediction model to the `tf.data.Dataset`. No pre- or post-processing is applied.

        Args:
            dataset: The `tf.data.Dataset` yielding one dictionary per element, or two if `PredictionParams.include_targets` is True

        Returns:
            The raw predicted Samples of the model

        See Also:
            - predict
            - predict_pipeline
            - predict_raw
        """
        if self._keras_model is None:
            raise ValueError("No model set. Call predictor.set_model(model)")

        keras_model = self._keras_model

        class WrappedModel(tf.keras.models.Model):
            def __init__(self, with_targets, **kwargs):
                super().__init__(**kwargs)
                self.with_targets = with_targets

            def call(self, inputs, training=None, mask=None):
                if self.with_targets:
                    inputs, targets, meta = inputs
                    return inputs, targets, keras_model(inputs), meta
                else:
                    inputs, meta = inputs
                    return inputs, keras_model(inputs), meta

            def get_config(self):
                raise NotImplementedError

        # wrap model so that it outputs inputs, meta and optionally the targets
        wrapped_model = WrappedModel(self.params.include_targets)
        wrapped_model.compile(run_eagerly=self.params.run_eagerly)

        if self._params.include_targets:
            dataset = dataset.map(lambda i, t, m: ((i, t, m), ))
        else:
            dataset = dataset.map(lambda i, m: ((i, m), ))

        with MeasureTime() as total_time:
            # The following code is copied from keras.model.Model.predict
            # It sets up the distribution strategy, the DataSet (here DataHandler)
            # Then one epoch is iterated until catch_stop_iteration() is reached
            with self._keras_model.distribute_strategy.scope():
                data_handler = data_adapter.DataHandler(dataset)
                predict_function = wrapped_model.make_predict_function()
                for _, iterator in data_handler.enumerate_epochs():  # Single epoch.
                    with data_handler.catch_stop_iteration():
                        for _ in data_handler.steps():
                            with MeasureTime() as batch_time:
                                r = predict_function(iterator)  # hack to access inputs

                                # If targets are included, the return value differs
                                if self._params.include_targets:
                                    inputs, targets, outputs, meta = sync_to_numpy_or_python_type(r)
                                else:
                                    inputs, outputs, meta = sync_to_numpy_or_python_type(r)
                                    targets = {}  # No targets in normal prediction

                                # split into single samples
                                try:
                                    batch_size = tf.nest.flatten(inputs)[0].shape[0]
                                except StopIteration as e:
                                    raise ValueError(
                                        f"Empty inputs {inputs}. This should never occur!") from e
                                for sample in self._unwrap_batch(inputs, targets, outputs, meta):
                                    self._on_sample_end(sample)
                                    yield sample

                            # Some Benchmarks
                            self.benchmark_results.finish_batch(
                                batch_size, batch_time.duration)
                            self._on_step_end(
                                Sample(inputs=inputs,
                                       outputs=outputs,
                                       targets=targets,
                                       meta=meta))

        # Overall Benchmarks
        self.benchmark_results.finish_epoch(total_time.duration)
        self._on_predict_end()
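A usage sketch: `predictor` and `dataset` below are placeholders for a configured predictor instance (with `set_model` already called) and a `tf.data.Dataset` in the expected dictionary format:

# Iterating the generator drives prediction batch by batch.
results = list(predictor.predict_dataset(dataset))  # raw Samples, no post-processing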