def testConditionalMaskUpdate(self):
  """Checks that conditional mask updates follow the pruning schedule.

  Pruning is configured to start at step 1, end at step 6 and fire every 2
  steps, while the sparsity variable is stepped through 0.0, 0.1, ..., 0.9
  (one value per global step). The number of surviving weights is recorded
  after every step and compared against the expected schedule.
  """
  param_list = [
      "pruning_frequency=2", "begin_pruning_step=1", "end_pruning_step=6",
      "nbins=100"
  ]
  test_spec = ",".join(param_list)
  pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
  weights = tf.Variable(tf.linspace(1.0, 100.0, 100), name="weights")
  masked_weights = pruning.apply_mask(weights)
  sparsity = tf.Variable(0.00, name="sparsity")
  # Set up pruning
  p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
  p._spec.threshold_decay = 0.0
  mask_update_op = p.conditional_mask_update_op()
  sparsity_val = tf.linspace(0.0, 0.9, 10)
  # Build one assign op per step up front. Creating ops inside the session
  # loop (as tf.assign / tensor slicing would) keeps growing the graph on
  # every iteration.
  set_sparsity_ops = [
      tf.assign(sparsity, sparsity_val[i]) for i in range(10)
  ]
  increment_global_step = tf.assign_add(self.global_step, 1)
  non_zero_count = []
  with self.cached_session() as session:
    tf.global_variables_initializer().run()
    for set_sparsity in set_sparsity_ops:
      session.run(set_sparsity)
      session.run(mask_update_op)
      session.run(increment_global_step)
      non_zero_count.append(np.count_nonzero(masked_weights.eval()))
  # The surviving-weight count only changes at steps 2, 4 and 6
  # (pruning_frequency=2) and stays fixed once end_pruning_step=6 has passed.
  expected_non_zero_count = [100, 100, 80, 80, 60, 60, 40, 40, 40, 40]
  self.assertAllEqual(expected_non_zero_count, non_zero_count)
def testWeightSpecificSparsity(self):
  """Per-weight entries in weight_sparsity_map override target_sparsity.

  Three masked weights are created. After the schedule has run past
  end_pruning_step, each weight's sparsity must equal the map entry its
  name matches (layer1 -> 0.6, layer2/weights -> 0.75, .*kernel -> 0.6)
  rather than the global target_sparsity of 0.5.
  """
  test_spec = ",".join([
      "begin_pruning_step=1",
      "pruning_frequency=1",
      "end_pruning_step=100",
      "target_sparsity=0.5",
      "weight_sparsity_map=[layer1:0.6,layer2/weights:0.75,.*kernel:0.6]",
      "threshold_decay=0.0",
  ])
  pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

  # (scope, variable name) for each masked weight, in creation order; the
  # order determines the order of pruning.get_weight_sparsity() below.
  layer_specs = (("layer1", "weights"), ("layer2", "weights"),
                 ("layer3", "kernel"))
  for scope_name, var_name in layer_specs:
    with tf.variable_scope(scope_name):
      layer_weights = tf.Variable(
          tf.linspace(1.0, 100.0, 100), name=var_name)
      pruning.apply_mask(layer_weights)

  p = pruning.Pruning(pruning_hparams)
  mask_update_op = p.conditional_mask_update_op()
  increment_global_step = tf.assign_add(self.global_step, 1)
  with self.cached_session() as session:
    tf.global_variables_initializer().run()
    # Run past end_pruning_step=100 so the final sparsities are reached.
    for _ in range(110):
      session.run(mask_update_op)
      session.run(increment_global_step)
    self.assertAllClose(session.run(pruning.get_weight_sparsity()),
                        [0.6, 0.75, 0.6])
def CreateDenseCoordinates(self, ranges):
  """Create a matrix of coordinate locations corresponding to a dense grid.

  Example: To create (x, y) coordinates over a 10x10 grid with step size 1,
  call ``CreateDenseCoordinates([(1, 10, 10), (1, 10, 10)])``.

  Args:
    ranges: A list of 3-tuples, each tuple is expected to contain (min, max,
      num_steps). Each list element corresponds to one dimension. Each tuple
      is passed into tf.linspace to create the values for a single
      dimension. num_steps must be a positive Python/NumPy integer.

  Returns:
    tf.float32 tensor of shape [total_points, len(ranges)], where
    total_points = product of all num_steps. Rows enumerate the grid in
    row-major order: the last dimension varies fastest, the first slowest.

  Raises:
    ValueError: if any num_steps is not positive.
  """
  steps_per_dim = [r_steps for _, _, r_steps in ranges]
  if any(r_steps <= 0 for r_steps in steps_per_dim):
    # Without this check a zero step count surfaces as an opaque
    # ZeroDivisionError in the `cycle_steps //= r_steps` line below.
    raise ValueError(
        "num_steps must be positive for every range, got %s" % steps_per_dim)
  total_points = int(np.prod(steps_per_dim))
  # Number of consecutive rows sharing the same value in the current
  # dimension; shrinks as we move to faster-varying dimensions.
  cycle_steps = total_points
  stack_coordinates = []
  for r_start, r_stop, r_steps in ranges:
    values = tf.linspace(
        tf.cast(r_start, tf.float32), tf.cast(r_stop, tf.float32),
        tf.cast(r_steps, tf.int32))
    cycle_steps //= r_steps
    gather_idx = (tf.range(total_points) // cycle_steps) % r_steps
    stack_coordinates.append(tf.gather(values, gather_idx))
  return tf.stack(stack_coordinates, axis=1)
def testUpdateSingleMask(self):
  """A single mask update at sparsity 0.95 leaves 5 of the 100 weights."""
  with self.cached_session() as session:
    weights = tf.Variable(tf.linspace(1.0, 100.0, 100), name="weights")
    masked_weights = pruning.apply_mask(weights)
    sparsity = tf.Variable(0.95, name="sparsity")
    pruning_obj = pruning.Pruning(sparsity=sparsity)
    pruning_obj._spec.threshold_decay = 0.0
    mask_update_op = pruning_obj.mask_update_op()
    tf.global_variables_initializer().run()

    # Before any update no weight is masked out.
    self.assertAllEqual(np.count_nonzero(masked_weights.eval()), 100)

    session.run(mask_update_op)
    self.assertAllEqual(np.count_nonzero(masked_weights.eval()), 5)
def testPartitionedVariableMasking(self):
  """Mask updates work on a weight created under a variable partitioner.

  The 100-element weight is created with
  tf.variable_axis_size_partitioner(40); one mask update at sparsity 0.5
  must leave 50 non-zero weights.
  """
  partitioner = tf.variable_axis_size_partitioner(40)
  with self.cached_session() as session:
    with tf.variable_scope("", partitioner=partitioner):
      sparsity = tf.Variable(0.5, name="Sparsity")
      weights = tf.get_variable(
          "weights", initializer=tf.linspace(1.0, 100.0, 100))
      masked_weights = pruning.apply_mask(
          weights, scope=tf.get_variable_scope())
    p = pruning.Pruning(sparsity=sparsity)
    p._spec.threshold_decay = 0.0
    mask_update_op = p.mask_update_op()
    tf.global_variables_initializer().run()
    # The pre-update value was previously evaluated but never checked;
    # assert it, mirroring testUpdateSingleMask: nothing is pruned yet.
    self.assertAllEqual(np.count_nonzero(masked_weights.eval()), 100)
    session.run(mask_update_op)
    self.assertAllEqual(np.count_nonzero(masked_weights.eval()), 50)