Exemplo n.º 1
0
    def infer(node: Node):
        """
        Infer the output shape of an Einsum node from its `equation` attribute and its input shapes.

        Each label of an input subscript is bound to the size of the matching input dimension, and the
        ellipsis label ("...") is bound to the (possibly empty) sub-shape it covers. Sizes bound to a label that
        is repeated must be equal, and the ellipsis sub-shapes of different inputs must be bi-directionally
        broadcastable (the broadcast result is kept). The output shape is the concatenation of the shapes bound
        to the labels of the output subscript.

        :param node: Einsum node with the `equation` attribute
        """
        node_name = node.soft_get('name', node.id)
        connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
        num_inputs = len(connected_in_ports)
        assert node.has_valid('equation'), "Einsum node {} must contain `equation` attribute".format(node_name)
        equation = node.equation

        # parse the equation and extract input and output subscripts
        input_subscripts, output_subscript = Einsum.parse_equation(node_name, equation)

        # check that each operand has the corresponding input subscript
        assert len(input_subscripts) == num_inputs, "The number of input operands of Einsum node {} " \
                                                    "must match the number of input subscripts " \
                                                    "in `equation`".format(node_name)

        # check compatibility of dimension sizes with the same label and generate a dictionary of shapes for labels
        label_to_shape = {}
        for input_ind in range(num_inputs):
            input_shape = node.in_port(input_ind).data.get_shape()
            labels = Einsum.extract_subscript_labels(node_name, input_subscripts[input_ind])
            num_dims = len(input_shape)
            # the ellipsis is a single label covering every dimension not claimed by the other labels; it may
            # cover zero dimensions, but never a negative number of them (that would move dim_ind backwards)
            num_broadcasted_dims = max(num_dims - len(labels) + 1, 0)
            dim_ind = 0
            for label in labels:
                if label == "...":
                    # processed even when all dimensions are already consumed: a trailing ellipsis covering no
                    # dimensions (e.g. "ab..." for a 2D input) must still be registered so that the output
                    # subscript can reference it
                    sub_shape = input_shape[dim_ind:dim_ind + num_broadcasted_dims]
                    if label in label_to_shape:
                        common_shape = bi_directional_shape_broadcasting(sub_shape, label_to_shape[label])
                        assert common_shape is not None, "The dimensions labeled of ellipsis must be broadcastable " \
                                                         "for Einsum node {}".format(node_name)
                        label_to_shape[label] = common_shape
                    else:
                        label_to_shape[label] = sub_shape
                    dim_ind += num_broadcasted_dims
                else:
                    if dim_ind >= num_dims:
                        # no input dimensions are left for the remaining labels
                        break
                    sub_shape = shape_array([input_shape[dim_ind]])
                    assert label not in label_to_shape or np.array_equal(label_to_shape[label], sub_shape), \
                        "Sizes of dimensions with the same label of Einsum node {} " \
                        "must be compatible".format(node_name)
                    label_to_shape[label] = sub_shape
                    dim_ind += 1

        # generate output shape based on the output subscript
        output_shape = shape_array([])
        for label in Einsum.extract_subscript_labels(node_name, output_subscript):
            assert label in label_to_shape, "The label in the output subscript must appear" \
                                            " in input subscripts in equation {} " \
                                            "of Einsum node {}".format(equation, node_name)
            output_shape = np.ma.concatenate((output_shape, label_to_shape[label]))

        node.out_port(0).data.set_shape(output_shape)
Exemplo n.º 2
0
def replace_resize(graph: Graph, resize: Node):
    """
    Replace an ONNX Resize-10 node with an opset4 Interpolate node.

    The Interpolate node works in 'scales' shape calculation mode over the axes Range(2, Rank(input), 1).
    Its 'sizes' input is computed as Floor(Cast<float32>(ShapeOf(input)) * (scales + 1e-5)) cast to int64.
    Both 'sizes' and 'scales' pass through a StridedSlice with begin=[2] (presumably keeping the elements
    from index 2 onward so they match the axes - confirm the mask convention). Consumers of the Resize
    output are moved to the Interpolate output. The Interpolate node takes over the original name, and the
    Resize node is renamed with a '/delete' suffix.

    :param graph: graph that contains `resize`
    :param resize: ONNX Resize-10 node to replace
    """
    resize_name = resize.soft_get('name', resize.id)
    log.debug("Converting of ONNX Resize-10 to Interpolate-4 "
              "is triggered for node {}.".format(resize_name))

    def make_tail_slice(name_suffix):
        # StridedSlice over a 1D tensor with begin=[2], end=[0], stride=[1] and the masks used for both inputs
        return create_op_with_const_inputs(
            graph, StridedSlice,
            {1: int64_array([2]), 2: int64_array([0]), 3: int64_array([1])},
            {'name': resize_name + name_suffix,
             'begin_mask': int64_array([1]),
             'end_mask': int64_array([0]),
             'new_axis_mask': int64_array([0]),
             'shrink_axis_mask': int64_array([0]),
             'ellipsis_mask': int64_array([0])})

    rank = Rank(graph, {'name': resize_name + '/max_axes'}).create_node()
    axes = create_op_with_const_inputs(graph, Range, {0: int64_array(2), 2: int64_array(1)},
                                       {'name': resize_name + '/axes'})
    sizes_slice = make_tail_slice('/sizes_ss')
    scales_slice = make_tail_slice('/scales_ss')

    # axes = Range(start=2, limit=rank, step=1)
    rank.out_port(0).connect(axes.in_port(1))

    interpolate_attrs = {
        'version': 'opset4',
        'mode': 'linear_onnx' if resize.mode == 'linear' else 'nearest',
        'coordinate_transformation_mode': 'asymmetric',
        'cube_coeff': -0.75,
        'nearest_mode': 'simple',
        'pads_begin': int64_array([0]),
        'pads_end': int64_array([0]),
        'antialias': 0,
        'shape_calculation_mode': 'scales',
        'in_ports_count': 4,
    }
    interpolate = Interpolate(graph, interpolate_attrs).create_node()
    axes.out_port(0).connect(interpolate.in_port(3))

    input_shape = Shape(graph, {'name': resize_name + '/ShapeOf'}).create_node()

    # Computing 'sizes' as floor(input_shape * scales) can be off by one: for scales = [1.0, 1.0, 1.33333, 2.0]
    # and input_shape = [1, 3, 30, 200] the product is [1, 3, 39.9999, 400], and floor gives 39 instead of 40.
    # Adding eps after the multiplication does not help either: with scales[2] = 1.333333 we get
    # 30 * 1.333333 + 1.0e-5 = 39.99991, which still floors to 39. Hence eps is added to the scales instead:
    # sizes = floor(input_shape * (scales + eps)).
    scales_plus_eps = create_op_with_const_inputs(graph, Add, {1: float_array([1.0e-5])},
                                                  {'name': resize_name + '/Add'})

    # shape values are kept in float32 even if data_type=FP16
    input_shape_float = Cast(graph, {'dst_type': np.float32}).create_node()
    input_shape.out_port(0).connect(input_shape_float.in_port(0))

    scaled_shape = Mul(graph, {'name': resize_name + '/Mul'}).create_node([input_shape_float, scales_plus_eps])
    floored_shape = Floor(graph, {'name': resize_name + '/Floor'}).create_node([scaled_shape])
    sizes_int = Cast(graph, {'dst_type': np.int64}).create_node([floored_shape])

    sizes_int.out_port(0).connect(sizes_slice.in_port(0))
    sizes_slice.out_port(0).connect(interpolate.in_port(1))
    scales_slice.out_port(0).connect(interpolate.in_port(2))

    data_connection = resize.in_port(0).get_connection()
    data_connection.set_destination(interpolate.in_port(0))

    scales_connection = resize.in_port(1).get_connection()
    scales_connection.set_destination(scales_slice.in_port(0))

    # the original data/scales producers also feed the shape and rank computations
    data_connection.get_source().connect(input_shape.in_port(0))
    data_connection.get_source().connect(rank.in_port(0))
    scales_connection.get_source().connect(scales_plus_eps.in_port(0))

    rename_nodes([(resize, resize_name + '/delete'), (interpolate, resize_name)])
    resize.out_port(0).get_connection().set_source(interpolate.out_port(0))
Exemplo n.º 3
0
    def insert_select(graph: Graph, node: Node):
        """
        Guard input 0 of `node` with a Select driven by an iteration counter.

        The Select forwards the original input when the counter output equals the "ones" constant. Otherwise
        it forwards a constant made by `create_const_with_batch_from_input` (presumably zeros - confirm the
        helper's default value). The counter is a ReadValue -> Crop -> Concat(with ones) -> Assign chain whose
        Concat output is cropped to its first element along axis 1. An existing counter matching that pattern
        is reused; otherwise a new one is built. Nothing is done when `node.frame_time + 1` equals 1.

        :param graph: graph containing `node`
        :param node: node whose input 0 is guarded; must have the `frame_time` attribute
        """
        context_len = node.frame_time + 1
        if context_len == 1:
            return

        source_port = node.in_port(0).get_source()
        source_shape = node.in_port(0).data.get_shape()
        node.in_port(0).disconnect()

        # add Select before saving state to avoid saving garbage
        select_node = Select(graph, {'name': 'select_' + node.name}).create_node()
        else_const = create_const_with_batch_from_input(source_port, source_shape[1])
        select_node.in_port(1).connect(source_port)
        select_node.in_port(2).connect(else_const.out_port(0))

        # check if we have already appropriate iteration counter
        counter_nodes = [
            ('mem_in', dict(op='ReadValue')),
            ('mem_in_data', dict(shape=int64_array([context_len]))),
            ('crop_mem_in', dict(op='Crop', axis=int64_array([1]), offset=int64_array([1]),
                                 dim=int64_array([context_len - 1]))),
            ('crop_mem_in_data', dict()),
            ('concat', dict(op='Concat', axis=1)),
            ('concat_data', dict()),
            ('const_1', dict(op='Const')),
            ('const_1_data', dict()),
            ('mem_out', dict(op='Assign')),
            ('crop_out', dict(op='Crop', axis=int64_array([1]), offset=int64_array([0]), dim=int64_array([1]))),
            ('crop_out_data', dict()),
            ('select', dict(op='Select')),
        ]
        counter_edges = [
            ('mem_in', 'mem_in_data'),
            ('mem_in_data', 'crop_mem_in'),
            ('crop_mem_in', 'crop_mem_in_data'),
            ('crop_mem_in_data', 'concat', {'in': 0}),
            ('const_1', 'const_1_data'),
            ('const_1_data', 'concat', {'in': 1}),
            ('concat', 'concat_data'),
            ('concat_data', 'mem_out'),
            ('concat_data', 'crop_out'),
            ('crop_out', 'crop_out_data'),
            ('crop_out_data', 'select'),
        ]
        counter_match = next(find_pattern_matches(graph, nodes=counter_nodes, edges=counter_edges), None)

        if counter_match is None:
            # build a new counter: ReadValue -> Crop(drop first) -> Concat(append ones) -> Assign
            init_value = create_const_with_batch_from_input(source_port, context_len, precision=np.int32)
            read_value = ReadValue(graph, {'name': 'iteration_number',
                                           'variable_id': 'iteration_' + node.name}).create_node()
            read_value.in_port(0).connect(init_value.out_port(0))

            cut_first = Crop(graph, {'name': 'cut_first',
                                     'axis': int64_array([1]),
                                     'offset': int64_array([1]),
                                     'dim': int64_array([context_len - 1])}).create_node()
            cut_first.in_port(0).connect(read_value.out_port(0))

            ones = create_const_with_batch_from_input(source_port, 1, 1, np.int32)
            concat = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2, 'axis': 1}).create_node()
            concat.in_port(0).connect(cut_first.out_port(0))
            concat.in_port(1).connect(ones.out_port(0))

            assign = Assign(graph, {'name': 'iteration_number_out',
                                    'variable_id': 'iteration_' + node.name}).create_node()
            assign.in_port(0).connect(concat.out_port(0))
            result = Result(graph, {}).create_node()
            assign.out_port(0).connect(result.in_port(0))

            cut_last = Crop(graph, {'name': 'cut_last',
                                    'axis': int64_array([1]),
                                    'offset': int64_array([0]),
                                    'dim': int64_array([1])}).create_node()
            cut_last.in_port(0).connect(concat.out_port(0))
            counter_port = cut_last.out_port(0)
        else:
            matched_ids = inverse_dict(counter_match)
            ones = Node(graph, matched_ids['const_1'])
            counter_port = Node(graph, matched_ids['crop_out']).out_port(0)

        # Check if data from memory is 1
        # if it is True, we have correct data and should proceed with saving it to memory
        # else we have not gathered context and have garbage here, shouldn't change initial state of memory
        is_ready = Equal(graph, {'name': counter_port.node.name + '/cast_to_bool'}).create_node()
        is_ready.in_port(0).connect(ones.out_port(0))
        is_ready.in_port(1).connect(counter_port)
        select_node.in_port(0).connect(is_ready.out_port(0))
        select_node.out_port(0).connect(node.in_port(0))
        select_node.out_port(0).data.set_shape(source_shape)
Exemplo n.º 4
0
 def remove_node_and_reset_connections(self, graph, node: Node, in_port):
     """
     Cut `node` out of `graph` and re-attach the consumers of its output 0 to `in_port`.

     Input port 0 of `node` is disconnected first, then the connection of output port 0 gets `in_port`
     as its new source, and finally the node is removed from the graph.

     :param graph: graph that owns `node`
     :param node: node to remove
     :param in_port: port that becomes the new source for the consumers of `node` (despite the name it is
         used as a source via `set_source`)
     """
     node.in_port(0).disconnect()
     consumers = node.out_port(0).get_connection()
     consumers.set_source(in_port)
     graph.remove_node(node.id)
