Пример #1
0
    def add_new_iteration_strategy_model(self, owner, adv_net_state_dict,
                                         cfr_iter):
        """Turn a network snapshot into an IterationStrategy and keep it.

        The new strategy is added to ``self._strategy_buffers`` for its owner.
        Depending on flags in ``self._t_prof``, it is additionally pickled to
        disk (``export_each_net``) and, for ``owner == 1`` on verbose runs
        (``log_verbose``), this process's resident memory (RSS) is logged.

        Args:
            owner: seat the new strategy belongs to.
            adv_net_state_dict: network weights; converted with
                ``self._ray.state_dict_to_torch`` onto
                ``self._t_prof.device_inference`` before being loaded.
            cfr_iter: CFR iteration this snapshot belongs to.
        """
        t_prof = self._t_prof
        device = t_prof.device_inference

        strategy = IterationStrategy(t_prof=t_prof,
                                     env_bldr=self._env_bldr,
                                     owner=owner,
                                     device=device,
                                     cfr_iter=cfr_iter)
        torch_weights = self._ray.state_dict_to_torch(adv_net_state_dict,
                                                      device=device)
        strategy.load_net_state_dict(torch_weights)
        self._strategy_buffers[strategy.owner].add(iteration_strat=strategy)

        # Optional on-disk export of every strategy net, one pickle per
        # (iteration, owner) pair, written via file_util.do_pickle.
        if t_prof.export_each_net:
            export_dir = ospj(t_prof.path_strategy_nets, t_prof.name)
            file_util.create_dir_if_not_exist(export_dir)
            snapshot = strategy.state_dict()
            pickle_name = "".join((str(strategy.cfr_iteration), "_P",
                                   str(strategy.owner), ".pkl"))
            file_util.do_pickle(obj=snapshot,
                                path=export_dir,
                                file_name=pickle_name)

        # Memory logging is only done for owner 1.
        if t_prof.log_verbose and owner == 1:
            rss = psutil.Process(os.getpid()).memory_info().rss
            self.add_scalar(self._exp_mem_usage,
                            "Debug/Memory Usage/Chief", cfr_iter, rss)
Пример #2
0
    def load_state_dict(self, state):
        """Replace the stored strategies with the ones described by ``state``.

        Args:
            state: dict with ``"owner"`` (must equal ``self.owner``) and
                ``"nets"``, an iterable of ``(net_state_dict, cfr_iter)``
                pairs. Each pair becomes an IterationStrategy on
                ``self._device``; its entry in ``self._weights`` is its
                ``cfr_iter``.

        Raises:
            AssertionError: if ``state["owner"]`` differs from ``self.owner``.
        """
        assert self.owner == state["owner"], (
            "owner mismatch: state has {}, expected {}".format(
                state["owner"], self.owner))

        # Rebuild _strategies and _weights together so they stay
        # index-aligned (previously only _strategies was reset, so _weights
        # kept growing across loads). Building into locals and committing at
        # the end also leaves the old contents intact if a load fails midway.
        strategies = []
        weights = []
        for net_state_dict, cfr_iter in state["nets"]:
            s = IterationStrategy(t_prof=self._t_prof,
                                  owner=self.owner,
                                  env_bldr=self._env_bldr,
                                  device=self._device,
                                  cfr_iter=cfr_iter)
            s.load_net_state_dict(net_state_dict)
            strategies.append(s)
            weights.append(cfr_iter)

        self._strategies = strategies
        self._weights = weights
        self._size = len(strategies)
Пример #3
0
    def update_weights(self, weights_for_eval_agent):
        """Install freshly received network weights into this eval agent.

        Args:
            weights_for_eval_agent: indexed by eval-mode key.
                When ``self._AVRG``, ``[self.EVAL_MODE_AVRG_NET]`` is read as
                one state dict per seat. When ``self._SINGLE``,
                ``[self.EVAL_MODE_SINGLE]`` is read as, per seat, a list of
                IterationStrategy state dicts, each carrying a ``"net"`` entry.
        """
        # ---- Deep CFR: reload each seat's average-strategy policy ----
        if self._AVRG:
            avrg_state_dicts = weights_for_eval_agent[self.EVAL_MODE_AVRG_NET]
            for seat in range(self.t_prof.n_seats):
                seat_policy = self.avrg_net_policies[seat]
                torch_weights = self.ray.state_dict_to_torch(
                    avrg_state_dicts[seat], device=self.device)
                seat_policy.load_net_state_dict(torch_weights)
                seat_policy.eval()

        # ---- SD-CFR: add each new iteration strategy to its seat's buffer ----
        if not self._SINGLE:
            return

        # Deep copy: the "net" entries are replaced in place below, and the
        # caller's dicts must not be mutated.
        strat_dicts_per_seat = copy.deepcopy(
            weights_for_eval_agent[self.EVAL_MODE_SINGLE])

        for seat in range(self.t_prof.n_seats):
            for strat_dict in strat_dicts_per_seat[seat]:
                strat_dict["net"] = self.ray.state_dict_to_torch(
                    strat_dict["net"], device=self.device)
                new_strat = IterationStrategy.build_from_state_dict(
                    state=strat_dict,
                    t_prof=self.t_prof,
                    env_bldr=self.env_bldr,
                    device=self.device)
                self._strategy_buffers[seat].add(iteration_strat=new_strat)
