class Counter(object):
    """Process-safe integer counter.

    Thin convenience wrapper around ``torch.multiprocessing.Value`` that
    serialises every mutation behind the value's built-in lock.
    """

    def __init__(self):
        from torch.multiprocessing import Value
        # 'i' -> signed C int, starting at zero.
        self.val = Value('i', 0)

    def increment(self, n=1):
        """Atomically add ``n`` (default 1) to the counter."""
        guard = self.val.get_lock()
        with guard:
            self.val.value += n

    def reset(self):
        """Atomically set the counter back to zero."""
        guard = self.val.get_lock()
        with guard:
            self.val.value = 0

    @property
    def value(self):
        """Current counter value (plain read, no lock taken)."""
        return self.val.value
def _worker(
    reader: DatasetReader,
    input_queue: Queue,
    output_queue: Queue,
    num_active_workers: Value,
    num_inflight_items: Value,
    worker_id: int,
) -> None:
    """
    Queue-driven reader process: consumes file paths from ``input_queue``,
    turns each file into instances via ``reader``, and pushes those instances
    onto ``output_queue``.  A ``None`` path is the shutdown sentinel; on
    receiving it the worker flushes its output queue and decrements
    ``num_active_workers`` so the parent knows it has finished.
    """
    logger.info(f"Reader worker: {worker_id} PID: {os.getpid()}")

    # iter(callable, sentinel) keeps pulling paths until the None sentinel.
    for file_path in iter(input_queue.get, None):
        logger.info(f"reading instances from {file_path}")
        for instance in reader.read(file_path):
            # Count the item as in-flight before handing it to the queue.
            with num_inflight_items.get_lock():
                num_inflight_items.value += 1
            output_queue.put(instance)

    # Close and join the queue *before* decrementing num_active_workers.
    # Otherwise our parent may join us before the queue's feeder thread has
    # passed all buffered items to the underlying pipe, which deadlocks.
    #
    # See:
    # https://docs.python.org/3.6/library/multiprocessing.html?highlight=process#pipes-and-queues
    # https://docs.python.org/3.6/library/multiprocessing.html?highlight=process#programming-guidelines
    output_queue.close()
    output_queue.join_thread()

    # Decrementing is not atomic, so take the value's lock.
    # See https://docs.python.org/2/library/multiprocessing.html#multiprocessing.Value.
    with num_active_workers.get_lock():
        num_active_workers.value -= 1
    logger.info(f"Reader worker {worker_id} finished")
class Signal(object):
    """Process-shared boolean flag.

    Small wrapper around ``torch.multiprocessing.Value``; the flag is stored
    as a C int and writes are guarded by the value's lock.
    """

    def __init__(self):
        from torch.multiprocessing import Value
        # Stored as 'i' (int); False coerces to 0.
        self.val = Value('i', False)

    def set_signal(self, boolean):
        """Atomically overwrite the flag with ``boolean``."""
        guard = self.val.get_lock()
        with guard:
            self.val.value = boolean

    @property
    def value(self):
        """Read the flag back as a genuine ``bool``."""
        raw = self.val.value
        return bool(raw)
