Exemplo n.º 1
0
def test_query_df_schema():
    """Check that ``Agent(df, schema)`` turns a stomach-cancer question into the expected SQL.

    ``df`` and ``schema`` are module-level names defined outside this function.
    The comparison ignores leading/trailing whitespace only.

    Raises:
        AssertionError: if the generated SQL differs from the expected SQL; the
            message shows the question, the predicted SQL and the expected SQL.
    """
    agent = Agent(df, schema)
    q = 'how many people died of stomach cancer in 2011'
    sql = 'SELECT SUM(death_count) FROM cancer_death WHERE cancer_site = "Stomach" AND year = "2011" '
    res = agent.get_query(q)
    # Raise explicitly (rather than ``assert``) so the check is not stripped under
    # ``python -O``, and carry the mismatch details so a failure is diagnosable.
    if res.strip() != sql.strip():
        raise AssertionError(
            f"question: {q!r}\npredicted: {res.strip()!r}\nactual: {sql.strip()!r}"
        )
Exemplo n.º 2
0
def test_query_df():
    """Check that ``Agent(df)`` (no schema) turns an aggregate question into the expected SQL.

    ``df`` is a module-level name defined outside this function. Without a
    schema, the expected SQL targets the table name ``dataframe``. The comparison
    ignores leading/trailing whitespace only.

    Raises:
        AssertionError: if the generated SQL differs from the expected SQL; the
            message shows the question, the predicted SQL and the expected SQL.
    """
    agent = Agent(df)
    # NOTE(review): the question asks for the average *age*, but the expected SQL
    # averages ``death_count`` -- confirm which one the test is meant to pin.
    q = 'Get me the average age of stomach cancer deaths'
    sql = 'SELECT AVG(death_count) FROM dataframe WHERE cancer_site = "Stomach" '
    res = agent.get_query(q)
    # Explicit raise survives ``python -O``; the message makes the mismatch visible.
    if res.strip() != sql.strip():
        raise AssertionError(
            f"question: {q!r}\npredicted: {res.strip()!r}\nactual: {sql.strip()!r}"
        )
Exemplo n.º 3
0
def test_query_df_schema():
    """Verify the SQL a schema-aware Agent generates for a 2011 stomach-cancer question.

    Uses the module-level ``df`` and ``schema`` names. On a mismatch (ignoring
    surrounding whitespace) the question, the predicted SQL and the expected SQL
    are printed, and then ``AssertionError`` is raised.
    """
    question = 'how many people died of stomach cancer in 2011'
    expected = "SELECT SUM(death_count) FROM cancer_death WHERE cancer_site = 'Stomach' AND year = '2011' "
    predicted = Agent(df, schema).get_query(question)

    # Guard clause: nothing more to do when the SQL matches.
    if predicted.strip() == expected.strip():
        return

    report = (("question", question), ("predicted", predicted), ("actual", expected))
    for label, value in report:
        print(label, value)
    raise AssertionError
Exemplo n.º 4
0
def test_query_df():
    """Verify the SQL a schema-less Agent generates for a gender-filtered average-age question.

    Uses the module-level ``df`` name. On a mismatch (ignoring surrounding
    whitespace) the question, the predicted SQL and the expected SQL are
    printed, and then ``AssertionError`` is raised.
    """
    question = 'Get me the average age of stomach cancer in male'
    expected = "SELECT AVG(age) FROM dataframe WHERE gender = 'Male' AND cancer_site = 'Stomach' "
    predicted = Agent(df).get_query(question)

    matches = predicted.strip() == expected.strip()
    if not matches:
        for label, value in (("question", question), ("predicted", predicted), ("actual", expected)):
            print(label, value)
        raise AssertionError
Exemplo n.º 5
0
def test_query_df():
    """Verify the SQL a schema-less Agent generates for an average-of-deaths question.

    Uses the module-level ``df`` name. On a mismatch (ignoring surrounding
    whitespace) the question, the predicted SQL and the expected SQL are
    printed, and then ``AssertionError`` is raised.
    """
    model = Agent(df)
    prompt = 'Get me the average age of stomach cancer deaths'
    want = 'SELECT AVG(death_count) FROM dataframe WHERE cancer_site = "Stomach" '
    got = model.get_query(prompt)

    if got.strip() == want.strip():
        return
    print("question", prompt)
    print("predicted", got)
    print("actual", want)
    raise AssertionError
