コード例 #1
0
    def getTaskProcess(self, args):
        """Worker loop: take requests from ``self.q_request`` and process them.

        Each iteration first checks ``self.has_work_left() > 0``; when that is
        false the loop ends. Each request is passed to ``self.do_job`` and the
        ``(request, answer)`` pair is put on ``self.q_finish``.
        ``self.running`` counts in-flight requests and is only modified while
        holding ``self.lock``.

        :param args: extra argument forwarded unchanged to ``self.do_job``.
        """
        try:
            from queue import Empty  # Python 3
        except ImportError:
            from Queue import Empty  # Python 2
        threadname = multiprocess.current_process().name
        while True:
            if not self.has_work_left() > 0:
                print(threadname + '总任务关闭')
                break
            try:
                req = self.q_request.get(block=True, timeout=3)
            except Empty:
                # Nothing arrived within the timeout; re-check remaining work.
                continue
            # The counter update must be atomic: enter the critical section.
            with self.lock:
                self.running = self.running + 1
            try:
                print('进程' + threadname + '发起请求: ')
                ans = self.do_job(self.task, req, threadname, args)
                self.q_finish.put((req, ans))
            finally:
                # Release the in-flight slot even if do_job/put raised, so
                # the counter cannot leak.
                with self.lock:
                    self.running -= 1
            print('进程' + threadname + '完成请求')
コード例 #2
0
	def do_job(self,args):

		if self.isThread==1:
			print threading.current_thread(),args
		else:
			print datetime.datetime.now()
			print multiprocess.current_process(),args
コード例 #3
0
def redirect_log(wkdir):
    """Redirect stdout and stderr of a forked worker to files in *wkdir*.

    The OS-level descriptors behind ``sys.stdout`` and ``sys.stderr`` are
    re-pointed with ``os.dup2`` at ``<wkdir>/stdout`` and ``<wkdir>/stderr``.
    Anything written to those descriptors after the call lands in the files.
    The files are created if missing and opened write-only (no truncation).

    :param wkdir: directory that receives the ``stdout``/``stderr`` files.
    :raises OSError: if either file cannot be opened.
    """
    import os, sys
    # define stdout and stderr names
    stdout = os.path.join(wkdir, 'stdout')
    stderr = os.path.join(wkdir, 'stderr')
    _info(" stdout->[%s]" % stdout)
    _info(" stderr->[%s]" % stderr)

    # Sync the log: flush Python-level buffers before swapping descriptors.
    # Use an explicit loop; under Python 3 `map` is lazy and would never
    # call flush().
    for stream in (sys.stdout, sys.stderr):
        stream.flush()

    flags = os.O_CREAT | os.O_WRONLY
    for path, stream in ((stdout, sys.stdout), (stderr, sys.stderr)):
        target_fd = stream.fileno()
        # os.open raises OSError on failure; it never returns a negative fd.
        new_fd = os.open(path, flags)
        if new_fd != target_fd:
            # dup2 closes target_fd itself before duplicating, so no separate
            # os.close(target_fd) is needed (that would leave a window where
            # the descriptor number is free for another thread to grab).
            os.dup2(new_fd, target_fd)
            # The temporary descriptor is no longer needed; closing it does
            # not affect target_fd.
            os.close(new_fd)
コード例 #4
0
	def getTaskProcess(self,args):
		"""Worker loop: take requests from ``self.q_request`` and process them.

		Each iteration first checks ``self.has_work_left() > 0``; when that is
		false the loop ends. Each request is passed to ``self.do_job`` and the
		``(request, answer)`` pair is put on ``self.q_finish``.
		``self.running`` counts in-flight requests and is only modified while
		holding ``self.lock``.

		:param args: extra argument forwarded unchanged to ``self.do_job``.
		"""
		try:
			from queue import Empty  # Python 3
		except ImportError:
			from Queue import Empty  # Python 2
		threadname = multiprocess.current_process().name
		while True:
			if not self.has_work_left() > 0:
				print(threadname + '总任务关闭')
				break
			try:
				req = self.q_request.get(block=True, timeout=3)
			except Empty:
				# Nothing arrived within the timeout; re-check remaining work.
				continue
			# The counter update must be atomic: enter the critical section.
			with self.lock:
				self.running = self.running + 1
			try:
				print('进程' + threadname + '发起请求: ')
				ans = self.do_job(self.task, req, threadname, args)
				self.q_finish.put((req, ans))
			finally:
				# Release the in-flight slot even if do_job/put raised, so
				# the counter cannot leak.
				with self.lock:
					self.running -= 1
			print('进程' + threadname + '完成请求')
コード例 #5
0
def function(a, game, fail, exc):
    """Target run in each child process.

    :param a: not used by this function.
    :param game: object whose ``end_turn()`` is called when not failing.
    :param fail: when true, raise ``ThisError`` instead of ending the turn.
    :param exc: queue that receives a ``ThisError`` describing any failure.
    """
    name = mp.current_process().name
    try:
        if fail:
            raise ThisError("Child process %s failed!" % name)
        game.end_turn()
        print("Child process %s current player: " % name)
    except Exception:
        # Report the failure to the parent through the queue. Catch
        # Exception rather than everything so KeyboardInterrupt/SystemExit
        # can still stop the child.
        exc.put(ThisError("Child process %s failed!" % name))