class Agent_sync(Agent):
    """
    Multi-process sampling agent: maintains one policy net and one environment
    per worker sub-process, which is useful for most single-agent RL/IL
    settings.

    Coordination happens through ``self._sync_signal`` (a shared int):
    -1 terminate, 0 normal running, >0 restart and wait for a parameter update.
    """
    def __init__(self, config: ParamDict, environment: Environment, policy: Policy, filter_op: Filter):
        """
        Spawn ``threads`` sampler sub-processes, each receiving its own deep
        copy of the policy, environment and filter.

        Args:
            config: must provide "threads" (worker count) and "gpu".
            environment: template environment, deep-copied per worker.
            policy: template policy, deep-copied per worker.
            filter_op: template filter, deep-copied per worker.
        """
        threads, gpu = config.require("threads", "gpu")
        super(Agent_sync, self).__init__(config, environment, policy, filter_op)
        # sync signal, -1: terminate, 0: normal running, >0 restart and waiting for parameter update
        self._sync_signal = Value('i', 0)
        # sampler sub-process list
        self._sampler_proc = []
        # used for synchronize commands
        self._cmd_pipe = None
        self._param_pipe = None
        self._cmd_lock = Lock()
        # Parent keeps one end of each pipe; all workers share the child ends.
        cmd_pipe_child, cmd_pipe_parent = Pipe(duplex=True)
        param_pipe_child, param_pipe_parent = Pipe(duplex=False)
        self._cmd_pipe = cmd_pipe_parent
        self._param_pipe = param_pipe_parent
        for i_thread in range(threads):
            child_name = f"sampler_{i_thread}"
            # Each worker gets a distinct seed derived from the parent's seed.
            worker_cfg = ParamDict({
                "seed": self.seed + 1024 + i_thread,
                "gpu": gpu
            })
            child = Process(target=Agent_sync._sampler_worker, name=child_name,
                            args=(worker_cfg, cmd_pipe_child, param_pipe_child,
                                  self._cmd_lock, self._sync_signal,
                                  deepcopy(policy), deepcopy(environment),
                                  deepcopy(filter_op)))
            self._sampler_proc.append(child)
            child.start()

    def __del__(self):
        """ We should terminate all child-process here """
        # -1 tells every sampler loop to exit; give them a moment, then join
        # with a timeout and hard-terminate any straggler.
        self._sync_signal.value = -1
        sleep(1)
        for _proc in self._sampler_proc:
            _proc.join(2)
            if _proc.is_alive():
                _proc.terminate()
        self._cmd_pipe.close()
        self._param_pipe.close()

    def broadcast(self, config: ParamDict):
        """
        Push new policy/filter state to every worker, then enqueue one
        sampling command per requested trajectory ("batch size" in total).

        Blocks until all workers have acknowledged the sync by decrementing
        ``self._sync_signal`` back to zero.
        """
        policy_state, filter_state, max_step, self._batch_size, fixed_env, fixed_policy, fixed_filter = \
            config.require("policy state dict", "filter state dict", "trajectory max step", "batch size",
                           "fixed environment", "fixed policy", "fixed filter")
        self._replay_buffer = []
        policy_state["fixed policy"] = fixed_policy
        filter_state["fixed filter"] = fixed_filter
        cmd = ParamDict({
            "trajectory max step": max_step,
            "fixed environment": fixed_env,
            "filter state dict": filter_state
        })
        assert self._sync_signal.value < 1, \
            "Last sync event not finished due to some error, some sub-proc maybe died, abort"
        # tell sub-process to reset
        with self._sync_signal.get_lock():
            self._sync_signal.value = len(self._sampler_proc)
        # sync net parameters
        with self._cmd_lock:
            for _ in range(len(self._sampler_proc)):
                self._param_pipe.send(policy_state)
        # wait for all agents' ready feedback
        while self._sync_signal.value > 0:
            sleep(0.01)
        # sync commands
        for _ in range(self._batch_size):
            self._cmd_pipe.send(cmd)

    def collect(self):
        """
        Poll the command pipe for one finished trajectory (0.1 s timeout).

        Returns:
            None while fewer than ``self._batch_size`` trajectories have been
            gathered; otherwise the batch built by the filter from the
            replay buffer.
        """
        if self._cmd_pipe.poll(0.1):
            self._replay_buffer.append(self._cmd_pipe.recv())
        if len(self._replay_buffer) < self._batch_size:
            return None
        else:
            batch = self._filter.operate_trajectoryList(self._replay_buffer)
            return batch

    @staticmethod
    def _sampler_worker(setups: ParamDict, pipe_cmd, pipe_param, read_lock,
                        sync_signal, policy, environment, filter_op):
        """
        Sub-process entry point: a small state machine that syncs parameters,
        waits for sampling commands, rolls out trajectories and ships them
        back through ``pipe_cmd``.  Runs until ``sync_signal`` drops below 0.
        """
        gpu, seed = setups.require("gpu", "seed")
        device = decide_device(gpu)
        # Seed every RNG source so workers are reproducible yet distinct.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        environment.init(display=False)
        filter_op.init()
        filter_op.to_device(device)
        policy.init()
        policy.to_device(device)
        # -1: syncing, 0: waiting for new command, 1: sampling
        local_state = 0
        current_step = None
        step_buffer = []
        cmd = None

        def _get_piped_data(pipe):
            # Near-non-blocking receive guarded by the shared lock; None if empty.
            with read_lock:
                if pipe.poll(0.001):
                    return pipe.recv()
                else:
                    return None

        while sync_signal.value >= 0:
            # check sync counter for sync event, and waiting for new parameters
            if sync_signal.value > 0 and local_state >= 0:
                # receive sync signal, reset all workspace settings, decrease sync counter,
                # and set state machine to -1 for not init again
                while _get_piped_data(pipe_cmd) is not None:
                    pass
                step_buffer.clear()
                _policy_state = _get_piped_data(pipe_param)
                if _policy_state is not None:
                    # set new parameters
                    policy.reset(_policy_state)
                    with sync_signal.get_lock():
                        sync_signal.value -= 1
                    local_state = -1
            # if sync ends, tell state machine to recover from syncing state, and reset environment
            elif sync_signal.value == 0 and local_state == -1:
                local_state = 0
            # waiting for states (states are list of dicts)
            elif sync_signal.value == 0 and local_state == 0:
                cmd = _get_piped_data(pipe_cmd)
                if cmd is not None:
                    step_buffer.clear()
                    cmd.require("filter state dict", "fixed environment", "trajectory max step")
                    current_step = environment.reset(
                        random=not cmd["fixed environment"])
                    filter_op.reset(cmd["filter state dict"])
                    local_state = 1
            # sampling
            elif sync_signal.value == 0 and local_state == 1:
                with torch.no_grad():
                    policy_step = filter_op.operate_currentStep(current_step)
                    last_step = policy.step([policy_step])[0]
                last_step, current_step, done = environment.step(last_step)
                record_step = filter_op.operate_recordStep(last_step)
                step_buffer.append(record_step)
                # Ship the trajectory once it is long enough or the episode ends.
                if len(step_buffer) >= cmd["trajectory max step"] or done:
                    traj = filter_op.operate_stepList(step_buffer, done=done)
                    with read_lock:
                        pipe_cmd.send(traj)
                    local_state = 0
        # finalization
        filter_op.finalize()
        policy.finalize()
        environment.finalize()
        pipe_cmd.close()
        pipe_param.close()
        print("Sampler sub-process exited")
# Parameter-server main loop: accept worker connections on a TCP socket in a
# background process (HandleWorkers) and repeatedly train the network,
# publishing fresh parameters to the workers through `param_queue`.
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.bind((address, port))
server.listen()
param_queue = Queue()
param_queue.put(net.state_dict())
# 'b' (signed char) flag flipped to 1 to ask the receiver process to stop.
shutdown_val = Value('b', 0)
receiver_proc = Process(target=HandleWorkers,
                        args=(server, replay_memory, mem_lock, param_queue,
                              shutdown_val))
receiver_proc.start()
while True:
    try:
        Train(net, replay_memory, mem_lock, args.output_file)
        if param_queue is not None:
            # BUG FIX: previously this queued the bound method
            # `net.state_dict` instead of calling it, so workers received a
            # method object rather than the parameter dict.
            param_queue.put(net.state_dict())
        torch.save(net.state_dict(), args.output_file)
    except KeyboardInterrupt:
        # Ctrl-C: signal the receiver process, wait for it, close the socket.
        if server is not None:
            assert (shutdown_val is not None and receiver_proc is not None)
            print("Shutting down...")
            with shutdown_val.get_lock():
                shutdown_val.value = 1
            receiver_proc.join()
            server.close()
        break
