Example #1
0
    def __init__(
        self,
        evaluator,
        epsilon,
        env_gen,
        optim=None,
        memory_queue=None,
        memory_size=20000,
        mem_type=None,
        batch_size=64,
        gamma=0.99,
    ):
        """Build an epsilon-greedy agent around copies of ``evaluator``.

        Args:
            evaluator: network that is deep-copied into ``policy_net`` and ``target_net``.
            epsilon: exploration rate; also kept in ``_epsilon`` as a backup for evaluation.
            env_gen: zero-argument factory; called once to create ``self.env``.
            optim: optimizer used for training (may be None).
            memory_queue: queue object stored for transition exchange (may be None).
            memory_size: capacity passed to the replay memory constructor.
            mem_type: ``"sumtree"`` selects ``WeightedMemory``; anything else ``Memory``.
            batch_size: number of transitions sampled per update.
            gamma: discount factor.
        """
        self.epsilon = self._epsilon = epsilon  # _epsilon: backup for evaluation
        self.optim = optim
        self.env = env_gen()
        self.memory_queue, self.batch_size, self.gamma = memory_queue, batch_size, gamma

        # Two independent deep copies of the same evaluator.
        self.policy_net, self.target_net = (copy.deepcopy(evaluator) for _ in range(2))

        memory_cls = WeightedMemory if mem_type == "sumtree" else Memory
        self.memory = memory_cls(memory_size)
Example #2
0
    def __init__(
        self,
        env,
        evaluator,
        lr=0.01,
        gamma=0.99,
        momentum=0.9,
        weight_decay=0.01,
        mem_type="sumtree",
        buffer_size=20000,
        batch_size=16,
        *args,
        **kwargs
    ):
        """Create policy/target networks, an SGD optimizer and a replay memory.

        Args:
            env: environment instance; its ``width * height`` becomes ``state_size``.
            evaluator: network template, deep-copied into both networks.
            lr, momentum, weight_decay: SGD hyper-parameters for ``policy_net``.
            gamma: discount factor.
            mem_type: ``"sumtree"`` selects ``WeightedMemory``; anything else ``Memory``.
            buffer_size: replay memory capacity.
            batch_size: number of transitions sampled per update.
            *args, **kwargs: accepted and ignored.
        """
        self.gamma = gamma
        self.env = env
        self.state_size = env.width * env.height

        # Both networks copy the evaluator's architecture, then get freshly re-initialised weights.
        self.policy_net = copy.deepcopy(evaluator)
        self.target_net = copy.deepcopy(evaluator)
        for net in (self.policy_net, self.target_net):
            net.apply(init_weights)

        self.optim = torch.optim.SGD(
            self.policy_net.parameters(),
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
        )

        memory_cls = WeightedMemory if mem_type == "sumtree" else Memory
        self.memory = memory_cls(buffer_size)
        self.batch_size = batch_size
    def __init__(
        self,
        evaluator,
        env_gen,
        optim=None,
        memory_queue=None,
        iterations=100,
        temperature_cutoff=5,
        batch_size=64,
        memory_size=200000,
        min_memory=20000,
        update_nn=True,
        starting_state_dict=None,
    ):
        """Set up the tree-search agent, its replay memory and optional Apex AMP.

        Args:
            evaluator: network moved to ``device`` and used to score states.
            env_gen: zero-argument environment factory; called once here.
            optim: optimizer for the evaluator (may be None).
            memory_queue: queue used to exchange finished-game moves.
            iterations: search iterations per move.
            temperature_cutoff: number of moves played with temperature 1.
            batch_size: moves sampled per training step.
            memory_size: replay memory capacity.
            min_memory: memory size required before training.
            update_nn: whether training is enabled at all.
            starting_state_dict: optional weights loaded at the end of setup.
        """
        self.iterations = iterations
        self.evaluator = evaluator.to(device)
        self.env_gen = env_gen
        self.optim = optim
        self.env = env_gen()
        self.root_node = None
        self.reset()
        self.update_nn = update_nn
        self.starting_state_dict = starting_state_dict

        self.memory_queue = memory_queue
        self.temp_memory = []
        self.memory = Memory(memory_size)
        self.min_memory = min_memory
        self.temperature_cutoff = temperature_cutoff
        self.actions = self.env.action_space.n

        self.evaluating = False

        self.batch_size = batch_size

        if APEX_AVAILABLE:
            opt_level = "O1"
            # Initialise the objects actually stored on self (the evaluator was moved to device above).
            if self.optim:
                self.evaluator, self.optim = amp.initialize(
                    self.evaluator, self.optim, opt_level=opt_level)
                print("updating optimizer and evaluator")
            else:
                self.evaluator = amp.initialize(self.evaluator, opt_level=opt_level)
                print(" updated evaluator")
            self.amp_state_dict = amp.state_dict()
            print(vars(amp._amp_state))

        if self.starting_state_dict:
            print("laoding [sic] state dict in mcts")
            self.load_state_dict(self.starting_state_dict)
    def __init__(self, env=Connect4Env):
        """Open the model/result/Elo shelves beside this module and register their cleanup.

        Args:
            env: environment class/factory stored for later games.
        """
        base_dir = os.path.dirname(__file__)

        def _open_shelf(file_name):
            return shelve.open(os.path.join(base_dir, file_name))

        self.model_shelf = _open_shelf(self.MODEL_SAVE_FILE)
        self.result_shelf = _open_shelf(self.RESULT_SAVE_FILE)
        self.elo_value_shelf = _open_shelf(self.ELO_VALUE_SAVE_FILE)

        self.env = env
        self.memory = Memory()

        # Close all shelves when the interpreter exits.
        atexit.register(self._close)
