def test_catch22_forest_classifier_on_basic_motions():
    """Check Catch22ForestClassifier probabilities on BasicMotions.

    A seeded permutation of the positions 0..19 selects which training cases
    the forest is fitted on and which test cases it predicts. The predicted
    probabilities must match the stored reference values exactly.
    """
    # Load both splits of the BasicMotions dataset (train first, then test).
    train_X, train_y = load_basic_motions(split="train", return_X_y=True)
    test_X, _ = load_basic_motions(split="test", return_X_y=True)

    # Fixed seed so the selected subset (and its order) is reproducible.
    subset = np.random.RandomState(0).permutation(20)

    # Fit a seeded Catch22 forest on the selected training cases.
    clf = Catch22ForestClassifier(random_state=0)
    clf.fit(train_X.iloc[subset], train_y[subset])

    # Predicted probabilities must equal the stored reference exactly.
    predicted = clf.predict_proba(test_X.iloc[subset])
    testing.assert_array_equal(
        predicted, catch22_forest_classifier_basic_motions_probas
    )
def set_classifier(cls, resampleId=None):
    """
    Basic way of creating the classifier to build using the default settings.

    This set up is to help with batch jobs for multiple problems to facilitate
    easy reproducibility. You can set up bespoke classifiers in many other
    ways.

    The lookup is case-insensitive: ``cls`` is lower-cased before it is
    compared against each classifier's short and long aliases.

    :param cls: String indicating which classifier you want
    :param resampleId: classifier random seed. It is passed as
        ``random_state`` to every classifier that is built with one; the kNN
        variants (dtwcv / dtw / msm), ElasticEnsemble, ShapeDTW and
        MrSEQLClassifier are constructed without it.
    :return: A classifier.
    :raises Exception: if ``cls`` does not name a known classifier.
    """
    name = cls.lower()

    # Distance based
    if name in ("pf", "proximityforest"):
        return ProximityForest(random_state=resampleId)
    elif name in ("pt", "proximitytree"):
        return ProximityTree(random_state=resampleId)
    elif name in ("ps", "proximitystump"):
        # Alias must be lower-case because `name` is lower-cased above; the
        # former "proximityStump" could never match.
        return ProximityStump(random_state=resampleId)
    elif name in ("dtwcv", "kneighborstimeseriesclassifier"):
        return KNeighborsTimeSeriesClassifier(distance="dtwcv")
    elif name in ("dtw", "1nn-dtw"):
        return KNeighborsTimeSeriesClassifier(distance="dtw")
    elif name in ("msm", "1nn-msm"):
        return KNeighborsTimeSeriesClassifier(distance="msm")
    elif name in ("ee", "elasticensemble"):
        return ElasticEnsemble()
    elif name == "shapedtw":
        return ShapeDTW()

    # Dictionary based
    elif name in ("boss", "bossensemble"):
        return BOSSEnsemble(random_state=resampleId)
    elif name in ("cboss", "contractableboss"):
        return ContractableBOSS(random_state=resampleId)
    elif name in ("tde", "temporaldictionaryensemble"):
        return TemporalDictionaryEnsemble(random_state=resampleId)
    elif name == "weasel":
        return WEASEL(random_state=resampleId)
    elif name == "muse":
        return MUSE(random_state=resampleId)

    # Interval based
    elif name in ("rise", "randomintervalspectralforest"):
        return RandomIntervalSpectralForest(random_state=resampleId)
    elif name in ("tsf", "timeseriesforestclassifier"):
        return TimeSeriesForestClassifier(random_state=resampleId)
    elif name in ("cif", "canonicalintervalforest"):
        return CanonicalIntervalForest(random_state=resampleId)
    elif name == "drcif":
        return DrCIF(random_state=resampleId)

    # Shapelet based
    elif name in ("stc", "shapelettransformclassifier"):
        # Contract the shapelet search to one minute per build.
        return ShapeletTransformClassifier(
            random_state=resampleId, time_contract_in_mins=1
        )
    elif name in ("mrseql", "mrseqlclassifier"):
        return MrSEQLClassifier(seql_mode="fs", symrep=["sax", "sfa"])
    elif name == "rocket":
        return ROCKETClassifier(random_state=resampleId)
    elif name == "arsenal":
        return Arsenal(random_state=resampleId)

    # Hybrid
    elif name == "catch22":
        return Catch22ForestClassifier(random_state=resampleId)
    elif name == "hivecotev1":
        return HIVECOTEV1(random_state=resampleId)

    raise Exception("UNKNOWN CLASSIFIER")