示例#1
0
def test_format_inputs_to_input_config_list_duplicate_channel():
    """A list holding the same RecordSet twice must be rejected with ``ValueError``.

    The error message is expected to mention that duplicate channels are not
    allowed.
    """
    record = RecordSet(s3_data=BUCKET_NAME, num_records=1, feature_dim=1)
    inputs = [record, record]

    with pytest.raises(ValueError) as ex:
        _Job._format_inputs_to_input_config(inputs)

    # Check the raised exception itself (ex.value). str() of the ExceptionInfo
    # wrapper is a debugging representation whose format depends on the pytest
    # version, so it is not a reliable place to look for the message.
    assert 'Duplicate channels not allowed.' in str(ex.value)
示例#2
0
def test_format_inputs_to_input_config_list_not_all_records():
    """A list mixing a RecordSet with a plain string must raise ``ValueError``.

    The error message is expected to state that lists are compatible only with
    RecordSets.
    """
    records = RecordSet(s3_data=BUCKET_NAME, num_records=1, feature_dim=1)
    inputs = [records, 'mock']

    with pytest.raises(ValueError) as ex:
        _Job._format_inputs_to_input_config(inputs)

    # Check the raised exception itself (ex.value). str() of the ExceptionInfo
    # wrapper is a debugging representation whose format depends on the pytest
    # version, so it is not a reliable place to look for the message.
    assert 'List compatible only with RecordSets.' in str(ex.value)
def test_format_inputs_to_input_config_list_not_all_records():
    """A list mixing a RecordSet with a plain string must raise ``ValueError``.

    The error message is expected to state that lists are compatible only with
    RecordSets or FileSystemRecordSets.
    """
    records = RecordSet(s3_data=BUCKET_NAME, num_records=1, feature_dim=1)
    inputs = [records, "mock"]

    with pytest.raises(ValueError) as ex:
        _Job._format_inputs_to_input_config(inputs)

    # Check the raised exception itself (ex.value). str() of the ExceptionInfo
    # wrapper is a debugging representation whose format depends on the pytest
    # version (and may be truncated), so it is not a reliable place to look
    # for the message.
    assert "List compatible only with RecordSets or FileSystemRecordSets." in str(ex.value)
示例#4
0
def test_format_input_s3_input():
    """An ``s3_input`` with all optional fields set formats to one "training" channel."""
    source = s3_input(
        "s3://foo/bar",
        distribution="ShardedByS3Key",
        compression="gzip",
        content_type="whizz",
        record_wrapping="bang",
    )

    formatted = _Job._format_inputs_to_input_config(source)

    # Every option passed to s3_input should surface under its channel-config key.
    expected_channel = {
        "ChannelName": "training",
        "CompressionType": "gzip",
        "ContentType": "whizz",
        "RecordWrapperType": "bang",
        "DataSource": {
            "S3DataSource": {
                "S3Uri": "s3://foo/bar",
                "S3DataType": "S3Prefix",
                "S3DataDistributionType": "ShardedByS3Key",
            }
        },
    }
    assert formatted == [expected_channel]
示例#5
0
def test_dict_of_mixed_input_types():
    """A dict mixing a plain S3 URI and a ``TrainingInput`` yields one channel per key."""

    def s3_channel(name, uri):
        # Channel config expected for an S3 prefix given no extra options.
        return {
            "ChannelName": name,
            "DataSource": {
                "S3DataSource": {
                    "S3DataDistributionType": "FullyReplicated",
                    "S3DataType": "S3Prefix",
                    "S3Uri": uri,
                }
            },
        }

    mixed_inputs = {"a": "s3://foo/bar", "b": TrainingInput("s3://whizz/bang")}
    actual = _Job._format_inputs_to_input_config(mixed_inputs)

    # Key both sides by channel name so the (arbitrary) list order is ignored.
    actual_by_name = {channel["ChannelName"]: channel for channel in actual}
    expected_by_name = {
        "a": s3_channel("a", "s3://foo/bar"),
        "b": s3_channel("b", "s3://whizz/bang"),
    }
    assert actual_by_name == expected_by_name
示例#6
0
def test_format_inputs_to_input_config_dict():
    """A ``{channel_name: uri}`` dict produces a channel whose S3Uri is that uri."""
    channel_map = {'train': BUCKET_NAME}

    formatted = _Job._format_inputs_to_input_config(channel_map)

    s3_source = formatted[0]['DataSource']['S3DataSource']
    assert s3_source['S3Uri'] == channel_map['train']
示例#7
0
def test_format_inputs_to_input_config_dict():
    """Formatting a single-entry dict keeps the entry's value as the channel S3Uri."""
    named_inputs = {"train": BUCKET_NAME}

    first_channel = _Job._format_inputs_to_input_config(named_inputs)[0]

    actual_uri = first_channel["DataSource"]["S3DataSource"]["S3Uri"]
    assert actual_uri == named_inputs["train"]
