コード例 #1
ファイル: RawAndBeta.py プロジェクト: NickZhangChen/EEGBaseDL
def evaluate(num):
    """Restore the trained CNN and run one subject's test data through it.

    Builds the inference graph, loads the moving-average weights from the
    latest checkpoint in ``Model_Save_Path``, feeds the data returned by
    ``test_batch(num)``, prints the test accuracy and returns the
    intermediate activations of the network.

    Args:
        num: Index of the subject whose test data is evaluated. It is also
            used to look up the placeholder batch size ``bat[num]``.

    Returns:
        A tuple ``(conv1, pool1, conv2, pool2, conv3, pool3, conv4, pool4)``
        with the values fetched by ``sess.run`` for those layers of
        ``BaseCNN``. If no checkpoint is found, a message is printed and a
        tuple of eight ``None`` values is returned instead.
    """
    layer_names = ['conv1', 'pool1', 'conv2', 'pool2',
                   'conv3', 'pool3', 'conv4', 'pool4']

    # Build into a fresh graph so evaluate() can be called repeatedly (e.g.
    # once per subject) without duplicating variables in the default graph;
    # duplicated variables get new names that the EMA restore cannot find in
    # the checkpoint.
    with tf.Graph().as_default():
        with tf.name_scope("input"):
            # First dimension: number of samples in the batch for this subject.
            input_x = tf.placeholder(tf.float32, [bat[num], 500, 9, 1], name='EEG-input')
            # Labels for the batch.
            input_y = tf.placeholder(tf.float32, [None, 7], name='EEG-lable')
        # Not really needed for testing, but BaseCNN takes a regularizer.
        regularlizer = tf.contrib.layers.l2_regularizer(Regularazition_Rate)
        is_training = tf.cast(False, tf.bool)
        out = BaseCNN(input_x, is_training, regularlizer)
        logits = out['logit']
        with tf.name_scope("test_acc"):
            correct_predection = tf.equal(tf.argmax(logits, 1), tf.argmax(input_y, 1))
            accuracy = tf.reduce_mean(tf.cast(correct_predection, tf.float32))
            tf.summary.scalar('test_acc', accuracy)
        # Map the checkpoint's moving-average values onto the model variables.
        variable_averages = tf.train.ExponentialMovingAverage(Moving_Average_Decay)
        saver = tf.train.Saver(variable_averages.variables_to_restore())

        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(Model_Save_Path)
            if not (ckpt and ckpt.model_checkpoint_path):
                print("No checkpoint file found")
                # Keep the 8-tuple shape so callers that unpack still work.
                return (None,) * len(layer_names)
            # Load the trained weights; without this the variables are
            # uninitialized and sess.run fails.
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Training step parsed from the checkpoint file name
            # (the text after the last '-').
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
            samples, labels = test_batch(num)  # test data of subject `num`
            reshape_xs = np.reshape(samples, (-1, 500, 9, 1))
            ys = one_hot(labels)
            fetched = sess.run([out[k] for k in layer_names] + [accuracy],
                               feed_dict={input_x: reshape_xs, input_y: ys})
            acc_score = fetched[-1]
            print("Afer %s training step, test accuracy = %g" % (global_step, acc_score))
    return tuple(fetched[:-1])
コード例 #2
def evaluate(num, channel, name, fliter):
    """Restore the trained CNN and extract one channel of the ``pool4`` output per filter.

    Builds the inference graph, loads the moving-average weights from the
    latest checkpoint in ``Model_Save_Path``, feeds the data returned by
    ``test_batch(num)`` and prints the test accuracy.

    Args:
        num: Index of the subject whose test data is evaluated.
        channel: Index along axis 2 of the layer output to extract.
        name: Meant to select which layer to inspect.
            NOTE(review): currently unused -- the ``'pool4'`` output is
            always taken; confirm with callers before wiring it up.
        fliter: Number of filters (indices along axis 3) to extract.

    Returns:
        A list with one flattened 1-D array per filter index
        ``0 .. fliter-1``, each being ``data[:, :, channel, i]`` of the
        ``pool4`` output. An empty list is returned if no checkpoint is found.
    """
    # Build into a fresh graph so repeated calls (e.g. one per subject or
    # channel) don't duplicate variables in the default graph, which would
    # break the EMA restore below.
    with tf.Graph().as_default():
        with tf.name_scope("input"):
            # First dimension: number of samples in the batch.
            input_x = tf.placeholder(tf.float32, [5424, 500, 9, 1],
                                     name='EEG-input')
            # Labels for the batch.
            input_y = tf.placeholder(tf.float32, [None, 7],
                                     name='EEG-lable')
        # Not really needed for testing, but BaseCNN takes a regularizer.
        regularlizer = tf.contrib.layers.l2_regularizer(Regularazition_Rate)
        out = BaseCNN(input_x, False, regularlizer)
        logits = out['logit']
        with tf.name_scope("test_acc"):
            correct_predection = tf.equal(tf.argmax(logits, 1), tf.argmax(input_y, 1))
            accuracy = tf.reduce_mean(tf.cast(correct_predection, tf.float32))
            tf.summary.scalar('test_acc', accuracy)
        # Map the checkpoint's moving-average values onto the model variables.
        variable_averages = tf.train.ExponentialMovingAverage(Moving_Average_Decay)
        saver = tf.train.Saver(variable_averages.variables_to_restore())
        raw_channel = []

        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(Model_Save_Path)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                # Training step parsed from the checkpoint file name
                # (the text after the last '-').
                global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                samples, labels = test_batch(num)  # test data of subject `num`
                reshape_xs = np.reshape(samples, (-1, 500, 9, 1))
                ys = one_hot(labels)
                data, acc_score = sess.run([out['pool4'], accuracy],
                                           feed_dict={
                                               input_x: reshape_xs,
                                               input_y: ys
                                           })
                # Collect one flattened slice per filter. Flattening with -1
                # replaces the previously hard-coded total length (173568).
                for i in range(fliter):
                    raw_channel.append(np.reshape(data[:, :, channel, i], [-1]))
                print("Afer %s training step, test accuracy = %g" %
                      (global_step, acc_score))
            else:
                print("No checkpoint file found")
    return raw_channel