Exemplo n.º 1
0
  def test_run_symmetric(self):
    """Tests if the driver for shuffled training runs with a symmetric mask.

    Runs `shuffled_mask.main` for one epoch with `mask_type='symmetric'` and
    checks that exactly one file matching 'events.out.tfevents.*' was written
    one directory level below the experiment directory.
    """
    import shutil  # pylint: disable=g-import-not-at-top

    experiment_dir = tempfile.mkdtemp()
    # mkdtemp() leaves deletion to the caller; remove the directory once the
    # test finishes so repeated runs do not accumulate temporary data.
    self.addCleanup(shutil.rmtree, experiment_dir, ignore_errors=True)
    eval_flags = dict(
        epochs=1,
        experiment_dir=experiment_dir,
        mask_type='symmetric',
    )

    # flagsaver restores the previous flag values when the block exits.
    with flagsaver.flagsaver(**eval_flags):
      shuffled_mask.main([])

    outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*')
    files = glob.glob(outfile)

    # Separate checks so a failure reports which condition broke (and which
    # files were found) instead of just "False is not true".
    self.assertEqual(
        len(files), 1, msg='Expected exactly one event file, found: %s' % files)
    self.assertTrue(path.exists(files[0]))
Exemplo n.º 2
0
  def test_run_conv(self):
    """Tests if the driver for shuffled training runs correctly with CNN.

    Runs `shuffled_mask.main` for one epoch with `model='MNIST_CNN'` and
    checks that exactly one file matching 'events.out.tfevents.*' was written
    one directory level below the experiment directory.
    """
    import shutil  # pylint: disable=g-import-not-at-top

    experiment_dir = tempfile.mkdtemp()
    # mkdtemp() leaves deletion to the caller; remove the directory once the
    # test finishes so repeated runs do not accumulate temporary data.
    self.addCleanup(shutil.rmtree, experiment_dir, ignore_errors=True)
    eval_flags = dict(
        epochs=1,
        experiment_dir=experiment_dir,
        model='MNIST_CNN',
    )

    # flagsaver restores the previous flag values when the block exits.
    with flagsaver.flagsaver(**eval_flags):
      shuffled_mask.main([])

    outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*')
    files = glob.glob(outfile)

    # Separate checks so a failure reports which condition broke (and which
    # files were found) instead of just "False is not true".
    self.assertEqual(
        len(files), 1, msg='Expected exactly one event file, found: %s' % files)
    self.assertTrue(path.exists(files[0]))