Example #1
    def infer(node):
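        # Shape inference for an ROI pooling-like operation: inputs are the feature
        # map, the ROI boxes and their batch indices; the output shape is
        # [num_rois, C, pooled_h, pooled_w] arranged according to the graph layout.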
        layout = node.graph.graph['layout']
        node_name = node.soft_get('name', node.id)

        assert len([port for port in node.in_ports().values() if not port.disconnected()]) == 3, \
            'The node "{}" must 3 inputs'.format(node_name)

        assert node.has_valid('pooled_w'), '"pooled_w" attribute is not set for node "{}"'.format(node_name)
        assert node.has_valid('pooled_h'), '"pooled_h" attribute is not set for node "{}"'.format(node_name)
        assert node.has_valid('mode'), '"mode" attribute is not set for node "{}"'.format(node_name)
        assert node.mode in ['avg', 'max'], \
            '"mode" attribute range of values is ["avg", "max"], got {} for node "{}"'.format(node.mode, node_name)

        input_shape = node.in_port(0).data.get_shape()
        rois_shape = node.in_port(1).data.get_shape()
        indices_shape = node.in_port(2).data.get_shape()
        assert input_shape is not None and rois_shape is not None and indices_shape is not None, \
            'The node "{}" input shape is None'.format(node_name)
        assert compatible_dims(rois_shape[0], indices_shape[0]), 'The number of batch indices does not correspond ' \
                                                                 'to number of ROIs for node "{}"'.format(node_name)
        assert compatible_dims(rois_shape[1], 4), 'The size of ROI element must be 4 for node "{}"'.format(node_name)
        assert len(input_shape) == 4, 'The rank of port 0 input tensor of node "{}" must be 4.'.format(node_name)

        node.out_port(0).data.set_shape(
            shape_for_layout(layout,
                             batch=rois_shape[0],
                             features=input_shape[get_features_dim(layout, 4)],
                             height=node.pooled_h,
                             width=node.pooled_w)
        )
Example #2
    def infer(node: Node):
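        # Shape inference for an LSTM-like recurrent cell: input 1 carries the hidden
        # state and input 2 the cell state; output 0 mirrors the hidden shape and,
        # when present, output 1 mirrors the cell shape.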
        if node.has_and_set('extra_inputs'):
            assert len(node.in_nodes()) == 8
        else:
            assert len(node.in_nodes()) == 5
        assert len(node.out_nodes()) in [1, 2]

        hidden_shape = node.in_node(1).shape.copy()
        cell_shape = node.in_node(2).shape.copy()

        mark_input_bins(node, start_port=3)
        node.out_node(0).shape = hidden_shape
        if len(node.out_nodes()) == 2:
            node.out_node(1).shape = cell_shape

        hidden_size = hidden_shape[1]

        if node.has_valid('hidden_size'):
            if node.hidden_size != hidden_size:
                raise Error(
                    "Input shape {} for hidden size doesn't match pre-defined hidden_size in node {}"
                    .format(node.in_node(1).shape, node.soft_get('name')))
        else:
            node['hidden_size'] = hidden_size

        assert cell_shape[1] == hidden_size

        input_shape = node.in_node(0).shape
        assert input_shape is not None
        assert compatible_dims(hidden_shape[0], cell_shape[0]) and \
               compatible_dims(cell_shape[0], input_shape[0]), 'States are not broadcastable by batch for node {}' \
                                                               ''.format(node.soft_get('name', node.id))
Example #3
    def infer(node: Node):
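        # Shape inference for a CTC greedy decoder-like operation: logits arrive as
        # [time, batch, classes] and the decoded output is shaped [batch, time, 1, 1].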
        node_name = node.soft_get('name', node.id)
        connected_in_ports = [
            port for port in node.in_ports().values()
            if not port.disconnected()
        ]
        assert len(connected_in_ports) == 2, \
            "Incorrect number of inputs for {} node".format(node_name)

        logits_shape = node.in_port(0).data.get_shape()
        sequence_mask_shape = node.in_port(1).data.get_shape()

        # check shapes of input tensors
        assert len(logits_shape) == 3, \
            'Incorrect rank of logits for {} node'.format(node_name)
        assert len(sequence_mask_shape) == 2, \
            'Incorrect rank of sequence length tensor for {} node'.format(node_name)
        assert compatible_dims(logits_shape[1], sequence_mask_shape[1]), \
            'Batch dimensions of input tensors must be the same for {} node'.format(node_name)
        assert compatible_dims(logits_shape[0], sequence_mask_shape[0]), \
            'Time dimensions of input tensors must be the same for {} node'.format(node_name)

        batch_size = logits_shape[1]
        time_size = logits_shape[0]
        node.out_port(0).data.set_shape([batch_size, time_size, 1, 1])
Example #4
    def infer(node: Node):
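        # Shape inference for a CTC loss-like operation: logits are a 3D tensor with
        # batch in dimension 0, and the loss output is a per-batch vector [batch].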
        node_name = node.soft_get('name', node.id)
        connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
        assert len(connected_in_ports) in [4, 5], \
            "Incorrect number of inputs for {} node".format(node_name)

        logits_shape = node.in_port(0).data.get_shape()
        logit_length_shape = node.in_port(1).data.get_shape()
        labels_shape = node.in_port(2).data.get_shape()
        label_length_shape = node.in_port(3).data.get_shape()
        blank_index_shape = int64_array([])
        if len(node.in_nodes()) == 5:
            blank_index_shape = node.in_port(4).data.get_shape()

        # check shapes of input tensors
        assert len(logits_shape) == 3 and len(logit_length_shape) == 1 and len(labels_shape) == 2\
            and len(label_length_shape) == 1 and len(blank_index_shape) == 0, \
            'Incorrect rank of some input tensor for {} node'.format(node_name)
        assert compatible_dims(logits_shape[0], logit_length_shape[0]) and \
               compatible_dims(logits_shape[0], labels_shape[0]) and \
               compatible_dims(logits_shape[0], label_length_shape[0]), \
            'Batch dimensions of input tensors must be the same for {} node'.format(node_name)
        assert compatible_dims(logits_shape[1], labels_shape[1]), \
            'Time dimensions of input tensors must be the same for {} node'.format(node_name)

        batch_size = logits_shape[0]
        node.out_port(0).data.set_shape([batch_size])
Example #5
    def infer(node: Node):
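        # DetectionOutput shape inference: prior boxes are packed in groups of 4
        # coordinates (5 when not normalized) and the output is the fixed
        # [1, 1, conf_batch * keep_top_k, 7] detection tensor.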
        node_name = node.soft_get('name', node.id)
        loc_shape = node.in_port(0).data.get_shape()
        conf_shape = node.in_port(1).data.get_shape()
        prior_boxes_shape = node.in_port(2).data.get_shape()

        if loc_shape is None or conf_shape is None or prior_boxes_shape is None:
            raise Error(
                'Shapes for the Detection Output node "{}" are not defined'.
                format(node_name))

        prior_size = 4
        if node.has('normalized') and not node.normalized:
            prior_size = 5

        if is_fully_defined(prior_boxes_shape[-1]) and prior_boxes_shape[-1] % prior_size != 0:
            raise Error(
                'Amount of prior box coordinates "{}" is not divisible by {} for node "{}"'
                ''.format(prior_boxes_shape[-1], prior_size, node_name))

        num_priors = prior_boxes_shape[-1] // prior_size
        if not node.has_valid('keep_top_k') or node.keep_top_k == -1:
            node['keep_top_k'] = num_priors

        num_classes = conf_shape[-1] // num_priors
        num_loc_classes = num_classes
        if node.has_and_set('share_location') and node.share_location:
            num_loc_classes = 1

        if not compatible_dims(num_priors * num_loc_classes * 4,
                               loc_shape[-1]):
            raise Error(
                'Locations and prior boxes shapes mismatch: "{}" vs "{}" for node "{}"'
                ''.format(loc_shape, prior_boxes_shape, node_name))

        if not node.variance_encoded_in_target and not compatible_dims(
                prior_boxes_shape[-2], 2):
            raise Error(
                'The "-2" dimension of the prior boxes must be 2 but it is "{}" for node "{}".'
                ''.format(prior_boxes_shape[-2], node_name))

        if is_fully_defined(conf_shape[-1]) and is_fully_defined(num_priors) \
                and conf_shape[-1] % num_priors != 0:
            raise Error(
                'Amount of confidences "{}" is not divisible by amount of priors "{}" for node "{}".'
                ''.format(conf_shape[-1], num_priors, node_name))

        node.out_port(0).data.set_shape(
            [1, 1, conf_shape[0] * node.keep_top_k, 7])

        # the line below is needed for the TF framework so the MO will not change the layout
        node.graph.node[node.out_node(0).id]['nchw_layout'] = True
Example #6
    def infer(node: Node):
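        # Shape inference for a CTC greedy decoder-like operation with sequence lengths:
        # output 0 is shaped [batch, time] and output 1, when connected, is a per-batch
        # vector [batch].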
        node_name = node.soft_get('name', node.id)
        connected_in_ports = [
            port for port in node.in_ports().values()
            if not port.disconnected()
        ]
        assert len(connected_in_ports) in [2, 3], \
            "Incorrect number of inputs for {} node".format(node_name)

        logits_shape = node.in_port(0).data.get_shape()
        sequence_len_shape = node.in_port(1).data.get_shape()
        if len(node.in_nodes()) == 3:
            blank_index_shape = node.in_port(2).data.get_shape()
            assert len(blank_index_shape) == 1, \
                'Incorrect rank of blank_index for {} node'.format(node_name)

        # check shapes of input tensors
        assert len(logits_shape) == 3, \
            'Incorrect rank of logits for {} node'.format(node_name)

        assert len(sequence_len_shape) == 1, \
            'Incorrect rank of sequence length tensor for {} node'.format(node_name)
        assert compatible_dims(logits_shape[0], sequence_len_shape[0]), \
            'Batch dimensions of input tensors must be the same for {} node'.format(node_name)

        batch_size = logits_shape[0]
        time_size = logits_shape[1]
        if node.is_out_port_connected(0):
            node.out_port(0).data.set_shape([batch_size, time_size])
        if node.is_out_port_connected(1):
            node.out_port(1).data.set_shape([batch_size])
Example #7
    def infer(node: Node):
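        # FullyConnected shape inference: weights on port 1 are 2D with `out-size` as
        # their first dimension, port 2 optionally carries a constant bias, and the
        # output replaces the last input dimension with `out-size`.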
        name = node.soft_get('name', node.id)

        connected_in_ports = {
            idx: port
            for idx, port in node.in_ports().items()
            if not port.disconnected()
        }
        assert len(connected_in_ports) >= 2 and 0 in connected_in_ports and 1 in connected_in_ports, \
            'FullyConnected should have 2 connected input ports, but it doesn\'t for node: `{}`. Ports: {}' \
            ''.format(name, connected_in_ports)

        assert node.has_valid('out-size')
        input_shape = node.in_port(0).data.get_shape()
        weights_shape = node.in_port(1).data.get_shape()
        assert input_shape is not None and weights_shape is not None, \
            'Incorrect FullyConnected input shapes. Node: {}. Shapes: {}'.format(name, [input_shape, weights_shape])
        assert weights_shape.size == 2
        out_size = node.soft_get('out-size')
        assert compatible_dims(weights_shape[0], out_size), \
            'weights_shape={}, out-size={}'.format(weights_shape, out_size)

        if 2 in connected_in_ports:
            bias_value = node.in_port(2).data.get_value()
            bias_shape = node.in_port(2).data.get_shape()
            assert bias_shape is not None, 'Shape was not inferred for biases of FullyConnected {}'.format(
                name)
            assert bias_value is not None, 'Value was not inferred for biases of FullyConnected {}'.format(
                name)
            assert compatible_shapes(bias_shape, [out_size]) or compatible_shapes(bias_shape, [1, out_size]), \
                'Incorrect FullyConnected bias shape `{}` for node {}. `out-size`={}'.format(bias_shape, node, out_size)

        node.out_port(0).data.set_shape([*input_shape[:-1], out_size])
Example #8
    def insert_pre_processing(graph: Graph, input_node: Node, node_mean_scale_values: np.array,
                              preprocessing_name: str):
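        # Inserts a Mul (for 'scale', with value 1/scale) or an Add (for 'mean', with
        # value -mean) right after the input node and rewires its consumers; the
        # insertion is skipped when the requested values would be a no-op.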
        assert preprocessing_name in ['scale', 'mean']
        if node_mean_scale_values.get(preprocessing_name) is None:
            return
        user_value = node_mean_scale_values[preprocessing_name]
        value = 1 / user_value if preprocessing_name == 'scale' else user_value * (-1)
        optimize_value = int(preprocessing_name == 'scale')
        op = Mul if preprocessing_name == 'scale' else Add

        if all([x == optimize_value for x in value]):
            return
        assert input_node.has_valid('shape')
        features_dim_idx = get_features_dim(graph.graph['layout'], len(input_node.shape))
        assert compatible_dims(value.size, input_node.shape[features_dim_idx]) or value.size == 1

        shape = np.ones(len(input_node.shape), dtype=np.int64)
        shape[features_dim_idx] = value.size
        value = value.reshape(shape)

        name = input_node.soft_get('name', input_node.id) + '/' + preprocessing_name
        preprocessing = create_op_with_const_inputs(graph, op=op, port_value_dict={1: value}, op_attrs={'name': name})

        for dst in input_node.out_port(0).get_destinations():
            if dst.node.soft_get('type') != 'ShapeOf':
                # After the insertion of additional operations model optimizer
                # should keep the link to the input layer. Parameter node in framework
                # should map to parameter node in IR.
                # For this reason 'fw_tensor_debug_info' should be kept in data node.
                dst.get_connection().set_source(preprocessing.out_port(0), "source")

        input_node.out_port(0).connect(preprocessing.in_port(0))
Example #9
    def insert_pre_processing(graph: Graph, input_node: Node,
                              node_mean_scale_values: np.array,
                              preprocessing_name: str):
        assert preprocessing_name in ['scale', 'mean']
        if node_mean_scale_values.get(preprocessing_name) is None:
            return
        user_value = node_mean_scale_values[preprocessing_name]
        value = 1 / user_value if preprocessing_name == 'scale' else user_value * (-1)
        optimize_value = int(preprocessing_name == 'scale')
        op = Mul if preprocessing_name == 'scale' else Add

        if all([x == optimize_value for x in value]):
            return
        assert input_node.has_valid('shape')
        features_dim_idx = get_features_dim(graph.graph['layout'],
                                            len(input_node.shape))
        assert compatible_dims(
            value.size, input_node.shape[features_dim_idx]) or value.size == 1

        shape = np.ones(len(input_node.shape), dtype=np.int64)
        shape[features_dim_idx] = value.size
        value = value.reshape(shape)

        name = input_node.soft_get('name',
                                   input_node.id) + '/' + preprocessing_name
        preprocessing = create_op_with_const_inputs(graph,
                                                    op=op,
                                                    port_value_dict={1: value},
                                                    op_attrs={'name': name})

        if input_node.is_out_port_connected(0) and len(
                input_node.out_port(0).get_destinations()) == 1:
            # There are models with pattern Parameter(uint8) -> Convert(float).
            # Adding mean/scale leads to the following:
            # Parameter(uint8) -> Mean/Scale -> Convert(float) which is incorrect.
            # To fix this mean and scale preprocessing node is inserted after Convert(float) node.
            out_node = input_node.out_port(0).get_destination().node
            convert_type = out_node.soft_get('dst_type')
            if out_node.soft_get('type') == "Convert" and convert_type in [np.float32, np.float16]:
                input_node = out_node
                if convert_type != value.dtype:
                    new_value = value.astype(convert_type)
                    const_node = preprocessing.in_port(1).get_connection().get_source().node
                    const_node['value'] = new_value

        for dst in input_node.out_port(0).get_destinations():
            if dst.node.soft_get('type') != 'ShapeOf':
                # After the insertion of additional operations model optimizer
                # should keep the link to the input layer. Parameter node in framework
                # should map to parameter node in IR.
                # For this reason 'fw_tensor_debug_info' should be kept in data node.
                dst.get_connection().set_source(preprocessing.out_port(0),
                                                "source")

        input_node.out_port(0).connect(preprocessing.in_port(0))
Example #10
    def infer(node: Node):
        """
        Performs shape inference of MatMul node as operation doc-string says
        Raises on any shape inconsistency
        """
        name = node.soft_get('name', str(node.id))
        connected_in_ports = {
            idx: port
            for idx, port in node.in_ports().items()
            if not port.disconnected()
        }
        assert len(connected_in_ports) == 2 and 0 in connected_in_ports and 1 in connected_in_ports, \
            "MatMul should have 2 connected input ports, but it doesn't for node: `{}`. Ports: {}" \
            "".format(name, connected_in_ports)

        log.debug('MatMul `{}` input shapes: {}'.format(
            name, [node.in_port(i).data.get_shape() for i in range(2)]))
        A_shape, B_shape = MatMul.shape_alignment(node)
        log.debug('MatMul `{}` aligned input shapes: {}'.format(
            name, [A_shape, B_shape]))

        assert compatible_dims(A_shape[-1], B_shape[-2]), \
            "MatMul input shapes are incorrect. COL_INDEX_DIMs are not equal. Node: {}. Shapes: {}" \
            "".format(name, [A_shape, B_shape])

        output_shape = np.ma.concatenate((A_shape[:-1], B_shape[-1:]))

        if node.in_port(0).data.get_shape().size == 1:
            assert compatible_dims(output_shape[-2], 1)
            output_shape = shape_delete(output_shape, -2)
        if node.in_port(1).data.get_shape().size == 1:
            assert compatible_dims(output_shape[-1], 1)
            output_shape = shape_delete(output_shape, -1)

        node.out_port(0).data.set_shape(output_shape)

        in_ch = 0 if not node.transpose_b else 1
        out_ch = 1 if not node.transpose_b else 0
        assign_dims_to_weights(node.in_node(1), None, in_ch, out_ch,
                               node.in_port(1).data.get_shape().size)
        MatMul.value_propagation(node)
Example #11
    def repack_weights(graph: Graph, match: dict):
        """
        Repack weights into general format (described above) and reorder gates.
        """
        rnn_layer = match['rnn_layer']
        W = match['W'].value.copy()
        R = match['R'].value.copy()
        num_directions = 2 if rnn_layer.direction == 'bidirectional' else 1

        graph.remove_edge(match['W'].id, rnn_layer.id)
        graph.remove_edge(match['R'].id, rnn_layer.id)

        # find optional 'B' biases blob
        if 3 in rnn_layer.in_nodes():
            # TODO: check if 'bin': 'B' attribute is assigned to this edge
            B = rnn_layer.in_node(3).value.copy()
            graph.remove_edge(rnn_layer.in_node(3).id, rnn_layer.id)
        else:
            B_shape = [num_directions, 2 * rnn_layer.multiplier * rnn_layer.hidden_size]  # from ONNX spec
            B = np.full(B_shape, 0, dtype=np.float32)

        # Add extra dimensions for W, R and B for easier repacking and reordering
        B = B.reshape([
            num_directions,  # 0: num of directions
            rnn_layer.num_layers,  # 1: num_layers
            2,  # 2: two input parts of the matrix: W, R
            rnn_layer.multiplier,  # 3: four output parts of the matrix for all gates in order: i, o, f, c
            rnn_layer.hidden_size,  # 4: output size per direction and gate
        ])

        W, R = [x.reshape([
                num_directions,  # 0: num of directions
                rnn_layer.num_layers,  # 1: num_layers
                rnn_layer.multiplier,  # 2: four output parts of the matrix for all gates in order: i, o, f, c
                rnn_layer.hidden_size,  # 3: output size per direction and gate
                -1])  # 4: input size/hidden size in W/R correspondingly
                for x in (W, R)]

        input_size = match['input'].shape[2]
        assert compatible_dims(input_size, W.shape[-1])

        # Reorder gates: iofc --> fico
        gate_reorder = rnn_layer.gate_order
        W, R = (np.take(x, gate_reorder, axis=2) for x in (W, R))
        B = np.take(B, gate_reorder, axis=3)

        for blob, port in [(W, 1), (R, 2), (B, 3)]:
            Op.create_and_connect_input_data_node(
                graph,
                rnn_layer,
                {'value': blob, 'shape': int64_array(blob.shape)},
                {'in': port, 'permutation': None}
            )
Example #12
    def get_suitable_channel_index(node: Node, shape):
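        # Returns the index of the channel dimension for a 4D shape when that dimension
        # holds 3 channels (guessing NCHW and correcting it from rt_info / layout hints);
        # returns None otherwise.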
        if len(shape) != 4:
            return None

        guessed_layout = 'NCHW'
        if node.has_valid('rt_info'):
            rt_info = node.rt_info
            if rt_info.contains('old_api_map_order'):
                old_api_map_version = rt_info.get_attribute_version('old_api_map_order')
                old_api_map = rt_info.info['old_api_map_order', old_api_map_version]
                if 'inverse_order' in old_api_map.info:
                    order = old_api_map.info['inverse_order']
                    assert len(order) == len(guessed_layout)
                    guessed_layout = np.array(list(guessed_layout))[order]
                    guessed_layout = ''.join(guessed_layout)
        idx, has_layout = get_dim_from_layout(node, 'C')
        if not has_layout:
            idx = get_features_dim(guessed_layout, len(node.shape))
        if compatible_dims(shape[idx], 3):
            return idx
        else:
            return None
Example #13
    def replace_pattern(graph, match: dict):
        # Check for SS params
        # Sanity check that we iterate over axis of some tensor
        ss = match['Strided_slice']
        params = ss.in_nodes()
        assert np.all(params[1].in_node().value == 0)
        assert np.all(params[2].in_node().value == 1)
        assert np.all(params[3].in_node().value == 1)

        # Check for comparing SS and seq_length source (it should be one tensor)
        # SIMPLE CHECK
        assert match['Strided_slice_data'].value is not None
        if match['minimum_data'].value is None:
            log.warning(
                'TF loop doesn\'t have a constant upper bound produced by node {}, or ModelOptimizer '
                'cannot detect a constant in this case. Loops with a dynamic number of iterations are not '
                'supported, so in the resulting IR, generated TensorIterator will have '
                'a maximum number of iterations determined by input tensor size: {}'
                ''.format(match['minimum_data'].soft_get('name'),
                          match['Strided_slice_data'].value))
        else:
            assert compatible_dims(match['Strided_slice_data'].value, match['minimum_data'].value), \
                'Values do not match: {} and {}'.format(match['Strided_slice_data'].value, match['minimum_data'].value)

        # Check that bound for Condition and Inputs/Outputs sizes match
        condition_time = match['condition'].out_node(0)
        inputs_and_outputs = condition_time.out_nodes()
        type_list = ['TensorIteratorInput']

        for ta in inputs_and_outputs:
            if ta.has_valid('kind') and ta['kind'] == 'op' and ta['op'] in type_list:
                assert ta.in_node(0).id == ss.id

        log.debug(
            '+++++++++++++++ Condition Check was successful ++++++++++++++++')
Example #14
    def infer(node: Node):
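        # GatherND shape inference: the output shape is the batch dimensions, followed
        # by indices.shape[batch_dims:-1], followed by the data dimensions not covered
        # by the index tuples; values are also gathered when both inputs are known.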
        node_name = node.soft_get('name', node.id)
        connected_in_ports = [
            port for port in node.in_ports().values()
            if not port.disconnected()
        ]
        assert len(connected_in_ports) == 2, \
            "Incorrect number of inputs for {} node".format(node_name)

        data_shape = node.in_port(0).data.get_shape()
        data_value = node.in_port(0).data.get_value()
        indices_shape = node.in_port(1).data.get_shape()
        indices_value = node.in_port(1).data.get_value()

        assert node.has_valid('batch_dims'), \
            "Node {} must contain `batch_dims` attribute".format(node_name)
        batch_dims = node.batch_dims

        # check that the number of batch dimensions is less than the rank of both data and indices tensors
        assert batch_dims < len(data_shape), \
            "Number of batch dimensions must be less than the rank of the data tensor"
        assert batch_dims < len(indices_shape), \
            "Number of batch dimensions must be less than the rank of the indices tensor"

        # check that batch dimensions of data and indices are the same
        for batch_dim in range(batch_dims):
            assert compatible_dims(data_shape[batch_dim], indices_shape[batch_dim]), \
                "The dimension {} for data and indices tensors must be the same".format(batch_dim)

        # check ranks of input tensors
        assert len(data_shape) > 0, "Data must not be a scalar"
        assert len(indices_shape) > 0, "Indices must not be a scalar"
        assert (batch_dims + indices_shape[-1]) <= len(data_shape), \
            "Length of a tuple with indices must not exceed the rank of the data tensor excluding batch dimensions"
        assert node['version'] in ['opset5', 'opset8'], 'Unsupported version of GatherND operation: {}, operation ' \
                                                        'name : {}'.format(node['version'], node.soft_get('name'))

        # compute output shape
        batch = []
        if batch_dims > 0:
            if node['version'] == 'opset5':  # Support old version of gatherND shape inference
                if is_fully_defined(data_shape[:batch_dims]):
                    batch = [np.prod(data_shape[:batch_dims]).tolist()]
                else:
                    batch = [dynamic_dimension_value]
            elif node['version'] == 'opset8':
                for dim in range(batch_dims):
                    assert compatible_dims(indices_shape[dim], data_shape[dim]),\
                        "Batch dimensions in data.shape and indices.shape must be compatible"
                if is_fully_defined(indices_shape[:batch_dims]):
                    batch = indices_shape[:batch_dims].tolist()
                elif is_fully_defined(data_shape[:batch_dims]):
                    batch = data_shape[:batch_dims].tolist()
                else:
                    for ind in range(batch_dims):
                        if indices_shape[ind] != dynamic_dimension_value:
                            batch.append(indices_shape[ind])
                        elif data_shape[ind] != dynamic_dimension_value:
                            batch.append(data_shape[ind])
                        else:
                            batch.append(dynamic_dimension_value)

        slice_shape = list(data_shape[(batch_dims + indices_shape[-1]):])

        output_shape = batch + list(indices_shape)[batch_dims:-1] + slice_shape
        node.out_port(0).data.set_shape(output_shape)

        # compute output value if all input indices are defined
        if is_fully_defined(indices_value) and data_value is not None:
            batch_dims_size = 1

            for i in range(batch_dims):
                batch_dims_size *= indices_shape[i]

            output_data = []

            reshaped_indices = indices_value.reshape(batch_dims_size, -1,
                                                     indices_shape[-1])

            reshaped_data = data_value.reshape((batch_dims_size, ) + tuple(
                (data_shape[batch_dims:])))

            for batch_dim in range(reshaped_indices.shape[0]):
                for outer_dim in range(reshaped_indices.shape[1]):
                    gather_index = tuple(
                        reshaped_indices[batch_dim][outer_dim])
                    output_data.append(reshaped_data[(batch_dim, ) +
                                                     gather_index])
            output_value = np.asarray(
                output_data, dtype=data_value.dtype).reshape(output_shape)
            node.out_port(0).data.set_value(output_value)
Example #15
def eltwise_reverse_infer(node: Node):
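    # Reverse shape inference for an elementwise operation: when the output shape is
    # known, propagate it back to whichever input still has an undefined shape,
    # honoring the auto_broadcast mode.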
    input_1_shape = node.in_port(0).data.get_shape()
    input_2_shape = node.in_port(1).data.get_shape()
    if input_1_shape is not None and input_2_shape is not None:
        return

    output_shape = node.out_port(0).data.get_shape()
    node_name = node.soft_get('name', node.id)

    if node['auto_broadcast'] == 'none':
        # input_1, input_2 and output shapes must match
        # therefore undefined partial shapes can be exactly defined from output shape
        if output_shape is not None:
            most_defined_shape = output_shape

            # if out_shape = [4, dyn] and input_1_shape = [dyn, 13]
            # then missing shape must be [4, 13]
            if input_1_shape is not None and not compatible_shapes(
                    output_shape, input_1_shape):
                raise Error("shapes are not compatible for node '{}'".format(
                    node_name))
            elif input_1_shape is not None:
                most_defined_shape = find_common_partial_shape(
                    output_shape, input_1_shape)

            if input_2_shape is not None and not compatible_shapes(
                    output_shape, input_2_shape):
                raise Error("shapes are not compatible for node '{}'".format(
                    node_name))
            elif input_2_shape is not None:
                most_defined_shape = find_common_partial_shape(
                    most_defined_shape, input_2_shape)

            if input_1_shape is None:
                node.in_port(0).data.set_shape(most_defined_shape)
            if input_2_shape is None:
                node.in_port(1).data.set_shape(most_defined_shape)
    elif node['auto_broadcast'] == 'numpy':
        if output_shape is not None:
            out_rank = len(output_shape)
            deduced_in_shape = undefined_shape_of_rank(out_rank)

            if input_1_shape is not None and input_2_shape is None and out_rank > len(
                    input_1_shape):
                in_port_to_update = 1
                defined_in_shape = input_1_shape
            elif input_2_shape is not None and input_1_shape is None and out_rank > len(
                    input_2_shape):
                in_port_to_update = 0
                defined_in_shape = input_2_shape
            else:
                return
            defined_in_rank = len(defined_in_shape)

            for i in range(-1, -defined_in_rank - 1, -1):
                assert defined_in_shape[i] == 1 or np.ma.is_masked(defined_in_shape[i]) \
                       or compatible_dims(defined_in_shape[i], output_shape[i]), \
                    "Shapes of Elementwise node '{}' are not compatible for reverse_infer.".format(node_name)

                # if defined_input_shape = [1] and output_shape = [N, 400, 400, 3]
                # partial shape information about sizes should not be lost
                if defined_in_shape[i] == 1 or output_shape[i] == 1:
                    deduced_in_shape[i] = output_shape[i]
            deduced_in_shape[:-defined_in_rank] = output_shape[:-defined_in_rank]

            node.in_port(in_port_to_update).data.set_shape(deduced_in_shape)
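
Every snippet above compares tensor dimensions through `compatible_dims` from the Model Optimizer partial-inference utilities. As a minimal sketch of the semantics these examples rely on (the name `compatible_dims_sketch` and the masked-value representation of dynamic dimensions are assumptions for illustration, not the library's actual code), two dimensions are treated as compatible when either one is dynamic or both are equal:

    import numpy as np

    def compatible_dims_sketch(dim1, dim2) -> bool:
        # Sketch only: a dimension counts as dynamic when it is a masked value,
        # mirroring the np.ma usage in the examples above; a dynamic dimension is
        # compatible with anything, static dimensions must be equal.
        if np.ma.is_masked(dim1) or np.ma.is_masked(dim2):
            return True
        return bool(dim1 == dim2)

Under this sketch, compatible_dims_sketch(np.ma.masked, 4) and compatible_dims_sketch(4, 4) hold while compatible_dims_sketch(3, 4) does not, which is the kind of check the assertions above lean on when shapes may be partially dynamic.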