def test_infer_ten_classes():
    """One-hot targets over ten classes should be inferred as ten classes."""
    clf_analyser = output_analysers.ClassificationAnalyser(name="a")
    one_hot_batches = utils.generate_one_hot_labels(
        dtype="dataset", num_classes=10
    )

    for batch in one_hot_batches:
        clf_analyser.update(batch)
    clf_analyser.finalize()

    assert clf_analyser.num_classes == 10
def test_multi_label_two_classes_has_two_columns():
    """Two-column multi-label targets are treated as encoded, with two classes.

    The original test only checked ``encoded``; the test name also promises
    the two-column / two-class outcome, so pin the inferred class count too.
    """
    analyser = output_analysers.ClassificationAnalyser(name="a", multi_label=True)
    dataset = tf.data.Dataset.from_tensor_slices(np.random.rand(10, 2)).batch(32)

    for data in dataset:
        analyser.update(data)
    analyser.finalize()

    assert analyser.encoded
    # The two target columns should be reported as two classes.
    assert analyser.num_classes == 2
def test_specify_two_classes_fit_single_column():
    """An explicit ``num_classes=2`` is accepted for a single target column."""
    clf_analyser = output_analysers.ClassificationAnalyser(name="a", num_classes=2)
    single_column = np.random.rand(10, 1)
    batches = tf.data.Dataset.from_tensor_slices(single_column).batch(32)

    for batch in batches:
        clf_analyser.update(batch)
    clf_analyser.finalize()

    assert clf_analyser.num_classes == 2
def test_infer_single_column_two_classes():
    """A single column holding labels 0 and 1 should be inferred as two classes.

    The labels are fixed rather than drawn with ``np.random.randint(0, 2, 10)``.
    A random draw is all zeros or all ones with probability 2/1024. In that
    case the analyser sees only one distinct label, so the two-class
    expectation cannot hold and the test fails intermittently.
    """
    analyser = output_analysers.ClassificationAnalyser(name="a")
    # Both classes are guaranteed to be present; same dtype/shape as randint.
    labels = np.array([0, 1] * 5)
    dataset = tf.data.Dataset.from_tensor_slices(labels).batch(32)

    for data in dataset:
        analyser.update(data)
    analyser.finalize()

    assert analyser.num_classes == 2
def test_wrong_num_classes_error():
    """Three target columns conflict with a declared ``num_classes=5``."""
    clf_analyser = output_analysers.ClassificationAnalyser(name="a", num_classes=5)
    three_columns = np.random.rand(10, 3)
    batches = tf.data.Dataset.from_tensor_slices(three_columns).batch(32)

    with pytest.raises(ValueError) as excinfo:
        for batch in batches:
            clf_analyser.update(batch)
        clf_analyser.finalize()

    error_text = str(excinfo.value)
    assert "Expect the target data for a to have shape" in error_text
def test_clf_head_one_hot_shape_error():
    """Ten-wide one-hot targets conflict with a declared ``num_classes=9``."""
    clf_analyser = output_analysers.ClassificationAnalyser(name="a", num_classes=9)
    one_hot = utils.generate_one_hot_labels(dtype="np", num_classes=10)
    batches = tf.data.Dataset.from_tensor_slices(one_hot).batch(32)

    with pytest.raises(ValueError) as excinfo:
        for batch in batches:
            clf_analyser.update(batch)
        clf_analyser.finalize()

    error_text = str(excinfo.value)
    assert "Expect the target data for a to have shape" in error_text
def test_one_class_error():
    """Targets containing a single distinct label must be rejected."""
    clf_analyser = output_analysers.ClassificationAnalyser(name="a")
    same_label = np.array(["a", "a", "a"])
    batches = tf.data.Dataset.from_tensor_slices(same_label).batch(32)

    with pytest.raises(ValueError) as excinfo:
        for batch in batches:
            clf_analyser.update(batch)
        clf_analyser.finalize()

    error_text = str(excinfo.value)
    assert "Expect the target data for a to have at least 2 classes" in error_text