コード例 #6
0
ファイル: multi_test.py プロジェクト: flowersw/Hearthstone-AI
def function(a, game, fail, exc):
    """Target run in each child process.

    :param a: not used by this function.
    :param game: object whose ``end_turn()`` is called when not failing.
    :param fail: when true, raise ``ThisError`` instead of ending the turn.
    :param exc: queue that receives a ``ThisError`` describing any failure.
    """
    name = mp.current_process().name
    try:
        if fail:
            raise ThisError("Child process %s failed!" % name)
        game.end_turn()
        print("Child process %s current player: " % name)
    except Exception:
        # Report the failure to the parent through the queue. Catch
        # Exception rather than everything so KeyboardInterrupt/SystemExit
        # can still stop the child.
        exc.put(ThisError("Child process %s failed!" % name))
コード例 #7
0
	def run(self):
		"""Execute queued ``(callable, args)`` tasks until the queue is empty.

		Each item from ``self.work_queue`` is unpacked as ``do, args`` and run
		as ``do(args)``. The loop ends when a non-blocking ``get`` finds the
		queue empty. Only an empty queue ends the loop quietly; an exception
		raised by a task propagates to the caller instead of being reported
		as "no task detected".
		"""
		try:
			from queue import Empty  # Python 3
		except ImportError:
			from Queue import Empty  # Python 2
		# Loop until the queue is drained, then let the worker exit.
		while True:
			threadname = multiprocess.current_process().name
			print('正在执行任务' + threadname)
			try:
				# Non-blocking dequeue; the queue synchronizes internally.
				do, args = self.work_queue.get(block=False)
			except Empty:
				threadname = multiprocess.current_process().name
				print('没有任务检测到,关闭多余资源' + threadname)
				break
			print('正在读取数据' + threadname)
			do(args)
			print('结束当前操作' + threadname)
コード例 #8
0
def RebuildProxy(func, token, serializer, kwds):
    '''
    Rebuild a proxy object while unpickling.

    If this process hosts the manager server whose address matches
    `token.address`, the shared object itself is returned.  Otherwise
    `func(token, serializer, incref=..., **kwds)` builds a proxy.  `incref`
    is popped from `kwds` (default True) and forced off while the current
    process is still inheriting state from its parent.
    '''
    proc = current_process()
    server = getattr(proc, '_manager_server', None)

    if not (server and server.address == token.address):
        wanted = kwds.pop('incref', True)
        incref = wanted and not getattr(proc, '_inheriting', False)
        return func(token, serializer, incref=incref, **kwds)
    return server.id_to_obj[token.id][0]
コード例 #9
0
ファイル: achilles_main.py プロジェクト: adpena/achilles
def setupGlobals():
    """Build the dictionary of shared globals.

    Sets the current process's authkey, starts a ``multiprocess`` manager
    and returns ``{"OUTPUT_QUEUE": <manager-backed queue>}``.
    """
    # NOTE(review): hard-coded authkey; presumably it must match the key set
    # by the peers that connect to this manager. Consider loading it from
    # configuration instead of source.
    multiprocess.current_process().authkey = b"176778741"

    mgr = multiprocess.Manager()
    output_queue = mgr.Queue()
    return dict(OUTPUT_QUEUE=output_queue)
コード例 #10
0
 def initializer(self, connection_strings):
     """Connect the calling process to one simulator.

     Takes one connection string from `connection_strings` (presumably a
     queue shared with the parent) and stores a new ``Simulation`` for it in
     ``global_sim.sim``.
     """
     # Imports are done here, inside the worker. NOTE(review): several of
     # these names are unused in this function; they are kept in case
     # importing them has side effects the worker relies on.
     from multiprocess import Pool, Queue, Manager
     import multiprocess as mp
     from dronekit import LocationGlobal, LocationGlobalRelative, LocationLocal
     from MyVehicle import MyVehicle
     from Simulation import Simulation
     from Const import Const
     import global_sim
     import fitness

     proc = mp.current_process()
     print('initializing process {}'.format(proc.name))
     conn_str = connection_strings.get()
     print('Connecting to sim {} with process {}'.format(conn_str, proc))
     sim = Simulation(conn_str, targets_amount=5, speedup=Const.SPEED_UP)
     global_sim.sim = sim
コード例 #11
0
def main():
    """Run ``function`` in four child processes and report their failures.

    Three children are started with ``fail=False`` and one with
    ``fail=True``. After all children are joined, every error they put on
    the shared queue is printed.
    """
    initialize()
    game = setup_game()

    with mp.Manager() as manager:
        exc = manager.Queue()

        arg_list = [((i, i + 3), game, False, exc) for i in range(3)]
        arg_list.append(((5, 5), game, True, exc))

        proc_list = []
        for arg in arg_list:
            proc_list.append(mp.Process(target=function, args=arg))
            proc_list[-1].start()

        print("Number of active children post start: %d" % len(mp.active_children()))
        for p in proc_list:
            p.join()
        # Drain the queue: any child may have reported an error, not only
        # the one asked to fail.
        while not exc.empty():
            e = exc.get()
            # Python 3 exceptions have no `.message` attribute; fall back to
            # printing the exception itself if ThisError does not define one.
            print(getattr(e, 'message', e))

    print("Number active children post join: %d " % len(mp.active_children()))
    print(mp.active_children())
    print(mp.current_process())
