def __init__(self, api_info):
    self.api = faw_pipelines_util.Api(api_info)
    with self.api.task_file_read('model.chkpt', taskname='learn') as f:
        buf = io.BytesIO(f.read())
        self._data = torch.load(buf)
    self._lock = threading.Lock()

    # This happens on a dask worker -- aka, needs a path adjustment
    path = os.path.dirname(os.path.abspath(__file__))
    sys.path.insert(0, path)
    from . import model as model_module

    self._model = model_module.Model.argparse_create(
            self._data['model_args'])
    self._model.build()
    self._model.load_state_dict(self._data['model']['state_dict'])
    self._model.state_restore(
            pickle.loads(self._data['model']['model_extra']))
    for optim, state_dict in zip([self._model._optim],
            self._data['model']['optims']):
        optim.load_state_dict(state_dict)
    self._model.eval()

    self._lastuse = time.monotonic()
    self._queue_todo = asyncio.Queue()
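
# Hypothetical sketch (not in the original source): a consumer for
# self._queue_todo on the same class, assuming each queued item is a
# (payload, asyncio.Future) pair and that self._model can be applied
# directly to an already-prepared payload. The method name, the queue-item
# shape, and the call convention are assumptions for illustration only.
async def _worker(self):
    while True:
        payload, fut = await self._queue_todo.get()
        self._lastuse = time.monotonic()
        try:
            with self._lock:
                with torch.no_grad():
                    result = self._model(payload)
            fut.set_result(result)
        except Exception as exc:
            fut.set_exception(exc)
        finally:
            self._queue_todo.task_done()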
def main(api_info, cmd_args):
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument('input_file')
    args = ap.parse_args(cmd_args)

    api = faw_pipelines_util.Api(api_info)
    model = CachedParser.get(api, 'model.chkpt', taskname='learn')

    with open(args.input_file, 'rb') as f:
        fdata = f.read(model.get_header_bytes().result())

    def bytes_to_str(b):
        return json.dumps(b.decode('latin1'))[1:-1]

    rules = collections.defaultdict(int)
    for i in range(len(fdata)):
        # range(n gram size)
        for j in range(4):
            if i + 1 + j > len(fdata):
                continue
            rules[bytes_to_str(fdata[i:i+j+1])] += 1
    for k, v in rules.items():
        print(f'bytes: {k}: {v}')
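
# Illustrative only (not part of the plugin): the same 1..4-byte n-gram
# counting as in main() above, run stand-alone on a tiny input to show what
# the emitted keys look like. `count_ngrams` is a hypothetical helper name.
import collections
import json

def count_ngrams(data: bytes, max_n: int = 4) -> dict:
    def bytes_to_str(b):
        return json.dumps(b.decode('latin1'))[1:-1]
    counts = collections.defaultdict(int)
    for i in range(len(data)):
        for j in range(max_n):
            if i + 1 + j > len(data):
                continue
            counts[bytes_to_str(data[i:i+j+1])] += 1
    return dict(counts)

if __name__ == '__main__':
    # For b'%PDF' this yields one count each for '%', '%P', '%PD', '%PDF',
    # 'P', 'PD', 'PDF', 'D', 'DF', and 'F'.
    print(count_ngrams(b'%PDF'))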
async def _pipeline_spawn_task(mongodb_conn, dask_client, task_cfg,
        tasks_downstream, task_api_info, old_future_info, new_future_info_cb):
    """Ensure that a given task is either up-to-date or running.
    """
    # Figure out the correct version of this task. There are two versions:
    # the config file version, and the data version. Each pipeline task depends
    # on its own config file version and the data versions of its dependents.
    # Furthermore, a task must only run if its dependents are up to date.
    api = faw_pipelines_util.Api(task_api_info, mongodb_conn)

    should_run = True
    last_task_status = await api._internal_task_status_get_state()
    if last_task_status.disabled:
        should_run = False
    else:
        for dep in task_cfg['dependsOn']:
            task_status = await api._internal_task_status_get_state(taskname=dep)
            if not task_status.done:
                should_run = False

    run_version = task_cfg['version']
    task_api_info['task_version'] = run_version

    # Abort if out of date
    if not should_run:
        if old_future_info is not None:
            await old_future_info['future'].cancel()
        return
    elif run_version != last_task_status.version:
        # Running an update on an old version OR not running at all
        if old_future_info is not None:
            await old_future_info['future'].cancel()

            # Wait 1s -- reasonable time to expect the process to have been
            # killed.
            await asyncio.sleep(1)

        # Clear database + set version + unset done flag.
        await api.destructive__task_change_version(run_version,
                tasks_downstream)
    else:
        # DB is on correct version; is it up to date (finished)?
        if last_task_status.done:
            # Done
            return

        # Not done -- ensure it's running
        if old_future_info is not None and not old_future_info['future'].done():
            # Keep the reference
            new_future_info_cb(old_future_info)
            return

    # If we reach here, we want to launch a new task runner
    if task_api_info['task'] == 'internal--faw-final-reprocess-db':
        # Special case -- reprocess DB, set done
        tools_to_reset = []
        for k, v in _config_base['parsers'].items():
            if v.get('pipeline') == task_api_info['pipeline']:
                # This parser needs to be recomputed when we finish
                tools_to_reset.append(k)
        await _reparse_db_fn(tools_to_reset=tools_to_reset)
        await api._internal_task_status_set_completed()
        return

    # Normal case -- spawn a dask task which spawns the FAW task.
    future = dask_client.submit(
            lambda: _pipeline_task_run(task_cfg, task_api_info),
            pure=False)
    future_info = {
        'future': future,
    }
    new_future_info_cb(future_info)
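
# Hypothetical caller sketch (not in the original source): a supervisor loop
# that periodically re-evaluates every pipeline task. `_pipeline_supervise`,
# `pipeline_tasks`, and `futures` are illustrative names; only
# `_pipeline_spawn_task` comes from the code above.
async def _pipeline_supervise(mongodb_conn, dask_client, pipeline_tasks):
    # pipeline_tasks: {name: (task_cfg, tasks_downstream, task_api_info)}
    futures = {}  # name -> future_info dict, as handed to the callback
    while True:
        for name, (task_cfg, tasks_downstream, task_api_info) in pipeline_tasks.items():
            def _store(future_info, name=name):
                futures[name] = future_info
            await _pipeline_spawn_task(
                    mongodb_conn, dask_client, task_cfg, tasks_downstream,
                    task_api_info, futures.get(name), _store)
        await asyncio.sleep(5)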