Beispiel #1
0
def create_toy_graph():
  """Builds a toy synset tree and the sampling graph spanning two of its leaves.

  Eight Synsets named 'a' through 'h' are created and linked into a tree, then
  the sampling graph is built for the Synsets whose `words` is 'f' or 'g'.

  Returns:
    A tuple (graph_nodes, spanning_leaves, synsets_subset) where graph_nodes is
    the result of `imagenet_spec.create_sampling_graph` on the selected
    Synsets, spanning_leaves is the result of
    `imagenet_spec.get_spanning_leaves` on graph_nodes, and synsets_subset is
    the list of selected Synsets.
  """
  node_names = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
  # Each Synset starts with two empty sets, which are filled in below through
  # its `children` and `parents` attributes.
  synsets = {
      name: imagenet_spec.Synset(index, name, set(), set())
      for index, name in enumerate(node_names)
  }

  # Edges of the toy ontology, written as (parent, child) pairs. Together they
  # form the following tree:
  #        a
  #    b       c
  #  g       d   e
  #             h f
  edges = (
      ('a', 'b'),
      ('a', 'c'),
      ('b', 'g'),
      ('c', 'd'),
      ('c', 'e'),
      ('e', 'f'),
      ('e', 'h'),
  )
  for parent_name, child_name in edges:
    parent_node = synsets[parent_name]
    child_node = synsets[child_name]
    # Record the relation in both directions.
    parent_node.children.add(child_node)
    child_node.parents.add(parent_node)

  wanted_words = ['f', 'g']
  synsets_subset = [
      node for node in synsets.values() if node.words in wanted_words
  ]

  # Build the graph made of all and only the ancestors of synsets_subset.
  # Nodes with exactly 1 child are expected to be collapsed, so the resulting
  # graph should be:
  #    a
  #  g   f
  graph_nodes = imagenet_spec.create_sampling_graph(synsets_subset)

  # Map every node of graph_nodes to the leaves reachable from it.
  spanning_leaves = imagenet_spec.get_spanning_leaves(graph_nodes)
  return graph_nodes, spanning_leaves, synsets_subset
Beispiel #2
0
  def test_imagenet_specification(self):
    """Validates the ImageNet specification and its valid/test sub-graphs.

    Creates the specification, derives the spanning leaves and the number of
    images they span, runs the graph / spanning-leaves / splits / image-count
    validators, and then runs the lowest-common-ancestor and upward-path checks
    on the full graph and on the validation and test sub-graphs.
    """
    specification = imagenet_spec.create_imagenet_specification(
        learning_spec.Split)
    (splits, _, graph_nodes, synsets_2012, num_synset_2012_images,
     roots) = specification
    spanning_leaves = imagenet_spec.get_spanning_leaves(graph_nodes)
    num_spanning_images = imagenet_spec.get_num_spanning_images(
        spanning_leaves, num_synset_2012_images)

    # Structural validation of the full specification.
    validate_graph(graph_nodes, synsets_2012, self)
    validate_spanning_leaves(spanning_leaves, synsets_2012, self)
    self.validate_splits(splits)
    self.validate_num_span_images(spanning_leaves, num_spanning_images)

    # Path-related checks on the whole graph (no explicit root passed).
    test_lowest_common_ancestor(graph_nodes, self)
    test_get_upward_paths(graph_nodes, self)

    # Make sure that in no sub-tree can the LCA of two chosen leaves of that
    # sub-tree be a node that is an ancestor of the sub-tree's root.
    subgraphs = (splits[learning_spec.Split.VALID],
                 splits[learning_spec.Split.TEST])
    subgraph_roots = (roots['valid'], roots['test'])
    for subgraph, subgraph_root in zip(subgraphs, subgraph_roots):
      test_lowest_common_ancestor(subgraph, self, subgraph_root)
      test_get_upward_paths(subgraph, self, subgraph_root)
