def testWrapper_CreatesProperCompressorOption1(self, low_rank_mock):
  """Checks the low-rank option wires its compressor into ApplyCompression.

  Args:
    low_rank_mock: Mock injected by a patch decorator outside this view;
      presumably it replaces the low-rank compressor class -- confirm against
      the decorator.
  """
  # Use the named option rather than the bare literal 1, matching how the
  # other option tests build their specs.
  hparams = self._create_compression_op_spec(
      comp_op_utils.CompressionOptions.LOWRANK_MATRIX_COMPRESSION)
  mock_compressor = MatrixCompressorInterfaceMock(
      self._default_compressor_spec(hparams))
  # A list side_effect makes the mock return mock_compressor on its first call.
  low_rank_mock.side_effect = [mock_compressor]
  # Patch ApplyCompression where get_apply_compression looks it up
  # (compression_wrapper), as the SimHash test does. If compression_wrapper
  # binds the name itself, patching comp_op.ApplyCompression would not
  # intercept the call.
  with mock.patch.object(compression_wrapper,
                         'ApplyCompression') as apply_mock:
    compression_wrapper.get_apply_compression(hparams, _GLOBAL_STEP)
    apply_mock.assert_called_with(
        scope='default_scope',
        compression_spec=hparams,
        compressor=mock_compressor,
        global_step=_GLOBAL_STEP)
def testWrapper_CreatesProperCompressorOption2(self, sim_hash_mock):
  """The SimHash option should hand its compressor to ApplyCompression.

  Args:
    sim_hash_mock: Mock injected by a patch decorator outside this view;
      presumably it stands in for the SimHash compressor class -- confirm.
  """
  simhash_option = comp_op_utils.CompressionOptions.SIMHASH_MATRIX_COMPRESSION
  spec = self._create_compression_op_spec(simhash_option)
  compressor_spec = self._default_compressor_spec(spec)
  fake_compressor = MatrixCompressorInterfaceMock(compressor_spec)
  # A list side_effect makes the mock return fake_compressor on its first call.
  sim_hash_mock.side_effect = [fake_compressor]

  apply_patch = mock.patch.object(compression_wrapper, 'ApplyCompression')
  with apply_patch as apply_mock:
    compression_wrapper.get_apply_compression(spec, _GLOBAL_STEP)
    expected_kwargs = dict(
        scope='default_scope',
        compression_spec=spec,
        compressor=fake_compressor,
        global_step=_GLOBAL_STEP)
    apply_mock.assert_called_with(**expected_kwargs)
def get_matrix_compression_object(
    hparams,  # pylint:disable=invalid-name
    global_step=None,
    sparsity=None):
  """Builds the pruning or compression object selected by `hparams`.

  Args:
    hparams: Pruning spec as defined in pruning.py. Its `prune_option` field
      chooses between weight/gradient pruning and matrix compression.
    global_step: A tensorflow variable used for scheduling
      pruning/compression. When None, the graph's training global step
      (cast to tf.int32) is used, or 0 if no global step exists.
    sparsity: A tensorflow scalar variable storing the sparsity. Only
      forwarded to pruning.Pruning.

  Returns:
    A pruning.Pruning object when `hparams.prune_option` is 'weight',
    'first_order_gradient' or 'second_order_gradient'; otherwise the
    compression_lib.compression_op.ApplyCompression object returned by
    compression_wrapper.get_apply_compression.
  """
  if global_step is None:
    default_step = tf.train.get_global_step()
    global_step = (0 if default_step is None
                   else tf.cast(default_step, tf.int32))

  pruning_options = ('weight', 'first_order_gradient', 'second_order_gradient')
  if hparams.prune_option not in pruning_options:
    return compression_wrapper.get_apply_compression(
        hparams, global_step=global_step)
  return pruning.Pruning(hparams, global_step, sparsity)
def testWrapper_CreatesProperCompressorOption1(self):
  """The low-rank option should hand its compressor to ApplyCompression."""
  low_rank_option = (
      comp_op_utils.CompressionOptions.LOWRANK_MATRIX_COMPRESSION)
  spec = self._create_compression_op_spec(low_rank_option)
  fake_compressor = MatrixCompressorInterfaceMock(
      self._default_compressor_spec(spec))

  # Swap out the low-rank compressor class via enter_context, so the patch
  # stays active until test cleanup. The list side_effect makes its first
  # call return fake_compressor.
  compressor_patch = mock.patch.object(
      comp_op, 'LowRankDecompMatrixCompressor', side_effect=[fake_compressor])
  self.enter_context(compressor_patch)

  apply_patch = mock.patch.object(compression_wrapper, 'ApplyCompression')
  with apply_patch as apply_mock:
    compression_wrapper.get_apply_compression(spec, _GLOBAL_STEP)
    expected_kwargs = dict(
        scope='default_scope',
        compression_spec=spec,
        compressor=fake_compressor,
        global_step=_GLOBAL_STEP)
    apply_mock.assert_called_with(**expected_kwargs)