Code example #1
0
def add_activation_function_after_node(graph: Graph, node: Node,
                                       activation_function: str):
    """
    Attach an activation operation fed by 'node' and return the node producing
    the activated output.
    :param graph: graph to operate on.
    :param node: node to add activation after.
    :param activation_function: string defining the activation function. These values are read from TensorFlow* object
    detection API pipeline configuration file
    :return: activation function node.
    """
    if activation_function == 'IDENTITY':
        # Identity requires no extra operation: reuse the input node as-is.
        return node
    if activation_function == 'SOFTMAX':
        # softmax to be applied to the confidence
        softmax_op = Softmax(graph, dict(axis=-1, nchw_layout=True))
        return softmax_op.create_node([node],
                                      dict(name=node.name + '/softmax'))
    if activation_function == 'SIGMOID':
        # sigmoid activation function to be applied to the confidence
        sigmoid_op = Sigmoid(graph, dict(nchw_layout=True))
        return sigmoid_op.create_node([node],
                                      dict(name=node.name + '/sigmoid'))
    raise Error('Unknown post-processing activation function "{}".'.format(
        activation_function))
Code example #2
0
File: softmax_ext.py  Project: zkzt/openvino
 def extract(cls, node):
     """Extract the TF Softmax 'axis' attribute onto the node (TF default: -1)."""
     pb_attrs = node.pb.attr
     # Use the protobuf value only when the attribute is actually present.
     axis = pb_attrs['axis'].i if 'axis' in pb_attrs else -1
     Softmax.update_node_stat(node, {'axis': axis})
     return cls.enabled
Code example #3
0
File: softmax.py  Project: zhenlusu500/openvino
    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
        """Fuse the matched reduce-max/reduce-sum normalization pattern into one Softmax node."""
        max_axis = match['reduce_indices_max'].value
        sum_axis = match['reduce_indices_sum'].value

        # Normalize scalar index arrays to 1-element vectors so the checks below
        # work uniformly for both representations.
        if max_axis.ndim == 0:
            max_axis = max_axis.reshape([1])
        if sum_axis.ndim == 0:
            sum_axis = sum_axis.reshape([1])

        if len(max_axis) != 1:
            log.info('The reductions indices contain more than 1 element. Cannot convert to Softmax.')
            return

        if not np.array_equal(max_axis, sum_axis):
            log.info('The reduce indices are not equal: {} vs {}. Cannot convert to Softmax'
                     ''.format(max_axis, sum_axis))
            return

        # Both reductions operate over the same single axis -> a plain Softmax.
        softmax = Softmax(graph, {'name': match['input'].name + '/Softmax', 'axis': sum_axis[0]}).create_node()
        match['input'].out_port(0).connect(softmax.in_port(0))
        match['div'].out_port(0).get_connection().set_source(softmax.out_port(0))

        log.debug('Successfully created SoftMax node')
Code example #4
0
 def extract(cls, node):
     """Store the softmax dimension taken from the traced PyTorch module on the node."""
     # 'op' deliberately uses the defining class, not a potential subclass.
     Softmax.update_node_stat(node, {
         'op': __class__.op,
         'axis': node.module.dim,
     })
     return cls.enabled
Code example #5
0
File: softmax_ext.py  Project: pc2/CustoNN2
    def extract(node):
        """Read the ONNX Softmax 'axis' attribute (ONNX default is 1) onto the node."""
        # update the attributes of the node with the extracted axis
        Softmax.update_node_stat(node, {'axis': onnx_attr(node, 'axis', 'i', default=1)})
        return __class__.enabled
Code example #6
0
    def extract(node):
        """Copy the Caffe softmax_param axis onto the node."""
        softmax_param = node.pb.softmax_param

        # update the attributes of the node
        Softmax.update_node_stat(node, {'axis': softmax_param.axis})
        return __class__.enabled
Code example #7
0
    def extract(cls, node):
        """Extract MXNet softmax attributes (axis, temperature) onto the node."""
        layer_attrs = get_mxnet_layer_attrs(node.symbol_dict)

        # update the attributes of the node; MXNet defaults: axis=-1, temperature=1.0
        Softmax.update_node_stat(node, {
            'type': 'SoftMax',
            'axis': layer_attrs.int("axis", -1),
            'temperature': layer_attrs.float('temperature', 1.0),
        })
        return cls.enabled
