def testSparsityDictErdosRenyiSparsitiesScale(
    self, shape1, shape2, default_sparsity):
  _ = self._setup_session()
  all_masks = [tf.get_variable(shape=shape1, name='var1/mask'),
               tf.get_variable(shape=shape2, name='var2/mask')]
  custom_sparsity = {}
  sparsities = sparse_utils.get_sparsities(
      all_masks, 'erdos_renyi', default_sparsity, custom_sparsity)
  sparsity1 = sparsities[all_masks[0].name]
  size1 = np.prod(shape1)
  sparsity2 = sparsities[all_masks[1].name]
  size2 = np.prod(shape2)
  # Ensure that the total number of connections is similar to the uniform
  # baseline.
  expected_zeros_uniform = (
      sparse_utils.get_n_zeros(size1, default_sparsity) +
      sparse_utils.get_n_zeros(size2, default_sparsity))
  expected_zeros_current = (
      sparse_utils.get_n_zeros(size1, sparsity1) +
      sparse_utils.get_n_zeros(size2, sparsity2))
  # Due to rounding we can have some difference. This is expected, but it
  # should be less than the number of rounding operations we make.
  diff = abs(expected_zeros_uniform - expected_zeros_current)
  tolerance = 2
  self.assertLessEqual(diff, tolerance)
  # Ensure that the Erdos-Renyi proportions are preserved.
  factor1 = (shape1[-1] + shape1[-2]) / float(shape1[-1] * shape1[-2])
  factor2 = (shape2[-1] + shape2[-2]) / float(shape2[-1] * shape2[-2])
  self.assertAlmostEqual((1 - sparsity1) / factor1,
                         (1 - sparsity2) / factor2)
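

# A minimal sketch (hypothetical helper, not part of sparse_utils) of the
# Erdos-Renyi scaling the test above verifies: per-layer density is
# proportional to (fan_in + fan_out) / (fan_in * fan_out), with one global
# scale `epsilon` chosen so the total number of kept connections matches the
# uniform `default_sparsity` baseline. As a result, (1 - sparsity_i) /
# factor_i is the same constant for every layer, which is exactly what the
# assertAlmostEqual check above asserts.
def _erdos_renyi_sparsities_sketch(shapes, default_sparsity):
  sizes = [np.prod(shape) for shape in shapes]
  factors = [(shape[-1] + shape[-2]) / float(shape[-1] * shape[-2])
             for shape in shapes]
  # Solve sum(epsilon * factor_i * size_i) == (1 - default_sparsity) * total.
  n_keep_total = (1 - default_sparsity) * sum(sizes)
  epsilon = n_keep_total / sum(f * n for f, n in zip(factors, sizes))
  return [1.0 - epsilon * f for f in factors]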
def dnw_fn(mask, sparsity, dtype):
  """Creates a mask that drops the smallest magnitudes at the given sparsity.

  Args:
    mask: tf.Tensor, used to obtain the correct corresponding gradient.
    sparsity: float, between 0 and 1.
    dtype: tf.dtype, type of the return value.

  Returns:
    tf.Tensor
  """
  del dtype
  var_name = sparse_utils.mask_extract_name_fn(mask.name)
  v = vars_dict[var_name]
  score_drop = math_ops.abs(v)
  n_total = np.prod(score_drop.shape.as_list())
  n_prune = sparse_utils.get_n_zeros(n_total, sparsity)
  n_keep = n_total - n_prune
  # Sort the entire array, since k needs to be constant for TPU.
  _, sorted_indices = nn_ops.top_k(
      array_ops.reshape(score_drop, [-1]), k=n_total)
  sorted_indices_ex = array_ops.expand_dims(sorted_indices, 1)
  # Indices are sorted by descending magnitude, so the first `n_keep`
  # positions get ones and the remaining positions get zeros.
  new_values = array_ops.where(
      math_ops.range(n_total) < n_keep,
      array_ops.ones_like(sorted_indices, dtype=mask.dtype),
      array_ops.zeros_like(sorted_indices, dtype=mask.dtype))
  new_mask = array_ops.scatter_nd(sorted_indices_ex, new_values,
                                  new_values.shape)
  return array_ops.reshape(new_mask, mask.shape)
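

# A NumPy sketch (hypothetical helper, not library code) of the same
# deterministic top-k masking as dnw_fn above: ones at the `n_keep`
# largest-magnitude positions, zeros everywhere else. The np.ceil rounding
# is an assumption about how sparse_utils.get_n_zeros rounds.
def _top_k_mask_sketch(weights, sparsity):
  scores = np.abs(weights).reshape(-1)
  n_total = scores.size
  n_prune = int(np.ceil(sparsity * n_total))  # Assumed get_n_zeros rounding.
  n_keep = n_total - n_prune
  sorted_indices = np.argsort(-scores)  # Descending, like top_k, k=n_total.
  new_mask = np.zeros(n_total, dtype=weights.dtype)
  new_mask[sorted_indices[:n_keep]] = 1
  return new_mask.reshape(weights.shape)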
def testDNWSparsity(self, n_inp, n_out, default_sparsity):
  """Checking whether masked_grad is calculated after apply_gradients."""
  # No drop since we don't want to change the mask but check whether the grad
  # is calculated after the gradient step.
  sess, train_op, _, mask, _ = self._setup_graph(
      default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)
  _ = sess.run([train_op])
  dnw_mask, = sess.run([mask])
  n_ones = np.sum(dnw_mask)
  n_zeros = dnw_mask.size - n_ones
  n_zeros_expected = sparse_utils.get_n_zeros(dnw_mask.size, default_sparsity)
  self.assertEqual(n_zeros, n_zeros_expected)
def testSparsityAfterDNWUpdates(self, n_inp, n_out, default_sparsity):
  """Checking whether the mask is updated correctly."""
  sess, train_op, _, mask, _ = self._setup_graph(
      default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)
  # On every iteration the mask should drop the least-magnitude connections,
  # keeping exactly the default sparsity.
  for _ in range(5):
    sess.run([train_op])
    dnw_mask, = sess.run([mask])
    n_ones = np.sum(dnw_mask)
    n_zeros = dnw_mask.size - n_ones
    n_zeros_expected = sparse_utils.get_n_zeros(dnw_mask.size,
                                                default_sparsity)
    self.assertEqual(n_zeros, n_zeros_expected)
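

# A tiny worked illustration (numbers hypothetical) of the invariant both
# DNW tests above assert: for a 10x20 mask at default_sparsity=0.5,
# get_n_zeros(200, 0.5) == 100, so after every training step the mask must
# contain exactly 100 zeros, regardless of which connections changed.
def _assert_exact_sparsity_sketch(mask_values, sparsity):
  # Assumed ceil rounding, mirroring sparse_utils.get_n_zeros.
  n_zeros_expected = int(np.ceil(sparsity * mask_values.size))
  n_zeros = mask_values.size - np.sum(mask_values)
  assert n_zeros == n_zeros_expected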