Пример #1
0
def test_get_close_db(app):
    """The database connection is reused within one app context and is
    unusable ("closed") once that context has ended."""
    with app.app_context():
        connection = get_db()
        # a second lookup in the same context must return the same object
        assert connection is get_db()

    # outside the context the connection must refuse to run queries
    with pytest.raises(sqlite3.ProgrammingError, match='closed'):
        connection.execute('SELECT 1')
Пример #2
0
def login():
    """Log in a registered user by adding the user id to the session.

    On POST the submitted username is looked up and the submitted password
    is checked against the stored hash. On success the session is reset to
    hold only ``user_id`` and the client is redirected to ``index``;
    otherwise the error message is flashed and the login form is rendered
    again. Any other method just renders the login form.
    """
    if request.method != "POST":
        return render_template("auth/login.html")

    form = request.form
    username = form["username"]
    password = form["password"]

    user = get_db().execute(
        "SELECT * FROM user WHERE username = ?", (username,)
    ).fetchone()

    if user is None:
        error = "Incorrect username."
    elif not check_password_hash(user["password"], password):
        error = "Incorrect password."
    else:
        # start from an empty session so no earlier keys survive the login
        session.clear()
        session["user_id"] = user["id"]
        return redirect(url_for("index"))

    flash(error)
    return render_template("auth/login.html")
Пример #3
0
def app():
    """Yield an app configured for testing on a throwaway SQLite file.

    The schema is initialised with ``init_db()`` and the SQL script in
    ``_data_sql`` is executed before the app is handed to the test. The
    temporary file descriptor is closed and the file removed afterwards,
    including when app creation or database setup raises.
    """
    db_fd, db_path = tempfile.mkstemp()
    try:
        app = create_app({
            'TESTING': True,
            'DATABASE': db_path,
        })

        with app.app_context():
            init_db()
            get_db().executescript(_data_sql)

        yield app
    finally:
        # runs on normal teardown and also if setup above failed, so the
        # temp file is never leaked
        os.close(db_fd)
        os.unlink(db_path)
Пример #4
0
def load_logged_in_user():
    """Set ``g.user`` to the ``user`` row whose id is stored in the session.

    ``g.user`` is ``None`` when the session holds no ``user_id`` (and also
    when no row matches, since ``fetchone()`` then returns ``None``).
    """
    user_id = session.get("user_id")
    user = None
    if user_id is not None:
        user = get_db().execute(
            "SELECT * FROM user WHERE id = ?", (user_id,)
        ).fetchone()
    g.user = user
Пример #5
0
def test_register(client, app):
    """Registering redirects to the login page and stores the new user."""
    # the registration form itself renders
    register_page = client.get('/auth/register')
    assert register_page.status_code == 200

    form = {'username': '******', 'password': '******'}
    response = client.post('/auth/register', data=form)
    assert response.headers['Location'] == 'http://localhost/auth/login'

    with app.app_context():
        stored = get_db().execute(
            "select * from user where username = '******'",
        ).fetchone()
        assert stored is not None
Пример #6
0
def test_n_periods(app):
    """Check that all treatment levels have the same number of periods.

    ``get_treatment_levels()`` is treated as the (lowest, highest) pair of
    treatment levels, inclusive. For every level in that range, the highest
    ``PERIOD`` in ``experiment_data`` for that ``ID`` is looked up, and all
    of them must be equal. Should move this test to the tests package.
    """
    with app.app_context():
        db = get_db()
        min_level, max_level = get_treatment_levels()
        # check every level in the inclusive range, not only the two bounds
        periods_by_level = {
            level: db.execute(
                "SELECT MAX(PERIOD) FROM experiment_data where ID = ?;",
                (level, )).fetchone()[0]
            for level in range(min_level, max_level + 1)
        }
        # exactly one distinct period count across all levels
        assert len(set(periods_by_level.values())) == 1, periods_by_level
