Exemplo n.º 1
0
    def errf(x):
        """Penalty function for Twiss (transfer-map based) matching.

        Applies the trial vector ``x`` to the matching variables, propagates
        the Twiss parameters through ``lat`` element by element and returns
        the accumulated penalty of all constraints in ``constr``.

        This is a closure: ``tw``, ``vars``, ``lat``, ``constr``, ``weights``,
        ``vary_bend_angle``, ``min_i5`` and ``verbose`` come from the
        enclosing scope (not visible here).

        How ``x[i]`` is applied depends on the type of ``vars[i]``:
          * Drift                  -> length ``l`` (negative -> early return)
          * Quadrupole             -> ``k1``
          * Solenoid               -> ``k``
          * RBend / SBend / Bend   -> ``angle`` if ``vary_bend_angle`` else ``k1``
          * ``[Twiss, attr]`` list -> attribute ``attr`` of the initial Twiss
          * tuple of elements      -> common ``k1`` for every element in it
        Element variables are modified in place and get their
        ``transfer_map`` rebuilt with ``lat.method.create_tm``.

        Per-element constraint values ``constr[elem][key]``:
          * number                      -> weighted squared deviation
          * ``['<', v]`` / ``['>', v]``   -> one-sided bound on the value
          * ``['a<', v]`` / ``['a>', v]`` -> one-sided bound on ``|value|``
          * ``['->', ref, dv]``         -> value equals the same key at element
                                           ``ref`` plus optional offset ``dv``
          * str                         -> ignored
        Special keys: ``'global'`` ('<' / '>' bounds checked after every
        element), ``'periodic'``, ``'total_len'`` and ``'delta'``.

        Returns
        -------
        float
            Total penalty; ``weights('negative_length')`` or
            ``weights('periodic')`` when the trial point is rejected early;
            ``100000.0`` when ``min_i5`` is set and a damping partition number
            is negative.
        """
        bound_ops = ('<', '>', 'a<', 'a>')

        def bound_penalty(op, value, bound):
            """Squared violation of a one-sided bound, 0.0 when satisfied.

            For 'a<' / 'a>' the violation is measured on ``|value|`` so the
            penalty vanishes continuously at the boundary for either sign.
            """
            if op == '<' and value > bound:
                return (value - bound)**2
            if op == '>' and value < bound:
                return (value - bound)**2
            if op == 'a<' and np.abs(value) > bound:
                return (np.abs(value) - bound)**2
            if op == 'a>' and np.abs(value) < bound:
                return (np.abs(value) - bound)**2
            return 0.0

        tw_loc = deepcopy(tw)
        # NOTE(review): tw0 is copied from the unmodified tw, so [Twiss, attr]
        # variables only reach radiation_integrals when 'periodic' overwrites
        # tw0 below - confirm this is intended.
        tw0 = deepcopy(tw)

        # The parameter that x[i] sets is chosen by the class of vars[i].
        for i, var in enumerate(vars):
            if isinstance(var, Drift):
                if x[i] < 0:
                    return weights('negative_length')
                var.l = x[i]
                var.transfer_map = lat.method.create_tm(var)
            if isinstance(var, Quadrupole):
                var.k1 = x[i]
                var.transfer_map = lat.method.create_tm(var)
            if isinstance(var, Solenoid):
                var.k = x[i]
                var.transfer_map = lat.method.create_tm(var)
            if isinstance(var, (RBend, SBend, Bend)):
                if vary_bend_angle:
                    var.angle = x[i]
                else:
                    var.k1 = x[i]
                var.transfer_map = lat.method.create_tm(var)
            if isinstance(var, list):
                if isinstance(var[0], Twiss) and isinstance(var[1], str):
                    tw_loc.__dict__[var[1]] = x[i]
            if isinstance(var, tuple):
                # all quadrupole strengths in the tuple are varied together
                for v in var:
                    v.k1 = x[i]
                    v.transfer_map = lat.method.create_tm(v)

        err = 0.0
        if constr.get("periodic"):
            tw_loc = periodic_twiss(tw_loc, lattice_transfer_map(lat, tw.E))
            if tw_loc is None:
                # periodic_twiss found no solution for this trial point
                print("########")
                return weights('periodic')
            tw0 = deepcopy(tw_loc)

        # Placeholders for the reference elements of '->' constraints; each one
        # is replaced by a Twiss copy when the element is passed below.
        ref_hsh = {}
        for key, cons in constr.items():
            if key == 'periodic' or key == 'total_len':
                continue
            for k, spec in cons.items():
                if isinstance(spec, list) and spec[0] == '->':
                    ref_hsh[spec[1]] = {k: 0.0}

        # Propagate through the lattice; constraints are evaluated at the
        # exit of each element.
        tw_loc.s = 0
        for e in lat.sequence:
            tw_loc = e.transfer_map * tw_loc

            if 'global' in constr:
                for c, spec in constr['global'].items():
                    if isinstance(spec, list) and spec[0] in ('<', '>'):
                        pen = bound_penalty(spec[0], tw_loc.__dict__[c],
                                            spec[1])
                        if pen:
                            err += weights(c) * pen

            if 'delta' in constr and e in constr['delta']:
                # record the current value of the requested key for this element
                tw_k = constr['delta'][e][0]
                constr['delta'][e][1] = tw_loc.__dict__[tw_k]

            if e in ref_hsh:
                ref_hsh[e] = deepcopy(tw_loc)

            if e in constr:
                for k, spec in constr[e].items():
                    if isinstance(spec, list):
                        op, v1 = spec[0], spec[1]
                        if op == '->':
                            # equality with the same key at reference element
                            # v1, plus an optional offset spec[2]
                            dv1 = float(spec[2]) if len(spec) > 2 else 0.0
                            try:
                                ref_val = ref_hsh[v1].__dict__[k]
                            except (AttributeError, KeyError):
                                # ref_hsh[v1] is still the placeholder dict:
                                # the reference element was not passed yet
                                print(
                                    'constraint error: rval should precede lval in lattice'
                                )
                            else:
                                err += (tw_loc.__dict__[k] - (ref_val + dv1))**2
                        else:
                            if op in bound_ops:
                                pen = bound_penalty(op, tw_loc.__dict__[k], v1)
                                if pen:
                                    err += weights(k) * pen
                            # NOTE(review): extra unweighted penalty whenever
                            # the value is negative, even if the bound holds;
                            # kept as-is - confirm intent.
                            if tw_loc.__dict__[k] < 0:
                                err += (tw_loc.__dict__[k] - v1)**2
                    elif not isinstance(spec, str):
                        # plain number: weighted squared deviation from target
                        err += weights(k) * (spec - tw_loc.__dict__[k])**2

        if "total_len" in constr:
            total_len = constr["total_len"]
            err += weights('total_len') * (tw_loc.s - total_len)**2

        if 'delta' in constr:
            # penalise (value at 1st element - value at 2nd element - val)
            delta_dict = constr['delta']
            elems = [el for el in delta_dict if isinstance(el, Element)]
            err += delta_dict["weight"] * (delta_dict[elems[0]][1] -
                                           delta_dict[elems[1]][1] -
                                           delta_dict["val"])**2

        if min_i5:
            # radiation integrals of the lattice starting from tw0
            I1, I2, I3, I4, I5 = radiation_integrals(lat, tw0, nsuperperiod=1)
            err += I5 * weights('i5')

            # damping partition numbers: Je = 2 + I4/I2, Jx = 1 - I4/I2
            Je = 2 + I4 / I2
            Jx = 1 - I4 / I2
            Jy = 1
            if Je < 0 or Jx < 0 or Jy < 0:
                err = 100000.0

        if verbose:
            print('iteration error:', err)
        return err
