def initialize_worker(rank, seed=None, cpu=None, torch_threads=None):
    """Assign CPU affinity, set random seed, set torch_threads if needed
    to prevent MKL deadlock.
    """
    message = f"Sampler rank {rank} initialized"
    cpus = [cpu] if isinstance(cpu, int) else cpu
    process = psutil.Process()
    try:
        if cpus is not None:
            process.cpu_affinity(cpus)
        affinity_report = process.cpu_affinity()
    except AttributeError:  # cpu_affinity not available (e.g. MacOS).
        affinity_report = "UNAVAILABLE MacOS"
    message += f", CPU affinity {affinity_report}"
    if torch_threads is None and cpus is not None:
        torch_threads = 1  # Default to 1 to avoid possible MKL hang.
    if torch_threads is not None:
        torch.set_num_threads(torch_threads)
    message += f", Torch threads {torch.get_num_threads()}"
    if seed is not None:
        set_seed(seed)
        time.sleep(0.3)  # (so the printing from set_seed is not intermixed)
        message += f", Seed {seed}"
    logger.log(message)
def initialize_worker(rank, seed=None, cpu=None, torch_threads=None):
    """
    Initialize a sampler worker: CPU affinity, torch threads, random seed.

    :param rank: identifying index of the sampling process.
    :param seed: random seed, an integer value.
    :param cpu: CPU index (or list of indexes), e.g. 0, 1, 2, etc.
    :param torch_threads: number of threads for torch CPU parallelism.
    """
    log_str = f"Sampler rank {rank} initialized"
    cpu = [cpu] if isinstance(cpu, int) else cpu
    p = psutil.Process()
    try:
        if cpu is not None:
            p.cpu_affinity(cpu)  # Set CPU affinity (unsupported on MacOS).
        cpu_affin = p.cpu_affinity()
    except AttributeError:
        cpu_affin = "UNAVAILABLE MacOS"
    log_str += f", CPU affinity {cpu_affin}"
    torch_threads = (
        1 if torch_threads is None and cpu is not None else torch_threads
    )  # Default to 1 to avoid possible MKL hang.
    if torch_threads is not None:
        torch.set_num_threads(torch_threads)  # Cap torch CPU parallelism.
    log_str += f", Torch threads {torch.get_num_threads()}"
    if seed is not None:
        set_seed(seed)
        time.sleep(0.3)  # (so the printing from set_seed is not intermixed)
        log_str += f", Seed {seed}"
    logger.log(log_str)
def startup(self):
    """
    Calls ``sampler.async_initialize()`` to get a double buffer for
    minibatches, followed by ``algo.async_initialize()`` to get a replay
    buffer on shared memory, then launches all workers (sampler, optimizer,
    memory copier).

    :return: ``(throttle_itr, delta_throttle_itr)`` from ``optim_startup()``.
    """
    if self.seed is None:
        self.seed = make_seed()
    set_seed(self.seed)
    double_buffer, examples = self.sampler.async_initialize(
        agent=self.agent,
        bootstrap_value=getattr(self.algo, "bootstrap_value", False),
        traj_info_kwargs=self.get_traj_info_kwargs(),
        seed=self.seed,
    )
    self.sampler_batch_size = self.sampler.batch_spec.size
    # NOTE(review): world size taken as one entry per optimizer affinity slot
    # -- confirm self.affinity.optimizer lists per-GPU optimizer affinities.
    self.world_size = len(self.affinity.optimizer)
    n_itr = self.get_n_itr()  # Number of sampler iterations.
    replay_buffer = self.algo.async_initialize(
        agent=self.agent,
        sampler_n_itr=n_itr,
        batch_spec=self.sampler.batch_spec,
        mid_batch_reset=self.sampler.mid_batch_reset,
        examples=examples,
        world_size=self.world_size,
    )
    self.launch_workers(n_itr, double_buffer, replay_buffer)
    throttle_itr, delta_throttle_itr = self.optim_startup()
    return throttle_itr, delta_throttle_itr
