def masked_scatter(g, self, mask, source):
    """Export ``Tensor.masked_scatter`` as an ONNX ``ScatterND``.

    Positions of ``self`` selected by ``mask`` (after ``mask`` is expanded to
    ``self``'s shape) are overwritten with consecutive elements taken from the
    front of the flattened ``source``.

    Args:
        g: the ONNX graph under construction.
        self: tensor value that receives the scattered elements.
        mask: boolean mask value, expanded to the shape of ``self``.
        source: value supplying the replacement elements; any shape, possibly
            holding more elements than there are selected positions.

    Returns:
        The ``ScatterND`` output value.
    """
    from torch.onnx.symbolic_opset9 import nonzero, expand_as, view, size

    expanded_mask = expand_as(g, mask, self)
    # Coordinates of every selected position; presumably laid out
    # (num_selected, rank) as ScatterND expects -- relies on opset9 ``nonzero``.
    scatter_index = nonzero(g, expanded_mask)

    # ONNX ScatterND needs exactly one update per index row, but ``source``
    # may be arbitrarily shaped and may carry surplus elements. Flatten it to
    # 1-D and keep only the first ``scatter_index.size(0)`` entries.
    flat_source = view(g, source, torch.LongTensor([-1]))
    selected_count = size(g, scatter_index, torch.LongTensor([0]))
    updates = sym_help._slice_helper(
        g,
        flat_source,
        axes=torch.LongTensor([0]),
        starts=torch.LongTensor([0]),
        ends=selected_count,
        dynamic_slice=True,
    )
    return g.op('ScatterND', self, scatter_index, updates)
def symbolic_nms(g, boxes, scores, iou_threshold, max_output_boxes):
    """Export a single-batch, single-class NMS as ONNX ``NonMaxSuppression``.

    Args:
        g: the ONNX graph under construction.
        boxes: box tensor value, reshaped to ``(1, -1, 4)``; each box is in
            ``(center_x, center_y, width, height)`` form (``center_point_box=1``).
        scores: score tensor value, reshaped to ``(1, 1, -1)``.
        iou_threshold: overlap threshold, embedded as a float constant.
        max_output_boxes: maximum number of boxes to keep; a value ``<= 0``
            means "keep all", implemented as a fixed cap of 10000.
            NOTE(review): assumes these scalars arrive as Python numbers
            (they are compared with 0 and fed to ``torch.tensor``).

    Returns:
        1-D value holding the indices of the kept boxes (column 2 of the
        NonMaxSuppression ``selected_indices`` output).
    """
    # Non-positive limit: "return everything", approximated by 10000.
    box_limit = 10000 if max_output_boxes <= 0 else max_output_boxes

    batched_boxes = view(g, boxes, (1, -1, 4))
    box_limit_const = g.op(
        'Constant', value_t=torch.tensor([box_limit], dtype=torch.long))
    iou_const = g.op(
        'Constant', value_t=torch.tensor([iou_threshold], dtype=torch.float))
    batched_scores = view(g, scores, (1, 1, -1))

    # center_point_box == 1 selects the (center_x, center_y, width, height)
    # box encoding.
    selected = g.op('NonMaxSuppression', batched_boxes, batched_scores,
                    box_limit_const, iou_const, center_point_box_i=1)

    # Each selected row is (batch_index, class_index, box_index); keep the
    # box index and flatten to 1-D.
    box_index_col = g.op('Constant', value_t=torch.tensor([2], dtype=torch.long))
    kept_indices = select(g, selected, 1, box_index_col)
    return view(g, kept_indices, (-1, ))