Пример #7
0
def register():
    """Register a new user.

    Validates that a username and a password were supplied and that the
    username is not already taken. The password is stored hashed. The new
    user gets a treatment level drawn at random from the inclusive range
    ``get_treatment_levels()[0]`` .. ``get_treatment_levels()[1]``. The user
    starts on slide 1, stage ``'demographics'``, simulation period 1, with a
    calculator count of 0. On success the client is redirected to the login
    page; otherwise the error is flashed and the form is rendered again.
    """
    db = get_db()
    treatment_levels = get_treatment_levels()
    min_treatment_level = treatment_levels[0]
    max_treatment_level = treatment_levels[1]

    if request.method == "POST":
        username = request.form["username"]
        password = request.form["password"]

        error = None
        if not username:
            error = "Username is required."
        elif not password:
            error = "Password is required."
        elif (db.execute("SELECT id FROM user WHERE username = ?",
                         (username, )).fetchone() is not None):
            error = "User {0} is already registered.".format(username)

        if error is None:
            # numpy's randint excludes ``high``; add 1 so the highest
            # treatment level can actually be assigned. int() keeps the
            # bound value a plain Python int for sqlite3.
            treatment_level = int(np.random.randint(min_treatment_level,
                                                    max_treatment_level + 1))

            # the name is available, store it in the database and go to
            # the login page
            db.execute(
                "INSERT INTO user (username, password,"
                " slide_number, current_stage, treatment_level,"
                " simulation_period, calculator_count)"
                " VALUES (?, ?, ?, ?, ?, ?, ?)",
                (username, generate_password_hash(password), 1, 'demographics',
                 treatment_level, 1, 0))
            db.commit()
            flash("Registration Successful!")
            return redirect(url_for("auth.login"))

        flash(error)

    return render_template("auth/register.html")
Пример #8
0
def index():
    """Home page: a welcome message plus instruction slides the user can
    step through.

    The user's current ``slide_number`` selects the slide image, recording
    and caption. A POST with ``action`` set to ``'Next Slide'`` (while below
    ``N_SLIDES``) or ``'Previous Slide'`` (while above 1) stores the new
    slide number and redirects back here. Otherwise the page is rendered.
    """
    db = get_db()
    user_id = session["user_id"]
    user_row = db.execute(
        "SELECT * FROM user WHERE id = ?", (user_id,)
    ).fetchone()
    slide = user_row['slide_number']
    fig_url = blog_functions.get_slide_url(slide)
    slide_recording = blog_functions.get_rec_url(slide)

    if request.method == 'POST':
        action = request.form.get('action')
        target_slide = None
        if action == 'Next Slide' and slide < N_SLIDES:
            target_slide = slide + 1
        elif action == 'Previous Slide' and slide > 1:
            target_slide = slide - 1

        if target_slide is not None:
            db.execute("UPDATE user SET slide_number = ? WHERE id = ?;",
                       (target_slide, user_id))
            db.commit()
            return redirect(url_for('blog.index'))

    return render_template(
        "blog/index.html",
        fig_url=fig_url,
        slide_caption=SLIDE_CAPTION_DICT[slide],
        slide_recording=slide_recording,
        user_slide_number=slide,
    )
