示例#1
0
class BasicTrainConfig(BaseConfig):
    """Basic training configuration

    Required parameters (these default to ``None`` but their validators are
    not wrapped in ``optional``, so leaving them unset is expected to fail
    validation):
    - train_dataloader (Iterable with length): training data loader
    - model (`torch.nn.Module`): model to train
    - optimizer (`torch.optim.Optimizer`): optimizer to use for training
    - criterion (`torch.nn.Module`): loss function to use for training
    - num_epochs (int): number of epochs (positive)

    Optional parameters:
    - metrics (dict): dictionary with ignite metrics, e.g `{'precision': Precision()}`.
      Defaults to a new empty dict for every instance.
    - log_interval (int): positive int, default 100
    - trainer_checkpoint_interval (int): positive int, default 1000
    - model_checkpoint_kwargs (dict): default None
    - lr_scheduler (`_LRScheduler`): default None
    - reduce_lr_on_plateau (`ReduceLROnPlateau`): default None
    - reduce_lr_on_plateau_var (str): default 'loss'
    - val_dataloader (Iterable with length): default None
    - val_metrics (dict of str -> ignite `Metric`): default None
    - val_interval_epochs (int): positive int, default 1
    - train_eval_dataloader (Iterable with length): default None
    - early_stopping_kwargs (dict): default None

    """
    train_dataloader = attr.ib(validator=is_iterable_with_length, default=None)

    model = attr.ib(validator=instance_of(nn.Module), default=None)
    optimizer = attr.ib(validator=instance_of(Optimizer), default=None)
    criterion = attr.ib(validator=instance_of(nn.Module), default=None)
    num_epochs = attr.ib(validator=and_(instance_of(int), is_positive),
                         default=None)

    # A Factory gives every instance its own dict; a literal `{}` default
    # would be one dict object shared (and mutated) across all configs.
    metrics = attr.ib(default=attr.Factory(dict),
                      validator=optional(is_dict_of_key_value_type(
                          str, Metric)))
    log_interval = attr.ib(default=100,
                           validator=optional(
                               and_(instance_of(int), is_positive)))

    trainer_checkpoint_interval = attr.ib(default=1000,
                                          validator=optional(
                                              and_(instance_of(int),
                                                   is_positive)))
    model_checkpoint_kwargs = attr.ib(default=None,
                                      validator=optional(instance_of(dict)))

    lr_scheduler = attr.ib(default=None,
                           validator=optional(instance_of(_LRScheduler)))
    reduce_lr_on_plateau = attr.ib(default=None,
                                   validator=optional(
                                       instance_of(ReduceLROnPlateau)))
    reduce_lr_on_plateau_var = attr.ib(default='loss',
                                       validator=optional(instance_of(str)))

    val_dataloader = attr.ib(default=None,
                             validator=optional(is_iterable_with_length))
    val_metrics = attr.ib(default=None,
                          validator=optional(
                              is_dict_of_key_value_type(str, Metric)))
    val_interval_epochs = attr.ib(default=1,
                                  validator=optional(
                                      and_(instance_of(int), is_positive)))

    train_eval_dataloader = attr.ib(
        default=None, validator=optional(is_iterable_with_length))

    early_stopping_kwargs = attr.ib(default=None,
                                    validator=optional(instance_of(dict)))
示例#2
0
class SolverConfig(_BaseConfig):
    """
    Optimisation settings.

    - optimizer (`torch.optim.Optimizer`): defaults to ``_dummy_optim``.
      NOTE(review): that default is a single module-level object, so every
      instance that does not pass its own optimizer shares it.
    - criterion (`torch.nn.Module`): loss module; defaults to a fresh, empty
      ``nn.Module()`` created per instance.
    - num_epochs (int or None): positive number of epochs, default None.
    - num_iterations (int or None): positive number of iterations, default None.
    """

    optimizer = attr.ib(default=_dummy_optim, validator=instance_of(Optimizer))

    # Factory: build a new module per instance instead of sharing one
    # nn.Module object created once at class-definition time.
    criterion = attr.ib(default=attr.Factory(nn.Module),
                        validator=instance_of(nn.Module))

    num_epochs = attr.ib(default=None,
                         validator=optional(and_(instance_of(int),
                                                 is_positive)))
    num_iterations = attr.ib(default=None,
                             validator=optional(
                                 and_(instance_of(int), is_positive)))
