def propagate_broadcast(op, const_value_by_tensor, dtype=''):
    # type: (TFOperation, _ConstValueByTensorT, str)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Infer the single output shape and dtype of a broadcasting elementwise TF op.

    All input shapes are combined with ``infer.elementwise`` using
    ``Broadcast.FROM_RIGHT`` (shapes aligned on their trailing dimensions).

    :param op: the operation whose output is being inferred.
    :param const_value_by_tensor: not read here; presumably kept so every
        propagator shares one signature -- TODO confirm.
    :param dtype: explicit output dtype; when empty/falsy, ``get_op_t(op)``
        supplies it instead.
    :return: ``([output_shape], [output_dtype])``.
    """
    input_shapes = [tensor.shape for tensor in op.inputs]
    output_shape = infer.elementwise(inputs=input_shapes,
                                     broadcast=infer.Broadcast.FROM_RIGHT)
    # A falsy ``dtype`` ('' by default) means "not given": fall back to the op.
    output_dtype = dtype or get_op_t(op)
    return [output_shape], [output_dtype]
def propagate_broadcast(op, dtype_from_index=0):
    # type: (ONNXOperation, int)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Infer the single output shape and dtype of a broadcasting elementwise ONNX op.

    Input shapes are combined with ``infer.elementwise`` using
    ``Broadcast.FROM_RIGHT`` (shapes aligned on their trailing dimensions).

    :param op: the operation whose output is being inferred.
    :param dtype_from_index: index of the input whose dtype the output copies.
    :return: ``([output_shape], [output_dtype])``.
    """
    shapes = [tensor.shape for tensor in op.inputs]
    result_shape = infer.elementwise(inputs=shapes,
                                     broadcast=infer.Broadcast.FROM_RIGHT)
    dtype_source = op.inputs[dtype_from_index]
    return [result_shape], [dtype_source.dtype]
def propagate_broadcast_with_axis(op, dtype=None):
    # type: (ONNXOperation, typing.Optional[str])->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Infer output shape and dtype of a binary ONNX op that may carry an ``axis`` attribute.

    With ``axis`` present, ``y`` is aligned against ``x`` starting at
    dimension ``axis``. This is done by prepending ``axis`` ones to ``y``'s
    shape and broadcasting ``FROM_LEFT``. This presumably models the legacy
    (pre-opset-7) ONNX broadcast form -- confirm. Without ``axis``, the two
    shapes are broadcast ``FROM_RIGHT`` (trailing dimensions aligned).

    :param op: operation with exactly two inputs ``(x, y)``.
    :param dtype: explicit output dtype; ``None`` means "same as ``x``".
        Honoured in both branches. Previously the ``axis`` branch silently
        ignored it.
    :return: ``([output_shape], [output_dtype])``.
    """
    x, y = op.inputs
    # Resolve the dtype once so both broadcasting modes agree on it.
    output_dtype = x.dtype if dtype is None else dtype

    if 'axis' in op.attribs:
        axis = op.attribs['axis']
        if axis < 0:
            # Negative axis counts from the end of x's dimensions.
            axis += x.rank
        # Leading ones shift y so its first dimension lines up with x[axis]
        # under left-aligned broadcasting. list() tolerates tuple shapes.
        y_shape = [1] * axis + list(y.shape)
        shape = infer.elementwise(inputs=[x.shape, y_shape],
                                  broadcast=infer.Broadcast.FROM_LEFT)
    else:
        shape = infer.elementwise(inputs=[x.shape, y.shape],
                                  broadcast=infer.Broadcast.FROM_RIGHT)

    return [shape], [output_dtype]
def propagate_expand(op):
    # type: (ONNXOperation)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Infer the output shape and dtype of an ONNX ``Expand`` op.

    The data input's shape is broadcast ``FROM_RIGHT`` against the shape
    obtained by evaluating the second input with
    ``evaluate_shape_tensor_simple``. The output keeps the data input's dtype.

    :param op: operation with inputs ``(data, shape_tensor)``.
    :return: ``([output_shape], [output_dtype])``.
    """
    data, shape_tensor = op.inputs
    data_shape = data.shape
    target_shape = evaluate_shape_tensor_simple(shape_tensor)
    expanded = infer.elementwise(inputs=[data_shape, target_shape],
                                 broadcast=infer.Broadcast.FROM_RIGHT)
    return [expanded], [data.dtype]
