def create_toy_graph():
  """Builds a toy ontology and the sampling graph derived from two leaves.

  Eight synsets named 'a' through 'h' are created and wired together with
  parent/child links. A sampling graph is then created from the synsets whose
  `words` is 'f' or 'g', and the spanning leaves of that graph are computed.

  Returns:
    A tuple (graph_nodes, spanning_leaves, synsets_subset), where graph_nodes
    is the result of `imagenet_spec.create_sampling_graph`, spanning_leaves is
    the result of `imagenet_spec.get_spanning_leaves` on graph_nodes, and
    synsets_subset is the list of Synsets the graph was created from.
  """
  names = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
  # The position of each name in the list doubles as its WordNet id.
  synsets = {
      name: imagenet_spec.Synset(wn_id, name, set(), set())
      for wn_id, name in enumerate(names)
  }

  # Each pair is (parent, child). Together they form the tree:
  #         a
  #      b     c
  #    g     d   e
  #            h   f
  is_a_relations = [('a', 'b'), ('a', 'c'), ('b', 'g'), ('c', 'd'),
                    ('c', 'e'), ('e', 'f'), ('e', 'h')]
  for parent, child in is_a_relations:
    parent_synset = synsets[parent]
    child_synset = synsets[child]
    parent_synset.children.add(child_synset)
    child_synset.parents.add(parent_synset)

  subset = ['f', 'g']
  synsets_subset = [
      synset for synset in synsets.values() if synset.words in subset
  ]

  # The sampling graph should hold all and only the ancestors of the subset,
  # with every node that has exactly 1 child collapsed. Expected result:
  #      a
  #    g   f
  graph_nodes = imagenet_spec.create_sampling_graph(synsets_subset)

  # Maps every node of the graph to the leaves reachable from it.
  spanning_leaves = imagenet_spec.get_spanning_leaves(graph_nodes)
  return graph_nodes, spanning_leaves, synsets_subset
def test_imagenet_specification(self):
  """Validates the full ImageNet specification and its valid/test subgraphs.

  Creates the specification, then checks the graph, the spanning leaves, the
  splits and the per-node spanning image counts. The lowest-common-ancestor
  and upward-path checks are run on the whole graph first, and then on the
  validation and test subgraphs together with their respective roots.
  """
  (splits, _, graph_nodes, synsets_2012, num_synset_2012_images,
   roots) = imagenet_spec.create_imagenet_specification(learning_spec.Split)

  span_leaves = imagenet_spec.get_spanning_leaves(graph_nodes)
  num_span_images = imagenet_spec.get_num_spanning_images(
      span_leaves, num_synset_2012_images)

  validate_graph(graph_nodes, synsets_2012, self)
  validate_spanning_leaves(span_leaves, synsets_2012, self)
  self.validate_splits(splits)
  self.validate_num_span_images(span_leaves, num_span_images)

  test_lowest_common_ancestor(graph_nodes, self)
  test_get_upward_paths(graph_nodes, self)

  # Make sure that in no sub-tree can the LCA of two chosen leaves of that
  # sub-tree be a node that is an ancestor of the sub-tree's root.
  valid_subgraph = splits[learning_spec.Split.VALID]
  test_subgraph = splits[learning_spec.Split.TEST]
  valid_root = roots['valid']
  test_root = roots['test']
  subgraphs_with_roots = (
      (valid_subgraph, valid_root),
      (test_subgraph, test_root),
  )
  for subgraph, root in subgraphs_with_roots:
    test_lowest_common_ancestor(subgraph, self, root)
    test_get_upward_paths(subgraph, self, root)
