Example #1

import copy
import random

import numpy as np
import torch
import torch.nn.functional as F

# Helpers assumed to be provided by the surrounding project (not shown here):
# BaseModel, Memory, WeightedMemory, Transition, weighted_smooth_l1_loss,
# init_weights, and the global `device`.

class EpsilonGreedy(BaseModel):
    def __init__(
        self,
        evaluator,
        epsilon,
        env_gen,
        optim=None,
        memory_queue=None,
        memory_size=20000,
        mem_type=None,
        batch_size=64,
        gamma=0.99,
    ):
        self.epsilon = epsilon
        self._epsilon = epsilon  # backup of epsilon, restored after evaluation
        self.optim = optim
        self.env = env_gen()
        self.memory_queue = memory_queue
        self.batch_size = batch_size
        self.gamma = gamma

        self.policy_net = copy.deepcopy(evaluator)
        self.target_net = copy.deepcopy(evaluator)

        if mem_type == "sumtree":
            self.memory = WeightedMemory(memory_size)
        else:
            self.memory = Memory(memory_size)

    def __call__(self, s):
        return self._epsilon_greedy(s)

    def q(self, s):
        if not isinstance(s, torch.Tensor):
            s = torch.from_numpy(s).long()
        # s = self.policy_net.preprocess(s)
        return self.policy_net(s)  # predicted Q-values from the policy net only

    def _epsilon_greedy(self, s):
        if np.random.rand() < self.epsilon:
            possible_moves = [i for i, move in enumerate(self.env.valid_moves()) if move]
            a = random.choice(possible_moves)
        else:
            weights = self.q(s).detach().cpu().numpy()  # TODO maybe do this with tensors
            # Invalid moves get a very large negative offset so argmax never picks them.
            mask = (-1000000000 * ~np.array(self.env.valid_moves())) + 1
            a = np.argmax(weights + mask)
        return a

    def load_state_dict(self, state_dict, target=False):
        self.policy_net.load_state_dict(state_dict)
        if target:
            self.target_net.load_state_dict(state_dict)

    def update(self, s, a, r, done, next_s):
        self.push_to_queue(s, a, r, done, next_s)
        self.pull_from_queue()
        if self.ready:
            self.update_from_memory()

    def push_to_queue(self, s, a, r, done, next_s):
        s = torch.tensor(s, device=device)
        # s = self.policy_net.preprocess(s)
        a = torch.tensor(a, device=device)
        r = torch.tensor(r, device=device)
        done = torch.tensor(done, device=device)
        next_s = torch.tensor(next_s, device=device)
        # next_s = self.policy_net.preprocess(next_s)
        self.memory_queue.put(Transition(s, a, r, done, next_s))

    def pull_from_queue(self):
        while not self.memory_queue.empty():
            experience = self.memory_queue.get()
            self.memory.add(experience)

    def update_from_memory(self):
        if isinstance(self.memory, WeightedMemory):
            tree_idx, batch, sample_weights = self.memory.sample(self.batch_size)
            sample_weights = torch.tensor(sample_weights, device=device)
        else:
            batch = self.memory.sample(self.batch_size)
        batch_t = Transition(*zip(*batch))  # transposed batch
        s_batch, a_batch, r_batch, done_batch, s_next_batch = batch_t
        s_batch = torch.cat(s_batch)
        a_batch = torch.stack(a_batch)
        r_batch = torch.stack(r_batch).view(-1, 1)
        s_next_batch = torch.cat(s_next_batch)
        done_batch = torch.stack(done_batch).view(-1, 1)
        q = self._state_action_value(s_batch, a_batch)

        # Compute target Q values (double DQN: actions chosen by the policy net,
        # evaluated by the target net).
        double_actions = self.policy_net(s_next_batch).max(1)[1].detach()
        q_next = self.target_net(s_next_batch).gather(1, double_actions.view(-1, 1)).detach()

        q_next_actual = (~done_batch) * q_next  # zero the bootstrap term for terminal states
        q_target = r_batch + self.gamma * q_next_actual
        ### TEST: whether clamping the target helps (or is even good practice)
        q_target = q_target.clamp(-1, 1)
        ### /TEST

        if isinstance(self.memory, WeightedMemory):
            absolute_loss = torch.abs(q - q_target).detach().cpu().numpy()
            loss = weighted_smooth_l1_loss(
                q, q_target, sample_weights
            )  # TODO fix potential non-linearities using huber loss
            self.memory.batch_update(tree_idx, absolute_loss)

        else:
            loss = F.smooth_l1_loss(q, q_target)

        self.optim.zero_grad()
        loss.backward()
        for param in self.policy_net.parameters():  # clip gradients element-wise to stabilise updates
            param.grad.data.clamp_(-1, 1)
        self.optim.step()

    # Training only starts once the replay memory has filled to capacity.
    @property
    def ready(self):
        return len(self.memory) >= self.memory.max_size

    def state_dict(self):
        return self.policy_net.state_dict()

    def update_target_net(self):
        self.target_net.load_state_dict(self.state_dict())

    def train(self, train_state=True):
        return self.policy_net.train(train_state)

    def reset(self, *args, **kwargs):
        self.env.reset()

    def _state_action_value(self, s, a):
        a = a.view(-1, 1)
        return self.policy_net(s).gather(1, a)

    def evaluate(self, evaluate_state=False):
        # Analogous to train(): switch exploration off for evaluation, restore it afterwards.
        if evaluate_state:
            # self._epsilon = self.epsilon
            self.epsilon = 0
        else:
            self.epsilon = self._epsilon
        # self.evaluating = evaluate_state

    def play_action(self, action, player):
        self.env.step(action, player)
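
A minimal driver for the EpsilonGreedy agent above might look like the sketch below. It is not part of the original example: `make_env`, `QNet`, and the gym-style `(s_next, r, done, ...)` return from `env.step(action, player)` are placeholder assumptions, and the project-level helpers the class relies on (Memory, Transition, device, ...) are assumed importable. The point is the call order: act via `__call__`, store and train via `update`, and periodically sync the target network with `update_target_net`.

import torch
import torch.multiprocessing as mp


def train_epsilon_greedy(make_env, QNet, episodes=1000, target_sync_every=20):
    # Sketch only: QNet and make_env are hypothetical stand-ins.
    queue = mp.Queue()
    agent = EpsilonGreedy(QNet(), epsilon=0.1, env_gen=make_env, memory_queue=queue)
    # Build the optimiser on the agent's own policy_net copy.
    agent.optim = torch.optim.SGD(agent.policy_net.parameters(), lr=1e-3, momentum=0.9)

    for episode in range(episodes):
        s = agent.env.reset()
        done = False
        while not done:
            a = agent(s)                                 # epsilon-greedy action
            s_next, r, done, *_ = agent.env.step(a, 1)   # assumed gym-like return
            agent.update(s, a, r, done, s_next)          # enqueue + train when ready
            s = s_next
        if episode % target_sync_every == 0:
            agent.update_target_net()                    # periodic target sync
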
Example #2
class Q:
    def __init__(
        self,
        env,
        evaluator,
        lr=0.01,
        gamma=0.99,
        momentum=0.9,
        weight_decay=0.01,
        mem_type="sumtree",
        buffer_size=20000,
        batch_size=16,
        *args,
        **kwargs
    ):

        self.gamma = gamma
        self.env = env
        self.state_size = self.env.width * self.env.height
        self.policy_net = copy.deepcopy(evaluator)
        # ConvNetConnect4(self.env.width, self.env.height, self.env.action_space.n).to(device)
        self.target_net = copy.deepcopy(evaluator)
        # ConvNetConnect4(self.env.width, self.env.height, self.env.action_space.n).to(device)

        self.policy_net.apply(init_weights)
        self.target_net.apply(init_weights)

        self.optim = torch.optim.SGD(
            self.policy_net.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay
        )

        if mem_type == "sumtree":
            self.memory = WeightedMemory(buffer_size)
        else:
            self.memory = Memory(buffer_size)
        self.batch_size = batch_size

    def __call__(self, s, player=None):  # TODO use player variable
        if not isinstance(s, torch.Tensor):
            s = torch.from_numpy(s).long()
        s = self.policy_net.preprocess(s)
        return self.policy_net(s)

    def state_action_value(self, s, a):
        a = a.view(-1, 1)
        return self.policy_net(s).gather(1, a)
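
    # `update` below checks `self.ready`, but the property is not defined in this
    # snippet; a minimal assumed definition is to train once at least one full
    # batch has been buffered.
    @property
    def ready(self):
        return len(self.memory) >= self.batch_size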

    def update(self, s, a, r, done, s_next):
        s = torch.tensor(s, device=device)
        # s = self.policy_net.preprocess(s)
        a = torch.tensor(a, device=device)
        r = torch.tensor(r, device=device)
        done = torch.tensor(done, device=device)
        s_next = torch.tensor(s_next, device=device)
        # s_next = self.policy_net.preprocess(s_next)

        self.memory.add(Transition(s, a, r, done, s_next))
        if not self.ready:
            return  # not enough experience buffered to sample a batch yet

        # Sample a batch from replay memory
        if isinstance(self.memory, WeightedMemory):
            tree_idx, batch, sample_weights = self.memory.sample(self.batch_size)
            sample_weights = torch.tensor(sample_weights, device=device)
        else:
            batch = self.memory.sample(self.batch_size)
        batch_t = Transition(*zip(*batch))  # transposed batch

        # Get expected Q values
        s_batch, a_batch, r_batch, done_batch, s_next_batch = batch_t
        s_batch = torch.cat(s_batch)
        a_batch = torch.stack(a_batch)
        r_batch = torch.stack(r_batch).view(-1, 1)
        s_next_batch = torch.cat(s_next_batch)
        done_batch = torch.stack(done_batch).view(-1, 1)
        q = self.state_action_value(s_batch, a_batch)

        # Compute target Q values (double DQN: actions chosen by the policy net,
        # evaluated by the target net).
        double_actions = self.policy_net(s_next_batch).max(1)[1].detach()
        q_next = self.target_net(s_next_batch).gather(1, double_actions.view(-1, 1)).detach()

        q_next_actual = (~done_batch) * q_next  # zero the bootstrap term for terminal states
        q_target = r_batch + self.gamma * q_next_actual
        ### TEST: whether clamping the target helps (or is even good practice)
        q_target = q_target.clamp(-1, 1)
        ### /TEST

        if isinstance(self.memory, WeightedMemory):
            absolute_loss = torch.abs(q - q_target).detach().cpu().numpy()
            loss = weighted_smooth_l1_loss(
                q, q_target, sample_weights
            )  # TODO fix potential non-linearities using huber loss
            self.memory.batch_update(tree_idx, absolute_loss)

        else:
            loss = F.smooth_l1_loss(q, q_target)

        self.optim.zero_grad()
        loss.backward()
        for param in self.policy_net.parameters():  # clip gradients element-wise to stabilise updates
            param.grad.data.clamp_(-1, 1)
        self.optim.step()
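
As with the first example, a short hypothetical driver shows how this Q learner might be used. It is a sketch, not part of the original: `env`, `evaluator`, and the gym-style return of `env.step(action, player)` are assumptions, and action selection (which the class itself does not provide) is done here with a greedy argmax over the environment's valid moves.

import numpy as np
import torch


def greedy_action(learner, s, env):
    # Pick the highest-valued legal move; invalid moves are masked to -inf.
    with torch.no_grad():
        q_values = learner(s).squeeze(0).cpu().numpy()
    valid = np.array(env.valid_moves(), dtype=bool)
    q_values[~valid] = -np.inf
    return int(np.argmax(q_values))


def train_q(env, evaluator, steps=10_000):
    # Sketch only: `env` and `evaluator` are whatever the project provides.
    learner = Q(env, evaluator, mem_type="sumtree", batch_size=16)
    s = env.reset()
    for _ in range(steps):
        a = greedy_action(learner, s, env)
        s_next, r, done, *_ = env.step(a, 1)     # assumed gym-like return
        learner.update(s, a, r, done, s_next)    # buffer, then train when ready
        s = s_next if not done else env.reset()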