def initialize_worker(rank, seed=None, cpu=None, torch_threads=None, group=None):
    """Set CPU affinity, torch thread count, and random seed for a sampler
    worker, then log the resulting configuration.  (``group`` is accepted
    but unused here.)
    """
    report = f"Sampler rank {rank} initialized"
    cpus = [cpu] if isinstance(cpu, int) else cpu
    process = psutil.Process()
    try:
        if cpus is not None:
            process.cpu_affinity(cpus)
        affin = process.cpu_affinity()
    except AttributeError:  # cpu_affinity not available (e.g. MacOS).
        affin = "UNAVAILABLE MacOS"
    report += f", CPU affinity {affin}"
    if torch_threads is None and cpus is not None:
        torch_threads = len(cpus)  # Default: one torch thread per assigned CPU.
    if torch_threads is not None:
        torch.set_num_threads(torch_threads)
    report += f", Torch threads {torch.get_num_threads()}"
    if seed is not None:
        set_seed(seed)
        time.sleep(0.3)  # (so the printing from set_seed is not intermixed)
        report += f", Seed {seed}"
    logger.log(report)
def startup(self):
    """Master-process startup for a sampler-free runner: set CPU affinity
    and torch threads, seed RNGs, then initialize the algorithm and logging.
    """
    p = psutil.Process()
    try:
        # Pin the master process to its assigned CPUs (MacOS lacks support).
        if self.affinity.get("master_cpus", None) is not None and self.affinity.get(
                "set_affinity", True):
            p.cpu_affinity(self.affinity["master_cpus"])
        cpu_affin = p.cpu_affinity()
    except AttributeError:
        cpu_affin = "UNAVAILABLE MacOS"
    logger.log(f"Runner {getattr(self, 'rank', '')} master CPU affinity: "
        f"{cpu_affin}.")
    if self.affinity.get("master_torch_threads", None) is not None:
        torch.set_num_threads(self.affinity["master_torch_threads"])
    logger.log(f"Runner {getattr(self, 'rank', '')} master Torch threads: "
        f"{torch.get_num_threads()}.")
    if self.seed is None:
        self.seed = make_seed()
    set_seed(self.seed)
    # self.rank = rank = getattr(self, "rank", 0)
    # self.world_size = world_size = getattr(self, "world_size", 1)
    self.algo.initialize(
        n_updates=self.n_updates,
        cuda_idx=self.affinity.get("cuda_idx", None),
    )
    self.initialize_logging()
def startup(self):
    """
    Sets hardware affinities, initializes the following: 1) sampler (which
    should initialize the agent), 2) agent device and data-parallel wrapper
    (if applicable), 3) algorithm, 4) logger.

    :return: number of training iterations to run.
    """
    p = psutil.Process()
    try:
        # Pin the master process to its assigned CPUs (MacOS lacks support).
        if (self.affinity.get("master_cpus", None) is not None and
                self.affinity.get("set_affinity", True)):
            p.cpu_affinity(self.affinity["master_cpus"])
        cpu_affin = p.cpu_affinity()
    except AttributeError:
        cpu_affin = "UNAVAILABLE MacOS"
    logger.log(f"Runner {getattr(self, 'rank', '')} master CPU affinity: "
        f"{cpu_affin}.")
    if self.affinity.get("master_torch_threads", None) is not None:
        torch.set_num_threads(self.affinity["master_torch_threads"])
    logger.log(f"Runner {getattr(self, 'rank', '')} master Torch threads: "
        f"{torch.get_num_threads()}.")
    if self.seed is None:
        self.seed = make_seed()
    set_seed(self.seed)
    self.rank = rank = getattr(self, "rank", 0)
    self.world_size = world_size = getattr(self, "world_size", 1)
    examples = self.sampler.initialize(
        agent=self.agent,  # Agent gets initialized in sampler.
        affinity=self.affinity,
        seed=self.seed + 1,
        bootstrap_value=getattr(self.algo, "bootstrap_value", False),
        traj_info_kwargs=self.get_traj_info_kwargs(),
        rank=rank,
        world_size=world_size,
    )
    self.itr_batch_size = self.sampler.batch_spec.size * world_size
    n_itr = self.get_n_itr()
    # Fix: was a bare debug ``print``; route through the project logger for
    # consistency with the rest of this method.
    logger.log(f"CUDA index: {self.affinity.get('cuda_idx', None)}.")
    self.agent.to_device(self.affinity.get("cuda_idx", None))
    if world_size > 1:
        self.agent.data_parallel()
    self.algo.initialize(
        agent=self.agent,
        n_itr=n_itr,
        batch_spec=self.sampler.batch_spec,
        mid_batch_reset=self.sampler.mid_batch_reset,
        examples=examples,
        world_size=world_size,
        rank=rank,
    )
    self.initialize_logging()
    return n_itr
