def testLowRankDecompMatrixCompressorInterface(self):
    """Checks that the default compressor recovers a rank-5 product.

    A 10x10 matrix is formed as the product of random 10x5 and 5x10 factors.
    The default-hparam LowRankDecompMatrixCompressor factors it into two
    matrices, and their product must match the original to within 0.01
    (np.linalg.norm of the difference).
    """
    compressor = compression_op.LowRankDecompMatrixCompressor(
        compression_op.LowRankDecompMatrixCompressor.get_default_hparams())
    left_factor = np.random.normal(0, 1, [10, 5])
    right_factor = np.random.normal(0, 1, [5, 10])
    original = left_factor @ right_factor
    recovered_left, recovered_right = compressor.static_matrix_compressor(
        original)
    reconstruction_error = np.linalg.norm(
        original - np.matmul(recovered_left, recovered_right))
    self.assertLess(reconstruction_error, 0.01)
Esempio n. 2
0
    def test_get_apply_matmul(self):
        """Checks BlockCompressionOp (block_method=mask) on a 12x8 matrix.

        The c matrix and its mask must both be 12x8, the mask must have 48
        nonzero entries (block_compression_factor=2), and the compressed
        matmul of a 1x12 operand must have shape 1x8.
        """
        with tf.Graph().as_default(), self.cached_session():
            spec_string = ("name=block_compression,"
                           "compression_option=10,"
                           "begin_compression_step=1000,"
                           "end_compression_step=120000,"
                           "compression_frequency=100,"
                           "block_method=mask,"
                           "block_compression_factor=2,")
            block_op_cls = compression_op.BlockCompressionOp
            op_spec = block_op_cls.get_default_hparams().parse(spec_string)

            compressor_cls = compression_op.LowRankDecompMatrixCompressor
            matrix_compressor = compressor_cls(
                spec=compressor_cls.get_default_hparams())

            step = tf.compat.v1.get_variable("global_step", initializer=100)
            applier = compression_op.ApplyCompression(
                scope="default_scope",
                compression_spec=op_spec,
                compressor=matrix_compressor,
                global_step=step)

            # Rank-one 12x8 weight matrix built as an outer product.
            left_factor = np.array(
                [1., 2., 3., 7., 8., 9., 1., 2., 5., -2., -7., -1.])
            right_factor = np.array([4., 5., 6., 3., 1., 8., 3., 2.])
            weights = tf.compat.v1.get_variable(
                "a_matrix",
                initializer=np.outer(left_factor, right_factor).astype(
                    np.float32),
                dtype=tf.float32)
            _ = applier.apply_compression(weights, scope="compressor")

            # Left operand: a single 12-element row, shape (1, 12).
            operand_init = np.array(
                [1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4.])[np.newaxis, :]
            operand = tf.compat.v1.get_variable(
                "left_operand",
                initializer=operand_init.astype(np.float32),
                dtype=tf.float32)
            block_op = applier._compression_ops[-1]
            tf.compat.v1.global_variables_initializer().run()
            product = block_op.get_apply_matmul(operand)

            # Both the c matrix and its mask keep the full 12x8 shape.
            for attr in ("c_matrix_tfvar", "c_mask_tfvar"):
                self.assertSequenceEqual(
                    list(getattr(block_op, attr).eval().shape), [12, 8])
            # 48 of the 96 mask entries are expected to be nonzero.
            self.assertEqual(
                np.count_nonzero(block_op.c_mask_tfvar.eval()), 48)
            self.assertSequenceEqual(list(product.eval().shape), [1, 8])
  def test_get_apply_matmul(self):
    """Checks InputOutputCompressionOp factor shapes and output shape.

    Input and output compression (factor 2, block size 4) are applied to a
    12x8 matrix. The b, c and d factors must be 4x2, 6x4 and 2x4, and the
    compressed product with a length-12 vector must have shape [8].
    """
    with tf.Graph().as_default(), self.cached_session():
      spec_string = ("name=input_output_compression,"
                     "compression_option=9,"
                     "begin_compression_step=1000,"
                     "end_compression_step=120000,"
                     "compression_frequency=100,"
                     "compress_input=True,"
                     "compress_output=True,"
                     "input_compression_factor=2,"
                     "input_block_size=4,"
                     "output_compression_factor=2,"
                     "output_block_size=4,")
      io_op_cls = compression_op.InputOutputCompressionOp
      op_spec = io_op_cls.get_default_hparams().parse(spec_string)

      compressor_cls = compression_op.LowRankDecompMatrixCompressor
      matrix_compressor = compressor_cls(
          spec=compressor_cls.get_default_hparams())

      step = tf.compat.v1.get_variable("global_step", initializer=100)
      applier = compression_op.ApplyCompression(
          scope="default_scope",
          compression_spec=op_spec,
          compressor=matrix_compressor,
          global_step=step)

      # Rank-one 12x8 weight matrix built as an outer product.
      left_factor = np.array(
          [1., 2., 3., 7., 8., 9., 1., 2., 5., -2., -7., -1.])
      right_factor = np.array([4., 5., 6., 3., 1., 8., 3., 2.])
      weights = tf.compat.v1.get_variable(
          "a_matrix",
          initializer=np.outer(left_factor, right_factor).astype(np.float32),
          dtype=tf.float32)
      _ = applier.apply_compression(weights, scope="compressor")

      # The left operand is a plain length-12 vector (no batch axis), so the
      # product is expected to be a length-8 vector.
      operand_init = np.array([1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4.])
      operand = tf.compat.v1.get_variable(
          "left_operand",
          initializer=operand_init.astype(np.float32),
          dtype=tf.float32)
      io_op = applier._compression_ops[-1]
      tf.compat.v1.global_variables_initializer().run()
      product = io_op.get_apply_matmul(operand)

      expected_shapes = (("b_matrix_tfvar", [4, 2]),
                         ("c_matrix_tfvar", [6, 4]),
                         ("d_matrix_tfvar", [2, 4]))
      for attr, shape in expected_shapes:
        self.assertSequenceEqual(
            list(getattr(io_op, attr).eval().shape), shape)

      self.assertSequenceEqual(list(product.eval().shape), [8])