Beispiel #3
0
  def __init__(self,
               dataset_spec,
               split,
               episode_descr_config,
               pool=None,
               use_dag_hierarchy=False,
               use_bilevel_hierarchy=False,
               use_all_classes=False):
    """Initializes an EpisodeDescriptionSampler.

    Classes of `split` that have fewer than
    `episode_descr_config.min_examples_in_class` examples (as reported by
    `dataset_spec.get_total_images_per_class`) are logged and left out of the
    filtered class set used for sampling.

    Args:
      dataset_spec: DatasetSpecification, dataset specification.
      split: one of Split.TRAIN, Split.VALID, or Split.TEST.
      episode_descr_config: An instance of EpisodeDescriptionConfig containing
        parameters relating to sampling shots and ways for episodes.
      pool: A string ('train' or 'test') or None, indicating which example-level
        split to select, if the current dataset has them.
      use_dag_hierarchy: Boolean, defaults to False. If a DAG-structured
        ontology is defined in dataset_spec, use it to choose related classes.
      use_bilevel_hierarchy: Boolean, defaults to False. If a bi-level ontology
        is defined in dataset_spec, use it for sampling classes.
      use_all_classes: Boolean, defaults to False. Uses all available classes,
        in order, instead of sampling. Overrides `num_ways` to the number of
        classes in `split`.

    Raises:
      ValueError: if the parameters are inconsistent: fewer than `min_ways`
        classes remain after filtering, `use_all_classes` is requested while
        some classes are filtered out, a hierarchy is combined with `num_ways`
        (or, for the bi-level one, with a positive `min_examples_in_class`),
        `dataset_spec` is not of the type required by the requested hierarchy,
        a superclass has fewer than `min_ways` classes, or no internal node of
        the DAG is eligible.
    """
    self.dataset_spec = dataset_spec
    self.split = split
    self.pool = pool
    self.use_dag_hierarchy = use_dag_hierarchy
    self.use_bilevel_hierarchy = use_bilevel_hierarchy
    self.use_all_classes = use_all_classes
    self.num_ways = episode_descr_config.num_ways
    self.num_support = episode_descr_config.num_support
    self.num_query = episode_descr_config.num_query
    self.min_ways = episode_descr_config.min_ways
    self.max_ways_upper_bound = episode_descr_config.max_ways_upper_bound
    self.max_num_query = episode_descr_config.max_num_query
    self.max_support_set_size = episode_descr_config.max_support_set_size
    self.max_support_size_contrib_per_class = (
        episode_descr_config.max_support_size_contrib_per_class)
    self.min_log_weight = episode_descr_config.min_log_weight
    self.max_log_weight = episode_descr_config.max_log_weight
    self.min_examples_in_class = episode_descr_config.min_examples_in_class

    self.class_set = dataset_spec.get_classes(self.split)
    self.num_classes = len(self.class_set)
    # Filter out classes with too few examples
    self._filtered_class_set = []
    # Store (class_id, n_examples) of skipped classes for logging.
    skipped_classes = []
    for class_id in self.class_set:
      n_examples = dataset_spec.get_total_images_per_class(class_id, pool=pool)
      if n_examples < self.min_examples_in_class:
        skipped_classes.append((class_id, n_examples))
      else:
        self._filtered_class_set.append(class_id)
    self.num_filtered_classes = len(self._filtered_class_set)

    if skipped_classes:
      logging.info(
          'Skipping the following classes, which do not have at least '
          '%d examples', self.min_examples_in_class)
      for class_id, n_examples in skipped_classes:
        logging.info('%s (ID=%d, %d examples)',
                     dataset_spec.class_names[class_id], class_id, n_examples)

    # A falsy `min_ways` (None or 0) disables this check.
    if self.min_ways and self.num_filtered_classes < self.min_ways:
      raise ValueError(
          '"min_ways" is set to {}, but split {} of dataset {} only has {} '
          'classes with at least {} examples ({} total), so it is not possible '
          'to create an episode for it. This may have resulted from applying a '
          'restriction on this split of this dataset by specifying '
          'benchmark.restrict_classes or benchmark.min_examples_in_class.'
          .format(self.min_ways, split, dataset_spec.name,
                  self.num_filtered_classes, self.min_examples_in_class,
                  self.num_classes))

    if self.use_all_classes:
      if self.num_classes != self.num_filtered_classes:
        raise ValueError('"use_all_classes" is not compatible with a value of '
                         '"min_examples_in_class" ({}) that results in some '
                         'classes being excluded.'.format(
                             self.min_examples_in_class))
      self.num_ways = self.num_classes

    # Maybe overwrite use_dag_hierarchy or use_bilevel_hierarchy if requested.
    if episode_descr_config.ignore_dag_ontology:
      self.use_dag_hierarchy = False
    if episode_descr_config.ignore_bilevel_ontology:
      self.use_bilevel_hierarchy = False

    # For Omniglot.
    if self.use_bilevel_hierarchy:
      if self.num_ways is not None:
        raise ValueError('"use_bilevel_hierarchy" is incompatible with '
                         '"num_ways".')
      if self.min_examples_in_class > 0:
        raise ValueError('"use_bilevel_hierarchy" is incompatible with '
                         '"min_examples_in_class".')

      if not isinstance(dataset_spec,
                        dataset_spec_lib.BiLevelDatasetSpecification):
        raise ValueError('Only applicable to datasets with a bi-level '
                         'dataset specification.')
      # The id's of the superclasses of the split (a contiguous range of ints).
      all_superclasses = dataset_spec.get_superclasses(self.split)
      self.superclass_set = []
      for superclass_id in all_superclasses:
        num_subclasses = self.dataset_spec.classes_per_superclass[superclass_id]
        # Every superclass must be able to provide at least `min_ways` classes.
        if num_subclasses < self.min_ways:
          raise ValueError(
              'Superclass: %d has num_classes=%d < min_ways=%d.' %
              (superclass_id, num_subclasses, self.min_ways))
        self.superclass_set.append(superclass_id)
    # For ImageNet.
    elif self.use_dag_hierarchy:
      if self.num_ways is not None:
        raise ValueError('"use_dag_hierarchy" is incompatible with "num_ways".')

      if not isinstance(dataset_spec,
                        dataset_spec_lib.HierarchicalDatasetSpecification):
        raise ValueError('Only applicable to datasets with a hierarchical '
                         'dataset specification.')

      # A DAG for navigating the ontology for the given split.
      graph = dataset_spec.get_split_subgraph(self.split)

      # Map the absolute class IDs in the split's class set to IDs relative to
      # the split.
      abs_to_rel_ids = {
          abs_id: rel_id for rel_id, abs_id in enumerate(self.class_set)
      }

      # Extract the sets of leaves and internal nodes in the DAG.
      leaves = set(imagenet_specification.get_leaves(graph))
      internal_nodes = graph - leaves  # set difference

      # Map each node of the DAG to the Synsets of the leaves it spans.
      spanning_leaves_dict = imagenet_specification.get_spanning_leaves(graph)

      # Built once so that the membership test in the nested loop below is a
      # constant-time lookup instead of a scan of the whole list.
      filtered_class_ids = frozenset(self._filtered_class_set)

      # Build a list of lists storing the relative class IDs of the spanning
      # leaves for each eligible internal node.
      self.span_leaves_rel = []
      for node in internal_nodes:
        # Absolute class IDs of the leaves spanned by this node.
        leaf_abs_ids = (
            dataset_spec.class_names_to_ids[leaf.wn_id]
            for leaf in spanning_leaves_dict[node])
        # Keep only the leaves that have at least min_examples_in_class
        # examples, expressed as relative class IDs.
        ids_rel = [
            abs_to_rel_ids[abs_id]
            for abs_id in leaf_abs_ids
            if abs_id in filtered_class_ids
        ]

        # Internal nodes are eligible if they span at least `self.min_ways`
        # and at most `MAX_SPANNING_LEAVES_ELIGIBLE` retained leaves.
        if self.min_ways <= len(ids_rel) <= MAX_SPANNING_LEAVES_ELIGIBLE:
          self.span_leaves_rel.append(ids_rel)

      num_eligible_nodes = len(self.span_leaves_rel)
      if num_eligible_nodes < 1:
        raise ValueError('There are no classes eligible for participating in '
                         'episodes. Consider changing the value of '
                         '`EpisodeDescriptionSampler.min_ways` in gin, '
                         'or MAX_SPANNING_LEAVES_ELIGIBLE in data.py.')