コード例 #12
0
def AutoProxy(token,
              serializer,
              manager=None,
              authkey=None,
              exposed=None,
              incref=True):
    '''
    Return an auto-proxy for `token`.

    The authentication key is resolved first: explicit `authkey`, else the
    manager's key, else the current process's key.  If `exposed` is None,
    the list of exposed methods is then fetched from the server at
    `token.address` using that key.  (Connecting with ``authkey=None`` would
    skip the challenge/response handshake entirely.)
    '''
    _Client = listener_client[serializer][1]

    if authkey is None and manager is not None:
        authkey = manager._authkey
    if authkey is None:
        authkey = current_process().authkey

    if exposed is None:
        conn = _Client(token.address, authkey=authkey)
        try:
            exposed = dispatch(conn, None, 'get_methods', (token, ))
        finally:
            conn.close()

    ProxyType = MakeProxyType('AutoProxy[%s]' % token.typeid, exposed)
    proxy = ProxyType(token,
                      serializer,
                      manager=manager,
                      authkey=authkey,
                      incref=incref)
    proxy._isauto = True
    return proxy
コード例 #13
0
ファイル: multi_test.py プロジェクト: flowersw/Hearthstone-AI
def main():
    """Run ``function`` in four child processes and report their failures.

    Three children are started with ``fail=False`` and one with
    ``fail=True``. After all children are joined, every error they put on
    the shared queue is printed.
    """
    initialize()
    game = setup_game()

    with mp.Manager() as manager:

        exc = manager.Queue()

        arg_list = [((i, i + 3), game, False, exc) for i in range(3)]
        arg_list.append(((5, 5), game, True, exc))

        proc_list = []

        for arg in arg_list:
            proc_list.append(mp.Process(target=function, args=arg))
            proc_list[-1].start()

        print("Number of active children post start: %d" %
              len(mp.active_children()))
        for p in proc_list:
            p.join()
        # Drain the queue: any child may have reported an error, not only
        # the one asked to fail.
        while not exc.empty():
            e = exc.get()
            # Python 3 exceptions have no `.message` attribute; fall back to
            # printing the exception itself if ThisError does not define one.
            print(getattr(e, 'message', e))

    print("Number active children post join: %d " % len(mp.active_children()))
    print(mp.active_children())
    print(mp.current_process())
コード例 #14
0
    def __init__(
        self,
        host,
        port,
        username,
        secret_key,
        achilles_function=None,
        achilles_args=None,
        achilles_callback=None,
        achilles_reducer=None,
        response_mode="OBJECT",
        globals_dict=None,
        chunksize=1,
        command=None,
        command_verified=False,
    ):
        """Store server credentials, job callables and options on the instance.

        Also overwrites the authkey of the current process (see note below).
        """
        # Server address (hostname or IP, port) and credentials.
        self.HOST, self.PORT = host, port
        self.USERNAME, self.SECRET_KEY = username, secret_key
        self.response_mode = response_mode
        # Initial SQLite / counter state.
        self.sqlite_db_created, self.sqlite_db = False, ""
        self.abs_counter = 0
        # User-supplied job definition.
        (
            self.achilles_function,
            self.achilles_args,
            self.achilles_callback,
            self.achilles_reducer,
        ) = (achilles_function, achilles_args, achilles_callback, achilles_reducer)
        self.globals_dict = globals_dict
        self.chunksize = chunksize
        self.command, self.command_verified = command, command_verified

        # NOTE(review): hard-coded authkey, presumably shared with the peer
        # processes; consider loading it from configuration.
        current_process().authkey = b"176778741"
コード例 #15
0
ファイル: managers.py プロジェクト: uqfoundation/multiprocess
def RebuildProxy(func, token, serializer, kwds):
    '''
    Rebuild a proxy object while unpickling.

    If this process hosts the manager server whose address matches
    `token.address`, the shared object itself is returned.  Otherwise
    `func(token, serializer, incref=..., **kwds)` builds a proxy.  `incref`
    is popped from `kwds` (default True) and forced off while the current
    process is still inheriting state from its parent.
    '''
    proc = current_process()
    server = getattr(proc, '_manager_server', None)

    if not (server and server.address == token.address):
        wanted = kwds.pop('incref', True)
        incref = wanted and not getattr(proc, '_inheriting', False)
        return func(token, serializer, incref=incref, **kwds)
    return server.id_to_obj[token.id][0]
コード例 #16
0
 def _connect(self):
     '''
     Open a connection to the manager and remember it on ``self._tls``.

     The connection is announced to the manager via ``accept_connection``
     under the current process name, suffixed with ``|<thread name>`` when
     the calling thread is not named 'MainThread'.
     '''
     util.debug('making connection to manager')
     thread_name = threading.current_thread().name
     name = current_process().name
     if thread_name != 'MainThread':
         name = name + '|' + thread_name
     conn = self._Client(self._token.address, authkey=self._authkey)
     dispatch(conn, None, 'accept_connection', (name,))
     self._tls.connection = conn
