def test_caffe_interp_infer_wh(self):
    """Caffe Interp with explicit width/height produces an NCHW output of that spatial size."""
    interp_attrs = {
        'width': 65,
        'height': 33,
        'zoom_factor': 1,
        'shrink_factor': 1,
        'pad_beg': 0,
        'pad_end': 0,
    }
    edges = [('node_1', 'interp'), ('interp', 'node_3'), ('node_3', 'op_output')]
    node_overrides = {
        'node_3': {'shape': None},
        'node_1': {'shape': np.array([1, 1024, 1, 1])},
        'interp': interp_attrs,
    }
    graph = build_graph(nodes_attributes, edges, node_overrides)
    graph.graph['layout'] = 'NCHW'

    InterpOp.interp_infer(Node(graph, 'interp'))

    expected = np.array([1, 1024, 33, 65])
    actual = graph.node['node_3']['shape']
    # Compare dimension by dimension, indexing into the inferred shape.
    for idx, dim in enumerate(expected):
        self.assertEqual(dim, actual[idx])
def test_caffe_interp_2_blobs(self):
    """Caffe Interp whose second input supplies the shape takes H and W from that blob."""
    edges = [
        ('node_1', 'interp'),
        ('node_2', 'interp'),
        ('interp', 'node_3'),
        ('node_3', 'op_output'),
    ]
    interp_attrs = {
        'zoom_factor': 1,
        'shrink_factor': 1,
        'pad_beg': 0,
        'pad_end': 0,
        'parse_2nd_input': 'shape',
    }
    node_overrides = {
        'node_3': {'shape': None},
        'node_1': {'shape': np.array([1, 256, 33, 66])},
        'node_2': {'shape': np.array([1, 1, 3, 6])},
        'interp': interp_attrs,
    }
    graph = build_graph(nodes_attributes, edges, node_overrides)
    graph.graph['layout'] = 'NCHW'

    InterpOp.interp_infer(Node(graph, 'interp'))

    expected = np.array([1, 256, 3, 6])
    actual = graph.node['node_3']['shape']
    for idx, dim in enumerate(expected):
        self.assertEqual(dim, actual[idx])
def extract(node):
    """Populate Interp attributes for ``node`` from its protobuf ``align_corners`` attr.

    Padding is always set to zero; ``align_corners`` is taken from the
    boolean field ``node.pb.attr['align_corners'].b`` and stored as an int.
    Returns the enclosing class's ``enabled`` flag.
    """
    align_corners_flag = node.pb.attr['align_corners'].b
    attrs = dict(
        pad_end=0,
        pad_beg=0,
        align_corners=int(align_corners_flag),
    )
    InterpOp.update_node_stat(node, attrs)
    return __class__.enabled