Example #5
0
def describe_dqn(memory: Memory, agent, gamma: float = 1):
    """Summarise the episodes stored in ``memory`` and the agent's Q-values.

    Args:
        memory: episode store exposing ``sample_episode()`` and ``episodes()``.
        agent: model passed to ``get_q_values``.
        gamma: discount factor handed to ``discount_rewards``.

    Returns:
        Dict of means and ``[None]``-expanded arrays of per-episode
        ``discount_rewards`` results (with ``gamma`` and with 1), episode
        lengths, and the max/min of the Q-values for the first state of one
        sampled episode.
    """
    probe_state = memory.sample_episode().state(0)
    per_episode_rewards = [episode.rewards() for episode in memory.episodes()]

    episode_lengths = np.array(lmap(len, per_episode_rewards))
    discounted_returns = np.array(
        [discount_rewards(ep_rewards, gamma) for ep_rewards in per_episode_rewards]
    )
    plain_returns = np.array(
        [discount_rewards(ep_rewards, 1) for ep_rewards in per_episode_rewards]
    )
    q_values = get_q_values(probe_state, agent)

    summary = {}
    summary['mean reward'] = discounted_returns.mean()
    summary['rewards'] = discounted_returns[None]
    summary['mean reward sum'] = plain_returns.mean()
    summary['rewards sum'] = plain_returns[None]
    summary['mean size'] = episode_lengths.mean()
    summary['sizes'] = episode_lengths[None]
    summary['q_max'] = q_values.max()
    summary['q_min'] = q_values.min()
    return summary
