def test_reshape_infer(self, input_value, input_shape, output_shape, ref_value, ref_shape):
    """Run Reshape.infer on a small graph and compare its output with references.

    The output shape is always compared with ``ref_shape``. The output value is
    compared with ``ref_value`` only when a reference value is provided.
    """
    edges = [
        ('input', 'data'),
        ('data', 'reshape'),
        ('output_shape', 'output_shape_data'),
        ('output_shape_data', 'reshape'),
        ('reshape', 'reshape_out'),
    ]
    attrs_update = {
        'data': {'shape': input_shape, 'value': input_value},
        'output_shape': {'value': output_shape, 'shape': output_shape.shape},
        'output_shape_data': {'value': output_shape, 'shape': output_shape.shape},
    }
    test_graph = build_graph(nodes_attributes, edges, attrs_update)

    reshape_node = Node(test_graph, 'reshape')
    Reshape.infer(reshape_node)

    out_data = reshape_node.out_port(0).data
    if ref_value is not None:
        self.assertTrue(strict_compare_tensors(out_data.get_value(), shape_array(ref_value)))
    self.assertTrue(strict_compare_tensors(out_data.get_shape(), shape_array(ref_shape)))
def replace_pattern(graph: Graph, match: dict):
    """Insert a Reshape in front of input 0 of ``match['op']`` when that input has rank > 2.

    The Reshape's target shape is an int64 Const ``[0, -1]``. Nodes whose
    input 0 has rank <= 2 are left untouched.
    """
    op_node = match['op']
    if len(op_node.in_port(0).data.get_shape()) <= 2:
        return

    # NOTE(review): [0, -1] presumably keeps dim 0 and flattens the rest, relying on
    # Reshape's special-zero handling -- confirm against the Reshape op definition
    target_shape = Const(graph, {'value': np.array([0, -1], dtype=np.int64)}).create_node()
    flatten = Reshape(graph, {}).create_node()

    # rewire: producer -> flatten:0 -> op_node:0, target_shape -> flatten:1
    producer = op_node.in_port(0).get_source()
    op_node.in_port(0).get_connection().set_source(flatten.out_port(0))
    producer.connect(flatten.in_port(0))
    target_shape.out_port(0).connect(flatten.in_port(1))

    # infer the new nodes right away so downstream shapes are up to date
    target_shape.infer(target_shape)
    flatten.infer(flatten)
def replace_pattern(self, graph: Graph, match: dict):
    """Rewrite a Convolution with a rank-3 output as a rank-4 one.

    A fake unit dimension is inserted at axis 2 of the input, the weights and
    the dilation/stride/pad/kernel_spatial attributes. The input is reshaped
    to the new 4D shape before the Convolution, and the Convolution output is
    reshaped back to the original 3D output shape.

    :param graph: graph being transformed
    :param match: pattern match; ``match['conv']`` is the Convolution node
    """
    conv = match['conv']

    assert len(conv.out_nodes()) == 1, "Convolution operation {} should have 1 output data node".format(conv.id)
    out_data = conv.out_node()

    assert out_data.has_valid('shape'), 'Output shape is undefined for {} in back phase'.format(conv.id)
    out_shape = out_data.shape

    # only convolutions producing a rank-3 tensor are rewritten
    if out_shape.size != 3:
        return

    # data (port 0) and weights (port 1) are both required below; check before
    # any attribute of the convolution is modified
    assert len(conv.in_nodes()) >= 2, "Convolution operation {} should have at least 2 input data nodes".format(
        conv.id)
    inp_data = conv.in_node()

    assert inp_data.has_valid('shape'), 'Input shape is undefined for {} in back phase'.format(conv.id)
    inp_shape = inp_data.shape
    new_inp_shape = np.insert(inp_shape, 2, 1)

    # setting to None to be overwritten by infer function
    conv.kernel_spatial_idx = None
    conv.spatial_dims = None

    # inserting fake H dimension
    conv.dilation = np.insert(conv.dilation, 2, 1)
    conv.kernel_spatial = np.append([1], conv.kernel_spatial)
    # NOTE(review): pad is presumably one [begin, end] row per axis; a zero row is added at axis 2
    conv.pad = np.insert(conv.pad, 2, [0, 0], axis=0)
    conv.stride = np.insert(conv.stride, 2, 1)

    weights_node = conv.in_node(1)
    weights_node.value = np.reshape(weights_node.value, np.insert(weights_node.value.shape, 2, 1))
    weights_node.shape = np.array(weights_node.value.shape, dtype=np.int64)

    reshape = Reshape(graph, {'name': conv.name + '/reshape'}).create_node()
    reshape_dim = Const(graph, {'value': new_inp_shape, 'name': reshape.id + '/Dim'}).create_node()
    conv.in_port(0).get_connection().insert_node(reshape)
    reshape.in_port(1).connect(reshape_dim.out_port(0))

    reshape_back = Reshape(graph, {'name': conv.name + '/reshape_back'}).create_node()
    # named after reshape_back (previously reshape.id, which duplicated reshape_dim's name)
    reshape_back_dim = Const(graph, {'value': out_shape, 'name': reshape_back.id + '/Dim'}).create_node()
    conv.out_port(0).get_connection().insert_node(reshape_back)
    reshape_back.in_port(1).connect(reshape_back_dim.out_port(0))

    # run shape inference manually for several nodes to override shapes of the model nodes which changed behaviour
    reshape_dim.infer(reshape_dim)
    reshape.infer(reshape)
    conv.infer(conv)
