示例#1
0
def add_stock():
    """Handle the "add stock" form submission for the logged-in user.

    Validates ``NewStockForm``; when valid, asks ``User_Stock.add_stock`` to
    build a holding for ``current_user`` and commits it. Always redirects to
    the portfolio page, flashing one of:

    * ``'Stock added'`` when the commit succeeds;
    * ``'Stock already in portfolio'`` when the commit raises
      ``IntegrityError`` (the primary key is user_id + stock_symbol);
    * ``'Stock Symbol Not Recognized'`` when ``add_stock`` returns a falsy
      value.

    An invalid form redirects without flashing anything.
    """
    form = NewStockForm()

    if form.validate_on_submit():
        # adds new user_stock
        new_stock = User_Stock.add_stock(
            current_user.id,
            form.stock_symbol.data,
            form.stock_num.data)

        if new_stock:
            try:
                db.session.commit()
            # IntegrityError occurs on a primary-key clash (PK here is the
            # combination of user_id & stock_symbol)
            except IntegrityError:
                # A failed commit leaves the session in a failed transaction;
                # it must be rolled back before the session can be used again.
                db.session.rollback()
                flash('Stock already in portfolio', 'warning')
            else:
                flash('Stock added', 'success')
            return redirect(url_for('portfolio'))

        flash('Stock Symbol Not Recognized', 'warning')

    return redirect(url_for('portfolio'))
示例#2
0
    def test_add_stock_symbol_not_valid(self):
        """An unrecognised stock symbol makes add_stock return None."""
        bogus_symbol = "MSJDSJAKLFLDSJAKLFJLSDKALJFS"

        result = User_Stock.add_stock(9876, bogus_symbol, 1)

        # nothing should be built for a symbol that cannot be found
        self.assertIsNone(result)
示例#3
0
    def test_get_users_stocks(self):
        """get_users_stocks returns (holdings, total initial value, total current value).

        The fixture user owns at least one stock, so both totals must exceed
        their default of 0 and the fixture holding must be the first row.
        """
        returned = User_Stock.get_users_stocks(user_id=self.u.id)
        # total initial value (index 1) is greater than the default 0
        self.assertGreater(returned[1], 0)
        # total current value (index 2) is greater than the default 0
        self.assertGreater(returned[2], 0)
        # the original fixture stock is returned (exact match, not substring)
        self.assertEqual(returned[0][0].stock_symbol, self.u_stock.stock_symbol)
示例#4
0
    def test_add_stock_symbol(self):
        """Adding a valid symbol persists the holding and registers it in Stock."""
        holding = User_Stock.add_stock(9876, "MSFT", 1)
        db.session.add(holding)
        db.session.commit()

        # add_stock should also have created the matching Stock row
        added_symbol = holding.stock_symbol
        registered = Stock.query.get("MSFT")
        self.assertEqual(added_symbol, registered.stock_symbol)
示例#5
0
    def test_get_users_stocks_invalidID(self):
        """get_users_stocks for an unknown user id returns no holdings and zero totals."""
        returned = User_Stock.get_users_stocks(123555)
        # total initial value (index 1) stays at its default of 0
        self.assertEqual(returned[1], 0)
        # total current value (index 2) stays at its default of 0
        self.assertEqual(returned[2], 0)
        # no holdings for an unknown user: indexing the first row raises
        with self.assertRaises(IndexError):
            returned[0][0]
示例#6
0
def portfolio():
    """Render the logged-in user's portfolio page.

    Both forms are shown as modals in the template; the result of
    ``User_Stock.get_users_stocks`` fills the portfolio table.
    """
    new_stock_form = NewStockForm()
    edit_form = EditStock()
    holdings = User_Stock.get_users_stocks(current_user.id)

    return render_template(
        'user/portfolio.html',
        form=new_stock_form,
        stock_details=holdings,
        edit_stock_form=edit_form,
    )
