Example #1
class _TableInternal(_InternalAction):
    """Implements the logic for a simple text based logger backend.

    This currently has to check the logged quantities every time to ensure it
    has not changed since the last run of `~.act`. Performance could be
    improved by allowing for writing of data without checking for a change in
    logged quantities, but would be more fragile.
    """

    _invalid_logger_categories = LoggerCategories.any([
        'sequence', 'object', 'particle', 'bond', 'angle', 'dihedral',
        'improper', 'pair', 'constraint', 'strings'
    ])

    def __init__(self,
                 logger,
                 output=stdout,
                 header_sep='.',
                 delimiter=' ',
                 pretty=True,
                 max_precision=10,
                 max_header_len=None):

        param_dict = ParameterDict(header_sep=str,
                                   delimiter=str,
                                   min_column_width=int,
                                   max_header_len=OnlyTypes(int,
                                                            allow_none=True),
                                   pretty=bool,
                                   max_precision=int,
                                   output=OnlyTypes(
                                       _OutputWriter,
                                       postprocess=_ensure_writable),
                                   logger=Logger)

        param_dict.update(
            dict(header_sep=header_sep,
                 delimiter=delimiter,
                 min_column_width=max(10, max_precision + 6),
                 max_header_len=max_header_len,
                 max_precision=max_precision,
                 pretty=pretty,
                 output=output,
                 logger=logger))
        self._param_dict = param_dict

        # Ensure that only scalar (and potentially string) categories are set
        # on the logger.
        if (LoggerCategories.scalar not in logger.categories
                or logger.categories & self._invalid_logger_categories !=
                LoggerCategories.NONE):
            raise ValueError(
                "Given Logger must have the scalar category set and must not "
                "contain non-scalar, non-string categories.")

        # Internal variables that are not part of the state.
        self._cur_headers_with_width = dict()
        self._fmt = _Formatter(pretty, max_precision)
        self._comm = None

    def _setattr_param(self, attr, value):
        """Makes self._param_dict attributes read only."""
        raise ValueError("Attribute {} is read-only.".format(attr))

    def attach(self, simulation):
        self._comm = simulation.device._comm

    def detach(self):
        self._comm = None

    def _get_log_dict(self):
        """Get a flattened dict for writing to output."""
        return {
            key: value[0]
            for key, value in dict_flatten(self.logger.log()).items()
        }
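`dict_flatten` is a HOOMD internal helper. Roughly, it flattens the logger's nested dict into a dict keyed by namespace tuples whose values are `(value, category)` pairs; `_get_log_dict` then keeps only the values. A minimal sketch of that behavior, using a hypothetical stand-in for the helper:

def flatten(nested, prefix=()):
    """Hypothetical stand-in for HOOMD's dict_flatten helper."""
    flat = {}
    for key, value in nested.items():
        if isinstance(value, dict):
            flat.update(flatten(value, prefix + (key,)))
        else:
            flat[prefix + (key,)] = value
    return flat


nested = {'md': {'pair': {'LJ': {'energy': (-1.5, 'scalar')}}}}
flat = flatten(nested)
# {('md', 'pair', 'LJ', 'energy'): (-1.5, 'scalar')}
log_dict = {key: value[0] for key, value in flat.items()}
# {('md', 'pair', 'LJ', 'energy'): -1.5}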

    def _update_headers(self, new_keys):
        """Update headers and write the current headers to output.

        This function could be made simpler and faster by moving some of the
        transformation into `act`. However, since we don't expect the headers
        to change often, that would likely slow the writer down overall. The
        design instead off-loads any potentially unnecessary calculations to
        this function, even if that means more computation when headers change.
        """
        header_output_list = []
        header_dict = {}
        for namespace in new_keys:
            header = self._determine_header(namespace, self.header_sep,
                                            self.max_header_len)
            column_size = max(len(header), self.min_column_width)
            header_dict[namespace] = column_size
            header_output_list.append((header, column_size))
        self._cur_headers_with_width = header_dict
        self.output.write(
            self.delimiter.join((self._fmt.format_str(hdr, width)
                                 for hdr, width in header_output_list)))
        self.output.write('\n')
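The header row is just fixed-width columns joined by `delimiter`. `_Formatter` is a HOOMD internal whose exact justification rules are not shown here; the sketch below assumes simple right-justification as a stand-in for `_Formatter.format_str`:

headers = [('LJ.energy', 16), ('pressure', 16)]


def format_str(text, width):
    """Assumed stand-in for _Formatter.format_str (right-justified)."""
    return f'{text:>{width}}'


# Prints two right-justified 16-character columns separated by a space.
print(' '.join(format_str(hdr, width) for hdr, width in headers))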

    @staticmethod
    def _determine_header(namespace, sep, max_len):
        if max_len is None:
            return sep.join(namespace)
        else:
            # Take the longest suffix of the namespace whose parts total at
            # most max_len characters (separators are not counted).
            index = -1
            char_count = len(namespace[-1])
            for name in reversed(namespace[:-1]):
                char_count += len(name)
                if char_count > max_len:
                    break
                index -= 1
            return sep.join(namespace[index:])
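A quick worked example of the suffix truncation above (a standalone copy of the method; the namespace and length limit are made up for illustration):

def determine_header(namespace, sep, max_len):
    """Standalone copy of _determine_header for illustration."""
    if max_len is None:
        return sep.join(namespace)
    index = -1
    char_count = len(namespace[-1])
    for name in reversed(namespace[:-1]):
        char_count += len(name)
        if char_count > max_len:
            break
        index -= 1
    return sep.join(namespace[index:])


namespace = ('md', 'pair', 'LJ', 'energy')
print(determine_header(namespace, '.', None))  # md.pair.LJ.energy
# 'energy' (6) + 'LJ' (2) fits within 10 characters, but adding 'pair'
# would exceed it, so only the last two parts survive:
print(determine_header(namespace, '.', 10))  # LJ.energy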

    def _write_row(self, data):
        """Write a row of data to output."""
        headers = self._cur_headers_with_width
        self.output.write(
            self.delimiter.join(
                (self._fmt(data[k], headers[k]) for k in headers)))
        self.output.write('\n')

    def act(self, timestep=None):
        """Write row to designated output.

        Also writes the header when the logged quantities are determined to
        have changed.
        """
        output_dict = self._get_log_dict()
        if self._comm is not None and self._comm.rank == 0:
            # determine if a header needs to be written. This is always the case
            # for the first call of act, and if the logged quantities change
            # within a run.
            new_keys = output_dict.keys()
            if new_keys != self._cur_headers_with_width.keys():
                self._update_headers(new_keys)

            # Write the data and flush. We must flush to ensure that the data
            # isn't merely stored in Python ready to be written later.
            self._write_row(output_dict)
            self.output.flush()

    def __getstate__(self):
        state = copy.copy(self.__dict__)
        state.pop('_comm', None)
        # This is to handle when the output specified is just stdout. By default
        # file objects like this are not picklable, so we need to handle it
        # differently. We let `None` represent stdout in the state dictionary.
        # Most other file like objects will simply fail to be pickled here.
        if self.output == stdout:
            param_dict = ParameterDict()
            param_dict.update(state['_param_dict'])
            state['_param_dict'] = param_dict
            # Deleting the key also drops its type converter entry (restored
            # in __setstate__), which lets the None assignment pass validation.
            del state['_param_dict']['output']
            state['_param_dict']['output'] = None
            return state
        else:
            return super().__getstate__()

    def __setstate__(self, state):
        if state['_param_dict']['output'] is None:
            del state['_param_dict']['output']
            state['_param_dict']['output'] = stdout
            # Restore the type converter entry for output that __getstate__
            # removed along with the key.
            state['_param_dict']._type_converter['output'] = OnlyTypes(
                _OutputWriter, postprocess=_ensure_writable)
        self.__dict__ = state
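The `None`-stands-for-`stdout` convention above is independent of HOOMD; a minimal sketch of the same pattern on a toy class (all names here are illustrative):

import copy
import pickle
from sys import stdout


class Writer:
    def __init__(self, output=stdout):
        self.output = output

    def __getstate__(self):
        state = copy.copy(self.__dict__)
        # stdout itself cannot be pickled; encode it as None.
        if state['output'] is stdout:
            state['output'] = None
        return state

    def __setstate__(self, state):
        # Decode None back into the live stdout of the new process.
        if state['output'] is None:
            state['output'] = stdout
        self.__dict__ = state


writer = pickle.loads(pickle.dumps(Writer()))
assert writer.output is stdout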
Example #2
class _GSDLogWriter:
    """Helper class to store `hoomd.logging.Logger` log data to GSD file.

    Class Attributes:
        _per_categories (`hoomd.logging.LoggerCategories`): categories that
            contain all per-{particle,bond,...} quantities.
        _convert_categories (`hoomd.logging.LoggerCategories`): categories that
            contain all types that must be converted for storage in a GSD file.
        _skip_categories (`hoomd.logging.LoggerCategories`): categories that
            should be skipped and not stored.
        _special_keys (`list` of `str`): list of loggable quantity names that
            need to be treated specially. In general, this is only for
            `type_shapes`.
        _global_prepend (`str`): a str that gets prepended to the namespace
            of each logged quantity.
    """
    _per_categories = LoggerCategories.any([
        'angle', 'bond', 'constraint', 'dihedral', 'improper', 'pair',
        'particle'
    ])
    _convert_categories = LoggerCategories.any(['string', 'strings'])
    _skip_categories = LoggerCategories['object']
    _special_keys = ['type_shapes']
    _global_prepend = 'log'

    def __init__(self, logger):
        self.logger = logger

    def log(self):
        """Get the flattened dictionary for consumption by GSD object."""
        log = dict()
        for key, value in dict_flatten(self.logger.log()).items():
            if 'state' in key and _iterable_is_incomplete(value[0]):
                # Skip state quantities whose iterables are incomplete; they
                # cannot be stored for this frame.
                continue
            log_value, type_category = value
            type_category = LoggerCategories[type_category]
            # This has to be checked first since type_shapes has a category
            # LoggerCategories.object.
            if key[-1] in self._special_keys:
                self._log_special(log, key[-1], log_value)
            # Now we can skip any categories we don't process, in this case
            # LoggerCategories.object.
            if type_category not in self._skip_categories:
                if log_value is None:
                    continue
                else:
                    # This places logged quantities that are
                    # per-{particle,bond,...} into the correct GSD namespace
                    # log/particles/{remaining namespace}. This preserves OVITO
                    # integration.
                    if type_category in self._per_categories:
                        log['/'.join((self._global_prepend,
                                      type_category.name + 's') +
                                     key)] = log_value
                    elif type_category in self._convert_categories:
                        self._log_convert_value(
                            log, '/'.join((self._global_prepend, ) + key),
                            type_category, log_value)
                    else:
                        log['/'.join((self._global_prepend,) + key)] = \
                            log_value
        return log
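To make the branch behavior above concrete, here is a sketch of the GSD key strings each branch builds (namespaces and categories invented for illustration):

prepend = 'log'

# Per-particle branch: the pluralized category name is spliced in after the
# 'log' prefix, which is the layout OVITO expects for per-particle data.
key = ('md', 'pair', 'LJ', 'energies')
print('/'.join((prepend, 'particle' + 's') + key))
# log/particles/md/pair/LJ/energies

# Default branch (e.g. scalars): the namespace is joined directly under 'log'.
key = ('md', 'pair', 'LJ', 'energy')
print('/'.join((prepend,) + key))
# log/md/pair/LJ/energy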

    def _write_frame(self, _gsd):
        _gsd.writeLogQuantities(self.log())

    def _log_special(self, dict_, key, value):
        """Handles special keys such as type_shapes.

        Before adding a key here, make sure special-casing it is the only
        option. In general, special cases like this should be avoided if
        possible.
        """
        if key == 'type_shapes':
            shape_list = [
                bytes(json.dumps(type_shape) + '\0', 'UTF-8')
                for type_shape in value
            ]
            max_len = np.max([len(shape) for shape in shape_list])
            num_shapes = len(shape_list)
            str_array = np.array(shape_list)
            dict_['particles/type_shapes'] = \
                str_array.view(dtype=np.int8).reshape(num_shapes, max_len)
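The encoding above leans on NumPy padding shorter byte strings with NUL bytes up to the array's itemsize; a minimal round-trip sketch with made-up shape definitions:

import json

import numpy as np

shapes = [{'type': 'Sphere', 'diameter': 1.0}, {'type': 'Ellipsoid'}]
shape_list = [bytes(json.dumps(s) + '\0', 'UTF-8') for s in shapes]
max_len = max(len(s) for s in shape_list)

# np.array pads the shorter byte strings with NUL bytes to max_len, so the
# int8 view reshapes into a rectangular (num_shapes, max_len) matrix.
encoded = np.array(shape_list).view(dtype=np.int8).reshape(len(shapes), max_len)

# Decoding strips the padding and the explicit terminator again.
decoded = [json.loads(row.tobytes().rstrip(b'\0')) for row in encoded]
assert decoded == shapes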

    def _log_convert_value(self, dict_, key, category, value):
        """Convert loggable types that cannot be directly stored by GSD."""
        if category == LoggerCategories.string:
            value = bytes(value, 'UTF-8')
            value = np.array([value], dtype=np.dtype((bytes, len(value) + 1)))
            value = value.view(dtype=np.int8)
        elif category == LoggerCategories.strings:
            value = [bytes(v + '\0', 'UTF-8') for v in value]
            max_len = np.max([len(string) for string in value])
            num_strings = len(value)
            value = np.array(value)
            value = value.view(dtype=np.int8).reshape(num_strings, max_len)
        dict_[key] = value
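Reading the converted values back is a matter of reversing the int8 view; a sketch for the single-string case (the value is invented):

import numpy as np

# Encode a single string the way _log_convert_value does for
# LoggerCategories.string.
value = bytes('hello', 'UTF-8')
encoded = np.array([value],
                   dtype=np.dtype((bytes, len(value) + 1))).view(np.int8)

# The extra byte in the dtype guarantees at least one NUL terminator,
# which the decode below strips off again.
decoded = encoded.tobytes().rstrip(b'\0').decode('UTF-8')
assert decoded == 'hello'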