Example #1
def learning_loop(exit_flag: mp.Value,
                  gvfs: Sequence[Sequence[Learner]],
                  behaviour_gvf: SARSA,
                  main2gvf: mp.SimpleQueue,
                  gvf2main: mp.SimpleQueue,
                  gvf2plot: mp.SimpleQueue):
    action, action_prob, obs, x = None, None, None, None

    # get first state
    while exit_flag.value == 0 and obs is None:
        while exit_flag.value == 0 and main2gvf.empty():
            time.sleep(0.001)
        if exit_flag.value == 0:
            obs, x = main2gvf.get()
            action, action_prob = behaviour_gvf.policy(obs=obs, x=x)
            gvf2main.put(action)

    # main loop
    while exit_flag.value == 0:
        while exit_flag.value == 0 and main2gvf.empty():
            time.sleep(0.001)
        if exit_flag.value:
            break

        # get data from servos
        obsp, xp = main2gvf.get()
        actionp, action_probp = behaviour_gvf.policy(obs=obsp, x=xp)

        # update weights
        for g in chain.from_iterable(gvfs):
            g.update(x, obs,
                     action, action_prob,
                     xp, obsp,
                     actionp, action_probp)

        # send action
        gvf2main.put(actionp)

        # send data to plots
        gdata = [[g.data(x, obs, action, xp, obsp)
                  for g in gs]
                 for gs in gvfs]
        data = dict(ChainMap(*chain.from_iterable(gdata)))
        data['obs'] = obs
        gvf2plot.put(data)

        # go to next state
        obs = obsp
        x = xp
        action = actionp
        action_prob = action_probp

    print('Done learning!')
Example #2
def learning_loop(exit_flag: mp.Value,
                  gvfs: Sequence[Sequence[GTDLearner]],
                  main2gvf: mp.SimpleQueue,
                  gvf2plot: mp.SimpleQueue):
    action, action_prob, obs, x = None, None, None, None

    # get first state
    while exit_flag.value == 0 and obs is None:
        while exit_flag.value == 0 and main2gvf.empty():
            time.sleep(0.001)
        if exit_flag.value == 0:
            action, action_prob, obs, x = main2gvf.get()

    i = 1

    # main loop
    while exit_flag.value == 0:
        while exit_flag.value == 0 and main2gvf.empty():
            time.sleep(0.001)
        if exit_flag.value:
            break

        i += 1
        ude = False
        rupee = False
        if 5000 < i < 5100:
            ude = True
        if i == 7000:
            rupee = True

        # get data from servos
        actionp, action_probp, obsp, xp = main2gvf.get()

        # update weights
        for gs, xi, xpi in zip(gvfs, x, xp):
            for g in gs:
                g.update(action, action_prob, obs, obsp, xi, xpi, ude, rupee)

        # send data to plots
        gdata = [[g.data(xi, obs, action, xpi, obsp)
                  for g in gs]
                 for gs, xi, xpi in zip(gvfs, x, xp)]
        data = dict(ChainMap(*chain.from_iterable(gdata)))
        data['obs'] = obs
        gvf2plot.put(data)

        # go to next state
        obs = obsp
        x = xp
        action = actionp
        action_prob = action_probp

    print('Done learning!')
Example #3
class RecordIndexService(DataShardService):
    def __init__(
        self,
        master_client,
        batch_size,
        num_epochs=None,
        dataset_size=None,
        task_type=elasticai_api_pb2.TRAINING,
        shuffle=False,
    ):
        super(RecordIndexService, self).__init__(
            master_client=master_client,
            batch_size=batch_size,
            num_epochs=num_epochs,
            dataset_size=dataset_size,
            shuffle=shuffle,
            task_type=task_type,
        )
        self._shard_queue = SimpleQueue()
        threading.Thread(
            target=self._get_shard_indices,
            name="fetch_shard_indices",
            daemon=True,
        ).start()

    def _get_shard_indices(self):
        while True:
            if self._shard_queue.empty():
                task = self.get_task(self._task_type)
                if not task.shard or task.type != self._task_type:
                    break
                ids = (task.shard.indices if task.shard.indices else list(
                    range(task.shard.start, task.shard.end)))
                for i in ids:
                    self._shard_queue.put(i)
            else:
                time.sleep(1)

    def fetch_record_index(self):
        """Fetch an index of the record. The function get an index
        from a queue because there may be multiple sub-process to call
        the function.
        """
        for _ in range(30):
            if not self._shard_queue.empty():
                return self._shard_queue.get()
            else:
                time.sleep(1)
        raise StopIteration
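
A minimal consumer sketch for fetch_record_index; the generator below is hypothetical and relies only on the StopIteration convention used above.

def iter_record_indices(shard_service):
    # Hypothetical helper: yield indices until the service gives up
    # (fetch_record_index raises StopIteration after ~30s without data).
    while True:
        try:
            yield shard_service.fetch_record_index()
        except StopIteration:
            return

# Usage sketch (service and dataset are assumed names):
# for idx in iter_record_indices(service):
#     record = dataset[idx]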
Example #4
class Network:

	def __init__(self):
		self.cmd_queue = SimpleQueue()
		self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

		print(colored('[STCK]', 'grey'), colored('Network: Initing.', 'white'))

	def connect(self, server_ip, server_port=8200):
		self.sock.connect((server_ip, server_port))

		print(colored('[STCK]', 'grey'), colored('Network: Connecting.', 'white'))

	def send(self, cmd):
		self.sock.send(('%s\n' % cmd).encode())
		#print('Network: Send %s' % cmd)

	def receive(self):
		recvData = self.sock.recv(1024).decode("utf8").split('\n')
		for i in range(len(recvData)):
			self.cmd_queue.put(recvData[i])

	def nextCmd(self):
		return self.cmd_queue.get()

	def hasCmd(self):
		return not self.cmd_queue.empty()

	def disconnect(self):
		print("Network: Closed")
		self.sock.close()
Example #5
def test_tracer_usage_multiprocess():
    q = MPQueue()

    # Similar to test_multiprocess(), ensures that no collisions are
    # generated between parent and child processes while using
    # multiprocessing.

    # Note that we have to be wary of the size of the underlying
    # pipe in the queue: https://bugs.python.org/msg143081

    def target(q):
        ids_list = list(
            chain.from_iterable((s.span_id, s.trace_id) for s in [tracer.start_span("s") for _ in range(10)])
        )
        q.put(ids_list)

    ps = [mp.Process(target=target, args=(q,)) for _ in range(30)]
    for p in ps:
        p.start()

    for p in ps:
        p.join()

    ids_list = list(chain.from_iterable((s.span_id, s.trace_id) for s in [tracer.start_span("s") for _ in range(100)]))
    ids = set(ids_list)
    assert len(ids) == len(ids_list), "Collisions found in ids"

    while not q.empty():
        child_ids_list = q.get()
        child_ids = set(child_ids_list)

        assert len(child_ids) == len(child_ids_list), "Collisions found in subprocess ids"

        assert ids & child_ids == set()
        ids = ids | child_ids  # accumulate the ids
Example #6
    def processDatabaseUpdate(
            databaseFileName: str,
            databaseUpdateQueue: multiprocessing.SimpleQueue) -> None:
        # Open database
        database = Database(databaseFileName)

        # Process updates
        while not databaseUpdateQueue.empty():
            # Get a task
            task = databaseUpdateQueue.get()
            taskName = task[0]

            # Update modification time
            if taskName == 'UpdateMtime':
                filePath = task[1]

                # Get file info
                fileInfo = database.getFile(filePath)
                # Update mtime
                fileInfo.stats['mtime'] = os.stat(filePath).st_mtime
                # Update database record
                database.setFile(fileInfo)

            # Remove a blob
            if taskName == 'RemoveBlob':
                blobId = task[1]

                # Remove blob
                database.removeBlob(blobId)

        # Close database
        database.commit()
        database.close()
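
Tasks arrive as plain tuples whose first element names the operation. A hedged producer sketch (database filename, path, and blob id are placeholders); note that the worker above drains until the queue is empty, so the queue should be filled before the worker runs.

from multiprocessing import SimpleQueue

updateQueue = SimpleQueue()
# The two task shapes handled by processDatabaseUpdate:
updateQueue.put(('UpdateMtime', '/path/to/file'))  # placeholder path
updateQueue.put(('RemoveBlob', 42))                # placeholder blob id
# processDatabaseUpdate('backup.db', updateQueue)  # placeholder db name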
Example #7
    def __init__(self, queue: SimpleQueue, address: str, port: int,
                 entWin: Window):
        """Initialize a level generator

        conn: Used to communicate with the Chase
        address: Connection address to start a new mcpi connection
        port: Connection port to start a new mcpi connection
        entrance: The entrance window to begin level with
        """
        self.queue = queue
        self.mc = mmc.Minecraft.create(address, port)
        self.entWin = entWin
        self._construct()
        self.players = []
        while True:
            while not queue.empty():
                rec: Tuple[Cmd, List] = queue.get()  # New msg
                if rec[0] == Cmd.TERM:
                    self._cleanup()
                    return
                elif rec[0] == Cmd.ENT:
                    self.players.extend(rec[1])
                elif rec[0] == Cmd.EXI:
                    for i in rec[1]:
                        try:
                            self.players.remove(i)
                        except ValueError:
                            sys.stderr.write(
                                f"Player(id) {i} not found in {self}!")
            self._loop()
Example #8
def _record_loop(q: SimpleQueue, filename, monitor, frame_rate):
    with mss() as sct:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        # adjust monitor to crop out the parts not visible
        if monitor['left'] < 0:
            monitor['width'] += monitor['left']
            monitor['left'] = 0
        if monitor['top'] < 0:
            monitor['height'] += monitor['top']
            monitor['top'] = 0
        monitor['height'] = min(monitor['height'],
                                sct.monitors[0]['height'] - monitor['top'])
        monitor['width'] = min(monitor['width'],
                               sct.monitors[0]['width'] - monitor['left'])
        out = cv2.VideoWriter(filename, fourcc, frame_rate,
                              (monitor['width'], monitor['height']))
        period = 1. / frame_rate
        while q.empty():
            start_time = time.time()

            img = np.array(sct.grab(monitor))
            out.write(img[:, :, :3])

            # wait for frame rate time
            elapsed = time.time() - start_time
            if elapsed < period:
                time.sleep(period - elapsed)
        out.release()
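
_record_loop keeps grabbing frames while q.empty() is true, so the parent ends the recording by putting any object on the queue. A hedged driver sketch (monitor region and filename are placeholders):

import time
from multiprocessing import Process, SimpleQueue

if __name__ == '__main__':
    q = SimpleQueue()
    monitor = {'left': 0, 'top': 0, 'width': 1280, 'height': 720}  # placeholder
    p = Process(target=_record_loop, args=(q, 'capture.mp4', monitor, 30))
    p.start()
    time.sleep(5)   # record for five seconds
    q.put(None)     # any object stops the loop: it only checks q.empty()
    p.join()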
Example #9
def importData(simulator):
    testSim = simulator
    q = SimpleQueue()
    jobs = []
    PERIOD_SIZE = 50
    BATCH_SIZE = 100
    BATCH_COUNT = 1
    BATCH_OFFSET = 100
    dates = testSim.getAllDates()
    index = list(range(BATCH_COUNT))
    feed = []
    threads = 16

    count = 0
    while True:
        if count < threads:
            for i in random.sample(index, threads - count if len(index) >= threads - count else len(index)):
                p = Process(target=testSim.processTimePeriod,
                            args=(q, PERIOD_SIZE, dates,
                                  BATCH_SIZE * (i + BATCH_OFFSET) + PERIOD_SIZE,
                                  BATCH_SIZE))
                jobs.append(p)
                p.start()
                index.remove(i)
        # rebuild the job list instead of calling jobs.remove() while
        # iterating over it, which skips elements
        count = 0
        alive = []
        for p in jobs:
            if not p.is_alive():
                p.terminate()
            else:
                alive.append(p)
                count += 1
        jobs = alive
        while not q.empty():
            print('Getting')
            feed.append(q.get())
        if count == 0 and len(index) == 0: 
            break
    return feed
Example #10
def test_multiprocess():
    q = MPQueue()

    def target(q):
        assert sum((_ is _rand.seed for _ in forksafe._registry)) == 1
        q.put([_rand.rand64bits() for _ in range(100)])

    ps = [mp.Process(target=target, args=(q,)) for _ in range(30)]
    for p in ps:
        p.start()

    for p in ps:
        p.join()
        assert p.exitcode == 0

    ids_list = [_rand.rand64bits() for _ in range(1000)]
    ids = set(ids_list)
    assert len(ids_list) == len(ids), "Collisions found in ids"

    while not q.empty():
        child_ids_list = q.get()
        child_ids = set(child_ids_list)

        assert len(child_ids_list) == len(child_ids), "Collisions found in subprocess ids"

        assert ids & child_ids == set()
        ids = ids | child_ids  # accumulate the ids
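
Both multiprocessing tests drain the queue only after every child has been joined; while producers are still alive, empty() can race with in-flight puts. A small helper capturing that pattern (the name is illustrative):

def drain(q):
    # Only safe once all producers have exited (e.g. after join()),
    # because q.empty() says nothing about puts still in flight.
    items = []
    while not q.empty():
        items.append(q.get())
    return items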
Example #11
def init_worker(status_queue: multiprocessing.SimpleQueue,
                param_queue: multiprocessing.SimpleQueue,
                result_queue: multiprocessing.SimpleQueue) -> None:
    global result
    global coverage_run
    global py_hash_secret
    global py_random_seed

    # Make sure the generator is re-seeded, as we have inherited
    # the seed from the parent process.
    py_random_seed = random.SystemRandom().randbytes(8)
    random.seed(py_random_seed)

    result = ChannelingTestResult(result_queue)
    if not param_queue.empty():
        server_addr, backend_dsn = param_queue.get()

        if server_addr is not None:
            os.environ['EDGEDB_TEST_CLUSTER_ADDR'] = json.dumps(server_addr)
        if backend_dsn:
            os.environ['EDGEDB_TEST_BACKEND_DSN'] = backend_dsn

    os.environ['EDGEDB_TEST_PARALLEL'] = '1'
    coverage_run = devmode.CoverageConfig.start_coverage_if_requested()
    py_hash_secret = cpython_state.get_py_hash_secret()
    status_queue.put(True)
Example #12
File: asset.py Project: mxie91/avocado
    def run(self, runnable):
        # pylint: disable=W0201
        self.runnable = runnable
        yield self.prepare_status("started")

        name = self.runnable.kwargs.get("name")
        # if name was passed correctly, run the Avocado Asset utility
        if name is not None:
            asset_hash = self.runnable.kwargs.get("asset_hash")
            algorithm = self.runnable.kwargs.get("algorithm")
            locations = self.runnable.kwargs.get("locations")
            expire = self.runnable.kwargs.get("expire")
            if expire is not None:
                expire = data_structures.time_to_seconds(str(expire))

            cache_dirs = self.runnable.config.get("datadir.paths.cache_dirs")
            if cache_dirs is None:
                cache_dirs = settings.as_dict().get("datadir.paths.cache_dirs")

            # let's spawn it to another process to be able to update the
            # status messages and to keep the Asset fetch from blocking
            # this process
            queue = SimpleQueue()
            process = Process(
                target=self._fetch_asset,
                args=(
                    name,
                    asset_hash,
                    algorithm,
                    locations,
                    cache_dirs,
                    expire,
                    queue,
                ),
            )
            process.start()

            while queue.empty():
                time.sleep(RUNNER_RUN_STATUS_INTERVAL)
                yield self.prepare_status("running")

            output = queue.get()
            result = output["result"]
            stdout = output["stdout"]
            stderr = output["stderr"]
        else:
            # Otherwise, log the missing package name
            result = "error"
            stdout = ""
            stderr = 'At least name should be passed as kwargs using name="uri".'

        yield self.prepare_status("running", {
            "type": "stdout",
            "log": stdout.encode()
        })
        yield self.prepare_status("running", {
            "type": "stderr",
            "log": stderr.encode()
        })
        yield self.prepare_status("finished", {"result": result})
Example #13
class AudioProcessingThread(threading.Thread):
    """
    Chirp Connect audio processing thread
    """
    DEBUG_AUDIO_FILENAME = 'chirp_audio.wav'

    def __init__(self, parent=None, *args, **kwargs):
        """
        Initialise audio processing.
        In debug mode, the audio data is saved to file.
        """
        self.sdk = parent.sdk
        self.sample_size = parent.sample_size
        self.block_size = parent.block_size
        self.sample_format = parent.sample_format
        self.process_input_fn = parent.process_input_fn
        self.sample_rate = float(parent.sdk.sample_rate)

        self.block_period = self.block_size / self.sample_rate or 0.1
        self.wav_filename = parent.wav_filename or self.DEBUG_AUDIO_FILENAME
        self.input_queue = SimpleQueue()
        super(AudioProcessingThread, self).__init__(*args, **kwargs)

        if self.sdk.debug:
            import soundfile as sf
            self.wav_file = sf.SoundFile(self.wav_filename,
                                         mode='w',
                                         channels=1,
                                         samplerate=self.sdk.sample_rate)

        self.daemon = True
        self.start()

    def run(self):
        """
        Continuously process any input data from circular buffer.
        Note: We need to sleep as much as possible in this thread
        to restrict CPU usage.
        """
        while self.is_alive():

            tstart = time.time()
            while not self.input_queue.empty():
                data = self.input_queue.get()
                self.process_input_fn(data)
                if self.sdk.debug and not self.wav_file.closed:
                    self.wav_file.buffer_write(data, dtype=self.sample_size)
                self.block_period = self.block_size / self.sample_rate

            tsleep = self.block_period - (time.time() - tstart)
            if tsleep > 0:
                time.sleep(tsleep)

    def stop(self):
        """ In debug mode, close wav file """
        if self.sdk.debug:
            self.wav_file.close()
Example #14
def find_optimal_temperature(config):

    model_paths = glob.glob(config["model_glob"])

    dwi_path_1 = config["inference"]["dwi_path"].format("")
    dwi_path_2 = config["inference"]["dwi_path"].format("retest")

    gpu_queue = SimpleQueue()
    for idx in get_gpus():
        gpu_queue.put(str(idx))

    procs = []
    pred_manager = Manager()
    predictions = pred_manager.dict()
    try:
        for mp in model_paths:

            #if any(t in mp for t in []):

            model_config = config["inference"].copy()
            model_config["model_path"] = mp

            for j in [0, 1]:
                run_config = model_config.copy()
                parse(run_config, "dwi_path", j)
                parse(run_config, "prior_path", j)
                parse(run_config, "term_path", j)
                parse(run_config, "seed_path", j)
                while gpu_queue.empty():
                    sleep(10)

                p = Process(target=run_inference,
                            args=(run_config, gpu_queue, predictions))
                p.start()
                procs.append(p)
                print("Launched {}: {}".format(mp.split("/")[-1], j))
                sleep(10)

    except KeyboardInterrupt:
        pass
    finally:
        for p in procs:
            p.join()
            while p.exitcode is None:
                sleep(0.1)

    pred_pairs = group_by_model(predictions)
    config["pred_pairs"] = pred_pairs

    save(config,
         name="opT_{}.yml".format(timestamp()),
         out_dir=os.path.dirname(config["model_glob"]))
    """
Example #15
async def data_from_file(main2gvf: mp.SimpleQueue,
                         gvf2plot: mp.SimpleQueue,
                         coder: KanervaCoder):
    data = np.load('offline_data.npy')

    for item in data:
        item[-1] = coder(item[-2])
        main2gvf.put(item)

    time.sleep(0.1)
    while not gvf2plot.empty():
        time.sleep(0.1)
Example #16
class AutohostFactory:

	def __init__(self):
		self.network = Network()
		self.idlehosts = SimpleQueue()
		self.count = 0
		self._load_autohosts()

	def new_autohost(self):

		print(colored('[INFO]', 'green'), colored('AFAC: Initing.', 'white'))
		if self.idlehosts.empty():
			username = self._new_autohost()
			print(colored('[INFO]', 'green'), colored('AFAC: Registering ' + username + '.', 'white'))
			return username
		else:
			username=self.idlehosts.get()

			print(colored('[INFO]', 'green'), colored('AFAC: Returning spare username:'******'.', 'white'))
			return username
		

	def free_autohost(self, username):
		self.idlehosts.put(username)

		print(colored('[INFO]', 'green'), colored('AFAC: Returning ' + username + ' to the idle pool.', 'white'))

	def _new_autohost(self):
		self.network.connect('127.0.0.1')
		username = "******" % self.count
		password = b64encode(md5(b'password').digest()).decode('utf8')
		self.network.send("REGISTER %s %s" % (username, password)) # TODO: Check for errors
		time.sleep(1)
		self.network.send("CONFIRMAGREEMENT")
		self.network.receive()
		while self.network.hasCmd():
			print(self.network.nextCmd())
		self._save_autohost(username)
		self.network.disconnect()
		self.count += 1
		return username

	def _load_autohosts(self):
		with open('autohosts.txt', 'r') as file:
			usernames = file.read().split('\n')
			for username in usernames:
				if username != '':
					print("added: %s" % username)
					self.idlehosts.put(username)
					self.count += 1

	def _save_autohost(self, username):
		with open('autohosts.txt', 'a') as file:
			file.write('%s\n' % username)
Example #17
def init_worker(param_queue: multiprocessing.SimpleQueue,
                result_queue: multiprocessing.SimpleQueue) -> None:
    global result

    # Make sure the generator is re-seeded, as we have inherited
    # the seed from the parent process.
    random.seed()

    result = ChannelingTestResult(result_queue)
    if not param_queue.empty():
        server_addr = param_queue.get()

        if server_addr is not None:
            os.environ['EDGEDB_TEST_CLUSTER_ADDR'] = json.dumps(server_addr)
Example #18
def handle_sub_prepare(css, sids, sub_flags):
    print("***Config updated --> Prepare")
    processes = []
    q = SimpleQueue()
    for sid in sids:
        flags = cdb.GET_MODS_INCLUDE_LISTS | cdb.GET_MODS_SUPPRESS_DEFAULTS
        mods = cdb.get_modifications(css, sid, flags, ROOT_PATH)
        if mods == []:
            print("no modifications for subid {}".format(sid))
        else:
            print("edit-config and validate for subid {}".format(sid))
            p = Process(target=process_modifications, args=(mods, sid, q))
            processes.append(p)
            p.start()
    for process in processes:
        process.join()
    if not q.empty():
        # multiprocessing.SimpleQueue has no .queue attribute, so drain
        # the error messages one by one
        errs = []
        while not q.empty():
            errs.append(q.get())
        cdb.sub_abort_trans(css, confd.ERRCODE_APPLICATION, 0, 0,
                            ', '.join(errs))
        # TODO USE lxml to get the error string from the NETCONF error reply
        # TODO sub_abort_trans_info(...)
        return
    processes = []
    for sid in sids:
        print("commit for subid {}".format(sid))
        p = Process(target=commit, args=(sid, q))
        processes.append(p)
        p.start()
    for process in processes:
        process.join()
    if not q.empty():
        # drain the error messages one by one, as above
        errs = []
        while not q.empty():
            errs.append(q.get())
        cdb.sub_abort_trans(css, confd.ERRCODE_APPLICATION, 0, 0,
                            ', '.join(errs))
        # TODO USE lxml to get the error string from the NETCONF error reply
        # TODO sub_abort_trans_info(...)
        return
    cdb.sync_subscription_socket(css, cdb.DONE_PRIORITY)
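
The q.empty() checks above rely on a convention: worker functions put an error string on q on failure and nothing on success. A hypothetical worker shape (validate_modifications is an assumed helper, not part of the source):

def process_modifications_sketch(mods, sid, q):
    try:
        validate_modifications(mods, sid)  # assumed helper
    except Exception as exc:
        q.put("subid {}: {}".format(sid, exc))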
Example #19
    def run(self, runnable):
        # pylint: disable=W0201
        self.runnable = runnable
        yield self.prepare_status('started')

        name = self.runnable.kwargs.get('name')
        # if name was passed correctly, run the Avocado Asset utility
        if name is not None:
            asset_hash = self.runnable.kwargs.get('asset_hash')
            algorithm = self.runnable.kwargs.get('algorithm')
            locations = self.runnable.kwargs.get('locations')
            expire = self.runnable.kwargs.get('expire')
            if expire is not None:
                expire = data_structures.time_to_seconds(str(expire))

            cache_dirs = self.runnable.config.get('datadir.paths.cache_dirs')
            if cache_dirs is None:
                cache_dirs = settings.as_dict().get('datadir.paths.cache_dirs')

            # let's spawn it to another process to be able to update the
            # status messages and to keep the Asset fetch from blocking
            # this process
            queue = SimpleQueue()
            process = Process(target=self._fetch_asset,
                              args=(name, asset_hash, algorithm, locations,
                                    cache_dirs, expire, queue))
            process.start()

            while queue.empty():
                time.sleep(RUNNER_RUN_STATUS_INTERVAL)
                yield self.prepare_status('running')

            output = queue.get()
            result = output['result']
            stdout = output['stdout']
            stderr = output['stderr']
        else:
            # Otherwise, log the missing package name
            result = 'error'
            stdout = ''
            stderr = ('At least name should be passed as kwargs using'
                      ' name="uri".')

        yield self.prepare_status('running',
                                  {'type': 'stdout',
                                   'log': stdout.encode()})
        yield self.prepare_status('running',
                                  {'type': 'stderr',
                                   'log': stderr.encode()})
        yield self.prepare_status('finished', {'result': result})
Example #20
    def mapping(fun, args_list, processes):
        ans = [None] * len(args_list)
        q = SimpleQueue()
        for batch_start in range(0, len(args_list), processes):
            ps = []
            for i in range(batch_start,
                           min(batch_start + processes, len(args_list))):
                p = Process(target=MultiprocessingWithoutPipe.work,
                            args=(fun, i, q, args_list[i]))
                p.start()
                ps.append(p)

            while not q.empty():
                num, ret = q.get()
                ans[num] = ret

            for p in ps:
                p.join()

        while not q.empty():
            num, ret = q.get()
            ans[num] = ret

        return ans
Example #21
class SafeQueue(object):
    __thread_pool = SingletonThreadPool()

    def __init__(self, *args, **kwargs):
        self._q = SimpleQueue(*args, **kwargs)

    def empty(self):
        return self._q.empty()

    def get(self):
        return self._q.get()

    def put(self, obj):
        # make sure the block put is done in the thread pool i.e. in the background
        SafeQueue.__thread_pool.get().apply_async(self._q.put, args=(obj, ))
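
The design choice here: put() hands the potentially blocking pipe write to a shared thread pool, so the caller returns immediately. A minimal self-contained rendition of the same idea (SingletonThreadPool comes from the surrounding project, so a plain ThreadPool stands in for it):

from multiprocessing import SimpleQueue
from multiprocessing.pool import ThreadPool

_pool = ThreadPool(processes=1)
q = SimpleQueue()
_pool.apply_async(q.put, args=('payload',))  # returns immediately
print(q.get())                               # blocks until the write lands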
Example #22
def handle_sub_commit(css, sids, sub_flags):
    print("***Config updated --> Commit")
    processes = []
    q = SimpleQueue()
    for sid in sids:
        print("confirm-commit for subid {}".format(sid))
        p = Process(target=confirm_commit, args=(sid, q))
        processes.append(p)
        p.start()
    for process in processes:
        process.join()
    if not q.empty():
        # multiprocessing.SimpleQueue has no .queue attribute, so drain
        # the error messages one by one
        errs = []
        while not q.empty():
            errs.append(q.get())
        sys.stderr.write("Confirm commit failed:\n{}\n".format('\n'.join(errs)))
    cdb.sync_subscription_socket(css, cdb.DONE_PRIORITY)
Example #23
    def test_meta(self):
        # Test feature two.
        cry = Crypto('pbkdf2', self.conf, self.metadata)
        ioc = OrderedDict()
        ioc['ip-dst'] = "192.168.0.0"
        ioc['url'] = "test.com"
        rule = cry.create_rule(ioc, "Hello, this is the message!")
        rule['salt'] = b64decode(rule['salt'])
        rule['nonce'] = b64decode(rule['nonce'])
        rule['attributes'] = rule['attributes'].split('||')
        rule['ciphertext-check'] = b64decode(rule['ciphertext-check'])
        rule['ciphertext'] = b64decode(rule['ciphertext'])
        queue = SimpleQueue()
        cry.match(ioc, rule, queue)
        self.assertTrue(not queue.empty())
Example #24
def learning_loop(exit_flag: mp.Value,
                  gvfs: Sequence[GTDLearner],
                  main2gvf: mp.SimpleQueue,
                  gvf2plot: mp.SimpleQueue):
    action, action_prob, obs, x = None, None, None, None

    # get first state
    while exit_flag.value == 0 and obs is None:
        while exit_flag.value == 0 and main2gvf.empty():
            time.sleep(0.001)
        if exit_flag.value == 0:
            action, action_prob, obs, x = main2gvf.get()

    # main loop
    while exit_flag.value == 0:
        while exit_flag.value == 0 and main2gvf.empty():
            time.sleep(0.001)
        if exit_flag.value:
            break

        # get data from servos
        actionp, action_probp, obsp, xp = main2gvf.get()

        # update weights
        for g in gvfs:
            g.update(action, action_prob, obs, obsp, x, xp)

        # send data to plots
        data = [[obs]] + [g.data(x, obs, action, xp, obsp) for g in gvfs]
        gvf2plot.put(data)

        # go to next state
        obs = obsp
        x = xp
        action = actionp
        action_prob = action_probp
Example #25
def plotting_loop(exit_flag: mp.Value,
                  gvf2plot: mp.SimpleQueue,
                  plots: Sequence[Plot]):
    while exit_flag.value == 0:
        if locks:
            print('plot gp a 1 a')
            gplock.acquire()
            print('plot gp a 1 b')
        while exit_flag.value == 0 and gvf2plot.empty():
            if locks:
                print('plot gp r 1 a')
                gplock.release()
                print('plot gp r 1 b')
            time.sleep(0.001)
            if locks:
                print('plot gp a 2 a')
                gplock.acquire()
                print('plot gp a 2 b')

        if locks:
            print('plot gp r 2 a')
            gplock.release()
            print('plot gp r 2 b')
        if exit_flag.value:
            break

        if locks:
            print('plot gp a 3 a')
            gplock.acquire()
            print('plot gp a 3 b')
        d = gvf2plot.get()
        if locks:
            print('plot gp r 3 a')
            gplock.release()
            print('plot gp r 3 b')

        for plot, data in zip(plots, d):
            plot.update(data)

    for plot in plots:
        try:
            index = np.arange(len(plot.y[0]))
            np.savetxt(f"{plot.title}.csv",
                       np.column_stack(sum(((np.asarray(y),) for y in plot.y),
                                           (index,))),
                       delimiter=',')
        except ValueError:
            continue
Example #26
def multiprocess() -> Dict[str, Tuple[int, float]]:
    from multiprocessing import Process, SimpleQueue
    queue = SimpleQueue()
    procs: List[Process] = []
    for url in sites:
        p = Process(target=lambda url, q: q.put(fetch_site(url)),
                    args=(url, queue))
        p.start()
        procs.append(p)
    for p in procs:
        p.join()
    result = {}
    while not queue.empty():
        url, size, duration = queue.get()
        result[url] = size, duration
    return result
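
One caveat: the lambda target only works under the fork start method; under spawn (the default on Windows and macOS) Process targets must be picklable, i.e. module-level functions. A spawn-safe variant of the same fan-out (fetch_site and sites are the assumed names from the example above):

from multiprocessing import Process, SimpleQueue
from typing import Dict, List, Tuple

def _fetch_to_queue(url, q):
    # Module-level target, so it is picklable under the spawn start method.
    q.put(fetch_site(url))

def multiprocess_spawn_safe() -> Dict[str, Tuple[int, float]]:
    queue = SimpleQueue()
    procs: List[Process] = [Process(target=_fetch_to_queue, args=(url, queue))
                            for url in sites]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    result = {}
    while not queue.empty():
        url, size, duration = queue.get()
        result[url] = size, duration
    return result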
Example #27
class NodeBFS:
    def __init__(self):
        self.node_queue = SimpleQueue()

    def __call__(self, visitor, node, final_suffix_offset=0):
        self.final_suffix_offset = final_suffix_offset
        self.node_queue.put(node)
        self.bfs(visitor, node)

    def bfs(self, visitor, node):
        while not self.node_queue.empty():
            node = self.node_queue.get()
            visitor.visit(node, self.final_suffix_offset)
            if node.children != None:
                for child in node.children:
                    self.node_queue.put(node.children[child])
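
A hedged usage sketch for NodeBFS; the node and visitor classes below are hypothetical stand-ins providing only what the traversal touches (a children dict or None, and visit(node, offset)):

class _Node:
    def __init__(self, label, children=None):
        self.label = label
        self.children = children  # dict of edge -> _Node, or None for a leaf

class _PrintVisitor:
    def visit(self, node, final_suffix_offset):
        print(node.label, final_suffix_offset)

root = _Node('root', {'a': _Node('left'), 'b': _Node('right')})
NodeBFS()(_PrintVisitor(), root)  # visits root, then left, then right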
Example #28
File: main.py Project: repen/92930809
def _main():
    q = SimpleQueue()
    data = get_data()
    processes = []

    for e, func in enumerate(functions, start=1):
        processes.append(Process(target=wrapper_run, args=(q, func, data, e)))

    for x in processes:
        x.start()
    for x in processes:
        x.join()
    results = []
    while True:
        items = q.get()
        results.append(items)
        if q.empty():
            break
    save(results)
Example #29
    def run(self, runnable):
        yield messages.StartedMessage.get()

        if not runnable.uri:
            reason = "uri identifying the podman image is required"
            yield messages.FinishedMessage.get("error", reason)
        else:
            queue = SimpleQueue()
            process = Process(target=self._run_podman_pull, args=(runnable.uri, queue))
            process.start()
            while queue.empty():
                time.sleep(RUNNER_RUN_STATUS_INTERVAL)
                yield messages.RunningMessage.get()

            output = queue.get()
            result = output.pop("result")
            yield messages.FinishedMessage.get(result, **output)
Example #30
    def run(self, runnable):
        yield messages.StartedMessage.get()

        if not runnable.uri:
            reason = "uri identifying the ansible module is required"
            yield messages.FinishedMessage.get("error", reason)
            return

        queue = SimpleQueue()
        process = Process(target=self._run_ansible_module,
                          args=(runnable, queue))
        process.start()
        yield from self.running_loop(lambda: not queue.empty())

        status = queue.get()
        yield messages.StdoutMessage.get(status["stdout"])
        yield messages.StderrMessage.get(status["stderr"])
        yield messages.FinishedMessage.get(status["result"])
Example #31
File: package.py Project: mxie91/avocado
    def run(self, runnable):
        # pylint: disable=W0201
        self.runnable = runnable
        yield messages.StartedMessage.get()
        # check if there is a valid 'action' argument
        cmd = self.runnable.kwargs.get("action", "install")
        # avoid invalid arguments
        if cmd not in ["install", "check", "remove"]:
            stderr = (f"Invalid action {cmd}. Use one of 'install', 'check' "
                      f"or 'remove'")
            yield messages.StderrMessage.get(stderr.encode())
            yield messages.FinishedMessage.get("error")
            return

        package = self.runnable.kwargs.get("name")
        # if package was passed correctly, run avocado-software-manager
        if package is not None:
            # let's spawn it to another process to be able to update the
            # status messages and to keep the software-manager from
            # blocking this process
            queue = SimpleQueue()
            process = Process(target=self._run_software_manager,
                              args=(cmd, package, queue))
            process.start()

            while queue.empty():
                time.sleep(RUNNER_RUN_STATUS_INTERVAL)
                yield messages.RunningMessage.get()

            output = queue.get()
            result = output["result"]
            stdout = output["stdout"]
            stderr = output["stderr"]
        else:
            # Otherwise, log the missing package name
            result = "error"
            stdout = ""
            stderr = (
                'Package name should be passed as kwargs using name="package_name".'
            )

        yield messages.StdoutMessage.get(stdout.encode())
        yield messages.StderrMessage.get(stderr.encode())
        yield messages.FinishedMessage.get(result)
Example #32
def plotting_loop(exit_flag: mp.Value,
                  gvf2plot: mp.SimpleQueue,
                  plots: Sequence[Plot]):

    while exit_flag.value == 0:
        while exit_flag.value == 0 and gvf2plot.empty():
            time.sleep(0.001)
        if exit_flag.value:
            break
        data = gvf2plot.get()

        for plot in plots:
            plot.update(data)

    for plot in plots:
        index = np.arange(len(plot.y[0]))
        np.savetxt(f"{plot.title}.csv",
                   sum(((np.asarray(y),) for y in plot.y), (index,)),
                   delimiter=',')
Example #33
def spawn_import_clients(options, files_info):
    # Spawn one reader process for each db.table, as well as many client processes
    task_queue = SimpleQueue()
    error_queue = SimpleQueue()
    exit_event = multiprocessing.Event()
    interrupt_event = multiprocessing.Event()
    errors = []
    reader_procs = []
    client_procs = []

    parent_pid = os.getpid()
    signal.signal(signal.SIGINT, lambda a, b: abort_import(a, b, parent_pid, exit_event, task_queue, client_procs, interrupt_event))

    try:
        progress_info = []
        rows_written = multiprocessing.Value(ctypes.c_longlong, 0)

        for i in xrange(options["clients"]):
            client_procs.append(multiprocessing.Process(target=client_process,
                                                        args=(options["host"],
                                                              options["port"],
                                                              options["auth_key"],
                                                              task_queue,
                                                              error_queue,
                                                              rows_written,
                                                              options["force"],
                                                              options["durability"])))
            client_procs[-1].start()

        for file_info in files_info:
            progress_info.append((multiprocessing.Value(ctypes.c_longlong, -1), # Current lines/bytes processed
                                  multiprocessing.Value(ctypes.c_longlong, 0))) # Total lines/bytes to process
            reader_procs.append(multiprocessing.Process(target=table_reader,
                                                        args=(options,
                                                              file_info,
                                                              task_queue,
                                                              error_queue,
                                                              progress_info[-1],
                                                              exit_event)))
            reader_procs[-1].start()

        # Wait for all reader processes to finish - hooray, polling
        while len(reader_procs) > 0:
            time.sleep(0.1)
            # If an error has occurred, exit out early
            while not error_queue.empty():
                exit_event.set()
                errors.append(error_queue.get())
            reader_procs = [proc for proc in reader_procs if proc.is_alive()]
            update_progress(progress_info)

        # Wait for all clients to finish
        alive_clients = sum([client.is_alive() for client in client_procs])
        for i in xrange(alive_clients):
            task_queue.put(StopIteration())

        while len(client_procs) > 0:
            time.sleep(0.1)
            client_procs = [client for client in client_procs if client.is_alive()]

        # If we were successful, make sure 100% progress is reported
        if len(errors) == 0 and not interrupt_event.is_set():
            print_progress(1.0)

        def plural(num, text):
            return "%d %s%s" % (num, text, "" if num == 1 else "s")

        # Continue past the progress output line
        print("")
        print("%s imported in %s" % (plural(rows_written.value, "row"),
                                     plural(len(files_info), "table")))
    finally:
        signal.signal(signal.SIGINT, signal.SIG_DFL)

    if interrupt_event.is_set():
        raise RuntimeError("Interrupted")

    if len(errors) != 0:
        # multiprocessing queues don't handle tracebacks, so they've already been stringified in the queue
        for error in errors:
            print("%s" % error[1], file=sys.stderr)
            if options["debug"]:
                print("%s traceback: %s" % (error[0].__name__, error[2]), file=sys.stderr)
            if len(error) == 4:
                print("In file: %s" % error[3], file=sys.stderr)
        raise RuntimeError("Errors occurred during import")
Example #34
def scan_regionset(regionset, options):
    """ This function scans all te region files in a regionset object
    and fills the ScannedRegionFile obj with the results
    """

    total_regions = len(regionset.regions)
    total_chunks = 0
    corrupted_total = 0
    wrong_total = 0
    entities_total = 0
    too_small_total = 0
    unreadable = 0

    # init progress bar
    if not options.verbose:
        pbar = progressbar.ProgressBar(
            widgets=['Scanning: ', FractionWidget(), ' ', progressbar.Percentage(), ' ', progressbar.Bar(left='[',right=']'), ' ', progressbar.ETA()],
            maxval=total_regions)

    # queue used by processes to pass finished stuff
    q = SimpleQueue()
    pool = multiprocessing.Pool(processes=options.processes,
            initializer=_mp_pool_init,initargs=(regionset,options,q))

    if not options.verbose:
        pbar.start()

    # start the pool
    # Note to self: every child process has its own memory space,
    # which means every obj received by it will be a copy of the
    # main obj
    result = pool.map_async(multithread_scan_regionfile, regionset.list_regions(None), max(1,total_regions//options.processes))

    # printing status
    region_counter = 0

    while not result.ready() or not q.empty():
        time.sleep(0.01)
        if not q.empty():
            r = q.get()
            if r is None:
                # something went wrong scanning this region file,
                # probably a bug... don't know if it's a good idea
                # to skip it
                continue
            if not isinstance(r, world.ScannedRegionFile):
                raise ChildProcessException(r)
            else:
                corrupted, wrong, entities_prob, shared_offset, num_chunks = r.get_counters()
                filename = r.filename
                # the obj returned is a copy, overwrite it in regionset
                regionset[r.get_coords()] = r
                corrupted_total += corrupted
                wrong_total += wrong
                total_chunks += num_chunks
                entities_total += entities_prob
                if r.status == world.REGION_TOO_SMALL:
                    too_small_total += 1
                elif r.status == world.REGION_UNREADABLE:
                    unreadable += 1
                region_counter += 1
                if options.verbose:
                    if r.status == world.REGION_OK:
                        stats = "(c: {0}, w: {1}, tme: {2}, so: {3}, t: {4})".format(corrupted, wrong, entities_prob, shared_offset, num_chunks)
                    elif r.status == world.REGION_TOO_SMALL:
                        stats = "(Error: not a region file)"
                    elif r.status == world.REGION_UNREADABLE:
                        stats = "(Error: unreadable region file)"
                    print("Scanned {0: <12} {1:.<43} {2}/{3}".format(filename, stats, region_counter, total_regions))
                else:
                    pbar.update(region_counter)

    if not options.verbose:
        pbar.finish()

    regionset.scanned = True
Example #35
class AsyncScanner(object):
    """ Class to derive all the scanner classes from.

    To implement a scanner you have to override:
    update_str_last_scanned()
    Use try-finally to call terminate(); otherwise child processes
    will be left hanging in the background.
    """
    def __init__(self, data_structure, processes, scan_function, init_args,
                 _mp_init_function):
        """ Init the scanner.

        data_structure is a world.DataSet
        processes is the number of child processes to use
        scan_function is the function to use for scanning
        init_args are the arguments passed to the init function
        _mp_init_function is the function used to init the child processes
        """
        assert(isinstance(data_structure, world.DataSet))
        self.data_structure = data_structure
        self.list_files_to_scan = data_structure._get_list()
        self.processes = processes
        self.scan_function = scan_function

        # Queue used by processes to pass results
        self.queue = SimpleQueue()
        init_args.update({'queue': self.queue})
        # NOTE TO SELF: initargs doesn't handle kwargs, only args!
        # Pass a dict with all the args
        self.pool = multiprocessing.Pool(processes=processes,
                initializer=_mp_init_function,
                initargs=(init_args,))

        # Recommended time to sleep between polls for results
        self.SCAN_START_SLEEP_TIME = 0.001
        self.SCAN_MIN_SLEEP_TIME = 1e-6
        self.SCAN_MAX_SLEEP_TIME = 0.1
        self.scan_sleep_time = self.SCAN_START_SLEEP_TIME
        self.queries_without_results = 0
        self.last_time = time()
        self.MIN_QUERY_NUM = 1
        self.MAX_QUERY_NUM = 5

        # Holds a friendly string with the name of the last file scanned
        self._str_last_scanned = None

    def scan(self):
        """ Launch the child processes and scan all the files. """
        
        logging.debug("########################################################")
        logging.debug("########################################################")
        logging.debug("Starting scan in: " + str(self))
        logging.debug("########################################################")
        logging.debug("########################################################")
        # Tests indicate that smaller amount of jobs per worker make all type
        # of scans faster
        jobs_per_worker = 5
        #jobs_per_worker = max(1, total_files // self.processes
        self._results = self.pool.map_async(self.scan_function,
                                            self.list_files_to_scan,
                                            jobs_per_worker)
                                            
        # No more tasks to the pool, exit the processes once the tasks are done
        self.pool.close()

        # See method
        self._str_last_scanned = ""

    def get_last_result(self):
        """ Return results of last file scanned. """

        q = self.queue
        ds = self.data_structure
        if not q.empty():
            d = q.get()
            if isinstance(d, tuple):
                self.raise_child_exception(d)
            # Copy it to the father process
            ds._replace_in_data_structure(d)
            ds._update_counts(d)
            self.update_str_last_scanned(d)
            # Got result! Reset it!
            self.queries_without_results = 0
            return d
        else:
            # Count amount of queries without result
            self.queries_without_results += 1
            return None

    def terminate(self):
        """ Terminate the pool, this will exit no matter what.
        """
        self.pool.terminate()

    def raise_child_exception(self, exception_tuple):
        """ Raises a ChildProcessException using the info
        contained in the tuple returned by the child process. """
        e = exception_tuple
        raise ChildProcessException(e[0], e[1][0], e[1][1], e[1][2])

    def update_str_last_scanned(self):
        """ Updates the string that represents the last file scanned. """
        raise NotImplementedError

    def sleep(self):
        """ Sleep waiting for results.

        This method will sleep less when results arrive faster and
        more when they arrive slower.
        """
        # If the query number is outside of our range...
        if not (self.MIN_QUERY_NUM < self.queries_without_results <
                self.MAX_QUERY_NUM):
            # ... increase or decrease it to optimize queries
            if (self.queries_without_results < self.MIN_QUERY_NUM):
                self.scan_sleep_time *= 0.5
            elif (self.queries_without_results > self.MAX_QUERY_NUM):
                self.scan_sleep_time *= 2.0
            # and don't go farther than max/min
            if self.scan_sleep_time > self.SCAN_MAX_SLEEP_TIME:
                logging.debug("Setting sleep time to MAX")
                self.scan_sleep_time = self.SCAN_MAX_SLEEP_TIME
            elif self.scan_sleep_time < self.SCAN_MIN_SLEEP_TIME:
                logging.debug("Setting sleep time to MIN")
                self.scan_sleep_time = self.SCAN_MIN_SLEEP_TIME

        # Log how it's going
        logging.debug("")
        logging.debug("Nº of queries without result: " + str(self.queries_without_results))
        logging.debug("Current sleep time: " + str(self.scan_sleep_time))
        logging.debug("Time between calls to sleep(): " + str(time() - self.last_time))
        self.last_time = time()

        # Sleep, let the other processes do their job
        sleep(self.scan_sleep_time)

    @property
    def str_last_scanned(self):
        """ A friendly string with last scanned thing. """
        return self._str_last_scanned if self._str_last_scanned \
            else "Scanning..."

    @property
    def finished(self):
        """ Finished the operation. The queue could have elements """
        return self._results.ready() and self.queue.empty()

    @property
    def results(self):
        """ Yield all the results from the scan.

        This is the simpler method to control the scanning process,
        but also the most sloppy. If you want to closely control the
        scan process (for example cancel the process in the middle,
        whatever is happening) use get_last_result().

        for result in scanner.results:
            # do things
        """

        q = self.queue
        # SCAN_WAIT_TIME is not defined above, so reuse the start sleep
        # time as the poll interval
        T = self.SCAN_START_SLEEP_TIME
        while not q.empty() or not self.finished:
            sleep(T)
            if not q.empty():
                d = q.get()
                if isinstance(d, tuple):
                    self.raise_child_exception(d)
                # Overwrite it in the data dict
                self.data_structure._replace_in_data_structure(d)
                yield d

    def __len__(self):
        return len(self.data_structure)
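
Following the class docstring, a minimal driving loop for a concrete AsyncScanner subclass; drive() is illustrative, not part of the source:

def drive(scanner):
    scanner.scan()
    try:
        while not scanner.finished:
            result = scanner.get_last_result()
            if result is None:
                scanner.sleep()  # adaptive back-off between polls
    finally:
        # per the docstring: without terminate(), child processes
        # may be left hanging in the background
        scanner.terminate()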
Example #36
File: _export.py Project: AtnNn/rethinkdb
def run_clients(options, workingDir, db_table_set):
    # Spawn one client for each db.table, up to options.clients at a time
    exit_event = multiprocessing.Event()
    processes = []
    error_queue = SimpleQueue()
    interrupt_event = multiprocessing.Event()
    sindex_counter = multiprocessing.Value(ctypes.c_longlong, 0)
    hook_counter = multiprocessing.Value(ctypes.c_longlong, 0)
    
    signal.signal(signal.SIGINT, lambda a, b: abort_export(a, b, exit_event, interrupt_event))
    errors = []

    try:
        progress_info = []
        arg_lists = []
        for db, table in db_table_set:
            
            tableSize = int(options.retryQuery("count", query.db(db).table(table).info()['doc_count_estimates'].sum()))
            
            progress_info.append((multiprocessing.Value(ctypes.c_longlong, 0),
                                  multiprocessing.Value(ctypes.c_longlong, tableSize)))
            arg_lists.append((db, table,
                              workingDir,
                              options,
                              error_queue,
                              progress_info[-1],
                              sindex_counter,
                              hook_counter,
                              exit_event,
                              ))


        # Wait for all tables to finish
        while processes or arg_lists:
            time.sleep(0.1)

            while not error_queue.empty():
                exit_event.set() # Stop immediately if an error occurs
                errors.append(error_queue.get())

            processes = [process for process in processes if process.is_alive()]

            if len(processes) < options.clients and len(arg_lists) > 0:
                newProcess = multiprocessing.Process(target=export_table, args=arg_lists.pop(0))
                newProcess.start()
                processes.append(newProcess)

            update_progress(progress_info, options)

        # If we were successful, make sure 100% progress is reported
        # (rows could have been deleted which would result in being done at less than 100%)
        if len(errors) == 0 and not interrupt_event.is_set() and not options.quiet:
            utils_common.print_progress(1.0, indent=4)

        # Continue past the progress output line and print total rows processed
        def plural(num, text, plural_text):
            return "%d %s" % (num, text if num == 1 else plural_text)

        if not options.quiet:
            print("\n    %s exported from %s, with %s, and %s" %
                  (plural(sum([max(0, info[0].value) for info in progress_info]), "row", "rows"),
                   plural(len(db_table_set), "table", "tables"),
                   plural(sindex_counter.value, "secondary index", "secondary indexes"),
                   plural(hook_counter.value, "hook function", "hook functions")
            ))
    finally:
        signal.signal(signal.SIGINT, signal.SIG_DFL)

    if interrupt_event.is_set():
        raise RuntimeError("Interrupted")

    if len(errors) != 0:
        # multiprocessing queues don't handle tracebacks, so they've already been stringified in the queue
        for error in errors:
            print("%s" % error[1], file=sys.stderr)
            if options.debug:
                print("%s traceback: %s" % (error[0].__name__, error[2]), file=sys.stderr)
        raise RuntimeError("Errors occurred during export")
Example #37
def learning_loop(exit_flag: mp.Value,
                  gvfs: Sequence[Sequence[GTDLearner]],
                  main2gvf: mp.SimpleQueue,
                  gvf2plot: mp.SimpleQueue,
                  parsrs: List[Callable]):
    action, action_prob, obs, x = None, None, None, None

    # get first state
    while exit_flag.value == 0 and obs is None:
        while exit_flag.value == 0 and main2gvf.empty():
            time.sleep(0.01)
        if exit_flag.value == 0:
            if locks:
                print('gvf gm a 1 a')
                gmlock.acquire()
                print('gvf gm a 1 b')
            action, action_prob, obs, x = main2gvf.get()
            if locks:
                print('gvf gm r 1 a')
                gmlock.release()
                print('gvf gm r 1 b')

    # main loop
    # tt = 0
    # ts = []
    while exit_flag.value == 0:
        # ts.append(time.time() - tt) if tt > 0 else None
        # print(np.mean(ts))
        # tt = time.time()
        if locks:
            print('gvf gm a 2 a')
            gmlock.acquire()
            print('gvf gm a 2 b')
        while exit_flag.value == 0 and main2gvf.empty():
            if locks:
                print('gvf gm r 2 a')
                gmlock.release()
                print('gvf gm r 2 b')
            time.sleep(0.01)
            if locks:
                print('gvf gm a 3 a')
                gmlock.acquire()
                print('gvf gm a 3 b')
        if locks:
            print('gvf gm r 3 a')
            gmlock.release()
            print('gvf gm r 3 b')
        if exit_flag.value:
            break

        # get data from servos
        if locks:
            print('gvf gm a 4 a')
            gmlock.acquire()
            print('gvf gm a 4 b')
        actionp, action_probp, obsp, xp = main2gvf.get()
        if locks:
            print('gvf gm r 4 a')
            gmlock.release()
            print('gvf gm r 4 b')
        # update weights
        for gs, xi, xpi in zip(gvfs, x, xp):
            for g in gs:
                g.update(action, action_prob, obs, obsp, xi, xpi)

        # send data to plots
        gdata = [g.data(xi, obs, action, xpi, obsp)
                 for gs, xi, xpi in zip(gvfs, x, xp)
                 for g in gs]

        data = dict(ChainMap(*gdata))
        data['obs'] = obs
        data['x'] = x
        data = [parse(data) for parse in parsrs]
        if locks:
            print('gvf gp a 1 a')
            gplock.acquire()
            print('gvf gp a 1 b')
        # data = np.copy(data)
        gvf2plot.put(data)
        if locks:
            print('gvf gp r 1 a')
            gplock.release()
            print('gvf gp r 1 b')

        # go to next state
        obs = obsp
        x = xp
        action = actionp
        action_prob = action_probp

    print('Done learning!')
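
Note: `locks`, `gmlock`, and `gplock` are read here without being defined, so they are presumably module-level globals shared with the rest of the program. A minimal sketch of the assumed setup (the real initialization is not shown in the original):

import multiprocessing as mp

locks = False        # assumed debug flag; True serializes queue access with trace prints
gmlock = mp.Lock()   # assumed lock guarding main2gvf
gplock = mp.Lock()   # assumed lock guarding gvf2plot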
Example #40
0
File: client.py  Project: ATRAN2/Futami
class InternalClient(Client):
    """This client is a fake client which is responsible for firing off
    all messages from the update notification side, and handling the
    routing of those messages to users watching.

    It does not have a socket, so it should not be included in the
    server's clients dictionary.
    """

    def __init__(self, server, nickname, user, host='localhost'):
        self.server = server
        self.nickname = nickname
        self.realname = nickname
        self.user = user
        self.host = host

        self._readbuffer = ""
        self._writebuffer = ""
        self.request_queue = SimpleQueue()
        self.response_queue = SimpleQueue()

        # dict of board => list of users
        self.board_watchers = defaultdict(list)

        # dict of board, thread => list of users
        self.thread_watchers = defaultdict(lambda: defaultdict(list))

        Process(
            target=Ami,
            name='immediate api worker',
            args=(self.request_queue, self.response_queue)
        ).start()

    def loop_hook(self):
        while not self.response_queue.empty():
            result = self.response_queue.get()

            # Handle exceptions in-band from child workers here.
            if isinstance(result, StoredException):
                print(result.traceback)
                raise RuntimeError(
                    "Exception caught from worker '{}', see above for exception details".format(
                        result.process,
                ))

            logger.debug("read from response queue {}".format(result))

            send_as = "/{}/{}".format(result.board, result.post_no)

            # Initial channel loads have identifiers, use them to find out
            # where to go
            if result.identifier:
                client, channel, target = result.identifier
                client = self.server.get_client(client)
                logger.debug("initial channel load, using identitifier info: sending to {} on {}".format(client, channel))

                if isinstance(target, BoardTarget):
                    self._send_message(
                        client, channel, result.summary,
                        sending_nick=send_as,
                    )
                    continue
                elif isinstance(target, ThreadTarget):
                    self._send_message(
                        client, channel, result.comment,
                        sending_nick=send_as,
                    )
                    continue

            if result.is_reply:  # Send to thread channel
                channel = "#/{}/{}".format(result.board, result.reply_to)
                logger.debug("sending reply to channel {}".format(channel))

                # TODO: Remove users who have disconnected from the server here
                for client in self.thread_watchers[result.board][result.reply_to]:
                    logger.debug("sending reply to {}".format(client))
                    self._send_message(
                        client, channel, result.comment,
                        sending_nick=send_as,
                    )
            else:
                channel = "#/{}/".format(result.board)
                logger.debug("sending thread update to channel {}".format(channel))

                # TODO: Remove users who have disconnected from the server here
                for client in self.board_watchers[result.board]:
                    self._send_message(
                        client, channel, result.summary,
                        sending_nick=send_as,
                    )

    def _parse_prefix(self, prefix):
        m = re.search(
            ":(?P<nickname>[^!]*)!(?P<username>[^@]*)@(?P<host>.*)",
            prefix
        )
        return m.groupdict()

    @property
    def socket(self):
        raise AttributeError('InternalClients have no sockets')

    def message(self, message):
        pass
        # prefix, message = message.split(" ", 1)

        # prefix = self._parse_prefix(prefix)

        # self.sending_client = self.server.get_client(prefix['nickname'])

        # self._readbuffer = message + '\r\n'
        # self._parse_read_buffer()

    def client_joined(self, client, channel):
        logger.debug("InternalClient handling {} joined {}".format(client, channel))

        channel_registration_map = {
            r'#/(.+)/$': self._client_register_board,
            r'#/(.+)/(\d+)$': self._client_register_thread,
        }

        matched_registration = False

        for regex, register_method in channel_registration_map.items():
            m = re.match(regex, channel.name)
            if m:
                register_method(client, channel, *m.groups())
                matched_registration = True
                break

        if not matched_registration:
            self._send_message(
                client, channel.name,
                "This channel ({}) doesn't look like a board. Nothing will happen in this channel.".format(channel.name)
            )
            return

    def _handle_command(self, command, arguments):
        # sending_client = self.sending_client
        # self.sending_client = None

        # Add handling here for actual input from users other than joins
        pass

    def _client_register_board(self, client, channel, board):
        logger.debug("registering to board: {}, {}, {}".format(client, channel, board))

        slash_board = '/{}/'.format(board)
        self._send_message(
            client, channel.name,
            "Welcome to {}, loading threads...".format(slash_board),
            sending_nick=slash_board,
        )

        target = BoardTarget(board)

        self.request_queue.put(
            SubscriptionUpdate.make(
                action=Action.LoadAndFollow,
                target=target,
                payload=(client.nickname, channel.name, target),
        ))

        self.board_watchers[board].append(client)

    def _client_register_thread(self, client, channel, board, thread):
        logger.debug("registering to thread: {}, {}, {}, {}".format(client, channel, board, thread))

        slash_board_thread = '/{}/{}'.format(board, thread)

        self._send_message(
            client, channel.name,
            "Welcome to >>>{}, loading posts...".format(slash_board_thread),
            sending_nick=slash_board_thread,
        )

        target = ThreadTarget(board, thread)

        self.request_queue.put(
            SubscriptionUpdate.make(
                action=Action.LoadAndFollow,
                target=target,
                payload=(client.nickname, channel.name, target),
        ))

        # Thread reply_tos are ints when they come back from the API
        self.thread_watchers[board][int(thread)].append(client)

    def _send_message(self, client, channel, message, sending_nick=None):
        if sending_nick:
            real_nick = self.nickname
            self.nickname = sending_nick

        client.message(
            ":{} PRIVMSG {} :{}".format(
                self.prefix,
                channel,
                message,
            )
        )

        if sending_nick:
            self.nickname = real_nick
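
Note: client_joined above dispatches on the channel name with a regex-to-handler map, where board channels look like #/board/ and thread channels like #/board/12345. A standalone sketch of the same dispatch (handlers hypothetical):

import re

def route_channel(name):
    # First matching pattern wins, mirroring channel_registration_map
    handlers = {
        r'#/(.+)/$': lambda board: ('board', board),
        r'#/(.+)/(\d+)$': lambda board, thread: ('thread', board, int(thread)),
    }
    for pattern, handler in handlers.items():
        m = re.match(pattern, name)
        if m:
            return handler(*m.groups())
    return None  # not a board-like channel

assert route_channel('#/g/') == ('board', 'g')
assert route_channel('#/g/12345') == ('thread', 'g', 12345)
assert route_channel('#random') is None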
Example #41
0
def run_clients(options, db_table_set):
    # Spawn one client for each db.table
    exit_event = multiprocessing.Event()
    processes = []
    error_queue = SimpleQueue()
    interrupt_event = multiprocessing.Event()
    sindex_counter = multiprocessing.Value(ctypes.c_longlong, 0)

    signal.signal(signal.SIGINT, lambda a, b: abort_export(a, b, exit_event, interrupt_event))
    errors = [ ]

    try:
        sizes = get_all_table_sizes(options["host"], options["port"], options["auth_key"], db_table_set)

        progress_info = []

        arg_lists = []
        for db, table in db_table_set:
            progress_info.append((multiprocessing.Value(ctypes.c_longlong, 0),
                                  multiprocessing.Value(ctypes.c_longlong, sizes[(db, table)])))
            arg_lists.append((options["host"],
                              options["port"],
                              options["auth_key"],
                              db, table,
                              options["directory_partial"],
                              options["fields"],
                              options["delimiter"],
                              options["format"],
                              error_queue,
                              progress_info[-1],
                              sindex_counter,
                              exit_event))


        # Wait for all tables to finish
        while len(processes) > 0 or len(arg_lists) > 0:
            time.sleep(0.1)

            while not error_queue.empty():
                exit_event.set() # Stop immediately if an error occurs
                errors.append(error_queue.get())

            processes = [process for process in processes if process.is_alive()]

            if len(processes) < options["clients"] and len(arg_lists) > 0:
                processes.append(multiprocessing.Process(target=export_table,
                                                         args=arg_lists.pop(0)))
                processes[-1].start()

            update_progress(progress_info)

        # If we were successful, make sure 100% progress is reported
        # (rows could have been deleted which would result in being done at less than 100%)
        if len(errors) == 0 and not interrupt_event.is_set():
            print_progress(1.0)

        # Continue past the progress output line and print total rows processed
        def plural(num, text, plural_text):
            return "%d %s" % (num, text if num == 1 else plural_text)

        print("")
        print("%s exported from %s, with %s" %
              (plural(sum([max(0, info[0].value) for info in progress_info]), "row", "rows"),
               plural(len(db_table_set), "table", "tables"),
               plural(sindex_counter.value, "secondary index", "secondary indexes")))
    finally:
        signal.signal(signal.SIGINT, signal.SIG_DFL)

    if interrupt_event.is_set():
        raise RuntimeError("Interrupted")

    if len(errors) != 0:
        # multiprocessing queues don't handle tracebacks, so they've already been stringified in the queue
        for error in errors:
            print("%s" % error[1], file=sys.stderr)
            if options["debug"]:
                print("%s traceback: %s" % (error[0].__name__, error[2]), file=sys.stderr)
        raise RuntimeError("Errors occurred during export")
Example #42
0
class QueuePool(object):
	Process = QueueProcess

	def __init__(self, callback, pool_size=1, check_interval=2):
		self.task_queue = SimpleQueue()
		self.result_queue = SimpleQueue()
		self._callback = callback
		self._pool = {}  # {process_name: process}
		self._tasks = {}  # {task_id: process_name}
		for _ in range(pool_size):
			process = self.Process(self.task_queue, self.result_queue)
			self._pool[process.name] = process
			process.start()
		# Check for progress periodically TODO: stop timer when queue is empty!
		self.timer = QTimer()
		self.timer.timeout.connect(self._check_for_results)
		self.timer.start(int(check_interval * 1000))

	def _check_for_results(self):
		while not self.result_queue.empty():
			process_name, task_id, result_object, is_exception, is_ready = self.result_queue.get()
			if is_ready or is_exception:
				if task_id in self._tasks:
					del self._tasks[task_id]
			else:
				self._tasks[task_id] = process_name
			self._callback(task_id, result_object, is_exception, is_ready)

	def change_check_interval(self, new_interval_in_seconds):
		try:
			interval = float(new_interval_in_seconds)
		except ValueError:
			return
		self.timer.stop()
		self.timer.start(int(interval * 1000))  # QTimer.start expects milliseconds as an int

	def change_pool_size(self, new_pool_size):
		try:
			diff = int(new_pool_size) - len(self._pool)
		except ValueError:
			return
		if diff < 0:
			for _ in range(abs(diff)):
				process_name, process = self._pool.popitem()
				process.soft_interrupt.set()
		else:
			for _ in range(diff):
				# Construct workers the same way as in __init__
				process = self.Process(self.task_queue, self.result_queue)
				self._pool[process.name] = process
				process.start()

	def add_task(self, task_id, *params):
		self.task_queue.put([task_id] + list(params))

	def cancel_task(self, task_id):
		process_name = self._tasks.get(task_id)
		if process_name is None:
			# task isn't active, but it may still be queued: drain the queue
			# and re-add everything except this task
			task_objects = []
			while not self.task_queue.empty():
				task_objects.append(self.task_queue.get())
			for obj in task_objects:
				if task_id != obj[0]:
					self.task_queue.put(obj)
			return
		process = self._pool.get(process_name)
		if process is None:
			# process might be already stopped -> ignore for now
			return
		process.hard_interrupt.set()

	def shutdown(self):
		for process in self._pool.values():
			process.hard_interrupt.set()
			self.task_queue.put(None)  # unblock queue

	def terminate(self):
		for process in self._pool.values():
			if process.exitcode is None:
				process.terminate()
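
Note: a hypothetical usage sketch of QueuePool; QueueProcess and a running Qt event loop (which drives the QTimer polling) are assumed, as in the class above:

def on_result(task_id, result, is_exception, is_ready):
	# Invoked by _check_for_results for every queued result
	if is_exception:
		print("task {} failed: {}".format(task_id, result))
	elif is_ready:
		print("task {} done: {}".format(task_id, result))

pool = QueuePool(on_result, pool_size=4, check_interval=1)
pool.add_task("job-1", "some", "input")
# ... Qt event loop runs; results arrive through on_result ...
pool.shutdown()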
Example #43
0
        # Flattened from a 12-level nested loop with itertools.product;
        # iteration order (rightmost varies fastest) is unchanged
        import itertools
        for (tax, mdl, gbr, gstdv, sbr, sstdv, alp, cat,
             drp, ndr, dup, ndp) in itertools.product(
                taxa_range, models, gene_branch_len, gene_branch_stdev,
                species_branch_len, species_branch_stdev, alphas, category_range,
                drop_chances, num_drops_range, duplication_chances,
                num_duplications_range):
            arguments.append((grp, tax, mdl, gbr, gstdv, sbr, sstdv,
                              alp, cat, drp, ndr, dup, ndp))

    arguments *= in_args.replicates
    broker_queue = SimpleQueue()
    broker = Process(target=broker_func, args=[broker_queue, in_args.output])
    broker.daemon = True
    broker.start()

    run_multicore_function(arguments, generate)

    os.remove("site_rates_info.txt")
    os.remove("site_rates.txt")

    # Busy-wait until the broker process has drained its queue
    while not broker_queue.empty():
        pass

    broker.terminate()
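
Note: broker_func isn't shown; from its usage it runs as a daemon process consuming (queue, output path), and the busy-wait above ensures the queue is drained before terminate(). A hypothetical sketch of such a broker:

def broker_func(queue, output_path):
    # Hypothetical: append queued results to the output file until a
    # None sentinel arrives (the real broker's protocol is assumed)
    with open(output_path, 'a') as fh:
        while True:
            item = queue.get()
            if item is None:
                break
            fh.write(str(item) + '\n')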