Ejemplo n.º 1
0
    def __init__(self, args, id=33):
        """
        Initialize an environment wrapper around a Unity Reacher build.

        Params
        ======
        args: run options; only ``args.eval`` is read in this method
        id: worker id forwarded to ``UnityEnvironment``
        """

        # Training mode is the inverse of the eval flag.
        self.train = not args.eval

        # NOTE(review): a disabled block used to choose the executable from
        # platform.system(). The build path below is fixed instead, and
        # graphics are always off (args.nographics is not consulted).
        reacher_build = '/Users/parksoy/Desktop/deep-reinforcement-learning/p2_continuous-control/Reacher_multi.app'
        self.env = UnityEnvironment(file_name=reacher_build,
                                    worker_id=id,
                                    no_graphics=True)
        print_bracketing(do_upper=False)

        # Work with the first brain the environment exposes.
        first_brain = self.env.brain_names[0]
        self.brain_name = first_brain
        self.brain = self.env.brains[first_brain]

        # Environment resets itself when the class is instantiated.
        # NOTE(review): reset() is presumably what fills in self.states and
        # self.env_info, which are read just below -- confirm in reset().
        self.reset()

        self.action_size = self.brain.vector_action_space_size
        self.state_size = self.states.shape[1]
        self.agent_count = len(self.env_info.agents)
    def _collect_params(self, args, agent):
        """
        Creates a list of all the Params used to run this training instance,
        prints this list to the command line if QUIET is not flagged, and stores
        it for later saving to the params log in the /logs/ directory.

        Params
        ======
        args: object whose instance attributes are the run options
        agent: object whose instance attributes are merged on top of args;
               leading underscores are stripped from attribute names

        Returns
        =======
        A list of "name: value" strings, one per retained parameter.
        """

        # Agent attributes are applied second, so they win on name clashes.
        param_dict = {key: getattr(args, key) for key in vars(args)}
        for key in vars(agent):
            param_dict[key.lstrip('_')] = getattr(agent, key)

        # Bookkeeping and display options that are dropped from the log.
        excluded = ('nographics', 'save_every', 'print_every', 'verbose',
                    'quiet', 'latest', 'avg_score', 'episode', 't_step')
        for key in excluded:
            param_dict.pop(key, None)

        # Only one of C / tau is kept, depending on the update type. When no
        # update type was supplied there is nothing to choose, so both stay
        # rather than raising KeyError.
        if 'update_type' in param_dict:
            unused = 'C' if param_dict['update_type'] == 'soft' else 'tau'
            param_dict.pop(unused, None)

        param_list = [
            "{}: {}".format(key, value) for (key, value) in param_dict.items()
        ]

        # Honor the quiet flag as the docstring promises. An instance that
        # never set `quietmode` keeps the previous always-print behavior.
        if not getattr(self, 'quietmode', False):
            print_bracketing(param_list)

        return param_list
    def __init__(self, agent=None, args=None, save_dir='.'):
        """
        Initialize a Logger object.

        Params
        ======
        agent: object providing ``framework`` and ``agent_count``
        args: run options providing ``eval``, ``num_episodes``, ``quiet``,
              ``log_every`` and ``print_every``
        save_dir: directory under which a ``logs`` folder is used

        When either agent or args is omitted, a notice is printed and the
        object is left without any attributes set.
        """

        # Identity comparison: an agent/args object with a permissive __eq__
        # must not be mistaken for a missing argument.
        if agent is None or args is None:
            print("Blank init for Logger object.")
            return
        self.eval = args.eval
        self.framework = agent.framework
        self.max_eps = args.num_episodes
        self.quietmode = args.quiet
        self.log_every = args.log_every
        self.print_every = args.print_every
        self.agent_count = agent.agent_count
        self.save_dir = save_dir
        # Forward slashes are used regardless of the platform separator.
        self.log_dir = os.path.join(self.save_dir, 'logs').replace('\\', '/')
        # The last component of the save directory is the log base name.
        self.filename = os.path.basename(self.save_dir)
        self.start_time = self.prev_timestamp = time.time()
        self.scores = []
        self._reset_rewards()

        # Log files are only created for training runs.
        if not self.eval:

            timestamp = time.strftime("%H:%M:%S", time.localtime())
            statement = "Starting training at: {}".format(timestamp)
            print_bracketing(statement)

            check_dir(self.log_dir)
            self._init_logs(self._collect_params(args, agent))
    def _init_logs(self, params):
        """
        Create the run's log files and write the parameter list to one of them.

        Params
        ======
        params: list of strings, written one per line to the "_LOG.txt" file

        Four files are opened in write mode (created or truncated) inside
        self.log_dir. Only the params file receives content here; the loss and
        score files start out empty.
        """

        stem = os.path.join(self.log_dir, self.filename)
        self.paramfile = stem + "_LOG.txt"
        self.alossfile = stem + "_actorloss.txt"
        self.clossfile = stem + "_criticloss.txt"
        self.scoresfile = stem + "_scores.txt"

        report = ["Logfiles saved to: {}".format(self.log_dir)]
        targets = (self.paramfile, self.alossfile, self.clossfile,
                   self.scoresfile)
        for path in targets:
            # Opening with 'w' is enough to leave a blank file behind.
            with open(path, 'w') as handle:
                if path.endswith("_LOG.txt"):
                    handle.writelines(line + '\n' for line in params)
            report.append("...{}".format(os.path.basename(path)))
        print_bracketing(report)