Example #6
0
class EpsilonGreedy(BaseModel):
    """Epsilon-greedy Q-learning agent with a policy network and a target network.

    With probability ``epsilon`` a random valid move is played; otherwise the
    valid move with the highest ``policy_net`` output. Transitions go through
    ``memory_queue`` (or straight into ``memory`` when no queue was given), and
    training starts once the replay memory is full. The target network is only
    refreshed when ``update_target_net`` is called.
    """

    def __init__(
        self,
        evaluator,
        epsilon,
        env_gen,
        optim=None,
        memory_queue=None,
        memory_size=20000,
        mem_type=None,
        batch_size=64,
        gamma=0.99,
    ):
        self.epsilon = epsilon
        self._epsilon = epsilon  # Backup restored by evaluate(False)
        self.optim = optim
        self.env = env_gen()
        self.memory_queue = memory_queue
        self.batch_size = batch_size
        self.gamma = gamma

        # Independent deep copies: training policy_net leaves target_net untouched.
        self.policy_net = copy.deepcopy(evaluator)
        self.target_net = copy.deepcopy(evaluator)

        if mem_type == "sumtree":
            self.memory = WeightedMemory(memory_size)
        else:
            self.memory = Memory(memory_size)

    def __call__(self, s):
        """Choose an action for state ``s``."""
        return self._epsilon_greedy(s)

    def q(self, s):
        """Return ``policy_net``'s output for ``s`` (numpy input becomes a long tensor)."""
        if not isinstance(s, torch.Tensor):
            s = torch.from_numpy(s).long()
        return self.policy_net(s)

    def _epsilon_greedy(self, s):
        """Pick a random valid move with probability epsilon, else the best-valued valid move."""
        explore = np.random.rand() < self.epsilon
        # Coerce to bool: ``~`` on a 0/1 integer array is bitwise NOT and would
        # turn the "invalid move" penalty into a positive bonus.
        valid = np.asarray(self.env.valid_moves(), dtype=bool)
        if explore:
            possible_moves = np.flatnonzero(valid).tolist()
            return random.choice(possible_moves)
        weights = self.q(s).detach().cpu().numpy()  # TODO maybe do this with tensors
        # Invalid moves get -inf so argmax can only select a valid move.
        return np.argmax(np.where(valid, weights, -np.inf))

    def load_state_dict(self, state_dict, target=False):
        """Load weights into ``policy_net`` (and ``target_net`` too when ``target``)."""
        self.policy_net.load_state_dict(state_dict)
        if target:
            self.target_net.load_state_dict(state_dict)

    def update(self, s, a, r, done, next_s):
        """Record one transition, drain the queue into memory and train once ready."""
        self.push_to_queue(s, a, r, done, next_s)
        self.pull_from_queue()
        if self.ready:
            self.update_from_memory()

    def push_to_queue(self, s, a, r, done, next_s):
        """Convert a transition to tensors and enqueue it (or store it if there is no queue)."""
        transition = Transition(
            torch.tensor(s, device=device),
            torch.tensor(a, device=device),
            torch.tensor(r, device=device),
            torch.tensor(done, device=device),
            torch.tensor(next_s, device=device),
        )
        if self.memory_queue is None:
            self.memory.add(transition)
        else:
            self.memory_queue.put(transition)

    def pull_from_queue(self):
        """Move every pending transition from ``memory_queue`` into replay memory."""
        if self.memory_queue is None:
            return
        while not self.memory_queue.empty():
            experience = self.memory_queue.get()
            self.memory.add(experience)

    def update_from_memory(self):
        """Run one double-Q optimisation step on a batch sampled from replay memory."""
        weighted = isinstance(self.memory, WeightedMemory)
        if weighted:
            tree_idx, batch, sample_weights = self.memory.sample(self.batch_size)
            sample_weights = torch.tensor(sample_weights, device=device)
        else:
            batch = self.memory.sample(self.batch_size)
        s_batch, a_batch, r_batch, done_batch, s_next_batch = Transition(*zip(*batch))
        s_batch = torch.cat(s_batch)
        a_batch = torch.stack(a_batch)
        r_batch = torch.stack(r_batch).view(-1, 1)
        s_next_batch = torch.cat(s_next_batch)
        done_batch = torch.stack(done_batch).view(-1, 1)

        q = self._state_action_value(s_batch, a_batch)

        # The bootstrap target must not receive gradients; only ``q`` is trained.
        with torch.no_grad():
            # Double Q-learning: policy_net picks the next action, target_net scores it.
            double_actions = self.policy_net(s_next_batch).max(1)[1]
            q_next = self.target_net(s_next_batch).gather(1, double_actions.view(-1, 1))
            # logical_not works for bool and 0/1 integer done flags alike.
            q_next_actual = torch.logical_not(done_batch) * q_next
            q_target = r_batch + self.gamma * q_next_actual
            ###TEST if clamping works or is even good practise
            q_target = q_target.clamp(-1, 1)

        if weighted:
            absolute_loss = torch.abs(q - q_target).detach().cpu().numpy()
            loss = weighted_smooth_l1_loss(q, q_target, sample_weights)
            self.memory.batch_update(tree_idx, absolute_loss)
        else:
            loss = F.smooth_l1_loss(q, q_target)

        self.optim.zero_grad()
        loss.backward()
        for param in self.policy_net.parameters():
            # Parameters not used in the forward pass have no gradient.
            if param.grad is not None:
                param.grad.data.clamp_(-1, 1)
        self.optim.step()

    # determines when a neural net has enough data to train
    @property
    def ready(self):
        """True once the replay memory has reached its maximum size."""
        return len(self.memory) >= self.memory.max_size

    def state_dict(self):
        """Return ``policy_net``'s state dict."""
        return self.policy_net.state_dict()

    def update_target_net(self):
        """Copy ``policy_net``'s weights into ``target_net``."""
        self.target_net.load_state_dict(self.state_dict())

    def train(self, train_state=True):
        """Toggle ``policy_net`` training mode."""
        return self.policy_net.train(train_state)

    def reset(self, *args, **kwargs):
        """Reset the environment (extra arguments are ignored)."""
        self.env.reset()

    def _state_action_value(self, s, a):
        """Return ``policy_net(s)`` gathered at actions ``a`` as a column."""
        a = a.view(-1, 1)
        return self.policy_net(s).gather(1, a)

    def evaluate(self, evaluate_state=False):
        """Enter (epsilon = 0) or leave (epsilon restored from backup) evaluation mode."""
        if evaluate_state:
            self.epsilon = 0
        else:
            self.epsilon = self._epsilon

    def play_action(self, action, player):
        """Apply ``action`` for ``player`` to the environment."""
        self.env.step(action, player)