def propagate_expand(op):
    # type: (ONNXOperation)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Infer the output shape and dtype of an ONNX ``Expand`` op, normalising the op in place.

    Steps:

    1. The data shape and the evaluated target shape are left-padded with ones
       to a common rank.
    2. Any ``-1`` in the target shape is replaced by the matching data
       dimension.
    3. The two equal-rank shapes are broadcast with ``Broadcast.SAME_RANK``.

    Side effects: stores the normalised target shape in ``op.attribs['shape']``
    and reduces ``op.inputs`` to just the data tensor.

    :param op: operation with inputs ``(data, shape_tensor)``.
    :return: ``([output_shape], [output_dtype])``; dtype copies the data input.
    """
    data, shape_tensor = op.inputs
    src_dims = list(data.shape)
    target_dims = evaluate_shape_tensor_simple(shape_tensor)

    # Left-pad whichever shape has the lower rank with ones.
    rank_gap = len(target_dims) - len(src_dims)
    if rank_gap > 0:
        src_dims = [1] * rank_gap + src_dims
    elif rank_gap < 0:
        target_dims = [1] * (-rank_gap) + target_dims

    # undocumented feature in ONNX, probably comes from pytorch:
    # -1 in the target shape means "keep the input's size for this dimension".
    target_dims = [src if tgt == -1 else tgt
                   for src, tgt in zip(src_dims, target_dims)]

    op.attribs['shape'] = target_dims
    op.inputs = (data, )
    result_shape = infer.elementwise(inputs=[src_dims, target_dims],
                                     broadcast=infer.Broadcast.SAME_RANK)
    return [result_shape], [data.dtype]
def test_elementwise(self):
    """Exercise ``infer.elementwise`` under every ``Broadcast`` mode.

    Checks:

    - NONE returns a fresh list for identical shapes.
    - FROM_RIGHT does not mutate its inputs.
    - SAME_RANK, FROM_LEFT and FROM_RIGHT produce the expected shapes.
    - Each mode raises AssertionError for incompatible shapes.
    """
    # NONE: identical shapes pass through, but the result must be a new list.
    lhs = [1, 2, 3]
    rhs = [1, 2, 3]
    out = infer.elementwise([lhs, rhs], infer.Broadcast.NONE)
    self.assertEqual(out, lhs)
    self.assertIsNot(out, lhs)
    self.assertIsNot(out, rhs)

    # FROM_RIGHT must leave its input shapes (and any list holding them) intact.
    lhs = [1, 2, 3]
    rhs = [3]
    both = [lhs, rhs]
    out = infer.elementwise([lhs, rhs], infer.Broadcast.FROM_RIGHT)
    self.assertEqual([1, 2, 3], out)
    self.assertEqual([1, 2, 3], lhs)
    self.assertEqual([3], rhs)
    self.assertEqual([lhs, rhs], both)

    # NONE rejects any difference in rank or in dimension sizes.
    for bad_shapes in ([[1, 2], [1]], [[1, 2], [1, 3]]):
        with self.assertRaises(AssertionError):
            infer.elementwise(bad_shapes, infer.Broadcast.NONE)

    # (mode, compatible shapes, expected result, incompatible shapes)
    cases = [
        (infer.Broadcast.SAME_RANK, [[1, 4, 1], [2, 1, 3]], [2, 4, 3], [[1, 2], [1]]),
        (infer.Broadcast.FROM_LEFT, [[1, 2, 3], [4, 2]], [4, 2, 3], [[2, 3], [3]]),
        (infer.Broadcast.FROM_RIGHT, [[1, 1, 3], [2, 3]], [1, 2, 3], [[2, 3], [2]]),
    ]
    for mode, shapes, expected, bad_shapes in cases:
        self.assertEqual(expected, infer.elementwise(shapes, mode))
        with self.assertRaises(AssertionError):
            infer.elementwise(bad_shapes, mode)