예제 #1
0
 def extract(cls, node: Node):
     """Mark ``node`` as a ``Slice`` op whose slicing parameters are left unset.

     Nothing is parsed from ``node`` here: ``axis``, ``start`` and ``end`` are
     all recorded as ``None`` via ``Slice.update_node_stat``.

     :param node: graph node to annotate.
     :return: ``cls.enabled``.
     """
     # NOTE(review): presumably the real slice bounds are supplied elsewhere
     # (e.g. as node inputs) and resolved later — confirm against the Slice op.
     unset_attrs = dict.fromkeys(('axis', 'start', 'end'))  # every value is None
     Slice.update_node_stat(node, unset_attrs)
     return cls.enabled
예제 #2
0
파일: slice_ext.py 프로젝트: pc2/CustoNN2
 def extract(node: Node):
     """Attach ``Slice`` attributes to ``node`` with all slicing parameters unset.

     ``axis``, ``start`` and ``end`` are each stored as ``None``.

     :param node: node passed to ``Slice.update_node_stat``.
     :return: the enclosing class's ``enabled`` flag. It is read through the
         implicit ``__class__`` cell, so this function must stay defined
         inside a class body.
     """
     attrs = {}
     for attr_name in ('axis', 'start', 'end'):
         attrs[attr_name] = None
     Slice.update_node_stat(node, attrs)
     return __class__.enabled
예제 #3
0
 def extract(node):
     """Parse slice points from ``node.parameters`` into ``Slice`` attributes.

     The number of slice points is read with ``read_binary_integer32_token``.
     That many values are then read with ``read_blob`` as ``np.int32``.
     The node is given a fixed layout (``axis=1``, ``batch_dims=0``,
     ``spatial_dims=1``) with ``caffe_slice_infer`` as its ``infer`` function.

     The parameter stream is closed once parsing finishes. This also happens
     when parsing raises, so a malformed record does not leave it open.

     :param node: node whose ``parameters`` attribute is a readable stream.
     :return: the enclosing class's ``enabled`` flag.
     """
     pb = node.parameters
     try:
         num_slice_points = read_binary_integer32_token(pb)
         slice_points = read_blob(pb, num_slice_points, np.int32)
     finally:
         # Close on both success and failure; previously an exception while
         # reading skipped the close() call entirely.
         pb.close()

     mapping_rule = {
         'axis': 1,
         'slice_point': slice_points,
         'batch_dims': 0,
         'spatial_dims': 1,
         'infer': caffe_slice_infer,
     }
     Slice.update_node_stat(node, mapping_rule)
     return __class__.enabled
예제 #4
0
    def extract(node):
        """Copy ONNX ``axes``/``starts``/``ends`` into ``Slice`` attributes.

        Each integer-list attribute is converted to an ``np.int64`` array.
        It is stored under the singular key ``axis``/``start``/``end``.
        A missing attribute (defaulted to ``[]``) or an empty one is stored
        as ``None``.

        :param node: ONNX node to annotate via ``Slice.update_node_stat``.
        :return: the enclosing class's ``enabled`` flag.
        """
        # (ONNX attribute name, Slice attribute key), in the original read order.
        onnx_to_slice = (('axes', 'axis'), ('starts', 'start'), ('ends', 'end'))

        slice_attrs = {}
        for onnx_name, slice_key in onnx_to_slice:
            values = np.array(onnx_attr(node, onnx_name, 'ints', default=[]), dtype=np.int64)
            slice_attrs[slice_key] = values if len(values) else None

        Slice.update_node_stat(node, slice_attrs)
        return __class__.enabled
예제 #5
0
    def extract(cls, node):
        """Extract ONNX ``Slice`` attributes, dispatching on the node's opset version.

        For opset < 10, ``starts``, ``ends`` and ``axes`` are read from node
        attributes and stored on an ``AttributedSlice`` op. When ``axes`` is
        omitted it defaults to ``0 .. len(starts) - 1``. For opset >= 10 the node
        is marked as a plain ``Slice`` and no attributes are extracted here.

        :param node: ONNX node to annotate.
        :return: ``cls.enabled``.
        :raises Error: opset < 10 and ``starts`` or ``ends`` is missing or empty.
        """
        if get_onnx_opset_version(node) < 10:
            starts = int64_array(onnx_attr(node, 'starts', 'ints', default=[]))
            ends = int64_array(onnx_attr(node, 'ends', 'ints', default=[]))
            axes = int64_array(onnx_attr(node, 'axes', 'ints', default=[]))

            if len(starts) == 0 or len(ends) == 0:
                raise Error("starts or/and ends are not specified for the node {}".format(node.name))
            if len(axes) == 0:
                # `np.int` was removed in NumPy 1.24 (AttributeError). Use int64,
                # consistent with the int64_array conversions of starts/ends.
                axes = np.arange(len(starts), dtype=np.int64)

            attrs = {'axes': axes, 'starts': starts, 'ends': ends}
            AttributedSlice.update_node_stat(node, attrs)
        else:  # onnx_opset_version >= 10
            # NOTE(review): presumably starts/ends/axes arrive as node inputs in
            # opset >= 10 and are handled by the Slice op itself — confirm.
            Slice.update_node_stat(node)
        return cls.enabled