def get_dim_from_layout(node: Node, dim: str):
    """
    Find the position of a named dimension in the source layout given for a node.

    The source layout is read from ``graph.graph['cmd_params'].layout_values``,
    a mapping keyed by node name whose entries carry a ``'source_layout'`` item.
    A layout registered under the empty name ``''`` is attributed to the graph's
    Parameter node, but only when the graph has exactly one Parameter.

    :param node: node to get dim for.
    :param dim: name of dimension to get index for.
    :return: tuple ``(index, has_layout)``. ``index`` is the position of ``dim``
        in the layout, or None when the layout does not name it. A negative
        index is offset by ``len(node.shape)``. ``has_layout`` tells whether a
        non-empty source layout was found for the node at all.
    """
    source_layout = None
    graph = node.graph
    cmd_params = graph.graph['cmd_params']

    if 'layout_values' in cmd_params and cmd_params.layout_values:
        # Work on a copy so the re-keying below does not mutate the cmd params.
        per_node = cmd_params.layout_values.copy()

        if '' in per_node:
            parameters = graph.get_op_nodes(op='Parameter')
            if len(parameters) == 1:
                only_param = parameters[0]
                # Assign first, then delete: if the parameter's own name is ''
                # the entry ends up removed, exactly as in the original code.
                per_node[only_param.soft_get('name', only_param.id)] = per_node['']
                del per_node['']

        node_name = node.soft_get('name', node.id)
        if node_name in per_node and per_node[node_name]['source_layout']:
            source_layout = per_node[node_name]['source_layout']

    if not source_layout:
        return None, False

    # Imported lazily: only needed once a layout string actually has to be parsed.
    from openvino.runtime import Layout  # pylint: disable=no-name-in-module,import-error

    parsed = Layout(source_layout)
    if not parsed.has_name(dim):
        return None, True

    position = parsed.get_index_by_name(dim)
    # get_index_by_name may report a position counted from the end; turn it
    # into an absolute index using the node's rank.
    if position < 0:
        position += len(node.shape)
    return position, True
class AppInputInfo:
    """Mutable record describing one application input: layout, shapes, and metadata."""

    def __init__(self):
        # All fields start empty/unset; callers populate them after construction.
        self.element_type = None
        self.layout = Layout()
        self.original_shape = None
        self.partial_shape = None
        self.data_shapes = []  # list of concrete shapes, indexed by layout axis
        self.scale = []
        self.mean = []
        self.name = None

    @property
    def is_image(self):
        """True when the layout is one of the 4D/3D image layouts and C equals 3."""
        return (
            str(self.layout) in ["[N,C,H,W]", "[N,H,W,C]", "[C,H,W]", "[H,W,C]"]
            and self.channels == 3
        )

    @property
    def is_image_info(self):
        """True when the layout is exactly ``[N,C]`` and C relaxes Dimension(2)."""
        return str(self.layout) == "[N,C]" and self.channels.relaxes(Dimension(2))

    def getDimentionByLayout(self, character):
        """Return the partial-shape dimension for layout axis ``character``.

        Falls back to ``Dimension(0)`` when the layout does not name that axis.
        """
        if not self.layout.has_name(character):
            return Dimension(0)
        return self.partial_shape[self.layout.get_index_by_name(character)]

    def getDimentionsByLayout(self, character):
        """Return, for each entry of ``data_shapes``, its size along axis ``character``.

        When the layout does not name that axis, a list of zeros of the same
        length is returned instead.
        """
        if not self.layout.has_name(character):
            return [0] * len(self.data_shapes)
        axis = self.layout.get_index_by_name(character)
        return [data_shape[axis] for data_shape in self.data_shapes]

    @property
    def shapes(self):
        """The single static shape if the partial shape is static, else ``data_shapes``."""
        return [self.partial_shape.to_shape()] if self.is_static else self.data_shapes

    @property
    def width(self):
        # Relies on len() of the returned Dimension yielding its length
        # (NOTE(review): confirm behaviour for dynamic dimensions).
        return len(self.getDimentionByLayout("W"))

    @property
    def widthes(self):
        return self.getDimentionsByLayout("W")

    @property
    def height(self):
        # Same len()-of-Dimension convention as ``width``.
        return len(self.getDimentionByLayout("H"))

    @property
    def heights(self):
        return self.getDimentionsByLayout("H")

    @property
    def channels(self):
        return self.getDimentionByLayout("C")

    @property
    def is_static(self):
        return self.partial_shape.is_static

    @property
    def is_dynamic(self):
        return self.partial_shape.is_dynamic