コード例 #17
0
ファイル: managers.py プロジェクト: uqfoundation/multiprocess
 def _connect(self):
     '''
     Open a connection to the manager and remember it on ``self._tls``.

     The connection is announced to the manager via ``accept_connection``
     under the current process name, suffixed with ``|<thread name>`` when
     the calling thread is not named 'MainThread'.
     '''
     util.debug('making connection to manager')
     thread_name = threading.current_thread().name
     name = current_process().name
     if thread_name != 'MainThread':
         name = name + '|' + thread_name
     conn = self._Client(self._token.address, authkey=self._authkey)
     dispatch(conn, None, 'accept_connection', (name,))
     self._tls.connection = conn
コード例 #18
0
 def __init__(self, address=None, authkey=None, serializer='pickle'):
     '''
     Set up an unstarted manager.

     `authkey` defaults to the current process's authkey and is wrapped in
     an ``AuthenticationString``.  The listener/client pair is looked up
     from `listener_client` by `serializer`.
     '''
     key = current_process().authkey if authkey is None else authkey
     self._address = address     # XXX not final address if eg ('', 0)
     self._authkey = AuthenticationString(key)
     state = State()
     state.value = State.INITIAL
     self._state = state
     self._serializer = serializer
     self._Listener, self._Client = listener_client[serializer]
コード例 #19
0
ファイル: managers.py プロジェクト: uqfoundation/multiprocess
 def __init__(self, address=None, authkey=None, serializer='pickle'):
     '''
     Set up an unstarted manager.

     `authkey` defaults to the current process's authkey and is wrapped in
     an ``AuthenticationString``.  The listener/client pair is looked up
     from `listener_client` by `serializer`.
     '''
     key = current_process().authkey if authkey is None else authkey
     self._address = address     # XXX not final address if eg ('', 0)
     self._authkey = AuthenticationString(key)
     state = State()
     state.value = State.INITIAL
     self._state = state
     self._serializer = serializer
     self._Listener, self._Client = listener_client[serializer]
コード例 #20
0
def rebuild_handle(pickled_data):
    '''
    Rebuild a handle that was pickled for transfer to this process.

    `pickled_data` is an ``(address, handle, inherited)`` tuple.  An
    inherited handle is returned unchanged.  Otherwise this connects to
    `address`, sends ``(handle, pid)`` and returns the handle received back
    via ``recv_handle``.  The connection is always closed, even on error.
    '''
    address, handle, inherited = pickled_data
    if inherited:
        return handle
    sub_debug('rebuilding handle %d', handle)
    conn = Client(address, authkey=current_process().authkey)
    try:
        conn.send((handle, os.getpid()))
        return recv_handle(conn)
    finally:
        conn.close()
コード例 #21
0
def rebuild_handle(pickled_data):
    '''
    Rebuild a handle that was pickled for transfer to this process.

    `pickled_data` is an ``(address, handle, inherited)`` tuple.  An
    inherited handle is returned unchanged.  Otherwise this connects to
    `address`, sends ``(handle, pid)`` and returns the handle received back
    via ``recv_handle``.  The connection is always closed, even on error.
    '''
    address, handle, inherited = pickled_data
    if inherited:
        return handle
    sub_debug('rebuilding handle %d', handle)
    conn = Client(address, authkey=current_process().authkey)
    try:
        conn.send((handle, os.getpid()))
        return recv_handle(conn)
    finally:
        conn.close()
コード例 #22
0
 def _start(self):
     '''
     Create the listener and start a daemon thread running ``self._serve``.

     Must only be called while no listener exists (asserted).  The
     listener's address is recorded in ``self._address``.
     '''
     from .connection import Listener
     assert self._listener is None
     debug('starting listener and thread for sending handles')
     listener = Listener(authkey=current_process().authkey)
     self._listener = listener
     self._address = listener.address
     server_thread = threading.Thread(target=self._serve, daemon=True)
     server_thread.start()
     self._thread = server_thread
コード例 #23
0
 def _start(self):
     '''
     Create the listener and start a daemon thread running ``self._serve``.

     Must only be called while no listener exists (asserted).  The
     listener's address is recorded in ``self._address``.
     '''
     from .connection import Listener
     assert self._listener is None
     debug('starting listener and thread for sending handles')
     listener = Listener(authkey=current_process().authkey)
     self._listener = listener
     self._address = listener.address
     server_thread = threading.Thread(target=self._serve, daemon=True)
     server_thread.start()
     self._thread = server_thread
コード例 #24
0
    def __init__(self, host, port):
        """Store the server address and reset connection and job state.

        Also overwrites the authkey of the current process (see note below).
        """
        self.HOST, self.PORT = host, port  # server hostname/IP and port
        self.connected = False
        self.client_id = -1
        # Job callables and the output queue all start out as None.
        self.func = self.callback = self.reducer = self.output_queue = None

        # NOTE(review): hard-coded authkey, presumably shared with the peer
        # processes; consider loading it from configuration.
        multiprocess.current_process().authkey = b"176778741"
コード例 #25
0
def _get_listener():
    '''
    Return the module-level handle listener, creating it on first use.

    The None check is repeated while holding `_lock`, so concurrent first
    callers create only one listener; the creating caller also starts a
    daemon thread running `_serve`.
    '''
    global _listener

    if _listener is not None:
        return _listener

    _lock.acquire()
    try:
        if _listener is None:
            debug('starting listener and thread for sending handles')
            _listener = Listener(authkey=current_process().authkey)
            server_thread = threading.Thread(target=_serve)
            server_thread.daemon = True
            server_thread.start()
    finally:
        _lock.release()
    return _listener
