def testConditionalMaskUpdate(self):
        """Checks that the mask tightens only when the schedule's target rises.

        The schedule always requests an update and reads its sparsity target
        from a fixed ten-entry table indexed by the training step. After each
        step the surviving (non-zero) weights are counted and compared to the
        expected sequence.
        """
        weight = K.variable(np.linspace(1.0, 100.0, 100), name="weights")
        mask = K.ones(weight.get_shape())
        threshold = K.zeros([])

        def linear_sparsity(step):
            # Per-step sparsity targets; the leading bool asks for an update.
            targets = ops.convert_to_tensor(
                [0.0, 0.1, 0.1, 0.3, 0.3, 0.5, 0.5, 0.5, 0.5, 0.5])
            return ops.convert_to_tensor(True), targets[step]

        pruning = pruning_impl.Pruning(
            pruning_vars=[(weight, mask, threshold)],
            training_step_fn=self.training_step_fn,
            pruning_schedule=linear_sparsity,
            block_size=self.block_size,
            block_pooling_type=self.block_pooling_type)

        eager = context.executing_eagerly()

        def execute(make_op):
            # Eagerly, calling make_op() already performs the work; in graph
            # mode it only builds an op, which must then be run in the session.
            if eager:
                make_op()
            else:
                K.get_session().run(make_op())

        observed = []
        for _ in range(10):
            execute(pruning.conditional_mask_update)
            execute(pruning.weight_mask_op)
            execute(lambda: state_ops.assign_add(self.global_step, 1))
            observed.append(np.count_nonzero(K.get_value(weight)))

        # The target rises at steps 1, 3 and 5, so pruning happens there.
        self.assertAllEqual([100, 90, 90, 70, 70, 50, 50, 50, 50, 50],
                            observed)
    def testExtremelySparseMask(self):
        """Checks that a near-total sparsity target still leaves one weight.

        All 100 mask entries start at one. The schedule is a ConstantSparsity
        with a 0.9999 target (the remaining arguments are presumably
        begin step, end step and frequency; confirm). After one mask update,
        exactly one entry must remain non-zero.
        """
        weight = tf.Variable(np.linspace(1.0, 100.0, 100), name="weights")
        dtype = weight.dtype.base_dtype
        # Mask and threshold share the weight's base dtype.
        mask = tf.Variable(tf.ones(weight.get_shape(), dtype=dtype),
                           name="mask", dtype=dtype)
        threshold = tf.Variable(tf.zeros([], dtype=dtype),
                                name="threshold", dtype=dtype)
        self.initialize()

        schedule = pruning_schedule.ConstantSparsity(0.9999, 0, 100, 1)
        pruning = pruning_impl.Pruning(
            pruning_vars=[(weight, mask, threshold)],
            training_step_fn=self.training_step_fn,
            pruning_schedule=schedule,
            block_size=self.block_size,
            block_pooling_type=self.block_pooling_type)

        self.assertAllEqual(np.count_nonzero(K.get_value(mask)), 100)

        if not tf.executing_eagerly():
            K.get_session().run(pruning.conditional_mask_update())
        else:
            pruning.conditional_mask_update()

        # A single connection must always survive.
        self.assertAllEqual(np.count_nonzero(K.get_value(mask)), 1)
