def _GetSimpleTestConfig(self):
  """Returns 'Train' params for image.mnist.LeNet5 sized for a tiny run.

  The cluster is configured in 'sync' mode as 'trainer_client', with both the
  worker and ps jobs named '/job:localhost', one replica each and zero GPUs
  per replica. The reporting job is read from FLAGS.vizier_reporting_job.
  Input comes from FakeMnistData written under this test's temp dir, and the
  same EMA decay is set on both task.train and train.
  """
  params = model_registry.GetParams('image.mnist.LeNet5', 'Train')

  cluster = params.cluster
  cluster.task = 0
  cluster.mode = 'sync'
  cluster.job = 'trainer_client'
  # Worker and ps get identical placement settings.
  for role in (cluster.worker, cluster.ps):
    role.name = '/job:localhost'
    role.replicas = 1
    role.gpus_per_replica = 0
  cluster.reporting_job = FLAGS.vizier_reporting_job

  # Generate 2 inputs.
  fake_data = FakeMnistData(self.get_temp_dir(), train_size=2, test_size=2)
  data = params.input
  data.ckpt = fake_data
  data.num_samples = 2
  data.batch_size = 2

  params.train.max_steps = 2
  decay = 0.9999
  # Chained assignment binds left to right: task.train first, then train.
  params.task.train.ema_decay = params.train.ema_decay = decay
  return params
def _GetTestConfig(self):
  """Returns 'Dev' params for image.mnist.LeNet5 sized for a tiny run.

  The cluster is configured in 'sync' mode as 'trainer_client', with both the
  worker and ps jobs named '/job:local', one replica each and zero GPUs per
  replica. FakeMnistData is unpacked into two values here: the first is kept
  on self._tmpdir and the second becomes the input checkpoint.
  """
  params = model_registry.GetParams('image.mnist.LeNet5', 'Dev')

  cluster = params.cluster
  cluster.task = 0
  cluster.mode = 'sync'
  cluster.job = 'trainer_client'
  # Worker and ps get identical placement settings.
  for role in (cluster.worker, cluster.ps):
    role.name = '/job:local'
    role.replicas = 1
    role.gpus_per_replica = 0

  # Generate 2 inputs.
  tmpdir, ckpt = FakeMnistData(train_size=0, test_size=2)
  self._tmpdir = tmpdir
  params.input.ckpt = ckpt
  params.input.num_samples = 2

  train = params.train
  train.max_steps = 2
  train.ema_decay = 0.9999
  return params
def testRunLocally(self):
  """Runs trainer.main with run_locally='cpu' and checks its output files.

  Flags are pointed at a fresh, randomly suffixed log directory and a 2-step
  LeNet5 run over FakeMnistData. Afterwards the globbed contents of the
  'train' and 'control' subdirectories are checked with self._HasFile.
  """
  temp_root = tf.test.get_temp_dir()
  logdir = os.path.join(temp_root, 'run_locally_test' + str(random.random()))

  FLAGS.logdir = logdir
  FLAGS.run_locally = 'cpu'
  FLAGS.mode = 'sync'
  FLAGS.model = 'image.mnist.LeNet5'
  fake_ckpt = FakeMnistData(self.get_temp_dir(), train_size=2, test_size=2)
  FLAGS.model_params_override = (
      'train.max_steps: 2; input.num_samples: 2; input.ckpt: %s' % fake_ckpt)

  trainer.main(None)

  # (glob suffix under logdir, names passed to _HasFile for that listing).
  expected_outputs = (
      ('/train/*', ('ckpt', 'tfevents')),
      ('/control/*', ('params.txt', 'model_analysis.txt', 'tfevents')),
  )
  for glob_suffix, wanted in expected_outputs:
    found = tf.io.gfile.glob(logdir + glob_suffix)
    for name in wanted:
      self.assertTrue(self._HasFile(found, name))
def _GetTestConfig(self):
  """Returns 'Train' params for image.mnist.LeNet5 sized for a tiny run.

  The cluster is configured in 'sync' mode as 'trainer_client', with both the
  worker and ps jobs named '/job:localhost', one replica each and zero GPUs
  per replica. Input comes from FakeMnistData written under this test's temp
  dir.
  """
  # Parameters for a registered model are looked up by model name and
  # dataset name through the model registry.
  params = model_registry.GetParams('image.mnist.LeNet5', 'Train')

  cluster = params.cluster
  cluster.task = 0
  cluster.mode = 'sync'
  cluster.job = 'trainer_client'
  # Worker and ps get identical placement settings.
  for role in (cluster.worker, cluster.ps):
    role.name = '/job:localhost'
    role.replicas = 1
    role.gpus_per_replica = 0

  # Generate 2 inputs.
  fake_data = FakeMnistData(self.get_temp_dir(), train_size=2, test_size=2)
  data = params.input
  data.ckpt = fake_data
  data.num_samples = 2
  data.batch_size = 2

  train = params.train
  train.max_steps = 2
  train.ema_decay = 0.9999
  return params