Ejemplo n.º 1
0
 def predict(self,
             model,
             x,
             batch_size=None,
             verbose=0,
             steps=None,
             callbacks=None,
             **kwargs):
   """Run prediction on `x` for a model that uses a distribution strategy.

   The inputs are validated, the batch size and step count are resolved, and
   `x` is standardized into a dataset. With a TPU strategy the step count has
   to be known; outside eager execution the dedicated TPU predict loop runs.
   Every other case goes through `training_arrays.predict_loop`.

   Args:
     model: Model to predict with; its `_distribution_strategy` is consulted.
     x: Input data.
     batch_size: Requested batch size, or None.
     verbose: Verbosity value handed to the selected predict loop.
     steps: Number of prediction steps, or None.
     callbacks: Callbacks handed to the selected predict loop.
     **kwargs: Accepted but not read by this method.

   Returns:
     The value returned by the selected predict loop.

   Raises:
     ValueError: With a TPU strategy, when the number of steps is still None
       after inference from the dataset.
   """
   dist_utils.validate_inputs(x=x, y=None)
   batch_size, steps = dist_utils.process_batch_and_step_size(
       model._distribution_strategy, x, batch_size, steps, ModeKeys.PREDICT)
   batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
   predict_dataset = model._distribution_standardize_user_data(
       x, batch_size=batch_size, allow_partial_batch=True)

   on_tpu = dist_utils.is_tpu_strategy(model._distribution_strategy)
   if on_tpu:
     # A TPU run needs a concrete step count before any loop starts.
     steps = training_utils.infer_steps_for_dataset(
         model, predict_dataset, steps, steps_name='steps')
     if steps is None:
       raise ValueError('Number of steps could not be inferred from the data, '
                        'please pass the steps argument.')

   # The custom TPU loop is only selected outside eager execution.
   if on_tpu and not context.executing_eagerly():
     return experimental_tpu_predict_loop(
         model,
         predict_dataset,
         verbose=verbose,
         steps=steps,
         callbacks=callbacks)

   return training_arrays.predict_loop(
       model,
       predict_dataset,
       batch_size=batch_size,
       verbose=verbose,
       steps=steps,
       callbacks=callbacks)
Ejemplo n.º 2
0
    def evaluate(self,
                 model,
                 x=None,
                 y=None,
                 batch_size=None,
                 verbose=1,
                 sample_weight=None,
                 steps=None,
                 callbacks=None,
                 **kwargs):
        """Evaluate a model that uses a distribution strategy.

        The inputs are validated, the batch size and step count are resolved,
        and `x`/`y`/`sample_weight` are standardized into a dataset. With a
        TPU strategy the step count has to be known; outside eager execution
        the dedicated TPU test loop runs. Every other case goes through
        `training_arrays_v1.test_loop`.

        Args:
          model: Model to evaluate; its `_distribution_strategy` is consulted.
          x: Input data.
          y: Target data.
          batch_size: Requested batch size, or None.
          verbose: Verbosity value handed to the selected test loop.
          sample_weight: Sample weights forwarded to data standardization.
          steps: Number of evaluation steps, or None.
          callbacks: Callbacks handed to the selected test loop.
          **kwargs: Accepted but not read by this method.

        Returns:
          The value returned by the selected test loop.

        Raises:
          ValueError: With a TPU strategy, when the number of steps is still
            None after inference from the dataset.
        """
        dist_utils.validate_inputs(x, y)
        batch_size, steps = dist_utils.process_batch_and_step_size(
            model._distribution_strategy, x, batch_size, steps, ModeKeys.TEST)
        batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
        eval_dataset = model._distribution_standardize_user_data(
            x,
            y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            allow_partial_batch=True)

        on_tpu = dist_utils.is_tpu_strategy(model._distribution_strategy)
        if on_tpu:
            # A TPU run needs a concrete step count before any loop starts.
            steps = training_utils.infer_steps_for_dataset(
                model, eval_dataset, steps, steps_name='steps')
            if steps is None:
                raise ValueError(
                    'Number of steps could not be inferred from the data, '
                    'please pass the steps argument.')

        # Run TPU evaluation in a custom loop in graph mode only.
        if on_tpu and not context.executing_eagerly():
            return experimental_tpu_test_loop(
                model,
                eval_dataset,
                verbose=verbose,
                steps=steps,
                callbacks=callbacks)

        return training_arrays_v1.test_loop(
            model,
            inputs=eval_dataset,
            batch_size=batch_size,
            verbose=verbose,
            steps=steps,
            callbacks=callbacks)
