コード例 #1
0
ファイル: test_pipeline.py プロジェクト: vsarmien/turicreate
    def test_pipeline_classifier_creation(self):
        """Build a single-model PipelineClassifier and inspect the stored spec.

        The pipeline is created from ``self.scikit_data.feature_names`` (with
        ``[1, 0]`` as the second constructor argument) and receives
        ``self.libsvm_spec`` as its only sub-model. The test then checks that
        the sub-model's description predicts ``'target'``, has exactly one
        int64 output named ``'target'``, and has one double-typed input per
        feature name.
        """
        input_names = self.scikit_data.feature_names
        p_classifier = PipelineClassifier(input_names, [1, 0])
        p_classifier.add_model(self.libsvm_spec)

        self.assertIsNotNone(p_classifier.spec)
        models = p_classifier.spec.pipelineClassifier.pipeline.models
        self.assertEqual(len(models), 1)

        # Test the model class of the svm model
        spec = models[0]
        self.assertIsNotNone(spec.description)

        # Test the interface class
        self.assertEqual(spec.description.predictedFeatureName, 'target')

        # Test the inputs and outputs
        self.assertEqual(len(spec.description.output), 1)
        self.assertEqual(spec.description.output[0].name, 'target')
        self.assertEqual(spec.description.output[0].type.WhichOneof('Type'),
                         'int64Type')

        # Every input feature must be double-typed, and the set of input
        # names must match the feature names (order-insensitive).
        input_features = spec.description.input
        for feature in input_features:
            self.assertEqual(feature.type.WhichOneof('Type'), 'doubleType')
        self.assertEqual(sorted(input_names),
                         sorted(feature.name for feature in input_features))
コード例 #2
0
    def test_pipeline_classifier_set_training_inputs(self):
        """Exercise make_updatable() when training inputs come from
        ``set_training_input`` called after construction.

        make_updatable() is expected to raise ValueError, leaving
        ``spec.isUpdatable`` False, while the pipeline has no sub-models and
        while its only sub-model is not updatable. After an updatable
        sub-model is appended it succeeds, the two training inputs are
        present in order, and a further add_model() raises ValueError.
        """
        base = self.create_base_builder()
        base.spec.isUpdatable = False
        features_for_training = [("input", datatypes.Array(3)),
                                 ("target", "String")]

        pipeline = PipelineClassifier(self.input_features, self.output_names)
        pipeline.set_training_input(features_for_training)

        def expect_make_updatable_rejected():
            # make_updatable() must raise and must not flip the flag.
            with self.assertRaises(ValueError):
                pipeline.make_updatable()
            self.assertEqual(pipeline.spec.isUpdatable, False)

        # No sub-models have been added yet.
        expect_make_updatable_rejected()

        # The only sub-model present is not updatable.
        pipeline.add_model(base.spec)
        expect_make_updatable_rejected()

        # Append an updatable sub-model; make_updatable() now succeeds.
        base.spec.isUpdatable = True
        pipeline.add_model(base.spec)
        self.assertEqual(pipeline.spec.isUpdatable, False)
        pipeline.make_updatable()
        self.assertEqual(pipeline.spec.isUpdatable, True)

        # Training inputs are checked positionally: name first, then type.
        recorded = pipeline.spec.description.trainingInput
        expected_inputs = [("input", "multiArrayType"),
                           ("target", "stringType")]
        for position, (expected_name, expected_kind) in enumerate(
                expected_inputs):
            self.assertEqual(recorded[position].name, expected_name)
            self.assertEqual(recorded[position].type.WhichOneof("Type"),
                             expected_kind)

        # Once the pipeline is updatable, adding another sub-model is refused.
        with self.assertRaises(ValueError):
            pipeline.add_model(base.spec)
        self.assertEqual(pipeline.spec.isUpdatable, True)
コード例 #3
0
    def test_pipeline_classifier_make_updatable(self):
        """Exercise make_updatable() when training inputs are passed to the
        constructor via ``training_features``.

        make_updatable() is expected to raise ValueError, leaving
        ``spec.isUpdatable`` False, while the pipeline has no sub-models and
        while its only sub-model is not updatable. After an updatable
        sub-model is appended it succeeds, the two training inputs are
        present in order, and a further add_model() raises ValueError.
        """
        base = self.create_base_builder()
        base.spec.isUpdatable = False
        features_for_training = [('input', datatypes.Array(3)),
                                 ('target', 'String')]

        pipeline = PipelineClassifier(self.input_features,
                                      self.output_names,
                                      training_features=features_for_training)

        def expect_make_updatable_rejected():
            # make_updatable() must raise and must not flip the flag.
            with self.assertRaises(ValueError):
                pipeline.make_updatable()
            self.assertEqual(pipeline.spec.isUpdatable, False)

        # No sub-models have been added yet.
        expect_make_updatable_rejected()

        # The only sub-model present is not updatable.
        pipeline.add_model(base.spec)
        expect_make_updatable_rejected()

        # Append an updatable sub-model; make_updatable() now succeeds.
        base.spec.isUpdatable = True
        pipeline.add_model(base.spec)
        self.assertEqual(pipeline.spec.isUpdatable, False)
        pipeline.make_updatable()
        self.assertEqual(pipeline.spec.isUpdatable, True)

        # Training inputs are checked positionally: name first, then type.
        recorded = pipeline.spec.description.trainingInput
        expected_inputs = [('input', 'multiArrayType'),
                           ('target', 'stringType')]
        for position, (expected_name, expected_kind) in enumerate(
                expected_inputs):
            self.assertEqual(recorded[position].name, expected_name)
            self.assertEqual(recorded[position].type.WhichOneof('Type'),
                             expected_kind)

        # Once the pipeline is updatable, adding another sub-model is refused.
        with self.assertRaises(ValueError):
            pipeline.add_model(base.spec)
        self.assertEqual(pipeline.spec.isUpdatable, True)