Example #1
0
def test_mrseql_on_gunpoint():
    """Check MrSEQL feature counts and accuracy on the ``load_gunpoint`` data.

    Three classifiers are built in feature-selection mode (``seql_mode='fs'``):
    one using only the SAX representation, one using only SFA, and one using
    both. The test asserts that the combined classifier's coefficient matrix
    has as many columns as the two single-representation ones together, and
    that each classifier gets the expected number of test cases right.
    """
    # train and test splits
    X_train, y_train = load_gunpoint(split='train', return_X_y=True)
    X_test, y_test = load_gunpoint(split='test', return_X_y=True)

    # one classifier per symbolic-representation setup, keyed by a short label
    representations = {
        'sax': ['sax'],
        'sfa': ['sfa'],
        'ss': ['sax', 'sfa'],
    }
    classifiers = {
        label: MrSEQLClassifier(seql_mode='fs', symrep=symrep)
        for label, symrep in representations.items()
    }

    # fit every classifier before predicting with any of them
    for clf in classifiers.values():
        clf.fit(X_train, y_train)

    predictions = {
        label: clf.predict(X_test) for label, clf in classifiers.items()
    }

    # feature space dimension: the multi-domain classifier ('ss') should
    # produce as many features as the 'sax' and 'sfa' classifiers combined
    n_features = {
        label: clf.ots_clf.coef_.shape[1] for label, clf in classifiers.items()
    }
    np.testing.assert_equal(
        n_features['ss'], n_features['sfa'] + n_features['sax'])

    # number of correct predictions per classifier
    expected_correct = {'sax': 148, 'sfa': 150, 'ss': 150}
    for label, n_correct in expected_correct.items():
        np.testing.assert_equal((predictions[label] == y_test).sum(), n_correct)
def set_classifier(cls, resampleId=None):
    """Construct a classifier from its name.

    Basic way of creating the classifier to build using the default settings. This
    set up is to help with batch jobs for multiple problems to facilitate easy
    reproducibility. You can set up bespoke classifier in many other ways.

    Parameters
    ----------
    cls : str
        Name of the classifier you want. The name is lower-cased before it is
        matched, so matching is case-insensitive.
    resampleId : default=None
        Classifier random seed. It is forwarded as ``random_state`` to every
        classifier built here except KNeighborsTimeSeriesClassifier,
        ElasticEnsemble, ShapeDTW and MrSEQLClassifier, which are constructed
        without it.

    Returns
    -------
    classifier
        A newly constructed classifier matching ``cls``.

    Raises
    ------
    Exception
        If ``cls`` does not match any of the known names.
    """
    # All aliases below must be written in lower case, because the requested
    # name is lower-cased before comparison.
    name = cls.lower()
    # Distance based
    if name in ("pf", "proximityforest"):
        return ProximityForest(random_state=resampleId)
    if name in ("pt", "proximitytree"):
        return ProximityTree(random_state=resampleId)
    if name in ("ps", "proximitystump"):
        return ProximityStump(random_state=resampleId)
    if name in ("dtwcv", "kneighborstimeseriesclassifier"):
        return KNeighborsTimeSeriesClassifier(distance="dtwcv")
    if name in ("dtw", "1nn-dtw"):
        return KNeighborsTimeSeriesClassifier(distance="dtw")
    if name in ("msm", "1nn-msm"):
        return KNeighborsTimeSeriesClassifier(distance="msm")
    if name in ("ee", "elasticensemble"):
        return ElasticEnsemble()
    if name == "shapedtw":
        return ShapeDTW()
    # Dictionary based
    if name in ("boss", "bossensemble"):
        return BOSSEnsemble(random_state=resampleId)
    if name in ("cboss", "contractableboss"):
        return ContractableBOSS(random_state=resampleId)
    if name in ("tde", "temporaldictionaryensemble"):
        return TemporalDictionaryEnsemble(random_state=resampleId)
    if name == "weasel":
        return WEASEL(random_state=resampleId)
    if name == "muse":
        return MUSE(random_state=resampleId)
    # Interval based
    if name in ("rise", "randomintervalspectralforest"):
        return RandomIntervalSpectralForest(random_state=resampleId)
    if name in ("tsf", "timeseriesforestclassifier"):
        return TimeSeriesForestClassifier(random_state=resampleId)
    if name in ("cif", "canonicalintervalforest"):
        return CanonicalIntervalForest(random_state=resampleId)
    if name == "drcif":
        return DrCIF(random_state=resampleId)
    # Shapelet based
    if name in ("stc", "shapelettransformclassifier"):
        return ShapeletTransformClassifier(random_state=resampleId,
                                           transform_contract_in_mins=60)
    if name in ("mrseql", "mrseqlclassifier"):
        return MrSEQLClassifier(seql_mode="fs", symrep=["sax", "sfa"])
    # Kernel based
    if name == "rocket":
        return ROCKETClassifier(random_state=resampleId)
    if name == "arsenal":
        return Arsenal(random_state=resampleId)
    # Feature based
    if name == "catch22":
        return Catch22ForestClassifier(random_state=resampleId)
    # Hybrid
    if name == "hivecotev1":
        return HIVECOTEV1(random_state=resampleId)
    raise Exception("UNKNOWN CLASSIFIER")