Ejemplo n.º 3
0
  def _model_iteration(
      self, model, mode, x=None, y=None, batch_size=None, verbose=1,
      sample_weight=None, steps=None, callbacks=None, max_queue_size=10,
      workers=1, use_multiprocessing=False, **kwargs):
    """Run one pass of `mode` over the inputs inside the strategy scope.

    The inputs are wrapped in a data adapter, the resulting dataset is
    distributed through the strategy, and a single epoch (index 0) is executed
    by `run_one_epoch`.

    Args:
      model: Model to run.
      mode: Mode key forwarded to input processing, the execution function,
        the callbacks and the epoch runner.
      x: Input data, forwarded to `_process_inputs`.
      y: Target data, forwarded to `_process_inputs`.
      batch_size: Batch size, or None; resolved through
        `model._validate_or_infer_batch_size` and
        `dist_utils.process_batch_and_step_size`.
      verbose: Verbosity passed to the training context.
      sample_weight: Forwarded to `_process_inputs` as `sample_weights`.
      steps: Number of steps; when falsy it falls back to the adapter size or
        to the value inferred from the dataset.
      callbacks: User callbacks; validated against the optimizer and then
        wrapped by `cbks.configure_callbacks`.
      max_queue_size: Forwarded to `_process_inputs`.
      workers: Forwarded to `_process_inputs`.
      use_multiprocessing: Forwarded to `_process_inputs`.
      **kwargs: Accepted but not read by this method.

    Returns:
      The result of `run_one_epoch`; when it has exactly one element, that
      element is returned instead of the container.
    """
    batch_size = model._validate_or_infer_batch_size(
        batch_size, steps, x)
    strategy = _get_distribution_strategy(model)
    batch_size, steps = dist_utils.process_batch_and_step_size(
        strategy, x, batch_size, steps, mode)
    dist_utils.validate_callbacks(input_callbacks=callbacks,
                                  optimizer=model.optimizer)
    # Enter tf.distribute.Strategy scope.
    with strategy.scope():
      adapter = _process_inputs(
          model,
          mode,
          x,
          y,
          batch_size=batch_size,
          sample_weights=sample_weight,
          steps=steps,
          distribution_strategy=strategy,
          max_queue_size=max_queue_size,
          workers=workers,
          use_multiprocessing=use_multiprocessing)
      total_samples = _get_total_number_of_samples(adapter)
      use_sample = total_samples is not None
      dataset = adapter.get_dataset()

      if not steps:
        # Raise an error if `steps` isn't specified but the dataset
        # is infinite.
        steps = adapter.get_size() or training_utils.infer_steps_for_dataset(
            model, dataset, steps, steps_name='steps')

      training_context = TrainingContext()
      dataset = strategy.experimental_distribute_dataset(dataset)

      execution_function = training_v2_utils._get_or_make_execution_function(
          model, mode)

      data_iterator = iter(dataset)

      # `samples` must be a count, not the `use_sample` flag: pass the number
      # of samples when it is known and fall back to the step count otherwise.
      callbacks = cbks.configure_callbacks(
          callbacks,
          model,
          do_validation=False,
          batch_size=batch_size,
          epochs=1,
          steps_per_epoch=steps,
          samples=total_samples or steps,
          count_mode='samples' if use_sample else 'steps',
          verbose=0,  # Handle ProgBarLogger separately in this loop.
          mode=mode)

      with training_context.on_start(
          model, callbacks, use_sample, verbose, mode):
        with training_context.on_epoch(0, mode) as epoch_logs:
          model.reset_metrics()
          result = run_one_epoch(
              model,
              data_iterator,
              execution_function,
              dataset_size=adapter.get_size(),
              batch_size=adapter.batch_size(),
              strategy=strategy,
              steps_per_epoch=steps,
              num_samples=total_samples,
              mode=mode,
              training_context=training_context,
              total_epochs=1)
          cbks.make_logs(model, epoch_logs, result, mode)

    # A single-element result is unwrapped for the caller.
    if len(result) == 1:
      result = result[0]
    return result
