def warp_atlas_subject(subject,
                       path,
                       labels,
                       input_image,
                       input_transform_prefix,
                       output_path,
                       exec_options=None):
    """
    Warp a training set subject's labels (and its image) to input_image.

    A combined warp ``<output_path>/<subject>/Warp.nii.gz`` is built with
    ``ants_compose_a_to_b`` from the subject's transform prefix
    (``<path>/<subject>/WMnMPRAGE``) and ``input_transform_prefix``. It is
    reused if it already exists. That warp is then applied to every requested
    label with ``--use-NN`` and to the subject image ``image_name`` with
    ``--use-BSpline``, using ``input_image`` as the reference.

    :param subject: name of the training subject directory under ``path``.
    :param path: root directory of the prior/training subjects.
    :param labels: iterable of label names; each one is read from
        ``<path>/<subject>/sanitized_rois/<label>.nii.gz``.
    :param input_image: target image the labels are warped onto.
    :param input_transform_prefix: transform prefix of ``input_image``.
    :param output_path: parent directory; a ``<subject>`` subdirectory is
        created inside it for this subject's outputs.
    :param exec_options: extra keyword arguments forwarded to every
        ``check_run`` call (none by default).
    :returns: dict mapping each label name, plus ``'WMnMPRAGE_bias_corr'``
        for the warped subject image, to the path of the warped file.
    :raises OSError: if the subject output directory cannot be created for
        any reason other than it already existing.
    """
    import errno

    if exec_options is None:
        exec_options = {}
    a_transform_prefix = os.path.join(path, subject + '/WMnMPRAGE')
    output_path = os.path.join(output_path, subject)
    try:
        os.mkdir(output_path)
    except OSError as err:
        # An existing directory is expected (re-runs reuse it); any other
        # failure (permissions, missing parent) must not be hidden.
        if err.errno != errno.EEXIST:
            raise
    combined_warp = os.path.join(output_path, 'Warp.nii.gz')
    if not os.path.exists(combined_warp):
        check_run(combined_warp,
                  ants_compose_a_to_b,
                  a_transform_prefix,
                  b_path=input_image,
                  b_transform_prefix=input_transform_prefix,
                  output=combined_warp,
                  **exec_options)
    output_labels = {}
    # Nearest-neighbour interpolation so label values are not blended.
    label_switches = '--use-NN'
    # OPT parallelize, or merge parallelism with subject level
    for label in labels:
        label_fname = os.path.join(path, subject, 'sanitized_rois',
                                   label + '.nii.gz')
        warped_label = os.path.join(output_path, label + '.nii.gz')
        check_run(warped_label,
                  ants_apply_only_warp,
                  template=input_image,
                  input_image=label_fname,
                  input_warp=combined_warp,
                  output_image=warped_label,
                  switches=label_switches,
                  **exec_options)
        output_labels[label] = warped_label
    # Warp the subject image too (stored under 'WMnMPRAGE_bias_corr'), with
    # B-spline interpolation because it is an intensity image, not a label.
    # ``image_name`` is a module-level name, presumably the bias-corrected
    # WMnMPRAGE file name -- TODO confirm.
    # TODO merge this into previous for loop to be DRY?
    output_image = os.path.join(output_path, image_name)
    output_labels['WMnMPRAGE_bias_corr'] = output_image
    if not os.path.exists(output_image):
        print(output_image)
        check_run(output_image,
                  ants_apply_only_warp,
                  template=input_image,
                  input_image=os.path.join(path, subject, image_name),
                  input_warp=combined_warp,
                  output_image=output_image,
                  switches='--use-BSpline',
                  **exec_options)
    return output_labels