def startup(self):
    """
    Adds support for next_obs sampling.

    Otherwise identical to the base runner startup: set hardware affinities,
    initialize sampler (which initializes the agent), agent device /
    data-parallel wrapper, algorithm, and logging.

    :return: number of training iterations to run.
    """
    p = psutil.Process()
    try:
        # Pin the master process to its assigned CPUs (MacOS lacks support).
        if (self.affinity.get("master_cpus", None) is not None and
                self.affinity.get("set_affinity", True)):
            p.cpu_affinity(self.affinity["master_cpus"])
        cpu_affin = p.cpu_affinity()
    except AttributeError:
        cpu_affin = "UNAVAILABLE MacOS"
    logger.log(f"Runner {getattr(self, 'rank', '')} master CPU affinity: "
        f"{cpu_affin}.")
    if self.affinity.get("master_torch_threads", None) is not None:
        torch.set_num_threads(self.affinity["master_torch_threads"])
    logger.log(f"Runner {getattr(self, 'rank', '')} master Torch threads: "
        f"{torch.get_num_threads()}.")
    if self.seed is None:
        self.seed = make_seed()
    set_seed(self.seed)
    self.rank = rank = getattr(self, "rank", 0)
    self.world_size = world_size = getattr(self, "world_size", 1)
    examples = self.sampler.initialize(
        agent=self.agent,  # Agent gets initialized in sampler.
        affinity=self.affinity,
        seed=self.seed + 1,
        bootstrap_value=getattr(self.algo, "bootstrap_value", False),
        next_obs=getattr(self.algo, "next_obs", False),  # MODIFIED HERE
        traj_info_kwargs=self.get_traj_info_kwargs(),
        rank=rank,
        world_size=world_size,
    )
    self.itr_batch_size = self.sampler.batch_spec.size * world_size
    n_itr = self.get_n_itr()
    self.agent.to_device(self.affinity.get("cuda_idx", None))
    if world_size > 1:
        self.agent.data_parallel()
    self.algo.initialize(
        agent=self.agent,
        n_itr=n_itr,
        batch_spec=self.sampler.batch_spec,
        mid_batch_reset=self.sampler.mid_batch_reset,
        examples=examples,
        world_size=world_size,
        rank=rank,
    )
    self.initialize_logging()
    return n_itr
def startup(self):
    """Optimizer-worker startup: set CPU affinity, torch threads, and seed,
    then initialize the agent on CUDA (data-parallel) and the async algorithm.
    """
    # Fix: ``rank`` was referenced in the log messages below but never
    # defined in this scope (NameError at runtime); bind it from self like
    # the distributed variant of this method does.
    rank = getattr(self, "rank", 0)
    p = psutil.Process()
    p.cpu_affinity(self.affinity["cpus"])
    logger.log(f"Optimizer rank {rank} CPU affinity: {p.cpu_affinity()}.")
    torch.set_num_threads(self.affinity["torch_threads"])
    logger.log(f"Optimizer rank {rank} Torch threads: {torch.get_num_threads()}.")
    logger.log(f"Optimizer rank {rank} CUDA index: "
        f"{self.affinity.get('cuda_idx', None)}.")
    set_seed(self.seed)
    self.agent.initialize_cuda(
        cuda_idx=self.affinity.get("cuda_idx", None),
        dpp=True,
    )
    self.algo.initialize_async(agent=self.agent,
        updates_per_sync=self.updates_per_sync)
def startup(self):
    """Distributed optimizer-worker startup: join the NCCL process group,
    set CPU affinity / torch threads / seed, move the agent to its GPU,
    wrap it data-parallel, and initialize per-rank optimizer state.
    """
    torch.distributed.init_process_group(
        backend="nccl",
        rank=self.rank,
        world_size=self.world_size,
        # NOTE(review): loopback rendezvous -- assumes all ranks run on the
        # same host; confirm before multi-node use.
        init_method=f"tcp://127.0.0.1:{self.port}",
    )
    p = psutil.Process()
    if self.affinity.get("set_affinity", True):
        p.cpu_affinity(self.affinity["cpus"])
    logger.log(f"Optimizer rank {self.rank} CPU affinity: {p.cpu_affinity()}.")
    torch.set_num_threads(self.affinity["torch_threads"])
    logger.log(f"Optimizer rank {self.rank} Torch threads: {torch.get_num_threads()}.")
    logger.log(f"Optimizer rank {self.rank} CUDA index: "
        f"{self.affinity.get('cuda_idx', None)}.")
    set_seed(self.seed)
    self.agent.to_device(cuda_idx=self.affinity.get("cuda_idx", None))
    self.agent.data_parallel()
    self.algo.optim_initialize(rank=self.rank)