def __init__(self,
             dataset_spec,
             split,
             episode_descr_config,
             pool=None,
             use_dag_hierarchy=False,
             use_bilevel_hierarchy=False,
             use_all_classes=False):
  """Initializes an EpisodeDescriptionSampler from an episode config.

  Classes of `split` that have fewer than
  `episode_descr_config.min_examples_in_class` examples are filtered out
  first. Hierarchy-specific state (`superclass_set` for bi-level ontologies,
  `span_leaves_rel` for DAG ontologies) is then prepared if requested.

  Args:
    dataset_spec: DatasetSpecification, dataset specification.
    split: one of Split.TRAIN, Split.VALID, or Split.TEST.
    episode_descr_config: An instance of EpisodeDescriptionConfig containing
      parameters relating to sampling shots and ways for episodes.
    pool: A string ('train' or 'test') or None, indicating which
      example-level split to select, if the current dataset has them.
    use_dag_hierarchy: Boolean, defaults to False. If a DAG-structured
      ontology is defined in dataset_spec, use it to choose related classes.
      Ignored when `episode_descr_config.ignore_dag_ontology` is set.
    use_bilevel_hierarchy: Boolean, defaults to False. If a bi-level ontology
      is defined in dataset_spec, use it for sampling classes. Ignored when
      `episode_descr_config.ignore_bilevel_ontology` is set.
    use_all_classes: Boolean, defaults to False. Uses all available classes,
      in order, instead of sampling. Overrides `num_ways` to the number of
      classes in `split`.

  Raises:
    ValueError: Inconsistent parameters: too few classes remain after
      filtering to satisfy `min_ways`, `use_all_classes` is combined with a
      filter that excludes classes, a hierarchy is combined with `num_ways`
      (or, for the bi-level one, with `min_examples_in_class`), the dataset
      specification has the wrong type for the requested hierarchy, a
      superclass has fewer than `min_ways` classes, or no internal node of
      the DAG is eligible.
  """
  self.dataset_spec = dataset_spec
  self.split = split
  self.pool = pool
  self.use_dag_hierarchy = use_dag_hierarchy
  self.use_bilevel_hierarchy = use_bilevel_hierarchy
  self.use_all_classes = use_all_classes
  self.num_ways = episode_descr_config.num_ways
  self.num_support = episode_descr_config.num_support
  self.num_query = episode_descr_config.num_query
  self.min_ways = episode_descr_config.min_ways
  self.max_ways_upper_bound = episode_descr_config.max_ways_upper_bound
  self.max_num_query = episode_descr_config.max_num_query
  self.max_support_set_size = episode_descr_config.max_support_set_size
  self.max_support_size_contrib_per_class = (
      episode_descr_config.max_support_size_contrib_per_class)
  self.min_log_weight = episode_descr_config.min_log_weight
  self.max_log_weight = episode_descr_config.max_log_weight
  self.min_examples_in_class = episode_descr_config.min_examples_in_class

  self.class_set = dataset_spec.get_classes(self.split)
  self.num_classes = len(self.class_set)

  # Filter out classes with too few examples.
  self._filtered_class_set = []
  # Store (class_id, n_examples) of skipped classes for logging.
  skipped_classes = []
  for class_id in self.class_set:
    n_examples = dataset_spec.get_total_images_per_class(class_id, pool=pool)
    if n_examples < self.min_examples_in_class:
      skipped_classes.append((class_id, n_examples))
    else:
      self._filtered_class_set.append(class_id)
  self.num_filtered_classes = len(self._filtered_class_set)

  if skipped_classes:
    logging.info(
        'Skipping the following classes, which do not have at least '
        '%d examples', self.min_examples_in_class)
  for class_id, n_examples in skipped_classes:
    logging.info('%s (ID=%d, %d examples)',
                 dataset_spec.class_names[class_id], class_id, n_examples)

  if self.min_ways and self.num_filtered_classes < self.min_ways:
    raise ValueError(
        '"min_ways" is set to {}, but split {} of dataset {} only has {} '
        'classes with at least {} examples ({} total), so it is not possible '
        'to create an episode for it. This may have resulted from applying a '
        'restriction on this split of this dataset by specifying '
        'benchmark.restrict_classes or benchmark.min_examples_in_class.'
        .format(self.min_ways, split, dataset_spec.name,
                self.num_filtered_classes, self.min_examples_in_class,
                self.num_classes))

  if self.use_all_classes:
    if self.num_classes != self.num_filtered_classes:
      raise ValueError('"use_all_classes" is not compatible with a value of '
                       '"min_examples_in_class" ({}) that results in some '
                       'classes being excluded.'.format(
                           self.min_examples_in_class))
    self.num_ways = self.num_classes

  # Maybe overwrite use_dag_hierarchy or use_bilevel_hierarchy if requested.
  if episode_descr_config.ignore_dag_ontology:
    self.use_dag_hierarchy = False
  if episode_descr_config.ignore_bilevel_ontology:
    self.use_bilevel_hierarchy = False

  # For Omniglot.
  if self.use_bilevel_hierarchy:
    if self.num_ways is not None:
      raise ValueError('"use_bilevel_hierarchy" is incompatible with '
                       '"num_ways".')
    if self.min_examples_in_class > 0:
      raise ValueError('"use_bilevel_hierarchy" is incompatible with '
                       '"min_examples_in_class".')

    if not isinstance(dataset_spec,
                      dataset_spec_lib.BiLevelDatasetSpecification):
      raise ValueError('Only applicable to datasets with a bi-level '
                       'dataset specification.')
    # The id's of the superclasses of the split (a contiguous range of ints).
    all_superclasses = dataset_spec.get_superclasses(self.split)
    self.superclass_set = []
    for i in all_superclasses:
      # Every superclass must be able to provide at least `min_ways` classes.
      if self.dataset_spec.classes_per_superclass[i] < self.min_ways:
        raise ValueError(
            'Superclass: %d has num_classes=%d < min_ways=%d.' %
            (i, self.dataset_spec.classes_per_superclass[i], self.min_ways))
      self.superclass_set.append(i)

  # For ImageNet.
  elif self.use_dag_hierarchy:
    if self.num_ways is not None:
      raise ValueError('"use_dag_hierarchy" is incompatible with "num_ways".')

    if not isinstance(dataset_spec,
                      dataset_spec_lib.HierarchicalDatasetSpecification):
      raise ValueError('Only applicable to datasets with a hierarchical '
                       'dataset specification.')

    # A DAG for navigating the ontology for the given split.
    graph = dataset_spec.get_split_subgraph(self.split)

    # Map the absolute class IDs in the split's class set to IDs relative to
    # the split.
    class_set = self.class_set
    abs_to_rel_ids = {abs_id: i for i, abs_id in enumerate(class_set)}

    # Extract the sets of leaves and internal nodes in the DAG.
    leaves = set(imagenet_specification.get_leaves(graph))
    internal_nodes = graph - leaves  # set difference

    # Map each node of the DAG to the Synsets of the leaves it spans.
    spanning_leaves_dict = imagenet_specification.get_spanning_leaves(graph)

    # Membership in the filtered classes is tested once per (node, leaf)
    # pair, so build a set once rather than scanning the list every time.
    eligible_class_ids = frozenset(self._filtered_class_set)

    # Build a list of lists storing the relative class IDs of the spanning
    # leaves for each eligible internal node.
    self.span_leaves_rel = []
    for node in internal_nodes:
      node_leaves = spanning_leaves_dict[node]
      # Build a list of relative class IDs of leaves that have at least
      # min_examples_in_class examples.
      ids_rel = []
      for leaf in node_leaves:
        abs_id = dataset_spec.class_names_to_ids[leaf.wn_id]
        if abs_id in eligible_class_ids:
          ids_rel.append(abs_to_rel_ids[abs_id])

      # Internal nodes are eligible if they span at least `min_ways` and at
      # most `MAX_SPANNING_LEAVES_ELIGIBLE` (filtered) leaves.
      if self.min_ways <= len(ids_rel) <= MAX_SPANNING_LEAVES_ELIGIBLE:
        self.span_leaves_rel.append(ids_rel)

    num_eligible_nodes = len(self.span_leaves_rel)
    if num_eligible_nodes < 1:
      raise ValueError('There are no classes eligible for participating in '
                       'episodes. Consider changing the value of '
                       '`EpisodeDescriptionSampler.min_ways` in gin, or '
                       'MAX_SPANNING_LEAVES_ELIGIBLE in data.py.')
