def testWrapper_CreatesProperCompressorOption1(self, low_rank_mock):
        """Checks that the low-rank option hands its compressor to ApplyCompression.

        `low_rank_mock` is presumably injected by a
        `mock.patch.object(comp_op, 'LowRankDecompMatrixCompressor')`
        decorator outside this view -- TODO confirm.
        """
        # Use the named option rather than the bare literal 1 so the test
        # states which compression path it exercises.
        hparams = self._create_compression_op_spec(
            comp_op_utils.CompressionOptions.LOWRANK_MATRIX_COMPRESSION)
        mock_compressor = MatrixCompressorInterfaceMock(
            self._default_compressor_spec(hparams))
        low_rank_mock.side_effect = [mock_compressor]

        # Patch ApplyCompression on compression_wrapper, as the other option
        # tests do. mock.patch replaces a name in the namespace where it is
        # looked up, so patching comp_op would miss a reference bound inside
        # compression_wrapper.
        with mock.patch.object(compression_wrapper,
                               'ApplyCompression') as apply_mock:
            compression_wrapper.get_apply_compression(hparams, _GLOBAL_STEP)
            apply_mock.assert_called_with(scope='default_scope',
                                          compression_spec=hparams,
                                          compressor=mock_compressor,
                                          global_step=_GLOBAL_STEP)
# Example #2
    def testWrapper_CreatesProperCompressorOption2(self, sim_hash_mock):
        """Checks that the SimHash option hands its compressor to ApplyCompression.

        `sim_hash_mock` is presumably injected by a patch decorator outside
        this view -- TODO confirm which class it replaces.
        """
        simhash_option = (
            comp_op_utils.CompressionOptions.SIMHASH_MATRIX_COMPRESSION)
        spec = self._create_compression_op_spec(simhash_option)
        fake_compressor = MatrixCompressorInterfaceMock(
            self._default_compressor_spec(spec))
        # The patched constructor yields our fake on its single expected call.
        sim_hash_mock.side_effect = [fake_compressor]

        expected_kwargs = dict(
            scope='default_scope',
            compression_spec=spec,
            compressor=fake_compressor,
            global_step=_GLOBAL_STEP,
        )
        apply_patch = mock.patch.object(compression_wrapper, 'ApplyCompression')
        with apply_patch as apply_mock:
            compression_wrapper.get_apply_compression(spec, _GLOBAL_STEP)
            apply_mock.assert_called_with(**expected_kwargs)
def get_matrix_compression_object(
        hparams,  # pylint:disable=invalid-name
        global_step=None,
        sparsity=None):
    """Returns a pruning or compression object for the given spec.

    Args:
      hparams: Pruning/compression spec (see pruning.py). Its `prune_option`
        selects the path: 'weight', 'first_order_gradient' and
        'second_order_gradient' build a `pruning.Pruning`; any other value is
        delegated to `compression_wrapper.get_apply_compression`.
      global_step: A tensorflow variable used for scheduling
        pruning/compression. If None, `tf.train.get_global_step()` is used
        (cast to tf.int32), or 0 when that call returns None.
      sparsity: A tensorflow scalar variable storing the sparsity. Only
        passed to `pruning.Pruning`; ignored on the compression path.

    Returns:
      A Pruning or compression_lib.compression_op.ApplyCompression object.
    """
    if global_step is None:
        train_global_step = tf.train.get_global_step()
        # Fall back to step 0 when no global step has been created.
        global_step = (0 if train_global_step is None
                       else tf.cast(train_global_step, tf.int32))

    pruning_options = ('weight', 'first_order_gradient',
                       'second_order_gradient')
    if hparams.prune_option in pruning_options:
        return pruning.Pruning(hparams, global_step, sparsity)
    return compression_wrapper.get_apply_compression(
        hparams, global_step=global_step)
# Example #4
    def testWrapper_CreatesProperCompressorOption1(self):
        """Checks that the low-rank option hands its compressor to ApplyCompression."""
        spec = self._create_compression_op_spec(
            comp_op_utils.CompressionOptions.LOWRANK_MATRIX_COMPRESSION)
        fake_compressor = MatrixCompressorInterfaceMock(
            self._default_compressor_spec(spec))

        # Replace the low-rank constructor for the rest of the test so it
        # returns our fake on its single expected call.
        low_rank_patch = mock.patch.object(
            comp_op,
            'LowRankDecompMatrixCompressor',
            side_effect=[fake_compressor])
        self.enter_context(low_rank_patch)

        apply_patch = mock.patch.object(compression_wrapper, 'ApplyCompression')
        with apply_patch as apply_mock:
            compression_wrapper.get_apply_compression(spec, _GLOBAL_STEP)
            apply_mock.assert_called_with(
                scope='default_scope',
                compression_spec=spec,
                compressor=fake_compressor,
                global_step=_GLOBAL_STEP)