Esempio n. 1
0
def train(agent, args, env, saver):
    """
    Train the agent for ``args.num_episodes`` episodes.

    Before the first episode ``agent.initialize_memory(args.pretrain, env)``
    is called to pre-fill the replay buffer. Each episode then runs for at
    most ``args.max_steps`` steps and stops early as soon as any entry of
    ``dones`` is truthy. After every episode ``saver.save_checkpoint`` is
    called with ``args.save_every``. After the last episode
    ``saver.save_final(agent)`` and ``logger.graph()`` are called.

    The environment is always closed, including when pre-filling or training
    raises; in that case the exception propagates and neither the final save
    nor the graph step runs.

    Args:
        agent: object providing ``initialize_memory``, ``act``, ``step`` and
            ``new_episode``.
        args: settings object; ``pretrain``, ``num_episodes``, ``max_steps``
            and ``save_every`` are read here.
        env: environment providing ``reset``, ``states``, ``step`` and
            ``close``.
        saver: object providing ``save_dir``, ``save_checkpoint`` and
            ``save_final``.
    """
    logger = Logger(agent, args, saver.save_dir)

    try:
        # Pre-fill the Replay Buffer
        agent.initialize_memory(args.pretrain, env)

        for episode in range(1, args.num_episodes + 1):
            # Begin each episode with a clean environment and its first states
            env.reset()
            states = env.states

            # Gather experience until done or max_steps is reached
            for _ in range(args.max_steps):
                actions = agent.act(states)
                next_states, rewards, dones = env.step(actions)
                agent.step(states, actions, rewards, next_states)
                states = next_states

                logger.log(rewards, agent)
                if np.any(dones):
                    break

            saver.save_checkpoint(agent, args.save_every)
            agent.new_episode()
            logger.step(episode, agent)
    finally:
        # Release the environment even if training fails part-way through.
        env.close()

    saver.save_final(agent)
    logger.graph()
Esempio n. 2
0
def eval(agent, args, env):
    """
    Run the agent in evaluation mode for ``args.num_episodes`` episodes.

    Actions are requested with ``agent.act(states, eval=True)`` and the agent
    is never given experience to learn from. Each episode runs for at most
    ``args.max_steps`` steps and stops early as soon as any entry of
    ``dones`` is truthy.

    This function does not load weights itself; presumably the caller has
    already restored the saved weights into ``agent`` (confirm against the
    caller).

    The environment is always closed, including when evaluation raises.

    NOTE: the name shadows the builtin ``eval``; it is kept because callers
    reference it.

    Args:
        agent: object providing ``act`` and ``new_episode``.
        args: settings object; ``num_episodes`` and ``max_steps`` are read
            here.
        env: environment providing ``reset``, ``states``, ``step`` and
            ``close``.
    """
    logger = Logger(agent, args)

    try:
        for episode in range(1, args.num_episodes + 1):
            # Begin each episode with a clean environment and its first states
            env.reset()
            states = env.states

            # Step the environment until done or max_steps is reached
            for _ in range(args.max_steps):
                actions = agent.act(states, eval=True)
                states, rewards, dones = env.step(actions)

                logger.log(rewards, agent)
                if np.any(dones):
                    break

            agent.new_episode()
            logger.step(episode)
    finally:
        # Release the environment even if evaluation fails part-way through.
        env.close()
Esempio n. 3
0
def train(agent, args, env, saver):
    """
    Train the agent for ``args.num_episodes`` episodes.

    Before the first episode ``agent.initialize_memory(args.pretrain, env)``
    is called to pre-fill the replay buffer. Each episode runs until
    ``env.step`` reports ``done``; there is no step cap in this loop. The
    terminal transition is handed to ``agent.step`` with ``next_state`` set
    to ``None``. After every episode ``saver.save_checkpoint`` is called with
    ``args.save_every``. After the last episode ``saver.save_final(agent)``
    and ``logger.graph()`` are called.

    The environment is always closed, including when pre-filling or training
    raises; in that case the exception propagates and neither the final save
    nor the graph step runs.

    Args:
        agent: object providing ``initialize_memory``, ``act``, ``step``,
            ``new_episode`` and an ``epsilon`` attribute.
        args: settings object; ``pretrain``, ``num_episodes`` and
            ``save_every`` are read here.
        env: environment providing ``reset``, ``state``, ``step`` and
            ``close``.
        saver: object providing ``save_dir``, ``save_checkpoint`` and
            ``save_final``.
    """
    logger = Logger(agent, args, saver.save_dir, log_every=50)

    try:
        # Pre-fill the Replay Buffer
        agent.initialize_memory(args.pretrain, env)

        for episode in range(1, args.num_episodes + 1):
            # Begin each episode with a clean environment and its first state
            done = False
            env.reset()
            state = env.state

            # Gather experience until the environment signals the episode end
            while not done:
                action = agent.act(state)
                next_state, reward, done = env.step(action)
                if done:
                    # Terminal transitions are stored without a successor state.
                    next_state = None
                agent.step(state, action, reward, next_state)
                state = next_state

                logger.log(reward, agent)

            saver.save_checkpoint(agent, args.save_every)
            agent.new_episode()
            logger.step(episode, agent.epsilon)
    finally:
        # Release the environment even if training fails part-way through.
        env.close()

    saver.save_final(agent)
    logger.graph()
