from typing import List, Optional import numpy as np import pandas as pd from cascade_at.executor.args.arg_utils import ArgumentList from cascade_at.executor.args.args import ModelVersionID, BoolArg, ListArg, StrArg, LogLevel from cascade_at.context.model_context import Context from cascade_at.core.log import get_loggers, LEVELS from cascade_at.dismod.api.dismod_io import DismodIO LOG = get_loggers(__name__) ARG_LIST = ArgumentList([ ModelVersionID(), ListArg('--locations', help='The locations to pull mulcov statistics from', type=int, required=True), ListArg('--sexes', help='The sexes to pull mulcov statistics from', type=int, required=True), StrArg('--outfile-name', help='Filepath where mulcov statistics will be saved', required=False, default='mulcov_stats'), BoolArg('--sample', help='If true, the results will be pulled from the sample table rather' 'than the fit_var table'), BoolArg('--mean', help='Whether or not to compute the mean'), BoolArg('--std', help='Whether or not to compute the standard deviation'), ListArg('--quantile', help='Quantiles to compute', type=float), LogLevel() ]) def common_covariate_names(dbs): return set.intersection( *map(set, [d.covariate.c_covariate_name.tolist() for d in dbs])
import logging
import os
import sys

from cascade_at.executor.args.arg_utils import ArgumentList
from cascade_at.executor.args.args import ModelVersionID, LogLevel
from cascade_at.context.model_context import Context
from cascade_at.core.log import get_loggers, LEVELS

LOG = get_loggers(__name__)

ARG_LIST = ArgumentList([ModelVersionID(), LogLevel()])


def cleanup(model_version_id: int) -> None:
    """
    Delete all database (.db) files under a model version's database directory.

    The directory tree rooted at ``Context(model_version_id).database_dir`` is
    walked recursively; every file whose name ends in ``.db`` is logged and
    removed. Other files and the directories themselves are left in place.

    Parameters
    ----------
    model_version_id
        The model version ID to delete databases for
    """
    context = Context(model_version_id=model_version_id)
    for root, _dirs, files in os.walk(context.database_dir):
        for f in files:
            if f.endswith(".db"):
                # os.walk already yields ``root`` prefixed with database_dir,
                # so join on ``root`` alone; prepending database_dir again
                # doubles the prefix whenever database_dir is a relative path.
                file = os.path.join(root, f)
                LOG.info(f"Deleting {file}.")
                os.remove(file)
def test_model_version_id():
    """ModelVersionID should register a required, int-typed --model-version-id flag."""
    arg = ModelVersionID()
    assert arg._flag == '--model-version-id'
    parser_kwargs = arg._parser_kwargs
    assert parser_kwargs['type'] == int
    assert parser_kwargs['required']
def test_argument_list_task_args():
    """An ArgumentList of IntArg('--foo') and ModelVersionID() splits task vs node args.

    'model_version_id' should be the only task arg and 'foo' the only node arg.
    """
    foo_arg = IntArg('--foo')
    arg_list = ArgumentList([foo_arg, ModelVersionID()])
    assert arg_list.task_args == ['model_version_id']
    assert arg_list.node_args == ['foo']