Exemplo n.º 5
0
def infer_nodes(graph: Graph,
                nodes: List[Node],
                constant_subgraph_only: bool = False):
    """
    Run "infer" function of the specified nodes.

    Only nodes that have the `is_partial_inferred` attribute set to a false value are inferred; after a
    successful inference the attribute is set to True. `control_flow_infer` is run for every processed node
    id, except nodes skipped in constant-subgraph mode.

    :param graph: graph with nodes
    :param nodes: list of node ids in the topological order (each element is passed to `Node(graph, n)`)
    :param constant_subgraph_only: flag which specifies whether only inference of constant sub-graphs should be
        done; in this mode Parameter nodes and nodes with a non-constant input are skipped, and missing output
        shapes are not reported
    :raises Error: "Stopped shape/value propagation ..." wrapping (via `from`) any failure of a node, including
        a missing infer function or output shapes that were not inferred
    """
    debug_logger = log.getLogger().isEnabledFor(log.DEBUG)
    for n in nodes:
        # Data Flow Infer
        node = Node(graph, n)
        node_name = node.soft_get('name', node.id)
        try:
            if node.has('is_partial_inferred') and not node.is_partial_inferred:
                if node.has('infer') and node.infer is not None:
                    # we consider that operation will produce value if all inputs are constants or it is
                    # 'ShapeOf' operation
                    # NOTE(review): the `any(value is None ...)` clause also skips ShapeOf nodes whose input value
                    # is unknown, which makes the ShapeOf-specific clause redundant and contradicts the comment
                    # above - confirm the intended behaviour before changing it.
                    if constant_subgraph_only:
                        in_values = [port.data.get_value() for port in node.in_ports().values()]
                        if node.soft_get('op') == 'Parameter' or any(value is None for value in in_values) or \
                                (node.soft_get('op') == 'ShapeOf' and node.in_port(0).data.get_shape() is None):
                            # if here will be any new ShapeOf type operation, we should update condition above
                            continue

                    if debug_logger:
                        log.debug('-' * 20)
                        log.debug('Partial infer for {}'.format(node.soft_get('name')))
                        log.debug('Op: {}'.format(node.soft_get('op')))
                        log.debug('Inputs:')
                        log_debug_dict(node.in_nodes(), 'input')

                    node.infer(node)
                    out_nodes = node.out_nodes()

                    # propagate nchw_layout attributes to data nodes
                    if node.has('nchw_layout'):
                        for out_node in out_nodes.values():
                            out_node['nchw_layout'] = node.nchw_layout

                    # In debug print current node attributes, input shapes/values and output shape/values
                    if debug_logger:
                        log.debug('Outputs:')
                        log_debug_dict(node.out_nodes(), 'output')

                    if not constant_subgraph_only:
                        # report every output without a shape, then fail if there was at least one; the flag must
                        # not be reset per port, otherwise only the last output port would be taken into account
                        not_all_output_shapes = False
                        for out_port, out_node in out_nodes.items():
                            if not out_node.has_valid('shape'):
                                log.error('Shape is not defined for output {} of "{}".'.format(out_port, node_name))
                                not_all_output_shapes = True

                        if not_all_output_shapes:
                            raise Error('Not all output shapes were inferred or fully defined for node "{}". ' +
                                        refer_to_faq_msg(40), node_name)
                elif node.kind != 'data':
                    raise Error('There is no registered "infer" function for node "{}" with op = "{}". ' +
                                'Please implement this function in the extensions. ' +
                                refer_to_faq_msg(37), node_name, node.soft_get('op'))
                node.is_partial_inferred = True
        except Exception as err:
            log.error('Cannot infer shapes or values for node "{}".'.format(node.soft_get('name')))
            log.error(str(err))
            log.error('')
            log.error('It can happen due to bug in custom shape infer function {}.'.format(node.soft_get('infer')))
            log.error('Or because the node inputs have incorrect values/shapes.')
            log.error('Or because input shapes are incorrect (embedded to the model or passed via --input_shape).')
            # collect hints left by earlier passes on any node of the graph
            debug_messages = '\n'.join(
                'Layer "' + layer_name + '": ' + layer_attrs['debug_message']
                for layer_name, layer_attrs in graph.nodes(data=True)
                if 'debug_message' in layer_attrs
            )
            if debug_messages != "":
                log.error('')
                log.error('Other possible failure reasons are listed below:')
                log.error(debug_messages)
            if not debug_logger:
                log.error('Run Model Optimizer with --log_level=DEBUG for more information.')
            else:
                # `graph.nodes[...]` works on every networkx 2.x, unlike `graph.node` (removed in networkx 2.4)
                log.debug('Node "{}" attributes: {}'.format(node.soft_get('name'), graph.nodes[node.id]))
            raise Error('Stopped shape/value propagation at "{}" node. '.format(node.soft_get('name')) +
                        refer_to_faq_msg(38)) from err
        control_flow_infer(graph, n)
Exemplo n.º 6
0
    def infer(node: Node):
        """
        Infer the output shape of a GatherND node and, when possible, its constant output value.

        The output shape is `batch + indices_shape[batch_dims:-1] + data_shape[batch_dims + indices_shape[-1]:]`.
        For opset5 the batch dimensions are folded into a single dimension (their product). For opset8 they
        are kept, taken from indices, from data, or per dimension from whichever is defined. The output value
        is computed only when the indices value is fully defined and the data value is known.

        :param node: GatherND node with the `batch_dims` and `version` attributes and two connected inputs
        """
        node_name = node.soft_get('name', node.id)
        num_connected = sum(1 for port in node.in_ports().values() if not port.disconnected())
        assert num_connected == 2, \
            "Incorrect number of inputs for {} node".format(node_name)

        data_shape = node.in_port(0).data.get_shape()
        data_value = node.in_port(0).data.get_value()
        indices_shape = node.in_port(1).data.get_shape()
        indices_value = node.in_port(1).data.get_value()

        assert node.has_valid('batch_dims'), "Node {} must contain `batch_dims` attribute".format(node_name)
        batch_dims = node.batch_dims

        # the batch dimensions must exist in both the data and the indices tensors
        assert batch_dims < len(data_shape), "Number of batch dimensions must be less than a rank of data"
        assert batch_dims < len(indices_shape), "Number of batch dimensions must be less than a rank of indices"

        # leading batch dimensions of data and indices have to agree
        for dim in range(batch_dims):
            assert compatible_dims(data_shape[dim], indices_shape[dim]), \
                "The dimension {} for data and indices tensors must be the same".format(dim)

        # check ranks of input tensors
        assert len(data_shape) > 0, "Data must not be a scalar"
        assert len(indices_shape) > 0, "Indices must not be a scalar"
        index_tuple_len = indices_shape[-1]
        assert (batch_dims + index_tuple_len) <= len(data_shape), \
            "Length of a tuple with indices must not exceed a rank of data tensor excluding batch dimensions"
        version = node['version']
        assert version in ['opset5', 'opset8'], 'Unsupported version of GatherND operation: {}, operation ' \
                                                'name : {}'.format(version, node.soft_get('name'))

        # compute output shape
        batch = []
        if batch_dims > 0 and version == 'opset5':
            # Support old version of gatherND shape inference: all batch dimensions are folded into one
            leading_dims = data_shape[:batch_dims]
            batch = [np.prod(leading_dims).tolist()] if is_fully_defined(leading_dims) \
                else [dynamic_dimension_value]
        elif batch_dims > 0 and version == 'opset8':
            for dim in range(batch_dims):
                assert compatible_dims(indices_shape[dim], data_shape[dim]), \
                    "Batch dimensions in data.shape and indices.shape must be compatible"
            if is_fully_defined(indices_shape[:batch_dims]):
                batch = indices_shape[:batch_dims].tolist()
            elif is_fully_defined(data_shape[:batch_dims]):
                batch = data_shape[:batch_dims].tolist()
            else:
                # per dimension: prefer the indices dimension, then the data dimension, else keep it dynamic
                for dim in range(batch_dims):
                    if indices_shape[dim] != dynamic_dimension_value:
                        batch.append(indices_shape[dim])
                    elif data_shape[dim] != dynamic_dimension_value:
                        batch.append(data_shape[dim])
                    else:
                        batch.append(dynamic_dimension_value)

        slice_shape = list(data_shape[(batch_dims + index_tuple_len):])
        output_shape = batch + list(indices_shape)[batch_dims:-1] + slice_shape
        node.out_port(0).data.set_shape(output_shape)

        # compute output value only if all input indices are defined and the data value is known
        if not is_fully_defined(indices_value) or data_value is None:
            return

        batch_dims_size = 1
        for dim in range(batch_dims):
            batch_dims_size *= indices_shape[dim]

        flat_indices = indices_value.reshape(batch_dims_size, -1, index_tuple_len)
        flat_data = data_value.reshape((batch_dims_size,) + tuple(data_shape[batch_dims:]))

        gathered = [flat_data[(batch_idx,) + tuple(flat_indices[batch_idx][outer_idx])]
                    for batch_idx in range(flat_indices.shape[0])
                    for outer_idx in range(flat_indices.shape[1])]
        output_value = np.asarray(gathered, dtype=data_value.dtype).reshape(output_shape)
        node.out_port(0).data.set_value(output_value)
Exemplo n.º 7
0
def copy_graph_with_ops(graph: Graph) -> Graph:
    """
    Build a new back-stage graph that mirrors `graph`, with extenders applied to the source operations.

    Steps:
      1. restore ports in the source graph and run the registered preprocessing routines, Const nodes first;
      2. for every source operation: save its input shapes, apply its extender (if registered) and create the
         matching operation in the new graph (custom op, registered MO op, or a generic `Op` that takes its
         shapes from the IR);
      3. restore the edges between the created operations;
      4. normalize outputs, mark Const data nodes, run postprocessing routines and restore tensor names;
      5. clean up the new graph.

    :param graph: Graph to copy
    :return: Copied graph with applied extenders
    """
    new_graph = Graph()
    new_graph.stage = 'back'
    new_graph.graph = graph.graph

    outputs_by_old_id = dict()
    old_to_new_id = dict()

    restore_correct_ports(graph)

    # Nodes preprocessing stage in source graph.
    # Const nodes go first because the other preprocessing routines assume Const nodes are already preprocessed.
    for const_op in graph.get_op_nodes(type='Const'):
        preprocessing_op_nodes[const_op.type](const_op)

    for op in graph.get_op_nodes():
        source_type = op.soft_get('type')
        if source_type != 'Const' and source_type in preprocessing_op_nodes:
            preprocessing_op_nodes[op.type](op)

    # Create a new copy of graph with correct attributes (shape & type infer, backend attrs etc.)
    for op in graph.get_op_nodes():
        # Save input shapes restored from IR
        op['old_input_shapes'] = [int64_array(op.in_node(port).shape) for port in op.in_nodes()]

        # Apply extenders to nodes in source graph
        if op.type not in Extender.registered_ops:
            log.debug('Extender for node {} with type={} not found, please note.'.format(op.name, op.type))
        else:
            Extender.get_extender_class_by_name(op.type).extend(op)

        # Add node with necessary type and extended attrs in new graph
        op_type = op.soft_get('type_to_create', op.type)

        if op_type in custom_ops:
            node = custom_ops[op_type](new_graph, op.attrs()).create_node()
        else:
            if op_type in Op.registered_ops:
                node = Op.get_op_class_by_name(op_type)(new_graph, op.attrs()).create_node()
            else:
                log.warning('Operation {} is not found in MO operations, please check it! '
                            'Simple shape infer function is used'.format(op_type))
                node = Op(new_graph, op.attrs()).create_node()
                assert 'type' in node, 'Operation {} have no `type` attribute.'.format(node.soft_get('name'))
                node['op'] = node.type
                node['infer'] = Extender.use_shapes_from_ir
                if 'ir_data_attrs' in op:
                    layer_attrs = [('id', lambda node: node.node), 'name', 'type', 'version']
                    layer_children = [('data', list(op.ir_data_attrs.keys()), []), '@ports', '@consts']
                    node['IE'] = [('layer', layer_attrs, layer_children)]

            # Fill out_ports_count attribute
            if 'out_ports_count' not in node and node.soft_get('type') != 'Result':
                node['out_ports_count'] = len(op.out_edges())

        # This attribute is no longer needed and we can delete it
        if 'ir_data_attrs' in node:
            del node['ir_data_attrs']

        if op.has_and_set('need_copy_input_blobs'):
            copy_input_blobs(op, node)

        # Collect node connections
        old_to_new_id[op.id] = node.id
        outputs_by_old_id[op.id] = collect_node_outputs(op)

    # Restore connections in new graph
    for src_old_id, port_outputs in outputs_by_old_id.items():
        for src_port_idx, destinations in port_outputs.items():
            for dst_port_idx, dst_old_id in destinations:
                src = Node(new_graph, old_to_new_id[src_old_id])
                dst = Node(new_graph, old_to_new_id[dst_old_id])
                src.out_port(src_port_idx).connect(dst.in_port(dst_port_idx))

    # Nodes postprocessing stage in new graph
    for op in new_graph.get_op_nodes():
        new_type = op.soft_get('type')

        # Call normalize node outputs for restored operations to connect temporary Result operations for disconnected
        # output ports. We need to do that for correct shape inference. These Result operations will be removed during
        # IR emitting. For TopK operation outputs normalizing we should use specific
        # function TopKNormalizer.normalize_outputs.
        if new_type != 'TopK':
            Op.normalize_outputs(op)

        # Set correct_data_type attribute to Const data nodes to correct processing of restored values
        if new_type == 'Const':
            assert len(op.out_nodes()) == 1 and op.out_node(0).soft_get('kind') == 'data', \
                'Const node {} not properly corrected to appropriate data node'.format(op.soft_get('name'))
            const_data = op.out_node(0)
            const_data['correct_data_type'] = True
            if op.has_and_set('rt_info'):
                const_data['rt_info'] = op.rt_info

        # operations postprocessing with some special types
        if new_type in postprocessing_op_nodes:
            postprocessing_op_nodes[op.type](op)

        restore_tensor_names(op)

    # clean up graph to shape inference
    new_graph.clean_up()

    return new_graph
