Example 1
 def startup(self):
     """
     Calls ``sampler.async_initialize()`` to get a double buffer for minibatches,
     followed by ``algo.async_initialize()`` to get a replay buffer on shared memory,
     then launches all workers (sampler, optimizer, memory copier).
     """
     if self.seed is None:
         self.seed = make_seed()
     set_seed(self.seed)
     double_buffer, examples = self.sampler.async_initialize(
         agent=self.agent,
         bootstrap_value=getattr(self.algo, "bootstrap_value", False),
         traj_info_kwargs=self.get_traj_info_kwargs(),
         seed=self.seed,
     )
     self.sampler_batch_size = self.sampler.batch_spec.size
     self.world_size = len(self.affinity.optimizer)
     n_itr = self.get_n_itr()  # Number of sampler iterations.
     replay_buffer = self.algo.async_initialize(
         agent=self.agent,
         sampler_n_itr=n_itr,
         batch_spec=self.sampler.batch_spec,
         mid_batch_reset=self.sampler.mid_batch_reset,
         examples=examples,
         world_size=self.world_size,
     )
     self.launch_workers(n_itr, double_buffer, replay_buffer)
     throttle_itr, delta_throttle_itr = self.optim_startup()
     return throttle_itr, delta_throttle_itr
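
To make the handoff concrete, here is a minimal sketch of how an asynchronous train() loop could consume the two values returned by startup(): the optimizer only advances once the sampler's shared iteration counter has passed the current throttle point. The names ctrl.quit, ctrl.sampler_itr, step_optimizer, and shutdown are hypothetical placeholders, not the actual runner API.

import time

def train(self):
    throttle_itr, delta_throttle_itr = self.startup()
    itr = 0
    while not self.ctrl.quit.value:                         # assumed shared flag
        # Block until the sampler has produced enough batches for this step.
        while self.ctrl.sampler_itr.value < throttle_itr:   # assumed shared counter
            time.sleep(0.05)
        self.step_optimizer(itr)             # placeholder for one optimization pass
        throttle_itr += delta_throttle_itr   # advance the throttle point
        itr += 1
    self.shutdown()                          # placeholder cleanup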
Example 2
 def launch_workers(self):
     self.affinities = self.affinity
     self.affinity = self.affinities[0]
     self.n_runners = n_runners = len(self.affinities)
     self.rank = rank = 0
     self.par = par = self.build_par_objs(n_runners)
     if self.seed is None:
         self.seed = make_seed()
     port = find_port(offset=self.affinity.get("run_slot", 0))
     workers_kwargs = [
         dict(
             algo=self.algo,
             agent=self.agent,
             sampler=self.sampler,
             n_steps=self.n_steps,
             seed=self.seed + 100 * rank,
             affinity=self.affinities[rank],
             log_interval_steps=self.log_interval_steps,
             rank=rank,
             n_runners=n_runners,
             port=port,
             par=par,
         ) for rank in range(1, n_runners)
     ]
     workers = [self.WorkerCls(**w_kwargs) for w_kwargs in workers_kwargs]
     self.workers = [mp.Process(target=w.train, args=()) for w in workers]
     for w in self.workers:
         w.start()
     torch.distributed.init_process_group(
         backend="nccl",
         rank=rank,
         world_size=n_runners,
         init_method=f"tcp://127.0.0.1:{port}",
     )
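
For the group initialization above to return, every spawned worker must make a matching call with its own rank; a rough sketch of what that looks like inside a worker is below. The attribute names (self.rank, self.n_runners, self.port) simply mirror the kwargs passed above and are assumptions about the worker class.

import torch.distributed as dist

def _init_distributed(self):
    # Same world_size, port, and backend as the master; only the rank differs.
    dist.init_process_group(
        backend="nccl",
        rank=self.rank,                       # 1 .. n_runners - 1
        world_size=self.n_runners,
        init_method=f"tcp://127.0.0.1:{self.port}",
    )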
Example 3
 def startup(self):
     p = psutil.Process()
     try:
         if self.affinity.get("master_cpus",
                              None) is not None and self.affinity.get(
                                  "set_affinity", True):
             p.cpu_affinity(self.affinity["master_cpus"])
         cpu_affin = p.cpu_affinity()
     except AttributeError:
         cpu_affin = "UNAVAILABLE MacOS"
     logger.log(f"Runner {getattr(self, 'rank', '')} master CPU affinity: "
                f"{cpu_affin}.")
     if self.affinity.get("master_torch_threads", None) is not None:
         torch.set_num_threads(self.affinity["master_torch_threads"])
     logger.log(f"Runner {getattr(self, 'rank', '')} master Torch threads: "
                f"{torch.get_num_threads()}.")
     if self.seed is None:
         self.seed = make_seed()
     set_seed(self.seed)
     # self.rank = rank = getattr(self, "rank", 0)
     # self.world_size = world_size = getattr(self, "world_size", 1)
     self.algo.initialize(
         n_updates=self.n_updates,
         cuda_idx=self.affinity.get("cuda_idx", None),
     )
     self.initialize_logging()
Example 4
    def startup(self):
        """
        Sets hardware affinities, initializes the following: 1) sampler (which
        should initialize the agent), 2) agent device and data-parallel wrapper (if applicable),
        3) algorithm, 4) logger.
        """
        p = psutil.Process()
        try:
            if (self.affinity.get("master_cpus", None) is not None
                    and self.affinity.get("set_affinity", True)):
                p.cpu_affinity(self.affinity["master_cpus"])
            cpu_affin = p.cpu_affinity()
        except AttributeError:
            cpu_affin = "UNAVAILABLE MacOS"
        logger.log(f"Runner {getattr(self, 'rank', '')} master CPU affinity: "
                   f"{cpu_affin}.")
        if self.affinity.get("master_torch_threads", None) is not None:
            torch.set_num_threads(self.affinity["master_torch_threads"])
        logger.log(f"Runner {getattr(self, 'rank', '')} master Torch threads: "
                   f"{torch.get_num_threads()}.")
        if self.seed is None:
            self.seed = make_seed()
        set_seed(self.seed)
        self.rank = rank = getattr(self, "rank", 0)
        self.world_size = world_size = getattr(self, "world_size", 1)

        examples = self.sampler.initialize(
            agent=self.agent,  # Agent gets initialized in sampler.
            affinity=self.affinity,
            seed=self.seed + 1,
            bootstrap_value=getattr(self.algo, "bootstrap_value", False),
            traj_info_kwargs=self.get_traj_info_kwargs(),
            rank=rank,
            world_size=world_size,
        )
        self.itr_batch_size = self.sampler.batch_spec.size * world_size
        n_itr = self.get_n_itr()
        print("CUDA: ", self.affinity.get("cuda_idx", None))
        self.agent.to_device(self.affinity.get("cuda_idx", None))
        if world_size > 1:
            self.agent.data_parallel()
        self.algo.initialize(
            agent=self.agent,
            n_itr=n_itr,
            batch_spec=self.sampler.batch_spec,
            mid_batch_reset=self.sampler.mid_batch_reset,
            examples=examples,
            world_size=world_size,
            rank=rank,
        )
        self.initialize_logging()
        return n_itr
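
For context, here is a sketch of the synchronous training loop that would sit on top of this startup(); the hook names store_diagnostics, log_diagnostics, log_interval_itrs, and shutdown follow the pattern suggested by initialize_logging() but are assumptions here, not verified API.

def train(self):
    n_itr = self.startup()
    for itr in range(n_itr):
        self.agent.sample_mode(itr)        # collection settings (e.g. exploration)
        samples, traj_infos = self.sampler.obtain_samples(itr)
        self.agent.train_mode(itr)
        opt_info = self.algo.optimize_agent(itr, samples)
        self.store_diagnostics(itr, traj_infos, opt_info)    # assumed logging hook
        if (itr + 1) % self.log_interval_itrs == 0:
            self.log_diagnostics(itr)                        # assumed logging hook
    self.shutdown()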
Example 5
 def launch_workers(self):
     """
     As part of startup, fork a separate Python process for each additional
     GPU; the master process runs on the first GPU.  Initialize
     ``torch.distributed`` so the ``DistributedDataParallel`` wrapper can
     work; it also makes ``torch.distributed`` available for other
     communication.
     """
     self.affinities = self.affinity
     self.affinity = self.affinities[0]
     self.world_size = world_size = len(self.affinities)
     self.rank = rank = 0
     self.par = par = self.build_par_objs(world_size)
     if self.seed is None:
         self.seed = make_seed()
     port = find_port(offset=self.affinity.get("master_cpus",
                                               [0])[0])  # 29500
     backend = "gloo" if self.affinity.get("cuda_idx",
                                           None) is None else "nccl"
     workers_kwargs = [
         dict(
             algo=self.algo,
             agent=self.agent,
             sampler=self.sampler,
             n_steps=self.n_steps,
             seed=self.seed + 100 * rank,
             affinity=self.affinities[rank],
             log_interval_steps=self.log_interval_steps,
             rank=rank,
             world_size=world_size,
             port=port,
             backend=backend,
             par=par,
         ) for rank in range(1, world_size)
     ]
     workers = [self.WorkerCls(**w_kwargs) for w_kwargs in workers_kwargs]
     # self.workers = [mp.Process(target=w.train, args=()) for w in workers]
     self.workers = [mp.Process(target=w.eval, args=()) for w in workers]
     for w in self.workers:
         w.start()
     # print(torch.cuda.device_count())
     # import os
     # os.environ["MASTER_ADDR"] = "127.0.0.1"
     # os.environ["MASTER_PORT"] = "29500"
     # os.environ["CUDA_VISIBLE_DEVICES"] = '1,2,3,4'
     # print(backend, rank, world_size)
     torch.distributed.init_process_group(
         backend=backend,
         rank=rank,
         world_size=world_size,
         init_method=f"tcp://128.112.35.85:{port}",
     )
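
The snippet relies on find_port() to pick a rendezvous port per run. A plausible implementation (a sketch of the idea, not the library's code) probes ports upward from a base, offset so that concurrent runs on the same machine do not collide:

import socket

def find_port(offset=0, base=29500):
    """Return the first bindable TCP port at or above base + offset."""
    port = base + offset
    while True:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(("127.0.0.1", port))
                return port
            except OSError:
                port += 1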
Example 6
 def startup(self):
     """
     Adds support for next_obs sampling.
     """
     p = psutil.Process()
     try:
         if (self.affinity.get("master_cpus", None) is not None
                 and self.affinity.get("set_affinity", True)):
             p.cpu_affinity(self.affinity["master_cpus"])
         cpu_affin = p.cpu_affinity()
     except AttributeError:
         cpu_affin = "UNAVAILABLE MacOS"
     logger.log(f"Runner {getattr(self, 'rank', '')} master CPU affinity: "
                f"{cpu_affin}.")
     if self.affinity.get("master_torch_threads", None) is not None:
         torch.set_num_threads(self.affinity["master_torch_threads"])
     logger.log(f"Runner {getattr(self, 'rank', '')} master Torch threads: "
                f"{torch.get_num_threads()}.")
     if self.seed is None:
         self.seed = make_seed()
     set_seed(self.seed)
     self.rank = rank = getattr(self, "rank", 0)
     self.world_size = world_size = getattr(self, "world_size", 1)
     examples = self.sampler.initialize(
         agent=self.agent,  # Agent gets initialized in sampler.
         affinity=self.affinity,
         seed=self.seed + 1,
         bootstrap_value=getattr(self.algo, "bootstrap_value", False),
         next_obs=getattr(self.algo, "next_obs", False),  # MODIFIED HERE
         traj_info_kwargs=self.get_traj_info_kwargs(),
         rank=rank,
         world_size=world_size,
     )
     self.itr_batch_size = self.sampler.batch_spec.size * world_size
     n_itr = self.get_n_itr()
     self.agent.to_device(self.affinity.get("cuda_idx", None))
     if world_size > 1:
         self.agent.data_parallel()
     self.algo.initialize(
         agent=self.agent,
         n_itr=n_itr,
         batch_spec=self.sampler.batch_spec,
         mid_batch_reset=self.sampler.mid_batch_reset,
         examples=examples,
         world_size=world_size,
         rank=rank,
     )
     self.initialize_logging()
     return n_itr
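
The only change relative to the previous startup() is the next_obs flag forwarded to the sampler. On the algorithm side this just requires exposing the attribute that the getattr() call looks up; a minimal illustration (the class itself is hypothetical):

class MyNextObsAlgo:
    bootstrap_value = False   # picked up by getattr(self.algo, "bootstrap_value", False)
    next_obs = True           # tells the sampler to also store the next observation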
Example 7
 def async_initialize(self,
                      agent,
                      bootstrap_value=False,
                      traj_info_kwargs=None,
                      seed=None):
     """Instantiate an example environment and use it to initialize the
     agent (on shared memory).  Pre-allocate a double-buffer for sample
     batches, and return that buffer along with example data (e.g.
     `observation`, `action`, etc.)
     """
     self.seed = make_seed() if seed is None else seed
     # Construct an example of each kind of data that needs to be stored.
     env = self.EnvCls(**self.env_kwargs)
     # Sampler always receives new params through shared memory:
     agent.initialize(
         env.spaces,
         share_memory=True,
         global_B=self.batch_spec.B,
         env_ranks=list(range(self.batch_spec.B)),
     )
     _, samples_np, examples = build_samples_buffer(
         agent,
         env,
         self.batch_spec,
         bootstrap_value,
         agent_shared=True,
         env_shared=True,
         subprocess=True,
     )  # Would like subprocess=True, but might hang?
     _, samples_np2, _ = build_samples_buffer(
         agent,
         env,
         self.batch_spec,
         bootstrap_value,
         agent_shared=True,
         env_shared=True,
         subprocess=True,
     )
     env.close()
     del env
     if traj_info_kwargs:
         for k, v in traj_info_kwargs.items():
             setattr(self.TrajInfoCls, "_" + k, v)
     self.double_buffer = double_buffer = (samples_np, samples_np2)
     self.samples_np = samples_np  # In case leftover use during worker init.
     self.examples = examples
     self.agent = agent
     return double_buffer, examples
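
A toy, self-contained illustration of why two sample buffers are pre-allocated: the producer fills one half while the consumer reads the half written on the previous iteration, alternating each step. The array shapes are arbitrary and stand in for the structured sample buffers.

import numpy as np

double_buffer = (np.zeros((20, 8)), np.zeros((20, 8)))  # e.g. (T, B) per half
for itr in range(4):
    write_buf = double_buffer[itr % 2]        # sampler writes here this iteration
    read_buf = double_buffer[(itr + 1) % 2]   # consumer drains the other half
    write_buf[:] = itr                        # stands in for collecting a batch
    print(itr, read_buf.mean())               # consumer sees the previous batch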
Example 8
 def launch_workers(self):
     """
     As part of startup, fork a separate Python process for each additional
     GPU; the master process runs on the first GPU.  Initialize
     ``torch.distributed`` so the ``DistributedDataParallel`` wrapper can
     work; it also makes ``torch.distributed`` available for other
     communication.
     """
     self.affinities = self.affinity
     self.affinity = self.affinities[0]
     self.world_size = world_size = len(self.affinities)
     self.rank = rank = 0
     self.par = par = self.build_par_objs(world_size)
     if self.seed is None:
         self.seed = make_seed()
     port = find_port(offset=self.affinity.get("master_cpus", [0])[0])
     backend = "gloo" if self.affinity.get("cuda_idx",
                                           None) is None else "nccl"
     workers_kwargs = [
         dict(
             algo=self.algo,
             agent=self.agent,
             sampler=self.sampler,
             n_steps=self.n_steps,
             seed=self.seed + 100 * rank,
             affinity=self.affinities[rank],
             log_interval_steps=self.log_interval_steps,
             rank=rank,
             world_size=world_size,
             port=port,
             backend=backend,
             par=par,
             log_dir=self.log_dir,
             pretrain=self.pretrain,
         ) for rank in range(1, world_size)
     ]
     workers = [self.WorkerCls(**w_kwargs) for w_kwargs in workers_kwargs]
     self.workers = [mp.Process(target=w.train, args=()) for w in workers]
     for w in self.workers:
         w.start()
     torch.distributed.init_process_group(
         backend=backend,
         rank=rank,
         world_size=world_size,
         init_method=f"tcp://127.0.0.1:{port}",
     )
Example 9
 def startup(self):
     p = psutil.Process()
     try:
         if self.affinity.get("master_cpus", None) is not None:
             p.cpu_affinity(self.affinity["master_cpus"])
         cpu_affin = p.cpu_affinity()
     except AttributeError:
         cpu_affin = "UNAVAILABLE MacOS"
     logger.log(f"Runner {getattr(self, 'rank', '')} master CPU affinity: "
                f"{cpu_affin}.")
     if self.affinity.get("master_torch_threads", None) is not None:
         torch.set_num_threads(self.affinity["master_torch_threads"])
     logger.log(f"Runner {getattr(self, 'rank', '')} master Torch threads: "
                f"{torch.get_num_threads()}.")
     if self.seed is None:
         self.seed = make_seed()
     set_seed(self.seed)
     examples = self.sampler.initialize(
         agent=self.agent,  # Agent gets initialized in sampler.
         affinity=self.affinity,
         seed=self.seed + 1,
         bootstrap_value=getattr(self.algo, "bootstrap_value", False),
         traj_info_kwargs=self.get_traj_info_kwargs(),
     )
     n_runners = getattr(self, "n_runners", 1)
     self.itr_batch_size = self.sampler.batch_spec.size * n_runners
     n_itr = self.get_n_itr()
     self.agent.initialize_cuda(
         cuda_idx=self.affinity.get("cuda_idx", None),
         ddp=n_runners > 1,  # Multi-GPU training (and maybe sampling).
     )
     self.algo.initialize(agent=self.agent,
                          n_itr=n_itr,
                          batch_spec=self.sampler.batch_spec,
                          mid_batch_reset=self.sampler.mid_batch_reset,
                          examples=examples)
     self.initialize_logging()
     return n_itr
Example 10
 def async_initialize(self,
                      agent,
                      bootstrap_value=False,
                      traj_info_kwargs=None,
                      seed=None):
     self.seed = make_seed() if seed is None else seed
     # Construct an example of each kind of data that needs to be stored.
     env = self.EnvCls(**self.env_kwargs)
     # Sampler always receives new params through shared memory:
     agent.initialize(env.spaces,
                      share_memory=True,
                      global_B=self.batch_spec.B,
                      env_ranks=list(range(self.batch_spec.B)))
     _, samples_np, examples = build_samples_buffer(
         agent,
         env,
         self.batch_spec,
         bootstrap_value,
         agent_shared=True,
         env_shared=True,
         subprocess=True)  # Would like subprocess=True, but might hang?
     _, samples_np2, _ = build_samples_buffer(agent,
                                              env,
                                              self.batch_spec,
                                              bootstrap_value,
                                              agent_shared=True,
                                              env_shared=True,
                                              subprocess=True)
     env.close()
     del env
     if traj_info_kwargs:
         for k, v in traj_info_kwargs.items():
             setattr(self.TrajInfoCls, "_" + k, v)
     self.double_buffer = double_buffer = (samples_np, samples_np2)
     self.samples_np = samples_np  # In case leftover use during worker init.
     self.examples = examples
     self.agent = agent
     return double_buffer, examples
Example 11
 def launch_workers(self):
     self.affinities = self.affinity
     self.affinity = self.affinities[0]
     self.world_size = world_size = len(self.affinities)
     self.rank = rank = 0
     self.par = par = self.build_par_objs(world_size)
     if self.seed is None:
         self.seed = make_seed()
     port = find_port(offset=self.affinity.get("master_cpus", [0])[0])
     backend = "gloo" if self.affinity.get("cuda_idx",
                                           None) is None else "nccl"
     workers_kwargs = [
         dict(
             algo=self.algo,
             agent=self.agent,
             sampler=self.sampler,
             n_steps=self.n_steps,
             seed=self.seed + 100 * rank,
             affinity=self.affinities[rank],
             log_interval_steps=self.log_interval_steps,
             rank=rank,
             world_size=world_size,
             port=port,
             backend=backend,
             par=par,
         ) for rank in range(1, world_size)
     ]
     workers = [self.WorkerCls(**w_kwargs) for w_kwargs in workers_kwargs]
     self.workers = [mp.Process(target=w.train, args=()) for w in workers]
     for w in self.workers:
         w.start()
     torch.distributed.init_process_group(
         backend=backend,
         rank=rank,
         world_size=world_size,
         init_method=f"tcp://127.0.0.1:{port}",
     )
Example 12
 def startup(self):
     if self.seed is None:
         self.seed = make_seed()
     set_seed(self.seed)
     double_buffer, examples = self.sampler.async_initialize(
         agent=self.agent,
         bootstrap_value=getattr(self.algo, "bootstrap_value", False),
         traj_info_kwargs=self.get_traj_info_kwargs(),
         seed=self.seed,
     )
     self.sampler_batch_size = self.sampler.batch_spec.size
     self.world_size = len(self.affinity.optimizer)
     n_itr = self.get_n_itr()  # Number of sampler iterations.
     replay_buffer = self.algo.async_initialize(
         agent=self.agent,
         sampler_n_itr=n_itr,
         batch_spec=self.sampler.batch_spec,
         mid_batch_reset=self.sampler.mid_batch_reset,
         examples=examples,
         world_size=self.world_size,
     )
     self.launch_workers(n_itr, double_buffer, replay_buffer)
     throttle_itr, delta_throttle_itr = self.optim_startup()
     return throttle_itr, delta_throttle_itr
Example 13
    def forward(self, observation, prev_action, prev_reward):
        lead_dim, T, B, img_shape = infer_leading_dims(observation, 3)
        action = torch.randint(low=0, high=self.num_actions, size=(T * B, ))
        action = restore_leading_dims(action, lead_dim, T, B)
        return action


# Set up the data collection pipeline
sampler = GpuSampler(EnvCls=gym.make,
                     env_kwargs=config["env"],
                     CollectorCls=GpuResetCollector,
                     eval_env_kwargs=config["env"],
                     **config["sampler"])
agent = RandomAgent(ModelCls=RandomDiscreteModel,
                    model_kwargs={"num_actions": 15})
seed = make_seed()
set_seed(seed)
sampler.initialize(agent=agent, affinity=affinity, seed=seed + 1, rank=0)
steps = config["train_steps"]

# Create the model
model = BiGAN(**config["model"])
if config["load_path"]:
    model.load_state_dict(torch.load(config["load_path"]))
# Set up the optimizers
lr = config["optim"]["lr"]
d_optimizer = torch.optim.Adam(model.d.parameters(),
                               lr=lr,
                               betas=(0.5, 0.999),
                               weight_decay=2.5e-5)
g_optimizer = torch.optim.Adam(list(model.e.parameters()) +
Example 14
    def startup(self):
        """
        一些初始化工作。
        """
        p = psutil.Process()  # Get a handle on the current process.

        # Set CPU affinity (not supported on macOS).
        try:
            if self.affinity.get("master_cpus",
                                 None) is not None and self.affinity.get(
                                     "set_affinity", True):
                p.cpu_affinity(self.affinity["master_cpus"])
            cpu_affin = p.cpu_affinity()  # set了之后再取出来
        except AttributeError:
            cpu_affin = "UNAVAILABLE MacOS"
        logger.log(
            f"Runner {getattr(self, 'rank', '')} master CPU affinity: {cpu_affin}."
        )

        # Set the Torch thread count.
        if self.affinity.get("master_torch_threads", None) is not None:
            # Number of threads Torch may use for intra-op CPU parallelism.
            torch.set_num_threads(self.affinity["master_torch_threads"])
        logger.log(
            f"Runner {getattr(self, 'rank', '')} master Torch threads: {torch.get_num_threads()}."
        )

        # Set the random seed.
        if self.seed is None:
            self.seed = make_seed()
        set_seed(self.seed)

        # rank and world_size may have been set externally (e.g. by a parallel
        # runner); they default to 0 and 1 respectively for a single process.
        self.rank = rank = getattr(self, "rank", 0)
        self.world_size = world_size = getattr(self, "world_size", 1)

        # Initialize the sampler.  The name `examples` is slightly misleading: it
        # holds example data items (observation, action, ...) used to size buffers.
        examples = self.sampler.initialize(
            agent=self.agent,  # Agent gets initialized inside the sampler, so pass it in.
            affinity=self.affinity,  # CPU affinity.
            seed=self.seed + 1,  # Random seed.
            bootstrap_value=getattr(self.algo, "bootstrap_value", False),
            traj_info_kwargs=self.get_traj_info_kwargs(),  # Here this only sets the discount factor.
            rank=rank,
            world_size=world_size,
        )
        """
        batch_spec.size的实现参见 BatchSpec 类,它表示的所有的environment实例上的所有时间步的数量总和。这里又乘了一个没有说明含义的
        world_size,用一个不正经的比喻,我猜这里有大概是"平行宇宙"的概念(美剧闪电侠),在"当前宇宙"内的发生的采样,它是算在 batch_spec.size 
        内的,而像这样的场景,我们可以把它复制很多个出来,用所有这些创造出来的集合来训练RL模型。
        """
        self.itr_batch_size = self.sampler.batch_spec.size * world_size  # 所有迭代的时间步数
        n_itr = self.get_n_itr(
        )  # 计算模型训练的迭代次数。在这里,迭代次数并不是直接指定的,而是经过一个复杂的方法计算出来的
        self.agent.to_device(self.affinity.get("cuda_idx",
                                               None))  # 在指定的设备上运行程序
        if world_size > 1:
            self.agent.data_parallel()

        # Initialize the algorithm object.
        self.algo.initialize(
            agent=self.agent,
            n_itr=n_itr,
            batch_spec=self.sampler.batch_spec,
            mid_batch_reset=self.sampler.mid_batch_reset,
            examples=examples,
            world_size=world_size,
            rank=rank,
        )

        # Initialize logging.
        self.initialize_logging()
        return n_itr
Example 15
def build_and_train(game="cartpole",
                    run_ID=0,
                    cuda_idx=None,
                    sample_mode="serial",
                    n_parallel=2,
                    eval=False,
                    serial=False,
                    train_mask=[True, True],
                    wandb_log=False,
                    save_models_to_wandb=False,
                    log_interval_steps=1e5,
                    observation_mode="agent",
                    inc_player_last_act=False,
                    alt_train=False,
                    eval_perf=False,
                    n_steps=50e6,
                    one_agent=False):
    # Define the environments:
    if observation_mode == "agent":
        fully_obs = False
        rand_obs = False
    elif observation_mode == "random":
        fully_obs = False
        rand_obs = True
    elif observation_mode == "full":
        fully_obs = True
        rand_obs = False

    n_serial = None
    if game == "cartpole":
        work_env = gym.make
        env_name = 'CartPole-v1'
        cont_act = False
        state_space_low = np.asarray([
            0.0, 0.0, 0.0, 0.0, -4.8000002e+00, -3.4028235e+38, -4.1887903e-01,
            -3.4028235e+38
        ])
        state_space_high = np.asarray([
            1.0, 1.0, 1.0, 1.0, 4.8000002e+00, 3.4028235e+38, 4.1887903e-01,
            3.4028235e+38
        ])
        obs_space = Box(state_space_low, state_space_high, dtype=np.float32)
        player_act_space = work_env(env_name).action_space
        player_act_space.shape = (1, )
        print(player_act_space)
        if inc_player_last_act:
            observer_obs_space = Box(np.append(state_space_low, 0),
                                     np.append(state_space_high, 1),
                                     dtype=np.float32)
        else:
            observer_obs_space = obs_space
        player_reward_shaping = player_reward_shaping_cartpole
        observer_reward_shaping = observer_reward_shaping_cartpole
        max_decor_steps = 20
        b_size = 20
        num_envs = 8
        max_episode_length = np.inf
        player_model_kwargs = dict(hidden_sizes=[24],
                                   lstm_size=16,
                                   nonlinearity=torch.nn.ReLU,
                                   normalize_observation=False,
                                   norm_obs_clip=10,
                                   norm_obs_var_clip=1e-6)
        observer_model_kwargs = dict(hidden_sizes=[64],
                                     lstm_size=16,
                                     nonlinearity=torch.nn.ReLU,
                                     normalize_observation=False,
                                     norm_obs_clip=10,
                                     norm_obs_var_clip=1e-6)

    elif game == "hiv":
        work_env = wn.gym.make
        env_name = 'HIV-v0'
        cont_act = False
        state_space_low = np.asarray(
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
        state_space_high = np.asarray([
            1.0, 1.0, 1.0, 1.0, 1.0, 1.0, np.inf, np.inf, np.inf, np.inf,
            np.inf, np.inf
        ])
        obs_space = Box(state_space_low, state_space_high, dtype=np.float32)
        player_act_space = work_env(env_name).action_space
        if inc_player_last_act:
            observer_obs_space = Box(np.append(state_space_low, 0),
                                     np.append(state_space_high, 3),
                                     dtype=np.float32)
        else:
            observer_obs_space = obs_space
        player_reward_shaping = player_reward_shaping_hiv
        observer_reward_shaping = observer_reward_shaping_hiv
        max_decor_steps = 10
        b_size = 32
        num_envs = 8
        max_episode_length = 100
        player_model_kwargs = dict(hidden_sizes=[32],
                                   lstm_size=16,
                                   nonlinearity=torch.nn.ReLU,
                                   normalize_observation=False,
                                   norm_obs_clip=10,
                                   norm_obs_var_clip=1e-6)
        observer_model_kwargs = dict(hidden_sizes=[64],
                                     lstm_size=16,
                                     nonlinearity=torch.nn.ReLU,
                                     normalize_observation=False,
                                     norm_obs_clip=10,
                                     norm_obs_var_clip=1e-6)

    elif game == "heparin":
        work_env = HeparinEnv
        env_name = 'Heparin-Simulator'
        cont_act = False
        state_space_low = np.asarray([
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 18728.926, 72.84662, 0.0, 0.0,
            0.0, 0.0, 0.0
        ])
        state_space_high = np.asarray([
            1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.7251439e+04, 1.0664291e+02,
            200.0, 8.9383472e+02, 1.0025734e+02, 1.5770737e+01, 4.7767456e+01
        ])
        # state_space_low = np.asarray([0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,18728.926,72.84662,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0])
        # state_space_high = np.asarray([1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,2.7251439e+04,1.0664291e+02,0.0000000e+00,8.9383472e+02,1.4476662e+02,1.3368750e+02,1.6815166e+02,1.0025734e+02,1.5770737e+01,4.7767456e+01,7.7194958e+00])
        obs_space = Box(state_space_low, state_space_high, dtype=np.float32)
        player_act_space = work_env(env_name).action_space
        if inc_player_last_act:
            observer_obs_space = Box(np.append(state_space_low, 0),
                                     np.append(state_space_high, 4),
                                     dtype=np.float32)
        else:
            observer_obs_space = obs_space
        player_reward_shaping = player_reward_shaping_hep
        observer_reward_shaping = observer_reward_shaping_hep
        max_decor_steps = 3
        b_size = 20
        num_envs = 8
        max_episode_length = 20
        player_model_kwargs = dict(hidden_sizes=[32],
                                   lstm_size=16,
                                   nonlinearity=torch.nn.ReLU,
                                   normalize_observation=False,
                                   norm_obs_clip=10,
                                   norm_obs_var_clip=1e-6)
        observer_model_kwargs = dict(hidden_sizes=[128],
                                     lstm_size=16,
                                     nonlinearity=torch.nn.ReLU,
                                     normalize_observation=False,
                                     norm_obs_clip=10,
                                     norm_obs_var_clip=1e-6)

    elif game == "halfcheetah":
        assert not serial
        assert not one_agent
        work_env = gym.make
        env_name = 'HalfCheetah-v2'
        cont_act = True
        temp_env = work_env(env_name)
        state_space_low = np.concatenate([
            np.zeros(temp_env.observation_space.low.shape),
            temp_env.observation_space.low
        ])
        state_space_high = np.concatenate([
            np.ones(temp_env.observation_space.high.shape),
            temp_env.observation_space.high
        ])
        obs_space = Box(state_space_low, state_space_high, dtype=np.float32)
        player_act_space = temp_env.action_space
        if inc_player_last_act:
            observer_obs_space = Box(np.append(state_space_low, 0),
                                     np.append(state_space_high, 4),
                                     dtype=np.float32)
        else:
            observer_obs_space = obs_space
        player_reward_shaping = None
        observer_reward_shaping = None
        temp_env.close()
        max_decor_steps = 0
        b_size = 20
        num_envs = 8
        max_episode_length = np.inf
        player_model_kwargs = dict(hidden_sizes=[256, 256])
        observer_model_kwargs = dict(hidden_sizes=[256, 256])
        player_q_model_kwargs = dict(hidden_sizes=[256, 256])
        observer_q_model_kwargs = dict(hidden_sizes=[256, 256])
        player_v_model_kwargs = dict(hidden_sizes=[256, 256])
        observer_v_model_kwargs = dict(hidden_sizes=[256, 256])
    if game == "halfcheetah":
        observer_act_space = Box(
            low=state_space_low[:int(len(state_space_low) / 2)],
            high=state_space_high[:int(len(state_space_high) / 2)])
    else:
        if serial:
            n_serial = int(len(state_space_high) / 2)
            observer_act_space = Discrete(2)
            observer_act_space.shape = (1, )
        else:
            if one_agent:
                observer_act_space = IntBox(
                    low=0,
                    high=player_act_space.n *
                    int(2**int(len(state_space_high) / 2)))
            else:
                observer_act_space = IntBox(low=0,
                                            high=int(2**int(
                                                len(state_space_high) / 2)))

    affinity = dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel)))
    gpu_cpu = "CPU" if cuda_idx is None else f"GPU {cuda_idx}"
    if sample_mode == "serial":
        alt = False
        Sampler = SerialSampler  # (Ignores workers_cpus.)
        if eval:
            eval_collector_cl = SerialEvalCollector
        else:
            eval_collector_cl = None
        print(f"Using serial sampler, {gpu_cpu} for sampling and optimizing.")
    elif sample_mode == "cpu":
        alt = False
        Sampler = CpuSampler
        if eval:
            eval_collector_cl = CpuEvalCollector
        else:
            eval_collector_cl = None
        print(
            f"Using CPU parallel sampler (agent in workers), {gpu_cpu} for optimizing."
        )
    env_kwargs = dict(work_env=work_env,
                      env_name=env_name,
                      obs_spaces=[obs_space, observer_obs_space],
                      action_spaces=[player_act_space, observer_act_space],
                      serial=serial,
                      player_reward_shaping=player_reward_shaping,
                      observer_reward_shaping=observer_reward_shaping,
                      fully_obs=fully_obs,
                      rand_obs=rand_obs,
                      inc_player_last_act=inc_player_last_act,
                      max_episode_length=max_episode_length,
                      cont_act=cont_act)
    if eval:
        eval_env_kwargs = env_kwargs
        eval_max_steps = 1e4
        num_eval_envs = num_envs
    else:
        eval_env_kwargs = None
        eval_max_steps = None
        num_eval_envs = 0
    sampler = Sampler(
        EnvCls=CWTO_EnvWrapper,
        env_kwargs=env_kwargs,
        batch_T=b_size,
        batch_B=num_envs,
        max_decorrelation_steps=max_decor_steps,
        eval_n_envs=num_eval_envs,
        eval_CollectorCls=eval_collector_cl,
        eval_env_kwargs=eval_env_kwargs,
        eval_max_steps=eval_max_steps,
    )
    if game == "halfcheetah":
        player_algo = SAC()
        observer_algo = SACBeta()
        player = SacAgent(ModelCls=PiMlpModel,
                          QModelCls=QofMuMlpModel,
                          model_kwargs=player_model_kwargs,
                          q_model_kwargs=player_q_model_kwargs,
                          v_model_kwargs=player_v_model_kwargs)
        observer = SacAgentBeta(ModelCls=PiMlpModelBeta,
                                QModelCls=QofMuMlpModel,
                                model_kwargs=observer_model_kwargs,
                                q_model_kwargs=observer_q_model_kwargs,
                                v_model_kwargs=observer_v_model_kwargs)
    else:
        player_model = CWTO_LstmModel
        observer_model = CWTO_LstmModel

        player_algo = PPO()
        observer_algo = PPO()
        player = CWTO_LstmAgent(ModelCls=player_model,
                                model_kwargs=player_model_kwargs,
                                initial_model_state_dict=None)
        observer = CWTO_LstmAgent(ModelCls=observer_model,
                                  model_kwargs=observer_model_kwargs,
                                  initial_model_state_dict=None)
    if one_agent:
        agent = CWTO_AgentWrapper(player,
                                  observer,
                                  serial=serial,
                                  n_serial=n_serial,
                                  alt=alt,
                                  train_mask=train_mask,
                                  one_agent=one_agent,
                                  nplayeract=player_act_space.n)
    else:
        agent = CWTO_AgentWrapper(player,
                                  observer,
                                  serial=serial,
                                  n_serial=n_serial,
                                  alt=alt,
                                  train_mask=train_mask)

    if eval:
        RunnerCl = MinibatchRlEval
    else:
        RunnerCl = MinibatchRl

    runner = RunnerCl(player_algo=player_algo,
                      observer_algo=observer_algo,
                      agent=agent,
                      sampler=sampler,
                      n_steps=n_steps,
                      log_interval_steps=log_interval_steps,
                      affinity=affinity,
                      wandb_log=wandb_log,
                      alt_train=alt_train)
    config = dict(domain=game)
    if game == "halfcheetah":
        name = "sac_" + game
    else:
        name = "ppo_" + game
    log_dir = os.getcwd() + "/cwto_logs/" + name
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
    if save_models_to_wandb:
        agent.save_models_to_wandb()
    if eval_perf:
        eval_n_envs = 10
        eval_envs = [CWTO_EnvWrapper(**env_kwargs) for _ in range(eval_n_envs)]
        set_envs_seeds(eval_envs, make_seed())
        eval_collector = SerialEvalCollector(envs=eval_envs,
                                             agent=agent,
                                             TrajInfoCls=TrajInfo_obs,
                                             max_T=1000,
                                             max_trajectories=10,
                                             log_full_obs=True)
        traj_infos_player, traj_infos_observer = eval_collector.collect_evaluation(
            runner.get_n_itr())
        observations = []
        player_actions = []
        returns = []
        observer_actions = []
        for traj in traj_infos_player:
            observations.append(np.array(traj.Observations))
            player_actions.append(np.array(traj.Actions))
            returns.append(traj.Return)
        for traj in traj_infos_observer:
            observer_actions.append(
                np.array([
                    obs_action_translator(act, eval_envs[0].power_vec,
                                          eval_envs[0].obs_size)
                    for act in traj.Actions
                ]))

        # save results:
        open_obs = open('eval_observations.pkl', "wb")
        pickle.dump(observations, open_obs)
        open_obs.close()
        open_ret = open('eval_returns.pkl', "wb")
        pickle.dump(returns, open_ret)
        open_ret.close()
        open_pact = open('eval_player_actions.pkl', "wb")
        pickle.dump(player_actions, open_pact)
        open_pact.close()
        open_oact = open('eval_observer_actions.pkl', "wb")
        pickle.dump(observer_actions, open_oact)
        open_oact.close()
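
A typical way to invoke build_and_train() from the command line is sketched below; the flag set simply mirrors a few of the function's keyword arguments and is an assumption, not part of the original script.

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--game", default="cartpole")
    parser.add_argument("--run_ID", type=int, default=0)
    parser.add_argument("--cuda_idx", type=int, default=None)
    parser.add_argument("--sample_mode", default="serial")
    parser.add_argument("--n_parallel", type=int, default=2)
    args = parser.parse_args()
    build_and_train(game=args.game,
                    run_ID=args.run_ID,
                    cuda_idx=args.cuda_idx,
                    sample_mode=args.sample_mode,
                    n_parallel=args.n_parallel)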