Example #1
0
def get_synsets_from_class_ids(class_ids):
    """Maps episode class ids to Synsets of the fine-grainedness split subgraph.

    Steps:
      1. Load the ImageNet ('ilsvrc_2012') DatasetSpecification from
         FLAGS.records_root_dir.
      2. Translate each class id into its WordNet id through the spec's
         class_names.
      3. Look up the matching Synsets in the subgraph of the split chosen for
         the fine-grainedness analysis.

    Args:
      class_ids: A np.array (or other iterable) of integer class ids chosen for
        an episode. Each entry is used as a key into the dataset spec's
        class_names.

    Returns:
      A list of Synsets, ordered like class_ids.

    Raises:
      ValueError: The dataset specification is not found in the expected
        location (presumably raised by dataset_spec.load_dataset_spec).
    """
    imagenet_records_dir = os.path.join(FLAGS.records_root_dir, 'ilsvrc_2012')
    spec = dataset_spec.load_dataset_spec(imagenet_records_dir)

    # The Synsets making up the subgraph of the split under analysis.
    subgraph = spec.split_subgraphs[get_finegrainedness_split_enum()]

    # Class ids -> WordNet ids (strings such as 'n02075296').
    wordnet_ids = [spec.class_names[cid] for cid in class_ids]

    # The lookup returns a mapping keyed by WordNet id; index it in
    # wordnet_ids order so the result lines up with class_ids.
    id_to_synset = imagenet_spec.get_synsets_from_ids(wordnet_ids, subgraph)
    return [id_to_synset[wid] for wid in wordnet_ids]
Example #2
0
 def test_on_toy_graph(self):
   """Checks that get_synsets_from_ids maps exactly the requested toy ids."""
   toy_graph, _, _ = create_toy_graph()
   requested_ids = [5, 0, 6]
   id_to_synset = imagenet_spec.get_synsets_from_ids(requested_ids, toy_graph)
   # The returned mapping must be keyed by exactly the requested ids.
   self.assertSetEqual(set(id_to_synset), set(requested_ids))
   # Each key must point at the Synset that carries that same wn_id.
   for requested_id in id_to_synset:
     self.assertEqual(requested_id, id_to_synset[requested_id].wn_id)