Ejemplo n.º 4
0
  def fit(
      self, model, x=None, y=None, batch_size=None, epochs=1, verbose=1,
      callbacks=None, validation_split=0., validation_data=None, shuffle=True,
      class_weight=None, sample_weight=None, initial_epoch=0,
      steps_per_epoch=None, validation_steps=None, validation_freq=1,
      max_queue_size=10, workers=1, use_multiprocessing=False, **kwargs):
    """Train `model`, optionally validating, inside the strategy scope.

    Inputs are turned into data adapters by `_process_training_inputs`, the
    resulting datasets are distributed with
    `strategy.experimental_distribute_dataset`, and each epoch is executed by
    `run_one_epoch`. Validation, when a validation adapter exists, runs inside
    the training epoch context and its logs are merged with a `val_` prefix.

    Args:
      model: Model to train.
      x: Training inputs, forwarded to `_process_training_inputs`.
      y: Training targets, forwarded to `_process_training_inputs`.
      batch_size: Batch size, or None; resolved through
        `model._validate_or_infer_batch_size` and
        `dist_utils.process_batch_and_step_size`.
      epochs: Exclusive upper bound of the epoch range that is iterated.
      verbose: Verbosity passed to the training context; a truthy value also
        enables the initial training summary print.
      callbacks: User callbacks; validated against the optimizer and then
        wrapped by `cbks.configure_callbacks`.
      validation_split: Forwarded to step-size and input processing.
      validation_data: Forwarded to input processing.
      shuffle: Forwarded to input processing.
      class_weight: Forwarded to input processing as `class_weights`.
      sample_weight: Forwarded to input processing as `sample_weights`.
      initial_epoch: Starting epoch; may be replaced by
        `model._maybe_load_initial_epoch_from_ckpt`.
      steps_per_epoch: Steps per training epoch; when falsy it falls back to
        the adapter size and, if still None, to the value inferred from the
        dataset.
      validation_steps: Steps per validation run; when falsy it falls back to
        the validation adapter size or to a value inferred from the dataset.
      validation_freq: Passed to `training_utils.should_run_validation`.
      max_queue_size: Forwarded to input processing.
      workers: Forwarded to input processing.
      use_multiprocessing: Forwarded to input processing.
      **kwargs: Accepted but not read by this method.

    Returns:
      `model.history`.
    """
    batch_size = model._validate_or_infer_batch_size(
        batch_size, steps_per_epoch, x)

    strategy = _get_distribution_strategy(model)
    batch_size, steps_per_epoch = dist_utils.process_batch_and_step_size(
        strategy,
        x,
        batch_size,
        steps_per_epoch,
        ModeKeys.TRAIN,
        validation_split=validation_split)
    dist_utils.validate_callbacks(input_callbacks=callbacks,
                                  optimizer=model.optimizer)
    # Enter tf.distribute.Strategy scope.
    with strategy.scope():
      training_data_adapter, validation_adapter = _process_training_inputs(
          model,
          x,
          y,
          batch_size=batch_size,
          epochs=epochs,
          sample_weights=sample_weight,
          class_weights=class_weight,
          validation_split=validation_split,
          steps_per_epoch=steps_per_epoch,
          shuffle=shuffle,
          validation_data=validation_data,
          validation_steps=validation_steps,
          distribution_strategy=strategy,
          max_queue_size=max_queue_size,
          workers=workers,
          use_multiprocessing=use_multiprocessing)

      total_samples = _get_total_number_of_samples(training_data_adapter)
      # Progress is counted in samples only when the total is known.
      use_sample = total_samples is not None
      do_validation = (validation_adapter is not None)

      # Asked of the adapter before `steps_per_epoch` is overwritten below.
      recreate_training_iterator = (
          training_data_adapter.should_recreate_iterator(steps_per_epoch))
      if not steps_per_epoch:
        # TODO(b/139762795): Add step inference for when steps is None to
        # prevent end of sequence warning message.
        steps_per_epoch = training_data_adapter.get_size()

      # tf.print('{} on {} steps.'.format(ModeKeys.TRAIN, steps_per_epoch))
      training_context = TrainingContext()

      training_dataset = training_data_adapter.get_dataset()
      # Raise an error if steps_per_epoch isn't specified but the dataset
      # is infinite.
      # TODO(scottzhu): This check should probably happen in the adapter
      inferred_steps = training_utils.infer_steps_for_dataset(
          model,
          training_dataset,
          steps_per_epoch,
          steps_name='steps_per_epoch',
          epochs=0)

      # The inferred value is only used when no step count is known yet.
      steps_per_epoch = (
          inferred_steps if steps_per_epoch is None else steps_per_epoch)

      training_dataset = strategy.experimental_distribute_dataset(
          training_dataset)

      training_function = training_v2_utils._get_or_make_execution_function(
          model, ModeKeys.TRAIN)

      # The training iterator is created lazily inside the epoch loop.
      training_data_iter = None
      if do_validation:
        validation_dataset = validation_adapter.get_dataset()
        if not validation_steps:
          # Raise an error if validation_steps isn't specified but the
          # validation dataset is infinite.
          validation_steps = (
              validation_adapter.get_size() or
              training_utils.infer_steps_for_dataset(
                  model,
                  validation_dataset,
                  validation_steps,
                  steps_name='validation_steps'))
        eval_function = training_v2_utils._get_or_make_execution_function(
            model, ModeKeys.TEST)
        # The validation iterator is created lazily on the first validation.
        eval_data_iter = None
        validation_dataset = strategy.experimental_distribute_dataset(
            validation_dataset)
        val_total_samples = _get_total_number_of_samples(validation_adapter)
      else:
        val_total_samples = None

      # Only print the summary when at least one of the counts is known.
      if verbose and (total_samples or steps_per_epoch):
        _print_train_info(total_samples, steps_per_epoch, val_total_samples,
                          validation_steps)

      training_callbacks = cbks.configure_callbacks(
          callbacks,
          model,
          do_validation=do_validation,
          batch_size=batch_size,
          epochs=epochs,
          steps_per_epoch=steps_per_epoch,
          samples=total_samples or steps_per_epoch,
          count_mode='samples' if use_sample else 'steps',
          verbose=0,  # Handle ProgBarLogger separately in this loop.
          mode=ModeKeys.TRAIN)

      with training_context.on_start(model, training_callbacks, use_sample,
                                     verbose, ModeKeys.TRAIN):

        initial_epoch = model._maybe_load_initial_epoch_from_ckpt(
            initial_epoch, ModeKeys.TRAIN)

        for epoch in range(initial_epoch, epochs):
          # Stop before starting a new epoch once stop_training has been set.
          if training_context.callbacks.model.stop_training:
            break

          # Training
          with training_context.on_epoch(epoch, ModeKeys.TRAIN) as epoch_logs:
            model.reset_metrics()
            if training_data_iter is None or recreate_training_iterator:
              if (training_data_iter is not None and
                  distribution_strategy_context.has_strategy()):
                # TODO(kaftan): remove this when MultiDeviceIterator is a
                ## compositetensor (unless this is more efficient)
                # NOTE(review): the bare attribute access presumably
                # re-initializes the existing iterator - confirm.
                training_data_iter._initializer  # pylint: disable=pointless-statement
              else:
                training_data_iter = iter(training_dataset)

            training_result = run_one_epoch(
                model,
                training_data_iter,
                training_function,
                dataset_size=training_data_adapter.get_size(),
                batch_size=training_data_adapter.batch_size(),
                strategy=strategy,
                steps_per_epoch=steps_per_epoch,
                num_samples=total_samples,
                mode=ModeKeys.TRAIN,
                training_context=training_context,
                total_epochs=epochs)
            cbks.make_logs(model, epoch_logs, training_result, ModeKeys.TRAIN)

            # In the case of steps_per_epoch = None, the final cardinality will
            # be determined when the inputs are fully consumed (eg dataset or
            # generator). Update the steps_per_epoch to the new value.
            if (steps_per_epoch is None
                and training_context.progbar.progbar.target is not None):
              steps_per_epoch = training_context.progbar.progbar.target

            # Evaluation
            if (do_validation and
                training_utils.should_run_validation(validation_freq, epoch) and
                not training_callbacks.model.stop_training):
              if (eval_data_iter is not None and
                  distribution_strategy_context.has_strategy()):
                # TODO(kaftan): remove this when MultiDeviceIterator is a
                ## compositetensor (unless this is more efficient)
                # NOTE(review): the bare attribute access presumably
                # re-initializes the existing iterator - confirm.
                eval_data_iter._initializer  # pylint: disable=pointless-statement
              else:
                eval_data_iter = iter(validation_dataset)

              # NOTE(review): `count_mode` below uses `use_sample`, which was
              # derived from the training adapter, not the validation one.
              validation_callbacks = cbks.configure_callbacks(
                  training_callbacks,
                  model,
                  batch_size=batch_size,
                  epochs=1,
                  steps_per_epoch=validation_steps,
                  samples=val_total_samples or validation_steps,
                  count_mode='samples' if use_sample else 'steps',
                  verbose=0,  # Handle ProgBarLogger separately in this loop.
                  mode=ModeKeys.TEST)

              # Validation gets its own context, created per validation run.
              eval_context = TrainingContext()
              with eval_context.on_start(
                  model,
                  validation_callbacks,
                  use_sample,
                  verbose=0,
                  mode=ModeKeys.TEST):
                with eval_context.on_epoch(epoch, ModeKeys.TEST):
                  model.reset_metrics()
                  eval_result = run_one_epoch(
                      model,
                      eval_data_iter,
                      eval_function,
                      dataset_size=validation_adapter.get_size(),
                      batch_size=validation_adapter.batch_size(),
                      strategy=strategy,
                      steps_per_epoch=validation_steps,
                      num_samples=val_total_samples,
                      mode=ModeKeys.TEST,
                      training_context=eval_context,
                      total_epochs=1)
                  # Validation results go into the training epoch's logs.
                  cbks.make_logs(model, epoch_logs, eval_result, ModeKeys.TEST,
                                 prefix='val_')

    return model.history
