Пример #1
0
    def testPruneSequentialModelPreservesBuiltState(self):
        """Pruning keeps a Sequential model's `built` state, also across JSON.

        Covered for a model whose layers declare no input shape (not built)
        and for one whose first layer declares `input_shape` (built). Both the
        source model and the pruned model are checked, as is a model
        deserialized from the pruned model's JSON config.
        """
        # No InputLayer
        model = keras.Sequential([
            layers.Dense(10),
            layers.Dense(10),
        ])
        self.assertEqual(model.built, False)
        pruned_model = prune.prune_low_magnitude(model, **self.params)
        # Pruning must leave the source model untouched...
        self.assertEqual(model.built, False)
        # ...and the pruned model must carry the same built state.
        self.assertEqual(pruned_model.built, False)

        # Test built state preserves across serialization
        with prune.prune_scope():
            loaded_model = keras.models.model_from_config(
                json.loads(pruned_model.to_json()))
        self.assertEqual(loaded_model.built, False)

        # With InputLayer
        model = keras.Sequential([
            layers.Dense(10, input_shape=(10, )),
            layers.Dense(10),
        ])
        self.assertEqual(model.built, True)
        pruned_model = prune.prune_low_magnitude(model, **self.params)
        self.assertEqual(model.built, True)
        self.assertEqual(pruned_model.built, True)

        # Test built state preserves across serialization
        with prune.prune_scope():
            loaded_model = keras.models.model_from_config(
                json.loads(pruned_model.to_json()))
        self.assertEqual(loaded_model.built, True)
    def testFunctionalModelForLatencyOnDoubleXNNPackPolicy(self):
        """Re-pruning with the XNNPack latency policy keeps one pruned layer.

        The model is pruned twice in a row with `PruneForLatencyOnXNNPack`;
        after each pass exactly one pruned layer is expected.
        """
        image = keras.Input(shape=(8, 8, 3))
        padded = layers.ZeroPadding2D(padding=1)(image)
        strided_conv = layers.Conv2D(
            filters=16,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding='valid',
        )(padded)
        pointwise_conv = layers.Conv2D(filters=16, kernel_size=[1, 1])(strided_conv)
        pooled = layers.GlobalAveragePooling2D()(pointwise_conv)
        current = keras.Model(inputs=[image], outputs=[pooled])

        # First pass prunes the plain model; second pass re-prunes its output.
        for _ in range(2):
            current = prune.prune_low_magnitude(
                current,
                pruning_policy=pruning_policy.PruneForLatencyOnXNNPack(),
                **self.params,
            )
            self.assertEqual(self._count_pruned_layers(current), 1)
Пример #3
0
 def testPruneModelCustomNonPrunableLayerRaisesError(self):
     """Pruning a Sequential that includes the custom non-prunable layer raises ValueError."""
     candidate_layers = [
         self.keras_prunable_layer,
         self.keras_non_prunable_layer,
         self.custom_prunable_layer,
         self.custom_non_prunable_layer,
     ]
     with self.assertRaises(ValueError):
         prune.prune_low_magnitude(keras.Sequential(candidate_layers),
                                   **self.params)
    def testPrunesMnist_ReachesTargetSparsity(self, model_type):
        """An MNIST model goes from 0.0 to 0.5 sparsity after one pruned fit.

        `model_type` picks how pruning is applied: for 'layer_list' the pruned
        result is wrapped in a new Sequential; for 'sequential' and
        'functional' the built model is pruned directly. Any other value
        leaves the built model unpruned.
        """
        base = test_utils.build_mnist_model(model_type, self.params)
        if model_type == 'layer_list':
            model = keras.Sequential(
                prune.prune_low_magnitude(base, **self.params))
        elif model_type in ('sequential', 'functional'):
            model = prune.prune_low_magnitude(base, **self.params)
        else:
            model = base

        tolerance = {'rtol': 1e-4, 'atol': 1e-4}
        model.compile(loss='categorical_crossentropy',
                      optimizer='sgd',
                      metrics=['accuracy'])
        test_utils.assert_model_sparsity(self, 0.0, model, **tolerance)

        images = np.random.rand(32, 28, 28, 1)
        labels = keras.utils.to_categorical(
            np.random.randint(10, size=(32, 1)), 10)
        model.fit(images,
                  labels,
                  callbacks=[pruning_callbacks.UpdatePruningStep()])

        test_utils.assert_model_sparsity(self, 0.5, model, **tolerance)

        self._check_strip_pruning_matches_original(model, 0.5)
