Ejemplo n.º 1
0
 def test_1d_shift(self):
     """Utils.shift_1d must agree with np.roll for shifts 0, +/-1 and +/-5."""
     x = [1, 2, 3, 4, 5, 6]
     # (shift, expected) pairs; numpy supplies the reference rotations and
     # a zero shift must return the input unchanged.
     cases = [(0, x)] + [(k, np.roll(x, k)) for k in (1, -1, 5, -5)]
     with self.test_session():
         x_tf = tf.constant(x)
         for shift, expected in cases:
             self.assertAllEqual(Utils.shift_1d(x_tf, shift).eval(), expected)
Ejemplo n.º 2
0
    def map_reward(self, state, electrode_weights=None, Fs=250, L=1000):
        """Reduce ``state`` to a scalar reward by summing its frequency-bin power.

        The band limits are ``self.mLOWALPHA`` .. ``self.mHIGHALPHA`` (presumably
        the alpha band, judging by the names), extracted as a single bin by
        ``Utils.extract_frequency_bins``.

        Args:
            state: Tensor handed to ``Utils.extract_frequency_bins``.
            electrode_weights: Currently unused. Defaults to ``None`` instead of
                a shared mutable ``[]`` default.
            Fs: Currently unused (presumably a sampling rate).
            L: Currently unused.

        Returns:
            A scalar tensor: the sum of every element returned by
            ``Utils.extract_frequency_bins``.
        """
        with tf.name_scope("Generate_Reward"):

            reward = Utils.extract_frequency_bins(state, self.mLOWALPHA,
                                                  self.mHIGHALPHA, 1)
            # Flatten so reduce_sum yields a single scalar across all channels.
            reward_flat = tf.reshape(reward, [-1], name="flatten_fft")
            reward_summed = tf.reduce_sum(reward_flat, name="alphapow_summing")
        return reward_summed
Ejemplo n.º 3
0
    def bod(l_idx, specgram):
        # tf.while_loop body: window the signal at offset w_shift * l_idx,
        # take its FFT and accumulate it into row l_idx of specgram.
        offset = tf.cast(w_shift * l_idx, tf.int32)
        windowed = tf.multiply(s_var, Utils.shift_1d(w_var, offset))
        spectrum = tf.fft(tf.cast(windowed, tf.complex64))

        # Zero-pad the single-row spectrum so it lines up with row l_idx of a
        # tensor that has as many rows as specgram.
        row = tf.expand_dims(spectrum, 0)
        n_rows = tf.shape(specgram)[0]
        padded = tf.pad(row, [[l_idx, n_rows - l_idx - 1], [0, 0]])

        # Next loop variables: advance the counter, add the new row.
        return [l_idx + 1, tf.add(specgram, padded)]
Ejemplo n.º 4
0
        def bod(lcl_idx, lcl_r_var):
            """While-loop body adding the FFT of frame ``lcl_idx`` to ``lcl_r_var``.

            The window ``w_var`` is shifted by ``w_overlap * lcl_idx``, applied
            to ``s_var``, transformed, and zero-padded so it occupies row
            ``lcl_idx`` of a tensor with as many rows as ``lcl_r_var``.

            Returns:
                ``[lcl_idx + 1, lcl_r_var + padded_fft]``.
            """
            shift_amt = tf.cast(w_overlap * lcl_idx, tf.int32)
            window = Utils.shift_1d(w_var, shift_amt)
            frame = tf.multiply(s_var, window)
            frame_fft = tf.fft(tf.cast(frame, tf.complex64))

            frame_row = tf.expand_dims(frame_fft, 0)
            rows_after = tf.shape(lcl_r_var)[0] - lcl_idx - 1
            frame_padded = tf.pad(frame_row, [[lcl_idx, rows_after], [0, 0]])

            next_idx = lcl_idx + 1
            next_acc = tf.add(lcl_r_var, frame_padded)
            return [next_idx, next_acc]