def main(args, temp_path, pool):
    """
    Segment the requested ROIs of ``args.input_image`` by multi-atlas fusion.

    Pipeline (numbers match the progress messages):

    1-3. Algorithm v2 only: affinely register the input to ``orig_template``,
         transform the mask into input space and crop the input to it. The
         cropped image is ``crop_<name>.nii.gz`` next to the input.
    4-5. Sanitize the input into ``temp_path``, optionally flip it L-R
         (``args.right``) and bias-correct it. This is skipped when the
         sanitized image already exists.
    -    Non-linearly register the input to the template. This is skipped
         when ``check_warps`` accepts existing warps, unless
         ``args.forcereg`` is set.
    6.   Warp every prior subject's labels and image to the input, then fuse
         them per label. Fusion uses joint fusion, majority voting, or ANTs
         joint fusion restricted to a dilated thalamus mask.
    -    Reorder the results like the original input into ``output_path``
         with ``swapdimlike.py``.
    -    Split 6-VLP into 6_VLPv / 6_VLPd when 6-VLP was requested.

    :param args: parsed command-line options. Reads input_image,
        output_path, roi_names, algorithm, template, mask, warp, right,
        forcereg, jointfusion and majorityvoting.
    :param temp_path: scratch directory for intermediate files.
    :param pool: worker pool. Its ``map`` runs the per-subject warps and
        per-label fusions and must return results in input order.

    Calls ``sys.exit`` when the algorithm or the template/mask combination
    is invalid.
    """
    import numpy as np

    input_image = orig_input_image = args.input_image

    # Default mask; replaced below for algorithm v2.
    mask = mask_93

    # Outputs go next to the input image unless a path was given.
    if args.output_path:
        output_path = args.output_path
    else:
        output_path = os.path.dirname(orig_input_image)

    # Map the requested ROI parameters onto label names.
    if roi['param_all'] in args.roi_names:
        labels = list(roi['label_names'])
    else:
        roi_dict = dict(zip(roi['param_names'], roi['label_names']))
        labels = [roi_dict[el] for el in args.roi_names]

    # Choose the template (and mask) for the selected algorithm.
    if args.algorithm == "v2":
        has_template = args.template is not None
        has_mask = args.mask is not None
        if has_template and has_mask:
            template = args.template
            mask = args.mask
            print("Custom template and mask")
        elif has_template or has_mask:
            # Exactly one of the pair was given.
            sys.exit(
                "!!!!!!! Both template and mask need to be specified simultaneously and they need to be of the same size !!!!!!!"
            )
        else:
            template = template_93
            mask = mask_93
            print("Algorithm is v2")
    elif args.algorithm == "v1":
        sys.exit("!!!!!!! v1 algorithm not yet implemented !!!!!!!")
    elif args.algorithm == "v0":
        template = orig_template
        print("Template is origtemplate.nii.gz")
    else:
        sys.exit("!!!!!!! Algorithm incorrectly specified !!!!!!!")

    # TODO prevent both jointfusion and majority voting being set; for now
    # jointfusion wins because it is tested first below.

    if args.warp:
        warp_path = args.warp
    else:
        # TODO remove this as the default behavior, instead do ANTS?
        # Default prefix: <temp_path>/<input file name without .nii/.gz>.
        tail = os.path.basename(input_image).replace('.nii', '').replace(
            '.gz', '')
        warp_path = os.path.join(temp_path, tail)

    t = time.time()

    if args.algorithm == "v2":
        # Crop the input: affine-register it to the template, bring the mask
        # into input space, then crop the input to that mask.
        # NOTE(review): registration always uses orig_template, even when a
        # custom template/mask pair was given -- confirm this is intended.
        print("1.   Linear Registration of input and template \n")
        t1 = time.time()
        ants_linear_registration(orig_template, orig_input_image)
        print('1.   ----- Time Elapsed: %s  \n \n' % timedelta(
            seconds=time.time() - t1))

        mask_input = os.path.join(os.path.dirname(orig_input_image),
                                  'mask_inp.nii.gz')

        print("2.   Transform mask from template space to input space \n")
        t1 = time.time()
        ants_WarpImageMultiTransform(mask, mask_input, orig_input_image)
        print('2.   ----- Time Elapsed: %s  \n \n' % timedelta(
            seconds=time.time() - t1))

        # Drop everything from the first dot (covers .nii and .nii.gz).
        file_name = os.path.basename(orig_input_image)
        file_name_without_extension = file_name[:file_name.index('.')]
        input_image = os.path.join(
            os.path.dirname(orig_input_image),
            'crop_' + file_name_without_extension + '.nii.gz')

        print("3.   Cropping input using this mask \n")
        t1 = time.time()
        parallel_command(
            crop_by_mask(orig_input_image, input_image, mask_input))
        print('3.   ----- Time Elapsed: %s  \n \n' % timedelta(
            seconds=time.time() - t1))

    # FSL automatically converts .nii to .nii.gz
    sanitized_image = os.path.join(
        temp_path,
        os.path.basename(input_image) +
        ('.gz' if input_image.endswith('.nii') else ''))
    if not os.path.exists(sanitized_image):
        input_image = sanitize_input(input_image, sanitized_image,
                                     parallel_command)
        if args.right:
            print("4.   Flipping along L-R \n")
            t1 = time.time()
            flip_lr(input_image, input_image, parallel_command)
            print('4.   ----- Time Elapsed: %s  \n \n' % timedelta(
                seconds=time.time() - t1))

        print("5.   Correcting bias \n")
        t1 = time.time()
        bias_correct(input_image, input_image, **exec_options)
        print('5.   ----- Time Elapsed: %s  \n \n' % timedelta(
            seconds=time.time() - t1))
    else:
        print('Skipped, using %s' % sanitized_image)
        input_image = sanitized_image

    if args.forcereg or not check_warps(warp_path):
        if args.warp:
            print('Saving output as %s' % warp_path)
        else:
            print('Saving output to temporary path.')
        ants_nonlinear_registration(template, input_image, warp_path,
                                    **exec_options)
    else:
        print('Skipped, using %sInverseWarp.nii.gz and %sAffine.txt' % (
            warp_path, warp_path))

    # generating the warped output
    print("6.   Warping prior labels and images \n \n")
    t1 = time.time()
    registered = os.path.join(temp_path, 'registered.nii.gz')
    cmd = 'WarpImageMultiTransform 3 %s %s -R %s %sWarp.nii.gz %sAffine.txt' % (
        input_image, registered, template, warp_path, warp_path)
    parallel_command(cmd)

    print('6.   --- Elapsed: %s  \n \n' % timedelta(seconds=time.time() - t1))

    t1 = time.time()
    # TODO should probably use output from warp_atlas_subject instead of hard coding paths in create_atlas
    # TODO make this more parallel
    subject_outputs = pool.map(
        partial(
            warp_atlas_subject,
            path=prior_path,
            # TODO cleanup this hack to always have whole thalamus so can estimate mask
            labels=set(labels + ['1-THALAMUS']),
            input_image=input_image,
            input_transform_prefix=warp_path,
            output_path=temp_path,
            exec_options=exec_options,
        ),
        subjects)
    # label -> list of per-subject paths, in ``subjects`` order (pool.map
    # returns results in input order). Plain lists keep every atlas label
    # aligned with the atlas image of the same subject. Unlike dict views,
    # they can also be pickled for the pool workers on Python 3.
    warped_labels = {
        label: [outputs[label] for outputs in subject_outputs]
        for label in subject_outputs[0]
    }
    atlas_images = warped_labels['WMnMPRAGE_bias_corr']

    print('!! --- Performing label fusion. --- Elapsed: %s  \n \n' % timedelta(
        seconds=time.time() - t1))
    # FIXME use whole-brain template registration optimized parameters instead, these are from crop pipeline
    optimal_picsl = optimal['PICSL']
    if args.jointfusion:
        pool.map(partial(label_fusion_picsl, input_image, atlas_images), [
            dict(atlas_labels=warped_labels[label],
                 output_label=os.path.join(temp_path, label + '.nii.gz'),
                 rp=optimal_picsl[label]['rp'],
                 rs=optimal_picsl[label]['rs'],
                 beta=optimal_picsl[label]['beta'],
                 **exec_options) for label in labels
        ])
    elif args.majorityvoting:
        pool.map(label_fusion_majority, [
            dict(atlas_labels=warped_labels[label],
                 output_label=os.path.join(temp_path, label + '.nii.gz'),
                 rp=optimal_picsl[label]['rp'],
                 rs=optimal_picsl[label]['rs'],
                 beta=optimal_picsl[label]['beta'],
                 **exec_options) for label in labels
        ])
    else:
        # Estimate mask to restrict computation
        mask = os.path.join(temp_path, 'mask.nii.gz')
        check_run(
            mask,
            conservative_mask,
            warped_labels['1-THALAMUS'],
            mask,
            dilation=10,
        )
        pool.map(partial(label_fusion_picsl_ants, input_image, atlas_images), [
            dict(atlas_labels=warped_labels[label],
                 output_label=os.path.join(temp_path, label + '.nii.gz'),
                 rp=optimal_picsl[label]['rp'],
                 rs=optimal_picsl[label]['rs'],
                 beta=optimal_picsl[label]['beta'],
                 mask=mask,
                 **exec_options) for label in labels
        ])

    print('!! --- Performing label fusion. --- Elapsed: %s  \n \n' % timedelta(
        seconds=time.time() - t1))

    files = [(os.path.join(temp_path, label + '.nii.gz'),
              os.path.join(output_path, label + '.nii.gz'))
             for label in labels]
    if args.right:
        # Undo the initial L-R flip. flip_lr takes (input, output, command),
        # as in step 4. pool.map would pass each (input, output) tuple as a
        # single argument, so call it per pair instead.
        for in_file, out_file in files:
            flip_lr(in_file, out_file, parallel_command)
        files = [(out_file, out_file) for _, out_file in files]
    # Resort output to original ordering
    pool.map(parallel_command, [
        '%s %s %s %s' % (os.path.join(this_path, 'swapdimlike.py'),
                         in_file, orig_input_image, out_file)
        for in_file, out_file in files
    ])

    # Split 6-VLP into 6_VLPv / 6_VLPd. This only applies when 6-VLP was
    # requested, because otherwise this run did not write 6-VLP.nii.gz.
    if '6-VLP' in labels:
        vlp_file = os.path.join(output_path, '6-VLP.nii.gz')

        # Re-orient to standard space - LR PA IS format
        san_vlp_file = os.path.join(output_path, 'san_6-VLP.nii.gz')
        sanitized_vlp = sanitize_input(vlp_file, san_vlp_file,
                                       parallel_command)

        # .dataobj/.header/.affine replace get_data()/get_header()/
        # get_affine(), which newer nibabel releases no longer provide.
        vlp_nii = nibabel.load(sanitized_vlp)
        data = np.asanyarray(vlp_nii.dataobj)

        # Coronal axis for RL PA IS orientation
        vlps = split_roi(data, None, 2)
        for fname, sub_vlp in zip(['6_VLPv.nii.gz', '6_VLPd.nii.gz'], vlps):
            output_nii = nibabel.Nifti1Image(sub_vlp, vlp_nii.affine,
                                             vlp_nii.header)
            # Written next to the other label outputs.
            output_nii.to_filename(os.path.join(output_path, fname))

    print('--- Finished --- Elapsed: %s' % timedelta(seconds=time.time() - t))