# Example #3
# 0
    def build(self, input_shape):
        """Create the pruning state for each prunable weight of the wrapped layer.

        For every weight returned by ``self.layer.get_prunable_weights()``,
        this adds two non-trainable variables:

        * ``mask``: same shape and dtype as the weight, initialised to ones.
        * ``threshold``: a scalar of the weight's dtype, initialised to zero.

        Both use MEAN aggregation. It also adds a scalar int64
        ``pruning_step`` counter initialised to -1, then builds the
        ``pruning_impl.Pruning`` object that reads that counter as its
        training step.

        Args:
          input_shape: Forwarded unchanged to the base-class ``build``.
        """
        super(PruneLowMagnitude, self).build(input_shape)

        self.prunable_weights = self.layer.get_prunable_weights()

        # One (weight, mask, threshold) triple per prunable weight.
        pruning_vars = []
        for weight in self.prunable_weights:
            # `add_variable` is a deprecated alias that forwards to
            # `add_weight`. Call `add_weight` directly, and pass `name` by
            # keyword so it is never taken as a different positional
            # parameter.
            mask = self.add_weight(
                name='mask',
                shape=weight.shape,
                initializer=tf.keras.initializers.get('ones'),
                dtype=weight.dtype,
                trainable=False,
                aggregation=tf.VariableAggregation.MEAN)
            threshold = self.add_weight(
                name='threshold',
                shape=[],
                initializer=tf.keras.initializers.get('zeros'),
                dtype=weight.dtype,
                trainable=False,
                aggregation=tf.VariableAggregation.MEAN)
            pruning_vars.append((weight, mask, threshold))
        self.pruning_vars = pruning_vars

        # Scalar tracking the number of updates to the wrapped layer. It
        # starts at -1, presumably so that the first increment yields step 0
        # (confirm against the code that increments it).
        self.pruning_step = self.add_weight(
            name='pruning_step',
            shape=[],
            initializer=tf.keras.initializers.Constant(-1),
            dtype=tf.int64,
            trainable=False)

        def training_step_fn():
            return self.pruning_step

        # Create the pruning object that drives mask updates.
        self.pruning_obj = pruning_impl.Pruning(
            training_step_fn=training_step_fn,
            pruning_vars=self.pruning_vars,
            pruning_schedule=self.pruning_schedule,
            block_size=self.block_size,
            sparsity_m_by_n=self.sparsity_m_by_n,
            block_pooling_type=self.block_pooling_type)
    def _blockMasking(self, block_size, block_pooling_type, weight,
                      expected_mask):
        """Prune ``weight`` once with the given block settings and check the mask.

        A fresh all-ones mask and a zero scalar threshold are built.
        ``Pruning._maybe_update_block_mask`` is then called under
        ``self.constant_sparsity``.

        Args:
          block_size: Block shape passed to ``pruning_impl.Pruning``.
          block_pooling_type: Pooling type passed to ``pruning_impl.Pruning``.
          weight: Variable to prune.
          expected_mask: Values the resulting mask must equal.
        """
        pruning = pruning_impl.Pruning(
            pruning_vars=[(weight, K.ones(weight.get_shape()), K.zeros([]))],
            training_step_fn=self.training_step_fn,
            pruning_schedule=self.constant_sparsity,
            block_size=block_size,
            block_pooling_type=block_pooling_type)

        _threshold, block_mask = pruning._maybe_update_block_mask(weight)
        # The mask must mirror the weight's shape exactly.
        self.assertAllEqual(block_mask.get_shape(), weight.get_shape())
        self.assertAllEqual(K.get_value(block_mask), expected_mask)
    def testConstructsMaskAndThresholdCorrectly(self):
        """Checks ``_update_mask`` on the vector 1..10 with a ~0.2 target.

        The two smallest entries must be masked out, giving a threshold of 3
        and a mask of two zeros followed by eight ones.
        """
        def sparsity_schedule(step):
            # Sparsity math often returns values with small tolerances.
            return True, 0.200000018

        # Positional arguments are presumably (training_step_fn, pruning_vars,
        # pruning_schedule, block_size, block_pooling_type); confirm against
        # Pruning.__init__.
        pruning = pruning_impl.Pruning(
            lambda: 0, None, sparsity_schedule, (1, 1), None)

        # The input is the integer vector [1, 2, ..., 10].
        threshold, mask = pruning._update_mask(np.arange(1, 11))

        self.assertEqual(3, K.get_value(threshold))
        # Expected mask: [0, 0, 1, 1, 1, 1, 1, 1, 1, 1].
        expected = np.concatenate((np.zeros(2), np.ones(8)))
        self.assertAllEqual(expected, K.get_value(mask))
    def _sparsity_m_by_n_masking(self, weight, m_by_n=(2, 4)):
        """Compute an m-by-n structured-sparsity mask for ``weight``.

        Args:
          weight: Variable to prune.
          m_by_n: Forwarded as ``sparsity_m_by_n`` to ``pruning_impl.Pruning``.
            This is presumably "keep m of every n"; confirm against Pruning.

        Returns:
          The mask produced by ``Pruning._maybe_update_block_mask``.
        """
        mask_var = tf.Variable(tf.ones(weight.get_shape()), name="mask")
        # NOTE(review): this is an integer threshold, unlike the mask. It is
        # presumably unused on the m-by-n path; confirm.
        threshold_var = tf.Variable(1, name="threshold")
        self.initialize()

        pruning = pruning_impl.Pruning(
            pruning_vars=[(weight, mask_var, threshold_var)],
            training_step_fn=self.training_step_fn,
            pruning_schedule=self.constant_sparsity,
            block_size=(1, 1),
            block_pooling_type="AVG",
            sparsity_m_by_n=m_by_n,
        )

        _threshold, m_by_n_mask = pruning._maybe_update_block_mask(weight)
        return m_by_n_mask
    def testUpdateSingleMask(self):
        """Checks that one update under ``self.constant_sparsity`` prunes half.

        All 100 mask entries start at one. After a single
        ``conditional_mask_update``, exactly 50 must remain non-zero.
        """
        weight = K.variable(np.linspace(1.0, 100.0, 100), name="weights")
        mask = K.ones(weight.get_shape())
        threshold = K.zeros([])

        pruning = pruning_impl.Pruning(
            pruning_vars=[(weight, mask, threshold)],
            training_step_fn=self.training_step_fn,
            pruning_schedule=self.constant_sparsity,
            block_size=self.block_size,
            block_pooling_type=self.block_pooling_type)

        self.assertAllEqual(np.count_nonzero(K.get_value(mask)), 100)

        # In graph mode the call only builds an op, so run it in the session.
        if not context.executing_eagerly():
            K.get_session().run(pruning.conditional_mask_update())
        else:
            pruning.conditional_mask_update()

        self.assertAllEqual(np.count_nonzero(K.get_value(mask)), 50)