Exemplo n.º 6
0
def test_query():
    """Check SQL generation for several questions against the ``cleaned_data`` directory.

    ``currpath`` is a module-level name defined outside this function. Every
    question is translated even after an earlier one fails, so a single run
    reports all mismatches. The comparison ignores leading/trailing whitespace
    only.

    Raises:
        AssertionError: if any generated SQL differs from its expected SQL; the
            message lists every mismatching question with its predicted and
            expected SQL.
    """
    agent=Agent(os.path.join(currpath,"cleaned_data"))
    qmaps={'how many nuclear medicine activities in 2012':'SELECT COUNT(activity_type) FROM activities_data WHERE activity_type_chapter = "Nuclear Medicine" AND year = "2012" ',
            'find me the diseases having above 3000 cases':'SELECT disease FROM communicable_diseases_data WHERE cases  > 3000 ',
            'which are the activities in 2011':'SELECT activity_type FROM activities_data WHERE year = "2011" ',
            'find the maximum number of cases':'SELECT MAX(cases) FROM communicable_diseases_data',
            'Get me the average age of stomach cancer deaths':'SELECT AVG(death_count) FROM cancer_death_data WHERE cancer_site = "Stomach" '
            }
    mismatches = []
    for q, sql in qmaps.items():
        res = agent.get_query(q)
        if res.strip() != sql.strip():
            mismatches.append(
                f"question: {q!r}\n  predicted: {res.strip()!r}\n  actual: {sql.strip()!r}"
            )
    # Explicit raise (not ``assert``) so the check is not stripped under ``python -O``.
    if mismatches:
        raise AssertionError(
            f"{len(mismatches)} of {len(qmaps)} queries mismatched:\n" + "\n".join(mismatches)
        )
Exemplo n.º 7
0
def test_query_schema():
    """Check SQL generation for several questions using ``cleaned_data`` plus a ``schema`` directory.

    ``currpath`` is a module-level name defined outside this function. Every
    question is translated even after an earlier one fails, so a single run
    reports all mismatches. The comparison ignores leading/trailing whitespace
    only.

    Raises:
        AssertionError: if any generated SQL differs from its expected SQL; the
            message lists every mismatching question with its predicted and
            expected SQL.
    """
    agent=Agent(os.path.join(currpath,"cleaned_data"),os.path.join(currpath,"schema"))
    qmaps={'how many people died of stomach cancer in 2011':'SELECT SUM(death_count) FROM cancer_death WHERE cancer_site = "Stomach" AND year = "2011" ',
            'how many deaths of age below 40 had stomach cancer':'SELECT SUM(death_count) FROM cancer_death WHERE cancer_site = "Stomach" AND age  < 40 ',
            'which are the activities in 2011':'SELECT activity_type_chapter FROM activities WHERE year = "2011" ',
            'find the maximum number of cases':'SELECT MAX(cases) FROM communicable_diseases ',
            'Get me the average age of stomach cancer deaths':'SELECT AVG(death_count) FROM cancer_death WHERE cancer_site = "Stomach" '
            }
    mismatches = []
    for q, sql in qmaps.items():
        res = agent.get_query(q)
        if res.strip() != sql.strip():
            mismatches.append(
                f"question: {q!r}\n  predicted: {res.strip()!r}\n  actual: {sql.strip()!r}"
            )
    # Explicit raise (not ``assert``) so the check is not stripped under ``python -O``.
    if mismatches:
        raise AssertionError(
            f"{len(mismatches)} of {len(qmaps)} queries mismatched:\n" + "\n".join(mismatches)
        )
Exemplo n.º 8
0
def simple_upload(request):
    """Render the upload page; on POST, answer a question against an uploaded CSV.

    Non-POST requests render the empty ``agent/simple_upload.html`` form. A POST
    must carry a ``myfile`` upload and a non-empty ``query`` field. The CSV is
    read with ``pd.read_csv`` and wrapped in ``Agent``. The template then
    receives the agent's ``query_db`` result as ``database_result`` and the
    generated SQL (``get_query``) as ``database_query``.

    Missing or unreadable input is reported to the user through
    ``database_result`` instead of raising.
    """
    template = 'agent/simple_upload.html'
    if request.method != 'POST':
        return render(request, template)

    # ``.get`` instead of ``[...]`` so that a missing file and a missing query
    # each get their own message, rather than both hitting one broad
    # MultiValueDictKeyError handler.
    myfile = request.FILES.get('myfile')
    if not myfile:
        return render(request, template, {'database_result': "Please Upload a CSV"})

    query = request.POST.get('query')
    if not query:
        return render(request, template, {'database_result': "Please enter a query"})

    # User uploads may be empty, malformed or not text at all; report that
    # instead of letting the exception become a server error.
    try:
        df = pd.read_csv(myfile)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError):
        return render(request, template, {
            'database_result': "Could not read the uploaded file as CSV"
        })

    agent = Agent(df)
    database_result = agent.query_db(query)
    database_query = agent.get_query(question=query)
    return render(request, template, {
        'database_result': database_result,
        'database_query': database_query,
    })
