示例#1
0
    def test_verify_dataset_shuffled(self):
        """Checks that verify_dataset_shuffled warns only for unshuffled data.

        An unshuffled `Dataset.range` dataset must produce a logged warning
        saying the input dataset `x` is not shuffled. The same dataset after
        `.shuffle()` must not produce that warning.
        """
        dataset = dataset_ops.Dataset.range(5)
        training_utils.assert_not_shuffled(dataset)

        with test.mock.patch.object(logging, 'warning') as mock_log:
            training_utils.verify_dataset_shuffled(dataset)
            # assertRegexpMatches is a deprecated alias (removed in Python
            # 3.12); assertRegex is the supported spelling.
            self.assertRegex(str(mock_log.call_args),
                             'input dataset `x` is not shuffled.')

        shuffled_dataset = dataset.shuffle(10)
        with test.mock.patch.object(logging, 'warning') as mock_log:
            training_utils.verify_dataset_shuffled(shuffled_dataset)
            # Only the "not shuffled" warning is ruled out. Unrelated warnings
            # emitted by other code during the call must not fail the test.
            for logged_call in mock_log.call_args_list:
                self.assertNotRegex(str(logged_call),
                                    'input dataset `x` is not shuffled.')
    def test_verify_dataset_shuffled(self, dataset_fn, expect_shuffled=False):
        """Checks verify_dataset_shuffled's verdict for one parameterized case.

        Args:
          dataset_fn: zero-argument callable that builds the dataset under test.
          expect_shuffled: whether the built dataset should be reported as
            shuffled.
        """
        ds = dataset_fn()

        if expect_shuffled:
            # A shuffled dataset must be reported as shuffled.
            self.assertTrue(training_utils.verify_dataset_shuffled(ds))
            return

        # An unshuffled dataset must log a warning and report False.
        with test.mock.patch.object(logging, 'warning') as warning_mock:
            is_shuffled = training_utils.verify_dataset_shuffled(ds)
            self.assertRegex(str(warning_mock.call_args),
                             'input dataset `x` is not shuffled.')
            self.assertFalse(is_shuffled)
示例#3
0
  def test_verify_dataset_shuffled(self):
    """Checks that verify_dataset_shuffled warns only for unshuffled data.

    An unshuffled `Dataset.range` dataset must produce a logged warning saying
    the input dataset `x` is not shuffled. The same dataset after `.shuffle()`
    must not produce that warning.
    """
    dataset = dataset_ops.Dataset.range(5)
    training_utils.assert_not_shuffled(dataset)

    with test.mock.patch.object(logging, 'warning') as mock_log:
      training_utils.verify_dataset_shuffled(dataset)
      # assertRegexpMatches is a deprecated alias (removed in Python 3.12);
      # assertRegex is the supported spelling.
      self.assertRegex(
          str(mock_log.call_args),
          'input dataset `x` is not shuffled.')

    shuffled_dataset = dataset.shuffle(10)
    with test.mock.patch.object(logging, 'warning') as mock_log:
      training_utils.verify_dataset_shuffled(shuffled_dataset)
      # Only the "not shuffled" warning is ruled out. Unrelated warnings
      # emitted by other code during the call must not fail the test.
      for logged_call in mock_log.call_args_list:
        self.assertNotRegex(
            str(logged_call),
            'input dataset `x` is not shuffled.')
示例#4
0
    def fit(self,
            model,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose=1,
            callbacks=None,
            validation_split=0.,
            validation_data=None,
            shuffle=True,
            class_weight=None,
            sample_weight=None,
            initial_epoch=0,
            steps_per_epoch=None,
            validation_steps=None,
            validation_freq=1,
            **kwargs):
        """Trains `model` on dataset-style input by delegating to `fit_generator`.

        The method first asks the model to validate or infer the batch size. It
        then checks that `y`, `sample_weight` and `validation_split` are not
        supplied. When `shuffle` is truthy and `x` is a `DatasetV1`/`DatasetV2`,
        it runs the shuffle check on `x`. Finally it forwards the training
        options to `fit_generator` with `workers=0`.

        Args:
          model: the model being trained.
          x: the training input; dataset instances get the shuffle check.
          y, sample_weight, validation_split: checked by
            `training_utils.validate_dataset_input` and not forwarded.
          batch_size: used only for `model._validate_or_infer_batch_size`.
          epochs, verbose, callbacks, validation_data, shuffle, class_weight,
            initial_epoch, steps_per_epoch, validation_steps, validation_freq:
            forwarded unchanged to `fit_generator`.
          **kwargs: accepted but ignored (presumably for signature parity with
            other training loops; confirm).

        Returns:
          Whatever `fit_generator` returns.
        """
        # Called only for its checks; the return value is discarded here.
        model._validate_or_infer_batch_size(batch_size, steps_per_epoch, x)
        # Make sure that y, sample_weights, validation_split are not passed.
        training_utils.validate_dataset_input(x, y, sample_weight,
                                              validation_split)

        dataset_classes = (dataset_ops.DatasetV1, dataset_ops.DatasetV2)
        if isinstance(x, dataset_classes) and shuffle:
            # Reports (via a logged warning) a dataset that is not shuffled.
            training_utils.verify_dataset_shuffled(x)

        # y, batch_size, validation_split, sample_weight and **kwargs are
        # intentionally not forwarded.
        generator_options = dict(
            steps_per_epoch=steps_per_epoch,
            epochs=epochs,
            verbose=verbose,
            callbacks=callbacks,
            validation_data=validation_data,
            validation_steps=validation_steps,
            validation_freq=validation_freq,
            class_weight=class_weight,
            workers=0,
            shuffle=shuffle,
            initial_epoch=initial_epoch,
            steps_name='steps_per_epoch',
        )
        return fit_generator(model, x, **generator_options)