示例#3
0
class InventoryModel:
    # Good name -> float value. Keys must be str and pass non_empty_str;
    # values must be float and pass positive_int.
    # NOTE(review): `positive_int` is applied to float values — confirm it
    # accepts floats.
    goods: Dict[str, float] = attrib(validator=validators.deep_mapping(
        key_validator=validators.and_(validators.instance_of(str),
                                      non_empty_str),
        value_validator=validators.and_(
            validators.instance_of(float),
            positive_int  # type: ignore
        ),
        mapping_validator=validators.instance_of(dict),
    ))
    # Good name -> DiscountModel, passed through defaultdict_of_classes with a
    # NO_DISCOUNT factory (presumably so unlisted goods get NO_DISCOUNT).
    discounts: Dict[str, DiscountModel] = attrib(
        validator=validators.deep_mapping(
            key_validator=validators.instance_of(str),
            value_validator=validators.instance_of(DiscountModel),
            mapping_validator=validators.instance_of(dict),
        ),
        converter=defaultdict_of_classes(  # type: ignore
            DiscountModel, lambda: NO_DISCOUNT),
    )
    # Good name -> MultibuyModel, same conversion scheme with NO_MULTIBUY.
    multibuy: Dict[str, MultibuyModel] = attrib(
        validator=validators.deep_mapping(
            key_validator=validators.instance_of(str),
            value_validator=validators.instance_of(MultibuyModel),
            mapping_validator=validators.instance_of(dict),
        ),
        converter=defaultdict_of_classes(  # type: ignore
            MultibuyModel, lambda: NO_MULTIBUY),
    )

    def __attrs_post_init__(self):
        """Reject offers that mention goods missing from ``goods``.

        Raises:
            ValueError: if a key of ``discounts`` or ``multibuy`` is not a
                known good, or if a multibuy deal discounts an unknown good.
        """
        known_goods = set(self.goods.keys())

        # The offer tables may only be keyed by known goods.
        for field_name in ("discounts", "multibuy"):
            offer_table = getattr(self, field_name)
            strays = set(offer_table.keys()).difference(known_goods)
            if strays:
                raise ValueError(
                    f"Unknown goods as key in '{field_name}': {strays}")

        # Every multibuy deal must also point at a known good.
        # NOTE(review): the attribute read is `discounts_goods` while the
        # error message names `discounted_goods`; one of them is probably a
        # typo — confirm against MultibuyModel.
        targeted = {deal.discounts_goods for deal in self.multibuy.values()}
        strays = targeted.difference(known_goods)
        if strays:
            raise ValueError(
                f"Unknown goods as value for 'multibuy.*.discounted_goods': {strays}"
            )
示例#4
0
def sequence_of(validator, allow_empty=False):
    """Build an attrs validator for a sequence whose items all pass *validator*.

    :param validator: item validator, or a list of validators that are
        combined with ``and_``.
    :param allow_empty: when False, an empty (falsy) value raises ValueError.
    :return: a ``validate(inst, attr, value)`` callable. It first runs
        ``is_seq`` on the value, then the item validator on every element.
        A failing element's error is re-raised with the attribute name
        rewritten to ``name[index]``.
    """
    item_validator = and_(*validator) if isinstance(validator, list) else validator

    def validate(inst, attr, value):
        is_seq(inst, attr, value)

        # allow_empty is tested first so `value` is only truth-tested when needed.
        if not allow_empty and not value:
            raise ValueError("'{name}' cannot be empty".format(name=attr.name,
                                                               value=value))

        for index, element in enumerate(value):
            try:
                item_validator(inst, attr, element)
            except (ValueError, TypeError):
                # Only on failure: copy the attribute, rename it to
                # "name[index]", and validate again so the raised error names
                # the offending element.
                indexed_attr = copy.copy(attr)
                object.__setattr__(indexed_attr, 'name',
                                   '{}[{}]'.format(attr.name, index))
                try:
                    item_validator(inst, indexed_attr, element)
                except Exception as retry_error:
                    six.raise_from(retry_error, None)
                # The retry unexpectedly passed; re-raise the original error.
                raise

    return validate
