def main():
    # Load the model from the architecture definition
    model, optimizer = Initialize_Att_model(args)
    model.eval()
    args.gpu = False

    ### Build the name of the SWA-averaged checkpoint file
    model_name = str(args.model_dir).split('/')[-1]
    ct = model_name + '_SWA_random_tag_' + str(args.SWA_random_tag)

    ## Check for the weight-averaged checkpoint file; if it does not exist, create it.
    ## If the file already exists, load it instead.
    if not isfile(join(args.model_dir, ct)):
        model_names, checkpoint_ter = get_best_weights(args.weight_text_file,
                                                       args.Res_text_file)
        model_names_checkpoints = model_names[:args.early_stopping_checkpoints]
        model = Stocasting_Weight_Addition(model, model_names_checkpoints)
        torch.save(model.state_dict(), join(args.model_dir, ct))
    else:
        print("taking the weights from", ct, join(args.model_dir, str(ct)))
        args.pre_trained_weight = join(args.model_dir, str(ct))
        model, optimizer = Initialize_Att_model(args)
    #---------------------------------------------
    model.eval()
    print("best_weight_file_after stocastic weight averaging")
    #=================================================
    model = model.cuda() if args.gpu else model
    plot_path = join(args.model_dir, 'decoding_files', 'plots')
    #=================================================
    #=================================================
    #### Read all the scp files and merge them into one list, with each line describing one feature
    decoding_files_list = glob.glob(args.dev_path + "*")
    scp_paths_decoding = []
    for i_scp in decoding_files_list:
        scp_paths_decoding_temp = open(i_scp, 'r').readlines()
        scp_paths_decoding += scp_paths_decoding_temp

    # scp_paths_decoding now holds every scp line to be decoded
    #====================================================
    ### If more jobs are specified than there are scp lines, an IndexError is raised here
    job_no = int(args.Decoding_job_no) - 1

    #args.gamma=0.5
    #print(job_no)
    #### get_cer_for_beam takes a list of scp lines as input
    present_path = [scp_paths_decoding[job_no]]

    text_file_dict = {
        line.split(' ')[0]: line.strip().split(' ')[1:]
        for line in open(args.text_file)
    }
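    # text_file_dict maps each utterance id to its reference transcript tokens,
    # e.g. a line "utt_0001 HELLO WORLD" becomes {'utt_0001': ['HELLO', 'WORLD']}.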
    get_cer_for_beam(present_path, model, text_file_dict, plot_path, args)
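The averaging step above is handled by Stocasting_Weight_Addition. A minimal sketch of what such checkpoint averaging can look like is shown below; it assumes each selected checkpoint is a plain state_dict saved with torch.save, and average_checkpoints is a hypothetical helper, not the function used in this repository.

import torch

def average_checkpoints(model, checkpoint_paths):
    """Average the floating-point parameters of several saved state_dicts
    and load the result into the model (a common way to realise SWA over
    the best checkpoints)."""
    avg_state = None
    for path in checkpoint_paths:
        state = torch.load(path, map_location='cpu')
        if avg_state is None:
            # Start from a copy of the first checkpoint.
            avg_state = {k: v.clone() for k, v in state.items()}
        else:
            # Accumulate only floating-point tensors; integer buffers
            # (e.g. BatchNorm counters) keep the first checkpoint's value.
            for k, v in state.items():
                if v.is_floating_point():
                    avg_state[k] += v
    for k, v in avg_state.items():
        if v.is_floating_point():
            avg_state[k] = v / len(checkpoint_paths)
    model.load_state_dict(avg_state)
    return model

# Usage corresponding to the branch above (names taken from the example):
#   model = average_checkpoints(model, model_names_checkpoints)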
Example #2
def main():
    args.gpu = False
    best_weight_file = get_best_weights(args.weight_text_file,
                                        args.Res_text_file)
    print("best_weight_file", best_weight_file)
    args.pre_trained_weight = join(best_weight_file)

    #=================================================
    model, optimizer = Initialize_Att_model(args)
    model.eval()
    model = model.cuda() if args.gpu else model

    #=================================================
    plot_path = join(args.model_dir, 'decoding_files', 'plots')
    if not isdir(plot_path):
        os.makedirs(plot_path)
    #=================================================
    #### Read all the scp files and merge them into one list, with each line describing one feature
    decoding_files_list = glob.glob(args.dev_path + "*")
    scp_paths_decoding = []
    for i_scp in decoding_files_list:
        scp_paths_decoding_temp = open(i_scp, 'r').readlines()
        scp_paths_decoding += scp_paths_decoding_temp

    # scp_paths_decoding now holds every scp line to be decoded
    #====================================================
    ### If more jobs are specified than there are scp lines, an IndexError is raised here
    job_no = int(args.Decoding_job_no) - 1

    #args.gamma=0.5
    #print(job_no)
    #### get_cer_for_beam takes a list of scp lines as input
    present_path = [scp_paths_decoding[job_no]]

    #print(present_path)
    text_file_dict = {
        line.split(' ')[0]: line.strip().split(' ')[1:]
        for line in open(args.text_file)
    }
    get_cer_for_beam(present_path, model, text_file_dict, plot_path, args)
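Both examples select one scp line per decoding job from the 1-based args.Decoding_job_no, and the comments note that asking for more jobs than there are lines raises an IndexError. Below is a small sketch of a guarded version of that selection; select_scp_line is a hypothetical helper, not part of the repository.

import sys

def select_scp_line(scp_paths_decoding, decoding_job_no):
    """Return a one-element list with the scp line for this job, or exit
    cleanly when the 1-based job number exceeds the number of lines."""
    job_no = int(decoding_job_no) - 1
    if not 0 <= job_no < len(scp_paths_decoding):
        print("job", decoding_job_no, "has no scp line to decode; exiting")
        sys.exit(0)
    # get_cer_for_beam expects a list, so wrap the single line.
    return [scp_paths_decoding[job_no]]

# Usage inside main() would then be:
#   present_path = select_scp_line(scp_paths_decoding, args.Decoding_job_no)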