Пример #4
0
    def generate_data(self, traverser, cfr_iter):
        """Generate ``n_traversals_per_iter`` traversals for ``traverser``.

        One IterationStrategy per seat is built, loaded with the weights from
        ``self._adv_wrappers[owner].net_state_dict()``, and handed to
        ``self._data_sampler.generate``. On verbose runs, when
        ``traverser == 1`` and ``cfr_iter`` is a multiple of 3, buffer sizes
        and this process's resident memory are sent to
        ``self._chief_handle.add_scalar`` via ``self._ray.remote``.

        Args:
            traverser: seat passed to the data sampler as the traverser.
            cfr_iter: current CFR iteration.
        """
        t_prof = self._t_prof

        # Construct every seat's strategy first, then load weights into each.
        strategies = []
        for seat in range(t_prof.n_seats):
            strategies.append(
                IterationStrategy(t_prof=t_prof,
                                  env_bldr=self._env_bldr,
                                  owner=seat,
                                  device=t_prof.device_inference,
                                  cfr_iter=cfr_iter))
        for strategy in strategies:
            net_weights = self._adv_wrappers[strategy.owner].net_state_dict()
            strategy.load_net_state_dict(state_dict=net_weights)

        self._data_sampler.generate(
            n_traversals=t_prof.n_traversals_per_iter,
            traverser=traverser,
            iteration_strats=strategies,
            cfr_iter=cfr_iter)

        # Log only after both players generated data (traverser 1 goes last).
        should_log = (t_prof.log_verbose and traverser == 1
                      and cfr_iter % 3 == 0)
        if not should_log:
            return

        for seat in range(t_prof.n_seats):
            self._ray.remote(self._chief_handle.add_scalar,
                             self._exps_adv_buffer_size[seat],
                             "Debug/BufferSize", cfr_iter,
                             self._adv_buffers[seat].size)
            if self._AVRG:
                self._ray.remote(self._chief_handle.add_scalar,
                                 self._exps_avrg_buffer_size[seat],
                                 "Debug/BufferSize", cfr_iter,
                                 self._avrg_buffers[seat].size)

        rss = psutil.Process(os.getpid()).memory_info().rss
        self._ray.remote(self._chief_handle.add_scalar,
                         self._exp_mem_usage, "Debug/MemoryUsage/LA",
                         cfr_iter, rss)
Пример #5
0
    def generate_data(self, traverser, cfr_iter):
        """Generate traversal data for ``traverser`` with an iteration-dependent budget.

        One IterationStrategy per seat is built, loaded with the weights from
        ``self._adv_wrappers[owner].net_state_dict()``, and handed to
        ``self._data_sampler.generate``. With
        ``base = self._t_prof.n_traversals_per_iter`` the traversal count is:

        * ``cfr_iter < 20``: ``base``
        * ``20 <= cfr_iter <= 100``: linear ramp from ``base`` to ``3 * base``,
          rounded to an int
        * ``cfr_iter > 100``: ``3 * base``

        On verbose runs, when ``traverser == 1`` and ``cfr_iter`` is a
        multiple of 3, buffer sizes and this process's resident memory are
        sent to ``self._chief_handle.add_scalar`` via ``self._ray.remote``.

        Args:
            traverser: seat passed to the data sampler as the traverser.
            cfr_iter: current CFR iteration; drives the traversal schedule
                and the logging cadence.
        """
        # Traversal schedule constants.
        # TODO: consider moving these into t_prof so they are configurable.
        ramp_start_iter = 20
        ramp_end_iter = 100
        peak_factor = 3

        iteration_strats = [
            IterationStrategy(t_prof=self._t_prof,
                              env_bldr=self._env_bldr,
                              owner=p,
                              device=self._t_prof.device_inference,
                              cfr_iter=cfr_iter)
            for p in range(self._t_prof.n_seats)
        ]
        for s in iteration_strats:
            s.load_net_state_dict(
                state_dict=self._adv_wrappers[s.owner].net_state_dict())

        # Hold at base, ramp linearly to peak, then hold at peak. An
        # if/elif/else guarantees n_traversals is always bound.
        base = self._t_prof.n_traversals_per_iter
        peak = base * peak_factor
        if cfr_iter < ramp_start_iter:
            n_traversals = base
        elif cfr_iter <= ramp_end_iter:
            # np.interp returns a float; round() picks the nearest integer
            # (ties go to the even neighbour).
            n_traversals = int(round(np.interp(
                cfr_iter, [ramp_start_iter, ramp_end_iter], [base, peak])))
        else:
            n_traversals = peak

        self._data_sampler.generate(
            n_traversals=n_traversals,
            traverser=traverser,
            iteration_strats=iteration_strats,
            cfr_iter=cfr_iter,
        )

        # Log after both players generated data
        if self._t_prof.log_verbose and traverser == 1 and (cfr_iter % 3 == 0):
            for p in range(self._t_prof.n_seats):
                self._ray.remote(self._chief_handle.add_scalar,
                                 self._exps_adv_buffer_size[p],
                                 "Debug/BufferSize", cfr_iter,
                                 self._adv_buffers[p].size)
                if self._AVRG:
                    self._ray.remote(self._chief_handle.add_scalar,
                                     self._exps_avrg_buffer_size[p],
                                     "Debug/BufferSize", cfr_iter,
                                     self._avrg_buffers[p].size)

            process = psutil.Process(os.getpid())
            self._ray.remote(self._chief_handle.add_scalar,
                             self._exp_mem_usage, "Debug/MemoryUsage/LA",
                             cfr_iter,
                             process.memory_info().rss)