示例#7
0
def company_details(stock_symbol):
    """Render the detailed view for ``stock_symbol``.

    Looks the symbol up in the local ``Stock`` table first. If it is missing,
    ``User_Stock.add_stock_symbol`` is asked to find it (per the original
    comments, by searching the API) and, on success,
    ``Stock.add_stock_details`` stores its basic details.

    If no stock is obtained, the view:

    * flashes 'Stock was not found';
    * rolls back the session;
    * redirects to the portfolio for logged-in users, or to the homepage
      otherwise.

    For logged-in users, the symbols already in their portfolio are passed to
    the template as ``stock_arr``.
    """
    # symbols the current user already holds (empty for anonymous visitors)
    stock_arr = []
    if current_user.is_authenticated:
        holdings = User_Stock.get_users_stocks(current_user.id)[0]
        stock_arr = [holding.stock_symbol for holding in holdings]

    # newStockForm is displayed as a Modal in the html
    form = NewStockForm()

    # check DB for stock first
    stock = Stock.query.get(stock_symbol)
    if not stock and User_Stock.add_stock_symbol(stock_symbol):
        # symbol was unknown locally but found via the lookup: store details
        stock = Stock.add_stock_details(stock_symbol)

    if stock:
        return render_template('/stock/detailed_stock_view.html',
                               stock_symbol=stock_symbol,
                               company_name=stock.stock_name,
                               stock_arr=stock_arr,
                               form=form)

    # stock not found locally nor via the lookup
    flash('Stock was not found', 'warning')
    # discard any partial changes made while looking the symbol up
    db.session.rollback()
    # same authentication check used above for stock_arr
    if not current_user.is_authenticated:
        return redirect(url_for('homepage'))

    return redirect(url_for('portfolio'))
示例#8
0
def send_portfolio():
    """Email the logged-in user a snapshot of their portfolio.

    Renders ``user/_portfolio_summary.html`` with the user's holdings as the
    HTML body and sends it to ``current_user.email`` with Flask-Mail. Then it
    flashes a success message and redirects to the portfolio page.
    """
    # get details to send
    stock_details = User_Stock.get_users_stocks(current_user.id)

    # craft message
    msg = Message('Portfolio SnapShot', sender=MAIL_USER,
                  recipients=[current_user.email])
    msg.html = render_template('user/_portfolio_summary.html',
                               stock_details=stock_details)
    # send message with flask-mail
    mail.send(msg)
    # plain literal: the original f-string had no placeholders
    flash("Portfolio Snap Shot Sent", "success")

    return redirect(url_for('portfolio'))
示例#9
0
    def setUp(self):
        """Reset the tables and seed one user (id 9876) holding 5 AAPL shares."""
        self.client = app.test_client()

        # delete order kept as User_Stock, Stock, User (User_Stock rows
        # presumably reference the other two)
        for model in (User_Stock, Stock, User):
            model.query.delete()

        user = User.signup("testUser", "*****@*****.**", "password", "USA",
                           "CA")
        user.id = 9876

        holding = User_Stock.add_stock(9876, "AAPL", "5")

        db.session.add_all([user, holding])
        db.session.commit()
        self.u, self.u_stock = user, holding
示例#10
0
    def test_adding_stock(self):
        """add_stock for a valid symbol yields a fully populated, persistable holding."""
        new_user_stock = User_Stock.add_stock(9876, "MSFT", 1)
        db.session.add(new_user_stock)
        db.session.commit()

        # checks stock was added to Stock model
        self.assertEqual(new_user_stock.stock_symbol,
                         Stock.query.get("MSFT").stock_symbol)

        # Checks that start/current dates, start/current prices and the share
        # count were all filled in. The original checked start_stock_price
        # twice and never checked the current price.
        # NOTE(review): 'current_stock_price' is assumed from the
        # start_date/current_date naming pattern; confirm against the model.
        for field in ('start_date', 'current_date', 'start_stock_price',
                      'current_stock_price', 'stock_num'):
            with self.subTest(field=field):
                self.assertIsNotNone(getattr(new_user_stock, field))
    def test_delete_stock_route(self):
        """Deleting a held stock via /user/stock/delete flashes a confirmation."""
        holding = User_Stock.add_stock(9876, "GS", "1")
        db.session.add(holding)
        db.session.commit()

        with self.client as client:
            credentials = {
                "login_username": self.u.username,
                "login_password": '******'
            }
            client.post('/login', data=credentials)

            response = client.post('/user/stock/delete',
                                   data={'stock_symbol': 'GS'},
                                   follow_redirects=True)

            self.assertEqual(response.status_code, 200)
            self.assertIn('GS has been deleted from your portfolio',
                          str(response.data))
