def propagate_pack(op, const_value_by_tensor):
    # type: (TFOperation, _ConstValueByTensorT)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Propagate the output shape and dtype of a Pack operation.

    The single output shape is the shapes of all inputs stacked together with
    ``infer.stack`` along ``op.attribs['axis']``. The single output dtype is
    read from ``op.attribs['T']``.

    :param op: the Pack operation whose output is being inferred
    :param const_value_by_tensor: not used by this propagator
    :return: a pair ``([output_shape], [output_dtype])``
    """
    input_shapes = [tensor.shape for tensor in op.inputs]
    stack_axis = op.attribs['axis']
    output_shapes = [infer.stack(inputs=input_shapes, axis=stack_axis)]
    output_dtypes = [op.attribs['T']]
    return output_shapes, output_dtypes
def concat_shape(op):
    # type: (Caffe2Operation)->ShapeResult
    """Infer the output shapes and dtypes of a Concat operation.

    The operation has two outputs:
      * the joined tensor, whose shape comes from ``infer.concat`` along
        ``op.attribs['axis']``. When ``op.attribs['add_axis']`` is set, it
        comes from ``infer.stack`` instead, which inserts a new axis.
      * ``split_info``, a 1-D int32 tensor with one entry per input holding
        how many elements that input contributed along the join axis.

    The consumers of ``split_info`` (``op.outputs[1]``) are rewired to a new
    constant tensor that holds those per-input sizes. The original tensor is
    kept in the graph (``remove=False``).

    :param op: the Concat operation being inferred
    :return: ``((output_shape, [num_inputs]), (input_dtype, DTYPE_INT32))``
    """
    axis = op.attribs['axis']
    input_shapes = [tensor.shape for tensor in op.inputs]

    if op.attribs['add_axis']:
        output_shape = infer.stack(input_shapes, axis=axis)
        # With add_axis the inputs are stacked: `axis` indexes the *new*
        # dimension, which the inputs do not have. Each input contributes
        # exactly one slice, so split_info is all ones. Indexing the input
        # shapes at `axis` here would read an unrelated dimension, or raise
        # IndexError when axis == input rank.
        split_sizes = [1] * len(op.inputs)
    else:
        output_shape = infer.concat(input_shapes, axis=axis)
        split_sizes = [shape[axis] for shape in input_shapes]

    graph_utils.replace_tensor_in_consumers(
        op.graph,
        op.outputs[1],
        Caffe2Tensor(
            graph=op.graph,
            shape=[len(op.inputs)],
            data=np.array(split_sizes, dtype=np.int32),
            dtype=DTYPE_INT32),
        remove=False)

    return (output_shape, [len(op.inputs)]), (op.inputs[0].dtype, DTYPE_INT32)
def test_stack(self):
    """Check ``infer.stack`` on three [1, 2] shapes at every insertion axis."""
    # (axis, expected stacked shape). -1 places the new axis last, same as 2.
    expectations = [
        (2, [1, 2, 3]),
        (-1, [1, 2, 3]),
        (0, [3, 1, 2]),
        (1, [1, 3, 2]),
    ]
    for axis, expected_shape in expectations:
        # Build fresh input lists for every call, as the original literals did.
        stacked_inputs = [[1, 2] for _ in range(3)]
        self.assertEqual(expected_shape, infer.stack(stacked_inputs, axis))