# print([u.subject_id for u in train_subject_instances])
# print([u.subject_id for u in val_subject_instances])
#utility.diagnose_training_subjects(all_subject_instances)
#data_generators.diagnose_generator_multiple_signal(train_gen , sig_type_source)
#data_generators.diagnose_generator_multiple_signal(val_gen , sig_type_source)

# Create the model. Only the 'Unetxl' architecture is handled here; for any
# other model_type this branch does not assign sig_model.
if model_type == 'Unetxl':
    sig_model = network_models.Unet_xl(input_size, kernel_size, filter_number,
                                       len(sig_type_source), no_layers)

# Move the model to the GPU and wrap it in DataParallel over CUDA devices 0 and 1.
sig_model = nn.DataParallel(sig_model.cuda(), device_ids=[0, 1])

#loss function is negative pearson loss
loss = network_models.PearsonRLoss()

# Training configuration. The values are stored in `config` and then copied
# into the `args` dict that follows.
cudnn.benchmark = True
config['n_epochs'] = 400
config['scheduler_milestones'] = [30, 60, 120]  #[50,100,200]
config['train_steps'] = 10
config['val_steps'] = 10
config['initial_lr'] = 0.001
# Path of the checkpoint file inside the 'Code Output' directory.
config['model_path'] = directory + '/Code Output/best_so_far.pt'
# Path prefix (model name appended) inside the 'Models for Video' directory.
config[
    'model_path_for_video'] = directory + '/Models for Video/' + file_name_pre
