예제 #1
0
파일: agent.py 프로젝트: zhly0/relaax
    def get_action_and_value_from_network(self):
        """Run the network on the current observation and pick an action.

        Feeds ``[self.observation.queue]`` to the session op selected by
        ``da3c_config.config.use_lstm``, then hands the network output to
        ``utils.choose_action_descrete`` (single-element output) or to
        ``utils.choose_action_continuous`` (a ``mu``/``sigma2`` pair).

        Side effects: stores ``self.last_probs``; may overwrite
        ``self.lstm_state``; records histograms through ``self.metrics``
        when the module-level ``M`` is truthy.

        Returns:
            A ``(chosen_action, value)`` tuple, where ``value`` is the single
            element unpacked from the network's value output.
        """
        cfg = da3c_config.config

        if not cfg.use_lstm:
            action, value = self.session.op_get_action_and_value(
                state=[self.observation.queue])
        else:
            action, value, next_lstm_state = \
                self.session.op_get_action_value_and_lstm_state(
                    state=[self.observation.queue],
                    lstm_state=self.lstm_state,
                    lstm_step=[1])
            # The stored LSTM state is left untouched when experience exists
            # and is either exactly batch_size long or the agent is flagged
            # terminal; in every other case it advances to the new state.
            if self.experience is None:
                hold_state = False
            else:
                hold_state = (len(self.experience) == cfg.batch_size
                              or self.terminal)
            if not hold_state:
                self.lstm_state = next_lstm_state

        # The value output is expected to hold exactly one element.
        (value,) = value

        if len(action) != 1:
            # Two-element output: treated as (mu, sigma2) for the
            # continuous chooser.
            mu, sigma2 = action
            self.last_probs = mu
            if M:
                self.metrics.histogram('mu', mu)
                self.metrics.histogram('sigma2', sigma2)
            chosen = utils.choose_action_continuous(
                mu, sigma2,
                cfg.output.action_low,
                cfg.output.action_high)
            return chosen, value

        # Single-element output: its only entry is passed to the discrete
        # chooser and remembered as last_probs.
        if M:
            self.metrics.histogram('action', action)
        (self.last_probs,) = action
        chosen = utils.choose_action_descrete(self.last_probs)
        return chosen, value
예제 #2
0
 def action_from_policy(self, state):
     """Choose an action for ``state`` using the session's action op.

     ``state`` is converted to an array, given a leading axis of length 1,
     and passed to ``self.session.op_get_action``. The single element of
     the result is forwarded, together with ``self.exploit``, to
     ``utils.choose_action_descrete``, whose return value is returned.

     ``state`` must not be ``None`` (checked with ``assert``).
     """
     assert state is not None
     observation = np.asarray(state)
     # Prepend a length-1 axis: shape (d0, d1, ...) becomes (1, d0, d1, ...).
     batched = observation.reshape((1, ) + observation.shape)
     # The op result is expected to contain exactly one element.
     (probabilities,) = self.session.op_get_action(state=batched)
     chosen = utils.choose_action_descrete(probabilities, self.exploit)
     return chosen
예제 #3
0
 def action_from_policy(self, state):
     """Choose an action for ``state`` using the session's action op.

     ``state`` is wrapped in a one-element list and passed to
     ``self.session.op_get_action``. The single element of the result is
     forwarded, together with ``self.exploit``, to
     ``utils.choose_action_descrete``, whose return value is returned.
     """
     op_output = self.session.op_get_action(state=[state])
     # The op result is expected to contain exactly one element.
     (probabilities,) = op_output
     chosen = utils.choose_action_descrete(probabilities, self.exploit)
     return chosen