Ejemplo n.º 5
0
    def fit(self,
            model,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose=1,
            callbacks=None,
            validation_split=0.,
            validation_data=None,
            shuffle=True,
            class_weight=None,
            sample_weight=None,
            initial_epoch=0,
            steps_per_epoch=None,
            validation_steps=None,
            validation_freq=1,
            **kwargs):
        """Train `model`, optionally validating, inside the strategy scope.

        Inputs are turned into data adapters by `_process_training_inputs`,
        the resulting datasets are distributed with
        `strategy.experimental_distribute_dataset`, and each epoch is executed
        by `run_one_epoch`. Validation, when a validation adapter exists, runs
        inside the training epoch context and its logs are merged with a
        `val_` prefix.

        Args:
          model: Model to train.
          x: Training inputs, forwarded to `_process_training_inputs`.
          y: Training targets, forwarded to `_process_training_inputs`.
          batch_size: Batch size, or None; resolved through
            `model._validate_or_infer_batch_size` and
            `dist_utils.process_batch_and_step_size`.
          epochs: Exclusive upper bound of the epoch range that is iterated.
          verbose: Verbosity passed to the training context.
          callbacks: User callbacks; validated against the optimizer and then
            wrapped by `cbks.configure_callbacks`.
          validation_split: Forwarded to input processing.
          validation_data: Forwarded to input processing.
          shuffle: Forwarded to input processing.
          class_weight: Forwarded to input processing as `class_weights`.
          sample_weight: Forwarded to input processing as `sample_weights`.
          initial_epoch: Starting epoch; may be replaced by
            `model._maybe_load_initial_epoch_from_ckpt`.
          steps_per_epoch: Steps per training epoch; when falsy it falls back
            to the training adapter size.
          validation_steps: Steps per validation run; when falsy it falls back
            to the validation adapter size.
          validation_freq: Passed to `training_utils.should_run_validation`.
          **kwargs: Accepted but not read by this method.

        Returns:
          `model.history`.
        """
        batch_size = model._validate_or_infer_batch_size(
            batch_size, steps_per_epoch, x)

        strategy = _get_distribution_strategy(model)
        batch_size, steps_per_epoch = dist_utils.process_batch_and_step_size(
            strategy, x, batch_size, steps_per_epoch, ModeKeys.TRAIN)
        dist_utils.validate_callbacks(input_callbacks=callbacks,
                                      optimizer=model.optimizer)
        # Enter tf.distribute.Strategy scope.
        with strategy.scope():
            training_data_adapter, validation_adapter = _process_training_inputs(
                model,
                x,
                y,
                batch_size=batch_size,
                sample_weights=sample_weight,
                class_weights=class_weight,
                validation_split=validation_split,
                steps_per_epoch=steps_per_epoch,
                shuffle=shuffle,
                validation_data=validation_data,
                validation_steps=validation_steps,
                distribution_strategy=strategy)

            total_samples = _get_total_number_of_samples(training_data_adapter)
            # Progress is counted in samples only when the total is known.
            use_sample = total_samples is not None
            do_validation = (validation_adapter is not None)

            if not steps_per_epoch:
                steps_per_epoch = training_data_adapter.get_size()

            # tf.print('{} on {} steps.'.format(ModeKeys.TRAIN, steps_per_epoch))
            training_context = TrainingContext()

            initial_epoch = model._maybe_load_initial_epoch_from_ckpt(
                initial_epoch, ModeKeys.TRAIN)

            training_dataset = training_data_adapter.get_dataset()
            # Raise an error if steps_per_epoch isn't specified but the dataset
            # is infinite.
            # TODO(scottzhu): This check should probably happen in the adapter
            # The return value is discarded; the call is made only for the
            # check described above.
            training_utils.infer_steps_for_dataset(
                training_dataset,
                steps_per_epoch,
                steps_name='steps_per_epoch',
                epochs=0)

            training_dataset = strategy.experimental_distribute_dataset(
                training_dataset)

            _update_sample_weight_mode(model, ModeKeys.TRAIN, training_dataset)
            training_function = training_v2_utils._get_or_make_execution_function(
                model, ModeKeys.TRAIN)

            # The training iterator is created lazily inside the epoch loop.
            training_data_iter = None
            # Only recreate iterator when the data has a fixed length, which will be
            # fully consumed every epoch, or has a unknown length (dataset, generator)
            # and will be fully consumed (steps_per_epoch is None)
            recreate_training_iterator = (training_data_adapter.get_size()
                                          is not None
                                          or steps_per_epoch is None)

            if do_validation:
                if not validation_steps:
                    validation_steps = validation_adapter.get_size()
                eval_function = training_v2_utils._get_or_make_execution_function(
                    model, ModeKeys.TEST)
                # The validation iterator is created lazily on first use.
                eval_data_iter = None

                validation_dataset = validation_adapter.get_dataset()
                # Raise an error if validation_steps isn't specified but the validation
                # dataset is infinite.
                # TODO(scottzhu): This check should probably happen in the adapter
                # The return value is discarded; the call is made only for
                # the check described above.
                training_utils.infer_steps_for_dataset(
                    validation_dataset,
                    validation_steps,
                    steps_name='validation_steps',
                    epochs=0)
                validation_dataset = strategy.experimental_distribute_dataset(
                    validation_dataset)

            # `callbacks` is rebound to the configured callback object and is
            # reused for both training and validation below.
            callbacks = cbks.configure_callbacks(
                callbacks,
                model,
                do_validation=do_validation,
                batch_size=batch_size,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                samples=total_samples,
                count_mode='samples' if use_sample else 'steps',
                verbose=0,  # Handle ProgBarLogger separately in this loop.
                mode=ModeKeys.TRAIN)

            with training_context.on_start(model, callbacks, use_sample,
                                           verbose, ModeKeys.TRAIN):
                # TODO(scottzhu): Handle TPUStrategy training loop
                for epoch in range(initial_epoch, epochs):
                    # Stop before starting a new epoch once stop_training
                    # has been set.
                    if training_context.callbacks.model.stop_training:
                        break

                    # Training
                    with training_context.on_epoch(
                            epoch, ModeKeys.TRAIN) as epoch_logs:
                        model.reset_metrics()
                        if training_data_iter is None or recreate_training_iterator:
                            if (training_data_iter is not None
                                    and distribution_strategy_context.
                                    has_strategy()):
                                # TODO(kaftan): remove this when MultiDeviceIterator is a
                                ## compositetensor (unless this is more efficient)
                                # NOTE(review): the bare attribute access
                                # presumably re-initializes the existing
                                # iterator - confirm.
                                training_data_iter._initializer  # pylint: disable=pointless-statement
                            else:
                                training_data_iter = iter(training_dataset)

                        training_result = run_one_epoch(
                            model,
                            training_data_iter,
                            training_function,
                            dataset_size=training_data_adapter.get_size(),
                            batch_size=training_data_adapter.batch_size(),
                            strategy=strategy,
                            steps_per_epoch=steps_per_epoch,
                            num_samples=total_samples,
                            mode=ModeKeys.TRAIN,
                            training_context=training_context,
                            total_epochs=epochs)
                        cbks.make_logs(model, epoch_logs, training_result,
                                       ModeKeys.TRAIN)

                        # Evaluation
                        if (do_validation
                                and training_utils.should_run_validation(
                                    validation_freq, epoch)
                                and not callbacks.model.stop_training):
                            if (eval_data_iter is not None
                                    and distribution_strategy_context.
                                    has_strategy()):
                                # TODO(kaftan): remove this when MultiDeviceIterator is a
                                ## compositetensor (unless this is more efficient)
                                # NOTE(review): the bare attribute access
                                # presumably re-initializes the existing
                                # iterator - confirm.
                                eval_data_iter._initializer  # pylint: disable=pointless-statement
                            else:
                                eval_data_iter = iter(validation_dataset)

                            val_total_samples = _get_total_number_of_samples(
                                validation_adapter)
                            # Validation gets its own context, created per
                            # validation run.
                            eval_context = TrainingContext()
                            with eval_context.on_start(model,
                                                       callbacks,
                                                       use_sample,
                                                       verbose=0,
                                                       mode=ModeKeys.TEST):
                                with eval_context.on_epoch(
                                        epoch, ModeKeys.TEST):
                                    model.reset_metrics()
                                    eval_result = run_one_epoch(
                                        model,
                                        eval_data_iter,
                                        eval_function,
                                        dataset_size=validation_adapter.
                                        get_size(),
                                        batch_size=validation_adapter.
                                        batch_size(),
                                        strategy=strategy,
                                        steps_per_epoch=validation_steps,
                                        num_samples=val_total_samples,
                                        mode=ModeKeys.TEST,
                                        training_context=eval_context,
                                        total_epochs=1)
                                    # Validation results go into the training
                                    # epoch's logs.
                                    cbks.make_logs(model,
                                                   epoch_logs,
                                                   eval_result,
                                                   ModeKeys.TEST,
                                                   prefix='val_')

        return model.history
