def testSparsityDictErdosRenyiSparsitiesScale(
    self, shape1, shape2, default_sparsity):
  """Checks how 'erdos_renyi' sparsities relate to the uniform default.

  Two mask variables of shapes `shape1` and `shape2` are created and their
  sparsities are requested with the 'erdos_renyi' method and no custom
  overrides. The test then verifies that:
    * the total zero count implied by the returned sparsities is within 2 of
      the total implied by applying `default_sparsity` to both masks, and
    * `(1 - sparsity) / factor` matches across the two masks, where `factor`
      is `(d_last + d_prev) / (d_last * d_prev)` over the last two dimensions
      of each shape.

  Args:
    self: Test case instance; provides `_setup_session` and the assertions.
    shape1: Shape of the first mask; indexed at [-1] and [-2].
    shape2: Shape of the second mask; indexed at [-1] and [-2].
    default_sparsity: Sparsity passed to `get_sparsities` as the default.
  """

  def scale_factor(shape):
    # Sum of the last two dimensions divided by their product.
    last_dim, prev_dim = shape[-1], shape[-2]
    return (last_dim + prev_dim) / float(last_dim * prev_dim)

  _ = self._setup_session()
  first_mask = tf.get_variable(shape=shape1, name='var1/mask')
  second_mask = tf.get_variable(shape=shape2, name='var2/mask')
  masks = [first_mask, second_mask]
  sparsity_by_name = sparse_utils.get_sparsities(
      masks, 'erdos_renyi', default_sparsity, {})

  layer_sizes = [np.prod(shape1), np.prod(shape2)]
  layer_sparsities = [sparsity_by_name[mask.name] for mask in masks]

  # Zero count if every layer simply used the default sparsity.
  zeros_uniform = sum(
      sparse_utils.get_n_zeros(size, default_sparsity)
      for size in layer_sizes)
  # Zero count implied by the sparsities that were actually returned.
  zeros_returned = sum(
      sparse_utils.get_n_zeros(size, sparsity)
      for size, sparsity in zip(layer_sizes, layer_sparsities))

  # Rounding may introduce a small mismatch between the two totals; the
  # allowed gap is 2 (two masks are involved).
  max_gap = 2
  self.assertLessEqual(abs(zeros_uniform - zeros_returned), max_gap)

  # The density of each layer, divided by its shape factor, must agree.
  first_sparsity, second_sparsity = layer_sparsities
  self.assertAlmostEqual(
      (1 - first_sparsity) / scale_factor(shape1),
      (1 - second_sparsity) / scale_factor(shape2))
def testSparsityDictErdosRenyiError(self, default_sparsity):
  """Checks that a custom sparsity is returned as-is under 'erdos_renyi'.

  Three mask variables are created and 'var3' is given a custom sparsity of
  0.8; the value returned for the third mask must equal 0.8.

  NOTE(review): despite the "Error" in the name, this body does not assert
  that anything raises -- confirm whether that is intentional.

  Args:
    self: Test case instance; provides `_setup_session` and the assertions.
    default_sparsity: Sparsity passed to `get_sparsities` as the default.
  """
  _ = self._setup_session()
  mask_specs = (
      ((2, 4), 'var1/mask'),
      ((2, 3), 'var2/mask'),
      ((1, 1, 3), 'var3/mask'),
  )
  masks = [
      tf.get_variable(shape=mask_shape, name=mask_name)
      for mask_shape, mask_name in mask_specs
  ]
  overrides = {'var3': 0.8}
  sparsity_by_name = sparse_utils.get_sparsities(
      masks, 'erdos_renyi', default_sparsity, overrides)
  # The overridden variable is the third (last) mask in the list.
  overridden_mask = masks[-1]
  self.assertEqual(sparsity_by_name[overridden_mask.name], 0.8)
def testSparsityDictRandom(self, default_sparsity):
  """Checks 'random' sparsities: custom value for var1, default elsewhere.

  Three mask variables are created and 'var1' is given a custom sparsity of
  0.8. The first mask must report 0.8 and the remaining two must report
  `default_sparsity`.

  Args:
    self: Test case instance; provides `_setup_session` and the assertions.
    default_sparsity: Sparsity passed to `get_sparsities` as the default.
  """
  _ = self._setup_session()
  mask_specs = (
      ((2, 3), 'var1/mask'),
      ((2, 3), 'var2/mask'),
      ((1, 1, 3), 'var3/mask'),
  )
  masks = [
      tf.get_variable(shape=mask_shape, name=mask_name)
      for mask_shape, mask_name in mask_specs
  ]
  overrides = {'var1': 0.8}
  sparsity_by_name = sparse_utils.get_sparsities(
      masks, 'random', default_sparsity, overrides)

  overridden_mask, *default_masks = masks
  self.assertEqual(sparsity_by_name[overridden_mask.name], 0.8)
  # Every mask without an override falls back to the default sparsity.
  for mask in default_masks:
    self.assertEqual(sparsity_by_name[mask.name], default_sparsity)