Exemplo n.º 8
0
 def type_infer(node: Node):
     """
     Propagate the data type of inputs 1 and 2, which must be equal, to output port 0.

     :param node: node whose input ports 1 and 2 are connected
     """
     second_input_type = node.in_port(1).get_source().get_data_type()
     third_input_type = node.in_port(2).get_source().get_data_type()
     assert second_input_type == third_input_type, \
         'The data type of the second and the third inputs must be equal for the node {}'.format(node.name)
     node.out_port(0).set_data_type(second_input_type)
Exemplo n.º 9
0
    def infer(node: Node):
        node_name = node.soft_get('name', node.id)
        assert len([port for port in node.in_ports().values() if not port.disconnected()]) == 3, \
            "Select operation must have 3 inputs: 'condition', 'then' and 'else' tensors for node {}".format(node_name)

        condition_value = node.in_port(0).data.get_value()
        condition_shape = node.in_port(0).data.get_shape()
        resulting_tensors = [
            node.in_port(1).data.get_value(),
            node.in_port(2).data.get_value()
        ]

        a_shape = node.in_port(1).data.get_shape()
        b_shape = node.in_port(2).data.get_shape()
        broadcast_rule = node.soft_get('auto_broadcast', 'numpy')

        if broadcast_rule == 'numpy':
            msg = "In Select node '{}' condition and then/else shapes must be broadcastable. " \
                  "But instead got: cond_shape={}, then_shape={}, else_shape={}".format(
                    node_name, condition_shape, a_shape, b_shape)

            output_shape = bi_directional_shape_broadcasting(a_shape, b_shape)
            assert output_shape is not None, msg

            output_is_scalar = len(output_shape) == 0

            # if Select was created from TF Where operations then 1D condition must have the same size
            # as 0-index dimension of output_shape. This condition is different from being numpy compatible
            # but by adding ones to the end we can achieve numpy compatibility, as in transformation SelectBroadcast.py
            if node.has_valid('format') and node['format'] == 'tf' and len(
                    condition_shape) == 1:
                # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/array_ops.py#L4596-L4598
                msg_tf = "In Select node '{}' if 'condition' is a 1D tensor then it's size " \
                         "must be matching with the first dimension of then/else branches. " \
                         "But instead got: cond_shape={}, then_shape={}, else_shape={}".format(
                            node_name, condition_shape, a_shape, b_shape)

                # check equality only if both values non-dynamic
                if is_fully_defined(
                        condition_shape[0]
                ) and not output_is_scalar and is_fully_defined(
                        output_shape[0]):
                    assert condition_shape[0] == output_shape[0], msg_tf
                ones_shape = len(output_shape) if output_is_scalar else len(
                    output_shape) - 1
                condition_shape = np.concatenate(
                    (condition_shape, np.ones(ones_shape, dtype=np.int64)))

            output_shape = bi_directional_shape_broadcasting(
                output_shape, condition_shape)
            assert output_shape is not None, msg

        elif broadcast_rule == 'pdpd':
            # todo: add pdpd broadcasting rule
            # note that additionally to output_shape resulting_tensors must be broadcasted as well
            raise Error("PDPD broadcasting rule is not implemented yet")
        else:  # broadcasting is not allowed
            assert compatible_shapes(a_shape, b_shape) and compatible_shapes(condition_shape, a_shape), \
                'In node \'{}\' for Select operation when broadcasting is off all inputs must be of the same shape. ' \
                'But instead got: cond_shape={}, then_shape={}, else_shape={}'.format(
                    node_name, condition_shape, a_shape, b_shape)
            output_shape = shape_array([
                i if i is not dynamic_dimension else j
                for i, j in zip(a_shape, b_shape)
            ])

        node.out_port(0).data.set_shape(output_shape)

        if condition_value is not None:
            if is_fully_defined(condition_value) and np.all(
                    condition_value == condition_value.item(0)):
                # in some graphs Select condition is always True[False] and
                # one of the branches is None (which is not selected)
                # if we use np.where for such cases then dtype of output_value will be object (non numeric type)
                # and subsequent numpy operation on such tensors will fail
                output_value = resulting_tensors[not np.
                                                 bool(condition_value.item(0))]
                if output_value is None:
                    return
                if broadcast_rule == 'numpy':
                    output_value = bi_directional_broadcasting(
                        output_value, output_shape)
                elif broadcast_rule == 'pdpd':
                    # todo: add pdpd broadcasting rule
                    raise Error(
                        "PDPD broadcasting rule is not implemented yet")

                node.out_port(0).data.set_value(output_value)
            elif resulting_tensors[0] is not None and resulting_tensors[
                    1] is not None:
                output_value = np.ma.where(condition_value,
                                           resulting_tensors[0],
                                           resulting_tensors[1])
                node.out_port(0).data.set_value(output_value)
    def replace_timeheightconv(self, graph: Graph, node: Node):
        """
        Replace a TimeHeightConvolution node with MemoryOffset + Concat context gathering and a Convolution.

        For every distinct time offset (column 0 of ``offsets``) a MemoryOffset node is created on the node's data
        input; their outputs are concatenated and fed into a Convolution whose kernel, padding and dilation are
        derived from the time (column 0) and height (column 1) offsets. The original node is renamed to
        ``<name>/to_delete``, its weights are re-laid out for the Convolution and the node is removed from the graph.

        :param graph: graph that contains ``node``
        :param node: the node to replace; ``offsets`` is indexed as a 2D array of (time, height) offset pairs
        """
        req_time_offsets = node.soft_get('time_offsets')
        offsets = node.soft_get("offsets", [[]])
        time_offsets = offsets[:, 0]
        height_offsets = offsets[:, 1]
        all_time_offsets = sorted(set(time_offsets))
        in_name = node.soft_get('name', node.id)
        rename_node(node, in_name + '/to_delete')

        # create MemoryOffsets for context gathering: one per distinct time offset, joined by a Concat
        concat = Concat(graph, attrs={'name': in_name + '/Concat',
                                      'in_ports_count': len(all_time_offsets)}).create_node()
        for i, t in enumerate(all_time_offsets):
            # if time offset is included in required time_offsets we don't need a default value
            has_default = t not in req_time_offsets
            memoff = MemoryOffset(graph, attrs={'name': in_name + '/MemoryOffset_' + str(i),
                                                't': t, 'has_default': has_default, 'splitted': False,
                                                'pair_name': in_name + '/MemoryOffset_pair_' + str(i)}).create_node()
            concat.in_port(i).connect(memoff.out_port(0))
            memoff.in_port(0).connect(node.in_port(0).get_source())

        stride = node.soft_get("height_subsample", 1)

        # kernel size along time / height = number of distinct offsets in the corresponding column
        kernel = int64_array([len(all_time_offsets), len(set(height_offsets))])

        pad_h = int64_array([0, 0])
        pad_h[0] = -min(height_offsets) if min(height_offsets) < 0 else 0
        pad_h[1] = stride * node.height_out - (node.height_in - max([max(height_offsets), 0]))

        # dilation = span of offsets / (kernel size - 1). Each axis must be checked against its OWN kernel size:
        # a 1-wide height kernel would otherwise divide by zero, and a 1-long time kernel would force dilation_h to 1
        dilation_t = (max(time_offsets) - min(time_offsets)) / (kernel[0] - 1) if kernel[0] > 1 else 1
        dilation_h = (max(height_offsets) - min(height_offsets)) / (kernel[1] - 1) if kernel[1] > 1 else 1

        conv_attrs = {
            'name': in_name,
            'output': node['out_channels'],
            'height_in': node.height_in,
            'bias_term': None,
            'pad': int64_array([[0, 0], [0, 0], [0, 0], pad_h]),
            'pad_spatial_shape': int64_array([[0, 0], pad_h]),
            'dilation': int64_array([1, 1, dilation_t, dilation_h]),
            'kernel': int64_array([node.out_channels, node.in_channels, kernel[0], kernel[1]]),
            'stride': int64_array([1, 1, 1, stride]),
            'kernel_spatial': kernel,
            'input_feature_channel': 1,
            'output_feature_channel': 0,
            'channel_dims': int64_array([1]),
            'spatial_dims': int64_array([2, 3]),
            'batch_dims': int64_array([0]),
            'kernel_spatial_idx': int64_array([2, 3]),
            'group': 1,
            'reshape_kernel': True,
            'bias_addable': True,
        }
        conv = Convolution(graph, attrs=conv_attrs).create_node()
        conv.in_port(0).connect(concat.out_port(0))
        conv.in_port(1).connect(node.in_port(1).get_source())

        # change layout for weights from OHWI to OIHW
        # in future should be replaced by common Permute mechanics
        weights = conv.in_port(1).get_source().node.value
        weights = weights.reshape(int64_array([node.out_channels, -1, node.in_channels]))
        weights = weights.transpose(int64_array([0, 2, 1]))
        weights = weights.flatten()
        conv.in_port(1).get_source().node.value = weights

        conv.in_port(2).connect(node.in_port(2).get_source())
        node.out_port(0).get_connection().set_source(conv.out_port(0))
        graph.remove_node(node.id)
Exemplo n.º 11
0
 def reverse_infer(node: Node):
     """
     Restore an unknown data-input shape from the shape of input port 2.

     Only acts when port 0 has no shape yet, port 2 is connected and port 2's shape is known: port 0 then gets an
     undefined shape whose rank is the first element of port 2's shape.
     """
     if node.in_port(0).data.get_shape() is not None:
         return
     if not node.is_in_port_connected(2):
         return
     port2_shape = node.in_port(2).data.get_shape()
     if port2_shape is None:
         return
     node.in_port(0).data.set_shape(undefined_shape_of_rank(port2_shape[0]))
Exemplo n.º 12
0
 def infer(node: Node):
     """
     Set the output shape: a copy of the input shape with dimension 1 recomputed.

     Dimension 1 becomes ``const_dim + (input_shape[1] - const_dim) * len(node.context)``, i.e. the part of
     dimension 1 beyond ``const_dim`` is multiplied by the number of entries in ``node.context``.
     """
     in_shape = node.in_port(0).data.get_shape()
     result_shape = in_shape.copy()
     const_part = node.const_dim
     variable_part = in_shape[1] - const_part
     result_shape[1] = const_part + variable_part * len(node.context)
     node.out_port(0).data.set_shape(result_shape)
Exemplo n.º 13
0
    def infer(node: Node):
        """
        Infer the output shape (and value, when the input is constant) of a Reshape node.

        Port 0 is the data tensor, port 1 the requested shape, which must be a constant. Entries of ``new_shape``:
          * ``-1``: at most one; its size is deduced from the remaining dimensions when possible;
          * ``0`` with ``special_zero`` set: copy the input dimension at the same index;
          * dynamic dimension: stays dynamic unless it is the only dynamic entry and the input shape is fully
            static, in which case it is deduced from the input element count.

        :param node: the Reshape node
        :raises AssertionError: if the two inputs are not connected, the input shape or ``new_shape`` is unknown,
            several ``-1`` values are given, or input/output element counts disagree.
        """
        name = node.soft_get('name', node.id)

        connected_inputs = {
            idx: port
            for idx, port in node.in_ports().items()
            if not port.disconnected()
        }
        assert len(connected_inputs) == 2 and all([i in connected_inputs for i in range(2)]), \
            "Reshape should have 2 connected input ports, but it doesn't for node: `{}`. Ports: {}" \
            "".format(name, connected_inputs)

        input_shape = node.in_port(0).data.get_shape()
        assert input_shape is not None

        new_shape = node.in_port(1).data.get_value()
        assert new_shape is not None, 'Dynamic Reshape second input is not supported. Node {}'.format(
            name)

        assert np.argwhere(new_shape == -1).size <= 1, \
            'Reshape second input should not have several `-1` values set. ' \
            'Node: {}, reshape second input value {}'.format(name, new_shape)

        special_zero = node.has_and_set('special_zero')

        num_of_input_elements = np.prod(input_shape)
        # product of the output dimensions known from new_shape (the -1 entry excluded); replaced by
        # dynamic_dimension_value as soon as new_shape has a dynamic entry
        num_of_output_elements = 1
        for index, x in enumerate(new_shape):
            if x is dynamic_dimension:
                num_of_output_elements = dynamic_dimension_value
            elif x == 0 and special_zero:
                if input_shape[index] is not dynamic_dimension:
                    num_of_output_elements *= input_shape[index]
            elif x != -1:
                num_of_output_elements *= x

        # input_shape = [dynamic, 5, 6], new_shape = [0, -1] => output_shape [dynamic, 30]
        # marker that no dynamic input dimensions or all of them are copied with "0" magic value
        all_dynamic_dimension_are_copied = True
        if not is_fully_defined(input_shape):
            for index, x in enumerate(input_shape):
                if x is dynamic_dimension:
                    if index >= len(new_shape) or new_shape[index] != 0:
                        all_dynamic_dimension_are_copied = False

        # value used for the -1 entry: deducible only when every other output dimension is known
        undefined_dim = dynamic_dimension
        if num_of_output_elements is not dynamic_dimension and all_dynamic_dimension_are_copied and \
                is_fully_defined(new_shape):
            undefined_dim = num_of_input_elements // num_of_output_elements
        output_shape = []
        for index, x in enumerate(new_shape):
            if x == 0 and special_zero:
                output_shape.append(input_shape[index])
            elif x == -1:
                output_shape.append(undefined_dim)
            else:
                output_shape.append(x)

        # even if the new_shape contains some dynamic values we can calculate the actual value by deducing it from the
        # input shape if it is static: input_shape = [5, 3, 8], new_shape = [4, d] => output_shape = [4, 30].
        # This is only possible when exactly ONE entry of new_shape is dynamic; with several dynamic entries the
        # split of the remaining elements between them is unknown, so they must all stay dynamic.
        if is_fully_defined(input_shape) and not is_fully_defined(new_shape):
            dynamic_indices = np.argwhere(
                [item is dynamic_dimension for item in new_shape])
            if dynamic_indices.size == 1:
                known_output_elements = 1
                for index, x in enumerate(new_shape):
                    if x == 0 and special_zero:
                        known_output_elements *= input_shape[index]
                    elif x is not dynamic_dimension and x != -1:
                        known_output_elements *= x
                assert num_of_input_elements % known_output_elements == 0, \
                    'Incorrect number of output elements deduced for node {}: '.format(name)
                output_shape[dynamic_indices[0]
                             [0]] = num_of_input_elements // known_output_elements

        assert not is_fully_defined(input_shape) or not is_fully_defined(output_shape) or \
               np.prod(input_shape) == np.prod(output_shape), \
               "Number of elements in input {} and output {} of reshape node {} mismatch" \
               "".format(input_shape, output_shape, name)

        PermuteInputs().set_input_permutation(node.in_node(1), node,
                                              'output:0', 'shape')

        if node.in_port(0).data.get_value() is not None and is_fully_defined(
                output_shape):
            node.out_port(0).data.set_value(
                node.in_port(0).data.get_value().reshape(output_shape))
        else:
            node.out_port(0).data.set_shape(output_shape)
