def __init__(self, ndata, nprocs):
    """Set up a shared work cursor over ``ndata`` items for ``nprocs`` workers.

    The total count and the next start index live in unsynchronized
    ``mp.RawValue`` ints, guarded by the explicit ``self._lock``.

    Args:
        ndata: total number of items to hand out.
        nprocs: number of worker processes sharing the range.
    """
    self._ndata = mp.RawValue(ctypes.c_int, ndata)
    self._start = mp.RawValue(ctypes.c_int, 0)
    self._lock = mp.Lock()
    per_proc = ndata // nprocs
    # A per-process share of 2 items or fewer is not split: the whole range
    # becomes a single chunk.
    self._chunk = per_proc if per_proc > 2 else ndata
def __init__(self, value=0):
    """Create shared ``capacity`` and ``refillRate`` ints, both set to ``value``.

    ``multiprocessing.RawValue`` (no built-in lock) is used instead of
    ``Value`` because an explicit ``self.lock`` is created for
    synchronization anyway.
    """
    self.lock = multiprocessing.Lock()
    self.refillRate = multiprocessing.RawValue('i', value)
    self.capacity = multiprocessing.RawValue('i', value)
def __init__(self, dim, bounds, popsize=31, stop_fitness=None, keep=200,
             F=0.5, Cr=0.9, rg=None, logger=None):
    """Configure the optimizer state and shared progress counters.

    Args:
        dim: problem dimension, passed with ``bounds`` to ``_check_bounds``.
        bounds: search bounds; ``_check_bounds`` yields (dim, lower, upper).
        popsize: population size.
        stop_fitness: stored as-is; presumably a target fitness for early
            stopping — TODO confirm.
        keep: stored as ``self.keep``.
        F: stored as ``self.F0`` (initial F).
        Cr: stored as ``self.Cr0`` (initial Cr).
        rg: random generator. If None, a fresh ``Generator(MT19937())`` is
            created for this instance. (The former default was evaluated once
            at definition time, so every instance shared, and advanced, the
            same generator.)
        logger: if not None, stored as ``self.logger``.
    """
    self.dim, self.lower, self.upper = _check_bounds(bounds, dim)
    self.popsize = popsize
    self.stop_fitness = stop_fitness
    self.keep = keep
    self.rg = Generator(MT19937()) if rg is None else rg
    self.F0 = F
    self.Cr0 = Cr
    self.stop = 0
    self.iterations = 0
    self.evals = 0
    self.p = 0
    self.improves = deque()
    self._init()
    if logger is not None:
        self.logger = logger
    # Shared between processes (RawValue: no built-in lock).
    self.best_y = mp.RawValue(ct.c_double, 1E99)
    self.n_evals = mp.RawValue(ct.c_long, 0)
    self.time_0 = time()
def __init__(self, initval=0, model='thread'):
    """Initialize the counter with ``initval``.

    The synchronization backend is chosen by the ``model`` argument:

    ``'thread'``
        Uses ``threading.Lock()`` and plain in-process value holders.
    anything else (e.g. ``'multiprocessing'``)
        Uses ``multiprocessing.Lock()`` and ``multiprocessing.RawValue('l', ...)``
        shared-memory longs.

    In both cases ``val`` starts at ``initval``, and ``incval``/``decval``
    start at 0.
    """
    if model != 'thread':
        self.lock = multiprocessing.Lock()
        self.val = multiprocessing.RawValue('l', initval)
        self.incval = multiprocessing.RawValue('l', 0)
        self.decval = multiprocessing.RawValue('l', 0)
        return

    class Value:
        # Minimal holder exposing ``.value``, mirroring the RawValue API.
        def __init__(self, value=None):
            self.value = value

    self.lock = threading.Lock()
    self.val = Value(initval)
    self.incval = Value(0)
    self.decval = Value(0)
def initialize(self, env_spaces, share_memory=False):
    """Set the initial sample and eval epsilon values.

    Args:
        env_spaces: accepted for interface compatibility; not read here.
        share_memory: when truthy, store both epsilons in ``mp.RawValue``
            (``c_float``) so they live in shared memory; otherwise keep them
            as plain Python numbers.
    """
    if not share_memory:
        # NOTE(review): eval_epsilon starts at 1 here but at 0 in the
        # shared-memory branch below; confirm whether the mismatch is intended.
        self.eval_epsilon = 1
        self.sample_epsilon = 1
        return
    # Scalar only: vector-valued epsilon is not supported in shared memory.
    self.eval_epsilon = mp.RawValue(ctypes.c_float, 0)
    self.sample_epsilon = mp.RawValue(ctypes.c_float, 1)
def __init__(self, pfun, bounds):
    """Wrap ``pfun`` as the "cv_score" objective with shared eval stats.

    Args:
        pfun: the wrapped callable, stored as ``self.pfun``.
        bounds: object with an ``lb`` sequence; its length sets ``self.dim``.
    """
    self.name = "cv_score"
    self.pfun, self.bounds = pfun, bounds
    self.dim = len(bounds.lb)
    # Shared across processes (RawValue: no built-in lock).
    self.evals = mp.RawValue(ct.c_int, 0)
    self.best_y = mp.RawValue(ct.c_double, math.inf)
    self.t0 = time.perf_counter()
