示例#1
0
    def test_verify_dataset_shuffled(self, dataset_fn, expect_shuffled=False):
        """Exercise `verify_dataset_shuffled` on the dataset from `dataset_fn`.

        If `expect_shuffled` is true, the helper must return a truthy value.
        Otherwise it must return a falsy value, and the arguments of its last
        `logging.warning` call must match 'input dataset `x` is not shuffled.'.
        """
        ds = dataset_fn()

        if expect_shuffled:
            self.assertTrue(training_utils_v1.verify_dataset_shuffled(ds))
            return

        # Swap out logging.warning so the emitted warning can be inspected.
        with test.mock.patch.object(logging, 'warning') as warning_mock:
            is_shuffled = training_utils_v1.verify_dataset_shuffled(ds)
            logged_call = str(warning_mock.call_args)
            self.assertRegex(logged_call,
                             'input dataset `x` is not shuffled.')
            self.assertFalse(is_shuffled)
示例#2
0
  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          **kwargs):
    """Train `model` on dataset-like input `x` by delegating to `fit_generator`.

    Argument usage:
      * `y`, `sample_weight` and `validation_split` are only handed to
        `training_utils_v1.validate_dataset_input`.
      * `batch_size` is only handed to `model._validate_or_infer_batch_size`,
        and that call's result is discarded.
      * `kwargs` is accepted but not used in this method.

    Returns:
      Whatever `fit_generator` returns.
    """
    # Result intentionally unused; the call is made for its checks only.
    model._validate_or_infer_batch_size(batch_size, steps_per_epoch, x)

    # Intent: make sure y, sample_weight and validation_split are not passed
    # for this kind of input.
    training_utils_v1.validate_dataset_input(x, y, sample_weight,
                                             validation_split)

    # Only DatasetV1/DatasetV2 inputs are checked, and only when shuffling was
    # requested. The helper's return value is ignored here.
    dataset_types = (dataset_ops.DatasetV1, dataset_ops.DatasetV2)
    is_dataset = isinstance(x, dataset_types)
    if is_dataset and shuffle:
      training_utils_v1.verify_dataset_shuffled(x)

    generator_kwargs = dict(
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        validation_data=validation_data,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        class_weight=class_weight,
        # Always 0 here. NOTE(review): presumably disables worker-based
        # data loading in fit_generator -- confirm.
        workers=0,
        shuffle=shuffle,
        initial_epoch=initial_epoch,
        steps_name='steps_per_epoch',
    )
    return fit_generator(model, x, **generator_kwargs)