Exemplo n.º 14
0
    def infer(node: Node):
        """
        Infer the output shape (and value, when it can be constant-folded) of a Broadcast node.

        Port 0 is the tensor to broadcast and port 1 the target shape. For ``mode == 'explicit'``, port 2 holds the
        axes mapping. Supported modes are 'numpy', 'bidirectional' and 'explicit'.

        :raises AssertionError: if ``mode`` is not set, the target shape is unknown (outside the dynamic 1D case),
            or the explicit-mode axes mapping is not constant.
        :raises Error: for any other ``mode`` value.
        """
        name = node.soft_get('name', node.id)

        data_shape = node.in_port(0).data.get_shape()
        data_value = node.in_port(0).data.get_value()
        target_rank_shape = node.in_port(1).data.get_shape()
        target = node.in_port(1).data.get_value()
        assert node.has_and_set('mode'), 'Broadcasting mode is not defined for node "{}"'.format(name)

        # A non-constant target shape still fixes the output rank, provided the target-shape input is static
        # and 1D (and the data is at most 1D, or mode is explicit)
        if target is None and len(target_rank_shape) == 1 and (len(data_shape) <= 1 or node.mode == 'explicit'):
            assert is_fully_defined(target_rank_shape)
            node.out_port(0).data.set_shape(undefined_shape_of_rank(target_rank_shape.item(0)))
            return
        assert target is not None, 'Output shape is not defined for node "{}"'.format(name)

        PermuteInputs().set_input_permutation(node.in_node(1), node, 'output:0', 'shape')

        fold_value = data_value is not None and not node.has_and_set('stop_value_propagation') and \
            is_fully_defined(target)
        mode = node.mode

        if mode == 'explicit':
            axes = node.in_port(2).data.get_value()
            assert axes is not None, 'Broadcast(mode="explicit") with dynamic axes_mapping input ' \
                                     'is not supported. Node: `{}`'.format(name)
            PermuteInputs().set_input_permutation(node.in_node(2), node, 'output:0', 'axis')
            # NOTE(review): axes mapping is read again after registering its permutation; unclear whether the
            # registration can change the value, so the re-read is kept
            axes = node.in_port(2).data.get_value()
            if fold_value:
                node.out_port(0).data.set_value(explicit_broadcasting(data_value, target, axes))
            else:
                result_shape, _ = explicit_shape_broadcasting(data_shape, target, axes)
                node.out_port(0).data.set_shape(result_shape)
        elif mode in ('numpy', 'bidirectional'):
            unidirectional = mode == 'numpy'
            if fold_value:
                value_fn = uni_directional_broadcasting if unidirectional else bi_directional_broadcasting
                node.out_port(0).data.set_value(value_fn(data_value, target))
            else:
                shape_fn = uni_directional_shape_broadcasting if unidirectional \
                    else bi_directional_shape_broadcasting
                node.out_port(0).data.set_shape(shape_fn(data_shape, target))
        else:
            raise Error('The node "{}" has unsupported mode "{}"'.format(
                name, node.mode))
Exemplo n.º 15
0
    def update_if_output_ports_shape(if_node: Node):
        """
        Update shape and values for If output ports.

        Every output port is matched with the body nodes that carry an ``output_id`` attribute in the ``then`` and
        ``else`` bodies. If the condition is a known constant, the value (or shape, when no value is known) of
        the selected branch is propagated. Otherwise, if exactly one body produces only fake outputs, the other
        body's shape is used. In the remaining cases the output shape keeps each dimension on which both bodies
        agree and marks the others dynamic.

        :param if_node: The If node to update output ports and shapes
        :return: None
        :raises AssertionError: if an output port is not connected to both bodies, or the branch output ranks differ
        """
        node_name = if_node.soft_get('name', if_node.id)

        then_outputs = [
            node for node in if_node.then_graph.get_op_nodes()
            if node.has('output_id')
        ]
        else_outputs = [
            node for node in if_node.else_graph.get_op_nodes()
            if node.has('output_id')
        ]
        outputs_mapping = {}
        outputs_number = len(if_node.out_ports())

        if outputs_number == 0 and len(
                if_node.out_ports(control_flow=True)) != 0:
            # Some models have If nodes with control flow outputs only.
            # Shape inference for such Ifs just assigns an empty shape to the control flow consumers.
            # TODO: need to rethink and redo support for control flow edges in if operation
            for node in if_node.out_nodes(control_flow=True).values():
                node.shape = int64_array([])
            return

        for port_id in if_node.out_ports().keys():
            outputs_mapping[port_id] = {}

        # then_contains_fake_outputs / else_contains_fake_outputs are True if all outputs of then_body / else_body
        # have shape [0]. That means the body does not return data and shape inference from it is not possible.
        # TODO: exclude support fake_outputs from this code when we will support shape_inference with empty tensors

        then_contains_fake_outputs = \
            If.results_mapping_and_finding_fake_outputs(then_outputs, 'then_graph', outputs_mapping)
        else_contains_fake_outputs = \
            If.results_mapping_and_finding_fake_outputs(else_outputs, 'else_graph', outputs_mapping)

        # use_then_shape is False only when then_body returns no data while else_body does. When it is True the
        # If outputs take the shapes of the then_body results.
        use_then_shape = else_contains_fake_outputs or not then_contains_fake_outputs

        cond_value = if_node.in_port(0).data.get_value()

        for port_id in outputs_mapping:
            then_else_nodes = outputs_mapping[port_id]
            assert 'then_graph' in then_else_nodes.keys(), 'then_graph does not connect with If.out_port[{0}] ' \
                                                           'in {1} node!'.format(port_id, node_name)
            assert 'else_graph' in then_else_nodes.keys(), 'else_graph does not connect with If.out_port[{0}] ' \
                                                           'in {1} node!'.format(port_id, node_name)

            then_shape = then_else_nodes['then_graph'].in_port(
                0).data.get_shape()
            then_value = then_else_nodes['then_graph'].in_port(
                0).data.get_value()
            else_shape = then_else_nodes['else_graph'].in_port(
                0).data.get_shape()
            else_value = then_else_nodes['else_graph'].in_port(
                0).data.get_value()

            if is_fully_defined(cond_value):
                # test truthiness, not identity: `cond_value.item() is True` is False for integer conditions like 1
                if bool(cond_value.item()):
                    if then_value is not None:
                        if_node.out_port(port_id).data.set_value(then_value)
                    else:
                        if_node.out_port(port_id).data.set_shape(then_shape)
                else:
                    if else_value is not None:
                        if_node.out_port(port_id).data.set_value(else_value)
                    else:
                        if_node.out_port(port_id).data.set_shape(else_shape)
            else:
                if then_contains_fake_outputs ^ else_contains_fake_outputs:
                    # if exactly one of the outputs is fake then use another one
                    if_node.out_port(port_id).data.set_shape(
                        then_shape if use_then_shape else else_shape)
                else:
                    # find "intersection" which is equal to the dimension value if corresponding dimensions are equal
                    # and dynamic otherwise
                    assert len(then_shape) == len(else_shape), 'Ranks of "then" and "else" output tensors are ' \
                                                               'different for node {} for port {}'.format(node_name,
                                                                                                          port_id)
                    output_shape = [
                        d1 if is_fully_defined(d1) and is_fully_defined(d2)
                        and d1 == d2 else dynamic_dimension_value
                        for d1, d2 in zip(then_shape, else_shape)
                    ]
                    if_node.out_port(port_id).data.set_shape(output_shape)