def startup(self):
    """Optimizer-master startup: set CPU affinity, torch threads, and seed;
    initialize agent (CUDA, data-parallel if multiple runners) and async
    algorithm; compute optimizer throttling iteration counts.

    :return: ``(throttle_itr, delta_throttle_itr)`` controlling how far the
        optimizer may run ahead of sampling.
    """
    p = psutil.Process()
    p.cpu_affinity(self.affinity["cpus"])
    # Fix: both log calls below were missing the ``f`` prefix, so the
    # ``{...}`` placeholders were printed literally instead of interpolated.
    logger.log(f"Optimizer master CPU affinity: {p.cpu_affinity()}.")
    torch.set_num_threads(self.affinity["torch_threads"])
    logger.log(f"Optimizer master Torch threads: {torch.get_num_threads()}.")
    set_seed(self.seed)
    self.agent.initialize_cuda(
        cuda_idx=self.affinity.get("cuda_idx", None),
        dpp=self.n_runner > 1,
    )
    self.algo.initialize_async(agent=self.agent,
        updates_per_sync=self.updates_per_sync)
    # First iteration at which optimization may begin (after min_steps_learn).
    throttle_itr = 1 + self.algo.min_steps_learn // self.itr_batch_size
    delta_throttle_itr = (self.algo.batch_size * self.n_runner *
        self.algo.updates_per_optimize /  # (is updates_per_sync)
        (self.itr_batch_size * self.training_ratio))
    # NOTE(review): method name is misspelled; call left as-is in case the
    # definition elsewhere carries the same typo -- confirm and fix both.
    self.initilaize_logging()
    return throttle_itr, delta_throttle_itr
def startup(self):
    """Runner startup: set CPU affinity and torch threads, seed RNGs, then
    initialize sampler (which initializes the agent), agent CUDA placement,
    algorithm, and logging.

    :return: number of training iterations to run.
    """
    p = psutil.Process()
    try:
        # Pin the master process to its assigned CPUs (MacOS lacks support).
        if self.affinity.get("master_cpus", None) is not None:
            p.cpu_affinity(self.affinity["master_cpus"])
        cpu_affin = p.cpu_affinity()
    except AttributeError:
        cpu_affin = "UNAVAILABLE MacOS"
    logger.log(f"Runner {getattr(self, 'rank', '')} master CPU affinity: "
        f"{cpu_affin}.")
    if self.affinity.get("master_torch_threads", None) is not None:
        torch.set_num_threads(self.affinity["master_torch_threads"])
    logger.log(f"Runner {getattr(self, 'rank', '')} master Torch threads: "
        f"{torch.get_num_threads()}.")
    if self.seed is None:
        self.seed = make_seed()
    set_seed(self.seed)
    examples = self.sampler.initialize(
        agent=self.agent,  # Agent gets initialized in sampler.
        affinity=self.affinity,
        seed=self.seed + 1,
        bootstrap_value=getattr(self.algo, "bootstrap_value", False),
        traj_info_kwargs=self.get_traj_info_kwargs(),
    )
    n_runners = getattr(self, "n_runners", 1)
    self.itr_batch_size = self.sampler.batch_spec.size * n_runners
    n_itr = self.get_n_itr()
    self.agent.initialize_cuda(
        cuda_idx=self.affinity.get("cuda_idx", None),
        ddp=n_runners > 1,  # Multi-GPU training (and maybe sampling).
    )
    self.algo.initialize(agent=self.agent, n_itr=n_itr,
        batch_spec=self.sampler.batch_spec,
        mid_batch_reset=self.sampler.mid_batch_reset,
        examples=examples)
    self.initialize_logging()
    return n_itr
