Пример #1
0
def optimize_despot1_protocol(params, start_range, t1, m0, L, wrt_in, protocol_frameworks, cost_types, **kwargs):
    """
    Optimize the DESPOT1 protocol of collecting SPGR images to estimate a subset of (t1, t2, m0, off_resonance_phase).

    For every framework in ``protocol_frameworks`` a cost function is constructed, evaluated once to warm up
    (compile) it, and then minimized with a 100-start SLSQP ``MultiStart`` search for each entry of ``cost_types``.

    Arguments:
        params -- dict keyed by sequence name of the parameters that are free to vary for each sequence;
            entries whose sequence is not a key of the current framework are dropped for that framework.
        start_range -- dict keyed by sequence name giving the solver's initial-point start range for the
            corresponding ``params``; filtered per framework in the same way as ``params``.
        t1 -- value bound to the cost function's ``t1`` argument (presumably the nominal T1 at which the
            protocol is evaluated -- confirm).
        m0 -- value bound to the cost function's ``m0`` argument.
        L -- the lambda weights in the cost function, in corresponding order with ``wrt_in``.
        wrt_in -- the parameters being estimated; the strings must match the actual signal equation
            function input names exactly.
        protocol_frameworks -- a list of dictionaries, for example [{'spgr': 2, 'ssfp': 2}]; each one is passed
            as ``n_images`` to ``construct_despot1_cost_function`` and used to filter ``params``/``start_range``.
        cost_types -- mapping from a cost-type name to a dict of keyword arguments (e.g. a 'combine'
            function) that are bound onto the cost function for that solve.
        **kwargs -- forwarded unchanged to ``construct_despot1_cost_function``.

    Returns:
        OrderedDict mapping ``'<framework> <cost_type_name>'`` to the MultiStart solver's ``candidates``
        attribute as it stands right after that solve.
    """
    store_solutions = OrderedDict()
    for framework in protocol_frameworks:
        print('\n\n========== Solving %s ==========' % (framework,))
        # Restrict the free parameters / start ranges to the sequences present in this framework.
        p = {k: v for k, v in params.items() if k in framework}
        sr = {k: v for k, v in start_range.items() if k in framework}
        DESPOT1_Cost_Function = construct_despot1_cost_function(n_images=framework, params=p, start_range=sr, **kwargs)
        # Single pre-formatted string so the output is identical under Python 2 and 3.
        print('Constraints: %s %s' % (len(DESPOT1_Cost_Function.constraints), DESPOT1_Cost_Function.constraints))
        partial_cost_func = partial(
            DESPOT1_Cost_Function,
            t1=t1,
            m0=m0,
            L=L,
            wrt_in=wrt_in,
        )

        # Call this first with arbitrary input to cache the compiled function and avoid MultiStart compiling many times.
        print('Compile Theano for floats')
        try:
            partial_cost_func(np.random.random(len(DESPOT1_Cost_Function.start_range)))
        except spla.LinAlgError:
            # Deliberate best-effort: the call exists only for its compile/caching side effect, and a random
            # point may be numerically invalid for the linear algebra; the result is irrelevant here.
            pass

        # Only SLSQP can handle equality and inequality constraints.
        M = MultiStart(
            100,
            DESPOT1_Cost_Function.start_range,
            constraints=DESPOT1_Cost_Function.constraints,
            method='SLSQP',
        )

        for cost_type_name, cost_type in cost_types.items():
            print('Cost Type %s %s' % (cost_type_name, cost_type))
            res = M.solve(
                parallel_pool=0,
                fun=partial(partial_cost_func, **cost_type),
                label=str(framework),
            )
            if res:
                # Wrap in a 1-tuple: if _parameter_values returns a tuple, a bare `%` would try to unpack it
                # and raise "not all arguments converted".
                print('  Top Solution: %s\n' % (DESPOT1_Cost_Function._parameter_values(res.x),))
            # NOTE(review): this stores a reference, not a copy; if MultiStart.solve mutates `candidates` in place
            # rather than rebinding it, entries for the same framework would alias -- confirm against MultiStart.
            store_solutions['%s %s' % (framework, cost_type_name)] = M.candidates
    return store_solutions
