コード例 #1
0
def test_build_tf_config_with_multiple_hosts(trainer):
    """With four hosts, 'algo-1' is master and the current host 'algo-3' is worker #1.

    The expected cluster spec places a 'ps' entry on port 2223 for every host,
    while master and worker entries use port 2222.
    """
    hosts = ['algo-1', 'algo-2', 'algo-3', 'algo-4']
    current_host = 'algo-3'

    wrapper = trainer.Trainer(customer_script=mock_script,
                              current_host=current_host,
                              hosts=hosts,
                              model_path=model_path)

    cluster_spec = dict(
        master=['algo-1:2222'],
        ps=['algo-1:2223', 'algo-2:2223', 'algo-3:2223', 'algo-4:2223'],
        worker=['algo-2:2222', 'algo-3:2222', 'algo-4:2222'],
    )
    # 'algo-3' sits at position 1 of the worker list above.
    expected = dict(environment='cloud',
                    cluster=cluster_spec,
                    task=dict(index=1, type='worker'))

    assert wrapper.build_tf_config() == expected
    assert wrapper.task_type == 'worker'
コード例 #2
0
def test_configure_s3_file_system(os_env, botocore, boto_client, trainer):
    """Constructing a Trainer with an S3 model_path configures S3 access.

    The test expects a single ``get_bucket_location(Bucket='my')`` lookup.
    It also expects ``S3_USE_HTTPS='1'`` and ``S3_REGION=<LocationConstraint>``
    to be written through ``os_env.__setitem__`` (presumably a patched
    ``os.environ``; confirm against the fixture).
    """
    trainer.Trainer(customer_script=mock_script,
                    current_host=current_host,
                    hosts=hosts,
                    model_path='s3://my/s3/path')

    location_lookup = boto_client('s3').get_bucket_location
    location_lookup.assert_called_once_with(Bucket='my')

    # Read the region only after the call-count check above, so this extra
    # mock call does not affect it.
    region = boto_client('s3').get_bucket_location()['LocationConstraint']
    expected_env_writes = [call('S3_USE_HTTPS', '1'), call('S3_REGION', region)]

    os_env.__setitem__.assert_has_calls(expected_env_writes, any_order=True)
コード例 #3
0
def test_trainer_keras_model_fn(os_environ, botocore, boto3, trainer, modules):
    '''
    Ensure that a customer script defining keras_model_fn (instead of model_fn)
    is converted with keras.estimator.model_to_estimator.

    Also ensure that the customer's train_input_fn, eval_input_fn and
    serving_input_fn are invoked with the expected params.
    '''
    customer_script = MagicMock(spec=[
        'keras_model_fn', 'train_input_fn', 'eval_input_fn', 'serving_input_fn'
    ])

    _trainer = trainer.Trainer(customer_script=customer_script,
                               current_host=current_host,
                               hosts=hosts,
                               model_path='s3://my/s3/path',
                               customer_params={
                                   'training_steps': 10,
                                   'num_gpu': 20
                               },
                               training_path='mytrainingpath')

    # Have learn_runner.run call the experiment function directly, so the
    # Trainer's experiment construction actually runs under the mocks.
    modules.learn_runner.run.side_effect = lambda experiment_fn, training_path: experiment_fn(
        training_path)
    # Helpers defined elsewhere in this file; presumably they invoke the input
    # functions so their arguments can be asserted below. Confirm.
    modules.Experiment.side_effect = execute_input_functions
    modules.saved_model_export_utils.make_export_strategy.side_effect = make_export_strategy_fn

    _trainer.train()

    # customer_params plus two keys this test never passes in
    # (min_eval_frequency, save_checkpoints_secs). Presumably these are
    # defaults filled in by the Trainer. Confirm.
    expected_params = {
        'num_gpu': 20,
        'min_eval_frequency': 1000,
        'training_steps': 10,
        'save_checkpoints_secs': 300
    }

    # Calling the mocks here yields their return_value, i.e. the same objects
    # the Trainer obtained from them.
    modules.keras.estimator.model_to_estimator.assert_called_with(
        config=modules.RunConfig(),
        keras_model=customer_script.keras_model_fn(),
    )

    modules.learn_runner.run.assert_called()
    modules.Experiment.assert_called()

    customer_script.train_input_fn.assert_called_with('mytrainingpath',
                                                      expected_params)
    customer_script.eval_input_fn.assert_called_with('mytrainingpath',
                                                     expected_params)
    customer_script.serving_input_fn.assert_called_with(expected_params)