def replace_pattern(graph: Graph, match: dict): node = match['interpolate'] # common mode = node.mode assert mode in ['linear', 'nearest', 'cubic', 'area'] in_shape = node.in_port(0).data.get_shape() assert in_shape is not None and len(in_shape) == 4 out_shape = node.out_port(0).data.get_shape() assert out_shape is not None and len(out_shape) == 4 in_height, in_width = in_shape[2], in_shape[3] out_height, out_width = out_shape[2], out_shape[3] factor = factor_update( None if not node.has_valid('factor') else node.factor, [float(out_height) / in_height, float(out_width) / in_width], [in_height, in_width], [out_height, out_width], node.soft_get('name')) update_attrs = { 'width': out_width, 'height': out_height, 'factor': factor, } if (node.has_valid('shrink_factor') and node.has_valid('zoom_factor')) or factor is None: del update_attrs['factor'] if node.has('factor'): del node['factor'] if ((node.has_valid('shrink_factor') and node.shrink_factor != 1) or (node.has_valid('zoom_factor') and node.zoom_factor != 1) or 'factor' in update_attrs) \ and ((not node.has_valid('width') or node.width == 0) and (not node.has_valid('height') or node.height == 0)): update_attrs['width'] = 0 update_attrs['height'] = 0 # specific if mode in ['nearest', 'cubic', 'area' ] or node.has_and_set('convert_to_resample'): assert not node.align_corners assert node.pads_begin == 0 and node.pads_end == 0 update_attrs[ 'resample_type'] = InterpolateToInterpOrResample.type_map[mode] ResampleOp.update_node_stat(node, update_attrs) node.in_port(1).disconnect() elif mode == 'linear': update_attrs.update({ 'pad_beg': node.pads_begin, 'pad_end': node.pads_end, 'align_corners': node.align_corners, }) InterpOp.update_node_stat(node, update_attrs) node.in_port(1).disconnect() node['force_precision_in_ports'] = None
def test_tf_interp_infer_two_inputs(self):
    """TF (NHWC) Interp takes the target H and W from the constant value on its second input."""
    edges = [('node_1', 'interp'), ('node_2', 'interp'), ('interp', 'node_3')]
    node_overrides = {
        'node_1': {'shape': np.array([1, 20, 30, 100])},
        'node_2': {'shape': np.array([2]), 'value': np.array([2, 3])},
    }
    graph = build_graph(nodes_attributes, edges, node_overrides)
    graph.graph['layout'] = 'NHWC'

    InterpOp.interp_infer(Node(graph, 'interp'))

    expected = np.array([1, 2, 3, 100])
    actual = graph.node['node_3']['shape']
    for idx, dim in enumerate(expected):
        self.assertEqual(dim, actual[idx])
def test_tf_interp_infer_one_input_hw(self):
    """TF (NHWC) Interp with a single input uses explicit height/width attributes."""
    interp_attrs = {
        'height': 4,
        'width': 6,
        'pad_beg': 0,
        'pad_end': 0,
        'zoom_factor': None,
        'shrink_factor': None,
    }
    edges = [('node_1', 'interp'), ('interp', 'node_3')]
    node_overrides = {
        'node_1': {'shape': np.array([1, 20, 30, 100])},
        'interp': interp_attrs,
    }
    graph = build_graph(nodes_attributes, edges, node_overrides)
    graph.graph['layout'] = 'NHWC'

    InterpOp.interp_infer(Node(graph, 'interp'))

    expected = np.array([1, 4, 6, 100])
    actual = graph.node['node_3']['shape']
    for idx, dim in enumerate(expected):
        self.assertEqual(dim, actual[idx])
def test_caffe_interp_infer_zoom_shrink_error(self):
    """Zero zoom/shrink factors with zero height/width leave the output shape unset."""
    interp_attrs = {
        'zoom_factor': 0,
        'height': 0,
        'width': 0,
        'shrink_factor': 0,
        'pad_beg': 0,
        'pad_end': 0,
    }
    edges = [('node_1', 'interp'), ('interp', 'node_3'), ('node_3', 'op_output')]
    node_overrides = {
        'node_3': {'shape': None},
        'node_1': {'shape': np.array([1, 256, 33, 65])},
        'interp': interp_attrs,
    }
    graph = build_graph(nodes_attributes, edges, node_overrides)
    graph.graph['layout'] = 'NCHW'

    InterpOp.interp_infer(Node(graph, 'interp'))

    inferred_shape = graph.node['node_3']['shape']
    self.assertIsNone(inferred_shape)
