def add_settings_to_test_db(
        db: DBHandler,
        db_settings: Optional[Dict[str, Any]],
        ignored_assets: Optional[List[Asset]],
        data_migration_version: Optional[int],
) -> None:
    """Apply the fixture-provided settings, ignored assets and data migration
    version to a test database.

    Usage analytics are always disabled for tests and the main currency is
    pinned to the tests' default; entries in db_settings override these presets.
    """
    settings: Dict[str, Any] = {
        # Never submit usage analytics from the test suite
        'submit_usage_analytics': False,
        'main_currency': DEFAULT_TESTS_MAIN_CURRENCY,
    }
    # Fixture-given settings take priority over the presets above
    if db_settings is not None:
        settings.update(db_settings)
    db.set_settings(ModifiableDBSettings(**settings))  # type: ignore

    for asset in ignored_assets or []:
        db.add_to_ignored_assets(asset)

    if data_migration_version is not None:
        db.conn.cursor().execute(
            'INSERT OR REPLACE INTO settings(name, value) VALUES(?, ?)',
            ('last_data_migration', data_migration_version),
        )
        db.conn.commit()
def test_upgrade_db_10_to_11(data_dir, username):
    """Test upgrading the DB from version 10 to version 11.

    Deleting all entries from used_query_ranges
    """
    msg_aggregator = MessagesAggregator()
    userdata_dir = os.path.join(data_dir, username)
    os.mkdir(userdata_dir)
    # Copy the canned v10 database into the new user's directory
    dir_path = os.path.dirname(os.path.realpath(__file__))
    copyfile(
        os.path.join(os.path.dirname(dir_path), 'data', 'v10_rotkehlchen.db'),
        os.path.join(userdata_dir, 'rotkehlchen.db'),
    )
    with target_patch(target_version=11):
        db = DBHandler(user_data_dir=userdata_dir, password='******', msg_aggregator=msg_aggregator)

    # Make sure that the blockchain accounts table is upgraded
    expected_results = [
        ('ETH', '0xB2CEB220df2e4a5ec6A0aC93d79655895E9886Bc', None),
        ('ETH', '0x926cbe37d3487a881F9EB18F4746Ee09557D79cB', None),
        ('BTC', '37SQZzaCPbDno9aFBjaVKhA9KkzTbt94x2', None),
    ]
    rows = db.conn.cursor().execute(
        'SELECT blockchain, account, label FROM blockchain_accounts;',
    )
    for idx, row in enumerate(rows):
        assert row == expected_results[idx]

    # Finally also make sure that we have updated to the target version
    assert db.get_version() == 11
def __init__(self, args: argparse.Namespace) -> None:
    """Open the database of the user given on the command line."""
    user_dir = default_data_directory() / args.user_name
    self.db = DBHandler(
        user_data_dir=user_dir,
        password=args.user_password,
        msg_aggregator=MessagesAggregator(),
        initial_settings=None,
    )
def unlock(self, username, password, create_new):
    """Unlock an existing user account or create a new one.

    Opens the user's database, remembers its data directory on self and
    returns that directory.

    Raises AuthenticationError when creating a user that already exists,
    unlocking a user that does not exist, or when the user directory is
    present but its database file is missing (the directory is then moved
    aside as a backup so the account can be recreated).
    """
    user_data_dir = os.path.join(self.data_directory, username)
    dir_exists = os.path.exists(user_data_dir)

    if create_new:
        if dir_exists:
            raise AuthenticationError(
                'User {} already exists'.format(username))
        os.mkdir(user_data_dir)
    else:
        if not dir_exists:
            raise AuthenticationError(
                'User {} does not exist'.format(username))

        if not os.path.exists(os.path.join(user_data_dir, 'rotkehlchen.db')):
            # This is bad. User directory exists but database is missing.
            # Make a backup of the directory that user should probably remove
            # on his own. At the same time delete the directory so that a new
            # user account can be created
            shutil.move(
                user_data_dir,
                os.path.join(self.data_directory, 'backup_%s' % username))
            raise AuthenticationError(
                'User {} exists but DB is missing. Somehow must have been manually '
                'deleted or is corrupt. Please recreate the user account.'.
                format(username))

    self.db = DBHandler(user_data_dir, username, password)
    self.user_data_dir = user_data_dir
    return user_data_dir
