Ejemplo n.º 1
0
    def extend(op: Node):
        """
        Normalize the StridedSlice mask attributes of `op` in place.

        Every mask that is present with a non-empty value is passed through Extender.attr_to_list.
        A missing or empty mask gets the default int64_array([0]). This is allowed for every mask
        except begin_mask and end_mask, which must be defined.
        Finally, each value v of begin_mask and end_mask is replaced by 1 - v, which flips 0/1 flags.

        :param op: StridedSlice node whose mask attributes are rewritten
        :raises AssertionError: if begin_mask or end_mask is absent or empty
        """
        for mask_name in StridedSlice.get_mask_names():
            # op.has_and_set(mask_name) is not usable as the test: it yields False when begin/end is a
            # 1D tensor whose begin_mask/end_mask equals 0, so presence and non-emptiness are checked directly
            if op.has(mask_name) and op[mask_name] != '':
                Extender.attr_to_list(op, mask_name)
                continue
            # only the optional masks may fall back to the default value
            assert mask_name not in ['begin_mask', 'end_mask'],\
                '{} is not defined for the node {}'.format(mask_name, op.soft_get('name', op.id))
            op[mask_name] = int64_array([0])

        # begin_mask first, then end_mask: replace every value v with 1 - v
        for inverted_mask in ('begin_mask', 'end_mask'):
            setattr(op, inverted_mask, int64_array([1 - bit for bit in getattr(op, inverted_mask)]))
Ejemplo n.º 2
0
    def normalize_strided_slice(graph: Graph, node: Node):
        """
        Expand an ellipsis or a short slice specification of a StridedSlice node into explicit entries.

        Computes how many slice entries are missing:
            input_rank - slice_rank + count_nonzero(new_axis_mask).
        If an ellipsis is set, its begin/end/ellipsis mask bits are cleared and the slice inputs are unrolled
        around it. Otherwise, if entries are missing, the slice inputs are extended at the end. In both cases
        every mask then receives the same number of zero entries at the matching position.

        :param graph: graph passed through to StridedSliceNormalizer.unroll_ellipsis_for_inputs
        :param node: StridedSlice node; its masks (and, via the helpers, its slice inputs) are modified in place
        :raises AssertionError: if the computed number of insertions is negative or more than one ellipsis is set
        """
        input_rank = len(node.in_port(0).data.get_shape())
        begin_value = node.in_port(1).data.get_value()
        # without a constant begin, the slice rank is estimated from the new-axis and shrink-axis masks
        slice_rank = (len(begin_value) if begin_value is not None
                      else input_rank + np.count_nonzero(node.new_axis_mask) - np.count_nonzero(node.shrink_axis_mask))

        # needed when the StridedSlice is created after partial_infer
        StridedSlice.align_mask_with_slice_rank(node, slice_rank)
        StridedSliceNormalizer.normalize_slices_attr(node)

        # new_axis_mask is re-read deliberately: the two calls above presumably may rewrite the masks
        new_axis_count = np.count_nonzero(node.new_axis_mask)
        num_insertions = input_rank - slice_rank + new_axis_count
        assert num_insertions >= 0, \
            'slice_rank - num_new_axis must <= input rank. Got instead: ' \
            'input_rank = {}, slice_rank = {}, num_new_axis = {}'.format(input_rank, slice_rank, new_axis_count)

        if np.any(node.ellipsis_mask):
            assert np.count_nonzero(node.ellipsis_mask) == 1, 'only one ellipsis_mask nonzero value is allowed'
            ellipsis_pos = np.nonzero(node.ellipsis_mask)[0][0]
            # no begin/end values are expected at the ellipsis, so take the whole range there
            # and drop the ellipsis flag itself
            for mask_name in ('begin_mask', 'end_mask', 'ellipsis_mask'):
                getattr(node, mask_name)[ellipsis_pos] = 0
            insert_at = ellipsis_pos + 1
            StridedSliceNormalizer.unroll_ellipsis_for_inputs(graph, node, ellipsis_pos, num_insertions)
        elif num_insertions > 0:
            insert_at = slice_rank  # blank entries go to the end of the masks
            StridedSliceNormalizer.extend_inputs(node, num_insertions)

        if num_insertions <= 0:
            return

        # pad every mask with zeros where the ellipsis was unrolled or the inputs were extended
        padding = [0] * num_insertions
        for mask_name in StridedSlice.get_mask_names():
            node[mask_name] = np.insert(node[mask_name], insert_at, padding).astype(int)