Exemplo n.º 2
0
    def errf(x):
        """Penalty function for matching on the tracked beam envelope.

        Applies the trial vector ``x`` to the matching variables, tracks a
        copy of ``p_array`` through ``lat`` with ``track`` and evaluates the
        constraints in ``constr`` on the returned envelope list, sampled at
        the point closest to the exit of each element.

        This is a closure: ``p_array``, ``bounds``, ``vars``, ``lat``,
        ``navi``, ``constr``, ``weights``, ``vary_bend_angle``, ``min_i5`` and
        ``verbose`` come from the enclosing scope (not visible here).

        How ``x[i]`` is applied depends on the type of ``vars[i]``:
          * Drift                  -> length ``l`` (negative -> early return)
          * Quadrupole             -> ``k1``
          * Solenoid               -> ``k``
          * RBend / SBend / Bend   -> ``angle`` if ``vary_bend_angle`` else ``k1``
          * ``[Twiss, attr]`` list -> attribute ``attr`` of the local Twiss copy
          * tuple of elements      -> common ``k1`` for every element in it
        Element variables are modified in place and get their
        ``transfer_map`` rebuilt with ``lat.method.create_tm``.

        Per-element constraint values ``constr[elem][key]``:
          * number                      -> weighted squared deviation
          * ``['<', v]`` / ``['>', v]``   -> one-sided bound on the value
          * ``['a<', v]`` / ``['a>', v]`` -> one-sided bound on ``|value|``
          * ``['->', ref, dv]``         -> value equals the same key at element
                                           ``ref`` plus optional offset ``dv``
          * str                         -> ignored
        Special keys: ``'global'`` ('<' / '>' bounds checked at every
        element), ``'periodic'`` and ``'total_len'``.

        Returns
        -------
        float
            Total penalty; ``weights('negative_length')`` or
            ``weights('periodic')`` when the trial point is rejected early;
            ``100000.0`` when ``min_i5`` is set and a damping partition number
            is negative.
        """
        bound_ops = ('<', '>', 'a<', 'a>')

        def bound_penalty(op, value, bound):
            """Squared violation of a one-sided bound, 0.0 when satisfied.

            For 'a<' / 'a>' the violation is measured on ``|value|`` so the
            penalty vanishes continuously at the boundary for either sign.
            """
            if op == '<' and value > bound:
                return (value - bound)**2
            if op == '>' and value < bound:
                return (value - bound)**2
            if op == 'a<' and np.abs(value) > bound:
                return (np.abs(value) - bound)**2
            if op == 'a>' and np.abs(value) < bound:
                return (np.abs(value) - bound)**2
            return 0.0

        p_array0 = deepcopy(p_array)
        tws = get_envelope(p_array0, bounds=bounds)
        # NOTE(review): tw_loc is only used by the 'periodic' branch before the
        # tracking loop overwrites it, and tw0 is copied before [Twiss, attr]
        # variables are applied - confirm such variables are meant to matter.
        tw_loc = deepcopy(tws)
        tw0 = deepcopy(tws)

        # The parameter that x[i] sets is chosen by the class of vars[i].
        for i, var in enumerate(vars):
            if isinstance(var, Drift):
                if x[i] < 0:
                    return weights('negative_length')
                var.l = x[i]
                var.transfer_map = lat.method.create_tm(var)
            if isinstance(var, Quadrupole):
                var.k1 = x[i]
                var.transfer_map = lat.method.create_tm(var)
            if isinstance(var, Solenoid):
                var.k = x[i]
                var.transfer_map = lat.method.create_tm(var)
            if isinstance(var, (RBend, SBend, Bend)):
                if vary_bend_angle:
                    var.angle = x[i]
                else:
                    var.k1 = x[i]
                var.transfer_map = lat.method.create_tm(var)
            if isinstance(var, list):
                if isinstance(var[0], Twiss) and isinstance(var[1], str):
                    tw_loc.__dict__[var[1]] = x[i]
            if isinstance(var, tuple):
                # all quadrupole strengths in the tuple are varied together
                for v in var:
                    v.k1 = x[i]
                    v.transfer_map = lat.method.create_tm(v)

        err = 0.0
        if constr.get("periodic") is True:
            tw_loc = periodic_twiss(tw_loc,
                                    lattice_transfer_map(lat, tw_loc.E))
            if tw_loc is None:
                # periodic_twiss found no solution for this trial point
                print("########")
                return weights('periodic')
            tw0 = deepcopy(tw_loc)

        # Placeholders for the reference elements of '->' constraints; each one
        # is replaced by a Twiss copy when the element is passed below.
        ref_hsh = {}
        for key, cons in constr.items():
            if key == 'periodic' or key == 'total_len':
                continue
            for k, spec in cons.items():
                if isinstance(spec, list) and spec[0] == '->':
                    ref_hsh[spec[1]] = {k: 0.0}

        # Track the beam and index the envelope list by longitudinal position.
        navi.go_to_start()
        tws_list, p_array0 = track(lat,
                                   p_array0,
                                   navi,
                                   print_progress=False,
                                   bounds=bounds)
        s = np.array([tw_i.s for tw_i in tws_list])

        L = 0.
        for e in lat.sequence:
            # envelope sample closest to the exit of element e
            L += e.l
            tw_loc = tws_list[np.abs(s - L).argmin()]

            if 'global' in constr:
                for c, spec in constr['global'].items():
                    if isinstance(spec, list) and spec[0] in ('<', '>'):
                        pen = bound_penalty(spec[0], tw_loc.__dict__[c],
                                            spec[1])
                        if pen:
                            err += weights(c) * pen

            if e in ref_hsh:
                ref_hsh[e] = deepcopy(tw_loc)

            if e in constr:
                for k, spec in constr[e].items():
                    if isinstance(spec, list):
                        op, v1 = spec[0], spec[1]
                        if op == '->':
                            # equality with the same key at reference element
                            # v1, plus an optional offset spec[2]
                            dv1 = float(spec[2]) if len(spec) > 2 else 0.0
                            try:
                                ref_val = ref_hsh[v1].__dict__[k]
                            except (AttributeError, KeyError):
                                # ref_hsh[v1] is still the placeholder dict:
                                # the reference element was not passed yet
                                print(
                                    'constraint error: rval should precede lval in lattice'
                                )
                            else:
                                err += (tw_loc.__dict__[k] - (ref_val + dv1))**2
                        else:
                            if op in bound_ops:
                                pen = bound_penalty(op, tw_loc.__dict__[k], v1)
                                if pen:
                                    err += weights(k) * pen
                            # NOTE(review): extra unweighted penalty whenever
                            # the value is negative, even if the bound holds;
                            # kept as-is - confirm intent.
                            if tw_loc.__dict__[k] < 0:
                                err += (tw_loc.__dict__[k] - v1)**2
                    elif not isinstance(spec, str):
                        # plain number: weighted squared deviation from target
                        err += weights(k) * (spec - tw_loc.__dict__[k])**2

        if verbose:
            # current values of the matching variables (diagnostic only)
            for v in vars:
                print(getattr(v, 'id', v), getattr(v, 'k1', None))

        if "total_len" in constr:
            total_len = constr["total_len"]
            err += weights('total_len') * (tw_loc.s - total_len)**2

        if min_i5:
            # radiation integrals of the lattice starting from tw0
            I1, I2, I3, I4, I5 = radiation_integrals(lat, tw0, nsuperperiod=1)
            err += I5 * weights('i5')

            # damping partition numbers: Je = 2 + I4/I2, Jx = 1 - I4/I2
            Je = 2 + I4 / I2
            Jx = 1 - I4 / I2
            Jy = 1
            if Je < 0 or Jx < 0 or Jy < 0:
                err = 100000.0

        if verbose:
            print('iteration error:', err)
        return err