コード例 #1
0
ファイル: agent.py プロジェクト: sharif1093/digideep
    def __init__(self, session, memory, **params):
        """Build the agent: policy, actor/critic optimizers, sampler and exploration noise.

        Args:
            session: Session object; only ``get_device()`` is used here.
            memory: Memory container, forwarded unchanged to the base class.
            **params: Agent parameters. Keys read here: ``policyargs``,
                ``optimname_actor``/``optimargs_actor``,
                ``optimname_critic``/``optimargs_critic``, ``sampler_list``,
                ``noisename``/``noiseargs``.
        """
        super(Agent, self).__init__(session, memory, **params)
        p = self.params

        self.device = self.session.get_device()

        # The policy class is fixed to the module-level ``Policy`` instead of
        # being resolved from ``params["policyname"]``.
        self.policy = Policy(device=self.device, **p["policyargs"])

        # One optimizer per sub-network (+ schedulers if any). Both classes are
        # resolved before either optimizer is constructed.
        actor_optim_cls = get_class(p["optimname_actor"])
        critic_optim_cls = get_class(p["optimname_critic"])
        self.optimizer = {
            "actor": actor_optim_cls(self.policy.model["actor"].parameters(),
                                     **p["optimargs_actor"]),
            "critic": critic_optim_cls(self.policy.model["critic"].parameters(),
                                       **p["optimargs_critic"]),
        }

        # The sampler is a composition of the classes named in ``sampler_list``.
        self.sampler = Compose([get_class(name) for name in p["sampler_list"]])

        # Exploration noise, resolved by class name and built from its kwargs.
        self.noise = get_class(p["noisename"])(**p["noiseargs"])

        self.state["i_step"] = 0
コード例 #2
0
ファイル: agent.py プロジェクト: sharif1093/digideep
    def __init__(self, session, memory, **params):
        """Build the agent's policy and its single optimizer.

        The policy class is resolved by name from ``params["policyname"]`` and
        built with ``params["policyargs"]`` on the session's device. The
        optimizer class is resolved from ``params["optimname"]`` and built over
        all parameters of ``self.policy.model`` with ``params["optimargs"]``.
        """
        super(Agent, self).__init__(session, memory, **params)
        settings = self.params
        self.device = self.session.get_device()

        # Model: resolve the policy class dynamically and instantiate it.
        policy_cls = get_class(settings["policyname"])
        self.policy = policy_cls(device=self.device, **settings["policyargs"])

        # Optimizer (+ schedulers if any) over the whole policy model.
        optim_cls = get_class(settings["optimname"])
        model_params = self.policy.model.parameters()
        self.optimizer = optim_cls(model_params, **settings["optimargs"])

        self.state["i_step"] = 0
コード例 #3
0
    def run_wrapper_stack(self, env, stack):
        """Wrap ``env`` with every enabled wrapper in ``stack``, in order.

        Each entry of ``stack`` is a dict with:
            * ``"enabled"``: whether this wrapper is applied at all;
            * ``"name"``: dotted class path, resolved with ``get_class``;
            * ``"args"``: keyword arguments for the wrapper (optional, defaults to none);
            * ``"request_for_args"`` (optional): extra runtime arguments the wrapper
              needs. Only ``"session_state"`` is supported; it injects
              ``self.session.state`` when a session is attached.

        Every wrapper is built as ``wrapper_class(env, mode=self.mode, **args)``,
        so wrappers can adjust themselves to the current mode.

        Args:
            env: The environment to wrap.
            stack: List of wrapper specifications (see above). It is not modified.

        Returns:
            The environment wrapped by all enabled wrappers; the last enabled
            entry ends up outermost.

        Raises:
            SystemExit: with status 1 if an unsupported argument is requested.
        """
        for entry in stack:
            if not entry["enabled"]:
                continue
            wrapper_class = get_class(entry["name"])
            # Work on a copy so injected runtime objects (e.g. the session state)
            # do not leak back into the caller's parameter dictionaries.
            wrapper_args = dict(entry.get("args", {}))
            for rfa in entry.get("request_for_args", []):
                logger("  Adding argument {} to the wrapper {}".format(
                    rfa, entry["name"]))
                if rfa == "session_state":
                    # Without a session the argument is silently not provided.
                    if self.session:
                        wrapper_args["session_state"] = self.session.state
                # TODO: Move the "mode" to optional parameter that can be requested!
                else:
                    logger.fatal("  Argument {} not found!".format(rfa))
                    # ``exit()`` comes from the ``site`` module (interactive use
                    # only) and would exit with status 0; signal failure instead.
                    raise SystemExit(1)

            env = wrapper_class(env, mode=self.mode, **wrapper_args)
        return env