def repeat(g, self, repeats):
    """Export ``Tensor.repeat`` as an ONNX ``Tile``.

    ``torch.repeat`` accepts more repeat entries than the input has
    dimensions, treating the missing leading dimensions as size 1. ONNX
    ``Tile`` instead requires exactly one repeat entry per input dimension,
    so ``self`` is first lifted to the repeat rank.

    Args:
        g: the ONNX graph under construction.
        self: input tensor value.
        repeats: repeat counts; a Python sequence of ints, or a graph value
            (a constant, a packed list, or a dynamically computed value).

    Returns:
        The ``Tile`` output value.
    """
    if not sym_help._is_value(repeats):
        repeats = g.op("Constant", value_t=torch.LongTensor(repeats))

    # Number of repeat entries if known at export time, otherwise None.
    # Previously a dynamic (non-constant, non-packed) ``repeats`` made
    # ``len()`` raise TypeError on the graph value.
    if sym_help._is_packed_list(repeats):
        repeat_size_len = len(sym_help._unpack_list(repeats))
    else:
        const_repeats = sym_help._maybe_get_const(repeats, 'is')
        if sym_help._is_value(const_repeats):
            repeat_size_len = None
        else:
            repeat_size_len = len(const_repeats)

    if repeat_size_len is not None and self.isCompleteTensor():
        # Static path: both ranks are known, so prepend unit dims with a reshape.
        sizes = self.type().sizes()
        diff_dims = repeat_size_len - len(sizes)
        if diff_dims > 0:
            self = sym_opset9.view(g, self, [1] * diff_dims + sizes)
        return g.op("Tile", self, repeats)

    # Input rank and/or repeat length is unknown at export time. Previously
    # an incomplete input shape skipped the padding and produced a Tile with
    # mismatched ranks. Expand follows numpy broadcasting, so expanding
    # against a ones tensor of length len(repeats) prepends unit dims where
    # needed and is a no-op otherwise.
    if repeat_size_len is not None:
        ones = g.op("Constant",
                    value_t=torch.ones(repeat_size_len, dtype=torch.long))
    else:
        ones = g.op("ConstantOfShape", g.op("Shape", repeats),
                    value_t=torch.tensor([1], dtype=torch.long))
    self = g.op("Expand", self, ones)
    return g.op("Tile", self, repeats)
def symbolic_nmsfilt(g, boxes, scores, iou_threshold, score_threshold, max_output_boxes):
    """Export NMS filtering: zero every score whose box NMS does not keep.

    A single-batch, single-class ``NonMaxSuppression`` (with both an IoU and
    a score threshold) chooses the surviving boxes. The result is a tensor of
    the same shape as ``scores`` that holds the original score at each kept
    box and 0 everywhere else.

    Args:
        g: the ONNX graph under construction.
        boxes: box tensor value, reshaped to ``(1, -1, 4)``; each box is in
            ``(center_x, center_y, width, height)`` form (``center_point_box=1``).
        scores: score tensor value; its original shape is restored on output.
        iou_threshold: overlap threshold, embedded as a float constant.
        score_threshold: minimum-score threshold, embedded as a float constant.
        max_output_boxes: maximum number of boxes to keep; a value ``<= 0``
            means "keep all", implemented as a fixed cap of 10000.
            NOTE(review): assumes these scalars arrive as Python numbers.

    Returns:
        The filtered scores value, in the original shape of ``scores``.
    """
    # Non-positive limit: "return everything", approximated by 10000.
    box_limit = 10000 if max_output_boxes <= 0 else max_output_boxes

    original_shape = g.op("Shape", scores)
    batched_boxes = view(g, boxes, (1, -1, 4))
    box_limit_const = g.op(
        'Constant', value_t=torch.tensor([box_limit], dtype=torch.long))
    iou_const = g.op(
        'Constant', value_t=torch.tensor([iou_threshold], dtype=torch.float))
    score_const = g.op(
        'Constant', value_t=torch.tensor([score_threshold], dtype=torch.float))
    batched_scores = view(g, scores, (1, 1, -1))

    # center_point_box == 1 selects the (center_x, center_y, width, height)
    # box encoding.
    selected = g.op('NonMaxSuppression', batched_boxes, batched_scores,
                    box_limit_const, iou_const, score_const,
                    center_point_box_i=1)

    # Each selected row is (batch_index, class_index, box_index); keep the
    # box index as a flat 1-D value.
    box_index_col = g.op('Constant', value_t=torch.tensor([2], dtype=torch.long))
    kept_indices = view(g, select(g, selected, 1, box_index_col), (-1,))

    flat_scores = view(g, scores, (-1,))
    flat_shape = g.op("Shape", flat_scores)
    kept_scores = view(g, index_select(g, flat_scores, 0, kept_indices), (-1,))

    # Zero-filled canvas, then write the kept scores back at their positions.
    # NOTE(review): ConstantOfShape without a value attribute yields float32
    # zeros, which assumes ``scores`` is float32 -- confirm.
    canvas = g.op("ConstantOfShape", flat_shape)
    filtered = scatter(g, canvas, 0, kept_indices, kept_scores)
    return view(g, filtered, original_shape)