示例#8
0
def test_format_inputs_to_input_config_record_set():
    """A bare RecordSet's S3 location and data type are carried into the first channel."""
    record_set = RecordSet(s3_data=BUCKET_NAME, num_records=1, feature_dim=1)

    formatted = _Job._format_inputs_to_input_config(record_set)

    s3_source = formatted[0]['DataSource']['S3DataSource']
    assert s3_source['S3Uri'] == record_set.s3_data
    assert s3_source['S3DataType'] == record_set.s3_data_type
示例#9
0
def test_format_inputs_to_input_config_training_input():
    """A ``TrainingInput``'s configured S3 URI appears unchanged in the first channel."""
    training_input = TrainingInput(BUCKET_NAME)

    formatted = _Job._format_inputs_to_input_config(training_input)

    actual_uri = formatted[0]["DataSource"]["S3DataSource"]["S3Uri"]
    configured_uri = training_input.config["DataSource"]["S3DataSource"]["S3Uri"]
    assert actual_uri == configured_uri
示例#10
0
def test_format_inputs_to_input_config_s3_input():
    """An ``s3_input``'s configured S3 URI appears unchanged in the first channel."""
    source = s3_input(BUCKET_NAME)

    formatted = _Job._format_inputs_to_input_config(source)

    actual_uri = formatted[0]['DataSource']['S3DataSource']['S3Uri']
    configured_uri = source.config['DataSource']['S3DataSource']['S3Uri']
    assert actual_uri == configured_uri
示例#11
0
def test_dict_of_mixed_input_types():
    """A dict mixing a plain S3 URI and an ``s3_input`` yields one channel per key."""

    def s3_channel(name, uri):
        # Channel config expected for an S3 prefix given no extra options.
        return {
            'ChannelName': name,
            'DataSource': {
                'S3DataSource': {
                    'S3DataDistributionType': 'FullyReplicated',
                    'S3DataType': 'S3Prefix',
                    'S3Uri': uri,
                }
            },
        }

    mixed_inputs = {'a': 's3://foo/bar', 'b': s3_input('s3://whizz/bang')}
    actual = _Job._format_inputs_to_input_config(mixed_inputs)

    # Key both sides by channel name so the (arbitrary) list order is ignored.
    actual_by_name = {channel['ChannelName']: channel for channel in actual}
    expected_by_name = {
        'a': s3_channel('a', 's3://foo/bar'),
        'b': s3_channel('b', 's3://whizz/bang'),
    }
    assert actual_by_name == expected_by_name
示例#12
0
def test_format_input_multiple_channels():
    """A dict of two S3 URIs is formatted into two channels named after the dict keys."""
    uri_by_channel = {"a": "s3://blah/blah", "b": "s3://foo/bar"}

    actual = _Job._format_inputs_to_input_config(uri_by_channel)

    # Build the expected config per channel name; comparing name-keyed maps
    # ignores the (arbitrary) order of the returned list.
    expected_by_name = {}
    for name, uri in (("a", "s3://blah/blah"), ("b", "s3://foo/bar")):
        expected_by_name[name] = {
            "ChannelName": name,
            "DataSource": {
                "S3DataSource": {
                    "S3DataDistributionType": "FullyReplicated",
                    "S3DataType": "S3Prefix",
                    "S3Uri": uri,
                }
            },
        }

    actual_by_name = {channel["ChannelName"]: channel for channel in actual}
    assert actual_by_name == expected_by_name
示例#13
0
def test_format_inputs_to_input_config_list():
    """A one-element list of RecordSet yields a channel mirroring that RecordSet's S3 data."""
    record_set = RecordSet(s3_data=BUCKET_NAME, num_records=1, feature_dim=1)

    formatted = _Job._format_inputs_to_input_config([record_set])

    s3_source = formatted[0]["DataSource"]["S3DataSource"]
    assert s3_source["S3Uri"] == record_set.s3_data
    assert s3_source["S3DataType"] == record_set.s3_data_type
示例#14
0
def test_format_input_single_unamed_channel():
    """A bare S3 URI string is formatted as a single channel named "training"."""
    formatted = _Job._format_inputs_to_input_config("s3://blah/blah")

    expected_source = {
        "S3DataDistributionType": "FullyReplicated",
        "S3DataType": "S3Prefix",
        "S3Uri": "s3://blah/blah",
    }
    expected_channel = {
        "ChannelName": "training",
        "DataSource": {"S3DataSource": expected_source},
    }
    assert formatted == [expected_channel]
