예제 #1
0
  def testSparsityDictErdosRenyiSparsitiesScale(
      self, shape1, shape2, default_sparsity):
    """Checks the budget and per-mask ratios of 'erdos_renyi' sparsities.

    Two mask variables of shapes `shape1` and `shape2` are passed to
    `sparse_utils.get_sparsities` with no custom overrides. The test asserts:
      * the total number of zeros implied by the returned sparsities is
        within 2 of the total implied by applying `default_sparsity` to both
        masks, and
      * (1 - sparsity) / ((d[-1] + d[-2]) / (d[-1] * d[-2])), computed from
        each mask's shape `d`, is (almost) equal for both masks.
    """
    _ = self._setup_session()
    shapes = (shape1, shape2)
    masks = [tf.get_variable(shape=shape, name=name)
             for shape, name in zip(shapes, ('var1/mask', 'var2/mask'))]
    sparsities = sparse_utils.get_sparsities(
        masks, 'erdos_renyi', default_sparsity, {})
    er_sparsities = [sparsities[m.name] for m in masks]
    sizes = [np.prod(shape) for shape in shapes]

    # Zeros removed if both masks used `default_sparsity` (uniform budget).
    uniform_zeros = sum(
        sparse_utils.get_n_zeros(size, default_sparsity) for size in sizes)
    # Zeros removed under the per-mask sparsities returned above.
    er_zeros = sum(
        sparse_utils.get_n_zeros(size, sparsity)
        for size, sparsity in zip(sizes, er_sparsities))
    # Each get_n_zeros call may round, so the totals can drift apart by a
    # little; allow at most 2 connections of slack.
    self.assertLessEqual(abs(uniform_zeros - er_zeros), 2)

    def _er_factor(shape):
      # Erdos-Renyi scale over the last two dims: (a + b) / (a * b).
      return (shape[-1] + shape[-2]) / float(shape[-1] * shape[-2])

    # Density divided by the ER factor must agree across masks, i.e. the
    # densities keep the Erdos-Renyi proportions.
    density_ratio_1, density_ratio_2 = [
        (1 - sparsity) / _er_factor(shape)
        for sparsity, shape in zip(er_sparsities, shapes)]
    self.assertAlmostEqual(density_ratio_1, density_ratio_2)
예제 #2
0
        def dnw_fn(mask, sparsity, dtype):
            """Creates a mask keeping the largest-magnitude weights.

            The weight variable paired with `mask` is found in `vars_dict`
            under the name returned by `sparse_utils.mask_extract_name_fn`.
            Its entries are ranked by absolute value (largest first); the top
            `n_total - sparse_utils.get_n_zeros(n_total, sparsity)` positions
            are set to one and all others to zero. The number of zeros is
            therefore fixed by `sparsity` alone.

            Args:
              mask: tf.Tensor, the current mask; its name selects the weight
                variable, and its dtype and shape define those of the result.
              sparsity: float, between 0 and 1.
              dtype: tf.dtype, ignored; the result always uses `mask.dtype`.

            Returns:
              tf.Tensor with the same shape and dtype as `mask`.
            """
            del dtype  # Unused: the new mask takes mask.dtype instead.
            var_name = sparse_utils.mask_extract_name_fn(mask.name)
            v = vars_dict[var_name]
            score_drop = math_ops.abs(v)
            n_total = np.prod(score_drop.shape.as_list())
            n_prune = sparse_utils.get_n_zeros(n_total, sparsity)
            n_keep = n_total - n_prune

            # Sort the entire array since the k needs to be constant for TPU.
            _, sorted_indices = nn_ops.top_k(
                array_ops.reshape(score_drop, [-1]), k=n_total)
            sorted_indices_ex = array_ops.expand_dims(sorted_indices, 1)
            # Ranks [0, n_keep) (largest magnitudes) get one; the rest zero.
            new_values = array_ops.where(
                math_ops.range(n_total) < n_keep,
                array_ops.ones_like(sorted_indices, dtype=mask.dtype),
                array_ops.zeros_like(sorted_indices, dtype=mask.dtype))
            # Scatter each rank's value back to its original flat position.
            new_mask = array_ops.scatter_nd(sorted_indices_ex, new_values,
                                            new_values.shape)
            return array_ops.reshape(new_mask, mask.shape)
예제 #3
0
 def testDNWSparsity(self, n_inp, n_out, default_sparsity):
   """Checks the mask's zero count matches `default_sparsity` after one step."""
   # NOTE(review): 'random' and the empty dict are forwarded to _setup_graph
   # unchanged; their exact meaning is defined there — confirm.
   sess, train_op, _, mask, _ = self._setup_graph(
       default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)
   sess.run([train_op])
   (mask_value,) = sess.run([mask])
   zero_count = mask_value.size - np.sum(mask_value)
   expected_zero_count = sparse_utils.get_n_zeros(mask_value.size,
                                                  default_sparsity)
   self.assertEqual(zero_count, expected_zero_count)
예제 #4
0
 def testSparsityAfterDNWUpdates(self, n_inp, n_out, default_sparsity):
   """Checks the mask keeps the expected zero count over repeated steps."""
   sess, train_op, _, mask, _ = self._setup_graph(
       default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)
   # After each of five training steps, the mask must still contain exactly
   # the number of zeros implied by `default_sparsity`.
   for _step in range(5):
     sess.run([train_op])
     (mask_value,) = sess.run([mask])
     zero_count = mask_value.size - np.sum(mask_value)
     self.assertEqual(
         zero_count,
         sparse_utils.get_n_zeros(mask_value.size, default_sparsity))