Ejemplo n.º 5
0
    def __init__(self, args, id=10):
        """
        Initialize an environment wrapper.

        Params
        ======
        args: run options providing ``eval``, ``pixels`` and ``nographics``
        id: worker id forwarded to ``UnityEnvironment``

        Raises
        ======
        RuntimeError: when running on macOS, for which no build is configured
        """

        self.train = not args.eval
        self.pixels = args.pixels

        system = platform.system()
        print("LOADING ON SYSTEM: {}".format(system))

        print_bracketing(do_lower=False)
        if system == 'Linux':
            # Linux always uses the NoVis build, whatever args.pixels says.
            unity_filename = "Banana_Linux_NoVis/Banana.x86"
        elif system == 'Darwin':
            # Fail explicitly here: previously execution continued and
            # crashed on the unassigned unity_filename a few lines later.
            print("MacOS not supported in this code!")
            raise RuntimeError("MacOS not supported in this code!")
        elif self.pixels:
            unity_filename = "Banana_Windows_x86_64_Visual/Banana.exe"
        else:
            unity_filename = "Banana_Windows_x86_64/Banana.exe"
        self.env = UnityEnvironment(file_name=unity_filename,
                                    worker_id=id,
                                    no_graphics=args.nographics)
        print_bracketing(do_upper=False)

        # Work with the first brain the environment exposes.
        self.brain_name = self.env.brain_names[0]
        self.brain = self.env.brains[self.brain_name]

        # Environment resets itself when the class is instantiated.
        # NOTE(review): reset() is presumably what fills in self.state and
        # self.env_info, which are read just below -- confirm in reset().
        self.reset()

        self.action_size = self.brain.vector_action_space_size
        self.state_size = self.state.shape[1]
        self.agent_count = len(self.env_info.agents)
Ejemplo n.º 6
0
 def _load_agent(self, load_file, agent):
     """
     Restore actor and critic weights from a saved checkpoint file.

     Params
     ======
     load_file: path of the checkpoint, read with torch.load
     agent: object exposing actor, critic, actor_target, critic_target and
            a _hard_update(source, target) method
     """

     # The map_location callable hands each storage back unchanged, so
     # nothing is relocated to the device recorded in the file.
     saved = torch.load(load_file,
                        map_location=lambda storage, loc: storage)

     # Load both live networks first, then sync each target from its
     # freshly loaded counterpart.
     for name in ('actor', 'critic'):
         getattr(agent, name).load_state_dict(saved[name + '_dict'])
     for name in ('actor', 'critic'):
         agent._hard_update(getattr(agent, name),
                            getattr(agent, name + '_target'))

     print_bracketing("Successfully loaded file: {}".format(load_file))
