Example 1
def main(_):
  gpu_options = tf.GPUOptions(
      per_process_gpu_memory_fraction=calc_gpu_fraction(FLAGS.gpu_fraction))

  with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    config = get_config(FLAGS) or FLAGS

    if config.env_type == 'simple':
      env = SimpleGymEnvironment(config)
    else:
      env = GymEnvironment(config)

    if FLAGS.cpu:
      config.cnn_format = 'NHWC'

    agent = Agent(config, env, sess)

    if FLAGS.save_weight:
      agent.save_weight_to_pkl()
    if FLAGS.load_weight:
      agent.load_weight_from_pkl(cpu_mode=FLAGS.cpu)

    if FLAGS.is_train:
      agent.train()
    else:
      agent.play()
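
Every example on this page passes FLAGS.gpu_fraction through calc_gpu_fraction before building tf.GPUOptions, but the helper itself is not shown. A minimal sketch, assuming the flag is a string such as '1/3' that should become a float in (0, 1]; the exact parsing in the original project may differ:

def calc_gpu_fraction(fraction_string):
  # Assumption: the flag encodes a fraction as 'numerator/denominator', e.g. '1/3'.
  num, den = fraction_string.split('/')
  fraction = float(num) / float(den)
  print(" [*] GPU memory fraction: %.4f" % fraction)
  return fraction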
Example 2
def main(_):
  gpu_options = tf.GPUOptions(
      per_process_gpu_memory_fraction=calc_gpu_fraction(FLAGS.gpu_fraction))

  with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    config = get_config(FLAGS) or FLAGS

    if config.env_type == 'simple':
      env = SimpleGymEnvironment(config)
    else:
      env = GymEnvironment(config)

    if FLAGS.use_gpu:
      config.cnn_format = 'NCHW'

    agent = Agent(config, env, sess)

    if FLAGS.is_train:
      agent.train()
    else:
      agent.play()
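
Examples 1 and 2 differ only in how they choose the convolution data format: 'NHWC' (batch, height, width, channels) when running on the CPU, and 'NCHW' (batch, channels, height, width), which cuDNN generally handles faster on the GPU. A small NumPy sketch of the layout difference; the batch shape here is illustrative only:

import numpy as np

# A batch of 32 histories of four stacked 84x84 frames in NHWC order.
batch_nhwc = np.zeros((32, 84, 84, 4), dtype=np.float32)

# The same batch rearranged into NCHW order for GPU-friendly convolutions.
batch_nchw = np.transpose(batch_nhwc, (0, 3, 1, 2))
assert batch_nchw.shape == (32, 4, 84, 84)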
Example 3
def main(_):
  gpu_options = tf.GPUOptions(
      per_process_gpu_memory_fraction=calc_gpu_fraction(FLAGS.gpu_fraction))

  with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True, log_device_placement=True)) as sess:
  # with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    config = get_config(FLAGS) or FLAGS

    if config.env_type == 'simple':
      env = SimpleGymEnvironment(config)
    else:
      env = GymEnvironment(config)

    if not FLAGS.use_gpu:
      config.cnn_format = 'NHWC'
    
    with tf.device('/gpu:2'):
      agent = Agent(config, env, sess)

    if FLAGS.is_train:
      agent.train()
    else:
      agent.play()
Example 4
def main(_):
    gpu_options = tf.GPUOptions(allow_growth=True, visible_device_list='0')

    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        config = get_config(FLAGS) or FLAGS

        if config.env_type == 'simple':
            env = SimpleGymEnvironment(config)
        else:
            env = GymEnvironment(config)

        if not tf.test.is_gpu_available() and FLAGS.use_gpu:
            raise Exception("use_gpu flag is true when no GPUs are available")

        if not FLAGS.use_gpu:
            config.cnn_format = 'NHWC'

        agent = Agent(config, env, sess)

        if FLAGS.is_train:
            agent.train()
        else:
            agent.play()
Example 5
def main(_):
    # By default TensorFlow grabs all available GPU memory while running (leaving about
    # 200 MB for the system), so the following statement specifies the fraction of GPU memory to allocate:
    if FLAGS.gpu_fraction == '':
        raise ValueError("--gpu_fraction should be defined")
    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=calc_gpu_fraction(FLAGS.gpu_fraction))
    # Monitor GPU usage in a terminal with: watch -n 10 nvidia-smi

    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:

        config = DQNConfig(FLAGS) or FLAGS
        print("\n [*] Current Configuration")
        pp(config.list_all_member())

        # Note before starting the process:
        # call remoteApi.start(19999) in V-REP first, otherwise unpredictable problems may occur

        if not tf.test.is_gpu_available() and FLAGS.use_gpu:
            raise Exception("use_gpu flag is true when no GPUs are available")

        if config.is_train:
            env = DQNEnvironment(config)
            agent = Agent(config, env, sess)
            agent.train()
        else:
            if config.is_sim:
                env = DQNEnvironment(config)
                agent = Agent(config, env, sess)
                agent.play()
                agent.randomplay()
            else:
                from experiment.environment import REALEnvironment
                env = REALEnvironment(config)
                agent = Agent(config, env, sess)
                agent.exp_play()

        env.close()
Example 6
def main(_):
    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=calc_gpu_fraction(FLAGS.gpu_fraction))

    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        config = get_config(FLAGS) or FLAGS

        if config.env_type == 'simple':
            env = SimpleGymEnvironment(config)
        else:
            env = GymEnvironment(config)

        if not tf.test.is_gpu_available() and FLAGS.use_gpu:
            raise Exception("use_gpu flag is true when no GPUs are available")

        if not FLAGS.use_gpu:
            config.cnn_format = 'NHWC'

        agent = Agent(config, env, sess)

        if FLAGS.is_train:
            agent.train()
        else:
            agent.play()
Example 7
def main(_):
    # Trying to request all the GPU memory will fail, since the system
    # always allocates a little memory on each GPU for itself. Only set
    # up a GPU configuration if fractional amount of memory is requested.
    tf_config = None
    gpu_fraction = calc_gpu_fraction(FLAGS.gpu_fraction)
    if gpu_fraction < 1:
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_fraction)
        tf_config = tf.ConfigProto(gpu_options=gpu_options)

    with tf.Session(config=tf_config) as sess:
        config = get_config(FLAGS) or FLAGS
        env = GymEnvironment(config)

        # Change data format for running on a CPU.
        if not FLAGS.use_gpu:
            config.cnn_format = 'NHWC'

        agent = Agent(config, env, sess)

        if FLAGS.train:
            agent.train()
        else:
            agent.play()
Example 8
def main(_):
  gpu_options = tf.GPUOptions(
      per_process_gpu_memory_fraction=calc_gpu_fraction(FLAGS.gpu_fraction))

  with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    config = get_config(FLAGS) or FLAGS

    if config.env_type == 'simple':
      env = SimpleGymEnvironment(config)
    else:
      env = GymEnvironment(config)

    if not tf.test.is_gpu_available() and FLAGS.use_gpu:
      raise Exception("use_gpu flag is true when no GPUs are available")

    if not FLAGS.use_gpu:
      config.cnn_format = 'NHWC'

    agent = Agent(config, env, sess)

    if FLAGS.is_train:
      agent.train()
    else:
      agent.play()
Example 9
def train(sess, config):

    env = GymEnvironment(config)

    log_dir = './log/{}_lookahead_{}_gats_{}/'.format(config.env_name,
                                                      config.lookahead,
                                                      config.gats)
    checkpoint_dir = os.path.join(log_dir, 'checkpoints/')
    image_dir = os.path.join(log_dir, 'rollout/')
    if os.path.isdir(log_dir):
        shutil.rmtree(log_dir)
        print(' [*] Removed log dir: ' + log_dir)

    with tf.variable_scope('step'):
        step_op = tf.Variable(0, trainable=False, name='step')
        step_input = tf.placeholder('int32', None, name='step_input')
        step_assign_op = step_op.assign(step_input)

    with tf.variable_scope('summary'):
        scalar_summary_tags = [
            'average.reward', 'average.loss', 'average.q value',
            'episode.max reward', 'episode.min reward', 'episode.avg reward',
            'episode.num of game', 'training.learning_rate', 'rp.rp_accuracy',
            'rp.rp_plus_accuracy', 'rp.rp_minus_accuracy',
            'rp.nonzero_rp_accuracy'
        ]

        summary_placeholders = {}
        summary_ops = {}

        for tag in scalar_summary_tags:
            summary_placeholders[tag] = tf.placeholder(
                'float32', None, name=tag.replace(' ', '_'))
            summary_ops[tag] = tf.summary.scalar(
                "%s-%s/%s" % (config.env_name, config.env_type, tag),
                summary_placeholders[tag])

        histogram_summary_tags = ['episode.rewards', 'episode.actions']

        for tag in histogram_summary_tags:
            summary_placeholders[tag] = tf.placeholder(
                'float32', None, name=tag.replace(' ', '_'))
            summary_ops[tag] = tf.summary.histogram(
                tag, summary_placeholders[tag])

    config.num_actions = env.action_size
    # config.num_actions = 3

    exploration = LinearSchedule(config.epsilon_end_t, config.epsilon_end)

    agent = Agent(sess, config, num_actions=config.num_actions)

    if config.gats:
        lookahead = config.lookahead
        rp_train_frequency = 4
        gdm_train_frequency = 4
        gdm = GDM(sess, config, num_actions=config.num_actions)
        rp = RP(sess, config, num_actions=config.num_actions)
        leaves_size = config.num_actions**config.lookahead
        if config.dyna:
            gan_memory = GANReplayMemory(config)
        else:
            gan_memory = None

        def base_generator():
            tree_base = np.zeros((leaves_size, lookahead)).astype('uint8')
            for i in range(leaves_size):
                n = i
                j = 0
                while n:
                    n, r = divmod(n, config.num_actions)
                    tree_base[i, lookahead - 1 - j] = r
                    j = j + 1
            return tree_base

        tree_base = base_generator()

    # memory = ReplayMemory(config)
    memory = ReplayMemory(config, log_dir)
    history = History(config)

    tf.global_variables_initializer().run()
    saver = tf.train.Saver(max_to_keep=30)

    # Load the model if a checkpoint exists.
    load_model(sess, saver, checkpoint_dir)

    agent.updated_target_q_network()

    writer = tf.summary.FileWriter(log_dir, sess.graph)

    num_game, update_count, ep_reward = 0, 0, 0.
    total_reward, total_loss, total_q_value = 0., 0., 0.
    max_avg_ep_reward = -100
    ep_rewards, actions = [], []

    rp_accuracy = []
    rp_plus_accuracy = []
    rp_minus_accuracy = []
    nonzero_rp_accuracy = []

    screen, reward, action, terminal = env.new_random_game()

    # init state
    for _ in range(config.history_length):
        history.add(screen)

    start_step = step_op.eval()

    # main
    for step in tqdm(range(start_step, config.max_step),
                     ncols=70,
                     initial=start_step):

        if step == config.learn_start:
            num_game, update_count, ep_reward = 0, 0, 0.
            total_reward, total_loss, total_q_value = 0., 0., 0.
            ep_rewards, actions = [], []

        if step == config.gan_dqn_learn_start:
            rp_accuracy = []
            rp_plus_accuracy = []
            rp_minus_accuracy = []
            nonzero_rp_accuracy = []

        # ε-greedy
        MCTS_FLAG = False
        epsilon = exploration.value(step)
        if random.random() < epsilon:
            action = random.randrange(config.num_actions)
        else:
            current_state = norm_frame(np.expand_dims(history.get(), axis=0))
            if config.gats and (step >= config.gan_dqn_learn_start):
                action, predicted_reward = MCTS_planning(
                    gdm, rp, agent, current_state, leaves_size, tree_base,
                    config, exploration, step, gan_memory)
                MCTS_FLAG = True
            else:
                action = agent.get_action(
                    norm_frame_Q(unnorm_frame(current_state)))

        # For GATS?
        apply_action = action
        # if int(apply_action != 0):
        #     apply_action += 1

        # Observe
        screen, reward, terminal = env.act(apply_action, is_training=True)
        reward = max(config.min_reward, min(config.max_reward, reward))
        history.add(screen)
        memory.add(screen, reward, action, terminal)

        if MCTS_FLAG:
            rp_accuracy.append(int(predicted_reward == reward))
            if reward != 0:
                nonzero_rp_accuracy.append(int(predicted_reward == reward))
                if reward == 1:
                    rp_plus_accuracy.append(int(predicted_reward == reward))
                elif reward == -1:
                    rp_minus_accuracy.append(int(predicted_reward == reward))

        # Train
        if step > config.gan_learn_start and config.gats:
            if step % rp_train_frequency == 0 and memory.can_sample(
                    config.rp_batch_size):
                obs, act, rew = memory.reward_sample(config.rp_batch_size)
                # obs, act, rew = memory.reward_sample2(
                #     config.rp_batch_size, config.lookahead)
                reward_obs, reward_act, reward_rew = memory.reward_sample(
                    config.nonzero_batch_size, nonzero=True)
                # reward_obs, reward_act, reward_rew = memory.nonzero_reward_sample(
                #     config.rp_batch_size, config.lookahead)
                obs_batch = norm_frame(
                    np.concatenate((obs, reward_obs), axis=0))
                act_batch = np.concatenate((act, reward_act), axis=0)
                rew_batch = np.concatenate((rew, reward_rew), axis=0)
                reward_label = rew_batch + 1

                trajectories = gdm.get_state(obs_batch, act_batch[:, :-1])

                rp_summary = rp.train(trajectories, act_batch, reward_label)
                writer.add_summary(rp_summary, step)

            if step % gdm_train_frequency == 0 and memory.can_sample(
                    config.gan_batch_size):
                state_batch, action_batch, next_state_batch = memory.GAN_sample()
                # state_batch, act_batch, next_state_batch = memory.GAN_sample2(
                #     config.gan_batch_size, config.lookahead)

                # gdm.summary, disc_summary, merged_summary = gdm.train(
                #     norm_frame(state_batch), act_batch, norm_frame(next_state_batch), warmup_bool)
                gdm.summary, disc_summary = gdm.train(
                    norm_frame(state_batch), action_batch,
                    norm_frame(next_state_batch))

        if step > config.learn_start:
            # if step % config.train_frequency == 0 and memory.can_sample(config.batch_size):
            if step % config.train_frequency == 0:
                # s_t, act_batch, rew_batch, s_t_plus_1, terminal_batch = memory.sample(
                #     config.batch_size, config.lookahead)
                s_t, act_batch, rew_batch, s_t_plus_1, terminal_batch = memory.sample()
                s_t, s_t_plus_1 = norm_frame(s_t), norm_frame(s_t_plus_1)
                if config.gats and config.dyna:
                    if step > config.gan_dqn_learn_start and gan_memory.can_sample(
                            config.batch_size):
                        gan_obs_batch, gan_act_batch, gan_rew_batch, gan_terminal_batch = gan_memory.sample()
                        # gan_obs_batch, gan_act_batch, gan_rew_batch = gan_memory.sample(
                        #     config.batch_size)
                        gan_obs_batch = norm_frame(gan_obs_batch)
                        trajectories = gdm.get_state(
                            gan_obs_batch, np.expand_dims(gan_act_batch,
                                                          axis=1))
                        gan_next_obs_batch = trajectories[:, -config.history_length:, ...]

                        # gan_obs_batch, gan_next_obs_batch = \
                        #     norm_frame(gan_obs_batch), norm_frame(gan_next_obs_batch)

                        s_t = np.concatenate([s_t, gan_obs_batch], axis=0)
                        act_batch = np.concatenate([act_batch, gan_act_batch],
                                                   axis=0)
                        rew_batch = np.concatenate([rew_batch, gan_rew_batch],
                                                   axis=0)
                        s_t_plus_1 = np.concatenate(
                            [s_t_plus_1, gan_next_obs_batch], axis=0)
                        terminal_batch = np.concatenate(
                            [terminal_batch, gan_terminal_batch], axis=0)

                s_t, s_t_plus_1 = norm_frame_Q(
                    unnorm_frame(s_t)), norm_frame_Q(unnorm_frame(s_t_plus_1))

                q_t, loss, dqn_summary = agent.train(s_t, act_batch, rew_batch,
                                                     s_t_plus_1,
                                                     terminal_batch, step)

                writer.add_summary(dqn_summary, step)
                total_loss += loss
                total_q_value += q_t.mean()
                update_count += 1

            if step % config.target_q_update_step == config.target_q_update_step - 1:
                agent.updated_target_q_network()

        # reinit
        if terminal:
            screen, reward, action, terminal = env.new_random_game()

            num_game += 1
            ep_rewards.append(ep_reward)
            ep_reward = 0.
        else:
            ep_reward += reward

        total_reward += reward

        # change training frequency
        if config.gats:
            if step == 10000 - 1:
                rp_train_frequency = 8
                gdm_train_frequency = 8
            if step == 50000 - 1:
                rp_train_frequency = 16
                gdm_train_frequency = 16
            if step == 100000 - 1:
                rp_train_frequency = 24
                gdm_train_frequency = 24

        # Perform a rollout and save the images
        if config.gats and step % config._test_step == config._test_step - 1:
            rollout_image(config, image_dir, gdm, memory, step + 1, 16)

        # calculate information
        if step >= config.learn_start:
            if step % config._test_step == config._test_step - 1:

                # plot
                if config.gats:
                    writer.add_summary(gdm.summary, step)
                    writer.add_summary(disc_summary, step)

                avg_reward = total_reward / config._test_step
                avg_loss = total_loss / update_count
                avg_q = total_q_value / update_count

                try:
                    max_ep_reward = np.max(ep_rewards)
                    min_ep_reward = np.min(ep_rewards)
                    avg_ep_reward = np.mean(ep_rewards)
                except ValueError:
                    max_ep_reward, min_ep_reward, avg_ep_reward = 0, 0, 0

                print(
                    '\navg_r: %.4f, avg_l: %.6f, avg_q: %3.6f, avg_ep_r: %.4f, max_ep_r: %.4f, min_ep_r: %.4f, # game: %d'
                    % (avg_reward, avg_loss, avg_q, avg_ep_reward,
                       max_ep_reward, min_ep_reward, num_game))

                # requires target q network
                if max_avg_ep_reward * 0.9 <= avg_ep_reward:
                    step_assign_op.eval({step_input: step + 1})
                    save_model(sess, saver, checkpoint_dir, step + 1)

                    max_avg_ep_reward = max(max_avg_ep_reward, avg_ep_reward)

                if step >= config.gan_dqn_learn_start:
                    if len(rp_accuracy) > 0:
                        rp_accuracy = np.mean(rp_accuracy)
                        rp_plus_accuracy = np.mean(rp_plus_accuracy)
                        rp_minus_accuracy = np.mean(rp_minus_accuracy)
                        nonzero_rp_accuracy = np.mean(nonzero_rp_accuracy)
                    else:
                        rp_accuracy = 0
                        rp_plus_accuracy = 0
                        rp_minus_accuracy = 0
                        nonzero_rp_accuracy = 0
                else:
                    rp_accuracy = 0
                    rp_plus_accuracy = 0
                    rp_minus_accuracy = 0
                    nonzero_rp_accuracy = 0

                # summary
                if step > 180:
                    inject_summary(
                        sess, writer, summary_ops, summary_placeholders, {
                            'average.reward': avg_reward,
                            'average.loss': avg_loss,
                            'average.q value': avg_q,
                            'episode.max reward': max_ep_reward,
                            'episode.min reward': min_ep_reward,
                            'episode.avg reward': avg_ep_reward,
                            'episode.num of game': num_game,
                            'episode.rewards': ep_rewards,
                            'episode.actions': actions,
                            'rp.rp_accuracy': rp_accuracy,
                            'rp.rp_plus_accuracy': rp_plus_accuracy,
                            'rp.rp_minus_accuracy': rp_minus_accuracy,
                            'rp.nonzero_rp_accuracy': nonzero_rp_accuracy
                        }, step)

                num_game = 0
                total_reward = 0.
                total_loss = 0.
                total_q_value = 0.
                update_count = 0
                ep_reward = 0.
                ep_rewards = []
                actions = []

                rp_accuracy = []
                rp_plus_accuracy = []
                rp_minus_accuracy = []
                nonzero_rp_accuracy = []
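
The base_generator inside the GATS branch of Example 9 enumerates every action sequence of length lookahead by writing each leaf index in base num_actions, so that MCTS_planning can evaluate all num_actions**lookahead rollouts at once. A standalone version of the same routine with small, illustrative parameters:

import numpy as np

def base_generator(num_actions, lookahead):
    # Row i holds the base-num_actions digits of i, one action per lookahead step.
    leaves_size = num_actions ** lookahead
    tree_base = np.zeros((leaves_size, lookahead), dtype='uint8')
    for i in range(leaves_size):
        n, j = i, 0
        while n:
            n, r = divmod(n, num_actions)
            tree_base[i, lookahead - 1 - j] = r
            j += 1
    return tree_base

# With 2 actions and a lookahead of 3, the 8 rows enumerate 000, 001, ..., 111.
print(base_generator(2, 3))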
Example 10
File: train.py, Project: pongib/dqn
import tensorflow as tf

from config import Config
from dqn.agent import Agent
from dqn.environment import Environment

with tf.Session() as sess:
    config = Config()
    environment = Environment(config)
    agent = Agent(config, environment, sess)
    agent.train()
Example 11
from datetime import datetime
from config import AgentConfig
from dqn.agent import Agent
import pytz
import warnings
import pandas as pd

if __name__ == '__main__':
    warnings.simplefilter("ignore", DeprecationWarning)

    config = AgentConfig()
    # env = Environment(sd, ed, config, datafile_loc='./fundretriever/snp500.h5')

    # parameters
    sd = datetime(2004, 1, 1, 0, 0, 0, 0, pytz.utc)
    ed = datetime(2015, 1, 1, 0, 0, 0, 0, pytz.utc)
    live_start_date = datetime(2015, 1, 1, 0, 0, 0, 0, pytz.utc)

    syms = pd.read_csv('sp500.csv')
    syms = syms.values[:, 0].tolist()
    capital = 1000000

    agent = Agent(config, syms, capital)
    agent.train(sd, ed)
    # agent.test(sd, ed, live_start_date=live_start_date)
Example 12
def main(argv=None):
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu_id

    train_dir = FLAGS.train_dir
    if not tf.gfile.Exists(train_dir):
        tf.logging.info("Creating training directory: %s", train_dir)
        tf.gfile.MakeDirs(train_dir)

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2)
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False,
                                    gpu_options=gpu_options)

    with tf.Session(config=session_config) as sess:
        env = gym.make(FLAGS.env_name)
        agent = Agent(FLAGS, env.action_space.n)

        # Initialize variables
        init = tf.global_variables_initializer()
        sess.run(init)

        # Setup logger/saver
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=200)
        ckpt = tf.train.get_checkpoint_state(train_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        # Variables for logging
        log_vars = _init_log_vars()

        done = True
        for step in tqdm(range(FLAGS.num_steps), ncols=70):
            if done:
                env.reset()
                log_vars["num_games"] += 1
                log_vars["ep_rewards"].append(log_vars["ep_reward"])
                log_vars["ep_reward"] = 0.
                action = 0

            reward = 0.
            for _ in range(FLAGS.action_repeat):
                raw_observation, one_reward, done, info = env.step(action)
                reward += one_reward
                if done:
                    reward -= 1.
                    break
            observation = atari_preprocessing(raw_observation,
                                              FLAGS.screen_width,
                                              FLAGS.screen_height)
            log_vars["ep_reward"] += reward

            if FLAGS.display:
                env.render()

            if FLAGS.is_train:
                action, epsilon, summary_str = agent.train(
                    observation, reward, done, step, sess)
            else:
                action, epsilon, summary_str = agent.predict(observation, sess)
            log_vars["actions"].append(action)

            # Log performance periodically
            if (step + 1) % TEST_STEP == 0:
                ep_rewards = log_vars["ep_rewards"]
                num_games = log_vars["num_games"]
                actions = log_vars["actions"]
                try:
                    max_r = np.max(ep_rewards)
                    min_r = np.min(ep_rewards)
                    avg_r = np.mean(ep_rewards)
                except ValueError:
                    max_r = 0.0
                    min_r = 0.0
                    avg_r = 0.0

                format_str = "[Step {}] avg_r: {:4}, max_r: {:4}, min_r: {:4}, # games: {}, epsilon: {:4}".format(
                    step, avg_r, max_r, min_r, num_games, epsilon)
                tf.logging.info(format_str)

                summary = tf.Summary()
                summary.value.add(tag="avg_r", simple_value=avg_r)
                summary.value.add(tag="max_r", simple_value=max_r)
                summary.value.add(tag="min_r", simple_value=min_r)
                summary.value.add(tag="num_games", simple_value=num_games)
                summary.value.add(tag="epsilon", simple_value=epsilon)
                summary.value.add(tag="actions", histo=get_histo(actions))
                summary_writer.add_summary(summary, step)

                log_vars = _init_log_vars()

            # Update Tensorboard
            if (step + 1) % TENSORBOARD_STEP == 0 and step > FLAGS.learn_start and summary_str:
                summary_writer.add_summary(summary_str, step)

            # Save the model checkpoint periodically
            if (step + 1) % SAVE_STEP == 0:
                tf.logging.info("Save checkpoint at {} step".format(step))
                checkpoint_path = os.path.join(train_dir, 'model.ckpt')
                saver.save(sess,
                           checkpoint_path,
                           global_step=agent.global_step.eval())

        env.close()
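
Example 12 relies on a helper _init_log_vars that is not shown on this page. Judging from how its return value is indexed in the loop, a minimal sketch could look like the following; the dict-based layout is an assumption:

def _init_log_vars():
    # Keys match the accesses made in the training loop above.
    return {
        "num_games": 0,
        "ep_reward": 0.,
        "ep_rewards": [],
        "actions": [],
    }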