Esempio n. 1
0
  def test_prune_local_pruning_schedule(self):
    """Tests training/pruning driver with a single layer sparsity schedule.

    Runs `prune.main` for 10 epochs with a `pruning_schedule` flag and checks
    that exactly one `events.out.tfevents.*` file is written one directory
    level below `experiment_dir`.
    """
    experiment_dir = self.create_tempdir().full_path
    # NOTE(review): the schedule string presumably maps layer index 1 to a
    # list of (epoch, target sparsity) pairs — confirm against the flag
    # parser in `prune`.
    eval_flags = dict(
        epochs=10,
        pruning_schedule='{1:[(5, 0.33), (7, 0.66), (9, 0.95)]}',
        experiment_dir=experiment_dir,
    )

    with flagsaver.flagsaver(**eval_flags):
      prune.main([])

      outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*')
      files = glob.glob(outfile)

      # Separate assertions so a failure reports the pattern and what was
      # actually found, instead of an opaque "False is not true".
      self.assertEqual(
          len(files), 1,
          msg='Expected exactly one event file matching %s, found: %s' %
          (outfile, files))
      self.assertTrue(path.exists(files[0]))
Esempio n. 2
0
  def test_prune_fixed_schedule(self):
    """Tests training/pruning driver with a fixed global sparsity.

    Runs `prune.main` for a single epoch with `pruning_rate=0.95` and checks
    that exactly one `events.out.tfevents.*` file is written one directory
    level below `experiment_dir`.
    """
    experiment_dir = self.create_tempdir().full_path
    eval_flags = dict(
        epochs=1,
        pruning_rate=0.95,
        experiment_dir=experiment_dir,
    )

    with flagsaver.flagsaver(**eval_flags):
      prune.main([])

      outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*')
      files = glob.glob(outfile)

      # Separate assertions so a failure reports the pattern and what was
      # actually found, instead of an opaque "False is not true".
      self.assertEqual(
          len(files), 1,
          msg='Expected exactly one event file matching %s, found: %s' %
          (outfile, files))
      self.assertTrue(path.exists(files[0]))