def process_entry(
        self,
        db: DBHandler,
        db_ledger: DBLedgerActions,
        timestamp: Timestamp,
        data: BinanceCsvRow,
) -> None:
    """Turn one binance deposit/withdrawal CSV row into an AssetMovement and store it."""
    raw_amount = data['Change']
    asset = data['Coin']
    if data['Operation'] == 'Deposit':
        category = AssetMovementCategory.DEPOSIT
    else:
        category = AssetMovementCategory.WITHDRAWAL
        # Withdrawal rows carry the amount with the opposite sign in the CSV
        raw_amount = -raw_amount

    movement = AssetMovement(
        location=Location.BINANCE,
        category=category,
        address=None,
        transaction_id=None,
        timestamp=timestamp,
        asset=asset,
        amount=AssetAmount(raw_amount),
        fee=Fee(ZERO),
        fee_asset=A_USD,
        link=f'Imported from binance CSV file. Binance operation: {data["Operation"]}',
    )
    db.add_asset_movements([movement])
def check_saved_events_for_exchange(
        exchange_location: Location,
        db: DBHandler,
        should_exist: bool,
        queryrange_formatstr: str = '{exchange}_{type}_{exchange}',
) -> None:
    """Assert that the trades and used query ranges of an exchange are
    (or are not, depending on should_exist) present in the DB."""
    trades = db.get_trades(
        filter_query=TradesFilterQuery.make(location=exchange_location),
        has_premium=True,
    )
    # One used-query-range entry per event type the exchange is queried for
    ranges = [
        db.get_used_query_range(
            queryrange_formatstr.format(exchange=exchange_location, type=query_type),
        )
        for query_type in ('trades', 'margins', 'asset_movements')
    ]
    if should_exist:
        assert all(entry is not None for entry in ranges)
        assert len(trades) != 0
    else:
        assert all(entry is None for entry in ranges)
        assert len(trades) == 0
def test_upgrade_sqlcipher_v3_to_v4_with_dbinfo(data_dir):
    """Test that an sqlcipher v3 DB accompanied by a dbinfo file is migrated
    in-place to v4 when the DB is opened."""
    if detect_sqlcipher_version() != 4:
        return  # nothing to test when sqlcipher v4 is not available

    username = '******'
    userdata_dir = os.path.join(data_dir, username)
    os.mkdir(userdata_dir)
    # get the v3 database file and copy it into the user's data directory
    dir_path = os.path.dirname(os.path.realpath(__file__))
    copyfile(
        os.path.join(os.path.dirname(dir_path), 'data', 'sqlcipher_v3_rotkehlchen.db'),
        os.path.join(userdata_dir, 'rotkehlchen.db'),
    )
    dbinfo = {'sqlcipher_version': 3, 'md5_hash': '20c910c28ca42370e4a5f24d6d4a73d2'}
    with open(os.path.join(userdata_dir, DBINFO_FILENAME), 'w') as f:
        f.write(rlk_jsondumps(dbinfo))

    # the constructor should migrate it in-place and we should have a working DB
    db = DBHandler(userdata_dir, '123', MessagesAggregator())
    assert db.get_version() == ROTKEHLCHEN_DB_VERSION
def _init_database(
        data_dir: FilePath,
        password: str,
        msg_aggregator: MessagesAggregator,
        db_settings: Optional[Dict[str, Any]],
        ignored_assets: Optional[List[Asset]],
        blockchain_accounts: BlockchainAccounts,
) -> DBHandler:
    """Create a test DBHandler and populate it with the fixture-provided data."""
    db = DBHandler(data_dir, password, msg_aggregator)
    # Presets: no usage analytics during tests, tests' default main currency.
    # Anything given in db_settings overrides them.
    settings: Dict[str, Any] = {
        'submit_usage_analytics': False,
        'main_currency': DEFAULT_TESTS_MAIN_CURRENCY,
    }
    if db_settings is not None:
        settings.update(db_settings)
    db.set_settings(ModifiableDBSettings(**settings))

    for asset in ignored_assets or []:
        db.add_to_ignored_assets(asset)

    # Make sure that the fixture provided accounts are in the blockchain
    db.add_blockchain_accounts(SupportedBlockchain.ETHEREUM, blockchain_accounts.eth)
    db.add_blockchain_accounts(SupportedBlockchain.BITCOIN, blockchain_accounts.btc)
    return db
