示例#1
0
  def test_main_tfconfig_dist(self, job_name, task_index, expected_is_chief,
                              expected_worker, mock_run, mock_device_setter,
                              mock_environ, mock_server):
    """parse_and_run() should honour a JSON cluster config with two ps hosts.

    The config is handed out JSON-encoded by the mocked `environ.get`
    (presumably standing in for the TF_CONFIG environment variable), and the
    mocked server factory returns a stub whose `target` the run must use.
    """
    cluster_spec = {'ps': ['ps1:800', 'ps2:800']}
    task_spec = {'type': job_name, 'index': task_index}
    # Cluster first, then task: same key order, so the same JSON text.
    tf_config = {'cluster': cluster_spec, 'task': task_spec}

    class _StubServer(object):
      target = 'some-target'

    mock_server.return_value = _StubServer()
    mock_environ.get.return_value = json.dumps(tf_config)

    model_train.parse_and_run()

    # The device setter is expected once, with the number of ps hosts (2).
    mock_device_setter.assert_called_once_with(
        len(cluster_spec['ps']), cluster=mock.ANY,
        worker_device=expected_worker)
    # The run must point at the stub server's target; the chief flag is
    # supplied by the test parameters (presumably derived from the task spec).
    mock_run.assert_called_once_with(
        _StubServer.target, expected_is_chief, device_fn=mock.ANY)
示例#2
0
    def test_main_tfconfig_dist(self, job_name, task_index, expected_is_chief,
                                expected_worker, mock_run, mock_device_setter,
                                mock_environ, mock_server):
        """Distributed config with two ps hosts drives device setup and run().

        The JSON-encoded config comes from the mocked `environ.get`
        (presumably the TF_CONFIG variable), and `mock_server` yields a stub
        exposing the target string that `run` is expected to receive.
        """
        ps_hosts = ['ps1:800', 'ps2:800']
        tf_config = {
            'cluster': {'ps': ps_hosts},
            'task': {'type': job_name, 'index': task_index},
        }

        class _StubServer(object):
            target = 'some-target'

        mock_server.return_value = _StubServer()
        mock_environ.get.return_value = json.dumps(tf_config)

        model_train.parse_and_run()

        # Expected once, with the ps-host count (2) and the parameterized
        # worker device.
        mock_device_setter.assert_called_once_with(
            len(ps_hosts), worker_device=expected_worker, cluster=mock.ANY)
        # The TPU switch is not pinned here, only that it is passed.
        mock_run.assert_called_once_with(
            _StubServer.target,
            expected_is_chief,
            device_fn=mock.ANY,
            use_tpu=mock.ANY,
        )
示例#3
0
  def test_main_tfconfig_local(self, mock_run, mock_device_setter,
                               mock_environ):
    """An empty JSON config ('{}') should produce a purely local run.

    Expects the device setter to be called with 0, and `run` to be called
    with an empty master string and True as its second positional argument
    (presumably is_chief).
    """
    empty_config = '{}'
    expected_master = ''
    mock_environ.get.return_value = empty_config
    model_train.parse_and_run()

    mock_device_setter.assert_called_once_with(0)  # no ps tasks configured
    mock_run.assert_called_once_with(expected_master, True,
                                     device_fn=mock.ANY)
示例#4
0
    def test_main_tfconfig_local(self, mock_run, mock_device_setter,
                                 mock_environ):
        """With an empty JSON config, training runs locally.

        The device setter should be called with 0, and `run` with an empty
        master string and True as its second positional argument (presumably
        is_chief).
        """
        mock_environ.get.return_value = '{}'

        model_train.parse_and_run()

        expected_ps_tasks = 0
        mock_device_setter.assert_called_once_with(expected_ps_tasks)
        mock_run.assert_called_once_with(
            '', True, device_fn=mock.ANY)
示例#5
0
  def test_main_internal(self, mock_run, mock_device_setter):
    """Cluster settings taken from FLAGS (master, ps_tasks, task).

    With task index 5, `run` is expected to get False as its second
    positional argument (presumably is_chief: only some other task leads).
    """
    master = 'some_master'
    ps_task_count = 10
    FLAGS.master = master
    FLAGS.ps_tasks = ps_task_count
    FLAGS.task = 5

    model_train.parse_and_run()

    mock_device_setter.assert_called_once_with(ps_task_count)
    mock_run.assert_called_once_with(master, False, device_fn=mock.ANY)
