Example #1
class Astra:
    """ 
    Astra simulation object. Essential methods:
    .__init__(...)
    .configure()
    .run()
    
    Input deck is held in .input
    Output data is parsed into .output
    .load_particles() will load particle data into .output['particles'][...]
    
    The Astra binary file can be set on init. If it doesn't exist, configure will check the
        $ASTRA_BIN
    environmental variable.
    
    
    """
    def __init__(self,
                 input_file=None,
                 initial_particles=None,
                 astra_bin='$ASTRA_BIN',
                 use_tempdir=True,
                 workdir=None,
                 verbose=False):
        # Save init
        self.original_input_file = input_file
        self.initial_particles = initial_particles
        self.use_tempdir = use_tempdir
        self.workdir = workdir
        if workdir:
            assert os.path.exists(
                workdir), 'workdir does not exist: ' + workdir
        self.verbose = verbose
        self.astra_bin = astra_bin

        # These will be set
        self.log = []
        self.output = {'stats': {}, 'particles': {}, 'run_info': {}}
        self.timeout = None
        self.error = False

        # Run control
        self.finished = False
        self.configured = False
        self.using_tempdir = False

        # Call configure
        if input_file:
            self.load_input(input_file)
            self.configure()
        else:
            self.vprint('Warning: No input file was given. Not configured.')
            self.original_input_file = 'astra.in'

    def clean_output(self):
        run_number = parsers.astra_run_extension(self.input['newrun']['run'])
        outfiles = parsers.find_astra_output_files(self.input_file, run_number)
        for f in outfiles:
            os.remove(f)

    def clean_particles(self):
        run_number = parsers.astra_run_extension(self.input['newrun']['run'])
        phase_files = parsers.find_phase_files(self.input_file, run_number)
        files = [x[0] for x in phase_files]  # This is sorted by approximate z
        for f in files:
            os.remove(f)

    # Convenience routines
    @property
    def particles(self):
        return self.output['particles']

    def stat(self, key):
        return self.output['stats'][key]

    def particle_stat(self, key, alive_only=True):
        """
        Compute a statistic from the particles.
        
        Alive particles have status == 1. By default, statistics will only be computed on these.
        
        n_dead will override the alive_only flag, 
        and return the number of particles with status < -6 (Astra convention)
        """

        if key == 'n_dead':
            return np.array(
                [len(np.where(P.status < -6)[0]) for P in self.particles])

        if key == 'n_alive':
            return np.array(
                [len(np.where(P.status > -6)[0]) for P in self.particles])

        pstats = []
        for P in self.particles:
            if alive_only and P.n_dead > 0:
                P = P.where(P.status == 1)
            pstats.append(P[key])
        return np.array(pstats)

    def configure(self):
        self.configure_astra(workdir=self.workdir)

    def configure_astra(self, input_filePath=None, workdir=None):

        if input_filePath:
            self.load_input(input_filePath)

        # Check that binary exists
        self.astra_bin = tools.full_path(self.astra_bin)
        assert os.path.exists(
            self.astra_bin
        ), 'ERROR: Astra binary does not exist: ' + self.astra_bin

        # Set paths
        if self.use_tempdir:
            # Need to attach this to the object. Otherwise it will go out of scope.
            self.tempdir = tempfile.TemporaryDirectory(dir=workdir)
            self.path = self.tempdir.name
        else:
            # Work in place
            self.path = self.original_path

        self.input_file = os.path.join(self.path, self.original_input_file)
        self.configured = True

    def load_initial_particles(self, h5):
        """Loads a openPMD-beamphysics particle h5 handle or file"""
        P = ParticleGroup(h5=h5)
        self.initial_particles = P

    def load_input(self, input_filePath, absolute_paths=True):
        f = tools.full_path(input_filePath)
        self.original_path, self.original_input_file = os.path.split(
            f)  # Get original path, filename
        self.input = parsers.parse_astra_input_file(f)

        if absolute_paths:
            parsers.fix_input_paths(self.input, root=self.original_path)

    def load_output(self, include_particles=True):
        """
        Loads Astra output files into .output
        
        .output is a dict with dicts:
            .stats
            .run_info
            .other
            
        and if include_particles,
            .particles = list of ParticleGroup objects
        
        """
        run_number = parsers.astra_run_extension(self.input['newrun']['run'])
        outfiles = parsers.find_astra_output_files(self.input_file, run_number)

        #assert len(outfiles)>0, 'No output files found'

        stats = self.output['stats'] = {}

        for f in outfiles:
            ftype = parsers.astra_output_type(f)
            d = parsers.parse_astra_output_file(f)
            if ftype in ['Cemit', 'Xemit', 'Yemit', 'Zemit']:
                stats.update(d)
            elif ftype in ['LandF']:
                self.output['other'] = d
            else:
                raise ValueError(f'Unknown output type: {ftype}')

        # Check that the lengths of all arrays are the same
        nlist = {len(stats[k]) for k in stats}

        assert len(nlist) == 1, (
            'Stat keys do not all have the same length: '
            f'{[len(stats[k]) for k in stats]}')

        if include_particles:
            self.load_particles()

    def load_particles(self, end_only=False):
        # Clear existing particles
        self.output['particles'] = []

        # Sort files by approximate z
        run_number = parsers.astra_run_extension(self.input['newrun']['run'])
        phase_files = parsers.find_phase_files(self.input_file, run_number)
        files = [x[0] for x in phase_files]  # This is sorted by approximate z
        zapprox = [x[1] for x in phase_files]

        if end_only:
            files = files[-1:]
        if self.verbose:
            print('loading ' + str(len(files)) + ' particle files')
            print(zapprox)
        for f in files:
            pdat = parsers.parse_astra_phase_file(f)
            P = ParticleGroup(data=pdat)
            self.output['particles'].append(P)

    def run(self):
        if not self.configured:
            print('not configured to run')
            return
        self.run_astra(verbose=self.verbose, timeout=self.timeout)

    def get_run_script(self, write_to_path=True):
        """
        Assembles the run script. Optionally writes a file 'run' with this line to path.
        """

        #_, infile = os.path.split(self.input_file)

        runscript = [self.astra_bin, self.input_file]

        if write_to_path:
            with open(os.path.join(self.path, 'run'), 'w') as f:
                f.write(' '.join(runscript))

        return runscript

    def run_astra(self, verbose=False, parse_output=True, timeout=None):
        """
        Runs Astra
        
        Does not change directory - this has problems with multithreading.
        """

        run_info = self.output['run_info'] = {}

        t1 = time()
        run_info['start_time'] = t1

        if self.initial_particles:
            fname = self.write_initial_particles()
            self.input['newrun']['distribution'] = fname

        # Write input file from internal dict
        self.write_input_file()

        runscript = self.get_run_script()
        run_info['run_script'] = ' '.join(runscript)

        try:
            if timeout:
                res = tools.execute2(runscript, timeout=timeout)
                log = res['log']
                self.error = res['error']
                run_info['why_error'] = res['why_error']
                # The log must contain this line to have finished properly
                if log.find('finished simulation') == -1:
                    run_info.update({
                        'error': True,
                        'why_error': "Couldn't find finished simulation"
                    })

            else:
                # Interactive output, for Jupyter
                log = []
                for path in tools.execute(runscript):
                    self.vprint(path, end="")
                    log.append(path)

            self.log = log

            if parse_output:
                self.load_output()
        except Exception as ex:
            print('Run Aborted', ex)
            self.error = True
            run_info['why_error'] = str(ex)
        finally:
            run_info['run_time'] = time() - t1
            run_info['run_error'] = self.error

        self.finished = True

        self.vprint(run_info)

    def fingerprint(self):
        """
        Data fingerprint using the input. 
        """
        return tools.fingerprint(self.input)

    def vprint(self, *args, **kwargs):
        # Verbose print
        if self.verbose:
            print(*args, **kwargs)

    def units(self, key):
        if key in parsers.OutputUnits:
            return parsers.OutputUnits[key]
        else:
            return 'unknown unit'

    def write_input_file(self):
        parsers.write_namelists(self.input, self.input_file)

    def write_initial_particles(self, fname=None):
        if not fname:
            fname = os.path.join(self.path, 'astra.particles')
        self.initial_particles.write_astra(fname)
        self.vprint(f'Initial particles written to {fname}')
        return fname

    def load_archive(self, h5=None):
        """
        Loads input and output from archived h5 file.
        
        See: Astra.archive
        """
        if isinstance(h5, str):
            g = h5py.File(h5, 'r')

            glist = archive.find_astra_archives(g)
            n = len(glist)
            if n == 0:
                # legacy: try top level
                message = 'legacy'
            elif n == 1:
                gname = glist[0]
                message = f'group {gname} from'
                g = g[gname]
            else:
                raise ValueError(
                    f'Multiple archives found in file {h5}: {glist}')

            self.vprint(f'Reading {message} archive file {h5}')
        else:
            g = h5

        self.input = archive.read_input_h5(g['input'])
        self.output = archive.read_output_h5(g['output'])
        if 'initial_particles' in g:
            self.initial_particles = ParticleGroup(h5=g['initial_particles'])

        self.vprint(
            'Loaded from archive. Note: Must reconfigure to run again.')
        self.configured = False

    def archive(self, h5=None):
        """
        Archive all data to an h5 handle or filename.
        
        If no file is given, a file based on the fingerprint will be created.
        
        """
        if not h5:
            h5 = 'astra_' + self.fingerprint() + '.h5'

        if isinstance(h5, str):
            g = h5py.File(h5, 'w')
            self.vprint(f'Archiving to file {h5}')
        else:
            # store directly in the given h5 handle
            g = h5

        # Write basic attributes
        archive.astra_init(g)

        # Initial particles
        if self.initial_particles:
            self.initial_particles.write(g, name='initial_particles')

        # All input
        archive.write_input_h5(g, self.input)

        # All output
        archive.write_output_h5(g, self.output)

        return h5
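
A minimal usage sketch for the Astra class above, assuming an input deck
'astra.in' exists and $ASTRA_BIN points at the Astra executable ('sigma_x'
is an illustrative stat key):

A = Astra(input_file='astra.in', verbose=True)
A.run()                      # writes input, executes Astra, parses output
sig_x = A.stat('sigma_x')    # statistics parsed from the emittance files
P_end = A.particles[-1]      # final ParticleGroup
afile = A.archive('run.h5')  # restore later with Astra().load_archive(afile)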
Example #2
class Astra:
    """ 
    Astra simulation object. Essential methods:
    .__init__(...)
    .configure()
    .run()
    
    Input deck is held in .input
    Output data is parsed into .output
    .load_particles() will load particle data into .output['particles'][...]
    
    The Astra binary file can be set on init. If it doesn't exist, configure will check the
        $ASTRA_BIN
    environmental variable.
    
    
    """
    def __init__(self,
                 input_file=None,
                 initial_particles=None,
                 astra_bin='$ASTRA_BIN',
                 use_tempdir=True,
                 workdir=None,
                 group=None,
                 verbose=False):
        # Save init
        self.original_input_file = input_file
        self.initial_particles = initial_particles
        self.use_tempdir = use_tempdir
        self.workdir = workdir
        if workdir:
            assert os.path.exists(
                workdir), 'workdir does not exist: ' + workdir
        self.verbose = verbose
        self.astra_bin = astra_bin

        # These will be set
        self.log = []
        self.output = {'stats': {}, 'particles': {}, 'run_info': {}}
        self.timeout = None
        self.error = False
        self.group = {}  # Control Groups
        self.fieldmap = {}  # Fieldmaps

        # Run control
        self.finished = False
        self.configured = False
        self.using_tempdir = False

        # Call configure
        if input_file:
            self.load_input(input_file)
            self.configure()

            # Add groups, if any.
            if group:
                for k, v in group.items():
                    self.add_group(k, **v)

        else:
            self.vprint('Warning: No input file was given. Not configured.')
            self.original_input_file = 'astra.in'

    def add_group(self, name, **kwargs):
        """
        Add a control group. See control.py
        """
        assert name not in self.input, f'{name} is an input namelist and cannot be overwritten by a group.'
        if name in self.group:
            self.vprint(f'Warning: group {name} already exists, overwriting.')

        g = ControlGroup(**kwargs)
        g.link(self.input)
        self.group[name] = g

        return self.group[name]
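        # Usage sketch: kwargs are forwarded to ControlGroup (see control.py
        # for the accepted arguments). Equivalently, pass group={'name': {...}}
        # to __init__, which calls add_group for each entry.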

    def clean_output(self):
        run_number = parsers.astra_run_extension(self.input['newrun']['run'])
        outfiles = parsers.find_astra_output_files(self.input_file, run_number)
        for f in outfiles:
            os.remove(f)

    def clean_particles(self):
        run_number = parsers.astra_run_extension(self.input['newrun']['run'])
        phase_files = parsers.find_phase_files(self.input_file, run_number)
        files = [x[0] for x in phase_files]  # This is sorted by approximate z
        for f in files:
            os.remove(f)

    # Convenience routines
    @property
    def particles(self):
        return self.output['particles']

    def stat(self, key):
        return self.output['stats'][key]

    def particle_stat(self, key, alive_only=True):
        """
        Compute a statistic from the particles.
        
        Alive particles have status == 1. By default, statistics will only be computed on these.
        
        n_dead will override the alive_only flag, 
        and return the number of particles with status < -6 (Astra convention)
        """

        if key == 'n_dead':
            return np.array(
                [len(np.where(P.status < -6)[0]) for P in self.particles])

        if key == 'n_alive':
            return np.array(
                [len(np.where(P.status > -6)[0]) for P in self.particles])

        pstats = []
        for P in self.particles:
            if alive_only and P.n_dead > 0:
                P = P.where(P.status == 1)
            pstats.append(P[key])
        return np.array(pstats)

    def configure(self):
        self.configure_astra(workdir=self.workdir)

    def configure_astra(self, input_filePath=None, workdir=None):

        if input_filePath:
            self.load_input(input_filePath)

        # Check that binary exists
        self.astra_bin = tools.full_path(self.astra_bin)
        assert os.path.exists(
            self.astra_bin
        ), 'ERROR: Astra binary does not exist: ' + self.astra_bin

        # Set paths
        if self.use_tempdir:
            # Need to attach this to the object. Otherwise it will go out of scope.
            self.tempdir = tempfile.TemporaryDirectory(dir=workdir)
            self.path = self.tempdir.name
        else:
            # Work in place
            self.path = self.original_path

        self.input_file = os.path.join(self.path, self.original_input_file)
        self.configured = True

    def load_fieldmaps(self):
        """
        Loads fieldmaps into Astra.fieldmap as a dict
        """

        # Do not consider files if fieldmaps have been loaded.
        if self.fieldmap:
            strip_path = False
        else:
            strip_path = True

        self.fieldmap = load_fieldmaps(self,
                                       fieldmap_dict=self.fieldmap,
                                       search_paths=[self.path],
                                       verbose=self.verbose,
                                       strip_path=strip_path)

    def load_initial_particles(self, h5):
        """Loads a openPMD-beamphysics particle h5 handle or file"""
        P = ParticleGroup(h5=h5)
        self.initial_particles = P

    def load_input(self, input_filePath, absolute_paths=True):
        f = tools.full_path(input_filePath)
        self.original_path, self.original_input_file = os.path.split(
            f)  # Get original path, filename
        self.input = parsers.parse_astra_input_file(f)

        if absolute_paths:
            parsers.fix_input_paths(self.input, root=self.original_path)

    def load_output(self, include_particles=True):
        """
        Loads Astra output files into .output
        
        .output is a dict with dicts:
            .stats
            .run_info
            .other
            
        and if include_particles,
            .particles = list of ParticleGroup objects
        
        """
        run_number = parsers.astra_run_extension(self.input['newrun']['run'])
        outfiles = parsers.find_astra_output_files(self.input_file, run_number)

        #assert len(outfiles)>0, 'No output files found'

        stats = self.output['stats'] = {}

        for f in outfiles:
            ftype = parsers.astra_output_type(f)
            d = parsers.parse_astra_output_file(f)
            if ftype in ['Cemit', 'Xemit', 'Yemit', 'Zemit']:
                stats.update(d)
            elif ftype in ['LandF']:
                self.output['other'] = d
            else:
                raise ValueError(f'Unknown output type: {ftype}')

        # Check that the lengths of all arrays are the same
        nlist = {len(stats[k]) for k in stats}

        assert len(nlist) == 1, (
            'Stat keys do not all have the same length: '
            f'{[len(stats[k]) for k in stats]}')

        if include_particles:
            self.load_particles()

    def load_particles(self, end_only=False):
        # Clear existing particles
        self.output['particles'] = []

        # Sort files by approximate z
        run_number = parsers.astra_run_extension(self.input['newrun']['run'])
        phase_files = parsers.find_phase_files(self.input_file, run_number)
        files = [x[0] for x in phase_files]  # This is sorted by approximate z
        zapprox = [x[1] for x in phase_files]

        if end_only:
            files = files[-1:]
        if self.verbose:
            print('loading ' + str(len(files)) + ' particle files')
            print(zapprox)
        for f in files:
            pdat = parsers.parse_astra_phase_file(f)
            P = ParticleGroup(data=pdat)
            self.output['particles'].append(P)

    def run(self):
        if not self.configured:
            print('not configured to run')
            return
        self.run_astra(verbose=self.verbose, timeout=self.timeout)

    def get_run_script(self, write_to_path=True):
        """
        Assembles the run script. Optionally writes a file 'run' with this line to path.
        
        This expect to run with .path as the cwd. 
        """

        _, infile = os.path.split(
            self.input_file
        )  # Expect to run locally. Astra has problems with long paths.

        runscript = [self.astra_bin, infile]

        if write_to_path:
            with open(os.path.join(self.path, 'run'), 'w') as f:
                f.write(' '.join(runscript))

        return runscript

    def run_astra(self, verbose=False, parse_output=True, timeout=None):
        """
        Runs Astra
        
        Changes directory, so does not work with threads. 
        """

        run_info = self.output['run_info'] = {}

        t1 = time()
        run_info['start_time'] = t1

        if self.initial_particles:
            fname = self.write_initial_particles()
            self.input['newrun']['distribution'] = fname

        # Write all input
        self.write_input()

        runscript = self.get_run_script()
        run_info['run_script'] = ' '.join(runscript)

        try:
            if timeout:
                res = tools.execute2(runscript, timeout=timeout, cwd=self.path)
                log = res['log']
                self.error = res['error']
                run_info['why_error'] = res['why_error']
                # The log must contain this line to have finished properly
                if log.find('finished simulation') == -1:
                    run_info.update({
                        'error': True,
                        'why_error': "Couldn't find finished simulation"
                    })

            else:
                # Interactive output, for Jupyter
                log = []
                for path in tools.execute(runscript, cwd=self.path):
                    self.vprint(path, end="")
                    log.append(path)

            self.log = log

            if parse_output:
                self.load_output()
        except Exception as ex:
            print('Run Aborted', ex)
            self.error = True
            run_info['why_error'] = str(ex)
        finally:
            run_info['run_time'] = time() - t1
            run_info['run_error'] = self.error

        self.finished = True

        self.vprint(run_info)

    def fingerprint(self):
        """
        Data fingerprint using the input. 
        """
        return tools.fingerprint(self.input)

    def vprint(self, *args, **kwargs):
        # Verbose print
        if self.verbose:
            print(*args, **kwargs)

    def units(self, key):
        if key in parsers.OutputUnits:
            return parsers.OutputUnits[key]
        else:
            return 'unknown unit'

    def load_archive(self, h5=None):
        """
        Loads input and output from archived h5 file.
        
        See: Astra.archive
        """
        if isinstance(h5, str):
            h5 = os.path.expandvars(h5)
            g = h5py.File(h5, 'r')

            glist = archive.find_astra_archives(g)
            n = len(glist)
            if n == 0:
                # legacy: try top level
                message = 'legacy'
            elif n == 1:
                gname = glist[0]
                message = f'group {gname} from'
                g = g[gname]
            else:
                raise ValueError(
                    f'Multiple archives found in file {h5}: {glist}')

            self.vprint(f'Reading {message} archive file {h5}')
        else:
            g = h5

        self.input = archive.read_input_h5(g['input'])
        self.output = archive.read_output_h5(g['output'])
        if 'initial_particles' in g:
            self.initial_particles = ParticleGroup(h5=g['initial_particles'])

        if 'fieldmap' in g:
            self.fieldmap = archive.read_fieldmap_h5(g['fieldmap'])

        if 'control_groups' in g:
            self.group = archive.read_control_groups_h5(g['control_groups'],
                                                        verbose=self.verbose)

        self.vprint(
            'Loaded from archive. Note: Must reconfigure to run again.')
        self.configured = False

        # Re-link groups
        # TODO: cleaner logic
        for _, cg in self.group.items():
            cg.link(self.input)

    def archive(self, h5=None):
        """
        Archive all data to an h5 handle or filename.
        
        If no file is given, a file based on the fingerprint will be created.
        
        """
        if not h5:
            h5 = 'astra_' + self.fingerprint() + '.h5'

        if isinstance(h5, str):
            h5 = os.path.expandvars(h5)
            g = h5py.File(h5, 'w')
            self.vprint(f'Archiving to file {h5}')
        else:
            # store directly in the given h5 handle
            g = h5

        # Write basic attributes
        archive.astra_init(g)

        # Initial particles
        if self.initial_particles:
            self.initial_particles.write(g, name='initial_particles')

        # Fieldmaps
        if self.fieldmap:
            archive.write_fieldmap_h5(g, self.fieldmap, name='fieldmap')

        # All input
        archive.write_input_h5(g, self.input)

        # All output
        archive.write_output_h5(g, self.output)

        # Control groups
        if self.group:
            archive.write_control_groups_h5(g,
                                            self.group,
                                            name='control_groups')

        return h5

    @classmethod
    def from_archive(cls, archive_h5):
        """
        Class method to return an GPT object loaded from an archive file
        """
        c = cls()
        c.load_archive(archive_h5)
        return c

    @classmethod
    def from_yaml(cls, yaml_file):
        """
        Returns an Astra object instantiated from a YAML config file
        
        Will load intial_particles from an h5 file. 
        
        """
        # Try file
        if os.path.exists(os.path.expandvars(yaml_file)):
            config = yaml.safe_load(open(yaml_file))

            # The input file might be relative to the yaml file
            if 'input_file' in config:
                f = os.path.expandvars(config['input_file'])
                if not os.path.isabs(f):
                    # Get the yaml file root
                    root, _ = os.path.split(tools.full_path(yaml_file))
                    config['input_file'] = os.path.join(root, f)

        else:
            #Try raw string
            config = yaml.safe_load(yaml_file)

        # Form ParticleGroup from file
        if 'initial_particles' in config:
            f = config['initial_particles']
            if not os.path.isabs(f):
                root, _ = os.path.split(tools.full_path(yaml_file))
                f = os.path.join(root, f)
            config['initial_particles'] = ParticleGroup(f)

        return cls(**config)

    def write_fieldmaps(self):
        """
        Writes any loaded fieldmaps to path
        """

        if self.fieldmap:
            write_fieldmaps(self.fieldmap, self.path)
            self.vprint(
                f'{len(self.fieldmap)} fieldmaps written to {self.path}')

    def write_input(self):
        """
        Writes all input. If fieldmaps have been loaded, these will also be written. 
        """

        self.write_fieldmaps()

        self.write_input_file()

    def write_input_file(self):

        if self.use_tempdir:
            make_symlinks = True
        else:
            make_symlinks = False

        writers.write_namelists(self.input,
                                self.input_file,
                                make_symlinks=make_symlinks,
                                verbose=self.verbose)

    def write_initial_particles(self, fname=None):
        if not fname:
            fname = os.path.join(self.path, 'astra.particles')
        self.initial_particles.write_astra(fname)
        self.vprint(f'Initial particles written to {fname}')
        return fname

    def plot(self,
             y=['sigma_x', 'sigma_y'],
             x='mean_z',
             xlim=None,
             y2=[],
             nice=True,
             include_layout=True,
             include_labels=False,
             include_particles=True,
             include_legend=True,
             **kwargs):
        """
        Plots stat output multiple keys.
    
        If a list of ykeys2 is given, these will be put on the right hand axis. This can also be given as a single key. 
    
        Logical switches, all default to True:
            nice: a nice SI prefix and scaling will be used to make the numbers reasonably sized.
        
            include_legend: The plot will include the legend
        
            include_layout: the layout plot (fieldmaps) will be displayed at the bottom
        
            include_labels: the layout will include element labels.    
            
        If there is no output to plot, the fieldmaps will be plotted with .plot_fieldmaps
        
        """

        # Just plot fieldmaps if there is no stat output
        if not self.output['stats']:
            return plot_fieldmaps(self,
                                  xlim=xlim,
                                  fieldmap_dict=self.fieldmap,
                                  **kwargs)

        plot_stats_with_layout(self,
                               ykeys=y,
                               ykeys2=y2,
                               xkey=x,
                               xlim=xlim,
                               nice=nice,
                               include_layout=include_layout,
                               include_labels=include_labels,
                               include_particles=include_particles,
                               include_legend=include_legend,
                               **kwargs)

    def plot_fieldmaps(self, **kwargs):
        return plot_fieldmaps(self, **kwargs)

    def copy(self):
        """
        Returns a deep copy of this object.
        
        If a tempdir is being used, will clear this and deconfigure. 
        """
        A2 = deepcopy(self)
        # Clear this
        if A2.use_tempdir:
            A2.path = None
            A2.configured = False

        return A2

    def __getitem__(self, key):
        """
        Convenience syntax to get a header or element attribute. 

        Special syntax:
        
        end_X
            will return the final item in a stat array X
            Example:
            'end_norm_emit_x'
            
        particles:N
            will return a ParticleGroup N from the .particles list
            Example:
                'particles:-1'
                returns the readback of the final particles
        particles:N:Y
            ParticleGroup N's property Y
            Example:
                'particles:-1:sigma_x'
            returns sigma_x from the end of the particles list.

        
        See: __setitem__
        """

        # Object attributes
        if hasattr(self, key):
            return getattr(self, key)

        # Send back top level input (namelist) or group object.
        # Do not add these to __setitem__. The user shouldn't be allowed to change them as a whole,
        #   because it will break all the links.
        if key in self.group:
            return self.group[key]
        if key in self.input:
            return self.input[key]

        if key.startswith('end_'):
            key2 = key[len('end_'):]
            assert key2 in self.output[
                'stats'], f'{key} does not have valid output stat: {key2}'
            return self.output['stats'][key2][-1]

        if key.startswith('particles:'):
            key2 = key[len('particles:'):]
            x = key2.split(':')
            if len(x) == 1:
                return self.particles[int(x[0])]
            else:
                return self.particles[int(x[0])][x[1]]

        # key is not a namelist or group, so it should have the form
        # 'name:attribute'
        x = key.split(':')
        assert len(x) == 2, (
            f'{key} was not found in the group or input dicts, '
            'and does not have the form name:attribute')
        name, attrib = x[0], x[1]

        # Look in input and group
        if name in self.input:
            return self.input[name][attrib]
        elif name in self.group:
            return self.group[name][attrib]

    def __setitem__(self, key, item):
        """
        Convenience syntax to set namelist or group attribute. 
        attribute_string should be 'header:key' or 'ele_name:key'
        
        Examples of attribute_string: 'header:Np', 'SOL1:solenoid_field_scale'
        
        Settable attributes can also be given:
        
        ['stop'] = 1.2345 will set Impact.stop = 1.2345
        
        """

        # Set attributes
        if hasattr(self, key):
            setattr(self, key, item)
            return

        # Must be in input or group
        name, attrib = key.split(':')
        if name in self.input:
            self.input[name][attrib] = item
        elif name in self.group:
            self.group[name][attrib] = item
        else:
            raise ValueError(
                f'{name} does not exist in the input namelists or groups of the Astra object.'
            )
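
A hedged sketch of the from_yaml entry point above, assuming a config file
'astra_config.yaml' whose keys map to Astra.__init__ keyword arguments
(relative paths are resolved against the YAML file's location):

# astra_config.yaml (illustrative):
#   input_file: astra.in
#   initial_particles: beam.h5   # loaded into a ParticleGroup
#   verbose: true

A = Astra.from_yaml('astra_config.yaml')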
Example #3
class GPT:
    """ 
    GPT simulation object. Essential methods:
    .__init__(...)
    .configure()
    .run()
    
    Input deck is held in .input
    Output data is parsed into .output
    .load_screens() will load particle data into .screen[...]
    
    The GPT binary file can be set on init. If it doesn't exist, configure will check the
        $GPT_BIN
    environmental variable.
    
    
    """
    def __init__(self,
                 input_file=None,
                 initial_particles=None,
                 gpt_bin='$GPT_BIN',
                 use_tempdir=True,
                 workdir=None,
                 timeout=None,
                 verbose=False,
                 ccs_beg='wcs',
                 ref_ccs=False,
                 kill_msgs=DEFAULT_KILL_MSGS):

        # Save init
        self.original_input_file = input_file
        self.initial_particles = initial_particles
        self.use_tempdir = use_tempdir

        self.workdir = workdir

        if workdir:
            assert os.path.exists(
                workdir), 'workdir does not exist: ' + workdir
            self.workdir = os.path.abspath(workdir)

        self.verbose = verbose
        self.gpt_bin = gpt_bin

        # These will be set
        self.log = []
        self.output = {}
        #self.screen = [] # list of screens
        self.timeout = timeout
        self.error = False

        # Run control
        self.finished = False
        self.configured = False
        self.using_tempdir = False

        self.ccs_beg = ccs_beg
        self.ref_ccs = ref_ccs
        self.kill_msgs = kill_msgs

        # Call configure
        if input_file:
            self.load_input(input_file)
            self.configure()
        else:
            self.vprint('Warning: No input file was given. Not configured.')

    def configure(self):
        """ Convenience wrapper for configure_gpt """
        self.configure_gpt(workdir=self.workdir)

    def configure_gpt(self, input_filePath=None, workdir=None):
        """ Configure the GPT object """
        if input_filePath:
            self.load_input(input_filePath)

        # Check that binary exists
        self.gpt_bin = tools.full_path(self.gpt_bin)
        assert os.path.exists(
            self.gpt_bin), 'ERROR: GPT binary does not exist: ' + self.gpt_bin

        # Set paths
        if self.use_tempdir:

            # Need to attach this to the object. Otherwise it will go out of scope.
            self.tempdir = tempfile.TemporaryDirectory(dir=workdir)
            self.path = self.tempdir.name

        elif (workdir):

            # Use the top level of the provided workdir
            self.path = workdir

        else:

            # Work in location of the template file
            self.path = self.original_path

        self.input_file = os.path.join(self.path, self.original_input_file)

        parsers.set_support_files(self.input['lines'], self.original_path)

        self.vprint('GPT.configure_gpt:')
        self.vprint(
            f'   Original input file "{self.original_input_file}" in "{self.original_path}"'
        )
        self.vprint(f'   Configured to run in "{self.path}"')

        self.configured = True

    def load_input(self, input_filePath, absolute_paths=True):
        """ Load the GPT template file """
        f = tools.full_path(input_filePath)
        self.original_path, self.original_input_file = os.path.split(
            f)  # Get original path, filename
        self.input = parsers.parse_gpt_input_file(f)

    def get_dist_file(self):
        """ Find the distribution input file name in the GPT file """
        for line in self.input['lines']:
            if ('setfile' in line):
                return parse_gpt_string(line)[1]

    def set_dist_file(self, dist_file):
        """ Set the input distirbution file name in a GPT file """
        dist_file_set = False
        for ii, line in enumerate(self.input['lines']):
            if ('setfile' in line):
                gpt_strs = parse_gpt_string(line)
                assert len(
                    gpt_strs
                ) == 2, "Couldn't find the distribution input file strings."
                assert gpt_strs[
                    0] == 'beam', "Could not find the beam definition in the setfile string."
                self.input['lines'][ii] = f'setfile("beam", "{dist_file}");'
                dist_file_set = True

        if (not dist_file_set):
            self.input['lines'].append(f'setfile("beam", "{dist_file}");')

    def set_variable(self, variable, value):
        """ Set variable in the GPT input file to a new value """
        if (variable in self.input["variables"]):
            self.input['variables'][variable] = value
            return True
        else:
            return False

    def set_variables(self, variables):
        """ Set a list of variables (variable.keys) to new values (variables.values()) in the GPT Input file """
        return {
            var: self.set_variable(var, variables[var])
            for var in variables.keys()
        }
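        # e.g. G.set_variables({'gun_voltage': 350e3, 'sol_current': 4.5})
        # returns {'gun_voltage': True, 'sol_current': False} if only
        # gun_voltage is defined in the input file (names are hypothetical)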

    def load_output(self, file='gpt.out.gdf'):
        """ Loads the raw GPT data and converts it into particle groups """

        self.vprint(f'   Loading GPT data from {file}')
        touts, screens = parsers.read_gdf_file(file,
                                               self.verbose)  # Raw GPT data

        self.output['particles'] = raw_data_to_particle_groups(
            touts, screens, verbose=self.verbose, ref_ccs=self.ref_ccs)
        self.output['n_tout'] = len(touts)
        self.output['n_screen'] = len(screens)

    @property
    def n_tout(self):
        """ number of tout particle groups """
        return self.output['n_tout']

    @property
    def n_screen(self):
        """ number of screen particle groups"""
        return self.output['n_screen']

    @property
    def tout(self):
        """ Returns output particle groups for touts """
        if ('particles' in self.output):
            return self.output['particles'][:self.output['n_tout']]

    @property
    def tout_ccs(self):
        """ Returns output particle groups for touts transformed into centroid coordinate system """
        if ('particles' in self.output):
            return [
                transform_to_centroid_coordinates(tout) for tout in self.tout
            ]

    @property
    def s_ccs(self):
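        """ Approximate path length s of the bunch centroid at each tout,
        accumulated from drift (beta*c*dt) or straight-line segments """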

        s = [
            np.sqrt(self.tout[0]['mean_x']**2 + self.tout[0]['mean_y']**2 +
                    self.tout[0]['mean_z']**2)
        ]

        current_tout = self.tout[0]
        current_p = current_tout['mean_p']

        for next_tout in self.tout[1:]:

            next_p = next_tout['mean_p']

            if (np.abs(current_p - next_p) / current_p <
                    1e-5):  # Beam drifting

                beta = current_tout['mean_beta']
                dt = next_tout['mean_t'] - current_tout['mean_t']
                ds = dt * beta * c

            else:  # Assume straight line acceleration

                dx = next_tout['mean_x'] - current_tout['mean_x']
                dy = next_tout['mean_y'] - current_tout['mean_y']
                dz = next_tout['mean_z'] - current_tout['mean_z']

                ds = np.sqrt(dx**2 + dy**2 + dz**2)

            s.append(s[-1] + ds)

            current_tout = next_tout

        return np.array(s)

    def tout_stat(self, key=None):
        """ Returns array of stats for key from tout particle groups """
        return self.stat(key, data_type='tout')

    def tout_ccs_stat(self, key=None):
        """ Returns array of stats for key from tout particle groups """
        return self.stat(key, data_type='tout_ccs')

    @property
    def screen(self):
        """ Returns output particle groups for screens """
        if ('particles' in self.output):
            return self.output['particles'][self.output['n_tout']:]

    def screen_stat(self, key):
        """ Returns array of stats for key from screen particle groups """
        return self.stat(key, data_type='screen')

    @property
    def particles(self):
        """ Returns output particle groups for touts + screens """
        if ('particles' in self.output):
            return self.output['particles']

    def trajectory(self, pid, data_type='tout'):
        """ Returns a 3d particle trajectory for particle with id = pid """
        if (data_type == 'tout'):
            particle_groups = self.tout
        elif (data_type == 'screen'):
            particle_groups = self.screen
        else:
            raise ValueError(
                f'GPT.trajectory got an unsupported data type = {data_type}.')

        pgs_with_pid = [pg for pg in particle_groups if (pid in pg['id'])]

        if (len(pgs_with_pid) == 0):
            return None

        variables = ['x', 'y', 'z', 'px', 'py', 'pz', 't']

        trajectory = {
            var: np.zeros((len(pgs_with_pid), ))
            for var in variables
        }

        for ii, pg in enumerate(pgs_with_pid):
            for var in variables:
                trajectory[var][ii] = pg[var][pg['id'] == pid]

        return trajectory
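        # e.g. traj = G.trajectory(pid=1)        # follow particle id 1
        #      plt.plot(traj['z'], traj['x'])    # assumes matplotlib as plt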

    def run(self, gpt_verbose=False):
        """ performs a basic GPT simulation configured in the current GPT object """

        if not self.configured:
            self.configure()
        self.run_gpt(verbose=self.verbose,
                     timeout=self.timeout,
                     gpt_verbose=gpt_verbose)

    def get_run_script(self, write_to_path=True):
        """
        Assembles the run script. Optionally writes a file 'run' with this line to path.
        """

        runscript = [
            self.gpt_bin, '-j1', '-v', '-o',
            self.get_gpt_output_file(), self.input_file
        ]

        if write_to_path:
            with open(os.path.join(self.path, 'run'), 'w') as f:
                f.write(' '.join(runscript))

        return runscript

    def get_gpt_output_file(self):
        """ get the name of the GPT output file """
        path, infile = os.path.split(self.input_file)
        tokens = infile.split('.')
        if (len(tokens) > 1):
            outfile = '.'.join(tokens[:-1]) + '.out.gdf'
        else:
            outfile = tokens[0] + '.out.gdf'
        return os.path.join(path, outfile)

    def run_gpt(self,
                verbose=False,
                parse_output=True,
                timeout=None,
                gpt_verbose=False):
        """ RUN GPT and read in results """
        self.vprint('GPT.run_gpt:')

        run_info = {}
        t1 = time()
        run_info['start_time'] = t1

        if self.initial_particles:
            fname = self.write_initial_particles()

            # Link input file to new particle file
            self.set_dist_file(fname)

        self.vprint('   Running GPT...')

        # Write input file from internal dict
        self.write_input_file()

        runscript = self.get_run_script()

        self.vprint(f'   Running with timeout = {timeout} sec.')
        run_time, exception, log = tools.execute(runscript,
                                                 kill_msgs=self.kill_msgs,
                                                 timeout=timeout,
                                                 verbose=gpt_verbose)

        if (exception is not None):
            self.error = True
            run_info["error"] = True
            run_info['why_error'] = exception.strip()

        self.log = log

        if parse_output:
            self.load_output(file=self.get_gpt_output_file())

        run_info['run_time'] = time() - t1
        run_info['run_error'] = self.error
        self.vprint(
            f'   Run finished, total time elapsed: {run_info["run_time"]:G} (sec)'
        )

        # Add run_info
        self.output.update(run_info)
        self.finished = True

    def fingerprint(self):
        """
        Data fingerprint using the input. 
        """
        return tools.fingerprint(self.input)

    def vprint(self, *args, **kwargs):
        # Verbose print
        if self.verbose:
            print(*args, **kwargs)

    def plot(self,
             y=['sigma_x', 'sigma_y'],
             x='mean_z',
             xlim=None,
             y2=[],
             nice=True,
             include_layout=False,
             include_labels=False,
             include_particles=True,
             include_legend=True,
             **kwargs):
        """
        Convenience plotting function for making nice plots.
        
        """
        plot_stats_with_layout(self,
                               ykeys=y,
                               ykeys2=y2,
                               xkey=x,
                               xlim=xlim,
                               nice=nice,
                               include_layout=include_layout,
                               include_labels=include_labels,
                               include_legend=include_legend,
                               **kwargs)

    def stat(self, key, data_type='all'):
        """
        Calculates any statistic that the ParticleGroup class can calculate, on all particle groups, or just touts, or screens
        """
        if (data_type == 'all'):
            particle_groups = self.output['particles']

        elif (data_type == 'tout'):
            particle_groups = self.tout

        elif (data_type == 'tout_ccs'):
            particle_groups = self.tout_ccs

        elif (data_type == 'screen'):
            particle_groups = self.screen

        else:
            raise ValueError(f'Unsupported GPT data type: {data_type}')

        return particle_stats(particle_groups, key)

    def units(self, key):
        """ Returns a str describing the physical units of a stat key """
        return pg_units(key)

    def write_input_file(self):
        """ Write the updated GPT input file """
        self.vprint(f'   Writing gpt input file to "{self.input_file}"')
        parsers.write_gpt_input_file(self.input, self.input_file, self.ccs_beg)

    def write_initial_particles(self, fname=None):
        """ Write the initial particle data to file for use with GPT """
        if not fname:
            fname = os.path.join(self.path, 'gpt.particles.gdf')
        self.initial_particles.write_gpt(fname,
                                         asci2gdf_bin='$ASCI2GDF_BIN',
                                         verbose=False)
        self.vprint(f'   Initial particles written to "{fname}"')
        return fname

    def load_initial_particles(self, h5):
        """Loads a openPMD-beamphysics particle h5 handle or file"""
        P = ParticleGroup(h5=h5)
        self.initial_particles = P

    def load_archive(self, h5=None):
        """
        Loads input and output from archived h5 file.
        
        See: GPT.archive
        """
        if isinstance(h5, str):
            h5 = os.path.expandvars(h5)
            g = h5py.File(h5, 'r')

            glist = gpt.archive.find_gpt_archives(g)
            n = len(glist)
            if n == 0:
                # legacy: try top level
                message = 'legacy'
            elif n == 1:
                gname = glist[0]
                message = f'group {gname} from'
                g = g[gname]
            else:
                raise ValueError(
                    f'Multiple archives found in file {h5}: {glist}')

            self.vprint(f'Reading {message} archive file {h5}')
        else:
            g = h5

        self.input = gpt.archive.read_input_h5(g['input'])

        if 'initial_particles' in g:
            self.initial_particles = ParticleGroup(g['initial_particles'])

        self.output = gpt.archive.read_output_h5(g['output'])

        self.vprint(
            'Loaded from archive. Note: Must reconfigure to run again.')
        self.configured = False

    def archive(self, h5=None):
        """
        Archive all data to an h5 handle or filename.
        
        If no file is given, a file based on the fingerprint will be created.
        
        """
        if not h5:
            h5 = 'gpt_' + self.fingerprint() + '.h5'

        if isinstance(h5, str):
            h5 = os.path.expandvars(h5)
            g = h5py.File(h5, 'w')
            self.vprint(f'Archiving to file {h5}')
        else:
            # store directly in the given h5 handle
            g = h5

        # Write basic attributes
        gpt.archive.gpt_init(g)

        # All input
        gpt.archive.write_input_h5(g, self.input, name='input')

        if self.initial_particles:
            self.initial_particles.write(g, name='initial_particles')

        # All output
        gpt.archive.write_output_h5(g, self.output, name='output')

        return h5

    @classmethod
    def from_archive(cls, archive_h5):
        """
        Class method to return an GPT object loaded from an archive file
        """
        c = cls()
        c.load_archive(archive_h5)
        return c

    def __str__(self):

        outstr = '\nGPT object:'

        if (self.configured):
            outstr = outstr + "\n   Original input file: " + self.original_input_file
            outstr = outstr + "\n   Template location: " + self.original_path

        if (self.workdir):
            outstr = outstr + "\n   Top level work dir: " + self.workdir

        if (self.use_tempdir):
            outstr = outstr + f"\n   Use temp directory: {self.use_tempdir}"
        #else:
        #    outstr = outstr+f"\n   Work directory: {self.path}"

        # Run control
        outstr = outstr + "\n\nRun Control"
        outstr = outstr + f"\n   Run configured: {self.configured}"

        if (self.configured):
            outstr = outstr + f"\n   Work location: {self.path}"
            outstr = outstr + f"\n   Timeout: {self.timeout} (sec)"

            # Results
            outstr = outstr + "\n\nResults"
            outstr = outstr + f"\n   Finished: {self.finished}"
            outstr = outstr + f"\n   Error occured: {self.error}"
            if (self.error):
                outstr = outstr + f'\n   Cause: {self.output["why_error"]}'
                errline = self.get_syntax_error_line(self.output["why_error"])
                if (errline):
                    outstr = outstr + f'\n   Suspected input file line: "{errline}"'
            if 'run_time' in self.output:
                rtime = self.output['run_time']
                outstr = outstr + f'\n   Run time: {rtime} (sec)'

        #outstr = outstr+f"\n
        #outstr = outstr+f'\n   Log: {self.log}\n'
        return outstr

    def get_syntax_error_line(self, error_msg):
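        """ If the error message reports a GPT syntax error, return the
        offending input file line (best effort) """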

        s = error_msg.strip().replace('\n', '')
        if (s.endswith('Error: syntax error')):
            error_line_index = int(s[s.find("(") + 1:s.find(")")])
            return self.input['lines'][error_line_index]
        else:
            return None

    def track(self, particles, s=None, output='tout'):
        return track(self, particles, s=s, output=output)

    def track1(self,
               x0=0,
               px0=0,
               y0=0,
               py0=0,
               z0=0,
               pz0=1e-15,
               t0=0,
               s=None,
               species='electron',
               output='tout'):
        return track1(self,
                      x0=x0,
                      px0=px0,
                      y0=y0,
                      py0=py0,
                      z0=z0,
                      pz0=pz0,
                      t0=t0,
                      species=species,
                      s=s,
                      output=output)

    def track1_to_z(self,
                    z_end,
                    ds=0,
                    ccs_beg='wcs',
                    ccs_end='wcs',
                    x0=0,
                    px0=0,
                    y0=0,
                    py0=0,
                    z0=0,
                    pz0=1e-15,
                    t0=0,
                    weight=1,
                    status=1,
                    species='electron',
                    s_screen=0):
        return track1_to_z(self,
                           z_end=z_end,
                           ds=ds,
                           ccs_beg=ccs_beg,
                           ccs_end=ccs_end,
                           x0=x0,
                           px0=px0,
                           y0=y0,
                           py0=py0,
                           z0=z0,
                           pz0=pz0,
                           t0=t0,
                           weight=weight,
                           status=status,
                           species=species,
                           s_screen=s_screen)

    def track1_in_ccs(self,
                      z_beg=0,
                      z_end=0,
                      ccs='wcs',
                      x0=0,
                      px0=0,
                      y0=0,
                      py0=0,
                      pz0=1e-15,
                      t0=0,
                      weight=1,
                      status=1,
                      species='electron',
                      xacc=6.5,
                      GBacc=6.5,
                      workdir=None,
                      use_tempdir=True,
                      n_screen=1,
                      s_beg=0):

        return track1_in_ccs(self,
                             z_beg=z_beg,
                             z_end=z_end,
                             ccs=ccs,
                             x0=x0,
                             px0=px0,
                             y0=y0,
                             py0=py0,
                             pz0=pz0,
                             t0=t0,
                             weight=weight,
                             status=status,
                             species=species,
                             xacc=xacc,
                             GBacc=GBacc,
                             workdir=workdir,
                             use_tempdir=use_tempdir,
                             n_screen=n_screen,
                             s_beg=s_beg)

    def get_zminmax_line(self, z_beg, z_end, ccs='wcs'):
        return get_zminmax_line(self, z_beg, z_end, ccs=ccs)

    def copy(self):
        """
        Returns a deep copy of this object.
        
        If a tempdir is being used, will clear this and deconfigure. 
        """
        G2 = deepcopy(self)
        # Clear this
        if G2.use_tempdir:
            G2.path = None
            G2.configured = False

        return G2
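
A minimal usage sketch for the GPT class above, assuming an input file
'gpt.in' exists, $GPT_BIN points at the gpt executable, and 'nps' is a
variable defined in that input file (names are illustrative):

G = GPT(input_file='gpt.in', verbose=True)
G.set_variables({'nps': 1000})      # returns {'nps': True} if 'nps' exists
G.run()
sig_x = G.screen_stat('sigma_x')    # one value per screen
P0 = G.tout[0]                      # first tout as a ParticleGroup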
Example #4
class Generator:
    """
    This class defines the main run engine object for distgen and is responsible for
    1. Parsing the input data dictionary passed from a Reader object
    2. Check the input for internal consistency
    3. Collect the parameters for distributions requested in the params dictionary 
    4. Form a the Beam object and populated the particle phase space coordinates
    """
    def __init__(self, input=None, verbose=0):
        """
        The class initialization takes in a verbose level for controlling text output to the user
        """
        self.verbose = verbose

        self.input = input

        # This will be set with .beam()
        self.rands = None

        # This will be set with .run()
        self.particles = None

        if input:
            self.parse_input(input)
            self.configure()

    def parse_input(self, input):
        """
        Parse the input structure passed from a Reader object.  
        The structure is then converted to an easier form for use in populating the Beam object.
        
        YAML or JSON is accepted if params is a filename (str)
        
        Relative paths for input 'file' keys will be expanded.
        """
        if isinstance(input, str):
            if os.path.exists(os.path.expandvars(input)):
                # File
                filename = full_path(input)
                with open(filename) as fid:
                    input = yaml.safe_load(fid)
                # Fill any 'file' keys
                expand_input_filepaths(input,
                                       root=os.path.split(filename)[0],
                                       ignore_keys=['output'])

            else:
                #Try raw string
                input = yaml.safe_load(input)
                assert isinstance(
                    input, dict
                ), f'ERROR: parsing unsuccessful, could not read {input}'
                expand_input_filepaths(input)

        self.input = input
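    # Example (not from the original source): parse_input accepts a filename
    # or a raw YAML string. A hypothetical minimal input, using distgen's
    # value/units convention for quantities:
    #
    #   gen = Generator()
    #   gen.parse_input("""
    #   n_particle: 10000
    #   random_type: hammersley
    #   total_charge:
    #     value: 10
    #     units: pC
    #   """)
    #   gen.configure()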

    def configure(self):
        """ Configures the generator for creating a 6d particle distribution:
        1. Copies the input dictionary read in from a file or passed directly
        2. Converts physical quantities to PINT quantities in the params dictionary
        3. Runs consistency checks on the resulting params dict
        """

        self.params = copy.deepcopy(self.input)  # Copy the input dictionary
        if ('start' not in self.params):
            self.params['start'] = {'type': 'free'}
        convert_params(
            self.params
        )  # Conversion of the input dictionary using tools.convert_params
        self.check_input_consistency(
            self.params)  # Check that the result is logically sound

    def check_input_consistency(self, params):
        ''' Perform consistency checks on the user input data'''

        # Make sure all required top level params are present
        required_params = ['n_particle', 'random_type', 'total_charge']
        for rp in required_params:
            assert rp in params, 'Required generator parameter ' + rp + ' not found.'

        # Check that only allowed params present at top level
        allowed_params = required_params + ['output', 'transforms', 'start']
        for p in params:
            assert p in allowed_params or p.endswith(
                '_dist'), 'Unexpected distgen input parameter: ' + p

        assert params['n_particle'] > 0, 'n_particle must be > 0.'

        # Check consistency of transverse coordinate definitions.
        # Exactly one overlapping definition is allowed; a plain three-way XOR
        # would also pass if all three were present.
        if ("r_dist" in params) or ("x_dist" in params) or ("xy_dist" in params):
            assert sum(k in params for k in ("r_dist", "x_dist", "xy_dist")) == 1, \
                "User must specify only one transverse x distribution."
        if ("r_dist" in params) or ("y_dist" in params) or ("xy_dist" in params):
            assert sum(k in params for k in ("r_dist", "y_dist", "xy_dist")) == 1, \
                "User must specify only one transverse y distribution."
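        # For example (illustrative, not from the original source): 'r_dist'
        # alone is valid, as is the pair 'x_dist' + 'y_dist'; combining
        # 'r_dist' with 'x_dist', or 'xy_dist' with either, fails these checks.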

        if (params['start']['type'] == "cathode"):

            vprint("Ignoring user specified z distribution for cathode start.",
                   self.verbose > 0 and "z_dist" in params, 0, True)
            vprint(
                "Ignoring user specified px distribution for cathode start.",
                self.verbose > 0 and "px_dist" in params, 0, True)
            vprint(
                "Ignoring user specified py distribution for cathode start.",
                self.verbose > 0 and "py_dist" in params, 0, True)
            vprint(
                "Ignoring user specified pz distribution for cathode start.",
                self.verbose > 0 and "pz_dist" in params, 0, True)

            assert "MTE" in params[
                'start'], "User must specify the MTE for cathode start."

            # Handle momentum distribution for cathode
            MTE = self.params['start']["MTE"]
            sigma_pxyz = (np.sqrt((MTE / MC2).to_reduced_units()) *
                          unit_registry("GB")).to("eV/c")

            self.params["px_dist"] = {"type": "g", "sigma_px": sigma_pxyz}
            self.params["py_dist"] = {"type": "g", "sigma_py": sigma_pxyz}
            self.params["pz_dist"] = {"type": "g", "sigma_pz": sigma_pxyz}

        elif (params['start']['type'] == 'time'):

            vprint("Ignoring user specified t distribution for time start.",
                   self.verbose > 0 and "t_dist" in params, 0, True)
            if ('t_dist' in params):
                warnings.warn(
                    'Ignoring user specified t distribution for time start.')
                self.params.pop('t_dist')

        if ('output' in self.params):
            out_params = self.params["output"]
            for op in out_params:
                assert op in ['file', 'type'
                              ], f'Unexpected output parameter specified: {op}'
        else:
            self.params['output'] = {"type": None}

    def __getitem__(self, varstr):
        return get_nested_dict(self.input, varstr, sep=':', prefix='distgen')

    def __setitem__(self, varstr, val):
        return set_nested_dict(self.input,
                               varstr,
                               val,
                               sep=':',
                               prefix='distgen')

    def get_dist_params(self):
        """ Loops through the input params dict and collects all distribution definitions """

        dist_params = {
            p.replace('_dist', ''): self.params[p]
            for p in self.params if p.endswith('_dist')
        }
        dist_vars = list(dist_params)

        if ('r' in dist_vars and 'theta' not in dist_vars):
            vprint("Assuming cylindrical symmetry...", self.verbose > 0, 1,
                   True)
            dist_params['theta'] = {
                'type': 'ut',
                'min_theta': 0 * unit_registry('rad'),
                'max_theta': 2 * pi
            }
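            # E.g. (not from the original source): given only {'r_dist': ...},
            # dist_params now contains both 'r' and this default uniform theta.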

        if (self.params['start']['type'] == 'time'
                and 't_dist' in self.params):
            raise ValueError('Error: t_dist should not be set for time start')

        return dist_params

    def get_rands(self, variables):
        """ Gets random numbers [0,1] for the coordinatess in variables 
        using either the Hammersley sequence or rand """

        specials = ['xy']
        self.rands = {var: None for var in variables if var not in specials}

        if ('xy' in variables):
            self.rands['x'] = None
            self.rands['y'] = None

        elif ('r' in variables and 'theta' not in variables):
            self.rands['theta'] = None

        n_coordinate = len(self.rands.keys())
        n_particle = int(self.params['n_particle'])
        shape = (n_coordinate, n_particle)

        if (n_coordinate > 0):
            rns = random_generator(shape, sequence=self.params['random_type'])

        for ii, key in enumerate(self.rands.keys()):
            if (len(rns.shape) > 1):
                self.rands[key] = rns[ii, :] * unit_registry('dimensionless')
            else:
                self.rands[key] = rns[:] * unit_registry('dimensionless')

        var_list = list(self.rands.keys())
        for ii, vii in enumerate(var_list[:-1]):
            viip1 = var_list[ii + 1]
            assert (
                not np.array_equal(self.rands[vii].magnitude,
                                   self.rands[viip1].magnitude)
            ) or n_particle == 1, f'Error: coordinate probabilities for {vii} and {viip1} are the same!'

            # These lines can be used to check for unwanted correlations
            #v0 = self.rands[vii].magnitude-self.rands[vii].magnitude.mean()
            #v1 = self.rands[viip1].magnitude-self.rands[viip1].magnitude.mean()
            #print( np.mean(v0*v1) )
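        # Shape note (not from the original source): random_generator is
        # assumed to return an array of shape (n_coordinate, n_particle), so
        # each coordinate gets its own row; a 1d return (single coordinate)
        # is handled by the len(rns.shape) > 1 branch above.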

    def beam(self):
        """ Creates a 6d particle distribution and returns it in a distgen.beam class """

        watch = StopWatch()
        watch.start()

        self.configure()

        verbose = self.verbose

        beam_params = {
            'total_charge': self.params['total_charge'],
            'n_particle': self.params['n_particle']
        }

        if ('transforms' in self.params):
            transforms = self.params['transforms']
        else:
            transforms = None


        vprint(f'Distribution format: {self.params["output"]["type"]}',
               self.verbose > 0, 0, True)

        N = int(self.params['n_particle'])
        bdist = Beam(**beam_params)

        if ("file" in self.params['output']):
            outfile = self.params['output']["file"]
        else:
            outfile = "None"
            vprint(
                f'Warning: no output file specified, defaulting to "{outfile}".',
                verbose > 0, 1, True)
        vprint(f'Output file: {outfile}', verbose > 0, 0, True)

        vprint('\nCreating beam distribution....', verbose > 0, 0, True)
        vprint(f"Beam starting from: {self.params['start']['type']}",
               verbose > 0, 1, True)
        vprint(f'Total charge: {bdist.q:G~P}.', verbose > 0, 1, True)
        vprint(f'Number of macroparticles: {N}.', verbose > 0, 1, True)

        units = {
            'x': 'm',
            'y': 'm',
            'z': 'm',
            'px': 'eV/c',
            'py': 'eV/c',
            'pz': 'eV/c',
            't': 's'
        }

        # Initialize coordinates to zero
        for var, unit in units.items():
            bdist[var] = np.full(N, 0.0) * unit_registry(units[var])

        bdist["w"] = np.full((N, ), 1 / N) * unit_registry("dimensionless")

        avgs = {var: 0 * unit_registry(units[var]) for var in units}
        stds = {var: 0 * unit_registry(units[var]) for var in units}

        # Get the relevant dist params, setting defaults as needed
        dist_params = self.get_dist_params()
        # Sample the random number generator for each coordinate
        self.get_rands(list(dist_params.keys()))

        # Do radial dist first if requested
        if ('r' in dist_params and 'theta' in dist_params):

            vprint('r distribution: ', verbose > 0, 1, False)

            # Get r distribution
            rdist = get_dist('r', dist_params['r'], verbose=verbose)

            if (rdist.rms() > 0):
                r = rdist.cdfinv(
                    self.rands['r'])  # Sample to get beam coordinates
            else:
                # Zero-size distribution: place all particles at r = 0
                # (assumes rdist.avg() carries the radial units)
                r = np.full(N, 0.0) * rdist.avg().units

            vprint('theta distribution: ', verbose > 0, 1, False)
            theta_dist = get_dist('theta',
                                  dist_params['theta'],
                                  verbose=verbose)
            theta = theta_dist.cdfinv(self.rands['theta'])

            rrms = rdist.rms()
            avgr = rdist.avg()

            avgCos = 0
            avgSin = 0
            avgCos2 = 0.5
            avgSin2 = 0.5

            bdist['x'] = r * np.cos(theta)
            bdist['y'] = r * np.sin(theta)

            avgs['x'] = avgr * avgCos
            avgs['y'] = avgr * avgSin

            stds['x'] = rrms * np.sqrt(avgCos2)
            stds['y'] = rrms * np.sqrt(avgSin2)

            # remove r, theta from list of distributions to sample
            del dist_params['r']
            del dist_params['theta']
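            # Moment bookkeeping above (illustrative note, not from the
            # original source): for theta uniform on [0, 2*pi),
            # <cos> = <sin> = 0 and <cos^2> = <sin^2> = 1/2, so
            # <x> = <r><cos> and std(x) = rms(r)*sqrt(<cos^2>), matching the
            # avgCos/avgSin constants used above.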

        # Do 2D distributions
        if ("xy" in dist_params):

            vprint('xy distribution: ', verbose > 0, 1, False)
            dist = get_dist('xy', dist_params['xy'], verbose=verbose)
            bdist['x'], bdist['y'] = dist.cdfinv(self.rands['x'],
                                                 self.rands['y'])

            dist_params.pop('xy')

            avgs['x'] = bdist.avg('x')
            avgs['y'] = bdist.avg('y')

            stds['x'] = bdist.std('x')
            stds['y'] = bdist.std('y')

        # Do all other specified single coordinate dists
        for x in dist_params.keys():

            vprint(x + " distribution: ", verbose > 0, 1, False)
            dist = get_dist(x, dist_params[x],
                            verbose=verbose)  # Get distribution

            if (dist.std() > 0):

                # Only reach here if the distribution has > 0 size
                bdist[x] = dist.cdfinv(
                    self.rands[x])  # Sample to get beam coordinates

                # Fix up the avg and std so they are exactly what user asked for
                if ("avg_" + x in dist_params[x]):
                    avgs[x] = dist_params[x]["avg_" + x]
                else:
                    avgs[x] = dist.avg()

                stds[x] = dist.std()
                #if("sigma_"+x in dist_params[x]):
                #    stds[x] = dist_params[x]["sigma_"+x]
                #else:
                #stds[x] = dist.std()
                #print(x, stds[x])
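                # Sampling note (not from the original source): cdfinv is
                # inverse-transform sampling; applying the inverse CDF of the
                # requested distribution to uniform [0,1] numbers yields
                # samples with that distribution.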

        # Shift and scale coordinates to undo sampling error
        for x in avgs:

            vprint(f'Shifting avg_{x} = {bdist.avg(x):G~P} -> {avgs[x]:G~P}',
                   verbose > 0 and bdist[x].mean() != avgs[x], 1, True)
            vprint(f'Scaling sigma_{x} = {bdist.std(x):G~P} -> {stds[x]:G~P}',
                   verbose > 0 and bdist[x].std() != stds[x], 1, True)

            bdist = set_avg_and_std(
                bdist, **{
                    'variables': x,
                    'avg_' + x: avgs[x],
                    'sigma_' + x: stds[x],
                    'verbose': 0
                })

        # Handle any start type specific settings
        if (self.params['start']['type'] == "cathode"):

            bdist['pz'] = np.abs(bdist['pz'])  # Only take forward hemisphere
            vprint('Cathode start: fixing pz momenta to forward hemisphere',
                   verbose > 0, 1, True)
            vprint(
                f'avg_pz -> {bdist.avg("pz"):G~P}, sigma_pz -> {bdist.std("pz"):G~P}',
                verbose > 0, 2, True)

        elif (self.params['start']['type'] == 'time'):

            if ('tstart' in self.params['start']):
                tstart = self.params['start']['tstart']

            else:
                vprint(
                    "Time start: no start time specified, defaulting to 0 sec.",
                    verbose > 0, 1, True)
                tstart = 0 * unit_registry('sec')

            vprint(
                f'Time start: fixing all particle time values to start time: {tstart:G~P}.',
                verbose > 0, 1, True)
            bdist = set_avg(
                bdist, **{
                    'variables': 't',
                    'avg_t': tstart,
                    'verbose': verbose > 0
                })

        elif (self.params['start']['type'] == 'free'):
            pass

        else:
            raise ValueError(
                f'Beam start type "{self.params["start"]["type"]}" is not supported!'
            )

        # Apply any user desired coordinate transformations
        if (transforms):

            # Check if the user supplied the transform order, otherwise just go through the dictionary
            if ('order' in transforms):
                order = transforms['order']
                if (not isinstance(order, list)):
                    raise ValueError(
                        'Transform "order" key must be associated with a list of transform IDs'
                    )
                del transforms['order']
            else:
                order = transforms.keys()

            for name in order:

                T = transforms[name]
                T['verbose'] = verbose > 0
                vprint(
                    f'Applying user supplied transform: "{name}" = {T["type"]}...',
                    verbose > 0, 1, True)
                bdist = transform(bdist, T)

        watch.stop()
        vprint(f'...done. Time elapsed: {watch.print()}.\n', verbose > 0, 0,
               True)
        return bdist

    def run(self):
        """ Runs the generator.beam function stores the partice in 
        an openPMD-beamphysics ParticleGroup in self.particles """
        beam = self.beam()
        self.particles = ParticleGroup(data=beam.data())
        vprint(f'Created particles in .particles: \n   {self.particles}',
               self.verbose > 0, 1, False)

    def fingerprint(self):
        """
        Data fingerprint using the input. 
        """
        return fingerprint(self.input)

    def load_archive(self, h5=None):
        """
        Loads input and output from archived h5 file.
        
        
        
        
        See: Generator.archive
        
        
        """
        if isinstance(h5, str):
            g = h5py.File(h5, 'r')

            glist = archive.find_distgen_archives(g)
            n = len(glist)
            if n == 0:
                # legacy: try top level
                message = 'legacy'
            elif n == 1:
                gname = glist[0]
                message = f'group {gname} from'
                g = g[gname]
            else:
                raise ValueError(
                    f'Multiple archives found in file {h5}: {glist}')

            vprint(f'Reading {message} archive file {h5}', self.verbose > 0, 1,
                   False)
        else:
            g = h5

            vprint(f'Reading Distgen archive file {h5}', self.verbose > 0, 1,
                   False)

        self.input = archive.read_input_h5(g['input'])

        if 'particles' in g:
            self.particles = ParticleGroup(g['particles'])
        else:
            vprint('No particles found.', self.verbose > 0, 1, False)

    def archive(self, h5=None):
        """
        Archive all data to an h5 handle or filename.
        
        If no file is given, a file based on the fingerprint will be created.
        
        """
        if not h5:
            h5 = 'distgen_' + self.fingerprint() + '.h5'

        if isinstance(h5, str):
            g = h5py.File(h5, 'w')
            # Proper openPMD init
            pmd_init(g, basePath='/', particlesPath='particles/')
            g.attrs['software'] = np.string_('distgen')  # makes a fixed string
            #TODO: add version: g.attrs('version') = np.string_(__version__)

        else:
            g = h5

        # Init
        archive.distgen_init(g)

        # Input
        archive.write_input_h5(g, self.input, name='input')

        # Particles
        if self.particles:
            self.particles.write(g, name='particles')

        return h5
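    # Round-trip sketch (not from the original source):
    #
    #   fname = gen.archive('gen.h5')     # write input (+ particles if run)
    #   gen2 = Generator()
    #   gen2.load_archive('gen.h5')       # restores .input and .particles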

    def __repr__(self):
        s = '<distgen.Generator with input: \n'
        return s + yaml.dump(self.input) + '\n>'

    def check_inputs(self, params):
        """ Checks the params sent to the generator only contain allowed inputs """

        # Make sure user isn't passing the wrong parameters:
        allowed_params = self.optional_params + self.required_params + [
            'verbose'
        ]
        for param in params:
            assert param in allowed_params, (
                f'Incorrect param given to {self.__class__.__name__}'
                f'.__init__(**kwargs): {param}\nAllowed params: {allowed_params}')

        # Make sure all required parameters are specified
        for req in self.required_params:
            assert req in params, (
                f'Required input parameter {req} to {self.__class__.__name__}'
                f'.__init__(**kwargs) was not found.')
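
# End-to-end usage sketch (not from the original source; assumes a distgen
# YAML input file 'distgen.yaml' with the required n_particle, random_type,
# and total_charge keys):
#
#   gen = Generator('distgen.yaml', verbose=1)
#   gen['n_particle'] = 100000          # __setitem__ takes ':'-separated keys
#   gen.run()                           # fills gen.particles (ParticleGroup)
#   gen.archive('distgen_run.h5')       # archive input and particles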