コード例 #26
0
def _get_listener():
    '''
    Return the module-level handle listener, creating it on first use.

    The None check is repeated while holding `_lock`, so concurrent first
    callers create only one listener; the creating caller also starts a
    daemon thread running `_serve`.
    '''
    global _listener

    if _listener is not None:
        return _listener

    _lock.acquire()
    try:
        if _listener is None:
            debug('starting listener and thread for sending handles')
            _listener = Listener(authkey=current_process().authkey)
            server_thread = threading.Thread(target=_serve)
            server_thread.daemon = True
            server_thread.start()
    finally:
        _lock.release()
    return _listener
コード例 #27
0
 def stop(self, timeout=None):
     '''
     Stop the server thread and close every cached resource.

     Does nothing if the sharer is not running (``self._address`` is None).
     Otherwise it sends ``None`` to its own listener, joins the server thread
     for up to `timeout` seconds (warning if it is still alive), closes the
     listener, resets the bookkeeping and calls each cached ``close``.
     '''
     from .connection import Client
     with self._lock:
         if self._address is None:
             return
         # Presumably ``None`` is the shutdown request understood by _serve.
         client = Client(self._address, authkey=current_process().authkey)
         client.send(None)
         client.close()
         self._thread.join(timeout)
         if self._thread.is_alive():
             sub_warn('ResourceSharer thread did not stop when asked')
         self._listener.close()
         self._thread = self._address = self._listener = None
         for _key, (_send, close_fn) in self._cache.items():
             close_fn()
         self._cache.clear()
コード例 #28
0
 def stop(self, timeout=None):
     '''
     Stop the server thread and close every cached resource.

     Does nothing if the sharer is not running (``self._address`` is None).
     Otherwise it sends ``None`` to its own listener, joins the server thread
     for up to `timeout` seconds (warning if it is still alive), closes the
     listener, resets the bookkeeping and calls each cached ``close``.
     '''
     from .connection import Client
     with self._lock:
         if self._address is None:
             return
         # Presumably ``None`` is the shutdown request understood by _serve.
         client = Client(self._address, authkey=current_process().authkey)
         client.send(None)
         client.close()
         self._thread.join(timeout)
         if self._thread.is_alive():
             sub_warn('ResourceSharer thread did not stop when asked')
         self._listener.close()
         self._thread = self._address = self._listener = None
         for _key, (_send, close_fn) in self._cache.items():
             close_fn()
         self._cache.clear()
コード例 #29
0
    def __init__(self,
                 token,
                 serializer,
                 manager=None,
                 authkey=None,
                 exposed=None,
                 incref=True):
        """Initialise a proxy for the shared object identified by `token`.

        Per-address state is shared by every proxy for the same manager
        address and is created on first use while holding
        ``BaseProxy._mutex``.  The authkey comes from `authkey`, else from
        `manager`, else from the current process.  ``self._incref()`` is
        called when `incref` is true.  `exposed` is accepted but unused here.
        """
        BaseProxy._mutex.acquire()
        try:
            per_address = BaseProxy._address_to_local.get(token.address, None)
            if per_address is None:
                per_address = util.ForkAwareLocal(), ProcessLocalSet()
                BaseProxy._address_to_local[token.address] = per_address
        finally:
            BaseProxy._mutex.release()

        # _tls: records the connection this thread uses to talk to the
        #   manager at token.address.
        # _idset: identities of the shared objects in that manager for which
        #   the current process owns references.
        self._tls, self._idset = per_address

        self._token = token
        self._id = token.id
        self._manager = manager
        self._serializer = serializer
        self._Client = listener_client[serializer][1]

        if authkey is not None:
            self._authkey = AuthenticationString(authkey)
        elif manager is not None:
            self._authkey = manager._authkey
        else:
            self._authkey = current_process().authkey

        if incref:
            self._incref()

        util.register_after_fork(self, BaseProxy._after_fork)
コード例 #30
0
 def serve_forever(self):
     '''
     Run the server forever.

     Each accepted connection is handled by ``self.handle_request`` on its
     own daemon thread; errors from ``accept`` are ignored and listening
     continues.  KeyboardInterrupt/SystemExit end the loop quietly; in every
     case ``self.stop`` is set to 999 and the listener is closed on exit.
     '''
     current_process()._manager_server = self
     try:
         while True:
             try:
                 conn = self.listener.accept()
             except (OSError, IOError):
                 continue
             worker = threading.Thread(target=self.handle_request, args=(conn,))
             worker.daemon = True
             worker.start()
     except (KeyboardInterrupt, SystemExit):
         pass
     finally:
         self.stop = 999
         self.listener.close()