def test_seed(self):
    """Drawing a batch after identical seeding twice yields identical
    samples; a further (unseeded) draw yields different samples."""
    sampler = SerialSampler(
        EnvCls=gym_make,
        env_kwargs={"id": "MountainCarContinuous-v0"},
        batch_T=1,
        batch_B=1,
    )
    agent = SacAgent(pretrain_std=0.0)
    agent.give_min_itr_learn(10000)

    def seeded_batch():
        # Reset global and sampler seeds, then draw one batch.
        set_seed(0)
        sampler.initialize(agent, seed=0)
        return sampler.obtain_samples(0)

    first = seeded_batch()
    second = seeded_batch()
    # Dirty hack to compare objects containing tensors.
    self.assertEqual(str(first), str(second))
    third = sampler.obtain_samples(0)
    self.assertNotEqual(first, third)
def startup(self):
    """Async-runner startup: seed RNGs, get the minibatch double buffer from
    ``sampler.async_initialize()`` and the shared-memory replay buffer from
    ``algo.async_initialize()``, then launch all workers.

    :return: ``(throttle_itr, delta_throttle_itr)`` from ``optim_startup()``.
    """
    if self.seed is None:
        self.seed = make_seed()
    set_seed(self.seed)
    double_buffer, examples = self.sampler.async_initialize(
        agent=self.agent,
        bootstrap_value=getattr(self.algo, "bootstrap_value", False),
        traj_info_kwargs=self.get_traj_info_kwargs(),
        seed=self.seed,
    )
    self.sampler_batch_size = self.sampler.batch_spec.size
    # NOTE(review): world size taken as one entry per optimizer affinity slot
    # -- confirm self.affinity.optimizer lists per-GPU optimizer affinities.
    self.world_size = len(self.affinity.optimizer)
    n_itr = self.get_n_itr()  # Number of sampler iterations.
    replay_buffer = self.algo.async_initialize(
        agent=self.agent,
        sampler_n_itr=n_itr,
        batch_spec=self.sampler.batch_spec,
        mid_batch_reset=self.sampler.mid_batch_reset,
        examples=examples,
        world_size=self.world_size,
    )
    self.launch_workers(n_itr, double_buffer, replay_buffer)
    throttle_itr, delta_throttle_itr = self.optim_startup()
    return throttle_itr, delta_throttle_itr
def build_and_train(log_dir, game="pong", run_ID=0, cuda_idx=None, eval=False,
        save_model='last', load_model_path=None, n_parallel=2, CumSteps=0):
    """Configure and run Dreamer training (or evaluation) on an Atari game.

    :param log_dir: directory for logger output.
    :param game: Atari game name.
    :param run_ID: run index used by the logger context.
    :param cuda_idx: CUDA device index, or None for CPU.
    :param eval: if True, run the eval configuration (short n_steps,
        MinibatchRlEval runner, zero exploration noise).
    :param save_model: snapshot mode passed to the logger context.
    :param load_model_path: optional checkpoint to restore agent/optimizer
        state from.
    :param n_parallel: number of sampler worker CPUs.
    :param CumSteps: steps already taken, used to resume the exploration
        schedule via itr_start.
    """
    device = 'cpu' if cuda_idx is None else 'cuda'
    params = torch.load(
        load_model_path,
        map_location=torch.device(device)) if load_model_path else {}
    agent_state_dict = params.get('agent_state_dict')
    optimizer_state_dict = params.get('optimizer_state_dict')
    ##--- wu ---##
    log_interval_steps = 5e4
    prefill = 5e4
    train_every = 16
    batch_B = 16
    n_steps = 1e4 if eval else 5e6
    # NOTE(review): prefill is a float (5e4), so itr_start and the
    # prefill // batch_B below are floats -- confirm downstream accepts that.
    itr_start = max(0, CumSteps - prefill) // train_every
    ##--- wu ---##
    action_repeat = 4  # 2
    env_kwargs = dict(
        name=game,
        action_repeat=action_repeat,
        size=(64, 64),
        grayscale=True,  # False
        life_done=True,
        sticky_actions=True,
    )
    factory_method = make_wapper(
        AtariEnv,
        [OneHotAction, TimeLimit],
        [dict(), dict(duration=1000000 / action_repeat)])  # 1000
    sampler = GpuSampler(
        EnvCls=factory_method,
        TrajInfoCls=AtariTrajInfo,
        env_kwargs=env_kwargs,
        eval_env_kwargs=env_kwargs,
        batch_T=1,
        batch_B=batch_B,
        max_decorrelation_steps=0,
        eval_n_envs=10,
        eval_max_steps=int(10e5),
        eval_max_trajectories=5,
    )
    algo = Dreamer(
        initial_optim_state_dict=optimizer_state_dict,
        horizon=10,
        use_pcont=True,
        replay_size=int(2e6),  # int(5e6)
        kl_scale=0.1,
        batch_size=50,
        batch_length=50,
        C=1,  # 100,
        train_every=train_every // batch_B,  # 1000
        pretrain=100,
        world_lr=2e-4,  # 6e-4,
        value_lr=1e-4,  # 8e-5,
        actor_lr=4e-5,  # 8e-5,
        discount=0.999,  # 0.99,
        expl_amount=0.0,  # 0.3,
        prefill=prefill // batch_B,  # 5000
        discount_scale=5.,  # 10.
        video_every=int(2e4 // 16 * 16 // batch_B),  # int(10)
    )
    if eval:
        # for eval - all versions
        agent = AtariDreamerAgent(train_noise=0.0, eval_noise=0,
            expl_type="epsilon_greedy", itr_start=itr_start,
            the_expl_mode='eval', expl_min=0.0, expl_decay=11000,
            initial_model_state_dict=agent_state_dict,
            model_kwargs=dict(use_pcont=True))
    else:
        # for train - all versions
        # agent = AtariDreamerAgent(train_noise=0.4, eval_noise=0, expl_type="epsilon_greedy", itr_start=itr_start, the_expl_mode='train',
        #     expl_min=0.1, expl_decay=11000, initial_model_state_dict=agent_state_dict,
        #     model_kwargs=dict(use_pcont=True))
        # for train - dreamer_V2
        agent = AtariDreamerAgent(train_noise=0.0, eval_noise=0,
            expl_type="epsilon_greedy", itr_start=itr_start,
            the_expl_mode='train', expl_min=0.0, expl_decay=11000,
            initial_model_state_dict=agent_state_dict,
            model_kwargs=dict(use_pcont=True))
    my_seed = 0  # reproductivity
    set_seed(my_seed)
    runner_cls = MinibatchRlEval if eval else MinibatchRl
    runner = runner_cls(
        algo=algo,  # Uses gathered samples to train the agent (e.g. defines a loss function and performs gradient descent).
        agent=agent,  # Chooses control action to the environment in sampler; trained by the algorithm. Interface to model.
        sampler=sampler,
        n_steps=n_steps,
        log_interval_steps=log_interval_steps,
        affinity=dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel))),
        seed=my_seed,
    )
    config = dict(game=game)
    name = "dreamer_" + game
    with logger_context(log_dir, run_ID, name, config, snapshot_mode=save_model,
            override_prefix=True, use_summary_writer=True):
        runner.train()
