def batch_and_learn(i, lock=threading.Lock()):
    """Thread target for the learning process.

    Pulls full rollout batches off the queues and runs `learn` until the
    global frame budget (`flags.total_frames`) is exhausted.

    NOTE: the mutable default `lock=threading.Lock()` is deliberate here —
    every thread started with the default argument shares the *same* lock
    object, which serializes the stats/frame-counter update below.

    Args:
        i: learner-thread index; thread 0 also logs timing summaries.
        lock: shared lock guarding `frames`, `stats` and the logger.
    """
    # `frames` and `stats` live in the enclosing training function and are
    # shared by all learner threads.
    nonlocal frames, stats
    timings = prof.Timings()
    while frames < flags.total_frames:
        timings.reset()
        batch, agent_state = get_batch(free_queue, full_queue, buffers,
                                       initial_agent_state_buffers, flags,
                                       timings)
        stats = learn(model, learner_model, batch, agent_state, optimizer,
                      scheduler, flags, position_count=position_count,
                      action_hist=action_hist)
        timings.time('learn')
        with lock:
            # Log and bump the shared frame counter atomically so concurrent
            # learner threads neither double-count nor interleave log rows.
            to_log = dict(frames=frames)
            to_log.update({k: stats[k] for k in stat_keys})
            plogger.log(to_log)
            # T = unroll length, B = batch size: frames consumed per batch.
            frames += T * B
    if i == 0:
        log.info('Batch and learn: %s', timings.summary())
def act(i: int, free_queue: mp.SimpleQueue, full_queue: mp.SimpleQueue,
        model: torch.nn.Module, buffers: Buffers,
        episode_state_count_dict: dict, train_state_count_dict: dict,
        initial_agent_state_buffers, flags):
    """Actor process: roll out `model` in its own environment, filling `buffers`.

    Repeatedly takes a free rollout slot index from `free_queue`, writes
    `flags.unroll_length + 1` timesteps of env/agent tensors into
    `buffers[...][index]` — including episodic and (optionally) training
    state-count exploration bonuses (1 / sqrt(count)) — then hands the slot
    to the learner via `full_queue`.  Exits cleanly when it receives `None`
    from `free_queue`.

    Args:
        i: actor index (actor 0 logs timing summaries).
        free_queue: queue of rollout-buffer indices ready to be filled.
        full_queue: queue of rollout-buffer indices ready for the learner.
        model: actor copy of the policy network.
        buffers: shared rollout tensors, indexed [key][slot][t, ...].
        episode_state_count_dict: per-episode state visit counts; reset on done.
        train_state_count_dict: lifetime state visit counts (count-based model).
        initial_agent_state_buffers: shared storage for the recurrent state
            at the start of each rollout.
        flags: experiment configuration namespace.
    """

    def _bump_count(count_dict: dict, state_key) -> float:
        # Increment the visit count for `state_key` and return the
        # exploration bonus 1 / sqrt(count).
        count_dict[state_key] = count_dict.get(state_key, 0) + 1
        return 1 / np.sqrt(count_dict[state_key])

    try:
        log.info('Actor %i started.', i)
        timings = prof.Timings()

        gym_env = create_env(flags)
        # Per-actor random seed so actors do not play identical episodes.
        seed = i ^ int.from_bytes(os.urandom(4), byteorder='little')
        gym_env.seed(seed)
        if flags.num_input_frames > 1:
            gym_env = FrameStack(gym_env, flags.num_input_frames)
        env = Environment(gym_env, fix_seed=flags.fix_seed,
                          env_seed=flags.env_seed)

        env_output = env.initial()
        agent_state = model.initial_state(batch_size=1)
        agent_output, unused_state = model(env_output, agent_state)

        while True:
            index = free_queue.get()
            if index is None:
                break

            # Write old rollout end (timestep 0 of the new slot).
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            for key in agent_output:
                buffers[key][index][0, ...] = agent_output[key]
            # BUGFIX: this loop previously used `i` as its variable, which
            # clobbered the actor index and broke the `if i == 0` log below.
            for state_idx, tensor in enumerate(agent_state):
                initial_agent_state_buffers[index][state_idx][...] = tensor

            # Update the episodic state counts.
            episode_state_key = tuple(env_output['frame'].view(-1).tolist())
            buffers['episode_state_count'][index][0, ...] = \
                torch.tensor(_bump_count(episode_state_count_dict,
                                         episode_state_key))

            # Reset the episode state counts when the episode is over.
            # (Rebind rather than clear, matching the original behavior.)
            if env_output['done'][0][0]:
                episode_state_count_dict = dict()

            # Update the training state counts for count-based exploration.
            if flags.model == 'count':
                train_state_key = tuple(env_output['frame'].view(-1).tolist())
                buffers['train_state_count'][index][0, ...] = \
                    torch.tensor(_bump_count(train_state_count_dict,
                                             train_state_key))

            # Do new rollout.
            for t in range(flags.unroll_length):
                timings.reset()

                with torch.no_grad():
                    agent_output, agent_state = model(env_output, agent_state)
                timings.time('model')

                env_output = env.step(agent_output['action'])
                timings.time('step')

                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]
                for key in agent_output:
                    buffers[key][index][t + 1, ...] = agent_output[key]

                # Update the episodic state counts.
                episode_state_key = tuple(
                    env_output['frame'].view(-1).tolist())
                buffers['episode_state_count'][index][t + 1, ...] = \
                    torch.tensor(_bump_count(episode_state_count_dict,
                                             episode_state_key))

                # Reset the episode state counts when the episode is over.
                if env_output['done'][0][0]:
                    episode_state_count_dict = dict()

                # Update the training state counts for count-based exploration.
                if flags.model == 'count':
                    train_state_key = tuple(
                        env_output['frame'].view(-1).tolist())
                    buffers['train_state_count'][index][t + 1, ...] = \
                        torch.tensor(_bump_count(train_state_count_dict,
                                                 train_state_key))
                timings.time('write')
            full_queue.put(index)

        if i == 0:
            log.info('Actor %i: %s', i, timings.summary())

    except KeyboardInterrupt:
        pass
    except Exception as e:
        logging.error('Exception in worker process %i', i)
        traceback.print_exc()
        print()
        raise e