class ProgressiveResize(ResizeNative):
    """Resize batches to a size chosen by a scheduler at every call."""

    def __init__(self, scheduler: scheduler_type, mode: str = 'nearest',
                 align_corners: bool = None, preserve_range: bool = False,
                 keys: Sequence = ('data', ), grad: bool = False, **kwargs):
        """
        Args:
            scheduler: callable mapping the transform's current step count to
                the target size; queried once per :meth:`forward` call
            mode: one of ``nearest``, ``linear``, ``bilinear``, ``bicubic``,
                ``trilinear``, ``area`` (for more information see
                :func:`torch.nn.functional.interpolate`)
            align_corners: input and output tensors are aligned by the center
                points of their corners pixels, preserving the values at the
                corner pixels
            preserve_range: output tensor has same range as input tensor
            keys: keys which should be augmented
            grad: enable gradient computation inside transformation
            **kwargs: keyword arguments passed to augment_fn

        Warnings:
            In combination with multiprocessing the step counter is not
            perfectly synchronized between processes, so the observed step
            may jump by up to the number of processes used.
        """
        super().__init__(size=0, mode=mode, align_corners=align_corners,
                         preserve_range=preserve_range, keys=keys, grad=grad,
                         **kwargs)
        self.scheduler = scheduler
        # Shared counter so the schedule advances across worker processes.
        self._step = Value('i', 0)

    def reset_step(self) -> ResizeNative:
        """
        Set the shared step counter back to zero.

        Returns:
            ResizeNative: returns self to allow chaining
        """
        counter = self._step
        with counter.get_lock():
            counter.value = 0
        return self

    def increment(self) -> ResizeNative:
        """
        Advance the shared step counter by one.

        Returns:
            ResizeNative: returns self to allow chaining
        """
        counter = self._step
        with counter.get_lock():
            counter.value += 1
        return self

    @property
    def step(self) -> int:
        """
        Current step.

        Returns:
            int: number of steps
        """
        return self._step.value

    def forward(self, **data) -> dict:
        """
        Resize the batch to the scheduler's size for the current step, then
        advance the counter.

        Args:
            **data: input batch

        Returns:
            dict: augmented batch
        """
        self.kwargs["size"] = self.scheduler(self.step)
        self.increment()
        return super().forward(**data)
def train():
    """
    Stage-3 training driver: builds the shared actor-critic, spawns an
    optimizer process, worker processes and a validation process, then loops
    logging results to TensorBoard and checkpointing the best model.
    """
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    writer = SummaryWriter()
    ac = AC(latent_num, cnn_chanel_num, stat_dim)
    # Log the model graph with dummy image + stat inputs.
    writer.add_graph(ac, (torch.zeros([1, 1, img_shape[0], img_shape[1]]),
                          torch.zeros([1, stat_dim])))
    # Per-module learning rates; f and V fall back to the global lr (5e-3).
    optim = GlobalAdam([{
        'params': ac.encode_img.parameters(),
        'lr': 2.5e-5
    }, {
        'params': ac.encode_stat.parameters(),
        'lr': 2.5e-5
    }, {
        'params': ac.pi.parameters(),
        'lr': 2.5e-5
    }, {
        'params': ac.actor.parameters(),
        'lr': 2.5e-5
    }, {
        'params': ac.f.parameters()
    }, {
        'params': ac.V.parameters()
    }], lr=5e-3, weight_decay=weight_decay)
    # Resume stage-3 training if a checkpoint exists; otherwise warm-start
    # from the stage-2 weights (non-strict load).
    if os.path.exists('S3_state_dict.pt'):
        ac.load_state_dict(torch.load('S3_state_dict.pt'))
        optim.load_state_dict(torch.load('S3_Optim_state_dict.pt'))
    else:
        ac.load_state_dict(torch.load('../stage2/S2_state_dict.pt'),
                           strict=False)
    result_queue = Queue()
    validate_queue = Queue()
    gradient_queue = Queue()
    loss_queue = Queue()
    ep_cnt = Value('i', 0)  # shared episode counter
    optimizer_lock = Lock()
    processes = []
    # Parameters live in shared memory so workers see optimizer updates.
    ac.share_memory()
    optimizer_worker = Process(target=update_shared_model,
                               args=(gradient_queue, optimizer_lock, optim,
                                     ac))
    optimizer_worker.start()
    # Leave 3 cores free for the optimizer/validator/main processes.
    for no in range(mp.cpu_count() - 3):
        worker = Worker(no, ac, ep_cnt, optimizer_lock, result_queue,
                        gradient_queue, loss_queue)
        worker.start()
        processes.append(worker)
    validater = Validate(ac, ep_cnt, optimizer_lock, validate_queue)
    validater.start()
    best_reward = 0
    while True:
        with ep_cnt.get_lock():
            if not result_queue.empty():
                ep_cnt.value += 1
                reward, money, win_rate = result_queue.get()
                objective_actor, loss_critic, loss_f = loss_queue.get()
                writer.add_scalar('Interaction/Reward', reward, ep_cnt.value)
                writer.add_scalar('Interaction/Money', money, ep_cnt.value)
                writer.add_scalar('Interaction/win_rate', win_rate,
                                  ep_cnt.value)
                writer.add_scalar('Update/objective_actor', objective_actor,
                                  ep_cnt.value)
                writer.add_scalar('Update/loss_critic', loss_critic,
                                  ep_cnt.value)
                writer.add_scalar('Update/loss_f', loss_f, ep_cnt.value)
                # Checkpoint under the optimizer lock so we never save while
                # the optimizer process is mutating the shared parameters.
                with optimizer_lock:
                    if reward > best_reward:
                        best_reward = reward
                        torch.save(ac.state_dict(), 'S3_BEST_state_dict.pt')
                    if ep_cnt.value % save_every == 0:
                        torch.save(ac.state_dict(), 'S3_state_dict.pt')
                        torch.save(optim.state_dict(),
                                   'S3_Optim_state_dict.pt')
            if not validate_queue.empty():
                val_reward, val_money, val_win_rate = validate_queue.get()
                writer.add_scalar('Validation/reward', val_reward,
                                  ep_cnt.value)
                writer.add_scalar('Validation/money', val_money, ep_cnt.value)
                writer.add_scalar('Validation/win_rate', val_win_rate,
                                  ep_cnt.value)
    # NOTE(review): the loop above has no break, so the cleanup below appears
    # unreachable — confirm whether a termination condition is missing.
    for worker in processes:
        worker.join()
    optimizer_worker.kill()