Exemplo n.º 16
0
    def replace_op(self, graph: Graph, node: Node):
        """
        Replace an LSTM-like cell node with an equivalent subgraph of primitive operations.

        The subgraph contains FullyConnected layers for the input, the previous output and the output projection.
        It has ScaleShift operations that apply per-gate weights to the cell state (peephole-style), Add/Mul/
        Sigmoid/Tanh gate arithmetic and a Clamp on the new cell state. Two state variables are read with
        ReadValue and written back with Assign, each Assign followed by a Result:
          * ``memory_pair_input``: the previous projected output;
          * ``memory_pair_output``: the previous cell state.

        :param graph: graph being transformed
        :param node: the node to replace; its ``gifo_*`` weights/shapes, gate weights, ``projection_weights`` and
            ``clip_value`` attributes are used
        :return: list with the id of the output FullyConnected node.
            NOTE(review): presumably the replacer framework reconnects ``node``'s consumers to it — confirm.
        """
        input_out_port = node.in_port(0).get_source()

        # variable ids shared by each ReadValue/Assign pair (previous output / previous cell state)
        memory_pair_input = unique_id('id')
        memory_pair_output = unique_id('id')

        # Input -> FullyConnected
        fc_layer_after_input_attrs = {
            'name': 'input_fullyconnected',
            'out-size': node.gifo_x_weights_shape[0],
            'transpose_weights': True,
            'bias_term': True,
        }

        fc_layer_after_input = FullyConnected(
            graph, fc_layer_after_input_attrs).create_node()
        fc_layer_after_input.in_port(0).connect(input_out_port)
        input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 1,
                       'weights', node.gifo_x_weights)
        input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 2,
                       'biases', node.gifo_biases)

        # *Memory(output): ReadValue of the previous output; initial value built by
        # create_const_with_batch_from_input (sized from the input batch and gifo_r_weights_shape[1])
        init_value_prev_lstm_output = create_const_with_batch_from_input(
            input_out_port, node.gifo_r_weights_shape[1])
        prev_lstm_output = ReadValue(graph, {
            'name': 'prev_memory_output',
            'variable_id': memory_pair_input
        }).create_node()
        prev_lstm_output.in_port(0).connect(
            init_value_prev_lstm_output.out_port(0))

        # *Memory(output) -> FullyConnected
        fc_layer_from_prev_state_attrs = {
            'name': 'prev_memory_output_fullyconnected',
            'out-size': node.gifo_r_weights_shape[0],
            'transpose_weights': True,
            'bias_term': False,
        }

        fc_layer_from_prev_state = FullyConnected(
            graph, fc_layer_from_prev_state_attrs).create_node()
        fc_layer_from_prev_state.in_port(0).connect(
            prev_lstm_output.out_port(0))
        input_as_const(fc_layer_from_prev_state,
                       fc_layer_from_prev_state_attrs, 1, 'weights',
                       node.gifo_r_weights)

        # Memory -> FullyConnected  \
        #                           *Eltwise(sum)
        # Input -> FullyConnected   /
        join_input_prev_state_sum = Add(graph, {
            'name': 'join_input_eltwise'
        }).create_node()
        join_input_prev_state_sum.in_port(0).connect(
            fc_layer_from_prev_state.out_port(0))
        join_input_prev_state_sum.in_port(1).connect(
            fc_layer_after_input.out_port(0))

        # *Eltwise(sum) -> Split
        # it is split along axis 1 into 4 parts consumed by: Tanh, Eltwise(sum) x3
        # the following order is mandatory
        #       ___Tanh                 (split output 0)
        #      /
        # Split ---(2)Eltwise(sum)      (split output 1, input gate)
        #     |\
        #     | \__(3)Eltwise(sum)      (split output 2, forget gate)
        #     |____(4)Eltwise(sum)      (split output 3, output gate)
        split_joined_input_axis = Const(graph, {
            'value': np.int64(1)
        }).create_node()
        split_joined_input = Split(graph, {
            'name': 'join_input_split',
            'num_splits': 4,
            'out_ports_count': 4
        }).create_node()
        split_joined_input.in_port(0).connect(
            join_input_prev_state_sum.out_port(0))
        split_joined_input.in_port(1).connect(
            split_joined_input_axis.out_port(0))

        # *Memory(state): ReadValue of the previous cell state, initial value sized from
        # input_gate_weights.shape[0] with the batch taken from the first split output
        init_value_prev_lstm_state = create_const_with_batch_from_input(
            split_joined_input.out_port(0), node.input_gate_weights.shape[0])
        prev_lstm_state = ReadValue(graph, {
            'name': 'prev_memory_state',
            'variable_id': memory_pair_output
        }).create_node()
        prev_lstm_state.in_port(0).connect(
            init_value_prev_lstm_state.out_port(0))

        # *Memory(state) -> *ScaleShift(input)
        state_input_scaleshift_attrs = {
            'name': 'input_scaleshift',
            'bias_term': False
        }
        state_input_scaleshift = ScaleShiftOp(
            graph, state_input_scaleshift_attrs).create_node()
        state_input_scaleshift.in_port(0).connect(prev_lstm_state.out_port(0))
        input_as_const(state_input_scaleshift, state_input_scaleshift_attrs, 1,
                       'weights', node.input_gate_weights)

        # *Memory(state) -> *ScaleShift(forget)
        state_forget_scaleshift_attrs = {
            'name': 'forget_scaleshift',
            'bias_term': False
        }
        state_forget_scaleshift = ScaleShiftOp(
            graph, state_forget_scaleshift_attrs).create_node()
        state_forget_scaleshift.in_port(0).connect(prev_lstm_state.out_port(0))
        input_as_const(state_forget_scaleshift, state_forget_scaleshift_attrs,
                       1, 'weights', node.forget_gate_weights)

        # Split                                 \
        #                                       (2)Eltwise(sum)
        # Memory(state) -> *ScaleShift(input)  /
        join_prev_lstm_input_joined_input_sum = Add(
            graph, {
                'name': 'join_prev_lstm_input_joined_input_eltwise'
            }).create_node()
        join_prev_lstm_input_joined_input_sum.in_port(0).connect(
            split_joined_input.out_port(1))
        join_prev_lstm_input_joined_input_sum.in_port(1).connect(
            state_input_scaleshift.out_port(0))
        # Split                                 \
        #                                       (3)Eltwise(sum)
        # Memory(state) -> *ScaleShift(forget)  /
        join_prev_lstm_input_joined_forget_sum = Add(
            graph, {
                'name': 'join_prev_lstm_input_joined_forget_sum',
            }).create_node()
        join_prev_lstm_input_joined_forget_sum.in_port(0).connect(
            split_joined_input.out_port(2))
        join_prev_lstm_input_joined_forget_sum.in_port(1).connect(
            state_forget_scaleshift.out_port(0))

        # Split -> Tanh (candidate values)
        remember_tahn = Tanh(graph, {'name': 'remember_tahnv'}).create_node()
        remember_tahn.in_port(0).connect(split_joined_input.out_port(0))

        # Split -> (2)Eltwise(sum) -> *Sigmoid
        remember_sigmoid = Sigmoid(graph, {
            'name': 'remember_sigmoid'
        }).create_node()
        remember_sigmoid.in_port(0).connect(
            join_prev_lstm_input_joined_input_sum.out_port(0))

        # Split -> (3)Eltwise(sum) -> **Sigmoid
        forget_sigmoid = Sigmoid(graph, {
            'name': 'forget_sigmoid'
        }).create_node()
        forget_sigmoid.in_port(0).connect(
            join_prev_lstm_input_joined_forget_sum.out_port(0))

        # *Memory(state)                        \
        #                                       (6)Eltwise(mul)
        # Split -> (3)Eltwise(sum) -> **Sigmoid /
        join_forget_prev_state_mul = Mul(graph, {
            'name': 'join_forget_prev_state_mul'
        }).create_node()
        join_forget_prev_state_mul.in_port(0).connect(
            forget_sigmoid.out_port(0))
        join_forget_prev_state_mul.in_port(1).connect(
            prev_lstm_state.out_port(0))

        # Split -> Tanh                         \
        #                                       (5)Eltwise(mul)
        # Split -> (2)Eltwise(sum) -> *Sigmoid   /
        join_remember_candidates_mul = Mul(
            graph, {
                'name': 'join_remember_candidates_mul'
            }).create_node()
        join_remember_candidates_mul.in_port(0).connect(
            remember_tahn.out_port(0))
        join_remember_candidates_mul.in_port(1).connect(
            remember_sigmoid.out_port(0))

        # (5)Eltwise(mul)  \
        #               (7)Eltwise(sum)
        # (6)Eltwise(mul)   /
        join_forget_remember_sum = Add(graph, {
            'name': 'join_forget_remember_sum'
        }).create_node()
        join_forget_remember_sum.in_port(0).connect(
            join_forget_prev_state_mul.out_port(0))
        join_forget_remember_sum.in_port(1).connect(
            join_remember_candidates_mul.out_port(0))

        # (7)Eltwise(sum) -> Clamp to [-clip_value, clip_value] (new cell state)
        join_forget_clamp = create_op_with_const_inputs(
            graph, Clamp, {
                1: float32_array(-node.clip_value),
                2: float32_array(node.clip_value)
            }, {'name': 'join_forget_clamp'}, join_forget_remember_sum)
        #
        # Clamp -> (2)Memory(state): Assign stores the new cell state for the next iteration
        next_lstm_state = Assign(graph, {
            'name': 'next_lstm_state',
            'variable_id': memory_pair_output
        }).create_node()
        next_lstm_state.in_port(0).connect(join_forget_clamp.out_port(0))

        # Result keeps the Assign node as a graph sink
        res_node = Result(graph, {'name': 'next_lstm_state_out'}).create_node()
        res_node.in_port(0).connect(next_lstm_state.out_port(0))

        # Clamp -> (2)Tanh
        state_filtered_tahn = Tanh(graph, {
            'name': 'state_filtered_tahn'
        }).create_node()
        state_filtered_tahn.in_port(0).connect(join_forget_clamp.out_port(0))

        # Clamp -> (2)ScaleShift (output gate weights applied to the new cell state)
        clamp_scaleshift_attrs = {
            'name': 'clamp_scaleshift',
            'bias_term': False
        }
        clamp_scaleshift = ScaleShiftOp(graph,
                                        clamp_scaleshift_attrs).create_node()
        clamp_scaleshift.in_port(0).connect(join_forget_clamp.out_port(0))
        input_as_const(clamp_scaleshift, clamp_scaleshift_attrs, 1, 'weights',
                       node.output_gate_weights)

        # Split                 \
        #                       (4)Eltwise(sum)
        # Clamp -> (2)ScaleShift /
        join_next_lstm_input_joined_input_sum = Add(
            graph, {
                'name': 'join_next_lstm_input_joined_input_sum',
            }).create_node()
        join_next_lstm_input_joined_input_sum.in_port(0).connect(
            split_joined_input.out_port(3))
        join_next_lstm_input_joined_input_sum.in_port(1).connect(
            clamp_scaleshift.out_port(0))

        # (4)Eltwise(sum) -> (3)Sigmoid
        output_sigmoid = Sigmoid(graph, {
            'name': 'output_sigmoid'
        }).create_node()
        output_sigmoid.in_port(0).connect(
            join_next_lstm_input_joined_input_sum.out_port(0))

        # (4)Eltwise(sum) -> (3)Sigmoid         \
        #                                       (8)Eltwise(mul)
        # Clamp -> (2)Tanh                      /
        joined_output_mul = Mul(graph, {
            'name': 'joined_output_mul'
        }).create_node()
        joined_output_mul.in_port(0).connect(state_filtered_tahn.out_port(0))
        joined_output_mul.in_port(1).connect(output_sigmoid.out_port(0))

        # (8)Eltwise(mul) -> (3)FullyConnected (output projection)
        fc_output_attrs = {
            'name': 'FullyConnected',
            'out-size': node.projection_weights_shape[0],
            'transpose_weights': True,
            'bias_term': False
        }
        fc_output = FullyConnected(graph, fc_output_attrs).create_node()
        fc_output.in_port(0).connect(joined_output_mul.out_port(0))
        input_as_const(fc_output, fc_output_attrs, 1, 'weights',
                       node.projection_weights)

        #                   / (2)Memory(output)
        # (3)FullyConnected
        #                   \ Output (any next node) (edge created automatically after replacement)
        next_lstm_output = Assign(graph, {
            'name': 'next_lstm_output',
            'variable_id': memory_pair_input
        }).create_node()
        next_lstm_output.in_port(0).connect(fc_output.out_port(0))

        # Result keeps the output Assign node as a graph sink
        res_node_lstm_output = Result(graph, {
            'name': 'next_lstm_output_out'
        }).create_node()
        res_node_lstm_output.in_port(0).connect(next_lstm_output.out_port(0))

        return [fc_output.id]
Exemplo n.º 17
0
    def lift_up_through_eltwise(node: Node, reverse_channels: Node):
        r"""
        Try to move `reverse_channels` from the output of the Eltwise `node` up to each of its inputs.

        BEFORE                      AFTER

                                    previous_op              previous_op'
                                          \                    /
        previous_op  previous_op'     ReverseChannels     ReverseChannels
                 \     /                            \     /
                Eltwise                             Eltwise
                   |                                  |
             ReverseChannels                       next_op
                  |
                next_op

        Every input of `node` whose size along `reverse_channels.axis` is not 1 gets its own copy of
        `reverse_channels` (with `axis` adjusted to that input's shape). The original `reverse_channels`
        is then bypassed and its input disconnected.

        :param node: the Eltwise operation whose output feeds `reverse_channels` (as in the diagram above)
        :param reverse_channels: the ReverseChannels operation located right after `node`
        :return: two objects:
            first - boolean value whether we should continue propagating current ReverseChannels operation up or not
            second - list of new ReverseChannels operations that were produced while propagating reverse_channels up
            ``(False, [])`` is returned, with the graph left unmodified, when some input cannot be handled.
        """
        # shape of the tensor currently flipped by reverse_channels
        before_shape = reverse_channels.in_port(0).data.get_shape()

        # First pass: for each input decide which axis its ReverseChannels copy must flip.
        # The graph is not modified in this pass, so the early return below leaves it intact.
        port_axis = []
        for idx, port in node.in_ports().items():
            shape = port.data.get_shape()

            non_one_dims = np.where(shape != 1)[0]
            # NOTE(review): the input shape is indexed with the axis of `before_shape`, which assumes every
            # input has the same rank as the flipped tensor (a lower-rank broadcast input would be indexed
            # incorrectly) - confirm inputs are rank-aligned before this pass runs
            if shape[reverse_channels.axis] == 1:
                continue  # nothing to flip for this input
            if len(non_one_dims) == 1 and shape[
                    non_one_dims.item()] == reverse_channels.order.size:
                # a single non-1 dimension holding exactly as many elements as the reversed channels
                axis = non_one_dims.item()
            elif np.array_equal(before_shape, shape):
                # the input has the full (non-broadcast) shape, so the original axis applies
                axis = reverse_channels.axis
            else:
                # shape has multiple non-one values and shape is not fully broadcasted to value port shape
                # it is safe not to propagate reverse channels
                return False, []
            port_axis.append((port, axis))

        # Second pass: insert a ReverseChannels copy between each selected input and its producer.
        copies = []
        for port, axis in port_axis:
            reverse_channels_copy = reverse_channels.copy_node(
                {'axis': mo_array(axis)})

            src = port.get_connection().get_source()
            if src.node.soft_get('type') == 'Parameter':
                # For Parameter nodes tensor debug attributes should not move to the last node
                # of subgraph. It is needed for the proper mapping of input framework name.
                # For this reason "source" mode is used to keep tensor debug attributes at Parameter node.
                port.get_connection().set_source(
                    reverse_channels_copy.out_port(0),
                    attributes_save_mode="source")
            else:
                port.get_connection().set_source(
                    reverse_channels_copy.out_port(0))
            src.connect(reverse_channels_copy.in_port(0))

            copies.append(reverse_channels_copy)

        # Bypass the original ReverseChannels: its consumers now read directly from its producer.
        reverse_channels.out_port(0).get_connection().set_source(
            reverse_channels.in_port(0).get_connection().get_source())
        reverse_channels.in_port(0).disconnect()

        # propagated reverse_channels successfully through current node, will continue propagation
        return True, copies
