Example #1
def _process_training_inputs(model,
                             x,
                             y,
                             batch_size=None,
                             epochs=1,
                             sample_weights=None,
                             class_weights=None,
                             steps_per_epoch=None,
                             validation_split=0.,
                             validation_data=None,
                             validation_steps=None,
                             shuffle=True,
                             distribution_strategy=None,
                             max_queue_size=10,
                             workers=1,
                             use_multiprocessing=False):
  """Process the data input for fit() with respect to validation_split."""
  if validation_split and 0. < validation_split < 1. and validation_data:
    raise ValueError('validation_data and validation_split cannot be used '
                     'at the same time.')

  adapter_cls = data_adapter.select_data_adapter(x, y)

  # Handle validation_split: split the data and take the training
  # section before handing it to the data adapter.
  if validation_split and 0. < validation_split < 1.:
    if adapter_cls not in _ADAPTER_FOR_VALIDATION_SPLIT:
      raise ValueError(
          '`validation_split` argument is not supported when '
          'data adapter is {}. Received: x={}, validation_split={}'.format(
              adapter_cls, x, validation_split))
    # Retrieve the training section from x and y, then construct a dataset
    # from it.
    x, y, sample_weights = model._standardize_user_data(
        x,
        y,
        sample_weight=sample_weights,
        class_weight=class_weights,
        batch_size=batch_size,
        check_steps=False,
        steps=steps_per_epoch)
    (x, y, sample_weights,
     val_x, val_y,
     val_sample_weights) = training_utils.split_training_and_validation_data(
         x, y, sample_weights, validation_split)

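    # Reuse each output's sample_weight_mode for both the training and
    # validation adapters.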
    sample_weight_modes = [
        e.sample_weight_mode for e in model._training_endpoints
    ]
    train_adapter = adapter_cls(
        x,
        y,
        batch_size=batch_size,
        epochs=epochs,
        sample_weights=sample_weights,
        sample_weight_modes=sample_weight_modes,
        shuffle=shuffle,
        distribution_strategy=distribution_strategy)

    val_adapter = adapter_cls(
        val_x,
        val_y,
        sample_weights=val_sample_weights,
        sample_weight_modes=sample_weight_modes,
        batch_size=batch_size,
        distribution_strategy=distribution_strategy)
  else:
    train_adapter = _process_inputs(
        model,
        ModeKeys.TRAIN,
        x,
        y,
        sample_weights=sample_weights,
        batch_size=batch_size,
        epochs=epochs,
        class_weights=class_weights,
        shuffle=shuffle,
        steps=steps_per_epoch,
        distribution_strategy=distribution_strategy,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing)
    val_adapter = None
    if validation_data:
      (val_x, val_y,
       val_sample_weights) = training_utils.unpack_validation_data(
           validation_data, raise_if_ambiguous=False)
      # If batch_size is unknown, use a representative batch size from the
      # training data for evaluation. This is useful when the training input
      # is a generator or Sequence and the validation input is numpy arrays.
      if not batch_size:
        batch_size = train_adapter.representative_batch_size()
      val_adapter = _process_inputs(
          model,
          ModeKeys.TEST,
          val_x,
          val_y,
          sample_weights=val_sample_weights,
          batch_size=batch_size,
          class_weights=class_weights,
          steps=validation_steps,
          distribution_strategy=distribution_strategy)
    elif validation_steps:
      raise ValueError('`validation_steps` should not be specified if '
                       '`validation_data` is None.')
  return train_adapter, val_adapter
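The heart of the validation_split branch above is a deterministic tail split: once the inputs are standardized, the last fraction of the arrays is held out for validation. A minimal standalone sketch of that behaviour (the helper name and NumPy-only signature here are hypothetical; in Keras the real work is done by training_utils.split_training_and_validation_data):

import numpy as np

def split_training_and_validation(x, y, validation_split):
  # Hold out the last `validation_split` fraction of the (already
  # standardized, unshuffled) arrays for validation.
  split_at = int(len(x) * (1. - validation_split))
  return (x[:split_at], y[:split_at]), (x[split_at:], y[split_at:])

x = np.arange(10, dtype='float32')
y = x * 2.
(train_x, train_y), (val_x, val_y) = split_training_and_validation(x, y, 0.2)
print(train_x.shape, val_x.shape)  # (8,) (2,)

Note that the split happens before any shuffling: the training adapter above receives shuffle=True only after the validation tail has been sliced off, so no training examples leak into the validation set.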
Example #2
    def fit(self,
            model,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose=1,
            callbacks=None,
            validation_split=0.,
            validation_data=None,
            shuffle=True,
            class_weight=None,
            sample_weight=None,
            initial_epoch=0,
            steps_per_epoch=None,
            validation_steps=None,
            validation_freq=1,
            **kwargs):
        """Fit loop for Distribution Strategies."""
        dist_utils.validate_callbacks(input_callbacks=callbacks,
                                      optimizer=model.optimizer)
        dist_utils.validate_inputs(x, y)

        batch_size, steps_per_epoch = dist_utils.process_batch_and_step_size(
            model._distribution_strategy, x, batch_size, steps_per_epoch,
            ModeKeys.TRAIN)
        batch_size = model._validate_or_infer_batch_size(
            batch_size, steps_per_epoch, x)
        dataset = model._distribution_standardize_user_data(
            x,
            y,
            sample_weight=sample_weight,
            class_weight=class_weight,
            batch_size=batch_size,
            validation_split=validation_split,
            shuffle=shuffle,
            epochs=epochs)
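        # If the model is not being distributed by cloning, standardize the
        # dataset within the strategy scope.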
        if not dist_utils.is_distributing_by_cloning(model):
            with model._distribution_strategy.scope():
                (dataset, _, _) = model._standardize_user_data(
                    dataset,
                    sample_weight=sample_weight,
                    class_weight=class_weight,
                    batch_size=batch_size,
                    validation_split=validation_split,
                    shuffle=shuffle)

        val_dataset = None
        if validation_data:
            val_x, val_y, val_sample_weights = training_utils.unpack_validation_data(
                validation_data)
            dist_utils.validate_inputs(val_x, val_y)
            _, validation_steps = dist_utils.process_batch_and_step_size(
                model._distribution_strategy, val_x, batch_size,
                validation_steps, ModeKeys.TEST)

            val_dataset = model._distribution_standardize_user_data(
                val_x,
                val_y,
                sample_weight=val_sample_weights,
                class_weight=None,
                batch_size=batch_size,
                validation_split=validation_split,
                shuffle=shuffle,
                allow_partial_batch=True)
        elif validation_split:
            raise ValueError('`validation_split` argument is not supported '
                             'with distribution strategies.')

        if dist_utils.is_tpu_strategy(model._distribution_strategy):
            steps_per_epoch = training_utils.infer_steps_for_dataset(
                dataset, steps_per_epoch, epochs, steps_name='steps_per_epoch')
            if steps_per_epoch is None:
                raise ValueError(
                    'Number of steps could not be inferred from the data; '
                    'please pass the `steps_per_epoch` argument.')

            if not context.executing_eagerly():
                # Run TPU training in a custom loop in graph mode.
                return experimental_tpu_fit_loop(
                    model,
                    dataset,
                    epochs=epochs,
                    verbose=verbose,
                    callbacks=callbacks,
                    val_dataset=val_dataset,
                    initial_epoch=initial_epoch,
                    steps_per_epoch=steps_per_epoch,
                    validation_steps=validation_steps,
                    validation_freq=validation_freq)

        return training_arrays.fit_loop(model,
                                        dataset,
                                        batch_size=batch_size,
                                        epochs=epochs,
                                        verbose=verbose,
                                        callbacks=callbacks,
                                        val_inputs=val_dataset,
                                        shuffle=shuffle,
                                        initial_epoch=initial_epoch,
                                        steps_per_epoch=steps_per_epoch,
                                        validation_steps=validation_steps,
                                        validation_freq=validation_freq,
                                        steps_name='steps_per_epoch')
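From the public API, this loop is reached through Model.fit on a model built under a tf.distribute strategy scope. A minimal sketch of such a call (the model and data are placeholders, and standard TensorFlow 2.x public APIs are assumed); since the loop above rejects validation_split under a strategy, explicit validation_data is passed instead:

import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    # Placeholder model; any compiled Keras model works here.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    model.compile(optimizer='sgd', loss='mse')

x = np.random.rand(64, 4).astype('float32')
y = np.random.rand(64, 1).astype('float32')

# Explicit validation_data rather than validation_split under a strategy.
model.fit(x, y, batch_size=8, epochs=2, validation_data=(x[:16], y[:16]))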