    def train(self, batch_data: TransitionData, **kwargs) -> dict:
        self.set_status('TRAIN')

        # Update the running normalization statistics of the input/output scalers
        # with the states, actions and state deltas from the new batch.
        self.state_input_scaler.update_scaler(batch_data.state_set)
        self.action_input_scaler.update_scaler(batch_data.action_set)
        self.output_delta_state_scaler.update_scaler(batch_data.new_state_set -
                                                     batch_data.state_set)

        # Use the TF session and iteration count passed via kwargs, falling back to
        # the default session and the configured 'train_iter' parameter.
        tf_sess = kwargs.get('sess') or tf.get_default_session()
        train_iter = kwargs['train_iter'] if 'train_iter' in kwargs else self.parameters('train_iter')
        feed_dict = {
            self.state_input: self.state_input_scaler.process(batch_data.state_set),
            self.action_input: self.action_input_scaler.process(
                flatten_n(self.env_spec.action_space, batch_data.action_set)),
            self.delta_state_label_ph: self.output_delta_state_scaler.process(
                batch_data.new_state_set - batch_data.state_set),
            **self.parameters.return_tf_parameter_feed_dict()
        }
        # Run the optimization op for train_iter steps and report the mean loss.
        average_loss = 0.0
        for i in range(train_iter):
            loss, _ = tf_sess.run([self.loss, self.optimize_op], feed_dict=feed_dict)
            average_loss += loss
        return dict(average_loss=average_loss / train_iter)
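
Below is a minimal, hypothetical usage sketch for the dynamics-model train() above. The names dynamics_model and batch are assumptions for illustration only; the snippet itself only confirms that train() takes a TransitionData batch plus optional sess and train_iter keyword arguments and returns the average loss.

import tensorflow as tf

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # `dynamics_model` and `batch` are assumed to exist already: `batch` is a
    # TransitionData holding state_set, action_set and new_state_set arrays.
    res = dynamics_model.train(batch_data=batch, sess=sess, train_iter=10)
    print('dynamics model average loss:', res['average_loss'])
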
Example 2
    def train(self, batch_data=None, train_iter=None, sess=None, update_target=True) -> dict:
        super(DQN, self).train()
        self.recorder.record()
        if batch_data and not isinstance(batch_data, TransitionData):
            raise TypeError('batch_data must be an instance of TransitionData')

        tf_sess = sess or tf.get_default_session()
        train_iter = train_iter or self.parameters("TRAIN_ITERATION")
        average_loss = 0.0

        for i in range(train_iter):
            # Use the supplied batch if given, otherwise sample a fresh batch from the replay buffer.
            train_data = self.replay_buffer.sample(
                batch_size=self.parameters('BATCH_SIZE')) if batch_data is None else batch_data

            # Q values of the target network on the successor states (one value per transition).
            _, target_q_val_on_new_s = self.predict_target_with_q_val(obs=train_data.new_state_set,
                                                                      batch_flag=True)
            target_q_val_on_new_s = np.expand_dims(target_q_val_on_new_s, axis=1)
            assert target_q_val_on_new_s.shape[0] == train_data.state_set.shape[0]
            feed_dict = {
                self.reward_input: train_data.reward_set,
                self.action_input: flatten_n(self.env_spec.action_space, train_data.action_set),
                self.state_input: train_data.state_set,
                self.done_input: train_data.done_set,
                self.target_q_input: target_q_val_on_new_s,
                **self.parameters.return_tf_parameter_feed_dict()
            }
            res, _ = tf_sess.run([self.q_value_func_loss, self.update_q_value_func_op],
                                 feed_dict=feed_dict)
            average_loss += res

        average_loss /= train_iter
        if update_target:
            # Sync the target Q network after the training iterations.
            tf_sess.run(self.update_target_q_value_func_op,
                        feed_dict=self.parameters.return_tf_parameter_feed_dict())
        return dict(average_loss=average_loss)
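
A similarly hedged sketch of driving the DQN train() above from an outer loop. Only the train() signature and its return value come from the snippet; the dqn variable, the step count, and the target-sync schedule are illustrative assumptions, and the replay buffer is presumed to have been filled elsewhere (e.g. by an exploration loop).

import tensorflow as tf

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(1000):
        # `dqn` is an assumed, pre-built agent whose replay buffer already holds transitions.
        # Sync the target network only every 100 steps; otherwise just fit the online Q network.
        res = dqn.train(sess=sess, update_target=(step % 100 == 0))
        if step % 100 == 0:
            print(step, res['average_loss'])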