def get_apply_compression(compression_op_spec, global_step):
    """Returns apply_compression operation matching compression_option input.

    A LowRankDecompMatrixCompressor hparams object is created, and `rank` and
    `block_size` are copied into it when compression_op_spec defines them.
    The compressor and apply-compression classes are then chosen by
    compression_op_spec.compression_option:

      1    -> LowRankDecompMatrixCompressor + comp_op.ApplyCompression
      2    -> SimhashMatrixCompressor + SimhashApplyCompression
      4, 8 -> KmeansMatrixCompressor + SimhashApplyCompression

    Options 2, 4 and 8 also set is_b_matrix_trainable=False on the
    compressor spec. An option not listed in _COMPRESSION_OPTIONS is logged
    and replaced by 1; this mutates compression_op_spec in place.

    Args:
      compression_op_spec: hparams of the compression operator; must define
        compression_option.
      global_step: global step passed through to the apply-compression object.

    Returns:
      The apply-compression object, or None if the option is listed in
      _COMPRESSION_OPTIONS but has no branch here.
    """
    compressor_spec = comp_op.LowRankDecompMatrixCompressor.get_default_hparams()
    # Forward only the hparams the operator spec actually defines. Reading an
    # undefined one would raise AttributeError instead of keeping the
    # compressor's default value.
    for name in ('rank', 'block_size'):
        if name in compression_op_spec:
            compressor_spec.set_hparam(name, getattr(compression_op_spec, name))
    logging.info('Compressor spec %s', compressor_spec.to_json())
    logging.info('Compression operator spec %s', compression_op_spec.to_json())

    if compression_op_spec.compression_option not in _COMPRESSION_OPTIONS:
        logging.info(
            'Compression_option %s not in expected options: %s. '
            'Will use low_rank decomp by default.',
            str(compression_op_spec.compression_option),
            ','.join([str(opt) for opt in _COMPRESSION_OPTIONS]))
        compression_op_spec.compression_option = 1

    option = compression_op_spec.compression_option
    apply_compression = None
    if option == 1:
        compressor = comp_op.LowRankDecompMatrixCompressor(
            spec=compressor_spec)
        apply_compression = comp_op.ApplyCompression(
            scope='default_scope',
            compression_spec=compression_op_spec,
            compressor=compressor,
            global_step=global_step)
    elif option in (2, 4, 8):
        # The hashing/clustering compressors keep the b matrix fixed.
        compressor_spec.set_hparam('is_b_matrix_trainable', False)
        if option == 2:
            compressor = simhash_comp_op.SimhashMatrixCompressor(
                spec=compressor_spec)
        else:
            # Options 4 and 8 both use the k-means compressor.
            compressor = simhash_comp_op.KmeansMatrixCompressor(
                spec=compressor_spec)
        apply_compression = simhash_comp_op.SimhashApplyCompression(
            scope='default_scope',
            compression_spec=compression_op_spec,
            compressor=compressor,
            global_step=global_step)

    return apply_compression