コード例 #4
0
    def instantiate(self):
        """Instantiate the memories, the agents and the explorers from ``self.params``.

        Build order matters: memories first (agents receive them), then agents
        (explorers receive them), then explorers.

        Keys read from ``self.params``:
            * ``params["memory"][name]``: ``{"type": <dotted class path>, "args": {...}}``;
              built as ``cls(self.session, mode=name, **args)``.
            * ``params["agents"][name]``: dict with a ``"type"`` class path; the
              whole dict is forwarded as keyword arguments to that class.
            * ``params["explorer"][mode]``: keyword arguments for ``Explorer``.
              Modes ``"train"``, ``"test"`` and ``"eval"`` are all required.

        Sets ``self.memory``, ``self.agents`` and ``self.explorer`` (dicts keyed
        by name/mode). The ``"eval"`` explorer is only created when
        ``self.session.is_playing`` is true.

        Raises:
            ValueError: if one of the required explorer modes is missing.
        """
        explorer_params = self.params["explorer"]
        # Validate up front, and without ``assert`` (stripped by ``python -O``),
        # so a bad configuration fails before any memory/agent is built.
        for mode in ("train", "test", "eval"):
            if mode not in explorer_params:
                raise ValueError(
                    "'{}' mode explorer is not defined in the explorer parameters.".format(mode))

        ## Instantiate Memory
        self.memory = {}
        for memory_name, memory_params in self.params["memory"].items():
            memory_class = get_class(memory_params["type"])
            self.memory[memory_name] = memory_class(
                self.session,
                mode=memory_name,
                **memory_params["args"])

        ## Instantiate Agents
        self.agents = {}
        for agent_name, agent_params in self.params["agents"].items():
            agent_class = get_class(agent_params["type"])
            self.agents[agent_name] = agent_class(
                self.session, self.memory, **agent_params)

        ## Instantiate Explorers (every mode except "eval", handled below)
        self.explorer = {}
        for mode in [m for m in explorer_params if m != "eval"]:
            self.explorer[mode] = Explorer(self.session,
                                           agents=self.agents,
                                           **explorer_params[mode])

        # "eval" must be created as the last explorer to avoid GLFW connection to X11 issues.
        # It is only created when the session is playing, to make sure that no
        # connections to X11 exist in the main thread otherwise.
        if self.session.is_playing:
            self.explorer["eval"] = Explorer(self.session,
                                             agents=self.agents,
                                             **explorer_params["eval"])
