示例#1
0
 def extract(node):
     """Populate ``node`` as a nearest-neighbour ResampleOp tagged with fw='tf'.

     Antialiasing is disabled. The return value is the enclosing class's
     ``enabled`` flag, read through the implicit ``__class__`` cell, so this
     function must be defined inside a class body.
     """
     rule = dict(
         resample_type='caffe.ResampleParameter.NEAREST',
         fw='tf',
         antialias=0,
     )
     ResampleOp.update_node_stat(node, rule)
     return __class__.enabled
    def replace_pattern(graph: Graph, match: dict):
        """Rewrite a matched 4D Interpolate node in place as Resample or Interp.

        Nodes whose mode is 'nearest', 'cubic' or 'area' (or whose
        ``convert_to_resample`` flag is set) become ResampleOp; remaining
        'linear' nodes become InterpOp. Input port 1 is disconnected in both
        cases, and ``force_precision_in_ports`` is cleared at the end.

        :param graph: the graph being transformed (not read by this body).
        :param match: pattern match; ``match['interpolate']`` is the node to rewrite.
        """
        interp = match['interpolate']

        # --- attributes shared by both target ops ---
        mode = interp.mode
        assert mode in ['linear', 'nearest', 'cubic', 'area']

        src_shape = interp.in_port(0).data.get_shape()
        assert src_shape is not None and len(src_shape) == 4
        dst_shape = interp.out_port(0).data.get_shape()
        assert dst_shape is not None and len(dst_shape) == 4

        # Axes 2 and 3 are used as height and width respectively.
        src_h, src_w = src_shape[2], src_shape[3]
        dst_h, dst_w = dst_shape[2], dst_shape[3]

        prior_factor = interp.factor if interp.has_valid('factor') else None
        factor = factor_update(
            prior_factor,
            [float(dst_h) / src_h, float(dst_w) / src_w],
            [src_h, src_w],
            [dst_h, dst_w],
            interp.soft_get('name'),
        )
        attrs = {'width': dst_w, 'height': dst_h, 'factor': factor}

        # No scalar factor is emitted when both shrink_factor and zoom_factor
        # are present, or when factor_update returned None.
        if (interp.has_valid('shrink_factor') and interp.has_valid('zoom_factor')) \
                or factor is None:
            attrs.pop('factor')
            if interp.has('factor'):
                del interp['factor']

        # If scaling is expressed through factors and the node carries no
        # explicit non-zero width/height, emit width = height = 0.
        factor_driven = (
            (interp.has_valid('shrink_factor') and interp.shrink_factor != 1)
            or (interp.has_valid('zoom_factor') and interp.zoom_factor != 1)
            or 'factor' in attrs
        )
        if factor_driven \
                and (not interp.has_valid('width') or interp.width == 0) \
                and (not interp.has_valid('height') or interp.height == 0):
            attrs['width'] = 0
            attrs['height'] = 0

        # --- op-specific conversion ---
        if mode in ('nearest', 'cubic', 'area') or interp.has_and_set('convert_to_resample'):
            assert not interp.align_corners
            assert interp.pads_begin == 0 and interp.pads_end == 0
            attrs['resample_type'] = InterpolateToInterpOrResample.type_map[mode]
            ResampleOp.update_node_stat(interp, attrs)
            interp.in_port(1).disconnect()
        elif mode == 'linear':
            attrs.update(
                pad_beg=interp.pads_begin,
                pad_end=interp.pads_end,
                align_corners=interp.align_corners,
            )
            InterpOp.update_node_stat(interp, attrs)
            interp.in_port(1).disconnect()

        interp['force_precision_in_ports'] = None
示例#3
0
    def extract(node):
        """Build a nearest-neighbour Resample node from MXNet symbol attributes.

        The MXNet ``scale`` attribute (via ``.int("scale", 1)``, so 1 when
        absent) becomes the Resample ``factor``; antialiasing is disabled.
        Returns the enclosing class's ``enabled`` flag through the implicit
        ``__class__`` cell.
        """
        mx_attrs = get_mxnet_layer_attrs(node.symbol_dict)
        scale = mx_attrs.int("scale", 1)
        ResampleOp.update_node_stat(node, dict(
            type='Resample',
            factor=scale,
            resample_type='caffe.ResampleParameter.NEAREST',
            antialias=0,
        ))
        return __class__.enabled
