示例#1
0
    def normalize_strided_slice(graph: Graph, node: Node):
        """Pad the masks and slice inputs of a StridedSlice node up to the input rank.

        The number of missing entries is
        ``input_rank - slice_rank + count_nonzero(node.new_axis_mask)``.
        When an ellipsis is set, its position is cleared in ``begin_mask``,
        ``end_mask`` and ``ellipsis_mask`` and the inputs are handed to
        ``StridedSliceNormalizer.unroll_ellipsis_for_inputs``; otherwise, if
        anything is missing, ``StridedSliceNormalizer.extend_inputs`` is called.
        Finally zeros are inserted into every mask listed by
        ``StridedSlice.get_mask_names()``.

        :param graph: graph object; only forwarded to ``unroll_ellipsis_for_inputs``
        :param node: StridedSlice node whose mask attributes are updated
        :raises AssertionError: if the number of insertions is negative or more
            than one ellipsis_mask entry is non-zero
        """
        in_rank = len(node.in_port(0).data.get_shape())
        begin_arg, _, _ = StridedSlice.validate_inputs_and_get_args(node)
        n_slices = len(begin_arg)

        # needed if StridedSlice is created after partial_infer
        StridedSlice.align_mask_with_slice_rank(node, n_slices)
        StridedSliceNormalizer.normalize_slices_attr(node)

        new_axis_count = np.count_nonzero(node.new_axis_mask)
        num_insertions = in_rank - n_slices + new_axis_count
        assert num_insertions >= 0, \
            'slice_rank - num_new_axis must <= input rank. Got instead: ' \
            'input_rank = {}, slice_rank = {}, num_new_axis = {}'.format(in_rank, n_slices, new_axis_count)

        if np.any(node.ellipsis_mask):
            assert np.count_nonzero(node.ellipsis_mask) == 1, 'only one ellipsis_mask nonzero value is allowed'
            ellipsis_pos = np.nonzero(node.ellipsis_mask)[0][0]
            # no values are expected in begin and end at the ellipsis position:
            # take the whole range along it
            for cleared_mask in (node.begin_mask, node.end_mask, node.ellipsis_mask):
                cleared_mask[ellipsis_pos] = 0
            # blank mask values go right after the former ellipsis position
            pad_position = ellipsis_pos + 1

            StridedSliceNormalizer.unroll_ellipsis_for_inputs(graph, node, ellipsis_pos, num_insertions)
        elif num_insertions > 0:
            # no ellipsis: blank values are appended to the ends of the masks
            pad_position = n_slices
            StridedSliceNormalizer.extend_inputs(node, num_insertions)

        if not num_insertions > 0:
            return

        # insert blank values for ellipsis unrolling and extending
        blanks = [0] * num_insertions
        for mask_name in StridedSlice.get_mask_names():
            padded = np.insert(node[mask_name], pad_position, blanks)
            node[mask_name] = padded.astype(int)
    def extend(op: Node):
        """Prepare the StridedSlice mask attributes of ``op`` and flip begin/end masks.

        Each mask named by ``StridedSlice.get_mask_names()`` that is present and
        not an empty string is passed to ``Extender.attr_to_list``; a missing or
        empty mask is replaced with ``int64_array([0])``. Afterwards every value
        ``i`` in ``begin_mask`` and ``end_mask`` is replaced with ``1 - i``.

        :param op: node whose mask attributes are rewritten in place
        :raises AssertionError: if ``begin_mask`` or ``end_mask`` is missing or empty
        """
        mandatory_masks = ['begin_mask', 'end_mask']

        for mask_name in StridedSlice.get_mask_names():
            # op.has_and_set(mask_name) can not be used as the condition here: it returns False if
            # begin/end is a 1D tensor and begin_mask/end_mask is equal to 0
            is_provided = op.has(mask_name) and op[mask_name] != ''
            if not is_provided:
                assert mask_name not in mandatory_masks, \
                    '{} is not defined for the node {}'.format(mask_name, op.soft_get('name', op.id))
                op[mask_name] = int64_array([0])
                continue
            Extender.attr_to_list(op, mask_name)

        # NOTE(review): the 1 - bit flip suggests the stored masks use the opposite
        # convention from the one expected afterwards — confirm against the op spec
        flipped_begin = [1 - bit for bit in op.begin_mask]
        op.begin_mask = int64_array(flipped_begin)
        flipped_end = [1 - bit for bit in op.end_mask]
        op.end_mask = int64_array(flipped_end)
示例#3
0
    def extend(op: Node):
        """Run ``Extender.attr_to_list`` on every StridedSlice mask and flip begin/end masks.

        All masks named by ``StridedSlice.get_mask_names()`` are processed
        unconditionally, so each of them is expected to be present on ``op``.
        Afterwards every value ``i`` in ``begin_mask`` and ``end_mask`` is
        replaced with ``1 - i``.

        :param op: node whose mask attributes are rewritten in place
        """
        for mask_name in StridedSlice.get_mask_names():
            Extender.attr_to_list(op, mask_name)

        # NOTE(review): the 1 - bit flip suggests the stored masks use the opposite
        # convention from the one expected afterwards — confirm against the op spec
        flipped_begin = [1 - bit for bit in op.begin_mask]
        op.begin_mask = int64_array(flipped_begin)
        flipped_end = [1 - bit for bit in op.end_mask]
        op.end_mask = int64_array(flipped_end)