Example #7
0
class Q:
    """Standalone Q-learning learner with its own SGD optimizer and replay memory.

    ``update`` stores transitions and, once the memory is full, takes one
    optimisation step per call. ``target_net`` is created but never synced or
    used here, so targets come from ``policy_net`` (detached).
    """

    def __init__(
        self,
        env,
        evaluator,
        lr=0.01,
        gamma=0.99,
        momentum=0.9,
        weight_decay=0.01,
        mem_type="sumtree",
        buffer_size=20000,
        batch_size=16,
        *args,
        **kwargs
    ):
        self.gamma = gamma
        self.env = env
        self.state_size = self.env.width * self.env.height
        self.policy_net = copy.deepcopy(evaluator)
        self.target_net = copy.deepcopy(evaluator)

        # Each copy is re-initialised separately, so the two start with different weights.
        self.policy_net.apply(init_weights)
        self.target_net.apply(init_weights)

        self.optim = torch.optim.SGD(self.policy_net.parameters(), weight_decay=weight_decay, momentum=momentum, lr=lr,)

        if mem_type == "sumtree":
            self.memory = WeightedMemory(buffer_size)
        else:
            self.memory = Memory(buffer_size)
        self.batch_size = batch_size

    def __call__(self, s, player=None):  # TODO use player variable
        """Preprocess ``s`` and return ``policy_net``'s output."""
        if not isinstance(s, torch.Tensor):
            s = torch.from_numpy(s).long()
        s = self.policy_net.preprocess(s)
        return self.policy_net(s)

    @property
    def ready(self):
        """True once the replay memory has reached its maximum size.

        ``update`` relied on this attribute but the class never defined it,
        so every call raised AttributeError.
        """
        return len(self.memory) >= self.memory.max_size

    def state_action_value(self, s, a):
        """Return ``policy_net(s)`` gathered at actions ``a`` as a column."""
        a = a.view(-1, 1)
        return self.policy_net(s).gather(1, a)

    def update(self, s, a, r, done, s_next):
        """Store one transition and, if the memory was already ready, train on one batch.

        Readiness is checked before the new transition is stored, so training
        begins on the call after the memory fills up.
        """
        s = torch.tensor(s, device=device)
        a = torch.tensor(a, device=device)
        r = torch.tensor(r, device=device)
        done = torch.tensor(done, device=device)
        s_next = torch.tensor(s_next, device=device)

        was_ready = self.ready
        self.memory.add(Transition(s, a, r, done, s_next))
        if not was_ready:
            return

        weighted = isinstance(self.memory, WeightedMemory)
        if weighted:
            tree_idx, batch, sample_weights = self.memory.sample(self.batch_size)
            sample_weights = torch.tensor(sample_weights, device=device)
        else:
            batch = self.memory.sample(self.batch_size)

        # Transpose the batch of transitions into per-field tensors.
        s_batch, a_batch, r_batch, done_batch, s_next_batch = Transition(*zip(*batch))
        s_batch = torch.cat(s_batch)
        a_batch = torch.stack(a_batch)
        r_batch = torch.stack(r_batch).view(-1, 1)
        s_next_batch = torch.cat(s_next_batch)
        done_batch = torch.stack(done_batch).view(-1, 1)
        q = self.state_action_value(s_batch, a_batch)

        # The bootstrap target must not receive gradients; only ``q`` is trained.
        with torch.no_grad():
            double_actions = self.policy_net(s_next_batch).max(1)[1]
            q_next = self.state_action_value(s_next_batch, double_actions)
            # logical_not works for bool and 0/1 integer done flags alike.
            q_next_actual = torch.logical_not(done_batch) * q_next
            q_target = r_batch + self.gamma * q_next_actual
            ###TEST if clamping works or is even good practise
            q_target = q_target.clamp(-1, 1)

        if weighted:
            absolute_loss = torch.abs(q - q_target).detach().cpu().numpy()
            loss = weighted_smooth_l1_loss(q, q_target, sample_weights)
            self.memory.batch_update(tree_idx, absolute_loss)
        else:
            loss = F.smooth_l1_loss(q, q_target)

        self.optim.zero_grad()
        loss.backward()
        for param in self.policy_net.parameters():
            # Parameters not used in the forward pass have no gradient.
            if param.grad is not None:
                param.grad.data.clamp_(-1, 1)
        self.optim.step()
