def get_time_for_files(translator, pairs,split_type='none'):
    """Collect the timing value of every pair in ``pairs``.

    The number of pairs is printed first.  Each pair is then handed to
    ``get_time_for_single`` together with ``translator`` and
    ``split_type``; the value it returns is written to the log and kept.

    Args:
        translator: Passed through unchanged to ``get_time_for_single``.
        pairs: Sized iterable of items to be timed.
        split_type: Passed through unchanged to ``get_time_for_single``.

    Returns:
        A list holding one ``get_time_for_single`` result per pair, in
        the order the pairs were given.
    """
    print(len(pairs))

    def _time_and_log(item):
        # Measure one pair and record the measurement before returning it.
        elapsed = get_time_for_single(translator, item, split_type)
        write_to_log("Time = {}".format(elapsed))
        return elapsed

    return [_time_and_log(item) for item in pairs]
def get_results(omega, model_num=2, split_type="none", fm_flag=False):
    """Train and score a translator for every entry of ``test_train_split``.

    For each entry of the module-level ``test_train_split``, element 1 is
    run through ``get_splits`` and used to build a translator, and
    element 0 is run through ``get_splits`` and used to score it with
    ``useful_functions.get_med_bow_norm``.  The first score of each entry
    is written to the log as the edit distance.

    Args:
        omega: Forwarded to ``get_translator``.
        model_num: Forwarded to ``get_translator``.
        split_type: Forwarded to ``get_splits`` for both data sets.
        fm_flag: Forwarded to ``get_translator``.

    Returns:
        A tuple ``(med_results, bow_results, norm_med_results)`` of three
        lists, each with one score per entry of ``test_train_split``.
    """
    med_results, bow_results, norm_med_results = [], [], []

    for fold in test_train_split:
        # Training data is prepared before the test data, as element 1 / 0.
        train_data = get_splits(fold[1], split_type)
        test_data = get_splits(fold[0], split_type)

        model = get_translator(train_data, omega, fm_flag, model_num)
        scores = useful_functions.get_med_bow_norm(model, test_data)
        med, bow, norm_med = scores

        write_to_log("Edit distance = {}\n".format(med))

        med_results.append(med)
        bow_results.append(bow)
        norm_med_results.append(norm_med)

    return med_results, bow_results, norm_med_results
def get_med_bow_norm(translator, test_set):
    """Accumulate translation scores over a test set.

    Every ``(ts, truth)`` item of ``test_set`` is translated with
    ``translator(ts)``; the prediction is logged (joined with spaces) and
    compared with ``truth`` using ``minimum_edit_distance``,
    ``bag_of_words_test`` and ``minimum_edit_distance_per_token``.

    Args:
        translator: Callable applied to the first element of each item;
            its result must be an iterable of strings, as it is joined
            with ``" "`` for the log.
        test_set: Iterable of two-element items ``(ts, truth)``.  It only
            needs to support iteration, so iterators and generators are
            accepted as well as lists.

    Returns:
        A tuple ``(meds, bows, mean_med_per_tok)``: the summed edit
        distance, the summed bag-of-words score, and the per-token edit
        distance averaged over the number of items.  For an empty
        ``test_set`` the result is ``(0, 0, 0.0)`` instead of a
        ``ZeroDivisionError``.
    """
    meds = 0
    bows = 0
    meds_per_tok = 0
    # Count items while iterating so that len() is never required.
    num_examples = 0
    for ts, truth in test_set:
        predict = translator(ts)
        write_to_log("predict: " + " ".join(predict) + "\n")
        meds += minimum_edit_distance(predict, truth)
        bows += bag_of_words_test(predict, truth)
        meds_per_tok += minimum_edit_distance_per_token(predict, truth)
        num_examples += 1

    if num_examples == 0:
        # Nothing was scored: the mean is undefined, so report 0.0 rather
        # than dividing by zero.
        return meds, bows, 0.0
    return meds, bows, meds_per_tok / num_examples
def validate_ibmmodel(omega_range,
                      model_num=2,
                      split_type="none",
                      fm_flag=False):
    """Score one translator per omega value on the validation data.

    The module-level ``train_test_data`` and ``validation_set`` are each
    run through ``get_splits`` once.  For every omega a translator is
    built with ``get_translator`` and scored with
    ``useful_functions.get_med_bow_norm``; only the first score (logged
    as the edit distance) is kept.  Progress goes to both stdout and the
    log.

    Args:
        omega_range: Iterable of omega values to try.
        model_num: Forwarded to ``get_translator``; also printed in the
            header line.
        split_type: Forwarded to ``get_splits`` for both data sets.
        fm_flag: Forwarded to ``get_translator``.

    Returns:
        A list of ``(omega, med)`` tuples in the order of ``omega_range``.
    """
    print("Validate ibmmodel{}".format(model_num))

    # The data does not depend on omega, so it is prepared once up front.
    train_data = get_splits(train_test_data, split_type)
    valid_data = get_splits(validation_set, split_type)

    scores = []
    for omega in omega_range:
        write_to_log("omega {}\n".format(omega))
        print("omega", omega)

        model = get_translator(train_data, omega, fm_flag, model_num)
        edit_distance, _, _ = useful_functions.get_med_bow_norm(
            model, valid_data)

        print(edit_distance)
        write_to_log("Edit distance = {}\n".format(edit_distance))
        scores.append((omega, edit_distance))
    return scores
    # NOTE(review): everything below this point is unreachable because of
    # the return above.  It looks like leftover driver code - confirm
    # whether it belongs in a separate entry point.
    # med,bow,norm = get_results_from_file("logs/results_enhanced.txt","enhanced")
    # print(med)
    # print(bow)
    # print(norm)
    # med =   [199, 118, 170, 280, 115, 184, 75, 58, 161]
    # bl_med =[161, 107, 134, 187, 89, 137, 88, 80, 129]
    # print(sum(med))
    # print(sum(bl_med))

    test_num = 3

    if test_num == 0:
        message = "omega v1 100"
        print(message)
        write_to_log(message)
        decoder_with_log.beam_size = 100
        omegas = [3.0, 2.7, 2.3]
        validate_ibmmodel(omegas, 1, "none", False)
    elif test_num == 1:
        message = "enhanced omega v2"
        print(message)
        write_to_log(message)
        omegas = [3.0]
        validate_ibmmodel(omegas, 2, "enhanced", True)
    elif test_num == 2:
        message = "results split v2"
        print(message)
        write_to_log(message)
        split_v2_omega = 2.9
        get_results(split_v2_omega, 2, "split", False)