Example #1
0
def masked_scatter(g, self, mask, source):
    """Export aten::masked_scatter as ONNX ScatterND.

    Writes elements of ``source`` into ``self`` at the positions where
    ``mask`` (broadcast to ``self``'s shape) is true.
    """
    from torch.onnx.symbolic_opset9 import nonzero, expand_as, view, size
    # Indices of every position selected by the broadcast mask.
    index = nonzero(g, expand_as(g, mask, self))
    # ScatterND wants exactly one update per index, but source may carry
    # extra elements and have an arbitrary shape — flatten it and keep
    # only the first len(index) values.
    flat_source = view(g, source, torch.LongTensor([-1]))
    trimmed_source = sym_help._slice_helper(
        g, flat_source,
        axes=torch.LongTensor([0]),
        starts=torch.LongTensor([0]),
        ends=size(g, index, torch.LongTensor([0])),
        dynamic_slice=True)
    return g.op('ScatterND', self, index, trimmed_source)
Example #2
0
 def symbolic_nms(g, boxes, scores, iou_threshold, max_output_boxes):
     """Build an ONNX NonMaxSuppression node and return the kept box indices."""
     # A non-positive limit means "return everything"; substitute a large cap.
     if max_output_boxes <= 0:
         max_output_boxes = 10000
     boxes = view(g, boxes, (1, -1, 4))
     keep_limit = g.op(
         'Constant',
         value_t=torch.tensor([max_output_boxes], dtype=torch.long))
     iou_thr = g.op(
         'Constant',
         value_t=torch.tensor([iou_threshold], dtype=torch.float))
     # center_point_box == 1 selects the (center_x, center_y, width, height)
     # box layout.
     nms_out = g.op(
         'NonMaxSuppression',
         boxes,
         view(g, scores, (1, 1, -1)),
         keep_limit,
         iou_thr,
         center_point_box_i=1)
     # Column 2 of the NMS output holds the selected box indices.
     col = g.op('Constant', value_t=torch.tensor([2], dtype=torch.long))
     idx = select(g, nms_out, 1, col)
     return view(g, idx, (-1, ))
Example #3
0
def repeat(g, self, repeats):
    """Export aten::repeat as ONNX Tile, padding the input rank if needed."""
    # Turn a plain Python repeats list into a graph Constant node.
    if not sym_help._is_value(repeats):
        repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
    # Count how many repeat factors were supplied.
    if sym_help._is_packed_list(repeats):
        repeat_size_len = len(sym_help._unpack_list(repeats))
    else:
        repeat_size_len = len(sym_help._maybe_get_const(repeats, 'is'))
    # Tile requires rank(input) == len(repeats): when the input tensor's
    # shape is fully known and its rank is smaller, prepend singleton dims.
    if self.isCompleteTensor():
        sizes = self.type().sizes()
        missing_dims = repeat_size_len - len(sizes)
        if missing_dims > 0:
            self = sym_opset9.view(g, self, [1] * missing_dims + sizes)
    return g.op("Tile", self, repeats)
Example #4
0
 def symbolic_nmsfilt(g, boxes, scores, iou_threshold, score_threshold, max_output_boxes):
     """Run NonMaxSuppression and return the scores tensor with every
     suppressed entry replaced by zero, in the original scores shape."""
     # A non-positive limit means "return everything"; substitute a large cap.
     if max_output_boxes <= 0:
         max_output_boxes = 10000
     orig_shape = g.op("Shape", scores)  # restored on the final view
     boxes = view(g, boxes, (1, -1, 4))
     keep_limit = g.op('Constant',
                       value_t=torch.tensor([max_output_boxes], dtype=torch.long))
     iou_thr = g.op('Constant',
                    value_t=torch.tensor([iou_threshold], dtype=torch.float))
     score_thr = g.op('Constant',
                      value_t=torch.tensor([score_threshold], dtype=torch.float))
     # center_point_box == 1 selects the (center_x, center_y, width, height)
     # box layout.
     nms_out = g.op('NonMaxSuppression',
                    boxes,
                    view(g, scores, (1, 1, -1)),
                    keep_limit,
                    iou_thr,
                    score_thr,
                    center_point_box_i=1)
     # Column 2 of the NMS output holds the surviving box indices.
     col = g.op('Constant', value_t=torch.tensor([2], dtype=torch.long))
     idx = view(g, select(g, nms_out, 1, col), (-1,))
     scores = view(g, scores, (-1,))
     flat_shape = g.op("Shape", scores)
     kept_scores = view(g, index_select(g, scores, 0, idx), (-1,))
     # Zeros of the flat size, then scatter the surviving scores back in.
     zeros = g.op("ConstantOfShape", flat_shape)
     filtered = scatter(g, zeros, 0, idx, kept_scores)
     return view(g, filtered, orig_shape)