示例#5
0
    def test_success(self):
        """
        A value accepted by every wrapped validator passes without raising.
        """
        combined = and_(instance_of(int), always_pass)
        attribute = simple_attr("test")

        combined(None, attribute, 42)
示例#6
0
    def test_success(self):
        """
        A value accepted by every wrapped validator passes without raising.
        """
        combined = and_(instance_of(int), always_pass)
        attribute = simple_attr("test")

        combined(None, attribute, 42)
示例#7
0
    def test_fail(self):
        """
        An exception raised by any wrapped validator (here the
        ZeroDivisionError expected from always_fail) propagates out of and_.
        """
        combined = and_(instance_of(int), always_fail)

        with pytest.raises(ZeroDivisionError):
            combined(None, simple_attr("test"), 42)
示例#8
0
    def test_fail(self):
        """
        An exception raised by any wrapped validator (here the
        ZeroDivisionError expected from always_fail) propagates out of and_.
        """
        combined = and_(instance_of(int), always_fail)

        with pytest.raises(ZeroDivisionError):
            combined(None, simple_attr("test"), 42)
示例#9
0
def implicit_or(validator):
    """Wrap *validator* so that the ``IMPLICIT_ZERO`` sentinel is always accepted.

    :param validator: a validator, or a list of validators combined with ``and_``.
    :return: a ``validate(inst, attr, value)`` callable that returns
        immediately when ``value == IMPLICIT_ZERO`` and otherwise delegates
        to the wrapped validator.
    """
    wrapped = and_(*validator) if isinstance(validator, list) else validator

    def validate(inst, attr, value):
        # The sentinel bypasses the wrapped check entirely.
        if value == IMPLICIT_ZERO:
            return
        wrapped(inst, attr, value)

    return validate
示例#10
0
class MemoryAmount:
    """
    A quantity of memory: an integer *amount* together with the
    `MemoryUnit` *unit* it is expressed in.
    """

    # An int that must be at least 1.
    amount: int = attrib(validator=and_(instance_of(int), in_(Range.at_least(1))))
    # NOTE(review): no validator is applied to the unit.
    unit: MemoryUnit = attrib(validator=None)

    # Digits, an optional single space, a unit letter (T/G/M/K in either
    # case) and an optional trailing "b"/"B", e.g. "512M", "4 GB", "16gb".
    # Used with `match`, so only the start of the string has to fit.
    _PARSE_PATTERN = re.compile(r"(\d+) ?([TtGgMmKk])[bB]?")

    @staticmethod
    def parse(memory_string: str) -> "MemoryAmount":
        """
        Build a `MemoryAmount` from text such as ``"512M"`` or ``"4 GB"``.

        The unit letter is handed to `MemoryUnit.parse`.

        Raises:
            RuntimeError: if the start of *memory_string* does not look like
                an integer followed by a unit letter.
        """
        found = MemoryAmount._PARSE_PATTERN.match(memory_string)
        if found is None:
            raise RuntimeError(
                f"Cannot parse {memory_string} as an amount of memory.  "
                f"Expected an integer followed by K, M, G, or T"
            )
        digits, unit_letter = found.groups()
        return MemoryAmount(amount=int(digits), unit=MemoryUnit.parse(unit_letter))
示例#11
0
 class C(object):
     # Two spellings of a combined validator: an explicit `and_(...)` call
     # (with a trailing comma) and a list of validators, which attrs combines
     # the same way as `and_`.
     # NOTE(review): the first positional argument of `attr.ib` is its
     # default, so "a1"/"a2" are str defaults that instance_of(int) would
     # reject if ever used — presumably this only illustrates the syntax.
     a1 = attr.ib("a1", validator=and_(instance_of(int), ))
     a2 = attr.ib("a2", validator=[
         instance_of(int),
     ])