コード例 #31
0
ファイル: managers.py プロジェクト: uqfoundation/multiprocess
 def serve_forever(self):
     '''
     Run the server forever.

     Each accepted connection is handled by ``self.handle_request`` on its
     own daemon thread; errors from ``accept`` are ignored and listening
     continues.  KeyboardInterrupt/SystemExit end the loop quietly; in every
     case ``self.stop`` is set to 999 and the listener is closed on exit.
     '''
     current_process()._manager_server = self
     try:
         while True:
             try:
                 conn = self.listener.accept()
             except (OSError, IOError):
                 continue
             worker = threading.Thread(target=self.handle_request, args=(conn,))
             worker.daemon = True
             worker.start()
     except (KeyboardInterrupt, SystemExit):
         pass
     finally:
         self.stop = 999
         self.listener.close()
コード例 #32
0
ファイル: managers.py プロジェクト: uqfoundation/multiprocess
 def serve_forever(self):
     '''
     Run the server until asked to stop, then exit.

     A daemon thread runs ``self.accepter``.  This thread waits on
     ``self.stop_event`` (waking at least once a second) until it is set or
     a KeyboardInterrupt/SystemExit arrives.  On the way out stdout/stderr
     are restored if they were replaced, and ``sys.exit(0)`` is called.
     '''
     self.stop_event = threading.Event()
     current_process()._manager_server = self
     try:
         accept_thread = threading.Thread(target=self.accepter)
         accept_thread.daemon = True
         accept_thread.start()
         try:
             # Event.wait returns True once the flag has been set.
             while not self.stop_event.wait(1):
                 pass
         except (KeyboardInterrupt, SystemExit):
             pass
     finally:
         redirected = sys.stdout != sys.__stdout__
         if redirected:
             util.debug('resetting stdout, stderr')
             sys.stdout, sys.stderr = sys.__stdout__, sys.__stderr__
         sys.exit(0)
コード例 #33
0
 def serve_forever(self):
     '''
     Run the server until asked to stop, then exit.

     A daemon thread runs ``self.accepter``.  This thread waits on
     ``self.stop_event`` (waking at least once a second) until it is set or
     a KeyboardInterrupt/SystemExit arrives.  On the way out stdout/stderr
     are restored if they were replaced, and ``sys.exit(0)`` is called.
     '''
     self.stop_event = threading.Event()
     current_process()._manager_server = self
     try:
         accept_thread = threading.Thread(target=self.accepter)
         accept_thread.daemon = True
         accept_thread.start()
         try:
             # Event.wait returns True once the flag has been set.
             while not self.stop_event.wait(1):
                 pass
         except (KeyboardInterrupt, SystemExit):
             pass
     finally:
         redirected = sys.stdout != sys.__stdout__
         if redirected:
             util.debug('resetting stdout, stderr')
             sys.stdout, sys.stderr = sys.__stdout__, sys.__stderr__
         sys.exit(0)
コード例 #34
0
ファイル: managers.py プロジェクト: uqfoundation/multiprocess
    def __init__(self, token, serializer, manager=None,
                 authkey=None, exposed=None, incref=True):
        """Initialise a proxy for the shared object identified by `token`.

        Per-address state is shared by every proxy for the same manager
        address and is created on first use while holding
        ``BaseProxy._mutex``.  The authkey comes from `authkey`, else from
        `manager`, else from the current process.  ``self._incref()`` is
        called when `incref` is true.  `exposed` is accepted but unused here.
        """
        BaseProxy._mutex.acquire()
        try:
            per_address = BaseProxy._address_to_local.get(token.address, None)
            if per_address is None:
                per_address = util.ForkAwareLocal(), ProcessLocalSet()
                BaseProxy._address_to_local[token.address] = per_address
        finally:
            BaseProxy._mutex.release()

        # _tls: records the connection this thread uses to talk to the
        #   manager at token.address.
        # _idset: identities of the shared objects in that manager for which
        #   the current process owns references.
        self._tls, self._idset = per_address

        self._token = token
        self._id = token.id
        self._manager = manager
        self._serializer = serializer
        self._Client = listener_client[serializer][1]

        if authkey is not None:
            self._authkey = AuthenticationString(authkey)
        elif manager is not None:
            self._authkey = manager._authkey
        else:
            self._authkey = current_process().authkey

        if incref:
            self._incref()

        util.register_after_fork(self, BaseProxy._after_fork)
コード例 #35
0
ファイル: managers.py プロジェクト: uqfoundation/multiprocess
def AutoProxy(token, serializer, manager=None, authkey=None,
              exposed=None, incref=True):
    '''
    Return an auto-proxy for `token`.

    The authentication key is resolved first: explicit `authkey`, else the
    manager's key, else the current process's key.  If `exposed` is None,
    the list of exposed methods is then fetched from the server at
    `token.address` using that key.  (Connecting with ``authkey=None`` would
    skip the challenge/response handshake entirely.)
    '''
    _Client = listener_client[serializer][1]

    if authkey is None and manager is not None:
        authkey = manager._authkey
    if authkey is None:
        authkey = current_process().authkey

    if exposed is None:
        conn = _Client(token.address, authkey=authkey)
        try:
            exposed = dispatch(conn, None, 'get_methods', (token,))
        finally:
            conn.close()

    ProxyType = MakeProxyType('AutoProxy[%s]' % token.typeid, exposed)
    proxy = ProxyType(token, serializer, manager=manager, authkey=authkey,
                      incref=incref)
    proxy._isauto = True
    return proxy
