コード例 #1
0
    def map_reward(self, state, electrode_weights=None, Fs=250, L=1000):
        """Map ``state`` to a scalar reward: the summed alpha-band output.

        Calls ``Utils.extract_frequency_bins`` on ``state`` for the band
        ``self.mLOWALPHA`` .. ``self.mHIGHALPHA`` with a bin count of 1,
        flattens the result and sums every element into one scalar tensor.

        Args:
            state: Tensor passed unchanged to ``Utils.extract_frequency_bins``.
                NOTE(review): presumably (channels, samples) signal data —
                confirm against callers.
            electrode_weights: Currently unused. Defaults to ``None`` instead
                of a shared mutable ``[]`` default.
            Fs: Currently unused; not forwarded to ``extract_frequency_bins``.
            L: Currently unused; not forwarded to ``extract_frequency_bins``.

        Returns:
            A scalar tensor equal to the sum of all values returned by
            ``Utils.extract_frequency_bins``.
        """
        # NOTE(review): Fs and L look like a sampling rate and signal length
        # that were meant to be forwarded to extract_frequency_bins; they are
        # left unforwarded here to preserve current behavior — confirm intent.
        with tf.name_scope("Generate_Reward"):
            # A single bin spanning the whole alpha band.
            reward = Utils.extract_frequency_bins(state, self.mLOWALPHA,
                                                  self.mHIGHALPHA, 1)
            # Flatten so the sum covers every element regardless of rank.
            reward_flat = tf.reshape(reward, [-1], name="flatten_fft")
            reward_summed = tf.reduce_sum(reward_flat, name="alphapow_summing")
        return reward_summed
コード例 #2
0
    def test_2d_array(self):
        """Check per-row binning of a two-row complex tensor into five bins.

        Row 0 is a 1-cycle-per-250-samples sine and row 1 a 20-cycle one, each
        1000 samples long. After ``Utils.extract_frequency_bins(a, 0, 25, 5)``
        and a cast to int64, row 0 must hold 500 in the first bin only and
        row 1 must hold 500 in the last bin only.
        """
        import os  # local import: only needed for the portable log directory

        # Float so ``x / sample_rate`` is true division on any interpreter.
        sample_rate = 250.0
        num_samples = 1000
        with self.test_session() as sess:
            expected_val = np.int64([[500, 0, 0, 0, 0], [0, 0, 0, 0, 500]])
            Sin1 = [np.sin(2 * np.pi * (x / sample_rate))
                    for x in range(num_samples)]
            Sin20 = [np.sin(2 * np.pi * (x / sample_rate) * 20)
                     for x in range(num_samples)]

            a = tf.constant([Sin1, Sin20], dtype=tf.complex64)
            b = Utils.extract_frequency_bins(a, 0, 25, 5)
            c = tf.cast(b, dtype=tf.int64)

            # Write the graph for inspection. Use a portable path (the old
            # ".\\Logs\\" literal is a single odd directory name on POSIX) and
            # close the writer so the event file is flushed and released.
            writer = tf.summary.FileWriter(os.path.join(".", "Logs"),
                                           sess.graph)
            writer.close()

            # Actual check
            self.assertAllEqual(c.eval(), expected_val)
コード例 #3
0
    def run(self, _buf):
        """Replace ``_buf['data']`` with its frequency-binned representation.

        Validates that ``_buf['data']`` is rank 2 with ``self.mNCHAN`` rows,
        then calls ``Utils.extract_frequency_bins`` over as many whole bins of
        roughly ``self.mBIN_WIDTH`` as fit in the spectrum.

        Args:
            _buf: Dict with keys ``'data'`` (rank-2 tensor, dim 0 must equal
                ``self.mNCHAN``), ``'summaries'`` (passed through unchanged)
                and ``'updates'`` (list; a new, extended list is returned).

        Returns:
            A new dict with ``'data'`` set to the binned output, the input
            ``'summaries'``, and ``'updates'`` extended with an op that forces
            the output-shape check to run.

        Raises:
            ValueError: At graph-construction time if TensorFlow can prove
                statically that the rank or channel-count check fails.
        """
        with tf.name_scope("RUN_BandPower"):
            # Every input check goes into the control-dependency list. When
            # the rank is only known at run time, tf.assert_rank returns an
            # op that executes only if something depends on it.
            asserts = [
                tf.assert_rank(_buf['data'],
                               2,
                               message="JCR: Input must be rank 2 tensor"),
                tf.assert_equal(
                    tf.shape(_buf['data'])[0],
                    self.mNCHAN,
                    message="JCR: Input Dim-0 must equal number of channels"),
            ]

            with tf.control_dependencies(asserts):
                # Hz covered by one spectrum point (mFS / mSIGLEN).
                hz_per_point = self.mFS / self.mSIGLEN
                # Spectrum points per output bin, never fewer than one.
                points_per_bin = np.maximum(self.mBIN_WIDTH // hz_per_point,
                                            1.0)

                # Whole bins only: trailing points that cannot fill a bin are
                # excluded by lowering f_end accordingly.
                total_number_bins = int(self.mSIGLEN // points_per_bin)
                total_number_points = int(total_number_bins * points_per_bin)
                f_start = 0
                f_end = total_number_points * self.mFS / self.mSIGLEN

                dout = Utils.extract_frequency_bins(_buf['data'], f_start,
                                                    f_end, total_number_bins,
                                                    self.mSIGLEN, self.mFS)
                asserts = [
                    tf.assert_equal(tf.shape(dout)[0],
                                    self.mNCHAN,
                                    message="JCR: Input/output shape mismatch",
                                    name='FinalCheck')
                ]
                with tf.control_dependencies(asserts):
                    return {
                        'data': dout,
                        'summaries': _buf['summaries'],
                        # dout was created before this context, so it does not
                        # depend on FinalCheck; the shape op created here does.
                        # Passing it in 'updates' makes the assert get evaluated.
                        'updates': _buf['updates'] + [tf.shape(dout)[0]]
                    }