示例#1
0
def test_export_saved_model_from_filesystem(mock_exists, mock_makedirs, serve):
    """A local checkpoint directory is exported with a single ``shutil.copy2``.

    ``shutil.copy2`` is patched for the duration of the export, so no files
    are copied. The test asserts that ``serve.export_saved_model`` copied
    exactly once, from the checkpoint directory to the model path.
    ``mock_exists`` and ``mock_makedirs`` are fixtures that are not used
    directly in this body.
    """
    source_dir, destination_dir = 'a/dir', 'possible/another/dir'

    with mock.patch('shutil.copy2') as copy_spy:
        serve.export_saved_model(source_dir, destination_dir)

    # The spy keeps its recorded calls after the patch is undone.
    copy_spy.assert_called_once_with(source_dir, destination_dir)
def test_export_saved_model_from_filesystem(mock_exists, mock_makedirs):
    """Exporting from a local directory delegates to ``serve._recursive_copy``.

    The helper is patched so nothing is copied. The test checks that it was
    invoked once with (checkpoint directory, model path).
    """
    source_dir = 'a/dir'
    destination_dir = 'possible/another/dir'

    with patch('tf_container.serve._recursive_copy') as recursive_copy:
        serve.export_saved_model(source_dir, destination_dir)

    recursive_copy.assert_called_once_with(source_dir, destination_dir)
def test_export_saved_model_from_filesystem(mock_exists, mock_makedirs):
    """``serve.export_saved_model`` copies a local checkpoint dir recursively.

    With ``tf_container.serve._recursive_copy`` patched out, the export must
    call it exactly once, forwarding the two paths unchanged.
    """
    # (checkpoint_dir, model_path)
    paths = ('a/dir', 'possible/another/dir')

    with patch('tf_container.serve._recursive_copy') as copy_mock:
        serve.export_saved_model(*paths)
        copy_mock.assert_called_once_with(*paths)
def test_export_saved_model_from_s3(makedirs, boto_session, serve):
    """Exporting from S3 downloads the saved model graph and variable index.

    ``boto_session`` is passed as the ``s3`` client. Its ``download_file``
    must have received two consecutive calls, in this order: one for
    ``test/1/saved_model.pb`` and one for ``test/1/variables/variables.index``.
    Both come from ``bucket`` and go to the matching paths under ``a/path``.
    ``assert_has_calls`` tolerates extra calls before or after that pair.
    """
    serve.export_saved_model('s3://bucket/test', 'a/path', s3=boto_session)

    # Paths relative to both the S3 prefix and the local destination.
    downloaded_files = ('1/saved_model.pb', '1/variables/variables.index')
    expected_downloads = [
        call('bucket', 'test/' + relative, 'a/path/' + relative)
        for relative in downloaded_files
    ]

    boto_session.download_file.assert_has_calls(expected_downloads)
def test_export_saved_model_from_s3(makedirs, boto_session, serve):
    """An ``s3://`` source is exported by downloading each model file.

    Checks that ``download_file`` on the supplied session was called, in
    sequence, for the graph file and then the variables index. Other calls
    surrounding that sequence are allowed.
    """
    serve.export_saved_model('s3://bucket/test', 'a/path', s3=boto_session)

    model_download = call('bucket', 'test/1/saved_model.pb', 'a/path/1/saved_model.pb')
    index_download = call('bucket', 'test/1/variables/variables.index',
                          'a/path/1/variables/variables.index')

    boto_session.download_file.assert_has_calls([model_download, index_download])
def train():
    """Run the training job for this host and export the model when done.

    Reads its configuration from ``cs.TrainingEnvironment()``, then:
    prepares and imports the user script, builds the trainer returned by
    ``_get_trainer_class()``, and trains. When more than one host is listed,
    a parameter server is started before training.

    Afterwards the master task calls ``serve.export_saved_model`` to move the
    model from the checkpoint directory into ``env.model_dir``. It does so
    only when those two differ and the trainer's ``saves_training()`` is
    truthy. Every non-master task instead calls
    ``_wait_until_master_is_down``.

    Hyperparameters read here (defaults in parentheses):
        training_steps (1000), evaluation_steps (100),
        s3_checkpoint_save_timeout (60000; exported as S3_REQUEST_TIMEOUT_MSEC).
    """
    env = cs.TrainingEnvironment()

    checkpoint_dir = _get_checkpoint_dir(env)
    training_steps = env.hyperparameters.get('training_steps', 1000)
    evaluation_steps = env.hyperparameters.get('evaluation_steps', 100)

    # TensorFlow's S3 client (C++ SDK) defaults to a 3 second request timeout,
    # which is too short when saving large checkpoints. See
    # https://github.com/tensorflow/tensorflow/issues/15868
    s3_timeout_msec = env.hyperparameters.get('s3_checkpoint_save_timeout', 60000)
    os.environ['S3_REQUEST_TIMEOUT_MSEC'] = str(s3_timeout_msec)

    archive_is_remote = env.user_script_archive.lower().startswith('s3://')
    if archive_is_remote:
        env.download_user_module()
    env.pip_install_requirements()

    user_module = env.import_user_module()

    trainer_cls = _get_trainer_class()
    trainer = trainer_cls(
        customer_script=user_module,
        current_host=env.current_host,
        hosts=env.hosts,
        train_steps=training_steps,
        eval_steps=evaluation_steps,
        input_channels=env.channel_dirs,
        model_path=checkpoint_dir,
        output_path=env.output_dir,
        customer_params=env.hyperparameters,
    )

    cluster_config = trainer.build_tf_config()

    # A parameter server is only started for distributed (multi-host) runs.
    is_distributed = len(env.hosts) > 1
    if is_distributed:
        _run_ps_server(env.current_host, env.hosts, cluster_config)

    save_tf_config_env_var(cluster_config)
    configure_mkl()

    trainer.train()

    # Only the master exports, and only when checkpoints were written
    # somewhere other than the final model directory.
    if checkpoint_dir != env.model_dir:
        if trainer.task_type == 'master' and trainer.saves_training():
            serve.export_saved_model(checkpoint_dir, env.model_dir)

    if trainer.task_type != 'master':
        _wait_until_master_is_down(_get_master(cluster_config))