示例#3
0
def main(args, temp_path, pool):
    """
    Segment the requested ROIs of ``args.input_image`` by multi-atlas fusion.

    Steps:

    - Crop: affinely register the input to ``orig_template``, transform the
      template-size mask into input space and crop the input to it
      (``crop_<name>.nii.gz`` next to the input).
    - Sanitize the cropped image into ``temp_path``, optionally flip it L-R
      (``args.right``) and bias-correct it. This is skipped when the
      sanitized image already exists.
    - Register the input to the 61 or 93 template (``args.template_size``).
      This is skipped when ``check_warps`` accepts existing warps, unless
      ``args.forcereg`` is set.
    - Warp every prior subject's labels and image to the input.
    - Fuse them per label with joint fusion, majority voting, or ANTs joint
      fusion restricted to a dilated thalamus mask.
    - Reorder the results like the original input into ``output_path``.

    :param args: parsed command-line options. Reads input_image,
        output_path, roi_names, template_size, warp, right, forcereg,
        jointfusion and majorityvoting.
    :param temp_path: scratch directory for intermediate files.
    :param pool: worker pool. Its ``map`` runs the per-subject warps and
        per-label fusions and must return results in input order.

    Calls ``sys.exit`` for the unsupported 61-template + ``-R`` combination.
    """
    input_image = orig_input_image = args.input_image
    # Default to the input image's directory; os.path.join(None, ...) would
    # otherwise fail when the outputs are written.
    output_path = args.output_path or os.path.dirname(orig_input_image)

    # Map the requested ROI parameters onto label names.
    if roi['param_all'] in args.roi_names:
        labels = list(roi['label_names'])
    else:
        roi_dict = dict(zip(roi['param_names'], roi['label_names']))
        labels = [roi_dict[el] for el in args.roi_names]

    if args.template_size == 61:
        template = template_61
        mask = mask_61
        if args.right:
            sys.exit("!!!!!!! Feature not implemented (61 with -R option) !!!!!!!")
    else:
        template = template_93
        mask = mask_93

    # TODO prevent both jointfusion and majority voting being set; for now
    # jointfusion wins because it is tested first below.

    if args.warp:
        warp_path = args.warp
    else:
        # TODO remove this as the default behavior, instead do ANTS?
        # Default prefix: <temp_path>/<input file name without .nii/.gz>.
        tail = os.path.basename(input_image).replace('.nii', '').replace('.gz', '')
        warp_path = os.path.join(temp_path, tail)

    t = time.time()

    # Crop the input.
    # Affine registering input to template
    ants_linear_registration(orig_template, orig_input_image)

    mask_input = os.path.join(os.path.dirname(orig_input_image), 'mask_inp.nii.gz')

    # Transform mask from template space to input space
    ants_WarpImageMultiTransform(mask, mask_input, orig_input_image)

    # Drop everything from the first dot (covers .nii and .nii.gz).
    file_name = os.path.basename(orig_input_image)
    file_name_without_extension = file_name[:file_name.index('.')]

    input_image = os.path.join(os.path.dirname(orig_input_image),
                               'crop_' + file_name_without_extension + '.nii.gz')

    # Cropping input using this mask
    parallel_command(crop_by_mask(orig_input_image, input_image, mask_input))

    # FSL automatically converts .nii to .nii.gz
    sanitized_image = os.path.join(
        temp_path,
        os.path.basename(input_image) + ('.gz' if input_image.endswith('.nii') else ''))
    print('--- Reorienting image. --- Elapsed: %s' % timedelta(seconds=time.time()-t))
    if not os.path.exists(sanitized_image):
        input_image = sanitize_input(input_image, sanitized_image, parallel_command)
        if args.right:
            print('--- Flipping along L-R. --- Elapsed: %s' % timedelta(seconds=time.time()-t))
            flip_lr(input_image, input_image, parallel_command)
        print('--- Correcting bias. --- Elapsed: %s' % timedelta(seconds=time.time()-t))
        bias_correct(input_image, input_image, **exec_options)
    else:
        print('Skipped, using %s' % sanitized_image)
        input_image = sanitized_image

    print('--- Registering to mean brain template. --- Elapsed: %s' % timedelta(seconds=time.time()-t))
    if args.forcereg or not check_warps(warp_path):
        if args.warp:
            print('Saving output as %s' % warp_path)
        else:
            print('Saving output to temporary path.')
        ants_nonlinear_registration(template, input_image, warp_path, **exec_options)
    else:
        print('Skipped, using %sInverseWarp.nii.gz and %sAffine.txt' % (warp_path, warp_path))

    registered = os.path.join(temp_path, 'registered.nii.gz')
    cmd = 'WarpImageMultiTransform 3 %s %s -R %s %sWarp.nii.gz %sAffine.txt' % (
        input_image, registered, template, warp_path, warp_path)
    parallel_command(cmd)

    print('--- Warping prior labels and images. --- Elapsed: %s' % timedelta(seconds=time.time()-t))
    # TODO should probably use output from warp_atlas_subject instead of hard coding paths in create_atlas
    # TODO make this more parallel
    subject_outputs = pool.map(partial(
        warp_atlas_subject,
        path=prior_path,
        # TODO cleanup this hack to always have whole thalamus so can estimate mask
        labels=set(labels + ['1-THALAMUS']),
        input_image=input_image,
        input_transform_prefix=warp_path,
        output_path=temp_path,
        exec_options=exec_options,
    ), subjects)
    # label -> list of per-subject paths, in ``subjects`` order (pool.map
    # returns results in input order). Plain lists keep every atlas label
    # aligned with the atlas image of the same subject. Unlike dict views,
    # they can also be pickled for the pool workers on Python 3.
    warped_labels = {
        label: [outputs[label] for outputs in subject_outputs]
        for label in subject_outputs[0]
    }
    atlas_images = warped_labels['WMnMPRAGE_bias_corr']

    print('--- Performing label fusion. --- Elapsed: %s' % timedelta(seconds=time.time() - t))
    # FIXME use whole-brain template registration optimized parameters instead, these are from crop pipeline
    optimal_picsl = optimal['PICSL']
    if args.jointfusion:
        pool.map(partial(label_fusion_picsl, input_image, atlas_images),
                 [dict(
                     atlas_labels=warped_labels[label],
                     output_label=os.path.join(temp_path, label + '.nii.gz'),
                     rp=optimal_picsl[label]['rp'],
                     rs=optimal_picsl[label]['rs'],
                     beta=optimal_picsl[label]['beta'],
                     **exec_options
                 ) for label in labels])
    elif args.majorityvoting:
        pool.map(label_fusion_majority,
                 [dict(
                     atlas_labels=warped_labels[label],
                     output_label=os.path.join(temp_path, label + '.nii.gz'),
                     rp=optimal_picsl[label]['rp'],
                     rs=optimal_picsl[label]['rs'],
                     beta=optimal_picsl[label]['beta'],
                     **exec_options
                 ) for label in labels])
    else:
        # Estimate mask to restrict computation
        mask = os.path.join(temp_path, 'mask.nii.gz')
        check_run(
            mask,
            conservative_mask,
            warped_labels['1-THALAMUS'],
            mask,
            dilation=10,
        )
        pool.map(partial(label_fusion_picsl_ants, input_image, atlas_images),
                 [dict(
                     atlas_labels=warped_labels[label],
                     output_label=os.path.join(temp_path, label + '.nii.gz'),
                     rp=optimal_picsl[label]['rp'],
                     rs=optimal_picsl[label]['rs'],
                     beta=optimal_picsl[label]['beta'],
                     mask=mask,
                     **exec_options
                 ) for label in labels])

    files = [(os.path.join(temp_path, label + '.nii.gz'),
              os.path.join(output_path, label + '.nii.gz'))
             for label in labels]
    if args.right:
        # Undo the initial L-R flip. flip_lr takes (input, output, command),
        # as in the sanitize step above. pool.map would pass each
        # (input, output) tuple as a single argument, so call it per pair.
        for in_file, out_file in files:
            flip_lr(in_file, out_file, parallel_command)
        files = [(out_file, out_file) for _, out_file in files]
    # Resort output to original ordering
    pool.map(parallel_command,
             ['%s %s %s %s' % (os.path.join(this_path, 'swapdimlike.py'),
                               in_file, orig_input_image, out_file)
              for in_file, out_file in files])
    print('--- Finished --- Elapsed: %s' % timedelta(seconds=time.time() - t))