def test_get_close_db(app):
    """get_db() is cached within one app context and unusable after it ends."""
    with app.app_context():
        connection = get_db()
        # a second call in the same context must return the cached object
        assert get_db() is connection

    # after the context is torn down, queries on the connection must fail
    with pytest.raises(sqlite3.ProgrammingError) as excinfo:
        connection.execute('SELECT 1')

    assert 'closed' in str(excinfo.value)
def login():
    """Log in a registered user by adding the user id to the session.

    On POST, the submitted username is looked up in the ``user`` table and
    the submitted password is checked against the stored hash. On success
    the session is reset to hold only the user's id and the client is
    redirected to ``index``; otherwise the error is flashed and the login
    form is rendered again. Any other request just renders the form.
    """
    if request.method == "POST":
        form = request.form
        username = form["username"]
        password = form["password"]
        row = get_db().execute(
            "SELECT * FROM user WHERE username = ?", (username,)
        ).fetchone()

        if row is None:
            message = "Incorrect username."
        elif not check_password_hash(row["password"], password):
            message = "Incorrect password."
        else:
            # start from an empty session so nothing from a previous login survives
            session.clear()
            session["user_id"] = row["id"]
            return redirect(url_for("index"))

        flash(message)

    return render_template("auth/login.html")
def app():
    """Yield an app configured for testing against a temporary SQLite file.

    The database is created with ``init_db()`` and seeded with ``_data_sql``.
    The temporary file is closed and deleted afterwards. This also happens
    when creating or seeding the app raises, so a failing setup no longer
    leaks the file.
    """
    db_fd, db_path = tempfile.mkstemp()
    try:
        app = create_app({
            'TESTING': True,
            'DATABASE': db_path,
        })

        with app.app_context():
            init_db()
            get_db().executescript(_data_sql)

        yield app
    finally:
        os.close(db_fd)
        os.unlink(db_path)
def load_logged_in_user():
    """Populate ``g.user`` from the session.

    ``g.user`` becomes the ``user`` row whose id is stored under
    ``session["user_id"]``, or ``None`` when no id is stored.
    """
    user_id = session.get("user_id")
    g.user = None if user_id is None else get_db().execute(
        "SELECT * FROM user WHERE id = ?", (user_id,)
    ).fetchone()
def test_register(client, app):
    """Registering a new user redirects to the login page and stores the user."""
    assert client.get('/auth/register').status_code == 200

    response = client.post(
        '/auth/register',
        data={'username': '******', 'password': '******'},
    )
    # NOTE(review): Werkzeug >= 2.1 no longer makes Location absolute; this may
    # need to become '/auth/login' depending on the pinned version — confirm.
    assert 'http://localhost/auth/login' == response.headers['Location']

    with app.app_context():
        assert get_db().execute(
            "select * from user where username = ?", ('******',),
        ).fetchone() is not None
def test_n_periods(app):
    """Check that every treatment level has the same number of periods.

    ``get_treatment_levels()`` returns the (min, max) bounds of the treatment
    levels, so every level in that inclusive range is checked, not only the
    two bounds. A level with no rows yields ``None`` and fails the check.

    Should move this test to tests package.
    """
    with app.app_context():
        db = get_db()
        min_level, max_level = get_treatment_levels()
        periods_by_level = {
            level: db.execute(
                "SELECT MAX(PERIOD) FROM experiment_data where ID = ?;",
                (level,)).fetchone()[0]
            for level in range(min_level, max_level + 1)
        }

    # all levels must agree on a single period count
    assert len(set(periods_by_level.values())) == 1, periods_by_level