def create_worker(gnet_actor: Actor, gnet_critic: Critic, opt: SharedAdam,
                  global_episode: mp.Value, global_results_queue: mp.Queue,
                  name: int) -> None:
    """
    This is our main function. It is in a function so that it can be spread
    over multiple processes.

    :param gnet_actor: Our global Actor network.
    :param gnet_critic: Our global Critic network.
    :param opt: Our shared Adam optimizer.
    :param global_episode: A shared value that tells us what episode we are
        on over all workers.
    :param global_results_queue: A shared queue that workers can put rewards onto.
    :param name: A number for this worker.
    :return: None
    """
    # Local copies of the global networks; re-synced after every update.
    lnet_actor, lnet_critic = Actor(), Critic()
    lnet_critic.load_state_dict(gnet_critic.state_dict())
    lnet_actor.load_state_dict(gnet_actor.state_dict())
    lenv = gym.make('CartPole-v0')
    if name == 0:
        # Only worker 0 records video.
        print("Creating Video Recorder")
        video_recorder = VideoRecorder(lenv,
                                       './output/06_Cartpole_A3C_Q_Critic.mp4',
                                       enabled=True)
    else:
        video_recorder = None
    print(f"Worker {name} starting run...")
    total_step = 1
    while global_episode.value < N_ITERS:
        buffer_state, buffer_log_probs, buffer_rewards, buffer_policy_dist, buffer_action_one_hot = [], [], [], [], []
        episode_reward = 0
        state = lenv.reset()
        for _ in count():
            # Capture a video frame every 10 steps (worker 0 only).
            if (total_step + 1) % 10 == 0 and video_recorder is not None:
                video_recorder.capture_frame()
            state = torch.FloatTensor(state).to(device)
            policy_dist = lnet_actor(state)
            policy = Categorical(policy_dist)
            action = policy.sample()
            action_one_hot = to_onehot(action, ACTION_DIM)
            next_state, reward, done, _ = lenv.step(action.cpu().numpy())
            log_prob = policy.log_prob(action).unsqueeze(0)
            episode_reward += reward
            buffer_policy_dist.append(policy_dist[None, :])
            buffer_action_one_hot.append(action_one_hot[None, :])
            buffer_log_probs.append(log_prob[None, :])
            buffer_state.append(state[None, :])
            buffer_rewards.append(
                torch.FloatTensor([reward])[None, :].to(device))
            state = next_state
            if total_step % UPDATE_GLOBAL_ITER == 0 or done:  # sync
                # Bootstrap value of the final state (0 if terminal).
                next_state = torch.FloatTensor(next_state).to(device)
                next_policy_dist = lnet_actor(next_state)
                next_policy = Categorical(next_policy_dist)
                next_action = next_policy.sample()
                next_action_one_hot = to_onehot(next_action, ACTION_DIM)
                final_value = lnet_critic(
                    next_state, next_policy_dist) if not done else 0

                # Concatenate buffers
                buffer_state = torch.cat(buffer_state, dim=0)
                buffer_policy_dist = torch.cat(buffer_policy_dist, dim=0)
                buffer_action_one_hot = torch.cat(buffer_action_one_hot, dim=0)
                buffer_log_probs = torch.cat(buffer_log_probs, dim=0)
                buffer_rewards = torch.cat(buffer_rewards, dim=0)

                # Discounted returns, walking the buffer back-to-front with
                # the bootstrapped value as the terminal value.
                # BUG FIX: the index used to be `-i`; since `-0 == 0`, the
                # first iteration read/wrote the *front* of the buffer instead
                # of the back. `-(i + 1)` correctly starts at the last element.
                cum_reward = final_value
                discounted_future_rewards = torch.FloatTensor(
                    len(buffer_rewards)).to(device)
                for i in range(len(buffer_rewards)):
                    cum_reward = buffer_rewards[-(i + 1)] + GAMMA * cum_reward
                    discounted_future_rewards[-(i + 1)] = cum_reward

                # Calculate the local losses for the states in the buffer
                values = lnet_critic(buffer_state, buffer_policy_dist)
                # Now we calculate the advantage function
                advantage = discounted_future_rewards - values
                # And the loss for both the actor and the critic
                actor_loss = -(buffer_log_probs * advantage.detach())
                critic_loss = advantage.pow(2)

                # calculate local gradients and push local parameters to global
                # We are going to couple these losses so that on each episode
                # they are related together
                opt.zero_grad()
                (actor_loss + critic_loss).mean().backward()
                for lp, gp in zip(lnet_actor.parameters(),
                                  gnet_actor.parameters()):
                    gp._grad = lp.grad
                for lp, gp in zip(lnet_critic.parameters(),
                                  gnet_critic.parameters()):
                    gp._grad = lp.grad
                opt.step()

                # pull global parameters
                lnet_critic.load_state_dict(gnet_critic.state_dict())
                lnet_actor.load_state_dict(gnet_actor.state_dict())
                buffer_state, buffer_log_probs, buffer_rewards, buffer_policy_dist, buffer_action_one_hot = [], [], [], [], []

                if done:
                    # Increment the global episode and report the reward.
                    with global_episode.get_lock():
                        global_episode.value += 1
                    # Update the results queue
                    # print(episode_reward)
                    global_results_queue.put(episode_reward)
                    # End this batch
                    break
            total_step += 1
    # Sentinel telling the consumer that this worker is finished.
    global_results_queue.put(None)
    print("DONE!")
    if video_recorder is not None:
        video_recorder.close()
    lenv.close()