Ejemplo n.º 6
0
    def fit(self,
            model,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose=1,
            callbacks=None,
            validation_split=0.,
            validation_data=None,
            shuffle=True,
            class_weight=None,
            sample_weight=None,
            initial_epoch=0,
            steps_per_epoch=None,
            validation_steps=None,
            validation_freq=1,
            **kwargs):
        """Train a model that uses a distribution strategy.

        Callbacks and inputs are validated, the batch size and step count are
        resolved, and the training (and optional validation) inputs are
        standardized into datasets. With a TPU strategy the step count has to
        be known; outside eager execution the dedicated TPU fit loop runs.
        Every other case goes through `training_arrays.fit_loop`.

        Args:
          model: Model to train; its `_distribution_strategy` is consulted.
          x: Training inputs.
          y: Training targets.
          batch_size: Requested batch size, or None.
          epochs: Number of epochs handed to the selected fit loop.
          verbose: Verbosity value handed to the selected fit loop.
          callbacks: Callbacks, validated against the model's optimizer.
          validation_split: Forwarded to data standardization; rejected with
            an error when truthy and no `validation_data` is supplied.
          validation_data: Validation inputs, unpacked with
            `training_utils.unpack_validation_data` when truthy.
          shuffle: Forwarded to data standardization and the array fit loop.
          class_weight: Forwarded to training data standardization.
          sample_weight: Forwarded to training data standardization.
          initial_epoch: Forwarded to the selected fit loop.
          steps_per_epoch: Training steps per epoch, or None.
          validation_steps: Validation steps, or None.
          validation_freq: Forwarded to the selected fit loop.
          **kwargs: Accepted but not read by this method.

        Returns:
          The value returned by the selected fit loop.

        Raises:
          ValueError: When `validation_split` is truthy without
            `validation_data`, or when a TPU strategy is used and the number
            of steps is still None after inference from the dataset.
        """
        dist_utils.validate_callbacks(input_callbacks=callbacks,
                                      optimizer=model.optimizer)
        dist_utils.validate_inputs(x, y)

        batch_size, steps_per_epoch = dist_utils.process_batch_and_step_size(
            model._distribution_strategy, x, batch_size, steps_per_epoch,
            ModeKeys.TRAIN)
        batch_size = model._validate_or_infer_batch_size(
            batch_size, steps_per_epoch, x)
        train_dataset = model._distribution_standardize_user_data(
            x,
            y,
            sample_weight=sample_weight,
            class_weight=class_weight,
            batch_size=batch_size,
            validation_split=validation_split,
            shuffle=shuffle,
            epochs=epochs)
        if not dist_utils.is_distributing_by_cloning(model):
            # Without cloning, the dataset goes through a second
            # standardization pass under the strategy scope.
            with model._distribution_strategy.scope():
                train_dataset, _, _ = model._standardize_user_data(
                    train_dataset,
                    sample_weight=sample_weight,
                    class_weight=class_weight,
                    batch_size=batch_size,
                    validation_split=validation_split,
                    shuffle=shuffle)

        if validation_data:
            val_x, val_y, val_sample_weights = (
                training_utils.unpack_validation_data(validation_data))
            dist_utils.validate_inputs(val_x, val_y)
            _, validation_steps = dist_utils.process_batch_and_step_size(
                model._distribution_strategy, val_x, batch_size,
                validation_steps, ModeKeys.TEST)

            val_dataset = model._distribution_standardize_user_data(
                val_x,
                val_y,
                sample_weight=val_sample_weights,
                class_weight=None,
                batch_size=batch_size,
                validation_split=validation_split,
                shuffle=shuffle,
                allow_partial_batch=True)
        else:
            if validation_split:
                raise ValueError(
                    'validation_split argument is not supported with '
                    'distribution strategies.')
            val_dataset = None

        on_tpu = dist_utils.is_tpu_strategy(model._distribution_strategy)
        if on_tpu:
            # A TPU run needs a concrete step count before any loop starts.
            steps_per_epoch = training_utils.infer_steps_for_dataset(
                train_dataset,
                steps_per_epoch,
                epochs,
                steps_name='steps_per_epoch')
            if steps_per_epoch is None:
                raise ValueError(
                    'Number of steps could not be inferred from the data, '
                    'please pass the steps_per_epoch argument.')

        # Run TPU training in a custom loop in graph mode only.
        if on_tpu and not context.executing_eagerly():
            return experimental_tpu_fit_loop(
                model,
                train_dataset,
                epochs=epochs,
                verbose=verbose,
                callbacks=callbacks,
                val_dataset=val_dataset,
                initial_epoch=initial_epoch,
                steps_per_epoch=steps_per_epoch,
                validation_steps=validation_steps,
                validation_freq=validation_freq)

        return training_arrays.fit_loop(
            model,
            train_dataset,
            batch_size=batch_size,
            epochs=epochs,
            verbose=verbose,
            callbacks=callbacks,
            val_inputs=val_dataset,
            shuffle=shuffle,
            initial_epoch=initial_epoch,
            steps_per_epoch=steps_per_epoch,
            validation_steps=validation_steps,
            validation_freq=validation_freq,
            steps_name='steps_per_epoch')