def __init__(self,
             dataset_spec,
             split,
             pool=None,
             use_dag_hierarchy=False,
             use_bilevel_hierarchy=False,
             use_all_classes=False,
             num_ways=None,
             num_support=None,
             num_query=None,
             min_ways=None,
             max_ways_upper_bound=None,
             max_num_query=None,
             max_support_set_size=None,
             max_support_size_contrib_per_class=None,
             min_log_weight=None,
             max_log_weight=None):
  """Initializes an EpisodeDescriptionSampler.

  Args:
    dataset_spec: DatasetSpecification, dataset specification.
    split: one of Split.TRAIN, Split.VALID, or Split.TEST.
    pool: A string ('train' or 'test') or None, indicating which
      example-level split to select, if the current dataset has them.
    use_dag_hierarchy: Boolean, defaults to False. If a DAG-structured
      ontology is defined in dataset_spec, use it to choose related classes.
    use_bilevel_hierarchy: Boolean, defaults to False. If a bi-level ontology
      is defined in dataset_spec, use it for sampling classes.
    use_all_classes: Boolean, defaults to False. Uses all available classes,
      in order, instead of sampling. Overrides `num_ways` to the number of
      classes in `split`.
    num_ways: Integer (optional), fixes the number of classes ("ways") to be
      used in each episode if provided. Incompatible with using any
      hierarchy.
    num_support: Integer (optional), fixes the number of examples for each
      class in the support set if provided.
    num_query: Integer (optional), fixes the number of examples for each class
      in the query set if provided.
    min_ways: Integer, the minimum value when sampling ways (has to be
      provided if `num_ways` is None).
    max_ways_upper_bound: Integer, the maximum value when sampling ways (has
      to be provided if `num_ways` is None). Note that the number of
      available classes acts as another upper bound.
    max_num_query: Integer, the maximum number of query examples per class
      (has to be provided if `num_query` is None).
    max_support_set_size: Integer, the maximum size for the support set (has
      to be provided if `num_support` is None).
    max_support_size_contrib_per_class: Integer, the maximum contribution for
      any given class to the support set size (has to be provided if
      `num_support` is None).
    min_log_weight: Float, the minimum log-weight to give to any particular
      class when determining the number of support examples per class (has to
      be provided if `num_support` is None).
    max_log_weight: Float, the maximum log-weight to give to any particular
      class (has to be provided if `num_support` is None).

  Raises:
    RuntimeError: if required parameters are missing.
    ValueError: Inconsistent parameters.
  """
  # Each entry maps the name of a "fixed" argument to a triple of: its value,
  # the names of the arguments required when it is None, and their values.
  arg_groups = {
      'num_ways': (num_ways, ('min_ways', 'max_ways_upper_bound'),
                   (min_ways, max_ways_upper_bound)),
      'num_query': (num_query, ('max_num_query',), (max_num_query,)),
      'num_support':
          (num_support,
           ('max_support_set_size', 'max_support_size_contrib_per_class',
            'min_log_weight', 'max_log_weight'),
           (max_support_set_size, max_support_size_contrib_per_class,
            min_log_weight, max_log_weight)),
  }

  for first_arg_name, values in arg_groups.items():
    first_arg, required_arg_names, required_args = values
    if first_arg is None and any(arg is None for arg in required_args):
      # Collect the names of the required arguments that are missing.
      none_arg_names = [
          name for var, name in zip(required_args, required_arg_names)
          if var is None
      ]
      raise RuntimeError(
          'The following arguments: %s can not be None, since %s is None. '
          'Arguments can be set up with gin, for instance by providing '
          '`--gin_file=learn/gin/setups/data_config.gin` or calling '
          '`gin.parse_config_file(...)` in the code. Please ensure the '
          'following gin arguments of EpisodeDescriptionSampler are set: '
          '%s' % (none_arg_names, first_arg_name, none_arg_names))

  self.dataset_spec = dataset_spec
  self.split = split
  self.pool = pool
  self.use_dag_hierarchy = use_dag_hierarchy
  self.use_bilevel_hierarchy = use_bilevel_hierarchy
  self.use_all_classes = use_all_classes
  self.num_ways = num_ways
  self.num_support = num_support
  self.num_query = num_query
  # Gin parameters
  self.min_ways = min_ways
  self.max_ways_upper_bound = max_ways_upper_bound
  self.max_num_query = max_num_query
  self.max_support_set_size = max_support_set_size
  self.max_support_size_contrib_per_class = max_support_size_contrib_per_class
  self.min_log_weight = min_log_weight
  self.max_log_weight = max_log_weight

  self.class_set = dataset_spec.get_classes(self.split)
  self.num_classes = len(self.class_set)

  if self.use_all_classes:
    self.num_ways = self.num_classes

  # For Omniglot.
  if self.use_bilevel_hierarchy:
    if self.num_ways is not None:
      raise ValueError('"use_bilevel_hierarchy" is incompatible with '
                       '"num_ways".')

    if not isinstance(dataset_spec,
                      dataset_spec_lib.BiLevelDatasetSpecification):
      raise ValueError('Only applicable to datasets with a bi-level '
                       'dataset specification.')
    # The id's of the superclasses of the split (a contiguous range of ints).
    self.superclass_set = dataset_spec.get_superclasses(self.split)

  # For ImageNet.
  elif self.use_dag_hierarchy:
    if self.num_ways is not None:
      raise ValueError('"use_dag_hierarchy" is incompatible with "num_ways".')

    if not isinstance(dataset_spec,
                      dataset_spec_lib.HierarchicalDatasetSpecification):
      raise ValueError('Only applicable to datasets with a hierarchical '
                       'dataset specification.')

    # A DAG for navigating the ontology for the given split.
    graph = dataset_spec.get_split_subgraph(self.split)

    # Map the absolute class IDs in the split's class set to IDs relative to
    # the split. The class set was already fetched above, so reuse it rather
    # than querying the dataset specification a second time.
    class_set = self.class_set
    abs_to_rel_ids = {abs_id: i for i, abs_id in enumerate(class_set)}

    # Extract the sets of leaves and internal nodes in the DAG.
    leaves = set(imagenet_specification.get_leaves(graph))
    internal_nodes = graph - leaves  # set difference

    # Map each node of the DAG to the Synsets of the leaves it spans.
    spanning_leaves_dict = imagenet_specification.get_spanning_leaves(graph)

    # Build a list of lists storing the relative class IDs of the spanning
    # leaves for each eligible internal node.
    self.span_leaves_rel = []
    for node in internal_nodes:
      node_leaves = spanning_leaves_dict[node]
      # Internal nodes are eligible if they span at least `min_ways` and at
      # most `MAX_SPANNING_LEAVES_ELIGIBLE` leaves.
      if self.min_ways <= len(node_leaves) <= MAX_SPANNING_LEAVES_ELIGIBLE:
        # Build a list of relative class IDs for this internal node.
        ids = [dataset_spec.class_names_to_ids[s.wn_id] for s in node_leaves]
        ids_rel = [abs_to_rel_ids[abs_id] for abs_id in ids]
        self.span_leaves_rel.append(ids_rel)

    num_eligible_nodes = len(self.span_leaves_rel)
    if num_eligible_nodes < 1:
      raise ValueError('There are no classes eligible for participating in '
                       'episodes. Consider changing the value of '
                       '`EpisodeDescriptionSampler.min_ways` in gin, or '
                       'MAX_SPANNING_LEAVES_ELIGIBLE in data.py.')