Exemplo n.º 18
0
    def dequantize_data(fake_quantize: Node, dst_type: type,
                        quantized_type: type) -> None:
        """
        Replace `fake_quantize`, which consumes compressed weights, with an explicit dequantization subgraph.

        The data on input 0 must come from a Convert node whose `dst_type` equals `quantized_type`.
        A Cast to `dst_type` is inserted after it, and the output of `fake_quantize` is replaced by

            (Cast(x) - zero_point) * scale

        where the limits come from inputs 1-4 (in_low, in_high, out_low, out_high):

            scale      = (out_high - out_low) / (in_high - in_low)
            shift      = in_low - out_low / scale
            zero_point = 0 if scale == 0 else shift

        Finally `fake_quantize` is removed from the graph.

        :param fake_quantize: node whose input 0 is the compressed weights and inputs 1-4 are the limits
        :param dst_type: numpy type of the dequantized values (used for the Cast and the zero constant)
        :param quantized_type: type the weights are stored in; checked against the Convert's `dst_type`
        :raises AssertionError: if input 0 is not produced by a Convert to `quantized_type`
        """
        graph = fake_quantize.graph
        quantized_data = fake_quantize.in_port(0).get_source().node
        name = fake_quantize.soft_get('name', fake_quantize.id)

        assert quantized_data.soft_get('type') == 'Convert' and quantized_data.dst_type == quantized_type, \
            'Weights aren`t compressed as expected for node {}'.format(fake_quantize.soft_get('name', fake_quantize.id))

        # cast the stored (quantized) values to the destination type before arithmetic
        dequantizing_cast = Cast(
            graph,
            dict(name=quantized_data.name +
                 "/to_{}".format(np_data_type_to_destination_type(dst_type)),
                 dst_type=dst_type,
                 stop_value_propagation=True)).create_node()
        fake_quantize.in_port(0).get_connection().set_destination(
            dequantizing_cast.in_port(0))

        # limits of dequantize
        in_low = fake_quantize.in_port(1).get_source()
        in_high = fake_quantize.in_port(2).get_source()
        out_low = fake_quantize.in_port(3).get_source()
        out_high = fake_quantize.in_port(4).get_source()

        # scale calculation: (out_high - out_low) / (in_high - in_low)
        output_range = Sub(graph, {
            'name': name + '/output_range'
        }).create_node()
        output_range.in_port(0).connect(out_high)
        output_range.in_port(1).connect(out_low)

        input_range = Sub(graph, {'name': name + '/input_range'}).create_node()
        input_range.in_port(0).connect(in_high)
        input_range.in_port(1).connect(in_low)

        scale = Div(graph, {'name': name + '/scale'}).create_node()
        scale.in_port(0).connect(output_range.out_port(0))
        scale.in_port(1).connect(input_range.out_port(0))

        # shift calculation: in_low - out_low / scale
        descaled_output_low = Div(graph, {
            'name': name + '/descaled_output_low'
        }).create_node()
        descaled_output_low.in_port(0).connect(out_low)
        descaled_output_low.in_port(1).connect(scale.out_port(0))

        shift = Sub(graph, {'name': name + '/shift'}).create_node()
        shift.in_port(0).connect(in_low)
        shift.in_port(1).connect(descaled_output_low.out_port(0))

        # zero_point = Select(scale == 0, 0, shift): use 0 where the scale is zero
        zero = Const(graph, {
            'name': name + '/zero',
            'value': mo_array(0, dtype=dst_type)
        }).create_node()
        scale_eq_zero = Equal(graph, {
            'name': name + '/scale_eq_zero'
        }).create_node()
        scale_eq_zero.in_port(0).connect(scale.out_port(0))
        scale_eq_zero.in_port(1).connect(zero.out_port(0))

        zero_point = Select(graph, {
            'name': name + '/zero_point'
        }).create_node()
        zero_point.in_port(0).connect(scale_eq_zero.out_port(0))
        zero_point.in_port(1).connect(zero.out_port(0))
        zero_point.in_port(2).connect(shift.out_port(0))

        # DeQuantize(x) == Mul(Sub(x, zero_point), scale)
        sub_zp = Sub(graph, {'name': name + '/minus_zp'}).create_node()
        sub_zp.in_port(0).connect(dequantizing_cast.out_port(0))
        sub_zp.in_port(1).connect(zero_point.out_port(0))

        # NOTE(review): 'mulpiply' is a typo in the node name; kept because node names may be matched elsewhere
        mul_scale = Mul(graph, {
            'name': name + '/mulpiply_by_scale'
        }).create_node()
        mul_scale.in_port(0).connect(sub_zp.out_port(0))
        mul_scale.in_port(1).connect(scale.out_port(0))

        # consumers of the FakeQuantize now read the dequantized values
        fake_quantize.out_port(0).get_connection().set_source(
            mul_scale.out_port(0))

        # NOTE(review): the second element is a Node object, not an id like `fake_quantize.id`;
        # confirm that remove_nodes_from really removes the intended output node
        graph.remove_nodes_from([fake_quantize.id, fake_quantize.out_node(0)])
Exemplo n.º 19
0
 def type_infer(node: Node):
     """Give output port 0 the same element type as input port 2."""
     source_type = node.in_port(2).get_data_type()
     node.out_port(0).set_data_type(source_type)
Exemplo n.º 20
0
    def infer(node: Node):
        """
        Infer the output of a Gather node (data on port 0, indices on port 1, axis on port 2).

        `axis` may be a scalar or a 1-element tensor; it and the `batch_dims` attribute may be negative and
        are normalized against the data rank and the indices rank respectively. The output shape is
        ``data_shape[:axis] + indices_shape[batch_dims:] + data_shape[axis + 1:]``.
        When the data value and a fully defined indices value are known, the output value is computed
        (per batch element when ``batch_dims > 0``) with the element type of the data preserved;
        otherwise only the output shape is set.

        :raises AssertionError: on missing inputs, undefined axis or out-of-range axis / batch_dims
        """
        name = node.soft_get('name', node.id)

        connected_in_ports = {
            idx: port
            for idx, port in node.in_ports().items()
            if not port.disconnected()
        }
        assert len(connected_in_ports) == 3 and 0 in connected_in_ports and 1 in connected_in_ports and \
               2 in connected_in_ports, "Gather should have 3 connected input port, but it doesn't for " \
                                        "node: `{}`. Ports: {}".format(name, connected_in_ports)

        data_shape = node.in_port(0).data.get_shape()
        assert data_shape is not None
        indices_shape = node.in_port(1).data.get_shape()
        assert indices_shape is not None
        axis = node.in_port(2).data.get_value()

        # axis of Gather could be accepted as both scalar and 1D tensor
        if isinstance(axis, np.ndarray):
            axis = axis.item()
        assert axis is not None, 'axis input is undefined'

        assert -len(data_shape) <= axis < len(data_shape), \
            'axis must be within interval [-data_rank, data_rank). Instead got axis = {}, data_rank = {} '.\
            format(axis, len(data_shape))

        batch_dims = node.batch_dims
        assert -len(indices_shape) <= batch_dims <= len(indices_shape), \
            'batch_dims must be within interval [-indices_rank, indices_rank]. Instead got batch_dims = {}, ' \
            'indices_rank = {} '.format(batch_dims, len(indices_shape))

        # normalize to positive values
        axis = axis + len(data_shape) if axis < 0 else axis
        batch_dims = batch_dims + len(
            indices_shape) if batch_dims < 0 else batch_dims

        assert np.ma.allequal(data_shape[:batch_dims], indices_shape[:batch_dims]), \
            'data and indices inputs must have equal first dimensions until batch_dims'

        # the arguments follow the order of the placeholders in the message
        assert batch_dims <= axis, \
            'normalized batch_dims must be <= axis. Instead got batch_dims = {}, axis = {}'.format(batch_dims, axis)

        # we import PermuteInputs locally because it uses Gather inside and we have recursive imports
        from openvino.tools.mo.graph.perm_inputs import PermuteInputs
        PermuteInputs().set_input_permutation(node.in_node(2), node, 'input:0',
                                              'axis')

        batch_dims_range = indices_shape[:batch_dims]
        out_shape = np.concatenate(
            (data_shape[:axis], indices_shape[batch_dims:],
             data_shape[axis + 1:]))

        data_value = node.in_port(0).data.get_value()
        indices_value = node.in_port(1).data.get_value()
        if data_value is not None and indices_value is not None and is_fully_defined(
                indices_value):
            if batch_dims == 0:
                node.out_port(0).data.set_value(
                    np.ma.take(data_value, indices_value, axis))
            else:
                # keep the data element type: np.empty defaults to float64, which would silently turn
                # e.g. int64 shape values into floats during constant folding
                out_value = np.empty(out_shape, dtype=data_value.dtype)
                for batch_idx in np.ndindex(tuple(batch_dims_range)):
                    out_value[batch_idx] = np.ma.take(data_value[batch_idx],
                                                      indices_value[batch_idx],
                                                      axis - batch_dims)
                node.out_port(0).data.set_value(out_value)
        else:
            node.out_port(0).data.set_shape(out_shape)
Exemplo n.º 21
0
    def infer(node: Node):
        """
        Infer output shapes of a NonMaxSuppression node.

        Port 0 holds boxes and port 1 holds scores; both must be rank 3. The optional port 2 holds
        `max_output_boxes_per_class`. When it is absent, zero or not constant, the number of boxes
        (``boxes_shape[1]``) is used instead. Output 0 gets shape ``[max_number_of_boxes, 3]``, an upper bound on
        the selected boxes. For opset5, output 0 gets a dynamic first dimension instead. The shape of
        output 1 (``[dynamic, 3]``) is set when it exists and at least two outputs are connected. The shape of
        output 2 (``[1]``) is set when it exists and at least three outputs are connected.

        :raises AssertionError: on a wrong number of inputs, undefined or non-3D input shapes,
            or a mismatch between the numbers of boxes in boxes and scores
        """
        num_of_inputs = len(node.in_ports())
        opset = node.get_opset()
        max_num_of_inputs = 6 if opset == 'opset5' else 5
        input_msg_fmt = 'NonMaxSuppression node {} from {} must have from 2 to {} inputs'
        node_name = node.soft_get('name', node.id)
        inputs_msg = input_msg_fmt.format(node_name, opset, max_num_of_inputs)
        assert 2 <= num_of_inputs <= max_num_of_inputs, inputs_msg

        boxes_shape = node.in_port(0).data.get_shape()
        assert boxes_shape is not None, 'The shape of tensor with boxes is not defined'
        scores_shape = node.in_port(1).data.get_shape()
        assert scores_shape is not None, 'The shape of tensor with scores is not defined'
        assert len(boxes_shape) == 3, 'Length of tensors with boxes must be equal to 3'
        assert len(scores_shape) == 3, 'Length of tensors with scores must be equal to 3'

        # According to the specification of the operation NonMaxSuppression,
        # the input 'max_output_boxes_per_class' (port 2) is optional, with default value 0.
        if num_of_inputs >= 3:
            max_output_boxes_per_class = node.in_port(2).data.get_value()
        else:
            max_output_boxes_per_class = 0

        # zero or unknown (None) value: fall back to the number of input boxes.
        # node_name is used because the 'name' attribute is not guaranteed to exist.
        if not max_output_boxes_per_class:
            log.info(
                'Set default "max_output_boxes_per_class" for node {} to number of boxes'
                .format(node_name))
            max_output_boxes_per_class = boxes_shape[1]

        # convert the np.array value to a scalar to avoid issue with ragged numpy array generation in the shape
        # calculation formulas below
        if isinstance(max_output_boxes_per_class, np.ndarray):
            max_output_boxes_per_class = max_output_boxes_per_class.item()

        num_classes = scores_shape[1]
        num_input_boxes = boxes_shape[1]
        assert scores_shape[2] is dynamic_dimension or scores_shape[2] == num_input_boxes or scores_shape[2] is None \
               or num_input_boxes is None, 'Number of boxes mismatch for operation {}'.format(node_name)

        if opset in ['opset4', 'opset5']:
            max_number_of_boxes = min(
                num_input_boxes,
                max_output_boxes_per_class) * boxes_shape[0] * num_classes
        else:
            max_number_of_boxes = min(
                num_input_boxes,
                boxes_shape[0] * max_output_boxes_per_class * num_classes)
        node.out_port(0).data.set_shape(shape_array([max_number_of_boxes, 3]))

        if opset == 'opset5':
            # opset5 outputs have a data-dependent number of selected boxes
            node.out_port(0).data.set_shape(
                shape_array([dynamic_dimension_value, 3]))
            num_of_outputs = len([
                port for port in node.out_ports().values()
                if not port.disconnected()
            ])
            if num_of_outputs >= 2 and node.has_port('out', 1):
                node.out_port(1).data.set_shape(
                    shape_array([dynamic_dimension_value, 3]))
            if num_of_outputs >= 3 and node.has_port('out', 2):
                node.out_port(2).data.set_shape(shape_array([1]))
