def Split3(in_tsr):
    """Move `in_tsr` from mesh 'mesh0' to mesh 'mesh1' via concat/split, then run.

    Two meshes are created on a fresh mtf graph:
      * 'mesh0' is backed by ``GetMeshImpl([2, 2], [0, 2, 4, 6])``.
      * 'mesh1' is backed by ``GetMeshImpl([2, 4])``.
    (Presumably the first argument is the mesh shape and the second a device
    list; confirm against GetMeshImpl.)

    Dimension 2 of the input is tagged with the name 'axis1'. The other
    dimensions are passed to GetShape unchanged. The tensor is imported on
    'mesh0', re-meshed onto 'mesh1' with ``mt.ReplaceMeshWithConcatSplit``,
    and both the TF input and the mtf result are handed to ``Run``.

    Args:
      in_tsr: a TF tensor whose static shape has at least 3 dimensions
        (indexing dimension 2 raises IndexError otherwise).
    """
    g = mtf.Graph()
    src_mesh = mtf.Mesh(g, 'mesh0')
    dst_mesh = mtf.Mesh(g, 'mesh1')
    src_impl = GetMeshImpl([2, 2], [0, 2, 4, 6])
    dst_impl = GetMeshImpl([2, 4])
    impls = {src_mesh: src_impl, dst_mesh: dst_impl}

    dims = in_tsr.get_shape().as_list()
    # Copy the dims and replace entry 2 with a ('axis1', size) pair.
    # This is equivalent to dims[:2] + [('axis1', dims[2])] + dims[3:].
    layout = list(dims)
    layout[2] = ('axis1', dims[2])
    mtf_shape = GetShape(layout)

    src = mtf.import_tf_tensor(src_mesh, in_tsr, mtf_shape)
    dst = mt.ReplaceMeshWithConcatSplit(
        src, dst_mesh, mtf_shape.dimension_names)
    Run(g, impls, in_tsr, dst)


def NoConcatSplit(in_tsr):
    """Re-mesh `in_tsr` between two meshes with the same layout, then run.

    Both 'mesh0' and 'mesh1' are backed by ``GetMeshImpl([4, 2])``. Judging
    by the function name, the concat/split is presumably expected to be
    trivial in this case; verify against ``mt.ReplaceMeshWithConcatSplit``.

    Dimension 1 of the input is tagged with the name 'axis0'. The other
    dimensions are passed to GetShape unchanged. The tensor is imported on
    'mesh0', moved to 'mesh1', and both the TF input and the mtf result are
    handed to ``Run``.

    Args:
      in_tsr: a TF tensor whose static shape has at least 2 dimensions
        (indexing dimension 1 raises IndexError otherwise).
    """
    g = mtf.Graph()
    src_mesh = mtf.Mesh(g, 'mesh0')
    dst_mesh = mtf.Mesh(g, 'mesh1')
    src_impl = GetMeshImpl([4, 2])
    dst_impl = GetMeshImpl([4, 2])
    impls = {src_mesh: src_impl, dst_mesh: dst_impl}

    dims = in_tsr.get_shape().as_list()
    # Copy the dims and replace entry 1 with a ('axis0', size) pair.
    # This is equivalent to [dims[0], ('axis0', dims[1])] + dims[2:].
    layout = list(dims)
    layout[1] = ('axis0', dims[1])
    mtf_shape = GetShape(layout)

    src = mtf.import_tf_tensor(src_mesh, in_tsr, mtf_shape)
    dst = mt.ReplaceMeshWithConcatSplit(
        src, dst_mesh, mtf_shape.dimension_names)
    Run(g, impls, in_tsr, dst)