コード例 #1
0
def test_args_mxnet_host_non_defaults():
    """Parse ``mxnet host`` with non-default options and check dispatch."""
    command_line = '{} mxnet host --role-name role {} {}'.format(
        LOG_ARGS, COMMON_ARGS, HOST_ARGS)
    parsed = cli.parse_arguments(command_line.split())

    assert_common_non_defaults(parsed)
    assert_host_non_defaults(parsed)
    # The subcommand must route to the mxnet ``host`` entry point.
    handler = parsed.func
    assert (handler.__module__, handler.__name__) == (
        'sagemaker.cli.mxnet', 'host')
コード例 #2
0
def test_mxnet_host(session, upload_model, model):
    """Run ``mxnet host`` end to end and check every collaborator was used.

    ``session``, ``upload_model`` and ``model`` are mocks provided from outside
    this view (presumably patch decorators or fixtures -- confirm).
    """
    parsed = cli.parse_arguments('mxnet host --role-name role'.split())
    parsed.func(parsed)

    # Session creation, model upload, model construction and deploy must all
    # have been invoked by the host command.
    expected_calls = (session, upload_model, model, model.return_value.deploy)
    for mock_obj in expected_calls:
        mock_obj.assert_called()
コード例 #3
0
def test_args_mxnet_train_non_defaults():
    """Parse ``mxnet train`` with non-default options and check dispatch."""
    command_line = "{} mxnet train --role-name role {} {}".format(
        LOG_ARGS, COMMON_ARGS, TRAIN_ARGS)
    parsed = cli.parse_arguments(command_line.split())

    assert_common_non_defaults(parsed)
    assert_train_non_defaults(parsed)
    # The subcommand must route to the mxnet ``train`` entry point.
    handler = parsed.func
    assert (handler.__module__, handler.__name__) == (
        "sagemaker.cli.mxnet", "train")
コード例 #4
0
def test_args_tensorflow_train_defaults():
    """``tensorflow train`` with only a role name uses default settings."""
    argv = 'tensorflow train --role-name role'.split()
    parsed = cli.parse_arguments(argv)

    assert_common_defaults(parsed)
    assert_train_defaults(parsed)
    # The TensorFlow-specific step counts are left unset when not given.
    for option in ('training_steps', 'evaluation_steps'):
        assert getattr(parsed, option) is None
    handler = parsed.func
    assert (handler.__module__, handler.__name__) == (
        'sagemaker.cli.tensorflow', 'train')
コード例 #5
0
def test_args_tensorflow_host_non_defaults():
    """Parse ``tensorflow host`` with non-default options and check dispatch."""
    command_line = "{} tensorflow host --role-name role {} {}".format(
        LOG_ARGS, COMMON_ARGS, HOST_ARGS)
    parsed = cli.parse_arguments(command_line.split())

    assert_common_non_defaults(parsed)
    assert_host_non_defaults(parsed)
    # The subcommand must route to the tensorflow ``host`` entry point.
    handler = parsed.func
    assert (handler.__module__, handler.__name__) == (
        "sagemaker.cli.tensorflow", "host")
コード例 #6
0
def test_args_tensorflow_train_non_defaults():
    """Non-default ``tensorflow train`` options, including step counts, parse."""
    template = '{} tensorflow train --role-name role --training-steps 10 --evaluation-steps 5 {} {}'
    command_line = template.format(LOG_ARGS, COMMON_ARGS, TRAIN_ARGS)
    parsed = cli.parse_arguments(command_line.split())

    assert_common_non_defaults(parsed)
    assert_train_non_defaults(parsed)
    # The step counts given on the command line are parsed as integers.
    assert (parsed.training_steps, parsed.evaluation_steps) == (10, 5)
    handler = parsed.func
    assert (handler.__module__, handler.__name__) == (
        'sagemaker.cli.tensorflow', 'train')
コード例 #7
0
def test_args_mxnet_host_defaults():
    """``mxnet host`` with only a role name uses default settings."""
    argv = 'mxnet host --role-name role'.split()
    parsed = cli.parse_arguments(argv)

    assert_common_defaults(parsed)
    assert_host_defaults(parsed)
    handler = parsed.func
    assert (handler.__module__, handler.__name__) == (
        'sagemaker.cli.mxnet', 'host')
コード例 #8
0
def test_mxnet_train(session, estimator):
    """Run ``mxnet train`` end to end and check the estimator was fitted.

    ``session`` and ``estimator`` are mocks provided from outside this view
    (presumably patch decorators or fixtures -- confirm).
    """
    parsed = cli.parse_arguments('mxnet train --role-name role'.split())
    parsed.func(parsed)

    # Data upload via the session, estimator construction and ``fit`` must
    # all have been invoked by the train command.
    expected_calls = (
        session.return_value.upload_data,
        estimator,
        estimator.return_value.fit,
    )
    for mock_obj in expected_calls:
        mock_obj.assert_called()
コード例 #9
0
def test_args_invalid_train_args_in_host():
    """``tensorflow host`` must reject the ``--hyperparameters`` option."""
    argv = 'tensorflow host --role-name role --hyperparameters foo.json'.split()
    with pytest.raises(SystemExit):
        cli.parse_arguments(argv)
コード例 #10
0
def test_args_invalid_host_args_in_train():
    """``mxnet train`` must reject the ``--env`` option."""
    argv = 'mxnet train --role-name role --env FOO=bar'.split()
    with pytest.raises(SystemExit):
        cli.parse_arguments(argv)
コード例 #11
0
def test_args_invalid_mxnet_python():
    """Parsing ``mxnet train`` with the extra tokens below must exit."""
    # NOTE(review): 'nython py2' looks like a typo for a --python option; as
    # written, parsing likely fails on unrecognized arguments rather than on
    # an invalid python version. Confirm the intended argv.
    argv = 'mxnet train --role-name role nython py2'.split()
    with pytest.raises(SystemExit):
        cli.parse_arguments(argv)
コード例 #12
0
def test_args_invalid_args():
    """An unknown option (``--notdata``) makes argument parsing exit."""
    argv = 'tensorflow train --role-name role --notdata foo'.split()
    with pytest.raises(SystemExit):
        cli.parse_arguments(argv)
コード例 #13
0
def test_args_invalid_subcommand():
    """An unknown subcommand (``drain``) makes argument parsing exit."""
    argv = 'mxnet drain'.split()
    with pytest.raises(SystemExit):
        cli.parse_arguments(argv)
コード例 #14
0
def test_args_invalid_framework():
    """An unknown framework name makes argument parsing exit."""
    argv = 'fakeframework train --role-name role'.split()
    with pytest.raises(SystemExit):
        cli.parse_arguments(argv)
コード例 #15
0
def test_args_tensorflow_host_defaults():
    """``tensorflow host`` with only a role name uses default settings."""
    argv = 'tensorflow host --role-name role'.split()
    parsed = cli.parse_arguments(argv)

    assert_common_defaults(parsed)
    assert_host_defaults(parsed)
    handler = parsed.func
    assert (handler.__module__, handler.__name__) == (
        'sagemaker.cli.tensorflow', 'host')
コード例 #16
0
def test_args_mxnet_train_defaults():
    """``mxnet train`` with only a role name uses default settings."""
    argv = "mxnet train --role-name role".split()
    parsed = cli.parse_arguments(argv)

    assert_common_defaults(parsed)
    assert_train_defaults(parsed)
    handler = parsed.func
    assert (handler.__module__, handler.__name__) == (
        "sagemaker.cli.mxnet", "train")