Exemplo n.º 1
0
def transform_polydata_from_disk(in_filename, transform_filename, out_filename):
    """Read a polydata file, apply a transform stored on disk, and save it.

    The transform is passed as a file name rather than as an object because
    (per the original author's note) the transform cannot be pickled. The
    extension of ``transform_filename`` selects how it is loaded:

    * ``.xfm``: read with vtkMNITransformReader.
    * ``.img``: read with vtkImageReader and used as the coefficient image
      of a vtkBSplineTransform.
    * anything else: opened as text; the first 16 lines are parsed as
      floats and fill a 4x4 matrix row by row.

    Progress, debug and timing information is printed to standard output.

    Args:
        in_filename: polydata file handed to ``read_polydata``.
        transform_filename: file holding the transform to apply.
        out_filename: destination handed to ``write_polydata``.

    Raises:
        ValueError: if a matrix line cannot be parsed as a float.
    """
    # Every print() below receives one pre-formatted argument, so the output
    # is the same whether this runs under Python 2 or Python 3.
    print("<io.py> Transforming  %s -> %s ..." % (in_filename, out_filename))

    vtk6_or_later = vtk.vtkVersion().GetVTKMajorVersion() >= 6.0

    # Read the transform from disk because we cannot pickle it
    (root, ext) = os.path.splitext(transform_filename)
    print("%s %s" % (root, ext))
    if ext == '.xfm':
        reader = vtk.vtkMNITransformReader()
        reader.SetFileName(transform_filename)
        reader.Update()
        transform = reader.GetTransform()
    elif ext == '.img':
        reader = vtk.vtkImageReader()
        reader.SetFileName(transform_filename)
        reader.Update()
        coeffs = reader.GetOutput()
        transform = vtk.vtkBSplineTransform()
        # Same version split that convert_transform_to_vtk uses: the setter
        # is called SetCoefficientData from VTK 6 on.
        if vtk6_or_later:
            transform.SetCoefficientData(coeffs)
        else:
            transform.SetCoefficients(coeffs)
        print(coeffs)
        print(transform)
    else:
        transform = vtk.vtkTransform()
        matrix = vtk.vtkMatrix4x4()
        # One matrix element per line, row-major. The context manager makes
        # sure the file is closed even when a line fails to parse.
        with open(transform_filename, 'r') as f:
            for i in range(4):
                for j in range(4):
                    matrix.SetElement(i, j, float(f.readline()))
        transform.SetMatrix(matrix)

    start_time = time.time()
    pd = read_polydata(in_filename)
    elapsed_time = time.time() - start_time
    print("READ: %s" % (elapsed_time,))

    # Transform it.
    start_time = time.time()
    transformer = vtk.vtkTransformPolyDataFilter()
    if vtk6_or_later:
        transformer.SetInputData(pd)
    else:
        transformer.SetInput(pd)
    transformer.SetTransform(transform)
    transformer.Update()
    elapsed_time = time.time() - start_time
    print("TXFORM: %s" % (elapsed_time,))

    # Write it out.
    start_time = time.time()
    write_polydata(transformer.GetOutput(), out_filename)
    elapsed_time = time.time() - start_time
    print("WRITE: %s" % (elapsed_time,))