def register():
    """Register a new user.

    Validates that the username is not already taken. Hashes the password
    for security. Each new user gets a treatment level drawn uniformly from
    the inclusive range given by ``get_treatment_levels()``. The user starts
    on slide 1 of the ``demographics`` stage, at simulation period 1, with a
    calculator count of 0.
    """
    db = get_db()

    if request.method == "POST":
        username = request.form["username"]
        password = request.form["password"]
        error = None

        if not username:
            error = "Username is required."
        elif not password:
            error = "Password is required."
        elif db.execute(
            "SELECT id FROM user WHERE username = ?", (username,)
        ).fetchone() is not None:
            error = "User {0} is already registered.".format(username)

        if error is None:
            treatment_levels = get_treatment_levels()
            min_treatment_level = treatment_levels[0]
            max_treatment_level = treatment_levels[1]
            # numpy's randint excludes ``high``; add 1 so the maximum
            # treatment level can be assigned too. int() keeps a plain
            # Python int for the sqlite3 binding.
            treatment_level = int(np.random.randint(
                min_treatment_level, max_treatment_level + 1))

            # the name is available, store it in the database and go to
            # the login page
            db.execute(
                "INSERT INTO user (username, password,"
                " slide_number, current_stage, treatment_level,"
                " simulation_period, calculator_count)"
                " VALUES (?, ?, ?, ?, ?, ?, ?)",
                (username, generate_password_hash(password), 1,
                 'demographics', treatment_level, 1, 0))
            db.commit()
            flash("Registration Successful!")
            return redirect(url_for("auth.login"))

        flash(error)

    return render_template("auth/register.html")
def index():
    """Home page.

    The user sees an overall welcome message and can flip between the
    instruction slides for the simulation. The current slide number is
    stored per user in ``user.slide_number``. POSTing ``action=Next Slide``
    or ``action=Previous Slide`` moves it by one, staying within
    ``1..N_SLIDES``, and redirects back here.
    """
    db = get_db()
    user_id = session["user_id"]
    row = db.execute(
        "SELECT * FROM user WHERE id = ?", (user_id,)
    ).fetchone()
    slide = row['slide_number']
    fig_url = blog_functions.get_slide_url(slide)
    slide_recording = blog_functions.get_rec_url(slide)

    if request.method == 'POST':
        action = request.form.get('action')
        if action in ('Next Slide', 'Previous Slide'):
            if action == 'Next Slide':
                may_move, target = slide < N_SLIDES, slide + 1
            else:
                may_move, target = slide > 1, slide - 1
            if may_move:
                db.execute("UPDATE user SET slide_number = ? WHERE id = ?;",
                           (target, user_id))
                db.commit()
            return redirect(url_for('blog.index'))

    return render_template("blog/index.html",
                           fig_url=fig_url,
                           slide_caption=SLIDE_CAPTION_DICT[slide],
                           slide_recording=slide_recording,
                           user_slide_number=slide)