def replace_resize(graph: Graph, resize: Node):
    """
    Replace an ONNX Resize-11 node with an Interpolate-4 node.

    Only inputs of rank 4 or 5 are converted. For other ranks a warning is logged and the graph is left
    unchanged. The resized axes are taken from the graph layout. They run from the height dimension (4D)
    or the depth dimension (5D) through the width dimension, and are passed to Interpolate explicitly.

    * When the sizes input (port 3) is connected, Interpolate works in 'sizes' mode. Its scales input is
      computed as ``float(sizes) / float(ShapeOf(data)) + 1e-5``.
    * Otherwise the scales input (port 2) is used ('scales' mode). The sizes input is computed as
      ``int64(floor(float(ShapeOf(data)) * scales + 1e-5))``.

    Both auxiliary inputs are cut down to the resized axes with StridedSlice. The new node takes over the
    name and the consumers of `resize`, which is renamed with a '/delete' suffix. It is not removed here.

    :param graph: graph containing `resize`; ``graph.graph['layout']`` must be set
    :param resize: the Resize node to convert
    :raises AssertionError: if input 0 or both the scales and sizes inputs are disconnected, or if
        the coordinate transformation mode is 'tf_crop_and_resize'
    """
    log.debug("Converting of ONNX Resize-11 to Interpolate-4 "
              "is triggered for node {}.".format(
                  resize.soft_get('name', resize.id)))

    # NOTE(review): len() raises TypeError when the input shape is still unknown (None) - confirm shapes
    # are always inferred before this transformation runs
    input_shape = resize.in_port(0).data.get_shape()
    input_rank = len(input_shape)
    resize_name = resize.soft_get('name', resize.id)
    if input_rank not in {4, 5}:
        log.warning(
            'The input shape is not 4D or 5D for op with name {}'.format(
                resize_name))
        return

    assert (resize.is_in_port_connected(0) and (resize.is_in_port_connected(2) or resize.is_in_port_connected(3))), \
        "Scales or sizes inputs must be connected to Node {} with op {}.".format(resize.soft_get("name", resize.id),
                                                                                 resize.op)

    assert resize.soft_get('coordinate_transformation_mode') != 'tf_crop_and_resize', \
        'Mode tf_crop_and_resize is not supported for op {} with name {}'.format(resize.op,
                                                                                 resize.soft_get("name", resize.id))

    layout = graph.graph['layout']

    # resized axes form the range [begin_dim, end_dim): from height (4D) / depth (5D) through width
    if input_rank == 4:
        begin_dim = get_height_dim(layout, input_rank)
        end_dim = get_width_dim(layout, input_rank) + 1
    else:
        begin_dim = get_depth_dim(layout, input_rank)
        end_dim = get_width_dim(layout, input_rank) + 1

    # StridedSlice nodes that cut the full-rank sizes / scales down to [begin_dim, end_dim)
    sizes_ss = create_op_with_const_inputs(
        graph, StridedSlice, {
            1: int64_array([begin_dim]),
            2: int64_array([end_dim]),
            3: int64_array([1])
        }, {
            'name': resize_name + '/StridedSlice_sizes',
            'begin_mask': int64_array([1]),
            'end_mask': int64_array([1]),
            'new_axis_mask': int64_array([0]),
            'shrink_axis_mask': int64_array([0]),
            'ellipsis_mask': int64_array([0])
        })
    scales_ss = create_op_with_const_inputs(
        graph, StridedSlice, {
            1: int64_array([begin_dim]),
            2: int64_array([end_dim]),
            3: int64_array([1])
        }, {
            'name': resize_name + '/StridedSlice_scales',
            'begin_mask': int64_array([1]),
            'end_mask': int64_array([1]),
            'new_axis_mask': int64_array([0]),
            'shrink_axis_mask': int64_array([0]),
            'ellipsis_mask': int64_array([0])
        })
    # explicit list of resized axes, fed to Interpolate input 3
    axes_node = Const(
        graph, {
            'name': resize_name + '/axis',
            'value': int64_array(np.arange(begin_dim, end_dim))
        }).create_node()

    # a connected sizes input (port 3) takes precedence over the scales input (port 2)
    shape_calculation_mode = 'sizes' if resize.is_in_port_connected(
        3) else 'scales'

    interpolate_node = Interpolate(
        graph, {
            'version': 'opset4',
            'mode': convert_mode(resize.mode),
            'coordinate_transformation_mode':
            resize.coordinate_transformation_mode,
            'cube_coeff': resize.cube_coeff,
            'nearest_mode': resize.nearest_mode,
            'pads_begin': int64_array([0]),
            'pads_end': int64_array([0]),
            'antialias': 0,
            'shape_calculation_mode': shape_calculation_mode,
            'in_ports_count': 4
        }).create_node()

    axes_node.out_port(0).connect(interpolate_node.in_port(3))
    shape_of = Shape(graph, {'name': resize_name + '/ShapeOf'}).create_node()

    # small epsilon added to the computed scales / to the sizes before floor; presumably it compensates
    # for floating-point rounding (e.g. 3.9999 flooring to 3) - confirm
    add_node = create_op_with_const_inputs(graph, Add,
                                           {1: float_array([1.0e-5])},
                                           {'name': resize_name + '/Add'})

    dst_dtype = np.float32  # even if data_type=FP16 use float32 for shape values

    if not resize.is_in_port_connected(3):
        # 'scales' mode: sizes = int64(floor(float(ShapeOf(data)) * scales + eps))
        cast_shape_to_float = Cast(graph, {
            'dst_type': dst_dtype
        }).create_node()
        mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node()
        shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
        cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))
        cast_add_result_to_int = Cast(graph, {
            'dst_type': np.int64
        }).create_node()
        floor_node = Floor(graph, {
            'name': resize_name + '/Floor'
        }).create_node()
        mul_node.out_port(0).connect(add_node.in_port(0))
        add_node.out_port(0).connect(floor_node.in_port(0))
        floor_node.out_port(0).connect(cast_add_result_to_int.in_port(0))
        cast_add_result_to_int.out_port(0).connect(sizes_ss.in_port(0))
        sizes_ss.out_port(0).connect(interpolate_node.in_port(1))
        scales_ss.out_port(0).connect(interpolate_node.in_port(2))

        connection_of_resize_input = resize.in_port(0).get_connection()
        connection_of_resize_input.set_destination(interpolate_node.in_port(0))

        connection_of_scales = resize.in_port(2).get_connection()
        connection_of_scales.set_destination(scales_ss.in_port(0))

        # the full-rank data and scales also feed the sizes computation
        connection_of_resize_input.get_source().connect(shape_of.in_port(0))
        connection_of_scales.get_source().connect(mul_node.in_port(1))
    else:
        # 'sizes' mode: scales = float(sizes) / float(ShapeOf(data)) + eps
        cast_shape_to_float = Cast(graph, {
            'dst_type': dst_dtype
        }).create_node()
        cast_sizes_to_float = Cast(graph, {
            'dst_type': dst_dtype
        }).create_node()
        div_node = Div(graph, {'name': resize_name + '/Div'}).create_node()
        cast_sizes_to_float.out_port(0).connect(div_node.in_port(0))
        cast_shape_to_float.out_port(0).connect(div_node.in_port(1))
        shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
        div_node.out_port(0).connect(add_node.in_port(0))
        add_node.out_port(0).connect(scales_ss.in_port(0))
        scales_ss.out_port(0).connect(interpolate_node.in_port(2))
        sizes_ss.out_port(0).connect(interpolate_node.in_port(1))

        connection_of_resize_input = resize.in_port(0).get_connection()
        connection_of_resize_input.set_destination(interpolate_node.in_port(0))

        connection_of_sizes = resize.in_port(3).get_connection()
        connection_of_sizes.set_destination(sizes_ss.in_port(0))

        # the full-rank data and sizes also feed the scales computation
        connection_of_resize_input.get_source().connect(shape_of.in_port(0))
        connection_of_sizes.get_source().connect(
            cast_sizes_to_float.in_port(0))

    # the new node takes over the name and the consumers of the original Resize
    rename_nodes([(resize, resize_name + '/delete'),
                  (interpolate_node, resize_name)])
    resize.out_port(0).get_connection().set_source(
        interpolate_node.out_port(0))
Exemplo n.º 23
0
 def remove_duplication(self, graph: Graph, fq: Node) -> list:
     """
     Delete node `fq` from `graph`, keeping only the operation that feeds its input 0.

     Consumers of `fq` output 0 are reconnected to the producer of `fq` input 0; then `fq` is
     disconnected and removed.

     :return: an empty list
     """
     # Keep only input operation: reroute fq's consumers to fq's producer
     fq.out_port(0).get_connection().set_source(fq.in_port(0).get_source())
     fq.in_port(0).disconnect()
     graph.remove_node(fq.id)
     return []
Exemplo n.º 24
0
    def infer(node: Node):
        """
        Deconvolution has an input argument that explicitly determines output shape, so in contrast
        to the forward Conv2d we shouldn't infer output shape. We just use this output shape as
        an input shape and pass it to our utilities that compute numeric values for padding.
        They also deliver output shape that is interpreted here as input shape for convolution.
        We need to check that the real input shape and shape inferred by those utility functions match.

        Inference is postponed (nothing but `kernel_shape` is set) while any of the following is unknown:
        the requested output-shape value (input 2), the data input shape, the kernel shape,
        `spatial_dims` or `stride`.
        """
        output_shape_value = node.in_node(2).value
        input_shape = node.in_port(0).data.get_shape()
        kernel_shape = node.in_port(1).data.get_shape()
        node['kernel_shape'] = kernel_shape
        # check for unknown values before indexing into them, so an unknown value postpones inference
        # instead of raising
        if output_shape_value is None or input_shape is None or kernel_shape is None or \
                node.spatial_dims is None or node.stride is None:
            return

        output_shape = shape_array(output_shape_value)
        # the batch dimension always comes from the actual data input
        output_shape[0] = input_shape[0]

        if not node.has_valid('kernel_spatial_idx'):
            node['kernel_spatial_idx'] = np.delete(
                [x for x in range(len(kernel_shape))],
                (node.input_feature_channel, node.output_feature_channel))

        if not node.has_valid('dilation'):
            node['dilation'] = np.full([len(output_shape)], 1, dtype=np.int64)

        if node.has_valid('get_group'):
            node['group'] = node.get_group(node)

        spatial_dims = node.spatial_dims
        output_spatial = shape_array(output_shape[spatial_dims])
        stride_spatial = shape_array(node.stride[spatial_dims])
        node['kernel_spatial'] = shape_array(
            kernel_shape[node.kernel_spatial_idx])
        node.pad_spatial_shape, input_spatial_for_check = tf_window_op_pad_infer(
            output_spatial, node.kernel_spatial, stride_spatial, node.auto_pad)

        # the spatial shape a forward convolution would need must match the real input
        assert compatible_shapes(input_spatial_for_check,
                                 node.in_node(0).shape[spatial_dims])

        pad = np.zeros((len(output_shape), 2), dtype=np.int64)
        pad[spatial_dims] = node.pad_spatial_shape
        node.pad = pad

        node.output = output_shape[node.channel_dims][0]
        node.output_shape = output_shape
        node.out_port(0).data.set_shape(output_shape)

        mark_input_bins(node, ['weights'], 1)
        assign_dims_to_weights(node.in_node(1), node.kernel_spatial_idx,
                               node.input_feature_channel,
                               node.output_feature_channel, len(kernel_shape))

        # OK, now we are sure this is a supported Deconvolution layer
        node.type = 'Deconvolution'
        node.op = 'Deconv2D'

        # Add permute_attrs
        PermuteAttrs.create_permute_attrs(
            node,
            attrs=[
                ('pad', 'input:0'),
                ('stride', 'input:0'),
                ('output_shape', 'input:0'),
                ('batch_dims', 'input:0'),
                ('channel_dims', 'input:0'),
                ('spatial_dims', 'input:0'),
                ('kernel_shape', 'input:1'),
                ('kernel_spatial_idx', 'input:1'),
                ('input_feature_channel', 'input:1'),
                ('output_feature_channel', 'input:1'),
            ])

        # is needed to permute Deconv weights from the original TF [H, W, C_OUT, C_IN] into IE [C_IN, C_OUT, H, W]
        # but for other nodes in weights subgraph permutations must turned off
        # by marking with MarkSubGraphsWithCorrectLayout even if graph layout is NCHW.
        PermuteAttrs.set_permutation(
            node.in_node(1), node, node.soft_get('get_weights_permute', None))
        PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:1',
                                              'transpose')
        PermuteInputs().set_input_permutation(node.in_node(2), node, 'input:0',
                                              'shape')

        node['force_precision_in_ports'] = {2: 'int64'}
Exemplo n.º 25
0
    def replace(node: Node, const: Node):
        """
        Re-route input 0 of `node` (assumed to be fed by `const`) to a Range computed at runtime.

        Applies only when `const` has exactly one non-1 dimension, its element count lies in the open interval
        (5, 500), and its value equals ``arange(prod(shape)).reshape(shape)``. Otherwise the function returns
        without changing the graph.

        The Range produces 0, 1, ..., L-1 with the dtype of `const`. L is read from the tensor on input 1 of
        `node` at index `negative_idx` (the non-1 axis counted from the end). If that value is 1, L falls back to
        the size of `const` along that axis. NOTE(review): input 1 of `node` is presumably a target-shape tensor
        (e.g. a Broadcast shape input) - confirm against the caller.
        Size-1 dimensions of `const` are restored with Unsqueeze so the rank stays the same. The new last node
        takes over the name of `const`, which is renamed with a '/ToBeDeleted' suffix and stays connected to
        the ShapeOf created here.
        """
        graph = node.graph
        shape = const.shape
        const_name = const.soft_get('name', const.id)

        non_one_dims = np.argwhere(shape != 1).flatten()
        one_dims = np.argwhere(shape == 1).flatten()

        if not (non_one_dims.size == 1 and 5 < np.prod(shape) < 500):
            # (5;500) range is deduced to affect less models
            return

        # only a constant holding 0, 1, ..., N-1 can be replaced by a Range
        value = const.value
        if not np.array_equal(
                np.arange(0, np.prod(shape), 1).reshape(shape), value):
            return

        positive_idx = non_one_dims.item(0)
        # the same axis counted from the end, so it also fits a higher-rank tensor on input 1
        negative_idx = positive_idx - len(shape)

        node_name = node.soft_get('name', node.id)
        # element negative_idx of the tensor on input 1 of node (connected below)
        gather = create_op_with_const_inputs(
            graph, Gather, {
                1: int64_array(negative_idx),
                2: int64_array(0)
            }, {'name': node_name + '/BroadcastingDim'})
        # size of const along the same axis
        gather_for_const = create_op_with_const_inputs(
            graph, Gather, {
                1: int64_array(negative_idx),
                2: int64_array(0)
            }, {'name': const_name + '/BroadcastingDim'})
        shapeof_node = Shape(graph, {
            'name': const_name + '/ShapeOf'
        }).create_node()
        shapeof_node.out_port(0).connect(gather_for_const.in_port(0))

        equal_node = create_op_with_const_inputs(
            graph, Equal, {1: int64_array(1)},
            {'name': node_name + '/ConstOne'})
        gather.out_port(0).connect(equal_node.in_port(0))

        # Range length = const size if the input-1 value is 1, otherwise the input-1 value
        select_node = Select(graph, {
            'name': node_name + '/Select',
            'auto_broadcast': 'numpy'
        }).create_node([equal_node, gather_for_const, gather])

        const.out_port(0).connect(shapeof_node.in_port(0))

        range_node = create_op_with_const_inputs(
            graph, Range, {
                0: mo_array(0, dtype=value.dtype),
                2: mo_array(1, dtype=value.dtype)
            }, {
                'name': const_name + '/Range',
                'dtype': value.dtype
            })
        select_node.out_port(0).connect(range_node.in_port(1))

        node.in_port(1).get_connection().add_destination(gather.in_port(0))

        node.in_port(0).get_connection().set_source(range_node.out_port(0))

        if one_dims.size:
            # restore the size-1 dimensions of the original constant
            unsqueeze = create_op_node_with_second_input(
                graph, Unsqueeze, one_dims,
                {'name': const_name + '/KeepShape'})
            range_node.out_port(0).get_connection().insert_node(unsqueeze)
            rename_nodes([(const, const_name + '/ToBeDeleted'),
                          (unsqueeze, const_name)])
        else:
            rename_nodes([(const, const_name + '/ToBeDeleted'),
                          (range_node, const_name)])
