Ejemplo n.º 1
0
    def extract(cls, node):
        """Populate *node* with ROIPooling attributes read from its layer proto.

        The values come from ``node.pb.roi_pooling_param`` (presumably a Caffe
        ``ROIPoolingParameter`` message — confirm against the loader).

        :param node: graph node whose ``pb`` attribute holds the layer proto.
        :return: the extractor's ``enabled`` flag.
        """
        roi_param = node.pb.roi_pooling_param
        ROIPooling.update_node_stat(node, dict(
            pooled_h=roi_param.pooled_h,
            pooled_w=roi_param.pooled_w,
            spatial_scale=roi_param.spatial_scale,
        ))
        return cls.enabled
Ejemplo n.º 2
0
 def extract(node):
     """Extract a CropAndResize node, reusing the ROIPooling attribute updater.

     ``op`` is forced to ``'CropAndResize'`` so that the extension which merges
     two of the node's inputs will be called later.

     :param node: graph node whose ``pb`` is presumably a TensorFlow ``NodeDef``
         (it is read via ``pb.attr['method'].s``) — confirm against the loader.
     :return: ``False`` (after logging a warning) when the interpolation method
         is not ``'bilinear'``; otherwise the extractor's ``enabled`` flag.
     """
     # TensorFlow registers CropAndResize's ``method`` attr with default
     # 'bilinear', so graphs saved with default attributes stripped omit it.
     # Indexing a protobuf message map with a missing key inserts an empty entry
     # (yielding '' here and wrongly rejecting the node), so test membership.
     if 'method' in node.pb.attr:
         method = node.pb.attr['method'].s.decode('utf-8')
     else:
         method = 'bilinear'
     if method != 'bilinear':
         # Lazy %-style args: the message text is unchanged, formatting is deferred.
         log.warning('The crop and resize method "%s" for node "%s" is not supported.',
                     method, node.soft_get('name'))
         return False
     ROIPooling.update_node_stat(node, {
         'spatial_scale': 1,
         'op': 'CropAndResize',
         'method': method,
     })
     return __class__.enabled
Ejemplo n.º 3
0
    def extract(node):
        """Translate an MXNet ROIPooling symbol's attributes onto *node*.

        Reads ``spatial_scale`` (default ``None``) and ``pooled_size`` (default
        ``(0, 0)``). It then builds a ROIPooling attribute dict, merging in
        ``layout_attrs()`` last so its keys win, and applies it to the node.

        :param node: graph node exposing ``symbol_dict`` for the MXNet layer.
        :return: the extractor's ``enabled`` flag.
        """
        layer_attrs = get_mxnet_layer_attrs(node.symbol_dict)
        scale = layer_attrs.float("spatial_scale", None)
        # Index 0 feeds pooled_h and index 1 feeds pooled_w, i.e. (height, width).
        pooled = layer_attrs.tuple("pooled_size", int, (0, 0))

        node_attrs = {
            'type': 'ROIPooling',
            'spatial_scale': scale,
            'pooled_w': pooled[1],
            'pooled_h': pooled[0],
            **layout_attrs(),
        }
        ROIPooling.update_node_stat(node, node_attrs)
        return __class__.enabled