def act(i: int, free_queue: mp.SimpleQueue, full_queue: mp.SimpleQueue,
        model: torch.nn.Module, buffers: Buffers,
        initial_agent_state_buffers, flags):
    """Actor process: roll out `model` in its own environment, filling `buffers`.

    Repeatedly takes a free rollout slot index from `free_queue`, writes
    `flags.unroll_length + 1` timesteps of env/agent tensors into
    `buffers[...][index]`, then hands the slot to the learner via
    `full_queue`.  Exits cleanly when it receives `None` from `free_queue`.

    Args:
        i: actor index (actor 0 logs timing summaries).
        free_queue: queue of rollout-buffer indices ready to be filled.
        full_queue: queue of rollout-buffer indices ready for the learner.
        model: actor copy of the policy network.
        buffers: shared rollout tensors, indexed [key][slot][t, ...].
        initial_agent_state_buffers: shared storage for the recurrent state
            at the start of each rollout.
        flags: experiment configuration namespace.
    """
    try:
        log.info('Actor %i started.', i)
        timings = prof.Timings()

        gym_env = create_env(flags)
        # Per-actor random seed so actors do not play identical episodes.
        seed = i ^ int.from_bytes(os.urandom(4), byteorder='little')
        gym_env.seed(seed)
        if flags.num_input_frames > 1:
            gym_env = FrameStack(gym_env, flags.num_input_frames)
        env = Environment(gym_env, fix_seed=flags.fix_seed,
                          env_seed=flags.env_seed)

        env_output = env.initial()
        agent_state = model.initial_state(batch_size=1)
        agent_output, unused_state = model(env_output, agent_state)

        while True:
            index = free_queue.get()
            if index is None:
                break

            # Write old rollout end (timestep 0 of the new slot).
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            for key in agent_output:
                buffers[key][index][0, ...] = agent_output[key]
            # BUGFIX: this loop previously used `i` as its variable, which
            # clobbered the actor index and broke the `if i == 0` log below.
            for state_idx, tensor in enumerate(agent_state):
                initial_agent_state_buffers[index][state_idx][...] = tensor

            # Do new rollout.
            for t in range(flags.unroll_length):
                timings.reset()

                with torch.no_grad():
                    agent_output, agent_state = model(env_output, agent_state)
                timings.time('model')

                env_output = env.step(agent_output['action'])
                timings.time('step')

                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]
                for key in agent_output:
                    buffers[key][index][t + 1, ...] = agent_output[key]
                timings.time('write')
            full_queue.put(index)

        if i == 0:
            log.info('Actor %i: %s', i, timings.summary())

    except KeyboardInterrupt:
        pass
    except Exception as e:
        logging.error('Exception in worker process %i', i)
        traceback.print_exc()
        print()
        raise e