コード例 #1
0
def test_dot():
    """An input name containing a period (``input.1``) is parsed intact."""
    parsed = parse_shape_string("input.1:[10,10,10]")
    assert parsed == {"input.1": [10, 10, 10]}
コード例 #2
0
def test_invalid_colon():
    """A shape string with a stray ``:``-prefixed token must be rejected.

    ``parse_shape_string`` is expected to raise ``argparse.ArgumentTypeError``.
    """
    malformed = "gpu_0/data_0:5,10 :test:10,10"
    pytest.raises(argparse.ArgumentTypeError, parse_shape_string, malformed)
コード例 #3
0
def test_invalid_slashes(shape_string):
    """Each supplied ``shape_string`` must raise ``argparse.ArgumentTypeError``.

    ``shape_string`` is injected by the test's parametrization. Per the test
    name, the inputs presumably use ``/`` in an invalid way; the cases
    themselves are defined outside this function.
    """
    pytest.raises(argparse.ArgumentTypeError, parse_shape_string, shape_string)
コード例 #4
0
def test_invalid_pattern():
    """A non-numeric dimension (``a`` in ``[a,10]``) must be rejected."""
    with pytest.raises(argparse.ArgumentTypeError):
        parse_shape_string("input:[a,10]")
コード例 #5
0
def test_invalid_separators():
    """Unbracketed shapes separated only by whitespace must be rejected."""
    unbracketed = "input:5,10 input2:10,10"
    pytest.raises(argparse.ArgumentTypeError, parse_shape_string, unbracketed)
コード例 #6
0
def test_multiple_valid_gpu_inputs():
    """Two slash-containing input names, one with a ``-1`` dimension, parse correctly.

    The comparison is done on ``str`` of the result because the ``-1``
    dimension renders as ``?`` (an Any placeholder) rather than a plain int.
    """
    spec = "gpu_0/data_0:[1, -1,224,224] gpu_1/data_1:[7, 7]"
    rendered = str(parse_shape_string(spec))
    assert rendered == "{'gpu_0/data_0': [1, ?, 224, 224], 'gpu_1/data_1': [7, 7]}"
コード例 #7
0
def test_negative_dimensions():
    """A ``-1`` dimension parses to Any, which renders as ``?``.

    The result is compared as a string so the Any placeholder can be matched.
    """
    rendered = str(parse_shape_string("input:[-1,3,224,224]"))
    assert rendered == "{'input': [?, 3, 224, 224]}"
コード例 #8
0
def test_alternate_syntaxes(shape_string):
    """Every supplied spelling of the two-input spec yields the same mapping.

    ``shape_string`` is injected by the test's parametrization, which is
    defined outside this function.
    """
    assert parse_shape_string(shape_string) == {
        "input": [10, 10, 10],
        "input2": [20, 20, 20, 20],
    }
コード例 #9
0
def test_alternate_syntax():
    """An input name that itself contains a colon (``input:0``) is kept whole."""
    expected = {"input:0": [10, 10, 10], "input2": [20, 20, 20, 20]}
    parsed = parse_shape_string("input:0:[10,10,10] input2:[20,20,20,20]")
    assert parsed == expected
コード例 #10
0
def test_shape_parser():
    """A single, well-formed ``name:[dims]`` entry parses to a one-key dict."""
    assert parse_shape_string("input:[10,10,10]") == {"input": [10, 10, 10]}