Example #1
def read_plot3d_grid(grid_file, multiblock=True, dim=3, blanking=False,
                     planes=False, binary=True, big_endian=False,
                     single_precision=True, unformatted=True, logger=None):
    """
    Returns a :class:`DomainObj` initialized from Plot3D `grid_file`.

    grid_file: string
        Grid filename.
    """
    logger = logger or NullLogger()
    domain = DomainObj()

    mode = 'rb' if binary else 'r'
    with open(grid_file, mode) as inp:
        logger.info('reading grid file %r', grid_file)
        stream = Stream(inp, binary, big_endian, single_precision, False,
                        unformatted, False)

        # Read zone dimensions.
        shape = _read_plot3d_shape(stream, multiblock, dim, logger)

        # Read zone coordinates.
        for i in range(len(shape)):
            zone = domain.add_zone('', Zone())
            name = domain.zone_name(zone)
            logger.debug('reading coordinates for %s', name)
            _read_plot3d_coords(zone, stream, shape[i], blanking, planes,
                                logger)
    return domain
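
A minimal usage sketch for the reader above. The module path and grid filename are hypothetical; the keyword arguments come from the signature shown.

import logging

from plot3d import read_plot3d_grid  # hypothetical module path

logging.basicConfig(level=logging.DEBUG)

# Read a 3-D multiblock binary grid stored as big-endian double precision.
domain = read_plot3d_grid('wing.xyz',  # hypothetical filename
                          multiblock=True, dim=3,
                          binary=True, big_endian=True,
                          single_precision=False, unformatted=True,
                          logger=logging.getLogger('plot3d'))
print('zones read: %d' % len(domain.zones))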
Example #2
def load_from_eggpkg(package, entry_group, entry_name, instance_name=None,
                     logger=None, observer=None):
    """
    Load object graph state by invoking the given package entry point.
    Returns the root object.

    package: string
        Name of package to load from.

    entry_group: string
        Name of group.

    entry_name: string
        Name of entry point in group.

    instance_name: string
        Name for instance loaded.

    logger: Logger
        Used for recording progress, etc.

    observer: callable
        Called via an :class:`EggObserver`.
    """
    logger = logger or NullLogger()
    observer = EggObserver(observer, logger)
    logger.debug('Loading %s from %s in %s...',
                 entry_name, package, os.getcwd())
    try:
        dist = pkg_resources.get_distribution(package)
    except pkg_resources.DistributionNotFound as exc:
        logger.error('Distribution not found: %s', exc)
        raise exc
    return _load_from_distribution(dist, entry_group, entry_name, instance_name,
                                   logger, observer)
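
A usage sketch, assuming the saved package was previously installed; the module path, package, group, and entry names are placeholders.

from eggloader import load_from_eggpkg  # hypothetical module path

top = load_from_eggpkg(package='my_saved_model',          # placeholder package
                       entry_group='openmdao.component',  # placeholder group
                       entry_name='top',
                       instance_name='top')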
Example #3
def load(instream, fmt=SAVE_CPICKLE, package=None, logger=None):
    """
    Load object(s) from an input stream (or filename).
    If `instream` is a string that is not an existing filename or
    absolute path, then it is searched for using :mod:`pkg_resources`.
    Returns the root object.

    instream: file or string
        Stream or filename to load from.

    fmt: int
        Format of state data.

    package: string
        Name of package to use.

    logger: Logger
        Used for recording progress, etc.
    """
    logger = logger or NullLogger()

    new_stream = False
    if isinstance(instream, basestring):
        if not os.path.exists(instream) and not os.path.isabs(instream):
            # Try to locate via pkg_resources.
            if not package:
                dot = instream.rfind('.')
                if dot < 0:
                    raise ValueError("Bad state filename '%s'." % instream)
                package = instream[:dot]
            logger.debug("Looking for '%s' in package '%s'", instream, package)
            path = pkg_resources.resource_filename(package, instream)
            if not os.path.exists(path):
                raise IOError("State file '%s' not found." % instream)
            instream = path

            # The state file assumes its package directory is on sys.path.
            package_dir = os.path.dirname(path)
            if not package_dir in sys.path:
                sys.path.append(package_dir)

        if fmt is SAVE_CPICKLE or fmt is SAVE_PICKLE:
            mode = 'rb'
        else:
            mode = 'rU'
        instream = open(instream, mode)
        new_stream = True

    try:
        if fmt is SAVE_CPICKLE:
            top = cPickle.load(instream)
        elif fmt is SAVE_PICKLE:
            top = pickle.load(instream)
        else:
            raise RuntimeError("Can't load object using format '%s'" % fmt)
    finally:
        if new_stream:
            instream.close()

    return top
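
A usage sketch of both calling styles; the module path, constants import, and filename are placeholders. Note that load() only closes streams it opened itself.

from eggloader import load, SAVE_CPICKLE, SAVE_PICKLE  # hypothetical module path

# Load from a filename; load() opens and closes the file itself.
top = load('model_state.pickle', fmt=SAVE_CPICKLE)

# Load from a stream the caller owns; load() leaves it open.
with open('model_state.pickle', 'rb') as inp:
    top = load(inp, fmt=SAVE_PICKLE)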
Example #4
def write_authorized_keys(allowed_users, filename, logger=None):
    """
    Write `allowed_users` to `filename` in ssh format.
    The file will be made private if supported on this platform.

    allowed_users: dict
        Dictionary of public keys indexed by user.

    filename: string
        File to write to.

    logger: :class:`logging.Logger`
        Used for log messages.
    """
    logger = logger or NullLogger()

    with open(filename, 'w') as out:
        for user in sorted(allowed_users.keys()):
            pubkey = allowed_users[user]
            buf = 'ssh-rsa'
            key_data = _longstr(len(buf), 4)
            key_data += buf
            buf = _longstr(pubkey.e)
            key_data += _longstr(len(buf), 4)
            key_data += buf
            buf = _longstr(pubkey.n)
            key_data += _longstr(len(buf), 4)
            key_data += buf
            data = base64.b64encode(key_data)
            out.write('ssh-rsa %s %s\n\n' % (data, user))

    if sys.platform == 'win32' and not HAVE_PYWIN32:  # pragma no cover
        logger.warning("Can't make authorized keys file %r private", filename)
    else:
        make_private(filename)
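
The loop above writes each public key as the base64 of length-prefixed fields ('ssh-rsa', e, n). Here is a standalone sketch of that framing using struct in place of the private _longstr helper, whose exact behavior is an assumption:

import base64
import struct

def _be_bytes(value):
    # Minimal big-endian encoding of a non-negative integer
    # (rough stand-in for _longstr; padding rules may differ).
    out = b''
    while value:
        out = struct.pack('>B', value & 0xFF) + out
        value >>= 8
    return out or b'\x00'

def ssh_rsa_blob(e, n):
    # Each field is prefixed by its 4-byte big-endian length,
    # matching the key_data assembly in write_authorized_keys.
    blob = b''
    for field in (b'ssh-rsa', _be_bytes(e), _be_bytes(n)):
        blob += struct.pack('>I', len(field)) + field
    return base64.b64encode(blob)

print(ssh_rsa_blob(65537, 0xC0FFEE))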
Example #5
    def is_equivalent(self, other, logger=None, tolerance=0.):
        """
        Test if self and `other` are equivalent.

        other: :class:`DomainObj`
            The domain to check against.

        logger: Logger or None
            Used to log debug messages that will indicate what, if anything,
            is not equivalent.

        tolerance: float
            The maximum relative difference in array values to be considered
            equivalent.
        """
        logger = logger or NullLogger()
        if not isinstance(other, DomainObj):
            logger.debug('other is not a DomainObj object.')
            return False

        if len(self.zones) != len(other.zones):
            logger.debug('zone count mismatch.')
            return False

        for zone in self.zones:
            name = self.zone_name(zone)
            try:
                other_zone = getattr(other, name)
            except AttributeError:
                logger.debug('other is missing zone %r.', name)
                return False
            if not zone.is_equivalent(other_zone, logger, tolerance):
                logger.debug('zone %r equivalence failed.', name)
                return False
        return True
Example #6
def unpack_zipfile(filename, logger=None, textfiles=None):
    """
    Unpack 'zip' file `filename`.
    Returns ``(nfiles, nbytes)``.

    filename: string
        Name of zip file to unpack.

    logger: Logger
        Used for recording progress.

    textfiles: list
        List of :mod:`fnmatch` style patterns specifying which unpacked files
        are text files possibly needing newline translation. If not supplied,
        the first 4KB of each unpacked file is scanned for a zero byte; if
        none is found, the file is assumed to be text.
    """
    logger = logger or NullLogger()

    # ZipInfo.create_system code for local system.
    local_system = 0 if sys.platform == 'win32' else 3

    nfiles = 0
    nbytes = 0
    with zipfile.ZipFile(filename, 'r') as zipped:
        for info in zipped.infolist():
            filename, size = info.filename, info.file_size
            logger.debug('unpacking %r (%d)...', filename, size)
            zipped.extract(info)

            if sys.platform != 'win32':
                # Set permissions, extract() doesn't.
                rwx = (info.external_attr >> 16) & 0777
                if rwx:
                    os.chmod(filename, rwx)  # Only if something valid.

            # Requires mismatched systems.
            if info.create_system != local_system:  # pragma no cover
                if textfiles is None:
                    with open(filename, 'rb') as inp:
                        data = inp.read(1 << 12)
                    if '\0' not in data:
                        logger.debug('translating %r...', filename)
                        translate_newlines(filename)
                else:
                    for pattern in textfiles:
                        if fnmatch.fnmatch(filename, pattern):
                            logger.debug('translating %r...', filename)
                            translate_newlines(filename)
            nfiles += 1
            nbytes += size

    return (nfiles, nbytes)
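
A usage sketch; the archive name and pattern are placeholders. Files are extracted into the current working directory.

import logging

from filexfer import unpack_zipfile  # hypothetical module path

logging.basicConfig(level=logging.DEBUG)

# '*.txt' members get newline translation when the archive was created
# on a different kind of system.
nfiles, nbytes = unpack_zipfile('results.zip',
                                logger=logging.getLogger('unpack'),
                                textfiles=['*.txt'])
print('unpacked %d files, %d bytes' % (nfiles, nbytes))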
Example #7
def save(root, outstream, fmt=SAVE_CPICKLE, proto=-1, logger=None):
    """
    Save the state of `root` and its children to an output stream (or filename).
    If `outstream` is a string, then it is used as a filename.
    The format can be supplied in case something other than :mod:`cPickle`
    is needed.  For the :mod:`pickle` formats, a `proto` of -1 means use the
    highest protocol.

    root: object
        The root of the object tree to save.

    outstream: file or string
        Stream or filename to save to.

    fmt: int
        What format to save the object state in.

    proto: int
        What protocol to use when pickling.

    logger: Logger
        Used for recording progress, etc.
    """
    logger = logger or NullLogger()

    if isinstance(outstream, basestring):
        if (fmt is SAVE_CPICKLE or fmt is SAVE_PICKLE) and proto != 0:
            mode = 'wb'
        else:
            mode = 'w'
        try:
            outstream = open(outstream, mode)
        except IOError as exc:
            raise type(exc)("Can't save to '%s': %s" %
                            (outstream, exc.strerror))
    if fmt is SAVE_CPICKLE:
        cPickle.dump(root, outstream, proto)
    elif fmt is SAVE_PICKLE:
        pickle.dump(root, outstream, proto)
    # elif fmt is SAVE_YAML:
    #     yaml.dump(root, outstream)
    # elif fmt is SAVE_LIBYAML:
    #     # Test machines have libyaml.
    #     if _libyaml is False:  # pragma no cover
    #         logger.warning('libyaml not available, using yaml instead')
    #     yaml.dump(root, outstream, Dumper=Dumper)
    else:
        raise RuntimeError("Can't save object using format '%s'" % fmt)
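
A usage sketch of both calling styles; the module path, filename, and root object are placeholders.

from eggsaver import save, SAVE_CPICKLE  # hypothetical module path

model = {'answer': 42}  # stands in for the real root object

# Save to a filename; pickle formats get a binary-mode file.
save(model, 'model_state.pickle', fmt=SAVE_CPICKLE, proto=-1)

# Save to a stream the caller manages.
with open('model_state.pickle', 'wb') as out:
    save(model, out)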
Example #8
def read_plot3d_shape(grid_file, multiblock=True, dim=3, binary=True,
                      big_endian=False, unformatted=True, logger=None):
    """
    Returns a list of zone dimensions from Plot3D `grid_file`.

    grid_file: string
        Grid filename.
    """
    logger = logger or NullLogger()

    mode = 'rb' if binary else 'r'
    with open(grid_file, mode) as inp:
        logger.info('reading grid file %r', grid_file)
        stream = Stream(inp, binary, big_endian, True, False,
                        unformatted, False)
        return _read_plot3d_shape(stream, multiblock, dim, logger)
Example #9
def check_requirements(required, logger=None, indent_level=0):
    """
    Display requirements (if the logger's debug level is enabled) and note
    any conflicts.
    Returns a list of unavailable requirements.

    required: list
        List of package requirements.

    logger: Logger
        Used for recording progress, etc.

    indent_level: int
        Used to improve readability of log messages.
    """
    def _recursive_check(required, logger, level, visited, working_set,
                         not_avail):
        indent = '    ' * level
        indent2 = '    ' * (level + 1)
        for req in required:
            logger.log(LOG_DEBUG2, '%schecking %s', indent, req)
            dist = None
            try:
                dist = working_set.find(req)
            # Difficult to generate a distribution that can't be reloaded.
            except pkg_resources.VersionConflict:  # pragma no cover
                dist = working_set.by_key[req.key]
                logger.debug('%sconflicts with %s %s', indent2,
                             dist.project_name, dist.version)
                not_avail.append(req)
            else:
                # Difficult to generate a distribution that can't be reloaded.
                if dist is None:  # pragma no cover
                    logger.debug('%sno distribution found', indent2)
                    not_avail.append(req)
                else:
                    logger.log(LOG_DEBUG2, '%s%s %s', indent2,
                               dist.project_name, dist.version)
                    if not dist in visited:
                        visited.add(dist)
                        _recursive_check(dist.requires(), logger, level + 1,
                                         visited, working_set, not_avail)

    logger = logger or NullLogger()
    not_avail = []
    _recursive_check(required, logger, indent_level, set(),
                     pkg_resources.WorkingSet(), not_avail)
    return not_avail
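
A usage sketch; check_requirements works on pkg_resources Requirement objects, which parse_requirements produces. The requirement strings and module path are placeholders.

import logging
import pkg_resources

from dist_utils import check_requirements  # hypothetical module path

logging.basicConfig(level=logging.DEBUG)
required = list(pkg_resources.parse_requirements(['numpy>=1.5', 'scipy']))
missing = check_requirements(required, logger=logging.getLogger('reqs'))
for req in missing:
    print('unsatisfied or conflicting requirement: %s' % req)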
Example #10
def write_plot3d_grid(domain, grid_file, planes=False, binary=True,
                      big_endian=False, single_precision=True,
                      unformatted=True, logger=None):
    """
    Writes `domain` to `grid_file` in Plot3D format.
    Ghost data is not written.

    domain: :class:`DomainObj` or :class:`Zone`
        The domain or zone to be written.

    grid_file: string
        Grid filename.
    """
    logger = logger or NullLogger()

    if isinstance(domain, DomainObj):
        writing_domain = True
        zones = domain.zones
    elif isinstance(domain, Zone):
        writing_domain = False
        zones = [domain]
    else:
        raise TypeError("'domain' argument must be a DomainObj or Zone")

    mode = 'wb' if binary else 'w'
    with open(grid_file, mode) as out:
        logger.info('writing grid file %r', grid_file)
        stream = Stream(out, binary, big_endian, single_precision, False,
                        unformatted, False)
        if len(zones) > 1:
            # Write number of zones.
            stream.write_int(len(zones), full_record=True)

        # Write zone dimensions.
        _write_plot3d_dims(domain, stream, logger)

        # Write zone coordinates.
        for zone in zones:
            if writing_domain:
                name = domain.zone_name(zone)
            else:
                name = 'zone'
            logger.debug('writing coords for %s', name)
            _write_plot3d_coords(zone, stream, planes, logger)
Example #11
def pack_zipfile(patterns, filename, logger=None):
    """
    Create 'zip' file `filename` of files in `patterns`.
    Returns ``(nfiles, nbytes)``.

    patterns: list
        List of :mod:`fnmatch` style patterns.

    filename: string
        Name of zip file to create.

    logger: Logger
        Used for recording progress.

    .. note::
        The code uses :meth:`glob.glob` to process `patterns`.
        It does not check for the existence of any matches.

    """
    logger = logger or NullLogger()

    # Scan to see if we have to use zip64 flag.
    nbytes = 0
    for pattern in patterns:
        for path in glob.glob(pattern):
            nbytes += os.path.getsize(path)
    zip64 = nbytes > zipfile.ZIP64_LIMIT
    compression = zipfile.ZIP_DEFLATED

    nfiles = 0
    nbytes = 0
    zipped = zipfile.ZipFile(filename, 'w', compression, zip64)
    try:
        for pattern in patterns:
            for path in glob.glob(pattern):
                size = os.path.getsize(path)
                logger.debug("packing '%s' (%d)...", path, size)
                zipped.write(path)
                nfiles += 1
                nbytes += size
    finally:
        zipped.close()
    return (nfiles, nbytes)
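
A usage sketch; per the docstring note, patterns are expanded with glob.glob and patterns that match nothing are silently skipped. The pattern and archive names are placeholders.

from filexfer import pack_zipfile  # hypothetical module path

nfiles, nbytes = pack_zipfile(['*.dat', 'input/*.xml'], 'case.zip')
print('packed %d files, %d bytes' % (nfiles, nbytes))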
Example #12
def load_from_eggfile(filename,
                      entry_group,
                      entry_name,
                      logger=None,
                      observer=None):
    """
    Extracts the files in the egg to a subdirectory matching the saved object
    name, then loads the object graph state by invoking the given entry point.
    Returns the root object.

    filename: string
        Name of egg file.

    entry_group: string
        Name of group.

    entry_name: string
        Name of entry point in group.

    logger: Logger
        Used for recording progress, etc.

    observer: callable
        Called via an :class:`EggObserver`.
    """
    logger = logger or NullLogger()
    observer = EggObserver(observer, logger)
    logger.debug('Loading %s from %s in %s...', entry_name, filename,
                 os.getcwd())

    egg_dir, dist = _dist_from_eggfile(filename, logger, observer)

    # Just being defensive, '.' is typically in the path.
    if not '.' in sys.path:  # pragma no cover
        sys.path.append('.')
    orig_dir = os.getcwd()
    os.chdir(egg_dir)
    try:
        return _load_from_distribution(dist, entry_group, entry_name, None,
                                       logger, observer)
    finally:
        os.chdir(orig_dir)
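
A usage sketch; the egg filename and entry-point names are placeholders. The egg is extracted under the current working directory and the entry point is invoked from inside the extracted tree.

from eggloader import load_from_eggfile  # hypothetical module path

top = load_from_eggfile('my_saved_model-2013.04.01.12.30-py2.7.egg',
                        entry_group='openmdao.component',  # placeholder group
                        entry_name='top')                  # placeholder entry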
Example #13
def read_plot3d_q(grid_file, q_file, multiblock=True, dim=3, blanking=False,
                  planes=False, binary=True, big_endian=False,
                  single_precision=True, unformatted=True, logger=None):
    """
    Returns a :class:`DomainObj` initialized from Plot3D `grid_file` and
    `q_file`.  Q variables are assigned to 'density', 'momentum', and
    'energy_stagnation_density'.  Scalars are assigned to 'mach', 'alpha',
    'reynolds', and 'time'.

    grid_file: string
        Grid filename.

    q_file: string
        Q data filename.
    """
    logger = logger or NullLogger()

    domain = read_plot3d_grid(grid_file, multiblock, dim, blanking, planes,
                              binary, big_endian, single_precision,
                              unformatted, logger)

    mode = 'rb' if binary else 'r'
    with open(q_file, mode) as inp:
        logger.info('reading Q file %r', q_file)
        stream = Stream(inp, binary, big_endian, single_precision, False,
                        unformatted, False)
        if multiblock:
            # Read number of zones.
            nblocks = stream.read_int(full_record=True)
        else:
            nblocks = 1
        if nblocks != len(domain.zones):
            raise RuntimeError('Q zones %d != Grid zones %d'
                               % (nblocks, len(domain.zones)))

        # Read zone dimensions.
        if unformatted:
            reclen = stream.read_recordmark()
            expected = stream.reclen_ints(dim * nblocks)
            if reclen != expected:
                logger.warning('unexpected dimensions recordlength'
                               ' %d vs. %d', reclen, expected)
        for zone in domain.zones:
            name = domain.zone_name(zone)
            imax, jmax, kmax = _read_plot3d_dims(stream, dim)
            if dim > 2:
                logger.debug('    %s: %dx%dx%d', name, imax, jmax, kmax)
                zone_i, zone_j, zone_k = zone.shape
                if imax != zone_i or jmax != zone_j or kmax != zone_k:
                    raise RuntimeError('%s: Q %dx%dx%d != Grid %dx%dx%d'
                                       % (name, imax, jmax, kmax,
                                          zone_i, zone_j, zone_k))
            else:
                logger.debug('    %s: %dx%d', name, imax, jmax)
                zone_i, zone_j = zone.shape
                if imax != zone_i or jmax != zone_j:
                    raise RuntimeError('%s: Q %dx%d != Grid %dx%d'
                                       % (name, imax, jmax, zone_i, zone_j))
        if unformatted:
            reclen2 = stream.read_recordmark()
            if reclen2 != reclen:
                logger.warning('mismatched dimensions recordlength'
                               ' %d vs. %d', reclen2, reclen)

        # Read zone scalars and variables.
        for zone in domain.zones:
            name = domain.zone_name(zone)
            logger.debug('reading data for %s', name)
            _read_plot3d_qscalars(zone, stream, logger)
            _read_plot3d_qvars(zone, stream, planes, logger)

    return domain
Example #14
    def test_null_logger(self):
        logging.debug('')
        logging.debug('test_null_logger')

        logger = NullLogger()
        logger.debug('debug message')
        logger.info('info message')
        logger.warning('warning message')
        logger.error('error message')
        logger.critical('critical message')
        logger.log(1, 'logged at level 1')
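
The guard logger = logger or NullLogger() appears in every example above, and the test only requires that the standard Logger methods exist and silently succeed. A minimal sketch of such a class (the real implementation may differ):

class NullLogger(object):
    """Accepts the standard Logger calls and discards them, so callers
    never have to check whether a logger was supplied."""

    def _noop(self, *args, **kwargs):
        pass

    debug = info = warning = error = critical = log = _noop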
Example #15
def get_key_pair(user_host,
                 logger=None,
                 overwrite_cache=False,
                 ignore_ssh=False):
    """
    Returns an RSA key containing both public and private keys for the user
    identified in `user_host`.  This can be an expensive operation, so
    we avoid generating a new key pair whenever possible.
    If ``~/.ssh/id_rsa`` exists and is private, that key is returned.

    user_host: string
        Format ``user@host``.

    logger: :class:`logging.Logger`
        Used for debug messages.

    overwrite_cache: bool
        If True, a new key is generated and forced into the cache of existing
        known keys.  Used for testing.

    ignore_ssh: bool
        If True, ignore any existing ssh id_rsa key file.  Used for testing.

    .. note::

        To avoid unnecessary key generation, the public/private key pair for
        the current user is stored in the private file ``~/.openmdao/keys``.
        On Windows this requires the pywin32 extension.  Also, the public
        key is stored in ssh form in ``~/.openmdao/id_rsa.pub``.

    """
    logger = logger or NullLogger()

    with _KEY_CACHE_LOCK:
        if overwrite_cache:
            key_pair = _generate(user_host, logger)
            _KEY_CACHE[user_host] = key_pair
            return key_pair

        # Look in previously generated keys.
        try:
            key_pair = _KEY_CACHE[user_host]
        except KeyError:
            # If key for current user (typical), check filesystem.
            # TODO: file lock to protect from separate processes.
            user, host = user_host.split('@')
            if user == getpass.getuser():
                current_user = True
                key_pair = None

                # Try to re-use SSH key. Exceptions should *never* be exercised!
                if not ignore_ssh:
                    id_rsa = \
                        os.path.expanduser(os.path.join('~', '.ssh', 'id_rsa'))
                    if is_private(id_rsa):
                        try:
                            with open(id_rsa, 'r') as inp:
                                key_pair = RSA.importKey(inp.read())
                        except Exception as exc:  # pragma no cover
                            logger.warning('ssh id_rsa import: %r', exc)
                        else:
                            generate = False
                    else:  # pragma no cover
                        logger.warning('Ignoring insecure ssh id_rsa.')

                if key_pair is None:
                    # Look for OpenMDAO key.
                    key_file = \
                        os.path.expanduser(os.path.join('~', '.openmdao', 'keys'))
                    if is_private(key_file):
                        try:
                            with open(key_file, 'rb') as inp:
                                key_pair = cPickle.load(inp)
                        except Exception:
                            generate = True
                        else:
                            generate = False
                    else:
                        logger.warning('Insecure keyfile! Regenerating keys.')
                        os.remove(key_file)
                        generate = True

            # Difficult to run test as non-current user.
            else:  # pragma no cover
                current_user = False
                generate = True

            if generate:
                key_pair = _generate(user_host, logger)
                if current_user:
                    key_dir = os.path.dirname(key_file)
                    if not os.path.exists(key_dir):
                        os.mkdir(key_dir)

                    # Save key pair in protected file.
                    if sys.platform == 'win32' and not HAVE_PYWIN32:  # pragma no cover
                        logger.debug('No pywin32, not saving keyfile')
                    else:
                        make_private(key_dir)  # Private while writing keyfile.
                        with open(key_file, 'wb') as out:
                            cPickle.dump(key_pair, out,
                                         cPickle.HIGHEST_PROTOCOL)
                        try:
                            make_private(key_file)
                        # Hard to cause (recoverable) error here.
                        except Exception:  # pragma no cover
                            os.remove(key_file)  # Remove unsecured file.
                            raise

                    # Save public key in ssh form.
                    users = {user_host: key_pair.publickey()}
                    filename = os.path.join(key_dir, 'id_rsa.pub')
                    write_authorized_keys(users, filename, logger)

            _KEY_CACHE[user_host] = key_pair

    return key_pair
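
A usage sketch; the module path is hypothetical, user_host is built from the current environment, and publickey() is the same accessor the code above uses when writing id_rsa.pub.

import getpass
import socket

from key_utils import get_key_pair  # hypothetical module path

user_host = '%s@%s' % (getpass.getuser(), socket.gethostname())
key_pair = get_key_pair(user_host)
pubkey = key_pair.publickey()  # public half, suitable for write_authorized_keys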
Example #16
def read_authorized_keys(filename=None, logger=None):
    """
    Return dictionary of public keys, indexed by user, read from `filename`.
    The file must be in ssh format, and only RSA keys are processed.
    If the file is not private, then no keys are returned.

    filename: string
        File to read from. The default is ``~/.ssh/authorized_keys``.

    logger: :class:`logging.Logger`
        Used for log messages.
    """
    if not filename:
        filename = \
            os.path.expanduser(os.path.join('~', '.ssh', 'authorized_keys'))

    logger = logger or NullLogger()

    if not os.path.exists(filename):
        raise RuntimeError('%r does not exist' % filename)

    if not is_private(filename):
        if sys.platform != 'win32' or HAVE_PYWIN32:
            raise RuntimeError('%r is not private' % filename)
        else:  # pragma no cover
            logger.warning('Allowed users file %r is not private', filename)

    errors = 0
    keys = {}
    with open(filename, 'r') as inp:
        for line in inp:
            line = line.rstrip()
            sharp = line.find('#')
            if sharp >= 0:
                line = line[:sharp]
            if not line:
                continue

            key_type, blank, rest = line.partition(' ')
            if key_type != 'ssh-rsa':
                logger.error('unsupported key type: %r', key_type)
                errors += 1
                continue

            key_data, blank, user_host = rest.partition(' ')
            if not key_data:
                logger.error('bad line (missing key data):')
                logger.error(line)
                errors += 1
                continue

            try:
                user, host = user_host.split('@')
            except ValueError:
                logger.error('bad line (require user@host):')
                logger.error(line)
                errors += 1
                continue

            logger.debug('user %r, host %r', user, host)
            try:
                ip_addr = socket.gethostbyname(host)
            except socket.gaierror:
                logger.warning('unknown host %r', host)
                logger.warning(line)

            data = base64.b64decode(key_data)
            start = 0
            name_len = _longint(data, start, 4)
            start += 4
            name = data[start:start + name_len]
            if name != 'ssh-rsa':
                logger.error('name error: %r vs. ssh-rsa', name)
                logger.error(line)
                errors += 1
                continue

            start += name_len
            e_len = _longint(data, start, 4)
            start += 4
            e = _longint(data, start, e_len)
            start += e_len
            n_len = _longint(data, start, 4)
            start += 4
            n = _longint(data, start, n_len)
            start += n_len
            if start != len(data):
                logger.error('length error: %d vs. %d', start, len(data))
                logger.error(line)
                errors += 1
                continue

            try:
                pubkey = RSA.construct((n, e))
            except Exception as exc:
                logger.error('key construct error: %r', exc)
                errors += 1
            else:
                keys[user_host] = pubkey

    if errors:
        raise RuntimeError('%d errors in %r, check log for details' %
                           (errors, filename))
    return keys
Example #18
def save_to_egg(entry_pts,
                version=None,
                py_dir=None,
                src_dir=None,
                src_files=None,
                dst_dir=None,
                logger=None,
                observer=None,
                need_requirements=True):
    """
    Save state and other files to an egg. Analyzes the objects saved for
    distribution dependencies.  Modules not found in any distribution are
    recorded in an ``egg-info/openmdao_orphans.txt`` file.  Also creates and
    saves loader scripts for each entry point.

    entry_pts: list
        List of ``(obj, obj_name, obj_group)`` tuples.
        The first of these specifies the root object and package name.

    version: string
        Defaults to a timestamp of the form 'YYYY.MM.DD.HH.mm'.

    py_dir: string
        The (root) directory for local Python files.
        It defaults to the current directory.

    src_dir: string
        The root of all (relative) `src_files`.

    dst_dir: string
        The directory to write the egg in.

    observer: callable
        Will be called via an :class:`EggObserver` intermediary.

    need_requirements: bool
        If True, distributions required by the egg will be determined.
        This can be set False if the egg is just being used for distribution
        within the local host.

    Returns ``(egg_filename, required_distributions, orphan_modules)``.
    """
    root, name, group = entry_pts[0]
    logger = logger or NullLogger()
    observer = eggobserver.EggObserver(observer, logger)

    orig_dir = os.getcwd()

    if not py_dir:
        py_dir = orig_dir
    else:
        py_dir = os.path.realpath(py_dir)
    if sys.platform == 'win32':  # pragma no cover
        py_dir = py_dir.lower()

    if src_dir:
        src_dir = os.path.realpath(src_dir)
        if sys.platform == 'win32':  # pragma no cover
            src_dir = src_dir.lower()

    src_files = src_files or set()

    if not version:
        now = datetime.datetime.now()  # Could consider using utcnow().
        version = '%d.%02d.%02d.%02d.%02d' % \
                  (now.year, now.month, now.day, now.hour, now.minute)

    dst_dir = dst_dir or orig_dir
    if not os.access(dst_dir, os.W_OK):
        msg = "Can't save to '%s', no write permission" % dst_dir
        observer.exception(msg)
        raise IOError(msg)

    egg_name = eggwriter.egg_filename(name, version)
    logger.debug('Saving to %s in %s...', egg_name, orig_dir)

    # Clone __main__ as a 'real' module and have classes reference that.
    _fix_main(logger)

    # Get a list of all objects we'll be saving.
    objs = _get_objects(root, logger)

    # Check that each object can be pickled.
    _check_objects(objs, logger, observer)

    # Verify that no __main__ references are still around.
    _verify_objects(root, logger, observer)

    tmp_dir = None
    try:
        if need_requirements:
            # Determine distributions and local modules required.
            required_distributions, local_modules, orphan_modules = \
                _get_distributions(objs, py_dir, logger, observer)
        else:
            required_distributions = set()
            local_modules = set()
            orphan_modules = set()
            # Collect Python modules.
            for dirpath, dirnames, filenames in os.walk(py_dir):
                dirs = copy.copy(dirnames)
                for path in dirs:
                    if not os.path.exists(
                            os.path.join(dirpath, path, '__init__.py')):
                        dirnames.remove(path)
                for path in filenames:
                    if path.endswith('.py'):
                        local_modules.add(os.path.join(dirpath, path))

        # Ensure module defining root is local. Saving a root which is defined
        # in a package can be hidden if the package was installed.
        root_mod = root.__class__.__module__
        root_mod = sys.modules[root_mod].__file__
        if root_mod.endswith('.pyc') or root_mod.endswith('.pyo'):
            root_mod = root_mod[:-1]
        local_modules.add(os.path.abspath(root_mod))

        logger.log(LOG_DEBUG2, '    py_dir: %s', py_dir)
        logger.log(LOG_DEBUG2, '    src_dir: %s', src_dir)
        logger.log(LOG_DEBUG2, '    local_modules:')
        for module in sorted(local_modules):
            mod = module
            if mod.startswith(py_dir):
                mod = mod[len(py_dir) + 1:]
            logger.log(LOG_DEBUG2, '        %s', mod)

        # Move to scratch area.
        tmp_dir = tempfile.mkdtemp(prefix='Egg_', dir=tmp_dir)
        os.chdir(tmp_dir)
        os.mkdir(name)
        cleanup_files = []

        try:
            # Copy external files from src_dir.
            if src_dir:
                for path in src_files:
                    subdir = os.path.dirname(path)
                    if subdir:
                        subdir = os.path.join(name, subdir)
                        if not os.path.exists(subdir):
                            os.makedirs(subdir)
                    src = os.path.join(src_dir, path)
                    dst = os.path.join(name, path)
                    if sys.platform == 'win32':  # pragma no cover
                        shutil.copy2(src, dst)
                    else:
                        os.symlink(src, dst)

            # Copy local modules from py_dir.
            for path in local_modules:
                if not os.path.exists(
                        os.path.join(name, os.path.basename(path))):
                    if not os.path.isabs(path):
                        path = os.path.join(py_dir, path)
                    shutil.copy(path, name)

            # For each entry point...
            entry_info = []
            for obj, obj_name, obj_group in entry_pts:
                clean_name = obj_name
                if clean_name.startswith(name + '.'):
                    clean_name = clean_name[len(name) + 1:]
                clean_name = clean_name.replace('.', '_')

                # Save state of object hierarchy.
                state_name, state_path = \
                    _write_state_file(name, obj, clean_name, logger, observer)
                src_files.add(state_name)
                cleanup_files.append(state_path)

                # Create loader script.
                loader = '%s_loader' % clean_name
                loader_path = os.path.join(name, loader + '.py')
                cleanup_files.append(loader_path)
                _write_loader_script(loader_path, state_name, name,
                                     obj is root)

                entry_info.append((obj_group, obj_name, loader))

            # If needed, make an empty __init__.py
            init_path = os.path.join(name, '__init__.py')
            if not os.path.exists(init_path):
                cleanup_files.append(init_path)
                out = open(init_path, 'w')
                out.close()

            # Save everything to an egg.
            doc = root.__doc__ or ''
            entry_map = _create_entry_map(entry_info)
            orphans = [mod for mod, path in orphan_modules]
            eggwriter.write(name, version, doc, entry_map, src_files,
                            required_distributions, orphans, dst_dir, logger,
                            observer.observer)
        finally:
            for path in cleanup_files:
                if os.path.exists(path):
                    os.remove(path)
    finally:
        os.chdir(orig_dir)
        if tmp_dir:
            shutil.rmtree(tmp_dir, onerror=onerror)

    return (egg_name, required_distributions, orphan_modules)
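
A usage sketch; the module path, root object, package name, and entry-point group are placeholders (the docstring above defines each tuple field).

from eggsaver import save_to_egg  # hypothetical module path

top = build_model()  # placeholder for a picklable root object defined locally
egg_name, required_dists, orphan_modules = save_to_egg(
    entry_pts=[(top, 'my_model', 'openmdao.component')])  # placeholder names
print('wrote %s, requiring %d distributions'
      % (egg_name, len(required_dists)))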
Example #19
def read_plot3d_f(grid_file, f_file, varnames=None, multiblock=True, dim=3,
                  blanking=False, planes=False, binary=True, big_endian=False,
                  single_precision=True, unformatted=True, logger=None):
    """
    Returns a :class:`DomainObj` initialized from Plot3D `grid_file` and
    `f_file`.  Variables are assigned to names of the form `f_N`.

    grid_file: string
        Grid filename.

    f_file: string
        Function data filename.
    """
    logger = logger or NullLogger()

    domain = read_plot3d_grid(grid_file, multiblock, dim, blanking, planes,
                              binary, big_endian, single_precision,
                              unformatted, logger)

    mode = 'rb' if binary else 'r'
    with open(f_file, mode) as inp:
        logger.info('reading F file %r', f_file)
        stream = Stream(inp, binary, big_endian, single_precision, False,
                        unformatted, False)
        if multiblock:
            # Read number of zones.
            nblocks = stream.read_int(full_record=True)
        else:
            nblocks = 1
        if nblocks != len(domain.zones):
            raise RuntimeError('F zones %d != Grid zones %d'
                               % (nblocks, len(domain.zones)))

        # Read zone dimensions.
        if unformatted:
            reclen = stream.read_recordmark()
            expected = stream.reclen_ints((dim+1) * nblocks)
            if reclen != expected:
                logger.warning('unexpected dimensions recordlength'
                               ' %d vs. %d', reclen, expected)
        for zone in domain.zones:
            name = domain.zone_name(zone)
            imax, jmax, kmax, nvars = _read_plot3d_dims(stream, dim, True)
            if dim > 2:
                logger.debug('    %s: %dx%dx%d %d',
                             name, imax, jmax, kmax, nvars)
                zone_i, zone_j, zone_k = zone.shape
                if imax != zone_i or jmax != zone_j or kmax != zone_k:
                    raise RuntimeError('%s: F %dx%dx%d != Grid %dx%dx%d'
                                       % (name, imax, jmax, kmax,
                                          zone_i, zone_j, zone_k))
            else:
                logger.debug('    %s: %dx%d %d', name, imax, jmax, nvars)
                zone_i, zone_j = zone.shape
                if imax != zone_i or jmax != zone_j:
                    raise RuntimeError('%s: F %dx%d != Grid %dx%d'
                                       % (name, imax, jmax, zone_i, zone_j))
        if unformatted:
            reclen2 = stream.read_recordmark()
            if reclen2 != reclen:
                logger.warning('mismatched dimensions recordlength'
                               ' %d vs. %d', reclen2, reclen)

        # Read zone variables.
        for zone in domain.zones:
            name = domain.zone_name(zone)
            logger.debug('reading data for %s', name)
            _read_plot3d_fvars(zone, stream, dim, nvars, varnames, planes,
                               logger)
    return domain
Example #20
def write_plot3d_q(domain, grid_file, q_file, planes=False, binary=True,
                   big_endian=False, single_precision=True, unformatted=True,
                   logger=None):
    """
    Writes `domain` to `grid_file` and `q_file` in Plot3D format.
    Requires 'density', 'momentum', and 'energy_stagnation_density' variables
    as well as 'mach', 'alpha', 'reynolds', and 'time' scalars.
    Ghost data is not written.

    domain: :class:`DomainObj` or :class:`Zone`
        The domain or zone to be written.

    grid_file: string
        Grid filename.

    q_file: string
        Q data filename.
    """
    logger = logger or NullLogger()

    if isinstance(domain, DomainObj):
        writing_domain = True
        zones = domain.zones
    elif isinstance(domain, Zone):
        writing_domain = False
        zones = [domain]
    else:
        raise TypeError("'domain' argument must be a DomainObj or Zone")

    # Verify we have the needed data.
    for zone in zones:
        flow = zone.flow_solution
        missing = []
        for name in ('mach', 'alpha', 'reynolds', 'time',
                     'density', 'momentum', 'energy_stagnation_density'):
            if not hasattr(flow, name):
                missing.append(name)
        if missing:
            if writing_domain:
                name = domain.zone_name(zone)
            else:
                name = ''
            raise AttributeError('zone %s flow_solution is missing %s'
                                 % (name, missing))
    # Write grid file.
    write_plot3d_grid(domain, grid_file, planes, binary, big_endian,
                      single_precision, unformatted, logger)
    # Write Q file.
    mode = 'wb' if binary else 'w'
    with open(q_file, mode) as out:
        logger.info('writing Q file %r', q_file)
        stream = Stream(out, binary, big_endian, single_precision, False,
                        unformatted, False)
        if len(zones) > 1:
            # Write number of zones.
            stream.write_int(len(zones), full_record=True)

        # Write zone dimensions.
        _write_plot3d_dims(domain, stream, logger)

        # Write zone scalars and variables.
        varnames = ('density', 'momentum', 'energy_stagnation_density')
        for zone in zones:
            if writing_domain:
                name = domain.zone_name(zone)
            else:
                name = 'zone'
            logger.debug('writing data for %s', name)
            _write_plot3d_qscalars(zone, stream, logger)
            _write_plot3d_vars(zone, stream, varnames, planes, logger)
Example #21
def read_q(grid_file, q_file, multiblock=True, blanking=False, logger=None):
    """
    Read grid and solution files.
    Returns a :class:`DomainObj` initialized from `grid_file` and `q_file`.

    grid_file: string
        Grid filename.

    q_file: string
        Q data filename.
    """
    logger = logger or NullLogger()

    domain = read_plot3d_grid(grid_file,
                              multiblock,
                              dim=3,
                              blanking=blanking,
                              planes=False,
                              binary=True,
                              big_endian=False,
                              single_precision=False,
                              unformatted=True,
                              logger=logger)

    with open(q_file, 'rb') as inp:
        logger.info("reading Q file '%s'", q_file)
        stream = Stream(inp,
                        binary=True,
                        big_endian=False,
                        single_precision=False,
                        integer_8=False,
                        unformatted=True,
                        recordmark_8=False)
        if multiblock:
            # Read number of zones.
            nblocks = stream.read_int(full_record=True)
        else:
            nblocks = 1
        if nblocks != len(domain.zones):
            raise RuntimeError('Q zones %d != Grid zones %d' \
                               % (nblocks, len(domain.zones)))

        # Read zone dimensions, nq, nqc.
        reclen = stream.read_recordmark()
        expected = stream.reclen_ints(3 * nblocks + 2)
        if reclen != expected:
            logger.warning('unexpected dimensions recordlength'
                           ' %d vs. %d', reclen, expected)

        for zone in domain.zones:
            name = domain.zone_name(zone)
            imax, jmax, kmax = stream.read_ints(3)
            if imax < 1 or jmax < 1 or kmax < 1:
                raise ValueError("invalid dimensions: %dx%dx%d" \
                                 % (imax, jmax, kmax))
            logger.debug('    %s: %dx%dx%d', name, imax, jmax, kmax)
            zone_i, zone_j, zone_k = zone.shape
            if imax != zone_i or jmax != zone_j or kmax != zone_k:
                raise RuntimeError('%s: Q %dx%dx%d != Grid %dx%dx%d' \
                                   % (name, imax, jmax, kmax,
                                      zone_i, zone_j, zone_k))

        nq, nqc = stream.read_ints(2)
        logger.debug('    nq %d, nqc %d', nq, nqc)

        reclen2 = stream.read_recordmark()
        if reclen2 != reclen:
            logger.warning('mismatched dimensions recordlength'
                           ' %d vs. %d', reclen2, reclen)

        # Read zone scalars and variables.
        for zone in domain.zones:
            name = domain.zone_name(zone)
            logger.debug('reading data for %s', name)
            _read_scalars(zone, nqc, stream, logger)
            _read_vars(zone, nq, nqc, stream, logger)

    return domain
Example #22
def write_plot3d_f(domain, grid_file, f_file, varnames=None, planes=False,
                   binary=True, big_endian=False, single_precision=True,
                   unformatted=True, logger=None):
    """
    Writes `domain` to `grid_file` and `f_file` in Plot3D format.
    If `varnames` is None, all arrays are written, followed by all vectors.
    Ghost data is not written.

    domain: :class:`DomainObj` or :class:`Zone`
        The domain or zone to be written.

    grid_file: string
        Grid filename.

    f_file: string
        Function data filename.
    """
    logger = logger or NullLogger()

    if isinstance(domain, DomainObj):
        writing_domain = True
        zones = domain.zones
    elif isinstance(domain, Zone):
        writing_domain = False
        zones = [domain]
    else:
        raise TypeError("'domain' argument must be a DomainObj or Zone")

    if varnames is None:
        flow = zones[0].flow_solution
        varnames = [flow.name_of_obj(obj) for obj in flow.arrays]
        varnames.extend([flow.name_of_obj(obj) for obj in flow.vectors])

    # Verify we have the needed data.
    for zone in zones:
        flow = zone.flow_solution
        missing = []
        for name in varnames:
            if not hasattr(flow, name):
                missing.append(name)
        if missing:
            if writing_domain:
                name = domain.zone_name(zone)
            else:
                name = ''
            raise AttributeError('zone %s flow_solution is missing %s'
                                 % (name, missing))
    # Write grid file.
    write_plot3d_grid(domain, grid_file, planes, binary, big_endian,
                      single_precision, unformatted, logger)
    # Write F file.
    mode = 'wb' if binary else 'w'
    with open(f_file, mode) as out:
        logger.info('writing F file %r', f_file)
        stream = Stream(out, binary, big_endian, single_precision, False,
                        unformatted, False)
        if len(zones) > 1:
            # Write number of zones.
            stream.write_int(len(zones), full_record=True)

        # Write zone dimensions.
        _write_plot3d_dims(domain, stream, logger, varnames)

        # Write zone variables.
        for zone in zones:
            if writing_domain:
                name = domain.zone_name(zone)
            else:
                name = 'zone'
            logger.debug('writing data for %s', name)
            _write_plot3d_vars(zone, stream, varnames, planes, logger)