Example no. 1
import numpy as np
# ReplayBuffer is assumed to be provided elsewhere in the project; a minimal sketch follows the class.


class PlayerTrainer(object):
    def __init__(self, actor, critic, buffersize, game, player, batch_size, gamma):
        self.actor = actor
        self.critic = critic
        self.replay = ReplayBuffer(buffersize)
        self.game = game
        self.player = player
        self.batch_size = batch_size
        self.gamma = gamma


    def noisyMaxQMove(self):
        state = self.game.space
        As = self.actor.predict(np.reshape(state, (1, *state.shape)))
        avail = self.game.avail()
        availQ = {}
        availP = []
        for k in avail:
            availQ[k] = As[0][k]
            availP.append(As[0][k])
        # if sum(availP)> 0:
        availP = np.array(availP)

        # map negative values to small positive weights so the result can be
        # normalized into a valid probability distribution over available moves.
        availP = [round(i, 5) if i >= 0 else (-.001 * round(i, 5)) for i in availP]
        availNorm = [i / sum(availP) for i in availP]

        a = np.random.choice(avail, p=availNorm)

        self.game.move(a, self.player)
        next_state, reward = self.game.step(self.player)

        self.bufferAdd(state, As, reward, self.game.game_over, next_state)
        if self.replay.size() > self.batch_size:
            s_batch, a_batch, r_batch, t_batch, s2_batch = self.replay.sample_batch(self.batch_size)
            target_q = self.critic.predict_target(s2_batch, self.actor.predict_target(s2_batch))
            y_i = []
            for k in range(self.batch_size):
                if t_batch[k]:
                    y_i.append(r_batch[k])
                else:
                    y_i.append(r_batch[k] + self.gamma * target_q[k])

            predicted_q_value, _ = self.critic.train(
                s_batch, a_batch, np.reshape(y_i, (self.batch_size, 1)))

            #ep_ave_max_q += np.amax(predicted_q_value)

            # Update the actor policy using the sampled gradient
            a_outs = self.actor.predict(s_batch)
            grads = self.critic.action_gradients(s_batch, a_outs)
            self.actor.train(s_batch, grads[0])

            # Update target networks
            self.actor.update_target_network()
            self.critic.update_target_network()
        return self.game.space, reward

    def bufferAdd(self, state, Qs, reward, terminal, next_state):
        self.replay.add(np.reshape(state, (self.actor.s_dim,)),
                        np.reshape(Qs, (self.actor.a_dim,)),
                        reward, terminal,
                        np.reshape(next_state, (self.actor.s_dim,)))
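
Example no. 1 assumes a ReplayBuffer providing add(), size() and sample_batch() with the
(state, action, reward, terminal, next_state) ordering used above; its implementation is not
shown. A minimal sketch matching that interface (the deque-based layout, the seed argument
and all defaults are assumptions, not the original class) could look like this:

import random
from collections import deque

import numpy as np


class ReplayBuffer(object):
    """Fixed-size FIFO buffer of (state, action, reward, terminal, next_state) tuples."""

    def __init__(self, buffer_size, seed=None):
        self.buffer = deque(maxlen=buffer_size)
        self.rng = random.Random(seed)

    def add(self, state, action, reward, terminal, next_state):
        # the oldest transition is dropped automatically once buffer_size is reached
        self.buffer.append((state, action, reward, terminal, next_state))

    def size(self):
        return len(self.buffer)

    def sample_batch(self, batch_size):
        # uniform sampling without replacement, capped at the current buffer size
        batch = self.rng.sample(self.buffer, min(batch_size, len(self.buffer)))
        s, a, r, t, s2 = map(np.array, zip(*batch))
        return s, a, r, t, s2
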
Example no. 2
def train(train_env, agent_action_fn, eval_mode=False):
    action_space = train_env.action_space
    obs_space = train_env.observation_space

    ######### instantiate actor, critic, replay buffer, OU process #########
    ## feed the online network with state; feed the target network with next_state.
    online_state_inputs = tf.placeholder(tf.float32,
                                         shape=(None, obs_space.shape[0]),
                                         name="online_state_inputs")

    target_state_inputs = tf.placeholder(tf.float32,
                                         shape=online_state_inputs.shape,
                                         name="target_state_inputs")

    ## inputs to q_net for training q.
    online_action_inputs_training_q = tf.placeholder(
        tf.float32,
        shape=(None, action_space.shape[0]),
        name='online_action_batch_inputs')
    # condition bool scalar to switch action inputs to online q.
    # feed True: training q.
    # feed False: training policy.
    cond_training_q = tf.placeholder(tf.bool, shape=[], name='cond_training_q')

    # batch-size vectors of termination flags and rewards.
    terminated_inputs = tf.placeholder(tf.float32,
                                       shape=(None,),
                                       name='terminated_inputs')
    reward_inputs = tf.placeholder(tf.float32,
                                   shape=(None,),
                                   name='rewards_inputs')

    # for summary text
    summary_text_tensor = tf.convert_to_tensor('summary_text',
                                               preferred_dtype=tf.string)
    tf.summary.text(name='summary_text',
                    tensor=summary_text_tensor,
                    collections=[DDPG_CFG.log_summary_keys])

    ##instantiate actor, critic.
    actor = Actor(
        action_dim=action_space.shape[0],
        online_state_inputs=online_state_inputs,
        target_state_inputs=target_state_inputs,
        input_normalizer=DDPG_CFG.actor_input_normalizer,
        input_norm_params=DDPG_CFG.actor_input_norm_params,
        n_fc_units=DDPG_CFG.actor_n_fc_units,
        fc_activations=DDPG_CFG.actor_fc_activations,
        fc_initializers=DDPG_CFG.actor_fc_initializers,
        fc_normalizers=DDPG_CFG.actor_fc_normalizers,
        fc_norm_params=DDPG_CFG.actor_fc_norm_params,
        fc_regularizers=DDPG_CFG.actor_fc_regularizers,
        output_layer_initializer=DDPG_CFG.actor_output_layer_initializer,
        output_layer_regularizer=None,
        output_normalizers=DDPG_CFG.actor_output_layer_normalizers,
        output_norm_params=DDPG_CFG.actor_output_layer_norm_params,
        output_bound_fns=DDPG_CFG.actor_output_bound_fns,
        learning_rate=DDPG_CFG.actor_learning_rate,
        is_training=is_training)  # is_training is assumed to be a module-level tf.bool placeholder

    critic = Critic(
        online_state_inputs=online_state_inputs,
        target_state_inputs=target_state_inputs,
        input_normalizer=DDPG_CFG.critic_input_normalizer,
        input_norm_params=DDPG_CFG.critic_input_norm_params,
        online_action_inputs_training_q=online_action_inputs_training_q,
        online_action_inputs_training_policy=actor.online_action_outputs_tensor,
        cond_training_q=cond_training_q,
        target_action_inputs=actor.target_action_outputs_tensor,
        n_fc_units=DDPG_CFG.critic_n_fc_units,
        fc_activations=DDPG_CFG.critic_fc_activations,
        fc_initializers=DDPG_CFG.critic_fc_initializers,
        fc_normalizers=DDPG_CFG.critic_fc_normalizers,
        fc_norm_params=DDPG_CFG.critic_fc_norm_params,
        fc_regularizers=DDPG_CFG.critic_fc_regularizers,
        output_layer_initializer=DDPG_CFG.critic_output_layer_initializer,
        output_layer_regularizer=None,
        learning_rate=DDPG_CFG.critic_learning_rate)

    ## track updates.
    global_step_tensor = tf.train.create_global_step()

    ## build whole graph
    copy_online_to_target_op, train_online_policy_op, train_online_q_op, update_target_op, saver \
      = build_ddpg_graph(actor, critic, reward_inputs, terminated_inputs, global_step_tensor)

    # the replay buffer can also persist its data to files.
    replay_buffer = ReplayBuffer(
        buffer_size=DDPG_CFG.replay_buff_size,
        save_segment_size=DDPG_CFG.replay_buff_save_segment_size,
        save_path=DDPG_CFG.replay_buffer_file_path,
        seed=DDPG_CFG.random_seed)
    if DDPG_CFG.load_replay_buffer_set:
        replay_buffer.load(DDPG_CFG.replay_buffer_file_path)

    sess = tf.Session(graph=tf.get_default_graph())
    summary_writer = tf.summary.FileWriter(logdir=os.path.join(
        DDPG_CFG.log_dir, "train"),
                                           graph=sess.graph)
    log_summary_op = tf.summary.merge_all(key=DDPG_CFG.log_summary_keys)

    sess.run(fetches=[tf.global_variables_initializer()])

    #copy init params from online to target
    sess.run(fetches=[copy_online_to_target_op])

    # Load a previous checkpoint if it exists
    latest_checkpoint = tf.train.latest_checkpoint(DDPG_CFG.checkpoint_dir)
    if latest_checkpoint:
        tf.logging.info(
            "==== Loading model checkpoint: {}".format(latest_checkpoint))
        saver.restore(sess, latest_checkpoint)
    elif eval_mode:
        raise FileNotFoundError(
            '=== in evaluation mode a checkpoint file is required, but none was found. ==='
        )

    ####### start training #########
    obs = train_env.reset()
    transition = preprocess_low_dim(obs)

    n_episodes = 1

    if not eval_mode:
        for step in range(1, DDPG_CFG.num_training_steps):
            #replace with new transition
            policy_out = sess.run(fetches=[actor.online_action_outputs_tensor],
                                  feed_dict={
                                      online_state_inputs:
                                      transition.next_state[np.newaxis, :],
                                      is_training:
                                      False
                                  })[0]
            transition = agent_action_fn(policy_out, replay_buffer, train_env)
            if step % 200 == 0:
                tf.logging.info(' +++++++++++++++++++ global_step:{} action:{}'
                                '  reward:{} term:{}'.format(
                                    step, transition.action, transition.reward,
                                    transition.terminated))
            # let a few transitions accumulate in the buffer before learning.
            if step < 10:
                continue
            ## ++++ sample mini-batch and train.++++
            state_batch, action_batch, reward_batch, next_state_batch, terminated_batch = \
             replay_buffer.sample_batch(DDPG_CFG.batch_size)

            # ---- 1. train policy.-----------
            sess.run(
                fetches=[train_online_policy_op],
                feed_dict={
                    online_state_inputs: state_batch,
                    cond_training_q: False,
                    online_action_inputs_training_q:
                    action_batch,  # feed but not used.
                    is_training: True
                })

            # ---- 2. train q. --------------
            sess.run(fetches=[train_online_q_op],
                     feed_dict={
                         online_state_inputs: state_batch,
                         cond_training_q: True,
                         online_action_inputs_training_q: action_batch,
                         target_state_inputs: next_state_batch,
                         reward_inputs: reward_batch,
                         terminated_inputs: terminated_batch,
                         is_training: True
                     })

            # ----- 3. update target ---------
            sess.run(fetches=[update_target_op], feed_dict=None)

            # do evaluation after eval_freq steps:
            if step % DDPG_CFG.eval_freq == 0:  ##and step > DDPG_CFG.eval_freq:
                evaluate(env=train_env,
                         num_eval_steps=DDPG_CFG.num_eval_steps,
                         preprocess_fn=preprocess_low_dim,
                         estimate_fn=lambda state: sess.run(
                             fetches=[actor.online_action_outputs_tensor],
                             feed_dict={
                                 online_state_inputs: state,
                                 is_training: False
                             }),
                         summary_writer=summary_writer,
                         saver=saver,
                         sess=sess,
                         global_step=step,
                         log_summary_op=log_summary_op,
                         summary_text_tensor=summary_text_tensor)

            if transition.terminated:
                transition = preprocess_low_dim(train_env.reset())
                n_episodes += 1
                continue  # begin new episode

    else:  #eval mode
        evaluate(env=train_env,
                 num_eval_steps=DDPG_CFG.eval_steps_after_training,
                 preprocess_fn=preprocess_low_dim,
                 estimate_fn=lambda state: sess.run(
                     fetches=[actor.online_action_outputs_tensor],
                     feed_dict={
                         online_state_inputs: state,
                         is_training: False
                     }),
                 summary_writer=summary_writer,
                 saver=None,
                 sess=sess,
                 global_step=0,
                 log_summary_op=log_summary_op,
                 summary_text_tensor=summary_text_tensor)

    sess.close()
    train_env.close()
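
build_ddpg_graph is not shown in any of these examples, but it is expected to return the hard-copy
op (copy_online_to_target_op) and the soft-update op (update_target_op) for the target networks.
A minimal TF1-style sketch of how such ops are commonly built is given below; the helper name
make_target_update_ops, the variable lists and the tau default are assumptions, not the original code.

import tensorflow as tf


def make_target_update_ops(online_vars, target_vars, tau=0.001):
    """Return (hard_copy_op, soft_update_op) for DDPG target networks.

    hard_copy_op:   target <- online                        (run once after initialization)
    soft_update_op: target <- tau*online + (1 - tau)*target (run after every training step)
    """
    hard_copies = [t.assign(o) for o, t in zip(online_vars, target_vars)]
    soft_updates = [t.assign(tau * o + (1.0 - tau) * t)
                    for o, t in zip(online_vars, target_vars)]
    return tf.group(*hard_copies), tf.group(*soft_updates)
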
Example no. 3
def train(train_env, monitor_env, agent_action_fn, noise_process):
    '''
    :return:
    '''
    action_space = train_env.action_space
    obs_space = train_env.observation_space

    ######### instantiate actor, critic, replay buffer, OU process #########
    ## feed the online network with state; feed the target network with next_state.
    online_state_inputs = tf.placeholder(tf.float32,
                                         shape=(None, obs_space.shape[0]),
                                         name="online_state_inputs")

    # tf.logging.info('@@@@ online_state_inputs shape:{}'.format(online_state_inputs.shape))
    target_state_inputs = tf.placeholder(tf.float32,
                                         shape=online_state_inputs.shape,
                                         name="target_state_inputs")

    ## inputs to q_net for training q.
    online_action_inputs_training_q = tf.placeholder(
        tf.float32,
        shape=(None, action_space.shape[0]),
        name='online_action_batch_inputs')
    # condition bool scalar to switch action inputs to online q.
    # feed True: training q.
    # feed False: training policy.
    cond_training_q = tf.placeholder(tf.bool, shape=[], name='cond_training_q')

    # batch-size vectors.
    terminated_inputs = tf.placeholder(tf.float32,
                                       shape=(None,),
                                       name='terminated_inputs')
    reward_inputs = tf.placeholder(tf.float32,
                                   shape=(None,),
                                   name='rewards_inputs')

    # placeholders for learning-rate decay
    actor_l_r = tf.placeholder(tf.float32, shape=[], name='actor_l_r')
    critic_l_r = tf.placeholder(tf.float32, shape=[], name='critic_l_r')

    # for summary text
    summary_text_tensor = tf.convert_to_tensor('summary_text',
                                               preferred_dtype=tf.string)
    tf.summary.text(name='summary_text',
                    tensor=summary_text_tensor,
                    collections=[DDPG_CFG.log_summary_keys])

    ##instantiate actor, critic.
    actor = Actor(
        action_dim=action_space.shape[0],
        online_state_inputs=online_state_inputs,
        target_state_inputs=target_state_inputs,
        input_normalizer=DDPG_CFG.actor_input_normalizer,
        input_norm_params=DDPG_CFG.actor_input_norm_params,
        n_fc_units=DDPG_CFG.actor_n_fc_units,
        fc_activations=DDPG_CFG.actor_fc_activations,
        fc_initializers=DDPG_CFG.actor_fc_initializers,
        fc_normalizers=DDPG_CFG.actor_fc_normalizers,
        fc_norm_params=DDPG_CFG.actor_fc_norm_params,
        fc_regularizers=DDPG_CFG.actor_fc_regularizers,
        output_layer_initializer=DDPG_CFG.actor_output_layer_initializer,
        # output_layer_regularizer=DDPG_CFG.actor_output_layer_regularizer,
        output_layer_regularizer=None,
        output_normalizers=DDPG_CFG.actor_output_layer_normalizers,
        output_norm_params=DDPG_CFG.actor_output_layer_norm_params,
        # output_normalizers=None,
        # output_norm_params=None,
        output_bound_fns=DDPG_CFG.actor_output_bound_fns,
        learning_rate=actor_l_r,
        is_training=is_training)

    critic = Critic(
        online_state_inputs=online_state_inputs,
        target_state_inputs=target_state_inputs,
        input_normalizer=DDPG_CFG.critic_input_normalizer,
        input_norm_params=DDPG_CFG.critic_input_norm_params,
        online_action_inputs_training_q=online_action_inputs_training_q,
        online_action_inputs_training_policy=actor.online_action_outputs_tensor,
        cond_training_q=cond_training_q,
        target_action_inputs=actor.target_action_outputs_tensor,
        n_fc_units=DDPG_CFG.critic_n_fc_units,
        fc_activations=DDPG_CFG.critic_fc_activations,
        fc_initializers=DDPG_CFG.critic_fc_initializers,
        fc_normalizers=DDPG_CFG.critic_fc_normalizers,
        fc_norm_params=DDPG_CFG.critic_fc_norm_params,
        fc_regularizers=DDPG_CFG.critic_fc_regularizers,
        output_layer_initializer=DDPG_CFG.critic_output_layer_initializer,
        output_layer_regularizer=DDPG_CFG.critic_output_layer_regularizer,
        # output_layer_regularizer = None,
        learning_rate=critic_l_r)

    ## track updates.
    global_step_tensor = tf.train.create_global_step()

    ## build whole graph
    copy_online_to_target_op, train_online_policy_op, train_online_q_op, update_target_op, saver,q_loss_tensor \
      = build_ddpg_graph(actor, critic, reward_inputs, terminated_inputs, global_step_tensor)

    # the replay buffer can also persist its data to files.
    replay_buffer = ReplayBuffer(
        buffer_size=DDPG_CFG.replay_buff_size,
        save_segment_size=DDPG_CFG.replay_buff_save_segment_size,
        save_path=DDPG_CFG.replay_buffer_file_path,
        seed=DDPG_CFG.random_seed)
    ##TODO test load replay buffer from files.
    if DDPG_CFG.load_replay_buffer_set:
        replay_buffer.load(DDPG_CFG.replay_buffer_file_path)

    # ===  finish building ddpg graph before this =================#

    ##create tf default session
    sess = tf.Session(graph=tf.get_default_graph())
    '''
    note: creating the FileWriter serializes the graph to a GraphDef, so the whole
    computation graph must be finished before this point.
    '''
    summary_writer = tf.summary.FileWriter(logdir=os.path.join(
        DDPG_CFG.log_dir, "train"),
                                           graph=sess.graph)
    actor_summary_op = tf.summary.merge_all(key=DDPG_CFG.actor_summary_keys)
    critic_summary_op = tf.summary.merge_all(key=DDPG_CFG.critic_summary_keys)
    log_summary_op = tf.summary.merge_all(key=DDPG_CFG.log_summary_keys)
    ######### initialize computation graph  ############
    '''
  # -------------trace graphdef only
  whole_graph_def = meta_graph.create_meta_graph_def(graph_def=sess.graph.as_graph_def())
  summary_writer.add_meta_graph(whole_graph_def,global_step=1)
  summary_writer.flush()

  run_options = tf.RunOptions(output_partition_graphs=True, trace_level=tf.RunOptions.FULL_TRACE)
  run_metadata = tf.RunMetadata()

  # including copy target -> online
  sess.run(fetches=[init_op],
           options=run_options,
           run_metadata=run_metadata
           )
  graphdef_part1 = run_metadata.partition_graphs[0]
  meta_graph_part1 = meta_graph.create_meta_graph_def(graph_def=graphdef_part1)
  part1_metagraph_writer = tf.summary.FileWriter(DDPG_CFG.log_dir + '/part1_metagraph')
  part1_metagraph_writer.add_meta_graph(meta_graph_part1)
  part1_metagraph_writer.close()

  graphdef_part2 = run_metadata.partition_graphs[1]
  meta_graph_part2 = meta_graph.create_meta_graph_def(graph_def=graphdef_part2)
  part2_metagraph_writer = tf.summary.FileWriter(DDPG_CFG.log_dir + '/part2_metagraph')
  part2_metagraph_writer.add_meta_graph(meta_graph_part2)
  part2_metagraph_writer.close()
  # --------------- end trace
  '''

    sess.run(fetches=[tf.global_variables_initializer()])

    #copy init params from online to target
    sess.run(fetches=[copy_online_to_target_op])

    # Load a previous checkpoint if it exists
    latest_checkpoint = tf.train.latest_checkpoint(DDPG_CFG.checkpoint_dir)
    if latest_checkpoint:
        print("=== Loading model checkpoint: {}".format(latest_checkpoint))
        saver.restore(sess, latest_checkpoint)

    ####### start training #########

    if not DDPG_CFG.train_from_replay_buffer_set_only:
        obs = train_env.reset()
        transition = preprocess_low_dim(obs)

    n_episodes = 1
    update_start = 0.0

    for step in range(1, DDPG_CFG.num_training_steps):
        noise_process.reset()

        # replace transition with a new one, unless we train purely from the saved buffer (no new samples needed).
        if not DDPG_CFG.train_from_replay_buffer_set_only:
            transition = agent_action_fn(step, sess, actor,
                                         online_state_inputs, is_training,
                                         transition.next_state[np.newaxis, :],
                                         replay_buffer, noise_process,
                                         train_env)
        if step % DDPG_CFG.summary_transition_freq == 0:
            summary_transition(summary_writer, action_space.shape[0],
                               transition, step)

        # after the replay buffer has been filled with some states, start learning.
        if step > DDPG_CFG.learn_start:
            # time the update duration for the first 10 updates
            if step < (DDPG_CFG.learn_start + 10):
                update_start = time.time()

            ## ++++ sample mini-batch and train.++++
            state_batch, action_batch, reward_batch, next_state_batch, terminated_batch = \
              replay_buffer.sample_batch(DDPG_CFG.batch_size)

            if step % 2000 == 0 and DDPG_CFG.train_from_replay_buffer_set_only:
                tf.logging.info(
                    '@@@@@ train from buffer only -one sample - global_step:{} action:{}'
                    '  reward:{} term:{} @@@@@@@@@@'.format(
                        step, action_batch[0], reward_batch[0],
                        terminated_batch[0]))

            # ---- 1. train policy. -----------
            # no need to feed reward, next_state, terminated; they are unused in the policy update.
            # run_options = tf.RunOptions(output_partition_graphs=True, trace_level=tf.RunOptions.FULL_TRACE)
            if step % DDPG_CFG.summary_freq == 0:
                # run_metadata = tf.RunMetadata()
                _, actor_summary = sess.run(
                    fetches=[train_online_policy_op, actor_summary_op],
                    feed_dict={
                        online_state_inputs: state_batch,
                        cond_training_q: False,
                        online_action_inputs_training_q:
                        action_batch,  # feed but not used.
                        actor_l_r: l_r_decay(DDPG_CFG.actor_learning_rate,
                                             step),
                        is_training: True
                    })
                # options=run_options,
                # run_metadata=run_metadata)
                # summary_writer._add_graph_def(run_metadata.partition_graphs[0])

                # the policy online network is updated above and will not affect training q.
                # ---- 2. train q. --------------
                _, critic_summary = sess.run(
                    fetches=[train_online_q_op, critic_summary_op],
                    feed_dict={
                        online_state_inputs: state_batch,
                        cond_training_q: True,
                        online_action_inputs_training_q: action_batch,
                        target_state_inputs: next_state_batch,
                        reward_inputs: reward_batch,
                        terminated_inputs: terminated_batch,
                        critic_l_r: l_r_decay(DDPG_CFG.critic_learning_rate,
                                              step),
                        is_training: True
                    })

                summary_writer.add_summary(actor_summary)
                summary_writer.add_summary(critic_summary)
                summary_writer.flush()
            else:
                _ = sess.run(
                    fetches=[train_online_policy_op],
                    feed_dict={
                        online_state_inputs: state_batch,
                        cond_training_q: False,
                        online_action_inputs_training_q:
                        action_batch,  # feed but not used.
                        actor_l_r: l_r_decay(DDPG_CFG.actor_learning_rate,
                                             step),
                        is_training: True
                    })

                # the policy online network is updated above and will not affect training q.
                # ---- 2. train q. --------------
                _, q_loss_value = sess.run(
                    fetches=[train_online_q_op, q_loss_tensor],
                    feed_dict={
                        online_state_inputs: state_batch,
                        cond_training_q: True,
                        online_action_inputs_training_q: action_batch,
                        target_state_inputs: next_state_batch,
                        reward_inputs: reward_batch,
                        terminated_inputs: terminated_batch,
                        critic_l_r: l_r_decay(DDPG_CFG.critic_learning_rate,
                                              step),
                        is_training: True
                    })
                if step % 2000 == 0:
                    tf.logging.info('@@ step:{} q_loss:{}'.format(
                        step, q_loss_value))

            # --end of summary --
            # ----- 3. update target ---------
            # including increment global step.
            _ = sess.run(fetches=[update_target_op], feed_dict=None)

            # time the update duration for the first 10 updates
            if step < (DDPG_CFG.learn_start + 10):
                tf.logging.info(
                    ' @@@@ one batch learn duration @@@@:{}'.format(
                        time.time() - update_start))

            # do evaluation after eval_freq steps:
            if step % DDPG_CFG.eval_freq == 0:  ##and step > DDPG_CFG.eval_freq:
                evaluate(env=monitor_env,
                         num_eval_steps=DDPG_CFG.num_eval_steps,
                         preprocess_fn=preprocess_low_dim,
                         estimate_fn=lambda state: sess.run(
                             fetches=[actor.online_action_outputs_tensor],
                             feed_dict={
                                 online_state_inputs: state,
                                 is_training: False
                             }),
                         summary_writer=summary_writer,
                         saver=saver,
                         sess=sess,
                         global_step=step,
                         log_summary_op=log_summary_op,
                         summary_text_tensor=summary_text_tensor)
                # if monitor_env is train_env:
                #   #torcs share. we should reset
                #   transition.terminated = True #fall through
        #-- end of learn

        # TODO: temporary solution for toggling vision on/off; use a thread instead
        if step % 2000 == 0:
            v_on = os.path.exists('/home/yuheng/Desktop/train_vision_on')
            if not train_env.vision_status and v_on:
                train_env.vision_on()  #will display next reset
                transition = preprocess_low_dim(train_env.reset(relaunch=True))
                n_episodes += 1
                tf.logging.info('@@ episodes: {} @@'.format(n_episodes))
                continue
            elif train_env.vision_status and not v_on:
                train_env.vision_off()
                transition = preprocess_low_dim(train_env.reset(relaunch=True))
                n_episodes += 1
                tf.logging.info('@@ episodes: {} @@'.format(n_episodes))
                continue

            # if os.path.exists('/home/yuheng/Desktop/eval_vision_on'):
            #   monitor_env.vision_on()  # will display next reset
            # else:
            #   monitor_env.vision_off()

        if (not DDPG_CFG.train_from_replay_buffer_set_only) and (
                transition.terminated):
            # relaunch TORCS every 3 episodes because of a memory leak
            # replace with the transition observed after reset; only the state is saved.
            transition = preprocess_low_dim(train_env.reset())
            n_episodes += 1
            tf.logging.info('@@ episodes: {} @@'.format(n_episodes))
            continue  # begin new episode
    # ====end for t.

    sess.close()
    train_env.close()
    monitor_env.close()
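
Example no. 3 feeds decayed learning rates through the actor_l_r and critic_l_r placeholders via
an l_r_decay helper that is not shown. A plausible sketch (the exponential schedule, its constants
and the floor value are assumptions about what the helper does) is:

def l_r_decay(initial_l_r, step, decay_rate=0.96, decay_steps=100000, min_l_r=1e-5):
    """Exponentially decay a learning rate with the global step.

    Returns initial_l_r * decay_rate**(step / decay_steps), floored at min_l_r.
    """
    decayed = initial_l_r * (decay_rate ** (step / float(decay_steps)))
    return max(decayed, min_l_r)
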
Example no. 4
def train(train_env, monitor_env):
  '''
    :return:
  '''
  action_space = train_env.action_space


  ######### instantiate actor, critic, replay buffer, OU process #########
  ## feed the online network with state; feed the target network with next_state.
  online_state_inputs = tf.placeholder(tf.float32,
                                       shape=(None, DDPG_CFG.screen_height, DDPG_CFG.screen_width,
                                              DDPG_CFG.screen_channels * DDPG_CFG.action_rpt),
                                       name="online_state_inputs")

  # tf.logging.info('@@@@ online_state_inputs shape:{}'.format(online_state_inputs.shape))
  target_state_inputs = tf.placeholder(tf.float32,
                                       shape=(None, DDPG_CFG.screen_height, DDPG_CFG.screen_width,
                                              DDPG_CFG.screen_channels * DDPG_CFG.action_rpt),
                                       name="target_state_inputs")

  ## inputs to q_net for training q.
  online_action_inputs_training_q = tf.placeholder(tf.float32,
                                                   shape=(None, action_space.shape[0]),
                                                   name='online_action_batch_inputs'
                                                   )
  # condition bool scalar to switch action inputs to online q.
  # feed True: training q.
  # feed False: training policy.
  cond_training_q = tf.placeholder(tf.bool, shape=[], name='cond_training_q')

  # target_action_inputs = tf.placeholder(tf.float32,
  #                                       shape=(None, len(DDPG_CFG.torcs_action_fields)),
  #                                       name='target_action_inputs'
  #                                       )

  # batch-size vectors.
  terminated_inputs = tf.placeholder(tf.float32, shape=(None,), name='terminated_inputs')
  reward_inputs = tf.placeholder(tf.float32, shape=(None,), name='rewards_inputs')

  ##instantiate actor, critic.
  actor = Actor(action_dim=action_space.shape[0],
                online_state_inputs=online_state_inputs,
                target_state_inputs=target_state_inputs,
                conv_n_feature_maps=DDPG_CFG.actor_conv_n_maps,
                conv_kernel_sizes=DDPG_CFG.actor_kernel_sizes,
                conv_strides=DDPG_CFG.actor_conv_strides,
                conv_padding=DDPG_CFG.actor_conv_paddings,
                conv_activations=DDPG_CFG.actor_conv_activations,
                conv_initializers=DDPG_CFG.actor_conv_initializers,
                conv_normalizers=DDPG_CFG.actor_conv_normalizers,
                conv_norm_params=DDPG_CFG.actor_conv_normal_params,
                conv_regularizers=DDPG_CFG.actor_conv_regularizers,
                n_fc_in=DDPG_CFG.actor_n_fc_in,
                n_fc_units=DDPG_CFG.actor_n_fc_units,
                fc_activations=DDPG_CFG.actor_fc_activations,
                fc_initializers=DDPG_CFG.actor_fc_initializers,
                fc_normalizers=DDPG_CFG.actor_fc_normalizers,
                fc_norm_params=DDPG_CFG.actor_fc_norm_params,
                fc_regularizers=DDPG_CFG.actor_fc_regularizers,
                output_layer_initializer=DDPG_CFG.actor_output_layer_initializer,
                output_layer_regularizer=DDPG_CFG.actor_output_layer_regularizer,
                # output_layer_regularizer=None,
                output_bound_fns=DDPG_CFG.actor_output_bound_fns,
                learning_rate=DDPG_CFG.actor_learning_rate,
                is_training=is_training)


  critic = Critic(online_state_inputs=online_state_inputs,
                  target_state_inputs=target_state_inputs,
                  online_action_inputs_training_q=online_action_inputs_training_q,
                  online_action_inputs_training_policy=actor.online_action_outputs_tensor,
                  cond_training_q=cond_training_q,
                  target_action_inputs=actor.target_action_outputs_tensor,
                  conv_n_feature_maps=DDPG_CFG.critic_conv_n_maps,
                  conv_kernel_sizes=DDPG_CFG.critic_kernel_sizes,
                  conv_strides=DDPG_CFG.critic_conv_strides,
                  conv_padding=DDPG_CFG.critic_conv_paddings,
                  conv_activations=DDPG_CFG.critic_conv_activations,
                  conv_initializers=DDPG_CFG.critic_conv_initializers,
                  conv_normalizers=DDPG_CFG.critic_conv_normalizers,
                  conv_norm_params=DDPG_CFG.critic_conv_normal_params,
                  conv_regularizers=DDPG_CFG.critic_conv_regularizers,
                  n_fc_in=DDPG_CFG.critic_n_fc_in,
                  n_fc_units=DDPG_CFG.critic_n_fc_units,
                  fc_activations=DDPG_CFG.critic_fc_activations,
                  fc_initializers=DDPG_CFG.critic_fc_initializers,
                  fc_normalizers=DDPG_CFG.critic_fc_normalizers,
                  fc_norm_params=DDPG_CFG.critic_fc_norm_params,
                  fc_regularizers=DDPG_CFG.critic_fc_regularizers,
                  output_layer_initializer=DDPG_CFG.critic_output_layer_initializer,
                  output_layer_regularizer=DDPG_CFG.critic_output_layer_regularizer,
                  learning_rate=DDPG_CFG.critic_learning_rate)

  ## track updates.
  global_step_tensor = tf.train.create_global_step()

  ## build whole graph
  copy_online_to_target_op, train_online_policy_op, train_online_q_op, update_target_op, saver \
    = build_ddpg_graph(actor, critic, reward_inputs, terminated_inputs, global_step_tensor)

  replay_buffer = ReplayBuffer(buffer_size=DDPG_CFG.replay_buff_size)

  # noise shape (3,)
  noise_process = UO_Process(mu=np.zeros(shape=action_space.shape))

  # ===  finish building ddpg graph before this =================#

  ##create tf default session
  sess = tf.Session(graph=tf.get_default_graph())
  '''
  note: creating the FileWriter serializes the graph to a GraphDef, so the whole
  computation graph must be finished before this point.
  '''
  summary_writer = tf.summary.FileWriter(logdir=os.path.join(DDPG_CFG.log_dir, "train"),
                                         graph=sess.graph)
  actor_summary_op = tf.summary.merge_all(key=DDPG_CFG.actor_summary_keys)
  critic_summary_op = tf.summary.merge_all(key=DDPG_CFG.critic_summary_keys)
  ######### initialize computation graph  ############

  '''
  # -------------trace graphdef only
  whole_graph_def = meta_graph.create_meta_graph_def(graph_def=sess.graph.as_graph_def())
  summary_writer.add_meta_graph(whole_graph_def,global_step=1)
  summary_writer.flush()

  run_options = tf.RunOptions(output_partition_graphs=True, trace_level=tf.RunOptions.FULL_TRACE)
  run_metadata = tf.RunMetadata()

  # including copy target -> online
  sess.run(fetches=[init_op],
           options=run_options,
           run_metadata=run_metadata
           )
  graphdef_part1 = run_metadata.partition_graphs[0]
  meta_graph_part1 = meta_graph.create_meta_graph_def(graph_def=graphdef_part1)
  part1_metagraph_writer = tf.summary.FileWriter(DDPG_CFG.log_dir + '/part1_metagraph')
  part1_metagraph_writer.add_meta_graph(meta_graph_part1)
  part1_metagraph_writer.close()

  graphdef_part2 = run_metadata.partition_graphs[1]
  meta_graph_part2 = meta_graph.create_meta_graph_def(graph_def=graphdef_part2)
  part2_metagraph_writer = tf.summary.FileWriter(DDPG_CFG.log_dir + '/part2_metagraph')
  part2_metagraph_writer.add_meta_graph(meta_graph_part2)
  part2_metagraph_writer.close()
  # --------------- end trace
  '''

  sess.run(fetches=[tf.global_variables_initializer()])

  #copy init params from online to target
  sess.run(fetches=[copy_online_to_target_op])

  # Load a previous checkpoint if it exists
  latest_checkpoint = tf.train.latest_checkpoint(DDPG_CFG.checkpoint_dir)
  if latest_checkpoint:
    print("=== Loading model checkpoint: {}".format(latest_checkpoint))
    saver.restore(sess, latest_checkpoint)

  ####### start training #########
  ##ddpg algo-1 processing

  # episode_reward_moving_average = 0.0
  # episode_steps_moving_average = 0

  episode = 1
  # track one epoch consuming time
  # epoch_start_time = time.time()  #in seconds
  # epoches = 0

  obs = train_env.reset()
  n_episodes = 1
  # we don't store the 1st frame; it is only stacked and used as input to the policy network to generate an action.
  transition = preprocess_img(frames=[obs])

  # while epoches < DDPG_CFG.num_training_epoches:
  update_start = 0.0
  for step in range(1, DDPG_CFG.num_training_steps):
    # episode_reward = 0.0
    # episode_steps = 0
    noise_process.reset()

    # for t in range(1,DDPG_CFG.num_timesteps_per_episode):

    # act randomly at the beginning to fill the replay buffer with some frames.
    if step < DDPG_CFG.learn_start:
      # stochastic_action = [np.random.uniform(low,high) for (low,high) in zip(action_space.low, action_space.high)]
      # give the car some speed at the beginning.
      stochastic_action = [None]*3
      stochastic_action[DDPG_CFG.policy_output_idx_steer] = np.random.uniform(-0.1,0.1)
      stochastic_action[DDPG_CFG.policy_output_idx_accel] = np.random.uniform(0.5, 1.0)
      stochastic_action[DDPG_CFG.policy_output_idx_brake] = np.random.uniform(0.01, 0.1)
    else:
      ## calc a_t = mu(s_t) + Noise
      ## FF once to fetch the mu(s_t)
      ## --out[0]:steer, out[1]:accel, out[2]:brake--
      # episode_steps = t
      policy_output = sess.run(fetches=[actor.online_action_outputs_tensor],
            feed_dict={online_state_inputs: transition.next_state[np.newaxis,:,:,:],
                       is_training:False})  # must reshape to (1,64,64,9)
      policy_output = policy_output[0]

      # TODO: anneal the random action probability, from a high chance of accelerating down to doing nothing.
      if step % 7 == 0 or step < (DDPG_CFG.learn_start + 30):
        #tf.logging.info('@@@@@@ policy output:{}  @@@@@@'.format(policy_output))
        # we add some random speed:
        policy_output[0][DDPG_CFG.policy_output_idx_steer] = np.random.uniform(-0.1, 0.1)
        policy_output[0][DDPG_CFG.policy_output_idx_accel] += np.random.uniform(0.8, 1.0)
        policy_output[0][DDPG_CFG.policy_output_idx_brake] += np.random.uniform(-0.9, 0.1)
      ##add noise and bound
      stochastic_action=policy_output_to_stochastic_action(policy_output, noise_process, action_space)



    ## execute a_t and store the transition.
    (frames, reward, terminated) = action_repeat_steps(train_env, stochastic_action)
    # episode_reward += reward

    if step % 50 == 0:
      tf.logging.info('@@@@@@@@@@ global_step:{} action:{}  reward:{} term:{} @@@@@@@@@@'.format(step,stochastic_action,reward,terminated))

    # replace transition with new one.
    transition = preprocess_img(action=stochastic_action,
        reward=reward,
        terminated=terminated,
        frames=frames)

    ## even if terminated, we still save next_state because the feed-forward Q network
    # will use it; the Q value is discarded in the end.
    replay_buffer.store(transition)

    # after the replay buffer has been filled with some frames, start learning.
    if step > DDPG_CFG.learn_start:
      # time the update duration for the first 10 updates
      if step < (DDPG_CFG.learn_start + 10):
        update_start = time.time()

      ## ++++ sample mini-batch and train.++++
      state_batch, action_batch, reward_batch, next_state_batch, terminated_batch = \
        replay_buffer.sample_batch(DDPG_CFG.batch_size)

      # forward/back-prop: SGD updates the online Q, gradient ascent (SGA) updates the online mu.
      ## training op will update online then soft-update target.
      # tf.logging.info('@@@@ state batch shape:{}'.format(state_batch.shape))


      # ---- 1. train policy. -----------
      # no need to feed reward, next_state, terminated; they are unused in the policy update.
      # run_options = tf.RunOptions(output_partition_graphs=True, trace_level=tf.RunOptions.FULL_TRACE)
      if step % 20 == 0:
        # run_metadata = tf.RunMetadata()
        _, actor_summary = sess.run(fetches=[train_online_policy_op,actor_summary_op],
                                   feed_dict={online_state_inputs: state_batch,
                                              cond_training_q: False,
                                              online_action_inputs_training_q: action_batch,  # feed but not used.
                                              is_training:True})
                                    # options=run_options,
                                    # run_metadata=run_metadata)
        # summary_writer._add_graph_def(run_metadata.partition_graphs[0])

        # the policy online network is updated above and will not affect training q.
        # ---- 2. train q. --------------
        _, critic_summary = sess.run(fetches=[train_online_q_op, critic_summary_op],
                                   feed_dict={
                                     online_state_inputs: state_batch,
                                     cond_training_q: True,
                                     online_action_inputs_training_q: action_batch,
                                     target_state_inputs: next_state_batch,
                                     reward_inputs: reward_batch,
                                     terminated_inputs: terminated_batch,
                                     is_training:True})

        summary_writer.add_summary(actor_summary)
        summary_writer.add_summary(critic_summary)
        summary_writer.flush()
      else:
        _ = sess.run(fetches=[train_online_policy_op],
                                 feed_dict={online_state_inputs: state_batch,
                                            cond_training_q: False,
                                            online_action_inputs_training_q: action_batch,  # feed but not used.
                                            is_training: True
                                            })

        # the policy online network is updated above and will not affect training q.
        # ---- 2. train q. --------------
        _ = sess.run(fetches=[train_online_q_op],
                                  feed_dict={
                                    online_state_inputs: state_batch,
                                    cond_training_q: True,
                                    online_action_inputs_training_q: action_batch,
                                    target_state_inputs: next_state_batch,
                                    reward_inputs: reward_batch,
                                    terminated_inputs: terminated_batch,
                                    is_training: True})


      # ----- 3. update target ---------
      # including increment global step.
      _ = sess.run(fetches=[update_target_op],
                   feed_dict=None)

      # time the update duration for the first 10 updates
      if step < (DDPG_CFG.learn_start + 10):
        tf.logging.info(' @@@@ one update duration @@@@:{}'.format(time.time() - update_start))

      # do evaluation after eval_freq steps:
      if step % DDPG_CFG.eval_freq == 0: ##and step > DDPG_CFG.eval_freq:
        evaluate(env=monitor_env,
                 num_eval_steps=DDPG_CFG.num_eval_steps,
                 preprocess_fn=preprocess_img,
                 estimate_fn=lambda state: sess.run(fetches=[actor.online_action_outputs_tensor],
                                                    feed_dict={online_state_inputs:state,
                                                    is_training:False} ),
                 summary_writer=summary_writer,
                 saver=saver, sess=sess, global_step=step)
    #-- end of learn

    if transition.terminated:
      new_obs = train_env.reset()  # relaunch TORCS every 3 episodes because of a memory leak
      # replace with the transition observed after reset; only the frames are saved.
      transition = preprocess_img(frames=[new_obs])
      n_episodes +=1
      tf.logging.info('@@ episodes: {} @@'.format(n_episodes))
      continue  # begin new episode
      # ====end for t. end of one episode ====

      # ---end of save---
  # ---end for episode---

  sess.close()
  train_env.close()
  monitor_env.close()
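
Examples no. 3 and no. 4 draw exploration noise from a UO_Process (an Ornstein-Uhlenbeck process)
whose implementation is not included. A minimal sketch is given below; the sample() method name
and the theta/sigma/dt defaults are assumptions about the original class.

import numpy as np


class UO_Process(object):
    """Ornstein-Uhlenbeck process: dx = theta*(mu - x)*dt + sigma*sqrt(dt)*N(0, 1)."""

    def __init__(self, mu, theta=0.15, sigma=0.2, dt=1e-2):
        self.mu = np.asarray(mu, dtype=np.float64)
        self.theta = theta
        self.sigma = sigma
        self.dt = dt
        self.reset()

    def reset(self):
        # restart the process from its mean
        self.x = np.copy(self.mu)

    def sample(self):
        # one Euler-Maruyama step of the OU stochastic differential equation
        dx = (self.theta * (self.mu - self.x) * self.dt
              + self.sigma * np.sqrt(self.dt) * np.random.standard_normal(self.mu.shape))
        self.x = self.x + dx
        return self.x
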
Example no. 5
def train(train_env, agent_action_fn, eval_mode=False):
    # action and observation spaces of the simulation environment
    action_space = train_env.action_space
    obs_space = train_env.observation_space

    ######### instantiate actor, critic, replay buffer, OU process #########
    ## feed the online network with state; feed the target network with next_state.
    online_state_inputs = tf.placeholder(tf.float32,
                                         shape=(None, obs_space.shape[0]),
                                         name="online_state_inputs")

    target_state_inputs = tf.placeholder(tf.float32,
                                         shape=online_state_inputs.shape,
                                         name="target_state_inputs")

    ## inputs to q_net for training q.
    online_action_inputs_training_q = tf.placeholder(
        tf.float32,
        shape=(None, action_space.shape[0]),
        name='online_action_batch_inputs')
    # cond_training_q: condition variable that switches the action input to the online Q network.
    # True: training q.
    # False: training policy.
    cond_training_q = tf.placeholder(tf.bool, shape=[], name='cond_training_q')

    # batch-size vectors.
    terminated_inputs = tf.placeholder(tf.float32,
                                       shape=(None,),
                                       name='terminated_inputs')
    reward_inputs = tf.placeholder(tf.float32,
                                   shape=(None,),
                                   name='rewards_inputs')

    # summary text
    summary_text_tensor = tf.convert_to_tensor('summary_text',
                                               preferred_dtype=tf.string)
    tf.summary.text(name='summary_text',
                    tensor=summary_text_tensor,
                    collections=[DDPG_CFG.log_summary_keys])

    # instantiate the actor and critic.
    actor = Actor(
        action_dim=action_space.shape[0],
        online_state_inputs=online_state_inputs,
        target_state_inputs=target_state_inputs,
        input_normalizer=DDPG_CFG.actor_input_normalizer,
        input_norm_params=DDPG_CFG.actor_input_norm_params,
        n_fc_units=DDPG_CFG.actor_n_fc_units,
        fc_activations=DDPG_CFG.actor_fc_activations,
        fc_initializers=DDPG_CFG.actor_fc_initializers,
        fc_normalizers=DDPG_CFG.actor_fc_normalizers,
        fc_norm_params=DDPG_CFG.actor_fc_norm_params,
        fc_regularizers=DDPG_CFG.actor_fc_regularizers,
        output_layer_initializer=DDPG_CFG.actor_output_layer_initializer,
        output_layer_regularizer=None,
        output_normalizers=DDPG_CFG.actor_output_layer_normalizers,
        output_norm_params=DDPG_CFG.actor_output_layer_norm_params,
        output_bound_fns=DDPG_CFG.actor_output_bound_fns,
        learning_rate=DDPG_CFG.actor_learning_rate,
        is_training=is_training)

    critic = Critic(
        online_state_inputs=online_state_inputs,
        target_state_inputs=target_state_inputs,
        input_normalizer=DDPG_CFG.critic_input_normalizer,
        input_norm_params=DDPG_CFG.critic_input_norm_params,
        online_action_inputs_training_q=online_action_inputs_training_q,
        online_action_inputs_training_policy=actor.online_action_outputs_tensor,
        cond_training_q=cond_training_q,
        target_action_inputs=actor.target_action_outputs_tensor,
        n_fc_units=DDPG_CFG.critic_n_fc_units,
        fc_activations=DDPG_CFG.critic_fc_activations,
        fc_initializers=DDPG_CFG.critic_fc_initializers,
        fc_normalizers=DDPG_CFG.critic_fc_normalizers,
        fc_norm_params=DDPG_CFG.critic_fc_norm_params,
        fc_regularizers=DDPG_CFG.critic_fc_regularizers,
        output_layer_initializer=DDPG_CFG.critic_output_layer_initializer,
        output_layer_regularizer=None,
        learning_rate=DDPG_CFG.critic_learning_rate)

    # track updates.
    global_step_tensor = tf.train.create_global_step()

    # build the whole DDPG computation graph
    copy_online_to_target_op, train_online_policy_op, train_online_q_op, update_target_op, saver \
        = build_ddpg_graph(actor, critic, reward_inputs, terminated_inputs, global_step_tensor)

    # instantiate the replay buffer and specify whether its data is saved to files
    replay_buffer = ReplayBuffer(
        buffer_size=DDPG_CFG.replay_buff_size,
        save_segment_size=DDPG_CFG.replay_buff_save_segment_size,
        save_path=DDPG_CFG.replay_buffer_file_path,
        seed=DDPG_CFG.random_seed)

    # load buffer data from files
    if DDPG_CFG.load_replay_buffer_set:
        replay_buffer.load(DDPG_CFG.replay_buffer_file_path)

    # use summaries to monitor data and parameters during training; the charts can be viewed in TensorBoard
    sess = tf.Session(graph=tf.get_default_graph())
    summary_writer = tf.summary.FileWriter(logdir=os.path.join(
        DDPG_CFG.log_dir, "train"),
                                           graph=sess.graph)
    actor_summary_op = tf.summary.merge_all(key=DDPG_CFG.actor_summary_keys)
    critic_summary_op = tf.summary.merge_all(key=DDPG_CFG.critic_summary_keys)
    log_summary_op = tf.summary.merge_all(key=DDPG_CFG.log_summary_keys)

    sess.run(fetches=[tf.global_variables_initializer()])

    # initialization: copy the online parameters to the target networks
    sess.run(fetches=[copy_online_to_target_op])

    # load a previously saved model checkpoint, if any:
    latest_checkpoint = tf.train.latest_checkpoint(DDPG_CFG.checkpoint_dir)
    if latest_checkpoint:
        tf.logging.info(
            "==== Loading model checkpoint: {}".format(latest_checkpoint))
        saver.restore(sess, latest_checkpoint)
    elif eval_mode:
        raise FileNotFoundError(
            '=== in evaluation mode a checkpoint file is required, but none was found. ==='
        )

    ####### start training #########
    obs = train_env.reset()
    transition = preprocess_low_dim(obs)

    n_episodes = 1  # episode counter

    # training mode
    if not eval_mode:
        # train for DDPG_CFG.num_training_steps steps in total
        for step in range(1, DDPG_CFG.num_training_steps):
            # get an action from the online policy network given the current state
            policy_out = sess.run(fetches=[actor.online_action_outputs_tensor],
                                  feed_dict={
                                      online_state_inputs:
                                      transition.next_state[np.newaxis, :],
                                      is_training:
                                      False
                                  })[0]
            # execute the action in the environment and store the transition in the replay buffer
            transition = agent_action_fn(policy_out, replay_buffer, train_env)

            if step % 200 == 0:
                tf.logging.info(' +++++++++++++++++++ global_step:{} action:{}'
                                '  reward:{} term:{}'.format(
                                    step, transition.action, transition.reward,
                                    transition.terminated))
            # let a few transitions accumulate in the buffer before learning.
            if step < 10:
                continue
            # sample a mini-batch from the replay buffer
            state_batch, action_batch, reward_batch, next_state_batch, terminated_batch = \
                replay_buffer.sample_batch(DDPG_CFG.batch_size)

            if step % DDPG_CFG.summary_freq == 0:
                _, actor_summary = sess.run(
                    fetches=[train_online_policy_op, actor_summary_op],
                    feed_dict={
                        online_state_inputs: state_batch,
                        cond_training_q: False,
                        online_action_inputs_training_q: action_batch,
                        is_training: True
                    })

                _, critic_summary = sess.run(
                    fetches=[train_online_q_op, critic_summary_op],
                    feed_dict={
                        online_state_inputs: state_batch,
                        cond_training_q: True,
                        online_action_inputs_training_q: action_batch,
                        target_state_inputs: next_state_batch,
                        reward_inputs: reward_batch,
                        terminated_inputs: terminated_batch,
                        is_training: True
                    })
                summary_writer.add_summary(actor_summary, global_step=step)
                summary_writer.add_summary(critic_summary, global_step=step)
                summary_writer.flush()
            else:
                # ---- 1. train the policy network -----------
                sess.run(fetches=[train_online_policy_op],
                         feed_dict={
                             online_state_inputs: state_batch,
                             cond_training_q: False,
                             online_action_inputs_training_q: action_batch,
                             is_training: True
                         })

                # ---- 2. train the Q network --------------
                sess.run(fetches=[train_online_q_op],
                         feed_dict={
                             online_state_inputs: state_batch,
                             cond_training_q: True,
                             online_action_inputs_training_q: action_batch,
                             target_state_inputs: next_state_batch,
                             reward_inputs: reward_batch,
                             terminated_inputs: terminated_batch,
                             is_training: True
                         })

            # ----- 3. update the target networks ---------
            sess.run(fetches=[update_target_op], feed_dict=None)

            # every eval_freq steps, run an evaluation so that a good model can be selected after training:
            if step % DDPG_CFG.eval_freq == 0:
                evaluate(env=train_env,
                         num_eval_steps=DDPG_CFG.num_eval_steps,
                         preprocess_fn=preprocess_low_dim,
                         estimate_fn=lambda state: sess.run(
                             fetches=[actor.online_action_outputs_tensor],
                             feed_dict={
                                 online_state_inputs: state,
                                 is_training: False
                             }),
                         summary_writer=summary_writer,
                         saver=saver,
                         sess=sess,
                         global_step=step,
                         log_summary_op=log_summary_op,
                         summary_text_tensor=summary_text_tensor)

            if transition.terminated:
                transition = preprocess_low_dim(train_env.reset())
                n_episodes += 1
                continue  # begin new episode
    # eval mode
    else:
        evaluate(env=train_env,
                 num_eval_steps=DDPG_CFG.eval_steps_after_training,
                 preprocess_fn=preprocess_low_dim,
                 estimate_fn=lambda state: sess.run(
                     fetches=[actor.online_action_outputs_tensor],
                     feed_dict={
                         online_state_inputs: state,
                         is_training: False
                     }),
                 summary_writer=summary_writer,
                 saver=None,
                 sess=sess,
                 global_step=0,
                 log_summary_op=log_summary_op,
                 summary_text_tensor=summary_text_tensor)

    sess.close()
    train_env.close()
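
All five examples delegate the critic loss to build_ddpg_graph, which is expected to implement the
standard DDPG one-step TD target y = r + gamma * (1 - terminated) * Q_target(s', mu_target(s')).
A minimal TF1-style sketch of that loss is given below; the helper name critic_td_loss and the
gamma default are assumptions about what build_ddpg_graph computes internally.

import tensorflow as tf


def critic_td_loss(reward_inputs, terminated_inputs, target_q_values,
                   online_q_values, gamma=0.99):
    """Mean-squared one-step TD error for the DDPG critic.

    y = r + gamma * (1 - terminated) * Q_target(s', mu_target(s'));
    the target is treated as a constant (stop_gradient) when differentiating.
    """
    y = reward_inputs + gamma * (1.0 - terminated_inputs) * target_q_values
    y = tf.stop_gradient(y)
    return tf.reduce_mean(tf.square(y - online_q_values))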