Exemplo n.º 2
0
def transform_polydata_from_disk(in_filename, transform_filename, out_filename):
    """Read a polydata file, apply a transform stored on disk, and save it.

    The transform is passed as a file name rather than as an object because
    (per the original author's note) the transform cannot be pickled. The
    extension of ``transform_filename`` selects how it is loaded:

    * ``.xfm``: read with vtkMNITransformReader.
    * ``.img``: read with vtkImageReader and used as the coefficient image
      of a vtkBSplineTransform.
    * anything else: opened as text; the first 16 lines are parsed as
      floats and fill a 4x4 matrix row by row.

    Progress, debug and timing information is printed to standard output.

    Args:
        in_filename: polydata file handed to ``read_polydata``.
        transform_filename: file holding the transform to apply.
        out_filename: destination handed to ``write_polydata``.

    Raises:
        ValueError: if a matrix line cannot be parsed as a float.
    """
    # Every print() below receives one pre-formatted argument, so the output
    # is the same whether this runs under Python 2 or Python 3.
    print("<io.py> Transforming  %s -> %s ..." % (in_filename, out_filename))

    vtk6_or_later = vtk.vtkVersion().GetVTKMajorVersion() >= 6.0

    # Read the transform from disk because we cannot pickle it
    (root, ext) = os.path.splitext(transform_filename)
    print("%s %s" % (root, ext))
    if ext == '.xfm':
        reader = vtk.vtkMNITransformReader()
        reader.SetFileName(transform_filename)
        reader.Update()
        transform = reader.GetTransform()
    elif ext == '.img':
        reader = vtk.vtkImageReader()
        reader.SetFileName(transform_filename)
        reader.Update()
        coeffs = reader.GetOutput()
        transform = vtk.vtkBSplineTransform()
        # Same version split that convert_transform_to_vtk uses: the setter
        # is called SetCoefficientData from VTK 6 on.
        if vtk6_or_later:
            transform.SetCoefficientData(coeffs)
        else:
            transform.SetCoefficients(coeffs)
        print(coeffs)
        print(transform)
    else:
        transform = vtk.vtkTransform()
        matrix = vtk.vtkMatrix4x4()
        # One matrix element per line, row-major. The context manager makes
        # sure the file is closed even when a line fails to parse.
        with open(transform_filename, 'r') as f:
            for i in range(4):
                for j in range(4):
                    matrix.SetElement(i, j, float(f.readline()))
        transform.SetMatrix(matrix)

    start_time = time.time()
    pd = read_polydata(in_filename)
    elapsed_time = time.time() - start_time
    print("READ: %s" % (elapsed_time,))

    # Transform it.
    start_time = time.time()
    transformer = vtk.vtkTransformPolyDataFilter()
    if vtk6_or_later:
        transformer.SetInputData(pd)
    else:
        transformer.SetInput(pd)
    transformer.SetTransform(transform)
    transformer.Update()
    elapsed_time = time.time() - start_time
    print("TXFORM: %s" % (elapsed_time,))

    # Write it out.
    start_time = time.time()
    write_polydata(transformer.GetOutput(), out_filename)
    elapsed_time = time.time() - start_time
    print("WRITE: %s" % (elapsed_time,))
Exemplo n.º 3
0
def convert_transform_to_vtk(transform):
    """Build a vtkBSplineTransform from a registration displacement field.

    ``transform`` is a numpy array holding the displacement field. Its
    length divided by three is taken as the number of displacement vectors,
    and the grid is treated as a cube whose edge (in nodes) is the rounded
    cube root of that count. The grid is placed as a 200mm cube centred on
    the origin, which hard-codes an assumption about where the polydata
    lies in space; other code should check that it is centred.

    Returns the vtkBSplineTransform, with its border mode set to zero.
    """
    # Wrap the displacements as a named, 3-component float VTK array.
    field = vtk.util.numpy_support.numpy_to_vtk(
        num_array=transform, deep=True, array_type=vtk.VTK_FLOAT)
    field.SetNumberOfComponents(3)
    field.SetName('DisplacementField')

    # VTK 6 changed several pipeline calls; evaluate the check once.
    modern_vtk = vtk.vtkVersion().GetVTKMajorVersion() >= 6.0

    image = vtk.vtkImageData()
    if modern_vtk:
        image.AllocateScalars(vtk.VTK_FLOAT, 3)
    else:
        image.SetScalarTypeToFloat()
        image.SetNumberOfScalarComponents(3)
    image.GetPointData().SetScalars(field)
    if not modern_vtk:
        image.Update()

    # This MUST correspond to the size used in congeal_multisubject
    # update_nonrigid_grid.
    size_mm = 200.0
    vector_count = len(transform) / 3
    edge_nodes = round(numpy.power(vector_count, 1.0 / 3.0))
    corner = -size_mm / 2.0
    step = size_mm / (edge_nodes - 1)
    node_count = int(edge_nodes)
    image.SetOrigin(corner, corner, corner)
    image.SetSpacing(step, step, step)
    image.SetDimensions(node_count, node_count, node_count)

    # Turn the sampled displacements into B-spline coefficients.
    coeff_filter = vtk.vtkImageBSplineCoefficients()
    if modern_vtk:
        attach_input = coeff_filter.SetInputData
    else:
        attach_input = coeff_filter.SetInput
    attach_input(image)
    coeff_filter.Update()
    # this was in the test code.
    coeff_filter.UpdateWholeExtent()

    bspline = vtk.vtkBSplineTransform()
    if modern_vtk:
        attach_coefficients = bspline.SetCoefficientData
    else:
        attach_coefficients = bspline.SetCoefficients
    attach_coefficients(coeff_filter.GetOutput())
    bspline.SetBorderModeToZero()

    return bspline