Ejemplo n.º 7
0
    def _model_iteration(self,
                         model,
                         mode,
                         x=None,
                         y=None,
                         batch_size=None,
                         verbose=1,
                         sample_weight=None,
                         steps=None,
                         callbacks=None,
                         **kwargs):
        """Run one pass of `mode` over the inputs.

        When the model has a distribution strategy, the whole pass runs inside
        `dist_utils.distributed_scope(strategy, learning_phase=0)`. The scope
        is entered manually and is always exited, including when the pass
        raises.

        Args:
          model: Model to run.
          mode: Mode key forwarded to the execution function, the callbacks
            and the epoch runner.
          x: Input data, forwarded to `_process_inputs`.
          y: Target data, forwarded to `_process_inputs`.
          batch_size: Batch size, or None; resolved through
            `model._validate_or_infer_batch_size` and, with a strategy,
            `dist_utils.process_batch_and_step_size`.
          verbose: Verbosity passed to the training context.
          sample_weight: Forwarded to `_process_inputs` as `sample_weights`.
          steps: Number of steps; when falsy it falls back to the adapter
            size.
          callbacks: User callbacks; wrapped by `cbks.configure_callbacks`.
          **kwargs: Accepted but not read by this method.

        Returns:
          The result of `run_one_epoch`; when it has exactly one element,
          that element is returned instead of the container.
        """
        batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
        strategy = _get_distribution_strategy(model)
        scope = None
        if strategy:
            batch_size, steps = dist_utils.process_batch_and_step_size(
                strategy, x, batch_size, steps, mode)
            dist_utils.validate_callbacks(input_callbacks=callbacks,
                                          optimizer=model.optimizer)
            # Enter tf.distribute.Strategy scope.
            scope = dist_utils.distributed_scope(strategy=strategy,
                                                 learning_phase=0)
            scope.__enter__()

        # The scope was entered by hand, so it has to be exited on every
        # path; otherwise an exception below would leave it entered.
        try:
            adapter = _process_inputs(model,
                                      x,
                                      y,
                                      batch_size=batch_size,
                                      sample_weights=sample_weight,
                                      steps=steps,
                                      distribution_strategy=strategy)

            if not steps:
                steps = adapter.get_size()

            training_context = TrainingContext()

            _update_sample_weight_mode(model, mode, adapter)
            execution_function = _make_execution_function(model, mode)
            data_iterator = _create_dataset_iterator(strategy,
                                                     adapter.get_dataset())

            callbacks = cbks.configure_callbacks(
                callbacks,
                model,
                do_validation=False,
                batch_size=batch_size,
                epochs=1,
                steps_per_epoch=steps,
                samples=None,
                verbose=0,  # Handle ProgBarLogger separately in this loop.
                mode=mode)

            with training_context.on_start(model, callbacks, verbose, mode):
                # TODO(scottzhu): Handle TPUStrategy training loop
                with training_context.on_epoch(0, mode) as epoch_logs:
                    model.reset_metrics()
                    result = run_one_epoch(model,
                                           data_iterator,
                                           execution_function,
                                           dataset_size=adapter.get_size(),
                                           strategy=strategy,
                                           steps_per_epoch=steps,
                                           mode=mode,
                                           training_context=training_context,
                                           current_epoch=1)
                    cbks.make_logs(model, epoch_logs, result, mode)
        finally:
            if scope is not None:
                scope.__exit__(None, None, None)

        # A single-element result is unwrapped for the caller.
        if len(result) == 1:
            result = result[0]
        return result
