예제 #1
0
from typing import List, Optional

import numpy as np
import pandas as pd

from cascade_at.executor.args.arg_utils import ArgumentList
from cascade_at.executor.args.args import ModelVersionID, BoolArg, ListArg, StrArg, LogLevel
from cascade_at.context.model_context import Context
from cascade_at.core.log import get_loggers, LEVELS
from cascade_at.dismod.api.dismod_io import DismodIO

LOG = get_loggers(__name__)


# Command-line arguments accepted by this script. The --sample help text is
# built from two adjacent string literals that Python concatenates; the
# trailing space on the first fragment keeps "rather" and "than" from
# running together in the rendered help.
ARG_LIST = ArgumentList([
    ModelVersionID(),
    ListArg('--locations', help='The locations to pull mulcov statistics from', type=int, required=True),
    ListArg('--sexes', help='The sexes to pull mulcov statistics from', type=int, required=True),
    StrArg('--outfile-name', help='Filepath where mulcov statistics will be saved', required=False, default='mulcov_stats'),
    BoolArg('--sample', help='If true, the results will be pulled from the sample table rather '
                             'than the fit_var table'),
    BoolArg('--mean', help='Whether or not to compute the mean'),
    BoolArg('--std', help='Whether or not to compute the standard deviation'),
    ListArg('--quantile', help='Quantiles to compute', type=float),
    LogLevel()
])


def common_covariate_names(dbs):
    """
    Return the covariate names shared by every database in ``dbs``.

    Parameters
    ----------
    dbs
        Iterable of database objects. Each must expose
        ``covariate.c_covariate_name`` supporting ``tolist()`` (presumably
        a pandas column of a DismodIO covariate table -- confirm with callers).

    Returns
    -------
    set
        Intersection of the per-database covariate-name sets. An empty set
        is returned when ``dbs`` is empty, rather than the ``TypeError``
        that ``set.intersection()`` raises when called with no arguments.
    """
    # Materialize once so one-shot iterables work and emptiness can be checked.
    name_sets = [set(d.covariate.c_covariate_name.tolist()) for d in dbs]
    if not name_sets:
        return set()
    return set.intersection(*name_sets)
예제 #2
0
import logging
import os
import sys

from cascade_at.executor.args.arg_utils import ArgumentList
from cascade_at.executor.args.args import ModelVersionID, LogLevel
from cascade_at.context.model_context import Context
from cascade_at.core.log import get_loggers, LEVELS

LOG = get_loggers(__name__)

# Only two arguments are accepted here: the model version ID and the log level.
ARG_LIST = ArgumentList([
    ModelVersionID(),
    LogLevel(),
])


def cleanup(model_version_id: int) -> None:
    """
    Delete all database (.db) files attached to a model version.

    Recursively walks the model version's ``database_dir`` and removes every
    file whose name ends with ``.db``, logging each deletion.

    Parameters
    ----------
    model_version_id
        The model version ID to delete databases for
    """
    context = Context(model_version_id=model_version_id)

    for root, _dirs, files in os.walk(context.database_dir):
        for name in files:
            if not name.endswith(".db"):
                continue
            # ``root`` yielded by os.walk already starts with database_dir, so
            # join it with the file name directly. Prefixing database_dir again
            # would double the path whenever database_dir is relative.
            path = os.path.join(root, name)
            LOG.info(f"Deleting {path}.")
            os.remove(path)
예제 #3
0
def test_model_version_id():
    """ModelVersionID registers a required ``--model-version-id`` flag of type int."""
    arg = ModelVersionID()
    assert arg._flag == '--model-version-id'
    parser_kwargs = arg._parser_kwargs
    assert parser_kwargs['type'] == int
    assert parser_kwargs['required']
예제 #4
0
def test_argument_list_task_args():
    """An ArgumentList of IntArg('--foo') and ModelVersionID splits into task and node args."""
    foo_arg = IntArg('--foo')
    arg_list = ArgumentList([foo_arg, ModelVersionID()])
    expected_task_args, expected_node_args = ['model_version_id'], ['foo']
    assert arg_list.task_args == expected_task_args
    assert arg_list.node_args == expected_node_args