コード例 #5
0
ファイル: agent_sched.py プロジェクト: sharif1093/dextron
    def __init__(self, session, memory, **params):
        """Build the scheduled agent.

        The agent consists of a policy, one optimizer per sub-network
        (``value``, ``softq``, ``actor``), MSE criteria for ``value`` and
        ``softq``, a composed sampler and a ``Scheduler``.

        Keys read from ``params``:
            * ``policyargs``;
            * ``optimname_<part>`` / ``optimargs_<part>`` for each part;
            * ``sampler_list``;
            * ``sampler_args`` with ``scheduler_start``, ``scheduler_steps``
              and ``scheduler_decay``.
        """
        super(AgentSchedule, self).__init__(session, memory, **params)
        cfg = self.params

        self.device = self.session.get_device()

        # The policy class is fixed to the module-level ``Policy``.
        self.policy = Policy(device=self.device, **cfg["policyargs"])

        # Resolve every optimizer class first, then build one optimizer
        # (+ schedulers if any) per sub-network, in the order value, softq, actor.
        parts = ("value", "softq", "actor")
        optim_classes = {part: get_class(cfg["optimname_" + part]) for part in parts}
        self.optimizer = {
            part: optim_classes[part](self.policy.model[part].parameters(),
                                      **cfg["optimargs_" + part])
            for part in parts
        }

        self.criterion = {"value": nn.MSELoss(), "softq": nn.MSELoss()}

        # The sampler is a composition of the classes named in ``sampler_list``.
        self.sampler = Compose([get_class(name) for name in cfg["sampler_list"]])

        self.state["i_step"] = 0

        sched_cfg = cfg["sampler_args"]
        self.scheduler = Scheduler(sched_cfg["scheduler_start"],
                                   sched_cfg["scheduler_steps"],
                                   sched_cfg["scheduler_decay"])
コード例 #6
0
ファイル: myrunner.py プロジェクト: sharif1093/dextron
 def instantiate(self):
     """Instantiate one memory object per entry of ``self.params["memory"]``.

     Each entry provides ``"type"`` (a dotted class path resolved with
     ``get_class``) and ``"args"`` (constructor keyword arguments). The memory
     is built as ``cls(self.session, mode=<entry name>, **args)`` and stored in
     ``self.memory`` under the entry's name.
     """
     # NOTE(review): the original docstring also mentions explorers and agents,
     # which this body does not build; confirm whether this excerpt is truncated.
     memories = {}
     self.memory = memories
     for name, spec in self.params["memory"].items():
         memories[name] = get_class(spec["type"])(self.session, mode=name, **spec["args"])
コード例 #7
0
ファイル: agent.py プロジェクト: sharif1093/digideep
    def __init__(self, session, memory, **params):
        """Build the agent: policy, twin-critic and actor optimizers, and sampler.

        Keys read from ``params``:
            * ``policyargs``;
            * ``optimname_critic`` / ``optimargs_critic`` (shared by both critics);
            * ``optimname_actor`` / ``optimargs_actor``;
            * ``sampler_list``.
        """
        super(Agent, self).__init__(session, memory, **params)
        hp = self.params

        self.device = self.session.get_device()

        # policy_type: Gaussian | Deterministic. Only "Gaussian" for now.
        # The policy class is fixed to the module-level ``Policy``.
        self.policy = Policy(device=self.device, **hp["policyargs"])

        # Both critics use the same optimizer class and arguments (+ schedulers if any).
        critic_optim = get_class(hp["optimname_critic"])
        actor_optim = get_class(hp["optimname_actor"])
        self.optimizer = {
            "critic1": critic_optim(self.policy.model["critic1"].parameters(), **hp["optimargs_critic"]),
            "critic2": critic_optim(self.policy.model["critic2"].parameters(), **hp["optimargs_critic"]),
            "actor": actor_optim(self.policy.model["actor"].parameters(), **hp["optimargs_actor"]),
        }

        # No criteria are registered here.
        self.criterion = {}

        # The sampler is a composition of the classes named in ``sampler_list``.
        self.sampler = Compose([get_class(name) for name in hp["sampler_list"]])

        # TODO: automatic entropy tuning (target entropy = -dim(A), learned
        # log_alpha) and exploration noise are not implemented yet.

        self.state["i_step"] = 0
