def randomly_sample_things_with_attribute(tx, query, example_var_name, attribute_var_name, attribute_vals,
                                          sample_size_per_label, population_size):
    """Sample concepts and their labels separately for each attribute value.

    For every value in ``attribute_vals`` the ``query`` template is formatted
    with that value and handed to a ``ConceptLabelExtractor`` configured with
    ``random.random_sample`` as its sampling method.

    Args:
        tx: Transaction object passed straight through to the extractor.
        query: Format string; ``query.format(value)`` yields the query used
            for one attribute value.
        example_var_name: Variable name given to the extractor as the first
            element of its variable configuration.
        attribute_var_name: Variable name used both as the key in the
            extractor's label configuration and as the key read from each
            returned label mapping.
        attribute_vals: Iterable of attribute values to sample for; also
            passed to the extractor as the values for ``attribute_var_name``.
        sample_size_per_label: Forwarded to the extractor call.
        population_size: Forwarded to the extractor call.

    Returns:
        A ``(concepts, labels)`` tuple of dicts keyed by attribute value.
        ``concepts[value]`` is a list holding element 0 of every extracted
        item; ``labels[value]`` is a ``float32`` numpy array built from
        element 1 of every extracted item, indexed by ``attribute_var_name``.

    Raises:
        RuntimeError: If the extractor returns nothing for a value.
    """
    concepts, labels = {}, {}

    for attribute_val in attribute_vals:
        target_concept_query = query.format(attribute_val)

        # A fresh label configuration is built for every extractor instance.
        label_vars = collections.OrderedDict([(attribute_var_name, attribute_vals)])
        extractor = label_extraction.ConceptLabelExtractor(
            target_concept_query,
            (example_var_name, label_vars),
            sampling_method=random.random_sample)
        samples = extractor(tx, sample_size_per_label, population_size)

        if not len(samples):
            raise RuntimeError(
                f'Couldn\'t find any concepts to match target query "{target_concept_query}"'
            )

        # Split each extracted item into its concept (index 0) and the label
        # value stored under attribute_var_name (index 1).
        sampled_concepts = []
        sampled_labels = []
        for sample in samples:
            sampled_concepts.append(sample[0])
            sampled_labels.append(sample[1][attribute_var_name])

        concepts[attribute_val] = sampled_concepts
        # TODO Should this be an array not a list?
        labels[attribute_val] = np.array(sampled_labels, dtype=np.float32)

    return concepts, labels
def test_get_called_for_each_attribute_variable(self):
    # Running the extractor should look up the example variable and then
    # each attribute variable on the answer mock, in this order.
    extractor = label_extraction.ConceptLabelExtractor(self._query, self._vars_config)
    extractor(self._grakn_tx, 5, 5)

    expected_calls = [
        mock.call(var_name) for var_name in ('x', 'age_var', 'gender_var')
    ]
    self._answer_mock.get.assert_has_calls(expected_calls)
def test_output_format_as_expected(self):
    # The extractor should return a list of (concept, labels) tuples, where
    # labels maps each attribute variable name to a list of values.
    extractor = label_extraction.ConceptLabelExtractor(self._query, self._vars_config)
    actual_output = extractor(self._grakn_tx, 5, 5)

    expected_labels = {'age_var': [66], 'gender_var': [0, 1]}
    expected_output = [(self._person_mock, expected_labels)]

    self.assertListEqual(expected_output, actual_output)
def test_value_called_for_each_attribute(self):
    # Running the extractor should read the value of each attribute mock
    # exactly once.
    extractor = label_extraction.ConceptLabelExtractor(self._query, self._vars_config)
    extractor(self._grakn_tx, 5, 5)

    for attribute_mock in (self._mock_age, self._mock_gender):
        attribute_mock.value.assert_called_once()