def agent_action(step, sess, actor, online_state_inputs, is_training, state,
                 replay_buffer, noise_process, env):
    policy_output = sess.run(fetches=[actor.online_action_outputs_tensor],
                             feed_dict={
                                 online_state_inputs: state,
                                 is_training: False
                             })  # state must already be reshaped to (1, obs_dim), e.g. (1, 11)
    policy_output = policy_output[0]  # sess.run returns a list of fetched values; take the first

    ## Add exploration noise and bound the action.
    stochastic_action = policy_output_to_stochastic_action(
        policy_output, env.action_space, noise_process)

    ## Execute a_t and store the transition.
    (state, reward, terminated) = env_step(env, stochastic_action)
    # episode_reward += reward

    # if step % 20 == 0:
    if step % 2000 == 0:
        tf.logging.info(' +++++++++++++++++++ global_step:{} action:{}'
                        '  reward:{} term:{}'.format(step, stochastic_action,
                                                     reward, terminated))
    # replace transition with new one.
    transition = preprocess_low_dim(action=stochastic_action,
                                    reward=reward,
                                    terminated=terminated,
                                    state=state)

    ## Even if terminated, we still save next_state because the feed-forward Q network
    # will use it, but the corresponding Q value is discarded in the end.
    replay_buffer.store(transition)
    return transition
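
# A minimal sketch of what policy_output_to_stochastic_action might look like for the call
# above. The noise_process.sample() interface and the exact signature are assumptions, not
# taken from the source: flatten the deterministic policy output, add exploration noise,
# and clip to the environment's action bounds.
import numpy as np

def policy_output_to_stochastic_action(policy_output, action_space, noise_process=None):
    action = np.asarray(policy_output).ravel()       # e.g. shape (1, act_dim) -> (act_dim,)
    if noise_process is not None:
        action = action + noise_process.sample()     # hypothetical sample() interface
    return np.clip(action, action_space.low, action_space.high)
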
def agent_action(policy_out, replay_buffer, env):
    ## Add exploration noise and bound the action.
    stochastic_action = policy_output_to_stochastic_action(policy_out, env.action_space)

    ## Execute a_t and store the transition.
    (state, reward, terminated) = env_step(env, stochastic_action)

    # replace transition with new one.
    transition = preprocess_low_dim(action=stochastic_action,
                                    reward=reward,
                                    terminated=terminated,
                                    state=state)
    ## Even if terminated, we still save next_state.
    replay_buffer.store(transition)
    return transition
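
# A minimal sketch of the helpers used above, under assumptions: env_step is a thin wrapper
# around a gym-style env.step() that drops the info dict, and preprocess_low_dim packs a
# transition into a simple namedtuple whose next_state field holds the freshly observed state.
# Field names and defaults are hypothetical; only the call sites above are from the source.
from collections import namedtuple
import numpy as np

Transition = namedtuple('Transition', ['action', 'reward', 'terminated', 'next_state'])

def env_step(env, action):
    next_state, reward, terminated, _info = env.step(action)
    return np.asarray(next_state, dtype=np.float32), reward, terminated

def preprocess_low_dim(state, action=None, reward=0.0, terminated=False):
    # Low-dimensional observations need no image preprocessing; just wrap them.
    return Transition(action=action, reward=reward, terminated=terminated,
                      next_state=np.asarray(state, dtype=np.float32))
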
Example #3
def agent_action(step, sess, actor, online_state_inputs, is_training, state,
                 replay_buffer, noise_process, env):
    # Act randomly at the beginning to fill some transitions into the replay buffer.
    if step < DDPG_CFG.learn_start:
        stochastic_action = [
            np.random.uniform(low, high)
            for (low, high) in zip(env.action_space.low, env.action_space.high)
        ]
    else:
        policy_output = sess.run(fetches=[actor.online_action_outputs_tensor],
                                 feed_dict={
                                     online_state_inputs: state,
                                     is_training: False
                                 })  # state must already be reshaped to (1, obs_dim), e.g. (1, 11)
        policy_output = policy_output[0]  # sess.run returns a list of fetched values; take the first

        ## Add exploration noise and bound the action.
        stochastic_action = policy_output_to_stochastic_action(
            policy_output, noise_process, env.action_space)

    ## Execute a_t and store the transition.
    (state, reward, terminated) = env_step(env, stochastic_action)

    if step % 2000 == 0:
        tf.logging.info('@@@@@@@@@@ global_step:{} action:{}'
                        '  reward:{} term:{} @@@@@@@@@@'.format(
                            step, stochastic_action, reward, terminated))

    # replace transition with new one.
    transition = preprocess_low_dim(action=stochastic_action,
                                    reward=reward,
                                    terminated=terminated,
                                    state=state)

    ## Even if terminated, we still save next_state because the feed-forward Q network
    # will use it, but the corresponding Q value is discarded in the end.
    replay_buffer.store(transition)
    return transition
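
# A minimal Ornstein-Uhlenbeck noise process sketch, the exploration noise commonly paired
# with DDPG. The class name and the reset()/sample() interface are assumptions chosen to
# match how noise_process is used in these examples; parameter values are illustrative.
import numpy as np

class OUNoise:
    def __init__(self, action_dim, mu=0.0, theta=0.15, sigma=0.2, dt=1e-2, seed=None):
        self.mu = mu * np.ones(action_dim)
        self.theta, self.sigma, self.dt = theta, sigma, dt
        self.rng = np.random.RandomState(seed)
        self.reset()

    def reset(self):
        # Restart the process from its long-run mean.
        self.state = self.mu.copy()

    def sample(self):
        # dx = theta * (mu - x) * dt + sigma * sqrt(dt) * N(0, 1)
        dx = (self.theta * (self.mu - self.state) * self.dt
              + self.sigma * np.sqrt(self.dt) * self.rng.randn(len(self.mu)))
        self.state = self.state + dx
        return self.state
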
Example #4
def train(train_env, agent_action_fn, eval_mode=False):
    action_space = train_env.action_space
    obs_space = train_env.observation_space

    ######### Instantiate actor, critic, replay buffer, OU process #########
    ## Feed the online network with state; feed the target network with next_state.
    online_state_inputs = tf.placeholder(tf.float32,
                                         shape=(None, obs_space.shape[0]),
                                         name="online_state_inputs")

    target_state_inputs = tf.placeholder(tf.float32,
                                         shape=online_state_inputs.shape,
                                         name="target_state_inputs")

    ## inputs to q_net for training q.
    online_action_inputs_training_q = tf.placeholder(
        tf.float32,
        shape=(None, action_space.shape[0]),
        name='online_action_batch_inputs')
    # condition bool scalar to switch action inputs to online q.
    # feed True: training q.
    # feed False: training policy.
    cond_training_q = tf.placeholder(tf.bool, shape=[], name='cond_training_q')

    terminated_inputs = tf.placeholder(tf.float32,
                                       shape=(None,),
                                       name='terminated_inputs')
    reward_inputs = tf.placeholder(tf.float32,
                                   shape=(None,),
                                   name='rewards_inputs')

    # for summary text
    summary_text_tensor = tf.convert_to_tensor('summary_text',
                                               preferred_dtype=tf.string)
    tf.summary.text(name='summary_text',
                    tensor=summary_text_tensor,
                    collections=[DDPG_CFG.log_summary_keys])

    ##instantiate actor, critic.
    actor = Actor(
        action_dim=action_space.shape[0],
        online_state_inputs=online_state_inputs,
        target_state_inputs=target_state_inputs,
        input_normalizer=DDPG_CFG.actor_input_normalizer,
        input_norm_params=DDPG_CFG.actor_input_norm_params,
        n_fc_units=DDPG_CFG.actor_n_fc_units,
        fc_activations=DDPG_CFG.actor_fc_activations,
        fc_initializers=DDPG_CFG.actor_fc_initializers,
        fc_normalizers=DDPG_CFG.actor_fc_normalizers,
        fc_norm_params=DDPG_CFG.actor_fc_norm_params,
        fc_regularizers=DDPG_CFG.actor_fc_regularizers,
        output_layer_initializer=DDPG_CFG.actor_output_layer_initializer,
        output_layer_regularizer=None,
        output_normalizers=DDPG_CFG.actor_output_layer_normalizers,
        output_norm_params=DDPG_CFG.actor_output_layer_norm_params,
        output_bound_fns=DDPG_CFG.actor_output_bound_fns,
        learning_rate=DDPG_CFG.actor_learning_rate,
        is_training=is_training)

    critic = Critic(
        online_state_inputs=online_state_inputs,
        target_state_inputs=target_state_inputs,
        input_normalizer=DDPG_CFG.critic_input_normalizer,
        input_norm_params=DDPG_CFG.critic_input_norm_params,
        online_action_inputs_training_q=online_action_inputs_training_q,
        online_action_inputs_training_policy=actor.online_action_outputs_tensor,
        cond_training_q=cond_training_q,
        target_action_inputs=actor.target_action_outputs_tensor,
        n_fc_units=DDPG_CFG.critic_n_fc_units,
        fc_activations=DDPG_CFG.critic_fc_activations,
        fc_initializers=DDPG_CFG.critic_fc_initializers,
        fc_normalizers=DDPG_CFG.critic_fc_normalizers,
        fc_norm_params=DDPG_CFG.critic_fc_norm_params,
        fc_regularizers=DDPG_CFG.critic_fc_regularizers,
        output_layer_initializer=DDPG_CFG.critic_output_layer_initializer,
        output_layer_regularizer=None,
        learning_rate=DDPG_CFG.critic_learning_rate)

    ## track updates.
    global_step_tensor = tf.train.create_global_step()

    ## build whole graph
    copy_online_to_target_op, train_online_policy_op, train_online_q_op, update_target_op, saver \
      = build_ddpg_graph(actor, critic, reward_inputs, terminated_inputs, global_step_tensor)

    #we save the replay buffer data to files.
    replay_buffer = ReplayBuffer(
        buffer_size=DDPG_CFG.replay_buff_size,
        save_segment_size=DDPG_CFG.replay_buff_save_segment_size,
        save_path=DDPG_CFG.replay_buffer_file_path,
        seed=DDPG_CFG.random_seed)
    if DDPG_CFG.load_replay_buffer_set:
        replay_buffer.load(DDPG_CFG.replay_buffer_file_path)

    sess = tf.Session(graph=tf.get_default_graph())
    summary_writer = tf.summary.FileWriter(logdir=os.path.join(
        DDPG_CFG.log_dir, "train"),
                                           graph=sess.graph)
    log_summary_op = tf.summary.merge_all(key=DDPG_CFG.log_summary_keys)

    sess.run(fetches=[tf.global_variables_initializer()])

    #copy init params from online to target
    sess.run(fetches=[copy_online_to_target_op])

    # Load a previous checkpoint if it exists
    latest_checkpoint = tf.train.latest_checkpoint(DDPG_CFG.checkpoint_dir)
    if latest_checkpoint:
        tf.logging.info(
            "==== Loading model checkpoint: {}".format(latest_checkpoint))
        saver.restore(sess, latest_checkpoint)
    elif eval_mode:
        raise FileNotFoundError(
            '=== In evaluation mode a checkpoint file is required, but none was found. ==='
        )

    ####### start training #########
    obs = train_env.reset()
    transition = preprocess_low_dim(obs)

    n_episodes = 1

    if not eval_mode:
        for step in range(1, DDPG_CFG.num_training_steps):
            # Replace with a new transition.
            policy_out = sess.run(fetches=[actor.online_action_outputs_tensor],
                                  feed_dict={
                                      online_state_inputs:
                                      transition.next_state[np.newaxis, :],
                                      is_training:
                                      False
                                  })[0]
            transition = agent_action_fn(policy_out, replay_buffer, train_env)
            if step % 200 == 0:
                tf.logging.info(' +++++++++++++++++++ global_step:{} action:{}'
                                '  reward:{} term:{}'.format(
                                    step, transition.action, transition.reward,
                                    transition.terminated))
            if step < 10:
                # Let the buffer accumulate a few transitions before training.
                continue
            ## ++++ sample mini-batch and train.++++
            state_batch, action_batch, reward_batch, next_state_batch, terminated_batch = \
             replay_buffer.sample_batch(DDPG_CFG.batch_size)

            # ---- 1. train policy.-----------
            sess.run(
                fetches=[train_online_policy_op],
                feed_dict={
                    online_state_inputs: state_batch,
                    cond_training_q: False,
                    online_action_inputs_training_q:
                    action_batch,  # feed but not used.
                    is_training: True
                })

            # ---- 2. train q. --------------
            sess.run(fetches=[train_online_q_op],
                     feed_dict={
                         online_state_inputs: state_batch,
                         cond_training_q: True,
                         online_action_inputs_training_q: action_batch,
                         target_state_inputs: next_state_batch,
                         reward_inputs: reward_batch,
                         terminated_inputs: terminated_batch,
                         is_training: True
                     })

            # ----- 3. update target ---------
            sess.run(fetches=[update_target_op], feed_dict=None)

            # do evaluation after eval_freq steps:
            if step % DDPG_CFG.eval_freq == 0:  ##and step > DDPG_CFG.eval_freq:
                evaluate(env=train_env,
                         num_eval_steps=DDPG_CFG.num_eval_steps,
                         preprocess_fn=preprocess_low_dim,
                         estimate_fn=lambda state: sess.run(
                             fetches=[actor.online_action_outputs_tensor],
                             feed_dict={
                                 online_state_inputs: state,
                                 is_training: False
                             }),
                         summary_writer=summary_writer,
                         saver=saver,
                         sess=sess,
                         global_step=step,
                         log_summary_op=log_summary_op,
                         summary_text_tensor=summary_text_tensor)

            if transition.terminated:
                transition = preprocess_low_dim(train_env.reset())
                n_episodes += 1
                continue  # begin new episode

    else:  #eval mode
        evaluate(env=train_env,
                 num_eval_steps=DDPG_CFG.eval_steps_after_training,
                 preprocess_fn=preprocess_low_dim,
                 estimate_fn=lambda state: sess.run(
                     fetches=[actor.online_action_outputs_tensor],
                     feed_dict={
                         online_state_inputs: state,
                         is_training: False
                     }),
                 summary_writer=summary_writer,
                 saver=None,
                 sess=sess,
                 global_step=0,
                 log_summary_op=log_summary_op,
                 summary_text_tensor=summary_text_tensor)

    sess.close()
    train_env.close()
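
# A hedged usage sketch for the train() function above, assuming a low-dimensional
# continuous-control gym environment. The environment id is illustrative and not taken
# from the source; agent_action refers to the simplified (policy_out, replay_buffer, env)
# variant shown earlier.
import gym
import tensorflow as tf

if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    env = gym.make('Pendulum-v0')   # any env with a Box action space works here
    env.seed(DDPG_CFG.random_seed)
    train(env, agent_action_fn=agent_action, eval_mode=False)
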
Example #5
def train(train_env, monitor_env, agent_action_fn, noise_process):
    '''Train DDPG on train_env and periodically evaluate on monitor_env.'''
    action_space = train_env.action_space
    obs_space = train_env.observation_space

    ######### Instantiate actor, critic, replay buffer, OU process #########
    ## Feed the online network with state; feed the target network with next_state.
    online_state_inputs = tf.placeholder(tf.float32,
                                         shape=(None, obs_space.shape[0]),
                                         name="online_state_inputs")

    # tf.logging.info('@@@@ online_state_inputs shape:{}'.format(online_state_inputs.shape))
    target_state_inputs = tf.placeholder(tf.float32,
                                         shape=online_state_inputs.shape,
                                         name="target_state_inputs")

    ## inputs to q_net for training q.
    online_action_inputs_training_q = tf.placeholder(
        tf.float32,
        shape=(None, action_space.shape[0]),
        name='online_action_batch_inputs')
    # condition bool scalar to switch action inputs to online q.
    # feed True: training q.
    # feed False: training policy.
    cond_training_q = tf.placeholder(tf.bool, shape=[], name='cond_training_q')

    # batch_size vector.
    terminated_inputs = tf.placeholder(tf.float32,
                                       shape=(None,),
                                       name='terminated_inputs')
    reward_inputs = tf.placeholder(tf.float32,
                                   shape=(None,),
                                   name='rewards_inputs')

    # For learning-rate decay.
    actor_l_r = tf.placeholder(tf.float32, shape=[], name='actor_l_r')
    critic_l_r = tf.placeholder(tf.float32, shape=[], name='critic_l_r')

    # for summary text
    summary_text_tensor = tf.convert_to_tensor('summary_text',
                                               preferred_dtype=tf.string)
    tf.summary.text(name='summary_text',
                    tensor=summary_text_tensor,
                    collections=[DDPG_CFG.log_summary_keys])

    ##instantiate actor, critic.
    actor = Actor(
        action_dim=action_space.shape[0],
        online_state_inputs=online_state_inputs,
        target_state_inputs=target_state_inputs,
        input_normalizer=DDPG_CFG.actor_input_normalizer,
        input_norm_params=DDPG_CFG.actor_input_norm_params,
        n_fc_units=DDPG_CFG.actor_n_fc_units,
        fc_activations=DDPG_CFG.actor_fc_activations,
        fc_initializers=DDPG_CFG.actor_fc_initializers,
        fc_normalizers=DDPG_CFG.actor_fc_normalizers,
        fc_norm_params=DDPG_CFG.actor_fc_norm_params,
        fc_regularizers=DDPG_CFG.actor_fc_regularizers,
        output_layer_initializer=DDPG_CFG.actor_output_layer_initializer,
        # output_layer_regularizer=DDPG_CFG.actor_output_layer_regularizer,
        output_layer_regularizer=None,
        output_normalizers=DDPG_CFG.actor_output_layer_normalizers,
        output_norm_params=DDPG_CFG.actor_output_layer_norm_params,
        # output_normalizers=None,
        # output_norm_params=None,
        output_bound_fns=DDPG_CFG.actor_output_bound_fns,
        learning_rate=actor_l_r,
        is_training=is_training)

    critic = Critic(
        online_state_inputs=online_state_inputs,
        target_state_inputs=target_state_inputs,
        input_normalizer=DDPG_CFG.critic_input_normalizer,
        input_norm_params=DDPG_CFG.critic_input_norm_params,
        online_action_inputs_training_q=online_action_inputs_training_q,
        online_action_inputs_training_policy=actor.online_action_outputs_tensor,
        cond_training_q=cond_training_q,
        target_action_inputs=actor.target_action_outputs_tensor,
        n_fc_units=DDPG_CFG.critic_n_fc_units,
        fc_activations=DDPG_CFG.critic_fc_activations,
        fc_initializers=DDPG_CFG.critic_fc_initializers,
        fc_normalizers=DDPG_CFG.critic_fc_normalizers,
        fc_norm_params=DDPG_CFG.critic_fc_norm_params,
        fc_regularizers=DDPG_CFG.critic_fc_regularizers,
        output_layer_initializer=DDPG_CFG.critic_output_layer_initializer,
        output_layer_regularizer=DDPG_CFG.critic_output_layer_regularizer,
        # output_layer_regularizer = None,
        learning_rate=critic_l_r)

    ## track updates.
    global_step_tensor = tf.train.create_global_step()

    ## build whole graph
    copy_online_to_target_op, train_online_policy_op, train_online_q_op, update_target_op, saver,q_loss_tensor \
      = build_ddpg_graph(actor, critic, reward_inputs, terminated_inputs, global_step_tensor)

    #we save the replay buffer data to files.
    replay_buffer = ReplayBuffer(
        buffer_size=DDPG_CFG.replay_buff_size,
        save_segment_size=DDPG_CFG.replay_buff_save_segment_size,
        save_path=DDPG_CFG.replay_buffer_file_path,
        seed=DDPG_CFG.random_seed)
    ##TODO test load replay buffer from files.
    if DDPG_CFG.load_replay_buffer_set:
        replay_buffer.load(DDPG_CFG.replay_buffer_file_path)

    # ===  finish building ddpg graph before this =================#

    ##create tf default session
    sess = tf.Session(graph=tf.get_default_graph())
    # Note: the FileWriter serializes the graph to a GraphDef now, so the whole computation
    # graph must be finished before creating the summary writer.
    summary_writer = tf.summary.FileWriter(logdir=os.path.join(
        DDPG_CFG.log_dir, "train"),
                                           graph=sess.graph)
    actor_summary_op = tf.summary.merge_all(key=DDPG_CFG.actor_summary_keys)
    critic_summary_op = tf.summary.merge_all(key=DDPG_CFG.critic_summary_keys)
    log_summary_op = tf.summary.merge_all(key=DDPG_CFG.log_summary_keys)
    ######### initialize computation graph  ############
    '''
  # -------------trace graphdef only
  whole_graph_def = meta_graph.create_meta_graph_def(graph_def=sess.graph.as_graph_def())
  summary_writer.add_meta_graph(whole_graph_def,global_step=1)
  summary_writer.flush()

  run_options = tf.RunOptions(output_partition_graphs=True, trace_level=tf.RunOptions.FULL_TRACE)
  run_metadata = tf.RunMetadata()

  # including copy target -> online
  sess.run(fetches=[init_op],
           options=run_options,
           run_metadata=run_metadata
           )
  graphdef_part1 = run_metadata.partition_graphs[0]
  meta_graph_part1 = meta_graph.create_meta_graph_def(graph_def=graphdef_part1)
  part1_metagraph_writer = tf.summary.FileWriter(DDPG_CFG.log_dir + '/part1_metagraph')
  part1_metagraph_writer.add_meta_graph(meta_graph_part1)
  part1_metagraph_writer.close()

  graphdef_part2 = run_metadata.partition_graphs[1]
  meta_graph_part2 = meta_graph.create_meta_graph_def(graph_def=graphdef_part2)
  part2_metagraph_writer = tf.summary.FileWriter(DDPG_CFG.log_dir + '/part2_metagraph')
  part2_metagraph_writer.add_meta_graph(meta_graph_part2)
  part2_metagraph_writer.close()
  # --------------- end trace
  '''

    sess.run(fetches=[tf.global_variables_initializer()])

    #copy init params from online to target
    sess.run(fetches=[copy_online_to_target_op])

    # Load a previous checkpoint if it exists
    latest_checkpoint = tf.train.latest_checkpoint(DDPG_CFG.checkpoint_dir)
    if latest_checkpoint:
        print("=== Loading model checkpoint: {}".format(latest_checkpoint))
        saver.restore(sess, latest_checkpoint)

    ####### start training #########

    if not DDPG_CFG.train_from_replay_buffer_set_only:
        obs = train_env.reset()
        transition = preprocess_low_dim(obs)

    n_episodes = 1
    update_start = 0.0

    for step in range(1, DDPG_CFG.num_training_steps):
        noise_process.reset()

        # Replace with a new transition.
        if not DDPG_CFG.train_from_replay_buffer_set_only:  # otherwise no new samples are needed
            transition = agent_action_fn(step, sess, actor,
                                         online_state_inputs, is_training,
                                         transition.next_state[np.newaxis, :],
                                         replay_buffer, noise_process,
                                         train_env)
        if step % DDPG_CFG.summary_transition_freq == 0:
            summary_transition(summary_writer, action_space.shape[0],
                               transition, step)

        # After the replay buffer has been filled with some transitions, we start learning.
        if step > DDPG_CFG.learn_start:
            # Time the duration of the first 10 updates.
            if step < (DDPG_CFG.learn_start + 10):
                update_start = time.time()

            ## ++++ sample mini-batch and train.++++
            state_batch, action_batch, reward_batch, next_state_batch, terminated_batch = \
              replay_buffer.sample_batch(DDPG_CFG.batch_size)

            if step % 2000 == 0 and DDPG_CFG.train_from_replay_buffer_set_only:
                tf.logging.info(
                    '@@@@@ train from buffer only -one sample - global_step:{} action:{}'
                    '  reward:{} term:{} @@@@@@@@@@'.format(
                        step, action_batch[0], reward_batch[0],
                        terminated_batch[0]))

            # ---- 1. train policy.-----------
            # No need to feed reward, next_state, terminated; they are unused in the policy update.
            # run_options = tf.RunOptions(output_partition_graphs=True, trace_level=tf.RunOptions.FULL_TRACE)
            if step % DDPG_CFG.summary_freq == 0:
                # run_metadata = tf.RunMetadata()
                _, actor_summary = sess.run(
                    fetches=[train_online_policy_op, actor_summary_op],
                    feed_dict={
                        online_state_inputs: state_batch,
                        cond_training_q: False,
                        online_action_inputs_training_q:
                        action_batch,  # feed but not used.
                        actor_l_r: l_r_decay(DDPG_CFG.actor_learning_rate,
                                             step),
                        is_training: True
                    })
                # options=run_options,
                # run_metadata=run_metadata)
                # summary_writer._add_graph_def(run_metadata.partition_graphs[0])

                # the policy online network is updated above and will not affect training q.
                # ---- 2. train q. --------------
                _, critic_summary = sess.run(
                    fetches=[train_online_q_op, critic_summary_op],
                    feed_dict={
                        online_state_inputs: state_batch,
                        cond_training_q: True,
                        online_action_inputs_training_q: action_batch,
                        target_state_inputs: next_state_batch,
                        reward_inputs: reward_batch,
                        terminated_inputs: terminated_batch,
                        critic_l_r: l_r_decay(DDPG_CFG.critic_learning_rate,
                                              step),
                        is_training: True
                    })

                summary_writer.add_summary(actor_summary, global_step=step)
                summary_writer.add_summary(critic_summary, global_step=step)
                summary_writer.flush()
            else:
                _ = sess.run(
                    fetches=[train_online_policy_op],
                    feed_dict={
                        online_state_inputs: state_batch,
                        cond_training_q: False,
                        online_action_inputs_training_q:
                        action_batch,  # feed but not used.
                        actor_l_r: l_r_decay(DDPG_CFG.actor_learning_rate,
                                             step),
                        is_training: True
                    })

                # the policy online network is updated above and will not affect training q.
                # ---- 2. train q. --------------
                _, q_loss_value = sess.run(
                    fetches=[train_online_q_op, q_loss_tensor],
                    feed_dict={
                        online_state_inputs: state_batch,
                        cond_training_q: True,
                        online_action_inputs_training_q: action_batch,
                        target_state_inputs: next_state_batch,
                        reward_inputs: reward_batch,
                        terminated_inputs: terminated_batch,
                        critic_l_r: l_r_decay(DDPG_CFG.critic_learning_rate,
                                              step),
                        is_training: True
                    })
                if step % 2000 == 0:
                    tf.logging.info('@@ step:{} q_loss:{}'.format(
                        step, q_loss_value))

            # --end of summary --
            # ----- 3. update target ---------
            # including increment global step.
            _ = sess.run(fetches=[update_target_op], feed_dict=None)

            # Time the duration of the first 10 updates.
            if step < (DDPG_CFG.learn_start + 10):
                tf.logging.info(
                    ' @@@@ one batch learn duration @@@@:{}'.format(
                        time.time() - update_start))

            # do evaluation after eval_freq steps:
            if step % DDPG_CFG.eval_freq == 0:  ##and step > DDPG_CFG.eval_freq:
                evaluate(env=monitor_env,
                         num_eval_steps=DDPG_CFG.num_eval_steps,
                         preprocess_fn=preprocess_low_dim,
                         estimate_fn=lambda state: sess.run(
                             fetches=[actor.online_action_outputs_tensor],
                             feed_dict={
                                 online_state_inputs: state,
                                 is_training: False
                             }),
                         summary_writer=summary_writer,
                         saver=saver,
                         sess=sess,
                         global_step=step,
                         log_summary_op=log_summary_op,
                         summary_text_tensor=summary_text_tensor)
                # if monitor_env is train_env:
                #   #torcs share. we should reset
                #   transition.terminated = True #fall through
        #-- end of learn

        # TODO: temporary solution to toggle vision; use a thread instead.
        if step % 2000 == 0:
            v_on = os.path.exists('/home/yuheng/Desktop/train_vision_on')
            if train_env.vision_status == False and v_on:
                train_env.vision_on()  #will display next reset
                transition = preprocess_low_dim(train_env.reset(relaunch=True))
                n_episodes += 1
                tf.logging.info('@@ episodes: {} @@'.format(n_episodes))
                continue
            elif train_env.vision_status == True and not v_on:
                train_env.vision_off()
                transition = preprocess_low_dim(train_env.reset(relaunch=True))
                n_episodes += 1
                tf.logging.info('@@ episodes: {} @@'.format(n_episodes))
                continue

            # if os.path.exists('/home/yuheng/Desktop/eval_vision_on'):
            #   monitor_env.vision_on()  # will display next reset
            # else:
            #   monitor_env.vision_off()

        if (not DDPG_CFG.train_from_replay_buffer_set_only) and (
                transition.terminated):
            # Relaunch TORCS every 3 episodes because of the memory-leak error.
            # Replace with the transition observed after reset; only the state is saved.
            transition = preprocess_low_dim(train_env.reset())
            n_episodes += 1
            tf.logging.info('@@ episodes: {} @@'.format(n_episodes))
            continue  # begin new episode
    # ====end for t.

    sess.close()
    train_env.close()
    monitor_env.close()
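
# A minimal sketch of the l_r_decay helper used in the training loop above. The actual decay
# schedule is not shown in the source; this assumes a simple linear decay from the initial
# learning rate down to a fraction of it over the whole training run.
def l_r_decay(initial_l_r, step, total_steps=None, min_ratio=0.1):
    total_steps = total_steps or DDPG_CFG.num_training_steps
    frac = min(float(step) / total_steps, 1.0)
    return initial_l_r * (1.0 - (1.0 - min_ratio) * frac)
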
def train(train_env, agent_action_fn, eval_mode=False):
    # Action and observation spaces of the simulated environment.
    action_space = train_env.action_space
    obs_space = train_env.observation_space

    ######### Instantiate actor, critic, replay buffer, OU process #########
    ## Feed the online network with state; feed the target network with next_state.
    online_state_inputs = tf.placeholder(tf.float32,
                                         shape=(None, obs_space.shape[0]),
                                         name="online_state_inputs")

    target_state_inputs = tf.placeholder(tf.float32,
                                         shape=online_state_inputs.shape,
                                         name="target_state_inputs")

    ## inputs to q_net for training q.
    online_action_inputs_training_q = tf.placeholder(
        tf.float32,
        shape=(None, action_space.shape[0]),
        name='online_action_batch_inputs')
    # cond_training_q: condition variable that switches the action input of the Q network.
    # True: training q.
    # False: training policy.
    cond_training_q = tf.placeholder(tf.bool, shape=[], name='cond_training_q')

    terminated_inputs = tf.placeholder(tf.float32,
                                       shape=(None,),
                                       name='terminated_inputs')
    reward_inputs = tf.placeholder(tf.float32,
                                   shape=(None,),
                                   name='rewards_inputs')

    # summary text
    summary_text_tensor = tf.convert_to_tensor('summary_text',
                                               preferred_dtype=tf.string)
    tf.summary.text(name='summary_text',
                    tensor=summary_text_tensor,
                    collections=[DDPG_CFG.log_summary_keys])

    # Create the actor and critic instances.
    actor = Actor(
        action_dim=action_space.shape[0],
        online_state_inputs=online_state_inputs,
        target_state_inputs=target_state_inputs,
        input_normalizer=DDPG_CFG.actor_input_normalizer,
        input_norm_params=DDPG_CFG.actor_input_norm_params,
        n_fc_units=DDPG_CFG.actor_n_fc_units,
        fc_activations=DDPG_CFG.actor_fc_activations,
        fc_initializers=DDPG_CFG.actor_fc_initializers,
        fc_normalizers=DDPG_CFG.actor_fc_normalizers,
        fc_norm_params=DDPG_CFG.actor_fc_norm_params,
        fc_regularizers=DDPG_CFG.actor_fc_regularizers,
        output_layer_initializer=DDPG_CFG.actor_output_layer_initializer,
        output_layer_regularizer=None,
        output_normalizers=DDPG_CFG.actor_output_layer_normalizers,
        output_norm_params=DDPG_CFG.actor_output_layer_norm_params,
        output_bound_fns=DDPG_CFG.actor_output_bound_fns,
        learning_rate=DDPG_CFG.actor_learning_rate,
        is_training=is_training)

    critic = Critic(
        online_state_inputs=online_state_inputs,
        target_state_inputs=target_state_inputs,
        input_normalizer=DDPG_CFG.critic_input_normalizer,
        input_norm_params=DDPG_CFG.critic_input_norm_params,
        online_action_inputs_training_q=online_action_inputs_training_q,
        online_action_inputs_training_policy=actor.online_action_outputs_tensor,
        cond_training_q=cond_training_q,
        target_action_inputs=actor.target_action_outputs_tensor,
        n_fc_units=DDPG_CFG.critic_n_fc_units,
        fc_activations=DDPG_CFG.critic_fc_activations,
        fc_initializers=DDPG_CFG.critic_fc_initializers,
        fc_normalizers=DDPG_CFG.critic_fc_normalizers,
        fc_norm_params=DDPG_CFG.critic_fc_norm_params,
        fc_regularizers=DDPG_CFG.critic_fc_regularizers,
        output_layer_initializer=DDPG_CFG.critic_output_layer_initializer,
        output_layer_regularizer=None,
        learning_rate=DDPG_CFG.critic_learning_rate)

    # track updates.
    global_step_tensor = tf.train.create_global_step()

    # Build the whole DDPG computation graph.
    copy_online_to_target_op, train_online_policy_op, train_online_q_op, update_target_op, saver \
        = build_ddpg_graph(actor, critic, reward_inputs, terminated_inputs, global_step_tensor)

    # Instantiate the replay buffer; optionally the buffer data is saved to files.
    replay_buffer = ReplayBuffer(
        buffer_size=DDPG_CFG.replay_buff_size,
        save_segment_size=DDPG_CFG.replay_buff_save_segment_size,
        save_path=DDPG_CFG.replay_buffer_file_path,
        seed=DDPG_CFG.random_seed)

    # Load buffer data from files.
    if DDPG_CFG.load_replay_buffer_set:
        replay_buffer.load(DDPG_CFG.replay_buffer_file_path)

    # Use summaries to monitor data and parameter changes during training; the resulting charts can be viewed in TensorBoard.
    sess = tf.Session(graph=tf.get_default_graph())
    summary_writer = tf.summary.FileWriter(logdir=os.path.join(
        DDPG_CFG.log_dir, "train"),
                                           graph=sess.graph)
    actor_summary_op = tf.summary.merge_all(key=DDPG_CFG.actor_summary_keys)
    critic_summary_op = tf.summary.merge_all(key=DDPG_CFG.critic_summary_keys)
    log_summary_op = tf.summary.merge_all(key=DDPG_CFG.log_summary_keys)

    sess.run(fetches=[tf.global_variables_initializer()])

    # At initialization, copy the online parameters to the target networks.
    sess.run(fetches=[copy_online_to_target_op])

    # Load a previously saved model checkpoint:
    latest_checkpoint = tf.train.latest_checkpoint(DDPG_CFG.checkpoint_dir)
    if latest_checkpoint:
        tf.logging.info(
            "==== Loading model checkpoint: {}".format(latest_checkpoint))
        saver.restore(sess, latest_checkpoint)
    elif eval_mode:
        raise FileNotFoundError(
            '=== In evaluation mode a checkpoint file is required, but none was found. ==='
        )

    ####### Start training #########
    obs = train_env.reset()
    transition = preprocess_low_dim(obs)

    n_episodes = 1  # episode counter

    # Training mode.
    if not eval_mode:
        # Train for DDPG_CFG.num_training_steps steps in total.
        for step in range(1, DDPG_CFG.num_training_steps):
            # Get an action from the online policy network for the current state.
            policy_out = sess.run(fetches=[actor.online_action_outputs_tensor],
                                  feed_dict={
                                      online_state_inputs:
                                      transition.next_state[np.newaxis, :],
                                      is_training:
                                      False
                                  })[0]
            # Execute the action in the simulated environment and store the transition into the replay buffer.
            transition = agent_action_fn(policy_out, replay_buffer, train_env)

            if step % 200 == 0:
                tf.logging.info(' +++++++++++++++++++ global_step:{} action:{}'
                                '  reward:{} term:{}'.format(
                                    step, transition.action, transition.reward,
                                    transition.terminated))
            # Let the buffer accumulate a few transitions before training.
            if step < 10:
                continue
            # Sample a mini-batch from the replay buffer.
            state_batch, action_batch, reward_batch, next_state_batch, terminated_batch = \
                replay_buffer.sample_batch(DDPG_CFG.batch_size)

            if step % DDPG_CFG.summary_freq == 0:
                _, actor_summary = sess.run(
                    fetches=[train_online_policy_op, actor_summary_op],
                    feed_dict={
                        online_state_inputs: state_batch,
                        cond_training_q: False,
                        online_action_inputs_training_q: action_batch,
                        is_training: True
                    })

                _, critic_summary = sess.run(
                    fetches=[train_online_q_op, critic_summary_op],
                    feed_dict={
                        online_state_inputs: state_batch,
                        cond_training_q: True,
                        online_action_inputs_training_q: action_batch,
                        target_state_inputs: next_state_batch,
                        reward_inputs: reward_batch,
                        terminated_inputs: terminated_batch,
                        is_training: True
                    })
                summary_writer.add_summary(actor_summary, global_step=step)
                summary_writer.add_summary(critic_summary, global_step=step)
                summary_writer.flush()
            else:
                # ---- 1. Train the policy network -----------
                sess.run(fetches=[train_online_policy_op],
                         feed_dict={
                             online_state_inputs: state_batch,
                             cond_training_q: False,
                             online_action_inputs_training_q: action_batch,
                             is_training: True
                         })

                # ---- 2. Train the Q network --------------
                sess.run(fetches=[train_online_q_op],
                         feed_dict={
                             online_state_inputs: state_batch,
                             cond_training_q: True,
                             online_action_inputs_training_q: action_batch,
                             target_state_inputs: next_state_batch,
                             reward_inputs: reward_batch,
                             terminated_inputs: terminated_batch,
                             is_training: True
                         })

            # ----- 3. Update the target networks ---------
            sess.run(fetches=[update_target_op], feed_dict=None)

            # Every eval_freq steps, run an evaluation so that a good model can be selected after training:
            if step % DDPG_CFG.eval_freq == 0:
                evaluate(env=train_env,
                         num_eval_steps=DDPG_CFG.num_eval_steps,
                         preprocess_fn=preprocess_low_dim,
                         estimate_fn=lambda state: sess.run(
                             fetches=[actor.online_action_outputs_tensor],
                             feed_dict={
                                 online_state_inputs: state,
                                 is_training: False
                             }),
                         summary_writer=summary_writer,
                         saver=saver,
                         sess=sess,
                         global_step=step,
                         log_summary_op=log_summary_op,
                         summary_text_tensor=summary_text_tensor)

            if transition.terminated:
                transition = preprocess_low_dim(train_env.reset())
                n_episodes += 1
                continue  # begin new episode
    # eval mode
    else:
        evaluate(env=train_env,
                 num_eval_steps=DDPG_CFG.eval_steps_after_training,
                 preprocess_fn=preprocess_low_dim,
                 estimate_fn=lambda state: sess.run(
                     fetches=[actor.online_action_outputs_tensor],
                     feed_dict={
                         online_state_inputs: state,
                         is_training: False
                     }),
                 summary_writer=summary_writer,
                 saver=None,
                 sess=sess,
                 global_step=0,
                 log_summary_op=log_summary_op,
                 summary_text_tensor=summary_text_tensor)

    sess.close()
    train_env.close()
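
# A hedged sketch of the soft target update that update_target_op typically implements in
# DDPG; build_ddpg_graph is not shown in the source, so the variable-scope names and tau
# value here are assumptions: target_params <- tau * online_params + (1 - tau) * target_params.
import tensorflow as tf

def make_soft_update_op(online_scope, target_scope, tau=0.001):
    online_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=online_scope)
    target_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=target_scope)
    updates = [t.assign(tau * o + (1.0 - tau) * t)
               for o, t in zip(online_vars, target_vars)]
    return tf.group(*updates)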