コード例 #8
0
ファイル: plotter.py プロジェクト: sharif1093/digideep
    # Count the total number of sessions across all groups; args.session_names
    # is iterated as a list of groups, each group being a list of session names.
    count = 0
    for ss in args.session_names:
        for s in ss:
            count += 1

    # With no explicit --output-dir and exactly one session, output_dir is None.
    # NOTE(review): how the plotter treats None (e.g. a default location) is not
    # visible here -- confirm in the plotter class.
    if args.output_dir == '' and count == 1:
        output_dir = None
    else:
        # If --output-dir is relative then --root-dir is prefixed to it.
        # (os.path.join drops root_dir when output_dir is absolute; an empty
        # output_dir with several sessions resolves to root_dir itself.)
        output_dir = os.path.join(args.root_dir, args.output_dir)

    # Change the PYTHONPATH to load the saved modules for more compatibility.
    # Putting root_dir first on sys.path is what lets get_class() below import
    # the "<session_name>.loader" modules located under root_dir.
    sys.path.insert(0, args.root_dir)

    ## Get the loaders from SaaM
    # loaders mirrors the grouping of args.session_names: one list of loader
    # objects per group.
    loaders = []
    for sublist in args.session_names:
        subloaders = []
        for s in sublist:
            subloaders += [get_class(s + "." + "loader")]
        loaders += [subloaders]

    #######################################
    ##          Actual Plotting          ##
    #######################################
    # Do the plotting: resolve the plotter class from its (possibly aliased)
    # type name, then plot keyy against keyx.
    plotter_class = get_class(type_aliases(args.type))
    pc = plotter_class(loaders, output_dir, **args.opts)
    pc.plot(keyx=args.keyx, keyy=args.keyy)
コード例 #9
0
ファイル: main.py プロジェクト: sharif1093/digideep
def _log_session_summary(params):
    """Log the session name, message and command line inside a framed header.

    The closing "-" rule is logged here; callers may add more sections after it.
    """
    logger.warn("="*50)
    logger.warn("Session:", params["session_name"])
    logger.warn("Message:", params["session_msg"])
    logger.warn("Command:\n\n$", params["session_cmd"], "\n")
    logger.warn("-"*50)


def main(session):
    """Create or load the session's runner, save it, then run it.

    Loading mode (``session.is_loading``):
        Params come from ``session.update_params({})`` and the runner is
        restored with ``session.load_runner()``.

    Fresh mode:
        Params are generated from the cpanel of the module named by
        ``session.args["params"]``, with ``session.args["cpanel"]`` applied via
        ``strict_update``. They are stored in the session, and the runner class
        named by ``params["runner"]["name"]`` is instantiated.

    In both modes the runner is saved with ``session.save_runner(runner, 0)``.
    In session-only mode the function returns right after saving. Otherwise the
    runner is started and runs ``enjoy()``, ``custom()`` or ``train()``
    depending on the session mode.
    """
    ##########################################
    ### LOOPING ###
    ###############
    # 1. Loading
    if session.is_loading:
        params = session.update_params({})
        _log_session_summary(params)
        runner = session.load_runner()
    else:
        ##########################################
        ### LOAD FRESH PARAMETERS ###
        #############################
        # Import method-specific modules
        ParamEngine = get_module(session.args["params"])
        cpanel = strict_update(ParamEngine.cpanel, session.args["cpanel"])
        params = ParamEngine.gen_params(cpanel)  ## Generate params from cpanel everytime

        # Storing parameters in the session.
        params = session.update_params(params)
        session.dump_cpanel(cpanel)
        session.dump_params(params)

        # Summary, followed by the cpanel hyper-parameters.
        _log_session_summary(params)
        logger.warn("Hyper-Parameters\n\n{}".format(json.dumps(cpanel, indent=4, sort_keys=False)))
        logger.warn("="*50)

        Runner = get_class(params["runner"]["name"])
        runner = Runner(params)

    # The runner is always saved; when only creating the session we stop here
    # and never start it.
    session.save_runner(runner, 0)
    if session.is_session_only:
        logger.fatal("Session was created; exiting ...")
        return

    # 2. Initializing: It will load_state_dicts if we are in loading mode
    runner.start(session)

    # 3. Train/Enjoy/Custom Loops
    if session.is_playing:
        runner.enjoy()
    elif session.is_customs:
        runner.custom()
    else:
        runner.train()