class Pipeline():
    """
    Producer/consumer front-end around a DatasetReader sub-process.

    The consumer side (start_reading / any_coming / get_batch / close) talks
    to the reader through ``command_queue`` and receives batches on
    ``batch_queue``; the reader side uses the ``put_*`` / ``get_command``
    helpers.  Shared state:

    - ``knows``: semaphore, > 0 once we know whether any examples are coming.
    - ``working``: semaphore, == 0 while the DatasetReader is processing a
      command.
    - ``finished_reading``: lock, held while the file is still being read.
    - ``in_pipe``: shared int, number of examples currently in the pipe.
    - ``_example_number``: shared int, examples handed to the consumer so far.
    - ``_close``: shared int, the kill switch polled via ``time_to_close``.
    """

    def __init__(self, config, share_batches=True, manager=None, new_process=True):
        # A Manager is only needed when the reader runs in its own process.
        if new_process == True and manager is None:
            manager = Manager()
        self.knows = Semaphore(0)  # > 0 if we know if any are coming
        # == 0 if DatasetReader is processing a command
        self.working = Semaphore(1 if new_process else 100)
        self.finished_reading = Lock()  # locked if we're still reading from file
        # number of molecules that have been sent to the pipe:
        self.in_pipe = Value('i', 0)
        # Tracking what's already been sent through the pipe:
        self._example_number = Value('i', 0)
        # The final kill switch:
        self._close = Value('i', 0)
        self.command_queue = manager.Queue(10)
        self.molecule_pipeline = None
        self.batch_queue = Queue(config.data.batch_queue_cap)  #manager.Queue(config.data.batch_queue_cap)
        self.share_batches = share_batches
        self.dataset_reader = DatasetReader("dataset_reader", self, config,
                                            new_process=new_process)
        if new_process:
            self.dataset_reader.start()

    def __getstate__(self):
        # The reader process itself must never be pickled into children.
        self_dict = self.__dict__.copy()
        self_dict['dataset_reader'] = None
        return self_dict

    # methods for pipeline user/consumer:
    def start_reading(self, examples_to_read, make_molecules=True, batch_size=None, wait=False):
        """Ask the reader to start reading; optionally block until done."""
        #print("Start reading...")
        assert check_semaphore(
            self.finished_reading
        ), "Tried to start reading file, but already reading!"
        with self.in_pipe.get_lock():
            assert self.in_pipe.value == 0, "Tried to start reading, but examples already in pipe!"
        set_semaphore(self.finished_reading, False)
        set_semaphore(self.knows, False)
        self.working.acquire()
        self.command_queue.put(
            StartReading(examples_to_read, make_molecules, batch_size))
        if wait:
            self.wait_till_done()

    def wait_till_done(self):
        """Block until the reader is idle; raise if examples remain queued."""
        # wait_semaphore(self.knows)
        # wait_semaphore(self.finished_reading)
        # Acquire+release 'working' purely as a barrier on the reader.
        self.working.acquire()
        self.working.release()
        if self.any_coming():
            with self.in_pipe.get_lock():
                ip = self.in_pipe.value
            raise Exception(f"Waiting with {ip} examples in pipe!")

    def scan_to(self, index):
        """Reposition the reader at ``index``; valid only when empty."""
        assert check_semaphore(
            self.knows), "Tried to scan to index, but don't know if finished!"
        assert check_semaphore(
            self.finished_reading
        ), "Tried to scan to index, but not finished reading!"
        assert not self.any_coming(
        ), "Tried to scan to index, but pipeline not empty!"
        self.working.acquire()
        self.command_queue.put(ScanTo(index))
        with self._example_number.get_lock():
            self._example_number.value = index
        # What to do if things are still in the pipe???

    def set_indices(self, test_set_indices):
        """Send the test-set indices to the reader, then rewind it to 0."""
        # NOTE(review): 'working' is acquired once per queued command with no
        # matching release here — presumably the reader releases it after each
        # command; confirm against DatasetReader.
        self.working.acquire()
        self.command_queue.put(SetIndices(torch.tensor(test_set_indices)))
        self.working.acquire()
        self.command_queue.put(ScanTo(0))

    def set_shuffle(self, shuffle):
        """Toggle shuffling in the reader (fire-and-forget command)."""
        self.command_queue.put(SetShuffle(shuffle))

    def any_coming(self):  # returns True if at least one example is coming
        # Blocks on 'knows' until the pipeline state is decided.
        wait_semaphore(self.knows)
        with self.in_pipe.get_lock():
            return self.in_pipe.value > 0

    def get_batch(self, timeout=None):
        """Pop one batch; updates in-pipe and example counters accordingly."""
        #assert self.any_coming(verbose=verbose), "Tried to get data from an empty pipeline!"
        x = self.batch_queue.get(True, timeout)
        #print(f"{type(x)} : {x}")
        #for b in x:
        #    print(f"  --{type(b)} : {b}")
        with self.in_pipe.get_lock():
            self.in_pipe.value -= x.n_examples
            # Pipe drained while still reading: we no longer know what's coming.
            if self.in_pipe.value == 0 and not check_semaphore(
                    self.finished_reading):
                set_semaphore(self.knows, False)
        with self._example_number.get_lock():
            self._example_number.value += x.n_examples
        return x

    @property
    def example_number(self):
        # Number of examples delivered to the consumer so far.
        with self._example_number.get_lock():
            return self._example_number.value

    def close(self):
        """Tell the reader to shut down; join with a timeout, then kill."""
        self.command_queue.put(CloseReader())
        with self._close.get_lock():
            self._close.value = True
        self.dataset_reader.join(4)
        self.dataset_reader.kill()

    # methods for DatasetReader:
    def get_command(self):
        return self.command_queue.get()

    def put_molecule_to_ext(self, m, block=True):
        """Forward one molecule to the external pipeline; False if refused."""
        r = self.molecule_pipeline.put_molecule(m, block)
        if not r:
            return False
        with self.in_pipe.get_lock():
            # First example into an empty pipe: announce that data is coming.
            if self.in_pipe.value == 0:
                set_semaphore(self.knows, True)
            self.in_pipe.value += 1
        return True

    def put_molecule_data(self, data, atomic_numbers, weights, ID, block=True):
        """Forward raw molecule arrays; a 3-D array counts one example per row."""
        r = self.molecule_pipeline.put_molecule_data(data, atomic_numbers,
                                                     weights, ID, block)
        if not r:
            return False
        with self.in_pipe.get_lock():
            if self.in_pipe.value == 0:
                set_semaphore(self.knows, True)
            # 3-dim input is a stack of examples; anything else counts as one.
            if data.ndim == 3:
                self.in_pipe.value += data.shape[0]
            else:
                self.in_pipe.value += 1
        return True

    def get_batch_from_ext(self, block=True):
        return self.molecule_pipeline.get_next_batch(block)

    def ext_batch_ready(self):
        return self.molecule_pipeline.batch_ready()

    # !!! Call only after you've put the molecules !!!
    def set_finished_reading(self):
        set_semaphore(self.finished_reading, True)
        set_semaphore(self.knows, True)
        self.molecule_pipeline.notify_finished()

    def put_batch(self, x):
        # Sharing is currently disabled ('if False'); kept for reference.
        if False:  #self.share_batches:
            print("[P] Sharing memory...")
            try:
                x.share_memory_()
            except Exception as e:
                print("[P] Failed when moving tensor to shared memory")
                print(e)
            print("[P] Done sharing memory")
        self.batch_queue.put(x)

    def time_to_close(self):
        # Polled by the reader process; truthy once close() has been called.
        with self._close.get_lock():
            return self._close.value
