Example #1
0
    def test_tsvm_get_param_names(self):
        """
        Verify the hyper-parameter names reported by ``get_params_names`` for
        the linear and the RBF TSVM estimators.

        The linear estimator is expected to report ``C1`` and ``C2``; the RBF
        estimator is expected to report ``gamma`` in addition.
        """

        linear_names = TSVM('linear').get_params_names()
        rbf_names = TSVM('RBF').get_params_names()

        expected = [['C1', 'C2'], ['C1', 'C2', 'gamma']]
        self.assertEqual([linear_names, rbf_names], expected)
Example #2
0
    def test_tsvm_set_get_params_rbf(self):
        """
        Round-trip check for the RBF TSVM: the hyper-parameters handed to
        ``set_params`` must be reported back by ``get_params``, together with
        the estimator's ``kernel`` and ``rect_kernel`` entries.
        """

        new_params = {'C1': 2, 'C2': 0.25, 'gamma': 0.125}

        # rect_kernel is never set here; the test expects it to be 1 on a
        # freshly constructed TSVM('RBF').
        expected_output = dict(new_params, kernel='RBF', rect_kernel=1)

        estimator = TSVM('RBF')
        estimator.set_params(**new_params)

        self.assertEqual(estimator.get_params(), expected_output,
                         "set_params and get_params output don't match")
Example #3
0
    def test_RBF_Validator_ttsplit(self):
        """
        Smoke test: evaluate a non-linear (RBF) TSVM with a train/test split
        for a single hyper-parameter setting.
        """
        rbf_tsvm = TSVM(kernel='RBF')
        split_method = ('t_t_split', self.train_set_size)
        validator = Validator(self.input.X_train, self.input.y_train,
                              split_method, rbf_tsvm)

        evaluate = validator.choose_validator()
        evaluate({'C1': 1, 'C2': 1, 'gamma': 1})
Example #4
0
    def test_linear_Validator_CV(self):
        """
        Smoke test: evaluate a linear TSVM (constructed with no arguments)
        using k-fold cross-validation for a single hyper-parameter setting.
        """
        linear_tsvm = TSVM()
        cv_method = ('CV', self.k_folds)
        validator = Validator(self.input.X_train, self.input.y_train,
                              cv_method, linear_tsvm)

        evaluate = validator.choose_validator()
        evaluate({'C1': 1, 'C2': 1})
Example #5
0
def initializer(user_input_obj):
    """
    It passes a user's input to the functions and classes for solving a
    classification task. The steps that this function performs can be summarized
    as follows:

    #. Specifies a TwinSVM classifier based on the user's input.
    #. Chooses an evaluation method for assessment of the classifier.
    #. Computes all the combinations of search elements.
    #. Computes the evaluation metrics for all the search elements using grid search.
    #. Saves the detailed classification results in a spreadsheet file (Excel).

    Parameters
    ----------
    user_input_obj : object
        An instance of :class:`UserInput` class which holds the user input.

    Returns
    -------
    object
        Whatever ``save_result`` returns for the grid-search results.

    Raises
    ------
    ValueError
        If ``user_input_obj.class_type`` is not one of ``'binary'``, ``'ovo'``
        or ``'ova'``.
    """

    class_type = user_input_obj.class_type

    if class_type == 'binary':
        tsvm_obj = TSVM(user_input_obj.kernel_type, user_input_obj.rect_kernel)
    elif class_type == 'ovo':
        tsvm_obj = OVO_TSVM(user_input_obj.kernel_type)
    elif class_type == 'ova':
        tsvm_obj = MCTSVM(user_input_obj.kernel_type)
    else:
        # Without this branch an unknown class type leaves tsvm_obj unbound
        # and the failure surfaces later as a confusing UnboundLocalError.
        raise ValueError("Unknown class_type %r; expected 'binary', 'ovo' "
                         "or 'ova'" % (class_type,))

    validate = Validator(user_input_obj.X_train, user_input_obj.y_train,
                         user_input_obj.test_method_tuple, tsvm_obj)

    search_elements = search_space(user_input_obj.kernel_type,
                                   user_input_obj.class_type,
                                   user_input_obj.lower_b_c,
                                   user_input_obj.upper_b_c,
                                   user_input_obj.lower_b_u,
                                   user_input_obj.upper_b_u)

    # Display headers
    print("%s-%s    Dataset: %s    Total Search Elements: %d" % (tsvm_obj.cls_name,
          user_input_obj.kernel_type, user_input_obj.filename, len(search_elements)))

    result = grid_search(search_elements, validate.choose_validator())

    try:
        return save_result(user_input_obj.filename, validate, result,
                           user_input_obj.result_path)
    except FileNotFoundError:
        # The output location does not exist yet: create the directory that
        # save_result was actually given (not a hard-coded 'result' folder)
        # and retry once. exist_ok avoids FileExistsError if it already exists.
        # NOTE(review): assumes result_path is the output directory that
        # save_result writes into — confirm against save_result.
        os.makedirs(user_input_obj.result_path, exist_ok=True)
        return save_result(user_input_obj.filename, validate, result,
                           user_input_obj.result_path)
Example #6
0
    def test_train_TSVM_RBF(self):
        """
        Smoke test: fit a TSVM with the RBF kernel (all other constructor
        arguments left at their defaults) on the training data, then predict
        on the same samples.
        """
        model = TSVM(kernel='RBF')
        features, labels = self.input.X_train, self.input.y_train

        model.fit(features, labels)
        model.predict(features)
Example #7
0
    def test_train_TSVM_linear(self):
        """
        Smoke test: fit a TSVM built with no constructor arguments (linear
        kernel) on the training data, then predict on the same samples.
        """
        model = TSVM()
        features, labels = self.input.X_train, self.input.y_train

        model.fit(features, labels)
        model.predict(features)