示例#1
0
  def _model_iteration(
      self, model, mode, x=None, y=None, batch_size=None, verbose=1,
      sample_weight=None, steps=None, callbacks=None, max_queue_size=10,
      workers=1, use_multiprocessing=False, **kwargs):
    """Runs a single pass over the inputs in `mode` and returns its result.

    The inputs are wrapped in a data adapter, the adapter's dataset is
    distributed with the model's distribution strategy, and the whole pass is
    executed by one `run_one_epoch` call inside a `TrainingContext`.

    Args:
      model: Model to run.
      mode: Mode key forwarded to input processing, the execution function,
        the callbacks and `run_one_epoch`.
      x: Input data, forwarded to `_process_inputs`.
      y: Target data, forwarded to `_process_inputs`.
      batch_size: Batch size; validated or inferred by the model and then
        adjusted by `dist_utils.process_batch_and_step_size`.
      verbose: Verbosity forwarded to `TrainingContext.on_start`.
      sample_weight: Sample weights, forwarded to `_process_inputs`.
      steps: Number of steps to run. When falsy it is taken from the adapter
        size, or inferred from the dataset if the adapter has no size.
      callbacks: Callbacks to configure for this pass.
      max_queue_size: Forwarded to `_process_inputs`.
      workers: Forwarded to `_process_inputs`.
      use_multiprocessing: Forwarded to `_process_inputs`.
      **kwargs: Accepted for signature compatibility; not used here.

    Returns:
      The value produced by `run_one_epoch`. If it has exactly one element,
      that element is returned instead of the container.
    """
    batch_size = model._validate_or_infer_batch_size(
        batch_size, steps, x)
    strategy = _get_distribution_strategy(model)
    batch_size, steps = dist_utils.process_batch_and_step_size(
        strategy, x, batch_size, steps, mode)
    dist_utils.validate_callbacks(input_callbacks=callbacks,
                                  optimizer=model.optimizer)
    # Enter tf.distribute.Strategy scope.
    with strategy.scope():
      adapter = _process_inputs(
          model,
          mode,
          x,
          y,
          batch_size=batch_size,
          sample_weights=sample_weight,
          steps=steps,
          distribution_strategy=strategy,
          max_queue_size=max_queue_size,
          workers=workers,
          use_multiprocessing=use_multiprocessing)
      total_samples = _get_total_number_of_samples(adapter)
      use_sample = total_samples is not None
      dataset = adapter.get_dataset()

      if not steps:
        # Prefer the adapter's own size; otherwise ask the utility to infer
        # the step count (per the original note, this is expected to raise
        # if `steps` isn't specified but the dataset is infinite).
        steps = adapter.get_size() or training_utils.infer_steps_for_dataset(
            model, dataset, steps, steps_name='steps')

      training_context = TrainingContext()
      dataset = strategy.experimental_distribute_dataset(dataset)

      execution_function = training_v2_utils._get_or_make_execution_function(
          model, mode)

      data_iterator = iter(dataset)

      callbacks = cbks.configure_callbacks(
          callbacks,
          model,
          do_validation=False,
          batch_size=batch_size,
          epochs=1,
          steps_per_epoch=steps,
          # Pass the actual total (sample count, or the step count when the
          # number of samples is unknown) rather than the `use_sample` flag,
          # so the callback params never carry a boolean as the total.
          samples=total_samples or steps,
          count_mode='samples' if use_sample else 'steps',
          verbose=0,  # Handle ProgBarLogger separately in this loop.
          mode=mode)

      with training_context.on_start(
          model, callbacks, use_sample, verbose, mode):
        with training_context.on_epoch(0, mode) as epoch_logs:
          model.reset_metrics()
          result = run_one_epoch(
              model,
              data_iterator,
              execution_function,
              dataset_size=adapter.get_size(),
              batch_size=adapter.batch_size(),
              strategy=strategy,
              steps_per_epoch=steps,
              num_samples=total_samples,
              mode=mode,
              training_context=training_context,
              total_epochs=1)
          cbks.make_logs(model, epoch_logs, result, mode)

    # Unwrap single-output results so callers get the bare value.
    if len(result) == 1:
      result = result[0]
    return result