Esempio n. 4
0
def notebook_eval_agent(args, env, filename, num_eps=2):
    """
    Build a fresh ``DQN_Agent`` and run it in ``env`` for ``num_eps`` episodes.

    Each episode runs for at most ``args.max_steps`` steps and stops early
    when ``env.step`` reports ``done``. The environment is not closed here,
    so the caller can keep using it.

    Side effect: ``args.eval`` is set to ``True`` on the object passed in.

    Args:
        args: settings object; ``save_dir`` and ``max_steps`` are read here
            and ``eval`` is overwritten.
        env: environment providing ``state_size``, ``action_size``,
            ``reset``, ``state`` and ``step``.
        filename: passed through to ``Saver`` together with ``args.save_dir``.
        num_eps: number of episodes to run. The loop was previously fixed at
            three episodes regardless of this argument.
    """
    eval_agent = DQN_Agent(env.state_size, env.action_size, args)
    # The saver is not referenced again, but it is still constructed:
    # presumably its constructor restores the weights from `filename` into
    # the agent. TODO confirm against Saver.
    eval_saver = Saver(eval_agent.framework, eval_agent, args.save_dir,
                       filename)
    args.eval = True
    logger = Logger(eval_agent, args)

    for episode in range(num_eps):
        env.reset()
        state = env.state
        for _ in range(args.max_steps):
            action = eval_agent.act(state)
            state, reward, done = env.step(action)
            logger.log(reward, eval_agent)
            if done:
                break
        eval_agent.new_episode()
        logger.step(episode)
Esempio n. 5
0
import pyodbc
import time
import os
import ssl

# Take a single timestamp so every date-derived name below agrees, even when
# the module is loaded right around midnight.
t = dt2.now()
_date_stamp = t.strftime('%b-%d')

# All data, log and database paths are resolved relative to this file.
root = os.path.dirname(__file__)

# Paths are built from components so the separator matches the host OS.
ss_filename = f"{_date_stamp}.xlsx"
spreadsheet_path = os.path.join(root, "data", "ae_output", ss_filename)

log_name = f"{_date_stamp}.log"
log_path = os.path.join(root, "logs", log_name)
db_path = os.path.join(root, "data", "info.db")

fh_logger = Logger(log_path, "FH")


class MainFh:
    def fh_main(self):
        if not isfile(db_path):
            fh_message("Failed to find sqlite DB", critical=True)
            return

        from sftp import Sftp
        fh_ftp = Sftp(host="sftp.placeholder.com",
                      user="******",
                      passwd="pass",
                      remote_directory="/daily/",
                      local_directory=join(root, "data\\to_process"),
                      fh=True,
Esempio n. 6
0
import pyodbc
import time
import os
import ssl
import sys

# All data, log and database paths are resolved relative to this file.
root = os.path.dirname(__file__)

# Format the date once so every date-derived name below agrees, even when the
# module is loaded right around midnight.
_date_stamp = dt2.now().strftime('%b-%d')

# Paths are built from components so the separator matches the host OS.
ss_filename = f"{_date_stamp}.xlsx"
spreadsheet_path = os.path.join(root, "data", "ae_output", ss_filename)

log_name = f"{_date_stamp}.log"
log_path = os.path.join(root, "logs", log_name)
db_path = os.path.join(root, "data", "info.db")

np_logger = Logger(log_path, 'NP')


class MainNP:
    def np_main(self):
        db_error_count = 0
        if not isfile(db_path):
            db_error_count = db_error_count + 1
            if db_error_count > 0:
                sys.exit(10)
            np_message("Failed to find sqlite DB", critical=True)

        from sftp import Sftp
        np_ftp = Sftp(host="sftp.placeholder.com", user="******", passwd="pass", remote_directory="/daily/",
                      local_directory=join(root, "data\\to_process"), fh=False, message=np_message)