def __init__(self,
             dataset_spec,
             split,
             episode_descr_config,
             pool=None,
             use_dag_hierarchy=False,
             use_bilevel_hierarchy=False,
             use_all_classes=False):
  """Initializes an EpisodeDescriptionSampler from an episode config.

  Args:
    dataset_spec: DatasetSpecification, dataset specification.
    split: one of Split.TRAIN, Split.VALID, or Split.TEST.
    episode_descr_config: An instance of EpisodeDescriptionConfig containing
      parameters relating to sampling shots and ways for episodes.
    pool: A string ('train' or 'test') or None, indicating which
      example-level split to select, if the current dataset has them.
    use_dag_hierarchy: Boolean, defaults to False. If a DAG-structured
      ontology is defined in dataset_spec, use it to choose related classes.
      Ignored when `episode_descr_config.ignore_dag_ontology` is set.
    use_bilevel_hierarchy: Boolean, defaults to False. If a bi-level ontology
      is defined in dataset_spec, use it for sampling classes. Ignored when
      `episode_descr_config.ignore_bilevel_ontology` is set.
    use_all_classes: Boolean, defaults to False. Uses all available classes,
      in order, instead of sampling. Overrides `num_ways` to the number of
      classes in `split`.

  Raises:
    ValueError: Inconsistent parameters: a hierarchy is combined with
      `num_ways`, the dataset specification has the wrong type for the
      requested hierarchy, or no internal node of the DAG is eligible.
  """
  self.dataset_spec = dataset_spec
  self.split = split
  self.pool = pool
  self.use_dag_hierarchy = use_dag_hierarchy
  self.use_bilevel_hierarchy = use_bilevel_hierarchy
  self.use_all_classes = use_all_classes
  self.num_ways = episode_descr_config.num_ways
  self.num_support = episode_descr_config.num_support
  self.num_query = episode_descr_config.num_query
  self.min_ways = episode_descr_config.min_ways
  self.max_ways_upper_bound = episode_descr_config.max_ways_upper_bound
  self.max_num_query = episode_descr_config.max_num_query
  self.max_support_set_size = episode_descr_config.max_support_set_size
  self.max_support_size_contrib_per_class = (
      episode_descr_config.max_support_size_contrib_per_class)
  self.min_log_weight = episode_descr_config.min_log_weight
  self.max_log_weight = episode_descr_config.max_log_weight

  self.class_set = dataset_spec.get_classes(self.split)
  self.num_classes = len(self.class_set)

  if self.use_all_classes:
    self.num_ways = self.num_classes

  # Maybe overwrite use_dag_hierarchy or use_bilevel_hierarchy if requested.
  if episode_descr_config.ignore_dag_ontology:
    self.use_dag_hierarchy = False
  if episode_descr_config.ignore_bilevel_ontology:
    self.use_bilevel_hierarchy = False

  # For Omniglot.
  if self.use_bilevel_hierarchy:
    if self.num_ways is not None:
      raise ValueError('"use_bilevel_hierarchy" is incompatible with '
                       '"num_ways".')

    if not isinstance(dataset_spec,
                      dataset_spec_lib.BiLevelDatasetSpecification):
      raise ValueError('Only applicable to datasets with a bi-level '
                       'dataset specification.')
    # The id's of the superclasses of the split (a contiguous range of ints).
    self.superclass_set = dataset_spec.get_superclasses(self.split)

  # For ImageNet.
  elif self.use_dag_hierarchy:
    if self.num_ways is not None:
      raise ValueError('"use_dag_hierarchy" is incompatible with "num_ways".')

    if not isinstance(dataset_spec,
                      dataset_spec_lib.HierarchicalDatasetSpecification):
      raise ValueError('Only applicable to datasets with a hierarchical '
                       'dataset specification.')

    # A DAG for navigating the ontology for the given split.
    graph = dataset_spec.get_split_subgraph(self.split)

    # Map the absolute class IDs in the split's class set to IDs relative to
    # the split. The class set was already fetched above, so reuse it rather
    # than querying the dataset specification a second time.
    class_set = self.class_set
    abs_to_rel_ids = {abs_id: i for i, abs_id in enumerate(class_set)}

    # Extract the sets of leaves and internal nodes in the DAG.
    leaves = set(imagenet_specification.get_leaves(graph))
    internal_nodes = graph - leaves  # set difference

    # Map each node of the DAG to the Synsets of the leaves it spans.
    spanning_leaves_dict = imagenet_specification.get_spanning_leaves(graph)

    # Build a list of lists storing the relative class IDs of the spanning
    # leaves for each eligible internal node.
    self.span_leaves_rel = []
    for node in internal_nodes:
      node_leaves = spanning_leaves_dict[node]
      # Internal nodes are eligible if they span at least `min_ways` and at
      # most `MAX_SPANNING_LEAVES_ELIGIBLE` leaves.
      if self.min_ways <= len(node_leaves) <= MAX_SPANNING_LEAVES_ELIGIBLE:
        # Build a list of relative class IDs for this internal node.
        ids = [dataset_spec.class_names_to_ids[s.wn_id] for s in node_leaves]
        ids_rel = [abs_to_rel_ids[abs_id] for abs_id in ids]
        self.span_leaves_rel.append(ids_rel)

    num_eligible_nodes = len(self.span_leaves_rel)
    if num_eligible_nodes < 1:
      raise ValueError('There are no classes eligible for participating in '
                       'episodes. Consider changing the value of '
                       '`EpisodeDescriptionSampler.min_ways` in gin, or '
                       'MAX_SPANNING_LEAVES_ELIGIBLE in data.py.')