示例#2
0
    def fit(self,
            model,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose=1,
            callbacks=None,
            validation_split=0.,
            validation_data=None,
            shuffle=True,
            class_weight=None,
            sample_weight=None,
            initial_epoch=0,
            steps_per_epoch=None,
            validation_steps=None,
            validation_freq=1,
            **kwargs):
        """Runs the epoch training loop for `model` and returns its history.

        Inputs are turned into a training adapter (and, when one is produced,
        a validation adapter), their datasets are distributed through the
        model's distribution strategy, and each epoch is executed by
        `run_one_epoch`. After a training epoch an evaluation pass runs when
        a validation adapter exists, `training_utils.should_run_validation`
        accepts the epoch, and the callbacks' model has not set
        `stop_training`.

        Args:
            model: Model being trained.
            x: Input data, forwarded to `_process_training_inputs`.
            y: Target data, forwarded to `_process_training_inputs`.
            batch_size: Batch size; validated or inferred by the model and
                then adjusted by `dist_utils.process_batch_and_step_size`.
            epochs: Upper bound of the epoch range that is iterated.
            verbose: Verbosity forwarded to the training context.
            callbacks: Callbacks to configure for the run.
            validation_split: Forwarded to `_process_training_inputs`.
            validation_data: Forwarded to `_process_training_inputs`.
            shuffle: Forwarded to `_process_training_inputs`.
            class_weight: Forwarded to `_process_training_inputs`.
            sample_weight: Forwarded to `_process_training_inputs`.
            initial_epoch: First epoch index, after being passed through
                `model._maybe_load_initial_epoch_from_ckpt`.
            steps_per_epoch: Steps per training epoch; when falsy the
                training adapter's size is used.
            validation_steps: Steps per evaluation pass; when falsy the
                validation adapter's size is used.
            validation_freq: Forwarded to
                `training_utils.should_run_validation`.
            **kwargs: Accepted for signature compatibility; not used here.

        Returns:
            `model.history`.
        """
        batch_size = model._validate_or_infer_batch_size(
            batch_size, steps_per_epoch, x)

        strategy = _get_distribution_strategy(model)
        train_mode = ModeKeys.TRAIN
        batch_size, steps_per_epoch = dist_utils.process_batch_and_step_size(
            strategy, x, batch_size, steps_per_epoch, train_mode)
        dist_utils.validate_callbacks(
            input_callbacks=callbacks, optimizer=model.optimizer)

        # Everything below runs inside the tf.distribute.Strategy scope.
        with strategy.scope():
            train_adapter, val_adapter = _process_training_inputs(
                model,
                x,
                y,
                batch_size=batch_size,
                sample_weights=sample_weight,
                class_weights=class_weight,
                validation_split=validation_split,
                steps_per_epoch=steps_per_epoch,
                shuffle=shuffle,
                validation_data=validation_data,
                validation_steps=validation_steps,
                distribution_strategy=strategy)

            total_samples = _get_total_number_of_samples(train_adapter)
            use_sample = total_samples is not None
            do_validation = val_adapter is not None

            # Fall back to the adapter's size when no step count was given.
            steps_per_epoch = steps_per_epoch or train_adapter.get_size()

            train_context = TrainingContext()

            initial_epoch = model._maybe_load_initial_epoch_from_ckpt(
                initial_epoch, train_mode)

            train_dataset = train_adapter.get_dataset()
            # The return value is ignored; per the original note this call is
            # meant to raise when steps_per_epoch is unset but the dataset is
            # infinite.
            # TODO(scottzhu): This check should probably happen in the adapter
            training_utils.infer_steps_for_dataset(
                train_dataset,
                steps_per_epoch,
                steps_name='steps_per_epoch',
                epochs=0)
            train_dataset = strategy.experimental_distribute_dataset(
                train_dataset)

            _update_sample_weight_mode(model, train_mode, train_dataset)
            train_fn = training_v2_utils._get_or_make_execution_function(
                model, train_mode)

            train_iter = None
            # The iterator is renewed each epoch when the adapter reports a
            # size, or when no step count is known at all; otherwise the
            # first iterator is kept across epochs.
            recreate_training_iterator = (
                train_adapter.get_size() is not None
                or steps_per_epoch is None)

            if do_validation:
                validation_steps = (
                    validation_steps or val_adapter.get_size())
                eval_fn = training_v2_utils._get_or_make_execution_function(
                    model, ModeKeys.TEST)
                eval_iter = None

                val_dataset = val_adapter.get_dataset()
                # Same check as for training, applied to validation_steps.
                # TODO(scottzhu): This check should probably happen in the adapter
                training_utils.infer_steps_for_dataset(
                    val_dataset,
                    validation_steps,
                    steps_name='validation_steps',
                    epochs=0)
                val_dataset = strategy.experimental_distribute_dataset(
                    val_dataset)

            callbacks = cbks.configure_callbacks(
                callbacks,
                model,
                do_validation=do_validation,
                batch_size=batch_size,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                samples=total_samples,
                count_mode='samples' if use_sample else 'steps',
                verbose=0,  # Handle ProgBarLogger separately in this loop.
                mode=train_mode)

            with train_context.on_start(
                    model, callbacks, use_sample, verbose, train_mode):
                # TODO(scottzhu): Handle TPUStrategy training loop
                for epoch in range(initial_epoch, epochs):
                    if train_context.callbacks.model.stop_training:
                        break

                    # --- Training pass ---
                    with train_context.on_epoch(epoch, train_mode) as logs:
                        model.reset_metrics()
                        if train_iter is None:
                            train_iter = iter(train_dataset)
                        elif recreate_training_iterator:
                            if distribution_strategy_context.has_strategy():
                                # TODO(kaftan): remove this when MultiDeviceIterator is a
                                ## compositetensor (unless this is more efficient)
                                train_iter._initializer  # pylint: disable=pointless-statement
                            else:
                                train_iter = iter(train_dataset)

                        train_outputs = run_one_epoch(
                            model,
                            train_iter,
                            train_fn,
                            dataset_size=train_adapter.get_size(),
                            batch_size=train_adapter.batch_size(),
                            strategy=strategy,
                            steps_per_epoch=steps_per_epoch,
                            num_samples=total_samples,
                            mode=train_mode,
                            training_context=train_context,
                            total_epochs=epochs)
                        cbks.make_logs(model, logs, train_outputs, train_mode)

                        # --- Evaluation pass ---
                        run_validation = (
                            do_validation
                            and training_utils.should_run_validation(
                                validation_freq, epoch)
                            and not callbacks.model.stop_training)
                        if run_validation:
                            if (eval_iter is None
                                    or not distribution_strategy_context.
                                    has_strategy()):
                                eval_iter = iter(val_dataset)
                            else:
                                # TODO(kaftan): remove this when MultiDeviceIterator is a
                                ## compositetensor (unless this is more efficient)
                                eval_iter._initializer  # pylint: disable=pointless-statement

                            val_total_samples = _get_total_number_of_samples(
                                val_adapter)
                            eval_context = TrainingContext()
                            with eval_context.on_start(
                                    model,
                                    callbacks,
                                    use_sample,
                                    verbose=0,
                                    mode=ModeKeys.TEST):
                                with eval_context.on_epoch(
                                        epoch, ModeKeys.TEST):
                                    model.reset_metrics()
                                    eval_outputs = run_one_epoch(
                                        model,
                                        eval_iter,
                                        eval_fn,
                                        dataset_size=val_adapter.get_size(),
                                        batch_size=val_adapter.batch_size(),
                                        strategy=strategy,
                                        steps_per_epoch=validation_steps,
                                        num_samples=val_total_samples,
                                        mode=ModeKeys.TEST,
                                        training_context=eval_context,
                                        total_epochs=1)
                                    # Validation metrics join the training
                                    # epoch's logs under the 'val_' prefix.
                                    cbks.make_logs(
                                        model,
                                        logs,
                                        eval_outputs,
                                        ModeKeys.TEST,
                                        prefix='val_')

        return model.history