class MCTreeSearch(BaseModel):
    """Monte Carlo tree search agent guided by a policy/value evaluator.

    The evaluator's output is unpacked as ``(probs, v)``. Moves played are
    buffered in ``temp_memory`` and, when a game ends, labelled with the final
    reward and sent to ``memory_queue`` (or straight to ``memory`` when no
    queue was given). Training uses value MSE plus policy cross-entropy.
    """

    def __init__(
        self,
        evaluator,
        env_gen,
        optim=None,
        memory_queue=None,
        iterations=100,
        temperature_cutoff=5,
        batch_size=64,
        memory_size=200000,
        min_memory=20000,
        update_nn=True,
        starting_state_dict=None,
    ):
        self.iterations = iterations
        self.evaluator = evaluator.to(device)
        self.env_gen = env_gen
        self.optim = optim
        self.env = env_gen()
        self.root_node = None
        self.reset()
        self.update_nn = update_nn
        self.starting_state_dict = starting_state_dict

        self.memory_queue = memory_queue
        self.temp_memory = []
        self.memory = Memory(memory_size)
        self.min_memory = min_memory
        self.temperature_cutoff = temperature_cutoff
        self.actions = self.env.action_space.n

        self.evaluating = False

        self.batch_size = batch_size

        if APEX_AVAILABLE:
            opt_level = "O1"
            # Initialise the objects actually stored on self (the evaluator was moved to device above).
            if self.optim:
                self.evaluator, self.optim = amp.initialize(
                    self.evaluator, self.optim, opt_level=opt_level)
                print("updating optimizer and evaluator")
            else:
                self.evaluator = amp.initialize(self.evaluator, opt_level=opt_level)
                print(" updated evaluator")
            self.amp_state_dict = amp.state_dict()
            print(vars(amp._amp_state))

        if self.starting_state_dict:
            print("laoding [sic] state dict in mcts")
            self.load_state_dict(self.starting_state_dict)

    def reset(self, player=1):
        """Reset the env, rebuild the tree root from the initial state, and return that state."""
        base_state = self.env.reset()
        probs, v = self.evaluator(base_state)
        self._set_root(MCNode(state=base_state, v=v, player=player))
        self.root_node.create_children(probs, self.env.valid_moves())
        self.moves_played = 0
        self.temp_memory = []

        return base_state

    ## Ignores the inputted state for the moment. Produces the correct action, and changes the root node appropriately
    # TODO might want to do a check on the state to make sure it is consistent
    def __call__(self, s):  # not using player
        """Search from the current root and return the chosen move (``s`` is ignored)."""
        move = self._search_and_play()
        return move

    def _search_and_play(self):
        """Run a search, then sample a move with a temperature depending on moves played."""
        final_temp = 1
        temperature = 1 if self.moves_played < self.temperature_cutoff else final_temp
        self.search()
        move = self._play(temperature)
        return move

    def play_action(self, action, player):
        """Advance the tree root to the child reached by ``action``."""
        self._set_node(action)

    def _set_root(self, node):
        """Make ``node`` the active root, clearing the flag on the previous root."""
        if self.root_node:
            self.root_node.active_root = False
        self.root_node = node
        self.root_node.active_root = True

    def _prune(self, action):
        """Move the root to child ``action`` and drop its siblings from the parent."""
        self._set_root(self.root_node.children[action])
        self.root_node.parent.children = [self.root_node]

    def _set_node(self, action):
        """Move the root to child ``action``, expanding it first if it has no visits."""
        node = self.root_node.children[action]
        if self.root_node.children[action].n == 0:
            # TODO check if leaf and n > 0??
            # TODO might not backup??????????
            node, v = self._expand_node(self.root_node, action,
                                        self.root_node.player)
            node.backup(v)
            node.v = v
        self._set_root(node)

    def update(self, s, a, r, done, next_s):
        """Flush finished-game moves, pull them into memory and train once ready."""
        self.push_to_queue(s, a, r, done, next_s)
        self.pull_from_queue()
        if self.ready:
            self.update_from_memory()

    def pull_from_queue(self):
        """Move every pending experience from ``memory_queue`` into replay memory."""
        if self.memory_queue is None:
            return
        while not self.memory_queue.empty():
            experience = self.memory_queue.get()
            self.memory.add(experience)

    def push_to_queue(self, s, a, r, done, next_s):
        """At game end, label buffered moves with reward ``r`` and send them on."""
        if done:
            for experience in self.temp_memory:
                experience = experience._replace(
                    actual_val=torch.tensor(r).float().to(device))
                if self.memory_queue is None:
                    self.memory.add(experience)
                else:
                    self.memory_queue.put(experience)
            self.temp_memory = []

    def loss(self, batch):
        """Value MSE plus batch-averaged cross-entropy against the search's move probabilities."""
        batch_t = Move(*zip(*batch))  # transposed batch
        s, actual_val, tree_probs = batch_t
        s_batch = torch.stack(s)
        net_probs_batch, predict_val_batch = self.evaluator.forward(s_batch)
        predict_val_batch = predict_val_batch.view(-1)
        actual_val_batch = torch.stack(actual_val)
        tree_probs_batch = torch.stack(tree_probs)

        c = MSELossFlat(floatify=True)
        value_loss = c(predict_val_batch, actual_val_batch)

        # Clamp before log: a zero probability would otherwise give -inf, and 0 * -inf = nan.
        eps = torch.finfo(net_probs_batch.dtype).tiny
        prob_loss = -(net_probs_batch.clamp(min=eps).log() *
                      tree_probs_batch).sum() / net_probs_batch.size()[0]

        loss = value_loss + prob_loss
        return loss

    def update_from_memory(self):
        """Take one optimiser step on a batch sampled from replay memory."""
        batch = self.memory.sample(self.batch_size)

        loss = self.loss(batch)

        self.optim.zero_grad()

        if APEX_AVAILABLE:
            with amp.scale_loss(loss, self.optim) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        self.optim.step()

    def _play(self, temp=0.05):
        """Sample a move from the root children's play probabilities and record it."""
        if self.evaluating:  # Might want to just make this greedy
            temp = temp / 20  # More likely to choose higher visited nodes

        # float() accepts Python floats, numpy scalars and 0-d tensors alike.
        play_probs = np.array(
            [float(child.play_prob(temp)) for child in self.root_node.children])
        play_probs = play_probs / play_probs.sum()

        action = np.random.choice(self.actions, p=play_probs)

        self.moves_played += 1

        # actual_val is unknown until the game ends; push_to_queue fills it in.
        self.temp_memory.append(
            Move(
                torch.tensor(self.root_node.state).to(device),
                None,
                torch.tensor(play_probs).float().to(device),
            ))
        return action

    def _expand_node(self, parent_node, action, player=1):
        """Simulate ``action`` from the parent's state and return ``(child, value)``."""
        env = self.env_gen()
        env.set_state(copy.copy(parent_node.state))
        s, r, done, _ = env.step(action, player=player)
        r = r * player
        if done:
            v = r
            child_node = parent_node.children[action]
        else:
            probs, v = self.evaluator(s, parent_node.player)
            child_node = parent_node.children[action]
            child_node.create_children(probs, env.valid_moves())
            assert child_node.children
        child_node.state = s
        return child_node, v

    def search(self):
        """Run ``iterations`` select/expand/backup passes from the current root."""
        # NOTE(review): noise is added only while evaluating; confirm this is intended.
        if self.evaluating:
            self.root_node.add_noise(
            )  # Might want to remove this in evaluation?
        for i in range(self.iterations):
            node = self.root_node
            while True:
                select_probs = [
                    child.select_prob if child.valid else -10000000000
                    for child in node.children
                ]  # real big negative nuber
                # Tiny random jitter breaks ties between equal scores.
                action = np.argmax(select_probs +
                                   0.000001 * np.random.rand(self.actions))
                if node.children[action].is_leaf:
                    node, v = self._expand_node(node, action, node.player)
                    node.backup(v)
                    node.v = v
                    break
                else:
                    node = node.children[action]
        if self.evaluating:
            self.root_node.remove_noise()  # Don't think this is necessary?

    def load_state_dict(self, state_dict, target=False):
        """Load weights into the evaluator (``target`` is accepted for interface parity)."""
        self.evaluator.load_state_dict(state_dict)

    # determines when a neural net has enough data to train
    @property
    def ready(self):
        """True when memory holds at least ``min_memory`` moves and ``update_nn`` is set."""
        return len(self.memory) >= self.min_memory and self.update_nn

    def state_dict(self):
        """Return the evaluator's state dict."""
        return self.evaluator.state_dict()

    def update_target_net(self):
        # No target net so pass
        pass

    def deduplicate(self):
        """Ask the memory to merge entries with the same state."""
        self.memory.deduplicate("state", ["actual_val", "tree_probs"], Move)

    def train(self, train_state=True):
        # Sets training true/false
        return self.evaluator.train(train_state)

    def evaluate(self, evaluate_state=False):
        # like train - sets evaluate state
        self.evaluating = evaluate_state