Esempio n. 5
0
    def test_get_apply_matmul(self):
        """Checks MixedBlockCompressionOp tensor shapes and matmul output shape.

        With compression_factor=4 and num_bases=2 applied to a 12x8 matrix,
        block_matrices must be 3x2x4, linear_mixer must be 4x4x2, and the
        compressed product of a 1x12 operand must be 1x8.
        """
        with tf.Graph().as_default(), self.cached_session():
            spec_string = ("name=mixed_block_compression,"
                           "begin_compression_step=1000,"
                           "end_compression_step=120000,"
                           "compression_frequency=100,"
                           "compression_factor=4,"
                           "num_bases=2,")
            mixed_op_cls = compression_op.MixedBlockCompressionOp
            op_spec = mixed_op_cls.get_default_hparams().parse(spec_string)
            compressor_cls = compression_op.LowRankDecompMatrixCompressor
            matrix_compressor = compressor_cls(
                spec=compressor_cls.get_default_hparams())

            step = tf.compat.v1.get_variable("global_step", initializer=100)
            applier = compression_wrapper.ApplyCompression(
                scope="default_scope",
                compression_spec=op_spec,
                compressor=matrix_compressor,
                global_step=step)

            # Rank-one 12x8 weight matrix built as an outer product.
            left_factor = np.array(
                [1., 2., 3., 7., 8., 9., 1., 2., 5., -2., -7., -1.])
            right_factor = np.array([4., 5., 6., 3., 1., 8., 3., 2.])
            weights = tf.compat.v1.get_variable(
                "a_matrix",
                initializer=np.outer(left_factor, right_factor).astype(
                    np.float32),
                dtype=tf.float32)
            _ = applier.apply_compression(weights, scope="compressor")

            # Left operand: a single 12-element row, shape (1, 12).
            operand_init = np.array(
                [1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4.])[np.newaxis, :]
            operand = tf.compat.v1.get_variable(
                "left_operand",
                initializer=operand_init.astype(np.float32),
                dtype=tf.float32)
            mixed_op = applier._compression_ops[-1]
            tf.compat.v1.global_variables_initializer().run()
            product = mixed_op.get_apply_matmul(operand)

            expected_shapes = (("block_matrices", [3, 2, 4]),
                               ("linear_mixer", [4, 4, 2]))
            for attr, shape in expected_shapes:
                self.assertSequenceEqual(
                    list(getattr(mixed_op, attr).eval().shape), shape)
            self.assertSequenceEqual(list(product.eval().shape), [1, 8])