示例#3
0
  def fit(
      self, model, x=None, y=None, batch_size=None, epochs=1, verbose=1,
      callbacks=None, validation_split=0., validation_data=None, shuffle=True,
      class_weight=None, sample_weight=None, initial_epoch=0,
      steps_per_epoch=None, validation_steps=None, validation_freq=1,
      max_queue_size=10, workers=1, use_multiprocessing=False, **kwargs):
    """Runs the epoch training loop for `model` and returns its history.

    Inputs are turned into a training adapter (and, when one is produced, a
    validation adapter), their datasets are distributed through the model's
    distribution strategy, and each epoch is executed by `run_one_epoch`.
    After a training epoch an evaluation pass runs when a validation adapter
    exists, `training_utils.should_run_validation` accepts the epoch, and the
    training callbacks' model has not set `stop_training`.

    Args:
      model: Model being trained.
      x: Input data, forwarded to `_process_training_inputs`.
      y: Target data, forwarded to `_process_training_inputs`.
      batch_size: Batch size; validated or inferred by the model and then
        adjusted by `dist_utils.process_batch_and_step_size`.
      epochs: Upper bound of the epoch range that is iterated.
      verbose: Verbosity; also gates the `_print_train_info` call.
      callbacks: Callbacks to configure for the run.
      validation_split: Forwarded to batch/step processing and to
        `_process_training_inputs`.
      validation_data: Forwarded to `_process_training_inputs`.
      shuffle: Forwarded to `_process_training_inputs`.
      class_weight: Forwarded to `_process_training_inputs`.
      sample_weight: Forwarded to `_process_training_inputs`.
      initial_epoch: First epoch index, after being passed through
        `model._maybe_load_initial_epoch_from_ckpt`.
      steps_per_epoch: Steps per training epoch; when falsy the training
        adapter's size is used, and if that leaves it `None` the inferred
        step count is used.
      validation_steps: Steps per evaluation pass; when falsy the validation
        adapter's size or the inferred step count is used.
      validation_freq: Forwarded to `training_utils.should_run_validation`.
      max_queue_size: Forwarded to `_process_training_inputs`.
      workers: Forwarded to `_process_training_inputs`.
      use_multiprocessing: Forwarded to `_process_training_inputs`.
      **kwargs: Accepted for signature compatibility; not used here.

    Returns:
      `model.history`.
    """
    batch_size = model._validate_or_infer_batch_size(
        batch_size, steps_per_epoch, x)

    strategy = _get_distribution_strategy(model)
    train_mode = ModeKeys.TRAIN
    batch_size, steps_per_epoch = dist_utils.process_batch_and_step_size(
        strategy, x, batch_size, steps_per_epoch, train_mode,
        validation_split=validation_split)
    dist_utils.validate_callbacks(
        input_callbacks=callbacks, optimizer=model.optimizer)

    # Everything below runs inside the tf.distribute.Strategy scope.
    with strategy.scope():
      train_adapter, val_adapter = _process_training_inputs(
          model,
          x,
          y,
          batch_size=batch_size,
          epochs=epochs,
          sample_weights=sample_weight,
          class_weights=class_weight,
          validation_split=validation_split,
          steps_per_epoch=steps_per_epoch,
          shuffle=shuffle,
          validation_data=validation_data,
          validation_steps=validation_steps,
          distribution_strategy=strategy,
          max_queue_size=max_queue_size,
          workers=workers,
          use_multiprocessing=use_multiprocessing)

      total_samples = _get_total_number_of_samples(train_adapter)
      use_sample = total_samples is not None
      do_validation = val_adapter is not None

      # Asked before the size fallback below, so the adapter sees the
      # steps_per_epoch value as it stood after batch/step processing.
      recreate_training_iterator = train_adapter.should_recreate_iterator(
          steps_per_epoch)
      # TODO(b/139762795): Add step inference for when steps is None to
      # prevent end of sequence warning message.
      steps_per_epoch = steps_per_epoch or train_adapter.get_size()

      train_context = TrainingContext()

      train_dataset = train_adapter.get_dataset()
      # Per the original note, this call is also meant to raise when
      # steps_per_epoch is unset but the dataset is infinite.
      # TODO(scottzhu): This check should probably happen in the adapter
      inferred_steps = training_utils.infer_steps_for_dataset(
          model,
          train_dataset,
          steps_per_epoch,
          steps_name='steps_per_epoch',
          epochs=0)
      if steps_per_epoch is None:
        steps_per_epoch = inferred_steps

      train_dataset = strategy.experimental_distribute_dataset(train_dataset)

      train_fn = training_v2_utils._get_or_make_execution_function(
          model, train_mode)

      train_iter = None
      val_total_samples = None
      if do_validation:
        val_dataset = val_adapter.get_dataset()
        if not validation_steps:
          # Use the adapter's size when it has one; otherwise infer the step
          # count (per the original note, expected to raise for an infinite
          # validation dataset).
          adapter_steps = val_adapter.get_size()
          if adapter_steps:
            validation_steps = adapter_steps
          else:
            validation_steps = training_utils.infer_steps_for_dataset(
                model,
                val_dataset,
                validation_steps,
                steps_name='validation_steps')
        eval_fn = training_v2_utils._get_or_make_execution_function(
            model, ModeKeys.TEST)
        eval_iter = None
        val_dataset = strategy.experimental_distribute_dataset(val_dataset)
        val_total_samples = _get_total_number_of_samples(val_adapter)

      if verbose and (total_samples or steps_per_epoch):
        _print_train_info(
            total_samples, steps_per_epoch, val_total_samples,
            validation_steps)

      train_callbacks = cbks.configure_callbacks(
          callbacks,
          model,
          do_validation=do_validation,
          batch_size=batch_size,
          epochs=epochs,
          steps_per_epoch=steps_per_epoch,
          samples=total_samples or steps_per_epoch,
          count_mode='samples' if use_sample else 'steps',
          verbose=0,  # Handle ProgBarLogger separately in this loop.
          mode=train_mode)

      with train_context.on_start(
          model, train_callbacks, use_sample, verbose, train_mode):

        initial_epoch = model._maybe_load_initial_epoch_from_ckpt(
            initial_epoch, train_mode)

        for epoch in range(initial_epoch, epochs):
          if train_context.callbacks.model.stop_training:
            break

          # --- Training pass ---
          with train_context.on_epoch(epoch, train_mode) as logs:
            model.reset_metrics()
            if train_iter is None:
              train_iter = iter(train_dataset)
            elif recreate_training_iterator:
              if distribution_strategy_context.has_strategy():
                # TODO(kaftan): remove this when MultiDeviceIterator is a
                ## compositetensor (unless this is more efficient)
                train_iter._initializer  # pylint: disable=pointless-statement
              else:
                train_iter = iter(train_dataset)

            train_outputs = run_one_epoch(
                model,
                train_iter,
                train_fn,
                dataset_size=train_adapter.get_size(),
                batch_size=train_adapter.batch_size(),
                strategy=strategy,
                steps_per_epoch=steps_per_epoch,
                num_samples=total_samples,
                mode=train_mode,
                training_context=train_context,
                total_epochs=epochs)
            cbks.make_logs(model, logs, train_outputs, train_mode)

            # With steps_per_epoch still None, adopt the progress bar's
            # target once it is known (per the original note, the final
            # cardinality is determined when the inputs are fully consumed).
            if steps_per_epoch is None:
              progress_bar = train_context.progbar.progbar
              if progress_bar.target is not None:
                steps_per_epoch = progress_bar.target

            # --- Evaluation pass ---
            run_validation = (
                do_validation and
                training_utils.should_run_validation(validation_freq, epoch)
                and not train_callbacks.model.stop_training)
            if run_validation:
              if (eval_iter is None or
                  not distribution_strategy_context.has_strategy()):
                eval_iter = iter(val_dataset)
              else:
                # TODO(kaftan): remove this when MultiDeviceIterator is a
                ## compositetensor (unless this is more efficient)
                eval_iter._initializer  # pylint: disable=pointless-statement

              val_callbacks = cbks.configure_callbacks(
                  train_callbacks,
                  model,
                  batch_size=batch_size,
                  epochs=1,
                  steps_per_epoch=validation_steps,
                  samples=val_total_samples or validation_steps,
                  count_mode='samples' if use_sample else 'steps',
                  verbose=0,  # Handle ProgBarLogger separately in this loop.
                  mode=ModeKeys.TEST)

              eval_context = TrainingContext()
              with eval_context.on_start(
                  model,
                  val_callbacks,
                  use_sample,
                  verbose=0,
                  mode=ModeKeys.TEST):
                with eval_context.on_epoch(epoch, ModeKeys.TEST):
                  model.reset_metrics()
                  eval_outputs = run_one_epoch(
                      model,
                      eval_iter,
                      eval_fn,
                      dataset_size=val_adapter.get_size(),
                      batch_size=val_adapter.batch_size(),
                      strategy=strategy,
                      steps_per_epoch=validation_steps,
                      num_samples=val_total_samples,
                      mode=ModeKeys.TEST,
                      training_context=eval_context,
                      total_epochs=1)
                  # Validation metrics join the training epoch's logs under
                  # the 'val_' prefix.
                  cbks.make_logs(
                      model, logs, eval_outputs, ModeKeys.TEST,
                      prefix='val_')

    return model.history
