def test_layout_helpers():
    """Exercise the ``has_*`` / ``*_idx`` layout helpers on three layouts.

    Covers a fully named layout ("NCHWD"), a layout with an ellipsis
    ("N...C") and a layout with an unnamed dimension ("NC?").
    """
    # Fully named layout: every dimension is present and its index matches
    # the position of its letter in the layout string.
    full_layout = ov.Layout("NCHWD")
    for has_dim in (
        layout_helpers.has_batch,
        layout_helpers.has_channels,
        layout_helpers.has_depth,
        layout_helpers.has_height,
        layout_helpers.has_width,
    ):
        assert has_dim(full_layout)
    for expected_idx, idx_of in enumerate((
        layout_helpers.batch_idx,
        layout_helpers.channels_idx,
        layout_helpers.height_idx,
        layout_helpers.width_idx,
        layout_helpers.depth_idx,
    )):
        assert idx_of(full_layout) == expected_idx

    # Layouts that only name batch and channels. With the ellipsis the
    # channel index is reported relative to the end (-1); with "NC?" it is
    # the plain position (1).
    for layout_str, expected_channels_idx in (("N...C", -1), ("NC?", 1)):
        partial_layout = ov.Layout(layout_str)

        assert layout_helpers.has_batch(partial_layout)
        assert layout_helpers.has_channels(partial_layout)
        for has_missing_dim in (
            layout_helpers.has_depth,
            layout_helpers.has_height,
            layout_helpers.has_width,
        ):
            assert not has_missing_dim(partial_layout)

        assert layout_helpers.batch_idx(partial_layout) == 0
        assert layout_helpers.channels_idx(partial_layout) == expected_channels_idx

        # Asking for the index of an absent dimension must raise.
        for missing_idx_of in (
            layout_helpers.height_idx,
            layout_helpers.width_idx,
            layout_helpers.depth_idx,
        ):
            with pytest.raises(RuntimeError):
                missing_idx_of(partial_layout)
def check_suitable_for_reverse(layout: Layout, ov_input):
    """
    Internal function. Checks if input with layout is suitable for reversing channels
    :param: layout Existing source/target layout items specified by user
    :param: ov_input Model's input
    :return: True if reverse channels can be applied to input, i.e. the layout
             has a channels dimension, the input rank is static and the
             channels dimension is either dynamic or exactly 3
    :raises Error: if the channels index taken from the layout does not fit
                   into the rank of the input shape
    """
    if not layout_helpers.has_channels(layout):
        return False

    # Fetch the shape once; it is needed for the rank and for the channel dim.
    partial_shape = ov_input.get_partial_shape()
    if partial_shape.rank.is_dynamic:
        return False

    c_idx = layout_helpers.channels_idx(layout)
    rank = partial_shape.rank.get_length()
    # The channels index may be negative (counted from the end, e.g. for
    # layouts containing '...'); convert it to an absolute position.
    if c_idx < 0:
        c_idx += rank
    # Reject indices outside the shape on either side. A still-negative value
    # means the layout names more trailing dimensions than the shape has, and
    # must not be used to index the shape.
    if c_idx < 0 or c_idx >= rank:
        raise Error('Layout {} for input {} is inconsistent with shape {}'.format(
            layout, ov_input.get_tensor().get_any_name(), partial_shape))

    c_num = partial_shape[c_idx]
    return c_num.is_dynamic or c_num.get_length() == 3