#=================================================
# Train / validation / test split of the Week1 files
#=================================================

os.chdir('C:/Users/rfuchs/Documents/cyto_classif')

source = "SSLAMM/Week1"

# os.listdir returns its entries in arbitrary order: sort them so that the
# split below is reproducible across machines and file systems.
files = sorted(os.listdir(source))
train_files = files[2:]  # every file from the third one on
valid_files = [files[0]]
# NOTE(review): the former comment called this the "last one of the 21 files",
# but index 1 is the second file of the (sorted) listing - confirm the intended split.
test_files = [files[1]]

#===============================================================
# Former model prediction for comparison
#===============================================================

# Run the former model (fumseck) on the held-out test file; predict() is given
# `source` as its destination folder.
predict(f"{source}/{test_files[0]}", source, fumseck, tn)

# Load a stored prediction file used for the comparison plot.
# NOTE(review): this reads a hard-coded file from the "original" sub-folder,
# not from the folder predict() was just given - confirm this is intended.
preds = pd.read_csv(f"{source}/original/Pulse6_2019-09-18 15h59.csv")

# Columns used as the two axes of the 2D plot
q1, q2 = 'Total FWS', 'Total FLR'
plot_2D(preds, tn, q1, q2)

#===============================================================
# Model preparation for fine-tuning
#===============================================================

# Freeze the first 5 layers of `fumseck` so that only the remaining ones are retrained.
# NOTE(review): assuming a Keras model, changing `trainable` only takes effect once
# the model is (re)compiled, and no compile call is visible here. The predictions
# below are also made with LottyNet, not fumseck, so this freeze does not affect them.
for layer in fumseck.layers[:5]:
    layer.trainable = False
folder = 'C:/Users/rfuchs/Documents/cyto_classif'
file = 'SSLAMM/Week1/Labelled_Pulse6_2019-09-18 14h35.parq'

# Matches the "PulseN_YYYY-MM-DD HHhMM" stamp of the file names ("u" or "h" separator)
date_regex = r"(Pulse[0-9]{1,2}_20[0-9]{2}-[0-9]{2}-[0-9]{2} [0-9]{2}(?:u|h)[0-9]{2})"
# NOTE(review): this date (2019-05-06 10h09) differs from the one in `file`
# (2019-09-18 14h35); confirm this is the CSV predict() writes for `file`.
pred_file = 'Pulse6_2019-05-06 10h09.csv'
os.chdir(folder)

# Load pre-trained model
LottyNet = load_model(
    'C:/Users/rfuchs/Documents/cyto_classif/LottyNet_FUMSECK')

# Making formatted predictions (written by predict() into `folder`)
source_path = folder + '/' + file
dest_folder = folder
predict(source_path, folder, LottyNet, tn)

# Getting those predictions
preds = pd.read_csv(folder + '/' + pred_file)

# Share of rows whose predicted class id equals the true class id.
# The value used to be computed and discarded; print it so the script reports it.
accuracy = np.mean(preds['True FFT id'] == preds['Pred FFT id'])
print('Accuracy:', accuracy)
print(confusion_matrix(preds['True FFT id'], preds['Pred FFT id']))

# Colour palette (hex codes); presumably used by plotting code later in the file
colors = [
    '#96ceb4', '#ffeead', '#ffcc5c', '#ff6f69', '#588c7e', '#f2e394',
    '#f2ae72', '#d96459'
]

#####################
# Batch prediction of the P1 files (files listed in the log are skipped)
#####################
# Create a log file in the destination folder: list of the already predicted files
preds_store_folder = "C:/Users/rfuchs/Documents/preds_files/P1"  # Where to store the predictions
# The log path was commented out, so the code below used an undefined (or stale) log_path
log_path = preds_store_folder + "/pred_logs.txt"  # Registers the files already predicted

if not os.path.isfile(log_path):
    open(log_path, 'w+').close()

# Read the log once and compare whole lines: a substring test on the raw text
# would wrongly skip a file whose name is contained in another logged name.
with open(log_path, "r") as log_file:
    already_predicted = set(log_file.read().splitlines())

# NOTE(review): files_to_pred, export_folder, model and precomputed_data_dir must be
# defined earlier in the file - they are not set in this section.
for file in files_to_pred:
    print('Currently predicting ' + file)
    path = export_folder + '/' + file

    if file in already_predicted:
        print(file, 'already predicted')
        continue

    # Predict the values
    predict(path, preds_store_folder, model, tn,
            is_ground_truth=False, precomputed_data_dir=precomputed_data_dir)

    # Record the file right away so an interrupted run resumes where it stopped
    with open(log_path, "a") as log_file:
        log_file.write(file + '\n')
    already_predicted.add(file)
# Files of the export listing whose name matches pulse_regex: the data to predict
files_to_pred = [
    file for file in export_files if re.search(pulse_regex, file)
]

# Create a log file in the destination folder: list of the already predicted files
preds_store_folder = "C:/Users/rfuchs/Documents/SSLAMM_P2/SSLAMM_L2"  # Where to store the predictions
log_path = preds_store_folder + "/pred_logs.txt"  # Registers the files already predicted

if not os.path.isfile(log_path):
    open(log_path, 'w+').close()

# Read the log once and compare whole lines: a substring test on the raw text
# would wrongly skip a file whose name is contained in another logged name.
with open(log_path, "r") as log_file:
    already_predicted = set(log_file.read().splitlines())

for file in files_to_pred:
    print('Currently predicting ' + file)
    path = export_folder + '/' + file

    if file in already_predicted:
        print(file, 'already predicted')
        continue

    # Predict the values
    predict(path, preds_store_folder, model, tn, is_ground_truth=False)

    # Record the file right away so an interrupted run resumes where it stopped
    with open(log_path, "a") as log_file:
        log_file.write(file + '\n')
    already_predicted.add(file)