Ejemplo n.º 7
0
    def _load_agent(self, load_file, agent):
        """
        Restore Q-network weights from a saved checkpoint file.

        Params
        ======
        load_file: path of the checkpoint, read with torch.load
        agent: object exposing q, q_target and a
               _hard_update(source, target) method
        """

        # The map_location callable hands each storage back unchanged, so
        # nothing is relocated to the device recorded in the file.
        saved = torch.load(load_file,
                           map_location=lambda storage, loc: storage)

        # Load the live network, then sync the target network from it.
        agent.q.load_state_dict(saved['q_dict'])
        agent._hard_update(agent.q, agent.q_target)

        print_bracketing("Successfully loaded file: {}".format(load_file))
Ejemplo n.º 8
0
 def __init__(self,
              prefix,
              agent,
              save_dir='saves',
              load_file=None,
              file_ext=".agent"):
     """
     Set up save naming and optionally restore a previously saved agent.

     Params
     ======
     prefix: forwarded to generate_savename to build the base filename
     agent: agent to restore into when load_file is given
     save_dir: forwarded to generate_savename
     load_file: checkpoint to load; when falsy, nothing is loaded and the
                base filename is announced instead
     file_ext: extension stored on the instance as file_ext
     """
     self.file_ext = file_ext
     directory, base_name = self.generate_savename(prefix, save_dir)
     self.save_dir = directory
     self.filename = base_name

     if not load_file:
         print_bracketing(
             "Saving to base filename: {}".format(self.filename))
         return
     self._load_agent(load_file, agent)
Ejemplo n.º 9
0
    def _collect_params(self, args, agent):
        """
        Build the list of formatted parameters for this run.

        Attributes of args that are not also attributes of agent come first,
        followed by every attribute of agent; each entry is produced by
        self._format_param. The list is printed unless quiet mode is set, and
        is returned either way.
        """

        agent_attrs = vars(agent)
        param_list = []

        # Options that the agent also carries are reported from the agent.
        for name in vars(args):
            if name in agent_attrs:
                continue
            param_list.append(self._format_param(name, args))

        for name in agent_attrs:
            param_list.append(self._format_param(name, agent))

        if self.quietmode:
            return param_list
        print_bracketing(param_list)
        return param_list
def _get_filepath(files):
    """
    Prompts the user about what save to load.

    The entries of `files` are listed with descending numbers, so the final
    entry is shown as number 1 and is the one labelled "(LATEST)". The user
    is asked again until the input is a listed number or a quit command.

    Params
    ======
    files: indexable sequence of file paths to choose from

    Returns
    =======
    The element of `files` matching the number the user typed.

    Raises
    ======
    KeyboardInterrupt: when the user answers "q" or "quit" (any case)
    """

    load_file_prompt = " (LATEST)\n\nPlease choose a saved Agent training file (or: q/quit): "
    user_quit_message = "User quit process before loading a file."
    total = len(files)
    listing = '\n'.join(
        "{}. {}".format(total - i, file) for i, file in enumerate(files))
    message = listing.replace('\\', '/') + load_file_prompt

    # Loop rather than recurse, so repeated bad input cannot exhaust the
    # call stack.
    while True:
        save_file = input(message)
        if save_file.lower() in ("q", "quit"):
            raise KeyboardInterrupt(user_quit_message)
        try:
            file_index = total - int(save_file)
        except ValueError:
            file_index = -1
        # Explicit range check instead of `assert`: assertions are removed
        # under -O, which would let a negative index pick the wrong file.
        if 0 <= file_index < total:
            return files[file_index]
        print_bracketing('Input "{}" is INVALID...'.format(save_file))