def __init__(self, dim, f8fun):
    """Wrap ``f8fun`` as the "f8" objective of dimension ``dim``.

    ``self.bounds`` comes from the module-level ``bounds(dim)`` helper.
    Evaluation count and best value are kept in shared ``mp.RawValue``
    slots (no built-in lock).
    """
    self.name, self.dim, self.f8fun = "f8", dim, f8fun
    self.bounds = bounds(dim)
    self.evals = mp.RawValue(ct.c_int, 0)
    self.best_y = mp.RawValue(ct.c_double, math.inf)
    self.t0 = time.perf_counter()
def __init__(self, context_size):
    """Create the shared (doc_id, in-document position) cursor.

    Both indices must always be updated together, so they are raw
    (unsynchronized) ``multiprocessing.RawValue('i', ...)`` ints guarded by a
    single explicit ``self._lock`` rather than one lock per value.

    Args:
        context_size: initial in-document position (the first position is
            at ``pos=context_size``).
    """
    # RawValue has NO lock of its own: readers/writers must hold self._lock.
    self._lock = multiprocessing.Lock()
    # 'i' = C signed int.
    self._doc_id = multiprocessing.RawValue('i', 0)
    self._in_doc_pos = multiprocessing.RawValue('i', context_size)
def start_capture(self):
    """Start the capture subprocess and block until it reports a first frame.

    Prepares the device-specific output (a temporary raw file for
    'decklink'; an SDK-version log line for 'pointgrey') and creates a
    temporary output directory. Shares a frame counter and a finished flag
    with a new ``CaptureProcess`` and starts it. Then polls
    ``self.capture_framenum()`` every 0.1 s for up to about 5 s.

    Raises:
        AssertionError: if a capture process already exists.
        Exception: for an unsupported decklink mode, or when no frame arrives
            in time (the capture is terminated before raising).
    """
    # should not call this more than once
    assert not self.capture_process

    output_raw_filename = None
    self.logger.info("Starting capture on device '%s' with mode: '%s'" %
                     (self.capture_device, self.mode))
    if self.capture_device == 'decklink':
        if self.mode not in supported_formats.keys():
            raise Exception("Unsupported video format %s" % self.mode)
        self.output_raw_file = tempfile.NamedTemporaryFile(
            dir=self.custom_tempdir)
        output_raw_filename = self.output_raw_file.name
    elif self.capture_device == 'pointgrey':
        # For debugging purposes, log the flycapture SDK version. The best
        # way found to do this is to look up the versioned .so filename.
        # This is diagnostic only, so an unreadable /usr/lib is treated like
        # "not found". max() needs default=None: on an empty sequence it
        # raises ValueError, which previously made the warning unreachable.
        try:
            candidates = [filename for filename in os.listdir('/usr/lib')
                          if 'libflycapture.so' in filename]
        except OSError:
            candidates = []
        flycap_lib = max(candidates, key=len, default=None)
        if flycap_lib:
            # skip past ".so." to the version suffix
            version_index = flycap_lib.find('.so') + 4
            self.logger.info('Using PointGrey SDK version: %s' %
                             flycap_lib[version_index:])
        else:
            self.logger.warning("Unable to determine PointGrey SDK version")

    self.outputdir = tempfile.mkdtemp(dir=self.custom_tempdir)
    # Shared with the capture process (RawValue: no built-in lock).
    self.frame_counter = multiprocessing.RawValue('i', 0)
    self.finished_semaphore = multiprocessing.RawValue('b', False)
    self.capture_process = CaptureProcess(
        self.capture_device, self.mode, self.frame_counter,
        self.finished_semaphore, output_raw_filename=output_raw_filename,
        outputdir=self.outputdir, fps=self.fps,
        camera_settings_file=self.camera_settings_file)
    self.logger.info("Starting capture...")
    self.capture_process.start()

    # wait for capture to actually start...
    self.logger.info("Waiting for first frame...")
    max_wait_for_frame = 5
    elapsed = 0
    interval = 0.1
    while self.capture_framenum() < 1:
        time.sleep(interval)
        elapsed += interval
        if elapsed > max_wait_for_frame:
            self.logger.error("Timed out waiting for first frame! Capture "
                              "program hung?")
            self.terminate_capture()
            raise Exception("Timed out waiting for first frame")
def build_sync_scat(n_parallel, n_in_out):
    """Allocate the shared-memory bookkeeping for a parallel scatter.

    Args:
        n_parallel: number of parallel workers; ``assign_idxs`` gets
            ``n_parallel + 1`` C-long slots exposed as a NumPy view.
        n_in_out: length of the shared ``data_IDs`` int array.

    Returns:
        A ``struct`` holding raw (lock-free) shared values/arrays; ``idxs_arr``
        starts as None.
    """
    return struct(
        assign_idxs=np.ctypeslib.as_array(mp.RawArray('l', n_parallel + 1)),
        use_idxs_arr=mp.RawValue(c_bool, False),
        tag=mp.RawValue('i', 0),
        size=mp.RawValue('l', 0),
        idxs_arr=None,  # allocated later; only needed if shuffling
        data_IDs=mp.RawArray('i', n_in_out),
    )
