예제 #1
0
 def test_country(self):
     """
     Test the country getter.

     Checks that ``PyTickerSymbols.get_all_countries`` returns a collection
     that contains Germany, the Netherlands and Sweden and has no
     duplicate entries.
     :return:
     """
     from collections import Counter

     stock_data = PyTickerSymbols()
     self.assertIsNotNone(stock_data)
     countries = list(stock_data.get_all_countries())
     self.assertIsNotNone(countries)
     self.assertIn("Germany", countries)
     self.assertIn("Netherlands", countries)
     self.assertIn("Sweden", countries)
     # duplicates are not allowed; count every entry once (O(n)) and
     # report the offending names instead of a bare "1 != 2" failure
     duplicates = [
         country for country, count in Counter(countries).items()
         if count > 1
     ]
     self.assertEqual(
         duplicates,
         [],
         'Duplicate countries: ' + ', '.join(map(str, duplicates)),
     )
예제 #2
0
    def test_valid_country_name(self):
        """
        Every stock must carry a non-empty, valid country name.

        First asserts that no country returned by ``get_all_countries`` is
        empty (listing the stocks whose country is ``''``), then asserts
        that every stock's ``'country'`` is a name in
        ``pycountry.countries`` (listing each offender as ``name(country)``).
        """
        stock_data = PyTickerSymbols()
        countries = stock_data.get_all_countries()
        empty_names = [name for name in countries if not name]
        empty_country_stocks = ', '.join(
            stock['name'] for stock in stock_data.get_stocks_by_country(''))
        self.assertEqual(
            len(empty_names),
            0,
            'The following stocks have an empty country string: ' +
            empty_country_stocks,
        )

        # Set of official pycountry names: O(1) membership per stock.
        valid_countries = {country.name for country in pycountry.countries}
        wrong_country_name = [
            stock for stock in stock_data.get_all_stocks()
            if stock['country'] not in valid_countries
        ]
        # format() instead of '+' so a non-string country (e.g. None) is
        # reported rather than raising TypeError while building the message.
        wrong_country_name_stocks = ', '.join(
            '{}({})'.format(stock['name'], stock['country'])
            for stock in wrong_country_name)
        self.assertEqual(
            len(wrong_country_name),
            0,
            'The following stocks have an invalid country name: ' +
            wrong_country_name_stocks,
        )
예제 #3
0
import datetime
import numpy as np
import pandas_datareader.data as web
from pandas import Series, DataFrame
import matplotlib.pyplot as plt
from matplotlib import style
import matplotlib as mpl
import altair as alt
#import yfinance as yf
#yf.pdr_override()
###https://github.com/ranaroussi/yfinance

from pytickersymbols import PyTickerSymbols

# Ticker metadata from pytickersymbols: the collections returned by its
# get_all_countries / get_all_indices / get_all_industries getters
# (evaluated in that order).
stock_data = PyTickerSymbols()
countries, indices, industries = (
    stock_data.get_all_countries(),
    stock_data.get_all_indices(),
    stock_data.get_all_industries(),
)

# Page layout and colour settings (presumably interpolated into the custom
# CSS style block that follows; units are not visible here).
max_width = 100
padding_top, padding_right, padding_left, padding_bottom = 1, 5, 5, 20
color, backgr = 'Black', 'white'