def main(argv):
    """Trains a low-rank-compressed MLP on the MNIST training split.

    Each step takes one SGD step on SparseCategoricalCrossentropy and then
    calls the model's run_alpha_update with the global step count. After
    each epoch, the loss of that epoch's last batch is printed.

    Args:
      argv: command-line arguments (unused).
    """
    del argv  # unused

    tf.enable_v2_behavior()

    # Model dimensions; input_dim also drives the flattening below.
    input_dim = 28 * 28
    num_hidden_nodes = 50
    num_classes = 10

    # Load MNIST data, divide pixel values by 255 and flatten each image into
    # a row of length input_dim. The row count is derived from the data rather
    # than hard-coded.
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (_, _) = mnist.load_data()
    x_train = x_train / 255.0
    x_train = x_train.reshape(-1, input_dim).astype('float32')

    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

    # Define model.
    lowrank_compressor = compression.LowRankDecompMatrixCompressor(
        compression.LowRankDecompMatrixCompressor.get_default_hparams())
    compressed_model = CompressedModel(input_dim, num_hidden_nodes,
                                       num_classes, lowrank_compressor)

    optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)
    loss = tf.keras.losses.SparseCategoricalCrossentropy()
    epochs = 10

    step_number = 0
    for epoch in range(epochs):
        for x, y in train_dataset:
            with tf.GradientTape() as t:
                loss_value = loss(y, compressed_model(x))
            grads = t.gradient(loss_value,
                               compressed_model.trainable_variables)
            optimizer.apply_gradients(
                zip(grads, compressed_model.trainable_variables))

            # Advance the compression schedule once per optimizer step.
            compressed_model.run_alpha_update(step_number)

            step_number += 1
        # loss_value holds the loss of the epoch's final batch.
        print('Training loss at epoch {} is {}.'.format(epoch, loss_value))
  def testCompressionOpInterface(self):
    """Exercises CompressionOp's alpha schedule and low-rank factor updates.

    Uses compression_option=1 with compression_frequency=10 and a rank-200
    LowRankDecompMatrixCompressor on a 3x3 matrix. The test checks four
    stages:
      * At step 30, before begin_compression_step, b and c are zero while
        a is not.
      * The update at step 1001 sets alpha to 0.99 and records 1001 as the
        last alpha-update step. It also fills b and c with the static
        decomposition of a, up to sign.
      * Re-running the update at the same step 1001 leaves alpha at 0.99.
      * The update at step 2000 lowers alpha to 0.98.
    """
    tolerance = 0.00001

    def norm_of(tfvar):
      # np.linalg.norm of the variable's current value.
      return np.linalg.norm(tfvar.eval())

    def factor_norms(op):
      # Norms of a, b and c, in that order.
      return np.array([
          norm_of(op.a_matrix_tfvar),
          norm_of(op.b_matrix_tfvar),
          norm_of(op.c_matrix_tfvar),
      ])

    with tf.Graph().as_default():
      with self.cached_session() as sess:
        compression_hparams = ("name=cifar10_compression,"
                               "begin_compression_step=1000,"
                               "end_compression_step=120000,"
                               "compression_frequency=10,"
                               "compression_option=1,"
                               "update_option=0")
        global_step = tf.compat.v1.get_variable("global_step", initializer=30)
        c = compression_op.CompressionOp(
            spec=compression_op.CompressionOp.get_default_hparams().parse(
                compression_hparams),
            global_step=global_step)
        # Known initial value for a_matrix so the expected factors can be
        # recomputed independently below.
        a_matrix_init = np.array([[1.0, 1.0, 1.0], [1.0, 0, 0], [1.0, 0, 0]])
        a_matrix = tf.compat.v1.get_variable(
            "a_matrix",
            initializer=a_matrix_init.astype(np.float32),
            dtype=tf.float32)
        matrix_compressor = compression_op.LowRankDecompMatrixCompressor(
            spec=compression_op.LowRankDecompMatrixCompressor
            .get_default_hparams().parse("num_rows=3,num_cols=3,rank=200"))

        [a_matrix_compressed, a_matrix_update_op] = c.get_apply_compression_op(
            a_matrix, matrix_compressor, scope="my_scope")

        tf.compat.v1.global_variables_initializer().run()
        # Step 30 precedes begin_compression_step: only a is populated.
        self.assertGreaterEqual(norm_of(c.a_matrix_tfvar), tolerance)
        self.assertLess(norm_of(c.b_matrix_tfvar), tolerance)
        self.assertLess(norm_of(c.c_matrix_tfvar), tolerance)

        # First update inside the compression window.
        tf.compat.v1.assign(global_step, 1001).eval()
        # sess.run accepts both Tensors and Operations, unlike .eval().
        sess.run(a_matrix_update_op)
        a_matrix_compressed.eval()
        self.assertEqual(c._global_step.eval(), 1001)
        self.assertAlmostEqual(c.alpha.eval(), 0.99)
        self.assertEqual(c._last_alpha_update_step.eval(), 1001)
        self.assertAllEqual(factor_norms(c) > 0, [True, True, True])
        self.assertGreaterEqual(norm_of(c.b_matrix_tfvar), tolerance)
        self.assertGreaterEqual(norm_of(c.c_matrix_tfvar), tolerance)

        # b and c must match an independent static decomposition of a. The
        # factors may differ in sign, so compare absolute values.
        [b_matrix,
         c_matrix] = matrix_compressor.static_matrix_compressor(a_matrix_init)
        self.assertLess(
            np.linalg.norm(np.abs(b_matrix) - np.abs(c.b_matrix_tfvar.eval())),
            tolerance)
        self.assertLess(
            np.linalg.norm(np.abs(c_matrix) - np.abs(c.c_matrix_tfvar.eval())),
            tolerance)
        self.assertGreaterEqual(norm_of(c.b_matrix_tfvar), tolerance)
        self.assertGreaterEqual(norm_of(c.c_matrix_tfvar), tolerance)

        # Re-running the update at the same step must leave alpha unchanged.
        print("before repeated 1001 step, c.alpha is ", c.alpha.eval())
        tf.compat.v1.assign(global_step, 1001).eval()
        sess.run(a_matrix_update_op)
        a_matrix_compressed.eval()
        print("after repeated 1001 step, c.alpha is ", c.alpha.eval())
        self.assertEqual(c._global_step.eval(), 1001)
        self.assertAlmostEqual(c.alpha.eval(), 0.99)
        self.assertEqual(c._last_alpha_update_step.eval(), 1001)
        # Element-wise check: every norm must be strictly positive.
        self.assertAllEqual(factor_norms(c) > 0, [True, True, True])

        # The update at step 2000 lowers alpha again.
        print("before 2000 step, alpha is ", c.alpha.eval())
        tf.compat.v1.assign(global_step, 2000).eval()
        sess.run(a_matrix_update_op)
        a_matrix_compressed.eval()
        print("after 2000 step, alpha is ", c.alpha.eval())
        self.assertEqual(c._global_step.eval(), 2000)
        self.assertAlmostEqual(c.alpha.eval(), 0.98)
        self.assertEqual(c._last_alpha_update_step.eval(), 2000)
        self.assertAllEqual(factor_norms(c) > 0, [True, True, True])
  def testApplyCompression(self):
    """Drives two low-rank CompressionOps through one ApplyCompression.

    Two 3x3 matrices, each an outer product, are compressed under separate
    scopes. The test checks the state at each stage:
      * Before any update (step 30): alpha is 1.0 and b and c are zero.
      * After the update at step 1001: alpha is 0.99, and b and c match
        the static decomposition up to sign.
      * After repeating step 1001: nothing changes.
      * After step 2001: alpha is 0.98 for both ops.
    """
    zero_tol = 0.00001

    def norm_of(tfvar):
      return np.linalg.norm(tfvar.eval())

    def norms_positive(op):
      # Element-wise "> 0" over the norms of a, b and c, in that order.
      return np.array([
          np.linalg.norm(op.a_matrix_tfvar.eval()),
          np.linalg.norm(op.b_matrix_tfvar.eval()),
          np.linalg.norm(op.c_matrix_tfvar.eval()),
      ]) > 0

    with tf.Graph().as_default():
      with self.cached_session():
        op_spec = compression_op.CompressionOp.get_default_hparams().parse(
            "name=cifar10_compression,"
            "begin_compression_step=1000,"
            "end_compression_step=120000,"
            "compression_frequency=100,"
            "compression_option=1")
        compressor_cls = compression_op.LowRankDecompMatrixCompressor
        matrix_compressor = compressor_cls(
            spec=compressor_cls.get_default_hparams().parse(
                "num_rows=5,num_cols=5,rank=200"))

        global_step = tf.compat.v1.get_variable("global_step", initializer=30)

        apply_comp = compression_op.ApplyCompression(
            scope="default_scope",
            compression_spec=op_spec,
            compressor=matrix_compressor,
            global_step=global_step)
        # Fixed initial value so the expected factors can be recomputed.
        a_matrix_init = np.outer(np.array([1., 2., 3.]), np.array([4., 5., 6.]))

        a_matrix = tf.compat.v1.get_variable(
            "a_matrix",
            initializer=a_matrix_init.astype(np.float32),
            dtype=tf.float32)
        a_matrix_compressed = apply_comp.apply_compression(
            a_matrix, scope="first_compressor")
        first = apply_comp._compression_ops[0]

        a_matrix2 = tf.compat.v1.get_variable(
            "a_matrix2",
            initializer=a_matrix_init.astype(np.float32),
            dtype=tf.float32)
        _ = apply_comp.apply_compression(a_matrix2, scope="second_compressor")
        second = apply_comp._compression_ops[1]

        _ = apply_comp.all_update_op()

        def run_update_at(step):
          # Move the global step, run every registered update, then evaluate
          # the first compressed matrix.
          tf.compat.v1.assign(global_step, step).eval()
          apply_comp._all_update_op.run()
          _ = a_matrix_compressed.eval()

        tf.compat.v1.global_variables_initializer().run()
        _ = a_matrix_compressed.eval()
        # Step 30 precedes begin_compression_step: no update has happened.
        self.assertEqual(first._global_step.eval(), 30)
        self.assertEqual(first.alpha.eval(), 1.0)
        self.assertEqual(second.alpha.eval(), 1.0)
        self.assertEqual(first._last_alpha_update_step.eval(), -1)
        self.assertAllEqual(norms_positive(first), [True, False, False])
        self.assertFalse(norm_of(first.a_matrix_tfvar) < zero_tol)
        self.assertTrue(norm_of(first.b_matrix_tfvar) < zero_tol)
        self.assertTrue(norm_of(first.c_matrix_tfvar) < zero_tol)

        run_update_at(1001)
        self.assertEqual(first._global_step.eval(), 1001)
        self.assertAlmostEqual(first.alpha.eval(), 0.99)
        self.assertEqual(first._last_alpha_update_step.eval(), 1001)
        self.assertAllEqual(norms_positive(first), [True, True, True])
        self.assertFalse(norm_of(first.b_matrix_tfvar) < zero_tol)
        self.assertFalse(norm_of(first.c_matrix_tfvar) < zero_tol)

        b_matrix, c_matrix = matrix_compressor.static_matrix_compressor(
            a_matrix_init)
        # The factors may differ in sign; compare absolute values.
        b_gap = np.linalg.norm(
            np.abs(b_matrix) - np.abs(first.b_matrix_tfvar.eval()))
        self.assertTrue(b_gap < zero_tol)
        c_gap = np.linalg.norm(
            np.abs(c_matrix) - np.abs(first.c_matrix_tfvar.eval()))
        self.assertTrue(c_gap < zero_tol)
        self.assertFalse(norm_of(first.b_matrix_tfvar) < zero_tol)
        self.assertFalse(norm_of(first.c_matrix_tfvar) < zero_tol)

        # Repeating step 1001 leaves the alpha schedule where it was.
        run_update_at(1001)
        self.assertEqual(first._global_step.eval(), 1001)
        self.assertAlmostEqual(first.alpha.eval(), 0.99)
        self.assertEqual(first._last_alpha_update_step.eval(), 1001)
        self.assertAllEqual(norms_positive(first), [True, True, True])

        # Step 2001 advances alpha for both compression ops.
        run_update_at(2001)
        self.assertEqual(first._global_step.eval(), 2001)
        self.assertAlmostEqual(first.alpha.eval(), 0.98)
        self.assertAlmostEqual(second.alpha.eval(), 0.98)
        self.assertEqual(first._last_alpha_update_step.eval(), 2001)
        self.assertAllEqual(norms_positive(first), [True, True, True])
