예제 #1
0
def model_to_tests(model_name, upload=True, bucket=EMMAA_BUCKET_NAME):
    """Build a corpus of StatementCheckingTests from a model's statements.

    Parameters
    ----------
    model_name :
        Name of the model; used to fetch its assembled statements and its
        config, and embedded in the S3 key of the saved tests.
    upload :
        When truthy, the resulting dictionary is saved to S3 as a pickle
        under ``tests/{model_name}_tests_{date_str}.pkl``.
    bucket :
        Bucket used both for loading the inputs and for saving the tests.

    Returns
    -------
    dict
        ``{'test_data': {'description': ..., 'name': ...}, 'tests': [...]}``
        where ``tests`` holds one StatementCheckingTest per retained
        statement.
    """
    statements, _ = get_assembled_statements(model_name, bucket=bucket)
    config = load_config_from_s3(model_name, bucket=bucket)

    # Metadata filtering only applies when the config stores a dict under
    # 'make_tests'; any other value leaves the statements untouched.
    if isinstance(config.get('make_tests'), dict):
        filter_spec = config['make_tests']['filter']
        conditions = filter_spec['conditions']
        evid_policy = filter_spec['evid_policy']
        statements = filter_indra_stmts_by_metadata(
            statements, conditions, evid_policy)

    # Keep only statements whose agent list has no falsy entries.
    tests = []
    for statement in statements:
        if all(statement.agent_list()):
            tests.append(StatementCheckingTest(statement))

    date_str = make_date_str()
    readable_name = config.get("human_readable_name")
    # Only the first ten characters of the date string go in the description.
    description = (
        f'These tests were generated from the '
        f'{readable_name} on {date_str[:10]}')
    corpus_name = f'{readable_name} model test corpus'

    test_data = {'description': description, 'name': corpus_name}
    test_dict = {'test_data': test_data, 'tests': tests}

    if upload:
        s3_key = f'tests/{model_name}_tests_{date_str}.pkl'
        save_tests_to_s3(test_dict, bucket, s3_key, 'pkl')
    return test_dict
예제 #2
0
파일: test_s3.py 프로젝트: indralab/emmaa
def test_get_assembled_stmts():
    """Check that the 'test' model yields exactly two Activation statements."""
    # Local imports are recommended when using moto
    from emmaa.model import get_assembled_statements

    # Called for its side effect; the returned client is not needed here.
    # NOTE(review): presumably this creates the mocked test bucket - confirm
    # against setup_bucket.
    setup_bucket(add_mm=True)

    # The second element of the returned pair is not checked by this test.
    stmts, _ = get_assembled_statements('test', bucket=TEST_BUCKET_NAME)

    assert len(stmts) == 2, stmts
    assert all(isinstance(stmt, Activation) for stmt in stmts)
예제 #3
0
파일: api.py 프로젝트: kolusask/emmaa
def _load_stmts_from_cache(model):
    """Return assembled statements for ``model``, refreshing a stale cache.

    The cached entry is reused only when its file key matches the latest
    ``assembled/{model}/statements_*.json`` key reported for the EMMAA
    bucket; otherwise the statements are reloaded and the cache entry is
    replaced.

    NOTE(review): with no cached entry and an S3 lookup that yields None,
    both keys compare equal and None is returned - confirm this is intended.
    """
    cached_stmts, cached_key = stmts_cache.get(model, (None, None))
    newest_key = find_latest_s3_file(
        EMMAA_BUCKET_NAME, f'assembled/{model}/statements_', '.json')

    if cached_key != newest_key:
        # Cache is missing or out of date: reload and remember the new key.
        fresh_stmts, fresh_key = get_assembled_statements(
            model, EMMAA_BUCKET_NAME)
        stmts_cache[model] = (fresh_stmts, fresh_key)
        return fresh_stmts

    logger.info(f'Loaded assembled stmts for {model} from cache.')
    return cached_stmts
예제 #4
0
def load_model_manager_from_s3(model_name=None,
                               key=None,
                               bucket=EMMAA_BUCKET_NAME):
    """Load a model manager from S3 by explicit key or by model name.

    Parameters
    ----------
    model_name :
        Name of the model. Used to look up the latest model manager key when
        ``key`` is not given, and for the statements-based fallback.
    key :
        S3 key of a pickled model manager. When given it takes precedence
        over ``model_name``.
    bucket :
        Bucket to load from.

    Returns
    -------
    The loaded model manager, or None when neither ``key`` nor
    ``model_name`` is given or when both the direct load and the
    statements-based fallback fail.
    """
    if not key:
        if not model_name:
            # Could not find either from key or from model name.
            logger.info('Could not find the model manager.')
            return None
        # Prefer the latest versioned key; fall back to the non-versioned
        # one when no versioned file is found.
        latest_key = find_latest_s3_file(
            bucket, f'results/{model_name}/model_manager_', '.pkl')
        if latest_key is None:
            latest_key = f'results/{model_name}/latest_model_manager.pkl'
        return load_model_manager_from_s3(model_name=model_name,
                                          key=latest_key,
                                          bucket=bucket)

    # A key is available: first try to load the pickle directly.
    try:
        manager = load_pickle_from_s3(bucket, key)
        if not manager.model.assembled_stmts:
            # The pickle carries no assembled statements, so fetch them
            # separately for the manager's own date.
            name = manager.model.name
            manager_date = strip_out_date(manager.date_str, 'date')
            stmts, _ = get_assembled_statements(name, manager_date,
                                                bucket=bucket)
            manager.model.assembled_stmts = stmts
        return manager
    except Exception as load_err:
        logger.info('Could not load the model manager directly')
        logger.info(load_err)

    # Direct load failed: rebuild the manager from assembled statements.
    if not model_name:
        # The model name is taken from the second path segment of the key.
        model_name = key.split('/')[1]
    date = strip_out_date(key, 'date')
    logger.info('Trying to load model manager from statements')
    try:
        return ModelManager.load_from_statements(
            model_name, date=date, bucket=bucket)
    except Exception as rebuild_err:
        logger.info('Could not load the model manager from '
                    'statements')
        logger.info(rebuild_err)
        return None