Exemplo n.º 9
0
def test_query():
    """Check SQL generation for several questions against the ``cleaned_data`` directory.

    ``currpath`` is a module-level name defined outside this function. Every
    question is translated even after an earlier one fails, so a single run
    reports all mismatches instead of only the first. The comparison ignores
    leading/trailing whitespace only.

    Raises:
        AssertionError: if any generated SQL differs from its expected SQL; the
            message lists every mismatching question with its predicted and
            expected SQL.
    """
    agent = Agent(os.path.join(currpath, "cleaned_data"))
    qmaps = {
        'how many nuclear medicine activities in 2012':
        "SELECT COUNT(activity_type) FROM activities_data WHERE activity_type_chapter = 'Nuclear Medicine' AND year = '2012' ",
        'find me the diseases having above 3000 cases':
        'SELECT disease FROM communicable_diseases_data WHERE cases  > 3000 ',
        'which are the activities in 2011':
        "SELECT activity_type,activity_type_chapter FROM activities_data WHERE year = '2011' ",
        'find the maximum number of cases':
        "SELECT MAX(cases) FROM communicable_diseases_data",
        'Get me the average of stomach cancer deaths':
        "SELECT AVG(death_count) FROM cancer_death_data WHERE cancer_site = 'Stomach' ",
        'how many activities between 2011 and 2014':
        "SELECT COUNT(activity_type) FROM activities_data WHERE year  BETWEEN 2011 AND 2014"
    }
    mismatches = []
    for q, sql in qmaps.items():
        res = agent.get_query(q)
        if res.strip() != sql.strip():
            mismatches.append(
                f"question: {q!r}\n  predicted: {res.strip()!r}\n  actual: {sql.strip()!r}"
            )
    # Put the diagnostics in the exception itself so the test runner shows them
    # without relying on captured stdout.
    if mismatches:
        raise AssertionError(
            f"{len(mismatches)} of {len(qmaps)} queries mismatched:\n" + "\n".join(mismatches)
        )
Exemplo n.º 10
0
def test_query_schema():
    """Check SQL generation for several questions using ``cleaned_data`` plus a ``schema`` directory.

    ``currpath`` is a module-level name defined outside this function. Every
    question is translated even after an earlier one fails, so a single run
    reports all mismatches instead of only the first. The comparison ignores
    leading/trailing whitespace only.

    Raises:
        AssertionError: if any generated SQL differs from its expected SQL; the
            message lists every mismatching question with its predicted and
            expected SQL.
    """
    agent = Agent(os.path.join(currpath, "cleaned_data"),
                  os.path.join(currpath, "schema"))
    qmaps = {
        'how many people died of stomach cancer in 2011':
        "SELECT SUM(death_count) FROM cancer_death WHERE cancer_site = 'Stomach' AND year = '2011' ",
        'how many deaths of age below 40 had stomach cancer':
        "SELECT SUM(death_count) FROM cancer_death WHERE cancer_site = 'Stomach' AND age  < 40 ",
        'which are the activities in 2011':
        "SELECT activity_type_chapter,activity_type FROM activities WHERE year = '2011' ",
        'find the maximum number of cases':
        "SELECT MAX(cases) FROM communicable_diseases ",
        'Get me the average of stomach cancer deaths':
        "SELECT AVG(death_count) FROM cancer_death WHERE cancer_site = 'Stomach' ",
        'how many people between age 30 and 40 died of stomach cancer':
        "SELECT SUM(death_count) FROM cancer_death WHERE cancer_site = 'Stomach' AND age  BETWEEN 30 AND 40 "
    }
    mismatches = []
    for q, sql in qmaps.items():
        res = agent.get_query(q)
        if res.strip() != sql.strip():
            mismatches.append(
                f"question: {q!r}\n  predicted: {res.strip()!r}\n  actual: {sql.strip()!r}"
            )
    # Put the diagnostics in the exception itself so the test runner shows them
    # without relying on captured stdout.
    if mismatches:
        raise AssertionError(
            f"{len(mismatches)} of {len(qmaps)} queries mismatched:\n" + "\n".join(mismatches)
        )