def train(self, batch_data: TransitionData, **kwargs) -> dict:
    self.set_status('TRAIN')
    # Update the running statistics of the input/output scalers with the new batch.
    self.state_input_scaler.update_scaler(batch_data.state_set)
    self.action_input_scaler.update_scaler(batch_data.action_set)
    self.output_delta_state_scaler.update_scaler(batch_data.new_state_set - batch_data.state_set)
    tf_sess = kwargs['sess'] if ('sess' in kwargs and kwargs['sess']) else tf.get_default_session()
    train_iter = self.parameters('train_iter') if 'train_iter' not in kwargs else kwargs['train_iter']
    # The network is trained to predict the normalized state difference (s' - s).
    feed_dict = {
        self.state_input: self.state_input_scaler.process(batch_data.state_set),
        self.action_input: self.action_input_scaler.process(
            flatten_n(self.env_spec.action_space, batch_data.action_set)),
        self.delta_state_label_ph: self.output_delta_state_scaler.process(
            batch_data.new_state_set - batch_data.state_set),
        **self.parameters.return_tf_parameter_feed_dict()
    }
    average_loss = 0.0
    for _ in range(train_iter):
        loss, _ = tf_sess.run([self.loss, self.optimize_op], feed_dict=feed_dict)
        average_loss += loss
    return dict(average_loss=average_loss / train_iter)
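
# The method above fits the dynamics network to the normalized state difference
# (new_state - state) rather than to the raw next state. The sketch below is a
# minimal, self-contained NumPy illustration of that idea, assuming the scalers
# perform simple mean/std standardization; `RunningScaler` and `predict_next_state`
# are hypothetical names for illustration only, not part of the original code.
import numpy as np


class RunningScaler(object):
    """Standardizes data with statistics accumulated over the batches seen so far."""

    def __init__(self, dim):
        self.mean = np.zeros(dim)
        self.std = np.ones(dim)
        self._buffer = []

    def update_scaler(self, data):
        # Accumulate the batch and recompute mean/std over everything seen so far.
        self._buffer.append(np.asarray(data, dtype=np.float64))
        stacked = np.concatenate(self._buffer, axis=0)
        self.mean = stacked.mean(axis=0)
        self.std = stacked.std(axis=0) + 1e-8

    def process(self, data):
        return (np.asarray(data, dtype=np.float64) - self.mean) / self.std

    def inverse_process(self, data):
        return np.asarray(data, dtype=np.float64) * self.std + self.mean


def predict_next_state(model, state, action, state_scaler, action_scaler, delta_scaler):
    """Predict s' = s + delta, undoing the normalization applied during training."""
    normalized_delta = model(state_scaler.process(state), action_scaler.process(action))
    return state + delta_scaler.inverse_process(normalized_delta)
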
def train(self, batch_data=None, train_iter=None, sess=None, update_target=True) -> dict:
    super(DQN, self).train()
    self.recorder.record()
    if batch_data and not isinstance(batch_data, TransitionData):
        raise TypeError('batch_data must be an instance of TransitionData')
    tf_sess = sess if sess else tf.get_default_session()
    train_iter = self.parameters('TRAIN_ITERATION') if not train_iter else train_iter
    average_loss = 0.0
    for _ in range(train_iter):
        # Use the supplied batch if given, otherwise sample a fresh mini-batch from the replay buffer.
        train_data = self.replay_buffer.sample(
            batch_size=self.parameters('BATCH_SIZE')) if batch_data is None else batch_data
        # Bootstrap value of the next state, taken from the target network.
        _, target_q_val_on_new_s = self.predict_target_with_q_val(obs=train_data.new_state_set,
                                                                  batch_flag=True)
        target_q_val_on_new_s = np.expand_dims(target_q_val_on_new_s, axis=1)
        assert target_q_val_on_new_s.shape[0] == train_data.state_set.shape[0]
        feed_dict = {
            self.reward_input: train_data.reward_set,
            self.action_input: flatten_n(self.env_spec.action_space, train_data.action_set),
            self.state_input: train_data.state_set,
            self.done_input: train_data.done_set,
            self.target_q_input: target_q_val_on_new_s,
            **self.parameters.return_tf_parameter_feed_dict()
        }
        res, _ = tf_sess.run([self.q_value_func_loss, self.update_q_value_func_op],
                             feed_dict=feed_dict)
        average_loss += res
    average_loss /= train_iter
    if update_target:
        # Sync the target Q network towards the online Q network.
        tf_sess.run(self.update_target_q_value_func_op,
                    feed_dict=self.parameters.return_tf_parameter_feed_dict())
    return dict(average_loss=average_loss)
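
# The DQN update above feeds the reward, done flag, and the target network's value of the
# next state into a TensorFlow loss op. The sketch below spells out the same temporal-
# difference target and loss in plain NumPy, assuming a discount factor `gamma`;
# `compute_td_target` and `td_loss` are hypothetical helpers, not part of the original code.
import numpy as np


def compute_td_target(reward, done, target_q_on_new_s, gamma=0.99):
    """y = r + gamma * (1 - done) * max_a' Q_target(s', a'); done zeroes the bootstrap term."""
    reward = np.asarray(reward, dtype=np.float64)
    done = np.asarray(done, dtype=np.float64)
    target_q_on_new_s = np.asarray(target_q_on_new_s, dtype=np.float64)
    return reward + gamma * (1.0 - done) * target_q_on_new_s


def td_loss(q_on_s_a, td_target):
    """Mean squared error between the online Q value of (s, a) and the TD target."""
    q_on_s_a = np.asarray(q_on_s_a, dtype=np.float64)
    return float(np.mean((q_on_s_a - td_target) ** 2))


# Example with a batch of three transitions; the last one is terminal, so its target is just r.
y = compute_td_target(reward=[1.0, 0.0, 1.0],
                      done=[0, 0, 1],
                      target_q_on_new_s=[2.5, 1.2, 3.0])
print(td_loss(q_on_s_a=[2.0, 1.0, 0.5], td_target=y))
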