def __init__(self, args: argparse.Namespace) -> None:
    """Open the database of the user named on the command line."""
    data_root = str(default_data_directory())
    user_path = FilePath(os.path.join(data_root, args.user_name))
    self.db = DBHandler(
        user_data_dir=user_path,
        password=args.user_password,
        msg_aggregator=MessagesAggregator(),
    )
def add_tags_to_test_db(db: DBHandler, tags: List[Dict[str, Any]]) -> None:
    """Insert the given tag entries into the test database.

    Each entry must have 'name', 'background_color' and 'foreground_color';
    'description' is optional and defaults to None.
    """
    for entry in tags:
        db.add_tag(
            name=entry['name'],
            description=entry.get('description'),
            background_color=entry['background_color'],
            foreground_color=entry['foreground_color'],
        )
def maybe_include_etherscan_key(db: DBHandler, include_etherscan_key: bool) -> None:
    """Store the tests-only etherscan API key in the DB, if requested."""
    if not include_etherscan_key:
        return
    # Add the tests only etherscan API key
    credentials = ExternalServiceApiCredentials(
        service=ExternalService.ETHERSCAN,
        api_key=ApiKey('8JT7WQBB2VQP5C3416Y8X3S8GBA3CVZKP4'),
    )
    db.add_external_service_credentials([credentials])
def process_entries(
        self,
        db: DBHandler,
        timestamp: Timestamp,
        data: List[BinanceCsvRow],
) -> int:
    """Convert the given CSV rows into trades, store them in the DB and
    return the number of trades saved."""
    processed_trades = self.process_trades(db=db, timestamp=timestamp, data=data)
    db.add_trades(processed_trades)
    return len(processed_trades)
def add_blockchain_accounts_to_db(db: DBHandler, blockchain_accounts: BlockchainAccounts) -> None:
    """Store the fixture-provided ETH and BTC accounts in the test database."""
    for chain, addresses in (
            (SupportedBlockchain.ETHEREUM, blockchain_accounts.eth),
            (SupportedBlockchain.BITCOIN, blockchain_accounts.btc),
    ):
        db.add_blockchain_accounts(
            chain,
            [BlockchainAccountData(address=x) for x in addresses],
        )
def query_ethereum_transactions(
        database: DBHandler,
        etherscan: Etherscan,
        from_ts: Optional[Timestamp] = None,
        to_ts: Optional[Timestamp] = None,
) -> List[EthereumTransaction]:
    """Queries for all transactions (normal AND internal) of all ethereum accounts.
    Returns a list of all transactions of all accounts sorted by time.

    May raise:
    - RemoteError if etherscan is used and there is a problem with reaching it or
    with parsing the response.
    """
    transactions: List[EthereumTransaction] = []

    accounts = database.get_blockchain_accounts()
    for address in accounts.eth:
        # If we already have any transactions in the DB for this from_address
        # from to_ts and on then that means the range has already been queried
        if to_ts:
            existing_txs = database.get_ethereum_transactions(from_ts=to_ts, address=address)
            if len(existing_txs) > 0:
                # So just query the DB only here
                transactions.extend(
                    database.get_ethereum_transactions(
                        from_ts=from_ts,
                        to_ts=to_ts,
                        address=address,
                    ),
                )
                continue

        # else we have to query etherscan for this address
        # TODO: Can we somehow shorten the query here by providing a block range?
        # Note: If we do, we then need to retrieve the rest of the transactions
        # from the DB.
        # Two etherscan queries per address: normal transactions and internal ones
        new_transactions = etherscan.get_transactions(account=address, internal=False)
        new_transactions.extend(
            etherscan.get_transactions(account=address, internal=True))

        # and finally also save the transactions in the DB
        database.add_ethereum_transactions(
            ethereum_transactions=new_transactions,
            from_etherscan=True,
        )
        transactions.extend(new_transactions)

    # Merge the per-address results into one chronologically sorted list
    transactions.sort(key=lambda tx: tx.timestamp)
    return transactions
