def get_stat(args, at_bool, sed_bool):
    """Evaluate AT and/or SED predictions against the testing-set labels.

    Reads prediction files from ``args.pred_dir``, writes per-task stat CSVs
    to ``args.stat_dir`` (and prints them), then writes submission CSVs to
    ``args.submission_dir``. Ground truth comes from the fixed paths
    ``meta_data/groundtruth_weak_label_testing_set.csv`` (AT) and
    ``meta_data/groundtruth_strong_label_testing_set.csv`` (SED).

    NOTE(review): a second ``def get_stat`` appears later in this module and
    rebinds the name, so this version is not what callers of ``get_stat``
    actually run.

    Args:
        args: object exposing ``pred_dir``, ``stat_dir`` and
            ``submission_dir`` attributes.
        at_bool: if truthy, run the audio-tagging (AT) evaluation.
        sed_bool: if truthy, run the sound-event-detection (SED) evaluation.
    """
    labels = cfg.lbs
    sed_step_sec = cfg.step_time_in_sec
    sed_max_len = cfg.max_len
    thresholds = [0.3] * len(labels)  # one 0.3 threshold per label

    if at_bool:
        at_pred_csv = os.path.join(args.pred_dir, "at_prob_mat.csv.gz")
        at_stat_csv = os.path.join(args.stat_dir, "at_stat.csv")
        at_sub_csv = os.path.join(args.submission_dir, "at_submission.csv")

        tagging_eval = evaluate.AudioTaggingEvaluate(
            weak_gt_csv="meta_data/groundtruth_weak_label_testing_set.csv",
            lbs=labels)
        at_result = tagging_eval.get_stats_from_prob_mat_csv(
            pd_prob_mat_csv=at_pred_csv, thres_ary=thresholds)

        # Save the AT stats, then print them by pointing print_stat at the
        # file that was just written.
        tagging_eval.write_stat_to_csv(stat=at_result, stat_path=at_stat_csv)
        tagging_eval.print_stat(stat_path=at_stat_csv)

        # The AT submission takes its thresholds from the returned stat
        # dict rather than from ``thresholds`` directly.
        io_task4.at_write_prob_mat_csv_to_submission_csv(
            at_prob_mat_path=at_pred_csv, lbs=labels,
            thres_ary=at_result['thres_ary'], out_path=at_sub_csv)

    if sed_bool:
        sed_pred_csv = os.path.join(args.pred_dir, "sed_prob_mat_list.csv.gz")
        sed_stat_csv = os.path.join(args.stat_dir, "sed_stat.csv")
        sed_sub_csv = os.path.join(args.submission_dir, "sed_submission.csv")

        detection_eval = evaluate.SoundEventDetectionEvaluate(
            strong_gt_csv="meta_data/groundtruth_strong_label_testing_set.csv",
            lbs=labels, step_sec=sed_step_sec, max_len=sed_max_len)
        sed_result = detection_eval.get_stats_from_prob_mat_list_csv(
            pd_prob_mat_list_csv=sed_pred_csv, thres_ary=thresholds)

        # Save the SED stats, then print them from the written file.
        detection_eval.write_stat_to_csv(stat=sed_result,
                                         stat_path=sed_stat_csv)
        detection_eval.print_stat(stat_path=sed_stat_csv)

        # NOTE(review): unlike the AT branch, the SED submission is written
        # with the input ``thresholds``, not thresholds read back from
        # ``sed_result`` — confirm this asymmetry is intended.
        io_task4.sed_write_prob_mat_list_csv_to_submission_csv(
            sed_prob_mat_list_path=sed_pred_csv, lbs=labels,
            thres_ary=thresholds, step_sec=sed_step_sec,
            out_path=sed_sub_csv)

    print("Calculating stat finished!")