args = {
    'lr': config['initial_lr'],
    'n_epochs': config['n_epochs'],
    'model_path': config['model_path'],
    def test_model(gen, sig_model):
        '''
        Run a model over a test generator and collect per-segment error metrics.

        The generator is consumed with next() until it yields an item whose first
        element (finished) is true. Every other item must be a tuple
        (finished, X_batch, Y_batch, subject_id_list), where the last channel of
        X_batch (X_batch[:, -1, :]) is the ECG used for R-peak detection and the
        remaining channels are the model input.

        If the generator reports finished on its very first item, all returned
        lists are empty.

        :param gen: test generator
        :param sig_model: model
        :return: list_pearson_r_loss: per-segment values returned by PearsonRLoss.get_induvidual_losses for estimate vs. target
        :return: list_subject_ids: subject ID for each segment
        :return: list_i_errors: list of R-I interval errors (ms) between each target and estimated signal segment, -1 if either point is -1
        :return: list_j_errors: list of R-J interval errors (ms) between each target and estimated signal segment, -1 if either point is -1
        :return: list_k_errors: list of R-K interval errors (ms) between each target and estimated signal segment, -1 if either point is -1
        :return: list_noise_var_target: get_noise_variance of the target ensemble average and beats, per segment
        :return: list_noise_var_estimate: get_noise_variance of the estimated ensemble average and beats, per segment
        :return: list_target_i_points: I point found on the target ensemble average, per segment
        :return: list_target_j_points: J point found on the target ensemble average, per segment
        :return: list_target_k_points: K point found on the target ensemble average, per segment
        :return: list_sdr: get_sdr of the normalized estimate against the normalized target, per segment
        '''
        # 500 is the sampling rate of the signal segments
        segment_fs = 500

        def _unit_norm(sig):
            # Remove the mean and scale to unit L2 norm.
            centered = sig - np.mean(sig)
            return centered / np.sqrt(np.sum(np.power(centered, 2)))

        def _interval_error_ms(target_point, estimate_point):
            # -1 is propagated when either point is -1; otherwise the sample
            # distance is converted to milliseconds (1000 * samples / fs).
            if target_point == -1 or estimate_point == -1:
                return -1
            return 1000 * np.abs(target_point - estimate_point) / segment_fs

        criterion = network_models.PearsonRLoss()

        list_pearson_r_loss = []
        list_i_errors = []
        list_j_errors = []
        list_k_errors = []
        list_subject_ids = []
        list_noise_var_target = []
        list_noise_var_estimate = []
        list_target_i_points = []
        list_target_j_points = []
        list_target_k_points = []
        list_sdr = []

        with torch.no_grad():
            while True:
                print('Running Testing')
                torch.cuda.empty_cache()
                finished, X_batch, Y_batch, subject_id_list = next(gen)

                if finished:
                    print('Generator Finished')
                    break

                #get ecg (last input channel); the other channels feed the model
                ecg = X_batch[:, -1, :]

                #calculate pearson correlation
                X_batch = torch.from_numpy(X_batch[:, :-1, :]).contiguous()
                X_batch = network_models.cuda(X_batch)
                X_batch = X_batch.type(torch.cuda.FloatTensor)

                # Call the module itself rather than .forward() so that any
                # registered hooks are run as well.
                Y_batch_predicted = sig_model(X_batch).squeeze()

                Y_batch = torch.from_numpy(Y_batch)
                Y_batch = network_models.cuda(Y_batch)
                Y_batch = Y_batch.type(torch.cuda.FloatTensor)
                Y_batch = Y_batch.squeeze()

                # squeeze() drops the batch axis for a single-segment batch;
                # restore it so both tensors are 2-D.
                if len(Y_batch.size()) == 1:
                    Y_batch = Y_batch.view(1, -1)

                if len(Y_batch_predicted.size()) == 1:
                    Y_batch_predicted = Y_batch_predicted.view(1, -1)

                loss = criterion.get_induvidual_losses(Y_batch_predicted, Y_batch)

                list_pearson_r_loss += loss.cpu().numpy().reshape(-1).tolist()
                list_subject_ids += subject_id_list

                #ijk points
                Y_batch_predicted = Y_batch_predicted.detach().cpu().numpy()
                Y_batch = Y_batch.detach().cpu().numpy()

                for v in range(Y_batch.shape[0]):
                    r_peaks = signal_processing_modules.get_R_peaks(ecg[v, :])

                    # Normalize each signal once and reuse it for both the
                    # ensemble average and the SDR.
                    target_norm = _unit_norm(Y_batch[v, :])
                    estimate_norm = _unit_norm(Y_batch_predicted[v, :])

                    ensemble_avg_target, ensemble_beats_target = signal_processing_modules.get_ensemble_avg(
                        r_peaks, target_norm, n_samples=500, upsample_factor=1)
                    i_point_target, j_point_target, k_point_target = signal_processing_modules.get_IJK_peaks(
                        ensemble_avg_target, upsample_factor=1)

                    ensemble_avg_estimate, ensemble_beats_estimate = signal_processing_modules.get_ensemble_avg(
                        r_peaks, estimate_norm, n_samples=500, upsample_factor=1)
                    i_point_estimate, j_point_estimate, k_point_estimate = signal_processing_modules.get_IJK_peaks(
                        ensemble_avg_estimate, upsample_factor=1)

                    list_i_errors.append(_interval_error_ms(i_point_target, i_point_estimate))
                    list_j_errors.append(_interval_error_ms(j_point_target, j_point_estimate))
                    list_k_errors.append(_interval_error_ms(k_point_target, k_point_estimate))
                    list_noise_var_target.append(
                        signal_processing_modules.get_noise_variance(ensemble_avg_target, ensemble_beats_target))
                    list_noise_var_estimate.append(
                        signal_processing_modules.get_noise_variance(ensemble_avg_estimate, ensemble_beats_estimate))
                    list_target_i_points.append(i_point_target)
                    list_target_j_points.append(j_point_target)
                    list_target_k_points.append(k_point_target)
                    list_sdr.append(signal_processing_modules.get_sdr(estimate_norm, target_norm))

        # No explicit `del` of the batch variables: they are released when the
        # function returns, and deleting Y_batch_predicted would raise
        # UnboundLocalError if the generator finished before any batch was run.
        return list_pearson_r_loss, list_subject_ids, list_i_errors, list_j_errors, list_k_errors, list_noise_var_target, list_noise_var_estimate, list_target_i_points, list_target_j_points, list_target_k_points, list_sdr
def _build_model_name(model_type, no_layers, sig_type_source, filter_number):
    """Build the model name used as prefix for every output file.

    The name has the form
    ``mdl_<model_type minus last char><no_layers>_a<axes>_dd_mmm_nn_<no_layers>_<filter_number>``
    where ``<axes>`` has one character per axis: 'x'/'y'/'z' if 'aX'/'aY'/'aZ'
    is in ``sig_type_source`` and '0' otherwise.

    The parts are assembled directly rather than by slicing a template at fixed
    offsets, so the name stays well-formed when ``no_layers`` has more than one
    digit.
    """
    axis_string = ''.join(
        axis if 'a' + axis.upper() in sig_type_source else '0'
        for axis in 'xyz')
    return 'mdl_{}{}_a{}_dd_mmm_nn_{}_{}'.format(
        model_type[:-1], no_layers, axis_string, no_layers, filter_number)


def main():
    """Train a model from the configured source signals to the target signal.

    Command line arguments:
        --n_f: int, stored as config['filter_number'].
        --L:   int, stored as config['no_layers'].

    Steps: load the subjects, split them 80/20 into train and validation sets,
    build the two generators, train with network_models.train_torch_generator_with_video,
    reload the checkpoint written to config['model_path'] and save its state
    dict under the model name, then write the pickled workspace, a text dump of
    the configuration and the loss plot into '<directory>/Code Output'.

    Raises:
        ValueError: if config['model_type'] is not a supported architecture.
    """
    #sampling rate of training dataset
    F_SAMPLING = 2000

    #config
    parser = argparse.ArgumentParser()
    # Both values feed the network constructor and the output file names, so a
    # missing one is reported by argparse instead of propagating None.
    parser.add_argument('--n_f', type=int, required=True, help='n_f')
    parser.add_argument('--L', type=int, required=True, help='L')
    cli_args = vars(parser.parse_args())
    print(cli_args)

    #configs
    config = {
        'file_name_pre': 'mdl_00000_a000_dd_mmm_nn_' + str(cli_args['L']) + '_' + str(cli_args['n_f']),
        'no_layers': cli_args['L'],
        'filter_number': cli_args['n_f'],
        'sig_type_source': ['aX', 'aY', 'aZ'],
        'mode': 'both',
        'eps': 1e-3,
        'model_type': 'Unetxl',
        'down_sample_factor': 4,
        'frame_length': int(4.096 * F_SAMPLING),
        'kernel_size': 7,  #(3,5) , #5,#(3,5), #5, #(3, 5),#5
        'directory': '/media/sinan/9E82D1BB82D197DB/RESEARCH VLAB work on/Gyroscope SCG project/Deep Learning Paper Code and Materials',
        'cycle_per_batch': 2,
        'sig_type_target': 'bcg',
        'loss_func': 'pearson_r',
        'produce_video': False,
        'store_in_ram': True,
        'augment_accel': True,
        'augment_theta_lim': 10,
        'augment_prob': 0.5,
    }

    # Unpack the settings used below.
    cycle_per_batch = config['cycle_per_batch']
    mode = config['mode']
    eps = config['eps']
    kernel_size = config['kernel_size']
    directory = config['directory']
    model_type = config['model_type']
    no_layers = config['no_layers']
    filter_number = config['filter_number']
    sig_type_source = config['sig_type_source']
    sig_type_target = config['sig_type_target']
    down_sample_factor = config['down_sample_factor']
    frame_length = config['frame_length']
    input_size = frame_length // down_sample_factor
    normalized = model_type != 'Unet_multiple_signal_in_not_normalized'
    produce_video = config['produce_video']
    store_in_ram = config['store_in_ram']
    augment_accel = config['augment_accel']
    augment_theta_lim = config['augment_theta_lim']
    augment_prob = config['augment_prob']
    output_dir = directory + '/Code Output'

    # config['file_name_pre'] keeps the raw template; the name actually used
    # for the output files is derived from the model settings.
    file_name_pre = _build_model_name(model_type, no_layers, sig_type_source, filter_number)
    print('Model Name: ' + file_name_pre)

    #get all subject data
    all_subject_instances = utility.load_subjects(directory + '/Training Data Analog Acc', store_in_ram)

    #train test split
    train_subject_instances, val_subject_instances = train_test_split(
        all_subject_instances, test_size=0.2, random_state=49)

    #make a train and val generator (augmentation options are passed to the training generator only)
    train_gen = data_generators.make_generator_multiple_signal(
        list_of_subjects=train_subject_instances, cycle_per_batch=cycle_per_batch, eps=eps,
        frame_length=frame_length, mode=mode, list_sig_type_source=sig_type_source,
        sig_type_target=sig_type_target, down_sample_factor=down_sample_factor,
        normalized=normalized, store_in_ram=store_in_ram,
        augment_accel=augment_accel, augment_theta_lim=augment_theta_lim, augment_prob=augment_prob)

    val_gen = data_generators.make_generator_multiple_signal(
        list_of_subjects=val_subject_instances, cycle_per_batch=cycle_per_batch, eps=eps,
        frame_length=frame_length, mode=mode, list_sig_type_source=sig_type_source,
        sig_type_target=sig_type_target, down_sample_factor=down_sample_factor,
        normalized=normalized, store_in_ram=store_in_ram)

    #check !
    #utility.diagnose_training_subjects(all_subject_instances)
    #data_generators.diagnose_generator_multiple_signal(train_gen , sig_type_source)
    #data_generators.diagnose_generator_multiple_signal(val_gen , sig_type_source)

    #create model
    if model_type == 'Unetxl':
        sig_model = network_models.Unet_xl(input_size, kernel_size, filter_number, len(sig_type_source), no_layers)
    else:
        # Fail here with a clear message instead of a NameError further down.
        raise ValueError('Unsupported model_type: ' + repr(model_type))

    #make model parallel
    sig_model = nn.DataParallel(sig_model.cuda(), device_ids=[0, 1])

    #loss function is negative pearson loss
    loss = network_models.PearsonRLoss()

    #training configs
    cudnn.benchmark = True
    config['n_epochs'] = 150
    config['scheduler_milestones'] = [30, 60, 120]  #[50,100,200]
    config['train_steps'] = 10
    config['val_steps'] = 10
    config['initial_lr'] = 0.001
    config['model_path'] = output_dir + '/best_so_far.pt'
    config['model_path_for_video'] = directory + '/Models for Video/' + file_name_pre
    train_args = {
        'lr': config['initial_lr'],
        'n_epochs': config['n_epochs'],
        'model_path': config['model_path'],
        'step_count': config['train_steps'],
        'val_steps': config['val_steps'],
        'scheduler_milestones': config['scheduler_milestones'],
        'model_path_for_video': config['model_path_for_video'],
    }

    #train the model
    start_model_train = time.time()
    train_history, valid_history, best_val = network_models.train_torch_generator_with_video(
        args=train_args,
        sig_model=sig_model,
        criterion=loss,
        train_gen=train_gen,
        val_gen=val_gen,
        init_optimizer=lambda lr: Adam(sig_model.parameters(), lr=lr),
        init_schedule=lambda optimizer, milestones: torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=milestones, gamma=0.5),
        produce_video=produce_video)
    end_model_train = time.time()
    training_duration_msg = 'Model Training Duration In Seconds: ' + str(end_model_train - start_model_train)
    print(training_duration_msg)
    print('Best Validation Loss: ' + str(best_val))

    # Reload the checkpoint stored at config['model_path'] and save its state
    # dict under the model name.
    sig_model = network_models.load_saved_model(model_path=config['model_path'],
                                                model_type=model_type,
                                                input_size=input_size,
                                                kernel_size=kernel_size,
                                                filter_number=filter_number,
                                                signal_number=len(sig_type_source),
                                                no_layers=no_layers)
    torch.save({
        'model': sig_model.state_dict(),
    }, output_dir + '/' + file_name_pre + '.pt')

    #save workspace (context manager closes the file even if pickling fails)
    pickle_list = [config, train_history, valid_history, train_subject_instances, val_subject_instances, best_val]
    with open(output_dir + '/' + file_name_pre + '_pickle', 'wb') as pickle_file:
        pickle.dump(pickle_list, pickle_file)

    #print configs to a text file
    with open(output_dir + '/' + file_name_pre + '_config.txt', "w") as text_file:
        print(config, file=text_file)
        print(training_duration_msg, file=text_file)

    #plot results
    network_models.show_loss_torch_model(train_history, valid_history, file_name_pre, output_dir)