Пример #1
0
def test_openDataset():
    """Load the UCR/UEA "GunPoint" dataset, report on it, and plot two samples.

    Prints the train and test array shapes, the set of distinct training
    labels, and the length of the ``dataset`` attribute of the loader returned
    by ``getDataLoader``. It then plots training series 0 and 25, each divided
    by the maximum value of the whole training array, and calls ``plt.show()``.

    NOTE(review): relies on module-level ``getDataLoader`` and ``plt``
    (presumably ``matplotlib.pyplot``) defined elsewhere in the file. The
    function makes no assertions, and ``plt.show()`` may block when run under
    pytest with an interactive backend.
    """
    from tslearn.datasets import UCR_UEA_datasets

    loader = UCR_UEA_datasets()
    train_x, train_y, test_x, test_y = loader.load_dataset("GunPoint")

    print("Train shape", train_x.shape)
    print("Test shape", test_x.shape)
    print(set(train_y))

    data_loader = getDataLoader(data=train_x, label=train_y)
    print("nb data", len(data_loader.dataset))

    # Both curves share one scale: the global maximum of the training array.
    scale = train_x.max()
    time_axis = range(train_x.shape[1])
    for sample_idx in (0, 25):
        plt.plot(time_axis, train_x[sample_idx, :] / scale)

    plt.show()
Пример #2
0
            best_scores_train = {k : 0. for k in _kernels}

            # try lead-lag only for short, low-dimensional series: at most 200 time steps (shape[1]) and at most 8 channels (shape[2])
            x_train, _, _, _ = UCR_UEA_datasets(use_cache=True).load_dataset(name)    
            if x_train.shape[1] <= 200 and x_train.shape[2] <= 8: 
                transforms = tqdm([(True,True), (False,True), (True,False), (False,False)], position=1, leave=False)
            else: # do not try lead-lag as dimension is already high
                transforms = tqdm([(True,False), (False,False)], position=1, leave=False)
                
            # grid-search for path-transforms (add-time, lead-lag)
            for (at,ll) in transforms:
                transforms.set_description(f"add-time: {at}, lead-lag: {ll}")

                # load train data
                x_train, y_train, _, _ = UCR_UEA_datasets(use_cache=True).load_dataset(name)
                x_train /= x_train.max()

                # encode outputs as labels
                y_train = LabelEncoder().fit_transform(y_train)

                # path-transform
                x_train = sigkernel.transform(x_train, at=at, ll=ll, scale=.1)

                # subsample long series: keep every `subsample`-th time step, where subsample = max(floor(length / 149), 1), i.e. no subsampling for fewer than 298 steps
                subsample = max(int(np.floor(x_train.shape[1]/149)),1)
                x_train = x_train[:,::subsample,:]
                datasets.set_description(f"dataset: {name} --- shape: {x_train.shape}")

                #==================================================================================
                # Linear, RBF and GAK kernels
                #==================================================================================