示例#1
0
    def test_dataset_input_shape_validation(self):
        """Check that train_on_batch rejects datasets with a mismatched input shape.

        The model under test is built with ``input_dim=3``. Two bad datasets
        are fed to ``train_on_batch`` and each must raise ``ValueError``:

        * an unbatched dataset of ``(3,)``-feature rows, whose error message is
          expected to mention a received shape of ``(1,)``;
        * a batched dataset whose rows have 5 features instead of 3.
        """
        unbatched_pattern = (
            r'expected (.*?) to have shape \(3,\) but got array with shape \(1,\)')
        wrong_dim_pattern = r'expected (.*?) to have shape \(3,\)'

        def make_dataset(feature_dim, batch_size=None):
            # 10 all-zero samples with 4 target columns, repeated 100 times;
            # optionally batched as the final pipeline step.
            features = np.zeros((10, feature_dim))
            labels = np.zeros((10, 4))
            ds = tf.data.Dataset.from_tensor_slices((features, labels))
            ds = ds.repeat(100)
            if batch_size is not None:
                ds = ds.batch(batch_size)
            return ds

        graph = tf.compat.v1.get_default_graph()
        with graph.as_default(), self.cached_session():
            model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
            model.compile(optimizer='rmsprop', loss='mse')

            # Caller forgot to call .batch() on the dataset.
            unbatched = make_dataset(3)
            with self.assertRaisesRegex(ValueError, unbatched_pattern):
                model.train_on_batch(unbatched)

            # Properly batched, but the feature dimension is 5 instead of 3.
            wrong_dim = make_dataset(5, batch_size=10)
            with self.assertRaisesRegex(ValueError, wrong_dim_pattern):
                model.train_on_batch(wrong_dim)
    def test_dataset_with_class_weight(self):
        """Smoke-test Model.fit on a batched dataset with a class_weight mapping.

        One uniform weight (0.25) is supplied for each of the 4 target columns.
        Nothing about the training result is asserted; the test only requires
        that ``fit`` completes without raising.
        """
        model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
        model.compile(optimizer='rmsprop', loss='mse')

        features = np.zeros((10, 3), np.float32)
        labels = np.zeros((10, 4), np.float32)
        dataset = (
            tf.data.Dataset.from_tensor_slices((features, labels))
            .repeat(100)
            .batch(10)
        )

        # Map class index 0..3 to its (equal) weight.
        weights = np.array([0.25, 0.25, 0.25, 0.25])
        class_weight = dict(enumerate(weights))

        model.fit(
            dataset,
            epochs=1,
            steps_per_epoch=2,
            verbose=1,
            class_weight=class_weight,
        )
示例#3
0
 def test_load_non_keras_saved_model(self):
   """Check that keras_load.load rejects a directory written by tf.saved_model.save.

   A small functional MLP is exported through the generic TensorFlow
   SavedModel API. Reading that directory back with ``keras_load.load`` must
   raise ``ValueError`` mentioning 'Unable to create a Keras model'.
   """
   mlp = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
   export_dir = self._save_model_dir()
   tf.saved_model.save(mlp, export_dir)

   expected_message = 'Unable to create a Keras model'
   with self.assertRaisesRegex(ValueError, expected_message):
     keras_load.load(export_dir)