def propagate_pack(op, const_value_by_tensor):
    # type: (TFOperation, _ConstValueByTensorT)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Propagate the output shape and dtype of a TF Pack (stack) operation.

    The shapes of all inputs are stacked along ``op.attribs['axis']`` using
    ``infer.stack``. The output dtype is copied from ``op.attribs['T']``.
    ``const_value_by_tensor`` is accepted for interface uniformity but is
    not read here.

    Returns:
        A pair ``(shapes, dtypes)``. Each is a one-element list describing
        the single output tensor.
    """
    input_shapes = [tensor.shape for tensor in op.inputs]
    packed_shape = infer.stack(inputs=input_shapes, axis=op.attribs['axis'])
    output_dtype = op.attribs['T']
    return [packed_shape], [output_dtype]
Beispiel #2
0
def concat_shape(op):
    # type: (Caffe2Operation)->ShapeResult
    """Infer the outputs of a Caffe2 Concat op and materialize its split-info tensor.

    If ``op.attribs['add_axis']`` is truthy, the inputs are stacked along a
    new axis at ``op.attribs['axis']``. Otherwise they are concatenated along
    the existing axis ``op.attribs['axis']``.

    Side effect: the second output (``op.outputs[1]``, the split info) is
    replaced in all of its consumers by a constant int32 tensor of length
    ``len(op.inputs)``. That tensor holds the extent each input contributes
    along the concat axis. The original tensor is kept in the graph
    (``remove=False``).

    Returns:
        ``((output_shape, [len(op.inputs)]), (input_dtype, DTYPE_INT32))``:
        the shapes and dtypes of the concatenated output and the split info.
    """
    add_axis = op.attribs['add_axis']
    axis = op.attribs['axis']
    input_shapes = [tensor.shape for tensor in op.inputs]

    if add_axis:
        output_shape = infer.stack(input_shapes, axis=axis)
        # With add_axis each input contributes exactly one slice along the
        # newly inserted axis (Caffe2 writes 1 into split_info in this case).
        # Indexing tensor.shape[axis] would read an unrelated dimension, or
        # raise IndexError when axis == rank.
        split_sizes = [1] * len(op.inputs)
    else:
        output_shape = infer.concat(input_shapes, axis=axis)
        split_sizes = [tensor.shape[axis] for tensor in op.inputs]

    graph_utils.replace_tensor_in_consumers(
        op.graph,
        op.outputs[1],
        Caffe2Tensor(
            graph=op.graph,
            shape=[len(op.inputs)],
            data=np.array(split_sizes, dtype=np.int32),
            dtype=DTYPE_INT32),
        remove=False)

    return (output_shape, [len(op.inputs)]), (op.inputs[0].dtype, DTYPE_INT32)
 def test_stack(self):
     """Check infer.stack on three [1, 2] shapes for each insertion axis.

     A negative axis is expected to count from the end of the output rank,
     so -1 gives the same result as 2 here.
     """
     cases = [
         (2, [1, 2, 3]),
         (-1, [1, 2, 3]),
         (0, [3, 1, 2]),
         (1, [1, 3, 2]),
     ]
     for axis, expected in cases:
         # Build fresh input lists on every call, as the original literals did.
         shapes = [[1, 2] for _ in range(3)]
         self.assertEqual(expected, infer.stack(shapes, axis))