def maybe_include_cryptocompare_key(db: DBHandler, include_cryptocompare_key: bool) -> None:
    """Store one of the tests-only cryptocompare API keys in the DB, if requested.

    A key is picked at random so the test load is spread over several keys.
    """
    if not include_cryptocompare_key:
        return
    keys = (
        'a4a36d7fd1835cc1d757186de8e7357b4478b73923933d09d3689140ecc23c03',
        'e929bcf68fa28715fa95f3bfa3baa3b9a6bc8f12112835586c705ab038ee06aa',
        '5159ca00f2579ef634b7f210ad725550572afbfb44e409460dd8a908d1c6416a',
        '6781b638eca6c3ca51a87efcdf0b9032397379a0810c5f8198a25493161c318d',
    )
    # Add one of the tests-only cryptocompare API keys
    db.add_external_service_credentials([ExternalServiceApiCredentials(
        service=ExternalService.CRYPTOCOMPARE,
        api_key=ApiKey(random.choice(keys)),
    )])
def temp_etherscan(function_scope_messages_aggregator, tmpdir_factory):
    """Create an Etherscan instance backed by a temporary database.

    If the ETHERSCAN_API_KEY environment variable is set, the key is stored
    in the DB so queries are authenticated.
    """
    directory = tmpdir_factory.mktemp('data')
    db = DBHandler(
        user_data_dir=directory,
        password='******',
        msg_aggregator=function_scope_messages_aggregator,
    )

    # Test with etherscan API key
    api_key = os.environ.get('ETHERSCAN_API_KEY')
    if api_key:
        db.add_external_service_credentials(credentials=[
            ExternalServiceApiCredentials(service=ExternalService.ETHERSCAN, api_key=api_key),
        ])
    return Etherscan(database=db, msg_aggregator=function_scope_messages_aggregator)
def temp_etherscan(database, inquirer, function_scope_messages_aggregator, tmpdir_factory):
    """Create an Etherscan instance with a temporary DB.

    Fails the test when no ETHERSCAN_API_KEY environment variable is set.
    """
    api_key = os.environ.get('ETHERSCAN_API_KEY')
    if not api_key:
        pytest.fail('No ETHERSCAN_API_KEY environment variable found.')

    directory = tmpdir_factory.mktemp('data')
    msg_aggregator = MessagesAggregator()
    db = DBHandler(user_data_dir=directory, password='******', msg_aggregator=msg_aggregator)
    db.add_external_service_credentials(credentials=[
        ExternalServiceApiCredentials(service=ExternalService.ETHERSCAN, api_key=api_key),
    ])
    return Etherscan(database=db, msg_aggregator=msg_aggregator)
def maybe_add_external_trades_to_history(
        db: DBHandler,
        start_ts: Timestamp,
        end_ts: Timestamp,
        history: List[Trade],
        msg_aggregator: MessagesAggregator,
) -> List[Trade]:
    """Query the DB for external trades, merge them into the given history
    and return it sorted by timestamp.

    If there is an unexpected error at the external trade deserialization
    an error is logged and the history is returned unmodified.
    """
    serialized_external_trades = db.get_trades()
    try:
        external_trades = trades_from_dictlist(
            given_trades=serialized_external_trades,
            start_ts=start_ts,
            end_ts=end_ts,
            location='external trades',
            msg_aggregator=msg_aggregator,
        )
    except KeyError:
        msg_aggregator.add_error(
            'External trades in the DB are in an unrecognized format')
        return history

    history.extend(external_trades)
    history.sort(key=lambda entry: entry.timestamp)
    return history
