예제 #1
0
def test_layout_helpers():
    """Exercise layout_helpers on a fully named layout and two partial ones.

    Covers the presence predicates (``has_*``), the index getters (``*_idx``),
    and the RuntimeError raised by an index getter whose dimension is absent.
    """
    # "NCHWD": every named dimension is present, so each has_* predicate holds
    # and each *_idx getter returns the letter's position in the string.
    layout = ov.Layout("NCHWD")
    for has_dim in (layout_helpers.has_batch,
                    layout_helpers.has_channels,
                    layout_helpers.has_depth,
                    layout_helpers.has_height,
                    layout_helpers.has_width):
        assert has_dim(layout)
    for dim_idx, expected in ((layout_helpers.batch_idx, 0),
                              (layout_helpers.channels_idx, 1),
                              (layout_helpers.height_idx, 2),
                              (layout_helpers.width_idx, 3),
                              (layout_helpers.depth_idx, 4)):
        assert dim_idx(layout) == expected

    def check_batch_and_channels_only(layout_str, expected_channels_idx):
        # Only N and C are named here. D/H/W predicates must be False and
        # their index getters must raise RuntimeError.
        partial = ov.Layout(layout_str)
        assert layout_helpers.has_batch(partial)
        assert layout_helpers.has_channels(partial)
        for has_dim in (layout_helpers.has_depth,
                        layout_helpers.has_height,
                        layout_helpers.has_width):
            assert not has_dim(partial)

        assert layout_helpers.batch_idx(partial) == 0
        assert layout_helpers.channels_idx(partial) == expected_channels_idx

        for dim_idx in (layout_helpers.height_idx,
                        layout_helpers.width_idx,
                        layout_helpers.depth_idx):
            with pytest.raises(RuntimeError):
                dim_idx(partial)

    # C follows "...", and its index is expected to be reported from the end.
    check_batch_and_channels_only("N...C", -1)
    # "?" fills a position without naming D/H/W, and C keeps index 1.
    check_batch_and_channels_only("NC?", 1)
예제 #2
0
def check_suitable_for_reverse(layout: Layout, ov_input):
    """
    Internal function. Checks if input with layout is suitable for reversing channels.

    :param layout: Existing source/target layout items specified by user
    :param ov_input: Model's input
    :return: True if reverse channels can be applied to input. This requires that
        the layout has a channels dimension, the input rank is static, and the
        channels dimension is either dynamic or equal to 3. Otherwise False.
    :raises Error: if the layout's channels index does not fit within the
        input's rank
    """
    if not layout_helpers.has_channels(layout):
        return False
    shape = ov_input.get_partial_shape()
    if shape.rank.is_dynamic:
        return False

    c_idx = layout_helpers.channels_idx(layout)
    rank = shape.rank.get_length()
    # A negative channels index counts from the end of the shape; convert it
    # to a non-negative position.
    if c_idx < 0:
        c_idx += rank
    # Check both bounds. If the layout addresses channels further from the end
    # than the input's rank, c_idx stays negative after the adjustment above.
    # It must be rejected rather than used as a negative index into the shape.
    if not 0 <= c_idx < rank:
        raise Error('Layout {} for input {} is inconsistent with shape {}'.format(
            layout, ov_input.get_tensor().get_any_name(), shape))
    c_num = shape[c_idx]
    return c_num.is_dynamic or c_num.get_length() == 3