def survey():
    """Survey home page.

    Walks the logged-in user through the survey stages stored in
    ``user.current_stage``. The page for a stage is ``blog/<stage>.html``,
    and a POST stores the stage's answers and moves the user on to
    ``shuttle_dict[<stage>]``.

    GET:
        * ``simulation``: shows the period's figure, recommendation
          parameters and history table. ``?action=Calculate`` also computes
          the expected profit for ``calc_Q`` and increments
          ``user.calculator_count``. If the stored period is past the last
          one, the user is moved to the next stage and redirected.
        * ``risk``: renders with ``QUESTION_DICT`` / ``RISK_PREFERENCE_DICT``.
        * ``risk_answer``: renders the outcome text for question ``RP9``.
        * any other stage: renders the stage template with no extra context.

    POST:
        Saves the form for the current stage, advances the stage (or the
        simulation period) and redirects back here. Simulation decisions
        that fail validation re-render the page with the errors. After the
        ``thankyou`` stage the session is cleared.
    """
    # Local import: only needed to timestamp the cognitive->simulation
    # transition and each simulation decision.
    from datetime import datetime

    def _now():
        # same text sqlite3's default datetime adapter would store
        return datetime.now().isoformat(" ")

    n_simulation_periods = get_n_periods()
    db = get_db()
    user_id = session["user_id"]
    user_data = db.execute(
        "SELECT * FROM user WHERE id = ?", (user_id,)
    ).fetchone()
    user_stage = user_data['current_stage']
    simulation_period = user_data['simulation_period']
    user_treatment_level = user_data['treatment_level']
    template = "blog/" + user_stage + ".html"

    display_dict = {}
    if user_stage == 'simulation':
        display_dict, fig_url, show_recs, history_df = _build_simulation_view(
            db, user_id, simulation_period, user_treatment_level,
            n_simulation_periods)

    if request.method == 'GET':
        if user_stage == 'simulation':
            if request.args.get('action') == 'Calculate':
                validate = blog_functions.validate_input()
                error_list = blog_functions.do_validate_instructions(
                    validate, display_dict, request, 'calc_Q', 'calc')
                if not error_list:
                    expected_profit = blog_functions.get_expected_profit(
                        int(display_dict['v']), int(display_dict['p']),
                        int(request.args.get('calc_Q')))
                    display_dict['expected_profit'] = np_round(
                        expected_profit, 2)
                    db.execute("UPDATE user"
                               " SET calculator_count = ?"
                               " WHERE id = ?;",
                               (user_data['calculator_count'] + 1, user_id))
                    db.commit()
            if simulation_period <= n_simulation_periods:
                return _render_simulation(template, display_dict,
                                          simulation_period, history_df,
                                          fig_url, show_recs)
            # Past the last period: advance the stage and let the next
            # request render it, instead of rendering the simulation
            # template without its context.
            db.execute("UPDATE user"
                       " SET current_stage = ?"
                       " WHERE id = ?;",
                       (shuttle_dict[user_stage], user_id))
            db.commit()
            return redirect(url_for("blog.survey"))

        if user_stage == 'risk':
            return render_template(template,
                                   question_dict=QUESTION_DICT,
                                   risk_preference_dict=RISK_PREFERENCE_DICT)

        if user_stage == 'risk_answer':
            choice = user_data['RP9']
            given_answer = RISK_PREFERENCE_DICT['RP9'][choice]
            answer_list = [
                'You chose ' + given_answer + '.',
                'The computer chose ' + UNFORTUNATE_RP9[choice][0]
                + ' points.',
                'If you would have chosen "'
                + RISK_PREFERENCE_DICT['RP9'][1 - choice]
                + '", you would have won '
                + UNFORTUNATE_RP9[choice][1] + ' points!',
            ]
            return render_template(template, answer_list=answer_list)

        return render_template(template)

    if request.method == 'POST':
        if user_stage == 'demographics':
            db.execute("UPDATE user"
                       " SET gender = ?, age = ?, sc_exp = ?,"
                       " procurement_exp = ?, current_stage = ?"
                       " WHERE id = ?;",
                       (request.form.get('gender'), request.form.get('age'),
                        request.form.get('sc'),
                        request.form.get('procurement'),
                        shuttle_dict[user_stage], user_id))
            db.commit()

        elif user_stage == 'cognitive':
            crt_answers = tuple(request.form.get("CRT" + str(i))
                                for i in range(1, 8))
            # enter_simulation records when the user reached the simulation
            db.execute("UPDATE user"
                       " SET CRT1 = ?, CRT2 = ?, CRT3 = ?,"
                       " CRT4 = ?, CRT5 = ?, CRT6 = ?, CRT7 = ?,"
                       " current_stage = ?, enter_simulation = ?"
                       " WHERE id = ?;",
                       crt_answers + (shuttle_dict[user_stage], _now(),
                                      user_id))
            db.commit()

        elif user_stage == 'simulation':
            if simulation_period <= n_simulation_periods:
                validate = blog_functions.validate_input()
                error_list = blog_functions.do_validate_instructions(
                    validate, display_dict, request, 'decision_Q', 'decision')
                if error_list:
                    # show the same period again with the validation errors
                    return _render_simulation(template, display_dict,
                                              simulation_period, history_df,
                                              fig_url, show_recs)
                db.execute("INSERT INTO contracts"
                           " (user_id, simulation_period, q, time_stamp,"
                           " calculator_count)"
                           " VALUES (?, ?, ?, ?, ?);",
                           (user_id, simulation_period,
                            int(request.form.get('decision_Q')), _now(),
                            user_data['calculator_count']))
                db.commit()
                if simulation_period < n_simulation_periods:
                    db.execute("UPDATE user"
                               " SET simulation_period = ?"
                               " WHERE id = ?",
                               (simulation_period + 1, user_id))
                else:
                    # last period done: go to the risk survey
                    db.execute("UPDATE user"
                               " SET current_stage = ?"
                               " WHERE id = ?;",
                               (shuttle_dict[user_stage], user_id))
                db.commit()

        elif user_stage == 'risk':
            fin_answers = [request.form.get(x) for x in QUESTION_DICT]
            risk_answers = [request.form.get(x) for x in RISK_PREFERENCE_DICT]
            # NOTE(review): assumes the dicts iterate as Fin1..Fin6 and
            # RP1..RP9, matching the column order below — confirm.
            all_updates = ([shuttle_dict[user_stage]]
                           + [int(x) for x in fin_answers]
                           + [int(x) for x in risk_answers]
                           + [user_id])
            db.execute("UPDATE user"
                       " SET current_stage = ?,"
                       " Fin1 = ?, Fin2 = ?, Fin3 = ?, Fin4 = ?, Fin5 = ?,"
                       " Fin6 = ?,"
                       " RP1 = ?, RP2 = ?, RP3 = ?, RP4 = ?, RP5 = ?, RP6 = ?,"
                       " RP7 = ?, RP8 = ?, RP9 = ?"
                       " WHERE id = ?;",
                       tuple(all_updates))
            db.commit()

        elif user_stage == 'risk_answer':
            db.execute("UPDATE user"
                       " SET current_stage = ?,"
                       " RP10 = ?"
                       " WHERE id = ?;",
                       (shuttle_dict[user_stage], request.form.get('RP10'),
                        user_id))
            db.commit()

        elif user_stage == 'thankyou':
            db.execute("UPDATE user"
                       " SET feedback = ?, current_stage = ?"
                       " WHERE id = ?;",
                       (request.form.get('feedback_input'),
                        shuttle_dict[user_stage], user_id))
            db.commit()
            session.clear()

    return redirect(url_for("blog.survey"))