Beispiel #4
0
  def __init__(self,
               dataset_spec,
               split,
               pool=None,
               use_dag_hierarchy=False,
               use_bilevel_hierarchy=False,
               use_all_classes=False,
               num_ways=None,
               num_support=None,
               num_query=None,
               min_ways=None,
               max_ways_upper_bound=None,
               max_num_query=None,
               max_support_set_size=None,
               max_support_size_contrib_per_class=None,
               min_log_weight=None,
               max_log_weight=None):
    """Initializes an EpisodeDescriptionSampler.

    Args:
      dataset_spec: DatasetSpecification, dataset specification.
      split: one of Split.TRAIN, Split.VALID, or Split.TEST.
      pool: A string ('train' or 'test') or None, indicating which example-level
        split to select, if the current dataset has them.
      use_dag_hierarchy: Boolean, defaults to False. If a DAG-structured
        ontology is defined in dataset_spec, use it to choose related classes.
      use_bilevel_hierarchy: Boolean, defaults to False. If a bi-level ontology
        is defined in dataset_spec, use it for sampling classes.
      use_all_classes: Boolean, defaults to False. Uses all available classes,
        in order, instead of sampling. Overrides `num_ways` to the number of
        classes in `split`.
      num_ways: Integer (optional), fixes the number of classes ("ways") to be
        used in each episode if provided. Incompatible with using any hierarchy.
      num_support: Integer (optional), fixes the number of examples for each
        class in the support set if provided.
      num_query: Integer (optional), fixes the number of examples for each class
        in the query set if provided.
      min_ways: Integer, the minimum value when sampling ways (has to be
        provided if `num_ways` is None).
      max_ways_upper_bound: Integer, the maximum value when sampling ways (has
        to be provided if `num_ways` is None). Note that the number of
        available classes acts as another upper bound.
      max_num_query: Integer, the maximum number of query examples per class
        (has to be provided if `num_query` is None).
      max_support_set_size: Integer, the maximum size for the support set (has
        to be provided if `num_support` is None).
      max_support_size_contrib_per_class: Integer, the maximum contribution for
        any given class to the support set size (has to be provided if
        `num_support` is None).
      min_log_weight: Float, the minimum log-weight to give to any particular
        class when determining the number of support examples per class (has to
        be provided if `num_support` is None).
      max_log_weight: Float, the maximum log-weight to give to any particular
        class (has to be provided if `num_support` is None).

    Raises:
      RuntimeError: if required parameters are missing, i.e. one of
        `num_ways`, `num_query` or `num_support` is None and so is at least one
        of the arguments that must replace it.
      ValueError: Inconsistent parameters: a hierarchy is combined with
        `num_ways`, `dataset_spec` is not of the type required by the requested
        hierarchy, or no internal node of the DAG is eligible.
    """
    # Each entry maps a "fixed value" argument to the names and values of the
    # arguments that become mandatory when that fixed value is not given.
    arg_groups = {
        'num_ways': (num_ways, ('min_ways', 'max_ways_upper_bound'),
                     (min_ways, max_ways_upper_bound)),
        'num_query': (num_query, ('max_num_query',), (max_num_query,)),
        'num_support':
            (num_support,
             ('max_support_set_size', 'max_support_size_contrib_per_class',
              'min_log_weight', 'max_log_weight'),
             (max_support_set_size, max_support_size_contrib_per_class,
              min_log_weight, max_log_weight)),
    }

    for first_arg_name, group in arg_groups.items():
      first_arg, required_arg_names, required_args = group
      # Names of the required arguments that were left as None.
      none_arg_names = [
          name for var, name in zip(required_args, required_arg_names)
          if var is None
      ]
      if first_arg is None and none_arg_names:
        raise RuntimeError(
            'The following arguments: %s can not be None, since %s is None. '
            'Arguments can be set up with gin, for instance by providing '
            '`--gin_file=learn/gin/setups/data_config.gin` or calling '
            '`gin.parse_config_file(...)` in the code. Please ensure the '
            'following gin arguments of EpisodeDescriptionSampler are set: '
            '%s' % (none_arg_names, first_arg_name, none_arg_names))
    self.dataset_spec = dataset_spec
    self.split = split
    self.pool = pool
    self.use_dag_hierarchy = use_dag_hierarchy
    self.use_bilevel_hierarchy = use_bilevel_hierarchy
    self.use_all_classes = use_all_classes
    self.num_ways = num_ways
    self.num_support = num_support
    self.num_query = num_query
    # Gin parameters
    self.min_ways = min_ways
    self.max_ways_upper_bound = max_ways_upper_bound
    self.max_num_query = max_num_query
    self.max_support_set_size = max_support_set_size
    self.max_support_size_contrib_per_class = max_support_size_contrib_per_class
    self.min_log_weight = min_log_weight
    self.max_log_weight = max_log_weight

    self.class_set = dataset_spec.get_classes(self.split)
    self.num_classes = len(self.class_set)

    if self.use_all_classes:
      self.num_ways = self.num_classes

    # For Omniglot.
    if self.use_bilevel_hierarchy:
      if self.num_ways is not None:
        raise ValueError('"use_bilevel_hierarchy" is incompatible with '
                         '"num_ways".')

      if not isinstance(dataset_spec,
                        dataset_spec_lib.BiLevelDatasetSpecification):
        raise ValueError('Only applicable to datasets with a bi-level '
                         'dataset specification.')
      # The id's of the superclasses of the split (a contiguous range of ints).
      self.superclass_set = dataset_spec.get_superclasses(self.split)

    # For ImageNet.
    elif self.use_dag_hierarchy:
      if self.num_ways is not None:
        raise ValueError('"use_dag_hierarchy" is incompatible with "num_ways".')

      if not isinstance(dataset_spec,
                        dataset_spec_lib.HierarchicalDatasetSpecification):
        raise ValueError('Only applicable to datasets with a hierarchical '
                         'dataset specification.')

      # A DAG for navigating the ontology for the given split.
      graph = dataset_spec.get_split_subgraph(self.split)

      # Map the absolute class IDs in the split's class set to IDs relative to
      # the split. `self.class_set` already holds the classes of the split, so
      # it is reused rather than queried a second time.
      abs_to_rel_ids = {
          abs_id: rel_id for rel_id, abs_id in enumerate(self.class_set)
      }

      # Extract the sets of leaves and internal nodes in the DAG.
      leaves = set(imagenet_specification.get_leaves(graph))
      internal_nodes = graph - leaves  # set difference

      # Map each node of the DAG to the Synsets of the leaves it spans.
      spanning_leaves_dict = imagenet_specification.get_spanning_leaves(graph)

      # Build a list of lists storing the relative class IDs of the spanning
      # leaves for each eligible internal node.
      self.span_leaves_rel = []
      for node in internal_nodes:
        node_leaves = spanning_leaves_dict[node]
        # Internal nodes are eligible if they span at least `self.min_ways`
        # and at most `MAX_SPANNING_LEAVES_ELIGIBLE` leaves.
        if self.min_ways <= len(node_leaves) <= MAX_SPANNING_LEAVES_ELIGIBLE:
          # Build a list of relative class IDs for this internal node.
          ids = [dataset_spec.class_names_to_ids[s.wn_id] for s in node_leaves]
          self.span_leaves_rel.append([abs_to_rel_ids[abs_id] for abs_id in ids])

      num_eligible_nodes = len(self.span_leaves_rel)
      if num_eligible_nodes < 1:
        raise ValueError('There are no classes eligible for participating in '
                         'episodes. Consider changing the value of '
                         '`EpisodeDescriptionSampler.min_ways` in gin, '
                         'or MAX_SPANNING_LEAVES_ELIGIBLE in data.py.')
