コード例 #1
0
def randomly_sample_things_with_attribute(tx, query, example_var_name,
                                          attribute_var_name, attribute_vals,
                                          sample_size_per_label,
                                          population_size):
    """Randomly sample concepts for each attribute value and collect labels.

    For every value ``a`` in ``attribute_vals``, the query ``query.format(a)``
    is handed to a ``label_extraction.ConceptLabelExtractor`` configured with
    ``sampling_method=random.random_sample``. The sampled concepts and their
    labels are gathered per value.

    Args:
        tx: Passed unchanged as the first argument of each extractor call
            (presumably a Grakn transaction — confirm against callers).
        query: Format template; each attribute value is substituted via
            ``query.format(value)``.
        example_var_name: Query variable name of the concept to sample.
        attribute_var_name: Query variable name whose extracted label is
            collected for every sampled concept.
        attribute_vals: Attribute values to iterate over. The whole collection
            is also given to the extractor as the label values configured
            for ``attribute_var_name``.
        sample_size_per_label: Sample size passed to the extractor for each
            attribute value.
        population_size: Population size passed to the extractor.

    Returns:
        A ``(concepts, labels)`` pair of dicts, both keyed by attribute value.
        ``concepts[a]`` is the list of sampled concepts. ``labels[a]`` is a
        ``np.float32`` array of their labels for ``attribute_var_name``, in
        the same order.

    Raises:
        RuntimeError: If the extractor returns no concepts for some value.
    """
    concepts = {}
    labels = {}

    for a in attribute_vals:
        target_concept_query = query.format(a)

        # NOTE(review): the extractor is configured with *all* attribute values,
        # not just ``a``; presumably it uses them to encode each label (the
        # extractor tests show one-hot-like lists) — confirm.
        extractor = label_extraction.ConceptLabelExtractor(
            target_concept_query,
            (example_var_name,
             collections.OrderedDict([(attribute_var_name, attribute_vals)])),
            sampling_method=random.random_sample)
        concepts_with_labels = extractor(tx, sample_size_per_label,
                                         population_size)
        if not concepts_with_labels:
            raise RuntimeError(
                f'Couldn\'t find any concepts to match target query "{target_concept_query}"'
            )

        # Each element is a (concept, {var_name: label}) pair.
        concepts[a] = [pair[0] for pair in concepts_with_labels]
        labels[a] = np.array(
            [pair[1][attribute_var_name] for pair in concepts_with_labels],
            dtype=np.float32)

    return concepts, labels
コード例 #2
0
 def test_get_called_for_each_attribute_variable(self):
     """One extraction calls ``get`` with 'x', 'age_var', 'gender_var' in sequence."""
     expected_calls = [
         mock.call(var_name) for var_name in ('x', 'age_var', 'gender_var')
     ]

     extractor = label_extraction.ConceptLabelExtractor(
         self._query, self._vars_config)
     extractor(self._grakn_tx, 5, 5)

     self._answer_mock.get.assert_has_calls(expected_calls)
コード例 #3
0
    def test_output_format_as_expected(self):
        """The extractor returns a list of (concept, {var_name: label}) tuples."""
        expected_output = [
            (self._person_mock, {'age_var': [66], 'gender_var': [0, 1]}),
        ]

        extractor = label_extraction.ConceptLabelExtractor(
            self._query, self._vars_config)
        actual_output = extractor(self._grakn_tx, 5, 5)

        self.assertListEqual(expected_output, actual_output)
コード例 #4
0
 def test_value_called_for_each_attribute(self):
     """One extraction calls ``value`` exactly once on each attribute mock."""
     label_extraction.ConceptLabelExtractor(
         self._query, self._vars_config)(self._grakn_tx, 5, 5)

     for attribute_mock in (self._mock_age, self._mock_gender):
         attribute_mock.value.assert_called_once()