Example #1
# Standard and third-party imports required by this example; the project's own
# components (MADDPG, IQL, ReplayBuffer, ModelSaver, Logger, Plotter,
# FrameSaver, USE_CUDA, GOAL_EPSILON) are assumed to come from the surrounding
# repository and are not shown here.
import argparse
import random
import time

import numpy as np
import torch
from torch.autograd import Variable
from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete


class Train:
    def __init__(self):
        self.parser = argparse.ArgumentParser(
            "Reinforcement Learning experiments for multiagent environments")
        self.parse_args()
        self.arglist = self.parser.parse_args()

    def parse_default_args(self):
        """
        Parse default arguments for MARL training script
        """
        # algorithm
        self.parser.add_argument("--alg",
                                 type=str,
                                 default="maddpg",
                                 help="name of the algorithm to use")
        self.parser.add_argument("--hidden_dim", default=128, type=int)

        # curiosity
        self.parser.add_argument("--curiosity",
                                 type=str,
                                 default=None,
                                 help="name of curiosity to use")
        self.parser.add_argument(
            "--joint_curiosity",
            action="store_true",
            default=False,
            help="flag if curiosity should be applied jointly for all agents",
        )
        self.parser.add_argument(
            "--curiosity_hidden_dim",
            type=int,
            default=64,
            help="curiosity internal state representation size",
        )
        self.parser.add_argument(
            "--curiosity_state_rep_size",
            type=int,
            default=64,
            help="curiosity internal state representation size",
        )
        self.parser.add_argument(
            "--count_key_dim",
            type=int,
            default=32,
            help="key dimensionality of hash-count-based curiosity",
        )
        self.parser.add_argument("--count_decay",
                                 type=float,
                                 default=1,
                                 help="factor for count decay speed")
        self.parser.add_argument("--eta",
                                 type=int,
                                 default=5,
                                 help="curiosity loss weighting factor")
        self.parser.add_argument(
            "--curiosity_lr",
            type=float,
            default=5e-6,
            help="learning rate for curiosity optimizer",
        )

        # training length
        self.parser.add_argument("--num_episodes",
                                 type=int,
                                 default=25000,
                                 help="number of episodes")
        self.parser.add_argument("--max_episode_len",
                                 type=int,
                                 default=25,
                                 help="maximum episode length")

        # core training parameters
        self.parser.add_argument("--n_training_threads",
                                 default=6,
                                 type=int,
                                 help="number of training threads")
        self.parser.add_argument(
            "--no_rewards",
            action="store_true",
            default=False,
            help="flag if no rewards should be used",
        )
        self.parser.add_argument(
            "--sparse_rewards",
            action="store_true",
            default=False,
            help="flag if sparse rewards should be used",
        )
        self.parser.add_argument("--sparse_freq",
                                 type=int,
                                 default=25,
                                 help="number of steps before sparse rewards")
        self.parser.add_argument("--gamma",
                                 type=float,
                                 default=0.9,
                                 help="discount factor")
        self.parser.add_argument(
            "--tau",
            type=float,
            default=0.01,
            help="tau as stepsize for target network updates")
        self.parser.add_argument("--lr",
                                 type=float,
                                 default=0.01,
                                 help="learning rate for Adam optimizer")
        self.parser.add_argument("--dropout_p",
                                 type=float,
                                 default=0.0,
                                 help="Dropout probability")
        self.parser.add_argument("--seed",
                                 type=int,
                                 default=None,
                                 help="random seed used throughout training")
        self.parser.add_argument("--steps_per_update",
                                 type=int,
                                 default=100,
                                 help="number of steps before updates")

        self.parser.add_argument("--buffer_capacity",
                                 type=int,
                                 default=int(1e6),
                                 help="Replay buffer capacity")
        self.parser.add_argument(
            "--batch_size",
            type=int,
            default=1024,
            help="number of episodes to optimize at the same time",
        )

        # exploration settings
        self.parser.add_argument(
            "--no_exploration",
            action="store_true",
            default=False,
            help="flag if no exploration should be used",
        )
        self.parser.add_argument("--decay_factor",
                                 type=float,
                                 default=0.99999,
                                 help="exploration decay factor")
        self.parser.add_argument("--exploration_bonus",
                                 type=float,
                                 default=1.0,
                                 help="exploration bonus value")
        self.parser.add_argument("--n_exploration_eps",
                                 default=25000,
                                 type=int)
        self.parser.add_argument("--init_noise_scale", default=0.3, type=float)
        self.parser.add_argument("--final_noise_scale",
                                 default=0.0,
                                 type=float)

        # visualisation
        self.parser.add_argument("--display",
                                 action="store_true",
                                 default=False)
        self.parser.add_argument("--save_frames",
                                 action="store_true",
                                 default=False)
        self.parser.add_argument("--plot",
                                 action="store_true",
                                 default=False,
                                 help="plot reward and exploration bonus")
        self.parser.add_argument("--eval_frequency",
                                 default=100,
                                 type=int,
                                 help="frequency of evaluation episodes")
        self.parser.add_argument("--eval_episodes",
                                 default=5,
                                 type=int,
                                 help="number of evaluation episodes")
        self.parser.add_argument(
            "--dump_losses",
            action="store_true",
            default=False,
            help="dump losses after computation",
        )

        # run name for store path
        self.parser.add_argument("--run",
                                 type=str,
                                 default="default",
                                 help="run name for stored paths")

        # model storing
        self.parser.add_argument(
            "--save_models_dir",
            type=str,
            default="models",
            help="path where models should be saved",
        )
        self.parser.add_argument("--save_interval", default=1000, type=int)
        self.parser.add_argument(
            "--load_models",
            type=str,
            default=None,
            help="path where models should be loaded from if set",
        )
        self.parser.add_argument(
            "--load_models_extension",
            type=str,
            default="final",
            help="name extension for models to load",
        )

    def parse_args(self):
        """
        parse own arguments
        """
        self.parse_default_args()

    def extract_sizes(self, spaces):
        """
        Extract space dimensions
        :param spaces: list of Gym spaces
        :return: list of ints with sizes for each agent
        """
        sizes = []
        for space in spaces:
            if isinstance(space, Box):
                size = sum(space.shape)
            elif isinstance(space, Dict):
                size = sum(self.extract_sizes(space.values()))
            elif isinstance(space, Discrete) or isinstance(space, MultiBinary):
                size = space.n
            elif isinstance(space, MultiDiscrete):
                size = sum(space.nvec)
            else:
                raise ValueError("Unknown class of space: ", type(space))
            sizes.append(size)
        return sizes

    def create_environment(self):
        """
        Create environment instance
        :return: environment (gym interface), env_name, task_name, n_agents, observation_sizes,
                 action_sizes, discrete_actions
        """
        raise NotImplementedError()

    def reset_environment(self):
        """
        Reset environment for new episode
        :return: observation (as torch tensor)
        """
        raise NotImplementedError

    def select_actions(self, obs, explore=True):
        """
        Select actions for agents
        :param obs: joint observation
        :param explore: flag if exploration should be used
        :return: action_tensor, action_list
        """
        raise NotImplementedError()

    def environment_step(self, actions):
        """
        Take step in the environment
        :param actions: actions to apply for each agent
        :return: reward, done, next_obs (as PyTorch tensors)
        """
        raise NotImplementedError()

    def environment_render(self):
        """
        Render visualisation of environment
        """
        raise NotImplementedError()

    def eval(self, ep, n_agents):
        """
        Execute evaluation episode without exploration
        :param ep: episode number
        :param n_agents: number of agents in task
        :return: episode_rewards, episode_length, done
        """
        obs = self.reset_environment()
        self.alg.reset(ep)

        episode_rewards = np.array([0.0] * n_agents)
        episode_length = 0
        done = False

        while not done and episode_length < self.arglist.max_episode_len:
            torch_obs = [
                Variable(torch.Tensor(obs[i]), requires_grad=False)
                for i in range(n_agents)
            ]

            actions, _ = self.select_actions(torch_obs, False)
            rewards, dones, next_obs = self.environment_step(actions)

            episode_rewards += rewards

            obs = next_obs
            episode_length += 1
            done = all(dones)

        return episode_rewards, episode_length, done

    def train(self):
        """
        Abstract training flow
        """
        # set random seeds before model creation
        if self.arglist.seed is not None:
            random.seed(self.arglist.seed)
            np.random.seed(self.arglist.seed)
            torch.manual_seed(self.arglist.seed)
            torch.cuda.manual_seed(self.arglist.seed)
            if torch.cuda.is_available():
                torch.backends.cudnn.deterministic = True
                torch.backends.cudnn.benchmark = False

        # use number of threads if no GPUs are available
        if not USE_CUDA:
            torch.set_num_threads(self.arglist.n_training_threads)

        env, env_name, task_name, n_agents, observation_sizes, action_sizes, discrete_actions = (
            self.create_environment())
        self.env = env
        self.n_agents = n_agents

        steps = self.arglist.num_episodes * self.arglist.max_episode_len
        # steps-th root of GOAL_EPSILON
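        # (so that decay_epsilon ** steps == GOAL_EPSILON, i.e. epsilon reaches
        # the goal value exactly after `steps` environment steps)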
        decay_epsilon = GOAL_EPSILON**(1 / float(steps))
        self.arglist.decay_factor = decay_epsilon
        print("Epsilon is decaying with factor %.7f to %.3f over %d steps." %
              (decay_epsilon, GOAL_EPSILON, steps))

        print("Observation sizes: ", observation_sizes)
        print("Action sizes: ", action_sizes)

        # Create curiosity instances
        if self.arglist.curiosity is None:
            print("No curiosity is to be used!")
        elif self.arglist.curiosity == "icm":
            print("Training uses Intrinsic Curiosity Module (ICM)!")
        elif self.arglist.curiosity == "rnd":
            print("Training uses Random Network Distillation (RND)!")
        elif self.arglist.curiosity == "count":
            print("Training uses hash-based counting exploration bonus!")
        else:
            raise ValueError("Unknown curiosity: " + self.arglist.curiosity)

        # create algorithm trainer
        if self.arglist.alg == "maddpg":
            self.alg = MADDPG(n_agents, observation_sizes, action_sizes,
                              discrete_actions, self.arglist)
            print(
                "Training multi-agent deep deterministic policy gradient (MADDPG) on "
                + env_name + " environment")
        elif self.arglist.alg == "iql":
            self.alg = IQL(n_agents, observation_sizes, action_sizes,
                           discrete_actions, self.arglist)
            print("Training independent q-learning (IQL) on " + env_name +
                  " environment")
        else:
            raise ValueError("Unknown algorithm: " + self.arglist.alg)

        self.memory = ReplayBuffer(
            self.arglist.buffer_capacity,
            n_agents,
            observation_sizes,
            action_sizes,
            self.arglist.no_rewards,
        )

        # set random seeds past model creation
        if self.arglist.seed is not None:
            random.seed(self.arglist.seed)
            np.random.seed(self.arglist.seed)
            torch.manual_seed(self.arglist.seed)
            torch.cuda.manual_seed(self.arglist.seed)
            if torch.cuda.is_available():
                torch.backends.cudnn.deterministic = True
                torch.backends.cudnn.benchmark = False

        if self.arglist.load_models is not None:
            print("Loading models from " + self.arglist.load_models +
                  " with extension " + self.arglist.load_models_extension)
            self.alg.load_model_networks(
                self.arglist.load_models,
                "_" + self.arglist.load_models_extension)

        self.model_saver = ModelSaver(self.arglist.save_models_dir,
                                      self.arglist.run, self.arglist.alg)
        self.logger = Logger(
            n_agents,
            self.arglist.eta,
            task_name,
            self.arglist.run,
            self.arglist.alg,
            self.arglist.curiosity,
        )
        self.plotter = Plotter(
            self.logger,
            n_agents,
            self.arglist.eval_frequency,
            task_name,
            self.arglist.run,
            self.arglist.alg,
            self.arglist.curiosity,
        )
        if self.arglist.save_frames:
            self.frame_saver = FrameSaver(self.arglist.eta, task_name,
                                          self.arglist.run, self.arglist.alg)

        print("Starting iterations...")
        start_time = time.time()
        t = 0

        for ep in range(self.arglist.num_episodes):
            obs = self.reset_environment()
            self.alg.reset(ep)

            episode_rewards = np.array([0.0] * n_agents)
            if self.arglist.sparse_rewards:
                sparse_rewards = np.array([0.0] * n_agents)
            episode_length = 0
            done = False
            interesting_episode = False

            while not done and episode_length < self.arglist.max_episode_len:
                torch_obs = [
                    Variable(torch.Tensor(obs[i]), requires_grad=False)
                    for i in range(n_agents)
                ]

                actions, agent_actions = self.select_actions(
                    torch_obs, not self.arglist.no_exploration)
                rewards, dones, next_obs = self.environment_step(actions)

                episode_rewards += rewards
                if self.arglist.sparse_rewards:
                    sparse_rewards += rewards

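                # with --no_rewards all extrinsic rewards are zeroed out;
                # with --sparse_rewards the reward accumulated so far is only
                # passed to the learner every sparse_freq steps, scaled by
                # 1 / sparse_freq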
                if self.arglist.no_rewards:
                    rewards = [0.0] * n_agents
                elif self.arglist.sparse_rewards:
                    if (episode_length + 1) % self.arglist.sparse_freq == 0:
                        rewards = list(sparse_rewards /
                                       self.arglist.sparse_freq)
                    else:
                        rewards = [0.0] * n_agents
                self.memory.push(obs, agent_actions, rewards, next_obs, dones)

                t += 1

                if (len(self.memory) >= self.arglist.batch_size
                        and (t % self.arglist.steps_per_update) == 0):
                    losses = self.alg.update(self.memory, USE_CUDA)
                    self.logger.log_losses(ep, losses)
                    if self.arglist.dump_losses:
                        self.logger.dump_losses(1)

                # for displaying learned policies
                if self.arglist.display:
                    self.environment_render()
                if self.arglist.save_frames:
                    self.frame_saver.add_frame(
                        self.env.render("rgb_array")[0], ep)
                    if self.arglist.curiosity is not None:
                        curiosities = self.alg.get_curiosities(
                            obs, agent_actions, next_obs)
                        interesting = self.frame_saver.save_interesting_frame(
                            curiosities)
                        interesting_episode = interesting_episode or interesting

                obs = next_obs
                episode_length += 1
                done = all(dones)

            if ep % self.arglist.eval_frequency == 0:
                eval_rewards = np.zeros((self.arglist.eval_episodes, n_agents))
                for i in range(self.arglist.eval_episodes):
                    ep_rewards, _, _ = self.eval(ep, n_agents)
                    eval_rewards[i, :] = ep_rewards
                if self.arglist.alg == "maddpg":
                    self.logger.log_episode(
                        ep,
                        eval_rewards.mean(0),
                        eval_rewards.var(0),
                        self.alg.agents[0].get_exploration_scale(),
                    )
                if self.arglist.alg == "iql":
                    self.logger.log_episode(ep, eval_rewards.mean(0),
                                            eval_rewards.var(0),
                                            self.alg.agents[0].epsilon)
                self.logger.dump_episodes(1)
            if ep % 100 == 0 and ep > 0:
                duration = time.time() - start_time
                self.logger.dump_train_progress(ep, self.arglist.num_episodes,
                                                duration)

            if interesting_episode:
                self.frame_saver.save_episode_gif()

            if ep % (self.arglist.save_interval // 2) == 0 and ep > 0:
                # update plots
                self.plotter.update_reward_plot(self.arglist.plot)
                self.plotter.update_exploration_plot(self.arglist.plot)
                self.plotter.update_alg_loss_plot(self.arglist.plot)
                if self.arglist.curiosity is not None:
                    self.plotter.update_cur_loss_plot(self.arglist.plot)
                    self.plotter.update_intrinsic_reward_plot(
                        self.arglist.plot)

            if ep % self.arglist.save_interval == 0 and ep > 0:
                # save plots
                print("Remove previous plots")
                self.plotter.clear_plots()
                print("Saving intermediate plots")
                self.plotter.save_reward_plot(str(ep))
                self.plotter.save_exploration_plot(str(ep))
                self.plotter.save_alg_loss_plots(str(ep))
                self.plotter.save_cur_loss_plots(str(ep))
                self.plotter.save_intrinsic_reward_plot(str(ep))
                # save models
                print("Remove previous models")
                self.model_saver.clear_models()
                print("Saving intermediate models")
                self.model_saver.save_models(self.alg, str(ep))
                # save logs
                print("Remove previous logs")
                self.logger.clear_logs()
                print("Saving intermediate logs")
                self.logger.save_episodes(extension=str(ep))
                self.logger.save_losses(extension=str(ep))
                # save parameter log
                self.logger.save_parameters(
                    env_name,
                    task_name,
                    n_agents,
                    observation_sizes,
                    action_sizes,
                    discrete_actions,
                    self.arglist,
                )

        duration = time.time() - start_time
        print("Overall duration: %.2fs" % duration)

        print("Remove previous plots")
        self.plotter.clear_plots()
        print("Saving final plots")
        self.plotter.save_reward_plot("final")
        self.plotter.save_exploration_plot("final")
        self.plotter.save_alg_loss_plots("final")
        self.plotter.save_cur_loss_plots("final")
        self.plotter.save_intrinsic_reward_plot("final")

        # save models
        print("Remove previous models")
        self.model_saver.clear_models()
        print("Saving final models")
        self.model_saver.save_models(self.alg, "final")

        # save logs
        print("Remove previous logs")
        self.logger.clear_logs()
        print("Saving final logs")
        self.logger.save_episodes(extension="final")
        self.logger.save_losses(extension="final")
        self.logger.save_duration_cuda(duration, torch.cuda.is_available())

        # save parameter log
        self.logger.save_parameters(
            env_name,
            task_name,
            n_agents,
            observation_sizes,
            action_sizes,
            discrete_actions,
            self.arglist,
        )

        env.close()

    if __name__ == "__main__":
        train = Train()
        train.train()
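
The helper below is a minimal, standalone sketch of the per-space size computation that Train.extract_sizes applies to each agent's observation and action spaces. It only assumes the gym package is installed; the function name space_size and the example spaces are illustrative and not part of the original code.

# Standalone sketch mirroring the size logic of Train.extract_sizes.
from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete


def space_size(space):
    """Return the flat size of a single Gym space (mirrors extract_sizes)."""
    if isinstance(space, Box):
        return sum(space.shape)
    if isinstance(space, Dict):
        return sum(space_size(s) for s in space.spaces.values())
    if isinstance(space, (Discrete, MultiBinary)):
        return space.n
    if isinstance(space, MultiDiscrete):
        return sum(space.nvec)
    raise ValueError(f"Unknown class of space: {type(space)}")


if __name__ == "__main__":
    # two hypothetical agents: a 4-dimensional continuous observation space
    # and a 5-action discrete action space
    spaces = [Box(low=0.0, high=1.0, shape=(4,)), Discrete(5)]
    print([space_size(s) for s in spaces])  # -> [4, 5]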
Example #2
class Train:
    def __init__(self):
        self.parser = argparse.ArgumentParser(
            "Reinforcement Learning experiments for multiagent environments")
        self.parse_args()
        self.arglist = self.parser.parse_args()

    def parse_default_args(self):
        """
        Parse default arguments for MARL training script
        """
        # algorithm
        self.parser.add_argument("--hidden_dim", default=128, type=int)
        self.parser.add_argument("--shared_experience",
                                 action="store_true",
                                 default=False)
        self.parser.add_argument("--shared_lambda", default=1.0, type=float)
        self.parser.add_argument("--targets",
                                 type=str,
                                 default="simple",
                                 help="target computation used for DQN")

        # training length
        self.parser.add_argument("--num_episodes",
                                 type=int,
                                 default=120000,
                                 help="number of episodes")
        self.parser.add_argument("--max_episode_len",
                                 type=int,
                                 default=25,
                                 help="maximum episode length")

        # core training parameters
        self.parser.add_argument("--n_training_threads",
                                 default=1,
                                 type=int,
                                 help="number of training threads")
        self.parser.add_argument("--gamma",
                                 type=float,
                                 default=0.99,
                                 help="discount factor")
        self.parser.add_argument(
            "--tau",
            type=float,
            default=0.05,
            help="tau as stepsize for target network updates")
        self.parser.add_argument(
            "--lr",
            type=float,
            default=0.0001,
            help="learning rate for Adam optimizer"  #use 5e-5 for RWARE
        )
        self.parser.add_argument("--seed",
                                 type=int,
                                 default=None,
                                 help="random seed used throughout training")
        self.parser.add_argument("--steps_per_update",
                                 type=int,
                                 default=1,
                                 help="number of steps before updates")

        self.parser.add_argument("--buffer_capacity",
                                 type=int,
                                 default=int(1e6),
                                 help="Replay buffer capacity")
        self.parser.add_argument(
            "--batch_size",
            type=int,
            default=128,
            help="number of episodes to optimize at the same time",
        )
        self.parser.add_argument("--epsilon",
                                 type=float,
                                 default=1.0,
                                 help="epsilon value")
        self.parser.add_argument("--goal_epsilon",
                                 type=float,
                                 default=0.01,
                                 help="epsilon target value")
        self.parser.add_argument("--epsilon_decay",
                                 type=float,
                                 default=10,
                                 help="epsilon decay value")
        self.parser.add_argument("--epsilon_anneal_slow",
                                 action="store_true",
                                 default=False,
                                 help="anneal epsilon slowly")

        # visualisation
        self.parser.add_argument("--render",
                                 action="store_true",
                                 default=False)
        self.parser.add_argument("--eval_frequency",
                                 default=50,
                                 type=int,
                                 help="frequency of evaluation episodes")
        self.parser.add_argument("--eval_episodes",
                                 default=5,
                                 type=int,
                                 help="number of evaluation episodes")
        self.parser.add_argument("--run",
                                 type=str,
                                 default="default",
                                 help="run name for stored paths")
        self.parser.add_argument("--save_interval", default=100, type=int)
        self.parser.add_argument("--training_returns_freq",
                                 default=100,
                                 type=int)

    def parse_args(self):
        """
        parse own arguments
        """
        self.parse_default_args()

    def extract_sizes(self, spaces):
        """
        Extract space dimensions
        :param spaces: list of Gym spaces
        :return: list of ints with sizes for each agent
        """
        sizes = []
        for space in spaces:
            if isinstance(space, Box):
                size = sum(space.shape)
            elif isinstance(space, Dict):
                size = sum(self.extract_sizes(space.values()))
            elif isinstance(space, Discrete) or isinstance(space, MultiBinary):
                size = space.n
            elif isinstance(space, MultiDiscrete):
                size = sum(space.nvec)
            else:
                raise ValueError("Unknown class of space: ", type(space))
            sizes.append(size)
        return sizes

    def create_environment(self):
        """
        Create environment instance
        :return: environment (gym interface), env_name, task_name, n_agents, observation_spaces,
                 action_spaces, observation_sizes, action_sizes
        """
        raise NotImplementedError()

    def reset_environment(self):
        """
        Reset environment for new episode
        :return: observation (as torch tensor)
        """
        raise NotImplementedError

    def select_actions(self, obs, explore=True):
        """
        Select actions for agents
        :param obs: joint observation
        :param explore: flag if exploration should be used
        :return: action_tensor, action_list
        """
        raise NotImplementedError()

    def environment_step(self, actions):
        """
        Take step in the environment
        :param actions: actions to apply for each agent
        :return: reward, done, next_obs, info (rewards and observations as PyTorch tensors)
        """
        raise NotImplementedError()

    def environment_render(self):
        """
        Render visualisation of environment
        """
        raise NotImplementedError()

    def fill_buffer(self, timesteps):
        """
        Randomly sample actions and store experience in buffer
        :param timesteps: number of timesteps
        """
        t = 0
        while t < timesteps:
            done = False
            obs = self.reset_environment()
            while not done and t < timesteps:
                actions = [space.sample() for space in self.action_spaces]
                rewards, dones, next_obs, _ = self.environment_step(actions)
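                # encode each agent's sampled discrete action as a one-hot row
                # (row i gets a 1 at column actions[i])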
                onehot_actions = np.zeros((len(actions), self.action_sizes[0]))
                onehot_actions[np.arange(len(actions)), actions] = 1
                self.memory.add(obs, onehot_actions, rewards, next_obs, dones)
                obs = next_obs
                t += 1
                done = all(dones)

    def eval(self, ep, n_agents):
        """
        Execute evaluation episode without exploration
        :param ep: episode number
        :param n_agents: number of agents in task
        :return: returns, episode_length, done
        """
        obs = self.reset_environment()
        self.alg.reset(ep)

        episode_returns = np.array([0.0] * n_agents)
        episode_length = 0
        done = False

        while not done and episode_length < self.arglist.max_episode_len:
            torch_obs = [
                Variable(torch.Tensor(obs[i]), requires_grad=False)
                for i in range(n_agents)
            ]

            actions, _ = self.select_actions(torch_obs, False)
            rewards, dones, next_obs, _ = self.environment_step(actions)

            episode_returns += rewards

            obs = next_obs
            episode_length += 1
            done = all(dones)

        return episode_returns, episode_length, done

    def set_seeds(self, seed):
        """
        Set random seeds before model creation
        :param seed (int): seed to use
        """
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            if torch.cuda.is_available():
                torch.backends.cudnn.deterministic = True
                torch.backends.cudnn.benchmark = False

    def train(self):
        """
        Abstract training flow
        """
        # set random seeds before model creation
        self.set_seeds(self.arglist.seed)

        # use number of threads if no GPUs are available
        if not USE_CUDA:
            torch.set_num_threads(self.arglist.n_training_threads)

        env, env_name, task_name, n_agents, observation_spaces, action_spaces, observation_sizes, action_sizes = (
            self.create_environment())
        self.env = env
        self.n_agents = n_agents
        self.observation_spaces = observation_spaces
        self.action_spaces = action_spaces
        self.observation_sizes = observation_sizes
        self.action_sizes = action_sizes

        if self.arglist.max_episode_len == 25:
            steps = self.arglist.num_episodes * 20  #self.arglist.max_episode_len
        else:
            steps = self.arglist.num_episodes * self.arglist.max_episode_len
        # per-step decay factor: steps-th root of epsilon_decay (slow annealing)
        # or of goal_epsilon (default schedule)
        if self.arglist.epsilon_anneal_slow:
            decay_factor = self.arglist.epsilon_decay**(1 / float(steps))
            self.arglist.decay_factor = decay_factor
            print(
                f"Epsilon is decaying with (({self.arglist.epsilon_decay} - {decay_factor}**t) / {self.arglist.epsilon_decay}) to {self.arglist.goal_epsilon} over {steps} steps."
            )
        else:
            decay_epsilon = self.arglist.goal_epsilon**(1 / float(steps))
            self.arglist.decay_factor = decay_epsilon
            print(
                "Epsilon is decaying with factor %.7f to %.3f over %d steps." %
                (decay_epsilon, self.arglist.goal_epsilon, steps))

        print("Observation sizes: ", observation_sizes)
        print("Action sizes: ", action_sizes)

        target_type = self.arglist.targets
        if target_type not in TARGET_TYPES:
            print(f"Invalid target type {target_type}!")
            return
        else:
            if target_type == "simple":
                print("Simple target computation used")
            elif target_type == "double":
                print("Double target computation used")
            elif target_type == "our-double":
                print("Agent-double target computation used")
            elif target_type == "our-clipped":
                print("Agent-clipped target computation used")

        # create algorithm trainer
        self.alg = IQL(n_agents, observation_sizes, action_sizes, self.arglist)

        obs_size = observation_sizes[0]
        for o_size in observation_sizes[1:]:
            assert obs_size == o_size
        act_size = action_sizes[0]
        for a_size in action_sizes[1:]:
            assert act_size == a_size

        self.memory = MARLReplayBuffer(
            self.arglist.buffer_capacity,
            n_agents,
        )

        # set random seeds past model creation
        self.set_seeds(self.arglist.seed)

        self.model_saver = ModelSaver("models", self.arglist.run)
        self.logger = Logger(
            n_agents,
            task_name,
            self.arglist.run,
        )

        self.fill_buffer(5000)

        print("Starting iterations...")
        start_time = time.process_time()
        # timer = time.process_time()
        # env_time = 0
        # step_time = 0
        # update_time = 0
        # after_ep_time = 0

        t = 0
        training_returns_saved = 0

        episode_returns = []
        episode_agent_returns = []
        for ep in range(self.arglist.num_episodes):
            obs = self.reset_environment()
            self.alg.reset(ep)

            # episode_returns = np.array([0.0] * n_agents)
            episode_length = 0
            done = False

            while not done and episode_length < self.arglist.max_episode_len:
                torch_obs = [
                    Variable(torch.Tensor(obs[i]), requires_grad=False)
                    for i in range(n_agents)
                ]

                # env_time += time.process_time() - timer
                # timer = time.process_time()
                actions, onehot_actions = self.select_actions(torch_obs)
                # step_time += time.process_time() - timer
                # timer = time.process_time()
                rewards, dones, next_obs, info = self.environment_step(actions)

                # episode_returns += rewards

                self.memory.add(obs, onehot_actions, rewards, next_obs, dones)

                t += 1

                # env_time += time.process_time() - timer
                # timer = time.process_time()
                if (len(self.memory) >= self.arglist.batch_size
                        and (t % self.arglist.steps_per_update) == 0):
                    losses = self.alg.update(self.memory, USE_CUDA)
                    self.logger.log_losses(ep, losses)
                    #self.logger.dump_losses(1)

                # update_time += time.process_time() - timer
                # timer = time.process_time()
                # for displaying learned policies
                if self.arglist.render:
                    self.environment_render()

                obs = next_obs
                episode_length += 1
                done = all(dones)

                if done or episode_length == self.arglist.max_episode_len:
                    episode_returns.append(info["episode_reward"])
                    agent_returns = []
                    for i in range(n_agents):
                        agent_returns.append(info[f"agent{i}/episode_reward"])
                    episode_agent_returns.append(agent_returns)

            # env_time += time.process_time() - timer
            # timer = time.process_time()
            # log mean training returns roughly every training_returns_freq
            # environment steps
            if (t >= (training_returns_saved + 1) *
                    self.arglist.training_returns_freq):
                training_returns_saved += 1
                returns = np.array(episode_returns[-10:])
                mean_return = returns.mean()
                agent_returns = np.array(episode_agent_returns[-10:])
                mean_agent_return = agent_returns.mean(axis=0)

                self.logger.log_training_returns(t, mean_return,
                                                 mean_agent_return)

            if ep % self.arglist.eval_frequency == 0:
                eval_returns = np.zeros((self.arglist.eval_episodes, n_agents))
                for i in range(self.arglist.eval_episodes):
                    ep_returns, _, _ = self.eval(ep, n_agents)
                    eval_returns[i, :] = ep_returns
                self.logger.log_episode(ep, eval_returns.mean(0),
                                        eval_returns.var(0),
                                        self.alg.agents[0].epsilon)
                self.logger.dump_episodes(1)
            if ep % 100 == 0 and ep > 0:
                duration = time.process_time() - start_time
                self.logger.dump_train_progress(ep, self.arglist.num_episodes,
                                                duration)

            if ep % self.arglist.save_interval == 0 and ep > 0:
                # save models
                print("Remove previous models")
                self.model_saver.clear_models()
                print("Saving intermediate models")
                self.model_saver.save_models(self.alg, str(ep))
                # save logs
                print("Remove previous logs")
                self.logger.clear_logs()
                print("Saving intermediate logs")
                self.logger.save_training_returns(extension=str(ep))
                self.logger.save_episodes(extension=str(ep))
                self.logger.save_losses(extension=str(ep))
                # save parameter log
                self.logger.save_parameters(
                    env_name,
                    task_name,
                    n_agents,
                    observation_sizes,
                    action_sizes,
                    self.arglist,
                )

            # after_ep_time += time.process_time() - timer
            # timer = time.process_time()
            # print(f"Episode {ep} times:")
            # print(f"\tEnv time: {env_time}s")
            # print(f"\tStep time: {step_time}s")
            # print(f"\tUpdate time: {update_time}s")
            # print(f"\tAfter Ep time: {after_ep_time}s")
            # env_time = 0
            # step_time = 0
            # update_time = 0
            # after_ep_time = 0

        duration = time.process_time() - start_time
        print("Overall duration: %.2fs" % duration)

        # save models
        print("Remove previous models")
        self.model_saver.clear_models()
        print("Saving final models")
        self.model_saver.save_models(self.alg, "final")

        # save logs
        print("Remove previous logs")
        self.logger.clear_logs()
        print("Saving final logs")
        self.logger.save_episodes(extension="final")
        self.logger.save_losses(extension="final")
        self.logger.save_duration_cuda(duration, torch.cuda.is_available())

        # save parameter log
        self.logger.save_parameters(
            env_name,
            task_name,
            n_agents,
            observation_sizes,
            action_sizes,
            self.arglist,
        )

        env.close()

    if __name__ == "__main__":
        train = Train()
        train.train()
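
For reference, here is a standalone sketch of the two epsilon-annealing schedules set up in Train.train() above, using the default --goal_epsilon and --epsilon_decay values and an assumed step count. How the IQL agents apply the per-step factor internally is not shown in either example, so this only reproduces the closed forms printed by the script.

# Sketch of the two epsilon schedules derived in Train.train() (illustrative values).
GOAL_EPSILON = 0.01     # --goal_epsilon default
EPSILON_DECAY = 10      # --epsilon_decay default
STEPS = 120000 * 20     # num_episodes * assumed average episode length


def default_schedule(t):
    # multiplicative decay: factor chosen so that factor ** STEPS == GOAL_EPSILON,
    # i.e. epsilon falls from 1.0 to the goal value over STEPS steps
    factor = GOAL_EPSILON ** (1 / STEPS)
    return factor ** t


def slow_schedule(t):
    # --epsilon_anneal_slow: epsilon(t) = (d - f**t) / d with f = d ** (1/STEPS),
    # which starts at (d - 1) / d and reaches 0 at t == STEPS
    factor = EPSILON_DECAY ** (1 / STEPS)
    return (EPSILON_DECAY - factor ** t) / EPSILON_DECAY


if __name__ == "__main__":
    for t in (0, STEPS // 2, STEPS):
        print(t, round(default_schedule(t), 4), round(slow_schedule(t), 4))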