Example #1
File: dask_util.py Project: GaloisInc/FAW
    def __init__(self, api_info):
        self.api = faw_pipelines_util.Api(api_info)
        # Read the checkpoint written by the 'learn' task into memory, then deserialize it
        with self.api.task_file_read('model.chkpt', taskname='learn') as f:
            buf = io.BytesIO(f.read())
        self._data = torch.load(buf)
        self._lock = threading.Lock()

        # This happens on a dask worker -- i.e., the import path needs adjusting
        path = os.path.dirname(os.path.abspath(__file__))
        sys.path.insert(0, path)

        # Rebuild the model from its saved arguments, then restore its weights,
        # extra state, and optimizer state from the checkpoint
        from . import model as model_module
        self._model = model_module.Model.argparse_create(
            self._data['model_args'])
        self._model.build()
        self._model.load_state_dict(self._data['model']['state_dict'])
        self._model.state_restore(
            pickle.loads(self._data['model']['model_extra']))
        for optim, state_dict in zip([self._model._optim],
                                     self._data['model']['optims']):
            optim.load_state_dict(state_dict)

        self._model.eval()
        self._lastuse = time.monotonic()
        self._queue_todo = asyncio.Queue()
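
The loader above implies a specific checkpoint layout: a top-level dict with 'model_args', plus a 'model' dict holding 'state_dict', 'model_extra' (a pickled blob consumed by state_restore), and a list of optimizer state dicts under 'optims'. The writer is not shown on this page; the following is a minimal sketch of what the 'learn' task might serialize, assuming a hypothetical api.task_file_write counterpart to task_file_read and a model.state_save() method that mirrors state_restore().

import io
import pickle

import torch


def save_checkpoint(api, model, optims, model_args):
    # Sketch only: task_file_write and model.state_save() are assumed, not confirmed FAW/model APIs
    data = {
        'model_args': model_args,  # arguments consumed by Model.argparse_create
        'model': {
            'state_dict': model.state_dict(),
            'model_extra': pickle.dumps(model.state_save()),  # assumed inverse of state_restore()
            'optims': [o.state_dict() for o in optims],
        },
    }
    buf = io.BytesIO()
    torch.save(data, buf)
    with api.task_file_write('model.chkpt', taskname='learn') as f:
        f.write(buf.getvalue())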
Example #2
def main(api_info, cmd_args):
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument('input_file')
    args = ap.parse_args(cmd_args)

    api = faw_pipelines_util.Api(api_info)
    model = CachedParser.get(api, 'model.chkpt', taskname='learn')

    with open(args.input_file, 'rb') as f:
        fdata = f.read(model.get_header_bytes().result())

    def bytes_to_str(b):
        # JSON-escape the latin1-decoded bytes so non-printable values stay readable
        return json.dumps(b.decode('latin1'))[1:-1]

    rules = collections.defaultdict(int)
    for i in range(len(fdata)):
        # n-gram sizes 1 through 4
        for j in range(4):
            if i + 1 + j > len(fdata):
                continue

            rules[bytes_to_str(fdata[i:i+j+1])] += 1
    for k, v in rules.items():
        print(f'bytes: {k}: {v}')
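
The counting loop itself is independent of the FAW plumbing. Below is a standalone sketch of the same byte n-gram tally (sizes 1 through 4) over an in-memory buffer, with the same JSON-style escaping; the function name and sample input are illustrative only.

import collections
import json


def count_ngrams(data: bytes, max_n: int = 4):
    # Count byte n-grams of length 1..max_n, keyed by their escaped text form
    counts = collections.defaultdict(int)
    for i in range(len(data)):
        for n in range(1, max_n + 1):
            if i + n > len(data):
                break
            key = json.dumps(data[i:i + n].decode('latin1'))[1:-1]
            counts[key] += 1
    return counts


if __name__ == '__main__':
    for k, v in count_ngrams(b'%PDF-1.7\n').items():
        print(f'bytes: {k}: {v}')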
Example #3
async def _pipeline_spawn_task(mongodb_conn, dask_client, task_cfg,
        tasks_downstream, task_api_info, old_future_info, new_future_info_cb):
    """Ensure that a given task is either up-to-date or running.
    """
    # Figure out the correct version of this task. There are two versions:
    # the config file version and the data version. Each pipeline task depends
    # on its own config file version and the data versions of its dependencies.
    # Furthermore, a task must only run if its dependencies are up to date.
    api = faw_pipelines_util.Api(task_api_info, mongodb_conn)

    should_run = True
    last_task_status = await api._internal_task_status_get_state()

    if last_task_status.disabled:
        should_run = False
    else:
        for dep in task_cfg['dependsOn']:
            task_status = await api._internal_task_status_get_state(taskname=dep)
            if not task_status.done:
                should_run = False

    run_version = task_cfg['version']
    task_api_info['task_version'] = run_version

    # Abort if out of date
    if not should_run:
        if old_future_info is not None:
            await old_future_info['future'].cancel()
        return
    elif run_version != last_task_status.version:
        # Running an update on an old version OR not running at all
        if old_future_info is not None:
            await old_future_info['future'].cancel()

            # Wait 1s -- reasonable time to expect the process to have been
            # killed.
            await asyncio.sleep(1)

        # Clear database + set version + unset done flag.
        await api.destructive__task_change_version(run_version,
                tasks_downstream)
    else:
        # DB is on correct version; is it up to date (finished)?
        if last_task_status.done:
            # Done
            return

        # Not done -- ensure it's running
        if old_future_info is not None and not old_future_info['future'].done():
            # Keep the reference
            new_future_info_cb(old_future_info)
            return

    # If we reach here, we want to launch a new task runner
    if task_api_info['task'] == 'internal--faw-final-reprocess-db':
        # Special case -- reprocess DB, set done
        tools_to_reset = []
        for k, v in _config_base['parsers'].items():
            if v.get('pipeline') == task_api_info['pipeline']:
                # This parser needs to be recomputed when we finish
                tools_to_reset.append(k)
        await _reparse_db_fn(tools_to_reset=tools_to_reset)
        await api._internal_task_status_set_completed()
        return

    # Normal case -- spawn a dask task which spawns the FAW task.
    future = dask_client.submit(
            lambda: _pipeline_task_run(task_cfg, task_api_info),
            pure=False)

    future_info = {
        'future': future,
    }
    new_future_info_cb(future_info)
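
The branching above boils down to a small decision table: skip when the task is disabled, a dependency is unfinished, or the work is already done; reset and respawn when the config version changed; keep the existing future when the same version is still running; otherwise spawn. The sketch below restates that logic as a pure helper (the names are illustrative, not part of the FAW API) and omits the side effects the real function performs, such as cancelling the old future, clearing the task database, and the reprocess-db special case.

from enum import Enum, auto


class Action(Enum):
    SKIP = auto()             # disabled, dependency unfinished, or already done
    RESET_AND_SPAWN = auto()  # config version changed: clear task data, then run
    KEEP_RUNNING = auto()     # same version, still running: keep the old future
    SPAWN = auto()            # same version, not running: launch a new runner


def decide(disabled, deps_done, config_version, db_version, db_done, old_running):
    if disabled or not deps_done:
        return Action.SKIP
    if config_version != db_version:
        return Action.RESET_AND_SPAWN
    if db_done:
        return Action.SKIP
    if old_running:
        return Action.KEEP_RUNNING
    return Action.SPAWN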