class UserViewsTestCase(TestCase):
    """Route tests for the user views.

    Covers the portfolio page, adding/editing/deleting stocks, user settings,
    password changes and emailing the portfolio.
    """

    # Users & stocks to be created once (not for every test) to limit API calls
    # (limited to 60 per min). This runs when the class body executes, i.e.
    # when the test module is imported; setUp re-queries these rows.
    u = User.signup("testUser", "*****@*****.**", "password", "USA", "CA")
    u.id = 9876

    u_stock = User_Stock.add_stock(u.id, "AAPL", "5")
    u_stock_2 = User_Stock.add_stock(u.id, "UNH", "1")

    db.session.add_all([u, u_stock, u_stock_2])
    db.session.commit()

    def setUp(self):
        """Create a test client and reload the fixture user and holdings."""

        self.client = app.test_client()

        u = User.query.get(9876)
        self.u = u
        u_stock = User_Stock.query.filter_by(stock_symbol='AAPL').filter_by(
            user_id=self.u.id)
        self.u_stock = u_stock[0]
        u_stock_2 = User_Stock.query.filter_by(stock_symbol='UNH').filter_by(
            user_id=self.u.id)
        self.u_stock_2 = u_stock_2[0]
        self.s = Stock.query.get('AAPL')

    def tearDown(self):
        """Discard any uncommitted changes left behind by a test."""
        db.session.rollback()

    def _login(self, client):
        """Log the fixture user in through the /login form.

        Shared by every test that needs an authenticated session; returns the
        login response.
        """
        return client.post('/login',
                           data={
                               "login_username": self.u.username,
                               "login_password": '******'
                           })

    def test_portfolio_route(self):
        """test portfolio route"""
        with self.client as c:
            self._login(c)

            resp = c.get('/user')

            self.assertEqual(resp.status_code, 200)
            self.assertIn(b'TOTAL PORTFOLIO VALUE', resp.data)

    def test_portfolio_route_not_signed_in(self):
        """test portfolio route when not signed in"""
        with self.client as c:

            resp = c.get('/user', follow_redirects=True)

            self.assertEqual(resp.status_code, 405)

    def test_add_stock_route(self):
        """test add_stock route"""
        with self.client as c:
            self._login(c)

            resp = c.post('/user/add',
                          data={
                              "stock_symbol": 'MS',
                              "stock_num": '1'
                          },
                          follow_redirects=True)

            self.assertEqual(resp.status_code, 200)
            self.assertIn(b"Stock added", resp.data)

    def test_add_stock_route_not_valid_stock(self):
        """test add_stock route with a stock Symbol which is not valid"""
        with self.client as c:
            self._login(c)

            resp = c.post('/user/add',
                          data={
                              "stock_symbol": 'THISISNOTVALID',
                              "stock_num": '1'
                          },
                          follow_redirects=True)

            self.assertEqual(resp.status_code, 200)
            self.assertIn(b"Stock Symbol Not Recognized", resp.data)

    def test_add_stock_route_stock_exists(self):
        """test add_stock route with a stock Symbol which already exists in user portfolio"""
        with self.client as c:
            self._login(c)

            resp = c.post('/user/add',
                          data={
                              "stock_symbol": 'AAPL',
                              "stock_num": '1'
                          },
                          follow_redirects=True)

            self.assertEqual(resp.status_code, 200)
            self.assertIn(b"Stock already in portfolio", resp.data)

    def test_add_stock_route_not_signed_in(self):
        """test add_stock route can only be accessed when signed in"""
        with self.client as c:
            resp = c.post('/user/add',
                          data={
                              "stock_symbol": 'AAPL',
                              "stock_num": '1'
                          },
                          follow_redirects=True)

            self.assertEqual(resp.status_code, 405)

    def test_user_settings_get_route(self):
        """test user settings get route"""
        with self.client as c:
            self._login(c)

            resp = c.get('/user/settings')

            self.assertEqual(resp.status_code, 200)
            self.assertIn(b'Update User Settings', resp.data)
            # checks that current details have been prefilled
            self.assertIn(self.u.username, str(resp.data))
            self.assertIn(self.u.email, str(resp.data))
            self.assertIn(self.u.country, str(resp.data))

    def test_user_settings_post_route(self):
        """test user settings post route with updated details"""
        with self.client as c:
            self._login(c)

            resp = c.post('/user/settings',
                          data={
                              'username': '******',
                              'email': self.u.email,
                              'password': self.u.password,
                              'country': self.u.country,
                              'state': self.u.state
                          },
                          follow_redirects=True)

            self.assertEqual(resp.status_code, 200)
            self.assertIn(b'settings have been updated', resp.data)
            self.assertIn(b'newUserName', resp.data)

    def test_user_settings_post_route_login_required(self):
        """test user settings post route can only be accessed if logged in"""
        with self.client as c:
            resp = c.post('/user/settings',
                          data={
                              'username': '******',
                              'email': self.u.email,
                              'password': self.u.password,
                              'country': self.u.country,
                              'state': self.u.state
                          },
                          follow_redirects=True)

            self.assertEqual(resp.status_code, 405)

    def test_edit_password_route(self):
        """test updating user password"""
        with self.client as c:
            self._login(c)

            resp = c.post('/user/password',
                          data={
                              'current_password': '******',
                              'new_password': '******',
                              'confirm_new_password': '******'
                          },
                          follow_redirects=True)

            self.assertEqual(resp.status_code, 200)
            self.assertIn(b'password has been updated', resp.data)

            # reset password back to original for future tests
            resp = c.post('/user/password',
                          data={
                              'current_password': '******',
                              'new_password': '******',
                              'confirm_new_password': '******'
                          },
                          follow_redirects=True)
            # later tests log in with the original password, so a failed
            # reset should surface here rather than as unrelated failures
            self.assertEqual(resp.status_code, 200)

    def test_edit_password_route_wrong_current_pw(self):
        """test updating user password with wrong current password does not let the password update"""
        with self.client as c:
            self._login(c)

            resp = c.post('/user/password',
                          data={
                              'current_password': '******',
                              'new_password': '******',
                              'confirm_new_password': '******'
                          },
                          follow_redirects=True)

            self.assertEqual(resp.status_code, 200)
            self.assertIn(b'Invalid credentials', resp.data)

    def test_edit_password_route_new_pw_not_matched(self):
        """test updating user password with new passwords which do not match does not update password"""
        with self.client as c:
            self._login(c)

            resp = c.post('/user/password',
                          data={
                              'current_password': '******',
                              'new_password': '******',
                              'confirm_new_password': '******'
                          },
                          follow_redirects=True)

            self.assertEqual(resp.status_code, 200)
            self.assertIn(b'New Passwords do not match', resp.data)

    def test_edit_password_route_login_required(self):
        """test updating user password can only be accessed when logged in"""
        with self.client as c:
            resp = c.post('/user/password',
                          data={
                              'current_password': '******',
                              'new_password': '******',
                              'confirm_new_password': '******'
                          },
                          follow_redirects=True)

            self.assertEqual(resp.status_code, 405)

    def test_edit_stock_route(self):
        """test edit stock route"""
        with self.client as c:
            self._login(c)

            resp = c.post('/user/stock',
                          data={
                              'stock_num': '100000',
                              'stock_symbol': self.u_stock.stock_symbol
                          },
                          follow_redirects=True)

            self.assertEqual(resp.status_code, 200)
            self.assertIn(b'100000', resp.data)

    def test_edit_stock_stock_symbol_not_valid(self):
        """test edit stock route with a stock symbol which is not in the User_Stock portfolio"""
        with self.client as c:
            self._login(c)

            resp = c.post('/user/stock',
                          data={
                              'stock_num': '100000',
                              'stock_symbol': 'NOTVALID'
                          },
                          follow_redirects=True)

            self.assertEqual(resp.status_code, 200)
            self.assertIn(b'an error occurred', resp.data)

    def test_edit_stock_stock_symbol_login_required(self):
        """test edit stock route can only be accessed when logged in"""
        with self.client as c:
            resp = c.post('/user/stock',
                          data={
                              'stock_num': '100000',
                              'stock_symbol': self.u_stock.stock_symbol
                          },
                          follow_redirects=True)

            self.assertEqual(resp.status_code, 405)

    def test_delete_stock_route(self):
        """test delete stock route"""

        new_stock = User_Stock.add_stock(9876, "GS", "1")
        db.session.add(new_stock)
        db.session.commit()

        with self.client as c:
            self._login(c)

            resp = c.post('/user/stock/delete',
                          data={'stock_symbol': 'GS'},
                          follow_redirects=True)

            self.assertEqual(resp.status_code, 200)
            self.assertIn('GS has been deleted from your portfolio',
                          str(resp.data))

    def test_delete_stock_symbol_not_valid(self):
        """test delete stock route with not a valid stock symbol"""
        with self.client as c:
            self._login(c)

            resp = c.post('/user/stock/delete',
                          data={'stock_symbol': 'NOTVALID'},
                          follow_redirects=True)

            self.assertEqual(resp.status_code, 200)
            self.assertIn('An error occurred', str(resp.data))

    def test_delete_stock_symbol_login_required(self):
        """test delete stock route can only be accessed when logged in"""
        with self.client as c:
            resp = c.post('/user/stock/delete',
                          data={'stock_symbol': 'NOTVALID'},
                          follow_redirects=True)

            self.assertEqual(resp.status_code, 405)

    def test_send_portfolio_route(self):
        """test flask mail"""

        with mail.record_messages() as outbox:
            with self.client as c:
                self._login(c)

                resp = c.get('/user/send-portfolio', follow_redirects=True)

                self.assertEqual(resp.status_code, 200)
                self.assertIn(b'Portfolio Snap Shot Sent', resp.data)
                self.assertEqual(len(outbox), 1)
                self.assertEqual(outbox[0].subject, 'Portfolio SnapShot')

    def test_send_portfolio_login_required(self):
        """test send portfolio route can only be accessed when logged in"""
        with self.client as c:
            resp = c.get('/user/send-portfolio', follow_redirects=True)

            self.assertEqual(resp.status_code, 405)