def replace_pattern(graph: Graph, match: dict):
    """Normalize a matched MatMul into FullyConnected or Gemm.

    The FullyConnected conversion applies when the second input is treated as
    constant and has at most two non-unit dimensions. "Constant" here means
    either it has a value, or its producer has ``stop_value_propagation`` set.
    The conversion then performs these steps:

    * the weights are transposed (unless ``transpose_b``) and, if needed,
      reshaped to ``[O, K]``;
    * the data input is transposed (if ``transpose_a``) and, if not 2D,
      reshaped to ``[prod(B) * I, K]``;
    * the node becomes FullyConnected with ``out-size`` O. Its output is
      reshaped back to ``[*B, I, O]`` when the original output rank is not 2.

    Otherwise the node becomes Gemm. Before that, the second input is
    unsqueezed at the leading axes if its rank is lower than the output rank.

    :param graph: graph being transformed
    :param match: pattern match; ``match['matmul']`` is the MatMul node
    """
    node = match['matmul']
    name = node.soft_get('name', node.id)

    def insert_before(in_port, op_class, op_attrs, second_input_value):
        # Put a new `op_class` node between `in_port` and its current producer:
        # producer -> new_node:0, Const(second_input_value) -> new_node:1,
        # new_node:0 -> in_port. Both new nodes are inferred immediately so the
        # shape checks below see updated shapes.
        const = Const(graph, {'value': second_input_value}).create_node()
        new_node = op_class(graph, op_attrs).create_node()

        source = in_port.get_source()
        in_port.get_connection().set_source(new_node.out_port(0))
        new_node.in_port(0).connect(source)
        new_node.in_port(1).connect(const.out_port(0))

        const.infer(const)
        new_node.infer(new_node)

    def swap_last_two_axes_order(rank):
        # permutation [0, 1, ..., rank-1, rank-2]: swaps the two innermost axes
        order = list(range(rank))
        order[-1], order[-2] = order[-2], order[-1]
        return int64_array(order)

    A_shape = node.in_port(0).data.get_shape()
    B_shape = node.in_port(1).data.get_shape()
    out_shape = node.out_port(0).data.get_shape()

    assert A_shape is not None and B_shape is not None and out_shape is not None

    B_value = node.in_port(1).data.get_value()
    B_is_const = B_value is not None or node.in_port(1).get_source().node.has_and_set('stop_value_propagation')

    if B_is_const and B_shape[B_shape != 1].size <= 2:
        # transferring from MatMul representation: [B, I, K] * [B, K, O] = [B, I, O]
        # to FullyConnected representation: [I, K] * [O, K] = [I, O]
        B, I, K, O, _, _ = MatMulToFullyConnected.get_matmul_BIKO(node)

        # weights normalization
        if not node.transpose_b:
            # FullyConnected weights layout is OI
            # MatMul second input layout is (B)IO
            insert_before(node.in_port(1), Transpose, {'name': name + '/weights_transpose'},
                          swap_last_two_axes_order(B_shape.size))

        if node.in_port(1).data.get_shape().size != 2:
            insert_before(node.in_port(1), Reshape, {'name': name + '/weights_reshape'}, int64_array([-1, K]))

        assert np.array_equal(node.in_port(1).data.get_shape(), int64_array([O, K])), \
            "MatMul `{}` was not converted to FullyConnected: wrong weights shape: {}, " \
            "B={}, I={}, K={}, O={}".format(name, node.in_port(1).data.get_shape(), B, I, K, O)

        node.in_port(1).bin = 'weights'
        del node['transpose_b']

        # input normalization
        if node.transpose_a:
            insert_before(node.in_port(0), Transpose, {'name': name + '/input_transpose'},
                          swap_last_two_axes_order(A_shape.size))

        if A_shape.size != 2:
            insert_before(node.in_port(0), Reshape, {'name': name + '/input_reshape'}, int64_array([-1, K]))

        assert np.array_equal(node.in_port(0).data.get_shape(), int64_array([np.prod(B) * I, K])), \
            "MatMul `{}` wasn't converted to FullyConnected: wrong input shape: {}, " \
            "B={}, I={}, K={}, O={}".format(name, node.in_port(0).data.get_shape(), B, I, K, O)

        del node['transpose_a']

        FullyConnected.update_node_stat(node, {'out-size': O})

        # output normalization
        if out_shape.size != 2:
            const = Const(graph, {'value': int64_array([*B, I, O])}).create_node()
            reshape = Reshape(graph, {'name': name + '/output_reshape'}).create_node()

            dst = node.out_port(0).get_destination()
            node.out_port(0).get_connection().set_destination(reshape.in_port(0))
            const.out_port(0).connect(reshape.in_port(1))
            reshape.out_port(0).connect(dst)

            # re-infer the node with its FullyConnected semantics before the trailing Reshape
            node.infer(node)

            const.infer(const)
            reshape.infer(reshape)
    else:
        assert A_shape.size == out_shape.size
        assert B_shape.size <= out_shape.size
        if B_shape.size != out_shape.size:
            # unsqueeze B at axes 0..(rank difference - 1) so its rank matches the output rank
            insert_before(node.in_port(1), Unsqueeze, {},
                          int64_array(list(range(out_shape.size - B_shape.size))))

        Gemm.update_node_stat(node, {
            'transpose_a': node.has_and_set('transpose_a'),
            'transpose_b': node.has_and_set('transpose_b'),
        })
