示例#1
0
class TrackBeam(object):
    """
    Generate a multivariate-normal beam from fitted tune data and track it.

    For each line of the find_tune output file the x and y signals are fitted
    with an ellipse, a 6D (x, px, y, py, z, pz) beam is sampled around the
    fitted centre and the beam is tracked with OPAL; the probe output is then
    read into a TunesAnalysis.
    """
    def __init__(self, config):
        """
        Initialise the TrackBeam object
        - config: configuration object (run_control, find_tune, track_beam and
                  tracking sections are read)
        """
        self.config = config
        self.output_dir = config.run_control["output_dir"]
        self.centre = None  # mean of (x, px, y, py, z, pz); set by fit_tune_data
        self.ellipse = None  # 6x6 covariance matrix; set by fit_tune_data
        self.run_dir = ""
        self.cwd = os.getcwd()  # working directory restored by track()
        self.hits_in = []
        self.hits_out = []
        self.energy = None

    def load_tune_data(self):
        """
        Load the find_tune output file, which holds one JSON document per
        line. Blank lines are ignored.

        Returns the list of decoded documents.
        """
        file_name = os.path.join(self.output_dir,
                                 self.config.find_tune["output_file"])
        with open(file_name) as fin:
            return [json.loads(line) for line in fin if line.strip()]

    def fit_tune_data(self, data):
        """
        Fit the beam centre and covariance matrix for one tune-data record.
        - data: one record from load_tune_data; reads
                data['substitutions']['__energy__'], data['x_signal'] and
                data['y_signal']

        Sets self.energy, self.centre and self.ellipse.
        """
        eps_max = self.config.track_beam['eps_max']
        x_emittance = self.config.track_beam['x_emittance']
        y_emittance = self.config.track_beam['y_emittance']
        sigma_pz = self.config.track_beam['sigma_pz']
        sigma_z = self.config.track_beam['sigma_z']

        self.energy = data['substitutions']['__energy__']
        pid = self.config.tracking["pdg_pid"]
        mass = xboa.common.pdg_pid_to_mass[abs(pid)]
        # total momentum from kinetic energy: p = sqrt((T+m)^2 - m^2)
        p = ((self.energy+mass)**2 - mass**2)**0.5

        x_centre, x_ellipse = xboa.common.fit_ellipse(
                                                  data['x_signal'],
                                                  eps_max,
                                                  verbose = False)
        y_centre, y_ellipse = xboa.common.fit_ellipse(
                                                  data['y_signal'],
                                                  eps_max,
                                                  verbose = False)
        # rescale each 2x2 ellipse so that its determinant equals the
        # configured emittance (det(c*M) = c^2 det(M) for a 2x2 matrix)
        x_ellipse *= (x_emittance/numpy.linalg.det(x_ellipse))**0.5
        y_ellipse *= (y_emittance/numpy.linalg.det(y_ellipse))**0.5
        self.centre = numpy.array([x for x in x_centre]+[y for y in y_centre]+[0., p])
        self.ellipse = numpy.zeros((6, 6))
        for i in range(2):
            for j in range(2):
                self.ellipse[i, j] = x_ellipse[i, j]
                self.ellipse[i+2, j+2] = y_ellipse[i, j]
        # NOTE: scaling of the transverse momentum terms by p is deliberately
        # disabled here (it used to multiply the px/py rows/columns by p).
        # NOTE(review): sigma_z and sigma_pz are placed on the covariance
        # diagonal, i.e. they are used as variances - confirm this is intended.
        self.ellipse[4, 4] = sigma_z
        self.ellipse[5, 5] = sigma_pz
        print("Centre", self.centre)
        print("Ellipse")
        print(self.ellipse)

    def setup_workspace(self):
        """
        Create the tracking run directory (if needed) and chdir into it.

        self.run_dir is stored as an absolute path so that the paths built
        from it later (lattice file, log file) stay valid after the chdir,
        even when run_control["output_dir"] is relative.
        """
        self.run_dir = os.path.abspath(os.path.join(
            self.config.run_control["output_dir"],
            self.config.track_beam["run_dir"]))
        os.makedirs(self.run_dir, exist_ok=True)
        os.chdir(self.run_dir)

    def generate_beam(self):
        """
        Sample track_beam["subs_overrides"]["__n_events__"] hits from a
        multivariate normal distribution with mean self.centre and covariance
        self.ellipse, storing them in self.hits_in.
        """
        n_events = self.config.track_beam["subs_overrides"]["__n_events__"]
        events = numpy.random.multivariate_normal(self.centre, self.ellipse, n_events)
        keys = "x", "px", "y", "py", "z", "pz"
        self.hits_in = []
        for item in events:
            hit = self.reference()
            for key, value in zip(keys, item):
                hit[key] = value
            self.hits_in.append(hit)
        for hit in self.hits_in[0:10]:
            print("   ", [hit[key] for key in keys])
        print("Made", len(self.hits_in), "hits")

    def run_tracking(self, index):
        """
        Build the lattice for one substitution set, track the beam (when
        track_beam["do_track"] is set) and read the probes into a
        TunesAnalysis.
        - index: index into self.config.substitution_list

        Must be called after setup_workspace(): OpalTracking is given the
        lattice file name relative to the current working directory.
        """
        opal_exe = self.config.tracking["opal_path"]
        input_file = self.config.tracking["lattice_file"]
        n_cores = self.config.tracking["n_cores"]
        mpi_exe = self.config.tracking["mpi_exe"]
        # written into run_dir, which is where OpalTracking looks for it
        lattice_file = os.path.join(self.run_dir, 'SectorFFAGMagnet.tmp')

        # copy so that the overrides do not leak back into the shared config
        subs = dict(self.config.substitution_list[index])
        for key, value in self.config.track_beam["subs_overrides"].items():
            subs[key] = value
        xboa.common.substitute(input_file, lattice_file, subs)
        log_name = os.path.join(self.run_dir, "log")
        ref_hit = self.reference()
        probe_files = self.config.track_beam["probe_files"]
        self.tracking = OpalTracking("SectorFFAGMagnet.tmp", 'disttest.dat', ref_hit, probe_files, opal_exe, log_name, None, n_cores, mpi_exe)

        tunes_analysis = TunesAnalysis(self.config)
        # NOTE(review): phase_space_plots is currently unused; it is still
        # constructed in case the constructor has side effects.
        phase_space_plots = PhaseSpacePlots(self.config)
        tunes_analysis.set_match(self.centre[0:4], self.ellipse[0:4, 0:4])
        if self.config.track_beam["do_track"]:
            print("Running tracking with\n   ", end=' ')
            for key, value in subs.items():
                print(utilities.sub_to_name(key)+":", value, end=' ')
            print()
            self.tracking.track_many(self.hits_in, None)
        print(os.getcwd(), probe_files)
        self.tracking._read_probes(tunes_analysis)

    def reference(self):
        """
        Generate a reference particle at x = 0 with kinetic energy
        self.energy (set by fit_tune_data) and charge 1.
        """
        hit_dict = {}
        hit_dict["pid"] = self.config.tracking["pdg_pid"]
        hit_dict["mass"] = xboa.common.pdg_pid_to_mass[abs(hit_dict["pid"])]
        hit_dict["charge"] = 1
        hit_dict["x"] = 0.
        hit_dict["kinetic_energy"] = self.energy
        return Hit.new_from_dict(hit_dict, "pz")

    def track(self):
        """
        Run the full workflow: load the tune data, set up the run directory
        and, for every tune record, fit the beam, generate it and track it.
        The original working directory is restored even on error.

        NOTE(review): the i-th tune record is paired with the i-th entry of
        config.substitution_list - confirm the two lists are aligned.
        """
        try:
            data = self.load_tune_data()
            self.setup_workspace()
            for i, item in enumerate(data):
                self.fit_tune_data(item)
                self.generate_beam()
                self.run_tracking(i)
        finally:
            os.chdir(self.cwd)