예제 #5
0
def model_to_tests(model_name, upload=True, bucket=EMMAA_BUCKET_NAME):
    """Turn a model's assembled statements into StatementCheckingTests.

    Parameters
    ----------
    model_name :
        Name of the model; used to fetch its assembled statements and its
        config, and embedded in the S3 key of the saved tests.
    upload :
        When truthy, the resulting dictionary is saved to S3 as a pickle
        under ``tests/{model_name}_tests_{date_str}.pkl``.
    bucket :
        Bucket used both for loading the inputs and for saving the tests.

    Returns
    -------
    dict
        ``{'test_data': {'description': ..., 'name': ...}, 'tests': [...]}``
    """
    stmts, _ = get_assembled_statements(model_name, bucket=bucket)
    config = load_config_from_s3(model_name, bucket=bucket)

    # Statements with a falsy entry in their agent list are left out.
    usable_stmts = (stmt for stmt in stmts if all(stmt.agent_list()))
    tests = [StatementCheckingTest(stmt) for stmt in usable_stmts]

    date_str = make_date_str()
    readable_name = config.get("human_readable_name")
    # Only the first ten characters of the date string go in the description.
    test_description = (
        f'These tests were generated from the '
        f'{readable_name} on {date_str[:10]}')
    test_name = f'{readable_name} model test corpus'

    test_dict = {
        'test_data': {
            'description': test_description,
            'name': test_name,
        },
        'tests': tests,
    }
    if upload:
        s3_key = f'tests/{model_name}_tests_{date_str}.pkl'
        save_tests_to_s3(test_dict, bucket, s3_key, 'pkl')
    return test_dict
예제 #6
0
 def load_from_statements(cls, model_name, mode='local', date=None,
                          bucket=EMMAA_BUCKET_NAME):
     """Build an instance of ``cls`` from already assembled statements.

     Parameters
     ----------
     model_name :
         Name of the model whose config, paper ids and assembled statements
         are loaded from S3.
     mode :
         Passed through unchanged to the ``cls`` constructor.
     date :
         Optional date used to narrow the paper-ids key prefix and passed to
         the assembled statements lookup.
     bucket :
         Bucket to load from.

     Returns
     -------
     The object returned by ``cls(model, mode=mode)``.
     """
     config = load_config_from_s3(model_name, bucket=bucket)

     # With a date, the prefix targets paper ids files for that date.
     paper_prefix = f'papers/{model_name}/paper_ids_'
     if date:
         paper_prefix = f'papers/{model_name}/paper_ids_{date}'
     paper_key = find_latest_s3_file(bucket, paper_prefix, 'json')
     paper_ids = load_json_from_s3(bucket, paper_key) if paper_key else None

     model = EmmaaModel(model_name, config, paper_ids)

     # Loading assembled statements to avoid reassembly
     assembled, stmts_key = get_assembled_statements(model_name, date, bucket)
     model.assembled_stmts = assembled
     # The model's date string is derived from the statements file key.
     model.date_str = strip_out_date(stmts_key, 'datetime')
     return cls(model, mode=mode)
예제 #7
0
    parser.add_argument('-cd',
                        '--chemical_disease',
                        help='Path to ctd chemical disease statements pkl',
                        required=True)
    parser.add_argument('-cg',
                        '--chemical_gene',
                        help='Path to ctd chemical gene statements pkl',
                        required=True)
    parser.add_argument('-gd',
                        '--gene_disease',
                        help='Path to ctd gene disease statements pkl',
                        required=True)
    args = parser.parse_args()

    # Load model statements and tests
    model_stmts, _ = get_assembled_statements('covid19')
    curated_tests, _ = load_tests_from_s3('covid19_curated_tests')
    if isinstance(curated_tests, dict):  # if descriptions were added
        curated_tests = curated_tests['tests']
    mitre_tests, _ = load_tests_from_s3('covid19_mitre_tests')
    if isinstance(mitre_tests, dict):  # if descriptions were added
        mitre_tests = mitre_tests['tests']
    all_test_stmts = [test.stmt for test in curated_tests] + \
        [test.stmt for test in mitre_tests]

    # Load CTD statements
    chem_dis_stmts = ac.load_statements(args.chemical_disease)
    chem_gene_stmts = ac.load_statements(args.chemical_gene)
    gene_dis_stmts = ac.load_statements(args.gene_disease)
    all_ctd_stmts = chem_dis_stmts + chem_gene_stmts + gene_dis_stmts