Ejemplo n.º 5
0
    def test_2d_array(self):
        """Check Utils.extract_frequency_bins on two pure sines.

        Two 1000-sample channels map sample index x to time x / 250 (250
        samples per second): a 1 Hz sine and a 20 Hz sine. Assuming the
        arguments are (start Hz, end Hz, number of bins), binning 0-25 Hz into
        5 bins puts the first channel's energy in bin 0 and the second's in
        bin 4. The value 500 is presumably the FFT magnitude of a
        unit-amplitude sine over 1000 samples (N / 2).
        """
        with self.test_session() as sess:
            expected_val = np.int64([[500, 0, 0, 0, 0], [0, 0, 0, 0, 500]])
            # 250.0 keeps the division floating-point even under Python 2,
            # where x / 250 with int x would truncate every phase to an int.
            Sin1 = [np.sin(2 * np.pi * (x / 250.0)) for x in range(1000)]
            Sin20 = [np.sin(2 * np.pi * (x / 250.0) * 20) for x in range(1000)]

            a = tf.constant([Sin1, Sin20], dtype=tf.complex64)
            b = Utils.extract_frequency_bins(a, 0, 25, 5)
            c = tf.cast(b, dtype=tf.int64)

            # Write graph to file; close() flushes the event file and releases
            # its handle rather than leaking it for the rest of the test run.
            writer = tf.summary.FileWriter(".\\Logs\\", sess.graph)
            writer.close()

            # Actual check
            self.assertAllEqual(c.eval(), expected_val)
Ejemplo n.º 6
0
    def bod(l_idx, specgram):
        """tf.while_loop body: add the spectrum of frame ``l_idx`` to ``specgram``.

        ``w_var`` (a free variable from the enclosing scope) is shifted by
        ``w_shift * l_idx`` via ``Utils.shift_2d`` (the trailing ``1`` is
        presumably the shift axis -- confirm in Utils), multiplied into
        ``s_var`` and FFT'd. The rank-2 FFT result is given a leading axis and
        zero-padded along it so it occupies index ``l_idx`` of a tensor with
        ``specgram_len`` entries on axis 0.

        Returns:
            ``[l_idx + 1, specgram + fft_padded]`` -- the next loop variables.
        """
        w_tmp = Utils.shift_2d(w_var, tf.cast(w_shift * l_idx, tf.int32), 1)
        s_tmp = tf.multiply(s_var, w_tmp)
        fft_tmp = tf.fft(tf.cast(s_tmp, tf.complex64))

        # Subtract one from the trailing padding because the fft itself takes
        # up one row: l_idx + 1 + (specgram_len - 1 - l_idx) == specgram_len.
        fft_padded = tf.pad(tf.expand_dims(fft_tmp, 0),
                            [[l_idx,
                              (specgram_len - 1) - l_idx], [0, 0], [0, 0]])

        # Update loop variables

        specgram = tf.add(specgram, fft_padded)
        # The counter increment is created under a control dependency on the
        # accumulation, so the new index is only produced after the add runs.
        with tf.control_dependencies([specgram]):
            l_idx = l_idx + 1
        return [l_idx, specgram]
Ejemplo n.º 7
0
    def run(self, _buf):
        """Apply the FIR coefficients ``self.mCOEFFS`` to every channel.

        Args:
            _buf: dict with ``'data'`` (must be a rank-2 tensor whose dim 0
                equals ``self.mNCHAN``), ``'summaries'`` and ``'updates'``;
                the latter two are passed through unchanged.

        Returns:
            dict with ``'data'`` replaced by
            ``Utils.multi_ch_conv(_buf['data'], self.mCOEFFS)`` and the other
            two entries unchanged.
        """
        with tf.name_scope("RUN_FIRFilt"):
            # tf.assert_rank returns an op that only executes if something
            # depends on it, so it joins the channel-count check in the
            # control-dependency list instead of being discarded.
            asserts = [
                tf.assert_rank(_buf['data'],
                               2,
                               message="JCR: Input must be rank 2 tensor"),
                tf.assert_equal(
                    tf.shape(_buf['data'])[0],
                    self.mNCHAN,
                    message="JCR: Input Dim-0 must equal number of channels"),
            ]

            with tf.control_dependencies(asserts):
                dout = Utils.multi_ch_conv(_buf['data'], self.mCOEFFS)
                return {
                    'data': dout,
                    'summaries': _buf['summaries'],
                    'updates': _buf['updates']
                }
