def test_main_tfconfig_dist(self, job_name, task_index, expected_is_chief,
                            expected_worker, mock_run, mock_device_setter,
                            mock_environ, mock_server):
    """Checks parse_and_run() with a two-parameter-server TF_CONFIG.

    mock_environ.get yields a JSON cluster spec with two ps hosts and a task
    built from (job_name, task_index); mock_server yields a stub whose
    ``target`` is 'some-target'. The test asserts that mock_device_setter is
    called once with the ps count and expected_worker, and that mock_run is
    called once with the stub's target and expected_is_chief.
    """

    class _StubServer(object):
        # Only ``target`` is provided; presumably that is all parse_and_run
        # reads from the server object (confirm in model_train).
        target = 'some-target'

    ps_hosts = ['ps1:800', 'ps2:800']
    # Key order ('cluster' before 'task', 'type' before 'index') is kept so
    # the serialized JSON string is unchanged.
    cluster_and_task = {
        'cluster': {'ps': ps_hosts},
        'task': {'type': job_name, 'index': task_index},
    }
    mock_environ.get.return_value = json.dumps(cluster_and_task)
    mock_server.return_value = _StubServer()

    model_train.parse_and_run()

    expected_ps_tasks = len(ps_hosts)  # == 2
    mock_device_setter.assert_called_once_with(
        expected_ps_tasks, worker_device=expected_worker, cluster=mock.ANY)
    mock_run.assert_called_once_with(
        _StubServer.target, expected_is_chief, device_fn=mock.ANY)
def test_main_tfconfig_dist(self, job_name, task_index, expected_is_chief,
                            expected_worker, mock_run, mock_device_setter,
                            mock_environ, mock_server):
    """Checks parse_and_run() with a two-parameter-server TF_CONFIG.

    mock_environ.get yields a JSON cluster spec with two ps hosts and a task
    built from (job_name, task_index); mock_server yields a stub whose
    ``target`` is 'some-target'. The test asserts that mock_device_setter is
    called once with the ps count and expected_worker, and that mock_run is
    called once with the stub's target, expected_is_chief, and any
    ``device_fn`` / ``use_tpu`` keyword values.
    """

    class _StubServer(object):
        # Only ``target`` is provided; presumably that is all parse_and_run
        # reads from the server object (confirm in model_train).
        target = 'some-target'

    ps_hosts = ['ps1:800', 'ps2:800']
    # Key order ('cluster' before 'task', 'type' before 'index') is kept so
    # the serialized JSON string is unchanged.
    cluster_and_task = {
        'cluster': {'ps': ps_hosts},
        'task': {'type': job_name, 'index': task_index},
    }
    mock_environ.get.return_value = json.dumps(cluster_and_task)
    mock_server.return_value = _StubServer()

    model_train.parse_and_run()

    expected_ps_tasks = len(ps_hosts)  # == 2
    mock_device_setter.assert_called_once_with(
        expected_ps_tasks, worker_device=expected_worker, cluster=mock.ANY)
    mock_run.assert_called_once_with(
        _StubServer.target,
        expected_is_chief,
        device_fn=mock.ANY,
        use_tpu=mock.ANY)
def test_main_tfconfig_local(self, mock_run, mock_device_setter, mock_environ):
    """Checks parse_and_run() when the environment lookup yields '{}'.

    With an empty config, mock_device_setter must be called once with 0, and
    mock_run once with an empty master string and is_chief=True.
    """
    empty_config = '{}'
    mock_environ.get.return_value = empty_config

    model_train.parse_and_run()

    mock_device_setter.assert_called_once_with(0)
    expected_master, expected_is_chief = '', True
    mock_run.assert_called_once_with(
        expected_master, expected_is_chief, device_fn=mock.ANY)
def test_main_tfconfig_local(self, mock_run, mock_device_setter, mock_environ):
    """Checks parse_and_run() when the environment lookup yields '{}'.

    With an empty config, mock_device_setter must be called once with 0, and
    mock_run once with an empty master string and is_chief=True.
    """
    empty_config = '{}'
    mock_environ.get.return_value = empty_config

    model_train.parse_and_run()

    mock_device_setter.assert_called_once_with(0)
    expected_master, expected_is_chief = '', True
    mock_run.assert_called_once_with(
        expected_master, expected_is_chief, device_fn=mock.ANY)
def test_main_internal(self, mock_run, mock_device_setter):
    """Checks parse_and_run() when master/ps_tasks/task come from FLAGS.

    Sets master='some_master', ps_tasks=10, task=5, then asserts that
    mock_device_setter gets the ps count and that mock_run gets the master
    with is_chief=False.
    """
    master, num_ps_tasks, task_id = 'some_master', 10, 5
    FLAGS.master = master
    FLAGS.ps_tasks = num_ps_tasks
    FLAGS.task = task_id

    model_train.parse_and_run()

    mock_device_setter.assert_called_once_with(num_ps_tasks)
    # Task 5 is expected not to be chief (presumably only task 0 is chief;
    # confirm in model_train).
    mock_run.assert_called_once_with(master, False, device_fn=mock.ANY)
