Example #1
import argparse
import json
import logging

# Project-specific imports assumed from the surrounding codebase (module paths are assumptions):
# from cityflow_env import CityFlowEnv
# from dqn_agent import DQNAgent
# from utility import parse_roadnet


def main():
    logging.getLogger().setLevel(logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='config/global_config.json')
    parser.add_argument('--num_step', type=int, default=2000)
    parser.add_argument('--ckpt', type=str, help='checkpoint file to load for inference')
    parser.add_argument('--algo', type=str, default='DQN', choices=['DQN', 'DDQN', 'DuelDQN'], help='choose an algorithm')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size (copied into the agent config below)')

    args = parser.parse_args()

    # preparing config
    # # for environment
    config = json.load(open(args.config))
    config["num_step"] = args.num_step
    cityflow_config = json.load(open(config['cityflow_config_file']))
    roadnetFile = cityflow_config['dir'] + cityflow_config['roadnetFile']
    config["lane_phase_info"] = parse_roadnet(roadnetFile)

    # # for agent
    intersection_id = "intersection_1_1"
    config["intersection_id"] = intersection_id
    config["state_size"] = len(config['lane_phase_info'][intersection_id]['start_lane']) + 1  # 1 is for the current phase. [vehicle_count for each start lane] + [current_phase]
    phase_list = config['lane_phase_info'][intersection_id]['phase']
    config["action_size"] = len(phase_list)
    config["batch_size"] = args.batch_size
    
    logging.info(phase_list)

    # build cityflow environment
    env = CityFlowEnv(config)

    # build agent
    agent = DQNAgent(config)
    
    # inference
    agent.load(args.ckpt)
    env.reset()
    state = env.get_state()
    
    for i in range(args.num_step): 
        action = agent.choose_action(state) # index of action
        action_phase = phase_list[action] # actual action
        next_state, reward = env.step(action_phase) # one step

        state = next_state

        # logging
        logging.info("step:{}/{}, action:{}, reward:{}"
                        .format(i, args.num_step, action, reward))
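This inference loop only relies on the agent exposing load() and choose_action(). A minimal sketch of such an agent is shown below; it is an illustration built on tf.keras, not the project's actual DQNAgent, and the layer sizes, optimizer, and weight format are assumptions.

import numpy as np
import tensorflow as tf

class MinimalDQNAgent:
    """Illustrative stand-in for the DQNAgent used above, not the original implementation."""
    def __init__(self, config):
        self.state_size = config["state_size"]
        self.action_size = config["action_size"]
        # Small fully connected Q-network; sizes are assumptions.
        self.model = tf.keras.Sequential([
            tf.keras.layers.Dense(32, activation='relu', input_shape=(self.state_size,)),
            tf.keras.layers.Dense(32, activation='relu'),
            tf.keras.layers.Dense(self.action_size, activation='linear'),
        ])
        self.model.compile(optimizer='adam', loss='mse')

    def load(self, ckpt_path):
        # Assumes the checkpoint stores Keras weights (e.g. an .h5 file).
        self.model.load_weights(ckpt_path)

    def choose_action(self, state):
        # Greedy action: index of the largest predicted Q-value.
        q_values = self.model.predict(np.asarray(state).reshape(1, -1), verbose=0)
        return int(np.argmax(q_values[0]))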
Example #2
replay_start_size = 2000
n_neurons = [32, 32]
render_delay = None
activations = ['relu', 'relu', 'linear']
filepath = "tetris-nn_4-8.h5"

agent = DQNAgent(env.get_action_space(),
                 n_neurons=n_neurons,
                 activations=activations,
                 epsilon=0,
                 epsilon_stop_episode=epsilon_stop_episode,
                 mem_size=mem_size,
                 discount=discount,
                 replay_start_size=replay_start_size)

agent.load(filepath)

scores = []

for episode in range(episodes):
    current_state = env.reset()
    done = False
    steps = 0

    # Game
    while not done and (not max_steps or steps < max_steps):
        next_states = env.get_next_states()
        # print(next_states)
        best_state = agent.best_state(next_states.values())
        # print(best_state)
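The excerpt stops right after best_state is chosen. A plausible continuation is sketched below: it maps the chosen successor state back to its action and steps the environment. env.play() and env.get_game_score() are assumed names, not taken from this excerpt.

        # next_states is assumed to map each action to its successor state,
        # since its .values() are passed to best_state() above.
        best_action = None
        for action, state in next_states.items():
            if state == best_state:
                best_action = action
                break

        # Hypothetical step call: assumed to return the reward and a done flag.
        reward, done = env.play(best_action, render=True, render_delay=render_delay)
        current_state = best_state
        steps += 1

    scores.append(env.get_game_score())  # assumed score accessor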
Example #3
# Comment out this line to enable training on your GPU (setting CUDA_VISIBLE_DEVICES to -1 hides all GPUs)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

# Initiating the Mountain Car environment
env = gym.make('MountainCar-v0')
state_size = env.observation_space.shape[0]
action_size = env.action_space.n

# Creating the DQN agent (with greedy policy, suited for evaluation)
agent = DQNAgent(state_size, action_size, epsilon=0.0, epsilon_min=0.0)

# Checking if weights from previous learning session exists
if os.path.exists('mountain_car.h5'):
    print('Loading weights from previous learning session.')
    agent.load("mountain_car.h5")
else:
    print('No weights found from previous learning session. Unable to proceed.')
    exit(-1)
return_history = []

for episodes in range(1, NUM_EPISODES + 1):
    # Reset the environment
    state = env.reset()
    # This reshape is needed to keep compatibility with Keras
    state = np.reshape(state, [1, state_size])
    # Cumulative reward is the return since the beginning of the episode
    cumulative_reward = 0.0
    for time in range(1, 500):
        # Render the environment for visualization
        env.render()
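The excerpt ends inside the inner loop. A hedged continuation of the evaluation step is sketched below; agent.act() is an assumed method name (the agent class is not shown here), while env.step() follows the classic Gym API.

        # Greedy action from the loaded network (method name assumed).
        action = agent.act(state)
        next_state, reward, done, _ = env.step(action)
        state = np.reshape(next_state, [1, state_size])
        cumulative_reward += reward
        if done:
            break
    return_history.append(cumulative_reward)
    print('episode: {}, return: {:.2f}'.format(episodes, cumulative_reward))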
Example #4
class DQNTrainer:
    def __init__(self,
                 level_filepath,
                 episodes=30000,
                 initial_epsilon=1.,
                 min_epsilon=0.1,
                 exploration_ratio=0.5,
                 max_steps=2000,
                 render_freq=500,
                 enable_render=True,
                 render_fps=20,
                 save_dir='checkpoints',
                 enable_save=True,
                 save_freq=500,
                 gamma=0.99,
                 batch_size=64,
                 min_replay_memory_size=1000,
                 replay_memory_size=100000,
                 target_update_freq=5,
                 seed=42):
        self.set_random_seed(seed)

        self.episodes = episodes
        self.max_steps = max_steps
        self.epsilon = initial_epsilon
        self.min_epsilon = min_epsilon
        self.exploration_ratio = exploration_ratio
        self.render_freq = render_freq
        self.enable_render = enable_render
        self.render_fps = render_fps
        self.save_dir = save_dir
        self.enable_save = enable_save
        self.save_freq = save_freq

        if enable_save and not os.path.exists(save_dir):
            os.makedirs(save_dir)

        level_loader = LevelLoader(level_filepath)

        self.agent = DQNAgent(level_loader.get_field_size(),
                              gamma=gamma,
                              batch_size=batch_size,
                              min_replay_memory_size=min_replay_memory_size,
                              replay_memory_size=replay_memory_size,
                              target_update_freq=target_update_freq)
        self.env = Snake(level_loader)
        self.summary = Summary()
        self.current_episode = 0
        self.max_average_length = 0

        self.epsilon_decay = (initial_epsilon -
                              min_epsilon) / (exploration_ratio * episodes)

    def set_random_seed(self, seed):
        random.seed(seed)
        np.random.seed(seed)
        os.environ['PYTHONHASHSEED'] = str(seed)
        tf.set_random_seed(seed)

    def train(self):
        pbar = tqdm(initial=self.current_episode,
                    total=self.episodes,
                    unit='episodes')
        while self.current_episode < self.episodes:
            current_state = self.env.reset()

            done = False
            steps = 0
            while not done and steps < self.max_steps:
                if random.random() > self.epsilon:
                    action = np.argmax(
                        self.agent.get_q_values(np.array([current_state])))
                else:
                    action = np.random.randint(NUM_ACTIONS)

                next_state, reward, done = self.env.step(action)

                self.agent.update_replay_memory(current_state, action, reward,
                                                next_state, done)
                self.summary.add('loss', self.agent.train())

                current_state = next_state
                steps += 1

            self.agent.increase_target_update_counter()

            self.summary.add('length', self.env.get_length())
            self.summary.add('reward', self.env.tot_reward)
            self.summary.add('steps', steps)

            # decay epsilon
            self.epsilon = max(self.epsilon - self.epsilon_decay,
                               self.min_epsilon)

            self.current_episode += 1

            # save model, training info
            if self.enable_save and self.current_episode % self.save_freq == 0:
                self.save(str(self.current_episode))

                average_length = self.summary.get_average('length')
                if average_length > self.max_average_length:
                    self.max_average_length = average_length
                    self.save('best')
                    print('best model saved - average_length: {}'.format(
                        average_length))

                self.summary.write(self.current_episode, self.epsilon)
                self.summary.clear()

            # update pbar
            pbar.update(1)

            # preview
            if self.enable_render and self.current_episode % self.render_freq == 0:
                self.preview(self.render_fps)

    def preview(self, render_fps, disable_exploration=False, save_dir=None):
        if save_dir is not None and not os.path.exists(save_dir):
            os.makedirs(save_dir)

        current_state = self.env.reset()

        self.env.render(fps=render_fps)
        if save_dir is not None:
            self.env.save_image(save_path=save_dir + '/0.png')

        done = False
        steps = 0
        while not done and steps < self.max_steps:
            if disable_exploration or random.random() > self.epsilon:
                action = np.argmax(
                    self.agent.get_q_values(np.array([current_state])))
            else:
                action = np.random.randint(NUM_ACTIONS)

            next_state, reward, done = self.env.step(action)
            current_state = next_state
            steps += 1

            self.env.render(fps=render_fps)
            if save_dir is not None:
                self.env.save_image(save_path=save_dir +
                                    '/{}.png'.format(steps))

        return self.env.get_length()

    def quit(self):
        self.env.quit()

    def save(self, suffix):
        self.agent.save(self.save_dir + '/model_{}.h5'.format(suffix),
                        self.save_dir + '/target_model_{}.h5'.format(suffix))

        dic = {
            'replay_memory': self.agent.replay_memory,
            'target_update_counter': self.agent.target_update_counter,
            'current_episode': self.current_episode,
            'epsilon': self.epsilon,
            'summary': self.summary,
            'max_average_length': self.max_average_length
        }

        with open(self.save_dir + '/training_info_{}.pkl'.format(suffix),
                  'wb') as fout:
            pickle.dump(dic, fout)

    def load(self, suffix):
        self.agent.load(self.save_dir + '/model_{}.h5'.format(suffix),
                        self.save_dir + '/target_model_{}.h5'.format(suffix))

        with open(self.save_dir + '/training_info_{}.pkl'.format(suffix),
                  'rb') as fin:
            dic = pickle.load(fin)

        self.agent.replay_memory = dic['replay_memory']
        self.agent.target_update_counter = dic['target_update_counter']
        self.current_episode = dic['current_episode']
        self.epsilon = dic['epsilon']
        self.summary = dic['summary']
        self.max_average_length = dic['max_average_length']
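A short usage sketch for the trainer above; the level file path is illustrative and the surrounding module layout is an assumption.

if __name__ == '__main__':
    trainer = DQNTrainer('levels/example_level.json',  # hypothetical level file
                         episodes=30000,
                         enable_render=False,
                         save_dir='checkpoints')
    # trainer.load('500')  # optionally resume from a saved checkpoint suffix
    trainer.train()
    trainer.quit()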
Example #5
    env = GameEnv(400)
    state_size = env.OBSERVATION_SPACE_VALUES
    action_size = env.ACTION_SPACE_SIZE
    agent = DQNAgent(state_size, action_size)
    done = False
    batch_size = 64

    best_score = -1
    render = False

    if PLAY:
        play_game(agent, False)
    else:
        if LOAD_MODEL:
            print("Loading Model...")
            agent.load(load_file_name)
            agent.epsilon = 1.0
            agent.learning_rate = 0.001
            agent.epsilon_decay = 0.990
            agent.gamma = 0.95

        for phase in range(7, 8):
            env.number_of_grids = phase + 3
            agent.epsilon = 1.0
            agent.memory = deque(maxlen=2000)
            phase_scores = deque(maxlen=5)
            for e in range(EPISODES):
                done = False
                state = env.reset()
                # env.seed(0)
                state = np.reshape(state, [1, state_size])
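The episode loop is cut off after the state reshape. A hedged continuation, assuming the same act/remember/replay interface seen in the other examples, might look like this:

                while not done:
                    action = agent.act(state)                    # assumed action-selection method
                    next_state, reward, done = env.step(action)  # assumed 3-tuple return for this custom env
                    next_state = np.reshape(next_state, [1, state_size])
                    agent.remember(state, action, reward, next_state, done)
                    state = next_state
                    if len(agent.memory) > batch_size:
                        agent.replay(batch_size)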
Example #6
def main():
    logging.getLogger().setLevel(logging.INFO)
    date = datetime.now().strftime('%Y%m%d_%H%M%S')
    parser = argparse.ArgumentParser()
    # parser.add_argument('--scenario', type=str, default='PongNoFrameskip-v4')
    parser.add_argument('--config',
                        type=str,
                        default='config/global_config.json',
                        help='config file')
    parser.add_argument('--algo',
                        type=str,
                        default='DQN',
                        choices=['DQN', 'DDQN', 'DuelDQN'],
                        help='choose an algorithm')
    parser.add_argument('--inference',
                        action="store_true",
                        help='inference or training')
    parser.add_argument('--ckpt', type=str, help='checkpoint path to load for inference')
    parser.add_argument('--epoch',
                        type=int,
                        default=10,
                        help='number of training epochs')
    parser.add_argument(
        '--num_step',
        type=int,
        default=200,
        help='number of timesteps for one episode, and for inference')
    parser.add_argument('--save_freq',
                        type=int,
                        default=1,
                        help='model saving frequency')
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help='batch size for training')
    parser.add_argument('--phase_step',
                        type=int,
                        default=15,
                        help='seconds of one phase')

    args = parser.parse_args()

    # preparing config
    # # for environment
    config = json.load(open(args.config))
    config["num_step"] = args.num_step

    assert "1x1" in config[
        'cityflow_config_file'], "please use 1x1 config file for cityflow"

    # config["replay_data_path"] = "replay"
    cityflow_config = json.load(open(config['cityflow_config_file']))
    roadnetFile = cityflow_config['dir'] + cityflow_config['roadnetFile']
    config["lane_phase_info"] = parse_roadnet(roadnetFile)

    # # for agent
    intersection_id = list(config['lane_phase_info'].keys())[0]
    config["intersection_id"] = intersection_id

    phase_list = config['lane_phase_info'][config["intersection_id"]]['phase']
    config["action_size"] = len(phase_list)
    config["batch_size"] = args.batch_size

    logging.info(phase_list)

    model_dir = "model/{}_{}".format(args.algo, date)
    result_dir = "result/{}_{}".format(args.algo, date)
    config["result_dir"] = result_dir

    # parameters for training and inference
    # batch_size = 32
    EPISODES = args.epoch
    learning_start = 300
    # update_model_freq = args.batch_size
    update_model_freq = 1
    update_target_model_freq = 10

    if not args.inference:
        # build cityflow environment
        cityflow_config["saveReplay"] = True
        json.dump(cityflow_config, open(config["cityflow_config_file"], 'w'))
        env = CityFlowEnv(
            lane_phase_info=config["lane_phase_info"],
            intersection_id=config["intersection_id"],  # for single agent
            num_step=args.num_step,
            cityflow_config_file=config["cityflow_config_file"])

        # build agent
        config["state_size"] = env.state_size
        if args.algo == 'DQN':
            agent = DQNAgent(intersection_id,
                             state_size=config["state_size"],
                             action_size=config["action_size"],
                             batch_size=config["batch_size"],
                             phase_list=phase_list,
                             env=env)

        elif args.algo == 'DDQN':
            agent = DDQNAgent(config)
        elif args.algo == 'DuelDQN':
            agent = DuelingDQNAgent(config)

        # make dirs
        if not os.path.exists("model"):
            os.makedirs("model")
        if not os.path.exists("result"):
            os.makedirs("result")
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        if not os.path.exists(result_dir):
            os.makedirs(result_dir)

        # training
        total_step = 0
        episode_rewards = []
        episode_scores = []
        with tqdm(total=EPISODES * args.num_step) as pbar:
            for i in range(EPISODES):
                # print("episode: {}".format(i))
                env.reset()
                state = env.get_state()

                episode_length = 0
                episode_reward = 0
                episode_score = 0
                while episode_length < args.num_step:

                    action = agent.choose_action_(state)  # index of action
                    action_phase = phase_list[action]  # actual action
                    # no yellow light
                    next_state, reward = env.step(action_phase)  # one step
                    # last_action_phase = action_phase
                    episode_length += 1
                    total_step += 1
                    episode_reward += reward
                    episode_score += env.get_score()

                    for _ in range(args.phase_step - 1):
                        next_state, reward_ = env.step(action_phase)
                        reward += reward_

                    reward /= args.phase_step

                    pbar.update(1)
                    # store to replay buffer
                    if episode_length > learning_start:
                        agent.remember(state, action_phase, reward, next_state)

                    state = next_state

                    # training
                    if episode_length > learning_start and total_step % update_model_freq == 0:
                        if len(agent.memory) > args.batch_size:
                            agent.replay()

                    # update target Q network
                    if episode_length > learning_start and total_step % update_target_model_freq == 0:
                        agent.update_target_network()

                    # logging
                    # logging.info("\repisode:{}/{}, total_step:{}, action:{}, reward:{}"
                    #             .format(i+1, EPISODES, total_step, action, reward))
                    pbar.set_description(
                        "total_step:{}, episode:{}, episode_step:{}, reward:{}"
                        .format(total_step, i + 1, episode_length, reward))

                # save episode rewards
                episode_rewards.append(
                    episode_reward /
                    args.num_step)  # record episode mean reward
                episode_scores.append(episode_score)
                print("score: {}, mean reward:{}".format(
                    episode_score, episode_reward / args.num_step))

                # save model
                if (i + 1) % args.save_freq == 0:
                    if args.algo != 'DuelDQN':
                        agent.model.save(model_dir +
                                         "/{}-{}.h5".format(args.algo, i + 1))
                    else:
                        agent.save(model_dir + "/{}-ckpt".format(args.algo),
                                   i + 1)

                    # save reward to file
                    df = pd.DataFrame({"rewards": episode_rewards})
                    df.to_csv(result_dir + '/rewards.csv', index=None)

                    df = pd.DataFrame({"rewards": episode_scores})
                    df.to_csv(result_dir + '/scores.csv', index=None)

                    # save figure
                    plot_data_lists([episode_rewards], ['episode reward'],
                                    figure_name=result_dir + '/rewards.pdf')
                    plot_data_lists([episode_scores], ['episode score'],
                                    figure_name=result_dir + '/scores.pdf')

    else:
        # inference
        cityflow_config["saveReplay"] = True
        json.dump(cityflow_config, open(config["cityflow_config_file"], 'w'))
        env = CityFlowEnv(
            lane_phase_info=config["lane_phase_info"],
            intersection_id=config["intersection_id"],  # for single agent
            num_step=args.num_step,
            cityflow_config_file=config["cityflow_config_file"])
        env.reset()

        # build agent
        config["state_size"] = env.state_size
        if args.algo == 'DQN':
            agent = DQNAgent(intersection_id,
                             state_size=config["state_size"],
                             action_size=config["action_size"],
                             batch_size=config["batch_size"],
                             phase_list=phase_list,
                             env=env)

        elif args.algo == 'DDQN':
            agent = DDQNAgent(config)
        elif args.algo == 'DuelDQN':
            agent = DuelingDQNAgent(config)
        agent.load(args.ckpt)

        state = env.get_state()
        scores = []
        for i in range(args.num_step):
            action = agent.choose_action(state)  # index of action
            action_phase = phase_list[action]  # actual action
            next_state, reward = env.step(action_phase)  # one step

            for _ in range(args.phase_step - 1):
                next_state, reward_ = env.step(action_phase)
                reward += reward_

            reward /= args.phase_step

            score = env.get_score()
            scores.append(score)
            state = next_state

            # logging
            logging.info("step:{}/{}, action:{}, reward:{}, score:{}".format(
                i + 1, args.num_step, action, reward, score))

        inf_result_dir = "result/" + args.ckpt.split("/")[1]
        df = pd.DataFrame({"inf_scores": scores})
        df.to_csv(inf_result_dir + '/inf_scores.csv', index=None)
        plot_data_lists([scores], ['inference scores'],
                        figure_name=inf_result_dir + '/inf_scores.pdf')
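For reference, the argparse flags above imply invocations along these lines; the script name, date stamp, and checkpoint file are illustrative, not taken from the source.

# Training run (saves models under model/DQN_<date>/ and results under result/DQN_<date>/):
#   python run_rl.py --algo DQN --epoch 10 --num_step 200 --phase_step 15 --batch_size 64
# Inference run, loading a saved checkpoint:
#   python run_rl.py --algo DQN --inference --ckpt model/DQN_20200101_000000/DQN-10.h5 --num_step 200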
Example #7
def main(model=None, mode='train', start_episode=0):
    my_xml = '''<?xml version="1.0" encoding="UTF-8" standalone="no" ?>
    <Mission xmlns="http://ProjectMalmo.microsoft.com" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
      <About>
        <Summary>Hill Descent.</Summary>
      </About>
      <ModSettings>
        <MsPerTick>20</MsPerTick>
      </ModSettings>
      <ServerSection>

        <ServerInitialConditions>

            <Time><StartTime>1</StartTime></Time>
        </ServerInitialConditions>
        <ServerHandlers>

          <DefaultWorldGenerator seed="-999595225643433963" forceReset="false" destroyAfterUse="false" />

          <ServerQuitFromTimeUp timeLimitMs="100000000"/>
          <ServerQuitWhenAnyAgentFinishes/>
        </ServerHandlers>
      </ServerSection>
      <AgentSection mode="Survival">
        <Name>Bob</Name>
        <AgentStart>
          <Placement x="28.5" y="87" z="330.5" pitch="-90" yaw="0"/>
        </AgentStart>
        <AgentHandlers>
          <DiscreteMovementCommands/>
          <MissionQuitCommands quitDescription="done"/>
          <ChatCommands/>
          <ObservationFromFullStats/>
          <ObservationFromGrid>
              <Grid name="sight">
                  <min x="{}" y="{}" z="{}"/>
                  <max x="{}" y="{}" z="{}"/>
              </Grid>
              <Grid name="feet">
                  <min x="0" y="-1" z="0"/>
                  <max x="0" y="-1" z="0"/>
              </Grid>
          </ObservationFromGrid>
          <AgentQuitFromTouchingBlockType>
              <Block type="cobblestone" />
          </AgentQuitFromTouchingBlockType>
        </AgentHandlers>
      </AgentSection>
    </Mission>

    '''.format(-(grid_width - 1) // 2, -grid_height, -(grid_width - 1) // 2,
               (grid_width - 1) // 2, grid_height, (grid_width - 1) // 2)

    batch_size = 100
    agent = DQNAgent(state_size, action_size, learning_rate, discount_rate,
                     epsilon, epsilon_min, epsilon_decay)
    if model != None:
        agent.load(model)
        if mode == 'test':
            agent.epsilon = 0.0
        print('loaded model: {}'.format(model))
    else:
        clear_csv('./data/results.csv')
        clear_csv('./data/moves.csv')

    my_client_pool = MalmoPython.ClientPool()
    my_client_pool.add(MalmoPython.ClientInfo("127.0.0.1", 10001))
    agent_host = MalmoPython.AgentHost()

    for e in range(start_episode + 1, episodes + 1):
        my_mission = MalmoPython.MissionSpec(my_xml, True)
        my_mission_record = MalmoPython.MissionRecordSpec()
        my_mission.requestVideo(800, 500)
        my_mission.setViewpoint(2)
        print("Waiting for the mission to start", end=' ')
        agent_host.startMission(
            my_mission,
            my_mission_record,
        )
        world_state = agent_host.getWorldState()
        while not world_state.has_mission_begun:
            print(".", end="")
            time.sleep(0.1)
            world_state = agent_host.getWorldState()
            for error in world_state.errors:
                print("Error:", error.text)
        print()
        agent_host.sendCommand('chat /kill @e[type=Chicken]')
        agent_host.sendCommand('chat /kill @e[type=Pig]')
        agent_host.sendCommand('chat /kill @e[type=Cow]')
        moves = 0
        episode_reward = 0

        while world_state.is_mission_running:
            world_state = agent_host.getWorldState()
            if world_state.number_of_observations_since_last_state > 0:
                try:
                    obvsText = world_state.observations[-1].text
                    data = json.loads(obvsText)
                except:
                    print("Error when getting state")
                    continue

                state = get_state(data)

                prev_x = data.get(u'XPos', 0)
                prev_y = data.get(u'YPos', 0)
                prev_z = data.get(u'ZPos', 0)

                useful_state = [state[2], state[6], state[7], state[8], \
                    state[10], state[11], state[13], \
                    state[14], state[16], state[17], \
                    state[18], state[22]]

                action = agent.act(useful_state)

                if ((action == 0 and state[grid_center - grid_width] == 0)
                        or (action == 1 and state[grid_center + 1] == 0) or
                    (action == 2 and state[grid_center + grid_width] == 0)
                        or (action == 3 and state[grid_center - 1] == 0)):
                    agent_host.sendCommand(jump_directions[action])
                else:
                    agent_host.sendCommand(directions[action])
                time.sleep(0.25)
                #print("North:", state[grid_center - grid_width], \
                #      "  East:", state[grid_center + 1], \
                #      "  South:", state[grid_center + grid_width], \
                #      "  West:", state[grid_center - 1])

                try:
                    world_state = wait_world_state(agent_host, world_state)
                    obvsText = world_state.observations[-1].text
                    data = json.loads(obvsText)
                except:
                    print("Error when getting state")
                    continue

                current_x = data.get(u'XPos', 0)
                current_y = data.get(u'YPos', 0)
                current_z = data.get(u'ZPos', 0)
                damage_taken = calculate_damage(prev_y, current_y)
                next_state = get_state(data)

                # Select the same feature subset from the next observation.
                useful_next_state = [next_state[2], next_state[6], next_state[7], next_state[8],
                                     next_state[10], next_state[11], next_state[13],
                                     next_state[14], next_state[16], next_state[17],
                                     next_state[18], next_state[22]]

                # print("previous and current y", prev_y, current_y)
                # print("damage taken", damage_taken)
                #print("X:", prev_x, current_x, "\n", \
                #      "Y:", prev_y, current_y, "\n", \
                #      "Z:", prev_z, current_z, "\n")
                # Reward descending (drop in y), penalize fall damage and each step;
                # heavy penalty if the agent did not move at all.
                moved = (prev_x != current_x or prev_y != current_y
                         or prev_z != current_z)
                reward = (2 * (prev_y - current_y) -
                          50 * damage_taken - 1) if moved else -1000
                episode_reward += reward
                done = (current_y <= goal_height
                        or not world_state.is_mission_running
                        or data['Life'] <= 0)

                agent.remember(useful_state, action, reward, useful_next_state,
                               done)
                if ((action == 0 and state[grid_center - grid_width] == 0)
                        or (action == 1 and state[grid_center + 1] == 0) or
                    (action == 2 and state[grid_center + grid_width] == 0)
                        or (action == 3 and state[grid_center - 1] == 0)):
                    print(
                        'episode {}/{}, action: {}, reward: {}, e: {:.2}, move: {}, done: {}'
                        .format(e, episodes, jump_directions[action], reward,
                                agent.epsilon, moves, done))
                else:
                    print(
                        'episode {}/{}, action: {}, reward: {}, e: {:.2}, move: {}, done: {}'
                        .format(e, episodes, directions[action], reward,
                                agent.epsilon, moves, done))
                moves += 1

                if mode == 'train' or model == None:
                    write_to_csv('./data/moves.csv',
                                 [e, current_x, current_y, current_z, reward])

                if e > batch_size:
                    agent.replay(batch_size)

                if done or moves > max_moves:
                    agent_host.sendCommand("quit")

        if (mode == 'train'
                or model == None) and (e in checkpoints
                                       or agent.epsilon <= epsilon_min):
            print('saving model at episode {}'.format(e))
            agent.save('./models/model_{}'.format(e))
            if agent.epsilon <= epsilon_min:
                break

        time.sleep(1)
        # my_mission.forceWorldReset()
        if mode == 'train' or model == None:
            write_to_csv('./data/results.csv',
                         [e, episode_reward, moves,
                          int(episode_reward > 0)])
    print "agent score average: %s" % (sum(tot_rewards) /
                                       float(len(tot_rewards)))


if __name__ == '__main__':
    # Parse the arguments
    args = parse_args()

    # set up weights dir
    set_up_weights_dir(args.weights_dir)

    # Train and play or just play
    if args.learning:
        print "#############\nlearning...with following args: %s\n#############" % vars(
            args)
        trained_agent = training(**vars(args))
    else:
        trained_agent = DQNAgent(environment=env,
                                 action_space=[0, 1, 2, 3, 4, 5, 6, 7],
                                 epsilon=0,
                                 NN_arch=args.NN_arch)
        if args.load:
            trained_agent.load(args.load)
        else:
            LOGGER.error(
                'No weights file specified! Please specify a weight file when '
                'running in playing mode. '
                'E.g: --load weights_dir/weights_file.txt')
            exit(1)
    play(environment=env, agent=trained_agent, quiet=args.quiet)