示例#4
0
    def extract(node):
        """Extract an ONNX Upsample node into a nearest-neighbour ResampleOp.

        Only a subset of Upsample is accepted: ``mode`` must be 'nearest'
        (default when the attribute is absent), the ``scales`` attribute must
        be absent, and ``width_scale`` / ``height_scale`` must both be present
        and equal. That common scale becomes the Resample ``factor``;
        antialiasing is disabled.

        :raises Error: if any of the restrictions above is violated.
        :return: the enclosing class's ``enabled`` flag, read through the
            implicit ``__class__`` cell.
        """
        mode = onnx_attr(node, 'mode', 's', default='nearest', dst_type=lambda x: x.decode())
        scales = onnx_attr(node, 'scales', 'floats', dst_type=lambda x: np.array(x, dtype=np.float32))
        width_scale = onnx_attr(node, 'width_scale', 'f')
        height_scale = onnx_attr(node, 'height_scale', 'f')

        # NOTE(review): Error is given the '{}' template and its values as
        # separate args; presumably Error performs the formatting — confirm.
        supported_modes = ['nearest', 'linear']
        if mode not in supported_modes:
            raise Error(
                'Error decoding Upsample node {}, mode = {} is not in the list of supported modes {}.',
                node.name,
                mode,
                supported_modes
            )

        # TODO: this is a temporary limitation
        if mode != 'nearest':
            raise Error(
                'Upsample mode {} for node {} is not supported. Only nearest is supported.',
                mode,
                node.name
            )

        # TODO: this is a temporary limitation
        if scales is not None:
            raise Error(
                'Upsample scales attribute is defined for node {}. Only width_scale and height_scale are supported.',
                node.name
            )

        if width_scale is None or height_scale is None:
            raise Error(
                'One or both of width_scale = {} and height_scale = {} are not defined for Upsample node {}.',
                width_scale,
                height_scale,
                node.name
            )

        if width_scale != height_scale:
            raise Error(
                'Upsample node {} has different width_scale = {} and height_scale = {}. It is not supported; they should match.',
                node.name,
                width_scale,
                height_scale
            )

        # The checks above guarantee mode == 'nearest' and a single, defined
        # scale shared by both spatial axes.
        mode_to_resample_type = {'nearest': 'caffe.ResampleParameter.NEAREST'}
        ResampleOp.update_node_stat(node, {'resample_type': mode_to_resample_type[mode], 'factor': width_scale, 'antialias': 0})
        return __class__.enabled
    def replace_pattern(graph: Graph, match: dict):
        """Rewrite a matched 4D/5D Interpolate node in place as Resample or Interp.

        Modes 'nearest', 'cubic' and 'area' (or a set ``convert_to_resample``
        flag) become ResampleOp; remaining 'linear' nodes become InterpOp,
        which is restricted to 4D input. For Resample, input port 1 is
        disconnected unless ``keep_shape_ops`` is enabled and the framework is
        'tf'. In that case port 1 is rewired to the output of
        ``new_shape_node_from_shape_nodes``, fed from a new Shape node on
        input 0 and from the original port-1 source.

        :param graph: graph being transformed; ``graph.graph['cmd_params']`` and
            ``graph.graph['fw']`` are read on the Resample path.
        :param match: pattern match; ``match['interpolate']`` is the node to rewrite.
        """
        node = match['interpolate']

        # Port 1 must be connected and carry a constant value. Format the node
        # name into the message: the '{}' was previously left unformatted.
        assert 1 in node.in_ports() and not node.in_port(1).disconnected() and \
               node.in_port(1).data.get_value() is not None, \
            'Interpolate node {} is corrupted: no 1-port input found'.format(node.soft_get('name'))

        # common
        mode = node.mode
        assert mode in ['linear', 'nearest', 'cubic', 'area']
        in_shape = node.in_port(0).data.get_shape()
        assert in_shape is not None and len(in_shape) in [4, 5]
        out_shape = node.out_port(0).data.get_shape()
        assert out_shape is not None and len(out_shape) in [4, 5]
        # NOTE(review): for 5D shapes, indices 2 and 3 are not the last two
        # axes; confirm these are the intended height/width axes for 5D input.
        in_height, in_width = in_shape[2], in_shape[3]
        out_height, out_width = out_shape[2], out_shape[3]
        factor = factor_update(
            None if not node.has_valid('factor') else node.factor,
            [float(out_height) / in_height,
             float(out_width) / in_width], [in_height, in_width],
            [out_height, out_width], node.soft_get('name'))
        update_attrs = {
            'width': out_width,
            'height': out_height,
            'factor': factor,
        }

        # No scalar factor is emitted when both shrink_factor and zoom_factor
        # are present, or when factor_update returned None.
        if (node.has_valid('shrink_factor')
                and node.has_valid('zoom_factor')) or factor is None:
            del update_attrs['factor']
            if node.has('factor'):
                del node['factor']

        # If scaling is expressed through factors and the node carries no
        # explicit non-zero width/height, emit width = height = 0.
        if ((node.has_valid('shrink_factor') and node.shrink_factor != 1) or
            (node.has_valid('zoom_factor') and node.zoom_factor != 1) or 'factor' in update_attrs) \
                and ((not node.has_valid('width') or node.width == 0) and
                     (not node.has_valid('height') or node.height == 0)):
            update_attrs['width'] = 0
            update_attrs['height'] = 0

        # specific
        if mode in ['nearest', 'cubic', 'area'
                    ] or node.has_and_set('convert_to_resample'):
            assert not node.align_corners
            assert node.pads_begin == 0 and node.pads_end == 0
            update_attrs[
                'resample_type'] = InterpolateToInterpOrResample.type_map[mode]
            ResampleOp.update_node_stat(node, update_attrs)

            if not graph.graph[
                    'cmd_params'].keep_shape_ops or graph.graph['fw'] != 'tf':
                node.in_port(1).disconnect()
            else:
                # we avoid making resample non-reshapable for tf version
                shape = Shape(graph, {}).create_node()
                node.in_port(0).get_source().connect(shape.in_port(0))

                batch = node_to_get_batch_value(shape)
                features = node_to_get_features_dimension_value(shape)
                full_shape = new_shape_node_from_shape_nodes(
                    [batch, features,
                     node.in_port(1).get_source().node])
                node.in_port(1).get_connection().set_source(
                    full_shape.out_port(0))
                full_shape['override_output_shape'] = True

        elif mode == 'linear':
            assert len(in_shape) == 4, 'Interp does not support 5D input'
            update_attrs.update({
                'pad_beg': node.pads_begin,
                'pad_end': node.pads_end,
                'align_corners': node.align_corners,
            })
            InterpOp.update_node_stat(node, update_attrs)
            node.in_port(1).disconnect()