Пример #5
0
  def testPruneCheckpoints_CheckpointsNotSparse(self):
    """Weights restored from a training checkpoint need not show 0.5 sparsity.

    A pruned model is trained for one epoch while `ModelCheckpoint` writes
    weights every batch. The latest checkpoint is then loaded into a freshly
    pruned model of the same architecture. The test passes when, in at least
    one of three runs, `is_model_sparsity_not(0.5, ...)` reports true for the
    restored model.
    """
    is_model_sparsity_not_list = []

    # Run multiple times since problem doesn't always happen.
    for _ in range(3):
      model = keras_test_utils.build_simple_dense_model()
      pruned_model = prune.prune_low_magnitude(model, **self.params)

      # TemporaryDirectory deletes the checkpoint files when the iteration
      # ends, so runs do not leave directories behind in the temp area.
      with tempfile.TemporaryDirectory() as checkpoint_dir:
        checkpoint_path = checkpoint_dir + '/weights.{epoch:02d}.tf'

        callbacks = [
            pruning_callbacks.UpdatePruningStep(),
            tf.keras.callbacks.ModelCheckpoint(
                filepath=checkpoint_path, save_weights_only=True, save_freq=1)
        ]

        # Train one step. Sparsity reaches final sparsity.
        self._train_model(pruned_model, epochs=1, callbacks=callbacks)
        test_utils.assert_model_sparsity(self, 0.5, pruned_model)

        latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)

        same_architecture_model = keras_test_utils.build_simple_dense_model()
        pruned_model = prune.prune_low_magnitude(same_architecture_model,
                                                 **self.params)

        # Sanity check.
        test_utils.assert_model_sparsity(self, 0, pruned_model)

        # Restore and inspect while the checkpoint files still exist.
        pruned_model.load_weights(latest_checkpoint)
        is_model_sparsity_not_list.append(
            test_utils.is_model_sparsity_not(0.5, pruned_model))

    self.assertTrue(any(is_model_sparsity_not_list))
Пример #6
0
  def testPruneModelDoesNotWrapAlreadyWrappedLayer(self):
    """A layer that was pruned up front is reused, not wrapped a second time."""
    plain_dense = layers.Dense(10)
    wrapped_dense = prune.prune_low_magnitude(layers.Dense(10), **self.params)
    model = keras.Sequential([plain_dense, wrapped_dense])

    pruned_model = prune.prune_low_magnitude(model, **self.params)
    pruned_model.build(input_shape=(10, 1))

    before, after = model.layers, pruned_model.layers
    self.assertEqual(len(before), len(after))
    self._validate_pruned_layer(before[0], after[0])
    # The already-pruned second layer must come through unchanged.
    self.assertEqual(before[1], after[1])
Пример #7
0
  def testPruneScope_NotNeededForTFCheckpoint(self):
    """TF-format weights load into a pruned model without `prune_scope`."""
    model = keras_test_utils.build_simple_dense_model()
    pruned_model = prune.prune_low_magnitude(model)

    # Use a self-cleaning temporary directory. mkstemp() returns an open file
    # descriptor that has to be closed, and it leaves a placeholder file
    # behind. The TF checkpoint itself is written under this path prefix.
    with tempfile.TemporaryDirectory() as tmp_dir:
      tf_weights = tmp_dir + '/weights.tf'
      pruned_model.save_weights(tf_weights)

      same_architecture_model = keras_test_utils.build_simple_dense_model()
      same_architecture_model = prune.prune_low_magnitude(
          same_architecture_model)

      # would error if `prune_scope` was needed.
      same_architecture_model.load_weights(tf_weights)