Ejemplo n.º 8
0
 def run(self, _buf):
     """Estimate each channel's peak (IAF) frequency in Hz.

     Each row of the rank-2 input is FFT'd, its magnitude is convolved with a
     box window ``self.mPEAK_WINDOW / 2`` Hz wide, and the arg-max over the
     first half of the result is converted from bin index to Hz.

     Args:
         _buf: dict with ``'data'`` (rank-2 tensor; dim 0 must equal
             ``self.mNCHAN``), ``'summaries'`` (passed through) and
             ``'updates'`` (a list; one tensor is appended).

     Returns:
         dict with ``'data'`` set to a float32 tensor whose length is checked
         against ``self.mNCHAN``, ``'summaries'`` unchanged, and ``'updates'``
         extended with ``tf.shape(dout)[0]`` so the final assert gets run.
     """
     with tf.name_scope("RUN_IAFPower"):
         data = _buf['data']
         # tf.assert_rank returns an op that only executes if something
         # depends on it, so it joins the channel check as a control dependency.
         asserts = [
             tf.assert_rank(data, 2, message="JCR: Input must be rank 2 tensor"),
             tf.assert_equal(tf.shape(data)[0], self.mNCHAN,
                             message="JCR: Input Dim-0 must equal number of channels"),
         ]

         with tf.control_dependencies(asserts):
             s_len = tf.shape(data)[1]
             fs = tf.cast(self.mFS, tf.float32)

             # FFT bins per Hz: N samples at fs Hz give a bin spacing of fs/N.
             pphz = tf.realdiv(tf.cast(s_len, tf.float32), fs)

             # Width (Hz) of the peak-smoothing window; a larger window uses more
             # surrounding values to calculate IAF. Cast like mFS so int or
             # float64 settings combine with the float32 tensors below.
             IAF_Peak_Window_Size = tf.realdiv(tf.cast(self.mPEAK_WINDOW, tf.float32), 2.0)

             # The window must cover at least one FFT bin.
             asserts = [
                 tf.assert_greater_equal(IAF_Peak_Window_Size, tf.realdiv(1.0, pphz),
                                         message="JCR: Invalid number of Hz/Window")
             ]
             with tf.control_dependencies(asserts):
                 IAF_Window = tf.ones(
                     [tf.cast(tf.multiply(IAF_Peak_Window_Size, pphz), tf.int32)])

                 data_fft = tf.fft(tf.cast(data, tf.complex64))

                 # Convolve the window over the FFT magnitude to strengthen the peak.
                 data_fft_neighbor_avg = Utils.multi_ch_conv(
                     tf.cast(tf.abs(data_fft), tf.float32), IAF_Window)
                 # Integer halving: only the first half of the spectrum is searched
                 # (the non-negative frequencies, assuming multi_ch_conv keeps length).
                 half_size = tf.shape(data_fft_neighbor_avg)[1] // 2
                 dout1 = tf.argmax(data_fft_neighbor_avg[:, 0:half_size], axis=1)
                 # half_size bins span 0 .. mFS/2 Hz; divide index by bins-per-Hz.
                 bins_per_hz = tf.realdiv(tf.cast(half_size, tf.float32),
                                          tf.realdiv(fs, 2.0))
                 dout = tf.realdiv(tf.cast(dout1, tf.float32), bins_per_hz)

                 asserts = [
                     tf.assert_equal(tf.shape(dout)[0], self.mNCHAN,
                                     message="JCR: Input/output shape mismatch",
                                     name='FinalCheck')
                 ]
                 with tf.control_dependencies(asserts):
                     return {
                         'data': dout,
                         'summaries': _buf['summaries'],
                         # pass dout shape to updates so that the assert gets evaluated
                         'updates': _buf['updates'] + [tf.shape(dout)[0]],
                     }