def replace_pattern(graph: Graph, match: dict):
    """
    Work around a Tile configuration the Inference Engine does not support
    (Tile is supported only for 2-D or 4-D tensors): find Tiles with 3-D output
    shapes and wrap them in Reshapes that add, then remove, a trailing unit axis.

    Example: Tile (axis=1, tiles=16):
        in_shape: [1,1,101]
        out_shape: [1,16,101]

    Old behaviour: Tile -> [1,16,101]
    New behaviour: Reshape [1,1,101,1] -> Tile -> [1,16,101,1] -> Reshape [1,16,101]
    """
    tile = match['tile']

    assert len(tile.out_nodes()) == 1, "Tile operation {} should have 1 output data node".format(tile.id)
    out_data = tile.out_node()

    assert out_data.has_valid('shape'), 'Output shape is undefined for {} in back phase'.format(tile.id)
    out_shape = out_data.shape

    # only Tiles producing a rank-3 tensor are rewritten
    if out_shape.size != 3:
        return

    assert len(tile.in_nodes()) == 1, "Tile operation {} should have 1 input data node".format(tile.id)
    inp_data = tile.in_node()

    assert inp_data.has_valid('shape'), 'Input shape is undefined for {} in back phase'.format(tile.id)
    inp_shape = inp_data.shape
    new_inp_shape = np.append(inp_shape, [1])

    reshape = Reshape(graph, {'name': tile.name + '/reshape'}).create_node()
    reshape_dim = Const(graph, {'value': new_inp_shape, 'name': reshape.id + '/Dim'}).create_node()
    tile.in_port(0).get_connection().insert_node(reshape)
    reshape.in_port(1).connect(reshape_dim.out_port(0))

    reshape_back = Reshape(graph, {'name': tile.name + '/reshape_back'}).create_node()
    # named after reshape_back (previously reshape.id, which duplicated reshape_dim's name)
    reshape_back_dim = Const(graph, {'value': out_shape, 'name': reshape_back.id + '/Dim'}).create_node()
    tile.out_port(0).get_connection().insert_node(reshape_back)
    reshape_back.in_port(1).connect(reshape_back_dim.out_port(0))

    # run shape inference manually for several nodes to override shapes of the model nodes which changed behaviour
    reshape_dim.infer(reshape_dim)
    reshape.infer(reshape)
    tile.infer(tile)