Пример #1
0
def test_key_selector_arr(mode, record_arr):
    """Load an array-valued record through a single-input/single-output KeySelector.

    The selector maps output name "img" to record key "x" and "label" to
    record key "y". In ``RecordMode.SCORE`` only the inputs dict is expected
    back; otherwise an ``(inputs, outputs)`` pair is expected.

    ``mode`` and ``record_arr`` are presumably pytest fixtures/parameters
    defined elsewhere. The expected arrays below assume ``record_arr`` holds
    x = zeros of shape (28, 28, 1) and y = [7] (inferred from the assertion;
    TODO confirm against the fixture).
    """
    selector = KeySelector(mode, {"inputs": {"img": "x"}, "outputs": {"label": "y"}})

    batch = selector.load(record_arr)
    features = {"img": np.zeros((28, 28, 1))}
    targets = {"label": np.array([7])}
    # Scoring mode yields features only, so drop the targets from the expectation.
    wanted = (features, ) if mode == RecordMode.SCORE else (features, targets)
    _assert_batch_equal(batch, wanted)
Пример #2
0
def test_key_selector_invalid_values(mode, inputs, outputs, sample_weights,
                                     err):
    """Check that constructing a KeySelector from an invalid configuration raises ``err``.

    ``inputs``, ``outputs``, ``sample_weights`` and the expected exception
    type ``err`` come from a parametrization defined elsewhere (presumably
    ``pytest.mark.parametrize``). They are assembled into the same config
    dict shape the other KeySelector tests use.
    """
    config = dict(inputs=inputs, outputs=outputs, sample_weights=sample_weights)
    # The constructor itself is expected to reject the configuration.
    with pytest.raises(err):
        KeySelector(mode, config)
Пример #3
0
def test_key_selector(mode, record):
    """Exercise KeySelector on a record, first without and then with sample weights.

    Both configurations select the same inputs ("x1" from keys "a" and "b",
    "x2" from key "c") and outputs ("y1" from "d", "y2" from "e"). The second
    configuration also maps each output to a sample-weight key and is invoked
    by calling the selector directly instead of through ``load``. In
    ``RecordMode.SCORE`` only the inputs dict is expected back.

    ``mode`` and ``record`` are presumably pytest fixtures/parameters defined
    elsewhere. The expected values assume ``record`` holds a=1, b=2,
    c="three", d=4, e="five" (inferred from the assertions; TODO confirm
    against the fixture).
    """

    def io_params():
        # Build a fresh dict on every call so each KeySelector gets its own config objects.
        return {
            "inputs": {"x1": ["a", "b"], "x2": ["c"]},
            "outputs": {"y1": ["d"], "y2": ["e"]},
        }

    def expected_io():
        # Expected (inputs, outputs) pair for the fixture record.
        return (
            {"x1": np.array([1, 2]), "x2": np.array(["three"])},
            {"y1": np.array([4]), "y2": np.array(["five"])},
        )

    def trim_for_mode(batch):
        # Scoring mode yields features only, so keep just the first element.
        return batch[:1] if mode == RecordMode.SCORE else batch

    # Case 1: no sample weights; read via the explicit ``load`` method.
    selector = KeySelector(mode, io_params())
    _assert_batch_equal(selector.load(record), trim_for_mode(expected_io()))

    # Case 2: per-output sample weights; invoke the selector as a callable.
    weighted_params = io_params()
    weighted_params["sample_weights"] = {"y1": "a", "y2": "b"}
    selector = KeySelector(mode, weighted_params)
    weighted_expected = expected_io() + ({"y1": 1, "y2": 2}, )
    _assert_batch_equal(selector(record), trim_for_mode(weighted_expected))
Пример #4
0
def test_key_selector_invalid_keys(mode, params):
    """Check that the KeySelector constructor raises ``KeyError`` for the given ``params``.

    ``mode`` and ``params`` are supplied by fixtures/parametrization defined
    elsewhere. Judging by the test name, each ``params`` case presumably
    contains an invalid key (TODO confirm).
    """
    # Callable form of pytest.raises: invokes KeySelector(mode, params) and
    # fails the test unless it raises KeyError.
    pytest.raises(KeyError, KeySelector, mode, params)