def self_play_repeat(self, max_timestep_alice, max_timestep_bob, episode,
                         tolerance, stop_update, set_update, alternate,
                         train_teacher):
        """Run one "repeat" self-play episode (single learner, recurrent stop controller).

        NOTE(review): another ``def self_play_repeat`` appears later in this
        class body. Python keeps only the last binding, so this version is
        unreachable as written; rename or delete one of the two.

        Alice (``self.teachers``) acts from a fresh reset until the stop
        controller (``self.stop``) emits action 2 or ``max_timestep_alice``
        steps elapse. Action 1 (only honoured in phase 0) records the agents'
        current positions as landmarks, clears the landmark flags and switches
        to phase 1. The agents' positions at the end of Alice's turn define
        the finish zone. The environment is then reset with the same seed,
        landmarks and finish zone, and Bob (``self.learners``) acts until the
        env reports solved or ``tA + tB`` reaches ``max_timestep_bob``.

        Intermediate transitions are pushed with reward 0 and both populations
        are trained after every step. Terminal rewards:
        Alice ``self_play_gamma * max(0, tB - tA)``, Bob
        ``-self_play_gamma * tB``, where a failing Bob is charged
        ``tB = max_timestep_bob - tA``.

        Args:
            max_timestep_alice: step cap for Alice's turn.
            max_timestep_bob: cap on the combined Alice + Bob step count.
            episode: step index used for the TensorBoard scalars and for the
                ``stop_update`` schedule.
            tolerance, set_update, alternate, train_teacher: not used here.
            stop_update: ``self.stop.update()`` runs when
                ``episode % stop_update == 0``.

        Returns:
            (tA, tB): Alice's step count and Bob's (possibly penalised) step count.
        """
        tA = 0
        tB = 0
        tSet = 0

        # One seed drives both resets so Bob starts from the same random
        # configuration Alice started from.
        seed = random.randint(0, 2**32 - 1)

        np.random.seed(seed)

        phase = 0

        s = self.env.reset()

        landmarks = np.random.uniform(-1, 1, (self.env.n_agents, 2))
        landmarks_flags = np.ones(self.env.n_agents)

        s = utils.state_to_teacher_state(s, landmarks, landmarks_flags)
        s = utils.add_phase_to_state(s, phase)
        s_init = copy.deepcopy(s)

        subs_learner = self.get_learners_subpolicies()
        subs_teacher = self.get_teachers_subpolicies()
        teacher_state = {}
        learner_state = {}

        # Recurrent state of the stop controller, threaded through every call.
        hidden_actor = None
        hidden_critic = None

        # ---- Alice's turn ----
        while True:

            tA = tA + 1

            # Teacher observation is [initial state, current state].
            teacher_input = np.hstack((np.array(s_init), np.array(s)))
            input_t = torch.Tensor(teacher_input)

            actions_detached = self.teachers.act(input_t, subs_teacher)

            s_t, r, done, i = self.env.step(copy.deepcopy(actions_detached))
            s_t = utils.state_to_teacher_state(s_t, landmarks, landmarks_flags)
            s_t = utils.add_phase_to_state(s_t, phase)

            # Always query the stop controller first, with the action mask of
            # the current phase.
            mask = self.get_mask(phase)
            action, log_prob, value, hidden_actor, hidden_critic = self.stop.act(
                input_t.flatten(),
                hidden_actor=hidden_actor,
                hidden_critic=hidden_critic,
                mask=torch.Tensor(mask))
            action_item = action.item()

            self.stop.memory.current_seq.append(teacher_input.flatten())
            self.stop.memory.log_prob.append(log_prob)
            self.stop.memory.actions.append(action)
            self.stop.memory.values.append(value)
            self.stop.memory.masks.append(mask)

            # Action 0: keep going. Action 1: fix landmarks (phase 0 only; the
            # phase check duplicates the mask defensively). Action 2, or
            # reaching the step cap: end Alice's turn.
            if action_item == 1 and phase == 0:
                landmarks = np.array([
                    copy.deepcopy(agent.get_pos()) for agent in self.env.agents
                ])
                landmarks_flags = np.zeros(landmarks_flags.shape)

                tSet = tA
                phase = 1

            if action_item == 2 or tA >= max_timestep_alice:
                finish_zone, finish_zone_radius = utils.compute_finish_zone(
                    np.array([
                        copy.deepcopy(agent.get_pos())
                        for agent in self.env.agents
                    ]))

                # Terminal teacher transition; its reward is known only after
                # Bob's turn, so it is pushed below.
                teacher_state['s'] = copy.deepcopy(
                    np.hstack((np.array(s_init), np.array(s))))
                teacher_state['s_t'] = copy.deepcopy(
                    np.hstack((np.array(s_init), np.array(s_t))))
                teacher_state['a'] = copy.deepcopy(actions_detached)
                teacher_state['d'] = True

                break

            self.stop.memory.rewards.append(0)
            self.stop.memory.dones.append(False)

            obs = np.hstack((np.array(s_init), np.array(s)))

            obs_t = np.hstack((np.array(s_init), np.array(s_t)))

            self.teachers.push_sample(obs, actions_detached, [0] * self.env.n,
                                      False, obs_t, subs_teacher)
            # Teachers are trained with their own sub-policies (the learners'
            # indices were passed here before).
            self.teachers.train(subs_teacher)

            s = s_t

        # ---- Bob's turn: same seed, scenario defined by Alice ----
        np.random.seed(seed)

        s = self.env.reset(landmark_positions=landmarks,
                           landmark_flags=landmarks_flags,
                           finish_zone_position=finish_zone,
                           finish_zone_radius=finish_zone_radius)

        while True:

            tB = tB + 1

            actions_detached = self.learners.act(s, subs_learner)

            s_t, _, solved, _ = self.env.step(copy.deepcopy(actions_detached))

            if tA + tB >= max_timestep_bob or solved:
                learner_state['s'] = copy.deepcopy(s)
                learner_state['s_t'] = copy.deepcopy(s_t)
                learner_state['a'] = copy.deepcopy(actions_detached)
                learner_state['d'] = solved
                break

            self.learners.push_sample(s, actions_detached, [0] * self.env.n,
                                      False, s_t, subs_learner)
            # Learners are trained with their own sub-policies (the teachers'
            # indices were passed here before).
            self.learners.train(subs_learner)

            s = s_t

        # A failing Bob is charged the whole remaining budget.
        if not solved:
            tB = max_timestep_bob - tA

        R_A = [self.self_play_gamma * max(0, tB - tA)] * self.env.n
        R_B = [self.self_play_gamma * -1 * tB] * self.env.n

        self.teachers.push_sample(teacher_state['s'], teacher_state['a'], R_A,
                                  teacher_state['d'], teacher_state['s_t'],
                                  subs_teacher)
        self.learners.push_sample(learner_state['s'], learner_state['a'], R_B,
                                  learner_state['d'], learner_state['s_t'],
                                  subs_learner)

        # Close the stop controller's sequence with Alice's reward.
        self.stop.memory.rewards.append(R_A[0])
        self.stop.memory.dones.append(True)
        self.stop.memory.new_seq()

        nb_bases = np.array([
            landmark.get_activated() for landmark in self.env.landmarks
        ]).astype(int).sum()

        self.writer.add_scalars(
            "Self play BOB bases activated {}".format(self.run_id),
            {'Bases activated': nb_bases}, episode)
        self.writer.add_scalars(
            "Self play episode time {}".format(self.run_id), {
                'ALICE TIME': tA,
                'BOB TIME': tB,
                'SET TIME': tSet
            }, episode)
        self.writer.add_scalars("Self play rewards {}".format(self.run_id), {
            "ALICE REWARD": R_A[0],
            'BOB REWARD': R_B[0]
        }, episode)
        self.writer.add_scalars(
            "Self play finish zone radius {}".format(self.run_id),
            {"FINISH ZONE RADIUS": finish_zone_radius}, episode)

        print("TA : {} TB : {} TS : {} RA : {} RB {} {}".format(
            tA, tB, tSet, R_A, R_B, "SOLVED" if solved else ""))

        if episode % stop_update == 0:
            self.stop.update()

        return tA, tB
    def self_play_repeat(self, max_timestep_alice, max_timestep_bob, episode,
                         tolerance, stop_update, set_update, alternate,
                         train_teacher):
        """Run one "repeat" self-play episode against ``self.n_learners`` learners.

        A target learner is drawn uniformly and one-hot encoded into the
        teacher state. Alice (``self.teachers``) acts from a fresh reset until
        the stop controller (``self.stop``) emits action 2 or
        ``max_timestep_alice`` steps elapse. Action 1 (only honoured in phase
        0) records the agents' current positions as landmarks, clears the
        landmark flags and switches to phase 1. The agents' positions at the
        end of Alice's turn define the finish zone.

        Each learner ``self.learners[k]`` then replays the same seeded scenario
        until the env reports solved or it has taken ``max_timestep_bob`` of
        its own steps. A learner's terminal reward is 1 if it solved, else 0.
        Alice's reward is ``2 * result[target] - sum(result)``: +1 if the
        target solved, -1 for each non-target learner that solved.

        Training happens after the rollouts. When
        ``alternate is False or train_teacher is True``, the teachers are
        trained ``tA`` times and ``self.stop.update()`` runs every
        ``stop_update`` episodes. When
        ``alternate is False or train_teacher is False``, each learner is
        trained once per step it took.

        Args:
            max_timestep_alice: step cap for Alice's turn.
            max_timestep_bob: per-learner step cap.
            episode: step index for the TensorBoard scalars and the
                ``stop_update`` schedule.
            tolerance, set_update: not used here.
            stop_update: period (in episodes) of the stop-controller update.
            alternate, train_teacher: select which side is trained (see above).

        Returns:
            (tA, tB). NOTE(review): ``tB`` is never updated in this version and
            is always 0. Per-learner step counts (``learners_steps``) are only
            logged and printed.
        """
        tA = 0
        tB = 0
        tSet = 0

        # One seed drives every reset so each learner replays Alice's start.
        seed = random.randint(0, 2**32 - 1)

        np.random.seed(seed)

        phase = 0

        s = self.env.reset()

        landmarks = np.random.uniform(-1, 1, (self.env.n_agents, 2))
        landmarks_flags = np.ones(self.env.n_agents)
        # One-hot encoding of the learner Alice should make succeed.
        target_learner = np.zeros(self.n_learners)
        target_learner[np.random.randint(self.n_learners)] = 1

        s = utils.state_to_teacher_state(s, landmarks, landmarks_flags,
                                         target_learner)
        s = utils.add_phase_to_state(s, phase)
        s_init = copy.deepcopy(s)

        subs_learner = [
            self.get_learners_subpolicies() for _ in range(self.n_learners)
        ]
        subs_teacher = self.get_teachers_subpolicies()
        teacher_state = {}
        learner_state = [{} for _ in range(self.n_learners)]

        # ---- Alice's turn ----
        while True:

            tA = tA + 1

            # Teacher observation is [initial state, current state].
            teacher_input = np.hstack((np.array(s_init), np.array(s)))
            input_t = torch.Tensor(teacher_input)

            actions_detached = self.teachers.act(input_t, subs_teacher)

            s_t, r, done, i = self.env.step(copy.deepcopy(actions_detached))
            s_t = utils.state_to_teacher_state(s_t, landmarks, landmarks_flags,
                                               target_learner)
            s_t = utils.add_phase_to_state(s_t, phase)

            # Always query the stop controller first, with the action mask of
            # the current phase.
            mask = self.get_mask(phase)
            action, log_prob, value = self.stop.act(input_t.flatten(),
                                                    torch.Tensor(mask))
            action_item = action.item()

            self.stop.memory.states.append(teacher_input.flatten())
            self.stop.memory.log_prob.append(log_prob)
            self.stop.memory.actions.append(action)
            self.stop.memory.values.append(value)
            self.stop.memory.masks.append(mask)

            # Action 0: keep going. Action 1: fix landmarks (phase 0 only; the
            # phase check duplicates the mask defensively). Action 2, or
            # reaching the step cap: end Alice's turn.
            if action_item == 1 and phase == 0:
                landmarks = np.array([
                    copy.deepcopy(agent.get_pos()) for agent in self.env.agents
                ])
                landmarks_flags = np.zeros(landmarks_flags.shape)

                tSet = tA
                phase = 1

            if action_item == 2 or tA >= max_timestep_alice:
                finish_zone, finish_zone_radius = utils.compute_finish_zone(
                    np.array([
                        copy.deepcopy(agent.get_pos())
                        for agent in self.env.agents
                    ]))

                # Terminal teacher transition; its reward is known only after
                # the learners' turns, so it is pushed below.
                teacher_state['s'] = copy.deepcopy(
                    np.hstack((np.array(s_init), np.array(s))))
                teacher_state['s_t'] = copy.deepcopy(
                    np.hstack((np.array(s_init), np.array(s_t))))
                teacher_state['a'] = copy.deepcopy(actions_detached)
                teacher_state['d'] = True

                break

            self.stop.memory.rewards.append(0)
            self.stop.memory.dones.append(False)

            obs = np.hstack((np.array(s_init), np.array(s)))

            obs_t = np.hstack((np.array(s_init), np.array(s_t)))

            self.teachers.push_sample(obs, actions_detached, [0] * self.env.n,
                                      False, obs_t, subs_teacher)

            s = s_t

        learners_results = np.zeros(self.n_learners)
        learners_steps = np.zeros(self.n_learners).astype(int)

        # ---- Each learner replays the scenario Alice defined ----
        for learner in range(self.n_learners):

            np.random.seed(seed)
            s = self.env.reset(landmark_positions=landmarks,
                               landmark_flags=landmarks_flags,
                               finish_zone_position=finish_zone,
                               finish_zone_radius=finish_zone_radius)

            while True:

                learners_steps[learner] += 1

                actions_detached = self.learners[learner].act(
                    s, subs_learner[learner])

                s_t, _, solved, _ = self.env.step(
                    copy.deepcopy(actions_detached))

                if learners_steps[learner] >= max_timestep_bob or solved:
                    learner_state[learner]['s'] = copy.deepcopy(s)
                    learner_state[learner]['s_t'] = copy.deepcopy(s_t)
                    learner_state[learner]['a'] = copy.deepcopy(
                        actions_detached)
                    learner_state[learner]['d'] = solved
                    break

                self.learners[learner].push_sample(s, actions_detached,
                                                   [0] * self.env.n, False,
                                                   s_t, subs_learner[learner])

                s = s_t

            learners_results[learner] = 1 if solved else 0

        R_A = [
            2 * learners_results[np.argmax(target_learner)] -
            np.sum(learners_results)
        ] * self.env.n

        self.teachers.push_sample(teacher_state['s'], teacher_state['a'], R_A,
                                  teacher_state['d'], teacher_state['s_t'],
                                  subs_teacher)

        for learner in range(self.n_learners):
            self.learners[learner].push_sample(
                learner_state[learner]['s'], learner_state[learner]['a'],
                [learners_results[learner]] * self.env.n,
                bool(learners_results[learner]), learner_state[learner]['s_t'],
                subs_learner[learner])

        # Close the stop controller's episode with Alice's reward.
        self.stop.memory.rewards.append(R_A[0])
        self.stop.memory.dones.append(True)

        # NOTE(review): read after the loop, so this reflects only the last
        # learner's rollout.
        nb_bases = np.array([
            landmark.get_activated() for landmark in self.env.landmarks
        ]).astype(int).sum()

        self.writer.add_scalars(
            "Self play BOB bases activated {}".format(self.run_id),
            {'Bases activated': nb_bases}, episode)
        self.writer.add_scalars(
            "Self play episode time {}".format(self.run_id), {
                'ALICE TIME': tA,
                'SET TIME': tSet
            }, episode)
        # Pass `episode` as the global step like every other scalar here; it
        # was previously omitted, so these points were not logged against the
        # episode index.
        self.writer.add_scalars(
            "Self play episode time {}".format(self.run_id), {
                'BOB {} TIME'.format(k): learners_steps[k]
                for k in range(self.n_learners)
            }, episode)
        self.writer.add_scalars("Self play rewards {}".format(self.run_id),
                                {"ALICE REWARD": R_A[0]}, episode)
        self.writer.add_scalars(
            "Self play rewards {}".format(self.run_id), {
                "BOB REWARD {}".format(k): learners_results[k]
                for k in range(self.n_learners)
            }, episode)
        self.writer.add_scalars(
            "Self play finish zone radius {}".format(self.run_id),
            {"FINISH ZONE RADIUS": finish_zone_radius}, episode)

        print("TA : {} TB : {} TS : {} RA : {} RB {}".format(
            tA, learners_steps, tSet, R_A, learners_results))

        if alternate is False or train_teacher is True:
            for _ in range(tA):
                self.teachers.train(subs_teacher)

            if episode % stop_update == 0:
                self.stop.update()

        if alternate is False or train_teacher is False:
            for learner in range(self.n_learners):
                for _ in range(learners_steps[learner]):
                    self.learners[learner].train(subs_learner[learner])

        return tA, tB
    def explore_self_play_repeat(self,
                                 tMAX,
                                 tolerance,
                                 set_probability=0.5,
                                 stop_probability=0.5):
        """Collect one random-policy "repeat" self-play episode (single learner).

        NOTE(review): another ``def explore_self_play_repeat`` appears later in
        this class body. Python keeps only the last binding, so this version is
        unreachable as written.

        Alice samples ``self.teachers.random_act()`` each step. On every step
        before landmarks are fixed, a ``Bernoulli(set_probability)`` draw
        (forced once ``tA >= tMAX``) decides whether the agents' current
        positions become the landmarks (flags cleared, phase 1). After the
        step, a ``Bernoulli(stop_probability)`` draw (forced once
        ``tA >= tMAX``) decides whether Alice's turn ends. Because the draws are
        independent, Alice may stop before landmarks are fixed.

        Bob then samples ``self.learners.random_act()`` in the same seeded
        scenario until solved or ``tA + tB >= tMAX``. Rewards match
        ``self_play_repeat``: Alice ``self_play_gamma * max(0, tB - tA)``,
        Bob ``-self_play_gamma * tB``, with a failing Bob charged
        ``tB = tMAX - tA``. Transitions are only pushed via ``push_sample``;
        nothing is trained and nothing is returned.

        Args:
            tMAX: cap on Alice's steps and on the combined Alice + Bob steps.
            tolerance: not used here.
            set_probability: per-step probability of fixing the landmarks.
            stop_probability: per-step probability of ending Alice's turn.
        """
        tA = 0
        tB = 0
        solved = False

        # One seed drives both resets so Bob replays Alice's initial randomness.
        seed = random.randint(0, 2**32 - 1)
        np.random.seed(seed)
        phase = 0

        s = self.env.reset()

        landmarks = np.random.uniform(-1, 1, (self.env.n_agents, 2))
        landmarks_flags = np.ones(self.env.n_agents)

        s = utils.state_to_teacher_state(s, landmarks, landmarks_flags)
        s = utils.add_phase_to_state(s, phase)

        s_init = copy.deepcopy(s)

        subs_learner = self.get_learners_subpolicies()
        subs_teacher = self.get_teachers_subpolicies()

        teacher_state = {}
        learner_state = {}

        set_flag = False

        # ---- Alice's turn (random actions) ----
        while True:

            tA = tA + 1

            if not set_flag:
                # The random draw is evaluated first, so the RNG advances
                # exactly as before even when the cap forces the flag.
                set_flag = np.random.rand() < set_probability or tA >= tMAX

                if set_flag:
                    landmarks = np.array([
                        copy.deepcopy(agent.get_pos())
                        for agent in self.env.agents
                    ])
                    landmarks_flags = np.zeros(landmarks_flags.shape)
                    phase = 1

            actions_detached = self.teachers.random_act()
            s_t, _, _, _ = self.env.step(copy.deepcopy(actions_detached))
            s_t = utils.state_to_teacher_state(s_t, landmarks, landmarks_flags)
            s_t = utils.add_phase_to_state(s_t, phase)

            stop_flag = np.random.rand() < stop_probability or tA >= tMAX

            if stop_flag:

                finish_zone, finish_zone_radius = utils.compute_finish_zone(
                    np.array([
                        copy.deepcopy(agent.get_pos())
                        for agent in self.env.agents
                    ]))

                # Terminal teacher transition; pushed once Bob's result is known.
                teacher_state['s'] = copy.deepcopy(s)
                teacher_state['s_t'] = copy.deepcopy(s_t)
                teacher_state['a'] = copy.deepcopy(actions_detached)
                teacher_state['d'] = True
                break

            obs = np.hstack((np.array(s_init), np.array(s)))

            obs_t = np.hstack((np.array(s_init), np.array(s_t)))

            self.teachers.push_sample(obs, actions_detached, [0] * self.env.n,
                                      False, obs_t, subs_teacher)
            s = s_t

        # ---- Bob's turn (random actions) in the scenario Alice defined ----
        np.random.seed(seed)

        # landmark_flags is passed as in self_play_repeat, so Bob's env agrees
        # with the flags encoded in the teacher state.
        s = self.env.reset(landmark_positions=landmarks,
                           landmark_flags=landmarks_flags,
                           finish_zone_position=finish_zone,
                           finish_zone_radius=finish_zone_radius)

        while True:

            tB = tB + 1
            actions_detached = self.learners.random_act()
            s_t, _, solved, _ = self.env.step(copy.deepcopy(actions_detached))

            if tA + tB >= tMAX or solved:
                learner_state['s'] = copy.deepcopy(s)
                learner_state['s_t'] = copy.deepcopy(s_t)
                learner_state['a'] = copy.deepcopy(actions_detached)
                learner_state['d'] = solved
                break

            self.learners.push_sample(s, actions_detached, [0] * self.env.n,
                                      solved, s_t, subs_learner)

            s = s_t

        # `not solved` rather than `solved is False`: an identity test against
        # False never matches a numpy boolean, which skipped the penalty.
        if not solved:
            tB = tMAX - tA

        R_A = [self.self_play_gamma * max(0, tB - tA)] * self.env.n
        R_B = [self.self_play_gamma * -1 * tB] * self.env.n

        obs = np.hstack((np.array(s_init), np.array(teacher_state['s'])))
        obs_t = np.hstack((np.array(s_init), np.array(teacher_state['s_t'])))

        self.teachers.push_sample(obs, teacher_state['a'], R_A,
                                  teacher_state['d'], obs_t, subs_teacher)
        self.learners.push_sample(learner_state['s'], learner_state['a'], R_B,
                                  solved, learner_state['s_t'], subs_learner)
    def explore_self_play_repeat(self,
                                 max_timestep_alice,
                                 max_timestep_bob,
                                 set_probability=0.5,
                                 stop_probability=0.5):
        """Collect one random-policy "repeat" self-play episode for every learner.

        A target learner is drawn uniformly and one-hot encoded into the
        teacher state. Alice samples ``self.teachers.random_act()`` each step.
        On every step before landmarks are fixed, a
        ``Bernoulli(set_probability)`` draw (forced once
        ``tA >= max_timestep_alice``) decides whether the agents' current
        positions become the landmarks (flags cleared, phase 1). After the
        step, a ``Bernoulli(stop_probability)`` draw (forced at the same cap)
        decides whether Alice's turn ends.

        Each learner then samples ``self.learners[k].random_act()`` in the same
        seeded scenario until solved or ``max_timestep_bob`` of its own steps.
        A learner's terminal reward is 1 if it solved, else 0. Alice's reward
        is ``2 * result[target] - sum(result)``. Transitions are only pushed
        via ``push_sample``; nothing is trained here.

        Args:
            max_timestep_alice: step cap for Alice's turn.
            max_timestep_bob: per-learner step cap.
            set_probability: per-step probability of fixing the landmarks.
            stop_probability: per-step probability of ending Alice's turn.
        """
        tA = 0
        tB = 0
        solved = False

        # One seed drives every reset so each learner replays Alice's start.
        seed = random.randint(0, 2**32 - 1)
        np.random.seed(seed)
        phase = 0

        s = self.env.reset()

        landmarks = np.random.uniform(-1, 1, (self.env.n_agents, 2))
        landmarks_flags = np.ones(self.env.n_agents)
        # One-hot encoding of the learner Alice should make succeed.
        target_learner = np.zeros(self.n_learners)
        target_learner[np.random.randint(self.n_learners)] = 1

        s = utils.state_to_teacher_state(s, landmarks, landmarks_flags,
                                         target_learner)
        s = utils.add_phase_to_state(s, phase)

        s_init = copy.deepcopy(s)

        subs_learner = [
            self.get_learners_subpolicies() for _ in range(self.n_learners)
        ]
        subs_teacher = self.get_teachers_subpolicies()

        teacher_state = {}
        learner_state = [{} for _ in range(self.n_learners)]

        stop_flag = False
        set_flag = False

        # ---- Alice's turn (random actions) ----
        while True:

            tA = tA + 1

            if not set_flag:

                set_flag = np.random.rand() < set_probability

                if tA >= max_timestep_alice:
                    set_flag = True

                if set_flag:
                    landmarks = np.array([
                        copy.deepcopy(agent.get_pos())
                        for agent in self.env.agents
                    ])
                    landmarks_flags = np.zeros(landmarks_flags.shape)
                    phase = 1

            actions_detached = self.teachers.random_act()
            s_t, r, done, i = self.env.step(copy.deepcopy(actions_detached))
            s_t = utils.state_to_teacher_state(s_t, landmarks, landmarks_flags,
                                               target_learner)
            s_t = utils.add_phase_to_state(s_t, phase)

            stop_flag = np.random.rand() < stop_probability

            if tA >= max_timestep_alice:
                stop_flag = True

            if stop_flag:

                finish_zone, finish_zone_radius = utils.compute_finish_zone(
                    np.array([
                        copy.deepcopy(agent.get_pos())
                        for agent in self.env.agents
                    ]))

                # Terminal teacher transition; pushed once the learners finish.
                teacher_state['s'] = copy.deepcopy(s)
                teacher_state['s_t'] = copy.deepcopy(s_t)
                teacher_state['a'] = copy.deepcopy(actions_detached)
                teacher_state['d'] = True
                s = s_t
                break

            obs = np.hstack((np.array(s_init), np.array(s)))

            obs_t = np.hstack((np.array(s_init), np.array(s_t)))

            self.teachers.push_sample(obs, actions_detached, [0] * self.env.n,
                                      False, obs_t, subs_teacher)
            s = s_t

        learners_results = np.zeros(self.n_learners)
        # Integer step counters, as in self_play_repeat.
        learners_step = np.zeros(self.n_learners).astype(int)

        # ---- Each learner (random actions) replays Alice's scenario ----
        for learner in range(self.n_learners):

            np.random.seed(seed)
            # landmark_flags is passed as in self_play_repeat, so the env
            # agrees with the flags encoded in the teacher state.
            s = self.env.reset(landmark_positions=landmarks,
                               landmark_flags=landmarks_flags,
                               finish_zone_position=finish_zone,
                               finish_zone_radius=finish_zone_radius)

            while True:

                learners_step[learner] += 1

                actions_detached = self.learners[learner].random_act()
                s_t, _, solved, _ = self.env.step(
                    copy.deepcopy(actions_detached))

                if learners_step[learner] >= max_timestep_bob or solved:

                    learner_state[learner]['s'] = copy.deepcopy(s)
                    learner_state[learner]['s_t'] = copy.deepcopy(s_t)
                    learner_state[learner]['a'] = copy.deepcopy(
                        actions_detached)
                    learner_state[learner]['d'] = solved
                    break

                self.learners[learner].push_sample(s, actions_detached,
                                                   [0] * self.env.n, solved,
                                                   s_t, subs_learner[learner])

                s = s_t

            learners_results[learner] = 1 if solved else 0

        obs = np.hstack((np.array(s_init), np.array(teacher_state['s'])))
        obs_t = np.hstack((np.array(s_init), np.array(teacher_state['s_t'])))

        R_A = [
            2 * learners_results[np.argmax(target_learner)] -
            np.sum(learners_results)
        ] * self.env.n

        self.teachers.push_sample(obs, teacher_state['a'], R_A,
                                  teacher_state['d'], obs_t, subs_teacher)

        # Each learner's terminal done flag is its own result; previously the
        # last learner's `solved` was reused for every learner.
        for learner in range(self.n_learners):
            self.learners[learner].push_sample(
                learner_state[learner]['s'], learner_state[learner]['a'],
                [learners_results[learner]] * self.env.n,
                bool(learners_results[learner]),
                learner_state[learner]['s_t'], subs_learner[learner])