def symbolic_nmsfilt(g, boxes, scores, iou_threshold, score_threshold, max_output_boxes):
    """Emit the ONNX graph for score filtering via non-maximum suppression.

    Runs an ONNX ``NonMaxSuppression`` node over ``boxes``/``scores`` and
    scatters the scores of the selected entries into a ``ConstantOfShape``
    tensor, which is finally viewed back to the original shape of ``scores``.

    Args:
        g: Graph context providing ``op`` for emitting ONNX nodes.
        boxes: Box values; viewed as ``(1, -1, 4)`` before NMS.
        scores: Score values; viewed as ``(1, 1, -1)`` for NMS and as a flat
            vector for gathering/scattering.
        iou_threshold: Python number baked into a float ``Constant``.
        score_threshold: Python number baked into a float ``Constant``.
        max_output_boxes: Python number; a non-positive value is replaced
            by 10000.

    Returns:
        The scattered tensor viewed with the original shape of ``scores``.
    """
    # A non-positive limit stands for "return all"; a large cap is used instead.
    box_limit = max_output_boxes if max_output_boxes > 0 else 10000

    # Remember the incoming shape so the result can be restored to it.
    orig_shape = g.op("Shape", scores)
    boxes_3d = view(g, boxes, (1, -1, 4))

    limit_const = g.op('Constant', value_t=torch.tensor([box_limit], dtype=torch.long))
    iou_const = g.op('Constant', value_t=torch.tensor([iou_threshold], dtype=torch.float))
    score_const = g.op('Constant', value_t=torch.tensor([score_threshold], dtype=torch.float))

    # center_point_box == 1 selects the center_x, center_y, width, height
    # box format used by this code.
    scores_3d = view(g, scores, (1, 1, -1))
    nms_out = g.op(
        'NonMaxSuppression',
        boxes_3d,
        scores_3d,
        limit_const,
        iou_const,
        score_const,
        center_point_box_i=1,
    )

    # Take entry 2 along dim 1 of the NMS output (the box index per the ONNX
    # operator spec) and flatten it.
    column = g.op('Constant', value_t=torch.tensor([2], dtype=torch.long))
    kept_idx = view(g, select(g, nms_out, 1, column), (-1,))

    flat_scores = view(g, scores, (-1,))
    flat_shape = g.op("Shape", flat_scores)
    kept_scores = view(g, index_select(g, flat_scores, 0, kept_idx), (-1,))

    # No value attribute is given, so the tensor uses the operator's default
    # fill (zero per the ONNX spec); kept scores are written at their indices.
    blank = g.op("ConstantOfShape", flat_shape)
    filtered = scatter(g, blank, 0, kept_idx, kept_scores)
    return view(g, filtered, orig_shape)
def _index_fill_reshape_helper(g, self, dim, index):
    """Expand ``index`` to the shape of ``self`` except along ``dim``.

    Overall index_fill export recipe:
      1. reshape index => [1, ..., 1, dim, 1, ..., 1]
      2. expand index => [..., dim, ...], same shape as self except for dim.
      3. expand value as well.
      4. apply onnx::scatter.
    Only steps 1 and 2 happen in this helper; steps 3 and 4 are presumably
    carried out by the caller.

    Args:
        g: Graph context providing ``op`` for emitting ONNX nodes.
        self: Input value whose rank and shape drive the expansion.
        dim: Dimension value; parsed as an int for the unsqueeze axes and
            also used as a graph value for the scatter index.
        index: Index value to be unsqueezed and expanded.

    Returns:
        ``(expanded_index_shape, expanded_index)``, or the result of
        ``_unimplemented`` when the rank of ``self`` is unknown.
    """
    from torch.onnx.symbolic_opset9 import expand

    # Opset 11 and later use a different scatter symbolic than opsets <= 10.
    if _export_onnx_opset_version > 10:
        from torch.onnx.symbolic_opset11 import scatter
    else:
        from torch.onnx.symbolic_opset9 import scatter

    rank = self.type().dim()
    if rank is None:
        return _unimplemented("index_fill", "input rank not accesible")

    dim_value = _parse_arg(dim, 'i')

    # Insert a singleton axis everywhere except at the target dimension.
    # NOTE(review): a negative dim_value never equals any i in range(rank),
    # so every axis would be unsqueezed - confirm callers pass dim >= 0.
    singleton_axes = [axis for axis in range(rank) if axis != dim_value]
    unsqueezed_index = g.op("Unsqueeze", index, axes_i=singleton_axes)

    # Shape of self with the entry at `dim` overwritten by the shape of index.
    self_shape = g.op("Shape", self)
    dim_as_vector = g.op("Unsqueeze", dim, axes_i=[0])
    index_shape = g.op("Shape", index)
    expanded_index_shape = scatter(g, self_shape, 0, dim_as_vector, index_shape)

    expanded_index = expand(g, unsqueezed_index, expanded_index_shape, None)
    return expanded_index_shape, expanded_index
def _scatter_helper(g, self, dim, index, src):
    """Dispatch to the scatter symbolic matching the export opset version.

    Uses the opset 9 implementation when ``_export_onnx_opset_version`` is
    at most 10 and the opset 11 implementation otherwise, forwarding all
    arguments unchanged.

    Returns:
        Whatever the selected ``scatter`` symbolic returns.
    """
    if _export_onnx_opset_version > 10:
        from torch.onnx.symbolic_opset11 import scatter as scatter_impl
    else:
        from torch.onnx.symbolic_opset9 import scatter as scatter_impl
    return scatter_impl(g, self, dim, index, src)