def __call__(self, inputs, state, scope=None, dtype=tf.float32):
    """Advance the ALIF (adaptive-threshold LIF) cell by one time step.

    Args:
        inputs: Input for this step; projected through ``self.W_in``.
        state: Previous ``ALIFStateTuple`` with fields ``v``, ``z``, ``b``,
            ``i_future_buffer`` and ``z_buffer``.
        scope: Not referenced by this body (presumably kept for RNN-cell
            call-signature compatibility).
        dtype: Not referenced by this body.

    Returns:
        ``([z, v, thr], next_state)``: the new spikes, the new membrane
        potentials and the threshold passed to ``LIF_dynamic`` this step,
        followed by the updated ``ALIFStateTuple``.
    """
    with tf.name_scope('ALIFcall'):
        # Judging by the helper's name, einsum_bi_ijk_to_bjk presumably maps
        # (batch, i) x (i, j, k) -> (batch, j, k) -- TODO confirm.
        feedforward = einsum_bi_ijk_to_bjk(inputs, self.W_in)
        recurrent = einsum_bi_ijk_to_bjk(state.z, self.W_rec)
        # Accumulate this step's feed-forward and recurrent currents into the
        # buffered future currents (summed left to right as before).
        current_buffer = state.i_future_buffer + feedforward + recurrent

        # Adaptation trace: weighted blend of the previous trace and the
        # previous spikes, controlled by decay_b.
        adaptation = self.decay_b * state.b + (1. - self.decay_b) * state.z

        # The threshold rises above its baseline in proportion to the trace.
        threshold = self.thr + adaptation * self.beta * self.V0

        potential, spikes = self.LIF_dynamic(
            v=state.v,
            z=state.z,
            z_buffer=state.z_buffer,
            i_future_buffer=current_buffer,
            decay=self._decay,
            thr=threshold,
        )

        # Advance both delay buffers along axis 2; the spike buffer is
        # presumably fed the fresh spikes by tf_roll -- verify helper.
        rolled_currents = tf_roll(current_buffer, axis=2)
        rolled_spikes = tf_roll(state.z_buffer, spikes, axis=2)

        next_state = ALIFStateTuple(
            v=potential,
            z=spikes,
            b=adaptation,
            i_future_buffer=rolled_currents,
            z_buffer=rolled_spikes,
        )
        return [spikes, potential, threshold], next_state
def __call__(self, inputs, state, scope=None, dtype=tf.float32):
    """Advance the LIF cell by one time step.

    Args:
        inputs: Input for this step; projected through ``self.W_in``.
        state: Previous ``LIFStateTuple`` with fields ``v``, ``z``,
            ``i_future_buffer`` and ``z_buffer``.
        scope: Not referenced by this body (presumably kept for RNN-cell
            call-signature compatibility).
        dtype: Not referenced by this body.

    Returns:
        ``(z, next_state)``: the new spikes and the updated ``LIFStateTuple``.
    """
    # Judging by the helper's name, einsum_bi_ijk_to_bjk presumably maps
    # (batch, i) x (i, j, k) -> (batch, j, k) -- TODO confirm.
    feedforward = einsum_bi_ijk_to_bjk(inputs, self.W_in)
    recurrent = einsum_bi_ijk_to_bjk(state.z, self.W_rec)
    # Accumulate this step's feed-forward and recurrent currents into the
    # buffered future currents (summed left to right as before).
    current_buffer = state.i_future_buffer + feedforward + recurrent

    # LIF_dynamic is called without decay/thr here, so it relies on its own
    # defaults for those arguments.
    potential, spikes = self.LIF_dynamic(
        v=state.v,
        z=state.z,
        z_buffer=state.z_buffer,
        i_future_buffer=current_buffer,
    )

    # Advance both delay buffers along axis 2; the spike buffer is
    # presumably fed the fresh spikes by tf_roll -- verify helper.
    rolled_currents = tf_roll(current_buffer, axis=2)
    rolled_spikes = tf_roll(state.z_buffer, spikes, axis=2)

    next_state = LIFStateTuple(
        v=potential,
        z=spikes,
        i_future_buffer=rolled_currents,
        z_buffer=rolled_spikes,
    )
    return spikes, next_state