コード例 #36
0
 def get_connection(ident):
     '''
     Connect to a resource sharer and request one resource.

     `ident` is an ``(address, key)`` pair.  An authenticated connection to
     `address` is opened, ``(key, pid)`` is sent over it, and the open
     connection is returned for the caller to read the reply from.
     '''
     from .connection import Client
     address, key = ident
     request = (key, os.getpid())
     connection = Client(address, authkey=current_process().authkey)
     connection.send(request)
     return connection
コード例 #37
0
ファイル: wb_conn.py プロジェクト: grace-shearrer/HCP_puberty
def start_process():
    """Print a start-up line naming the current process."""
    name = mltp.current_process().name
    print('Starting', name)
コード例 #38
0
def get_config() -> dict:
    """Load the configuration dictionary.

    - Default configuration is obtained from `./fledge/config_default.yml`.
    - Custom configuration is obtained from `./config.yml` and overwrites the respective default configuration.
    - `./` denotes the repository base directory.

    Entries under `paths` are normalized, and a *leading* `./` is replaced by
    the base directory. When not running in the main process,
    `multiprocessing.run_parallel` is forced to False.
    """

    # Load default configuration values.
    with open(os.path.join(base_path, 'fledge', 'config_default.yml'),
              'r') as file:
        default_config = yaml.safe_load(file)

    # Create local `config.yml` for custom configuration in base directory, if not existing.
    # - The additional data paths setting is added for reference.
    if not os.path.isfile(os.path.join(base_path, 'config.yml')):
        with open(os.path.join(base_path, 'config.yml'), 'w') as file:
            file.write(
                "# Local configuration parameters.\n"
                "# - Configuration parameters and their defaults are defined in `fledge/config_default.yml`\n"
                "# - Copy from `fledge/config_default.yml` and modify parameters here to set the local configuration.\n"
                "paths:\n"
                "  additional_data: []\n")

    # Load custom configuration values, overwriting the default values.
    with open(os.path.join(base_path, 'config.yml'), 'r') as file:
        custom_config = yaml.safe_load(file)

    # Define utility function to recursively merge default and custom configuration.
    def merge_config(default_values: dict, custom_values: dict) -> dict:
        full_values = default_values.copy()
        full_values.update({
            key:
            (merge_config(default_values[key], custom_values[key]) if
             ((key in default_values) and isinstance(default_values[key], dict)
              and isinstance(custom_values[key], dict)) else
             custom_values[key])
            for key in custom_values.keys()
        })
        return full_values

    # Obtain complete configuration.
    if custom_config is not None:
        complete_config = merge_config(default_config, custom_config)
    else:
        complete_config = default_config

    # Define utility function to obtain full paths.
    # - Only a leading `./` means "base directory"; replacing every `./`
    #   occurrence would corrupt paths such as `../data` or `a/./b`.
    def get_full_path(path: str) -> str:
        if path.startswith('./'):
            path = os.path.join(base_path, path[2:])
        return os.path.normpath(path)

    # Obtain full paths.
    complete_config['paths']['data'] = get_full_path(
        complete_config['paths']['data'])
    # An empty `additional_data:` entry loads as None; treat it as no paths.
    complete_config['paths']['additional_data'] = [
        get_full_path(path)
        for path in complete_config['paths']['additional_data'] or []
    ]
    complete_config['paths']['database'] = get_full_path(
        complete_config['paths']['database'])
    complete_config['paths']['results'] = get_full_path(
        complete_config['paths']['results'])

    # If not running as main process, set `run_parallel` to False.
    # - Workaround to avoid that subprocesses / workers infinitely spawn further subprocesses / workers.
    if multiprocess.current_process().name != 'MainProcess':
        complete_config['multiprocessing']['run_parallel'] = False

    return complete_config