def __init__(
        self,
        fun,  # fitness function
        bounds,  # bounds of the objective function arguments
        max_eval_fac=None,  # maximal number of evaluations factor
        check_interval=100,  # sort evaluation store after check_interval iterations
        capacity=500,  # capacity of the evaluation store
        logger=None,  # if None logging is switched off
        num_retries=None,
        statistic_num=0,
        datafile=None):
    """Set up the retry/evaluation store and its process-shared state.

    ``max_eval_fac`` defaults to 50 when ``num_retries`` is None. Otherwise it
    defaults to ``int(min(50, 1 + num_retries // check_interval))``.
    ``num_retries`` defaults to ``max_eval_fac * check_interval``.
    ``eval_fac_incr`` is chosen so ``max_eval_fac`` is reached at the last
    retry. It is 0.0 when there are no retries, which avoids dividing by zero.

    When ``statistic_num > 0``, additional shared statistics arrays are
    allocated.
    """
    self.fun = fun
    self.lower, self.upper = _convertBounds(bounds)
    self.delta = self.upper - self.lower
    self.logger = logger
    self.capacity = capacity
    if max_eval_fac is None:
        if num_retries is None:
            max_eval_fac = 50
        else:
            max_eval_fac = int(min(50, 1 + num_retries // check_interval))
    if num_retries is None:
        num_retries = max_eval_fac * check_interval
    self.num_retries = num_retries
    # increment eval_fac so that max_eval_fac is reached at last retry
    if num_retries:
        self.eval_fac_incr = max_eval_fac / (num_retries / check_interval)
    else:
        self.eval_fac_incr = 0.0
    self.max_eval_fac = max_eval_fac
    self.check_interval = check_interval
    self.dim = len(self.lower)
    self.random = Random()
    self.t0 = time.perf_counter()

    # shared between processes (Raw*: no built-in locking; use the mutexes)
    self.add_mutex = mp.Lock()
    self.check_mutex = mp.Lock()
    self.xs = mp.RawArray(ct.c_double, capacity * self.dim)
    self.ys = mp.RawArray(ct.c_double, capacity)
    self.eval_fac = mp.RawValue(ct.c_double, 1)
    self.count_evals = mp.RawValue(ct.c_long, 0)
    self.count_runs = mp.RawValue(ct.c_int, 0)
    self.num_stored = mp.RawValue(ct.c_int, 0)
    self.num_sorted = mp.RawValue(ct.c_int, 0)
    self.best_y = mp.RawValue(ct.c_double, math.inf)
    self.worst_y = mp.RawValue(ct.c_double, math.inf)
    self.best_x = mp.RawArray(ct.c_double, self.dim)

    self.statistic_num = statistic_num
    self.datafile = datafile
    if statistic_num > 0:  # enable statistics
        self.time = mp.RawArray(ct.c_double, self.statistic_num)
        self.val = mp.RawArray(ct.c_double, self.statistic_num)
        self.si = mp.RawValue(ct.c_int, 0)
        self.sevals = mp.RawValue(ct.c_long, 0)
        self.bval = mp.RawValue(ct.c_double, math.inf)
def _build_parallel_ctrl(self, n_worker):
    """Create the shared control flags, barriers and queues for workers.

    Both barriers are sized ``n_worker + 1``, presumably the workers plus
    this master process (TODO confirm against the worker loop). All flags are
    unsynchronized ``mp.RawValue`` objects.
    """
    n_parties = n_worker + 1
    self.ctrl = AttrDict(
        quit=mp.RawValue(ctypes.c_bool, False),
        barrier_in=mp.Barrier(n_parties),
        barrier_out=mp.Barrier(n_parties),
        do_eval=mp.RawValue(ctypes.c_bool, False),
        itr=mp.RawValue(ctypes.c_long, 0),
    )
    self.traj_infos_queue = mp.Queue()
    self.eval_traj_infos_queue = mp.Queue()
    self.sync = AttrDict(stop_eval=mp.RawValue(ctypes.c_bool, False))
def RunCameraInterface(config, no_escape=True):
    """Launch the camera, optional video-writer, and pyglet window processes.

    Each call bumps the module-level ``num_cameras`` and uses it to key the
    new processes and their "finished" flags. The keys are
    ``'Camera<n>'``, ``'Writer<n>'`` and ``'Pyglet<n>'``, stored in the
    module-level ``all_processes``, ``producer_finished_flags`` and
    ``consumer_finished_flags`` dicts. The frame queues are appended to
    ``all_queues`` so they can be drained on termination. A SIGINT handler
    is installed.

    Args:
        config: camera configuration mapping; ``config['RecordVideo']``
            (default False) enables the video-writer process.
        no_escape: forwarded to ``start_window``.

    Returns:
        The module-level ``terminate_flag`` shared with all child processes.
    """
    global num_cameras, all_queues, producer_finished_flags, terminate_flag
    global all_processes, consumer_finished_flags

    num_cameras = num_cameras + 1

    # Initialize objects used to communicate between processes.
    # Filled by the camera acquisition process, emptied by the video writing process.
    storage_frame_queue = LabeledQueue(frame_type='jpeg')
    # Filled by the camera acquisition process, emptied by the (visualization) primary process.
    visualization_frame_queue = LabeledQueue(frame_type='img')

    do_record = config.get('RecordVideo', False)
    if do_record:
        frame_queues = [visualization_frame_queue, storage_frame_queue]
    else:
        print('NOT RECORDING!!!')
        frame_queues = [visualization_frame_queue]

    # We'll drain these when we terminate!
    all_queues.extend(frame_queues)

    signal.signal(signal.SIGINT, termination_handler)

    # Signals to the main process that the camera acquisition process has terminated.
    camera_process_finished = multiprocessing.RawValue('b', True)
    producer_finished_flags['Camera{}'.format(num_cameras)] = camera_process_finished

    camera_process = multiprocessing.Process(
        target=start_camera,
        args=(config, frame_queues, terminate_flag, camera_process_finished))
    camera_process.daemon = True
    all_processes['Camera{}'.format(num_cameras)] = camera_process
    camera_process.start()  # Launch the camera frame acquisition process

    if do_record:
        # Signals to the main process that the video writing process has terminated.
        vwriter_process_finished = multiprocessing.RawValue('b', True)
        # Register the writer's own flag. The camera's flag was previously
        # stored here by mistake, so the writer flag was never registered.
        consumer_finished_flags['Writer{}'.format(num_cameras)] = vwriter_process_finished
        vwriter_process = multiprocessing.Process(
            target=start_writer,
            args=(config, storage_frame_queue, terminate_flag,
                  vwriter_process_finished))
        vwriter_process.daemon = True
        all_processes['Writer{}'.format(num_cameras)] = vwriter_process
        vwriter_process.start()  # Launch the video writing process

    # Create the main pyglet window
    time.sleep(0.1)
    # Signals to the main process that the pyglet window process has terminated.
    pyglet_process_finished = multiprocessing.RawValue('b', True)
    consumer_finished_flags['Pyglet{}'.format(num_cameras)] = pyglet_process_finished
    pyglet_process = multiprocessing.Process(
        target=start_window,
        args=(config, visualization_frame_queue, terminate_flag,
              pyglet_process_finished, no_escape))
    pyglet_process.daemon = True
    # Register the pyglet process itself. The camera process was previously
    # stored under this key by mistake.
    all_processes['Pyglet{}'.format(num_cameras)] = pyglet_process
    pyglet_process.start()  # Launch the pyglet visualization process

    return terminate_flag
def __init__(self):
    """Allocate the shared worker-status slots (state, task id, timestamp).

    Everything lives in raw (lock-free) shared memory:
    ``_state`` is a C int starting at ``WorkerState.NOT_READY``, ``_tid`` a
    64-byte char array, and ``_timestamp`` an unsigned int starting at 0.
    """
    # A c_char_p RawValue would hold a pointer, which is meaningless in
    # another process, so the task id uses a fixed-size char buffer instead.
    # A c_float timestamp lost precision for values like 1234567890.99
    # (float holds ~7 significant digits), hence the unsigned int.
    self._timestamp = multiprocessing.RawValue(ctypes.c_uint, 0)
    self._tid = multiprocessing.RawArray('c', 64)
    self._state = multiprocessing.RawValue(ctypes.c_int, WorkerState.NOT_READY)
def __init__(
        self,
        bounds,  # bounds of the objective function arguments
        max_evaluations=50000,  # maximum evaluation count
        check_interval=10,  # sort evaluation memory after check_interval iterations
        capacity=500,  # capacity of the evaluation store
        logger=None  # if None logging is switched off
):
    """Set up the evaluation store and its process-shared counters.

    ``self.delta`` is a plain list of per-dimension ``upper - lower`` widths.
    Shared state uses raw ctypes arrays/values guarded by ``self.add_mutex``.
    """
    self.lower, self.upper = _convertBounds(bounds)
    self.logger = logger
    self.max_evals = max_evaluations
    self.capacity = capacity
    self.check_interval = check_interval
    self.dim = len(self.lower)
    self.delta = [self.upper[k] - self.lower[k] for k in range(self.dim)]

    # --- shared between processes ---
    self.add_mutex = mp.Lock()
    self.xs = mp.RawArray(ct.c_double, self.capacity * self.dim)
    self.ys = mp.RawArray(ct.c_double, self.capacity)
    self.count_evals = mp.RawValue(ct.c_long, 0)
    for counter in ('count_runs', 'num_stored', 'num_sorted', 'count_stat_runs'):
        setattr(self, counter, mp.RawValue(ct.c_int, 0))
    self.t0 = time.perf_counter()
    self.mean = mp.RawValue(ct.c_double, 0)
    self.qmean = mp.RawValue(ct.c_double, 0)
    self.best_y = mp.RawValue(ct.c_double, math.inf)
def __init__(self, backend, main, np, args=()):
    """Prepare (without starting) the worker processes and their guard threads.

    Args:
        backend: factory provider (``QueueFactory``, ``StorageFactory``,
            ``EventFactory``, ``SlaveFactory``).
        main: stored as ``self.main``.
        np: number of worker ("slave") processes to create.
        args: stored as ``self.args``.
    """
    self.Errors = backend.QueueFactory(1)
    self._tls = backend.StorageFactory()
    self.main, self.args = main, args
    self.slaveguard = threading.Thread(target=self._slaveGuard)
    self.errorguard = threading.Thread(target=self._errorGuard)
    # _allDead must come from the backend because the slaves check it via
    # is_alive().
    self._allDead = backend.EventFactory()
    # Each dead child releases the semaphore once; when all children are
    # dead, the _allDead event gets set.
    self.semaphore = threading.Semaphore(0)
    # Raw C long in shared memory (zero-initialized when no value is given).
    self.JoinedProcesses = multiprocessing.RawValue('l')
    ranks = range(np)
    # workers
    self.P = [backend.SlaveFactory(target=self._slaveMain, args=(rank,))
              for rank in ranks]
    # one nanny thread per worker
    self.N = [threading.Thread(target=self._slaveNanny, args=(rank, proc))
              for rank, proc in zip(ranks, self.P)]
    return
def __init__(self, remote_addr, remote_port, metadata, compdata_queue):
    """Store connection info and create shared timing/bandwidth slots.

    The three measurements are ``multiprocessing.RawValue`` doubles starting
    at 0.0. The base-class constructor is called with ``target=self.transfer``.
    """
    self.remote_addr, self.remote_port = remote_addr, remote_port
    self.metadata = metadata
    self.compdata_queue = compdata_queue

    # measurement
    for slot in ('monitor_network_bw', 'vm_resume_time_at_dest',
                 'time_finish_transmission'):
        setattr(self, slot, multiprocessing.RawValue(ctypes.c_double, 0))
    self.is_first_recv = False
    self.time_first_recv = 0
    super(StreamSynthesisClient, self).__init__(target=self.transfer)
def init_par_objs(self, n_parallel):
    """Allocate shared counters and barriers for ``n_parallel`` workers.

    Stores ``(shareds, barriers)`` in ``self._par_objs``. The ``*_vec`` entries
    are NumPy views over ``mp.RawArray('l', n_parallel)``. The barriers are
    ``mp.Barrier(n_parallel)``.
    """
    n = n_parallel

    def _shared_long_vec():
        # Match the element type of RawArray('l') (C long). The builtin
        # ``int`` dtype is platform dependent and need not match. For
        # example, on Windows C long is 32-bit but NumPy 2's default int is
        # 64-bit, so frombuffer would misread or reject the buffer.
        return np.frombuffer(mp.RawArray('l', n), dtype='l')

    shareds = SimpleContainer(
        new_state_action_count_vec=_shared_long_vec(),
        # NOTE(review): indexing [0] yields a NumPy scalar *copy*, so this
        # entry is not backed by shared memory; confirm that is intended.
        total_state_action_count=np.frombuffer(
            mp.RawValue('l'),
            dtype='l',
        )[0],
        max_state_action_count_vec=_shared_long_vec(),
        min_state_action_count_vec=_shared_long_vec(),
        sum_state_action_count_vec=_shared_long_vec(),
        n_steps_vec=_shared_long_vec(),
    )
    barriers = SimpleContainer(
        summarize_count=mp.Barrier(n),
        update_count=mp.Barrier(n),
    )
    self._par_objs = (shareds, barriers)
def run_test(queue_cls, num_producers, num_consumers, msgs_per_prod, consume_many):
    """Time a producer/consumer run over ``queue_cls(100000)``.

    Producers run ``produce_msgs``. Consumers run ``consume_msgs`` and
    receive a shared bool flag that is set to True once every producer has
    been joined.

    Returns:
        Elapsed wall-clock seconds (``time()`` difference).
    """
    t_begin = time()
    queue = queue_cls(100000)
    done_flag = multiprocessing.RawValue(ctypes.c_bool, False)
    producer_procs = [
        multiprocessing.Process(target=produce_msgs, args=(queue, idx, msgs_per_prod))
        for idx in range(num_producers)
    ]
    consumer_procs = [
        multiprocessing.Process(target=consume_msgs,
                                args=(queue, idx, done_flag, consume_many))
        for idx in range(num_consumers)
    ]
    for proc in producer_procs:
        proc.start()
    for proc in consumer_procs:
        proc.start()
    for proc in producer_procs:
        proc.join()
    done_flag.value = True
    for proc in consumer_procs:
        proc.join()
    queue.close()

    qualified_name = queue_cls.__module__ + '.' + queue_cls.__name__
    log.info('Exiting queue type %s', qualified_name)
    elapsed = time() - t_begin
    log.info('Time taken by queue type %s is %.5f', qualified_name, elapsed)
    return elapsed
def _build_parallel_ctrl(self, n_worker):
    """Create shared control flags, barriers, queues and curriculum state.

    Barriers are sized ``n_worker + 1``, presumably the workers plus this
    master process (TODO confirm). ``ctrl`` flags are raw (lock-free)
    values. The curriculum fields in ``sync`` use synchronized
    ``mp.Value``/``mp.Array``; ``seeds`` has one int per worker.
    """
    n_parties = n_worker + 1
    self.ctrl = AttrDict(
        quit=mp.RawValue(ctypes.c_bool, False),
        barrier_in=mp.Barrier(n_parties),
        barrier_out=mp.Barrier(n_parties),
        do_eval=mp.RawValue(ctypes.c_bool, False),
        itr=mp.RawValue(ctypes.c_long, 0),
        # TODO SAVE state of curriculum?
    )
    self.traj_infos_queue = mp.Queue()
    self.eval_traj_infos_queue = mp.Queue()
    self.sync = AttrDict(
        stop_eval=mp.RawValue(ctypes.c_bool, False),
        glob_average_return=mp.Value('d', 0.0),
        curriculum_stage=mp.Value('i', 0),
        difficulty=mp.Value('d', 0.0),
        seeds=mp.Array('i', n_worker),
    )
def __init__(self, transfers):
    """Cache the columns of ``transfers`` as NumPy arrays plus shared stats.

    Args:
        transfers: table indexed by column name with ``.to_numpy()`` columns
            (presumably a pandas DataFrame — TODO confirm). Each listed
            column becomes an attribute of the same name.
    """
    # writable across python processes (RawValue: no built-in lock)
    self.evals = mp.RawValue(ct.c_long, 0)
    self.best_y = mp.RawValue(ct.c_double, math.inf)
    self.t0 = time.perf_counter()
    self.transfers = transfers
    for column in ("asteroid", "station", "trajectory", "transfer_start",
                   "transfer_time", "mass", "dv"):
        setattr(self, column, transfers[column].to_numpy())
    self.trajectory_dv = trajectory_dv(self.asteroid, self.trajectory, self.dv)
def __init__(self, backend, main, np, args=()):
    """Prepare (without starting) the worker processes and per-worker guards.

    Args:
        backend: factory provider (``QueueFactory``, ``StorageFactory``,
            ``EventFactory``, ``SlaveFactory``).
        main: stored as ``self.main``.
        np: number of worker ("slave") processes to create.
        args: stored as ``self.args``.
    """
    self.Errors = backend.QueueFactory(1)
    self._tls = backend.StorageFactory()
    self.main, self.args = main, args
    self.guard = threading.Thread(target=self._guardMain)
    self.errorguard = threading.Thread(target=self._errorGuard)
    # This has to come from the backend because the slaves check it.
    self.guardDead = backend.EventFactory()
    # Each dead child releases the semaphore once; when all are dead, the
    # guard proceeds to set guardDead.
    self.semaphore = threading.Semaphore(0)
    # Raw C long in shared memory (zero-initialized when no value is given).
    self.JoinedProcesses = multiprocessing.RawValue('l')
    ranks = range(np)
    self.P = [backend.SlaveFactory(target=self._slaveMain, args=(rank,))
              for rank in ranks]
    self.G = [threading.Thread(target=self._slaveGuard, args=(rank, proc))
              for rank, proc in zip(ranks, self.P)]
    return
def __init__(self, ModelCls=None, model_kwargs=None, initial_model_state_dict=None):
    """
    Arguments are saved but no model initialization occurs.

    Args:
        ModelCls: The model class to be used.
        model_kwargs (optional): Any keyword arguments to pass when instantiating the model.
        initial_model_state_dict (optional): Initial model parameter values.
    """
    # Must stay the first statement: locals() must contain only the
    # constructor arguments. save__init__args presumably stores them as
    # attributes, since self.model_kwargs is read below.
    save__init__args(locals())
    self.model = None  # type: torch.nn.Module
    self.shared_model = None
    self.distribution = None
    self.device = torch.device("cpu")
    self._mode = None
    if self.model_kwargs is None:
        self.model_kwargs = {}

    # The rest only for async operations:
    self._rw_lock = RWLock()
    self._send_count = mp.RawValue("l", 0)  # shared C long, no built-in lock
    self._recv_count = 0
def sample_runner_initialize(self, affinity):
    """Split envs across workers and start one action-server process per entry.

    ``batch_spec.B`` envs are split as evenly as possible over all workers
    (summed ``workers_cpus`` across ``affinity``). Remainder envs go to the
    first workers. Evaluation envs per worker and ``eval_max_T`` are derived
    from ``self.eval_n_envs`` / ``self.eval_max_steps``.

    The shared ``ctrl`` barriers are sized ``n_server + n_worker + 1``,
    presumably including this master process (TODO confirm). Each server
    process receives its own kwargs merged with ``ctrl`` and
    ``traj_infos_queue``.
    """
    n_server = len(affinity)
    n_worker = sum(len(aff["workers_cpus"]) for aff in affinity)
    n_envs_list = [self.batch_spec.B // n_worker] * n_worker
    if self.batch_spec.B % n_worker != 0:
        logger.log(
            "WARNING: unequal number of envs per process, from "
            f"batch_B {self.batch_spec.B} and n_parallel {n_worker} "
            "(possible suboptimal speed).")
        for b in range(self.batch_spec.B % n_worker):
            n_envs_list[b] += 1

    if self.eval_n_envs > 0:
        eval_n_envs_per = max(1, self.eval_n_envs // len(n_envs_list))
        eval_n_envs = eval_n_envs_per * n_worker
        logger.log(f"Total parallel evaluation envs: {eval_n_envs}.")
        self.eval_max_T = 1 + int(self.eval_max_steps // eval_n_envs)
        self.eval_n_envs_per = eval_n_envs_per
    else:
        self.eval_n_envs_per = 0
        self.eval_max_T = 0

    ctrl = AttrDict(
        quit=mp.RawValue(ctypes.c_bool, False),
        barrier_in=mp.Barrier(n_server + n_worker + 1),
        barrier_out=mp.Barrier(n_server + n_worker + 1),
        do_eval=mp.RawValue(ctypes.c_bool, False),
        itr=mp.RawValue(ctypes.c_long, 0),
    )
    traj_infos_queue = mp.Queue()

    common_kwargs = dict(
        ctrl=ctrl,
        traj_infos_queue=traj_infos_queue,
    )
    servers_kwargs = assemble_servers_kwargs(affinity, n_envs_list,
                                             self.seed, self.double_buffer)
    # Build a merged dict per server. dict.update() returns None, so the
    # former ``kwargs=s_kwargs.update(...)`` handed None to Process.
    # common_kwargs still takes precedence on key clashes, as with update().
    servers = [
        mp.Process(target=self.action_server_process,
                   kwargs={**s_kwargs, **common_kwargs})
        for s_kwargs in servers_kwargs
    ]
    for s in servers:
        s.start()
    self.servers = servers
    self.ctrl = ctrl
    self.traj_infos_queue = traj_infos_queue
def find_rule_processes(tree, holdout, y, x1, x2, hierarchy):
    """Find the best rule to add, searching leaves in parallel worker processes.

    With ``config.TUNING_PARAMS.use_stumps`` the search is delegated to
    ``find_rule_process_stumps``. Otherwise ``config.NCPU`` workers
    (``find_rule_process_worker``) pull leaf indices from a shared counter
    and record the best rule found in lock-protected shared values.

    Returns:
        ``(best_leaf, best_reg_sign, best_loss_hier_node, best_loss_mat)``
        where ``best_loss_mat`` is an ``(x1.num_row, x2.num_col)`` NumPy
        array.
    """
    if config.TUNING_PARAMS.use_stumps:
        # Get rule to add to root
        (best_leaf, best_reg_sign, best_loss_hier_node, best_loss_mat) = \
            find_rule_process_stumps(tree, holdout, y, x1, x2, hierarchy)
        return (best_leaf, best_reg_sign, best_loss_hier_node, best_loss_mat)

    nrow = x1.num_row
    ncol = x2.num_col

    # Lock controlling access to the shared best-rule objects below
    # (Raw* values carry no lock of their own).
    lock = multiprocessing.Lock()
    # Initialize loss to a large number, so that the first loss is chosen
    best_loss = multiprocessing.RawValue(ctypes.c_double, 1e100)
    best_leaf = multiprocessing.RawValue('i', -1)
    shared_best_loss_mat = multiprocessing.RawArray(ctypes.c_double, nrow * ncol)
    best_loss_reg = multiprocessing.RawValue('i', 0)
    best_loss_hier_node = multiprocessing.RawValue('i', 0)

    # Next leaf index to be processed, so that the workers know which leaf
    # to work on (synchronized Value).
    leaf_index_cntr = multiprocessing.Value('i', 0)

    # Pack arguments for the worker processes
    args = [
        tree, holdout, y, x1, x2, hierarchy, leaf_index_cntr,
        (lock, best_loss, best_leaf, best_loss_hier_node,
         shared_best_loss_mat, best_loss_reg)
    ]

    # Fork worker processes, and wait for them to return
    fork_and_wait(config.NCPU, find_rule_process_worker, args)

    # Convert all of the shared types into standard python values
    best_leaf = int(best_leaf.value)
    best_loss_reg = int(best_loss_reg.value)
    best_loss_hier_node = int(best_loss_hier_node.value)
    # Copy the raw array into a (nrow, ncol) numpy array
    best_loss_mat = np.array(shared_best_loss_mat).reshape(nrow, ncol)

    return (best_leaf, best_loss_reg, best_loss_hier_node, best_loss_mat)
def __init__(
        self,
        bounds,  # bounds of the objective function arguments
        max_eval_fac=50,  # maximal number of evaluations
        check_interval=100,  # sort evaluation store after check_interval iterations
        capacity=500,  # capacity of the evaluation store
        logger=None  # if None logging is switched off
):
    """Set up the evaluation store with per-entry bounds and shared state.

    ``xs``, ``lowers`` and ``uppers`` are flat ``capacity * dim`` double
    arrays; ``ys`` holds one double per stored entry. All shared state is raw
    (lock-free) ctypes memory guarded by ``add_mutex`` / ``check_mutex``.
    """
    self.lower, self.upper = _convertBounds(bounds)
    self.delta = self.upper - self.lower
    self.logger = logger
    self.capacity = capacity
    self.max_eval_fac = max_eval_fac
    self.check_interval = check_interval
    self.dim = len(self.lower)
    self.random = Random()
    self.t0 = time.perf_counter()

    # --- shared between processes ---
    self.add_mutex = mp.Lock()
    self.check_mutex = mp.Lock()
    flat_len = capacity * self.dim
    self.xs = mp.RawArray(ct.c_double, flat_len)
    self.lowers = mp.RawArray(ct.c_double, flat_len)
    self.uppers = mp.RawArray(ct.c_double, flat_len)
    self.ys = mp.RawArray(ct.c_double, capacity)
    self.eval_fac = mp.RawValue(ct.c_int, 1)
    self.count_evals = mp.RawValue(ct.c_long, 0)
    for counter in ('count_runs', 'num_stored', 'num_sorted'):
        setattr(self, counter, mp.RawValue(ct.c_int, 0))
    self.best_y = mp.RawValue(ct.c_double, math.inf)
    self.worst_y = mp.RawValue(ct.c_double, math.inf)
    self.best_x = mp.RawArray(ct.c_double, self.dim)
def get_n_gpu():
    """Return the GPU count reported by ``n_gpu_subprocess`` in a child process.

    The child writes its result into a shared C int. The probe runs in a
    separate process, presumably so GPU/pygpu initialization does not happen
    in this process (TODO confirm).

    Raises:
        ImportError: if the child reports -1 (pygpu could not be imported).
    """
    shared_count = mp.RawValue('i', 0)
    probe = mp.Process(target=n_gpu_subprocess, args=(shared_count,))
    probe.start()
    probe.join()
    count = int(shared_count.value)
    if count != -1:
        return count
    raise ImportError("Must be able to import pygpu to use GPUs.")
def __init__(self, command_queue, task_queue, mode_queue, output_queue,
             comp_type, comp_level):
    """Store the queues and compression settings and create shared totals.

    The four ``child_*_total`` accumulators are raw (lock-free) shared
    values starting at 0. The base-class constructor is called with
    ``target=self._comp``.
    """
    self.command_queue, self.task_queue = command_queue, task_queue
    self.mode_queue, self.output_queue = mode_queue, output_queue
    self.comp_type, self.comp_level = comp_type, comp_level

    # shared variables between processes
    for name in ('child_process_time_total', 'child_process_block_total'):
        setattr(self, name, multiprocessing.RawValue(ctypes.c_double, 0))
    for name in ('child_input_size_total', 'child_output_size_total'):
        setattr(self, name, multiprocessing.RawValue(ctypes.c_ulong, 0))
    super(CompChildProc, self).__init__(target=self._comp)
def launch_workers(self, double_buffer, traj_infos_queue, affinity, seed,
                   n_envs_list, eval_n_envs_per):
    """Build step buffers and sync primitives, then start sampling processes.

    One worker per entry of ``affinity["workers_cpus"]``. Each worker gets a
    step-blocker and an act-waiter semaphore. An eval step buffer is built
    only when ``self.eval_n_envs_per > 0``.
    """
    # NOTE(review): the eval_n_envs_per argument is never read;
    # self.eval_n_envs_per is used throughout. Confirm which is intended.
    n_worker = len(affinity["workers_cpus"])
    sync = AttrDict(
        step_blockers=[mp.Semaphore(0) for _ in range(n_worker)],
        act_waiters=[mp.Semaphore(0) for _ in range(n_worker)],
        stop_eval=mp.RawValue(ctypes.c_bool, False),
    )
    step_buffer_pyt, step_buffer_np = build_step_buffer(
        self.examples, sum(n_envs_list))

    eval_step_buffer_np = None
    if self.eval_n_envs_per > 0:
        eval_step_buffer_pyt, eval_step_buffer_np = build_step_buffer(
            self.examples, self.eval_n_envs_per * n_worker)
        self.eval_step_buffer_pyt = eval_step_buffer_pyt
        self.eval_step_buffer_np = eval_step_buffer_np

    common_kwargs = dict(
        EnvCls=self.EnvCls,
        env_kwargs=self.env_kwargs,
        agent=None,
        batch_T=self.batch_spec.T,
        CollectorCls=self.CollectorCls,
        TrajInfoCls=self.TrajInfoCls,
        traj_infos_queue=traj_infos_queue,
        ctrl=self.ctrl,
        max_decorrelation_steps=self.max_decorrelation_steps,
        eval_n_envs=self.eval_n_envs_per,
        eval_CollectorCls=self.eval_CollectorCls or EvalCollector,
        eval_env_kwargs=self.eval_env_kwargs,
        eval_max_T=self.eval_max_T,
    )
    workers_kwargs = assemble_workers_kwargs(
        affinity, seed, double_buffer, n_envs_list, step_buffer_np, sync,
        self.eval_n_envs_per, eval_step_buffer_np)

    workers = []
    for w_kwargs in workers_kwargs:
        workers.append(mp.Process(
            target=sampling_process,
            kwargs=dict(common_kwargs=common_kwargs, worker_kwargs=w_kwargs)))
    for proc in workers:
        proc.start()

    self.workers = workers
    self.step_buffer_pyt = step_buffer_pyt
    self.step_buffer_np = step_buffer_np
    self.sync = sync
    self.mid_batch_reset = self.CollectorCls.mid_batch_reset
def __init__(self, ndata, nprocs, chunk=None, schedule='guided'):
    """Set up a shared work cursor with an OpenMP-like scheduling strategy.

    Chunk size:

    - 'guided'/'dynamic': ``chunk`` if given (truthy), else
      ``ndata // (10 * nprocs)``.
    - 'static': ``ndata // nprocs``, raised to ``chunk`` if that is larger.

    The result is never below 1.

    Raises:
        ValueError: if ``schedule`` is not 'guided', 'dynamic' or 'static'.
    """
    if schedule not in ('guided', 'dynamic', 'static'):
        raise ValueError('unknown scheduling strategy')
    self._ndata = mp.RawValue(ctypes.c_int, ndata)
    self._start = mp.RawValue(ctypes.c_int, 0)
    self._lock = mp.Lock()
    self._schedule = schedule
    self._nprocs = nprocs
    if schedule == 'static':
        size = ndata // nprocs
        if chunk:
            size = max(chunk, size)
    else:
        # Division runs even when chunk is given (kept for identical behavior).
        size = ndata // (10 * nprocs)
        if chunk:
            size = chunk
    self._chunk = max(size, 1)