def test_export_saved_model_from_filesystem(mock_exists, mock_makedirs, serve):
    """Exporting from a local checkpoint path hands both paths to ``shutil.copy2``.

    ``mock_exists`` and ``mock_makedirs`` are not used in the body; presumably they
    are injected by patch decorators that are not visible here — TODO confirm.
    """
    source, destination = 'a/dir', 'possible/another/dir'

    with mock.patch('shutil.copy2') as copy2:
        serve.export_saved_model(source, destination)

    # The mock keeps its recorded calls after the patch is undone.
    copy2.assert_called_once_with(source, destination)
def test_export_saved_model_from_filesystem(mock_exists, mock_makedirs):
    """Exporting from a local checkpoint path delegates to ``_recursive_copy`` exactly once.

    ``mock_exists`` and ``mock_makedirs`` are not used in the body; presumably they
    come from patch decorators that are not visible here — TODO confirm.
    """
    source = 'a/dir'
    destination = 'possible/another/dir'

    with patch('tf_container.serve._recursive_copy') as recursive_copy:
        serve.export_saved_model(source, destination)

    # Call records survive after the patch context exits.
    recursive_copy.assert_called_once_with(source, destination)
def test_export_saved_model_from_s3(makedirs, boto_session, serve):
    """Exporting from an S3 prefix downloads the SavedModel files below the local path.

    ``assert_has_calls`` only requires the listed downloads to appear as a
    consecutive run, in this order, somewhere in ``download_file``'s calls.
    Extra downloads before or after them are tolerated. ``makedirs`` is not
    asserted on here.
    """
    serve.export_saved_model('s3://bucket/test', 'a/path', s3=boto_session)

    boto_session.download_file.assert_has_calls([
        call('bucket', 'test/1/saved_model.pb', 'a/path/1/saved_model.pb'),
        call('bucket', 'test/1/variables/variables.index', 'a/path/1/variables/variables.index'),
    ])
def train():
    """Run one training task inside the container.

    Steps, in order:
    1. Read the training environment and step counts.
    2. Set the S3 request timeout used while checkpointing.
    3. Fetch (when the archive lives on S3), install and import the user module.
    4. Build the trainer and its TF config.
    5. Start a parameter server when there is more than one host.
    6. Save TF_CONFIG and configure MKL.
    7. Train.
    8. Finish: the master exports the model when checkpoints were written outside
       the model directory and the trainer saves training. Every non-master task
       waits until the master is down.
    """
    environment = cs.TrainingEnvironment()
    checkpoint_location = _get_checkpoint_dir(environment)

    training_steps = environment.hyperparameters.get('training_steps', 1000)
    evaluation_steps = environment.hyperparameters.get('evaluation_steps', 100)

    # https://github.com/tensorflow/tensorflow/issues/15868
    # The S3 client in the C++ SDK defaults to a 3 second request timeout, which is
    # too short when saving large checkpoints. Default here is 60000 ms, and the
    # 's3_checkpoint_save_timeout' hyperparameter can override it.
    timeout_msec = environment.hyperparameters.get('s3_checkpoint_save_timeout', 60000)
    os.environ['S3_REQUEST_TIMEOUT_MSEC'] = str(timeout_msec)

    archive_is_on_s3 = environment.user_script_archive.lower().startswith('s3://')
    if archive_is_on_s3:
        environment.download_user_module()
    environment.pip_install_requirements()
    user_module = environment.import_user_module()

    trainer = _get_trainer_class()(
        customer_script=user_module,
        current_host=environment.current_host,
        hosts=environment.hosts,
        train_steps=training_steps,
        eval_steps=evaluation_steps,
        input_channels=environment.channel_dirs,
        model_path=checkpoint_location,
        output_path=environment.output_dir,
        customer_params=environment.hyperparameters,
    )

    cluster_config = trainer.build_tf_config()

    # A parameter server is only needed when training is distributed over several hosts.
    is_distributed = len(environment.hosts) > 1
    if is_distributed:
        _run_ps_server(environment.current_host, environment.hosts, cluster_config)

    save_tf_config_env_var(cluster_config)
    configure_mkl()

    trainer.train()

    # Only the master exports, and only when checkpoints live outside the model dir.
    # The nested checks keep the original short-circuit order.
    if checkpoint_location != environment.model_dir:
        if trainer.task_type == 'master' and trainer.saves_training():
            serve.export_saved_model(checkpoint_location, environment.model_dir)

    if trainer.task_type != 'master':
        _wait_until_master_is_down(_get_master(cluster_config))