示例#2
0
class DAFinder(object):
    """
    Find the dynamic aperture (DA) around a list of closed orbits.

    Trajectories offset from each closed orbit are tracked with OPAL; a
    trajectory passes if it produces more than required_n_hits hits.
    get_all_da searches along one axis at a time, da_all_scan tracks a 2D grid
    of offsets.
    """
    def __init__(self, config):
        """
        Initialise the DAFinder object, load the closed orbits and change the
        working directory to the run directory.
        - config: configuration object
        """
        output_dir = config.run_control["output_dir"]
        # Paths are made absolute because setup() chdirs into run_dir and the
        # output and lattice files are only opened after that.
        self.closed_orbit_file_name = os.path.abspath(os.path.join(
            output_dir,
            config.find_closed_orbits["output_file"]) + ".out")
        self.da_file_name = os.path.abspath(os.path.join(
            output_dir, config.find_da["get_output_file"]))
        self.scan_file_name = os.path.abspath(os.path.join(
            output_dir, config.find_da["scan_output_file"]))
        self.config = config
        self.run_dir = os.path.abspath(os.path.join(
            output_dir, config.find_da["run_dir"]))
        self.co_list = self.load_closed_orbits()
        self.ref_hit = None
        self.min_delta = config.find_da["min_delta"]
        self.max_delta = config.find_da["max_delta"]
        self.data = []
        self.fout_scan_tmp = None
        self.fout_get_tmp = None
        self.required_n_hits = config.find_da["required_n_hits"]
        self.max_iterations = config.find_da["max_iterations"]
        # setup() returns nothing; the tracking object is built per closed
        # orbit by setup_tracking()
        self.tracking = None
        self.setup()

    def get_all_da(self, co_index_list, seed_x, seed_y):
        """
        Find the DA for several closed orbits and write each element (with
        'x_da' / 'y_da' added) as one JSON line to the da output .tmp file.
        - co_index_list: indices into self.co_list, or None for every element.
                         Out-of-range indices are reported and skipped.
        - seed_x: first horizontal offset to try; None or <= 0 disables the
                  horizontal search.
        - seed_y: first vertical offset to try; None or <= 0 disables the
                  vertical search.
        """
        if co_index_list is None:
            co_index_list = list(range(len(self.co_list)))
        for i in co_index_list:
            try:
                co_element = self.co_list[i]
            except IndexError:
                print("Failed to find index", i, "in co_list of length",
                      len(self.co_list))
                continue
            print("Finding da for element", i)
            if seed_x is not None and seed_x > 0.:
                co_element['x_da'] = self.get_da(co_element, 'x', seed_x)
            if seed_y is not None and seed_y > 0.:
                co_element['y_da'] = self.get_da(co_element, 'y', seed_y)
            print(json.dumps(co_element), file=self.fout_get())
            self.fout_get().flush()

    def da_all_scan(self, co_index_list, x_list, y_list):
        """
        Run a 2D grid scan of offsets for several closed orbits.
        - co_index_list: indices into self.co_list, or None for every element.
                         Out-of-range indices are reported and skipped.
        - x_list: horizontal offsets from the closed orbit
        - y_list: vertical offsets from the closed orbit
        """
        if co_index_list is None:
            co_index_list = list(range(len(self.co_list)))
        for i in co_index_list:
            try:
                co_element = self.co_list[i]
            except IndexError:
                print("Failed to find index", i, "in co_list of length",
                      len(self.co_list))
                continue
            print("Scanning da for element", i)
            co_element['da_scan'] = self.da_scan(co_element, x_list, y_list)

    def load_closed_orbits(self):
        """
        Load the closed orbits file (one JSON document per line; blank lines
        are ignored) and return the list of decoded documents.
        """
        with open(self.closed_orbit_file_name) as fin:
            co_list = [json.loads(line) for line in fin if line.strip()]
        print("Loaded", len(co_list), "closed orbits")
        return co_list

    def setup(self):
        """
        Create the run directory (if needed), chdir into it and locate the
        OPAL executable from ${OPAL_EXE_PATH}.
        """
        self.tmp_dir = "./"
        os.makedirs(self.run_dir, exist_ok=True)
        os.chdir(self.run_dir)
        print("Running in", os.getcwd())
        self.opal_exe = os.path.expandvars("${OPAL_EXE_PATH}/opal")

    def reference(self, hit_dict):
        """
        Generate a reference particle from hit_dict with x and px set to 0.
        """
        hit = Hit.new_from_dict(hit_dict)
        hit["x"] = 0.
        hit["px"] = 0.
        return hit

    def setup_tracking(self, co_element):
        """
        Build the lattice file and the OpalTracking object for one closed orbit.
        - co_element: closed orbit record. Its "substitutions" dict is updated
                      in place with config.find_da["subs_overrides"], and
                      hits[0] seeds the reference particle.
        """
        subs = co_element["substitutions"]
        for key, value in self.config.find_da["subs_overrides"].items():
            subs[key] = value
        print("Set up tracking for da with", end=' ')
        for key in sorted(subs.keys()):
            print(utilities.sub_to_name(key), subs[key], end=' ')
        print()
        self.ref_hit = self.reference(co_element["hits"][0])
        lattice_src = self.config.tracking["lattice_file"]
        lattice_file = os.path.join(self.run_dir, "SectorFFAGMagnet.tmp")
        common.substitute(lattice_src, lattice_file, subs)
        tracking_file = self.config.find_da["probe_files"]
        self.tracking = OpalTracking(lattice_file,
                                     self.tmp_dir + '/disttest.dat',
                                     self.ref_hit, tracking_file,
                                     self.opal_exe, self.tmp_dir + "/log")

    def new_seed(self):
        """
        Generate a new seed for the DA finding routines. self.data must be
        sorted by offset.
        * If the largest offset passed, it is doubled (unless > max_delta)
        * If the smallest offset failed, it is halved (unless < min_delta)
        * Otherwise bisect between the last passing offset and the first
          failing offset (unless they are closer than min_delta)
        Returns the new seed or None if the iteration has finished.
        """
        if self.data[-1][0] < self.min_delta:  # reference run?
            return None
        if self.test_pass(*self.data[-1]):  # upper limit is okay; keep going up
            if self.data[-1][0] > self.max_delta:  # too big; give up
                return None
            return self.data[-1][0] * 2.
        elif not self.test_pass(*self.data[0]):  # lower limit is bad; go down
            if abs(self.data[0][0]) < self.min_delta:
                return None
            return self.data[0][0] / 2.
        else:
            # find the first failing entry; data[0] passes and data[-1] fails
            # in this branch, so the loop always breaks
            for i, item in enumerate(self.data[1:]):
                if not self.test_pass(*item):
                    break
            if abs(self.data[i][0] - self.data[i + 1][0]) < self.min_delta:
                return None
            return (self.data[i][0] + self.data[i + 1][0]) / 2.

    def test_pass(self, seed, hits_list):
        """
        Return True if hits_list has more than self.required_n_hits entries.
        The seed argument is unused.
        """
        return len(hits_list) > self.required_n_hits

    def events_generator(self, co_element, x_list, y_list):
        """
        Yield ({"x": x, "y": y}, hit) for every (x, y) in the grid, where hit
        is the reference hit with x/px taken from the closed orbit, x offset
        by x and y offset by y.
        """
        co_hit = co_element["hits"][0]
        for x in x_list:
            for y in y_list:
                a_hit = self.ref_hit.deepcopy()
                a_hit['x'] = co_hit['x'] + x
                a_hit['px'] = co_hit['px']
                a_hit['y'] += y
                yield {"x": x, "y": y}, a_hit

    def da_scan(self, co_element, x_list, y_list):
        """
        Track the (x_list, y_list) grid of offsets for one closed orbit and
        write the results as one JSON line to the scan output .tmp file.
        Returns the list of [offset dict, [hit dicts]] pairs.
        """
        self.setup_tracking(co_element)
        gen = self.events_generator(co_element, x_list, y_list)
        self.data = []
        finished = False
        while not finished:
            event_list = []
            track_list = []
            try:
                while len(event_list) < 1:  # events are tracked one at a time
                    track, event = next(gen)
                    track_list.append(track)
                    event_list.append(event)
            except StopIteration:
                finished = True
            if len(event_list) == 0:
                break
            many_tracks = self.tracking.track_many(event_list)
            for i, hits in enumerate(many_tracks):
                print("Tracked", len(hits), "total hits with track",
                      track_list[i], "first event x px y py", hits[0]["x"],
                      hits[0]["px"], hits[0]["y"], hits[0]["py"])
                self.data.append(
                    [track_list[i], [a_hit.dict_from_hit() for a_hit in hits]])
        print(json.dumps(self.data), file=self.fout_scan())
        self.fout_scan().flush()
        return self.data

    def get_da(self, co_element, axis, seed_x):
        """
        Search for the DA along one axis of one closed orbit.
        - co_element: closed orbit record (see setup_tracking)
        - axis: hit key to offset (get_all_da uses 'x' or 'y')
        - seed_x: first offset to try; an offset with magnitude below 1e-6 is
                  treated as a reference run and tracked only once
        Returns a list of [offset, [hit dicts]] pairs sorted by offset.
        """
        is_ref = abs(seed_x) < 1e-6
        self.setup_tracking(co_element)
        self.data = []
        co_delta = {"x": 0, "y": 0}
        iteration = 0
        while seed_x is not None and iteration < self.max_iterations:
            co_delta[axis] = seed_x
            my_time = time.time()
            a_hit = Hit.new_from_dict(co_element["hits"][0])
            a_hit[axis] += seed_x
            try:
                hits = self.tracking.track_one(a_hit)
            except (RuntimeError, OSError):
                sys.excepthook(*sys.exc_info())
                print("Never mind, keep on going...")
                # record the failed trajectory with only its start hit rather
                # than an undefined / stale hits list from a previous seed
                hits = [a_hit]
            self.data.append(
                [co_delta[axis], [hit.dict_from_hit() for hit in hits]])
            # sort on the offset only; the hit lists are not orderable
            self.data.sort(key=lambda item: item[0])
            print("Axis", axis, "Seed", seed_x, "Number of cells hit",
                  len(hits), "in",
                  time.time() - my_time, "[s]")
            sys.stdout.flush()
            seed_x = self.new_seed()
            if is_ref:
                seed_x = None
            iteration += 1
        self.data = [list(item) for item in self.data]
        return self.data

    def fout_scan(self):
        """
        Open the scan output .tmp file on first use and return it.
        """
        if self.fout_scan_tmp is None:
            file_name = self.scan_file_name + ".tmp"
            self.fout_scan_tmp = open(file_name, "w")
            print("Opened file", file_name)
        return self.fout_scan_tmp

    def fout_get(self):
        """
        Open the da finder output .tmp file on first use and return it.
        """
        if self.fout_get_tmp is None:
            file_name = self.da_file_name + ".tmp"
            self.fout_get_tmp = open(file_name, "w")
            print("Opened file", file_name)
        return self.fout_get_tmp