def startup(self):
    """
    Sets hardware affinities, initializes the following: 1) sampler (which
    should initialize the agent), 2) agent device and data-parallel wrapper
    (if applicable), 3) algorithm, 4) logger.

    This function is nearly identical to MinibatchRlBase.startup with the
    main difference being the initialization of the extra eval samplers.

    :return: number of training iterations to run.
    """
    p = psutil.Process()
    try:
        # Pin the master process to its assigned CPUs (MacOS lacks support).
        if self.affinity.get("master_cpus", None) is not None and self.affinity.get(
                "set_affinity", True):
            p.cpu_affinity(self.affinity["master_cpus"])
        cpu_affin = p.cpu_affinity()
    except AttributeError:
        cpu_affin = "UNAVAILABLE MacOS"
    print(f"Runner {getattr(self, 'rank', '')} master CPU affinity: "
        f"{cpu_affin}.")
    if self.affinity.get("master_torch_threads", None) is not None:
        torch.set_num_threads(self.affinity["master_torch_threads"])
    print(f"Runner {getattr(self, 'rank', '')} master Torch threads: "
        f"{torch.get_num_threads()}.")
    # NOTE(review): unlike the base runner there is no make_seed() fallback
    # here -- assumes self.seed was set by the caller; confirm.
    set_seed(self.seed)
    self.rank = rank = getattr(self, "rank", 0)
    self.world_size = world_size = getattr(self, "world_size", 1)
    # Give each extra eval sampler its own seed offset i.
    for i, sampler in enumerate(self.extra_eval_samplers.values()):
        sampler.initialize(
            agent=self.agent,  # Agent gets initialized in sampler.
            affinity=self.affinity,
            seed=self.seed + i,
            bootstrap_value=getattr(self.algo, "bootstrap_value", False),
            traj_info_kwargs=self.get_traj_info_kwargs(),
            rank=rank,
            world_size=world_size,
        )
    examples = self.sampler.initialize(
        agent=self.agent,  # Agent gets initialized in sampler.
        affinity=self.affinity,
        seed=self.seed + 1,
        bootstrap_value=getattr(self.algo, "bootstrap_value", False),
        traj_info_kwargs=self.get_traj_info_kwargs(),
        rank=rank,
        world_size=world_size,
    )
    self.itr_batch_size = self.sampler.batch_spec.size * world_size
    n_itr = self.get_n_itr()
    self.agent.to_device(self.affinity.get("cuda_idx", None))
    if world_size > 1:
        self.agent.data_parallel()
    self.algo.initialize(
        agent=self.agent,
        # NOTE(review): passes self.n_itr rather than the local n_itr --
        # presumably get_n_itr() set self.n_itr; confirm they agree.
        n_itr=self.n_itr,
        batch_spec=self.sampler.batch_spec,
        mid_batch_reset=self.sampler.mid_batch_reset,
        examples=examples,
        world_size=world_size,
        rank=rank,
    )
    print(
        f"Running {self.n_itr} iterations with batch size {self.itr_batch_size}."
    )
    self.logger.initialize_logging()
    return n_itr