Exemplo n.º 4
0
def _write_itk_affine_transform(tx, fname):
    """Save the inverse of affine ``tx``, converted to LPS, as an ITK .tfm."""
    tx_inverse = vtk.vtkTransform()
    tx_inverse.DeepCopy(tx)
    tx_inverse.Inverse()

    # Flipping the first two axes converts between RAS and LPS; the same
    # scaling serves for both directions.
    ras_2_lps = vtk.vtkTransform()
    ras_2_lps.Scale(-1, -1, 1)
    lps_2_ras = vtk.vtkTransform()
    lps_2_ras.Scale(-1, -1, 1)
    tx2 = vtk.vtkTransform()
    tx2.Concatenate(lps_2_ras)
    tx2.Concatenate(tx_inverse)
    tx2.Concatenate(ras_2_lps)

    matrix = tx2.GetMatrix()
    three_by_three = [matrix.GetElement(i, j)
                      for i in range(3) for j in range(3)]
    translation = [matrix.GetElement(i, 3) for i in range(3)]

    with open(fname, 'w') as f:
        f.write('#Insight Transform File V1.0\n')
        f.write('# Transform 0\n')
        f.write('Transform: AffineTransform_double_3_3\n')
        f.write('Parameters: ')
        f.writelines('{0} '.format(el) for el in three_by_three + translation)
        f.write('\nFixedParameters: 0 0 0\n')