def _init_database(
        data_dir: Path,
        password: str,
        msg_aggregator: MessagesAggregator,
        db_settings: Optional[Dict[str, Any]],
        ignored_assets: Optional[List[Asset]],
        blockchain_accounts: BlockchainAccounts,
        include_etherscan_key: bool,
        include_cryptocompare_key: bool,
        tags: List[Dict[str, Any]],
        manually_tracked_balances: List[ManuallyTrackedBalance],
        data_migration_version: int,
        use_custom_database: Optional[str],
) -> DBHandler:
    """Create the test database and populate it with all fixture-provided data.

    If use_custom_database names a prepared DB, it is put in place before the
    DBHandler opens (and possibly upgrades) it. The helpers below each insert
    one category of fixture data; settings are applied first since some of the
    later additions may depend on them.
    """
    if use_custom_database is not None:
        _use_prepared_db(data_dir, use_custom_database)

    db = DBHandler(
        user_data_dir=data_dir,
        password=password,
        msg_aggregator=msg_aggregator,
        initial_settings=None,
    )
    # Make sure that the fixture provided data are included in the DB
    add_settings_to_test_db(db, db_settings, ignored_assets, data_migration_version)
    add_blockchain_accounts_to_db(db, blockchain_accounts)
    maybe_include_etherscan_key(db, include_etherscan_key)
    maybe_include_cryptocompare_key(db, include_cryptocompare_key)
    add_tags_to_test_db(db, tags)
    add_manually_tracked_balances_to_test_db(db, manually_tracked_balances)
    return db
def maybe_add_external_trades_to_history(
        db: DBHandler,
        start_ts: Timestamp,
        end_ts: Timestamp,
        history: List[Union[Trade, MarginPosition]],
        msg_aggregator: MessagesAggregator,
) -> List[Union[Trade, MarginPosition]]:
    """Query the DB for external trades, merge them into the given history
    and return it sorted by timestamp.

    If there is an unexpected error at the external trade deserialization
    an error is logged and the history is returned unmodified.
    """
    serialized_external_trades = db.get_trades()
    try:
        external_trades = trades_from_dictlist(
            given_trades=serialized_external_trades,
            start_ts=start_ts,
            end_ts=end_ts,
            location='external trades',
            msg_aggregator=msg_aggregator,
        )
    except (KeyError, DeserializationError):
        msg_aggregator.add_error(
            'External trades in the DB are in an unrecognized format')
        return history

    history.extend(external_trades)
    # TODO: We also sort in one other place in this file and also in accountant.py
    # Get rid of the unneeded cases?
    history.sort(key=action_get_timestamp)
    return history
def add_blockchain_accounts_to_db(db: DBHandler, blockchain_accounts: BlockchainAccounts) -> None:
    """Store the fixture-provided ETH and BTC accounts in the test database.

    An InputError from the DB (typically duplicate accounts caused by two
    databases or too many fixtures being initialized) is turned into an
    AssertionError so the test setup fails loudly.
    """
    try:
        for chain, addresses in (
                (SupportedBlockchain.ETHEREUM, blockchain_accounts.eth),
                (SupportedBlockchain.BITCOIN, blockchain_accounts.btc),
        ):
            db.add_blockchain_accounts(
                chain,
                [BlockchainAccountData(address=x) for x in addresses],
            )
    except InputError as e:
        raise AssertionError(
            f'Got error at test setup blockchain account addition: {str(e)} '
            f'Probably using two different databases or too many fixtures initialized. '
            f'For example do not initialize both a rotki api server and another DB at same time',
        ) from e
def test_binance_pairs(user_data_dir):
    """Test that setting and querying the user's binance trade pairs works."""
    msg_aggregator = MessagesAggregator()
    db = DBHandler(user_data_dir, '123', msg_aggregator, None)

    db.add_exchange(
        'binance',
        Location.BINANCE,
        ApiKey('binance_api_key'),
        ApiSecret(b'binance_api_secret'),
    )

    db.set_binance_pairs('binance', ['ETHUSDC', 'ETHBTC', 'BNBBTC'], Location.BINANCE)
    assert db.get_binance_pairs('binance', Location.BINANCE) == ['ETHUSDC', 'ETHBTC', 'BNBBTC']

    # Setting an empty list should clear the stored pairs
    db.set_binance_pairs('binance', [], Location.BINANCE)
    assert db.get_binance_pairs('binance', Location.BINANCE) == []