Code example #8
0
    def replace_op(self, graph: Graph, node: Node):
        """Replace the node with an explicit Softmax(axis=1) followed by Log."""
        softmax = Softmax(graph, {
            'axis': 1,
            'name': node.name + '/SoftMax_'
        }).create_node()
        log = LogOp(graph, {'name': node.name + '/Log_'}).create_node()

        # Rewire: input -> Softmax -> Log
        node.in_port(0).get_connection().set_destination(softmax.in_port(0))
        log.in_port(0).get_connection().set_source(softmax.out_port(0))

        # The "explicit" version of the return value is: [(out_node.id, 0)])
        return [log.id]
    def extract(cls, node):
        """Map the MXNet 'mode' attribute to a softmax axis: 'channel' -> 1, otherwise -> -1."""
        layer_attrs = get_mxnet_layer_attrs(node.symbol_dict)

        # 'instance' is the MXNet default mode.
        axis = 1 if layer_attrs.str("mode", "instance") == "channel" else -1

        # update the attributes of the node
        Softmax.update_node_stat(node, {'axis': axis})
        return cls.enabled
Code example #10
0
    def replace_op(self, graph: Graph, node: Node):
        """Decompose the node into Softmax followed by Log, preserving its 'axis'."""
        node_name = node.soft_get('name', node.id)
        assert node.has_valid(
            'axis'
        ), 'The node "{}" does not have mandatory attribute "axis"'.format(
            node_name)

        softmax = Softmax(graph, {
            'axis': node.axis,
            'name': node_name + '/Softmax'
        }).create_node()
        log = LogOp(graph, {}).create_node()
        # The Log node inherits the original node's name.
        rename_nodes([(node, node_name + '/delete'), (log, node_name)])

        # Rewire: input -> Softmax -> Log
        node.in_port(0).get_connection().set_destination(softmax.in_port(0))
        log.in_port(0).get_connection().set_source(softmax.out_port(0))
        return [log.id]
Code example #11
0
    def extract(node):
        """Derive the softmax axis from the MXNet string flags on the layer."""
        layer_attrs = get_mxnet_layer_attrs(node.symbol_dict)

        # Default axis is 1; 'preserve_shape' switches to the last axis,
        # but 'multi_output' takes precedence and forces axis 1 back.
        axis = 1
        if layer_attrs.str('preserve_shape', 'False') == 'True':
            axis = -1
        if layer_attrs.str('multi_output', 'False') == 'True':
            axis = 1

        # update the attributes of the node
        Softmax.update_node_stat(node, {'axis': axis})
        return __class__.enabled
Code example #12
0
    def replace_op(self, graph: Graph, node: Node):
        """
        Expand a softmax over an arbitrary axis into
        FlattenONNX(axis) -> Softmax(axis=1) -> Reshape(original shape).
        """
        node_name = node.soft_get('name', node.id)
        assert node.has_valid(
            'axis'
        ), 'The node "{}" does not have mandatory attribute "axis"'.format(
            node_name)

        flatten_node = FlattenONNX(graph, {
            'name': node_name + '/FlattenONNX_',
            'axis': node.axis
        }).create_node()
        shape_node = Shape(graph, {
            'name': node_name + '/ShapeOf_'
        }).create_node()
        softmax_node = Softmax(graph, {
            'name': node_name + '/Softmax_',
            # after flattening, the softmax dimension is always axis 1
            'axis': 1,
            'framework_node_name': node_name,
            'rename_condition':
                lambda n: len(n.graph.get_op_nodes(name=node_name)) == 0
        }).create_node()
        reshape_node = Reshape(graph, {}).create_node()

        # The final Reshape takes over the original node's name.
        rename_nodes([(node, node_name + '/delete'),
                      (reshape_node, node_name)])

        # input -> Flatten -> Softmax -> Reshape; Shape restores the original layout.
        flatten_node.out_port(0).connect(softmax_node.in_port(0))
        softmax_node.out_port(0).connect(reshape_node.in_port(0))
        shape_node.out_port(0).connect(reshape_node.in_port(1))

        source = node.in_port(0).get_source()
        flatten_node.in_port(0).connect(source)
        shape_node.in_port(0).connect(source)

        return [reshape_node.id]
Code example #13
0
File: softmax_ext.py  Project: zhenlusu500/openvino
 def extract(cls, node):
     """Mark the node as Softmax; its output shape equals the input shape."""
     # softmax does not change the shape, so reuse the generic copy-shape inference
     attrs = {'infer': copy_shape_infer}
     Softmax.update_node_stat(node, attrs)
     return cls.enabled
Code example #14
0
 def extract(cls, node):
     """Extract the ONNX Softmax 'axis' attribute (ONNX default: 1) onto the node."""
     attrs = {'axis': onnx_attr(node, 'axis', 'i', default=1)}
     Softmax.update_node_stat(node, attrs)
     return cls.enabled
Code example #15
0
 def extract(node):
     """Mark the node as Softmax with shape-preserving inference."""
     # output shape matches input shape, so the generic copy-shape infer applies
     attrs = {'infer': copy_shape_infer}
     Softmax.update_node_stat(node, attrs)
     return __class__.enabled