示例#15
0
def test_format_input_single_unamed_channel():
    """A bare S3 URI string is formatted as a single channel named 'training'."""
    uri = 's3://blah/blah'

    formatted = _Job._format_inputs_to_input_config(uri)

    expected_channel = {
        'ChannelName': 'training',
        'DataSource': {
            'S3DataSource': {
                'S3DataDistributionType': 'FullyReplicated',
                'S3DataType': 'S3Prefix',
                'S3Uri': 's3://blah/blah',
            },
        },
    }
    assert formatted == [expected_channel]
示例#16
0
def test_format_input_s3_input():
    """An ``s3_input`` with all optional fields set formats to one 'training' channel."""
    source = s3_input(
        's3://foo/bar',
        distribution='ShardedByS3Key',
        compression='gzip',
        content_type='whizz',
        record_wrapping='bang',
    )

    formatted = _Job._format_inputs_to_input_config(source)

    # Every option passed to s3_input should surface under its channel-config key.
    expected_channel = {
        'ChannelName': 'training',
        'CompressionType': 'gzip',
        'ContentType': 'whizz',
        'RecordWrapperType': 'bang',
        'DataSource': {
            'S3DataSource': {
                'S3Uri': 's3://foo/bar',
                'S3DataType': 'S3Prefix',
                'S3DataDistributionType': 'ShardedByS3Key',
            },
        },
    }
    assert formatted == [expected_channel]
示例#17
0
def test_format_inputs_to_input_config_file_system_record_set():
    """A ``FileSystemRecordSet`` is formatted into a read-only FileSystemDataSource channel."""
    record_set = FileSystemRecordSet(
        file_system_id="fs-0a48d2a1",
        file_system_type="EFS",
        directory_path="ipinsights",
        num_records=1,
        feature_dim=1,
    )

    formatted = _Job._format_inputs_to_input_config(record_set)

    # Check each field individually (not whole-dict equality) so any extra
    # keys in the data source do not affect the outcome.
    expected_fields = (
        ("DirectoryPath", "ipinsights"),
        ("FileSystemId", "fs-0a48d2a1"),
        ("FileSystemType", "EFS"),
        ("FileSystemAccessMode", "ro"),
    )
    for field, expected_value in expected_fields:
        assert formatted[0]["DataSource"]["FileSystemDataSource"][field] == expected_value
示例#18
0
def test_format_inputs_to_input_config_string():
    """A plain string input becomes the S3Uri of the first formatted channel."""
    uri = BUCKET_NAME

    first_channel = _Job._format_inputs_to_input_config(uri)[0]

    s3_source = first_channel["DataSource"]["S3DataSource"]
    assert s3_source["S3Uri"] == uri
示例#19
0
def test_format_inputs_none():
    """Passing ``inputs=None`` yields ``None`` rather than a channel list."""
    formatted = _Job._format_inputs_to_input_config(inputs=None)
    assert formatted is None
示例#20
0
def test_format_inputs_to_input_config_exception():
    """An integer input is expected to be rejected with ``ValueError``."""
    unsupported_input = 1
    with pytest.raises(ValueError):
        _Job._format_inputs_to_input_config(unsupported_input)
示例#21
0
def test_format_inputs_to_input_config_exception():
    """Formatting the integer ``1`` as job inputs must raise ``ValueError``."""
    not_an_input = 1

    # The call itself is the only statement expected to raise.
    with pytest.raises(ValueError):
        _Job._format_inputs_to_input_config(not_an_input)
示例#22
0
def test_unsupported_type_in_dict():
    """A dict whose channel value is an integer must raise ``ValueError``."""
    bad_inputs = {'a': 66}

    with pytest.raises(ValueError):
        _Job._format_inputs_to_input_config(bad_inputs)
示例#23
0
def test_format_inputs_to_input_config_dict():
    """The value stored under 'train' becomes the S3Uri of the first formatted channel."""
    train_inputs = {'train': BUCKET_NAME}

    formatted = _Job._format_inputs_to_input_config(train_inputs)
    data_source = formatted[0]['DataSource']

    assert data_source['S3DataSource']['S3Uri'] == train_inputs['train']
示例#24
0
def test_unsupported_type_in_dict():
    """An integer used as a channel value in the inputs dict must raise ``ValueError``."""
    inputs_with_int_value = {"a": 66}

    with pytest.raises(ValueError):
        _Job._format_inputs_to_input_config(inputs_with_int_value)
示例#25
0
def test_format_inputs_to_input_config_string():
    """A plain string input becomes the S3Uri of the first formatted channel."""
    location = BUCKET_NAME

    formatted = _Job._format_inputs_to_input_config(location)

    actual_uri = formatted[0]['DataSource']['S3DataSource']['S3Uri']
    assert actual_uri == location