import numpy as np
from sqlite3 import dbapi2 as sqlite
from get_numpy_dtype import get_dtype
import sqlite_io

# Globals
# Year columns present in the value tables: 1961..2010 inclusive (50 years).
COUNTRY_YR_IDX = np.arange(1961, 2011)

# Connect to the database and create a cursor
DB = r".\GFIN_DB.db3"  # Windows-style relative path to the master SQLite database
connection = sqlite.connect(DB)
connection.text_factory = str # use 8 bit strings instead of unicode strings in SQLite
cursor = connection.cursor()

# Get all Consumption and Production Data and mask values less than or equal to 0
# NOTE(review): element_id 51/100 presumably mean Production/Consumption per the
# comment above -- confirm against the element lookup table.
ndtype, names = get_dtype(connection, "Commodity", nameReturn=True, remove_id=True)
Q = "SELECT %s FROM Commodity WHERE (element_id=51 OR element_id=100)"%",".join(names)
# NOTE(review): despite the comment above, no <=0 masking is applied here -- the
# rows are loaded into a masked array with no mask set; confirm masking happens
# downstream (the per-country loop below does its own masked_less_equal).
commodity_xs = np.ma.array(cursor.execute(Q).fetchall(), ndtype)
unique_country_ids = np.unique(commodity_xs['country_id']) # unique country ids

# Array to hold all values to insert into Commodity table
# Pre-sized scratch buffer; 100000 rows is an assumed upper bound -- TODO confirm.
# NOTE(review): not referenced anywhere in this chunk.
insert_xs = np.ma.empty((100000, len(names)))

# Demographic table names and data types.
# One "yrNNNN" column name per year in COUNTRY_YR_IDX:
# yr1961, yr1962, ..., yr2009, yr2010 -- each paired with a float64 type code.
demo_names = ",".join("yr%s"%x for x in COUNTRY_YR_IDX)
# list(...) makes the result a concrete list of (name, '<f8') pairs: under
# Python 3, zip() returns a lazy iterator, which is not usable as a numpy
# dtype specification (and would be silently exhausted after one use).
# On Python 2 this is a no-op, so behavior is unchanged there.
demo_ndtype = list(zip(demo_names.split(","), len(COUNTRY_YR_IDX)*['<f8']))

# Go through each country in the Commodity table: convert its Demographic
# level series (element_id 511-603) into year-over-year net-change series
# (level element_id + 100) and append them to the Demographic table.
# NOTE(review): in the original, the loop body was not indented under the
# `for` (a syntax error as written) and the SELECT had no country filter, so
# the loop variable was unused and every iteration re-processed and
# re-inserted the entire table. Both are fixed here -- confirm intent.
count = 0  # NOTE(review): never read in this chunk; kept for compatibility
for country_id in unique_country_ids:
    # Parameterized query: only this country's rows, levels 511-603.
    select = """
    SELECT * FROM Demographic WHERE country_id=? AND element_id BETWEEN 511 AND 603
    """
    # Mask non-positive values (missing data) and drop the leading id column.
    xs = np.ma.masked_less_equal(
        np.array(cursor.execute(select, (country_id,)).fetchall()), 0)[:, 1:]

    # Calculate masked differences along the first axis.
    # Columns 0-3 are id fields; year values start at column 4 -- TODO confirm.
    diff_xs = np.ma.masked_all(np.shape(xs[:, 4:]))
    diff_xs[:, 1:] = np.ma.diff(xs[:, 4:], axis=1)  # first year has no diff
    diff_xs = np.ma.filled(diff_xs, -1)  # -1 encodes missing/undefined

    # Stack id fields with net change values (-1 fill for masked ids too).
    xs = np.ma.filled(xs, -1)
    ys = np.hstack((xs[:, :4], diff_xs))

    # Convert ndarray to recarray matching the Demographic schema (sans id).
    ys = ys.reshape(-1, ).view(get_dtype(connection, 'Demographic', remove_id=True))

    # Replace element values w/ their corresponding net change values.
    # Net-change ids are the level ids + 100 -- TODO confirm mapping.
    for element_id in [511, 512, 513, 551, 561, 571, 581, 591, 592, 593, 601, 602, 603]:
        np.put(ys['element_id'], np.where(ys['element_id']==element_id), element_id + 100)

    # Next primary key = current MAX(id) + 1; re-queried every iteration
    # because the insert below grows the table.
    max_id, = np.array(cursor.execute("SELECT MAX(id) FROM Demographic").fetchall()).flatten() + 1

    # Insert the new net-change rows (sqlite_io is imported at module top).
    sqlite_io.tosqlite(ys, max_id, DB, "Demographic", autoid=True, create=False)

# Close the cursor and the connection to release the SQLite file handle.
cursor.close()
connection.close()
    SELECT name FROM sqlite_master
    WHERE type='table' AND name!='sqlite_sequence' AND name!='Commodity' AND name!='Demographic'
    """).fetchall()).flatten()

# Create new tables
# Pull every user table's CREATE statement out of the master database's
# sqlite_master, reformat it via format_creates, and replay each against the
# new database so its schema mirrors the master's.
create_statements = np.array(cursor.execute(
    "SELECT sql FROM sqlite_master WHERE type='table' and name!='sqlite_sequence'"
).fetchall()).flatten()
create_strs = format_creates(create_statements)
# Plain loop: the original list comprehension was used only for its side
# effects and built a throwaway list of cursor results.
for statement in create_strs:
    new_cursor.execute(statement)
new_connection.commit()

# Insert data into each table: copy every row of every table from the master
# database into the new database via sqlite_io.
for table in copy_tables:
    # Tables whose primary key is a plain autoincrement 'id' column.
    is_autoid = table in ('SchemeColor', 'AreaGroup')
    ndtype, names = get_dtype(connection, table, remove_id=is_autoid, nameReturn=True)

    # Get data from master database for copying
    xs = np.ma.array(cursor.execute("SELECT %s FROM %s"%(",".join(names), table)).fetchall(), ndtype)

    # Mask all None values and create primary keys.
    # is_autoid is already the boolean flag we need; the original
    # [False, True][is_autoid] indexing trick obscured that.
    autoid = is_autoid
    # Autoid tables let sqlite_io assign keys; the rest use '<table>_id'.
    primary_key = False if is_autoid else "%s_id" % table.lower()
    xs = mask_none_values(xs) # mask none values
    sqlite_io.tosqlite(xs, 0, NEW_DB, table, autoid=autoid,
        create=False, primary_key=primary_key)

# Format value tables with -1 values for missing values
for table in TABLES:
    (names, typestr) = zip(*(_[1:3] for _ in connection.execute("PRAGMA TABLE_INFO(%s)"%table).fetchall()))
    names = ",".join([name.strip() for name in names if name.strip()!='id'])