Exemplo n.º 26
0
def concat_convolutions(graph: Graph, start_node: Node, last_node: Node):
    """
    This function converts group of convolutions into one.

    The convolutions consuming the outputs of `start_node` (a split) and feeding `last_node` (a concat)
    are fused into the first of them. It becomes a grouped convolution with ``group = len(conv_nodes)``.
    Its weights (and biases, if any) are the concatenation of the weights (biases) of all convolutions.
    It is rewired to read the split input and write the concat output.

    :param graph: graph containing the nodes
    :param start_node: split operation whose consumers are the convolutions; input port 1 holds the split axis
    :param last_node: concat operation joining the convolution outputs; its `axis` attribute is the concat axis
    :return: True if the convolutions were fused, False if a check failed (the graph is unchanged in that case)
    """
    conv_nodes = get_next_operation(start_node)
    assert len(conv_nodes) == len(last_node.in_nodes())
    gconv = conv_nodes[0]

    for idx, conv in enumerate(conv_nodes):
        # Check that concatenation makes in the same order
        if conv.out_node().id != last_node.in_node(idx).id:
            return False
        # Check that all convolutions have same weights shapes
        if not np.array_equal(conv.in_node(1).shape, gconv.in_node(1).shape):
            log.debug(
                'Grouped convolutions fusion : convolutions have different weights shape'
            )
            return False

    # Check that split and concat dims are valid
    channel_dim = gconv.channel_dims[0]
    split_axis = start_node.in_port(1).data.get_value()
    if channel_dim != split_axis or channel_dim != last_node.axis:
        log.debug(
            'Grouped convolutions fusion : split or concat has weird axis!')
        return False

    # Check that all convolutions has the same parameters
    conv_attrs = ['pad', 'stride']
    for attr in conv_attrs:
        for conv in conv_nodes:
            if not np.array_equal(gconv[attr], conv[attr]):
                log.debug(
                    'Grouped convolutions fusion : attrs {} doesn\'t match'.
                    format(attr))
                return False

    # Either all convolutions have biases (a third input) or none of them does. Every convolution is
    # inspected, so a mix where the first one has no bias is rejected as well.
    bias_flags = [len(conv.in_nodes()) == 3 for conv in conv_nodes]
    has_biases = all(bias_flags)
    if any(bias_flags) and not has_biases:
        log.debug('Grouped convolutions fusion : only some convolutions have biases')
        return False

    # Check that all biases have same shape (np.array_equal: `!=` on shape arrays yields an array whose
    # truth value is ambiguous when it has more than one element)
    if has_biases:
        for conv in conv_nodes:
            if not np.array_equal(conv.in_node(2).shape, gconv.in_node(2).shape):
                log.debug(
                    'Group convolutions fusion : convolutions have different biases shape {} and {}'
                    .format(conv.in_node(2).shape,
                            gconv.in_node(2).shape))
                return False

    graph.remove_edge(gconv.in_node(0).id, gconv.id)
    graph.remove_edge(gconv.id, gconv.out_node().id)

    input_data = start_node.in_node(0)
    output_data = last_node.out_node()

    # Removing edges from data nodes to Split and Concat
    graph.remove_edge(input_data.id, start_node.id)
    graph.remove_edge(last_node.id, output_data.id)

    # Add edges to grouped convolution
    graph.add_edges_from([(input_data.id, gconv.id, {
        'in': 0
    }), (gconv.id, output_data.id, {
        'out': 0
    })])

    # Concatenation of convolutions
    weights_node = gconv.in_node(1)
    bias_node = gconv.in_node(2) if has_biases else None

    weights_value = mo_array(weights_node.value)
    bias_value = mo_array(bias_node.value) if has_biases else None

    # gconv.get_weights_permute.perm contains permutation indices
    # where feature dimension is set to zero position, so 0 value
    # in gconv.get_weights_permute.inv indicates original feature dimension index
    feature_dim = np.where(gconv.get_weights_permute.inv == 0)[0][0]

    for conv in conv_nodes[1:]:
        weights_value = np.concatenate((weights_value, conv.in_node(1).value),
                                       axis=feature_dim)
        if has_biases:
            bias_value = np.concatenate((bias_value, conv.in_node(2).value),
                                        axis=-1)  # Not validated

    weights_node.value = mo_array(weights_value)
    weights_node.shape = mo_array(weights_value.shape)

    if has_biases:
        bias_node.value = mo_array(bias_value)
        bias_node.shape = mo_array(bias_value.shape)

    log.debug('Start node : {} Last node : {}  Nodes inside : {}'.format(
        start_node.id, last_node.id, len(start_node.out_nodes())))
    log.debug('Output shape : {}'.format(weights_value.shape))

    gconv.group = len(conv_nodes)
    gconv.output = weights_node.shape[feature_dim]
    gconv.output_shape[feature_dim] = weights_node.shape[feature_dim]

    return True
Exemplo n.º 27
0
 def reverse_infer(node: Node):
     """If input 0 has no shape yet, give it an undefined shape whose rank is the length of `window`."""
     if node.in_port(0).data.get_shape() is not None:
         return
     window = node.soft_get('window', None)
     if window is None:
         return
     node.in_port(0).data.set_shape(undefined_shape_of_rank(len(window)))
Exemplo n.º 28
0
 def infer(node: Node):
     """Check that the `group` attribute is set, then pass the input 0 shape through to output 0."""
     name = node.soft_get('name', node.id)
     group = node.soft_get('group')
     assert group is not None, 'The attribute "group" must be set for node {}'.format(name)
     in_shape = node.in_port(0).data.get_shape()
     node.out_port(0).data.set_shape(in_shape)
Exemplo n.º 29
0
    def calculate_frame_time(graph: Graph):
        """
        Compute `frame_time` (delay in frames) for operations reachable from the main network input.

        Nodes are visited in BFS order starting from the input returned by `check_inputs`. Only operation
        nodes whose `frame_time` is still negative are updated:
          * Splice: delay of its input plus ``len(context) - 1``;
          * Crop fed by a Splice: delay of the Splice input plus ``context[offset // dim] - context[0]``;
            any other Crop inherits the delay of its input;
          * ShapeOf: gets the special value -2, which marks a shape-computation sub-path;
          * Broadcast: ends the shape sub-path and inherits the delay of its input 0;
          * any other node: the maximum non-negative delay among its connected inputs (-1 if there is none,
            0 for nodes without input ports). It gets -2 when the smallest input delay is -2.
        Splice and Crop nodes are skipped while the delay of their input is still -1.
        """
        # there are either one or two inputs in Kaldi. Only main input can change delay in network.
        # Usually ivector input has name 'ivector'.
        # special frame_time value marking nodes on a shape-computation sub-path (ShapeOf ... Broadcast)
        max_frame_time = -2
        inp = check_inputs(graph)
        inp_name = inp.soft_get('name', inp.id)

        # sort nodes to calculate delays
        nodes = list(bfs_search(graph, [inp_name]))

        for n in nodes:
            node = Node(graph, n)

            # just ignore data nodes
            if node.kind != 'op':
                continue

            # calculate frame_time (delay) that was not calculated
            if node.frame_time < 0:
                # Splice increases frame delay
                if node.op == "Splice":
                    if node.in_port(0).get_source().node.frame_time == -1:
                        continue
                    node.frame_time = node.in_port(
                        0).get_source().node.frame_time + len(node.context) - 1
                # crop often used to get concrete time frame, set frame_time correctly for this case
                elif node.op == 'Crop':
                    if node.in_port(0).get_source().node.frame_time == -1:
                        continue
                    if node.in_port(0).get_connection().get_source(
                    ).node.op == 'Splice':
                        splice_node = node.in_port(0).get_source().node
                        assert len(node.offset) == 1
                        assert len(node.dim) == 1
                        # delay of the context element selected by the crop offset
                        new_delay = splice_node.context[
                            node.offset[0] //
                            node.dim[0]] - splice_node.context[0]
                        node.frame_time = splice_node.in_port(
                            0).get_source().node.frame_time + new_delay
                    else:
                        node.frame_time = node.in_port(
                            0).get_source().node.frame_time
                elif node.op == 'ShapeOf':
                    # exclude shape path from time delay calculation using special value
                    node.frame_time = max_frame_time
                elif node.op == 'Broadcast':
                    # finished shape path
                    node.frame_time = node.in_port(
                        0).get_source().node.frame_time
                # for node with several inputs frame_time = maximum of delays from branches
                else:
                    # find out maximum of delay and check that we have at least one branch with another delay
                    node.frame_time = -1 if len(node.in_ports()) != 0 else 0
                    min_in_frame_time = -1
                    for port_idx in node.in_ports():
                        if node.in_port(port_idx).disconnected():
                            continue
                        in_node = node.in_port(port_idx).get_source().node
                        if in_node.frame_time < min_in_frame_time:
                            min_in_frame_time = in_node.frame_time
                        if in_node.frame_time > node.frame_time and in_node.frame_time != -1:
                            node.frame_time = in_node.frame_time
                    # if the smallest input delay is the special value, some input is on the shape path,
                    # so this node is treated as being on the shape path too
                    # NOTE(review): this triggers when ANY input is on the shape path, not only when all are
                    if min_in_frame_time == max_frame_time:
                        node.frame_time = max_frame_time
Exemplo n.º 30
0
def eltwise_reverse_infer(node: Node):
    """Deduce undefined input shapes of a binary elementwise node from its output shape.

    Nothing is done when both input shapes (ports 0 and 1) are already known, or when
    the output shape is unknown. Otherwise the behaviour depends on
    ``node['auto_broadcast']``:

    * ``'none'``: inputs and output must have identical shapes. Each undefined input
      receives the most defined shape obtainable by combining the output shape with
      whichever input shape is known.
    * ``'numpy'``: handled only when exactly one input shape is known and its rank is
      lower than the output rank. The undefined input then must have the output rank.
      - Its leading dimensions (those not covered by the known input) are copied from
        the output.
      - A trailing dimension is copied from the output only where broadcasting forces
        the value: the known input has a 1 there, or the output has a 1 there.
      - Other trailing dimensions stay undefined.

    Any other ``auto_broadcast`` value leaves the node unchanged.

    :param node: elementwise node with input ports 0 and 1 and output port 0
    :raises Error: ('none' mode) a known input shape is incompatible with the output shape
    :raises AssertionError: ('numpy' mode) a trailing dimension of the known input is
        neither 1, dynamic, nor compatible with the corresponding output dimension
    """
    input_1_shape = node.in_port(0).data.get_shape()
    input_2_shape = node.in_port(1).data.get_shape()
    if input_1_shape is not None and input_2_shape is not None:
        return

    output_shape = node.out_port(0).data.get_shape()
    node_name = node.soft_get('name', node.id)

    # use equality, not identity: `is` on string literals depends on interning
    if node['auto_broadcast'] == 'none':
        # input_1, input_2 and output shapes must match
        # therefore undefined partial shapes can be exactly defined from output shape
        if output_shape is not None:
            most_defined_shape = output_shape

            # if out_shape = [4, dyn] and input_1_shape = [dyn, 13]
            # then missing shape must be [4, 13]
            if input_1_shape is not None and not compatible_shapes(
                    output_shape, input_1_shape):
                raise Error("shapes are not compatible for node '{}'".format(
                    node_name))
            elif input_1_shape is not None:
                most_defined_shape = find_common_partial_shape(
                    output_shape, input_1_shape)

            if input_2_shape is not None and not compatible_shapes(
                    output_shape, input_2_shape):
                raise Error("shapes are not compatible for node '{}'".format(
                    node_name))
            elif input_2_shape is not None:
                most_defined_shape = find_common_partial_shape(
                    most_defined_shape, input_2_shape)

            if input_1_shape is None:
                node.in_port(0).data.set_shape(most_defined_shape)
            if input_2_shape is None:
                node.in_port(1).data.set_shape(most_defined_shape)
    elif node['auto_broadcast'] == 'numpy':
        if output_shape is not None:
            out_rank = len(output_shape)
            deduced_in_shape = undefined_shape_of_rank(out_rank)

            if input_1_shape is not None and input_2_shape is None and out_rank > len(
                    input_1_shape):
                in_port_to_update = 1
                defined_in_shape = input_1_shape
            elif input_2_shape is not None and input_1_shape is None and out_rank > len(
                    input_2_shape):
                in_port_to_update = 0
                defined_in_shape = input_2_shape
            else:
                return
            defined_in_rank = len(defined_in_shape)

            # walk the trailing dimensions shared by the known input and the output
            for i in range(-1, -defined_in_rank - 1, -1):
                assert defined_in_shape[i] == 1 or np.ma.is_masked(defined_in_shape[i]) \
                       or compatible_dims(defined_in_shape[i], output_shape[i]), \
                    "Shapes of Elementwise node '{}' are not compatible for reverse_infer.".format(node_name)

                # if defined_input_shape = [1] and output_shape = [N, 400, 400, 3]
                # partial shape information about sizes should not be lost
                if defined_in_shape[i] == 1 or output_shape[i] == 1:
                    deduced_in_shape[i] = output_shape[i]

            # Leading dimensions exist only in the undefined input, so they equal the output's.
            # Slice with a positive count: `[:-defined_in_rank]` would be the empty slice `[:0]`
            # for a scalar (rank-0) known input and leave every dimension undefined.
            num_leading_dims = out_rank - defined_in_rank
            deduced_in_shape[:num_leading_dims] = output_shape[:num_leading_dims]

            node.in_port(in_port_to_update).data.set_shape(deduced_in_shape)