def test_export_saved_model_from_s3(makedirs, boto_session, serve):
    """Exporting from S3 downloads exactly the expected files and creates their directories.

    Both checks compare the complete ``mock_calls`` lists. Any extra, missing or
    reordered download or directory creation fails the test.
    """
    downloads = [
        call('bucket', 'test/1/saved_model.pb', 'a/path/1/saved_model.pb'),
        call('bucket', 'test/1/variables/variables.index', 'a/path/1/variables/variables.index'),
        call('bucket', 'test/1/assets/vocabulary.txt', 'a/path/1/assets/vocabulary.txt'),
    ]
    created_dirs = [call('a/path/1'), call('a/path/1/variables'), call('a/path/1/assets')]

    serve.export_saved_model('s3://bucket/test', 'a/path', s3=boto_session)

    assert boto_session.download_file.mock_calls == downloads
    assert makedirs.mock_calls == created_dirs
def test_export_saved_model_from_s3(makedirs, boto_session):
    """Exporting from S3 issues the exact download and ``makedirs`` call sequences.

    ``serve`` is a module-level name here (not a fixture). Each comparison is
    against the complete ordered ``mock_calls`` list.
    """
    wanted_downloads = [
        call('bucket', 'test/1/saved_model.pb', 'a/path/1/saved_model.pb'),
        call('bucket', 'test/1/variables/variables.index', 'a/path/1/variables/variables.index'),
        call('bucket', 'test/1/assets/vocabulary.txt', 'a/path/1/assets/vocabulary.txt'),
    ]
    wanted_makedirs = [
        call('a/path/1'),
        call('a/path/1/variables'),
        call('a/path/1/assets'),
    ]

    serve.export_saved_model('s3://bucket/test', 'a/path', s3=boto_session)

    assert boto_session.download_file.mock_calls == wanted_downloads
    assert makedirs.mock_calls == wanted_makedirs
def train():
    """Entry point for a container training task.

    The flow:
    1. Load the environment.
    2. Set the S3 checkpoint-save timeout.
    3. Make the user module available and import it.
    4. Construct the trainer and its TF config.
    5. Run a parameter server for multi-host jobs.
    6. Publish TF_CONFIG and configure MKL.
    7. Train.
    8. Finish by role: the master exports the SavedModel when the checkpoint
       directory differs from the model directory and the trainer saves training.
       Every non-master task blocks until the master is down.
    """
    env = cs.TrainingEnvironment()
    ckpt_dir = _get_checkpoint_dir(env)
    n_train = env.hyperparameters.get('training_steps', 1000)
    n_eval = env.hyperparameters.get('evaluation_steps', 100)

    # https://github.com/tensorflow/tensorflow/issues/15868
    # The C++ SDK's S3 client times out after 3 seconds by default, which breaks
    # saving large checkpoints. Use 60000 ms unless 's3_checkpoint_save_timeout'
    # is given.
    os.environ['S3_REQUEST_TIMEOUT_MSEC'] = str(
        env.hyperparameters.get('s3_checkpoint_save_timeout', 60000))

    # Only an S3 archive needs downloading; requirements are always installed.
    if env.user_script_archive.lower().startswith('s3://'):
        env.download_user_module()
    env.pip_install_requirements()
    script = env.import_user_module()

    wrapper_cls = _get_trainer_class()
    # Keyword order matches the original call so argument evaluation order is unchanged.
    wrapper_kwargs = dict(
        customer_script=script,
        current_host=env.current_host,
        hosts=env.hosts,
        train_steps=n_train,
        eval_steps=n_eval,
        input_channels=env.channel_dirs,
        model_path=ckpt_dir,
        output_path=env.output_dir,
        customer_params=env.hyperparameters,
    )
    wrapper = wrapper_cls(**wrapper_kwargs)

    tf_cfg = wrapper.build_tf_config()

    if len(env.hosts) > 1:
        # Parameter servers exist only for distributed (multi-host) runs.
        _run_ps_server(env.current_host, env.hosts, tf_cfg)

    save_tf_config_env_var(tf_cfg)
    configure_mkl()
    wrapper.train()

    # The master alone exports. The conditions are evaluated left to right, so
    # saves_training() is consulted only when the first two hold.
    should_export = (ckpt_dir != env.model_dir
                     and wrapper.task_type == 'master'
                     and wrapper.saves_training())
    if should_export:
        serve.export_saved_model(ckpt_dir, env.model_dir)

    if wrapper.task_type != 'master':
        _wait_until_master_is_down(_get_master(tf_cfg))