def _write_itk_spline_displacements(tx, fname):
    """Sample the inverse of spline ``tx`` on an LPS grid and save it as .tfm."""
    # Deep copy to avoid modifying input transform that will be applied to
    # polydata.
    if tx.GetClassName() == 'vtkThinPlateSplineTransform':
        spline = vtk.vtkThinPlateSplineTransform()
    else:
        spline = vtk.vtkBSplineTransform()
    spline.DeepCopy(tx)

    # invert to get the transform suitable for resampling an image
    spline.Inverse()

    # convert the inverse spline transform from RAS to LPS
    ras_2_lps = vtk.vtkTransform()
    ras_2_lps.Scale(-1, -1, 1)
    lps_2_ras = vtk.vtkTransform()
    lps_2_ras.Scale(-1, -1, 1)
    spline_inverse_lps = vtk.vtkGeneralTransform()
    spline_inverse_lps.Concatenate(lps_2_ras)
    spline_inverse_lps.Concatenate(spline)
    spline_inverse_lps.Concatenate(ras_2_lps)

    # Grid resolution (original author's notes): coarser grids (15 nodes at
    # 10mm) gave 1-2mm differences against Slicer; this one limited the
    # differences to under .1mm in tests at the cost of large files (47M).
    # The file format has no inverse flag, so the inverse has to be stored
    # at high resolution instead of the forward transform on its own grid.
    grid_size = [105, 105, 105]
    grid_spacing = 2

    # Floor division keeps the extents integral (range() needs ints) on
    # Python 3 as well as Python 2.
    extent_0 = [-(n - 1) // 2 for n in grid_size]
    extent_1 = [(n - 1) // 2 for n in grid_size]

    origin = -grid_spacing * (numpy.array(extent_1) -
                              numpy.array(extent_0)) / 2.0

    # ordering of grid points must match itk-style array order for images
    # NOTE(review): the outer (S) loop reads extent index 0 and the inner
    # (L) loop index 2; this only matters if the grid stops being cubic.
    grid_points_LPS = list()
    for s_idx in range(extent_0[0], extent_1[0] + 1):
        for p_idx in range(extent_0[1], extent_1[1] + 1):
            for l_idx in range(extent_0[2], extent_1[2] + 1):
                grid_points_LPS.append([l_idx * grid_spacing,
                                        p_idx * grid_spacing,
                                        s_idx * grid_spacing])

    # Single-argument print() gives the same output on Python 2 and 3.
    print("LPS grid for storing transform: %s %s %s" %
          (grid_points_LPS[0], grid_points_LPS[-1], grid_spacing))

    # Push every grid point through the inverse transform; the stored value
    # is the displacement (transformed point minus grid point).
    lps_points = vtk.vtkPoints()
    lps_points2 = vtk.vtkPoints()
    for gp_lps in grid_points_LPS:
        lps_points.InsertNextPoint(gp_lps[0], gp_lps[1], gp_lps[2])
    spline_inverse_lps.TransformPoints(lps_points, lps_points2)

    displacements_LPS = list()
    for pidx, gp_lps in enumerate(grid_points_LPS):
        pt = lps_points2.GetPoint(pidx)
        displacements_LPS.append(
            [pt[0] - gp_lps[0], pt[1] - gp_lps[1], pt[2] - gp_lps[2]])

    with open(fname, 'w') as f:
        f.write('#Insight Transform File V1.0\n')
        f.write('# Transform 0\n')
        # ITK version 4 name; it does not include a second transform in the
        # file (ITK 3 used BSplineDeformableTransform_double_3_3).
        f.write('Transform: BSplineTransform_double_3_3\n')
        f.write('Parameters: ')
        # Three blocks: all dx values, then all dy, then all dz.
        for axis in range(3):
            f.writelines('{0} '.format(diff[axis])
                         for diff in displacements_LPS)

        # FixedParameters: grid size, origin, spacing, direction cosines.
        f.write('\nFixedParameters:')
        f.write(' {0} {1} {2}'.format(
            grid_size[0], grid_size[1], grid_size[2]))
        f.write(' {0} {1} {2}'.format(origin[0], origin[1], origin[2]))
        f.write(' {0} {0} {0}'.format(grid_spacing))
        f.write(' 1 0 0 0 1 0 0 0 1\n')


def write_transforms_to_itk_format(transform_list, outdir, subject_ids=None):
    """Write VTK affine or spline transforms to ITK 4 text file formats.

    Input transforms are in VTK RAS space and are forward transforms. Output
    transforms are in LPS space and are the corresponding inverse
    transforms, according to the conventions for these file formats and for
    resampling images. Affine transforms are written as their matrix and
    translation. Thin-plate and B-spline transforms are written as a list of
    displacements sampled on a regular grid, in the order ITK stores them.

    Every transform that is not a vtkBSplineTransform is also saved as it
    is, in MNI format, to ``vtk_txform_<id>.xfm`` in ``outdir``.

    Args:
        transform_list: VTK transforms to export.
        outdir: directory that receives the output files.
        subject_ids: optional sequence indexed in step with
            ``transform_list``; ``str()`` of each entry is used in the file
            names. When None, the zero-padded five-digit index is used.

    Returns:
        List of the ``itk_txform_<id>.tfm`` paths written, in input order.
    """
    tx_fnames = list()
    for idx, tx in enumerate(transform_list):
        class_name = tx.GetClassName()
        if subject_ids is not None:
            label = str(subject_ids[idx])
        else:
            label = '{0:05d}'.format(idx)

        # save out the vtk transform to a text file as it is
        # The MNI transform reader/writer are available in vtk so use those:
        if class_name != 'vtkBSplineTransform':
            writer = vtk.vtkMNITransformWriter()
            writer.AddTransform(tx)
            writer.SetFileName(
                os.path.join(outdir, 'vtk_txform_' + label + '.xfm'))
            writer.Write()

        fname = os.path.join(outdir, 'itk_txform_' + label + '.tfm')
        tx_fnames.append(fname)

        # Save the itk transform as the inverse of this transform (resampling
        # transform) and in LPS. To apply our transform to resample a volume
        # in LPS: convert to RAS, use inverse of transform to resample,
        # convert back to LPS.
        if class_name in ('vtkThinPlateSplineTransform',
                          'vtkBSplineTransform'):
            _write_itk_spline_displacements(tx, fname)
        else:
            _write_itk_affine_transform(tx, fname)

    return tx_fnames
Exemplo n.º 5
0
def _write_itk_affine_transform(tx, fname):
    """Save the inverse of affine ``tx``, converted to LPS, as an ITK .tfm."""
    tx_inverse = vtk.vtkTransform()
    tx_inverse.DeepCopy(tx)
    tx_inverse.Inverse()

    # Flipping the first two axes converts between RAS and LPS; the same
    # scaling serves for both directions.
    ras_2_lps = vtk.vtkTransform()
    ras_2_lps.Scale(-1, -1, 1)
    lps_2_ras = vtk.vtkTransform()
    lps_2_ras.Scale(-1, -1, 1)
    tx2 = vtk.vtkTransform()
    tx2.Concatenate(lps_2_ras)
    tx2.Concatenate(tx_inverse)
    tx2.Concatenate(ras_2_lps)

    matrix = tx2.GetMatrix()
    three_by_three = [matrix.GetElement(i, j)
                      for i in range(3) for j in range(3)]
    translation = [matrix.GetElement(i, 3) for i in range(3)]

    with open(fname, 'w') as f:
        f.write('#Insight Transform File V1.0\n')
        f.write('# Transform 0\n')
        f.write('Transform: AffineTransform_double_3_3\n')
        f.write('Parameters: ')
        f.writelines('{0} '.format(el) for el in three_by_three + translation)
        f.write('\nFixedParameters: 0 0 0\n')


def _write_itk_spline_displacements(tx, fname):
    """Sample the inverse of spline ``tx`` on an LPS grid and save it as .tfm."""
    # Deep copy to avoid modifying input transform that will be applied to
    # polydata.
    if tx.GetClassName() == 'vtkThinPlateSplineTransform':
        spline = vtk.vtkThinPlateSplineTransform()
    else:
        spline = vtk.vtkBSplineTransform()
    spline.DeepCopy(tx)

    # invert to get the transform suitable for resampling an image
    spline.Inverse()

    # convert the inverse spline transform from RAS to LPS
    ras_2_lps = vtk.vtkTransform()
    ras_2_lps.Scale(-1, -1, 1)
    lps_2_ras = vtk.vtkTransform()
    lps_2_ras.Scale(-1, -1, 1)
    spline_inverse_lps = vtk.vtkGeneralTransform()
    spline_inverse_lps.Concatenate(lps_2_ras)
    spline_inverse_lps.Concatenate(spline)
    spline_inverse_lps.Concatenate(ras_2_lps)

    # Grid resolution (original author's notes): coarser grids (15 nodes at
    # 10mm) gave 1-2mm differences against Slicer; this one limited the
    # differences to under .1mm in tests at the cost of large files (47M).
    # The file format has no inverse flag, so the inverse has to be stored
    # at high resolution instead of the forward transform on its own grid.
    grid_size = [105, 105, 105]
    grid_spacing = 2

    # Floor division keeps the extents integral (range() needs ints) on
    # Python 3 as well as Python 2.
    extent_0 = [-(n - 1) // 2 for n in grid_size]
    extent_1 = [(n - 1) // 2 for n in grid_size]

    origin = -grid_spacing * (numpy.array(extent_1) -
                              numpy.array(extent_0)) / 2.0

    # ordering of grid points must match itk-style array order for images
    # NOTE(review): the outer (S) loop reads extent index 0 and the inner
    # (L) loop index 2; this only matters if the grid stops being cubic.
    grid_points_LPS = list()
    for s_idx in range(extent_0[0], extent_1[0] + 1):
        for p_idx in range(extent_0[1], extent_1[1] + 1):
            for l_idx in range(extent_0[2], extent_1[2] + 1):
                grid_points_LPS.append([l_idx * grid_spacing,
                                        p_idx * grid_spacing,
                                        s_idx * grid_spacing])

    # Single-argument print() gives the same output on Python 2 and 3.
    print("LPS grid for storing transform: %s %s %s" %
          (grid_points_LPS[0], grid_points_LPS[-1], grid_spacing))

    # Push every grid point through the inverse transform; the stored value
    # is the displacement (transformed point minus grid point).
    lps_points = vtk.vtkPoints()
    lps_points2 = vtk.vtkPoints()
    for gp_lps in grid_points_LPS:
        lps_points.InsertNextPoint(gp_lps[0], gp_lps[1], gp_lps[2])
    spline_inverse_lps.TransformPoints(lps_points, lps_points2)

    displacements_LPS = list()
    for pidx, gp_lps in enumerate(grid_points_LPS):
        pt = lps_points2.GetPoint(pidx)
        displacements_LPS.append(
            [pt[0] - gp_lps[0], pt[1] - gp_lps[1], pt[2] - gp_lps[2]])

    with open(fname, 'w') as f:
        f.write('#Insight Transform File V1.0\n')
        f.write('# Transform 0\n')
        # ITK version 4 name; it does not include a second transform in the
        # file (ITK 3 used BSplineDeformableTransform_double_3_3).
        f.write('Transform: BSplineTransform_double_3_3\n')
        f.write('Parameters: ')
        # Three blocks: all dx values, then all dy, then all dz.
        for axis in range(3):
            f.writelines('{0} '.format(diff[axis])
                         for diff in displacements_LPS)

        # FixedParameters: grid size, origin, spacing, direction cosines.
        f.write('\nFixedParameters:')
        f.write(' {0} {1} {2}'.format(
            grid_size[0], grid_size[1], grid_size[2]))
        f.write(' {0} {1} {2}'.format(origin[0], origin[1], origin[2]))
        f.write(' {0} {0} {0}'.format(grid_spacing))
        f.write(' 1 0 0 0 1 0 0 0 1\n')


def write_transforms_to_itk_format(transform_list, outdir, subject_ids=None):
    """Write VTK affine or spline transforms to ITK 4 text file formats.

    Input transforms are in VTK RAS space and are forward transforms. Output
    transforms are in LPS space and are the corresponding inverse
    transforms, according to the conventions for these file formats and for
    resampling images. Affine transforms are written as their matrix and
    translation. Thin-plate and B-spline transforms are written as a list of
    displacements sampled on a regular grid, in the order ITK stores them.

    Every transform that is not a vtkBSplineTransform is also saved as it
    is, in MNI format, to ``vtk_txform_<id>.xfm`` in ``outdir``.

    Args:
        transform_list: VTK transforms to export.
        outdir: directory that receives the output files.
        subject_ids: optional sequence indexed in step with
            ``transform_list``; ``str()`` of each entry is used in the file
            names. When None, the zero-padded five-digit index is used.

    Returns:
        List of the ``itk_txform_<id>.tfm`` paths written, in input order.
    """
    tx_fnames = list()
    for idx, tx in enumerate(transform_list):
        class_name = tx.GetClassName()
        if subject_ids is not None:
            label = str(subject_ids[idx])
        else:
            label = '{0:05d}'.format(idx)

        # save out the vtk transform to a text file as it is
        # The MNI transform reader/writer are available in vtk so use those:
        if class_name != 'vtkBSplineTransform':
            writer = vtk.vtkMNITransformWriter()
            writer.AddTransform(tx)
            writer.SetFileName(
                os.path.join(outdir, 'vtk_txform_' + label + '.xfm'))
            writer.Write()

        fname = os.path.join(outdir, 'itk_txform_' + label + '.tfm')
        tx_fnames.append(fname)

        # Save the itk transform as the inverse of this transform (resampling
        # transform) and in LPS. To apply our transform to resample a volume
        # in LPS: convert to RAS, use inverse of transform to resample,
        # convert back to LPS.
        if class_name in ('vtkThinPlateSplineTransform',
                          'vtkBSplineTransform'):
            _write_itk_spline_displacements(tx, fname)
        else:
            _write_itk_affine_transform(tx, fname)

    return tx_fnames
def convert_transform_to_vtk(transform):
    """Build a vtkBSplineTransform from a registration displacement field.

    ``transform`` is a numpy array holding the displacement field. Its
    length divided by three is taken as the number of displacement vectors,
    and the grid is treated as a cube whose edge (in nodes) is the rounded
    cube root of that count. The grid is placed as a 200mm cube centred on
    the origin, which hard-codes an assumption about where the polydata
    lies in space; other code should check that it is centred.

    Returns the vtkBSplineTransform, with its border mode set to zero.
    """
    # Wrap the displacements as a named, 3-component float VTK array.
    field = vtk.util.numpy_support.numpy_to_vtk(
        num_array=transform, deep=True, array_type=vtk.VTK_FLOAT)
    field.SetNumberOfComponents(3)
    field.SetName('DisplacementField')

    # VTK 6 changed several pipeline calls; evaluate the check once.
    modern_vtk = vtk.vtkVersion().GetVTKMajorVersion() >= 6.0

    image = vtk.vtkImageData()
    if modern_vtk:
        image.AllocateScalars(vtk.VTK_FLOAT, 3)
    else:
        image.SetScalarTypeToFloat()
        image.SetNumberOfScalarComponents(3)
    image.GetPointData().SetScalars(field)
    if not modern_vtk:
        image.Update()

    # This MUST correspond to the size used in congeal_multisubject
    # update_nonrigid_grid.
    size_mm = 200.0
    vector_count = len(transform) / 3
    edge_nodes = round(numpy.power(vector_count, 1.0 / 3.0))
    corner = -size_mm / 2.0
    step = size_mm / (edge_nodes - 1)
    node_count = int(edge_nodes)
    image.SetOrigin(corner, corner, corner)
    image.SetSpacing(step, step, step)
    image.SetDimensions(node_count, node_count, node_count)

    # Turn the sampled displacements into B-spline coefficients.
    coeff_filter = vtk.vtkImageBSplineCoefficients()
    if modern_vtk:
        attach_input = coeff_filter.SetInputData
    else:
        attach_input = coeff_filter.SetInput
    attach_input(image)
    coeff_filter.Update()
    # this was in the test code.
    coeff_filter.UpdateWholeExtent()

    bspline = vtk.vtkBSplineTransform()
    if modern_vtk:
        attach_coefficients = bspline.SetCoefficientData
    else:
        attach_coefficients = bspline.SetCoefficients
    attach_coefficients(coeff_filter.GetOutput())
    bspline.SetBorderModeToZero()

    return bspline