def test_main_internal(self, mock_run, mock_device_setter):
    """Checks parse_and_run() when master/ps_tasks/task come from FLAGS.

    Sets master='some_master', ps_tasks=10, task=5, then asserts that
    mock_device_setter gets the ps count and that mock_run gets the master
    with is_chief=False.
    """
    master, num_ps_tasks, task_id = 'some_master', 10, 5
    FLAGS.master = master
    FLAGS.ps_tasks = num_ps_tasks
    FLAGS.task = task_id

    model_train.parse_and_run()

    mock_device_setter.assert_called_once_with(num_ps_tasks)
    # Task 5 is expected not to be chief (presumably only task 0 is chief;
    # confirm in model_train).
    mock_run.assert_called_once_with(master, False, device_fn=mock.ANY)
def _run_tiny_training(self, model_name, dataset):
    """Trains `model_name` for a single step and checks a checkpoint exists.

    deepvariant.data_providers.get_dataset is patched to return `dataset`,
    so the '/path/to/mock.pbtxt' config path is a placeholder (presumably
    never opened). train_dir is derived from `model_name` via
    test_utils.test_tmpfile.
    """
    patch_target = 'deepvariant.data_providers.get_dataset'
    with mock.patch(patch_target) as fake_get_dataset:
        fake_get_dataset.return_value = dataset

        # Flag order is kept as-is in case flag validators depend on it.
        FLAGS.train_dir = test_utils.test_tmpfile(model_name)
        FLAGS.batch_size = 2
        FLAGS.model_name = model_name
        FLAGS.save_interval_secs = 0
        FLAGS.number_of_steps = 1
        FLAGS.dataset_config_pbtxt = '/path/to/mock.pbtxt'
        FLAGS.start_from_checkpoint = ''

        model_train.parse_and_run()

        # The dataset was requested exactly once, for the configured pbtxt.
        fake_get_dataset.assert_called_once_with(FLAGS.dataset_config_pbtxt)
        # Training left a checkpoint behind in train_dir.
        self.assertIsNotNone(tf.train.latest_checkpoint(FLAGS.train_dir))
def _run_tiny_training(self, model_name, dataset):
    """Trains `model_name` for a single step and checks a checkpoint exists.

    deepvariant.data_providers.get_dataset is patched to return `dataset`,
    so the '/path/to/mock.pbtxt' config path is a placeholder (presumably
    never opened). train_dir is derived from `model_name` via
    test_utils.test_tmpfile.
    """
    patch_target = 'deepvariant.data_providers.get_dataset'
    with mock.patch(patch_target) as fake_get_dataset:
        fake_get_dataset.return_value = dataset

        # Flag order is kept as-is in case flag validators depend on it.
        FLAGS.train_dir = test_utils.test_tmpfile(model_name)
        FLAGS.batch_size = 2
        FLAGS.model_name = model_name
        FLAGS.save_interval_secs = 0
        FLAGS.number_of_steps = 1
        FLAGS.dataset_config_pbtxt = '/path/to/mock.pbtxt'
        FLAGS.start_from_checkpoint = ''

        model_train.parse_and_run()

        # The dataset was requested exactly once, for the configured pbtxt.
        fake_get_dataset.assert_called_once_with(FLAGS.dataset_config_pbtxt)
        # Training left a checkpoint behind in train_dir.
        self.assertIsNotNone(tf.train.latest_checkpoint(FLAGS.train_dir))
def _run_tiny_training(self, model_name, dataset, warm_start_from=''):
    """Trains `model_name` for one step in a brand-new train_dir.

    deepvariant.data_providers.get_input_fn_from_dataset is patched to return
    `dataset` (presumably an input_fn-like object; confirm with callers). A
    fresh, uuid-named temporary train_dir is used on every call, and
    `warm_start_from` is forwarded as FLAGS.start_from_checkpoint. Afterwards
    the test asserts the patched factory was called once in TRAIN mode and
    that a checkpoint was written.
    """
    patch_target = 'deepvariant.data_providers.get_input_fn_from_dataset'
    with mock.patch(patch_target) as fake_input_fn_factory:
        fake_input_fn_factory.return_value = dataset

        # Flag order is kept as-is in case flag validators depend on it.
        FLAGS.train_dir = tf_test_utils.test_tmpdir(uuid.uuid4().hex)
        FLAGS.batch_size = 2
        FLAGS.model_name = model_name
        # Time-based saving is switched off (-1) in favor of saving every
        # step, so the single training step should produce a checkpoint.
        FLAGS.save_interval_secs = -1
        FLAGS.save_interval_steps = 1
        FLAGS.number_of_steps = 1
        FLAGS.dataset_config_pbtxt = '/path/to/mock.pbtxt'
        FLAGS.start_from_checkpoint = warm_start_from
        FLAGS.master = ''

        model_train.parse_and_run()

        # The input_fn was requested exactly once, for training.
        fake_input_fn_factory.assert_called_once_with(
            dataset_config_filename=FLAGS.dataset_config_pbtxt,
            mode=tf.estimator.ModeKeys.TRAIN,
            use_tpu=mock.ANY,
        )
        # Training left a checkpoint behind in train_dir.
        self.assertIsNotNone(tf.train.latest_checkpoint(FLAGS.train_dir))