class Elo:
    """Persistent Elo bookkeeping for saved models.

    Model containers, head-to-head result totals and fitted ratings live in
    three ``shelve`` files next to this module; they are closed at interpreter
    exit. Result keys have the form ``"<larger name>__<smaller name>"``.
    """

    MODEL_SAVE_FILE = ".ELO_MODEL"
    RESULT_SAVE_FILE = ".ELO_RESULT"
    ELO_VALUE_SAVE_FILE = ".ELO_VALUE"

    ELO_CONSTANT = 400

    def __init__(self, env=Connect4Env):
        path = os.path.dirname(__file__)
        self.model_shelf = shelve.open(os.path.join(path, self.MODEL_SAVE_FILE))
        self.result_shelf = shelve.open(os.path.join(path, self.RESULT_SAVE_FILE))
        self.elo_value_shelf = shelve.open(os.path.join(path, self.ELO_VALUE_SAVE_FILE))

        self.env = env  # environment factory, handed to the scheduler as env_gen

        self.memory = Memory()

        atexit.register(self._close)

    def _close(self):
        """Close all three shelves."""
        self.model_shelf.close()
        self.result_shelf.close()
        self.elo_value_shelf.close()

    def add_model(self, name, model_container):
        """Store ``model_container`` under ``name``.

        Raises:
            ValueError: if ``name`` is already stored (checked by membership,
                so a stored falsy value can no longer slip through silently).
        """
        if name in self.model_shelf:
            raise ValueError("Model name already in use")
        self.model_shelf[name] = model_container
        print(f"added model {name}")

    def compare_models(self, *args):
        """Play every unordered pair of the named models against each other."""
        for model_1, model_2 in itertools.combinations(args, 2):
            self._compare(model_1, model_2)

    def _compare(self, model_1, model_2, num_games=100):
        """Play ``num_games`` and add the outcome to the stored totals for this pair.

        Totals are stored from the perspective of the lexicographically larger name.
        """
        assert model_1 != model_2
        # "__" separates the two names in the result key.
        assert "_" not in model_1
        assert "_" not in model_2
        if model_1 > model_2:
            key = f"{model_1}__{model_2}"
            swap = False
        else:
            key = f"{model_2}__{model_1}"
            swap = True
        if key in self.result_shelf:
            old_results = self.result_shelf[key]
        else:
            old_results = {"wins": 0, "draws": 0, "losses": 0}
        new_results = self._get_results(model_1, model_2, num_games)
        if swap:
            new_results_ordered = {"wins": new_results["losses"], "draws": new_results["draws"],
                                   "losses": new_results["wins"]}
        else:
            new_results_ordered = new_results
        total_results = {status: new_results_ordered[status] + old_results[status] for status in
                         ("wins", "draws", "losses")}
        self.result_shelf[key] = total_results

    def _get_results(self, model_1, model_2, num_games=100):
        """Return wins/draws/losses of ``model_1`` against ``model_2`` summed over both seats."""
        scheduler = self_play_parallel.SelfPlayScheduler(policy_container=self.model_shelf[model_1],
                                                         opposing_policy_container=self.model_shelf[model_2],
                                                         evaluation_policy_container=self.model_shelf[model_2],
                                                         env_gen=self.env, epoch_length=num_games, initial_games=0,
                                                         self_play=False, save_dir=None)
        _, breakdown = scheduler.compare_models()
        results = {status: breakdown["first"][status] + breakdown["second"][status] for status in
                   ("wins", "draws", "losses")}
        print(f"{model_1} wins: {results['wins']} {model_2} wins: {results['losses']} draws: {results['draws']}")
        return results

    @staticmethod
    def _model_indices(models, anchor_model):
        """Map each non-anchor model name to a consecutive index (the anchor gets none)."""
        return {model: i for i, model in enumerate(m for m in models if m != anchor_model)}

    def _initial_weights(self, models, anchor_model):
        """Stored Elo values for each non-anchor model (0 if unknown), or None if nothing is stored."""
        if "elo" not in self.elo_value_shelf:
            return None
        elo_values = self.elo_value_shelf["elo"]
        return [elo_values.get(model, 0) for model in models if model != anchor_model]

    def _fit_elo(self, models, model_indices, initial_weights, anchor_model, anchor_elo):
        """Fit an EloNetwork on ``self.memory``, persist the ratings and return them."""
        epoch_length = 1000
        num_epochs = 200
        batch_size = 32
        elo_net = EloNetwork(len(models), initial_weights)
        optim = torch.optim.SGD([elo_net.elo_vals.weight], lr=400)
        for _ in range(num_epochs):
            for _ in range(epoch_length):
                optim.zero_grad()

                batch = self.memory.sample(batch_size)
                players, results = result_container(*zip(*batch))
                players = torch.stack(players)
                results = torch.stack(results)
                expected_results = elo_net(players)
                loss = elo_net.loss(expected_results, results)

                loss.backward()
                optim.step()
            # Decay the learning rate by 1% after every epoch.
            for param_group in optim.param_groups:
                param_group['lr'] = param_group['lr'] * 0.99

        weights = elo_net.elo_vals.weight.tolist()
        model_elos = {model: weights[model_indices[model]] for model in models if
                      model != anchor_model}
        model_elos[anchor_model] = anchor_elo
        self.elo_value_shelf["elo"] = model_elos
        print(model_elos)
        return model_elos

    def calculate_elo_2(self, anchor_model="random", anchor_elo=0):
        """Fit ratings with each game encoded as two 0/1 samples (draws as one of each)."""
        models = list(self.model_shelf.keys())
        model_indices = self._model_indices(models, anchor_model)
        initial_weights = self._initial_weights(models, anchor_model)
        self._convert_memory2(model_indices)
        return self._fit_elo(models, model_indices, initial_weights, anchor_model, anchor_elo)

    def calculate_elo(self, anchor_model="random", anchor_elo=0):
        """Fit ratings with wins=1, losses=0, draws=0.5; the anchor is pinned at ``anchor_elo``."""
        models = list(self.model_shelf.keys())
        model_indices = self._model_indices(models, anchor_model)
        initial_weights = self._initial_weights(models, anchor_model)
        if initial_weights is not None:
            print(initial_weights)
        self._convert_memory(model_indices)
        return self._fit_elo(models, model_indices, initial_weights, anchor_model, anchor_elo)

    def _result_pairs(self, model_indices):
        """Yield ``(players, totals)`` for every stored pairing; ``players`` stacks both one-hots."""
        for key in list(self.result_shelf.keys()):
            model1, model2 = key.split("__")
            val_1 = self._onehot(model1, model_indices)
            val_2 = self._onehot(model2, model_indices)
            players = torch.stack((val_1, val_2), 1).t()
            yield players, self.result_shelf[key]

    def _convert_memory(self, model_indices):
        """Rebuild ``self.memory`` with one sample per game (win 1, loss 0, draw 0.5).

        The memory is recreated because the shelf holds cumulative totals;
        appending again would double-count earlier games.
        """
        self.memory = Memory()
        result_map = {"wins": 1, "losses": 0, "draws": 0.5}
        for players, results in self._result_pairs(model_indices):
            for result, value in result_map.items():
                for _ in range(results[result]):
                    self.memory.add(result_container(players, torch.tensor(value, dtype=torch.float)))

    def _convert_memory2(self, model_indices):
        """Rebuild ``self.memory`` with two 0/1 samples per game (a draw is one 0 and one 1).

        The memory is recreated because the shelf holds cumulative totals.
        """
        self.memory = Memory()
        sample_values = {"wins": (1, 1), "losses": (0, 0), "draws": (0, 1)}
        for players, results in self._result_pairs(model_indices):
            for result, values in sample_values.items():
                for _ in range(results[result]):
                    for value in values:
                        self.memory.add(result_container(players, torch.tensor(value, dtype=torch.float)))

    def _onehot(self, model, model_indices):
        """One-hot vector for ``model``; all zeros for a model without an index (the anchor)."""
        model_idx = model_indices.get(model)
        if model_idx is not None:
            val = torch.nn.functional.one_hot(torch.tensor(model_idx), len(model_indices))
        else:
            val = torch.zeros(len(model_indices), dtype=torch.long)
        return val

    def manual_play(self, model_name):
        """Play against a stored model interactively."""
        model = self.model_shelf[model_name].setup()
        model.train(False)
        model.evaluate(True)

        manual = ManualPlay(Connect4Env(), model)
        manual.play()

    def observe(self, model_name, opponent_model_name):
        """Watch two stored models play each other."""
        model = self.model_shelf[model_name].setup()
        opponent_model = self.model_shelf[opponent_model_name].setup()

        model.train(False)
        opponent_model.train(False)
        model.evaluate(True)
        opponent_model.evaluate(True)

        view = View(Connect4Env(), model, opponent_model)
        view.play()