Example #1
0
 def symbolic_nmsfilt(g, boxes, scores, iou_threshold, score_threshold, max_output_boxes):
     """Emit ONNX nodes that zero every score not kept by non-maximum suppression.

     Args:
         g: graph context used to emit ONNX nodes via ``g.op``.
         boxes: graph value holding the boxes. It is reshaped to ``(1, -1, 4)``
             and read in (center_x, center_y, width, height) form
             (``center_point_box_i=1``).
         scores: graph value holding the scores. It is reshaped to
             ``(1, 1, -1)`` for NMS, i.e. treated as one batch and one class.
         iou_threshold: IoU threshold, embedded as a float constant.
         score_threshold: score threshold, embedded as a float constant.
         max_output_boxes: cap on the number of kept boxes. Values ``<= 0`` are
             replaced by 10000, which is intended to mean "return all".

     Returns:
         A graph value reshaped back to the original shape of ``scores``.
         Positions selected by NMS hold their original score; every other
         position holds the ``ConstantOfShape`` fill value (0).
     """
     # A non-positive limit means "return all"; 10000 stands in for "unbounded".
     # NOTE(review): if more than 10000 boxes survive, the rest are still dropped.
     limit = 10000 if max_output_boxes <= 0 else max_output_boxes

     def _const(value, dtype):
         # One-element constant tensor node.
         return g.op('Constant', value_t=torch.tensor([value], dtype=dtype))

     scores_shape = g.op("Shape", scores)  # restored on the final output
     boxes_3d = view(g, boxes, (1, -1, 4))
     max_per_class_t = _const(limit, torch.long)
     iou_t = _const(iou_threshold, torch.float)
     score_t = _const(score_threshold, torch.float)
     scores_3d = view(g, scores, (1, 1, -1))

     # center_point_box == 1 selects the (center_x, center_y, width, height) format.
     selected = g.op('NonMaxSuppression',
                     boxes_3d, scores_3d, max_per_class_t, iou_t, score_t,
                     center_point_box_i=1)

     # Take column 2 of the NMS output (box index within the batch/class) as a 1-D index.
     box_col = _const(2, torch.long)
     kept_idx = view(g, select(g, selected, 1, box_col), (-1,))

     flat_scores = view(g, scores, (-1,))
     flat_shape = g.op("Shape", flat_scores)
     kept_scores = view(g, index_select(g, flat_scores, 0, kept_idx), (-1,))

     # With no `value` attribute, ConstantOfShape fills with 0 (float32 per the ONNX spec).
     # NOTE(review): this presumably assumes `scores` is float32 — confirm for other dtypes.
     zeros = g.op("ConstantOfShape", flat_shape)
     filtered = scatter(g, zeros, 0, kept_idx, kept_scores)
     return view(g, filtered, scores_shape)
Example #2
0
def _index_fill_reshape_helper(g, self, dim, index):
    """Build the expanded index used to lower ``index_fill`` onto ONNX scatter.

    This helper does only the index-side work. Expanding the fill value and
    emitting the final scatter are left to the caller:

    1. unsqueeze ``index`` on every axis except ``dim``, giving
       ``[1, ..., 1, len(index), 1, ..., 1]``;
    2. expand it to the shape of ``self`` with the ``dim`` entry replaced by
       the length of ``index``.

    Args:
        g: graph context used to emit ONNX nodes.
        self: input tensor value. Its rank must be known statically.
        dim: graph value for the fill dimension. It must be parseable as a
            constant int by ``_parse_arg``; a negative value is accepted.
        index: index tensor value. Presumably 1-D, since its shape is
            scattered into a single slot of ``self``'s shape.

    Returns:
        ``(expanded_index_shape, expanded_index)`` graph values, or the result
        of ``_unimplemented`` when the rank of ``self`` is unknown.
    """
    from torch.onnx.symbolic_opset9 import expand
    if _export_onnx_opset_version <= 10:
        from torch.onnx.symbolic_opset9 import scatter
    else:
        from torch.onnx.symbolic_opset11 import scatter

    if self.type().dim() is None:
        return _unimplemented("index_fill", "input rank not accesible")
    self_dim = self.type().dim()
    dim_value = _parse_arg(dim, 'i')
    # Normalize a negative dim (e.g. -1). Otherwise `i != dim_value` is always
    # true, every axis gets unsqueezed, and the index ends up with rank
    # self_dim + 1 instead of self_dim.
    if dim_value < 0:
        dim_value += self_dim
    unsqueezed_index = g.op("Unsqueeze", index, axes_i=[i for i in range(self_dim) if i != dim_value])
    # Replace the `dim` entry of self's shape with the length of `index`.
    # NOTE(review): this uses the raw (possibly negative) `dim` as the scatter
    # index. ScatterElements (opset 11+) documents indices in [-s, s-1];
    # confirm the same holds for the opset 9/10 Scatter path.
    expanded_index_shape = scatter(g, g.op("Shape", self), 0,
                                   g.op("Unsqueeze", dim, axes_i=[0]), g.op("Shape", index))
    expanded_index = expand(g, unsqueezed_index, expanded_index_shape, None)
    return expanded_index_shape, expanded_index
def _scatter_helper(g, self, dim, index, src):
    """Forward to the ``scatter`` symbolic matching the export opset.

    Opset 11 and newer use ``torch.onnx.symbolic_opset11.scatter``; older
    opsets use ``torch.onnx.symbolic_opset9.scatter``. The arguments are
    passed through unchanged, and the chosen implementation's result is
    returned.
    """
    if _export_onnx_opset_version > 10:
        from torch.onnx.symbolic_opset11 import scatter as _scatter_impl
    else:
        from torch.onnx.symbolic_opset9 import scatter as _scatter_impl
    return _scatter_impl(g, self, dim, index, src)