示例#4
0
  def _model_iteration(
      self, model, mode, x=None, y=None, batch_size=None, verbose=1,
      sample_weight=None, steps=None, callbacks=None, **kwargs):
    """Executes one pass of `mode` over the given inputs.

    The inputs are wrapped in a data adapter, the adapter's dataset is
    distributed with the model's distribution strategy, and the pass is run
    by a single `run_one_epoch` call inside a `TrainingContext`.

    Args:
      model: Model to run.
      mode: Mode key forwarded to the execution function, the callbacks and
        `run_one_epoch`.
      x: Input data, forwarded to `_process_inputs`.
      y: Target data, forwarded to `_process_inputs`.
      batch_size: Batch size; validated or inferred by the model and then
        adjusted by `dist_utils.process_batch_and_step_size`.
      verbose: Verbosity forwarded to `TrainingContext.on_start`.
      sample_weight: Sample weights, forwarded to `_process_inputs`.
      steps: Number of steps to run; when falsy the adapter's size is used.
      callbacks: Callbacks to configure for this pass.
      **kwargs: Accepted for signature compatibility; not used here.

    Returns:
      The output of `run_one_epoch`, unwrapped to its sole element when it
      has length one.
    """
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    strategy = _get_distribution_strategy(model)
    batch_size, steps = dist_utils.process_batch_and_step_size(
        strategy, x, batch_size, steps, mode)
    dist_utils.validate_callbacks(
        input_callbacks=callbacks, optimizer=model.optimizer)

    # Run inside the distributed scope, passing learning_phase=0.
    with dist_utils.distributed_scope(strategy=strategy, learning_phase=0):
      data_adapter = _process_inputs(
          model,
          x,
          y,
          batch_size=batch_size,
          sample_weights=sample_weight,
          steps=steps,
          distribution_strategy=strategy)

      # Fall back to the adapter's size when no step count was given.
      steps = steps or data_adapter.get_size()

      context = TrainingContext()

      batches = data_adapter.get_dataset()
      # The return value is ignored; per the original note this call is
      # meant to raise when `steps` is unset but the dataset is infinite.
      # TODO(scottzhu): This check should probably happen in the adapter
      training_utils.infer_steps_for_dataset(
          batches, steps, steps_name='steps', epochs=0)
      batches = strategy.experimental_distribute_dataset(batches)

      _update_sample_weight_mode(model, mode, batches)
      step_function = training_v2_utils._get_or_make_execution_function(
          model, mode)

      batch_iterator = iter(batches)

      callbacks = cbks.configure_callbacks(
          callbacks,
          model,
          do_validation=False,
          batch_size=batch_size,
          epochs=1,
          steps_per_epoch=steps,
          samples=None,
          verbose=0,  # Handle ProgBarLogger separately in this loop.
          mode=mode)

      with context.on_start(model, callbacks, verbose, mode):
        # TODO(scottzhu): Handle TPUStrategy training loop
        with context.on_epoch(0, mode) as logs:
          model.reset_metrics()
          outputs = run_one_epoch(
              model,
              batch_iterator,
              step_function,
              dataset_size=data_adapter.get_size(),
              strategy=strategy,
              steps_per_epoch=steps,
              mode=mode,
              training_context=context,
              total_epochs=1)
          cbks.make_logs(model, logs, outputs, mode)

    # Single-output results are returned as the bare value.
    return outputs[0] if len(outputs) == 1 else outputs