def get_stat(args, at_bool, sed_bool):
    """Evaluate AT and/or SED predictions and report how long it took.

    Reads prediction files from ``args.pred_dir``, writes per-task stat CSVs
    to ``args.stat_dir`` (and prints them), then writes submission CSVs to
    ``args.submission_dir``. Ground-truth CSV paths are supplied by the
    caller through ``args.gt_weak_csv`` (AT) and ``args.gt_strong_csv`` (SED).
    When finished, prints the elapsed time in seconds.

    Args:
        args: object exposing ``pred_dir``, ``stat_dir``, ``submission_dir``,
            ``gt_weak_csv`` and ``gt_strong_csv`` attributes.
        at_bool: if truthy, run the audio-tagging (AT) evaluation.
        sed_bool: if truthy, run the sound-event-detection (SED) evaluation.
    """
    print("Get Stat started!")
    # perf_counter is monotonic and high-resolution, so the elapsed time
    # printed at the end is not skewed by system clock adjustments (which
    # time.time() is subject to).
    t_stat = time.perf_counter()
    lbs = cfg.lbs
    step_time_in_sec = cfg.step_time_in_sec
    max_len = cfg.max_len
    # NOTE(review): this two-element list is passed regardless of len(lbs);
    # a per-label list ([0.5] * len(lbs)) was previously used here. Confirm
    # the evaluators accept this shape before trusting the stats.
    thres_ary = [0.001, 0.999]

    # Calculate AT stat
    if at_bool:
        pd_prob_mat_csv_path = os.path.join(args.pred_dir, "at_prob_mat.csv.gz")
        at_stat_path = os.path.join(args.stat_dir, "at_stat.csv")
        at_submission_path = os.path.join(args.submission_dir, "at_submission.csv")

        at_evaluator = evaluate.AudioTaggingEvaluate(
            weak_gt_csv=args.gt_weak_csv,  # TODO(finn)
            lbs=lbs)

        at_stat = at_evaluator.get_stats_from_prob_mat_csv(
            pd_prob_mat_csv=pd_prob_mat_csv_path,
            thres_ary=thres_ary)

        # Write out & print AT stat (print_stat is given the path just written)
        at_evaluator.write_stat_to_csv(stat=at_stat,
                                       stat_path=at_stat_path)
        at_evaluator.print_stat(stat_path=at_stat_path)

        # Write AT to submission format, using the thresholds returned in
        # the stat dict rather than the input ``thres_ary``.
        io_task4.at_write_prob_mat_csv_to_submission_csv(
            at_prob_mat_path=pd_prob_mat_csv_path,
            lbs=lbs,
            thres_ary=at_stat['thres_ary'],
            out_path=at_submission_path)

    # Calculate SED stat
    if sed_bool:
        sed_prob_mat_list_path = os.path.join(args.pred_dir, "sed_prob_mat_list.csv.gz")
        sed_stat_path = os.path.join(args.stat_dir, "sed_stat.csv")
        sed_submission_path = os.path.join(args.submission_dir, "sed_submission.csv")

        sed_evaluator = evaluate.SoundEventDetectionEvaluate(
            strong_gt_csv=args.gt_strong_csv,  # TODO(finn)
            lbs=lbs,
            step_sec=step_time_in_sec,
            max_len=max_len)

        # Compute SED stat
        sed_stat = sed_evaluator.get_stats_from_prob_mat_list_csv(
            pd_prob_mat_list_csv=sed_prob_mat_list_path,
            thres_ary=thres_ary)

        # Write out & print SED stat
        sed_evaluator.write_stat_to_csv(stat=sed_stat,
                                        stat_path=sed_stat_path)
        sed_evaluator.print_stat(stat_path=sed_stat_path)

        # Write SED to submission format.
        # NOTE(review): unlike AT, this uses the input ``thres_ary`` rather
        # than thresholds read back from ``sed_stat`` — confirm intended.
        io_task4.sed_write_prob_mat_list_csv_to_submission_csv(
            sed_prob_mat_list_path=sed_prob_mat_list_path,
            lbs=lbs,
            thres_ary=thres_ary,
            step_sec=step_time_in_sec,
            out_path=sed_submission_path)

    # Elapsed-time print, used to check whether stat computation dominates
    # the run time.
    print("FINN Calculating stat finished!, time: ", (time.perf_counter() - t_stat))