def build_and_train(env_id="HalfCheetah-v3", log_dir='results', alg_name='ddpg',
        run_ID=0, cuda_idx=None, seed=42, q_hidden_sizes=[64, 64],
        q_nonlinearity='relu', batch_size=32, q_target=None, log_freq=1e3):
    """Configure and run DDPG or PreQN training with evaluation.

    :param env_id: gym environment id.
    :param log_dir: root directory for logger output (experiment parameters
        are appended to the path).
    :param alg_name: 'ddpg' or 'preqn' (case-insensitive).
    :param seed: random seed for both global RNGs and the runner.
    :param q_hidden_sizes: hidden layer sizes for the Q network.  (Mutable
        default is safe here: it is only read, never mutated.)
    :param q_nonlinearity: 'relu', 'sine', or 'linear'.
    :param q_target: whether to use a target network; defaults to True for
        DDPG, False for PreQN when None.
    :param log_freq: steps between log points (also used as min_steps_learn).
    :raises ValueError: on unrecognized ``q_nonlinearity`` or ``alg_name``.
    """
    set_seed(seed)
    sampler = SerialSampler(
        EnvCls=gym_make,
        env_kwargs=dict(id=env_id),
        eval_env_kwargs=dict(id=env_id),
        batch_T=1,  # One time-step per sampler iteration.
        batch_B=1,  # One environment (i.e. sampler Batch dimension).
        max_decorrelation_steps=0,
        eval_n_envs=10,
        eval_max_steps=int(51e3),
        eval_max_trajectories=50,
    )
    # Fix: these selection chains previously left q_nonlin / algo / agent
    # unbound on unrecognized inputs, crashing later with NameError; fail
    # fast with a clear error instead.
    if q_nonlinearity == 'relu':
        q_nonlin = torch.nn.ReLU
    elif q_nonlinearity == 'sine':
        q_nonlin = Sine
    elif q_nonlinearity == 'linear':
        q_nonlin = Linear
    else:
        raise ValueError(f"Unknown q_nonlinearity: {q_nonlinearity!r}")
    if alg_name.lower() == 'ddpg':
        if q_target is None:
            q_target = True
        algo = DDPG(batch_size=batch_size, target=q_target,
            min_steps_learn=log_freq)
        agent = DdpgAgent(q_hidden_sizes=q_hidden_sizes,
            q_nonlinearity=q_nonlin)
    elif alg_name.lower() == 'preqn':
        if q_target is None:
            q_target = False
        algo = PreQN(batch_size=batch_size, target=q_target,
            min_steps_learn=log_freq)
        agent = PreqnAgent(q_hidden_sizes=q_hidden_sizes,
            q_nonlinearity=q_nonlin)
    else:
        raise ValueError(f"Unknown alg_name: {alg_name!r}")
    runner = MinibatchRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        seed=seed,
        n_steps=1e6,
        log_interval_steps=log_freq,  # 1e4,
        affinity=dict(cuda_idx=cuda_idx),
    )
    config = dict(env_id=env_id)
    # Encode the experiment configuration into the log directory name.
    log_dir = os.path.join(log_dir, env_id)
    log_dir = os.path.join(log_dir, alg_name.lower())
    log_dir += '-' + q_nonlinearity
    log_dir += '-hs' + str(q_hidden_sizes)
    log_dir += '-qt' + str(q_target)
    log_dir += '-bs' + str(batch_size)
    name = ''  # env_id
    with logger_context(log_dir, run_ID, name, config, override_prefix=True,
            use_summary_writer=True):
        runner.train()
