Example 1
  def validate_splits(self, splits):
    """Check the correctness of the class splits."""
    train_graph = splits[learning_spec.Split.TRAIN]
    valid_graph = splits[learning_spec.Split.VALID]
    test_graph = splits[learning_spec.Split.TEST]

    # Make sure that by following child/parent pointers of nodes of a given
    # split's subgraph, we will reach nodes that also belong to that subgraph.
    def ensure_isolated(nodes):
      for n in nodes:
        for c in n.children:
          self.assertIn(c, nodes)
        for p in n.parents:
          self.assertIn(p, nodes)

    ensure_isolated(train_graph)
    ensure_isolated(valid_graph)
    ensure_isolated(test_graph)

    train_classes = imagenet_spec.get_leaves(train_graph)
    valid_classes = imagenet_spec.get_leaves(valid_graph)
    test_classes = imagenet_spec.get_leaves(test_graph)

    # Ensure that there is no overlap between classes of different splits,
    # and that, combined, they cover all ILSVRC 2012 classes.
    all_classes = train_classes + valid_classes + test_classes
    self.assertLen(set(all_classes), 1000)  # all covered
    self.assertLen(set(all_classes), len(all_classes))  # no duplicates
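For reference, here is a minimal self-contained sketch of the same isolation check; the toy Node class is a hypothetical stand-in for the real Synset objects with their two-way children/parents pointers:

# Minimal sketch of the isolation check, using a hypothetical Node
# class in place of the real Synset objects.
class Node(object):

  def __init__(self, name):
    self.name = name
    self.children = set()
    self.parents = set()


def is_isolated(nodes):
  """True if all child/parent pointers of `nodes` stay within `nodes`."""
  return all(c in nodes for n in nodes for c in n.children) and all(
      p in nodes for n in nodes for p in n.parents)


a, b = Node('a'), Node('b')
a.children.add(b)
b.parents.add(a)
assert is_isolated({a, b})  # self-contained subgraph
assert not is_isolated({a})  # 'a' points to a node outside the set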
Example 2
  def initialize(self, restricted_classes_per_split=None):
    """Initializes a HierarchicalDatasetSpecification.

    Args:
      restricted_classes_per_split: A dict that specifies for each split, a
        number to restrict its classes to. This number must be no greater than
        the total number of classes of that split. By default this is None and
        no restrictions are applied (all classes are used).
    """
    # Set self.class_names_to_ids to the inverse dict of self.class_names.
    self.class_names_to_ids = dict(
        zip(self.class_names.values(), self.class_names.keys()))

    # Maps each Split enum to the number of its classes.
    self.classes_per_split = self.get_classes_per_split()

    # Map each class ID to its corresponding number of examples.
    examples_per_class = {}
    for split in learning_spec.Split:
      leaves = imagenet_specification.get_leaves(self.split_subgraphs[split])
      for node in leaves:
        num_examples = self.images_per_class[split][node]
        examples_per_class[self.class_names_to_ids[node.wn_id]] = num_examples
    self.examples_per_class = examples_per_class

    if restricted_classes_per_split is not None:
      _check_validity_of_restricted_classes_per_split(
          restricted_classes_per_split, self.classes_per_split)
      # Apply the restriction.
      for split, restricted_num_classes in restricted_classes_per_split.items():
        self.classes_per_split[split] = restricted_num_classes
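The two dictionary manipulations in initialize are easy to check in isolation; the ids and counts below are made-up illustrative values:

# Inverting a name -> id mapping, and applying a per-split restriction,
# with made-up values.
class_names = {0: 'n01440764', 1: 'n01443537'}
class_names_to_ids = dict(zip(class_names.values(), class_names.keys()))
assert class_names_to_ids['n01443537'] == 1

classes_per_split = {'train': 712, 'valid': 158, 'test': 130}
restricted_classes_per_split = {'valid': 50}
for split, restricted_num_classes in restricted_classes_per_split.items():
  # Mirrors the _check_validity_of_restricted_classes_per_split check.
  assert restricted_num_classes <= classes_per_split[split]
  classes_per_split[split] = restricted_num_classes
assert classes_per_split['valid'] == 50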
Example 3
  def validate_splits(self, splits, spanning_leaves):
    # Make sure that the classes assigned to each split cover all the leaves
    # and that no class is assigned to more than one split.
    train_wn_ids = splits['train']
    valid_wn_ids = splits['valid']
    test_wn_ids = splits['test']
    self.assertFalse(train_wn_ids & valid_wn_ids)
    self.assertFalse(train_wn_ids & test_wn_ids)
    self.assertFalse(test_wn_ids & valid_wn_ids)
    all_wn_ids = train_wn_ids | valid_wn_ids | test_wn_ids
    leaves = imagenet_spec.get_leaves(spanning_leaves.keys())
    self.assertLen(all_wn_ids, len(leaves))  # all covered
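The disjointness-and-coverage assertions reduce to plain set algebra; here with made-up WordNet ids:

# Set algebra behind the assertions above, with made-up WordNet ids.
train_wn_ids = {'n01', 'n02'}
valid_wn_ids = {'n03'}
test_wn_ids = {'n04'}
assert not (train_wn_ids & valid_wn_ids)  # pairwise disjoint ...
assert not (train_wn_ids & test_wn_ids)
assert not (test_wn_ids & valid_wn_ids)
all_wn_ids = train_wn_ids | valid_wn_ids | test_wn_ids
assert len(all_wn_ids) == 4  # ... and together they cover all leaves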
Example 4
def test_lowest_common_ancestor(graph_nodes, test_instance, root=None):
  # Test the computation of the lowest common ancestor of two nodes.
  # Randomly sample two leaves a number of times, find their lowest common
  # ancestor and its height and verify that they are computed correctly.
  leaves = imagenet_spec.get_leaves(graph_nodes)
  for _ in range(10000):
    first_ind = np.random.randint(len(leaves))
    second_ind = np.random.randint(len(leaves))
    while first_ind == second_ind:
      second_ind = np.random.randint(len(leaves))
    leaf_a = leaves[first_ind]
    leaf_b = leaves[second_ind]
    lca, height = imagenet_spec.get_lowest_common_ancestor(leaf_a, leaf_b)
    test_lowest_common_ancestor_(
        lca, height, leaf_a, leaf_b, test_instance, root=root)
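To make the property being tested concrete, here is a toy lowest-common-ancestor computation on a child -> parent dict. It handles the single-parent case only (the real graph is a DAG with multiple parents), and the real function's height convention may differ from the leaf-to-LCA distance used here:

# Toy LCA sketch on a child -> parent dict, with made-up names.
parent = {'dog': 'canine', 'wolf': 'canine', 'canine': 'animal',
          'cat': 'feline', 'feline': 'animal'}


def ancestor_chain(node):
  chain = [node]
  while node in parent:
    node = parent[node]
    chain.append(node)
  return chain


def lca_with_height(a, b):
  ancestors_b = set(ancestor_chain(b))
  for height, node in enumerate(ancestor_chain(a)):
    if node in ancestors_b:
      return node, height
  raise ValueError('No common ancestor.')


assert lca_with_height('dog', 'wolf') == ('canine', 1)
assert lca_with_height('dog', 'cat') == ('animal', 2)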
Example 5
def validate_graph(graph_nodes, subset_synsets, test_instance):
  """Checks that the DAG structure is as expected."""
  # 1) Test that the leaves are all and only the ILSVRC 2012 synsets
  leaves = imagenet_spec.get_leaves(graph_nodes)
  test_instance.assertEqual(len(leaves), len(subset_synsets))
  test_instance.assertEqual(set(leaves), set(subset_synsets))

  # 2) Validate the connectivity
  # If a node is listed as a child of another, the latter must also be listed as
  # a parent of the former, and similarly if a node is listed as a parent of
  # another, the latter must also be listed as a child of the former.
  for n in graph_nodes:
    for c in n.children:
      test_instance.assertIn(n, c.parents)
    for p in n.parents:
      test_instance.assertIn(n, p.children)

  # 3) Check that no node has only 1 child, as it's not possible to create an
  # episode from such a node.
  for n in graph_nodes:
    test_instance.assertNotEqual(len(n.children), 1)

  # 4) Check that the graph is detached from the remaining non-graph synsets.
  # We want to guarantee that by following parent or child pointers of graph
  # nodes we will stay within the graph.
  for n in graph_nodes:
    for c in n.children:
      test_instance.assertIn(c, graph_nodes)
    for p in n.parents:
      test_instance.assertIn(p, graph_nodes)

  # 5) Check that every node in graph nodes is either an ILSVRC 2012 synset or
  # an ancestor of an ILSVRC 2012 synset.
  for n in graph_nodes:
    if n in subset_synsets:
      continue
    has_2012_descendent = False
    for s in subset_synsets:
      has_2012_descendent = imagenet_spec.is_descendent(s, n)
      if has_2012_descendent:
        break
    test_instance.assertTrue(has_2012_descendent)
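A sketch of the descendant test behind check (5): s is a descendant of n if n is reachable from s by following parent pointers. The dict-based graph below is a toy stand-in; the real is_descendent walks Synset objects:

# Toy descendant check, with made-up names.
parents = {'dog': ['canine'], 'canine': ['animal'], 'animal': []}


def is_descendent(s, n):
  """True if `n` is a proper ancestor of `s`."""
  frontier = list(parents[s])
  seen = set(frontier)
  while frontier:
    cur = frontier.pop()
    if cur == n:
      return True
    for p in parents[cur]:
      if p not in seen:
        seen.add(p)
        frontier.append(p)
  return False


assert is_descendent('dog', 'animal')
assert not is_descendent('animal', 'dog')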
Example 6
  def __init__(self,
               dataset_spec,
               split,
               episode_descr_config,
               pool=None,
               use_dag_hierarchy=False,
               use_bilevel_hierarchy=False,
               use_all_classes=False):
    """Initializes an EpisodeDescriptionSampler.episode_config.

    Args:
      dataset_spec: DatasetSpecification, dataset specification.
      split: one of Split.TRAIN, Split.VALID, or Split.TEST.
      episode_descr_config: An instance of EpisodeDescriptionConfig containing
        parameters relating to sampling shots and ways for episodes.
      pool: A string ('train' or 'test') or None, indicating which example-level
        split to select, if the current dataset has them.
      use_dag_hierarchy: Boolean, defaults to False. If a DAG-structured
        ontology is defined in dataset_spec, use it to choose related classes.
      use_bilevel_hierarchy: Boolean, defaults to False. If a bi-level ontology
        is defined in dataset_spec, use it for sampling classes.
      use_all_classes: Boolean, defaults to False. Uses all available classes,
        in order, instead of sampling. Overrides `num_ways` to the number of
        classes in `split`.

    Raises:
      RuntimeError: if required parameters are missing.
      ValueError: Inconsistent parameters.
    """
    self.dataset_spec = dataset_spec
    self.split = split
    self.pool = pool
    self.use_dag_hierarchy = use_dag_hierarchy
    self.use_bilevel_hierarchy = use_bilevel_hierarchy
    self.use_all_classes = use_all_classes
    self.num_ways = episode_descr_config.num_ways
    self.num_support = episode_descr_config.num_support
    self.num_query = episode_descr_config.num_query
    self.min_ways = episode_descr_config.min_ways
    self.max_ways_upper_bound = episode_descr_config.max_ways_upper_bound
    self.max_num_query = episode_descr_config.max_num_query
    self.max_support_set_size = episode_descr_config.max_support_set_size
    self.max_support_size_contrib_per_class = (
        episode_descr_config.max_support_size_contrib_per_class)
    self.min_log_weight = episode_descr_config.min_log_weight
    self.max_log_weight = episode_descr_config.max_log_weight
    self.min_examples_in_class = episode_descr_config.min_examples_in_class

    self.class_set = dataset_spec.get_classes(self.split)
    self.num_classes = len(self.class_set)
    # Filter out classes with too few examples
    self._filtered_class_set = []
    # Store (class_id, n_examples) of skipped classes for logging.
    skipped_classes = []
    for class_id in self.class_set:
      n_examples = dataset_spec.get_total_images_per_class(class_id, pool=pool)
      if n_examples < self.min_examples_in_class:
        skipped_classes.append((class_id, n_examples))
      else:
        self._filtered_class_set.append(class_id)
    self.num_filtered_classes = len(self._filtered_class_set)

    if skipped_classes:
      logging.info(
          'Skipping the following classes, which do not have at least '
          '%d examples', self.min_examples_in_class)
    for class_id, n_examples in skipped_classes:
      logging.info('%s (ID=%d, %d examples)',
                   dataset_spec.class_names[class_id], class_id, n_examples)

    if self.min_ways and self.num_filtered_classes < self.min_ways:
      raise ValueError(
          '"min_ways" is set to {}, but split {} of dataset {} only has {} '
          'classes with at least {} examples ({} total), so it is not possible '
          'to create an episode for it. This may have resulted from applying a '
          'restriction on this split of this dataset by specifying '
          'benchmark.restrict_classes or benchmark.min_examples_in_class.'
          .format(self.min_ways, split, dataset_spec.name,
                  self.num_filtered_classes, self.min_examples_in_class,
                  self.num_classes))

    if self.use_all_classes:
      if self.num_classes != self.num_filtered_classes:
        raise ValueError('"use_all_classes" is not compatible with a value of '
                         '"min_examples_in_class" ({}) that results in some '
                         'classes being excluded.'.format(
                             self.min_examples_in_class))
      self.num_ways = self.num_classes

    # Maybe overwrite use_dag_hierarchy or use_bilevel_hierarchy if requested.
    if episode_descr_config.ignore_dag_ontology:
      self.use_dag_hierarchy = False
    if episode_descr_config.ignore_bilevel_ontology:
      self.use_bilevel_hierarchy = False

    # For Omniglot.
    if self.use_bilevel_hierarchy:
      if self.num_ways is not None:
        raise ValueError('"use_bilevel_hierarchy" is incompatible with '
                         '"num_ways".')
      if self.min_examples_in_class > 0:
        raise ValueError('"use_bilevel_hierarchy" is incompatible with '
                         '"min_examples_in_class".')

      if not isinstance(dataset_spec,
                        dataset_spec_lib.BiLevelDatasetSpecification):
        raise ValueError('Only applicable to datasets with a bi-level '
                         'dataset specification.')
      # The ids of the superclasses of the split (a contiguous range of ints).
      all_superclasses = dataset_spec.get_superclasses(self.split)
      self.superclass_set = []
      for i in all_superclasses:
        if self.dataset_spec.classes_per_superclass[i] < self.min_ways:
          raise ValueError(
              'Superclass: %d has num_classes=%d < min_ways=%d.' %
              (i, self.dataset_spec.classes_per_superclass[i], self.min_ways))
        self.superclass_set.append(i)
    # For ImageNet.
    elif self.use_dag_hierarchy:
      if self.num_ways is not None:
        raise ValueError('"use_dag_hierarchy" is incompatible with "num_ways".')

      if not isinstance(dataset_spec,
                        dataset_spec_lib.HierarchicalDatasetSpecification):
        raise ValueError('Only applicable to datasets with a hierarchical '
                         'dataset specification.')

      # A DAG for navigating the ontology for the given split.
      graph = dataset_spec.get_split_subgraph(self.split)

      # Map the absolute class IDs in the split's class set to IDs relative to
      # the split.
      class_set = self.class_set
      abs_to_rel_ids = dict((abs_id, i) for i, abs_id in enumerate(class_set))

      # Extract the sets of leaves and internal nodes in the DAG.
      leaves = set(imagenet_specification.get_leaves(graph))
      internal_nodes = graph - leaves  # set difference

      # Map each node of the DAG to the Synsets of the leaves it spans.
      spanning_leaves_dict = imagenet_specification.get_spanning_leaves(graph)

      # Build a list of lists storing the relative class IDs of the spanning
      # leaves for each eligible internal node.
      self.span_leaves_rel = []
      for node in internal_nodes:
        node_leaves = spanning_leaves_dict[node]
        # Build a list of relative class IDs of leaves that have at least
        # min_examples_in_class examples.
        ids_rel = []
        for leaf in node_leaves:
          abs_id = dataset_spec.class_names_to_ids[leaf.wn_id]
          if abs_id in self._filtered_class_set:
            ids_rel.append(abs_to_rel_ids[abs_id])

        # Internal nodes are eligible if they span at least `min_ways` and at
        # most `MAX_SPANNING_LEAVES_ELIGIBLE` leaves.
        if self.min_ways <= len(ids_rel) <= MAX_SPANNING_LEAVES_ELIGIBLE:
          self.span_leaves_rel.append(ids_rel)

      num_eligible_nodes = len(self.span_leaves_rel)
      if num_eligible_nodes < 1:
        raise ValueError('There are no classes eligible for participating in '
                         'episodes. Consider changing the value of '
                         '`EpisodeDescriptionSampler.min_ways` in gin, or '
                         'MAX_SPANNING_LEAVES_ELIGIBLE in data.py.')
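The DAG-eligibility step at the end of Example 6 boils down to: keep an internal node if the number of its spanning leaves that survive the min-examples filter lies between min_ways and MAX_SPANNING_LEAVES_ELIGIBLE. A dict-based sketch, where the names, counts, and bound are all illustrative stand-ins (the real constant lives in data.py):

# Dict-based sketch of the spanning-leaves eligibility filter.
min_ways = 5
max_spanning_leaves_eligible = 50  # stand-in for MAX_SPANNING_LEAVES_ELIGIBLE
spanning_leaves = {
    'node_a': ['leaf%d' % i for i in range(10)],
    'node_b': ['leaf0', 'leaf1'],
}
filtered_class_set = {'leaf%d' % i for i in range(8)}  # enough examples

span_leaves_rel = []
for node, node_leaves in spanning_leaves.items():
  ids_rel = [leaf for leaf in node_leaves if leaf in filtered_class_set]
  if min_ways <= len(ids_rel) <= max_spanning_leaves_eligible:
    span_leaves_rel.append(ids_rel)

assert len(span_leaves_rel) == 1  # only node_a spans enough eligible leaves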
Example 7
  def list_leaf_num_images(split):
    return [
        self.images_per_class[split][n]
        for n in imagenet_specification.get_leaves(self.split_subgraphs[split])
    ]
Example 8
  def count_split_classes(split):
    graph = self.split_subgraphs[split]
    leaves = imagenet_specification.get_leaves(graph)
    return len(leaves)
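Assuming self is an initialized HierarchicalDatasetSpecification-like object, the two helpers above relate as follows (hypothetical usage):

# Hypothetical usage, assuming an initialized specification `self`:
#   num_images = list_leaf_num_images(learning_spec.Split.TRAIN)
#   assert len(num_images) == count_split_classes(learning_spec.Split.TRAIN)
# i.e. there is exactly one image count per leaf class of the split.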
Example 9
  def __init__(self,
               dataset_spec,
               split,
               pool=None,
               use_dag_hierarchy=False,
               use_bilevel_hierarchy=False,
               use_all_classes=False,
               num_ways=None,
               num_support=None,
               num_query=None,
               min_ways=None,
               max_ways_upper_bound=None,
               max_num_query=None,
               max_support_set_size=None,
               max_support_size_contrib_per_class=None,
               min_log_weight=None,
               max_log_weight=None):
    """Initializes an EpisodeDescriptionSampler.

    Args:
      dataset_spec: DatasetSpecification, dataset specification.
      split: one of Split.TRAIN, Split.VALID, or Split.TEST.
      pool: A string ('train' or 'test') or None, indicating which example-level
        split to select, if the current dataset has them.
      use_dag_hierarchy: Boolean, defaults to False. If a DAG-structured
        ontology is defined in dataset_spec, use it to choose related classes.
      use_bilevel_hierarchy: Boolean, defaults to False. If a bi-level ontology
        is defined in dataset_spec, use it for sampling classes.
      use_all_classes: Boolean, defaults to False. Uses all available classes,
        in order, instead of sampling. Overrides `num_ways` to the number of
        classes in `split`.
      num_ways: Integer (optional), fixes the number of classes ("ways") to be
        used in each episode if provided. Incompatible with using any hierarchy.
      num_support: Integer (optional), fixes the number of examples for each
        class in the support set if provided.
      num_query: Integer (optional), fixes the number of examples for each class
        in the query set if provided.
      min_ways: Integer, the minimum value when sampling ways (has to be
        provided if `num_ways` is None).
      max_ways_upper_bound: Integer, the maximum value when sampling ways (has
        to be provided if `num_ways` is None). Note that the number of
        available classes acts as another upper bound.
      max_num_query: Integer, the maximum number of query examples per class
        (has to be provided if `num_query` is None).
      max_support_set_size: Integer, the maximum size for the support set (has
        to be provided if `num_support` is None).
      max_support_size_contrib_per_class: Integer, the maximum contribution for
        any given class to the support set size. (has to be provided if
        `num_support` is None).
      min_log_weight: Float, the minimum log-weight to give to any particular
        class when determining the number of support examples per class (has to
        be provided if `num_support` is None).
      max_log_weight: Float, the maximum log-weight to give to any particular
        class (has to be provided if `num_support` is None).

    Raises:
      RuntimeError: if required parameters are missing.
      ValueError: Inconsistent parameters.
    """
    arg_groups = {
        'num_ways': (num_ways, ('min_ways', 'max_ways_upper_bound'),
                     (min_ways, max_ways_upper_bound)),
        'num_query': (num_query, ('max_num_query',), (max_num_query,)),
        'num_support':
            (num_support,
             ('max_support_set_size', 'max_support_size_contrib_per_class',
              'min_log_weight', 'max_log_weight'),
             (max_support_set_size, max_support_size_contrib_per_class,
              min_log_weight, max_log_weight)),
    }

    for first_arg_name, values in arg_groups.items():
      first_arg, required_arg_names, required_args = values
      if ((first_arg is None) and any(arg is None for arg in required_args)):
        # Get the names of the arguments that are None.
        none_arg_names = [
            name for var, name in zip(required_args, required_arg_names)
            if var is None
        ]
        raise RuntimeError(
            'The following arguments: %s cannot be None, since %s is None. '
            'Arguments can be set up with gin, for instance by providing '
            '`--gin_file=learn/gin/setups/data_config.gin` or calling '
            '`gin.parse_config_file(...)` in the code. Please ensure the '
            'following gin arguments of EpisodeDescriptionSampler are set: '
            '%s' % (none_arg_names, first_arg_name, none_arg_names))
    self.dataset_spec = dataset_spec
    self.split = split
    self.pool = pool
    self.use_dag_hierarchy = use_dag_hierarchy
    self.use_bilevel_hierarchy = use_bilevel_hierarchy
    self.use_all_classes = use_all_classes
    self.num_ways = num_ways
    self.num_support = num_support
    self.num_query = num_query
    # Gin parameters
    self.min_ways = min_ways
    self.max_ways_upper_bound = max_ways_upper_bound
    self.max_num_query = max_num_query
    self.max_support_set_size = max_support_set_size
    self.max_support_size_contrib_per_class = max_support_size_contrib_per_class
    self.min_log_weight = min_log_weight
    self.max_log_weight = max_log_weight

    self.class_set = dataset_spec.get_classes(self.split)
    self.num_classes = len(self.class_set)

    if self.use_all_classes:
      self.num_ways = self.num_classes

    # For Omniglot.
    if self.use_bilevel_hierarchy:
      if self.num_ways is not None:
        raise ValueError('"use_bilevel_hierarchy" is incompatible with '
                         '"num_ways".')

      if not isinstance(dataset_spec,
                        dataset_spec_lib.BiLevelDatasetSpecification):
        raise ValueError('Only applicable to datasets with a bi-level '
                         'dataset specification.')
      # The ids of the superclasses of the split (a contiguous range of ints).
      self.superclass_set = dataset_spec.get_superclasses(self.split)

    # For ImageNet.
    elif self.use_dag_hierarchy:
      if self.num_ways is not None:
        raise ValueError('"use_dag_hierarchy" is incompatible with "num_ways".')

      if not isinstance(dataset_spec,
                        dataset_spec_lib.HierarchicalDatasetSpecification):
        raise ValueError('Only applicable to datasets with a hierarchical '
                         'dataset specification.')

      # A DAG for navigating the ontology for the given split.
      graph = dataset_spec.get_split_subgraph(self.split)

      # Map the absolute class IDs in the split's class set to IDs relative to
      # the split.
      class_set = self.dataset_spec.get_classes(self.split)
      abs_to_rel_ids = dict((abs_id, i) for i, abs_id in enumerate(class_set))

      # Extract the sets of leaves and internal nodes in the DAG.
      leaves = set(imagenet_specification.get_leaves(graph))
      internal_nodes = graph - leaves  # set difference

      # Map each node of the DAG to the Synsets of the leaves it spans.
      spanning_leaves_dict = imagenet_specification.get_spanning_leaves(graph)

      # Build a list of lists storing the relative class IDs of the spanning
      # leaves for each eligible internal node.
      self.span_leaves_rel = []
      for node in internal_nodes:
        node_leaves = spanning_leaves_dict[node]
        # Internal nodes are eligible if they span at least `min_ways` and at
        # most `MAX_SPANNING_LEAVES_ELIGIBLE` leaves.
        if self.min_ways <= len(node_leaves) <= MAX_SPANNING_LEAVES_ELIGIBLE:
          # Build a list of relative class IDs for this internal node.
          ids = [dataset_spec.class_names_to_ids[s.wn_id] for s in node_leaves]
          ids_rel = [abs_to_rel_ids[abs_id] for abs_id in ids]
          self.span_leaves_rel.append(ids_rel)

      num_eligible_nodes = len(self.span_leaves_rel)
      if num_eligible_nodes < 1:
        raise ValueError('There are no classes eligible for participating in '
                         'episodes. Consider changing the value of '
                         '`EpisodeDescriptionSampler.min_ways` in gin, or '
                         'MAX_SPANNING_LEAVES_ELIGIBLE in data.py.')
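The arg_groups check at the top of Example 9 is a reusable pattern: when a directly specified value is absent, all of its fallback sampling parameters must be present. A standalone sketch of that validation logic:

# Standalone sketch of the arg-group validation pattern used above.
def check_arg_groups(arg_groups):
  for first_arg_name, values in arg_groups.items():
    first_arg, required_arg_names, required_args = values
    if first_arg is None and any(arg is None for arg in required_args):
      none_arg_names = [
          name for arg, name in zip(required_args, required_arg_names)
          if arg is None
      ]
      raise RuntimeError('%s is None, so the following arguments must be '
                         'set: %s' % (first_arg_name, none_arg_names))


# OK: num_ways is given directly, so its fallbacks may be None.
check_arg_groups({'num_ways': (5, ('min_ways',), (None,))})
# Raises: num_ways is None and min_ways is missing too.
try:
  check_arg_groups({'num_ways': (None, ('min_ways',), (None,))})
except RuntimeError as e:
  print(e)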
Example 10
  def __init__(self,
               dataset_spec,
               split,
               episode_descr_config,
               pool=None,
               use_dag_hierarchy=False,
               use_bilevel_hierarchy=False,
               use_all_classes=False):
    """Initializes an EpisodeDescriptionSampler.episode_config.

    Args:
      dataset_spec: DatasetSpecification, dataset specification.
      split: one of Split.TRAIN, Split.VALID, or Split.TEST.
      episode_descr_config: An instance of EpisodeDescriptionConfig containing
        parameters relating to sampling shots and ways for episodes.
      pool: A string ('train' or 'test') or None, indicating which example-level
        split to select, if the current dataset has them.
      use_dag_hierarchy: Boolean, defaults to False. If a DAG-structured
        ontology is defined in dataset_spec, use it to choose related classes.
      use_bilevel_hierarchy: Boolean, defaults to False. If a bi-level ontology
        is defined in dataset_spec, use it for sampling classes.
      use_all_classes: Boolean, defaults to False. Uses all available classes,
        in order, instead of sampling. Overrides `num_ways` to the number of
        classes in `split`.

    Raises:
      RuntimeError: if required parameters are missing.
      ValueError: Inconsistent parameters.
    """
    self.dataset_spec = dataset_spec
    self.split = split
    self.pool = pool
    self.use_dag_hierarchy = use_dag_hierarchy
    self.use_bilevel_hierarchy = use_bilevel_hierarchy
    self.use_all_classes = use_all_classes
    self.num_ways = episode_descr_config.num_ways
    self.num_support = episode_descr_config.num_support
    self.num_query = episode_descr_config.num_query
    self.min_ways = episode_descr_config.min_ways
    self.max_ways_upper_bound = episode_descr_config.max_ways_upper_bound
    self.max_num_query = episode_descr_config.max_num_query
    self.max_support_set_size = episode_descr_config.max_support_set_size
    self.max_support_size_contrib_per_class = (
        episode_descr_config.max_support_size_contrib_per_class)
    self.min_log_weight = episode_descr_config.min_log_weight
    self.max_log_weight = episode_descr_config.max_log_weight

    self.class_set = dataset_spec.get_classes(self.split)
    self.num_classes = len(self.class_set)

    if self.use_all_classes:
      self.num_ways = self.num_classes

    # Maybe overwrite use_dag_hierarchy or use_bilevel_hierarchy if requested.
    if episode_descr_config.ignore_dag_ontology:
      self.use_dag_hierarchy = False
    if episode_descr_config.ignore_bilevel_ontology:
      self.use_bilevel_hierarchy = False

    # For Omniglot.
    if self.use_bilevel_hierarchy:
      if self.num_ways is not None:
        raise ValueError('"use_bilevel_hierarchy" is incompatible with '
                         '"num_ways".')

      if not isinstance(dataset_spec,
                        dataset_spec_lib.BiLevelDatasetSpecification):
        raise ValueError('Only applicable to datasets with a bi-level '
                         'dataset specification.')
      # The ids of the superclasses of the split (a contiguous range of ints).
      self.superclass_set = dataset_spec.get_superclasses(self.split)

    # For ImageNet.
    elif self.use_dag_hierarchy:
      if self.num_ways is not None:
        raise ValueError('"use_dag_hierarchy" is incompatible with "num_ways".')

      if not isinstance(dataset_spec,
                        dataset_spec_lib.HierarchicalDatasetSpecification):
        raise ValueError('Only applicable to datasets with a hierarchical '
                         'dataset specification.')

      # A DAG for navigating the ontology for the given split.
      graph = dataset_spec.get_split_subgraph(self.split)

      # Map the absolute class IDs in the split's class set to IDs relative to
      # the split.
      class_set = self.dataset_spec.get_classes(self.split)
      abs_to_rel_ids = dict((abs_id, i) for i, abs_id in enumerate(class_set))

      # Extract the sets of leaves and internal nodes in the DAG.
      leaves = set(imagenet_specification.get_leaves(graph))
      internal_nodes = graph - leaves  # set difference

      # Map each node of the DAG to the Synsets of the leaves it spans.
      spanning_leaves_dict = imagenet_specification.get_spanning_leaves(graph)

      # Build a list of lists storing the relative class IDs of the spanning
      # leaves for each eligible internal node.
      self.span_leaves_rel = []
      for node in internal_nodes:
        node_leaves = spanning_leaves_dict[node]
        # Internal nodes are eligible if they span at least `min_ways` and at
        # most `MAX_SPANNING_LEAVES_ELIGIBLE` leaves.
        if self.min_ways <= len(node_leaves) <= MAX_SPANNING_LEAVES_ELIGIBLE:
          # Build a list of relative class IDs for this internal node.
          ids = [dataset_spec.class_names_to_ids[s.wn_id] for s in node_leaves]
          ids_rel = [abs_to_rel_ids[abs_id] for abs_id in ids]
          self.span_leaves_rel.append(ids_rel)

      num_eligible_nodes = len(self.span_leaves_rel)
      if num_eligible_nodes < 1:
        raise ValueError('There are no classes eligible for participating in '
                         'episodes. Consider changing the value of '
                         '`EpisodeDescriptionSampler.min_ways` in gin, or '
                         'MAX_SPANNING_LEAVES_ELIGIBLE in data.py.')
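Finally, the absolute-to-relative class-id remapping that Examples 6, 9, and 10 all share is just an enumeration over the split's class set; the class ids below are made up:

# The absolute-to-relative id mapping from the DAG branch, in isolation.
class_set = [12, 47, 363]  # absolute class ids of a hypothetical split
abs_to_rel_ids = dict((abs_id, i) for i, abs_id in enumerate(class_set))
assert abs_to_rel_ids == {12: 0, 47: 1, 363: 2}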