def get_stat_eval(args, at_bool, sed_bool, clip_duration_sec=10.0,
                  weak_gt_csv="./metadata/eval/eval_file_available_weak_labels.csv",
                  strong_gt_csv="./metadata/eval/eval_file_available_strong_labels.csv"):
    """Evaluate AT and/or SED predictions on the evaluation dataset.

    Reads ``*_eval`` prediction files from ``args['pred_dir']``, writes
    per-task stat CSVs to ``args['stat_dir']`` (and prints them), and writes
    submission CSVs named with ``args['defined_name']`` to
    ``args['submission_dir']``.

    Args:
        args: object indexed by string keys; reads 'timesteps', 'threshold',
            'pred_dir', 'stat_dir', 'submission_dir' and 'defined_name'.
        at_bool: if truthy, run the audio-tagging (AT) evaluation.
        sed_bool: if truthy, run the sound-event-detection (SED) evaluation.
        clip_duration_sec: duration in seconds covered by
            ``args['timesteps']`` frames; the SED step is
            ``clip_duration_sec / args['timesteps']``. Defaults to 10.0, the
            value that used to be hard-coded.
        weak_gt_csv: weak-label ground-truth CSV used for AT scoring.
        strong_gt_csv: strong-label ground-truth CSV used for SED scoring.
    """
    lbs = cfg.lbs
    # float() keeps this a true division even if an int duration is passed.
    step_time_in_sec = float(clip_duration_sec) / args['timesteps']
    max_len = args['timesteps']
    thres_ary = args['threshold']

    # Calculate AT stat
    if at_bool:
        pd_prob_mat_csv_path = os.path.join(args['pred_dir'],
                                            "at_prob_mat_eval.csv.gz")
        at_stat_path = os.path.join(args['stat_dir'], "at_stat_eval.csv")
        at_submission_path = os.path.join(
            args['submission_dir'],
            "at_submission_eval-{0}.csv".format(args['defined_name']))

        at_evaluator = evaluate.AudioTaggingEvaluate(
            weak_gt_csv=weak_gt_csv,
            lbs=lbs)

        at_stat = at_evaluator.get_stats_from_prob_mat_csv(
            pd_prob_mat_csv=pd_prob_mat_csv_path, thres_ary=thres_ary)

        # Write out & print AT stat (print_stat is given the path just written)
        at_evaluator.write_stat_to_csv(stat=at_stat, stat_path=at_stat_path)
        at_evaluator.print_stat(stat_path=at_stat_path)

        # Write AT to submission format, using the thresholds returned in
        # the stat dict rather than the input ``thres_ary``.
        # NOTE(review): this uses ``io_config`` while ``get_stat`` uses
        # ``io_task4`` for the same writer — confirm both modules exist.
        io_config.at_write_prob_mat_csv_to_submission_csv(
            at_prob_mat_path=pd_prob_mat_csv_path,
            lbs=lbs,
            thres_ary=at_stat['thres_ary'],
            out_path=at_submission_path)

    # Calculate SED stat
    if sed_bool:
        sed_prob_mat_list_path = os.path.join(args['pred_dir'],
                                              "sed_prob_mat_list_eval.csv.gz")
        sed_stat_path = os.path.join(args['stat_dir'], "sed_stat_eval.csv")
        sed_submission_path = os.path.join(
            args['submission_dir'],
            "sed_submission_eval-{0}.csv".format(args['defined_name']))

        sed_evaluator = evaluate.SoundEventDetectionEvaluate(
            strong_gt_csv=strong_gt_csv,
            lbs=lbs,
            step_sec=step_time_in_sec,
            max_len=max_len)

        # Compute SED stat
        sed_stat = sed_evaluator.get_stats_from_prob_mat_list_csv(
            pd_prob_mat_list_csv=sed_prob_mat_list_path, thres_ary=thres_ary)

        # Write out & print SED stat
        sed_evaluator.write_stat_to_csv(stat=sed_stat, stat_path=sed_stat_path)
        sed_evaluator.print_stat(stat_path=sed_stat_path)

        # Write SED to submission format (with the input ``thres_ary``)
        io_config.sed_write_prob_mat_list_csv_to_submission_csv(
            sed_prob_mat_list_path=sed_prob_mat_list_path,
            lbs=lbs,
            thres_ary=thres_ary,
            step_sec=step_time_in_sec,
            out_path=sed_submission_path)

    print("Calculating stat for Evaluation Dataset finished!")