def add_unsqueeze_for_new(graph: Graph, ss_node: Node):
        log.info(
            "StridedSlice op with new axis mask '{}' has been detected".format(
                ss_node.id))
        if len(ss_node.in_nodes()) != 4 or len(ss_node.out_nodes()) != 1:
            return

        shape_out = ss_node.out_node().shape
        dim = mo_array(range(len(ss_node['new_axis_mask'])))[mo_array(
            ss_node['new_axis_mask'], dtype=bool)]
        ss_shape = []
        for i in range(0, len(ss_node['new_axis_mask'])):
            if not ss_node['new_axis_mask'][i]:
                ss_shape.append(shape_out[i])
            else:
                ss_node['new_axis_mask'][i] = 0

        ss_node.out_port(0).data.set_shape(ss_shape)

        # insert Unsqueeze
        unsqueeze_node = Unsqueeze(graph,
                                   dict(name=ss_node.name +
                                        '/Unsqueeze_new')).create_node()
        ss_node.out_port(0).get_connection().insert_node(unsqueeze_node)
        unsqueeze_node.out_port(0).data.set_shape(shape_out)

        dims_node = Const(graph, {
            'name': unsqueeze_node.id + '/Indices',
            'value': int64_array(dim)
        }).create_node()
        dims_node.out_port(0).connect(unsqueeze_node.in_port(1))
# Example #2
    def replace_sub_graph(self, graph: Graph, match: dict):
        node = match['op']
        if not node.has_valid('start') or not node.has_valid(
                'stop') or not node.has_valid('step'):
            return

        start_value = Const(
            graph,
            dict(value=node.start,
                 symbol_dict={'name':
                              node.id + '/const_start'})).create_node()
        limit_value = Const(
            graph,
            dict(value=node.stop,
                 symbol_dict={'name':
                              node.id + '/const_limit'})).create_node()
        delta_value = Const(
            graph,
            dict(value=node.step,
                 symbol_dict={'name':
                              node.id + '/const_delta'})).create_node()
        node.in_port(0).get_connection().set_source(start_value.out_port(0))
        node.in_port(1).get_connection().set_source(limit_value.out_port(0))
        node.in_port(2).get_connection().set_source(delta_value.out_port(0))
        if node.has_valid('repeat') and node.repeat > 1:
            rep = MXRepeat(
                graph,
                dict(name=node.id + '/mxrepeat', axis=0,
                     repeats=node.repeat)).create_node()
            node.out_port(0).get_destination().get_connection().set_source(
                rep.out_port(0))
            rep.in_port(0).connect(node.out_port(0))
# Example #3
    def replace_op(self, graph: Graph, node: Node):
        pb = node.parameters
        weights_size = read_binary_integer32_token(pb)
        weights = read_blob(pb, weights_size, dtype=np.int32) - 1

        node_name = node.soft_get('name', node.id)
        const_attrs = {
                       'name': node_name + '/indexes',
                       'value': np.array(weights),
                       'shape': [weights_size],
                       'data_type': np.int32
                      }
        indexes_node = Const(graph).create_node(attrs=const_attrs)

        perm_in_1 = Const(graph, {'value': int64_array([1, 0]), 'name': node_name + '/order'}).create_node()
        perm1_node = Transpose(graph, {'name': node_name + '/input_permute'}).create_node([node.in_node(0)])
        perm1_node.in_port(0).connect(node.in_port(0).get_source())
        perm1_node.in_port(1).connect(perm_in_1.out_port(0))

        gather_node = create_op_with_const_inputs(graph, Gather, {2: int64_array(0)}, {'name': node_name + '/gather'})
        gather_node.in_port(0).connect(perm1_node.out_port(0))
        gather_node.in_port(1).connect(indexes_node.out_port(0))

        perm2_node = Transpose(graph, {'name': node_name + '/output_permute'}).create_node()
        perm2_node.in_port(0).connect(gather_node.out_port(0))
        perm2_node.in_port(1).connect(perm_in_1.out_port(0))

        return [perm2_node.id]
def resolve_shared_inputs(node: Node, port_ids_to_duplicate: List[int]):
    """
    Duplicates shared constants that are consumed by more than one node.
    If a constant is consumed by several ports of one node, no duplication is done.
    """
    graph = node.graph

    for port_id in port_ids_to_duplicate:
        dst_port_map = defaultdict(list)
        for dst in node.in_port(
                port_id).get_source().get_connection().get_destinations():
            dst_port_map[dst.node].append(dst.idx)
        del dst_port_map[node]
        value = node.in_port(port_id).data.get_value()
        if value is None:
            log.debug('Cannot duplicate: no data for in_port {} of node {}'.format(port_id, node.name))
            continue
        for dst_node, idxs in dst_port_map.items():
            const = Const(graph, {
                'value': mo_array(value),
                'name': dst_node.soft_get('name', dst_node.id) + '/duplicated_'
            }).create_node()
            for idx in idxs:
                dst_node.in_port(idx).disconnect()
                const.out_port(0).connect(dst_node.in_port(idx))
            const.infer(const)
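# Usage sketch for resolve_shared_inputs (illustrative, not part of the original
# snippet): `graph` is assumed to be an already-built MO Graph and 'fq' the id of
# a FakeQuantize node whose inputs on ports 1 and 2 reuse constants shared with
# other consumers; the call gives this node its own private copies.
fq_node = Node(graph, 'fq')  # hypothetical node id
resolve_shared_inputs(node=fq_node, port_ids_to_duplicate=[1, 2])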
    def find_and_replace_pattern(self, graph: Graph):
        for roll_node in graph.get_op_nodes(op='Roll'):
            if not roll_node.in_port(2).disconnected():
                continue
            node_name = roll_node.soft_get('name', roll_node.id)

            # reshape to 1d tensor
            reshape_to_1d = create_op_node_with_second_input(
                graph, Reshape, int64_array([-1]),
                {'name': node_name + '/reshape'})
            roll_node.in_port(0).get_connection().insert_node(reshape_to_1d)

            # add zero const as axes input to roll
            const_zero = Const(graph, {
                'value': int64_array([0]),
                'name': node_name + '/axes'
            }).create_node()
            const_zero.out_port(0).connect(roll_node.in_port(2))

            # reshape to original shape
            shape_of = Shape(graph, {
                'name': node_name + '/shape_of'
            }).create_node()
            reshape_to_1d.in_port(0).get_connection().add_destination(
                shape_of.in_port(0))
            reshape_to_orig_shape = Reshape(graph, {}).create_node()
            rename_nodes([(roll_node, node_name + '/roll'),
                          (reshape_to_orig_shape, node_name)])
            shape_of.out_port(0).connect(reshape_to_orig_shape.in_port(1))
            roll_node.out_port(0).get_connection().insert_node(
                reshape_to_orig_shape)
    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
        node = match['reduce']
        connected_in_ports = [
            port for port in node.in_ports().values()
            if not port.disconnected()
        ]
        if len(connected_in_ports) == 1:
            node_name = node.soft_get('name', node.id)

            # if the 'axis' attribute is set, move it into a constant second input of the node;
            # otherwise (see the else branch) build a [0, rank) Range so the reduction covers all dimensions
            if node.has_valid('axis'):
                const = Const(graph, {
                    'name': node_name + '/axis',
                    'value': node.axis
                }).create_node()
                node.add_input_port(1, skip_if_exist=True)
                const.out_port(0).connect(node.in_port(1))
                del graph.node[node.id]['axis']
            else:
                # The default (if there is no 'axis') is to reduce over all the dimensions of the input tensor.
                axes = create_op_with_const_inputs(
                    graph, Range, {
                        0: int64_array(0),
                        2: int64_array(1)
                    }, dict(name=node_name + '/axes'))
                end_of_range = Rank(graph, dict(name=node_name +
                                                '/range_end')).create_node()
                node.in_port(0).get_connection().get_source().connect(
                    end_of_range.in_port(0))
                end_of_range.out_port(0).connect(axes.in_port(1))

                node.add_input_port(1, skip_if_exist=True)
                axes.out_port(0).connect(node.in_port(1))
    def replace_pattern(self, graph: Graph, match: [str, Node]):
        node = match['transpose']
        assert len(node.in_nodes()) == 1
        order = np.arange(len(node.in_port(0).data.get_shape()))[::-1]
        const = Const(graph, {'value': order, 'name': node.soft_get('name', node.id) + '/Order'}).create_node()
        node.add_input_port(1, skip_if_exist=True)
        const.out_port(0).connect(node.in_port(1))
        node['reverse_order'] = False
    def replace_pattern(graph: Graph, match: dict):
        node = match['op']
        pair_node = Node(graph, node.pair_name)

        if node.t >= 0:
            raise Error('Does not support IfDefined with t >= 0')

        if node.in_port(0).get_source() is not None:
            input_port = node.in_port(0).get_source()
            op_output_id = node.out_port(0).get_destination().node.id
            out_port = pair_node.out_port(0)
            node_name = node.name
            pair_name = pair_node.name
        else:
            input_port = pair_node.in_port(0).get_source()
            op_output_id = pair_node.out_port(0).get_destination().node.id
            out_port = node.out_port(0)
            node_name = pair_node.name
            pair_name = node.name

        in_shape = input_port.data.get_shape()
        node_t = abs(node.t)

        init_value_memory_out = Const(graph, {'name': 'init_value_' + pair_name,
                                              'value': np.zeros(int64_array([in_shape[0], in_shape[1]*node_t]), dtype=np.float32),
                                              'shape': int64_array([in_shape[0], in_shape[1]*node_t])}).create_node()
        memory_out = ReadValue(graph, {'name': pair_name, 'variable_id': node_name+pair_name}).create_node()
        init_value_memory_out.out_port(0).connect(memory_out.in_port(0))

        if node_t > 1:
            crop_concat = Crop(graph, {'name': 'Memory_crop', 'dim': mo_array([in_shape[1]*(node_t-1)]),
                                       'offset': mo_array([in_shape[1]]), 'axis': mo_array([1])}).create_node()
            memory_out.out_port(0).connect(crop_concat.in_port(0))
            concat = Concat(graph, {'name': 'Memory_concat'}).create_node()
            concat.add_sequence_of_ports('in', range(2))
            crop_concat.out_port(0).connect(concat.in_port(0))
            concat.in_port(1).connect(input_port)

            memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
            concat.out_port(0).connect(memory_in.in_port(0))
            out = Result(graph, {'name': 'Memory_output'}).create_node()
            memory_in.out_port(0).connect(out.in_port(0))

            crop_out = Crop(graph, {'name': 'Memory_crop_out', 'dim': mo_array([in_shape[1]]),
                                    'offset': mo_array([0]), 'axis': mo_array([1])}).create_node()
            memory_out.out_port(0).connect(crop_out.in_port(0))
            out_port.get_connection().set_source(crop_out.out_port(0))
        else:
            memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
            memory_in.in_port(0).connect(input_port)
            out = Result(graph, {'name': 'Memory_output'}).create_node()
            memory_in.out_port(0).connect(out.in_port(0))
            out_port.get_connection().set_source(memory_out.out_port(0))

        graph.remove_node(op_output_id)
        graph.remove_node(node.id)
        graph.remove_node(pair_node.id)
def create_ss_interval_border(graph: Graph, slice_border_port: Port,
                              shape: np.ndarray, axes: np.ndarray,
                              node_name: str):
    """
    This function creates "begin"/"end" parameters for the StridedSlice based on Slice's "starts"/"ends"

    :param graph: graph to operate on.
    :param slice_border_port: node output port that provides "starts"/"ends" values for the Slice.
    :param shape: input shape of the Slice
    :param axes: axes that "starts" and "ends" apply to
    :param node_name: Slice node name
    :return: Concat node that forms "begin"/"end" values for the StridedSlice
    """
    # the 'starts' or 'ends' value might be the maximum/minimum possible int64 value. Such values must be clamped
    # to the int32 range because they do not fit into the int32 type supported by the StridedSlice layer
    clamp = create_op_with_const_inputs(graph,
                                        Clamp,
                                        port_value_dict={
                                            1: np.iinfo(np.int32).min,
                                            2: np.iinfo(np.int32).max
                                        },
                                        op_attrs=dict(name=node_name +
                                                      '/Clamp'))
    clamp.in_port(0).connect(slice_border_port)
    # cast the "starts"/"ends" values coming from the network to the same data type as the constants created here,
    # otherwise the Concat node would fail due to mismatched input types
    cast = Cast(graph, dict(name=node_name + '/CastToI64',
                            dst_type=np.int64)).create_node()
    cast.in_port(0).connect(clamp.out_port(0))
    concat = Concat(graph, dict(name=node_name + '/Concat',
                                axis=0)).create_node()
    for value_idx, port_idx in enumerate(axes):
        concat.add_input_port(port_idx)
        # "axes" may not be sorted, so we need to split "starts"/"ends" values and connect each value to the correct
        # Concat input port
        value = create_op_with_const_inputs(
            graph,
            Gather,
            port_value_dict={
                1: int64_array([value_idx]),
                2: int64_array(0)
            },
            op_attrs={'name': node_name + '/Gather'})
        cast.out_port(0).connect(value.in_port(0))
        value.out_port(0).connect(concat.in_port(port_idx))
    for port_idx in range(len(shape)):
        if not concat.is_in_port_connected(port_idx):
            concat.add_input_port(port_idx)
            # This border value would be ignored in StridedSlice because of the begin_mask/end_mask
            const = Const(
                graph, dict(name=node_name + '/Const',
                            value=int64_array([0]))).create_node()
            const.out_port(0).connect(concat.in_port(port_idx))

    return concat
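# Usage sketch (hedged): how create_ss_interval_border might be called from a
# Slice-to-StridedSlice conversion. `slice_node`, `input_shape`, `slice_axes`,
# `slice_name` and the input port numbers for "starts"/"ends" are assumptions
# made for illustration only.
begin = create_ss_interval_border(graph, slice_node.in_port(1).get_source(),
                                  input_shape, slice_axes, slice_name)
end = create_ss_interval_border(graph, slice_node.in_port(2).get_source(),
                                input_shape, slice_axes, slice_name)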
# Example #10
def input_as_const(node: Node, attrs: dict, port: int, bin: str, value: np.ndarray):
    """
    Inserts a constant node with `value` and `attrs` on input `port` of `node`.
    Marks the input edge with the `bin` attribute.
    """
    graph = node.graph
    const = Const(graph, {'value': value, **attrs}).create_node()
    node.add_input_port(port, skip_if_exist=True)
    const.out_port(0).connect(node.in_port(port))
    node.in_port(port).bin = bin
    node.in_port(port).in_attrs.append('bin')
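# Usage sketch (hypothetical names): attach a weights blob to port 1 of a
# convolution node and mark the edge so the value is serialized as binary data.
# `conv_node` and `weights_value` are assumed to exist in the caller.
input_as_const(conv_node, {'name': conv_node.name + '/weights'}, 1, 'weights', weights_value)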
    def replace_sub_graph(graph: Graph, match: dict):
        node = match['op']
        for port_index, value_attr, attrs in node['embedded_inputs']:
            const = Const(graph, dict(value=node[value_attr])).create_node()
            node.add_input_port(port_index, skip_if_exist=True)
            const.out_port(0).connect(node.in_port(port_index))
            node.in_port(port_index).bin = attrs['bin']
            node.in_port(port_index).in_attrs.append('bin')
            del node[value_attr]
        del node['embedded_inputs']
# Example #12
def create_op_node_with_second_input(graph: Graph, op: callable, second_input_value: np.array, op_attrs=None,
                                     input_node=None):
    operation = op(graph, op_attrs)
    node = operation.create_node()
    if input_node is not None:
        input_node.out_port(0).connect(node.in_port(0))
    second_input_node = Const(graph, {'name': node.name + '/value', 'value': second_input_value}).create_node()
    second_input_node.out_port(0).connect(node.in_port(1))
    if graph.stage != 'front':
        second_input_node.infer(second_input_node)
    return node
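# Usage sketch: the same call pattern used elsewhere in this listing: create a
# Reshape whose second (shape) input is the constant [-1], flattening its input
# to 1D. `graph` and `node_name` are assumed to be defined by the caller.
reshape_to_1d = create_op_node_with_second_input(graph, Reshape, int64_array([-1]),
                                                 {'name': node_name + '/reshape'})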
    def __insert_mul_node_with_coeff(node: Node, port: int, coeff: float):
        if coeff != 1:
            mul_node = Mul(node.graph, {'name': node.id + '/coeff_mul'}).create_node()
            const_node = Const(node.graph, {'name': node.id + '/coeff', 'value': mo_array([coeff])}).create_node()
            node.in_port(port).get_connection().insert_node(mul_node)
            const_node.out_port(0).connect(mul_node.in_port(1))
# Example #14
    def replace_sub_graph(self, graph: Graph, match: dict):
        node = match['op']
        name = node.soft_get('name', node.id)
        assert node.has_valid('axis')

        axis = Const(graph, {'name': name + '/axis', 'value': int64_array(node.axis)}).create_node()
        gather = Gather(graph, {'name': name}).create_node()
        node.in_port(0).get_connection().set_destination(gather.in_port(0))
        node.in_port(1).get_connection().set_destination(gather.in_port(1))
        axis.out_port(0).connect(gather.in_port(2))
        node.out_port(0).get_connection().set_source(gather.out_port(0))
# Example #15
    def replace_op(self, graph: Graph, node: Node):
        const = Const(graph, dict(value=mo_array(-1.),
                                  name=node.name + '/reciprocal_pow_const_')).create_node()
        reciprocal = Pow(graph, {'name': node.name + '/reciprocal_pow_'}).create_node()
        node.in_port(0).get_connection().set_destination(reciprocal.in_port(0))
        const.out_port(0).connect(reciprocal.in_port(1))
        return [reciprocal.id]
    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
        node = match['reshape']
        connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
        if len(connected_in_ports) == 1:
            if node.has('dim'):
                const = Const(graph, {'value': node.dim}).create_node()
                node.add_input_port(1, skip_if_exist=True)
                const.out_port(0).connect(node.in_port(1))
                del node['dim']
            else:
                raise Error('The `dim` attribute for node {} is not set'.format(node.op))
# Example #17
    def unroll_ellipsis_for_inputs(graph: Graph, node: Node,
                                   ellipsis_start: int, num_insertions: int):
        node_name = node.soft_get('name', node.id)

        for i, input_name in [(1, 'begin'), (2, 'end'), (3, 'strides')]:
            if i == 3 and not node.is_in_port_connected(3):
                continue  # no need to extend strides if they are not connected

            blank_values_arr = np.zeros(num_insertions) if input_name != 'strides' else np.ones(num_insertions)
            blank_values_node = Const(graph, {
                'name': node_name + '/const_to_unroll_{}_ellipsis'.format(input_name),
                'value': int64_array(blank_values_arr)
            }).create_node()

            concat_in_ports_count = 3 if ellipsis_start != 0 else 2
            concat = Concat(
                graph, {
                    'axis': 0,
                    'name': node_name + '/concat_{}'.format(input_name),
                    'in_ports_count': concat_in_ports_count
                }).create_node()

            if ellipsis_start != 0:
                split = create_op_with_const_inputs(
                    graph, VariadicSplit,
                    {1: int64_array(0), 2: int64_array([ellipsis_start, -1])},
                    {'name': node_name + '/split_for_{}_ellipsis'.format(input_name), 'out_ports_count': 2})
                node.in_port(i).get_connection().set_destination(
                    split.in_port(0))

                concat.in_port(0).connect(split.out_port(0))
                concat.in_port(1).connect(blank_values_node.out_port(0))
                concat.in_port(2).connect(split.out_port(1))
            else:
                concat.in_port(0).connect(blank_values_node.out_port(0))
                node.in_port(i).get_connection().set_destination(
                    concat.in_port(1))

            concat.out_port(0).get_connection().set_destination(
                node.in_port(i))
    def add_squeeze_for_shrink(graph: Graph, ss_node: Node):
        # add Squeeze for shrink_axis_mask
        log.info(
            "StridedSlice op with shrink mask '{}' has been detected".format(
                ss_node.id))

        if len(ss_node.in_nodes()) != 4 or len(ss_node.out_nodes()) != 1:
            return

        shape_out = ss_node.out_node().shape
        dim = mo_array(range(len(ss_node['shrink_axis_mask'])))[mo_array(
            ss_node['shrink_axis_mask'], dtype=bool)]
        ss_shape = []
        i = 0
        k = 0

        # Don't permute reshape if channels were squeezed
        dont_permute = graph.graph['layout'] == 'NCHW'
        if graph.graph['layout'] == 'NHWC' and ss_node['shrink_axis_mask'][-1] == 1:
            dont_permute = True

        while k < len(shape_out):
            if i >= len(ss_node['shrink_axis_mask']) or not ss_node['shrink_axis_mask'][i]:
                ss_shape.append(shape_out[k])
                k = k + 1
            else:
                ss_node['shrink_axis_mask'][i] = 0
                ss_shape.append(1)
            i = i + 1

        while i < len(ss_node['shrink_axis_mask']):
            ss_node['shrink_axis_mask'][i] = 0
            ss_shape.append(1)
            i = i + 1

        ss_node.out_port(0).data.set_shape(ss_shape)

        # insert Squeeze
        squeeze_node = Squeeze(
            graph,
            dict(name=ss_node.name + '/Squeeze_shrink',
                 nchw_layout=dont_permute,
                 correct_data_layout=dont_permute)).create_node()
        ss_node.out_port(0).get_connection().insert_node(squeeze_node)
        squeeze_node.out_port(0).data.set_shape(shape_out)

        dims_node = Const(graph, {
            'name': squeeze_node.id + '/Indices',
            'value': int64_array(dim)
        }).create_node()
        dims_node.out_port(0).connect(squeeze_node.in_port(1))
# Example #19
    def replace_pattern(self, graph: Graph, match: [str, Node]):
        swapaxis = match['op']
        assert len(swapaxis.in_ports()) == 1
        assert swapaxis.has_and_set('order')
        order = swapaxis.order

        swapaxis.add_input_port(1)
        const = Const(graph, {'value': order, 'name': swapaxis.soft_get('name', swapaxis.id) + '/Order'}).create_node()
        const.out_port(0).connect(swapaxis.in_port(1))

        Transpose.update_node_stat(swapaxis, {'need_shape_inference': True})

        del swapaxis['order']
    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
        node = match['transpose']
        connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
        if len(connected_in_ports) == 1:
            if node.has_valid('order'):
                const = Const(graph, {'value': node.order}).create_node()
                node.add_input_port(1, skip_if_exist=True)
                const.out_port(0).connect(node.in_port(1))
                del graph.node[node.id]['order']
            elif node.has('order') and node.order is None:
                assert node.has_and_set('reverse_order')
            else:
                raise Error('Can not deduce transpose `order` for {}: only one in_port and no `order` parameter.'
                            ''.format(node.soft_get('name', node.id)))
# Example #21
def create_fake_quantize_node(graph: Graph, name, data_type=np.float32):
    fq = FakeQuantize(graph, {
        'name': name,
        'levels': 0,
        'stop_value_propagation': True
    }).create_node()

    input_low = Const(graph, {
        'value': np.array(0.0, dtype=data_type)
    }).create_node()
    input_height = Const(graph, {
        'value': np.array(0.0, dtype=data_type)
    }).create_node()
    output_low = Const(graph, {
        'value': np.array(0.0, dtype=data_type)
    }).create_node()
    output_height = Const(graph, {
        'value': np.array(0.0, dtype=data_type)
    }).create_node()

    input_low.out_port(0).connect(fq.in_port(1))
    input_height.out_port(0).connect(fq.in_port(2))
    output_low.out_port(0).connect(fq.in_port(3))
    output_height.out_port(0).connect(fq.in_port(4))

    input_low.infer(input_low)
    input_height.infer(input_height)
    output_low.infer(output_low)
    output_height.infer(output_height)

    return fq
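# Usage sketch (hedged): create a FakeQuantize with zero-initialized range
# constants and splice it into an existing connection. `graph` and
# `producer_port` (an output Port of the node to be quantized) are assumptions.
fq = create_fake_quantize_node(graph, 'my_fq')
producer_port.get_connection().insert_node(fq)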
# Example #22
def create_op_with_const_inputs(graph: Graph, op: callable, port_value_dict: Dict[int, np.array],
                                op_attrs=None, input_node=None):
    operation = op(graph, op_attrs)
    node = operation.create_node()
    if input_node is not None:
        input_node.out_port(0).connect(node.in_port(0))

    for idx, value in port_value_dict.items():
        node.add_input_port(idx, skip_if_exist=True)
        value_input_node = Const(graph, {'name': node.name + '_input_port_' + str(idx) + '/value',
                                         'value': value}).create_node()
        value_input_node.out_port(0).connect(node.in_port(idx))
        if graph.stage != 'front':
            value_input_node.infer(value_input_node)
    return node
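# Usage sketch: the helper as used elsewhere in this listing, building a Gather
# whose indices (port 1) and axis (port 2) are constants. `graph` and `node_name`
# are assumed to be defined by the caller.
gather = create_op_with_const_inputs(graph, Gather,
                                     {1: int64_array([0]), 2: int64_array(0)},
                                     {'name': node_name + '/gather'})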
# Example #23
    def swap_pad_and_unsqueeze(self, pad: Node, unsqueeze: Node):
        # insert additional items to the pads in the position specified by the Unsqueeze axis
        unsqueeze_axis = unsqueeze.in_port(1).data.get_value()
        for port_id in [1, 2]:
            current_value = pad.in_port(port_id).get_connection().data.get_value()
            new_value_node = Const(pad.graph, {
                'name': pad.soft_get('name', pad.id) + '/value_{}'.format(port_id),
                'value': shape_insert(current_value, unsqueeze_axis.item(), 0),
                'override_output_shape': True
            }).create_node()
            pad.in_port(port_id).disconnect()
            pad.in_port(port_id).connect(new_value_node.out_port(0))

        # swap Pad and Unsqueeze layers
        unsqueeze.in_port(0).disconnect()
        pad.in_port(0).get_connection().set_destination(unsqueeze.in_port(0))
        unsqueeze.out_port(0).get_connection().set_source(pad.out_port(0))
        unsqueeze.out_port(0).connect(pad.in_port(0))

        # output shapes of Pad and Unsqueeze changed so need to recalculate them
        pad['override_output_shape'] = True
        unsqueeze['override_output_shape'] = True
    def replace_pattern(self, graph: Graph, match: dict):
        bias_add = match['BiasAdd']

        # Replace BiasAdd by Add operation
        new_add = Add(graph, {'name': bias_add.id + '/Add'}).create_node()

        bias_add.in_port(0).get_connection().set_destination(new_add.in_port(0))
        bias_add.in_port(1).get_connection().set_destination(new_add.in_port(1))
        bias_add.out_port(0).get_connection().set_source(new_add.out_port(0))

        if bias_add.data_format != 'NCHW':
            return

        input_shape = new_add.in_port(0).data.get_shape()
        bias_shape = new_add.in_port(1).data.get_shape()
        assert len(bias_shape) == 1

        unsqueeze_dims = np.arange(len(input_shape))
        channel_dim = get_features_dim('NCHW', len(input_shape))
        unsqueeze_dims = np.delete(unsqueeze_dims, channel_dim, 0)

        unsqueeze_node = Unsqueeze(graph, {'name': new_add.id + '/BiasUnsqueeze'}).create_node()
        unsqueeze_dims_node = Const(graph, {'name': new_add.id + '/Dims',
                                            'value': unsqueeze_dims}).create_node()
        # Reconnecting nodes
        unsqueeze_node.in_port(1).connect(unsqueeze_dims_node.out_port(0))
        unsqueeze_node['override_output_shape'] = True

        new_add.in_port(1).get_connection().insert_node(unsqueeze_node)
# Example #25
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Adds Normalize layer weights, which are required by Inference Engine
        but do not always exist in the MXNet model.

        L2Normalization is mapped to the Normalize layer,
        so we need to generate Normalize weights filled with ones.

        Parameters
        ----------
        graph : Graph
            Graph with loaded model.
        match : dict
            Patterns which were found in graph structure.
        """
        l2_normalization_node = match['l2_normalization']
        if len(l2_normalization_node.in_nodes()) < 2:
            value = np.full([l2_normalization_node.in_node(0).shape[1]], 1.0, dtype=np.float32)
            weights_node = Const(graph, dict(name=l2_normalization_node['name'] + '_weights',
                                             value=value)).create_node()
            l2_normalization_node.add_input_port(1)
            l2_normalization_node.in_port(1).connect(weights_node.out_port(0))
            l2_normalization_node.in_port(1).bin = 'weights'
# Example #26
def create_bias_node(graph: Graph, src_node):
    logger.debug('Creating new bias for {}'.format(src_node.name))
    destination_ports = []
    for dest_port in src_node.out_port(0).get_destinations():
        destination_ports.append(dest_port)

    # Create Add and constant with zero bias
    bias_shape = src_node.out_port(0).data.get_shape()
    add_bias_shape = [1] * len(bias_shape)
    add_bias_shape[1] = bias_shape[1]
    weights = get_weights_for_node(src_node)
    bias_dtype = np.float32
    if weights and weights.out_port(0).is_data_type_defined():
        bias_dtype = weights.out_port(0).get_data_type()
    add_bias = Const(
        graph, {
            'value': np.zeros(add_bias_shape, dtype=bias_dtype),
            'shape': add_bias_shape,
            'need_shape_inference': True
        }).create_node()
    add_op = Add(graph, {
        'name': src_node.name + '/add_',
        'need_shape_inference': True
    }).create_node()

    # Connect Const to Add node
    add_op.in_port(1).connect(add_bias.out_port(0))

    # Reconnect src_node -> output to src_node -> Add -> output
    src_node.out_port(0).disconnect()
    src_node.out_port(0).get_connection().set_destination(add_op.in_port(0))

    for destination_port in destination_ports:
        add_op.out_port(0).connect(destination_port)
    add_bias.out_node(0)['Insert_Convert_operation_after'] = True
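# Usage sketch (hypothetical): add an explicit zero bias after a convolution so
# that later passes can fuse or quantize it; `conv_node` is assumed to be a node
# in `graph` with a known output shape. The helper rewires all consumers itself.
create_bias_node(graph, conv_node)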
# Example #27
    def replace_sub_graph(graph: Graph, match: dict):
        strided_slice_node = match['strided_slice']
        const_node = match['const']
        reshape_node = match['reshape']
        pack_node = match['pack']

        if not const_node.has_valid('value') or not is_value_is_constant(const_node.value, -1):
            log.debug('The pattern does not correspond to flatten. The second reshape dimension is not -1. It is {}'.
                      format(const_node.soft_get('value')))
            return
        if len(pack_node.in_nodes()) != 2:
            log.debug('The pattern does not correspond to flatten. The "Pack" operation produces tensor with 3 items '
                      'but should produce just 2.')
            return

        expected_values = [0, 1, 1]  # expected values to a StridedSlice to get the batch size
        for ind in range(3):
            if not strided_slice_node.in_node(ind + 1).has_valid('value') or \
                    not is_value_is_constant(strided_slice_node.in_node(ind + 1).value, expected_values[ind]):
                log.debug('The pattern does not correspond to flatten because of the input with index {}. The value is '
                          '"{}".'.format(ind, strided_slice_node.soft_get('value')))
                return

        reshape_node.in_port(1).disconnect()
        reshape_const_node = Const(graph, {'value': int64_array([0, -1]),
                                           'name': reshape_node.soft_get('name', reshape_node.id) + '/shape'}).create_node()
        reshape_node.in_port(1).connect(reshape_const_node.out_port(0))
        reshape_node['special_zero'] = True
        log.debug('The node "{}" is actually a Flatten node'.format(reshape_node.soft_get('name')))
# Example #28
    def replace_op(self, graph: Graph, node: Node):
        node_name = node.soft_get('name', node.id)
        const_dtype = np.float32
        if node.has_valid('data_type'):
            const_dtype = node.data_type
        const = Const(graph, {'value': mo_array([1], dtype=const_dtype)}).create_node()
        add = Add(graph, {'name': node.name + '/Add_'}).create_node()
        log = Log(graph, {'name': node.name + '/Log_'}).create_node()

        # Connect nodes: input -> Add -> Log
        const.out_port(0).connect(add.in_port(0))
        node.in_port(0).get_connection().set_destination(add.in_port(1))
        add.out_port(0).connect(log.in_port(0))
        rename_nodes([(node, node_name + '/delete'), (log, node_name)])

        # The "explicit" version of the return value is: [(out_node.id, 0)])
        return [log.id]
# Example #29
    def transform_map_fn_output_concatenation(external_match: dict,
                                              internal_match: dict):
        """
        Transforms TensorFlow 2 output concatenation into use of axis attribute for output port of Loop node
        :param external_match: a match used for handling a part of the main graph responsible for output concatenation
        :param internal_match: a match used for handling a part of the body graph responsible for output concatenation
        """
        loop_node = external_match['while']
        stack_node = external_match['stack']
        list_reserve_node = external_match['reserve']
        body_graph = loop_node['body']

        tensor_list_set_item_node = internal_match['concatenation']
        tensor_list_set_item_node_name = tensor_list_set_item_node.soft_get(
            'name', tensor_list_set_item_node.id)
        list_result_node = internal_match['concatenation_result']

        # replace TensorListSetItem with Unsqueeze and use axis attribute for corresponding Result node
        # to concatenate results from different iterations
        unsqueeze_list_element = create_op_with_const_inputs(
            body_graph, Unsqueeze, {1: int64_array(0)},
            {'name': 'TensorListSetItemUnsqueeze'})
        tensor_list_set_item_node.in_port(2).get_connection().set_destination(
            unsqueeze_list_element.in_port(0))
        tensor_list_set_item_node.out_port(0).get_connection().set_source(
            unsqueeze_list_element.out_port(0))
        rename_nodes([(tensor_list_set_item_node,
                       tensor_list_set_item_node_name + '/AbandonedName'),
                      (unsqueeze_list_element, tensor_list_set_item_node_name)
                      ])
        list_result_node_layer_id = list_result_node.internal_layer_id
        Loop.update_port_map_value_ext(loop_node.output_port_map,
                                       'internal_layer_id',
                                       list_result_node_layer_id, 'axis', 0)

        # remove TensorListStack to by-pass the node since the result from the Loop node is already concatenated
        stack_node.out_port(0).get_connection().set_source(
            stack_node.in_port(0).get_connection().get_source())

        # disconnect ListReserve node because it is no longer needed for Loop
        list_reserve_node.out_port(0).disconnect()

        # connect a number of iterations with trip count that can be received from the second input of ListReserve
        # create a constant network with True value for execution_condition so that IE can ignore execution condition
        # and perform trip_counts iterations. This approach with known trip count value allows to avoid dynamism.
        loop_node.in_port(1).disconnect()
        list_reserve_node.in_port(1).get_source().connect(loop_node.in_port(1))
        for record in loop_node.output_port_map:
            if 'purpose' in record and record[
                    'purpose'] == 'execution_condition':
                exec_cond_layer_id = record['internal_layer_id']
                exec_cond_node = Loop.get_body_node_by_internal_id(
                    loop_node, exec_cond_layer_id)
                const_true = Const(body_graph, {
                    'value': np.array(True, dtype=bool)
                }).create_node()
                exec_cond_node.in_port(0).get_connection().set_source(
                    const_true.out_port(0))
# Example #30
    def find_and_replace_pattern(self, graph: Graph):
        global_poolings = graph.get_op_nodes(type='Pooling', global_pool=True)
        if len(global_poolings) == 0:
            return

        layout = graph.graph['layout']
        assert layout != 'NHWC', 'Global pooling transformation depends on layout (NHWC not enabled)'

        for pooling in global_poolings:
            name = pooling.soft_get('name', pooling.id)
            assert pooling.has_valid(
                'pool_method'
            ), 'Global Pooling {} has no `pool_method` attribute'.format(name)
            method = pooling['pool_method']
            assert method in self.pool_method_to_reduce_type, \
                'Unexpected Global Pooling method `{}` for node `{}`'.format(method, name)
            reduce_op_class = self.pool_method_to_reduce_type[method]

            reduce = reduce_op_class(graph, {
                'name': name + '/reduce',
                'keep_dims': True
            }).create_node()

            pooling.out_port(0).get_connection().set_source(reduce.out_port(0))
            src = pooling.in_port(0).get_connection().get_source()

            reduce.in_port(0).get_connection().set_source(src)

            start = Const(graph, {'value': int64_array(2)}).create_node()
            end = Rank(graph, {'name': name + '/input_rank'}).create_node()
            delta = Const(graph, {'value': int64_array(1)}).create_node()

            axis = Range(graph, {
                'name': name + '/global_pooling_reduce_axis'
            }).create_node()

            axis.in_port(0).connect(start.out_port(0))
            src.connect(end.in_port(0))
            axis.in_port(1).connect(end.out_port(0))
            axis.in_port(2).connect(delta.out_port(0))

            axis.out_port(0).connect(reduce.in_port(1))

            log.debug('Global {} pooling was converted to reduce: `{}`'.format(
                method, name))