示例#6
0
  def test_main_internal(self, mock_run, mock_device_setter):
    """parse_and_run() reads master / ps_tasks / task from FLAGS.

    Expects the device setter to be called with FLAGS.ps_tasks (10), and
    `run` with the master string and False (presumably is_chief, since
    task 5 is used).
    """
    FLAGS.master, FLAGS.ps_tasks, FLAGS.task = 'some_master', 10, 5

    model_train.parse_and_run()

    expected_master, expected_is_chief = 'some_master', False
    mock_device_setter.assert_called_once_with(10)
    mock_run.assert_called_once_with(
        expected_master, expected_is_chief, device_fn=mock.ANY)
示例#7
0
 def _run_tiny_training(self, model_name, dataset):
   """Trains `model_name` for one step on `dataset` and checks the result.

   `deepvariant.data_providers.get_dataset` is patched to return `dataset`.
   FLAGS.train_dir is set from `test_utils.test_tmpfile(model_name)`. The run
   uses batch size 2 and a single step. Afterwards the test checks that the
   patched loader was called once with the config path, and that a
   checkpoint exists in train_dir.
   """
   patch_target = 'deepvariant.data_providers.get_dataset'
   with mock.patch(patch_target) as fake_get_dataset:
     fake_get_dataset.return_value = dataset
     FLAGS.train_dir = test_utils.test_tmpfile(model_name)
     FLAGS.batch_size = 2
     FLAGS.model_name = model_name
     FLAGS.save_interval_secs = 0
     FLAGS.number_of_steps = 1
     FLAGS.dataset_config_pbtxt = '/path/to/mock.pbtxt'
     FLAGS.start_from_checkpoint = ''
     model_train.parse_and_run()
     # The input was requested exactly once, with the configured pbtxt path.
     fake_get_dataset.assert_called_once_with(FLAGS.dataset_config_pbtxt)
     # We have a checkpoint after training.
     self.assertIsNotNone(tf.train.latest_checkpoint(FLAGS.train_dir))
示例#8
0
 def _run_tiny_training(self, model_name, dataset):
     """Runs a one-step training of `model_name` with `dataset` injected.

     Args:
       model_name: assigned to FLAGS.model_name. Also used to name the
         train_dir, via `test_utils.test_tmpfile`.
       dataset: returned by the patched
         `deepvariant.data_providers.get_dataset`.
     """
     with mock.patch('deepvariant.data_providers.get_dataset') as stub_loader:
         stub_loader.return_value = dataset
         FLAGS.train_dir = test_utils.test_tmpfile(model_name)
         FLAGS.batch_size = 2
         FLAGS.model_name = model_name
         FLAGS.save_interval_secs = 0
         FLAGS.number_of_steps = 1
         FLAGS.dataset_config_pbtxt = '/path/to/mock.pbtxt'
         FLAGS.start_from_checkpoint = ''

         model_train.parse_and_run()

         config_path = FLAGS.dataset_config_pbtxt
         stub_loader.assert_called_once_with(config_path)
         # We have a checkpoint after training.
         latest = tf.train.latest_checkpoint(FLAGS.train_dir)
         self.assertIsNotNone(latest)
示例#9
0
 def _run_tiny_training(self, model_name, dataset, warm_start_from=''):
     """Runs one training step. This function always starts a new train_dir.

     Args:
       model_name: assigned to FLAGS.model_name.
       dataset: returned by the patched
         `deepvariant.data_providers.get_input_fn_from_dataset`. Given that
         name, it is presumably an input_fn rather than a raw dataset
         (confirm against callers).
       warm_start_from: assigned to FLAGS.start_from_checkpoint; defaults
         to ''.
     """
     patch_target = ('deepvariant.data_providers.'
                     'get_input_fn_from_dataset')
     with mock.patch(patch_target) as fake_input_fn_getter:
         fake_input_fn_getter.return_value = dataset
         # A uuid4-based directory name makes every call use a fresh
         # train_dir.
         FLAGS.train_dir = tf_test_utils.test_tmpdir(uuid.uuid4().hex)
         FLAGS.batch_size = 2
         FLAGS.model_name = model_name
         # NOTE(review): presumably disables time-based saving in favour of
         # saving every step -- confirm against model_train.
         FLAGS.save_interval_secs = -1
         FLAGS.save_interval_steps = 1
         FLAGS.number_of_steps = 1
         FLAGS.dataset_config_pbtxt = '/path/to/mock.pbtxt'
         FLAGS.start_from_checkpoint = warm_start_from
         FLAGS.master = ''

         model_train.parse_and_run()

         fake_input_fn_getter.assert_called_once_with(
             dataset_config_filename=FLAGS.dataset_config_pbtxt,
             mode=tf.estimator.ModeKeys.TRAIN,
             use_tpu=mock.ANY,
         )
         # We have a checkpoint after training.
         self.assertIsNotNone(tf.train.latest_checkpoint(FLAGS.train_dir))