コード例 #4
0
def test_trainer_experiment_params(os_environ, botocore, boto3, trainer,
                                   modules):
    '''
    Ensure that Experiment-related customer_params are forwarded to Experiment.

    The forwarded keys are min_eval_frequency, local_eval_frequency,
    eval_delay_secs, continuous_eval_throttle_secs and
    train_steps_per_iteration. The eval_steps constructor argument must be
    forwarded as well.
    '''
    customer_script = MagicMock(spec=[
        'model_fn', 'train_input_fn', 'eval_input_fn', 'serving_input_fn'
    ])

    _trainer = trainer.Trainer(customer_script=customer_script,
                               current_host=current_host,
                               hosts=hosts,
                               model_path='s3://my/s3/path',
                               eval_steps=23,
                               customer_params={
                                   'min_eval_frequency': 2,
                                   'local_eval_frequency': 3,
                                   'eval_delay_secs': 7,
                                   'continuous_eval_throttle_secs': 25,
                                   'train_steps_per_iteration': 13
                               },
                               training_path='mytrainingpath')

    # Have learn_runner.run call the experiment function directly, so the
    # Experiment mock records the kwargs the Trainer builds.
    modules.learn_runner.run.side_effect = lambda experiment_fn, training_path: experiment_fn(
        training_path)

    _trainer.train()

    # Only the explicitly configured values are pinned; the input functions,
    # export strategies and train_steps are matched with ANY.
    modules.Experiment.assert_called_with(
        estimator=modules.estimator.Estimator(),
        train_input_fn=ANY,
        eval_input_fn=ANY,
        export_strategies=ANY,
        train_steps=ANY,
        eval_steps=23,
        min_eval_frequency=2,
        local_eval_frequency=3,
        eval_delay_secs=7,
        continuous_eval_throttle_secs=25,
        train_steps_per_iteration=13)
コード例 #5
0
def test_build_tf_config_with_one_host(trainer):
    """With a single host, the cluster is just a master, and the current host is that master.

    The expected cluster spec contains no 'ps' or 'worker' entries.
    """
    hosts = ['algo-1']
    current_host = 'algo-1'

    wrapper = trainer.Trainer(customer_script=mock_script,
                              current_host=current_host,
                              hosts=hosts,
                              model_path=model_path)

    master_only = dict(master=['algo-1:2222'])
    expected = dict(environment='cloud',
                    cluster=master_only,
                    task=dict(index=0, type='master'))

    assert wrapper.build_tf_config() == expected
    assert wrapper.task_type == 'master'
コード例 #6
0
def test_trainer_run_config_params(os_environ, botocore, boto3, trainer,
                                   modules):
    '''
    Ensure that RunConfig-related customer_params are forwarded to RunConfig.

    The forwarded keys are save_summary_steps, save_checkpoints_secs,
    save_checkpoints_steps, keep_checkpoint_max, keep_checkpoint_every_n_hours
    and log_step_count_steps. RunConfig's model_dir must also be set to the
    Trainer's model_path.
    '''
    customer_script = MagicMock(spec=[
        'model_fn', 'train_input_fn', 'eval_input_fn', 'serving_input_fn'
    ])

    _trainer = trainer.Trainer(customer_script=customer_script,
                               current_host=current_host,
                               hosts=hosts,
                               model_path='s3://my/s3/path',
                               eval_steps=23,
                               customer_params={
                                   'save_summary_steps': 1,
                                   'save_checkpoints_secs': 2,
                                   'save_checkpoints_steps': 3,
                                   'keep_checkpoint_max': 4,
                                   'keep_checkpoint_every_n_hours': 5,
                                   'log_step_count_steps': 6
                               },
                               training_path='mytrainingpath')
    # Have learn_runner.run call the experiment function directly, so the
    # RunConfig mock records the kwargs the Trainer builds.
    modules.learn_runner.run.side_effect = lambda experiment_fn, training_path: experiment_fn(
        training_path)

    _trainer.train()

    modules.RunConfig.assert_called_with(save_summary_steps=1,
                                         save_checkpoints_secs=2,
                                         save_checkpoints_steps=3,
                                         keep_checkpoint_max=4,
                                         keep_checkpoint_every_n_hours=5,
                                         log_step_count_steps=6,
                                         model_dir='s3://my/s3/path')