Code example #1
    def __init__(
        self,
        evaluator,
        epsilon,
        env_gen,
        optim=None,
        memory_queue=None,
        memory_size=20000,
        mem_type=None,
        batch_size=64,
        gamma=0.99,
    ):
        """Initialise the agent from an evaluator network and an env factory.

        The evaluator is deep-copied twice: once as the trainable policy
        network and once as the (lagging) target network.  ``_epsilon``
        stores the original exploration rate so it can be restored after
        an evaluation phase.
        """
        # Exploration rate plus a backup used to restore it after evaluation.
        self.epsilon = epsilon
        self._epsilon = epsilon

        self.optim = optim
        self.memory_queue = memory_queue
        self.batch_size = batch_size
        self.gamma = gamma
        self.env = env_gen()

        # Independent copies so the target net can lag behind the policy net.
        self.policy_net = copy.deepcopy(evaluator)
        self.target_net = copy.deepcopy(evaluator)

        # Prioritised (sum-tree) replay when requested, plain buffer otherwise.
        memory_cls = WeightedMemory if mem_type == "sumtree" else Memory
        self.memory = memory_cls(memory_size)
Code example #2
    def __init__(
        self,
        env,
        evaluator,
        lr=0.01,
        gamma=0.99,
        momentum=0.9,
        weight_decay=0.01,
        mem_type="sumtree",
        buffer_size=20000,
        batch_size=16,
        *args,
        **kwargs
    ):
        """Build the trainer: clone the evaluator into policy/target nets,
        freshly initialise their weights, and attach an SGD optimizer and a
        replay buffer.

        ``state_size`` is derived from the board dimensions of *env*.
        Extra ``*args``/``**kwargs`` are accepted and ignored.
        """
        self.env = env
        self.gamma = gamma
        self.state_size = env.width * env.height

        # Two copies of the evaluator: one is trained, one is periodically synced.
        self.policy_net = copy.deepcopy(evaluator)
        self.target_net = copy.deepcopy(evaluator)
        for net in (self.policy_net, self.target_net):
            net.apply(init_weights)

        self.optim = torch.optim.SGD(
            self.policy_net.parameters(),
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
        )

        # Prioritised (sum-tree) replay by default, plain buffer otherwise.
        if mem_type == "sumtree":
            self.memory = WeightedMemory(buffer_size)
        else:
            self.memory = Memory(buffer_size)
        self.batch_size = batch_size
Code example #3
    def __init__(
        self,
        evaluator,
        env_gen,
        optim=None,
        memory_queue=None,
        iterations=100,
        temperature_cutoff=5,
        batch_size=64,
        memory_size=200000,
        min_memory=20000,
        update_nn=True,
        starting_state_dict=None,
    ):
        """Set up the MCTS agent around *evaluator* and an environment factory.

        Moves the evaluator to the global ``device``, creates the replay
        memory, optionally wraps model/optimizer with NVIDIA apex mixed
        precision (when ``APEX_AVAILABLE``), and loads a starting state
        dict when one is supplied.
        """
        self.iterations = iterations
        self.evaluator = evaluator.to(device)
        self.env_gen = env_gen
        self.optim = optim
        self.env = env_gen()
        self.root_node = None
        self.reset()
        self.update_nn = update_nn
        self.starting_state_dict = starting_state_dict

        self.memory_queue = memory_queue
        self.temp_memory = []
        self.memory = Memory(memory_size)
        self.min_memory = min_memory
        self.temperature_cutoff = temperature_cutoff
        self.actions = self.env.action_space.n

        self.evaluating = False

        self.batch_size = batch_size

        if APEX_AVAILABLE:
            # Wrap the model (and optimizer, when one was given) for
            # automatic mixed precision at the "O1" patch level.
            opt_level = "O1"

            if self.optim:
                self.evaluator, self.optim = amp.initialize(
                    evaluator, optim, opt_level=opt_level)
                print("updating optimizer and evaluator")
            else:
                self.evaluator = amp.initialize(evaluator, opt_level=opt_level)
                print(" updated evaluator")
            self.amp_state_dict = amp.state_dict()
            print(vars(amp._amp_state))
            # NOTE(review): the original code had a second `elif APEX_AVAILABLE:`
            # branch here; it duplicated the condition above and so was
            # unreachable dead code.  It has been removed — presumably it was
            # meant to test a different flag (e.g. an "apex installed but
            # unused" case); confirm against the module-level constants.

        if self.starting_state_dict:
            # Fixed typo in the original log message ("laoding").
            print("loading state dict in mcts")
            self.load_state_dict(self.starting_state_dict)
Code example #4
    def __init__(self, env=Connect4Env):
        """Open the persistent shelves that live beside this module.

        Shelf file names come from the class-level ``*_SAVE_FILE``
        constants; ``atexit`` ensures the shelves are closed (and flushed)
        when the interpreter exits.
        """
        base_dir = os.path.dirname(__file__)

        def _open_shelf(filename):
            # All shelves are stored in the same directory as this source file.
            return shelve.open(os.path.join(base_dir, filename))

        self.model_shelf = _open_shelf(self.MODEL_SAVE_FILE)
        self.result_shelf = _open_shelf(self.RESULT_SAVE_FILE)
        self.elo_value_shelf = _open_shelf(self.ELO_VALUE_SAVE_FILE)

        self.env = env
        self.memory = Memory()

        # Guarantee shelf cleanup on interpreter shutdown.
        atexit.register(self._close)