def startup(self):
    """
    Master-process initialization: set CPU affinity and torch threads, seed
    the RNGs, then initialize the sampler, agent device, algorithm, and
    logging.

    :return: number of training iterations to run.
    """
    p = psutil.Process()  # Handle on the current process.
    # Set CPU affinity (unsupported on MacOS).
    try:
        if self.affinity.get("master_cpus", None) is not None and self.affinity.get(
                "set_affinity", True):
            p.cpu_affinity(self.affinity["master_cpus"])
        cpu_affin = p.cpu_affinity()  # Read back the affinity just set.
    except AttributeError:
        cpu_affin = "UNAVAILABLE MacOS"
    logger.log(
        f"Runner {getattr(self, 'rank', '')} master CPU affinity: {cpu_affin}."
    )
    # Set the torch thread count.
    if self.affinity.get("master_torch_threads", None) is not None:
        torch.set_num_threads(
            self.affinity["master_torch_threads"])  # Threads for CPU parallelism.
    logger.log(
        f"Runner {getattr(self, 'rank', '')} master Torch threads: {torch.get_num_threads()}."
    )
    # Seed the random number generators.
    if self.seed is None:
        self.seed = make_seed()
    set_seed(self.seed)
    self.rank = rank = getattr(self, "rank", 0)  # Externally supplied rank; defaults to 0.
    self.world_size = world_size = getattr(
        self, "world_size", 1)  # Externally supplied world size; defaults to 1.
    # Initialize the sampler instance; the name "examples" is a little
    # misleading -- it holds example batch data from the sampler.
    examples = self.sampler.initialize(
        agent=self.agent,  # Agent gets initialized in sampler, so pass it in.
        affinity=self.affinity,  # CPU affinity.
        seed=self.seed + 1,  # Random seed.
        bootstrap_value=getattr(self.algo, "bootstrap_value", False),
        traj_info_kwargs=self.get_traj_info_kwargs(),  # Only sets the discount.
        rank=rank,
        world_size=world_size,
    )
    """
    batch_spec.size (see the BatchSpec class) is the total number of time
    steps summed over all environment instances.  It is multiplied by
    world_size: sampling in the "local" replica counts toward
    batch_spec.size, and world_size such replicas run in parallel, so the
    product is the total collected per iteration across all replicas.
    """
    self.itr_batch_size = self.sampler.batch_spec.size * world_size  # Time steps per iteration.
    n_itr = self.get_n_itr(
    )  # Number of training iterations; computed rather than given directly.
    self.agent.to_device(self.affinity.get("cuda_idx", None))  # Run on the chosen device.
    if world_size > 1:
        self.agent.data_parallel()
    # Initialize the algorithm object.
    self.algo.initialize(
        agent=self.agent,
        n_itr=n_itr,
        batch_spec=self.sampler.batch_spec,
        mid_batch_reset=self.sampler.mid_batch_reset,
        examples=examples,
        world_size=world_size,
        rank=rank,
    )
    # Initialize logging parameters.
    self.initialize_logging()
    return n_itr
lead_dim, T, B, img_shape = infer_leading_dims(observation, 3) action = torch.randint(low=0, high=self.num_actions, size=(T * B, )) action = restore_leading_dims((action), lead_dim, T, B) return action # Setup the data collection pipeline sampler = GpuSampler(EnvCls=gym.make, env_kwargs=config["env"], CollectorCls=GpuResetCollector, eval_env_kwargs=config["env"], **config["sampler"]) agent = RandomAgent(ModelCls=RandomDiscreteModel, model_kwargs={"num_actions": 15}) seed = make_seed() set_seed(seed) sampler.initialize(agent=agent, affinity=affinity, seed=seed + 1, rank=0) steps = config["train_steps"] # Create the model model = BiGAN(**config["model"]) if config["load_path"]: model.load_state_dict(torch.load(config["load_path"])) # Setup the optimizers lr = config["optim"]["lr"] d_optimizer = torch.optim.Adam(model.d.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=2.5e-5) g_optimizer = torch.optim.Adam(list(model.e.parameters()) + list(model.g.parameters()),