receiver_proc = Process(target=HandleWorkers, args=(server, out_queue, param_queue, shutdown_server)) receiver_proc.start() state_dict = resnet.state_dict() print(type(state_dict)) param_queue.put(state_dict) tensor = out_queue.get() net_output = resnet(tensor) print(net_output) print("Shutting down server...") with shutdown_server.get_lock(): shutdown_server.value = 1 else: param_queue = Queue() server = communication.WorkerSocket(address, port) print("Connected to server") receiver_proc = Process(target=ReceiveParams, args=(server, param_queue), daemon=True) receiver_proc.start() # for _ in range(3): # SendPlayout(server) # time.sleep(1)
class CustomDataset(Dataset):
    def __init__(self, dataset, debug_mode=False, build_fn=None, num_workers=3,
                 pca=None, params=None, cache_path=None, logger=None,
                 debug=False):
        """
        Converts the KDD Dataset files into a PyTorch-Dataset object. The
        Dataset object is to be passed to a PyTorch-DataLoader. The DataLoader
        then uses __getitem__ to retrieve batches in the form
        {'data': input_vector, 'label': attack_type}. Textual fields are one
        hot encoded.

        Args:
            dataset (DataFrame): A Pandas dataset that will be used
            debug_mode: If True only a small sample for debugging will be
                returned
            build_fn: A function that takes a single row as input and returns
                a dict with keys 'data' and 'label', both of which hold a
                list of numpy objects
        """
        self.dataset = dataset
        self.params = params
        self.preprocessed = []
        self.build_fn = build_fn
        # Skip the (expensive) build step when a pickled cache can be loaded.
        enter_build_fn_block = True
        if cache_path is not None:
            try:
                pickle_in = open(cache_path, "rb")
            except FileNotFoundError:
                pass
            else:
                enter_build_fn_block = False
                self.preprocessed = pickle.load(pickle_in)
                pickle_in.close()
        if self.build_fn and enter_build_fn_block:
            self.num_workers = num_workers
            # One Pipe per worker process.
            self.messagePipes = {
                'proc_' + str(id): Pipe()
                for id in range(self.num_workers)
            }
            # Shared counter of helper processes that have finished.
            self.finished = Value('i', 0)
            self._build(dataset)
            # Persist preprocessed items so later runs can skip the build.
            if cache_path is not None:
                pickle_out = open(cache_path, "wb")
                pickle.dump(self.preprocessed, pickle_out)
                pickle_out.close()
        # Creating the logger before the multiprocessing step
        # causes errors as spawned processes try to pickle it
        self.logger = logger
        if not self.build_fn:
            if self.logger:
                self.logger.warning(
                    'No build function specified. __getitem__() will return the raw dataset.'
                )
        else:
            if self.logger:
                self.logger.debug('Build complete with {} failure(s)'.format(
                    abs(len(self.preprocessed) - len(dataset))))
        if pca is not None:
            self._do_pca(pca)

    def _do_pca(self, n_components):
        """Project all preprocessed 'data' vectors onto ``n_components`` PCs."""
        msg = 'Performing PCA with {} components.'.format(n_components)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)
        data_only = [entry['data'].tolist() for entry in self.preprocessed]
        pca = PCA(n_components=n_components)
        transformed = pca.fit_transform(data_only)
        # Replace each entry's data vector with its PCA projection.
        for i, elem in enumerate(transformed):
            self.preprocessed[i]['data'] = torch.tensor(elem)
        msg = 'Retained variance: {:.4f}'.format(
            pca.explained_variance_ratio_.cumsum()[-1])
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)

    def __call__(self, batch_size, shuffle=False, num_workers=0,
                 pin_memory=False):
        """Convenience: wrap this dataset in a DataLoader."""
        return DataLoader(self,
                          batch_size=batch_size,
                          shuffle=shuffle,
                          num_workers=num_workers,
                          pin_memory=pin_memory)

    def __len__(self):
        # When built, length is the number of preprocessed items (failed rows
        # are dropped); otherwise the raw dataset length.
        if self.build_fn:
            return len(self.preprocessed)
        else:
            return len(self.dataset)

    def _helper(self, dataset, sender_end, position):
        '''
        Worker-process body: preprocess rows and stream them down the pipe in
        batches of ``every_n`` rows, ending with a -1 sentinel.

        Args:
            dataset (DataFrame) - part of the whole dataset to be preprocessed
                by a worker process
        '''
        every_n = 10000
        batch = []
        for count, (_, row) in enumerate(dataset.iterrows()):
            batch.append(self.build_fn(row, self.params))
            if count % every_n == 0:
                sender_end.send(batch)
                batch = []
            # NOTE(review): this increment is overwritten by enumerate on the
            # next iteration — appears to be dead code; confirm.
            count += 1
        # Flush the final partial batch, then the end-of-stream sentinel.
        sender_end.send(batch)
        sender_end.send(-1)
        with self.finished.get_lock():
            self.finished.value += 1

    def _receiver(self, receiver_end, id):
        """Thread body: drain one worker's pipe into ``self.preprocessed``."""
        # NOTE(review): 'finished' is never used in this method.
        finished = 0
        batch_2 = []
        while True:
            batch = receiver_end.recv()
            # -1 is the worker's end-of-stream sentinel.
            if batch == -1:
                break
            for item in batch:
                label_tensor = torch.tensor(item['label'])
                data_tensor = torch.tensor(item['data'])
                batch_2.append({'data': data_tensor, 'label': label_tensor})
                # Flush to the shared list in chunks to limit lock traffic.
                if len(batch_2) >= 200:
                    # NOTE(review): 'lock' is not defined in this class —
                    # presumably a module-level Lock; confirm it exists.
                    with lock:
                        self.preprocessed.extend(batch_2)
                    batch_2 = []
        with lock:
            self.preprocessed.extend(batch_2)

    def _single_thread_build(self, dataset):
        """Preprocess every row in-process (used when num_workers == 1)."""
        for _, row in dataset.iterrows():
            item = self.build_fn(row, self.params)
            label_tensor = torch.tensor(item['label'])
            data_tensor = torch.tensor(item['data'])
            self.preprocessed.append({
                'data': data_tensor,
                'label': label_tensor
            })

    def _build(self, dataset):
        """Preprocess ``dataset`` either in-process or via worker processes."""
        if self.num_workers == 1:
            self._single_thread_build(dataset)
        else:
            # Divide the dataset into equal parts and send to the worker processes
            processes = []
            threads = []
            ds_size = len(dataset)
            chunk_size = ds_size // self.num_workers
            for i in range(self.num_workers):
                # The last worker also takes the remainder rows.
                if i == self.num_workers - 1:
                    ds_chunk = dataset.iloc[i * chunk_size:]
                else:
                    ds_chunk = dataset.iloc[i * chunk_size:i * chunk_size +
                                            chunk_size]
                # thread = Thread(target=self._receiver, args=(self.messagePipes['proc_' + str(i)][0],i))
                process = Process(target=self._helper,
                                  name='helper_' + str(i),
                                  args=(ds_chunk,
                                        self.messagePipes['proc_' + str(i)][1],
                                        i))
                processes.append(process)
                # thread.start()
                process.start()
            # One receiver thread per worker drains the corresponding pipe.
            for i in range(self.num_workers):
                thread = Thread(target=self._receiver,
                                args=(self.messagePipes['proc_' + str(i)][0],
                                      i))
                threads.append(thread)
                thread.start()
            for process in processes:
                process.join()
            for thread in threads:
                thread.join()

    def output_size(self):
        # Number of distinct labels among the preprocessed items.
        return len(
            set([elem['label'][0].item() for elem in self.preprocessed]))

    def input_size(self):
        # Length of one preprocessed feature vector.
        return len(self.preprocessed[0]['data'])

    def __getitem__(self, idx):
        # Preprocessed dict when built; otherwise the raw row/element.
        if self.build_fn:
            return self.preprocessed[idx]
        elif isinstance(self.dataset, pd.DataFrame):
            return self.dataset.iloc[idx, :]
        else:
            return self.dataset[idx]