def test_multiple_location_data_and_balances_same_timestamp(user_data_dir):
    """
    Test that adding location and balance data with same timestamp raises an error
    and no balance/location is added.
    Regression test for https://github.com/rotki/rotki/issues/1043
    """
    msg_aggregator = MessagesAggregator()
    db = DBHandler(user_data_dir, '123', msg_aggregator, None)

    # Two asset balances sharing (time, asset) violate the primary key
    colliding_balances = [
        DBAssetBalance(
            category=BalanceType.ASSET,
            time=1590676728,
            asset=A_BTC,
            amount='1.0',
            usd_value='8500',
        ),
        DBAssetBalance(
            category=BalanceType.ASSET,
            time=1590676728,
            asset=A_BTC,
            amount='1.1',
            usd_value='9100',
        ),
    ]
    with pytest.raises(InputError) as exc_info:
        db.add_multiple_balances(colliding_balances)
    assert exc_info.errisinstance(InputError)
    assert 'Adding timed_balance failed.' in str(exc_info.value)
    # The whole insert must have been rolled back
    assert db.query_timed_balances(from_ts=0, to_ts=1590676728, asset=A_BTC) == []

    # Two location entries sharing (time, location) violate the primary key
    colliding_locations = [
        LocationData(time=1590676728, location='H', usd_value='55'),
        LocationData(time=1590676728, location='H', usd_value='56'),
    ]
    with pytest.raises(InputError) as exc_info:
        db.add_multiple_location_data(colliding_locations)
    assert exc_info.errisinstance(InputError)
    assert 'Tried to add a timed_location_data for' in str(exc_info.value)
    assert db.get_latest_location_value_distribution() == []
def add_settings_to_test_db(
        db: DBHandler,
        db_settings: Optional[Dict[str, Any]],
        ignored_assets: Optional[List[Asset]],
) -> None:
    """Apply the fixture-provided settings and ignored assets to a test DB.

    Usage analytics are always disabled for tests and the main currency is
    pinned to the tests' default; entries in db_settings override these presets.
    """
    settings: Dict[str, Any] = {
        # Never submit usage analytics from the test suite
        'submit_usage_analytics': False,
        'main_currency': DEFAULT_TESTS_MAIN_CURRENCY,
    }
    # Fixture-given settings take priority over the presets above
    if db_settings is not None:
        settings.update(db_settings)
    db.set_settings(ModifiableDBSettings(**settings))  # type: ignore

    for asset in ignored_assets or []:
        db.add_to_ignored_assets(asset)
def check_saved_events_for_exchange(
        exchange_location: Location,
        db: DBHandler,
        should_exist: bool,
) -> None:
    """Assert that the trades and used query ranges of an exchange are
    (or are not, depending on should_exist) present in the DB."""
    location = str(exchange_location)
    trades = db.get_trades(location=exchange_location)
    # One used-query-range entry per event type the exchange is queried for
    ranges = [
        db.get_used_query_range(f'{location}_{suffix}')
        for suffix in ('trades', 'margins', 'asset_movements')
    ]
    if should_exist:
        assert all(entry is not None for entry in ranges)
        assert len(trades) != 0
    else:
        assert all(entry is None for entry in ranges)
        assert len(trades) == 0
def trades_historian(accounting_data_dir, function_scope_messages_aggregator):
    """Create a TradesHistorian backed by a fresh test database."""
    database = DBHandler(accounting_data_dir, '123', function_scope_messages_aggregator)
    return TradesHistorian(
        user_directory=accounting_data_dir,
        db=database,
        eth_accounts=[],
        msg_aggregator=function_scope_messages_aggregator,
    )
