def fullCycleTest(config="config/TestConfig.ini", chimney="EW00", **kargs):
    reader = None
    globalTimer = stopwatch.StopWatch()
    try:
        timer = stopwatch.StopWatch()

        if not os.path.exists(config):
            relConfig = os.path.join(os.getcwd(), config)  # os.pwd() does not exist
            if os.path.exists(relConfig):
                config = relConfig
        # if

        logging.info("fullCycleTest(): ChimneyReader setup")
        kargs['configurationFile'] = config
        kargs['chimney'] = chimney
        timer.restart()
        reader = testDriver.ChimneyReader(**kargs)
        timer.stop()
        logging.info("fullCycleTest(): ChimneyReader setup ended in {}".format(
            timer.toString()))
        if not reader.scope:
            raise RuntimeError("Failed to contact the oscilloscope.")

        logging.info("fullCycleTest(): readout loop")
        timer.restart()
        reader.start()
        iRead = 0
        while reader.readNext():
            iRead += 1
            print("Read #{}".format(iRead))
        # while
        timer.stop()
        logging.info("fullCycleTest(): readout ended in {}".format(
            timer.toString()))

        logging.info("fullCycleTest(): verification")
        timer.restart()
        success = reader.verify()
        timer.stop()
        logging.info("fullCycleTest(): verification ended in {}".format(
            timer.toString()))
        if not success:
            raise RuntimeError("Verification failed!")

        logging.info("fullCycleTest(): archival script generation")
        timer.restart()
        reader.generateArchivalScript()
        timer.stop()
        logging.info("fullCycleTest(): script generation ended in {}".format(
            timer.toString()))

    except Exception as e:
        print(e, file=sys.stderr)  # Python 3 form of `print >> sys.stderr, e`

    globalTimer.stop()
    logging.info("fullCycleTest(): test took {}".format(
        globalTimer.toString()))
    if reader:
        reader.printTimers()
    return reader
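fullCycleTest times each phase with a restartable stopwatch and logs the duration. A minimal stand-in sketch of that pattern, assuming a StopWatch exposing restart()/stop()/toString() as used above (the class below is illustrative, not the project's actual stopwatch module):

import logging
import time

class StopWatch:
    """Illustrative stand-in for the assumed restart()/stop()/toString() API."""
    def __init__(self):
        self._t0 = time.monotonic()
        self._elapsed = 0.0
    def restart(self):
        self._t0 = time.monotonic()
    def stop(self):
        self._elapsed = time.monotonic() - self._t0
    def toString(self):
        return "{:.3f} s".format(self._elapsed)

timer = StopWatch()
timer.restart()
time.sleep(0.1)  # stand-in for a timed phase
timer.stop()
logging.info("phase ended in {}".format(timer.toString()))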
def make_new_timer(self):
    new_timer = tk.Frame(self.indTimersFrame)
    new_timer.pack(side=tk.TOP, anchor=tk.N)
    timer_e = tk.Entry(new_timer, width=5)
    timer_e.insert(tk.END, 'Well #')
    timer_e.pack(side=tk.LEFT)
    sw = stopwatch.StopWatch(parent=new_timer)
    sw.pack(side=tk.LEFT)
    start_button = tk.Button(
        new_timer,
        text="Start",
        command=lambda: self.sw_mating_start(timer_e.get(), sw))
    start_button.pack(side=tk.LEFT)
    stop_button = tk.Button(
        new_timer,
        text="Stop",
        command=lambda: self.sw_mating_stop(timer_e.get(), sw))
    stop_button.pack(side=tk.LEFT)
    reset_button = tk.Button(new_timer, text="Reset", command=sw.Reset)
    reset_button.pack(side=tk.LEFT)
    destroy_button = tk.Button(
        new_timer,
        text="Destroy",
        command=lambda: self.destroy_timer(
            new_timer,  # was `timer`, which is undefined in this scope
            sw, start_button, stop_button, reset_button, destroy_button))
    destroy_button.pack(side=tk.LEFT)
def __init__(self):
    super().__init__(parent=None, title='BeBe Clock Tool', size=(700, 300))
    super().SetBackgroundColour(wx.WHITE)
    nb = wx.Notebook(self)
    nb.AddPage(stopwatch.StopWatch(nb), "Stopwatch")
    nb.AddPage(timer.Timer(nb), "Timer")
    nb.AddPage(alarm.Alarm(nb), "Alarm")
    self.Show()
def start_policy_worker(inputs):
    sw = stopwatch.StopWatch()
    args, experiment_name, i, lock, stats_queue, device, \
        next_obs, next_done, obs, actions, logprobs, rewards, dones, values, traj_availables, \
        rollout_task_queues, policy_request_queue, learner_request_queue = inputs
    device = torch.device('cuda')
    agent = Agent(4).to(device)
    min_num_requests = 6
    wait_for_min_requests = 0.025
    # time.sleep(5)
    step = 0
    while True:
        step += 1
        with sw.timer('policy_worker'):
            waiting_started = time.time()
            policy_requests = []
            with sw.timer('policy_requests'):
                while len(policy_requests) < min_num_requests and time.time(
                ) - waiting_started < wait_for_min_requests:
                    try:
                        policy_requests.extend(
                            policy_request_queue.get_many(timeout=0.005))
                    except Empty:
                        pass
            if len(policy_requests) == 0:
                continue
            with sw.timer('prepare_data'):
                ls = np.concatenate(policy_requests)
                rollout_worker_idxs = ls.T[0, ::args.num_envs // args.num_env_split]
                split_idxs = ls.T[1, ::args.num_envs // args.num_env_split]
                step_idxs = ls.T[-1, ::args.num_envs // args.num_env_split]
                idxs = tuple(ls.T)
            with sw.timer('index'):
                t1 = next_obs[idxs[:-1]]
            with sw.timer('create array'):
                t2 = torch.from_numpy(next_obs[idxs[:-1]])
            with sw.timer('convert float'):
                t3 = t2.float()
            with sw.timer('move_to_gpu'):
                next_o = t3.to(device)
            # next_o = torch.from_numpy(next_obs[idxs[:-1]]).float().to(device)
            with sw.timer('inference'):
                with torch.no_grad():
                    a, l, e = agent.get_action(next_o)
            with sw.timer('move_to_cpu'):
                actions[idxs] = a.cpu()
            with sw.timer('execute_action'):
                for j in range(len(rollout_worker_idxs)):
                    rollout_worker_idx = rollout_worker_idxs[j]
                    split_idx = split_idxs[j]
                    step_idx = step_idxs[j]
                    rollout_task_queues[rollout_worker_idx].put(
                        [split_idx, step_idx])
        if step % 1000 == 0:
            print(ls.shape)
            print(stopwatch.format_report(sw.get_last_aggregated_report()))
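The request-draining loop above assumes a queue exposing get_many(timeout=...), which is what the faster-fifo package provides (an assumption inferred from usage here, together with the Empty exception from the queue module). The same idiom recurs in the learner further below; a hedged sketch with drain_at_least as a hypothetical helper name:

import time
from queue import Empty

def drain_at_least(q, min_items, max_wait):
    """Collect at least `min_items` from a get_many()-style queue,
    giving up after `max_wait` seconds; may return an empty list."""
    items, started = [], time.time()
    while len(items) < min_items and time.time() - started < max_wait:
        try:
            items.extend(q.get_many(timeout=0.005))
        except Empty:
            pass
    return items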
def benchmark_find_local_nodes_impact():
    sw = stopwatch.StopWatch()
    with sw.timer('connection_speed'):
        for i in range(100):
            with sw.timer('default'):
                Node().default()
        # for i in range(100):
        #     with sw.timer('preselected'):
        #         Node().default2()
    print(stopwatch.format_report(sw.get_last_aggregated_report()))
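For reference, the profiling idiom shared by these snippets: nested sw.timer(...) context managers accumulate a hierarchical timing tree, and format_report(sw.get_last_aggregated_report()) prints each scope's time as a percentage of its parent. A minimal self-contained example using only calls already seen above:

import time
import stopwatch  # the same package imported by the snippets above

sw = stopwatch.StopWatch()
with sw.timer('outer'):
    for _ in range(3):
        with sw.timer('inner'):
            time.sleep(0.01)  # stand-in for the work being profiled
print(stopwatch.format_report(sw.get_last_aggregated_report()))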
def start_rollout_worker(self, rollout_worker_idx, env_fns):
    sw = stopwatch.StopWatch()
    next_obs, next_done, obs, actions, logprobs, rewards, dones, values = self.storage
    env_idxs = range(
        rollout_worker_idx * self.num_envs_per_rollout_worker,
        rollout_worker_idx * self.num_envs_per_rollout_worker +
        self.num_envs_per_rollout_worker)
    envs = [None for _ in range(len(self.env_fns))]
    for env_idx in env_idxs:
        envs[env_idx] = self.env_fns[env_idx]()
        next_step = 0
        self.policy_request_queue.put(
            [next_step, env_idx, rollout_worker_idx])
        next_obs[env_idx] = torch.tensor(envs[env_idx].reset())
        next_done[env_idx] = 0
    local_step = 0
    while True:
        with sw.timer('act'):
            with sw.timer('wait_rollout_task_queue'):
                tasks = self.rollout_task_queues[
                    rollout_worker_idx].get_many()
            with sw.timer('rollouts'):
                for task in tasks:
                    step, env_idx = task
                    obs[step, env_idx] = next_obs[env_idx].copy()
                    dones[step, env_idx] = next_done[env_idx].copy()
                    next_obs[env_idx], r, d, info = envs[env_idx].step(
                        actions[step, env_idx])
                    if d:
                        next_obs[env_idx] = envs[env_idx].reset()
                    rewards[step, env_idx] = r
                    next_done[env_idx] = d
                    next_step = step + 1
                    local_step += 1
                    with sw.timer('logging'):
                        self.policy_request_queue.put(
                            [next_step, env_idx, rollout_worker_idx])
                        if 'episode' in info.keys():
                            # print(["charts/episode_reward", info['episode']['r']])
                            # self.stats_queue.put(['l', info['episode']['l']])
                            self.stats_queue.put([
                                "charts/episode_reward", info['episode']['r']
                            ])
        if local_step % 1000 == 0:
            print(stopwatch.format_report(sw.get_last_aggregated_report()))
            print()
def start_policy_worker(inputs):
    # raise
    args, experiment_name, i, lock, stats_queue, device, \
        next_obs, next_done, obs, actions, logprobs, rewards, dones, values, traj_availables, \
        rollout_task_queues, policy_request_queue, learner_request_queue = inputs
    data_loader = torch.utils.data.DataLoader(
        PolicyWorkerDataset(policy_request_queue, obs),
        batch_size=200)  # , num_workers=2, pin_memory=True
    sw = stopwatch.StopWatch()
    device = torch.device('cuda')
    agent = Agent(4).to(device)
    min_num_requests = 3
    wait_for_min_requests = 0.01
    # time.sleep(5)
    step = 0
    for batch_idx, (ls, next_o) in enumerate(data_loader):
        step += 1
        with sw.timer('policy_worker'):
            with sw.timer('create array at gpu'):
                next_o = next_o.to(device, non_blocking=True)
            with sw.timer("prepare_data"):
                ls = ls.numpy()
                rollout_worker_idxs = ls.T[0, ::args.num_envs // args.num_env_split]
                split_idxs = ls.T[1, ::args.num_envs // args.num_env_split]
                step_idxs = ls.T[-1, ::args.num_envs // args.num_env_split]
                idxs = tuple(ls.T)
            with sw.timer('inference'):
                with torch.no_grad():
                    a, l, e = agent.get_action(next_o)
            with sw.timer('move_to_cpu'):
                actions[idxs] = a.cpu()
            with sw.timer('execute_action'):
                for j in range(len(rollout_worker_idxs)):
                    rollout_worker_idx = rollout_worker_idxs[j]
                    split_idx = split_idxs[j]
                    step_idx = step_idxs[j]
                    rollout_task_queues[rollout_worker_idx].put(
                        [split_idx, step_idx])
            # for idx, item in enumerate(idxs):
            #     rollout_worker_idx = item[0]
            #     split_idx = item[1]
            #     step_idx = item[2]
            #     with sw.timer('put_action'):
            #         rollout_task_queues[rollout_worker_idx].put([split_idx, step_idx])
            #     actions[idxs] = a.cpu()
            #     for j in range(len(rollout_worker_idxs)):
            #         rollout_worker_idx = rollout_worker_idxs[j]
            #         split_idx = split_idxs[j]
            #         step_idx = step_idxs[j]
        if step % 100 == 0:
            # print(ls.shape)
            print(stopwatch.format_report(sw.get_last_aggregated_report()))
def track_time(t: tracker, curr_subject):
    '''control the starting and stopping of the timer'''
    print("Time Started")
    timer = stopwatch.StopWatch()
    timer.start()
    x = input(message3)
    while x:
        print(timer.elapsed())
        x = input(message3)
    timer.stop()
    print(timer.elapsed())
    time = get_valid_time()
    if time != 'q' and time != 0:
        enter_time(t, curr_subject, time)
def make_new_timer(self):
    # Creates a new individual timer in the timer frame
    new_timer = tk.Frame(self.indTimersFrame)
    new_timer.pack(side=tk.TOP, anchor=tk.N)
    timer = tk.Entry(new_timer, width=5)
    timer.insert(tk.END, 'Well #')
    timer.pack(side=tk.LEFT)
    sw = stopwatch.StopWatch(parent=new_timer)
    sw.pack(side=tk.LEFT)
    start_button = tk.Button(new_timer, text="Start", command=sw.Start)
    start_button.pack(side=tk.LEFT)
    stop_button = tk.Button(new_timer, text="Stop", command=sw.Stop)
    stop_button.pack(side=tk.LEFT)
    reset_button = tk.Button(new_timer, text="Reset", command=sw.Reset)
    reset_button.pack(side=tk.LEFT)
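These tkinter snippets assume a StopWatch that behaves like a regular widget: it accepts a parent keyword, packs into a frame, and exposes Start/Stop/Reset. A hypothetical minimal implementation of such a widget, reconstructed from the usage above rather than from the actual module:

import time
import tkinter as tk

class StopWatch(tk.Label):
    """Hypothetical tk stopwatch widget matching the Start/Stop/Reset
    usage above; the display refreshes every 100 ms while running."""
    def __init__(self, parent=None):
        super().__init__(parent, text="0.0")
        self._elapsed = 0.0
        self._start = 0.0
        self._running = False

    def _tick(self):
        if self._running:
            self._elapsed = time.time() - self._start
            self.config(text="{:.1f}".format(self._elapsed))
            self.after(100, self._tick)

    def Start(self):
        if not self._running:
            self._start = time.time() - self._elapsed
            self._running = True
            self._tick()

    def Stop(self):
        self._running = False

    def Reset(self):
        self._elapsed = 0.0
        self._start = time.time()
        self.config(text="0.0")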
def __init__(self, kpool, kqueue, kstats):
    self.lock = threading.Lock()
    self.pool = kpool
    self.queue = kqueue
    self.stats = kstats
    self.item = None
    self.songs = []
    self.song = None
    self.paused = False
    self.player = None
    self.playing = False
    self.stopwatch = None
    self.running = threading.Event()
    self.skip_current_item = False
    self.skip_current_song = False
    self.stopwatch = stopwatch.StopWatch()
    super(Control, self).__init__()
def start_rollout_worker(self, rollout_worker_idx):
    sw = stopwatch.StopWatch()
    next_obs, next_done, obs, actions, logprobs, rewards, dones, values = self.storage
    rollout_task_queue = self.rollout_task_queues[rollout_worker_idx]
    env_idxs = range(
        rollout_worker_idx * self.num_envs_per_rollout_worker,
        rollout_worker_idx * self.num_envs_per_rollout_worker +
        self.num_envs_per_rollout_worker)
    for env_idx in env_idxs:
        next_step = 0
        self.policy_request_queue.put(
            [next_step, env_idx, rollout_worker_idx])
        next_obs[env_idx] = torch.tensor(self.envs[env_idx].reset())
        next_done[env_idx] = 0
        print(env_idx)
    last_report = last_report_frames = total_env_frames = 0
    while True:
        with sw.timer('act'):
            with sw.timer('wait_rollout_task_queue'):
                tasks = []
                for _ in range(4):
                    tasks.extend(
                        self.rollout_task_queues[rollout_worker_idx].get())
            for task in tasks:
                step, env_idx = task
                with sw.timer('rollouts'):
                    obs[step, env_idx] = next_obs[env_idx].copy()
                    dones[step, env_idx] = next_done[env_idx].copy()
                    next_obs[env_idx], r, d, info = self.envs[
                        env_idx].step(actions[step, env_idx])
                    if d:
                        next_obs[env_idx] = self.envs[env_idx].reset()
                    rewards[step, env_idx] = r
                    next_done[env_idx] = d
                    next_step = (step + 1) % self.num_steps
                    self.policy_request_queue.put(
                        [next_step, env_idx, rollout_worker_idx])
                    if 'episode' in info.keys():
                        print([
                            "charts/episode_reward", info['episode']['r']
                        ])
def start_policy_worker(self):
    next_obs, next_done, obs, actions, logprobs, rewards, dones, values = self.storage
    sw = stopwatch.StopWatch()
    # min_num_requests = 3
    # wait_for_min_requests = 0.01
    # time.sleep(5)
    step = 0
    while True:
        step += 1
        with sw.timer('policy_worker'):
            # waiting_started = time.time()
            with sw.timer('policy_requests'):
                # policy_requests = []
                # while len(policy_requests) < min_num_requests and time.time() - waiting_started < wait_for_min_requests:
                #     try:
                #         policy_requests.extend(self.policy_request_queue.get_many(timeout=0.005))
                #     except Empty:
                #         pass
                # if len(policy_requests) == 0:
                #     continue
                policy_requests = []
                for _ in range(4):
                    policy_requests.extend(self.policy_request_queue.get())
            with sw.timer('prepare_data'):
                ls = np.array(policy_requests)
            with sw.timer('index'):
                next_o = next_obs[ls[:, 1]]
            with sw.timer('inference'):
                with torch.no_grad():
                    a, l, _ = self.agent.get_action(next_o)
                    v = self.agent.get_value(next_o)
                print(a)
            for idx, item in enumerate(ls):
                with sw.timer('move_to_cpu'):
                    # index by (step, env_idx); the original `item[0, 1]`
                    # raises IndexError on a 1-D row
                    actions[item[0], item[1]] = a[idx]
                    logprobs[item[0], item[1]] = l[idx]
                    values[item[0], item[1]] = v[idx]
                with sw.timer('execute_action'):
                    self.rollout_task_queues[item[2]].put(
                        [item[0], item[1]])
def benchmark_steem_passtrough():
    sw = stopwatch.StopWatch()
    steem = Node().default()
    with sw.timer('connection_speed'):
        for i in range(1000):
            with sw.timer('default'):
                print(i)
                Account("furion")
        for i in range(1000):
            with sw.timer('passtrough'):
                Account("furion", steem=steem)
        # for i in range(100):
        #     with sw.timer('preselected'):
        #         Node().default2()
    print(stopwatch.format_report(sw.get_last_aggregated_report()))

    # passing vs initiating a new one doesn't make much difference
    # ************************
    # *** StopWatch Report ***
    # ************************
    # connection_speed          3054.388ms (100%)
    #     default      1000     1525.720ms (50%)
    #     passtrough   1000     1501.971ms (49%)
    # Annotations:
# leading line reconstructed from the parallel buffers below (it was truncated in the source):
actions = torch.zeros((args.num_steps, args.num_envs) +
                      envs.action_space.shape).to(device)
logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
values = torch.zeros((args.num_steps, args.num_envs)).to(device)
invalid_action_masks = torch.zeros((args.num_steps, args.num_envs) +
                                   (envs.action_space.nvec.sum(), )).to(device)

# TRY NOT TO MODIFY: start the game
global_step = 0
start_time = time.time()
# Note how `next_obs` and `next_done` are used; their usage is equivalent to
# https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail/blob/84a7582477fb0d5c82ad6d850fe476829dddd2e1/a2c_ppo_acktr/storage.py#L60
next_obs = envs.reset()
next_done = torch.zeros(args.num_envs).to(device)
num_updates = args.total_timesteps // args.batch_size

sw = stopwatch.StopWatch()

## CRASH AND RESUME LOGIC:
starting_update = 1
if args.prod_mode and wandb.run.resumed:
    print("previous run.summary", run.summary)
    starting_update = run.summary['charts/update'] + 1
    global_step = starting_update * args.batch_size
    api = wandb.Api()
    run = api.run(run.get_url()[len("https://app.wandb.ai/"):])
    model = run.file('agent.pt')
    model.download(f"models/{experiment_name}/")
    agent.load_state_dict(torch.load(f"models/{experiment_name}/agent.pt"))
    agent.eval()
    print(f"resumed at update {starting_update}")

for update in range(starting_update, num_updates + 1):
def start_policy_worker(inputs):
    # raise
    i, args, experiment_name, lock, stats_queue, device, \
        next_obs, next_done, obs, actions, logprobs, rewards, dones, values, traj_availables, \
        rollout_task_queues, policy_request_queues, learner_request_queue, new_policy_queues = inputs
    sw = stopwatch.StopWatch()
    device = torch.device('cuda')
    agent = Agent(4).to(device)
    min_num_requests = 3
    wait_for_min_requests = 0.01
    # time.sleep(5)
    step = 0
    while True:
        step += 1
        with sw.timer('policy_worker'):
            waiting_started = time.time()
            with sw.timer('policy_requests'):
                policy_requests = []
                while len(policy_requests) < min_num_requests and time.time(
                ) - waiting_started < wait_for_min_requests:
                    try:
                        policy_requests.extend(
                            policy_request_queues[i].get_many(timeout=0.005))
                    except Empty:
                        pass
            if len(policy_requests) == 0:
                continue
            with sw.timer('prepare_data'):
                ls = np.concatenate(policy_requests)
                idxs = tuple(ls.T)
            with sw.timer('index'):
                t1 = next_obs[idxs[:-1]]
            with sw.timer('create array at gpu'):
                next_o = torch.Tensor(t1).to(device, non_blocking=True)
            with sw.timer('prepare_data2'):
                rollout_worker_idxs = ls.T[0, ::args.num_envs // args.num_env_split]
                split_idxs = ls.T[1, ::args.num_envs // args.num_env_split]
                step_idxs = ls.T[-1, ::args.num_envs // args.num_env_split]
            with sw.timer('inference'):
                with torch.no_grad():
                    a, l, _ = agent.get_action(next_o)
                    v = agent.get_value(next_o)
            with sw.timer('move_to_cpu'):
                actions[idxs] = a.cpu()
                logprobs[idxs] = l.cpu()
                values[idxs] = v.flatten().cpu()
            with sw.timer('execute_action'):
                for j in range(len(rollout_worker_idxs)):
                    rollout_worker_idx = rollout_worker_idxs[j]
                    split_idx = split_idxs[j]
                    step_idx = step_idxs[j]
                    rollout_task_queues[rollout_worker_idx].put(
                        [split_idx, step_idx])
            with sw.timer('update_policy'):
                try:
                    new_policies = new_policy_queues[i].get_many(timeout=0.005)
                    agent.load_state_dict(new_policies[-1])
                except Empty:
                    pass
        if step % 100 == 0:
            # print(ls.shape)
            print(stopwatch.format_report(sw.get_last_aggregated_report()))
def learn(inputs):
    i, args, experiment_name, lock, stats_queue, device, \
        next_obs, next_done, obs, actions, logprobs, rewards, dones, values, traj_availables, \
        rollout_task_queues, policy_request_queues, learner_request_queue, new_policy_queues = inputs
    s_next_obs, s_next_done, s_obs, s_actions, s_logprobs, s_rewards, s_dones, s_values, s_traj_availables = \
        next_obs, next_done, obs, actions, logprobs, rewards, dones, values, traj_availables
    sw = stopwatch.StopWatch()
    device = torch.device('cuda')
    agent = Agent(4).to(device)
    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
    min_num_requests = 3
    wait_for_min_requests = 0.01
    # time.sleep(5)
    step = 0
    while True:
        with sw.timer('learner'):
            waiting_started = time.time()
            learner_requests = []
            with sw.timer('learner_requests'):
                while len(learner_requests) < min_num_requests and time.time(
                ) - waiting_started < wait_for_min_requests:
                    try:
                        learner_requests.extend(
                            learner_request_queue.get_many(timeout=0.005))
                    except Empty:
                        pass
            if len(learner_requests) == 0:
                continue
            with sw.timer('prepare_data'):
                ls = np.concatenate(learner_requests)
                rollout_worker_idxs = ls.T[0, ::args.num_envs // args.num_env_split]
                split_idxs = ls.T[1, ::args.num_envs // args.num_env_split]
                step_idxs = ls.T[-1, ::args.num_envs // args.num_env_split]
                idxs = tuple(ls.T)
                next_done = torch.tensor(s_next_done[idxs[:-1]], device=device)
                next_obs = torch.tensor(s_next_obs[idxs[:-1]], device=device)
                rewards = torch.tensor(s_rewards[idxs[:-1]],
                                       device=device).transpose(0, 1)
                obs = torch.tensor(s_obs[idxs[:-1]],
                                   device=device).transpose(0, 1)
                logprobs = torch.tensor(s_logprobs[idxs[:-1]],
                                        device=device).transpose(0, 1)
                actions = torch.tensor(s_actions[idxs[:-1]],
                                       device=device).transpose(0, 1)
                values = torch.tensor(s_values[idxs[:-1]],
                                      device=device).transpose(0, 1)
                dones = torch.tensor(s_dones[idxs[:-1]],
                                     device=device).transpose(0, 1)

            # bootstrap reward if not done, reached the batch limit
            with torch.no_grad():
                last_value = agent.get_value(next_obs).reshape(1, -1)
                if args.gae:
                    advantages = torch.zeros_like(rewards).to(device)
                    lastgaelam = 0
                    for t in reversed(range(args.num_steps)):
                        if t == args.num_steps - 1:
                            nextnonterminal = 1.0 - next_done
                            nextvalues = last_value
                        else:
                            nextnonterminal = 1.0 - dones[t + 1]
                            nextvalues = values[t + 1]
                        delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                        advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
                    returns = advantages + values
                else:
                    returns = torch.zeros_like(rewards).to(device)
                    for t in reversed(range(args.num_steps)):
                        if t == args.num_steps - 1:
                            nextnonterminal = 1.0 - next_done
                            next_return = last_value
                        else:
                            nextnonterminal = 1.0 - dones[t + 1]
                            next_return = returns[t + 1]
                        returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
                    advantages = returns - values

            # flatten the batch
            b_obs = obs.reshape((-1, ) + (4, 84, 84))
            b_logprobs = logprobs.reshape(-1)
            b_actions = actions.reshape(-1)
            b_advantages = advantages.reshape(-1)
            b_returns = returns.reshape(-1)
            b_values = values.reshape(-1)

            # Optimizing the policy and value network
            # target_agent = Agent(4).to(device)
            inds = np.arange(len(b_obs))
            minibatch_ind = inds
            # print("==============", b_obs.shape)
            # for i_epoch_pi in range(args.update_epochs):
            #     np.random.shuffle(inds)
            #     target_agent.load_state_dict(agent.state_dict())
            #     for start in range(0, args.batch_size, args.minibatch_size):
            #         end = start + args.minibatch_size
            #         minibatch_ind = inds[start:end]
            mb_advantages = b_advantages[minibatch_ind]
            if args.norm_adv:
                mb_advantages = (mb_advantages - mb_advantages.mean()) / (
                    mb_advantages.std() + 1e-8)

            _, newlogproba, entropy = agent.get_action(
                b_obs[minibatch_ind], b_actions[minibatch_ind])
            ratio = (newlogproba - b_logprobs[minibatch_ind]).exp()

            # Stats
            approx_kl = (b_logprobs[minibatch_ind] - newlogproba).mean()

            # Policy loss
            pg_loss1 = -mb_advantages * ratio
            pg_loss2 = -mb_advantages * torch.clamp(
                ratio, 1 - args.clip_coef, 1 + args.clip_coef)
            pg_loss = torch.max(pg_loss1, pg_loss2).mean()
            entropy_loss = entropy.mean()

            # Value loss
            new_values = agent.get_value(b_obs[minibatch_ind]).view(-1)
            if args.clip_vloss:
                v_loss_unclipped = ((new_values - b_returns[minibatch_ind])**2)
                v_clipped = b_values[minibatch_ind] + torch.clamp(
                    new_values - b_values[minibatch_ind], -args.clip_coef,
                    args.clip_coef)
                v_loss_clipped = (v_clipped - b_returns[minibatch_ind])**2
                v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                v_loss = 0.5 * v_loss_max.mean()
            else:
                v_loss = 0.5 * ((new_values - b_returns[minibatch_ind])**2).mean()

            loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
            optimizer.step()
            # raise
            del next_done, next_obs, rewards, obs, logprobs, actions, values, dones

            for learner_request in learner_requests:
                policy_request_queues[0].put(learner_request)
            # update the policy in the policy worker
            for new_policy_queue in new_policy_queues:
                new_policy_queue.put(agent.state_dict())
            stats_queue.put(["losses/value_loss", v_loss.item()])
            stats_queue.put(["losses/policy_loss", pg_loss.item()])
            stats_queue.put(["losses/entropy", entropy_loss.item()])
            stats_queue.put(["losses/approx_kl", approx_kl.item()])
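The bootstrap block above is Generalized Advantage Estimation (GAE): delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_{t+1}) - V(s_t), accumulated backwards as A_t = delta_t + gamma * lambda * (1 - done_{t+1}) * A_{t+1}. A self-contained numeric sketch of the same recursion, with made-up inputs:

import torch

gamma, lam = 0.99, 0.95
rewards = torch.tensor([1.0, 0.0, 1.0])   # r_0..r_2 (illustrative values)
values = torch.tensor([0.5, 0.4, 0.6])    # V(s_0)..V(s_2)
dones = torch.tensor([0.0, 0.0, 0.0])
last_value = torch.tensor(0.3)            # bootstrapped V(s_T)
next_done = torch.tensor(0.0)

advantages = torch.zeros_like(rewards)
lastgaelam = 0.0
for t in reversed(range(len(rewards))):
    if t == len(rewards) - 1:
        nextnonterminal = 1.0 - next_done
        nextvalues = last_value
    else:
        nextnonterminal = 1.0 - dones[t + 1]
        nextvalues = values[t + 1]
    delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
    advantages[t] = lastgaelam = delta + gamma * lam * nextnonterminal * lastgaelam
returns = advantages + values  # same identity the learner uses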
def _StopWatchInit(self):
    self._stopwatch = sw.StopWatch()
def _StringToday(self):
    s = sw.StopWatch()
    return s._Today()
def __init__(self, policy_request_queue, obs):
    self.step = 0
    self.sw = stopwatch.StopWatch()
    self.policy_request_queue = policy_request_queue
    self.obs = obs
def act(inputs):
    sw = stopwatch.StopWatch()
    args, experiment_name, i, lock, stats_queue, device, \
        next_obs, next_done, obs, actions, logprobs, rewards, dones, values, traj_availables, \
        rollout_task_queue, policy_request_queues, learner_request_queue = inputs
    envs = []

    def make_env(gym_id, seed, idx):
        env = gym.make(gym_id)
        env = wrap_atari(env)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env = wrap_deepmind(
            env,
            clip_rewards=True,
            frame_stack=True,
            scale=False,
        )
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    envs = [make_env(args.gym_id, args.seed + i, i) for i in range(args.num_envs)]
    envs = np.array(envs, dtype=object)

    # for "Double-buffered" sampling
    policy_request_queue_idx = 0
    for split_idx in range(args.num_env_split):
        policy_request_idxs = []
        for env_idx, env in enumerate(envs[split_idx::args.num_env_split]):
            next_obs[i, split_idx, env_idx, 0, 0] = env.reset()
            next_done[i, split_idx, env_idx, 0, 0] = 0
            policy_request_idxs += [[i, split_idx, env_idx, 0, 0, 0]]
        policy_request_queue_idx = (policy_request_queue_idx + 1) % args.num_policy_workers
        policy_request_queues[policy_request_queue_idx].put(policy_request_idxs)

    last_report = last_report_frames = total_env_frames = 0
    while True:
        with sw.timer('act'):
            with sw.timer('wait_rollout_task_queue'):
                tasks = []
                while len(tasks) == 0:
                    try:
                        tasks = rollout_task_queue.get_many(timeout=0.01)
                    except Empty:
                        pass
            for task in tasks:
                # for "Double-buffered" sampling
                with sw.timer('rollouts'):
                    split_idx, step = task
                    policy_request_idxs = []
                    for env_idx, env in enumerate(envs[split_idx::args.num_env_split]):
                        obs[i, split_idx, env_idx, 0, 0, step] = next_obs[i, split_idx, env_idx, 0, 0].copy()
                        dones[i, split_idx, env_idx, 0, 0, step] = next_done[i, split_idx, env_idx, 0, 0]
                        next_obs[i, split_idx, env_idx, 0, 0], r, d, info = env.step(
                            actions[i, split_idx, env_idx, 0, 0, step])
                        if d:
                            next_obs[i, split_idx, env_idx, 0, 0] = env.reset()
                        rewards[i, split_idx, env_idx, 0, 0, step] = r
                        next_done[env_idx] = d
                        next_step = (step + 1) % args.num_steps
                        policy_request_idxs += [[i, split_idx, env_idx, 0, 0, next_step]]
                        num_frames = 1
                        total_env_frames += num_frames
                        if 'episode' in info.keys():
                            stats_queue.put(info['episode']['l'])
                with sw.timer('policy_request_queue.put'):
                    policy_request_queue_idx = (policy_request_queue_idx + 1) % args.num_policy_workers
                    policy_request_queues[policy_request_queue_idx].put(policy_request_idxs)
        if total_env_frames % 1000 == 0 and i == 0:
            print(stopwatch.format_report(sw.get_last_aggregated_report()))