Ejemplo n.º 8
0
    def fit(self,
            model,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose=1,
            callbacks=None,
            validation_split=0.,
            validation_data=None,
            shuffle=True,
            class_weight=None,
            sample_weight=None,
            initial_epoch=0,
            steps_per_epoch=None,
            validation_steps=None,
            validation_freq=1,
            **kwargs):
        """Run the training loop for `model`, with optional validation.

        When the model has a distribution strategy, the whole loop runs
        inside `dist_utils.distributed_scope(..., learning_phase=1)`. The
        scope is always exited, even when input processing, a callback or
        an epoch raises.

        Args:
            model: The model to train.
            x: Input data, forwarded to `_process_training_inputs`.
            y: Target data, forwarded to `_process_training_inputs`.
            batch_size: Batch size; validated or inferred through
                `model._validate_or_infer_batch_size`.
            epochs: Upper bound (exclusive) of the epoch loop, which
                iterates over `range(initial_epoch, epochs)`.
            verbose: Verbosity passed to `training_context.on_start`.
            callbacks: Callbacks handed to `cbks.configure_callbacks`.
            validation_split: Forwarded to `_process_training_inputs`.
            validation_data: Forwarded to `_process_training_inputs`.
            shuffle: Forwarded to `_process_training_inputs`.
            class_weight: Forwarded to `_process_training_inputs` as
                `class_weights`.
            sample_weight: Forwarded to `_process_training_inputs` as
                `sample_weights`.
            initial_epoch: First epoch index; may be replaced by the value
                returned from `model._maybe_load_initial_epoch_from_ckpt`.
            steps_per_epoch: Steps per training epoch. When falsy, the
                training adapter's `get_size()` is used instead.
            validation_steps: Steps per validation pass. When falsy, the
                validation adapter's `get_size()` is used instead.
            validation_freq: Passed with the current epoch to
                `training_utils.should_run_validation` to decide whether a
                validation pass runs after that epoch.
            **kwargs: Accepted for signature compatibility; not used here.

        Returns:
            `model.history`.
        """
        batch_size = model._validate_or_infer_batch_size(
            batch_size, steps_per_epoch, x)

        strategy = _get_distribution_strategy(model)
        # `scope` stays None when there is no strategy, so the cleanup in
        # the `finally` clause below knows whether there is anything to exit.
        scope = None
        if strategy:
            batch_size, steps_per_epoch = dist_utils.process_batch_and_step_size(
                strategy, x, batch_size, steps_per_epoch, ModeKeys.TRAIN)
            dist_utils.validate_callbacks(input_callbacks=callbacks,
                                          optimizer=model.optimizer)
            # Enter tf.distribute.Strategy scope.
            scope = dist_utils.distributed_scope(strategy=strategy,
                                                 learning_phase=1)
            scope.__enter__()

        try:
            training_data_adapter, validation_adapter = (
                _process_training_inputs(
                    model,
                    x,
                    y,
                    batch_size=batch_size,
                    sample_weights=sample_weight,
                    class_weights=class_weight,
                    validation_split=validation_split,
                    steps_per_epoch=steps_per_epoch,
                    shuffle=shuffle,
                    validation_data=validation_data,
                    validation_steps=validation_steps,
                    distribution_strategy=strategy))

            do_validation = (validation_adapter is not None)

            if not steps_per_epoch:
                steps_per_epoch = training_data_adapter.get_size()

            training_context = TrainingContext()

            initial_epoch = model._maybe_load_initial_epoch_from_ckpt(
                initial_epoch, ModeKeys.TRAIN)

            _update_sample_weight_mode(model, ModeKeys.TRAIN,
                                       training_data_adapter)
            training_function = _make_execution_function(
                model, ModeKeys.TRAIN)

            training_data_iter = None
            # Only recreate iterator when the data has a fixed length, which
            # will be fully consumed every epoch, or has a unknown length
            # (dataset, generator) and will be fully consumed
            # (steps_per_epoch is None)
            recreate_training_iterator = (
                training_data_adapter.get_size() is not None
                or steps_per_epoch is None)

            if do_validation:
                if not validation_steps:
                    validation_steps = validation_adapter.get_size()
                eval_function = _make_execution_function(
                    model, ModeKeys.TEST)
                eval_data_iter = None
                recreate_eval_iterator = (
                    validation_adapter.get_size() is not None
                    or validation_steps is None)

            callbacks = cbks.configure_callbacks(
                callbacks,
                model,
                do_validation=do_validation,
                batch_size=batch_size,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                samples=None,
                verbose=0,  # Handle ProgBarLogger separately in this loop.
                mode=ModeKeys.TRAIN)

            with training_context.on_start(model, callbacks, verbose,
                                           ModeKeys.TRAIN):
                # TODO(scottzhu): Handle TPUStrategy training loop
                for epoch in range(initial_epoch, epochs):
                    if training_context.callbacks.model.stop_training:
                        break

                    # Training
                    with training_context.on_epoch(
                            epoch, ModeKeys.TRAIN) as epoch_logs:
                        model.reset_metrics()
                        if (training_data_iter is None
                                or recreate_training_iterator):
                            training_data_iter = _create_dataset_iterator(
                                strategy,
                                training_data_adapter.get_dataset())

                        training_result = run_one_epoch(
                            model,
                            training_data_iter,
                            training_function,
                            dataset_size=training_data_adapter.get_size(),
                            strategy=strategy,
                            steps_per_epoch=steps_per_epoch,
                            mode=ModeKeys.TRAIN,
                            training_context=training_context,
                            current_epoch=epoch)
                        cbks.make_logs(model, epoch_logs, training_result,
                                       ModeKeys.TRAIN)

                        # Evaluation
                        if (do_validation
                                and training_utils.should_run_validation(
                                    validation_freq, epoch)
                                and not callbacks.model.stop_training):
                            if (eval_data_iter is None
                                    or recreate_eval_iterator):
                                eval_data_iter = _create_dataset_iterator(
                                    strategy,
                                    validation_adapter.get_dataset())
                            # Validation gets its own context so that its
                            # start/epoch hooks are independent of training.
                            eval_context = TrainingContext()
                            with eval_context.on_start(model,
                                                       callbacks,
                                                       verbose=0,
                                                       mode=ModeKeys.TEST):
                                with eval_context.on_epoch(epoch,
                                                           ModeKeys.TEST):
                                    model.reset_metrics()
                                    # NOTE(review): this passes the total
                                    # `epochs`, not the loop's `epoch`, as
                                    # `current_epoch` -- confirm intended.
                                    eval_result = run_one_epoch(
                                        model,
                                        eval_data_iter,
                                        eval_function,
                                        dataset_size=(
                                            validation_adapter.get_size()),
                                        strategy=strategy,
                                        steps_per_epoch=validation_steps,
                                        mode=ModeKeys.TEST,
                                        training_context=eval_context,
                                        current_epoch=epochs)
                                    # Validation metrics are merged into the
                                    # training epoch's logs under 'val_'.
                                    cbks.make_logs(model,
                                                   epoch_logs,
                                                   eval_result,
                                                   ModeKeys.TRAIN,
                                                   prefix='val_')
        finally:
            # Always leave the distribution scope, including on error, so a
            # failed fit() does not leave the strategy / learning-phase
            # scope entered for subsequent code.
            if scope is not None:
                scope.__exit__(None, None, None)

        return model.history