Beispiel #5
0
  def __init__(self,
               dataset_spec,
               split,
               episode_descr_config,
               pool=None,
               use_dag_hierarchy=False,
               use_bilevel_hierarchy=False,
               use_all_classes=False):
    """Initializes an EpisodeDescriptionSampler.

    Args:
      dataset_spec: DatasetSpecification, dataset specification.
      split: one of Split.TRAIN, Split.VALID, or Split.TEST.
      episode_descr_config: An instance of EpisodeDescriptionConfig containing
        parameters relating to sampling shots and ways for episodes.
      pool: A string ('train' or 'test') or None, indicating which example-level
        split to select, if the current dataset has them.
      use_dag_hierarchy: Boolean, defaults to False. If a DAG-structured
        ontology is defined in dataset_spec, use it to choose related classes.
      use_bilevel_hierarchy: Boolean, defaults to False. If a bi-level ontology
        is defined in dataset_spec, use it for sampling classes.
      use_all_classes: Boolean, defaults to False. Uses all available classes,
        in order, instead of sampling. Overrides `num_ways` to the number of
        classes in `split`.

    Raises:
      ValueError: Inconsistent parameters: a hierarchy is combined with
        `num_ways`, `dataset_spec` is not of the type required by the requested
        hierarchy, or no internal node of the DAG is eligible.
    """
    self.dataset_spec = dataset_spec
    self.split = split
    self.pool = pool
    self.use_dag_hierarchy = use_dag_hierarchy
    self.use_bilevel_hierarchy = use_bilevel_hierarchy
    self.use_all_classes = use_all_classes
    self.num_ways = episode_descr_config.num_ways
    self.num_support = episode_descr_config.num_support
    self.num_query = episode_descr_config.num_query
    self.min_ways = episode_descr_config.min_ways
    self.max_ways_upper_bound = episode_descr_config.max_ways_upper_bound
    self.max_num_query = episode_descr_config.max_num_query
    self.max_support_set_size = episode_descr_config.max_support_set_size
    self.max_support_size_contrib_per_class = (
        episode_descr_config.max_support_size_contrib_per_class)
    self.min_log_weight = episode_descr_config.min_log_weight
    self.max_log_weight = episode_descr_config.max_log_weight

    self.class_set = dataset_spec.get_classes(self.split)
    self.num_classes = len(self.class_set)

    if self.use_all_classes:
      self.num_ways = self.num_classes

    # Maybe overwrite use_dag_hierarchy or use_bilevel_hierarchy if requested.
    if episode_descr_config.ignore_dag_ontology:
      self.use_dag_hierarchy = False
    if episode_descr_config.ignore_bilevel_ontology:
      self.use_bilevel_hierarchy = False

    # For Omniglot.
    if self.use_bilevel_hierarchy:
      if self.num_ways is not None:
        raise ValueError('"use_bilevel_hierarchy" is incompatible with '
                         '"num_ways".')

      if not isinstance(dataset_spec,
                        dataset_spec_lib.BiLevelDatasetSpecification):
        raise ValueError('Only applicable to datasets with a bi-level '
                         'dataset specification.')
      # The id's of the superclasses of the split (a contiguous range of ints).
      self.superclass_set = dataset_spec.get_superclasses(self.split)

    # For ImageNet.
    elif self.use_dag_hierarchy:
      if self.num_ways is not None:
        raise ValueError('"use_dag_hierarchy" is incompatible with "num_ways".')

      if not isinstance(dataset_spec,
                        dataset_spec_lib.HierarchicalDatasetSpecification):
        raise ValueError('Only applicable to datasets with a hierarchical '
                         'dataset specification.')

      # A DAG for navigating the ontology for the given split.
      graph = dataset_spec.get_split_subgraph(self.split)

      # Map the absolute class IDs in the split's class set to IDs relative to
      # the split. `self.class_set` already holds the classes of the split, so
      # it is reused rather than queried a second time.
      abs_to_rel_ids = {
          abs_id: rel_id for rel_id, abs_id in enumerate(self.class_set)
      }

      # Extract the sets of leaves and internal nodes in the DAG.
      leaves = set(imagenet_specification.get_leaves(graph))
      internal_nodes = graph - leaves  # set difference

      # Map each node of the DAG to the Synsets of the leaves it spans.
      spanning_leaves_dict = imagenet_specification.get_spanning_leaves(graph)

      # Build a list of lists storing the relative class IDs of the spanning
      # leaves for each eligible internal node.
      self.span_leaves_rel = []
      for node in internal_nodes:
        node_leaves = spanning_leaves_dict[node]
        # Internal nodes are eligible if they span at least `self.min_ways`
        # and at most `MAX_SPANNING_LEAVES_ELIGIBLE` leaves.
        if self.min_ways <= len(node_leaves) <= MAX_SPANNING_LEAVES_ELIGIBLE:
          # Build a list of relative class IDs for this internal node.
          ids = [dataset_spec.class_names_to_ids[s.wn_id] for s in node_leaves]
          self.span_leaves_rel.append([abs_to_rel_ids[abs_id] for abs_id in ids])

      num_eligible_nodes = len(self.span_leaves_rel)
      if num_eligible_nodes < 1:
        raise ValueError('There are no classes eligible for participating in '
                         'episodes. Consider changing the value of '
                         '`EpisodeDescriptionSampler.min_ways` in gin, '
                         'or MAX_SPANNING_LEAVES_ELIGIBLE in data.py.')