Example #3
0
def set_classifier(cls, resample_id=None):
    """Construct a classifier from its name.

    Basic way of creating the classifier to build using the default settings. This
    set up is to help with batch jobs for multiple problems to facilitate easy
    reproducibility for use with load_and_run_classification_experiment. You can pass a
    classifier object instead to run_classification_experiment.

    Parameters
    ----------
    cls : str
        String indicating which classifier you want. The name is lower-cased
        before it is matched, so matching is case-insensitive.
    resample_id : Random seed as an int, default=None
        Classifier random seed. It is forwarded as ``random_state`` to every
        classifier built here except KNeighborsTimeSeriesClassifier,
        ElasticEnsemble, ShapeDTW and MrSEQLClassifier, which are constructed
        without it.

    Returns
    -------
    classifier : A BaseClassifier.
        The classifier matching the input classifier name.

    Raises
    ------
    Exception
        If ``cls`` does not match any of the known names.
    """
    # All aliases below must be written in lower case, because the requested
    # name is lower-cased before comparison.
    name = cls.lower()
    # Dictionary based
    if name in ("boss", "bossensemble"):
        return BOSSEnsemble(random_state=resample_id)
    if name in ("cboss", "contractableboss"):
        return ContractableBOSS(random_state=resample_id)
    if name in ("tde", "temporaldictionaryensemble"):
        return TemporalDictionaryEnsemble(random_state=resample_id)
    if name == "weasel":
        return WEASEL(random_state=resample_id)
    if name == "muse":
        return MUSE(random_state=resample_id)
    # Distance based
    if name in ("pf", "proximityforest"):
        return ProximityForest(random_state=resample_id)
    if name in ("pt", "proximitytree"):
        return ProximityTree(random_state=resample_id)
    if name in ("ps", "proximitystump"):
        return ProximityStump(random_state=resample_id)
    if name in ("dtwcv", "kneighborstimeseriesclassifier"):
        return KNeighborsTimeSeriesClassifier(distance="dtwcv")
    if name in ("dtw", "1nn-dtw"):
        return KNeighborsTimeSeriesClassifier(distance="dtw")
    if name in ("msm", "1nn-msm"):
        return KNeighborsTimeSeriesClassifier(distance="msm")
    if name in ("ee", "elasticensemble"):
        return ElasticEnsemble()
    if name == "shapedtw":
        return ShapeDTW()
    # Feature based
    # NOTE(review): this section used to be labelled "Removed because of soft
    # dependency", yet the branches are live -- confirm that the classes used
    # here are actually imported by this module.
    if name == "catch22":
        return Catch22Classifier(
            random_state=resample_id,
            estimator=RandomForestClassifier(n_estimators=500))
    if name == "matrixprofile":
        return MatrixProfileClassifier(random_state=resample_id)
    if name == "signature":
        return SignatureClassifier(
            random_state=resample_id,
            classifier=RandomForestClassifier(n_estimators=500),
        )
    if name == "tsfresh":
        return TSFreshClassifier(
            random_state=resample_id,
            estimator=RandomForestClassifier(n_estimators=500))
    if name == "tsfresh-r":
        return TSFreshClassifier(
            random_state=resample_id,
            estimator=RandomForestClassifier(n_estimators=500),
            relevant_feature_extractor=True,
        )
    # Hybrid
    if name == "hivecotev1":
        return HIVECOTEV1(random_state=resample_id)
    # Interval based
    if name in ("rise", "randomintervalspectralforest"):
        return RandomIntervalSpectralForest(random_state=resample_id,
                                            n_estimators=500)
    if name in ("tsf", "timeseriesforestclassifier"):
        return TimeSeriesForestClassifier(random_state=resample_id,
                                          n_estimators=500)
    if name in ("cif", "canonicalintervalforest"):
        return CanonicalIntervalForest(random_state=resample_id,
                                       n_estimators=500)
    if name == "stsf":
        return SupervisedTimeSeriesForest(random_state=resample_id,
                                          n_estimators=500)
    if name == "drcif":
        return DrCIF(random_state=resample_id, n_estimators=500)
    # Kernel based
    if name == "rocket":
        return ROCKETClassifier(random_state=resample_id)
    if name == "arsenal":
        return Arsenal(random_state=resample_id)
    # Shapelet based
    if name in ("stc", "shapelettransformclassifier"):
        return ShapeletTransformClassifier(random_state=resample_id,
                                           n_estimators=500)
    if name in ("mrseql", "mrseqlclassifier"):
        return MrSEQLClassifier(seql_mode="fs", symrep=["sax", "sfa"])
    raise Exception("UNKNOWN CLASSIFIER")