Пример #9
0
def survey():
    """Survey home page: render or advance the logged-in user's survey stage.

    The user's ``current_stage`` selects the template (``blog/<stage>.html``)
    and, on POST, which form fields are stored. The stages handled here are
    ``demographics``, ``cognitive``, ``simulation``, ``risk``,
    ``risk_answer`` and ``thankyou``. When a stage is completed the user's
    stage is set to ``shuttle_dict[<stage>]`` (presumably the next stage of
    the flow -- confirm against ``shuttle_dict``).

    GET  -- render the current stage. In ``simulation``, ``?action=Calculate``
            computes the expected profit for ``calc_Q`` and increments the
            user's ``calculator_count``.
    POST -- store the answers for the current stage, advance the stage (or
            the simulation period), then redirect back here. In ``thankyou``
            the session is cleared as well.
    """
    N_SIMULATION_PERIODS = get_n_periods()
    db = get_db()
    user_data = db.execute(
        "SELECT * FROM user WHERE id = ?",
        (session["user_id"],)
    ).fetchone()
    user_stage = user_data['current_stage']
    simulation_period = user_data['simulation_period']
    user_treatment_level = user_data['treatment_level']
    display_dict = {}

    if user_stage == 'simulation':
        fig_url = \
            blog_functions.get_fig_url(user_treatment_level, simulation_period)

        experiment_data, rec_param_demand_data = \
            blog_functions.get_experiment_data(
                db, simulation_period, user_treatment_level)

        # values shown on the page; 'v' and 'p' also feed the profit
        # calculations below
        rec_param_demand_data_cols = ['Q_rec', 'v', 'p']
        display_dict.update({x: int(rec_param_demand_data[x].tolist()[0])
                             for x in rec_param_demand_data_cols})
        show_recs = True

        # calculator / decision inputs start at 0 with no errors recorded
        calc_decision_suffixes = ['_Q']
        calc_decision_list = [
            x + y for x in ['calc', 'decision'] for y in calc_decision_suffixes
        ]
        display_dict.update({x: 0 for x in calc_decision_list})
        display_dict.update({
            'calc_errors': [],
            'calc_n_errors': 0,
            'decision_errors': [],
            'decision_n_errors': 0,
            'expected_profit': 0,
        })

        # history-table columns; also used for the empty table shown before
        # any history exists
        temp_display_df_cols = ['Period',
                                'Ordered From Supplier',
                                'Demand',
                                'Profit ($)']

        if 2 <= simulation_period <= N_SIMULATION_PERIODS:
            # get the relevant historical data and display it as a table
            temp_exp_df = experiment_data.loc[
                (experiment_data['ID'] == user_treatment_level)
                & (experiment_data['Period'] < simulation_period)][
                ['Period', 'Demand']]

            # now get the full contracts table for this user (bound
            # parameter instead of string-concatenated SQL)
            temp_user_contracts_df = read_sql_query(
                "SELECT * FROM contracts WHERE user_id = ?",
                con=db, params=(session["user_id"],)
            )

            temp_display_df = temp_exp_df.merge(
                temp_user_contracts_df,
                how='left',
                left_on='Period', right_on='simulation_period')

            temp_display_df = blog_functions.get_contract_metrics(
                temp_display_df,
                display_dict['v'],
                display_dict['p'],
                'Demand',
                'q'
            )

            temp_display_df.rename(
                {'q': 'Ordered From Supplier',
                 'sales': 'Sales (Units)',
                 'lost_sales': 'Lost Sales (Units)',
                 'profit': 'Profit ($)'}, axis=1, inplace=True)

            temp_display_df = temp_display_df[temp_display_df_cols]
        else:
            temp_display_df = DataFrame(columns=temp_display_df_cols)

    def render_simulation():
        # render the simulation page; only called when user_stage is
        # 'simulation', so the names it reads have all been bound above
        return render_template("blog/" + user_stage + ".html",
                               display_dict=display_dict,
                               simulation_period=simulation_period,
                               historical_table=temp_display_df.to_html(
                                   index=False,
                                   justify='left'),
                               fig_url=fig_url,
                               show_recs=show_recs)

    if request.method == 'GET':
        if user_stage == 'simulation':
            if request.args.get('action') == 'Calculate':
                validate = blog_functions.validate_input()
                error_list = blog_functions.do_validate_instructions(
                    validate, display_dict, request, 'calc_Q', 'calc'
                )

                if len(error_list) == 0:
                    expected_profit = blog_functions.get_expected_profit(
                        int(display_dict['v']),
                        int(display_dict['p']),
                        int(request.args.get('calc_Q'))
                    )
                    display_dict.update({
                        'expected_profit': np_round(expected_profit, 2)
                    })

                    db.execute("UPDATE user"
                               " SET calculator_count = ?"
                               " WHERE id = ?;",
                               (user_data['calculator_count'] + 1,
                                session["user_id"]))
                    db.commit()

                return render_simulation()

            if simulation_period <= N_SIMULATION_PERIODS:
                return render_simulation()

            # all periods are done: move the user to the next stage and
            # reload, so that stage renders with its own context (falling
            # through would render the simulation template without its data)
            db.execute("UPDATE user"
                       " SET current_stage = ?"
                       " WHERE id = ?;",
                       (shuttle_dict[user_stage], session["user_id"]))
            db.commit()
            return redirect(url_for("blog.survey"))

        if user_stage == 'risk':
            return render_template("blog/" + user_stage + ".html",
                                   question_dict=QUESTION_DICT,
                                   risk_preference_dict=RISK_PREFERENCE_DICT)

        if user_stage == 'risk_answer':
            rp9 = user_data['RP9']
            # ``1 - rp9`` is the option the user did not pick (assumes RP9
            # is stored as 0 or 1 -- confirm)
            answer_list = [
                'You chose ' + RISK_PREFERENCE_DICT['RP9'][rp9] + '.',
                'The computer chose ' + UNFORTUNATE_RP9[rp9][0] + ' points.',
                'If you would have chosen "'
                + RISK_PREFERENCE_DICT['RP9'][1 - rp9]
                + '", you would have won '
                + UNFORTUNATE_RP9[rp9][1] + ' points!',
            ]
            return render_template("blog/" + user_stage + ".html",
                                   answer_list=answer_list)

        return render_template("blog/" + user_stage + ".html")

    if request.method == 'POST':
        if user_stage == 'demographics':
            db.execute("UPDATE user"
                       " SET gender = ?, age = ?, sc_exp = ?,"
                       " procurement_exp = ?, current_stage = ?"
                       " WHERE id = ?;",
                       (request.form.get('gender'),
                        request.form.get('age'),
                        request.form.get('sc'),
                        request.form.get('procurement'),
                        shuttle_dict[user_stage], session["user_id"]))
            db.commit()

        if user_stage == 'cognitive':
            crt_answers = [request.form.get("CRT" + str(i))
                           for i in range(1, 8)]
            # enter_simulation records when the user leaves this stage
            db.execute("UPDATE user"
                       " SET CRT1 = ?, CRT2 = ?, CRT3 = ?,"
                       " CRT4 = ?, CRT5 = ?, CRT6 = ?, CRT7 = ?,"
                       " current_stage = ?, enter_simulation = ?"
                       " WHERE id = ?;",
                       (*crt_answers,
                        shuttle_dict[user_stage],
                        datetime.now(),
                        session["user_id"]))
            db.commit()

        if user_stage == 'simulation':
            if simulation_period <= N_SIMULATION_PERIODS:
                validate = blog_functions.validate_input()
                error_list = blog_functions.do_validate_instructions(
                    validate, display_dict, request, 'decision_Q', 'decision'
                )

                if len(error_list) != 0:
                    # invalid decision: show the page again with the errors
                    return render_simulation()

                db.execute("INSERT INTO contracts"
                           " (user_id, simulation_period, q, time_stamp,"
                           " calculator_count)"
                           " VALUES (?, ?, ?, ?, ?);",
                           (session["user_id"],
                            simulation_period,
                            int(request.form.get('decision_Q')),
                            datetime.now(),
                            user_data['calculator_count']))
                db.commit()

                if simulation_period < N_SIMULATION_PERIODS:
                    db.execute("UPDATE user"
                               " SET simulation_period = ?"
                               " WHERE id = ?",
                               (simulation_period + 1, session["user_id"]))
                    db.commit()
                else:
                    # last period: leave the simulation (period not advanced)
                    db.execute("UPDATE user"
                               " SET current_stage = ?"
                               " WHERE id = ?;",
                               (shuttle_dict[user_stage], session["user_id"]))
                    db.commit()

        if user_stage == 'risk':
            # NOTE(review): the placeholders below are bound positionally, so
            # this relies on QUESTION_DICT iterating as Fin1..Fin6 and
            # RISK_PREFERENCE_DICT as RP1..RP9 -- confirm their key order.
            all_updates = (
                [shuttle_dict[user_stage]]
                + [int(request.form.get(x)) for x in QUESTION_DICT]
                + [int(request.form.get(x)) for x in RISK_PREFERENCE_DICT]
                + [session["user_id"]]
            )
            db.execute("UPDATE user"
                       " SET current_stage = ?,"
                       " Fin1 = ?, Fin2 = ?, Fin3 = ?, Fin4 = ?, Fin5 = ?,"
                       " Fin6 = ?,"
                       " RP1 = ?, RP2 = ?, RP3 = ?, RP4 = ?, RP5 = ?, RP6 = ?,"
                       " RP7 = ?, RP8 = ?, RP9 = ?"
                       " WHERE id = ?;",
                       tuple(all_updates))
            db.commit()

        if user_stage == 'risk_answer':
            db.execute("UPDATE user"
                       " SET current_stage = ?,"
                       " RP10 = ?"
                       " WHERE id = ?;",
                       (shuttle_dict[user_stage], request.form.get('RP10'),
                        session["user_id"]))
            db.commit()

        if user_stage == 'thankyou':
            db.execute("UPDATE user"
                       " SET feedback = ?, current_stage = ?"
                       " WHERE id = ?;",
                       (request.form.get('feedback_input'),
                        shuttle_dict[user_stage],
                        session["user_id"]))
            db.commit()
            # the survey is finished: log the user out
            session.clear()
            return redirect(url_for("blog.survey"))

        return redirect(url_for("blog.survey"))