class HogwildWorld(World):
    """One inner world per process, coordinated through shared primitives.

    Shared state:

    - ``queued_items``: Semaphore counting examples waiting to be claimed;
      every ``parley`` call releases it once and each Process acquires it to
      claim an example.
    - ``epochDone``: Condition used to announce that every queued example has
      been fully processed.
    - ``terminate``: boolean Value telling the inner worlds to shut down.
    - ``cnt``: integer Value holding the number of queued-but-unprocessed
      examples (acquiring the semaphore only claims them; this counter is
      decremented once processing completes).
    """

    def __init__(self, world_class, opt, agents):
        super().__init__(opt)
        self.inner_world = world_class(opt, agents)
        self.queued_items = Semaphore(0)  # examples awaiting processing
        self.epochDone = Condition()  # signalled when all examples are done
        self.terminate = Value('b', False)  # shutdown flag for the threads
        self.cnt = Value('i', 0)  # examples queued but not yet processed
        # One HogwildProcess per requested thread; start them all afterwards.
        self.threads = [
            HogwildProcess(idx, world_class, opt, agents, self.queued_items,
                           self.epochDone, self.terminate, self.cnt)
            for idx in range(opt['numthreads'])
        ]
        for proc in self.threads:
            proc.start()

    def display(self):
        self.shutdown()
        raise NotImplementedError(
            'Hogwild does not support displaying in-run task data. '
            'Use `--numthreads 1`.')

    def episode_done(self):
        # Hogwild runs continuously; no single-episode boundary exists here.
        return False

    def parley(self):
        """Queue one item to be processed."""
        with self.cnt.get_lock():
            self.cnt.value += 1
        self.queued_items.release()
        self.total_parleys += 1

    def getID(self):
        return self.inner_world.getID()

    def report(self, compute_time=False):
        return self.inner_world.report(compute_time)

    def save_agents(self):
        self.inner_world.save_agents()

    def synchronize(self):
        """Sync barrier: block until every queued example is processed."""
        with self.epochDone:
            self.epochDone.wait_for(lambda: self.cnt.value == 0)

    def shutdown(self):
        """Raise the shutdown flag, wake all workers, and join them."""
        # Flag first so each woken worker exits instead of processing.
        with self.terminate.get_lock():
            self.terminate.value = True
        # One fake example per worker guarantees every worker wakes up.
        for _ in self.threads:
            self.queued_items.release()
        # wait for threads to close
        for proc in self.threads:
            proc.join()
