# Example No. 1
0
from sklearn.metrics import confusion_matrix #performance diagnostic tool
import pdb #debugging package - use by including 'pdb.set_trace()' in code

#### LOAD DATA ####

setnum = 1 #1 is opp, 2 is pam, 3 is skoda

# Map each data-set selector to the DB identifier passed to loadingDB().
# A lookup table keeps the three options in one place and lets an
# unsupported selector fail immediately instead of leaving DB undefined.
_DB_BY_SETNUM = {1: 79, 2: 52, 3: 60}
if setnum not in _DB_BY_SETNUM:
    raise ValueError(
        "setnum must be one of {}, got {!r}".format(sorted(_DB_BY_SETNUM), setnum))
DB = _DB_BY_SETNUM[setnum] #reference to the data set to load below

#load the database through the function 'loadingDB()' from dataset.py
train_x, valid_x, test_x, train_y, valid_y, test_y = loadingDB('../', DB)

#csv = open(str(DB)+'.csv','a') #to store performance results later on


#### FUNCTION THAT DETERMINES THE DATA SET USED IN EACH ITERATION ####

def create_batches(data_x,data_y,range_split, random_start = False):
    """Pick a random number of batches and the per-batch length for a data set.

    NOTE(review): this definition appears truncated in this file. The visible
    body computes values but never returns anything, and ``data_y`` and
    ``random_start`` are not used in it. Confirm against the full source.

    Args:
        data_x: array-like with a ``.shape`` attribute. ``data_x.shape[0]`` is
            divided by the chosen batch count to get the batch length.
        data_y: not used in the visible body.
        range_split: indexable pair ``(low, high)``. The batch count is drawn
            with ``np.random.randint(low, high, 1)`` (numpy treats ``high`` as
            exclusive).
        random_start: not used in the visible body.

    Raises:
        ZeroDivisionError: if the drawn batch count is 0 (possible when
            ``range_split[0] <= 0``).
    """
    # NOTE(review): these two read the module-level ``train_x`` / ``train_y``,
    # not the ``data_x`` / ``data_y`` arguments. That is presumably
    # unintended; confirm, and consider using the parameters instead.
    dim_data = train_x.shape[1]
    n_classes = train_y.shape[1] 
    
    #determine number of batches in range (min_n_batches:max_n_batches)
    n_batches = np.random.randint(range_split[0],range_split[1],1)[0] #use [0] because the function returns an array and we want a number only
    #the length of each batch (integer division, so up to n_batches-1 trailing rows are not covered)
    l_batches = data_x.shape[0]//n_batches
    
# Example No. 2
0
import tensorflow as tf
from tensorflow.contrib import rnn
from dataset import loadingDB
import numpy as np
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix
import pickle
import pandas as pd
import pdb #debugging package - use by including 'pdb.set_trace()' in code

## Test script: select the data set to use and load its data.
setnum = 1 #1 is opp79, 2 is pamap2, 3 is skoda

# setnum -> (DB, n_classes). DB is passed to loadingDB() and is also the
# number of input features.
_DATASETS = {
    1: (79, 18),  # opp79
    2: (52, 12),  # pamap2
    3: (60, 11),  # skoda
}
if setnum not in _DATASETS:
    # Fail loudly here instead of leaving DB / n_classes / train_x undefined.
    raise ValueError(
        "setnum must be one of {}, got {!r}".format(sorted(_DATASETS), setnum))
DB, n_classes = _DATASETS[setnum]
train_x, valid_x, test_x, train_y, valid_y, test_y = loadingDB('../', DB)

# Hyperparameters of the network structure (the input dimension and DB are the same).
nm_epochs = 2
rnn_size = 256 #number of nodes in the hidden layer
number_of_layers = 2
keep_rate = 0.5 #presumably the dropout keep probability, i.e. 1 - dropout rate (not 1/dropout rate); confirm at its use site