def Split3(in_tsr):
    """Move `in_tsr` from one mesh to a differently laid-out mesh and run it.

    Two meshes, 'mesh0' and 'mesh1', are created on a fresh mtf.Graph.
    - 'mesh0' is mapped to GetMeshImpl([2, 2], [0, 2, 4, 6]). The second
      argument is presumably an explicit device list (TODO confirm against
      GetMeshImpl).
    - 'mesh1' is mapped to GetMeshImpl([2, 4]).

    The input is imported onto 'mesh0' and transferred to 'mesh1' with
    mt.ReplaceMeshWithConcatSplit. The graph, the mesh-to-impl mapping, the
    original tensor and the transferred tensor are then handed to Run.

    Args:
      in_tsr: a TF tensor whose static shape is fully known; `as_list()` is
        called on it. It must have at least three dimensions, because
        index 2 is read.
    """
    graph = mtf.Graph()
    src_mesh = mtf.Mesh(graph, 'mesh0')
    dst_mesh = mtf.Mesh(graph, 'mesh1')
    impl_by_mesh = {
        src_mesh: GetMeshImpl([2, 2], [0, 2, 4, 6]),
        dst_mesh: GetMeshImpl([2, 4]),
    }

    dims = in_tsr.get_shape().as_list()
    # Only the third dimension gets an explicit ('axis1', size) spec; the
    # others are passed to GetShape as plain sizes.
    mtf_shape = GetShape([*dims[:2], ('axis1', dims[2]), *dims[3:]])

    mtf_input = mtf.import_tf_tensor(src_mesh, in_tsr, mtf_shape)
    mtf_output = mt.ReplaceMeshWithConcatSplit(
        mtf_input, dst_mesh, mtf_shape.dimension_names)
    Run(graph, impl_by_mesh, in_tsr, mtf_output)
def NoConcatSplit(in_tsr):
    """Move `in_tsr` between two meshes that share the same 4x2 layout.

    Both 'mesh0' and 'mesh1' are mapped to GetMeshImpl([4, 2]). Judging by
    the name, this presumably exercises the path where
    mt.ReplaceMeshWithConcatSplit needs no real concat/split (TODO confirm).

    The input is imported onto 'mesh0' and transferred to 'mesh1'. The
    graph, the mesh-to-impl mapping, the original tensor and the transferred
    tensor are then handed to Run.

    Args:
      in_tsr: a TF tensor whose static shape is fully known; `as_list()` is
        called on it. It must have at least two dimensions, because index 1
        is read.
    """
    graph = mtf.Graph()
    src_mesh = mtf.Mesh(graph, 'mesh0')
    dst_mesh = mtf.Mesh(graph, 'mesh1')
    impl_by_mesh = {
        src_mesh: GetMeshImpl([4, 2]),
        dst_mesh: GetMeshImpl([4, 2]),
    }

    dims = in_tsr.get_shape().as_list()
    # Only the second dimension gets an explicit ('axis0', size) spec; the
    # others are passed to GetShape as plain sizes.
    mtf_shape = GetShape([dims[0], ('axis0', dims[1]), *dims[2:]])

    mtf_input = mtf.import_tf_tensor(src_mesh, in_tsr, mtf_shape)
    mtf_output = mt.ReplaceMeshWithConcatSplit(
        mtf_input, dst_mesh, mtf_shape.dimension_names)
    Run(graph, impl_by_mesh, in_tsr, mtf_output)