def _build_simulation_view(db, user_id, simulation_period,
                           user_treatment_level, n_simulation_periods):
    """Collect what the simulation page needs for the user's current period.

    Returns ``(display_dict, fig_url, show_recs, history_df)``:
    ``display_dict`` holds the integer ``Q_rec``/``v``/``p`` values plus
    zeroed calculator/decision inputs and empty error lists, and
    ``history_df`` is the table of earlier periods. ``history_df`` is empty
    unless ``2 <= simulation_period <= n_simulation_periods``.
    """
    fig_url = blog_functions.get_fig_url(user_treatment_level,
                                         simulation_period)
    experiment_data, rec_param_demand_data = \
        blog_functions.get_experiment_data(db, simulation_period,
                                           user_treatment_level)

    display_dict = {x: int(rec_param_demand_data[x].tolist()[0])
                    for x in ['Q_rec', 'v', 'p']}
    display_dict.update({
        'calc_Q': 0, 'decision_Q': 0,
        'calc_errors': [], 'calc_n_errors': 0,
        'decision_errors': [], 'decision_n_errors': 0,
        'expected_profit': 0,
    })
    show_recs = True

    history_cols = ['Period', 'Ordered From Supplier', 'Demand', 'Profit ($)']
    if not 2 <= simulation_period <= n_simulation_periods:
        # no history yet: an empty table that still has the right headers
        return display_dict, fig_url, show_recs, DataFrame(columns=history_cols)

    # demand observed in this treatment level's earlier periods
    past = experiment_data.loc[
        (experiment_data['ID'] == user_treatment_level)
        & (experiment_data['Period'] < simulation_period)][['Period', 'Demand']]
    # this user's contracts; user_id is bound, not concatenated into SQL
    contracts = read_sql_query(
        "SELECT * FROM contracts WHERE user_id = ?", con=db, params=(user_id,))
    history = past.merge(contracts, how='left',
                         left_on='Period', right_on='simulation_period')
    history = blog_functions.get_contract_metrics(
        history, display_dict['v'], display_dict['p'], 'Demand', 'q')
    history = history.rename({'q': 'Ordered From Supplier',
                              'sales': 'Sales (Units)',
                              'lost_sales': 'Lost Sales (Units)',
                              'profit': 'Profit ($)'}, axis=1)
    return display_dict, fig_url, show_recs, history[history_cols]


