def train(train_env, monitor_env, agent_action_fn, noise_process):
    """Build the DDPG graph and run the training loop on train_env.

    monitor_env is used for periodic evaluation; noise_process supplies the
    exploration noise (a UO/Ornstein-Uhlenbeck process).
    """
    action_space = train_env.action_space
    obs_space = train_env.observation_space

    ######### instantiate actor, critic, replay buffer, uo-process #########
    ## feed online with state, feed target with next_state.
    online_state_inputs = tf.placeholder(tf.float32,
                                         shape=(None, obs_space.shape[0]),
                                         name="online_state_inputs")
    target_state_inputs = tf.placeholder(tf.float32,
                                         shape=online_state_inputs.shape,
                                         name="target_state_inputs")

    ## inputs to q_net for training q.
    online_action_inputs_training_q = tf.placeholder(
        tf.float32,
        shape=(None, action_space.shape[0]),
        name='online_action_batch_inputs')

    # condition bool scalar to switch the action inputs of the online q net.
    # feed True: training q.
    # feed False: training policy.
    cond_training_q = tf.placeholder(tf.bool, shape=[], name='cond_training_q')

    # batch_size vectors.
    terminated_inputs = tf.placeholder(tf.float32, shape=(None,), name='terminated_inputs')
    reward_inputs = tf.placeholder(tf.float32, shape=(None,), name='rewards_inputs')

    # for learning-rate decay
    actor_l_r = tf.placeholder(tf.float32, shape=[], name='actor_l_r')
    critic_l_r = tf.placeholder(tf.float32, shape=[], name='critic_l_r')

    # for summary text
    summary_text_tensor = tf.convert_to_tensor('summary_text', preferred_dtype=tf.string)
    tf.summary.text(name='summary_text',
                    tensor=summary_text_tensor,
                    collections=[DDPG_CFG.log_summary_keys])

    ## instantiate actor, critic.
    # is_training is a bool placeholder assumed to be defined at module level.
    actor = Actor(
        action_dim=action_space.shape[0],
        online_state_inputs=online_state_inputs,
        target_state_inputs=target_state_inputs,
        input_normalizer=DDPG_CFG.actor_input_normalizer,
        input_norm_params=DDPG_CFG.actor_input_norm_params,
        n_fc_units=DDPG_CFG.actor_n_fc_units,
        fc_activations=DDPG_CFG.actor_fc_activations,
        fc_initializers=DDPG_CFG.actor_fc_initializers,
        fc_normalizers=DDPG_CFG.actor_fc_normalizers,
        fc_norm_params=DDPG_CFG.actor_fc_norm_params,
        fc_regularizers=DDPG_CFG.actor_fc_regularizers,
        output_layer_initializer=DDPG_CFG.actor_output_layer_initializer,
        output_layer_regularizer=None,
        output_normalizers=DDPG_CFG.actor_output_layer_normalizers,
        output_norm_params=DDPG_CFG.actor_output_layer_norm_params,
        output_bound_fns=DDPG_CFG.actor_output_bound_fns,
        learning_rate=actor_l_r,
        is_training=is_training)

    critic = Critic(
        online_state_inputs=online_state_inputs,
        target_state_inputs=target_state_inputs,
        input_normalizer=DDPG_CFG.critic_input_normalizer,
        input_norm_params=DDPG_CFG.critic_input_norm_params,
        online_action_inputs_training_q=online_action_inputs_training_q,
        online_action_inputs_training_policy=actor.online_action_outputs_tensor,
        cond_training_q=cond_training_q,
        target_action_inputs=actor.target_action_outputs_tensor,
        n_fc_units=DDPG_CFG.critic_n_fc_units,
        fc_activations=DDPG_CFG.critic_fc_activations,
        fc_initializers=DDPG_CFG.critic_fc_initializers,
        fc_normalizers=DDPG_CFG.critic_fc_normalizers,
        fc_norm_params=DDPG_CFG.critic_fc_norm_params,
        fc_regularizers=DDPG_CFG.critic_fc_regularizers,
        output_layer_initializer=DDPG_CFG.critic_output_layer_initializer,
        output_layer_regularizer=DDPG_CFG.critic_output_layer_regularizer,
        learning_rate=critic_l_r)

    ## track updates.
    global_step_tensor = tf.train.create_global_step()

    ## build the whole graph
    copy_online_to_target_op, train_online_policy_op, train_online_q_op, update_target_op, saver, q_loss_tensor \
        = build_ddpg_graph(actor, critic, reward_inputs, terminated_inputs, global_step_tensor)

    # we save the replay buffer data to files.
    replay_buffer = ReplayBuffer(
        buffer_size=DDPG_CFG.replay_buff_size,
        save_segment_size=DDPG_CFG.replay_buff_save_segment_size,
        save_path=DDPG_CFG.replay_buffer_file_path,
        seed=DDPG_CFG.random_seed)

    # TODO: test loading the replay buffer from files.
    if DDPG_CFG.load_replay_buffer_set:
        replay_buffer.load(DDPG_CFG.replay_buffer_file_path)

    # === finish building the ddpg graph before this point ===

    ## create the tf default session
    sess = tf.Session(graph=tf.get_default_graph())

    # note: FileWriter serializes the graph to a GraphDef now, so the whole
    # computation graph must be finished before the summary writer is created.
    summary_writer = tf.summary.FileWriter(logdir=os.path.join(DDPG_CFG.log_dir, "train"),
                                           graph=sess.graph)
    actor_summary_op = tf.summary.merge_all(key=DDPG_CFG.actor_summary_keys)
    critic_summary_op = tf.summary.merge_all(key=DDPG_CFG.critic_summary_keys)
    log_summary_op = tf.summary.merge_all(key=DDPG_CFG.log_summary_keys)

    ######### initialize computation graph ############
    sess.run(fetches=[tf.global_variables_initializer()])

    # copy initial params from online to target
    sess.run(fetches=[copy_online_to_target_op])

    # Load a previous checkpoint if it exists
    latest_checkpoint = tf.train.latest_checkpoint(DDPG_CFG.checkpoint_dir)
    if latest_checkpoint:
        tf.logging.info("=== Loading model checkpoint: {}".format(latest_checkpoint))
        saver.restore(sess, latest_checkpoint)

    ####### start training #########
    if not DDPG_CFG.train_from_replay_buffer_set_only:
        obs = train_env.reset()
        transition = preprocess_low_dim(obs)

    n_episodes = 1
    update_start = 0.0

    for step in range(1, DDPG_CFG.num_training_steps):
        noise_process.reset()

        # replace with a new transition
        if not DDPG_CFG.train_from_replay_buffer_set_only:  # otherwise no new samples are needed
            transition = agent_action_fn(step, sess, actor, online_state_inputs,
                                         is_training,
                                         transition.next_state[np.newaxis, :],
                                         replay_buffer, noise_process, train_env)
            if step % DDPG_CFG.summary_transition_freq == 0:
                summary_transition(summary_writer, action_space.shape[0], transition, step)

        # after the replay buffer holds some transitions, we start learning.
        if step > DDPG_CFG.learn_start:
            # time the update duration for the first 10 updates
            if step < (DDPG_CFG.learn_start + 10):
                update_start = time.time()

            ## ++++ sample a mini-batch and train. ++++
            state_batch, action_batch, reward_batch, next_state_batch, terminated_batch = \
                replay_buffer.sample_batch(DDPG_CFG.batch_size)

            if step % 2000 == 0 and DDPG_CFG.train_from_replay_buffer_set_only:
                tf.logging.info(
                    '@@@@@ train from buffer only - one sample - global_step:{} action:{}'
                    ' reward:{} term:{} @@@@@@@@@@'.format(
                        step, action_batch[0], reward_batch[0], terminated_batch[0]))

            if step % DDPG_CFG.summary_freq == 0:
                # ---- 1. train policy. -----------
                # reward, next_state and terminated are unused in the policy update,
                # so they need not be fed.
                _, actor_summary = sess.run(
                    fetches=[train_online_policy_op, actor_summary_op],
                    feed_dict={
                        online_state_inputs: state_batch,
                        cond_training_q: False,
                        online_action_inputs_training_q: action_batch,  # fed but not used.
                        actor_l_r: l_r_decay(DDPG_CFG.actor_learning_rate, step),
                        is_training: True
                    })

                # the online policy network updated above does not affect training q.
                # ---- 2. train q. --------------
                _, critic_summary = sess.run(
                    fetches=[train_online_q_op, critic_summary_op],
                    feed_dict={
                        online_state_inputs: state_batch,
                        cond_training_q: True,
                        online_action_inputs_training_q: action_batch,
                        target_state_inputs: next_state_batch,
                        reward_inputs: reward_batch,
                        terminated_inputs: terminated_batch,
                        critic_l_r: l_r_decay(DDPG_CFG.critic_learning_rate, step),
                        is_training: True
                    })

                summary_writer.add_summary(actor_summary, global_step=step)
                summary_writer.add_summary(critic_summary, global_step=step)
                summary_writer.flush()
            else:
                # ---- 1. train policy. -----------
                _ = sess.run(
                    fetches=[train_online_policy_op],
                    feed_dict={
                        online_state_inputs: state_batch,
                        cond_training_q: False,
                        online_action_inputs_training_q: action_batch,  # fed but not used.
                        actor_l_r: l_r_decay(DDPG_CFG.actor_learning_rate, step),
                        is_training: True
                    })

                # the online policy network updated above does not affect training q.
                # ---- 2. train q. --------------
                _, q_loss_value = sess.run(
                    fetches=[train_online_q_op, q_loss_tensor],
                    feed_dict={
                        online_state_inputs: state_batch,
                        cond_training_q: True,
                        online_action_inputs_training_q: action_batch,
                        target_state_inputs: next_state_batch,
                        reward_inputs: reward_batch,
                        terminated_inputs: terminated_batch,
                        critic_l_r: l_r_decay(DDPG_CFG.critic_learning_rate, step),
                        is_training: True
                    })
                if step % 2000 == 0:
                    tf.logging.info('@@ step:{} q_loss:{}'.format(step, q_loss_value))
            # -- end of summary --

            # ----- 3. update target ---------
            # this also increments the global step.
            _ = sess.run(fetches=[update_target_op], feed_dict=None)

            # log the update duration for the first 10 updates
            if step < (DDPG_CFG.learn_start + 10):
                tf.logging.info('@@@@ one batch learn duration @@@@:{}'.format(
                    time.time() - update_start))

            # run an evaluation every eval_freq steps:
            if step % DDPG_CFG.eval_freq == 0:
                evaluate(env=monitor_env,
                         num_eval_steps=DDPG_CFG.num_eval_steps,
                         preprocess_fn=preprocess_low_dim,
                         estimate_fn=lambda state: sess.run(
                             fetches=[actor.online_action_outputs_tensor],
                             feed_dict={
                                 online_state_inputs: state,
                                 is_training: False
                             }),
                         summary_writer=summary_writer,
                         saver=saver,
                         sess=sess,
                         global_step=step,
                         log_summary_op=log_summary_op,
                         summary_text_tensor=summary_text_tensor)
        # -- end of learn --

        # TODO: temporary file-based switch to toggle vision; use a thread instead.
        if step % 2000 == 0:
            v_on = os.path.exists('/home/yuheng/Desktop/train_vision_on')
            if train_env.vision_status == False and v_on:
                train_env.vision_on()  # will display on the next reset
                transition = preprocess_low_dim(train_env.reset(relaunch=True))
                n_episodes += 1
                tf.logging.info('@@ episodes: {} @@'.format(n_episodes))
                continue
            elif train_env.vision_status == True and not v_on:
                train_env.vision_off()
                transition = preprocess_low_dim(train_env.reset(relaunch=True))
                n_episodes += 1
                tf.logging.info('@@ episodes: {} @@'.format(n_episodes))
                continue

        if (not DDPG_CFG.train_from_replay_buffer_set_only) and transition.terminated:
            # replace with the transition observed after reset; only the state is kept.
            transition = preprocess_low_dim(train_env.reset())
            n_episodes += 1
            tf.logging.info('@@ episodes: {} @@'.format(n_episodes))
            continue  # begin a new episode
    # ==== end for step ====

    sess.close()
    train_env.close()
    monitor_env.close()
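
# --------------------------------------------------------------------------
# build_ddpg_graph() is not shown in this listing. As a rough orientation, the
# two target-network ops it returns and that are run above (copy_online_to_target_op
# once after initialization, update_target_op every training step) are typically
# built like the sketch below. This is an illustrative sketch only, not the
# project's implementation; `tau` and the variable lists are assumptions.
def make_target_update_ops(online_vars, target_vars, tau=0.001):
    # hard copy: target <- online, run once right after global initialization.
    copy_op = tf.group(*[t.assign(o) for o, t in zip(online_vars, target_vars)])
    # soft update: target <- tau * online + (1 - tau) * target, run every step.
    soft_update_op = tf.group(*[t.assign(tau * o + (1.0 - tau) * t)
                                for o, t in zip(online_vars, target_vars)])
    return copy_op, soft_update_op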
def train(train_env, agent_action_fn, eval_mode=False):
    action_space = train_env.action_space
    obs_space = train_env.observation_space

    ######### instantiate actor, critic, replay buffer, uo-process #########
    ## feed online with state, feed target with next_state.
    online_state_inputs = tf.placeholder(tf.float32,
                                         shape=(None, obs_space.shape[0]),
                                         name="online_state_inputs")
    target_state_inputs = tf.placeholder(tf.float32,
                                         shape=online_state_inputs.shape,
                                         name="target_state_inputs")

    ## inputs to q_net for training q.
    online_action_inputs_training_q = tf.placeholder(
        tf.float32,
        shape=(None, action_space.shape[0]),
        name='online_action_batch_inputs')

    # condition bool scalar to switch the action inputs of the online q net.
    # feed True: training q.
    # feed False: training policy.
    cond_training_q = tf.placeholder(tf.bool, shape=[], name='cond_training_q')

    terminated_inputs = tf.placeholder(tf.float32, shape=(None,), name='terminated_inputs')
    reward_inputs = tf.placeholder(tf.float32, shape=(None,), name='rewards_inputs')

    # for summary text
    summary_text_tensor = tf.convert_to_tensor('summary_text', preferred_dtype=tf.string)
    tf.summary.text(name='summary_text',
                    tensor=summary_text_tensor,
                    collections=[DDPG_CFG.log_summary_keys])

    ## instantiate actor, critic.
    # is_training is a bool placeholder assumed to be defined at module level.
    actor = Actor(
        action_dim=action_space.shape[0],
        online_state_inputs=online_state_inputs,
        target_state_inputs=target_state_inputs,
        input_normalizer=DDPG_CFG.actor_input_normalizer,
        input_norm_params=DDPG_CFG.actor_input_norm_params,
        n_fc_units=DDPG_CFG.actor_n_fc_units,
        fc_activations=DDPG_CFG.actor_fc_activations,
        fc_initializers=DDPG_CFG.actor_fc_initializers,
        fc_normalizers=DDPG_CFG.actor_fc_normalizers,
        fc_norm_params=DDPG_CFG.actor_fc_norm_params,
        fc_regularizers=DDPG_CFG.actor_fc_regularizers,
        output_layer_initializer=DDPG_CFG.actor_output_layer_initializer,
        output_layer_regularizer=None,
        output_normalizers=DDPG_CFG.actor_output_layer_normalizers,
        output_norm_params=DDPG_CFG.actor_output_layer_norm_params,
        output_bound_fns=DDPG_CFG.actor_output_bound_fns,
        learning_rate=DDPG_CFG.actor_learning_rate,
        is_training=is_training)

    critic = Critic(
        online_state_inputs=online_state_inputs,
        target_state_inputs=target_state_inputs,
        input_normalizer=DDPG_CFG.critic_input_normalizer,
        input_norm_params=DDPG_CFG.critic_input_norm_params,
        online_action_inputs_training_q=online_action_inputs_training_q,
        online_action_inputs_training_policy=actor.online_action_outputs_tensor,
        cond_training_q=cond_training_q,
        target_action_inputs=actor.target_action_outputs_tensor,
        n_fc_units=DDPG_CFG.critic_n_fc_units,
        fc_activations=DDPG_CFG.critic_fc_activations,
        fc_initializers=DDPG_CFG.critic_fc_initializers,
        fc_normalizers=DDPG_CFG.critic_fc_normalizers,
        fc_norm_params=DDPG_CFG.critic_fc_norm_params,
        fc_regularizers=DDPG_CFG.critic_fc_regularizers,
        output_layer_initializer=DDPG_CFG.critic_output_layer_initializer,
        output_layer_regularizer=None,
        learning_rate=DDPG_CFG.critic_learning_rate)

    ## track updates.
    global_step_tensor = tf.train.create_global_step()

    ## build the whole graph
    copy_online_to_target_op, train_online_policy_op, train_online_q_op, update_target_op, saver \
        = build_ddpg_graph(actor, critic, reward_inputs, terminated_inputs, global_step_tensor)

    # we save the replay buffer data to files.
    replay_buffer = ReplayBuffer(
        buffer_size=DDPG_CFG.replay_buff_size,
        save_segment_size=DDPG_CFG.replay_buff_save_segment_size,
        save_path=DDPG_CFG.replay_buffer_file_path,
        seed=DDPG_CFG.random_seed)

    if DDPG_CFG.load_replay_buffer_set:
        replay_buffer.load(DDPG_CFG.replay_buffer_file_path)

    sess = tf.Session(graph=tf.get_default_graph())

    summary_writer = tf.summary.FileWriter(logdir=os.path.join(DDPG_CFG.log_dir, "train"),
                                           graph=sess.graph)
    log_summary_op = tf.summary.merge_all(key=DDPG_CFG.log_summary_keys)

    sess.run(fetches=[tf.global_variables_initializer()])

    # copy initial params from online to target
    sess.run(fetches=[copy_online_to_target_op])

    # Load a previous checkpoint if it exists
    latest_checkpoint = tf.train.latest_checkpoint(DDPG_CFG.checkpoint_dir)
    if latest_checkpoint:
        tf.logging.info("==== Loading model checkpoint: {}".format(latest_checkpoint))
        saver.restore(sess, latest_checkpoint)
    elif eval_mode:
        raise FileNotFoundError(
            '== in evaluation mode a checkpoint file is required, but none was found. ===')

    ####### start training #########
    obs = train_env.reset()
    transition = preprocess_low_dim(obs)
    n_episodes = 1

    if not eval_mode:
        for step in range(1, DDPG_CFG.num_training_steps):
            # replace with a new transition
            policy_out = sess.run(fetches=[actor.online_action_outputs_tensor],
                                  feed_dict={
                                      online_state_inputs: transition.next_state[np.newaxis, :],
                                      is_training: False
                                  })[0]
            transition = agent_action_fn(policy_out, replay_buffer, train_env)

            if step % 200 == 0:
                tf.logging.info(' +++++++++++++++++++ global_step:{} action:{}'
                                ' reward:{} term:{}'.format(
                                    step, transition.action, transition.reward,
                                    transition.terminated))

            if step < 10:  # feed some transitions into the buffer first.
                continue

            ## ++++ sample a mini-batch and train. ++++
            state_batch, action_batch, reward_batch, next_state_batch, terminated_batch = \
                replay_buffer.sample_batch(DDPG_CFG.batch_size)

            # ---- 1. train policy. -----------
            sess.run(fetches=[train_online_policy_op],
                     feed_dict={
                         online_state_inputs: state_batch,
                         cond_training_q: False,
                         online_action_inputs_training_q: action_batch,  # fed but not used.
                         is_training: True
                     })

            # ---- 2. train q. --------------
            sess.run(fetches=[train_online_q_op],
                     feed_dict={
                         online_state_inputs: state_batch,
                         cond_training_q: True,
                         online_action_inputs_training_q: action_batch,
                         target_state_inputs: next_state_batch,
                         reward_inputs: reward_batch,
                         terminated_inputs: terminated_batch,
                         is_training: True
                     })

            # ----- 3. update target ---------
            sess.run(fetches=[update_target_op], feed_dict=None)

            # run an evaluation every eval_freq steps:
            if step % DDPG_CFG.eval_freq == 0:
                evaluate(env=train_env,
                         num_eval_steps=DDPG_CFG.num_eval_steps,
                         preprocess_fn=preprocess_low_dim,
                         estimate_fn=lambda state: sess.run(
                             fetches=[actor.online_action_outputs_tensor],
                             feed_dict={
                                 online_state_inputs: state,
                                 is_training: False
                             }),
                         summary_writer=summary_writer,
                         saver=saver,
                         sess=sess,
                         global_step=step,
                         log_summary_op=log_summary_op,
                         summary_text_tensor=summary_text_tensor)

            if transition.terminated:
                transition = preprocess_low_dim(train_env.reset())
                n_episodes += 1
                continue  # begin a new episode
    else:  # eval mode
        evaluate(env=train_env,
                 num_eval_steps=DDPG_CFG.eval_steps_after_training,
                 preprocess_fn=preprocess_low_dim,
                 estimate_fn=lambda state: sess.run(
                     fetches=[actor.online_action_outputs_tensor],
                     feed_dict={
                         online_state_inputs: state,
                         is_training: False
                     }),
                 summary_writer=summary_writer,
                 saver=None,
                 sess=sess,
                 global_step=0,
                 log_summary_op=log_summary_op,
                 summary_text_tensor=summary_text_tensor)

    sess.close()
    train_env.close()
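
# --------------------------------------------------------------------------
# The reward_inputs / terminated_inputs placeholders fed in step 2 above enter
# the one-step TD target of the critic. A minimal sketch of that objective is
# given below; the tensor arguments (online_q, target_q), the variable list and
# the use of Adam are assumptions for illustration, not the project's
# build_ddpg_graph().
def make_critic_train_op(reward_inputs, terminated_inputs, online_q, target_q,
                         critic_vars, gamma=0.99, learning_rate=1e-3):
    # y = r + gamma * (1 - terminated) * Q_target(s', mu_target(s'))
    td_target = reward_inputs + gamma * (1.0 - terminated_inputs) * tf.stop_gradient(target_q)
    # mean squared TD error, minimized only w.r.t. the online critic variables.
    q_loss = tf.reduce_mean(tf.squared_difference(td_target, online_q))
    train_op = tf.train.AdamOptimizer(learning_rate).minimize(q_loss, var_list=critic_vars)
    return train_op, q_loss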
def train(train_env, agent_action_fn, eval_mode=False):
    # action and observation dimensions of the simulation environment
    action_space = train_env.action_space
    obs_space = train_env.observation_space

    ######### instantiate actor, critic, replay buffer, uo-process #########
    ## feed online with state, feed target with next_state.
    online_state_inputs = tf.placeholder(tf.float32,
                                         shape=(None, obs_space.shape[0]),
                                         name="online_state_inputs")
    target_state_inputs = tf.placeholder(tf.float32,
                                         shape=online_state_inputs.shape,
                                         name="target_state_inputs")

    ## inputs to q_net for training q.
    online_action_inputs_training_q = tf.placeholder(
        tf.float32,
        shape=(None, action_space.shape[0]),
        name='online_action_batch_inputs')

    # cond_training_q: condition variable that switches the action input of the q net.
    # True: training q.
    # False: training policy.
    cond_training_q = tf.placeholder(tf.bool, shape=[], name='cond_training_q')

    terminated_inputs = tf.placeholder(tf.float32, shape=(None,), name='terminated_inputs')
    reward_inputs = tf.placeholder(tf.float32, shape=(None,), name='rewards_inputs')

    # summary text
    summary_text_tensor = tf.convert_to_tensor('summary_text', preferred_dtype=tf.string)
    tf.summary.text(name='summary_text',
                    tensor=summary_text_tensor,
                    collections=[DDPG_CFG.log_summary_keys])

    # create the actor and critic instances
    # is_training is a bool placeholder assumed to be defined at module level.
    actor = Actor(
        action_dim=action_space.shape[0],
        online_state_inputs=online_state_inputs,
        target_state_inputs=target_state_inputs,
        input_normalizer=DDPG_CFG.actor_input_normalizer,
        input_norm_params=DDPG_CFG.actor_input_norm_params,
        n_fc_units=DDPG_CFG.actor_n_fc_units,
        fc_activations=DDPG_CFG.actor_fc_activations,
        fc_initializers=DDPG_CFG.actor_fc_initializers,
        fc_normalizers=DDPG_CFG.actor_fc_normalizers,
        fc_norm_params=DDPG_CFG.actor_fc_norm_params,
        fc_regularizers=DDPG_CFG.actor_fc_regularizers,
        output_layer_initializer=DDPG_CFG.actor_output_layer_initializer,
        output_layer_regularizer=None,
        output_normalizers=DDPG_CFG.actor_output_layer_normalizers,
        output_norm_params=DDPG_CFG.actor_output_layer_norm_params,
        output_bound_fns=DDPG_CFG.actor_output_bound_fns,
        learning_rate=DDPG_CFG.actor_learning_rate,
        is_training=is_training)

    critic = Critic(
        online_state_inputs=online_state_inputs,
        target_state_inputs=target_state_inputs,
        input_normalizer=DDPG_CFG.critic_input_normalizer,
        input_norm_params=DDPG_CFG.critic_input_norm_params,
        online_action_inputs_training_q=online_action_inputs_training_q,
        online_action_inputs_training_policy=actor.online_action_outputs_tensor,
        cond_training_q=cond_training_q,
        target_action_inputs=actor.target_action_outputs_tensor,
        n_fc_units=DDPG_CFG.critic_n_fc_units,
        fc_activations=DDPG_CFG.critic_fc_activations,
        fc_initializers=DDPG_CFG.critic_fc_initializers,
        fc_normalizers=DDPG_CFG.critic_fc_normalizers,
        fc_norm_params=DDPG_CFG.critic_fc_norm_params,
        fc_regularizers=DDPG_CFG.critic_fc_regularizers,
        output_layer_initializer=DDPG_CFG.critic_output_layer_initializer,
        output_layer_regularizer=None,
        learning_rate=DDPG_CFG.critic_learning_rate)

    # track updates.
    global_step_tensor = tf.train.create_global_step()

    # build the whole ddpg computation graph
    copy_online_to_target_op, train_online_policy_op, train_online_q_op, update_target_op, saver \
        = build_ddpg_graph(actor, critic, reward_inputs, terminated_inputs, global_step_tensor)

    # instantiate the replay buffer and specify whether its data is saved to files
    replay_buffer = ReplayBuffer(
        buffer_size=DDPG_CFG.replay_buff_size,
        save_segment_size=DDPG_CFG.replay_buff_save_segment_size,
        save_path=DDPG_CFG.replay_buffer_file_path,
        seed=DDPG_CFG.random_seed)

    # load buffer data from files
    if DDPG_CFG.load_replay_buffer_set:
        replay_buffer.load(DDPG_CFG.replay_buffer_file_path)

    # use summaries to monitor data and parameters during training;
    # the resulting charts can be inspected in TensorBoard.
    sess = tf.Session(graph=tf.get_default_graph())

    summary_writer = tf.summary.FileWriter(logdir=os.path.join(DDPG_CFG.log_dir, "train"),
                                           graph=sess.graph)
    actor_summary_op = tf.summary.merge_all(key=DDPG_CFG.actor_summary_keys)
    critic_summary_op = tf.summary.merge_all(key=DDPG_CFG.critic_summary_keys)
    log_summary_op = tf.summary.merge_all(key=DDPG_CFG.log_summary_keys)

    sess.run(fetches=[tf.global_variables_initializer()])

    # at initialization, copy the online parameters to the target networks
    sess.run(fetches=[copy_online_to_target_op])

    # load a previously saved model checkpoint:
    latest_checkpoint = tf.train.latest_checkpoint(DDPG_CFG.checkpoint_dir)
    if latest_checkpoint:
        tf.logging.info("==== Loading model checkpoint: {}".format(latest_checkpoint))
        saver.restore(sess, latest_checkpoint)
    elif eval_mode:
        raise FileNotFoundError(
            '== in evaluation mode a checkpoint file is required, but none was found. ===')

    ####### start training #########
    obs = train_env.reset()
    transition = preprocess_low_dim(obs)
    n_episodes = 1  # episode counter

    # training mode
    if not eval_mode:
        # train for DDPG_CFG.num_training_steps steps in total
        for step in range(1, DDPG_CFG.num_training_steps):
            # get an action for the current state from the online policy network
            policy_out = sess.run(fetches=[actor.online_action_outputs_tensor],
                                  feed_dict={
                                      online_state_inputs: transition.next_state[np.newaxis, :],
                                      is_training: False
                                  })[0]
            # execute the action in the environment and store the transition in the replay buffer
            transition = agent_action_fn(policy_out, replay_buffer, train_env)

            if step % 200 == 0:
                tf.logging.info(' +++++++++++++++++++ global_step:{} action:{}'
                                ' reward:{} term:{}'.format(
                                    step, transition.action, transition.reward,
                                    transition.terminated))

            # feed some transitions into the buffer first.
            if step < 10:
                continue

            # sample a mini-batch from the replay buffer
            state_batch, action_batch, reward_batch, next_state_batch, terminated_batch = \
                replay_buffer.sample_batch(DDPG_CFG.batch_size)

            if step % DDPG_CFG.summary_freq == 0:
                # train both networks and write their summaries
                _, actor_summary = sess.run(
                    fetches=[train_online_policy_op, actor_summary_op],
                    feed_dict={
                        online_state_inputs: state_batch,
                        cond_training_q: False,
                        online_action_inputs_training_q: action_batch,
                        is_training: True
                    })
                _, critic_summary = sess.run(
                    fetches=[train_online_q_op, critic_summary_op],
                    feed_dict={
                        online_state_inputs: state_batch,
                        cond_training_q: True,
                        online_action_inputs_training_q: action_batch,
                        target_state_inputs: next_state_batch,
                        reward_inputs: reward_batch,
                        terminated_inputs: terminated_batch,
                        is_training: True
                    })
                summary_writer.add_summary(actor_summary, global_step=step)
                summary_writer.add_summary(critic_summary, global_step=step)
                summary_writer.flush()
            else:
                # ---- 1. train the policy network -----------
                sess.run(fetches=[train_online_policy_op],
                         feed_dict={
                             online_state_inputs: state_batch,
                             cond_training_q: False,
                             online_action_inputs_training_q: action_batch,
                             is_training: True
                         })

                # ---- 2. train the q network --------------
                sess.run(fetches=[train_online_q_op],
                         feed_dict={
                             online_state_inputs: state_batch,
                             cond_training_q: True,
                             online_action_inputs_training_q: action_batch,
                             target_state_inputs: next_state_batch,
                             reward_inputs: reward_batch,
                             terminated_inputs: terminated_batch,
                             is_training: True
                         })

            # ----- 3. update the target networks ---------
            sess.run(fetches=[update_target_op], feed_dict=None)

            # every eval_freq steps run an evaluation, so that a good model can be
            # selected after training:
            if step % DDPG_CFG.eval_freq == 0:
                evaluate(env=train_env,
                         num_eval_steps=DDPG_CFG.num_eval_steps,
                         preprocess_fn=preprocess_low_dim,
                         estimate_fn=lambda state: sess.run(
                             fetches=[actor.online_action_outputs_tensor],
                             feed_dict={
                                 online_state_inputs: state,
                                 is_training: False
                             }),
                         summary_writer=summary_writer,
                         saver=saver,
                         sess=sess,
                         global_step=step,
                         log_summary_op=log_summary_op,
                         summary_text_tensor=summary_text_tensor)

            if transition.terminated:
                transition = preprocess_low_dim(train_env.reset())
                n_episodes += 1
                continue  # begin a new episode
    else:  # eval mode
        evaluate(env=train_env,
                 num_eval_steps=DDPG_CFG.eval_steps_after_training,
                 preprocess_fn=preprocess_low_dim,
                 estimate_fn=lambda state: sess.run(
                     fetches=[actor.online_action_outputs_tensor],
                     feed_dict={
                         online_state_inputs: state,
                         is_training: False
                     }),
                 summary_writer=summary_writer,
                 saver=None,
                 sess=sess,
                 global_step=0,
                 log_summary_op=log_summary_op,
                 summary_text_tensor=summary_text_tensor)

    sess.close()
    train_env.close()
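
# --------------------------------------------------------------------------
# The "uo-process" mentioned in the comments is the Ornstein-Uhlenbeck noise that
# noise_process / agent_action_fn add to the policy output for exploration. A
# minimal sketch of such a process is shown below; the theta/sigma defaults are
# illustrative values, not the project's configuration (np is numpy, as used
# elsewhere in this file).
class OUNoise(object):
    def __init__(self, action_dim, mu=0.0, theta=0.15, sigma=0.2):
        self.mu = mu * np.ones(action_dim)
        self.theta = theta
        self.sigma = sigma
        self.reset()

    def reset(self):
        # restart the process from its mean, e.g. at the beginning of an episode.
        self.state = np.copy(self.mu)

    def sample(self):
        # dx = theta * (mu - x) + sigma * N(0, I); the sample is added to the action.
        dx = self.theta * (self.mu - self.state) + self.sigma * np.random.randn(len(self.mu))
        self.state = self.state + dx
        return self.state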