def optimize_diffusion_protocol(params, start_range, D, m0, L, wrt_in, protocol_frameworks, cost_types, **kwargs):
    """
    Optimize the diffusion protocol of collecting images at different b-values to estimate m0 and D.

    For every framework in ``protocol_frameworks`` a cost function is constructed, evaluated once to warm up
    (compile) it, and then minimized with a 100-start L-BFGS-B ``MultiStart`` search for each entry of ``cost_types``.

    Arguments:
        params -- dict keyed by sequence name of the parameters that are free to vary for each sequence;
            entries whose sequence is not a key of the current framework are dropped for that framework.
        start_range -- dict keyed by sequence name giving the solver's initial-point start range for the
            corresponding ``params``; filtered per framework in the same way as ``params``.
        D -- value bound to the cost function's ``D`` argument (presumably the nominal diffusion
            coefficient at which the protocol is evaluated -- confirm).
        m0 -- value bound to the cost function's ``m0`` argument.
        L -- weights bound to the cost function's ``L`` argument, presumably in corresponding order with ``wrt_in``.
        wrt_in -- names of the parameters being estimated, bound to the cost function's ``wrt_in`` argument.
        protocol_frameworks -- a list of dictionaries keyed by sequence name, for example [{'spgr': 2, 'ssfp': 2}]
            (values presumably image counts); each one is passed as ``n_images`` to
            ``construct_diffusion_cost_function`` and used to filter ``params``/``start_range``.
        cost_types -- mapping from a cost-type name to a dict of keyword arguments that are bound onto the
            cost function for that solve.
        **kwargs -- forwarded unchanged to ``construct_diffusion_cost_function``.

    Returns:
        OrderedDict mapping ``'<framework> <cost_type_name>'`` to the MultiStart solver's ``candidates``
        attribute as it stands right after that solve.
    """
    store_solutions = OrderedDict()
    for framework in protocol_frameworks:
        print("\n\n========== Solving %s ==========" % (framework,))
        # Restrict the free parameters / start ranges to the sequences present in this framework.
        p = {k: v for k, v in params.items() if k in framework}
        sr = {k: v for k, v in start_range.items() if k in framework}
        Cost_Function = construct_diffusion_cost_function(n_images=framework, params=p, start_range=sr, **kwargs)
        # Single pre-formatted string so the output is identical under Python 2 and 3.
        print("Constraints: %s %s" % (len(Cost_Function.constraints), Cost_Function.constraints))
        partial_cost_func = partial(Cost_Function, D=D, m0=m0, L=L, wrt_in=wrt_in, regularization=0.0)

        # Evaluate once with arbitrary input so the compiled function is cached before MultiStart runs.
        print("Compile Theano for floats")
        try:
            partial_cost_func(np.random.random(len(Cost_Function.start_range)))
        except spla.LinAlgError:
            # Deliberate best-effort: the call exists only for its compile/caching side effect, and a random
            # point may be numerically invalid for the linear algebra; the result is irrelevant here.
            pass

        # Cost_Function.constraints are reported above but intentionally NOT passed: assuming MultiStart forwards
        # `method` to scipy.optimize.minimize, L-BFGS-B supports only bound constraints (general constraints
        # would require e.g. SLSQP).
        M = MultiStart(
            100,
            Cost_Function.start_range,
            method="L-BFGS-B",
        )

        for cost_type_name, cost_type in cost_types.items():
            print("Cost Type %s %s" % (cost_type_name, cost_type))
            res = M.solve(parallel_pool=0, fun=partial(partial_cost_func, **cost_type), label=str(framework))
            if res:
                top_solution = Cost_Function._parameter_values(res.x)
                print("  Top Solution: %s %s\n" % (res.fun, top_solution))
            # NOTE(review): this stores a reference, not a copy; if MultiStart.solve mutates `candidates` in place
            # rather than rebinding it, entries for the same framework would alias -- confirm against MultiStart.
            store_solutions["%s %s" % (framework, cost_type_name)] = M.candidates
    return store_solutions