class ActorCritic:

    def __init__(
            self,
            _NAUDIO_COMMANDS,         # scalar, number of possible audio commands
            _EEG_INPUT_SHAPE,         # shape, (ntimepoints, nchan, nfreqs)
            _LOGDIR,                  # directory to write TensorBoard summaries to
            _POLICY_LR=1e-4,          # scalar, policy learning rate
            _VALUE_LR=1e-3,           # scalar, value learning rate
            _REWARD_MA_LEN=100,       # scalar, reward moving-average window length
            _LSTM_CELLS=[30, 30, 30]  # lstm dimensions, (cell0_size, cell1_size, ...); the list length is the number of cells
    ):
        # These should not be changed by the user, but may change later in the architecture
        self._InputShape = list(_EEG_INPUT_SHAPE)
        self._LSTMCells = list(_LSTM_CELLS)
        self._LSTMUnrollLength = 1
        self._ValueDiscount = 1.0

        self._Policy = Policy(_LEARNING_RATE=_POLICY_LR,
                              _ACTIONSPACE_SIZE=_NAUDIO_COMMANDS)
        self._Value = Value(_LEARNING_RATE=_VALUE_LR,
                            _DISCOUNT_RATE=self._ValueDiscount)
        self._Reward = Reward(_INPUT_SHAPE=_EEG_INPUT_SHAPE,
                              _MA_LENGTH=_REWARD_MA_LEN)
        self._Shared = Shared(_CELLS=_LSTM_CELLS,
                              _UNROLL_LENGTH=self._LSTMUnrollLength)

        # We store a local copy of the LSTM hidden state, which is fed back in every iteration.
        # NOTE: this buffer uses the last cell's size for every cell, which only matches the
        # per-cell placeholders built in _buildModel when all cells share the same size
        # (as they do in the default configuration).
        self._HiddenStateShape = (len(_LSTM_CELLS), 2,
                                  self._LSTMUnrollLength, _LSTM_CELLS[-1])
        self._LocalHiddenState = np.zeros(self._HiddenStateShape)

        # Save the logdir
        self.mLogdir = _LOGDIR

        self._buildModel()
        self._buildSummaries()
        self._buildFeedDicts()
        self._initSession()

    def _buildModel(self):
        # Inputs (from the outside world)
        self._phInputEEG = tf.placeholder(
            tf.float32, shape=[None] + self._InputShape, name="St0")
        self._phHiddenState = [
            tf.placeholder(tf.float32,
                           (2, self._LSTMUnrollLength, self._LSTMCells[idx]),
                           name="Cell" + str(idx) + "_Ht0_Ct0")
            for idx in range(len(self._LSTMCells))
        ]
        '''
        ========= REFERENCE =========
        Graph inputs:
            - self._phInputEEG
            - self._phHiddenState
        Graph Outputs:
            - self._Action
            - self._LSTMState
        Graph Train:
            - self._ValueTrain
            - self._PolicyTrain
        Graph Assigns:
            - self._RewardAssgn (after training)
            - self._ValueAssgn (after training)
            - self._PolicyGradientsAssgn (!! with prediction)

        Build order is as follows:
            0. Reshape EEG input
            1. Shared, must be built before policy and before value
            2. Reward, must be built before value
            3. Value, must be built before policy
            4. Policy
        '''
        # St
        self._netInputEEG = self._reshapeEEGInput(self._phInputEEG)

        # shared LSTM
        self._LSTMOutput, self._LSTMState = self._Shared.buildGraph(
            _phINPUT_EEG=self._netInputEEG,
            _INPUT_HSTATES=self._phHiddenState)

        # R(St)
        self._RewardOutput, self._RewardAssgn = self._Reward.buildGraph(
            _phINPUT_EEG=self._netInputEEG)

        # V(St)
        self._ValueTrain, self._TDError, self._ValueAssgn = self._Value.buildGraph(
            _INPUT_LAYER=self._LSTMOutput,
            _INPUT_REWARD=self._RewardOutput)

        # Pi(St)
        self._PolicyTrain, self._Action, self._PolicyGradientsAssgn = self._Policy.buildGraph(
            _INPUT_LAYER=self._LSTMOutput,
            _INPUT_TDERR=self._TDError)

    def _buildSummaries(self):
        weights_policy = tf.get_collection(tf.GraphKeys.WEIGHTS, self._Policy.mScope)
        weights_value = tf.get_collection(tf.GraphKeys.WEIGHTS, self._Value.mScope)
        biases_policy = tf.get_collection(tf.GraphKeys.BIASES, self._Policy.mScope)
        biases_value = tf.get_collection(tf.GraphKeys.BIASES, self._Value.mScope)

        # Value
        self.mValueTrainSummaries = tf.summary.merge(
            [
                tf.summary.scalar("val_loss", self._Value.mLoss),
                tf.summary.scalar("td_error", tf.reduce_mean(self._TDError)),
                tf.summary.scalar("R_t1_instant", tf.reduce_mean(self._Reward.mRewardInstant)),
                tf.summary.scalar("R_t1_average", self._Reward.mRewardMA),
                tf.summary.scalar("R_t1_actual", tf.reduce_mean(self._Reward.mRewardActual)),
                tf.summary.scalar("V_t0", self._Value.mVt0[0, 0]),
                tf.summary.scalar("V_t1", self._Value.mVt1[0, 0])
            ] + [
                tf.summary.histogram(w.name.split(':')[0], tf.reshape(w.value(), [-1]))
                for w in weights_value
            ] + [
                tf.summary.histogram(b.name.split(':')[0], tf.reshape(b.value(), [-1]))
                for b in biases_value
            ] + [
                # weight matrices visualised as single-channel images (NHWC)
                tf.summary.image(w.name.split(':')[0],
                                 tf.reshape(w.value(), [1, 1, self._LSTMCells[-1], 1]),
                                 max_outputs=10)
                for w in weights_value
            ])

        # Policy
        self.mPolicyTrainSummaries = tf.summary.merge(
            [
                tf.summary.scalar("action_greedy", self._Policy.mAt0Greedy[0]),
                tf.summary.scalar("action_taken", self._Action[0]),
                tf.summary.histogram("softmax", self._Policy.mSoftmax),
                tf.summary.histogram("action_taken", self._Action)
            ] + [
                tf.summary.histogram(w.name.split(':')[0], tf.reshape(w.value(), [-1]))
                for w in weights_policy
            ] + [
                tf.summary.histogram(b.name.split(':')[0], tf.reshape(b.value(), [-1]))
                for b in biases_policy
            ] + [
                # weight matrices visualised as single-channel images (NHWC)
                tf.summary.image(w.name.split(':')[0],
                                 tf.reshape(w.value(), [1, -1, self._LSTMCells[-1], 1]),
                                 max_outputs=10)
                for w in weights_policy
            ])
        # [tf.summary.histogram(g.name.split("pv_rnn/")[-1].split(':')[0], g) for g in pol_gradients]

        # ==============================================================================
        # self.mInputSummary = tf.summary.merge([
        #     tf.summary.image('spect',
        #                      tf.reshape(in_raw_eeg_features, shape_eeg_input[1:] + [1]),
        #                      max_outputs=nchan)
        # ])
        #
        # tgt_prof_summaries = tf.summary.merge([
        #     tf.summary.image('tgt_prof',
        #                      tf.reshape(in_raw_tgt_profile, [ntimepoints, nchan, nfreqs, 1]),
        #                      max_outputs=nchan),
        #     tf.summary.image('tgt_weighting',
        #                      tf.reshape(in_raw_tgt_weighting, [ntimepoints, nchan, nfreqs, 1]),
        #                      max_outputs=nchan)
        # ])
        # ==============================================================================

    def _buildFeedDicts(self):
        # Generate the fetch dictionaries for the various actions
        self.mPolicyActFetch = {
            'rnn_output_state': self._LSTMState,
            'action': self._Action,
            'step': self._Policy.mStepPredict,
            'gradient_save': self._PolicyGradientsAssgn,
        }
        self.mPolicyActFetch_Summaries = self.mPolicyActFetch  # no summaries

        self.mPolicyTrainFetch = {
            'rnn_output_state': self._LSTMState,
            'train_op': self._PolicyTrain,
            'step': self._Policy.mStep
        }
        self.mPolicyTrainFetch_Summaries = dict(
            summaries=self.mPolicyTrainSummaries, **self.mPolicyTrainFetch)

        self.mValueTrainFetch = {
            'rnn_output_state': self._LSTMState,
            'train_op': self._ValueTrain,
            'loss': self._TDError,
            'step': self._Value.mStep,
            # these will be evaluated last (using control_dependencies)
            'assgn_ops': [self._RewardAssgn, self._ValueAssgn]
        }
        self.mValueTrainFetch_Summaries = dict(
            summaries=self.mValueTrainSummaries, **self.mValueTrainFetch)

    def _initSession(self):
        # Tensorflow init
        self.mSaver = tf.train.Saver()
        self.mSess = tf.Session()
        self.mWriter = tf.summary.FileWriter(self.mLogdir, self.mSess.graph)
        self.mSess.run(tf.global_variables_initializer())
        self.idx = 0

    # Provide a _SESSRUN_NAME to record a full trace of that run
    def _runSession(self, _FETCH, _FEED, _SESSRUN_NAME=''):
        self.idx += 1  # simple run counter
        if _SESSRUN_NAME != '':
            metadata = tf.RunMetadata()
            out = self.mSess.run(
                fetches=_FETCH,
                feed_dict=_FEED,
                options=tf.RunOptions(trace_level=tf.RunOptions.SOFTWARE_TRACE),
                run_metadata=metadata)
            self.mWriter.add_run_metadata(metadata, _SESSRUN_NAME)
        else:
            out = self.mSess.run(_FETCH, _FEED)
        return out

    def _reshapeEEGInput(self, _IN):
        shape_spectrogram_flat = np.prod(self._InputShape)
        reshape_shape = [-1, self._LSTMUnrollLength, shape_spectrogram_flat]
        reshaped_input = tf.reshape(_IN, reshape_shape)
        return reshaped_input

    def _addHiddenState(self, feed):
        # Feed the locally stored hidden state into the per-cell placeholders
        for idx in range(len(self._LocalHiddenState)):
            feed[self._phHiddenState[idx]] = self._LocalHiddenState[idx]

    def trainPolicy(self, _INPUT, _DO_SUMMARIES=False, _UPDATE_HIDDEN_STATE=False):
        feed = {self._phInputEEG: _INPUT}
        self._addHiddenState(feed)
        if _DO_SUMMARIES:
            fetch = self.mPolicyTrainFetch_Summaries
        else:
            fetch = self.mPolicyTrainFetch
        out = self._runSession(fetch, feed)
        if _UPDATE_HIDDEN_STATE:
            self._LocalHiddenState = out['rnn_output_state']
        return out

    def trainValue(self, _INPUT, _DO_SUMMARIES=False, _UPDATE_HIDDEN_STATE=False):
        feed = {self._phInputEEG: _INPUT}
        self._addHiddenState(feed)
        if _DO_SUMMARIES:
            fetch = self.mValueTrainFetch_Summaries
        else:
            fetch = self.mValueTrainFetch
        out = self._runSession(fetch, feed)
        if _UPDATE_HIDDEN_STATE:
            self._LocalHiddenState = out['rnn_output_state']
        return out

    def chooseAction(self, _INPUT, _DO_SUMMARIES=False, _UPDATE_HIDDEN_STATE=False):
        feed = {self._phInputEEG: _INPUT}
        self._addHiddenState(feed)
        if _DO_SUMMARIES:
            fetch = self.mPolicyActFetch_Summaries
        else:
            fetch = self.mPolicyActFetch
        out = self._runSession(fetch, feed)
        if _UPDATE_HIDDEN_STATE:
            self._LocalHiddenState = out['rnn_output_state']
        return out

    def run(self, _INPUT, _DO_SUMMARIES=False):
        t0 = self.trainPolicy(_INPUT, _DO_SUMMARIES)
        t1 = self.trainValue(_INPUT, _DO_SUMMARIES)
        act0 = self.chooseAction(_INPUT, _DO_SUMMARIES, _UPDATE_HIDDEN_STATE=True)
        if _DO_SUMMARIES:
            self.mWriter.add_summary(t0['summaries'], global_step=t0['step'])
            self.mWriter.add_summary(t1['summaries'], global_step=t1['step'])
            self.mWriter.flush()
        return act0['action']

    def updateTargetState(self):
        # reshape_tf_tgt = [1, self._LSTMUnrollLength, shape_spectrogram_flat]
        pass
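
# ----------------------------------------------------------------------------
# Usage sketch: a minimal, assumption-laden driver showing how ActorCritic
# might be exercised. The EEG shape, number of audio commands, log directory,
# and random input below are illustrative placeholders only; the Policy,
# Value, Reward, and Shared classes are assumed to be defined or imported
# elsewhere in this package, along with numpy (np) and tensorflow (tf).
if __name__ == "__main__":
    eeg_shape = (10, 8, 16)  # assumed (ntimepoints, nchan, nfreqs)
    agent = ActorCritic(
        _NAUDIO_COMMANDS=4,
        _EEG_INPUT_SHAPE=eeg_shape,
        _LOGDIR="/tmp/actor_critic_logs")

    for step in range(100):
        # One synthetic EEG window per step; a real driver would stream data in.
        eeg_window = np.random.randn(1, *eeg_shape).astype(np.float32)
        action = agent.run(eeg_window, _DO_SUMMARIES=(step % 10 == 0))
        print("step", step, "action", action)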