def extend(op: Node):
    """Prepare the StridedSlice mask attributes of ``op`` and invert two of them.

    Every attribute named by ``StridedSlice.get_mask_names()`` is visited:

    * if the attribute exists on the node and is not the empty string, it is
      handed to ``Extender.attr_to_list`` (presumably converting it into a
      list-like value -- confirm against ``Extender``);
    * otherwise it is set to ``int64_array([0])``. This fallback is refused
      for ``begin_mask`` and ``end_mask``, which must be defined.

    Finally ``begin_mask`` and ``end_mask`` are replaced by their element-wise
    complement (each element ``i`` becomes ``1 - i``).

    :param op: node whose mask attributes are updated in place
    :raises AssertionError: if ``begin_mask`` or ``end_mask`` is missing or empty
    """
    for mask_name in StridedSlice.get_mask_names():
        # op.has_and_set(mask_name) is deliberately not used as the condition:
        # it would report False when begin/end is a 1D tensor and the
        # corresponding begin_mask/end_mask equals 0.
        if op.has(mask_name) and op[mask_name] != '':
            Extender.attr_to_list(op, mask_name)
            continue

        # Only the optional masks may fall back to a default value.
        assert mask_name not in ['begin_mask', 'end_mask'],\
            '{} is not defined for the node {}'.format(mask_name, op.soft_get('name', op.id))
        op[mask_name] = int64_array([0])

    # Flip every bit of the two mandatory masks.
    op.begin_mask = int64_array([1 - bit for bit in op.begin_mask])
    op.end_mask = int64_array([1 - bit for bit in op.end_mask])
def normalize_strided_slice(graph: Graph, node: Node):
    """Bring the masks and slice inputs of a StridedSlice node to a common length.

    The slice rank is the length of the ``begin`` value (input port 1) when
    that value is known; otherwise it is derived from the rank of the data
    input (port 0) plus the number of new axes minus the number of shrunk axes.

    After the masks are aligned with the slice rank, the number of positions
    to insert is ``input_rank - slice_rank + <number of new axes>``:

    * if an ellipsis is present (exactly one is allowed), the begin, end and
      ellipsis mask entries at its position are cleared, the inputs are
      unrolled via ``StridedSliceNormalizer.unroll_ellipsis_for_inputs`` and
      the new mask entries go right after the ellipsis position;
    * otherwise, if there is anything to insert, the inputs are extended via
      ``StridedSliceNormalizer.extend_inputs`` and the new mask entries are
      appended at index ``slice_rank``.

    :param graph: graph the node belongs to; forwarded to the ellipsis unrolling
    :param node: StridedSlice node, modified in place
    :raises AssertionError: if the computed number of insertions is negative
        or more than one ellipsis_mask entry is nonzero
    """
    input_rank = len(node.in_port(0).data.get_shape())
    begin_value = node.in_port(1).data.get_value()

    # Without a constant `begin` the slice rank has to be deduced from masks.
    slice_rank = len(begin_value) if begin_value is not None else (
        input_rank + np.count_nonzero(node.new_axis_mask) - np.count_nonzero(node.shrink_axis_mask))

    # Needed if StridedSlice is created after partial_infer.
    StridedSlice.align_mask_with_slice_rank(node, slice_rank)
    StridedSliceNormalizer.normalize_slices_attr(node)

    # Counted only now, after the masks went through the two calls above.
    new_axis_count = np.count_nonzero(node.new_axis_mask)
    num_insertions = input_rank - slice_rank + new_axis_count
    assert num_insertions >= 0, (
        'slice_rank - num_new_axis must <= input rank. Got instead: '
        'input_rank = {}, slice_rank = {}, num_new_axis = {}'.format(input_rank, slice_rank, new_axis_count))

    if np.any(node.ellipsis_mask):
        assert np.count_nonzero(node.ellipsis_mask) == 1, 'only one ellipsis_mask nonzero value is allowed'
        ellipsis_pos = np.nonzero(node.ellipsis_mask)[0][0]
        # No begin/end values are expected for the ellipsis, so the whole
        # range along that position is taken.
        for mask in (node.begin_mask, node.end_mask, node.ellipsis_mask):
            mask[ellipsis_pos] = 0
        insert_at = ellipsis_pos + 1

        StridedSliceNormalizer.unroll_ellipsis_for_inputs(graph, node, ellipsis_pos, num_insertions)
    elif num_insertions > 0:
        # Blank values are appended at the end of the masks.
        insert_at = slice_rank
        StridedSliceNormalizer.extend_inputs(node, num_insertions)

    if num_insertions <= 0:
        return

    # Blank mask values for both ellipsis unrolling and plain extension.
    blanks = [0] * num_insertions
    for mask_name in StridedSlice.get_mask_names():
        node[mask_name] = np.insert(node[mask_name], insert_at, blanks).astype(int)