示例#12
0
 class C:
     # Same validator written two ways: `and_(...)` and a one-element list,
     # which attrs combines like `and_`.
     # NOTE(review): "a1"/"a2" are positional `attr.ib` defaults (str), which
     # instance_of(int) would reject if used — presumably illustrative only.
     a1 = attr.ib("a1", validator=and_(instance_of(int)))
     a2 = attr.ib("a2", validator=[instance_of(int)])
示例#13
0
def sequence_of(types):
    """Return a validator for a sequence whose items are all instances of *types*.

    The result is ``and_(sequence, <item check>)``.

    :param types: a type or tuple of types, as accepted by ``isinstance``.
    :raises AssertionError: from the item check when any element is not an
        instance of *types*; the message is the attribute's name.
    """

    def validator(_, attrib, value):
        # Explicit raise instead of `assert`: `python -O` strips assert
        # statements, which silently disabled this check. AssertionError is
        # kept so callers that catch it keep working.
        if not all(isinstance(x, types) for x in value):
            raise AssertionError(attrib.name)

    return and_(sequence, validator)
示例#14
0
class ConfigDocFormat():
    """
    Class to store formats for printing out configuration setting information.
    Tries to handle a lot of heavy lifting and provide a nicer interface for
    common case definitions and a convenient interface for template authors.

    Attributes:

       doc (str): the base docstring to be used by the template.

       template (str or jinja2.Template): Template that is used to render
          the final docstring for a full config group. A `str` template is
          dedented and compiled into a `jinja2.Template` on construction.

          ??? Info "Available Template Variables"
              These will be variables the template has access to:

              - `doc`: documentation of the config group being rendered
              - `val_table`: dict w/ various info on each var.
                  - `exists`: Does the table exist? Other entries will not be
                       there if this is false.
                  - `name`: The name of the table
                  - `headers`: column headings (in definition order)
                  - `entries`: list of dicts for each table entry, where the
                       keys are the column headings and the values are the
                       formatted strings.

                       **Note:** Any variables available to format specifiers, like
                       `is_group` or `is_val`, will also be made available
                       to the template under each entry.

              - `group_table`: dict w/ various info on groups
                  - `exists`: Does the table exist
                  - `name`: The name of the table
                  - `headers`: column headings (in definition order)
                  - `entries`: list of dicts for each table entry, where the
                       keys are the column headings and the values are the
                       formatted strings.

                       **Note:** Any variables available to format specifiers, like
                       `is_group` or `is_val`, will also be made available
                       to the template under each entry.

              - `combined_table`: A table of combined vars and groups
                  - `exists`: Does the table exist
                  - `name`: The name of the table
                  - `headers`: column headings (in definition order)
                  - `entries`: list of dicts for each table entry, where the
                       keys are the column headings and the values are the
                       formatted strings.

                       **Note:** Any variables available to format specifiers, like
                       `is_group` or `is_val`, will also be made available
                       to the template under each entry.

                  - `group_col_widths`: map from keys to colwidths for a group.
                       Should be used to set the `colspan` parameter in table
                       cells if needed.

                       The width of a column will be `0` if it should be
                       omitted.

                  - `val_col_widths`: map from keys to colwidths for a value
                       Should be used to set the `colspan` parameter in table
                       cells if needed.

                       The width of a column will be `0` if it should be
                       omitted.

              Entries of `other_vars` are also passed as top-level variables
              (the names above take precedence over them).

       val_table_name (str): The title for the value table.

       val_table_format (dict): Dictionary used to set the columns of the val
          table. Each key should be a column heading, with the value being
          a format string used to format the entry.

          Use `None` if you don't want the val table to appear at all. The
          table will also be omitted if there are no entries for it.

          ??? Info "Available Format Specifiers"
              These are the keys that will be available to the format strings.

               - `is_val`: is this a variable?
               - `is_group`: is this a group?
               - `defined_in`: the module where the term was defined / last updated
               - `val`: the value the term was set to
               - `val_repr`: the string representation of a value.
               - `doc`: the provided docstring for the term
               - `name`: the name of the term, only the final component
               - `path`: the full pathname of the value

       val_table_headers (list[str]): Column order of the val table. Defaults
          to the keys of `val_table_format` in order.

       group_table_headers (list[str]): Column order of the group table.
          Defaults to the keys of `group_table_format` in order.

       group_table_name (str): The title for the subgroup table.

       group_table_format (dict): Dictionary used to set the columns of the
          group table. Each key should be a column heading, with the value
          being a format string used to format the entry.

          Use `None` if you don't want the group table to appear at all. The
          table will also be omitted if there are no entries for it.

          ??? Info "Available Format Specifiers"
              These are the keys that will be available to the format strings.

               - `is_val`: is this a variable?
               - `is_group`: is this a group?
               - `defined_in`: the module where the term was defined / last updated
               - `doc`: the provided docstring for the term
               - `name`: the name of the term, only the final component
               - `path`: the full pathname of the value

       combine_tables (bool): Do we combine the value and group tables into a
           single joint table?

       combined_table_name (str): The title for the combined table.

       combined_table_headers (list[str]): An ordered list of headers to use
           for a combined entry table. These should match keys in the
           `*_table_format` parameters. Defaults to an ordering consistent
           with both `val_table_headers` and `group_table_headers`.

       recurse_entries (bool): If `True` we recurse to all entries and subgroups
           otherwise we only print the direct children of the config group.

       other_vars (dict): other variables to be passed to the template directly.
    """

    template = attr.ib()

    # NOTE(review): not read by any method of this class; the template's
    # `doc` variable comes from the rendered config group instead.
    doc = attr.ib(default="")

    @template.validator
    def __template_validator(self, attr, val):
        # Validates *and* normalises: a str template is dedented, compiled
        # and stored back on the instance.
        if isinstance(val, str):
            self.template = jinja2.Template(textwrap.dedent(val))
            return None
        elif isinstance(val, jinja2.Template):
            self.template = val
        else:
            raise TypeError(
                ("ConfigDocFormat expects a template of type `str` or of " +
                 "type `jinja2.Template`. Instead got a template of type {}"
                 ).format(type(val)))

    val_table_name = attr.ib(
        default="",
        validator=valid.instance_of(str),
    )

    def __normalize_format_dict(format_dict):
        # Converter for the *_table_format attributes: dedent each format
        # string. attrs also runs converters on default values, so the `None`
        # ("no table") default must pass through untouched.
        if format_dict is None:
            return None
        return {
            key: textwrap.dedent(val)
            for (key, val) in format_dict.items()
        }

    val_table_format = attr.ib(
        default=None,
        converter=__normalize_format_dict,
        validator=valid.optional(
            valid.deep_mapping(
                key_validator=valid.instance_of(str),
                value_validator=valid.instance_of(str),
            )),
    )

    val_table_headers = attr.ib(validator=valid.optional(
        valid.and_(
            valid.instance_of(list),
            valid.deep_iterable(valid.instance_of(str)),
        )), )

    @val_table_headers.default
    def val_table_headers_default(self):
        hide_doc = True  # Docs must be the first statement to show up
        """
        List of headers for the value table.
        """
        if self.val_table_format is not None:

            keys = list(self.val_table_format)

            log.debug(
                textwrap.dedent("""
                Attempting to get ordered headers for val table:

                  Dict:
                  %s

                  List:
                  %s
                """),
                pformat(self.val_table_format),
                pformat(keys),
            )

            return keys
        else:
            return []

    # Defaults to None and is filled in by __attrs_post_init__: a decorator
    # default would run before `group_table_format` (defined further down)
    # has been assigned, and fail with AttributeError.
    group_table_headers = attr.ib(
        default=None,
        validator=valid.optional(
            valid.and_(
                valid.instance_of(list),
                valid.deep_iterable(valid.instance_of(str)),
            )),
    )

    def group_table_headers_default(self):
        hide_doc = True  # Docs must be the first statement to show up
        """
        List of headers for the group table
        """
        if self.group_table_format is not None:
            return list(self.group_table_format)
        else:
            return []

    group_table_name = attr.ib(
        default="",
        validator=valid.instance_of(str),
    )

    group_table_format = attr.ib(
        default=None,
        converter=__normalize_format_dict,
        validator=valid.optional(
            valid.deep_mapping(
                key_validator=valid.instance_of(str),
                value_validator=valid.instance_of(str),
            )),
    )

    combine_tables = attr.ib(
        default=False,
        validator=valid.instance_of(bool),
    )

    combined_table_name = attr.ib(
        default="",
        validator=valid.instance_of(str),
    )

    # Filled in by __attrs_post_init__ because it depends on
    # group_table_headers, which is itself computed there.
    combined_table_headers = attr.ib(
        default=None,
        validator=valid.optional(
            valid.and_(
                valid.instance_of(list),
                valid.deep_iterable(valid.instance_of(str)),
            )),
    )

    def combined_table_headers_default(self):
        hide_doc = True  # Docs must be the first statement to show up
        """
        Merge the value and group header orders into a single order that is
        consistent with both (a topological sort of "comes before" edges).
        Raises graphlib.CycleError if the two orders contradict each other.
        """
        graph = graphlib.TopologicalSorter()
        for headers in (self.val_table_headers or [],
                        self.group_table_headers or []):
            # Register every header so ones with no neighbour (e.g. a
            # single-column table) still appear in the result.
            for header in headers:
                graph.add(header)
            # Each header must come after the one immediately before it.
            for (start, end) in zip(headers, headers[1:]):
                graph.add(end, start)
        return list(graph.static_order())

    @staticmethod
    def get_colwidths(combined_cols, elem_cols):
        hide_doc = True  # Docs must be the first statement to show up
        """
        Compute colspan widths for one kind of entry in a combined table.

        A column in `elem_cols` is 1 wide, widened to also cover the run of
        immediately following columns that are not in `elem_cols`. Columns
        not in `elem_cols` get width 0 (they should be omitted).
        """
        col_widths = dict()

        # get whether each combined cell is in the group cols
        is_member = [col in elem_cols for col in combined_cols]

        log.debug(
            textwrap.dedent("""
            Generating Colspan widths:

               Combined Cols:
               %s

               Elem Type Cols:
               %s

               Membership List:
               %s
            """),
            pformat(combined_cols),
            pformat(elem_cols),
            pformat(is_member),
        )

        for (ind, col) in enumerate(combined_cols):
            if is_member[ind]:
                # Membership flags of the following columns, split into runs
                # of equal value, e.g. [False, False, True] ->
                # [[False, False], [True]].
                next_elems = is_member[ind + 1:]
                next_cols = [
                    list(run) for _, run in itertools.groupby(next_elems)
                ]

                # every valid entry is 1 col wide
                width = 1

                # If we're followed by a bunch of entries not in our set
                # then we can grow wide enough to fill those too.
                if next_cols and not next_cols[0][0]:
                    width += len(next_cols[0])

                log.debug(
                    textwrap.dedent("""
                    For col '%s' at index '%s':

                       Next Members:
                       %s

                       Next Groups:
                       %s

                       Final Width:
                       %s
                    """), col, ind, next_elems, next_cols, width)

                col_widths[col] = width
            else:
                # invalid entries should be dropped
                col_widths[col] = 0

        return col_widths

    recurse_entries = attr.ib(default=True, validator=valid.instance_of(bool))

    other_vars = attr.ib(
        factory=dict,
        validator=valid.and_(
            valid.instance_of(dict),
            valid.deep_mapping(
                key_validator=valid.instance_of(str),
                value_validator=(lambda _a, _b, _c: True),
            ),
        ),
    )

    def __attrs_post_init__(self):
        hide_doc = True  # Docs must be the first statement to show up
        # Compute header defaults that need attributes assigned after their
        # own definition; only fills in values the caller left as None.
        if self.group_table_headers is None:
            self.group_table_headers = self.group_table_headers_default()
        if self.combined_table_headers is None:
            self.combined_table_headers = (
                self.combined_table_headers_default())

    @property
    def combined_group_col_widths(self):
        hide_doc = True  # Docs must be the first statement to show up
        """
        The column widths for a group entry in a combined table
        """

        widths = ConfigDocFormat.get_colwidths(self.combined_table_headers,
                                               self.group_table_headers)

        log.debug(
            textwrap.dedent("""
            Genrating Group Col Widths:

               Combined Headers:
               %s

               Group Headers:
               %s

               Result Widths:
               %s
            """),
            pformat(self.combined_table_headers),
            pformat(self.group_table_headers),
            pformat(widths),
        )

        return widths

    @property
    def combined_val_col_widths(self):
        hide_doc = True  # Docs must be the first statement to show up
        """
        The column widths for a value entry in a combined table
        """

        widths = ConfigDocFormat.get_colwidths(self.combined_table_headers,
                                               self.val_table_headers)

        log.debug(
            textwrap.dedent("""
            Genrating Value Col Widths:

               Combined Headers:
               %s

               Value Headers:
               %s

               Result Widths:
               %s
            """),
            pformat(self.combined_table_headers),
            pformat(self.val_table_headers),
            pformat(widths),
        )

        return widths

    @classmethod
    def get_val_data(cls, config_val):
        hide_doc = True  # Docs must be the first statement to show up
        """
        Extract and format data for a ConfigValue
        """
        return {
            config_val.path_string(): {
                'is_val': True,
                'is_group': False,
                'val': config_val.value,
                'val_repr': repr(config_val.value),
                'doc': config_val.doc,
                'defined_in': config_val.ctxt.__qualname__,
                'name': config_val.path[-1],
                'path': config_val.path_string(),
            }
        }

    @classmethod
    def get_group_data(cls, config_group):
        hide_doc = True  # Docs must be the first statement to show up
        """
        Extract and format data for a ConfigGroup
        """
        return {
            config_group.path_string(): {
                'is_val': False,
                'is_group': True,
                'doc': config_group.doc,
                'name': config_group.path[-1],
                'path': config_group.path_string(),
            }
        }

    @classmethod
    def collect_group_data(cls, config_group, recurse=False):
        hide_doc = True  # Docs must be the first statement to show up
        """
        Run through the members of a ConfigGroup and gather all the template
        data in a dictionary, where the keys are the path to each term.
        This will recurse if necessary.
        """
        data = dict()
        for member in config_group.members.values():
            if isinstance(member, ConfigValue):
                data.update(cls.get_val_data(member))
            else:
                data.update(cls.get_group_data(member))
                if recurse:
                    data.update(cls.collect_group_data(member, recurse))
        return data

    @classmethod
    def format_data(cls, format_dict, data):
        hide_doc = True  # Docs must be the first statement to show up
        """
        Format all the values in a format dictionary given a particular
        dataset. Raw data keys not overridden by a format are copied through.
        """

        output = dict()
        for (key, format_string) in format_dict.items():
            output[key] = format_string.format_map(data)
        for (key, value) in data.items():
            if key not in output:
                output[key] = value

        log.debug(
            textwrap.dedent("""
            Formatting Data:

              Format String:
              %s

              Data:
              %s

              Result:
              %s
            """),
            pformat(format_dict, depth=6, indent=2),
            pformat(data, depth=6, indent=2),
            pformat(output, depth=6, indent=2),
        )

        return output

    def format_val_data(self, val_data):
        hide_doc = True  # Docs must be the first statement to show up
        """
        Format a value table entry if we have a formatter for it.
        """
        if self.val_table_format is not None:
            return self.format_data(self.val_table_format, val_data)
        else:
            return dict()

    def format_group_data(self, group_data):
        hide_doc = True  # Docs must be the first statement to show up
        """
        Format a group table entry if we have a formatter for it.
        """
        if self.group_table_format is not None:
            return self.format_data(self.group_table_format, group_data)
        else:
            return dict()

    def format_member_data(self, data):
        hide_doc = True  # Docs must be the first statement to show up
        # Dispatch to the value or group formatter based on `is_val`; the
        # result is an empty dict when that kind has no format.

        log.debug("Formatting Member Data:\n\n%s", pformat(data))

        output = None
        if data['is_val']:
            output = self.format_val_data(data)
        else:
            output = self.format_group_data(data)

        log.debug("Formatted Member Data:\n\n%s", pformat(output))

        return output

    def process_config_group(self, config):
        hide_doc = True  # Docs must be the first statement to show up
        """
        Run through a configuration group to produce all the variables that
        will be fed to the template.
        """

        # Process data from config group
        rows = self.collect_group_data(config, self.recurse_entries)
        formatted = dict()
        for (path, data) in rows.items():
            formatted[path] = self.format_member_data(data)
            log.debug("Formatted `%s` data to get:\n\n%s", path,
                      pformat(formatted[path]))

        # Classify on the raw rows: a formatted entry is an empty dict (no
        # 'is_val' key) when its table's format is None.
        vals = {
            path: formatted[path]
            for (path, data) in rows.items() if data['is_val']
        }
        groups = {
            path: formatted[path]
            for (path, data) in rows.items() if data['is_group']
        }

        log.debug(
            textwrap.dedent("""
            Control Group Raw Contents:

              Rows:
              %s

              Vals:
              %s

              Groups :
              %s
            """),
            pformat(rows),
            pformat(vals),
            pformat(groups),
        )

        # Figure out which tables should exist in our output.
        val_table_can_exist = ((len(vals) > 0)
                               and (self.val_table_format is not None))
        group_table_can_exist = ((len(groups) > 0)
                                 and (self.group_table_format is not None))
        combined_table_exists = ((val_table_can_exist
                                  and group_table_can_exist)
                                 and self.combine_tables)
        val_table_exists = (not combined_table_exists) and val_table_can_exist
        group_table_exists = (
            not combined_table_exists) and group_table_can_exist

        log.debug(
            textwrap.dedent("""
            Table Existence Data:

              Val Table Can Exist: %s
              Grp Table Can Exist: %s

              Val Table Exists: %s
              Grp Table Exists: %s
              Cmb Table Exists: %s
            """),
            val_table_can_exist,
            group_table_can_exist,
            val_table_exists,
            group_table_exists,
            combined_table_exists,
        )

        # generate top level output data; the fixed keys set below take
        # precedence over same-named entries from other_vars
        output = dict()
        output.update(self.other_vars)
        output['doc'] = config.doc
        output['val_table'] = dict()
        output['group_table'] = dict()
        output['combined_table'] = dict()

        # generate val table data
        output['val_table']['exists'] = val_table_exists
        if val_table_exists:
            output['val_table']['name'] = self.val_table_name
            output['val_table']['headers'] = self.val_table_headers
            output['val_table']['entries'] = [
                vals[path] for path in sorted(vals.keys())
            ]

        # generate group table data
        output['group_table']['exists'] = group_table_exists
        if group_table_exists:
            output['group_table']['name'] = self.group_table_name
            output['group_table']['headers'] = self.group_table_headers
            output['group_table']['entries'] = [
                groups[path] for path in sorted(groups.keys())
            ]

        # generate combined table data (only when both formats exist, so
        # every formatted entry is non-empty)
        output['combined_table']['exists'] = combined_table_exists
        if combined_table_exists:
            output['combined_table']['name'] = self.combined_table_name
            output['combined_table']['headers'] = self.combined_table_headers
            output['combined_table']['entries'] = [
                formatted[path] for path in sorted(formatted.keys())
            ]
            output['combined_table']['group_col_widths'] = (
                self.combined_group_col_widths)
            output['combined_table']['val_col_widths'] = (
                self.combined_val_col_widths)

        log.debug("Processed Doc Data:\n\n%s", pformat(output))

        return output

    @staticmethod
    def render_docs(docs_format, config_group, empty_doc=""):
        """
        Render the documentation for a config group using the provided
        formatter. Returns `empty_doc` when none of the tables would exist.
        """
        processed_data = docs_format.process_config_group(config_group)

        if (processed_data['val_table']['exists']
                or processed_data['group_table']['exists']
                or processed_data['combined_table']['exists']):

            return docs_format.template.render(processed_data)
        else:
            return empty_doc