Esempio n. 9
0
def get_apply_compression(compression_op_spec, global_step):
  """Creates the ApplyCompression object for compression_op_spec.

  A LowRankDecompMatrixCompressor hparams object is built first, and `rank`
  and `block_size` are copied into it when compression_op_spec defines them.
  The compressor class, and the trainability flags set on its spec, depend
  on compression_op_spec.compression_option:

    LOWRANK_MATRIX_COMPRESSION      -> LowRankDecompMatrixCompressor
    SIMHASH_MATRIX_COMPRESSION      -> SimhashMatrixCompressor (b frozen)
    KMEANS_MATRIX_COMPRESSION and
    KMEANS_AND_PRUNING_...          -> KmeansMatrixCompressor (b trainable)
    INPUTOUTPUT_COMPRESSION         -> LowRankDecompMatrixCompressor
                                       (b, c, d trainable)
    BLOCK_COMPRESSION and
    MIXED_BLOCK_COMPRESSION         -> LowRankDecompMatrixCompressor
                                       (c trainable)
    DL_MATRIX_COMPRESSION           -> DLMatrixCompressor

  An option that is not a member of CompressionOptions is logged and replaced
  by LOWRANK_MATRIX_COMPRESSION; this replacement mutates compression_op_spec
  in place.

  Args:
    compression_op_spec: hparams of the compression operator.
    global_step: global step passed through to ApplyCompression.

  Returns:
    An ApplyCompression instance, or None when the option matches none of
    the branches below.
  """
  compressor_spec = comp_op.LowRankDecompMatrixCompressor.get_default_hparams()
  for shared_hparam in ('rank', 'block_size'):
    if shared_hparam in compression_op_spec:
      compressor_spec.set_hparam(shared_hparam,
                                 getattr(compression_op_spec, shared_hparam))
  logging.info('Compressor spec %s', compressor_spec.to_json())
  logging.info('Compression operator spec %s', compression_op_spec.to_json())

  if compression_op_spec.compression_option not in list(CompressionOptions):
    # Unknown options fall back to low-rank compression.
    logging.info(
        'Compression_option %s not in expected options: %s. '
        'Will use low_rank decomp by default.',
        str(compression_op_spec.compression_option),
        ','.join([str(opt) for opt in CompressionOptions]))
    compression_op_spec.compression_option = (
        CompressionOptions.LOWRANK_MATRIX_COMPRESSION)

  option = compression_op_spec.compression_option
  compressor = None
  if option == CompressionOptions.LOWRANK_MATRIX_COMPRESSION:
    compressor = comp_op.LowRankDecompMatrixCompressor(spec=compressor_spec)
  elif option == CompressionOptions.SIMHASH_MATRIX_COMPRESSION:
    compressor_spec.set_hparam('is_b_matrix_trainable', False)
    compressor = simhash_comp_op.SimhashMatrixCompressor(spec=compressor_spec)
  elif (option == CompressionOptions.KMEANS_MATRIX_COMPRESSION or
        option == CompressionOptions.KMEANS_AND_PRUNING_MATRIX_COMPRESSION):
    compressor_spec.set_hparam('is_b_matrix_trainable', True)
    compressor = simhash_comp_op.KmeansMatrixCompressor(spec=compressor_spec)
  elif option == CompressionOptions.INPUTOUTPUT_COMPRESSION:
    for flag in ('is_b_matrix_trainable', 'is_c_matrix_trainable',
                 'is_d_matrix_trainable'):
      compressor_spec.set_hparam(flag, True)
    compressor = comp_op.LowRankDecompMatrixCompressor(spec=compressor_spec)
  elif (option == CompressionOptions.BLOCK_COMPRESSION or
        option == CompressionOptions.MIXED_BLOCK_COMPRESSION):
    compressor_spec.set_hparam('is_c_matrix_trainable', True)
    compressor = comp_op.LowRankDecompMatrixCompressor(spec=compressor_spec)
  elif option == CompressionOptions.DL_MATRIX_COMPRESSION:
    compressor = dl_comp_op.DLMatrixCompressor(spec=compressor_spec)

  if compressor is None:
    return None
  # Every option wraps its compressor in the same way.
  return ApplyCompression(
      scope='default_scope',
      compression_spec=compression_op_spec,
      compressor=compressor,
      global_step=global_step)