def _render_simulation(template, display_dict, simulation_period,
                       history_df, fig_url, show_recs):
    """Render the simulation page, converting the history table to HTML."""
    return render_template(template,
                           display_dict=display_dict,
                           simulation_period=simulation_period,
                           historical_table=history_df.to_html(
                               index=False, justify='left'),
                           fig_url=fig_url,
                           show_recs=show_recs)
def test_survey(client, app, auth):
    """Walk a user through every survey stage for each treatment level.

    For each treatment level in the inclusive range from
    ``get_treatment_levels()``, the test logs in and resets the user to the
    ``demographics`` stage at simulation period 1. It then POSTs to
    ``/survey`` once per stage and checks that the stage (and, during the
    simulation, the period) advances: demographics -> cognitive ->
    simulation (one POST per period) -> risk -> risk_answer -> thankyou.
    The final POST must clear the session.
    """

    def stage_and_period():
        # Call inside ``with client:`` so ``session`` is the preserved
        # session of the request just made.
        with app.app_context():
            return get_db().execute(
                "SELECT current_stage, simulation_period"
                " FROM user WHERE id = ?",
                (session["user_id"],)).fetchone()

    with app.app_context():
        min_treatment_level, max_treatment_level = get_treatment_levels()
        n_simulation_periods = get_n_periods()

    for treatment_level in range(min_treatment_level,
                                 max_treatment_level + 1):
        with client:
            auth.login()
            # restart the survey for this treatment level
            with app.app_context():
                db = get_db()
                db.execute(
                    "UPDATE user"
                    " SET current_stage = ?,"
                    " simulation_period = ?, treatment_level = ?"
                    " WHERE id = ?",
                    ('demographics', 1, treatment_level, session["user_id"]))
                db.commit()

        with client:
            client.post('/survey')
            assert stage_and_period()[0] == 'cognitive'

        with client:
            client.post('/survey')
            stage, current_period = stage_and_period()
            assert stage == 'simulation'
            assert current_period == 1

        for period in range(1, n_simulation_periods + 1):
            with client:
                client.post('/survey', data=dict(decision_Q=5))
                stage, current_period = stage_and_period()
                if period < n_simulation_periods:
                    assert stage == 'simulation'
                    assert current_period == period + 1
                else:
                    assert stage == 'risk'

        with client:
            answer_dict = {'Fin' + str(x): 5 for x in range(1, 7)}
            answer_dict.update({'RP' + str(x): 0 for x in range(1, 10)})
            client.post('/survey', data=answer_dict)
            assert stage_and_period()[0] == 'risk_answer'

        with client:
            client.post('/survey', data=dict(RP10=5))
            assert stage_and_period()[0] == 'thankyou'

        with client:
            client.post('/survey', data=dict(feedback_input='test_feedback'))
            with app.app_context():
                assert 'user_id' not in session