def replace_pattern(graph: Graph, match: dict):
    """Rewrite the matched Interpolate node in place as a Resample or Interp node.

    Modes 'nearest', 'cubic' and 'area' (or any node with
    ``convert_to_resample`` set) become Resample. Mode 'linear' becomes
    Interp, which is only allowed for 4D inputs.

    :param graph: graph being transformed. For the Resample target its
        ``cmd_params.keep_shape_ops`` and ``fw`` entries decide whether
        input port 1 is disconnected or rewired to a computed full shape.
    :param match: pattern-match result; ``match['interpolate']`` is the
        node that gets rewritten.
    :raises AssertionError: if input port 1 is missing, disconnected or has
        no constant value; on an unsupported mode; on non-4D/5D shapes; when
        the Resample preconditions (no align_corners, zero pads) fail; or
        when a 5D input reaches the linear (Interp) branch.
    """
    node = match['interpolate']
    # The message's '{}' placeholder must be formatted, otherwise the failing
    # node is never identified in the error.
    assert 1 in node.in_ports() and not node.in_port(1).disconnected() and \
        node.in_port(1).data.get_value() is not None, \
        'Interpolate node {} is corrupted: no 1-port input found'.format(node.soft_get('name'))

    # common: attributes shared by both Resample and Interp targets
    mode = node.mode
    assert mode in ['linear', 'nearest', 'cubic', 'area']
    in_shape = node.in_port(0).data.get_shape()
    assert in_shape is not None and len(in_shape) in [4, 5]
    out_shape = node.out_port(0).data.get_shape()
    assert out_shape is not None and len(out_shape) in [4, 5]
    # NOTE(review): indices 2 and 3 are used as H and W; for a 5D input these
    # are presumably depth/height rather than height/width — confirm intent.
    in_height, in_width = in_shape[2], in_shape[3]
    out_height, out_width = out_shape[2], out_shape[3]
    # NOTE(review): factor_update is defined elsewhere; presumably it returns a
    # single scale factor when the per-axis ratios agree, else None — confirm.
    factor = factor_update(
        None if not node.has_valid('factor') else node.factor,
        [float(out_height) / in_height, float(out_width) / in_width],
        [in_height, in_width],
        [out_height, out_width],
        node.soft_get('name'))
    update_attrs = {
        'width': out_width,
        'height': out_height,
        'factor': factor,
    }
    # Drop 'factor' when zoom/shrink factors are both present or no common
    # factor was found; this must precede the 'factor' in update_attrs check.
    if (node.has_valid('shrink_factor') and node.has_valid('zoom_factor')) or factor is None:
        del update_attrs['factor']
        if node.has('factor'):
            del node['factor']
    # With factor-based scaling and no explicit non-zero width/height on the
    # node, zero out width/height.
    if ((node.has_valid('shrink_factor') and node.shrink_factor != 1) or
            (node.has_valid('zoom_factor') and node.zoom_factor != 1) or
            'factor' in update_attrs) \
            and ((not node.has_valid('width') or node.width == 0) and
                 (not node.has_valid('height') or node.height == 0)):
        update_attrs['width'] = 0
        update_attrs['height'] = 0

    # specific: choose the target op
    if mode in ['nearest', 'cubic', 'area'] or node.has_and_set('convert_to_resample'):
        assert not node.align_corners
        assert node.pads_begin == 0 and node.pads_end == 0
        update_attrs['resample_type'] = InterpolateToInterpOrResample.type_map[mode]
        ResampleOp.update_node_stat(node, update_attrs)

        if not graph.graph['cmd_params'].keep_shape_ops or graph.graph['fw'] != 'tf':
            node.in_port(1).disconnect()
        else:
            # we avoid making resample non-reshapable for tf version:
            # port 1 is fed [batch, features, <original port-1 source>] built
            # from the runtime shape of input 0 instead of being disconnected.
            shape = Shape(graph, {}).create_node()
            node.in_port(0).get_source().connect(shape.in_port(0))

            batch = node_to_get_batch_value(shape)
            features = node_to_get_features_dimension_value(shape)
            full_shape = new_shape_node_from_shape_nodes(
                [batch, features, node.in_port(1).get_source().node])
            node.in_port(1).get_connection().set_source(full_shape.out_port(0))
            full_shape['override_output_shape'] = True

    elif mode == 'linear':
        assert len(in_shape) == 4, 'Interp does not support 5D input'
        update_attrs.update({
            'pad_beg': node.pads_begin,
            'pad_end': node.pads_end,
            'align_corners': node.align_corners,
        })
        InterpOp.update_node_stat(node, update_attrs)
        node.in_port(1).disconnect()