def test_export_saved_model_from_s3(makedirs, boto_session, serve):
    """Exporting from S3 downloads every model file and creates its folders.

    The recorded calls must match exactly, including order:
    * ``download_file`` fetches the graph, the variables index and the
      vocabulary asset from ``bucket/test/1`` into ``a/path/1``;
    * ``makedirs`` creates ``a/path/1`` and its ``variables`` and ``assets``
      subdirectories.
    """
    serve.export_saved_model('s3://bucket/test', 'a/path', s3=boto_session)

    artifacts = ('saved_model.pb', 'variables/variables.index', 'assets/vocabulary.txt')
    expected_boto_calls = [
        call('bucket', 'test/1/' + artifact, 'a/path/1/' + artifact)
        for artifact in artifacts
    ]
    expected_makedirs_calls = [
        call('a/path/1' + subdir) for subdir in ('', '/variables', '/assets')
    ]

    assert boto_session.download_file.mock_calls == expected_boto_calls
    assert makedirs.mock_calls == expected_makedirs_calls
def test_export_saved_model_from_s3(makedirs, boto_session):
    """S3 export: exact download and directory-creation call sequences.

    Uses the module-level ``serve``. Asserts that ``download_file`` and
    ``makedirs`` saw precisely the listed calls, in order, and nothing else.
    """
    serve.export_saved_model('s3://bucket/test', 'a/path', s3=boto_session)

    model_files = ['saved_model.pb', 'variables/variables.index', 'assets/vocabulary.txt']
    local_dirs = ['a/path/1', 'a/path/1/variables', 'a/path/1/assets']

    assert boto_session.download_file.mock_calls == [
        call('bucket', 'test/1/{}'.format(name), 'a/path/1/{}'.format(name))
        for name in model_files
    ]
    assert makedirs.mock_calls == [call(directory) for directory in local_dirs]
def train():
    """Train on this host, then export (master) or wait for the master (others).

    Flow:
      1. Load ``cs.TrainingEnvironment()`` and the step-count hyperparameters
         ``training_steps`` (default 1000) and ``evaluation_steps``
         (default 100).
      2. Set ``S3_REQUEST_TIMEOUT_MSEC`` from ``s3_checkpoint_save_timeout``
         (default 60000).
      3. Download the user module if its archive is on S3. Then install its
         requirements and import it.
      4. Build the trainer and its TF config. Start a parameter server when
         ``env.hosts`` has more than one entry. Publish the config and call
         ``configure_mkl()``.
      5. Train. Then the master exports to ``env.model_dir`` when the
         checkpoint directory differs and ``saves_training()`` is truthy.
         Non-master tasks call ``_wait_until_master_is_down``.
    """
    env = cs.TrainingEnvironment()

    checkpoint_dir = _get_checkpoint_dir(env)
    steps_to_train = env.hyperparameters.get('training_steps', 1000)
    steps_to_evaluate = env.hyperparameters.get('evaluation_steps', 100)

    # https://github.com/tensorflow/tensorflow/issues/15868
    # The C++ SDK behind TensorFlow's S3 filesystem times requests out after
    # 3 seconds by default, which is not enough for large checkpoints.
    timeout_setting = env.hyperparameters.get('s3_checkpoint_save_timeout', 60000)
    os.environ['S3_REQUEST_TIMEOUT_MSEC'] = str(timeout_setting)

    if env.user_script_archive.lower().startswith('s3://'):
        env.download_user_module()
    env.pip_install_requirements()

    script = env.import_user_module()

    wrapper = _get_trainer_class()(
        customer_script=script,
        current_host=env.current_host,
        hosts=env.hosts,
        train_steps=steps_to_train,
        eval_steps=steps_to_evaluate,
        input_channels=env.channel_dirs,
        model_path=checkpoint_dir,
        output_path=env.output_dir,
        customer_params=env.hyperparameters,
    )

    config = wrapper.build_tf_config()

    # Parameter servers are only created for distributed (multi-host) runs.
    if len(env.hosts) > 1:
        _run_ps_server(env.current_host, env.hosts, config)

    save_tf_config_env_var(config)
    configure_mkl()

    wrapper.train()

    # Short-circuit order matters: saves_training() is only consulted for the
    # master when the checkpoint dir is not already the model dir.
    should_export = (checkpoint_dir != env.model_dir
                     and wrapper.task_type == 'master'
                     and wrapper.saves_training())
    if should_export:
        serve.export_saved_model(checkpoint_dir, env.model_dir)

    if wrapper.task_type != 'master':
        _wait_until_master_is_down(_get_master(config))