Example #1
def read_or_process_data(path, limit: int = None):
    tmp_train_data_file = "/tmp/train_data.jsonl.gz"
    tmp_meta_file = "/tmp/meta_data.jsonl"
    if not os.path.exists(tmp_train_data_file):
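        # first run: build features from the raw CSVs and cache them as jsonl files under /tmp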

        train_df, features = read_csvs_build_features(path)
        types = {
            name: typ
            for name, typ in train_df.dtypes.to_dict().items()
            if name in features
        }
        numerical_features = [
            name for name, typ in types.items() if typ == float or typ == int
        ]
        # print(numerical_features)
        categorical_features = [
            name for name, typ in types.items() if typ == str
        ]
        # print(categorical_features)
        input_dim = len(numerical_features) + len(categorical_features)

        def dataframe_to_dicts(df):
            data = [row.to_dict() for _, row in df.iterrows()]
            for d in data:
                del d["Date"]
            return data

        train_data_dicts = dataframe_to_dicts(train_df)
        y_train_list = train_df["target"].tolist()
        for d, t in zip(train_data_dicts, y_train_list):
            d["target"] = t
        data_io.write_jsonl(tmp_train_data_file, train_data_dicts)
        data_io.write_jsonl(
            tmp_meta_file,
            [{
                "numerical_features": numerical_features,
                "categorical_features": categorical_features,
            }],
        )
    else:
        print("loading already processed data")
        train_data_dicts = list(
            data_io.read_jsonl(tmp_train_data_file, limit=limit))
        y_train_list = [d["target"] for d in train_data_dicts]
        meta = list(data_io.read_jsonl(tmp_meta_file))[0]
        numerical_features = meta["numerical_features"]
        categorical_features = meta["categorical_features"]
    # features = categorical_features + numerical_features
    return train_data_dicts, categorical_features, numerical_features
Example #2
        def consumer(file):
            num_to_skip = es_client.get_source(index=STATE_INDEX_NAME,
                                               id=file,
                                               doc_type=STATE_TYPE)["line"]
            process_name = multiprocessing.current_process().name
            print("%s is skipping %d lines in file: %s " %
                  (process_name, num_to_skip, file))

            results_g = helpers.streaming_bulk(
                es_client,
                actions=(build_es_action(try_to_process(d),
                                         es_index_name,
                                         es_type,
                                         op_type="index")
                         for d in data_io.read_jsonl(
                             file, limit=limit, num_to_skip=num_to_skip)),
                chunk_size=chunk_size,
                yield_ok=True,
                raise_on_error=False,
                raise_on_exception=False,
            )
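            # count processed docs and periodically persist the current line number via update_state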
            counter = num_to_skip
            for k, (ok, d) in enumerate(results_g):
                counter += 1
                if not ok and "index" in d:
                    print("shit")
                if k % 1000 == 0:
                    update_state(file, {"line": counter})

            update_state(file, {"line": counter})
            if limit is None or counter < limit:
                update_state(file, {"done": True})

            print("%s is done; inserted %d new docs!" %
                  (process_name, counter - num_to_skip))
Example #3
def download_edictos(
    data_dir=f"{os.environ['HOME']}/data/corteconstitucional/edictos", ):
    """
    Needs to be run several times; sometimes it claims that it cannot find downloaded PDFs.
    :param data_dir:
    :return:
    """
    url = "https://www.corteconstitucional.gov.co/secretaria/edictos/"
    download_dir = f"{data_dir}/downloads"
    os.makedirs(download_dir, exist_ok=True)

    wd = build_chrome_driver(download_dir, headless=True)
    hrefs = get_hrefs(url, wd)

    old_file = f"{data_dir}/documents.jsonl"
    found_existing_documents = os.path.isfile(old_file)
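    # if a documents.jsonl already exists, write into a separate *_updated.jsonl and move it over the old file afterwards (see the finally block)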
    if found_existing_documents:
        new_file = old_file.split(".jsonl")[0] + "_updated.jsonl"
        old_docs = list(data_io.read_jsonl(old_file))
    else:
        old_docs = []
        new_file = old_file
    try:
        data_io.write_jsonl(
            new_file, generate_raw_docs(old_docs, hrefs, wd, download_dir))
    except Exception:
        traceback.print_exc()
        print("downloading edictos failed")
    finally:
        if found_existing_documents:
            shutil.move(new_file, old_file)
Example #4
def read_scierc_seqs(
    jsonl_file, process_fun=lambda x: (x["sentences"], x["ner"])
) -> List[List[Tuple[str, str]]]:
    seqs = [
        sent for sentences, ner in (process_fun(d)
                                    for d in data_io.read_jsonl(jsonl_file))
        for sent in build_tagged_scierc_sequences(sentences=sentences, ner=ner)
    ]
    return seqs
Example #5
def dummy_project_user_documents():
    user_id = 33
    if len(list_projects()) == 0:
        create_project(project_name="testproject")
    if len(list_users()) == 1:
        create_user("user", "password", user_id)

    project_id = next(iter(list_projects()))["id"]
    add_all_user_to_project(project_id)
    DATA_DIR = "."
    filename = "sample_docs.jsonl"
    jsonl_file = os.path.join(DATA_DIR, filename)
    docs = read_jsonl(jsonl_file)
    # user_id = [u['id'] for u in list_users() if u['id']!=1][0]
    create_documents(docs, project_id=project_id, user_id=user_id)
Example #6
def build_schema_and_corpus():
    fields = {
        process_field_name(f): TEXT(stored=True, lang="de")
        for f in FIELDS
    }
    schema = Schema(
        aktenzeichen=ID(stored=True),
        **fields,
    )
    file = "BverfG.jsonl.gz"
    data = ({
        process_field_name(k): v
        for k, v in d.items() if k in FIELDS + ["aktenzeichen"]
    } for d in data_io.read_jsonl(file))
    return schema, data
Example #7
def generate_edictos(
        data_dir=f"{os.environ['HOME']}/data/corteconstitucional/edictos",
        limit=None) -> Generator[Edicto, None, None]:

    g = (d for d in data_io.read_jsonl(f"{data_dir}/documents.jsonl")
         if d["href"] not in HREF_TO_EXCLUDE)
    for k, d in enumerate(g):
        if d["href"] in KNOWN_TO_HAVE_NO_EXPEDIENTE:
            continue
        if "pdf" in d:
            pdf_file = f"{data_dir}/downloads/{d['pdf']}".replace(" ", "\ ")
            text = parse_pdf(pdf_file)
            yield from extract_data(d["href"], text)

        elif "html" in d:
            text = html2text.html2text(d["html"])
            yield from extract_data(d["href"], text)
        if limit is not None and k > limit:
            break
Example #8
def scierc_to_postgres(postgres_host, scierc_file: str, from_scratch: bool = False):
    sqlalchemy_base, sqlalchemy_engine = get_sqlalchemy_base_engine(
        host=postgres_host)
    data = data_io.read_jsonl(scierc_file)
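    # use the json-encoded doc_key as the 'id' primary-key column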
    data = ({**{'id': json.dumps(d.pop('doc_key'))}, **d} for d in data)
    data = (d for d in data if isinstance(d['id'], str))
    tables = get_tables_by_reflection(sqlalchemy_base.metadata,
                                      sqlalchemy_engine)
    table_name = 'scierc'
    if table_name in tables:
        table = tables[table_name]
        if from_scratch:
            table.drop(sqlalchemy_engine)
            table = None
    else:
        table = None
    if table is None:
        columns = [Column('id', String, primary_key=True)] + [
            Column(colname, String)
            for colname in ['sentences', 'ner', 'relations', 'clusters']
        ]
        table = Table(table_name,
                      sqlalchemy_base.metadata,
                      *columns,
                      extend_existing=True)
        print('creating table %s' % table.name)
        table.create()

    def update_fun(val, old_row):
        d = {
            k: json.dumps({'annotator_luan': v})
            for k, v in val.items() if k != 'sentences'
        }
        d['sentences'] = json.dumps(val['sentences'])
        return d

    with sqlalchemy_engine.connect() as conn:
        insert_or_update(conn,
                         table, ['sentences', 'ner', 'relations', 'clusters'],
                         data,
                         update_fun=update_fun)
Example #9
def populate_es_parallel_bulk(
    es, files, es_index_name, es_type, limit=None, num_processes=4, chunk_size=500
):
    dicts_g = (d for file in files for d in read_jsonl(file, limit=limit))

    actions_g = (build_es_action(d, es_index_name, es_type) for d in dicts_g)
    results_g = helpers.parallel_bulk(
        es,
        actions_g,
        thread_count=num_processes,
        queue_size=num_processes,
        chunk_size=chunk_size,
        raise_on_exception=False,
        raise_on_error=False,
    )
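    # collect failed actions, ignoring duplicate-document conflicts (HTTP status 409)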
    failed_g = (
        pop_exception(d)
        for ok, d in tqdm(results_g)
        if not ok and d.get("create", {}).get("status", 200) != 409
    )
    data_io.write_jsonl("failed.jsonl", failed_g)
Example #10
        def consumer(file: str):  # called multiple times on a worker
            with sqlalchemy_engine.connect() as conn:
                assert isinstance(file, str)
                num_to_skip = num_lines_already_read(conn, file)
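                # skip the lines that were already inserted in a previous (possibly interrupted) run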
                print("%s skipping %d lines in file %s" %
                      (multiprocessing.current_process(), num_to_skip, file))
                data_g = (fill_missing_with_Nones(d, column_names)
                          for d in data_io.read_jsonl(
                              file, limit=limit, num_to_skip=num_to_skip))

                for count in populate_table(conn,
                                            table,
                                            data_g,
                                            batch_size=batch_size):
                    conn.execute(
                        state_table.update()
                        .values(line=count + num_to_skip)
                        .where(state_table.c.file == file))

                if limit is None or num_lines_already_read(conn, file) < limit:
                    conn.execute(state_table.update().values(done=True).where(
                        state_table.c.file == file))
Example #11
def prepare_manifest(corpora_dir="/content/corpora", limit=None):

    manifest = "manifest.jsonl"
    manifests = list(Path(corpora_dir).rglob("manifest.jsonl.gz"))
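    # spread the overall limit evenly across all found per-corpus manifests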
    limit = round(limit / len(manifests)) if limit is not None else None

    def get_file_name(f):
        if "/" in f:
            o = f.split('/')[-1]
        else:
            o = f
        return o

    g = (
        {
            "audio_filepath":
            f"{str(f).replace(f.name, '')}/mp3/{get_file_name(d['audio_file'])}",  #TODO(tilo): just hack for TEDLIUM!
            "duration": d["duration"],
            "text": d["text"],
        } for f in manifests for d in data_io.read_jsonl(str(f), limit=limit))
    data_io.write_jsonl(manifest, g)
    return manifest
Example #12
        def consumer(file):
            print("%s is doing %s; limit: %d" %
                  (multiprocessing.current_process(), file, limit))

            dicts_g = (d for d in data_io.read_jsonl(file, limit=limit))

            actions_g = (build_es_action(d,
                                         es_index_name,
                                         es_type,
                                         op_type="index") for d in dicts_g)
            results_g = helpers.streaming_bulk(
                es_client,
                actions_g,
                chunk_size=chunk_size,
                yield_ok=True,
                raise_on_error=False,
                raise_on_exception=False,
            )

            failed_g = (pop_exception(d) for ok, d in results_g if not ok)
            data_io.write_jsonl(
                "%s_failed.jsonl" % multiprocessing.current_process(),
                failed_g)
Example #13
        sns.set_style("whitegrid")
        plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)
        chart.set_xticklabels(chart.get_xticklabels(), rotation=45)
        filter_fun_str = (inspect.getsourcelines(filter_fun)[0][0].split(":")
                          [1].strip("\n").strip(","))
        # plt.gcf().subplots_adjust(bottom=0.15,left=0.15, right=0.15)
        plt.tight_layout()
        ax.figure.savefig("boxplot_%s.png" % filter_fun_str)

    # plt.show()
    plt.close()


def get_num_runs(data):
    exp_names = list(set([d["exp_name"] for d in data]))
    num_runs = len([
        d for d in data
        if d["exp_name"] == exp_names[0] and d["split-name"] == "train"
    ])
    return num_runs


if __name__ == "__main__":
    # path = os.environ["HOME"] + "/gunther/data/plato_results/40000_4000"
    # path = os.environ["HOME"] + "/gunther/data/plato_results/5000_500_again"
    # file = "scores_2000traindialogues.jsonl"
    # scoring_runs = list(data_io.read_jsonl("results/40000_4000/results.jsonl"))
    scoring_runs = list(data_io.read_jsonl("results/5000_500/results.jsonl"))

    plot_results(scoring_runs)
Example #14
    data = [{
        "train_size": step["train_size"],
        "f1-micro-spanlevel": step["scores"]["test"]["seqeval-f1"],
        "select_fun": e["select_fun"],
    } for e in experiments for step in e["scores"]]
    df = pd.DataFrame(data=data)
    ax = sns.boxplot(
        ax=ax,
        x="train_size",
        y="f1-micro-spanlevel",
        hue="select_fun",
        data=df,
    )

    df1 = df[df.train_size == df.train_size[0]]
    num_runs = len(
        df1[df1.select_fun == df1.select_fun[0]])  # well I rarely use pandas!

    ax.set_title("conll03-en %s-set scores;  %d runs" % ("test", num_runs))

    ax.figure.savefig(save_dir + "/active_learning_curve.png")

    plt.close()


if __name__ == "__main__":
    folder = "conll03_en_1percent"
    data = data_io.read_jsonl("active_learning/results/%s/scores.jsonl" %
                              folder)
    plot_it(data, save_dir='active_learning/results/%s' % folder)
Example #15
def regex_tokenizer(text, pattern=r"(?u)\b\w\w+\b"):  # pattern stolen from scikit-learn
    return [m.group() for m in re.finditer(pattern, text)]


def text_to_bow(text):
    return get_nrams(regex_tokenizer(text), 1, 3)


if __name__ == "__main__":
    file = "BverfG_juris.jsonl.gz"
    # print(Counter(k for d in data_io.read_jsonl(file) for k in d.keys()))
    p = "Orientierungssatz"
    data = [d[p] for d in data_io.read_jsonl(file) if p in d]
    texts = [" ".join(l) for l in data]
    print("%d texts" % len(texts))

    vectorizer = TfidfVectorizer(
        min_df=3,
        tokenizer=lambda x: x,
        preprocessor=lambda x: x,
        lowercase=False,
        sublinear_tf=False,
        max_features=20000,
        max_df=0.75,
    )
    tf = vectorizer.fit_transform([text_to_bow(text) for text in texts])

    pca = TruncatedSVD(n_components=20, random_state=42)
Example #16
import os
from dataclasses import asdict  # assumes Edicto is a dataclass
from pathlib import Path
from typing import List

from tqdm import tqdm
from util import data_io
from util.util_methods import merge_dicts

from corteconstitucional.parse_edictos import Edicto
from corteconstitucional.parse_proceso_tables import parse_table


def merge_edictos_proceso_tables(
    edictos: List,
    data_path=f"{os.environ['HOME']}/data/corteconstitucional/procesos_tables"
) -> List:
    raw_data = list(
        data_io.read_json(str(file))
        for file in tqdm(Path(data_path).glob("*.json")))
    print("parse tables")
    table_data = (parse_table(d) for d in raw_data)
    exp2table = {t.expediente: t for t in tqdm(table_data)}
    g = (merge_dicts([
        asdict(e), {
            "tables": [asdict(exp2table[exp]) for exp in e.expedientes]
        }
    ]) for e in edictos)
    merged_data = list(g)
    return merged_data


if __name__ == "__main__":
    edictos = [Edicto(**d) for d in data_io.read_jsonl("edictos.jsonl")]
    merged_data = merge_edictos_proceso_tables(edictos)
    data_io.write_jsonl("/tmp/merged_edictos2tables.jsonl", merged_data)
Example #17
def spaced_tokens_and_tokenoffset2charoffset(sentences: List[List[str]]):
    g = [(sent_id, tok) for sent_id, sent in enumerate(sentences) for tok in sent]
    # interleave tokens with single spaces; spaces get a half-step token-id
    spaced_tokens = [
        x
        for tok_id, (sent_id, tok) in enumerate(g)
        for x in [(tok, tok_id), (' ', tok_id + 0.5)]
    ]
    tok2sent_id = {tok_id: sent_id for tok_id, (sent_id, tok) in enumerate(g)}
    char_offsets = numpy.cumsum([0] + [len(x) for x, _ in spaced_tokens])
    tok2charoff = {
        tok_id: char_offsets[i] for i, (tok, tok_id) in enumerate(spaced_tokens)
    }
    return spaced_tokens, tok2charoff, tok2sent_id


def another_span_is_wider(s, spans):
    return any(
        s['start'] >= o['start'] and s['end'] <= o['end'] and s['id'] != o['id']
        for o in spans
    )


def convert_to_doccano(doc):
    spaced_tokens, tok2charoff, tok2sent_id = spaced_tokens_and_tokenoffset2charoffset(
        doc['sentences'])
    if not isinstance(doc['ner'], dict):
        ner = {'luan': doc['ner']}
    else:
        ner = doc['ner']
    spans = build_ner_spans(ner, lambda x: True, tok2charoff)
    # deduplicate spans by character offsets and drop spans contained in wider ones
    spans = list({'%d-%d' % (s['start'], s['end']): s for s in spans}.values())
    spans = [s for s in spans if not another_span_is_wider(s, spans)]
    assert len(spans) > 0
    text = ''.join(s for s, _ in spaced_tokens)
    labels = [[int(s['start']), int(s['end']), s['label']] for s in spans]
    return {'text': text, 'labels': labels}


if __name__ == '__main__':
    file = 'data/processed_data/json/train.json'
    data = list(data_io.read_jsonl(file))[:3]
    doccano_datum = convert_to_doccano(data[0])
    pprint(doccano_datum)
Example #18
    try:
        key = f"{int(d['Año'])}_{int(d['Nro. Edicto'])}"
    except Exception:
        key = None
    return key


def normalize_expediente(e):
    tipo, number = e.split("-")
    while number.startswith("0"):
        number = number[1:]
    return f"{tipo}-{number}"


if __name__ == "__main__":
    g = data_io.read_jsonl("/tmp/merged_edictos2tables.jsonl")
    data = list(g)
    # print(Counter(len(d["tables"]) for d in g))
    df = load_tati_data()
    tati_data = df.to_dict("records")
    key2tati_data = {
        get_key_from_tati_data(d): d
        for d in tati_data if get_key_from_tati_data(d) is not None
    }
    shit_counter = 0
    for d in data:
        year = d["edicto_year"]
        num = int(d["no"])
        key = f"{year}_{num}"
        if key in key2tati_data.keys():
            tati_datum = key2tati_data.pop(key)
Example #19
def multi_eval(algos, LOGS_DIR, num_eval=5, num_workers=12):
    """
    evaluating 12 jobs with 1 worker took: 415.78 seconds
    evaluating 12 jobs with 3 workers took: 154.78 seconds
    evaluating 12 jobs with 6 workers took: 91.88 seconds
    evaluating 12 jobs with 12 workers took: 70.68 seconds

    on gunther one gets a CUDA out-of-memory error with num_workers > 12
    """

    task = PlatoScoreTask(LOGS_DIR=LOGS_DIR)

    jobs = [
        Experiment(
            job_id=get_id(),
            name=build_name(algo, error_sim, two_slots),
            config=build_config(algo, error_sim=error_sim, two_slots=two_slots),
            train_dialogues=td,
            eval_dialogues=1000,
            num_warmup_dialogues=warmupd
        )
        for _ in range(num_eval)
        for error_sim in [False, True]
        for two_slots in [False, True]
        for td in [40000]
        for warmupd in [4000]
        for algo in algos
    ]
    start = time()

    outfile = LOGS_DIR + "/results.jsonl"

    mode = "wb"
    if os.path.isdir(LOGS_DIR):
        results = list(data_io.read_jsonl(outfile))
        done_ids = [e['job_id'] for e in results]
        jobs = [e for e in jobs if e.job_id not in done_ids]
        print('only got %d jobs to do' % len(jobs))
        print([e.job_id for e in jobs])
        mode = "ab"
    else:
        os.makedirs(LOGS_DIR)

    if num_workers > 0:
        num_workers = min(len(jobs), num_workers)
        with WorkerPool(processes=num_workers, task=task, daemons=False) as p:
            processed_jobs = p.process_unordered(jobs)
            data_io.write_jsonl(outfile, processed_jobs, mode=mode)
    else:
        with task as t:
            processed_jobs = [t(job) for job in jobs]
            data_io.write_jsonl(outfile, processed_jobs, mode=mode)

    scoring_runs = list(data_io.read_jsonl(outfile))
    plot_results(scoring_runs, LOGS_DIR)

    print(
        "evaluating %d jobs with %d workers took: %0.2f seconds"
        % (len(jobs), num_workers, time() - start)
    )
Example #20
    speeds = []

    def benchmark_speed_print_and_append(populate_fun, method_name: str,
                                         num_processes: int):
        speed = benchmark_speed(populate_fun)
        speeds.append({
            "method": method_name,
            "speed": speed,
            "num-processes": num_processes
        })
        print("%d processes %s-speed: %0.2f docs per second" %
              (num_processes, method_name, speed))

    speed = benchmark_speed(lambda: populate_es_streaming_bulk(
        build_es_client(),
        (d for d in read_jsonl(files[0], limit=limit)),
        INDEX_NAME,
        TYPE,
    ))
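    # record the single-process streaming-bulk baseline before benchmarking parallel_bulk below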
    speeds.append({"method": "streaming", "speed": speed, "num-processes": 1})
    print("streaming-speed: %0.2f docs per second" % speed)
    give_es_some_time()

    for num_processes in [1, 2, 4]:
        give_es_some_time()
        benchmark_speed_print_and_append(
            populate_fun=lambda: populate_es_parallel_bulk(
                build_es_client(),
                [files[0]],
                INDEX_NAME,
                TYPE,
Example #21
from util import data_io

from doccano_api import purge, create_project, create_user
admin_id = 1
if __name__ == '__main__':
    purge()
    for a in data_io.read_jsonl('annotators.jsonl'):
        create_user(a['name'], a['password'], a['id'])
        create_project(project_name=a['name'] + '_project',
                       users=[admin_id, a['id']])
Example #22
        yield doc


if __name__ == "__main__":

    INDEX_NAME = "sample-index"
    TYPE = "document"
    host = 'localhost'  # or somewhere else!
    es = Elasticsearch(hosts=[{"host": host, "port": 9200}])
    es.indices.delete(index=INDEX_NAME, ignore=[400, 404])
    es.indices.create(index=INDEX_NAME, ignore=400)

    path = '.'
    file_names = [
        file_name for file_name in os.listdir(path)
        if file_name.endswith('.jsonl')
    ]
    dicts_g = (d for file_name in file_names
               for d in read_jsonl(path + '/' + file_name))

    actions = es_actions_generator(dicts_g)
    bulk(es, actions)

    sleep(3)
    count = es.count(index=INDEX_NAME,
                     doc_type=TYPE,
                     body={"query": {
                         "match_all": {}
                     }})['count']
    print("you've got an es-index of %d documents" % count)
Example #23
            "Sentencia":
            fix_sentencia(datum["sentencia"]),
            FECHA_RADICACION:
            reformat_date(datum[radicacion]),
            FECHA_DECISION:
            datum["sentencia_date"],
            FIJACION_EDICTO:
            reformat_date(datum["edicto_date"]),
            ACUMULADA:
            reformat_date(datum[ACUMULADA][0]) if ACUMULADA in datum else None
        }
        yield tati_datum


def build_dataframe(merged_data: List) -> DataFrame:
    rows = (manual_patch(r) for d in merged_data
            for r in flatten_expedientes(d) if r[ANO] >= 2015)
    # pprint(Counter(f"{r['no']}_{r['edicto_year']}" for r in rows))
    df = pandas.DataFrame(rows)
    return df


if __name__ == '__main__':
    merged_data = list(data_io.read_jsonl("/tmp/merged_edictos2tables.jsonl"))
    df = build_dataframe(merged_data)
    df = df.sort_values([ANO, NO_EDICTO], ascending=[True, True])
    consecutive_edicto_no_step = df[df[NO_EDICTO].diff() > 1]
    for _, d in consecutive_edicto_no_step.iterrows():
        print(f"missing: year: {d[ANO]}; no: {d[NO_EDICTO]-1}")
    df.to_csv(f"tilo_table.csv", sep="\t", index=False)
Example #24
import dash
import dash_auth
import dash_bootstrap_components as dbc
from flask import Flask
from util import data_io

VALID_USERNAME_PASSWORD_PAIRS = {
    d["login"]: d["password"] for d in data_io.read_jsonl("credentials.jsonl")
}
server = Flask(__name__)
app = dash.Dash(
    __name__,
    server=server,
    suppress_callback_exceptions=True,
    external_stylesheets=[dbc.themes.BOOTSTRAP],
)
auth = dash_auth.BasicAuth(app, VALID_USERNAME_PASSWORD_PAIRS)
Example #25
def not_found_generator():
    for d in tqdm(data_io.read_jsonl(file)):
        body = build_body(d)
        r = es_client.search(index=INDEX, body=body, size=3)
        if r['hits']['total']['value'] < 1:
            yield d
Example #26
def read_scierc_data_to_FlairSentences(jsonl_file: str) -> Dataset:
    dataset: Dataset = [
        sent for d in data_io.read_jsonl(jsonl_file)
        for sent in build_flair_sentences(d)
    ]
    return dataset
Example #27
from util import data_io
import networkx as nx
from matplotlib import pyplot as plt

if __name__ == '__main__':
    graph = nx.DiGraph()
    data = [d['a'] for d in data_io.read_jsonl('datasample.json')]

    nodes = [node for d in data for node in d['nodes']]
    for n in nodes:
        graph.add_node(n['id'], **n)

    for n in [r for d in data for r in d['rels']]:
        graph.add_edge(n['start']['id'], n['end']['id'], **n)

    plt.figure(figsize=(50, 50))
    pos = nx.drawing.layout.spring_layout(graph)
    labels = {d['id']: d['properties']['EntityName'] for d in nodes}
    nx.draw_networkx(graph, pos=pos, labels=labels, font_size=9)
    plt.savefig("graph.png")
    plt.show()

    print()