Ejemplo n.º 9
0
    def run(self, _buf):
        """Split each channel's spectrum into equal-width frequency bins.

        The bin width in FFT points is ``self.mBIN_WIDTH // (mFS / mSIGLEN)``,
        clamped to at least 1. The covered range is 0 Hz up to the frequency
        of the last whole bin, and the work is delegated to
        ``Utils.extract_frequency_bins``.

        Args:
            _buf: dict with ``'data'`` (must be a rank-2 tensor whose dim 0
                equals ``self.mNCHAN``), ``'summaries'`` (passed through) and
                ``'updates'`` (a list; one tensor is appended).

        Returns:
            dict with ``'data'`` set to the binned output,
            ``'summaries'`` unchanged, and ``'updates'`` extended with
            ``tf.shape(dout)[0]`` so the final shape assert gets run.
        """
        with tf.name_scope("RUN_BandPower"):
            # tf.assert_rank returns an op that only executes if something
            # depends on it, so it joins the channel check as a control dependency.
            asserts = [
                tf.assert_rank(_buf['data'],
                               2,
                               message="JCR: Input must be rank 2 tensor"),
                tf.assert_equal(
                    tf.shape(_buf['data'])[0],
                    self.mNCHAN,
                    message="JCR: Input Dim-0 must equal number of channels"),
            ]

            with tf.control_dependencies(asserts):
                # Python-side bin geometry (no graph ops).
                hz_per_point = self.mFS / self.mSIGLEN
                points_per_bin = np.maximum(self.mBIN_WIDTH // hz_per_point,
                                            1.0)

                total_number_bins = int(self.mSIGLEN // points_per_bin)
                total_number_points = int(total_number_bins * points_per_bin)
                f_start = 0
                f_end = total_number_points * self.mFS / self.mSIGLEN

                dout = Utils.extract_frequency_bins(_buf['data'], f_start,
                                                    f_end, total_number_bins,
                                                    self.mSIGLEN, self.mFS)
                asserts = [
                    tf.assert_equal(tf.shape(dout)[0],
                                    self.mNCHAN,
                                    message="JCR: Input/output shape mismatch",
                                    name='FinalCheck')
                ]
                with tf.control_dependencies(asserts):
                    return {
                        'data': dout,
                        'summaries': _buf['summaries'],
                        #pass dout shape to updates so that the assert gets evaluated
                        'updates': _buf['updates'] + [tf.shape(dout)[0]]
                    }
Ejemplo n.º 10
0
        return tf.greater(specgram_len, tf.cast(l_idx, tf.int32))

    return tf.while_loop(cond, bod, loopvars, parallel_iterations=1000)


tf.reset_default_graph()

##SPECGRAM TST
# Filters the recorded data `rd` with coefficients `b`, trims the edges,
# builds the spectrogram graph and runs it once with software tracing; the
# result is stored in the module-level name `spectro`.
if True:
    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        raw_data_tensor = tf.constant(np.asarray(np.transpose(rd)),
                                      dtype=tf.float32)

        b_coeffs = tf.constant(b, dtype=tf.float32)

        # Filter every row (the `bp` naming suggests band-pass), then drop 300
        # entries from each end of axis 1, presumably filter edge transients.
        data_bp_filt_tmp = Utils.multi_ch_conv(raw_data_tensor, b_coeffs)
        data_bp_filt = tf.slice(data_bp_filt_tmp, [0, 300],
                                [-1, tf.shape(data_bp_filt_tmp)[1] - 600])

        raw_data_tensor = data_bp_filt

        z = specgram_tf_2d(raw_data_tensor, 500, 480)
        init = tf.global_variables_initializer()
        print("Global variables init, ret: ", sess.run(init))
        writer = tf.summary.FileWriter(".\\Logs\\", sess.graph)
        try:
            run_options = tf.RunOptions(trace_level=tf.RunOptions.SOFTWARE_TRACE)
            run_metadata = tf.RunMetadata()
            spectro = sess.run([z], options=run_options, run_metadata=run_metadata)
            writer.add_run_metadata(run_metadata, 'S: %d' % 1)
        finally:
            # FileWriter buffers events; close() flushes the graph and run
            # metadata to disk even if the traced run raises.
            writer.close()