コード例 #39
0
    def _process_one(self, tile_info):
        """Patch one raw tile, save every patch to disk and return its stats.

        Patches are saved as ``data.<i>.pt`` under
        ``<self.processed_dirpath>/<city>/<number:02d>``. A ``processed_flag``
        file marks a finished tile; on later calls the tile is skipped and its
        saved ``stats.pt`` is reloaded (unless ``self.mask_only``).

        :param tile_info: mapping with at least ``"city"`` and ``"number"``.
        :return: dict with ``"class_freq"`` and ``"num"`` when stats were
            computed; empty when ``self.mask_only`` is set or the tile yielded
            no patches.
        :raises NotImplementedError: if ``self.patch_size`` is None, or if
            stats are requested for a ``self.gt_type`` other than npy/geojson.
        """
        # Row index for this worker's tqdm bar so concurrent bars do not
        # overlap. Parse the whole numeric suffix of the process name
        # (presumably pool names such as "ForkPoolWorker-12"): taking only the
        # last character mis-numbers workers >= 10 and raises ValueError in
        # "MainProcess". Names without a numeric suffix use row 0.
        process_name = multiprocess.current_process().name
        name_suffix = process_name.rsplit("-", 1)[-1]
        process_id = int(name_suffix) if name_suffix.isdigit() else 0

        # --- Init
        tile_name = IMAGE_NAME_FORMAT.format(city=tile_info["city"],
                                             number=tile_info["number"])
        processed_tile_relative_dirpath = os.path.join(
            tile_info['city'], f"{tile_info['number']:02d}")
        processed_tile_dirpath = os.path.join(self.processed_dirpath,
                                              processed_tile_relative_dirpath)
        processed_flag_filepath = os.path.join(processed_tile_dirpath,
                                               "processed_flag")
        stats_filepath = os.path.join(processed_tile_dirpath, "stats.pt")
        os.makedirs(processed_tile_dirpath, exist_ok=True)
        stats = {}

        # --- Check if tile has been processed already
        if os.path.exists(processed_flag_filepath):
            if not self.mask_only:
                stats = torch.load(stats_filepath)
            return stats

        # --- Read data:
        raw_data = self.load_raw_data(tile_info)

        # --- Patch tiles
        if self.patch_size is not None:
            patch_stride = self.patch_stride if self.patch_stride is not None else self.patch_size
            patch_boundingboxes = image_utils.compute_patch_boundingboxes(
                raw_data["image"].shape[0:2],
                stride=patch_stride,
                patch_res=self.patch_size)
            class_freq_list = []
            for i, bbox in enumerate(
                    tqdm(patch_boundingboxes,
                         desc=f"Patching {tile_name}",
                         leave=False,
                         position=process_id)):
                sample = {
                    "image_filepath": raw_data["image_filepath"],
                    "name":
                    f"{tile_name}.rowmin_{bbox[0]}_colmin_{bbox[1]}_rowmax_{bbox[2]}_colmax_{bbox[3]}",
                    "bbox": bbox,
                    "city": tile_info["city"],
                    "number": tile_info["number"],
                }

                if self.gt_type == "npy" or self.gt_type == "geojson":
                    patch_gt_polygons = polygon_utils.patch_polygons(
                        raw_data["gt_polygons"],
                        minx=bbox[1],
                        miny=bbox[0],
                        maxx=bbox[3],
                        maxy=bbox[2])
                    sample["gt_polygons"] = patch_gt_polygons
                elif self.gt_type == "tif":
                    patch_gt_mask = raw_data["gt_polygons_image"][
                        bbox[0]:bbox[2], bbox[1]:bbox[3], :]
                    sample["gt_polygons_image"] = patch_gt_mask

                sample["image"] = raw_data["image"][bbox[0]:bbox[2],
                                                    bbox[1]:bbox[3], :]

                sample = self.pre_transform(
                    sample
                )  # Needs "image" to infer shape even if mask_only is True
                if self.mask_only:
                    del sample["image"]  # Don't need RGB image anymore

                relative_filepath = os.path.join(
                    processed_tile_relative_dirpath,
                    "data.{:06d}.pt".format(i))
                filepath = os.path.join(self.processed_dirpath,
                                        relative_filepath)
                torch.save(sample, filepath)

                # Compute stats
                if not self.mask_only:
                    if self.gt_type == "npy" or self.gt_type == "geojson":
                        class_freq_list.append(
                            np.mean(sample["gt_polygons_image"], axis=(0, 1)) /
                            255)
                    elif self.gt_type == "mask":
                        raise NotImplementedError("mask class freq")
                    else:
                        raise NotImplementedError(
                            f"gt_type={self.gt_type} not implemented for computing stats"
                        )

            # Aggregate stats
            if not self.mask_only:
                if len(class_freq_list):
                    class_freq_array = np.stack(class_freq_list, axis=0)
                    stats["class_freq"] = np.mean(class_freq_array, axis=0)
                    stats["num"] = len(class_freq_list)
                else:
                    print("Empty tile:", tile_info["city"],
                          tile_info["number"], "polygons:",
                          len(raw_data["gt_polygons"]))
        else:
            # `NotImplemented` is a constant, not an exception class; calling
            # it would raise TypeError instead of the intended error.
            raise NotImplementedError("patch_size is None")

        # Save stats
        if not self.mask_only:
            torch.save(stats, stats_filepath)

        # Mark tile as processed with flag
        pathlib.Path(processed_flag_filepath).touch()

        return stats
コード例 #40
0
def nprint(*args, **kwargs):
    """Print *args* prefixed with the current process's identity, flushed.

    The prefix is the first element of the process's ``_identity``. In the
    main process, where ``_identity`` is empty, the process name is used
    instead. Keyword arguments are forwarded to ``print``. ``end`` and
    ``flush`` apply only to the message; the prefix always ends with
    ``': '`` and output is always flushed.
    """
    proc = mp.current_process()
    prefix = proc._identity[0] if proc._identity else proc.name
    # Drop `end`/`flush` for the prefix call; passing them next to the fixed
    # values would raise "got multiple values for keyword argument".
    prefix_kwargs = {k: v for k, v in kwargs.items() if k not in ('end', 'flush')}
    _orig_print(prefix, flush=True, end=': ', **prefix_kwargs)
    kwargs['flush'] = True
    _orig_print(*args, **kwargs)
コード例 #41
0
 def get_connection(ident):
     '''
     Connect to a resource sharer and request one resource.

     `ident` is an ``(address, key)`` pair.  An authenticated connection to
     `address` is opened, ``(key, pid)`` is sent over it, and the open
     connection is returned for the caller to read the reply from.
     '''
     from .connection import Client
     address, key = ident
     request = (key, os.getpid())
     connection = Client(address, authkey=current_process().authkey)
     connection.send(request)
     return connection