Example #1
 if cfg.test:
     cfg_log.load(CFG_FILE)
     # load data
     df_train, df_test, df_rate = load_data(cfg)
     rl_returns = []
     naked_returns = []
     covered_returns = []
     delta_returns = []
     env = DummyVecEnv([lambda: HedgeEnv(df_test, df_rate, cfg)])
     T = env.get_attr('T')[0]
     model = DDPG.load(TEST_MODEL, env=env)  # load() is a classmethod; keep its return value instead of discarding it
     delta = DeltaHedge()
     for i in range(cfg.test_times):
         # rl
         env.set_attr("b_rl", True)
         obs = env.reset()  # every time, create a new transaction
         naked_returns.append(naked(env))
         covered_returns.append(covered(env))
         for _ in range(T):
             action, _states = model.predict(obs)
             obs, rewards, done, info = env.step(action)
             # env.render()
         rl_returns.append(env.get_attr('final_reward')[0])
         env.env_method('restart')  # only trace back to the initial state
         env.set_attr("b_rl", False)
         # delta
         for _ in range(T):
             action = delta.make_decision(env)
             obs, rewards, done, info = env.step(action)
             # env.render()
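         # presumably the delta result is recorded here, mirroring the RL
         # branch above (the snippet ends before this line):
         # delta_returns.append(env.get_attr('final_reward')[0])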
Example #2
def manytest(
        freq_dict,
        dir_dict,
        threshold_list,
        DATE_PREFIX="0418",
        startyear=2000,
        endyear=2019,
        SAVE_DIR_PREFIX="./output/",
        LOAD_DIR="./output/306",
        MODEL_FILE_PREFIX="BRZ_TW_NASDAQ-Selected_Trans-withleakage+RSI-200000-"
):
    ENV_PARAM = {}
    # Set Default Values
    for key in DEFAULT_PARAMETER:
        if key not in ENV_PARAM or type(ENV_PARAM[key]) is not type(DEFAULT_PARAMETER[key]):
            ENV_PARAM[key] = DEFAULT_PARAMETER[key]

    testEnvSettings = getTestEnvSettings()
    model_list = []
    for i in range(10):
        model_name = MODEL_FILE_PREFIX + str(i) + "-model.model"
        model = PPO2.load(path.join(LOAD_DIR, model_name))
        model_list.append(model)

    for freq in dir_dict:
        SAVE_DIR = "./output/" + dir_dict[freq]
        ENV_PARAM['SAVE_DIR'] = SAVE_DIR
        if not os.path.exists(SAVE_DIR):
            os.makedirs(SAVE_DIR)
        for thres in threshold_list:
            ENV_PARAM['MDD_window'] = freq_dict[freq]
            ENV_PARAM['MDD_threshold'] = thres

            testEnv = DummyVecEnv([
                lambda: RebalancingEnv(df_dict=testEnvSettings["df_dict"],
                                       col_list=testEnvSettings["col_list"],
                                       isTraining=False,
                                       env_param=ENV_PARAM)
            ])

            for start in range(startyear, endyear - 3):
                testStartDate = pd.to_datetime(str(start) + "-01-01")
                ENV_PARAM['testStartDate'] = testStartDate
                for end in range(start + 4, endyear + 1):
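                    # e.g. start=2000 pairs with end=2004..endyear, so each
                    # test window spans at least five calendar years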
                    print("@@@@@@@@@@@@@@@@@@@", thres, freq, start, end,
                          "@@@@@@@@@@@@@@@@@@@")
                    testEndDate = pd.to_datetime(str(end) + "-12-31")
                    testEnv.set_attr("roughStartDate", testStartDate)
                    testEnv.set_attr("roughEndDate", testEndDate)
                    testEnv.reset()

                    VAIRABLE_PREFIX = "TEST_" + (
                        "%.2f" % thres
                    ) + "_" + freq + "Crisis_" + str(start) + "_" + str(end)
                    common_fileName_prefix = DATE_PREFIX + "-" + VAIRABLE_PREFIX + "-"
                    summary_fileName_suffix = "summary-X.out"
                    detail_fileName_suffix = "detailed-ModelNo_X.out"

                    detail_fileName_model = common_fileName_prefix + '-' + detail_fileName_suffix
                    summary_fileName_model = common_fileName_prefix + '-' + summary_fileName_suffix

                    for modelNo in range(10):
                        model = model_list[modelNo]
                        profit_list = []
                        act_profit_list = []
                        detail_list = []
                        print("\n============= START TESTING " + str(modelNo) +
                              " =============\n")
                        obs = testEnv.reset()
                        final_result = []
                        tstep = 200000
                        # use the day index within the test window as testNo
                        for testNo in range((testEndDate - testStartDate).days):
                            action, _states = model.predict(obs)
                            if np.isnan(action).any():
                                print(testNo)
                            obs, rewards, done, info = testEnv.step(action)
                            if done:
                                print("Done")
                                break
                            profit_list.append(info[0]['profit'])
                            act_profit_list.append(info[0]['actual_profit'])
                            singleDay_record = testEnv.render(mode="detail")
                            singleDay_record['testNo'] = testNo
                            singleDay_record['rewards'] = rewards[0]
                            detail_list.append(singleDay_record)

                            if testNo % 365 == 0:
                                print("\n======= TESTING " + str(testNo) +
                                      " =======")
                                testEnv.render()

                        detail_fileName = (detail_fileName_model[:-5] +
                                           str(tstep) + '-' + str(modelNo) +
                                           detail_fileName_model[-4:])
                        with open(path.join(SAVE_DIR, detail_fileName), "wb") as f:
                            pickle.dump(detail_list, f)

                        final_result.append({
                            # "trainStart": trainStartDate,
                            # "trainEnd": trainEndDate,
                            "testStart": testStartDate,
                            "testEnd": testEndDate,
                            "train_step": tstep,
                            "mean": np.mean(profit_list),
                            "max": np.max(profit_list),
                            "min": np.min(profit_list),
                            "std": np.std(profit_list),
                            "final": profit_list[-1],
                            "act_mean": np.mean(act_profit_list),
                            "act_max": np.max(act_profit_list),
                            "act_min": np.min(act_profit_list),
                            "act_std": np.std(act_profit_list),
                            "act_final": act_profit_list[-1],
                            "total_shares_sold": info[0]['total_shares_sold']
                        })

                        summary_fileName = summary_fileName_model[:-5] + str(tstep) + ".out"
                        with open(path.join(SAVE_DIR, summary_fileName), "wb") as f:
                            pickle.dump(final_result, f)
                        print("********* LENGTH: ", len(final_result),
                              " *********")
                        pprint.pprint(final_result[-1])
Example #3
class LagrangianCMDPSolver(CMDPSolverBase):
    """
    Class to solve a CMDP with the Lagrangian method.

    The method is based on "Batch Policy Learning under Constraints" by
    Le et al. The constrained MDP is addressed by solving a sequence of
    unconstrained MDPs. In particular, we alternate between a best response
    (BR) algorithm, which solves the unconstrained problem obtained by fixing
    the Lagrange multipliers, and an online optimization algorithm, which
    sets the multipliers based on the performance of the BR.
    """

    # TODO: Estimate the duality gap for stopping

    def __init__(self, env, br_algo, online_algo, br_kwargs=None, online_kwargs=None, _init_setup_model=True,
                 lagrangian_rounds=10, log_training=False,
                 br_uses_vec_env=False, n_envs=1, use_sub_proc_env=True):
        """
        
        Parameters
        ----------
        env: src.envs.CMDP or None
        br_algo: stable baselines algorithm class
            Best response algorithm
        online_algo: src.online
            Online optimization algorithm class
        br_kwargs: dict
            Keyword arguments for best response
        online_kwargs: dict
            Keyword arguments for online opt algorithm
        _init_setup_model: bool
            Whether to set up the br and online upon initialization
        lagrangian_rounds: int
            Number of times we alternate between br and online
        log_training: bool  
            Whether to log episode rewards and constraints during training
        br_uses_vec_env: bool
            Whether br algorithms needs a vectorized environment            
        n_envs: int 
            Number of environments to use (only relevant for vectorized case)
        use_sub_proc_env: bool
            Whether to use subprocesses for vectorized env (otherwise dummy 
            vec is used)
        """
        self.br_algo = br_algo
        self.online_algo = online_algo
            
        self.br_kwargs = {} if br_kwargs is None else br_kwargs
        online_kwargs = {} if online_kwargs is None else online_kwargs
        self.online_kwargs = online_kwargs.copy()
        
        # Initialize placeholders to fill when setting the environment and 
        # the model
        self.br = None
        self.online = None
        self.unconstrainedMDP = None  # The MDP resulting from the Lagrangian of the CMDP
        
        self._env = None
        self.observation_space = None
        self.action_space = None
        self.env_generator = None
        self.lagrangian_rounds = lagrangian_rounds
        self._log_training = log_training
        self.training_rewards = None
        self.training_constraints = None
        
        # Vectorized environment arguments        
        self.br_uses_vec_env = br_uses_vec_env
        self.use_sub_proc_env = use_sub_proc_env
        self.n_envs = n_envs

        self.set_env(env)

        if _init_setup_model:
            self.setup_model()

    def set_unconstrainedMDP(self):
        """
        Set up the unconstrained Lagrangian MDP.

        It can be set up either as a normal environment, a dummy vectorized
        environment, or a multiprocessing vectorized environment.
        """
        assert self.online is not None, 'Need a value for Lagrange ' \
                                        'multipliers to initialize the ' \
                                        'unconstrained MDP'

        if self.br_uses_vec_env:
            # The function that generate the Lagrangian environment needs to
            # be outside the class to avoid pickling errors with
            # multiprocessing
            lagrangian_env = partial(get_lagrangian_env,
                                     cenv=None,  # Passing _env here is not necessary and slows down serialization a lot
                                     w=self.online.w,
                                     cenv_gen=self.env_generator)
            assert self.env_generator is not None, \
                'Environment generator is necessary for vectorized env'

            # With subprocesses for env
            if self.use_sub_proc_env:
                self.unconstrainedMDP = SubprocVecEnv(
                        [lagrangian_env for _ in range(self.n_envs)])

            # With dummy vec env
            else:
                self.unconstrainedMDP = DummyVecEnv(
                    [lagrangian_env for _ in range(self.n_envs)])
        else:
            lagrangian_env = partial(get_lagrangian_env,
                                     cenv=self._env,
                                     w=self.online.w,
                                     cenv_gen=self.env_generator)
            self.unconstrainedMDP = lagrangian_env()

    def _initialize_online(self):
        if self._env is not None:
            d = self._env.n_constraints + 1
            self.online_kwargs.update({'d': d})
            self.online = self.online_algo(**self.online_kwargs)
        else:
            print('Skipping online initialization since there is no env')

    def update_online(self, keep_multipliers=False):
        """
        Update online optimization algorithm.
        """
        if self.online is not None and keep_multipliers and \
                self._env.n_constraints + 1 == len(self.online.w):
            pass
        else:
            self._initialize_online()

    def setup_model(self):
        """
        Set best response.
        """
        if self.unconstrainedMDP is None:
            self.br = None
        else:
            br_kwargs = self.br_kwargs.copy()
            br_kwargs.update({'env': self.unconstrainedMDP})
            self.br = self.br_algo(**br_kwargs)


    def _setup_learn(self, seed):
        """
        Check the environment, set the seed, and set the logger.

        Parameters
        ----------
        seed: int
            The seed value
        """
        if self._env is None:
            raise ValueError("Error: cannot train the model without a valid environment, please set an environment with"
                             "set_env(self, env) method.")
        if seed is not None:
            set_global_seeds(seed)

    def learn(self, total_timesteps, seed=None, log=False):
        """
        Solve the CMDP alternating BR and online algorithm.

        Parameters
        ----------
        total_timesteps: int
            Total number of timesteps the algorithm is run for. Each
            Lagrangian round (i.e. alternation of br and online) is run to
            total_timesteps/self.lagrangian_rounds.
        seed: int or None
            The random seed
        log: bool
            Print to screen some statistics about the BR training.

        Returns
        -------
        R: float
            Return when evaluating the policy learned by BR in last
            Lagrangian round
        G: np.ndarray
            Constraint when evaluating the policy learned by BR in last
            Lagrangian round
        w: np.ndarray
            Lagrange multipliers
        """

        self._setup_learn(seed)

        if total_timesteps < self.lagrangian_rounds:
            raise ValueError("There should be more time steps than Lagrangian rounds")

        # Number of timesteps per Lagrangian round
        br_time_steps = np.full(self.lagrangian_rounds, int(total_timesteps / self.lagrangian_rounds))
        br_time_steps[-1] += np.mod(total_timesteps, self.lagrangian_rounds)
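        # e.g. total_timesteps=105 with lagrangian_rounds=10 gives
        # br_time_steps = [10, 10, ..., 10, 15]: the remainder goes to the last round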

        # Alternate between br and online
        for ts in br_time_steps:

            # Reset the monitor that tracks the performance of BR on the
            # unconstrained Lagrangian MDP (constraint violation is also
            # tracked)
            if self.br_uses_vec_env:
                self.unconstrainedMDP.env_method('reset_monitor')
            else:
                self.unconstrainedMDP.reset_monitor()
            self.br._init_num_timesteps()  # Reset exploration schedule

            # Train BR on unconstrained MDP
            if log:
                self.br.learn(ts, log_interval=ts)
            else:
                self.br.learn(ts, log_interval=np.inf)

            # Get training performance
            if self.br_uses_vec_env:
                # Get reward and constraints from all envs
                r_tmp = self.unconstrainedMDP.env_method(
                    'get_episode_rewards')
                g_tmp = self.unconstrainedMDP.env_method(
                    'get_episode_constraints')
                current_rewards = np.concatenate(r_tmp)
                current_constraints = np.concatenate(g_tmp)
            else:
                current_rewards = \
                    self.unconstrainedMDP.get_episode_rewards()
                current_constraints = \
                    self.unconstrainedMDP.get_episode_constraints()

            R = np.mean(current_rewards)
            G = np.mean(current_constraints, axis=0)

            # Log info about training
            if self._log_training:
                if self.training_rewards is None:
                    self.training_rewards = np.copy(current_rewards)
                else:
                    self.training_rewards = np.hstack((
                        self.training_rewards, current_rewards))
                # self.training_rewards.append(list(current_rewards))
                if self.training_constraints is None:
                    self.training_constraints = np.copy(current_constraints)
                else:
                    self.training_constraints = np.vstack((
                        self.training_constraints, current_constraints))

            # evaluate performance may be necessary for off-policy methods
            # where the deployed policy is different from the one that
            # collects data (in that case, it would make sense to adjust the
            # multipliers according to the optimized policy and not the
            # exploratory one)
            # R, G = self.evaluate_performance(int(0.2 * ts), min_episodes=5)


            # print('Evaluation r:{}\tEvaluation g {}'.format(R, G))

            # Online algorithm updates multipliers based on BR performance
            self.online.step(-np.append(G, 0))

            # Set new multipliers
            if self.br_uses_vec_env:
                self.unconstrainedMDP.set_attr('lam', self.online.w[:-1])
            else:
                self.unconstrainedMDP.lam = self.online.w[:-1]

        return R, G, self.online.w

    def predict(self, observation, state=None, mask=None, deterministic=True):
        """
        Get the best response action from an observation
        """
        if self.br is not None:
            return self.br.predict(observation, state, mask, deterministic)
        else:
            raise ValueError('Need a valid environment to setup learner and predict its action')

    def action_probability(self, observation, state=None, mask=None, actions=None, logp=False):
        if self.br is not None:
            return self.br.action_probability(observation, state, mask, actions, logp)
        else:
            raise ValueError('Need a valid environment to setup learner and predict its action probabilities')

    def evaluate_performance(self, min_steps, min_episodes):
        """
        Deploy policy learned by BR to evaluate its performance in terms of
        return and constraint violation.

        Parameters
        ----------
        min_steps: int
            Minimum number of steps that we run the environment for
        min_episodes: int
            Minimum number of episodes

        Returns
        -------
        R: float
            Average return across episodes
        G: np.ndarray
            Average constraint value across episodes

        """
        if self.unconstrainedMDP is None:
            raise ValueError('Cannot reset monitor without a valid environment')

        n_episodes = 0
        n_steps = 0
        max_steps = min_steps * 5  # Fix a timeout

        # TODO: If we move to subproc env, we should aim to use the
        #  vectorized env properly here

        if self.br_uses_vec_env:
            # This is equivalent to the non-vectorized case since we operate
            # only on one env. However, we still need to use the vectorized
            # env interface to access the individual attributes and methods.

            # Reset monitor and env
            self.unconstrainedMDP.env_method('reset_monitor')
            obs = self.unconstrainedMDP.env_method('reset', indices=0)[0]

            # Run env
            while (n_episodes < min_episodes or n_steps < min_steps) and not n_steps > max_steps:
                action, _ = self.br.predict(obs, deterministic=True)
                obs, reward, done, info = self.unconstrainedMDP.env_method(
                    'step', action, indices=0)[0]

                if done:
                    n_episodes += 1
                    obs = self.unconstrainedMDP.env_method('reset',
                                                           indices=0)[0]
                n_steps += 1

            # Compute return and constraint
            R = np.mean(self.unconstrainedMDP.env_method(
                'get_episode_rewards', indices=0)[0])
            G = np.mean(self.unconstrainedMDP.env_method(
                'get_episode_constraints', indices=0)[0], axis=0)
        else:
            # Reset monitor and env
            self.unconstrainedMDP.reset_monitor()
            obs = self.unconstrainedMDP.reset()

            # Run env
            while (n_episodes < min_episodes or n_steps < min_steps) and not n_steps > max_steps:
                action, _ = self.br.predict(obs, deterministic=True)
                obs, reward, done, info = self.unconstrainedMDP.step(action)

                if done:
                    n_episodes += 1
                    obs = self.unconstrainedMDP.reset()
                n_steps += 1

            # Compute return and constraint
            R = np.mean(self.unconstrainedMDP.get_episode_rewards())
            G = np.mean(self.unconstrainedMDP.get_episode_constraints(), axis=0)

        return R, G

    def set_env(self, env, keep_multipliers=False, reset_br=False):
        """
        Set a new environment.

        Parameters
        ----------
        env: src.envs.CMDP or callable
            The new environment (an environment-generating callable is
            required when br_uses_vec_env is True)
        keep_multipliers: bool
            Whether to keep the current Lagrange multipliers
        reset_br: bool
            Whether to re-initialize the best response model
        """
        # Clean up resources if vectorized env already exists

        if isinstance(self.unconstrainedMDP, (DummyVecEnv, SubprocVecEnv)):
            self.unconstrainedMDP.close()

        # For vectorized environment we need an environment generating
        # function, otherwise we can simply set the env
        if self.br_uses_vec_env:
            if env is not None:
                assert callable(env), 'An environment-generating callable is ' \
                                      'necessary for algorithms requiring a ' \
                                      'vectorized environment'

                # If necessary, this extra copy of the env can be removed.
                # Need to check all the places where _env is accessed and
                # modify them.
                super().set_env(env())
                self.env_generator = env
        else:
            super().set_env(env)
            self.env_generator = None  # Not needed in non-vectorized case

        if self.get_env() is not None:
            self.update_online(keep_multipliers)
            self.set_unconstrainedMDP()
            if reset_br or self.br is None:
                self.setup_model()
            self.br.set_env(self.unconstrainedMDP)

        self.training_rewards = None
        self.training_constraints = None

    def get_env(self):
        return super().get_env()

    def set_multipliers(self, w):
        if self.online is not None:
            if len(w) != len(self.online.w):
                raise ValueError('Multipliers must have the same length. Old ones have length {}, while new ones have '
                                 'length {}'.format(len(self.online.w), len(w)))
            else:
                self.online.w = w
        else:
            warnings.warn('There is no online algorithm to set the multipliers for')

    def get_multipliers(self):
        return self.online.w

    def get_br_params(self):
        return self.br.get_parameters()

    def set_br_params(self, params):
        self.br.load_parameters(params)

    def get_params(self):
        params = self.get_br_params()
        multipliers = self.get_multipliers()
        params.update({'multipliers': multipliers})
        return params

    def set_params(self, params):
        multipliers = params['multipliers']
        self.set_multipliers(multipliers)

        del params['multipliers']
        self.set_br_params(params)

    def get_training_performance(self):
        if not self._log_training:
            warnings.warn('Log training is set to False and no data was logged')

        return self.training_rewards, self.training_constraints

    @property
    def log_training(self):
        return self._log_training

    @log_training.setter
    def log_training(self, value):
        self._log_training = bool(value)
Example #4
class AsymmetricSelfPlay(TrainingSession):
    def __init__(self, model_builder, model_params, env_params, eval_env_params,
                 train_episodes, eval_episodes, num_evals,
                 switch_frequency, path, seed, num_envs=1):
        super(AsymmetricSelfPlay, self).__init__(model_params, path, seed)

        # log start time
        start_time = time.perf_counter()

        # initialize parallel training environments
        self.logger.debug("Initializing training envs...")
        env1, env2 = [], []

        for i in range(num_envs):
            # no overlap between episodes at each process
            if seed is not None:
                current_seed = seed + (train_episodes // num_envs) * i
            else:
                current_seed = None

            # create one env per process, binding current_seed as a default
            # argument so each lambda keeps its own seed instead of the last
            # value of the loop variable
            env1.append(lambda seed=current_seed: LOCMDraftSelfPlayEnv(
                seed=seed, play_first=True, **env_params))
            env2.append(lambda seed=current_seed: LOCMDraftSelfPlayEnv(
                seed=seed, play_first=False, **env_params))

        # wrap envs in a vectorized env
        self.env1 = DummyVecEnv(env1)
        self.env2 = DummyVecEnv(env2)

        # initialize parallel evaluating environments
        self.logger.debug("Initializing evaluation envs...")
        eval_seed = seed + train_episodes if seed is not None else None
        self.evaluator: Evaluator = Evaluator(eval_env_params, eval_episodes,
                                              eval_seed, num_envs)

        # build the models
        self.logger.debug("Building the models...")
        self.model1 = model_builder(self.env1, seed, **model_params)
        self.model1.adversary = model_builder(self.env2, seed, **model_params)
        self.model2 = model_builder(self.env2, seed, **model_params)
        self.model2.adversary = model_builder(self.env1, seed, **model_params)

        # initialize parameters of adversary models accordingly
        self.model1.adversary.load_parameters(self.model2.get_parameters(), exact_match=True)
        self.model2.adversary.load_parameters(self.model1.get_parameters(), exact_match=True)

        # set adversary models as adversary policies of the self-play envs
        def make_adversary_policy(model, env):
            def adversary_policy(obs):
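                # the adversary model expects a batch of num_envs observations,
                # so pad with zeros and fill only the first slot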
                zero_completed_obs = np.zeros((num_envs,) + env.observation_space.shape)
                zero_completed_obs[0, :] = obs

                actions, _ = model.adversary.predict(zero_completed_obs)

                return actions[0]

            return adversary_policy

        self.env1.set_attr('adversary_policy',
                           make_adversary_policy(self.model1, self.env1))
        self.env2.set_attr('adversary_policy',
                           make_adversary_policy(self.model2, self.env2))

        # create necessary folders
        os.makedirs(self.path + '/role0', exist_ok=True)
        os.makedirs(self.path + '/role1', exist_ok=True)

        # set tensorboard log dirs (the stable-baselines attribute is
        # tensorboard_log)
        self.model1.tensorboard_log = self.path + '/role0'
        self.model2.tensorboard_log = self.path + '/role1'

        # save parameters
        self.train_episodes = train_episodes
        self.eval_episodes = eval_episodes
        self.num_evals = num_evals
        self.switch_frequency = switch_frequency
        self.eval_frequency = train_episodes / num_evals
        self.num_switches = math.ceil(train_episodes / switch_frequency)

        # initialize control attributes
        self.model1.role_id, self.model2.role_id = 0, 1
        self.model1.last_eval, self.model1.next_eval = None, 0
        self.model2.last_eval, self.model2.next_eval = None, 0
        self.model1.last_switch, self.model1.next_switch = 0, self.switch_frequency
        self.model2.last_switch, self.model2.next_switch = 0, self.switch_frequency

        # initialize results
        self.checkpoints = [], []
        self.win_rates = [], []
        self.episode_lengths = [], []
        self.action_histograms = [], []

        # log end time
        end_time = time.perf_counter()

        self.logger.debug("Finished initializing training session "
                          f"({round(end_time - start_time, ndigits=3)}s).")

    def _training_callback(self, _locals=None, _globals=None):
        model = _locals['self']
        # convert timesteps to episodes (a draft episode lasts 30 timesteps)
        episodes_so_far = model.num_timesteps // 30

        # if it is time to evaluate, do so
        if episodes_so_far >= model.next_eval:
            # save model
            model_path = f'{self.path}/role{model.role_id}/{episodes_so_far}'
            model.save(model_path)
            save_model_as_json(model, self.params['activation'], model_path)
            self.logger.debug(f"Saved model at {model_path}.zip/json.")

            # evaluate the model
            self.logger.info(f"Evaluating model {model.role_id} "
                             f"({episodes_so_far} episodes)...")
            start_time = time.perf_counter()

            mean_reward, ep_length, act_hist = \
                self.evaluator.run(RLDraftAgent(model),
                                   play_first=model.role_id == 0)

            end_time = time.perf_counter()
            self.logger.info(f"Finished evaluating "
                             f"({round(end_time - start_time, 3)}s). "
                             f"Avg. reward: {mean_reward}")

            # save the results
            self.checkpoints[model.role_id].append(episodes_so_far)
            self.win_rates[model.role_id].append((mean_reward + 1) / 2)
            self.episode_lengths[model.role_id].append(ep_length)
            self.action_histograms[model.role_id].append(act_hist)

            # update control attributes
            model.last_eval = episodes_so_far
            model.next_eval += self.eval_frequency

            # write partial results to file
            self._save_results()

        # if it is time to switch roles, return False to stop this learn() call
        time_to_switch = episodes_so_far >= model.next_switch

        if time_to_switch:
            model.last_switch = episodes_so_far
            model.next_switch += self.switch_frequency

        return not time_to_switch

    def _train(self):
        # save and evaluate starting models
        self._training_callback({'self': self.model1})
        self._training_callback({'self': self.model2})

        try:
            self.logger.debug(f"Training will switch models every "
                              f"{self.switch_frequency} episodes")

            for _ in range(self.num_switches):
                # train the first player model
                self.model1.learn(total_timesteps=REALLY_BIG_INT,
                                  reset_num_timesteps=False,
                                  callback=self._training_callback)
                self.logger.debug(f"Model {self.model1.role_id} trained for "
                                  f"{self.model1.num_timesteps // 30} episodes. "
                                  f"Switching to model {self.model2.role_id}.")

                # train the second player model
                self.model2.learn(total_timesteps=REALLY_BIG_INT,
                                  reset_num_timesteps=False,
                                  callback=self._training_callback)
                self.logger.debug(f"Model {self.model2.role_id} trained for "
                                  f"{self.model2.num_timesteps // 30} episodes. "
                                  f"Switching to model {self.model1.role_id}.")

                # update parameters of adversary models
                self.model1.adversary.load_parameters(self.model2.get_parameters(),
                                                      exact_match=True)
                self.model2.adversary.load_parameters(self.model1.get_parameters(),
                                                      exact_match=True)
                self.logger.debug("Parameters of adversary networks updated.")
        except KeyboardInterrupt:
            pass

        self.logger.debug(f"Training ended at {self.model1.num_timesteps // 30} "
                          f"episodes")

        # save and evaluate final models, if not done yet
        if len(self.win_rates[0]) < self.num_evals:
            self._training_callback({'self': self.model1})

        if len(self.win_rates[1]) < self.num_evals:
            self._training_callback({'self': self.model2})

        # close the envs
        for e in (self.env1, self.env2, self.evaluator):
            e.close()