示例#3
0
class DAFinder(object):
    """
    DAFinder attempts to find the dynamic aperture by performing a binary search
    of trajectories offset from the closed orbit. DA is determined to be the
    trajectory that passes through more than a (user defined) number of probes.
    """
    def __init__(self, config):
        """
        Initialise the DAFinder object, load the closed orbits and change the
        working directory to the run directory.
        - config: configuration object
        """
        output_dir = config.run_control["output_dir"]
        # Paths are made absolute because setup() chdirs into run_dir and the
        # output and lattice files are only opened after that.
        self.closed_orbit_file_name = os.path.abspath(os.path.join(
            output_dir, config.find_closed_orbits["output_file"]))
        self.da_file_name = os.path.abspath(os.path.join(
            output_dir, config.find_da["get_output_file"]))
        self.scan_file_name = os.path.abspath(os.path.join(
            output_dir, config.find_da["scan_output_file"]))
        self.config = config
        self.run_dir = os.path.abspath(os.path.join(
            output_dir, config.find_da["run_dir"]))
        self.co_list = self.load_closed_orbits()
        self.ref_hit = None
        self.min_delta = config.find_da["min_delta"]
        self.max_delta = config.find_da["max_delta"]
        self.data = []
        self.fout_scan_tmp = None
        self.fout_get_tmp = None
        self.required_n_hits = config.find_da["required_n_hits"]
        self.max_iterations = config.find_da["max_iterations"]
        # setup() returns nothing; the tracking object is built per closed
        # orbit by setup_tracking()
        self.tracking = None
        self.setup()

    def get_all_da(self, co_index_list, seed_x, seed_y):
        """
        Get the DA. Each processed element (with 'x_da' / 'y_da' added) is
        written as one JSON line to self.da_file_name.
        - co_index_list: list of indices into self.co_list (the loaded closed
                         orbits) to find DA for. Out-of-range indices are
                         reported and skipped. Set to None to iterate over
                         every element.
        - seed_x: (float) best guess position offset for trajectory on the
                  horizontal DA. Set to None or negative value to disable
                  horizontal DA finding.
        - seed_y: (float) best guess position offset for trajectory on the
                  vertical DA. Set to None or negative value to disable vertical
                  DA finding.
        """
        if co_index_list is None:
            co_index_list = list(range(len(self.co_list)))
        for i in co_index_list:
            try:
                co_element = self.co_list[i]
            except IndexError:
                print("Failed to find index", i, "in co_list of length", len(self.co_list))
                continue
            print("Finding da for element", i)
            if seed_x is not None and seed_x > 0.:
                co_element['x_da'] = self.get_da(co_element, 'x', seed_x)
            if seed_y is not None and seed_y > 0.:
                co_element['y_da'] = self.get_da(co_element, 'y', seed_y)
            print(json.dumps(co_element), file=self.fout_get())
            self.fout_get().flush()
        self._finalise_output("fout_get_tmp", self.da_file_name)

    def da_all_scan(self, co_index_list, x_list, y_list):
        """
        Scan the DA
        - co_index_list: list of indices into self.co_list (the loaded closed
                         orbits) to scan. Out-of-range indices are reported and
                         skipped. Set to None to iterate over every element.
        - x_list: list of floats. Each list element is a horizontal position
                  offset from the closed orbit; the algorithm will track the
                  particle and count the number of probes through which the
                  particle passes.
        - y_list: list of floats. Each list element is a vertical position
                  offset from the closed orbit; the algorithm will track the
                  particle and count the number of probes through which the
                  particle passes.

        The scan routine will generate a 2D grid in x and y. The trajectories
        will be written to the scan file name.
        """
        if co_index_list is None:
            co_index_list = list(range(len(self.co_list)))
        for i in co_index_list:
            try:
                co_element = self.co_list[i]
            except IndexError:
                print("Failed to find index", i, "in co_list of length", len(self.co_list))
                continue
            print("Scanning da for element", i)
            co_element['da_scan'] = self.da_scan(co_element, x_list, y_list)
        self._finalise_output("fout_scan_tmp", self.scan_file_name)

    def _finalise_output(self, attr_name, file_name):
        """
        Close the ".tmp" output file held in attribute attr_name (if one was
        opened) and move it to file_name. Does nothing when no output was
        written, so an empty run does not fail on a missing .tmp file.
        """
        fout = getattr(self, attr_name)
        if fout is None:
            return
        fout.close()
        # reset so that a later run opens (and finalises) a fresh .tmp file
        setattr(self, attr_name, None)
        os.replace(file_name+".tmp", file_name)

    def load_closed_orbits(self):
        """
        Load the closed orbits file (one JSON document per line; blank lines
        are ignored) and return the list of decoded documents.
        """
        with open(self.closed_orbit_file_name) as fin:
            co_list = [json.loads(line) for line in fin if line.strip()]
        print("Loaded", len(co_list), "closed orbits")
        return co_list

    def setup(self):
        """
        Create the run directory (if needed), chdir into it and locate the
        OPAL executable from ${OPAL_EXE_PATH}.
        """
        self.tmp_dir = "./"
        os.makedirs(self.run_dir, exist_ok=True)
        os.chdir(self.run_dir)
        print("Running in", os.getcwd())
        self.opal_exe = os.path.expandvars("${OPAL_EXE_PATH}/opal")

    def reference(self, hit_dict):
        """
        Generate a reference particle from hit_dict with x and px set to 0.
        """
        hit = Hit.new_from_dict(hit_dict)
        hit["x"] = 0.
        hit["px"] = 0.
        return hit

    def setup_tracking(self, co_element):
        """
        Setup the tracking routines for one closed orbit.
        - co_element: closed orbit record. Its "substitutions" dict is updated
                      in place with config.find_da["subs_overrides"], and
                      hits[0] seeds the reference particle.
        """
        subs = co_element["substitutions"]
        for key, value in self.config.find_da["subs_overrides"].items():
            subs[key] = value
        print("Set up tracking for da with", end=' ')
        for key in sorted(subs.keys()):
            print(utilities.sub_to_name(key), subs[key], end=' ')
        print()
        self.ref_hit = self.reference(co_element["hits"][0])
        lattice_src = self.config.tracking["lattice_file"]
        lattice_file = os.path.join(self.run_dir, "SectorFFAGMagnet.tmp")
        common.substitute(
            lattice_src,
            lattice_file,
            subs
        )
        tracking_file = self.config.find_da["probe_files"]
        self.tracking = OpalTracking(lattice_file, self.tmp_dir+'/disttest.dat', self.ref_hit, tracking_file, self.opal_exe, self.tmp_dir+"/log")

    def new_seed(self):
        """
        Generate a new seed for the DA finding routines. self.data must be
        sorted by offset.
        * If all previous iterations have passed, then the largest offset is
          doubled (unless it exceeds max_delta)
        * If all previous iterations have failed, then the smallest offset is
          halved (unless it is below min_delta)
        * If some previous iterations have failed and some have passed, then a
          binary interpolation is performed between the highest passing
          iteration and the lowest failing iteration
        Returns the new seed or None if the iteration has finished.
        """
        if self.data[-1][0] < self.min_delta: # reference run?
            return None
        if self.test_pass(*self.data[-1]): # upper limit is okay; keep going up
            if self.data[-1][0] > self.max_delta: # too big; give up
                return None
            return self.data[-1][0]*2.
        elif not self.test_pass(*self.data[0]): # lower limit is bad; try going down
            if abs(self.data[0][0]) < self.min_delta:
                return None
            return self.data[0][0]/2.
        else:
            # find the first failing entry; data[0] passes and data[-1] fails
            # in this branch, so the loop always breaks
            for i, item in enumerate(self.data[1:]):
                if not self.test_pass(*item):
                    break
            if abs(self.data[i][0]-self.data[i+1][0]) < self.min_delta:
                return None
            return (self.data[i][0]+self.data[i+1][0])/2.

    def test_pass(self, seed, hits_list):
        """
        Check to see if an iteration passed

        Returns true if the number of hits in the hits list is greater than the
        required number of hits. The seed argument is unused.
        """
        return len(hits_list) > self.required_n_hits

    def events_generator(self, co_element, x_list, y_list):
        """
        Generates the list of events for the da scan: yields
        ({"x": x, "y": y}, hit) for each grid point, where hit is the
        reference hit with x/px taken from the closed orbit, x offset by x and
        y offset by y.
        """
        co_hit = co_element["hits"][0]
        for x in x_list:
            for y in y_list:
                a_hit = self.ref_hit.deepcopy()
                a_hit['x'] = co_hit['x'] + x
                a_hit['px'] = co_hit['px']
                a_hit['y'] += y
                yield {"x":x, "y":y}, a_hit

    def da_scan(self, co_element, x_list, y_list):
        """
        Do the da scan for one closed orbit element. Writes the results as one
        JSON line to the scan output .tmp file and returns the list of
        [offset dict, [hit dicts]] pairs.
        """
        self.setup_tracking(co_element)
        gen = self.events_generator(co_element, x_list, y_list)
        self.data = []
        finished = False
        while not finished:
            event_list = []
            track_list = []
            try:
                while len(event_list) < 1: # events are tracked one at a time
                    track, event = next(gen)
                    track_list.append(track)
                    event_list.append(event)
            except StopIteration:
                finished = True
            if len(event_list) == 0:
                break
            many_tracks = self.tracking.track_many(event_list)
            for i, hits in enumerate(many_tracks):
                print("Tracked", len(hits), "total hits with track", track_list[i], "first event x px y py", hits[0]["x"], hits[0]["px"], hits[0]["y"], hits[0]["py"])
                self.data.append([track_list[i], [a_hit.dict_from_hit() for a_hit in hits]])
        print(json.dumps(self.data), file=self.fout_scan())
        self.fout_scan().flush()
        return self.data

    def get_da(self, co_element, axis, seed_x):
        """
        Do the da finding for one closed orbit element along one axis.
        - co_element: closed orbit record (see setup_tracking)
        - axis: hit key to offset (get_all_da uses 'x' or 'y')
        - seed_x: first offset to try; an offset with magnitude below 1e-6 is
                  treated as a reference run and tracked only once
        Returns a list of [offset, [hit dicts]] pairs sorted by offset.
        """
        is_ref = abs(seed_x) < 1e-6
        self.setup_tracking(co_element)
        self.data = []
        co_delta = {"x":0, "y":0}
        iteration = 0
        while seed_x is not None and iteration < self.max_iterations:
            co_delta[axis] = seed_x
            my_time = time.time()
            a_hit = Hit.new_from_dict(co_element["hits"][0])
            a_hit[axis] += seed_x
            try:
                hits = self.tracking.track_one(a_hit)
            except (RuntimeError, OSError):
                sys.excepthook(*sys.exc_info())
                print("Never mind, keep on going...")
                # record the failed trajectory with only its start hit
                hits = [a_hit]
            self.data.append([co_delta[axis], [hit.dict_from_hit() for hit in hits]])
            # sort on the offset only; the hit lists are not orderable
            self.data.sort(key=lambda item: item[0])
            print("Axis", axis, "Seed", seed_x, "Number of cells hit", len(hits), "in", time.time() - my_time, "[s]")
            sys.stdout.flush()
            seed_x = self.new_seed()
            if is_ref:
                seed_x = None
            iteration += 1
        self.data = [list(item) for item in self.data]
        return self.data

    def fout_scan(self):
        """
        Open the scan output .tmp file on first use and return it
        """
        if self.fout_scan_tmp is None:
            file_name = self.scan_file_name+".tmp"
            self.fout_scan_tmp = open(file_name, "w")
            print("Opened file", file_name)
        return self.fout_scan_tmp

    def fout_get(self):
        """
        Open the da finder output .tmp file on first use and return it
        """
        if self.fout_get_tmp is None:
            file_name = self.da_file_name+".tmp"
            self.fout_get_tmp = open(file_name, "w")
            print("Opened file", file_name)
        return self.fout_get_tmp