Пример #10
0
def test_survey(client, app, auth):
    """Walk the logged-in user through every survey stage, once per
    treatment level.

    For each level in ``get_treatment_levels()`` (inclusive), the user is
    reset to stage ``demographics`` at simulation period 1. The test then
    checks that each POST to ``/survey`` moves the user through
    demographics -> cognitive -> simulation (one POST per period) -> risk
    -> risk_answer -> thankyou. After the last POST the session must no
    longer hold ``user_id``.
    """
    def fetch_user(columns):
        # read ``columns`` from the logged-in user's row; must be called
        # inside ``with client:`` so ``session`` is available
        with app.app_context():
            return get_db().execute(
                "SELECT " + columns + " FROM user WHERE id = ?",
                (session["user_id"], )).fetchone()

    with app.app_context():
        min_treatment_level, max_treatment_level = get_treatment_levels()
        n_simulation_periods = get_n_periods()

    for treatment_level in range(min_treatment_level, max_treatment_level + 1):
        with client:
            auth.login()
            # put the user back at the start under this treatment level
            with app.app_context():
                db = get_db()
                db.execute(
                    "UPDATE user"
                    " SET current_stage = ?,"
                    " simulation_period = ?, treatment_level = ?"
                    " WHERE id = ?",
                    ('demographics', 1, treatment_level, session["user_id"]))
                db.commit()

        with client:
            client.post('/survey')
            assert fetch_user("current_stage")[0] == 'cognitive'

        with client:
            client.post('/survey')
            row = fetch_user("current_stage, simulation_period")
            assert row[0] == 'simulation'
            assert row[1] == 1

        for period in range(1, n_simulation_periods + 1):
            with client:
                client.post('/survey', data={'decision_Q': 5})
                row = fetch_user("current_stage, simulation_period")
                if period < n_simulation_periods:
                    assert row[0] == 'simulation'
                    assert row[1] == period + 1
                else:
                    assert row[0] == 'risk'

        with client:
            risk_answers = {'Fin' + str(i): 5 for i in range(1, 7)}
            risk_answers.update({'RP' + str(i): 0 for i in range(1, 10)})
            client.post('/survey', data=risk_answers)
            assert fetch_user("current_stage")[0] == 'risk_answer'

        with client:
            client.post('/survey', data={'RP10': 5})
            assert fetch_user("current_stage")[0] == 'thankyou'

        with client:
            client.post('/survey', data={'feedback_input': 'test_feedback'})
            with app.app_context():
                assert 'user_id' not in session