Пример #8
0
def main(unused_argv):
	"""Train and save layer-wise, Sequential and functional pruned MNIST models."""
	img_rows, img_cols = 28, 28

	# MNIST is loaded already split into train and test sets.
	(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

	# Put the single channel axis where the backend's data format expects it.
	if tf.keras.backend.image_data_format() == 'channels_first':
		input_shape = (1, img_rows, img_cols)
	else:
		input_shape = (img_rows, img_cols, 1)
	x_train = x_train.reshape((x_train.shape[0],) + input_shape)
	x_test = x_test.reshape((x_test.shape[0],) + input_shape)

	# Convert pixels to float32 and divide by 255.
	x_train = x_train.astype('float32') / 255
	x_test = x_test.astype('float32') / 255
	print('x_train shape:', x_train.shape)
	print(x_train.shape[0], 'train samples')
	print(x_test.shape[0], 'test samples')

	# One-hot encode the integer class labels.
	y_train, y_test = (
			tf.keras.utils.to_categorical(labels, num_classes)
			for labels in (y_train, y_test))

	schedule = PolynomialDecay(
			initial_sparsity=0.1,
			final_sparsity=0.75,
			begin_step=1000,
			end_step=5000,
			frequency=100)
	pruning_params = {'pruning_schedule': schedule}

	# The first model wraps individual layers; the other two are pruned whole.
	layerwise_model = build_layerwise_model(input_shape, **pruning_params)
	whole_pruned_models = [
			prune.prune_low_magnitude(builder(input_shape), **pruning_params)
			for builder in (build_sequential_model, build_functional_model)
	]

	train_and_save(
			[layerwise_model] + whole_pruned_models,
			x_train, y_train, x_test, y_test)
Пример #9
0
    def testPruneSequentialModel(self):
        """Both Dense layers of a Sequential get wrapped, with or without input shape."""
        # First pass: no InputLayer. Second pass: first layer declares input_shape.
        for first_layer_kwargs in ({}, {'input_shape': (10, )}):
            model = keras.Sequential([
                layers.Dense(10, **first_layer_kwargs),
                layers.Dense(10),
            ])
            pruned_model = prune.prune_low_magnitude(model, **self.params)
            self.assertEqual(self._count_pruned_layers(pruned_model), 2)
Пример #10
0
 def testPruneSubclassModel(self):
   """Pruning a subclassed model raises ValueError naming the model class."""
   subclassed = TestSubclassedModel()
   with self.assertRaises(ValueError) as raised:
     prune.prune_low_magnitude(subclassed, **self.params)
   expected_message = self.INVALID_TO_PRUNE_PARAM_ERROR.format(
       input='TestSubclassedModel')
   self.assertEqual(str(raised.exception), expected_message)
Пример #11
0
    def testPrunesPreviouslyUnprunedModel(self):
        """A model trained without pruning can be pruned later and reach 0.5 sparsity."""

        def compile_model(target):
            target.compile(loss='categorical_crossentropy',
                           optimizer='sgd',
                           metrics=['accuracy'])

        def random_labels():
            return np_utils.to_categorical(
                np.random.randint(5, size=(20, 1)), 5)

        model = keras_test_utils.build_simple_dense_model()
        compile_model(model)
        # Simple unpruned model. No sparsity.
        model.fit(np.random.rand(20, 10), random_labels(),
                  epochs=2,
                  batch_size=20)
        test_utils.assert_model_sparsity(self, 0.0, model)

        # Apply pruning to model.
        model = prune.prune_low_magnitude(model, **self.params)
        compile_model(model)
        # Since newly compiled, iterations starts from 0.
        model.fit(np.random.rand(20, 10), random_labels(),
                  batch_size=20,
                  callbacks=[pruning_callbacks.UpdatePruningStep()])
        test_utils.assert_model_sparsity(self, 0.5, model)

        self._check_strip_pruning_matches_original(model, 0.5)
    def testPruneStopAndRestart_PreservesSparsity(self, save_restore_fn):
        """Pruning continues from where it stopped after a save/restore.

        Trains one step (initial sparsity 0.2), passes the model through
        `save_restore_fn`, then trains three more steps and expects the final
        sparsity 0.6.

        Args:
          save_restore_fn: Callable that takes the model and returns the model
            to keep training (presumably a saved-and-restored copy).
        """
        # TODO(tfmot): renable once SavedModel preserves step again.
        # This existed in TF 2.0 and 2.1 and should be reenabled in
        # TF 2.3. b/151755698
        if save_restore_fn.__name__ == '_save_restore_tf_model':
            # Report as skipped rather than silently passing.
            self.skipTest('SavedModel does not preserve the pruning step '
                          '(b/151755698).')

        begin_step, end_step = 1, 4
        # Copy so overriding the schedule does not mutate the shared fixture.
        params = dict(self.params)
        params['pruning_schedule'] = pruning_schedule.PolynomialDecay(
            0.2, 0.6, begin_step, end_step, 3, 1)

        model = prune.prune_low_magnitude(
            keras_test_utils.build_simple_dense_model(), **params)
        model.compile(loss='categorical_crossentropy',
                      optimizer='sgd',
                      metrics=['accuracy'])
        # Model hasn't been trained yet. Sparsity 0.0
        test_utils.assert_model_sparsity(self, 0.0, model)

        # Train only 1 step. Sparsity 0.2 (initial_sparsity)
        self._train_model(model, epochs=1)
        test_utils.assert_model_sparsity(self, 0.2, model)

        model = save_restore_fn(model)

        # Training has run all 4 steps. Sparsity 0.6 (final_sparsity)
        self._train_model(model, epochs=3)
        test_utils.assert_model_sparsity(self, 0.6, model)

        self._check_strip_pruning_matches_original(model, 0.6)
Пример #13
0
  def testPrunesSimpleDenseModel(self, distribution):
    """Pruning under a distribution strategy reaches 0.5 sparsity.

    The model is built and compiled inside `distribution.scope()`, trained
    with `UpdatePruningStep`, and then saved to and reloaded from an H5 file.
    The reloaded model must also show 0.5 sparsity.
    """
    with distribution.scope():
      model = prune.prune_low_magnitude(
          keras_test_utils.build_simple_dense_model(), **self.params)
      model.compile(
          loss='categorical_crossentropy',
          optimizer='sgd',
          metrics=['accuracy'])

    # Model hasn't been trained yet. Sparsity 0.0
    test_utils.assert_model_sparsity(self, 0.0, model)

    # Train the pruned model with the UpdatePruningStep callback.
    model.fit(
        np.random.rand(20, 10),
        keras.utils.np_utils.to_categorical(
            np.random.randint(5, size=(20, 1)), 5),
        epochs=2,
        callbacks=[pruning_callbacks.UpdatePruningStep()],
        batch_size=20)
    model.predict(np.random.rand(20, 10))
    test_utils.assert_model_sparsity(self, 0.5, model)

    # Use a self-cleaning temporary directory. mkstemp() would leave an open
    # file descriptor and a stray .h5 file behind.
    with tempfile.TemporaryDirectory() as tmp_dir:
      keras_file = tmp_dir + '/model.h5'
      keras.models.save_model(model, keras_file)

      with prune.prune_scope():
        loaded_model = keras.models.load_model(keras_file)

    test_utils.assert_model_sparsity(self, 0.5, loaded_model)
    def testPruneWithPolynomialDecayPastEndStep_PreservesSparsity(
            self, save_restore_fn):
        """Sparsity stays at final_sparsity past end_step, across save/restore.

        Args:
          save_restore_fn: Callable that takes the model and returns the model
            to keep training (presumably a saved-and-restored copy).
        """
        begin_step, end_step = 0, 2
        # Copy so overriding the schedule does not mutate the shared fixture.
        params = dict(self.params)
        params['pruning_schedule'] = pruning_schedule.PolynomialDecay(
            0.2, 0.6, begin_step, end_step, 3, 1)

        model = prune.prune_low_magnitude(
            keras_test_utils.build_simple_dense_model(), **params)
        model.compile(loss='categorical_crossentropy',
                      optimizer='sgd',
                      metrics=['accuracy'])

        # Model hasn't been trained yet. Sparsity 0.0
        test_utils.assert_model_sparsity(self, 0.0, model)

        # Train 3 steps, past end_step. Sparsity 0.6 (final_sparsity)
        self._train_model(model, epochs=3)
        test_utils.assert_model_sparsity(self, 0.6, model)

        model = save_restore_fn(model)

        # Ensure sparsity is preserved.
        test_utils.assert_model_sparsity(self, 0.6, model)

        # Train one more step to ensure nothing happens that brings sparsity
        # back below 0.6.
        self._train_model(model, epochs=1)
        test_utils.assert_model_sparsity(self, 0.6, model)

        self._check_strip_pruning_matches_original(model, 0.6)
    def testPrunesModel_CustomTrainingLoop_ReachesTargetSparsity(self):
        """A hand-written loop driving UpdatePruningStep reaches 0.5 sparsity.

        The pruning callback hooks are invoked manually around two single-batch
        training steps performed with `tf.GradientTape`.
        """
        pruned_model = prune.prune_low_magnitude(
            keras_test_utils.build_simple_dense_model())

        batch_size = 20
        x_train = np.random.rand(batch_size, 10)
        y_train = keras.utils.to_categorical(
            np.random.randint(5, size=(batch_size, 1)), 5)
        # Loop-invariant: the same batch is fed on every step.
        inp = np.reshape(x_train, [batch_size, 10])

        loss = keras.losses.categorical_crossentropy
        optimizer = keras.optimizers.SGD()

        # Placeholder batch index passed to the callback hooks.
        unused_arg = -1

        step_callback = pruning_callbacks.UpdatePruningStep()
        step_callback.set_model(pruned_model)
        # NOTE(review): attached presumably because the callback reads the
        # model's optimizer (e.g. its iteration count) — confirm.
        pruned_model.optimizer = optimizer

        step_callback.on_train_begin()
        # 2 epochs, each consisting of a single batch.
        for _ in range(2):
            step_callback.on_train_batch_begin(batch=unused_arg)
            with tf.GradientTape() as tape:
                logits = pruned_model(inp, training=True)
                loss_value = loss(y_train, logits)
            # Compute and apply gradients outside the recording context so
            # these ops are not themselves traced by the tape.
            grads = tape.gradient(loss_value, pruned_model.trainable_variables)
            optimizer.apply_gradients(
                zip(grads, pruned_model.trainable_variables))
            step_callback.on_epoch_end(batch=unused_arg)

        test_utils.assert_model_sparsity(self, 0.5, pruned_model)
Пример #16
0
    def testRNNLayersWithRNNCellParams(self):
        """A pruned RNN layer built from a list of four cells reaches 0.5 sparsity."""
        model = keras.Sequential()
        cells = [
            layers.LSTMCell(10),
            layers.GRUCell(10),
            layers.PeepholeLSTMCell(10),
            layers.SimpleRNNCell(10),
        ]
        pruned_rnn = prune.prune_low_magnitude(keras.layers.RNN(cells),
                                               input_shape=(3, 4),
                                               **self.params)
        model.add(pruned_rnn)

        model.compile(loss='categorical_crossentropy',
                      optimizer='sgd',
                      metrics=['accuracy'])
        test_utils.assert_model_sparsity(self, 0.0, model)

        inputs = np.random.randn(
            *self._batch(model.input.get_shape().as_list(), 32))
        targets = np.random.randn(
            *self._batch(model.output.get_shape().as_list(), 32))
        model.fit(inputs,
                  targets,
                  callbacks=[pruning_callbacks.UpdatePruningStep()])

        test_utils.assert_model_sparsity(self, 0.5, model)

        self._check_strip_pruning_matches_original(model, 0.5)
    def testPruneStopAndRestart_PreservesSparsity(self, save_restore_fn):
        """Pruning continues from where it stopped after a save/restore.

        Trains one step (initial sparsity 0.2), passes the model through
        `save_restore_fn`, then trains three more steps and expects the final
        sparsity 0.6.

        Args:
          save_restore_fn: Callable that takes the model and returns the model
            to keep training (presumably a saved-and-restored copy).
        """
        begin_step, end_step = 0, 4
        # Copy so overriding the schedule does not mutate the shared fixture.
        params = dict(self.params)
        params['pruning_schedule'] = pruning_schedule.PolynomialDecay(
            0.2, 0.6, begin_step, end_step, 3, 1)

        model = prune.prune_low_magnitude(
            keras_test_utils.build_simple_dense_model(), **params)
        model.compile(loss='categorical_crossentropy',
                      optimizer='sgd',
                      metrics=['accuracy'])
        # Model hasn't been trained yet. Sparsity 0.0
        test_utils.assert_model_sparsity(self, 0.0, model)

        # Train only 1 step. Sparsity 0.2 (initial_sparsity)
        self._train_model(model, epochs=1)
        test_utils.assert_model_sparsity(self, 0.2, model)

        model = save_restore_fn(model)

        # Training has run all 4 steps. Sparsity 0.6 (final_sparsity)
        self._train_model(model, epochs=3)
        test_utils.assert_model_sparsity(self, 0.6, model)

        self._check_strip_pruning_matches_original(model, 0.6)
Пример #18
0
    def testPruneStopAndRestartOnModel(self, save_restore_fn):
        """Block-pruned model reaches 0.2 sparsity, then 0.6 after save/restore.

        Args:
          save_restore_fn: Callable that takes the model and returns the model
            to keep training (presumably a saved-and-restored copy).
        """
        schedule = pruning_schedule.PolynomialDecay(0.2, 0.6, 0, 4, 3, 1)
        params = {
            'pruning_schedule': schedule,
            'block_size': (1, 1),
            'block_pooling_type': 'AVG',
        }

        def fit_random(target, **extra_fit_kwargs):
            # One batch of 20 random samples per epoch.
            target.fit(np.random.rand(20, 10),
                       np_utils.to_categorical(
                           np.random.randint(5, size=(20, 1)), 5),
                       batch_size=20,
                       callbacks=[pruning_callbacks.UpdatePruningStep()],
                       **extra_fit_kwargs)

        model = prune.prune_low_magnitude(
            keras_test_utils.build_simple_dense_model(), **params)
        model.compile(loss='categorical_crossentropy',
                      optimizer='sgd',
                      metrics=['accuracy'])
        # Model hasn't been trained yet. Sparsity 0.0
        test_utils.assert_model_sparsity(self, 0.0, model)

        fit_random(model)
        # Training has run only 1 step. Sparsity 0.2 (initial_sparsity)
        test_utils.assert_model_sparsity(self, 0.2, model)

        model = save_restore_fn(model)
        fit_random(model, epochs=3)
        # Training has run all 4 steps. Sparsity 0.6 (final_sparsity)
        test_utils.assert_model_sparsity(self, 0.6, model)

        self._check_strip_pruning_matches_original(model, 0.6)
    def testMbyNSparsityPruning_SupportedLayers(self,
                                                layer_type,
                                                layer_arg,
                                                input_shape,
                                                m_by_n=(2, 4),
                                                sparsity_ratio=0.50):
        """Check that we prune supported layers with m by n sparsity.

        Args:
          layer_type: Layer class to instantiate and wrap for pruning.
          layer_arg: Positional arguments passed to `layer_type`.
          input_shape: Input shape given to the pruning wrapper.
          m_by_n: Pattern forwarded to pruning as `sparsity_m_by_n`.
          sparsity_ratio: Sparsity expected when comparing against the
            stripped model.
        """
        # Build a per-test copy instead of update()-ing the shared fixture.
        params = dict(self.params, sparsity_m_by_n=m_by_n)

        model = keras.Sequential()
        model.add(
            prune.prune_low_magnitude(layer_type(*layer_arg),
                                      input_shape=input_shape,
                                      **params))
        model.compile(loss='categorical_crossentropy',
                      optimizer='sgd',
                      metrics=['accuracy'])

        test_utils.assert_model_sparsity(self, 0.0, model)
        model.fit(np.random.randn(
            *self._batch(model.input.get_shape().as_list(), 32)),
                  np.random.randn(
                      *self._batch(model.output.get_shape().as_list(), 32)),
                  callbacks=[pruning_callbacks.UpdatePruningStep()])

        test_utils.assert_model_sparsity_m_by_n(self, model, m_by_n)
        self._check_strip_pruning_matches_original(model, sparsity_ratio)
Пример #20
0
def float_cnn_allPrune(name_,
                       Inputs,
                       nclasses,
                       filters,
                       kernel,
                       strides,
                       pooling,
                       dropout,
                       activation,
                       pruning_params=None):
    """Build a CNN whose Conv2D layers and hidden Dense layer are pruning-wrapped.

    Args:
        name_: Name given to the returned ``Model``.
        Inputs: Input tensor; it is also used as the model's input.
        nclasses: Units of the final softmax ``Dense`` layer.
        filters, kernel, strides, pooling, dropout: Per-convolution settings.
            All five sequences must have the same length. A pooling value of
            0 disables max-pooling for that convolution.
        activation: Activation applied after each hidden BatchNormalization.
        pruning_params: Keyword arguments forwarded to
            ``prune.prune_low_magnitude``. ``None`` means no extra arguments.

    Returns:
        The assembled ``Model``.

    Exits via ``sys.exit`` when the per-convolution sequences differ in length.
    """
    # None replaces the shared mutable ``{}`` default. The dict is only read
    # (splatted) below.
    if pruning_params is None:
        pruning_params = {}
    print("Building model: float_cnn")
    length = len(filters)
    if any(
            len(lst) != length
            for lst in [filters, kernel, strides, pooling, dropout]):
        sys.exit(
            "One value for kernel, stride, pooling and dropout must be added "
            "for each filter! Exiting")
    x = x_in = Inputs
    x = BatchNormalization()(x)
    x = ZeroPadding2D(padding=(1, 1), data_format="channels_last")(x)
    # NOTE(review): `dropout` is only length-checked above. No Dropout layer
    # is built from it; confirm whether that is intentional.
    for i, (f, k, s, p) in enumerate(zip(filters, kernel, strides, pooling)):
        print((
            "Adding layer with {} filters, kernel_size=({},{}), strides=({},{})"
        ).format(f, k, k, s, s))
        x = prune.prune_low_magnitude(Conv2D(int(f),
                                             kernel_size=(int(k), int(k)),
                                             strides=(int(s), int(s)),
                                             use_bias=False,
                                             name='conv_%i' % i),
                                      **pruning_params)(x)
        if float(p) != 0:
            x = MaxPooling2D(pool_size=(int(p), int(p)))(x)
        x = BatchNormalization()(x)
        x = Activation(activation, name='conv_act_%i' % i)(x)
    x = Flatten()(x)
    x = prune.prune_low_magnitude(Dense(128,
                                        kernel_initializer='lecun_uniform',
                                        use_bias=False,
                                        name='dense_1'), **pruning_params)(x)
    x = BatchNormalization()(x)
    x = Activation(activation, name='dense_act')(x)
    # The output layer is deliberately left unwrapped.
    x_out = Dense(nclasses, activation='softmax', name='output')(x)
    model = Model(inputs=[x_in], outputs=[x_out], name=name_)
    return model
Пример #21
0
    def testPruneMiscObject(self):
        """Pruning a plain `object()` raises ValueError with the 'object' message."""
        with self.assertRaises(ValueError) as ctx:
            prune.prune_low_magnitude(object(), **self.params)
        expected = self.INVALID_TO_PRUNE_PARAM_ERROR.format(input='object')
        self.assertEqual(str(ctx.exception), expected)
Пример #22
0
 def testPruneFunctionalModel(self):
     """Pruning a two-input functional model yields three pruned layers."""
     model_inputs = [keras.Input(shape=(10, )) for _ in range(2)]
     branches = [layers.Dense(10)(tensor) for tensor in model_inputs]
     merged = layers.Add()(branches)
     model = keras.Model(inputs=model_inputs, outputs=merged)
     pruned_model = prune.prune_low_magnitude(model, **self.params)
     self.assertEqual(self._count_pruned_layers(pruned_model), 3)
Пример #23
0
 def testPruneModelRecursively(self):
     """Layers of a nested Sequential are pruned along with top-level ones."""
     inner_dense = keras.layers.Dense(10, input_shape=(10, ))
     internal_model = keras.Sequential([inner_dense])
     outer_dense = layers.Dense(10)
     original_model = keras.Sequential([internal_model, outer_dense])
     pruned_model = prune.prune_low_magnitude(original_model, **self.params)
     self.assertEqual(self._count_pruned_layers(pruned_model), 2)
Пример #24
0
def build_layerwise_model(input_shape, **pruning_params):
	"""Build a Sequential CNN whose Conv2D and Dense layers are pruned one by one."""
	def prunable(layer, **extra):
		# `extra` carries wrapper-level keywords such as input_shape.
		return prune.prune_low_magnitude(layer, **extra, **pruning_params)

	stack = [
			prunable(
					l.Conv2D(32, 5, padding='same', activation='relu'),
					input_shape=input_shape),
			l.MaxPooling2D((2, 2), (2, 2), padding='same'),
			l.BatchNormalization(),
			prunable(l.Conv2D(64, 5, padding='same', activation='relu')),
			l.MaxPooling2D((2, 2), (2, 2), padding='same'),
			l.Flatten(),
			prunable(l.Dense(1024, activation='relu')),
			l.Dropout(0.4),
			prunable(l.Dense(num_classes, activation='softmax')),
	]
	return tf.keras.Sequential(stack)
Пример #25
0
 def testPruneFunctionalModelWithLayerReused(self):
     """A Dense layer applied twice in a functional model is pruned only once."""
     inp = keras.Input(shape=(10, ))
     shared_dense = layers.Dense(10)
     out = shared_dense(shared_dense(inp))
     model = keras.Model(inputs=[inp], outputs=[out])
     pruned_model = prune.prune_low_magnitude(model, **self.params)
     self.assertEqual(self._count_pruned_layers(pruned_model), 1)
    def testPrunePretrainedModel_SameInferenceWithoutTraining(self):
        """Wrapping a pretrained model for pruning leaves untrained predictions unchanged."""
        model = self._get_pretrained_model()
        pruned_model = prune.prune_low_magnitude(model, **self.params)

        input_data = np.random.rand(10, 10)

        out = model.predict(input_data)
        pruned_out = pruned_model.predict(input_data)

        # Unlike assertTrue((a == b).all()), this also checks that the shapes
        # match and reports which elements differ.
        np.testing.assert_array_equal(out, pruned_out)
Пример #27
0
    def testStripPruningSequentialModel(self):
        """strip_pruning removes every wrapper and restores the original config."""
        original = keras.Sequential([layers.Dense(10), layers.Dense(10)])
        stripped = prune.strip_pruning(
            prune.prune_low_magnitude(original, **self.params))
        self.assertEqual(self._count_pruned_layers(stripped), 0)
        self.assertEqual(original.get_config(), stripped.get_config())
Пример #28
0
    def testPruneValidLayersListSuccessful(self):
        """Pruning a list of supported layers returns one result per input layer."""
        inputs = [
            self.keras_prunable_layer,
            self.keras_non_prunable_layer,
            self.custom_prunable_layer,
        ]
        outputs = prune.prune_low_magnitude(inputs, **self.params)

        expected_count = len(inputs)
        self.assertEqual(expected_count, len(outputs))
        for source_layer, result_layer in zip(inputs, outputs):
            self._validate_pruned_layer(source_layer, result_layer)
    def testPruneWithHighSparsity_Fails(self):
        """Training with a constant 0.99 target sparsity raises InvalidArgumentError."""
        # Copy so overriding the schedule does not mutate the shared fixture.
        params = dict(self.params)
        params['pruning_schedule'] = pruning_schedule.ConstantSparsity(
            target_sparsity=0.99, begin_step=0, frequency=1)

        model = prune.prune_low_magnitude(
            keras_test_utils.build_simple_dense_model(), **params)

        with self.assertRaises(tf.errors.InvalidArgumentError):
            self._train_model(model, epochs=1)
Пример #30
0
    def testPruneModelValidLayersSuccessful(self):
        """Pruning a Sequential of supported layers keeps one layer per original."""
        source_model = keras.Sequential([
            self.keras_prunable_layer,
            self.keras_non_prunable_layer,
            self.custom_prunable_layer,
        ])
        result_model = prune.prune_low_magnitude(source_model, **self.params)
        result_model.build(input_shape=(1, 28, 28, 1))

        source_layers, result_layers = source_model.layers, result_model.layers
        self.assertEqual(len(source_layers), len(result_layers))
        for before, after in zip(source_layers, result_layers):
            self._validate_pruned_layer(before, after)