示例#13
0
class CompanyAndNewsViewsTestCase(TestCase):
    """Route tests for the company-detail pages and the JSON company/news API."""

    # Users & stocks to be created once (not for every test) to limit API calls
    # (limited to 60 per min). This runs when the class body executes, i.e.
    # when the test module is imported; setUp re-queries these rows.
    u = User.signup("testUser", "*****@*****.**", "password", "USA", "CA")
    u.id = 9876

    u_stock = User_Stock.add_stock(u.id, "AAPL", "5")

    db.session.add_all([u, u_stock])
    db.session.commit()

    def setUp(self):
        """Create a test client and reload the fixture user and holding."""

        self.client = app.test_client()

        u = User.query.get(9876)
        self.u = u
        u_stock = User_Stock.query.filter_by(stock_symbol='AAPL').filter_by(
            user_id=self.u.id)
        self.u_stock = u_stock[0]
        self.s = Stock.query.get('AAPL')

    def tearDown(self):
        """Discard any uncommitted changes left behind by a test."""
        db.session.rollback()

    # The original defined this test twice with identical bodies; the second
    # silently shadowed the first, so only one copy is kept.
    def test_company_details_route_existing_stock(self):
        """test company details with a stock already in database"""
        with self.client as c:
            resp = c.get('/company-details/AAPL')

            self.assertEqual(resp.status_code, 200)
            self.assertIn(
                b'To see further peer details click on the below links',
                resp.data)

    def test_company_details_route_new_stock(self):
        """test company details with a stock not currently in database"""
        with self.client as c:
            resp = c.get('/company-details/GS')

            self.assertEqual(resp.status_code, 200)
            self.assertIn(
                b'To see further peer details click on the below links',
                resp.data)

            stock = Stock.query.get('GS')
            # the route should have stored the newly found stock
            self.assertIsNotNone(stock)
            # compare the whole symbol; stock_symbol[0] would only be 'G'
            self.assertEqual('GS', stock.stock_symbol)

    def test_company_details_route_invalid_stock(self):
        """test company details when stock symbol is not valid"""
        with self.client as c:
            resp = c.get('/company-details/INVALIDNAME', follow_redirects=True)

            self.assertEqual(resp.status_code, 200)
            self.assertIn(b'Stock was not found', resp.data)

    def test_send_stock_details_route(self):
        """test send stock details"""
        with self.client as c:
            resp = c.post('/api/company-details',
                          json={'stock_symbol': 'AAPL'})

            # check the status first so a failure is reported as such rather
            # than as an error while indexing the response body
            self.assertEqual(resp.status_code, 200)
            data = resp.json['stock']
            self.assertEqual(data['currency'], 'USD')
            self.assertEqual(data['country'], 'US')
            self.assertEqual(data['name'], 'Apple Inc')
            self.assertEqual(data['ipo'], '1980-12-12')

    def test_send_stock_details_route_invalid_stock(self):
        """test send stock details when stock symbol is not valid"""
        with self.client as c:
            resp = c.post('/api/company-details',
                          json={'stock_symbol': 'INVALIDNAME'})

            self.assertEqual(resp.status_code, 404)

    def test_send_advanced_stock_details_route(self):
        """test send advanced stock details (price figures and peers)"""
        with self.client as c:
            resp = c.post('/api/advanced-company-details',
                          json={'stock_symbol': 'AAPL'})

            self.assertEqual(resp.status_code, 200)
            data = resp.json['stock']
            peers = resp.json['peers']
            # Compare numerically: `x > '0'` raises TypeError when x is a
            # number and is a lexicographic comparison when x is a string.
            self.assertGreater(float(data['price']), 0)
            self.assertGreater(float(data['targetMean']), 0)
            self.assertGreater(float(data['yearlyHigh']), 0)
            self.assertIsNotNone(peers[0])

    def test_send_advanced_stock_details_route_invalid_stock(self):
        """test send advanced stock details when stock symbol is not valid"""
        with self.client as c:
            resp = c.post('/api/advanced-company-details',
                          json={'stock_symbol': 'INVALIDNAME'})

            self.assertEqual(resp.status_code, 404)

    def test_news_route(self):
        """test news route with stock symbol sent via json"""
        with self.client as c:
            resp = c.post('/api/company-details/news',
                          json={'stock_symbol': 'AAPL'})

            self.assertEqual(resp.status_code, 200)
            data = resp.json['news']
            # only testing the first news article returned
            self.assertIsNotNone(data[0]['category'])
            self.assertIsNotNone(data[0]['datetime'])

    def test_news_route_invalid_stock(self):
        """test news route when stock symbol is not valid"""
        with self.client as c:
            resp = c.post('/api/company-details/news',
                          json={'stock_symbol': 'INVALIDNAME'})

            self.assertEqual(resp.status_code, 404)

    def test_news_route_with_no_json(self):
        """test news route when no stock symbol is given"""
        with self.client as c:
            resp = c.post('/api/company-details/news')

            self.assertEqual(resp.status_code, 200)
            data = resp.json['news']
            # only testing the first news article returned
            self.assertIsNotNone(data[0]['category'])
            self.assertIsNotNone(data[0]['datetime'])

    def test_auto_route_with_stock(self):
        """test auto route with a stock passed through in params"""
        with self.client as c:
            resp = c.get('/api/_stock-autocomplete?name=APP')
            self.assertEqual(resp.status_code, 200)
            self.assertIsNotNone(resp.json)
            self.assertIsNotNone(resp.json[0]['description'])
            self.assertIsNotNone(resp.json[0]['symbol'])