Ejemplo n.º 9
0
  def _model_iteration(
      self, model, mode, x=None, y=None, batch_size=None, verbose=1,
      sample_weight=None, steps=None, callbacks=None, **kwargs):
    """Run one epoch over the inputs in `mode` and return its result.

    The batch size and step count are normalized for the model's
    distribution strategy, the inputs are wrapped in an adapter, and the
    resulting dataset is distributed and iterated for a single epoch inside
    the strategy scope (entered with learning phase 0).

    Args:
      model: The model to run.
      mode: Mode key forwarded to the helpers and to `run_one_epoch`.
      x: Input data, forwarded to `_process_inputs`.
      y: Target data, forwarded to `_process_inputs`.
      batch_size: Batch size; validated or inferred by the model.
      verbose: Verbosity passed to `training_context.on_start`.
      sample_weight: Forwarded to `_process_inputs` as `sample_weights`.
      steps: Number of steps to run. When falsy, the adapter's `get_size()`
        is used instead.
      callbacks: Callbacks handed to `cbks.configure_callbacks`.
      **kwargs: Accepted for signature compatibility; not used here.

    Returns:
      The value produced by `run_one_epoch`; when it has exactly one
      element, that element is returned on its own.
    """
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    strategy = _get_distribution_strategy(model)
    batch_size, steps = dist_utils.process_batch_and_step_size(
        strategy, x, batch_size, steps, mode)
    dist_utils.validate_callbacks(
        input_callbacks=callbacks, optimizer=model.optimizer)

    # Everything below, up to and including the epoch itself, executes
    # inside the tf.distribute.Strategy scope.
    strategy_scope = dist_utils.distributed_scope(
        strategy=strategy, learning_phase=0)
    with strategy_scope:
      data_adapter = _process_inputs(
          model, x, y,
          batch_size=batch_size,
          sample_weights=sample_weight,
          steps=steps,
          distribution_strategy=strategy)

      # Fall back to the adapter's own size when no step count was given.
      if not steps:
        steps = data_adapter.get_size()

      loop_context = TrainingContext()

      # Raise an error if `steps` isn't specified but the dataset
      # is infinite.
      # TODO(scottzhu): This check should probably happen in the adapter
      source_dataset = data_adapter.get_dataset()
      training_utils.infer_steps_for_dataset(
          source_dataset, steps, steps_name='steps', epochs=0)
      distributed_dataset = strategy.experimental_distribute_dataset(
          source_dataset)

      _update_sample_weight_mode(model, mode, distributed_dataset)
      step_function = training_v2_utils._get_or_make_execution_function(
          model, mode)

      batch_iterator = iter(distributed_dataset)

      # verbose=0 here: ProgBarLogger is handled separately in this loop.
      callback_list = cbks.configure_callbacks(
          callbacks, model,
          do_validation=False,
          batch_size=batch_size,
          epochs=1,
          steps_per_epoch=steps,
          samples=None,
          verbose=0,
          mode=mode)

      with loop_context.on_start(model, callback_list, verbose, mode):
        # TODO(scottzhu): Handle TPUStrategy training loop
        with loop_context.on_epoch(0, mode) as epoch_logs:
          model.reset_metrics()
          result = run_one_epoch(
              model, batch_iterator, step_function,
              dataset_size=data_adapter.get_size(),
              strategy=strategy,
              steps_per_epoch=steps,
              mode=mode,
              training_context=loop_context,
              total_epochs=1)
          cbks.make_logs(model, epoch_logs, result, mode)

    # Unwrap single-element results so callers get the bare value.
    return result[0] if len(result) == 1 else result