Exemplo n.º 1
0
def test_get_model_sparsity():
  """Tests if the method get_model_sparsity in utils.py works correctly.

  Sets the quantized network to a series of sparsity levels (ten random
  levels rounded to two decimals, plus the boundary cases 1.0 and 0.0) and
  checks that get_model_sparsity reports each level to within 0.01.
  """
  qmodel = create_quantized_network()

  # Generate sparsity levels to test. A fixed seed keeps the test
  # reproducible; an unseeded draw would make a failure impossible to replay.
  rng = np.random.RandomState(42)
  sparsity_levels = np.concatenate((rng.rand(10), [1.0, 0.0])).round(2)

  # Test various sparsity levels
  for true_sparsity in sparsity_levels:
    qmodel = set_network_sparsity(qmodel, true_sparsity)
    calc_sparsity = get_model_sparsity(qmodel)
    assert np.abs(calc_sparsity - true_sparsity) < 0.01, (
        "expected sparsity %s, got %s" % (true_sparsity, calc_sparsity))
Exemplo n.º 2
0
def test_get_po2_model_sparsity():
  """Tests get_model_sparsity on a po2-quantized model.

  Models quantized with po2 quantizers should have a sparsity near 0 because
  if the exponent is set to 0, the value of the weight will equal 2^0 == 1 != 0.
  So whatever sparsity level is requested via set_network_sparsity, the
  reported sparsity is expected to stay within 0.01 of zero.
  """
  qmodel = create_quantized_po2_network()

  # Generate sparsity levels to test. A fixed seed keeps the test
  # reproducible; an unseeded draw would make a failure impossible to replay.
  rng = np.random.RandomState(42)
  sparsity_levels = np.concatenate((rng.rand(10), [1.0, 0.0])).round(2)

  # Test various sparsity levels
  for set_sparsity in sparsity_levels:
    qmodel = set_network_sparsity(qmodel, set_sparsity)
    calc_sparsity = get_model_sparsity(qmodel)
    assert np.abs(calc_sparsity) < 0.01, (
        "po2 model reported sparsity %s after requesting %s; expected ~0"
        % (calc_sparsity, set_sparsity))