def test_concat(self):
    """Check ``infer.concat`` for positive and negative axis values."""
    # Each case: (expected output shape, input shapes, axis).
    cases = [
        ([3, 2, 3, 4], [[1, 2, 3, 4], [2, 2, 3, 4]], 0),
        ([3, 2, 3, 4], [[1, 2, 3, 4], [2, 2, 3, 4]], -4),
        ([1, 12, 3, 4], [[1, 2, 3, 4], [1, 10, 3, 4]], 1),
        ([1, 2, 3, 9], [[1, 2, 3, 4], [1, 2, 3, 5]], -1),
    ]
    for expected, input_shapes, axis in cases:
        self.assertEqual(expected, infer.concat(input_shapes, axis))
def propagate_concat(op, const_value_by_tensor):
    # type: (TFOperation, _ConstValueByTensorT)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Propagate output shape and dtype for a concat op whose last input is the axis.

    The last entry of ``op.inputs`` is looked up in ``const_value_by_tensor``
    and converted to a scalar with ``.item()``; every earlier input
    contributes its shape to ``infer.concat``.

    Returns a pair ``([output_shape], [dtype])`` where the dtype is taken
    from the op's ``'T'`` attribute.
    """
    axis_tensor = op.inputs[-1]
    axis = const_value_by_tensor[axis_tensor].item()

    # Everything except the trailing axis tensor is data to be concatenated.
    input_shapes = [tensor.shape for tensor in op.inputs[:-1]]
    output_shape = infer.concat(inputs=input_shapes, axis=axis)

    return [output_shape], [op.attribs['T']]
def propagate_concat(op):
    # type: (ONNXOperation)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Propagate output shape and dtype for a concat op with an ``axis`` attribute.

    The axis is read from ``op.attribs`` and falls back to 1 when the
    attribute is absent.

    Returns a pair ``([output_shape], [dtype])`` where the dtype is that of
    the first input.
    """
    input_shapes = [tensor.shape for tensor in op.inputs]
    axis = op.attribs.get('axis', 1)
    output_shape = infer.concat(inputs=input_shapes, axis=axis)
    output_dtype = op.inputs[0].dtype
    return [output_shape], [output_dtype]
def concat_shape(op):
    # type: (Caffe2Operation)->ShapeResult
    """Infer the shapes and dtypes of both outputs of a Caffe2 concat op.

    The first output is the joined tensor: ``infer.stack`` is used when the
    ``add_axis`` attribute is truthy, ``infer.concat`` otherwise. The second
    output is a 1-D int32 tensor with one entry per input; a constant
    ``Caffe2Tensor`` carrying that data is handed to
    ``graph_utils.replace_tensor_in_consumers`` for ``op.outputs[1]``.

    Returns:
        ``((output_shape, [len(op.inputs)]), (first_input_dtype, DTYPE_INT32))``
    """
    add_axis = op.attribs['add_axis']
    axis = op.attribs['axis']
    input_shapes = [tensor.shape for tensor in op.inputs]

    if add_axis:
        output_shape = infer.stack(input_shapes, axis=axis)
        # With add_axis the axis refers to the *output*, which has one more
        # dimension than the inputs. Indexing an input shape with it raises
        # IndexError when axis == rank(input) and reads the wrong dimension
        # for negative axes. Each input occupies one slot along the new axis.
        # NOTE(review): matches Caffe2's documented split_info for add_axis=1
        # ([1, 1, ...]) -- confirm against the Caffe2 operator reference.
        split_sizes = [1] * len(input_shapes)
    else:
        output_shape = infer.concat(input_shapes, axis=axis)
        split_sizes = [shape[axis] for shape in input_shapes]

    # NOTE(review): presumably this makes consumers of the second output read
    # the constant tensor, with remove=False leaving the original output in
    # the graph -- verify against graph_utils.
    graph_utils.replace_tensor_in_consumers(
        op.graph,
        op.outputs[1],
        Caffe2Tensor(
            graph=op.graph,
            shape=[len(op.inputs)],
            data=np.array(split_sizes, dtype=np.int32),
            dtype=DTYPE_INT32),
        remove=False)

    return (output_shape, [len(op.inputs)]), (op.inputs[0].dtype, DTYPE_INT32)
def concat_shape(op):
    """Return ``shapes.concat`` of the input shapes along the op's ``axis`` attribute."""
    input_shapes = [tensor.shape for tensor in op.inputs]
    axis = op.attribs['axis']
    return shapes.concat(inputs=input_shapes, axis=axis)