Пример #1
0
    def _GetSimpleTestConfig(self):
        """Returns a tiny single-process LeNet5 'Train' config for tests.

        The cluster is configured as one local replica with no GPUs for both
        the worker and the parameter-server jobs. The input is a fake MNIST
        checkpoint with 2 train / 2 test examples written under the test temp
        dir. Training is capped at 2 steps, and the EMA decay is set on both
        `cfg.task.train` and `cfg.train`.

        Returns:
          The params object from `model_registry.GetParams`, mutated in place.
        """
        cfg = model_registry.GetParams('image.mnist.LeNet5', 'Train')

        cluster = cfg.cluster
        cluster.task = 0
        cluster.mode = 'sync'
        cluster.job = 'trainer_client'
        # Worker and ps get identical single-replica, CPU-only settings.
        for job_params in (cluster.worker, cluster.ps):
            job_params.name = '/job:localhost'
            job_params.replicas = 1
            job_params.gpus_per_replica = 0
        cluster.reporting_job = FLAGS.vizier_reporting_job

        # Two fake examples per split, consumed as a single batch.
        input_params = cfg.input
        input_params.ckpt = FakeMnistData(
            self.get_temp_dir(), train_size=2, test_size=2)
        input_params.num_samples = 2
        input_params.batch_size = 2

        cfg.train.max_steps = 2
        # The same decay goes into both train params; presumably so whichever
        # one the trainer reads agrees (NOTE(review): confirm which is used).
        decay = 0.9999
        cfg.task.train.ema_decay = decay
        cfg.train.ema_decay = decay
        return cfg
Пример #2
0
    def _GetTestConfig(self):
        """Returns a tiny single-process LeNet5 'Dev' config for tests.

        Both the worker and the ps jobs are set to one '/job:local' replica
        with no GPUs. A fake MNIST checkpoint with 0 train / 2 test examples
        is generated. Its returned directory is kept on `self._tmpdir`, and
        its checkpoint is used as the input. Training is capped at 2 steps
        with an EMA decay of 0.9999.

        Returns:
          The params object from `model_registry.GetParams`, mutated in place.
        """
        cfg = model_registry.GetParams('image.mnist.LeNet5', 'Dev')

        cluster = cfg.cluster
        cluster.task = 0
        cluster.mode = 'sync'
        cluster.job = 'trainer_client'
        for job_params in (cluster.worker, cluster.ps):
            job_params.name = '/job:local'
            job_params.replicas = 1
            job_params.gpus_per_replica = 0

        # This FakeMnistData variant returns a pair; presumably
        # (temp dir, checkpoint path) — confirm against its definition.
        fake_dir, fake_ckpt = FakeMnistData(train_size=0, test_size=2)
        self._tmpdir = fake_dir
        cfg.input.ckpt = fake_ckpt
        cfg.input.num_samples = 2

        cfg.train.max_steps = 2
        cfg.train.ema_decay = 0.9999
        return cfg
Пример #3
0
    def testRunLocally(self):
        """Runs the trainer in-process on CPU and checks the files it writes.

        The test overrides several global FLAGS to train LeNet5 for 2 steps on
        a tiny fake MNIST checkpoint. It then asserts that checkpoint and
        event files exist under `<logdir>/train`, and that params, model
        analysis and event files exist under `<logdir>/control`.

        The overridden FLAGS are restored afterwards, even if the trainer
        raises, so they do not leak into other tests in the same process.
        """
        # The random suffix keeps repeated runs from reusing a stale log
        # directory under the shared test temp dir.
        logdir = os.path.join(tf.test.get_temp_dir(),
                              'run_locally_test' + str(random.random()))

        overridden = ('logdir', 'run_locally', 'mode', 'model',
                      'model_params_override')
        saved_flags = {name: getattr(FLAGS, name) for name in overridden}
        try:
            FLAGS.logdir = logdir
            FLAGS.run_locally = 'cpu'
            FLAGS.mode = 'sync'
            FLAGS.model = 'image.mnist.LeNet5'
            FLAGS.model_params_override = (
                'train.max_steps: 2; input.num_samples: 2; input.ckpt: %s' %
                FakeMnistData(self.get_temp_dir(), train_size=2, test_size=2))
            trainer.main(None)
        finally:
            # FLAGS is process-global; undo our overrides for later tests.
            for name, value in saved_flags.items():
                setattr(FLAGS, name, value)

        train_files = tf.io.gfile.glob(os.path.join(logdir, 'train', '*'))
        self.assertTrue(self._HasFile(train_files, 'ckpt'))
        self.assertTrue(self._HasFile(train_files, 'tfevents'))
        control_files = tf.io.gfile.glob(os.path.join(logdir, 'control', '*'))
        self.assertTrue(self._HasFile(control_files, 'params.txt'))
        self.assertTrue(self._HasFile(control_files, 'model_analysis.txt'))
        self.assertTrue(self._HasFile(control_files, 'tfevents'))
Пример #4
0
    def _GetTestConfig(self):
        """Returns a tiny single-process LeNet5 'Train' config for tests.

        The registered 'Train' params for 'image.mnist.LeNet5' are fetched
        from `model_registry`. Both the worker and the ps jobs are set to one
        '/job:localhost' replica with no GPUs. The input points at a fake
        MNIST checkpoint with 2 train / 2 test examples under the test temp
        dir. Training is capped at 2 steps with an EMA decay of 0.9999.

        Returns:
          The params object from `model_registry.GetParams`, mutated in place.
        """
        cfg = model_registry.GetParams('image.mnist.LeNet5', 'Train')

        cluster = cfg.cluster
        cluster.task = 0
        cluster.mode = 'sync'
        cluster.job = 'trainer_client'
        for job_params in (cluster.worker, cluster.ps):
            job_params.name = '/job:localhost'
            job_params.replicas = 1
            job_params.gpus_per_replica = 0

        # Two fake examples per split, consumed as a single batch.
        input_params = cfg.input
        input_params.ckpt = FakeMnistData(
            self.get_temp_dir(), train_size=2, test_size=2)
        input_params.num_samples = 2
        input_params.batch_size = 2

        train_params = cfg.train
        train_params.max_steps = 2
        train_params.ema_decay = 0.9999
        return cfg