def test_upgrade_sqlcipher_v3_to_v4_without_dbinfo(user_data_dir):
    """Test that we can upgrade from an sqlcipher v3 to v4 rotkehlchen database
    Issue: https://github.com/rotki/rotki/issues/229
    """
    if detect_sqlcipher_version() != 4:
        return  # nothing to test when sqlcipher v4 is not available

    # get the v3 database file and copy it into the user's data directory
    dir_path = os.path.dirname(os.path.realpath(__file__))
    copyfile(
        os.path.join(os.path.dirname(dir_path), 'data', 'sqlcipher_v3_rotkehlchen.db'),
        user_data_dir / 'rotkehlchen.db',
    )

    # the constructor should migrate it in-place and we should have a working DB
    db = DBHandler(user_data_dir, '123', MessagesAggregator(), None)
    assert db.get_version() == ROTKEHLCHEN_DB_VERSION
def trades_historian(data_dir, function_scope_messages_aggregator, blockchain):
    """Create a TradesHistorian wired to a fresh DB and an empty exchange manager."""
    database = DBHandler(data_dir, '123', function_scope_messages_aggregator)
    exchange_manager = ExchangeManager(
        msg_aggregator=function_scope_messages_aggregator)
    return TradesHistorian(
        user_directory=data_dir,
        db=database,
        msg_aggregator=function_scope_messages_aggregator,
        exchange_manager=exchange_manager,
        chain_manager=blockchain,
    )
def test_timed_balances_primary_key_works(user_data_dir):
    """Test the primary key of the timed balances table.

    Two entries with the same (time, asset) collide and only the first is
    kept (a warning is emitted), while entries that differ in category are
    both stored.
    """
    msg_aggregator = MessagesAggregator()
    db = DBHandler(user_data_dir, '123', msg_aggregator, None)

    def make_balance(category, asset, amount, usd_value):
        # All entries share the same timestamp on purpose
        return DBAssetBalance(
            category=category,
            time=1590676728,
            asset=asset,
            amount=amount,
            usd_value=usd_value,
        )

    # Same category/time/asset: the second insert violates the primary key
    db.add_multiple_balances([
        make_balance(BalanceType.ASSET, A_BTC, '1.0', '8500'),
        make_balance(BalanceType.ASSET, A_BTC, '1.1', '9100'),
    ])
    assert len(msg_aggregator.consume_warnings()) == 1
    assert len(msg_aggregator.consume_errors()) == 0
    assert len(db.query_timed_balances(asset=A_BTC)) == 1

    # Different categories: both entries are kept, no warnings
    db.add_multiple_balances([
        make_balance(BalanceType.ASSET, A_ETH, '1.0', '8500'),
        make_balance(BalanceType.LIABILITY, A_ETH, '1.1', '9100'),
    ])
    assert len(msg_aggregator.consume_warnings()) == 0
    assert len(msg_aggregator.consume_errors()) == 0
    assert len(db.query_timed_balances(asset=A_ETH)) == 2
def test_binance_query_trade_history_custom_markets(function_scope_binance, user_data_dir):
    """Test that custom pairs are queried correctly.

    Patches the binance session so that each trade-history request is
    intercepted, and asserts that exactly one request is made per configured
    market and no other markets are queried.
    """
    msg_aggregator = MessagesAggregator()
    db = DBHandler(user_data_dir, '123', msg_aggregator, None)

    binance_api_key = ApiKey('binance_api_key')
    binance_api_secret = ApiSecret(b'binance_api_secret')
    db.add_exchange('binance', Location.BINANCE, binance_api_key, binance_api_secret)

    binance = function_scope_binance

    # Configure the custom markets that should be queried
    markets = ['ETHBTC', 'BNBBTC', 'BTCUSDC']
    db.set_binance_pairs('binance', markets, Location.BINANCE)

    count = 0  # total number of intercepted requests
    p = re.compile(r'symbol=[A-Z]*')
    seen = set()  # markets already requested; each one must appear only once

    def mock_my_trades(url, timeout):  # pylint: disable=unused-argument
        nonlocal count
        count += 1
        # [7:] strips the leading 'symbol=' from the regex match
        market = p.search(url).group()[7:]
        assert market in markets and market not in seen
        seen.add(market)
        # Empty trade list is enough: we only check which markets get queried
        text = '[]'
        return MockResponse(200, text)

    with patch.object(binance.session, 'get', side_effect=mock_my_trades):
        binance.query_trade_history(start_ts=0, end_ts=1564301134, only_cache=False)
    # One request per configured market
    assert count == len(markets)