def train(args,
          worker_id: int,
          global_model: Union[ActorNetwork, ActorCriticNetwork],
          T: Value,
          global_reward: Value,
          optimizer: torch.optim.Optimizer = None,
          global_model_critic: CriticNetwork = None,
          optimizer_critic: torch.optim.Optimizer = None,
          lr_scheduler: torch.optim.lr_scheduler = None,
          lr_scheduler_critic: torch.optim.lr_scheduler = None):
    """
    Start worker in training mode, i.e. training the shared model with backprop
    loosely based on https://github.com/ikostrikov/pytorch-a3c/blob/master/train.py

    :param args: console arguments
    :param worker_id: id of worker to differentiate them and init different seeds
    :param global_model: global model, which is optimized/ for split models: actor
    :param T: global counter of steps
    :param global_reward: global running reward value
    :param optimizer: optimizer for shared model/ for split models: actor model
    :param global_model_critic: optional global critic model for split networks
    :param optimizer_critic: optional critic optimizer for split networks
    :param lr_scheduler: optional learning rate scheduler instance for shared
                         model / for split models: actor learning rate scheduler
    :param lr_scheduler_critic: optional learning rate scheduler instance for critic model
    :return: None
    """
    # distinct seed per worker so parallel actors explore differently
    torch.manual_seed(args.seed + worker_id)

    if args.worker == 1:
        # A2C: one worker driving a batch of vectorized environments
        logging.info(f"Running A2C with {args.n_envs} environments.")
        if "RR" not in args.env_name:
            env = SubprocVecEnv([
                make_env(args.env_name, args.seed, i, args.log_dir)
                for i in range(args.n_envs)
            ])
        else:
            env = DummyVecEnv(
                [make_env(args.env_name, args.seed, worker_id, args.log_dir)])
    else:
        # A3C: many workers, each with a single environment
        logging.info(f"Running A3C: training worker {worker_id} started.")
        env = DummyVecEnv(
            [make_env(args.env_name, args.seed, worker_id, args.log_dir)])
        # avoid any issues if this is not 1
        args.n_envs = 1

    normalizer = get_normalizer(args.normalizer, env)

    # init local NN instance for worker thread
    model = copy.deepcopy(global_model)
    model.train()

    model_critic = None
    if global_model_critic:
        model_critic = copy.deepcopy(global_model_critic)
        model_critic.train()

    # if no shared optimizer is provided use individual one
    if not optimizer:
        optimizer, optimizer_critic = get_optimizer(
            args.optimizer,
            global_model,
            args.lr,
            model_critic=global_model_critic,
            lr_critic=args.lr_critic)
        if args.lr_scheduler == "exponential":
            lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                                  gamma=0.99)
            if optimizer_critic:
                lr_scheduler_critic = torch.optim.lr_scheduler.ExponentialLR(
                    optimizer_critic, gamma=0.99)

    state = torch.Tensor(env.reset())

    # per-env step counters / reward accumulators (reset on episode end)
    t = np.zeros(args.n_envs)
    global_iter = 0
    episode_reward = np.zeros(args.n_envs)

    # only worker 0 writes tensorboard logs; other workers never touch `writer`
    if worker_id == 0:
        writer = SummaryWriter(log_dir='experiments/runs/')

    while True:
        # Get state of the global model
        model.load_state_dict(global_model.state_dict())
        if not args.shared_model:
            model_critic.load_state_dict(global_model_critic.state_dict())

        # containers for computing loss
        values = []
        log_probs = []
        rewards = []
        entropies = []
        # container to check whether a terminal state was reached from one of the envs
        terminals = []

        # collect one rollout of experience
        for step in range(args.rollout_steps):
            t += 1

            if args.shared_model:
                value, mu, std = model(normalizer(state))
            else:
                mu, std = model(normalizer(state))
                value = model_critic(normalizer(state))

            dist = torch.distributions.Normal(mu, std)

            # ------------------------------------------
            # select action
            action = dist.sample()

            # ------------------------------------------
            # Compute statistics for loss
            entropy = dist.entropy().sum(-1).unsqueeze(-1)
            log_prob = dist.log_prob(action).sum(-1).unsqueeze(-1)

            # make selected move (clipped to the env's action bounds)
            action = np.clip(action.detach().numpy(), -args.max_action,
                             args.max_action)
            state, reward, dones, _ = env.step(
                action[0]
                if not args.worker == 1 or "RR" in args.env_name else action)

            reward = shape_reward(args, reward)
            episode_reward += reward

            # probably don't set terminal state if max_episode length
            dones = np.logical_or(dones, t >= args.max_episode_length)

            values.append(value)
            log_probs.append(log_prob)
            rewards.append(torch.Tensor(reward).unsqueeze(-1))
            entropies.append(entropy)
            # 1 - done masks the bootstrap/GAE terms at episode boundaries
            terminals.append(torch.Tensor(1 - dones).unsqueeze(-1))

            for i, done in enumerate(dones):
                if done:
                    # keep track of the avg overall global reward
                    # (exponential moving average across all workers)
                    with global_reward.get_lock():
                        if global_reward.value == -np.inf:
                            global_reward.value = episode_reward[i]
                        else:
                            global_reward.value = .99 * global_reward.value \
                                + .01 * episode_reward[i]
                    if worker_id == 0 and T.value % args.log_frequency == 0:
                        writer.add_scalar("reward/global",
                                          global_reward.value, T.value)

                    episode_reward[i] = 0
                    t[i] = 0

                    # vectorized A2C envs auto-reset internally; only the
                    # single-env (A3C / RR) case needs an explicit reset
                    if args.worker != 1 or "RR" in args.env_name:
                        env.reset()

            with T.get_lock():
                # this is one for a3c and n for A2C (actually the lock is not needed for A2C)
                T.value += args.n_envs

            # FIX: original condition was `T.value % args.lr_scheduler_step`
            # (no `== 0`), which is truthy on every step EXCEPT exact
            # multiples of the period — the scheduler stepped almost every
            # iteration and skipped precisely the intended ones. All other
            # periodic checks here use `% ... == 0`.
            if lr_scheduler and worker_id == 0 \
                    and T.value % args.lr_scheduler_step == 0 \
                    and global_iter != 0:
                lr_scheduler.step(T.value / args.lr_scheduler_step)
                if lr_scheduler_critic:
                    lr_scheduler_critic.step(T.value / args.lr_scheduler_step)

            state = torch.Tensor(state)

        # bootstrap the return from the value estimate of the last state
        if args.shared_model:
            v, _, _ = model(normalizer(state))
            G = v.detach()
        else:
            G = model_critic(normalizer(state)).detach()

        values.append(G)

        # compute loss and backprop
        advantages = torch.zeros((args.n_envs, 1))
        ret = torch.zeros((args.rollout_steps, args.n_envs, 1))
        adv = torch.zeros((args.rollout_steps, args.n_envs, 1))

        # iterate over all time steps from most recent to the starting one
        for i in reversed(range(args.rollout_steps)):
            # G can be seen essentially as the return over the course of the rollout
            G = rewards[i] + args.discount * terminals[i] * G
            if not args.no_gae:
                # Generalized Advantage Estimation
                td_error = rewards[i] + args.discount * terminals[i] \
                    * values[i + 1] - values[i]
                # terminals here to "reset" advantages to 0, because reset is
                # called internally in the env and a new trajectory started
                advantages = advantages * args.discount * args.tau \
                    * terminals[i] + td_error
            else:
                advantages = G - values[i].detach()

            adv[i] = advantages.detach()
            ret[i] = G.detach()

        policy_loss = -(torch.stack(log_probs) * adv).mean()
        # minus 1 in order to remove the last element, which is only
        # necessary for next timestep value
        value_loss = .5 * (ret - torch.stack(values[:-1])).pow(2).mean()
        entropy_loss = torch.stack(entropies).mean()

        # zero grads to reset the gradients
        optimizer.zero_grad()

        if args.shared_model:
            # combined loss for shared architecture
            total_loss = policy_loss + args.value_loss_weight * value_loss \
                - args.entropy_loss_weight * entropy_loss
            total_loss.backward()
        else:
            optimizer_critic.zero_grad()

            value_loss.backward()
            (policy_loss - args.entropy_loss_weight * entropy_loss).backward()

            # this is just used for plotting in tensorboard
            total_loss = policy_loss + args.value_loss_weight * value_loss \
                - args.entropy_loss_weight * entropy_loss

        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        # push local gradients into the shared model before stepping
        sync_grads(model, global_model)
        optimizer.step()

        if not args.shared_model:
            torch.nn.utils.clip_grad_norm_(model_critic.parameters(),
                                           args.max_grad_norm)
            sync_grads(model_critic, global_model_critic)
            optimizer_critic.step()

        global_iter += 1

        if worker_id == 0 and T.value % args.log_frequency == 0:
            log_to_tensorboard(writer, model, optimizer, rewards, values,
                               total_loss, policy_loss, value_loss,
                               entropy_loss, T.value,
                               model_critic=model_critic,
                               optimizer_critic=optimizer_critic)