st.markdown(
    f"""
<style>
    .reportview-container .main .block-container{{
예제 #4
0
파일: base.py 프로젝트: ajmal017/pystockdb
class DBBase:
    """
    Database installer.

    Base class for building the stock database: it binds ``db``, optionally
    recreates all tables and seeds the initial types and tags, and provides
    helpers to add indices with their stocks and to download historical
    prices. Subclasses implement :meth:`build`.
    """
    def __init__(self, arguments: dict, logger: logging.Logger):
        """
        :param arguments: configuration dict; ``arguments["db_args"]`` holds
            the keyword arguments passed to ``db.bind``. When
            ``db_args.get('create_db')`` is truthy all tables are dropped,
            recreated and filled with the initial types/tags.
        :param logger: logger for progress and warning messages
        """
        self.logger = logger
        self.db_args = arguments["db_args"]
        self.arguments = arguments
        self.ticker_symbols = PyTickerSymbols()
        # Bind only once: a BindingError from db.bind presumably means the
        # database is already bound (and mapped) by an earlier instance.
        try:
            db.bind(**self.db_args)
        except core.BindingError:
            pass
        else:
            db.generate_mapping(check_tables=False)

        if self.db_args.get('create_db', False):
            db.drop_all_tables(with_all_data=True)
            db.create_tables()
            self.__insert_initial_data()

    def build(self):
        """
        Starts the database installation.

        Must be implemented by subclasses.
        :return:
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    @db_session
    def add_indices_and_stocks(self, indices_list):
        """
        Add the given indices together with all of their stocks.

        For every index name an ``Index`` is created, its Yahoo symbol
        (tagged ``Tag.IDX``) is attached, and every stock returned by
        ``get_stocks_by_index`` is linked to it. The session is committed
        after each fully processed index.

        :param indices_list: iterable of index names known to PyTickerSymbols
        """
        for index_name in indices_list:
            # NOTE(review): the Index is created before the Yahoo symbol is
            # resolved, so an untranslatable index is presumably still
            # persisted (without symbol and stocks) on a later commit.
            # Kept as-is because other code may expect every index to exist.
            index = Index(name=index_name, price_item=PriceItem(item=Item()))
            # add index symbol
            yah_sym = self.ticker_symbols.index_to_yahoo_symbol(index_name)
            if yah_sym is None:
                self.logger.warning(
                    'Can not translate {} into yahoo symbol'.format(
                        index_name))
                continue
            idx_item = Item()
            idx_item.tags.add(Tag.get(name=Tag.IDX))
            index.price_item.symbols.create(name=yah_sym, item=idx_item)
            stocks = self.ticker_symbols.get_stocks_by_index(index.name)
            for stock_info in stocks:
                self.__add_stock_to_index(index, stock_info)
            commit()

    def __add_stock_to_index(self, index, stock_info):
        """
        Link one stock to ``index``, creating the stock first if necessary.

        A stock already stored under the same name is simply added to the
        index. Otherwise a new ``Stock`` is created with its Google/Yahoo
        symbols and tagged with its country (region tag) and industries.

        :param index: ``Index`` entity the stock belongs to
        :param stock_info: stock dict from PyTickerSymbols; this method reads
            the keys ``'name'``, ``'symbols'``, ``'country'``,
            ``'industries'`` and ``Type.SYM``
        """
        stock_in_db = Stock.get(name=stock_info['name'])
        if stock_in_db:
            self.logger.info('Add stock {}:{} to index.'.format(
                index.name, stock_in_db.name))
            index.stocks.add(stock_in_db)
            return

        self.logger.info('Add stock {}:{} to db'.format(
            index.name, stock_info[Type.SYM]))
        # create stock
        stock = Stock(name=stock_info['name'],
                      price_item=PriceItem(item=Item()))
        # add symbols; __create_symbol itself skips missing tags and the
        # '-' placeholder, so no pre-check is needed here
        yao = Tag.get(name=Tag.YAO)
        gog = Tag.get(name=Tag.GOG)
        usd = Tag.get(name=Tag.USD)
        eur = Tag.get(name=Tag.EUR)
        for symbol in stock_info['symbols']:
            self.__create_symbol(stock, Tag.GOG, gog, symbol, eur, usd)
            self.__create_symbol(stock, Tag.YAO, yao, symbol, eur, usd)
        index.stocks.add(stock)
        # connect stock with its country (region tag)
        name = stock_info['country']
        country = Tag.select(
            lambda t: t.name == name and t.type.name == Type.REG).first()
        if country is None:
            # .first() returns None when no such region tag exists; warn
            # instead of failing with AttributeError
            self.logger.warning(
                'Country {} of stock {} is not a known region.'.format(
                    name, stock_info['name']))
        else:
            country.items.add(stock.price_item.item)
        # connect stock with its industries
        indus = stock_info['industries']
        industries = Tag.select(
            lambda t: t.name in indus and t.type.name == Type.IND)
        for industry in industries:
            industry.items.add(stock.price_item.item)

    @db_session
    def __create_symbol(self, stock, my_tag, my_tag_item, symbol, eur, usd):
        """
        Attach the ``my_tag`` ticker from ``symbol`` to ``stock``.

        Nothing happens when ``symbol`` has no ``my_tag`` entry or the entry
        is the placeholder ``'-'``. If a ``Symbol`` with that name already
        exists, a warning is logged and nothing is created. Otherwise the new
        symbol's item is tagged with ``my_tag_item`` and with ``eur`` when
        the ticker starts with ``'FRA'`` or ends with ``'.F'`` (presumably
        Frankfurt listings), else with ``usd``.

        :param stock: ``Stock`` entity that receives the symbol
        :param my_tag: key to read from ``symbol`` (e.g. ``Tag.GOG``)
        :param my_tag_item: ``Tag`` entity matching ``my_tag``
        :param symbol: one entry of the stock's ``'symbols'`` list
        :param eur: EUR currency tag
        :param usd: USD currency tag
        """
        if my_tag not in symbol or symbol[my_tag] == '-':
            return
        sym_name = symbol[my_tag]
        if Symbol.get(name=sym_name):
            self.logger.warning('Symbol {} is related to more than one'
                                ' stock.'.format(sym_name))
            return
        # Create the Item only after the duplicate check so that skipped
        # symbols do not leave orphan, tagged Item rows behind.
        cur = eur if sym_name.startswith('FRA') or \
            sym_name.endswith('.F') else usd
        item = Item()
        item.add_tags([my_tag_item, cur])
        stock.price_item.symbols.create(item=item, name=sym_name)

    @db_session
    def download_historicals(self, symbols, start, end):
        """
        Download and store historical prices for ``symbols``.

        Symbols are requested from ``DataCrawler.get_series_stack`` in chunks
        of 50 names; each value of a symbol's series is stored via
        ``symbol.prices.create(**value)``. The session is committed after
        each chunk.

        :param symbols: sliceable sequence of symbol entities (``.name``,
            ``.prices``)
        :param start: start of the period; nothing is done if falsy
        :param end: end of the period; nothing is done if falsy
        :return: False if ``start`` or ``end`` is missing, True otherwise
        """
        if not (start and end):
            return False
        crawler = DataCrawler()
        chunk_size = 50
        chunks = [symbols[x:x + chunk_size]
                  for x in range(0, len(symbols), chunk_size)]
        for chunk in chunks:
            ids = [symbol.name for symbol in chunk]
            if not ids:
                continue
            self.logger.debug('Download prices for {}.'.format(ids))
            series = crawler.get_series_stack(ids, start=start, end=end)
            for symbol in chunk:
                self.logger.debug('Add prices for {} from {} until {}.'.format(
                    symbol.name, start, end))
                for value in series[symbol.name]:
                    symbol.prices.create(**value)
            commit()
        return True

    @db_session
    def __insert_initial_data(self):
        """
        Seed a freshly created database with its types and tags.

        Creates one ``Type`` for each of ``Type.REG``, ``IND``, ``MSC``,
        ``SYM``, ``CUR``, ``FDM``, ``FIL``, ``ICR`` and ``ARG`` (with their
        fixed tags where given). It then adds one region tag per country and
        one industry tag per industry reported by PyTickerSymbols, and
        commits.
        """
        # insert types
        region_type = Type(name=Type.REG)
        industry_type = Type(name=Type.IND)
        Type(name=Type.MSC).add_tags([Tag.IDX])
        Type(name=Type.SYM).add_tags([Tag.YAO, Tag.GOG])
        Type(name=Type.CUR).add_tags([Tag.USD, Tag.EUR])
        Type(name=Type.FDM).add_tags(
            [Tag.ICA, Tag.ICF, Tag.REC, Tag.ICO, Tag.BLE, Tag.CSH])
        Type(name=Type.FIL)
        Type(name=Type.ICR)
        Type(name=Type.ARG)
        countries = self.ticker_symbols.get_all_countries()
        for country in countries:
            region_type.tags.create(name=country)
        industries = self.ticker_symbols.get_all_industries()
        for industry in industries:
            industry_type.tags.create(name=industry)
        commit()