コード例 #1
0
    def extract(cls, node):
        proto_layer = node.pb
        param = proto_layer.argmax_param

        update_attrs = {
            'out_max_val': int(param.out_max_val),
            'top_k': param.top_k,
            'axis': param.axis,
        }

        mapping_rule = merge_attrs(param, update_attrs)

        ArgMaxOp.update_node_stat(node, mapping_rule)
        # ArgMax must be converted to TopK but without the output with values
        ArgMaxOp.update_node_stat(node, {'remove_values_output': True})
        return cls.enabled
コード例 #2
0
ファイル: argmax_ext.py プロジェクト: zoeysgithub/openvino
    def extract(cls, node):
        keepdims = onnx_attr(node, 'keepdims', 'i', default=1)
        axis = onnx_attr(node, 'axis', 'i', default=0)

        attrs = {
            'axis': axis,

            # ONNX ArgMax always computes an index of one maximum value
            'top_k': 1,
            'out_max_val': 0,

            # Set attribute to trigger ArgMax replacer in case do not keep the dimension
            'keepdims': keepdims
        }

        ArgMaxOp.update_node_stat(node, attrs)
        return cls.enabled
コード例 #3
0
ファイル: argmax_ext.py プロジェクト: www096/openvino
 def extract(cls, node):
     ArgMaxOp.update_node_stat(
         node, {
             'out_max_val':
             0,
             'top_k':
             1,
             'axis':
             None,
             'dim_attrs': ['axis'],
             'keepdims':
             0,
             'remove_values_output':
             True,
             'output_type':
             tf_dtype_extractor(node.pb.attr['out_type'].type, np.int64),
         })
     return cls.enabled
コード例 #4
0
ファイル: argmax_ext.py プロジェクト: srinivasdasu24/dldt
 def extract(cls, node):
     ArgMaxOp.update_node_stat(node, {'out_max_val': 0, 'top_k': 1, 'axis': None,
                                      'dim_attrs': ['axis'], 'keepdims': 0, 'remove_values_output': True})
     return cls.enabled