# Training script: configuration constants and dataset preparation.
# The "model design" section starts at the marker at the end of this block.
import gc
import time

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (
    Activation,
    BatchNormalization,
    Conv2D,
    Dense,
    Dropout,
    Flatten,
    MaxPooling2D,
)

import dataLoading
from configReader import ConfigReader

CONF = ConfigReader("hyperParameters.ini")

# Class labels come from dataLoading; input geometry comes from the config file.
ACTIONS = dataLoading.ACTIONS
TIME_SLOT = CONF.time_slot
FREQUENCY_SLOT = CONF.frequency_slot
CHANNELS_NUM = len(CONF.selected_channels)

# Target shape for reshaping the input later, laid out as (N, T, F, C).
# Keras convolution layers default to channels_last (the channel count sits on
# the last axis), so the channel dimension is placed last here.
# The CPU build of TensorFlow does not support channels_first; it only supports
# NHWC, i.e. channels_last.
RESHAPE = (-1, TIME_SLOT, FREQUENCY_SLOT, CHANNELS_NUM)
HIDDEN_LAYERS = int(CONF.getAttr("default", "hidden_layers"))

# Output size: equal to the number of actions (classes).
OUT_SIZE = len(ACTIONS)

# ========================== data create =======================
print("Loading data...")
train_data, test_data, validate_data = dataLoading.load("new_data")
print("Done.")

# Separate each partition into (X, y) pairs via dataLoading.tag_divide
# (presumably features and tags/labels).
(train_X, train_y), (test_X, test_y), (val_X, val_y) = (
    dataLoading.tag_divide(partition)
    for partition in (train_data, test_data, validate_data)
)

# ========================== model design ==========================
# Real-time inference script: reads configuration and loads the horizontal and
# vertical Keras classifiers named in hyperParameters.ini.
import os
import time
from collections import deque

from pylsl import StreamInlet, resolve_stream
import tensorflow as tf
import numpy as np

from boxGraphicView import BoxGraphicView
from configReader import ConfigReader
import dataLoading

CONF = ConfigReader("hyperParameters.ini")

# NOTE(review): hard-coded, unlike CHANNELS_NUM / FFT_MAX_HZ below which come
# from the config. Presumably (batch, channels, bins) = (-1, 8, 60); confirm it
# matches CONF before changing the channel count or frequency slot.
RESHAPE = (-1, 8, 60)
FFT_MAX_HZ = CONF.frequency_slot
HM_SECONDS = 10
# NOTE(review): 25 is presumably the expected iterations per second; TODO confirm.
TOTAL_ITERS = HM_SECONDS * 25


def _read_actions(section):
    """Return the comma-separated ``actions`` entry of *section* as a list.

    Surrounding whitespace is stripped from each entry so that a config value
    such as ``left, none, right`` yields ``['left', 'none', 'right']`` rather
    than entries with leading spaces.
    """
    return [action.strip() for action in CONF.getAttr(section, "actions").split(',')]


ACTIONS_H = _read_actions("horizontal")
ACTIONS_V = _read_actions("vertical")
CHANNELS_NUM = CONF.channels_num
TIME_SLOT = CONF.time_slot


def _load_model(section):
    """Load the Keras model ``<models_dir>/<test_model>`` configured in *section*."""
    model_path = os.path.join(CONF.getAttr(section, "models_dir"),
                              CONF.getAttr(section, "test_model"))
    return tf.keras.models.load_model(model_path)


model_h = _load_model("horizontal")
model_v = _load_model("vertical")

# Disabled call kept for reference; its (32, 60, 60, 8) shape suggests the
# models take (batch, time, freq, channels) input — confirm before re-enabling.
# model_h.predict(np.zeros((32,60,60,8)))

last_print = time.time()
fps_counter = deque(maxlen=150)