def get_Bleu_for_beam(scp_paths_decoding, model, text_file_dict,
                      plot_path_name, args):
    """Decode every utterance of an scp listing and print its n-best beam.

    Each non-blank line of ``scp_paths_decoding`` is read as
    ``"<key> <feat_path>"``.  The hypotheses returned by
    ``model.predict(feat_path, args)`` are printed one per line together with
    their raw score, the softmax of the scores across the beam, the token
    sequence and, when a reference exists for ``key``, the value returned by
    ``compute_cer`` multiplied by 100.  Despite its name, this variant does
    not compute BLEU.

    How to read the output: if the best hypothesis has a worse error rate
    than the rest of the beam, tune the beam hyper-parameters (Am_wt,
    len_pen, gamma); if it is better than the others, improve the model
    training instead.

    Args:
        scp_paths_decoding: iterable of ``"<key> <feat_path>"`` lines.
        model: object whose ``predict(feat_path, args)`` returns a list of
            hypothesis dicts with the keys ``'score'``, ``'Text_seq'``,
            ``'yseq'`` and optionally ``'alpha_i_list'``.
        text_file_dict: mapping from utterance key to its reference; keys
            without a reference are allowed.
        plot_path_name: directory the attention plots are written into.
        args: settings object; ``plot_decoding_pics`` is read here and the
            whole object is forwarded to ``model.predict``.

    Returns:
        None.  All results are written to stdout.
    """
    for line in scp_paths_decoding:
        line = line.strip()
        if not line:
            # A blank line names no utterance; indexing its path would fail.
            continue
        fields = line.split(' ')
        key = fields[0]
        feat_path = fields[1].strip()

        # Model predictions for this utterance.
        Output_seq = model.predict(feat_path, args)

        # Reference for this utterance, if one exists.  The joined form is
        # built once and only when a reference is present, so unlabeled
        # utterances no longer crash on " ".join(None).
        True_label = text_file_dict.get(key, None)
        if True_label is None:
            ref_text = None
        else:
            ref_text = " ".join(True_label)

        # Normalise the hypothesis scores across the beam.
        llr = [item.get('score').unsqueeze(0) for item in Output_seq]
        norm_llr = torch.nn.functional.softmax(torch.cat(llr, dim=0), dim=0)

        print("final_ouputs", '====', 'key', 'Text_seq', 'LLR',
              'Beam_norm_llr', 'Yseq', 'CER')
        print("True_label", True_label)

        for ind, seq in enumerate(Output_seq):
            Text_seq = seq['Text_seq'][0]
            # Collapse repeated spaces in the hypothesis text.
            hyp_text = " ".join(x for x in Text_seq.split(' ') if x.strip())

            # .cpu() keeps the conversion valid for tensors living on a GPU.
            Yseq = seq['yseq'].data.cpu().numpy()
            Ynorm_llr = norm_llr[ind].data.cpu().numpy()
            Yllr = seq['score'].data.cpu().numpy()

            # Attention weights are optional; the non-tensor default simply
            # disables plotting for this hypothesis.
            attention_record = seq.get('alpha_i_list', 'None')
            if torch.is_tensor(attention_record) and args.plot_decoding_pics:
                attention_record = attention_record[:, :, 0].transpose(0, 1)
                attention_record = attention_record.data.cpu().numpy()
                pname = str(key) + '_beam_' + str(ind)
                plotting(join(plot_path_name, pname), attention_record)

            if True_label:
                CER = compute_cer(hyp_text, ref_text, 'doesnot_matter') * 100
            else:
                CER = None

            # The top hypothesis gets an extra summary line.
            if ind == 0:
                print("nbest_output", '=', key, '=', hyp_text, '=', ref_text,
                      '=', CER)

            print("final_ouputs", '=', ind, '=', key, '=', Text_seq, '=', Yllr,
                  '=', Ynorm_llr, '=', Yseq, '=', CER)
def main():
    """Train the attention model, validating and checkpointing every epoch.

    Reads its configuration from the module-level ``args`` (and ``png_dir``
    for plots).  Each epoch runs ``args.validate_interval`` training steps,
    then validates on batches from the dev loader until
    ``args.max_val_examples`` examples have been drawn, saves the model's
    ``state_dict`` under ``args.model_dir`` and appends the checkpoint path
    and the validation CER to ``args.weight_text_file`` and
    ``args.Res_text_file``.  With ``args.early_stopping`` set, weight noise
    and SpecAugment are switched on once validation stops improving and the
    process exits after ``args.early_stopping_patience`` is exceeded.
    """
    ##Load setpiece models for Dataloaders
    Word_model = Load_sp_models(args.Word_model_path)
    Char_model = Load_sp_models(args.Char_model_path)
    ###initilize the model
    model, optimizer = Initialize_Att_model(args)

    train_files = (glob.glob(args.data_dir + "train_scp_splits/aa_*")
                   + glob.glob(args.data_dir + "train_scp_splits_temp1/bb_*")
                   + glob.glob(args.data_dir + "train_scp_splits_temp2/cc_*"))
    train_gen = DataLoader(files=train_files,
                           max_batch_label_len=args.max_batch_label_len,
                           max_batch_len=args.max_batch_len,
                           max_feat_len=args.max_feat_len,
                           max_label_len=args.max_label_len,
                           Word_model=Word_model,
                           Char_model=Char_model,
                           apply_cmvn=int(args.apply_cmvn))

    dev_gen = DataLoader(files=glob.glob(args.data_dir + "dev_scp"),
                         max_batch_label_len=args.max_batch_label_len,
                         max_batch_len=args.max_batch_len,
                         max_feat_len=5000,
                         max_label_len=1000,
                         Word_model=Word_model,
                         Char_model=Char_model,
                         apply_cmvn=int(args.apply_cmvn))

    # Flags that may change while training.  Both are bound on every path so
    # the first training step cannot hit a NameError when
    # args.spec_aug_flag != 2.
    weight_noise_flag = False
    spec_aug_flag = False
    if args.spec_aug_flag == 2:
        spec_aug_flag = True

    val_history = np.zeros(args.nepochs)
    for epoch in range(args.nepochs):
        ##start of the epoch
        tr_CER = []
        tr_BPE_CER = []
        L_train_cost = []
        model.train()
        for trs_no in range(args.validate_interval):
            B1 = train_gen.next()
            assert B1 is not None, "None should never come out of the DataLoader"

            Output_trainval_dict = train_val_model(smp_no=trs_no,
                                                   args=args,
                                                   model=model,
                                                   optimizer=optimizer,
                                                   data_dict=B1,
                                                   weight_noise_flag=weight_noise_flag,
                                                   spec_aug_flag=spec_aug_flag,
                                                   trainflag=True)

            # Collect the losses from the returned dict.
            L_train_cost.append(Output_trainval_dict.get('cost_cpu'))
            tr_CER.append(Output_trainval_dict.get('Char_cer'))
            tr_BPE_CER.append(Output_trainval_dict.get('Word_cer'))

            if trs_no % args.tr_disp == 0:
                print("tr ep:==:>", epoch, "sampl no:==:>", trs_no,
                      "train_cost==:>", mean(L_train_cost), "CER:",
                      mean(tr_CER), 'BPE_CER', mean(tr_BPE_CER), flush=True)
                if args.plot_fig_training:
                    # Fetch the attention only when a plot is requested, and
                    # skip the plot if the step did not return one.
                    attention_record = Output_trainval_dict.get('attention_record')
                    if attention_record is not None:
                        plot_name = join(png_dir, 'train_epoch' + str(epoch) +
                                         '_attention_single_file_' +
                                         str(trs_no) + '.png')
                        plotting(plot_name,
                                 attention_record.data.cpu().numpy())

        ###validate the model
        model.eval()
        Vl_CER = []
        Vl_BPE_CER = []
        L_val_cost = []
        val_examples = 0
        for vl_smp in range(args.max_val_examples):
            B1 = dev_gen.next()
            # Check the batch before it is dereferenced.
            assert B1 is not None, "None should never come out of the DataLoader"
            smp_feat = B1.get('smp_feat')
            val_examples += smp_feat.shape[0]

            # Stop once enough validation examples have been drawn.
            if val_examples >= args.max_val_examples:
                break

            # NOTE(review): the last training step index is passed as smp_no
            # here, as in the original; confirm train_val_model does not
            # expect the validation index instead.
            Val_Output_trainval_dict = train_val_model(smp_no=trs_no,
                                                       args=args,
                                                       model=model,
                                                       optimizer=optimizer,
                                                       data_dict=B1,
                                                       weight_noise_flag=False,
                                                       spec_aug_flag=False,
                                                       trainflag=False)

            L_val_cost.append(Val_Output_trainval_dict.get('cost_cpu'))
            Vl_CER.append(Val_Output_trainval_dict.get('Char_cer'))
            Vl_BPE_CER.append(Val_Output_trainval_dict.get('Word_cer'))

            if (vl_smp % args.vl_disp == 0) or (val_examples == args.max_val_examples - 1):
                print("val epoch:==:>", epoch, "val smp no:==:>", vl_smp,
                      "val_cost:==:>", mean(L_val_cost), "CER:", mean(Vl_CER),
                      'BPE_CER', mean(Vl_BPE_CER), flush=True)
                if args.plot_fig_validation:
                    attention_record = Val_Output_trainval_dict.get('attention_record')
                    if attention_record is not None:
                        plot_name = join(png_dir, 'val_epoch' + str(epoch) +
                                         '_attention_single_file_' +
                                         str(vl_smp) + '.png')
                        plotting(plot_name,
                                 attention_record.data.cpu().numpy())

        val_history[epoch] = (mean(Vl_CER) * 100)
        print("val_history:", val_history[:epoch + 1])

        ####saving_weights
        ct = ("model_epoch_" + str(epoch) + "_sample_" + str(trs_no) + "_" +
              str(mean(L_train_cost)) + "___" + str(mean(L_val_cost)) + "__" +
              str(mean(Vl_CER)))
        print(ct)
        torch.save(model.state_dict(), join(args.model_dir, str(ct)))

        # Open, write and close the log files each epoch to avoid delays.
        with open(args.weight_text_file, 'a+') as weight_saving_file:
            print(join(args.model_dir, str(ct)), file=weight_saving_file)

        with open(args.Res_text_file, 'a+') as Res_saving_file:
            print(float(mean(Vl_CER)), file=Res_saving_file)

        # Early stopping and delayed regularisation.
        if args.early_stopping:
            # Entries that are still zero are excluded from the history.
            Non_zero_loss = val_history[val_history > 0]
            # argmin raises on an empty array, so skip until there is data.
            if Non_zero_loss.size:
                min_cpts = np.argmin(Non_zero_loss)
                Non_zero_len = len(Non_zero_loss)

                # The best epoch is no longer the latest one: start
                # regularising.
                if (Non_zero_len - min_cpts) > 1:
                    weight_noise_flag = True
                    spec_aug_flag = True

                if (Non_zero_len - min_cpts) > args.early_stopping_patience:
                    print("The model is early stopping........",
                          "minimum value of model is:", min_cpts)
                    # exit() is the interactive helper installed by `site`;
                    # SystemExit is the equivalent that is always available.
                    raise SystemExit(0)
# Example #3
# 0
def main():
    """Train the attention model with validation-driven LR reduction.

    Reads its configuration from the module-level ``args`` (and ``png_dir``
    for plots).  Each epoch runs ``args.validate_interval`` training steps,
    validates on batches from the dev loader until ``args.max_val_examples``
    examples have been drawn, saves the model's ``state_dict`` under
    ``args.model_dir`` and appends the checkpoint path and the validation CER
    to ``args.weight_text_file`` and ``args.Res_text_file``.  With
    ``args.reduce_learning_rate_flag`` set, the learning rate is reduced and
    weight noise / SpecAugment are enabled once validation stops improving;
    with ``args.early_stopping`` also set, the process exits when the
    learning rate drops to 1e-8 or below.
    """
    ##Load setpiece models for Dataloaders
    Word_model = Load_sp_models(args.Word_model_path)
    Char_model = Load_sp_models(args.Char_model_path)
    ###initilize the model
    model, optimizer = Initialize_Att_model(args)

    train_gen = DataLoader(files=glob.glob(args.train_path + "*"),
                           max_batch_label_len=20000,
                           max_batch_len=args.max_batch_len,
                           max_feat_len=args.max_feat_len,
                           max_label_len=args.max_label_len,
                           Word_model=Word_model,
                           Char_model=Char_model,
                           text_file=args.text_file)

    dev_gen = DataLoader(files=glob.glob(args.dev_path + "*"),
                         max_batch_label_len=20000,
                         max_batch_len=args.max_batch_len,
                         max_feat_len=args.max_feat_len,
                         max_label_len=args.max_label_len,
                         Word_model=Word_model,
                         Char_model=Char_model,
                         text_file=args.text_file)

    # Flags that may change while training.
    weight_noise_flag = False
    spec_aug_flag = False
    val_history = np.zeros(args.nepochs)

    for epoch in range(args.nepochs):
        ##start of the epoch
        tr_CER = []
        tr_BPE_CER = []
        L_train_cost = []
        model.train()
        for trs_no in range(args.validate_interval):
            B1 = train_gen.next()
            assert B1 is not None, "None should never come out of the DataLoader"

            Output_trainval_dict = train_val_model(
                args=args,
                model=model,
                optimizer=optimizer,
                data_dict=B1,
                weight_noise_flag=weight_noise_flag,
                spec_aug_flag=spec_aug_flag,
                trainflag=True)

            # Collect the losses from the returned dict.
            L_train_cost.append(Output_trainval_dict.get('cost_cpu'))
            tr_CER.append(Output_trainval_dict.get('Char_cer'))
            tr_BPE_CER.append(Output_trainval_dict.get('Word_cer'))

            if trs_no % args.tr_disp == 0:
                print("tr ep:==:>", epoch, "sampl no:==:>", trs_no,
                      "train_cost==:>", mean(L_train_cost), "CER:",
                      mean(tr_CER), 'BPE_CER', mean(tr_BPE_CER), flush=True)
                if args.plot_fig_training:
                    # The attention is copied to the CPU only when a plot is
                    # actually requested, not on every training step.
                    attention_record = Output_trainval_dict.get('attention_record')
                    if attention_record is not None:
                        plot_name = join(
                            png_dir, 'train_epoch' + str(epoch) +
                            '_attention_single_file_' + str(trs_no) + '.png')
                        plotting(plot_name,
                                 attention_record.data.cpu().numpy())

        ###validate the model
        model.eval()
        Vl_CER = []
        Vl_BPE_CER = []
        L_val_cost = []
        val_examples = 0
        for vl_smp in range(args.max_val_examples):
            B1 = dev_gen.next()
            # Check the batch before it is dereferenced.
            assert B1 is not None, "None should never come out of the DataLoader"
            smp_feat = B1.get('smp_feat')
            val_examples += smp_feat.shape[0]

            # Stop once enough validation examples have been drawn.
            if val_examples >= args.max_val_examples:
                break

            Val_Output_trainval_dict = train_val_model(args=args,
                                                       model=model,
                                                       optimizer=optimizer,
                                                       data_dict=B1,
                                                       weight_noise_flag=False,
                                                       spec_aug_flag=False,
                                                       trainflag=False)

            L_val_cost.append(Val_Output_trainval_dict.get('cost_cpu'))
            Vl_CER.append(Val_Output_trainval_dict.get('Char_cer'))
            Vl_BPE_CER.append(Val_Output_trainval_dict.get('Word_cer'))

            if (vl_smp % args.vl_disp == 0) or (val_examples
                                                == args.max_val_examples - 1):
                print("val epoch:==:>", epoch, "val smp no:==:>", vl_smp,
                      "val_cost:==:>", mean(L_val_cost), "CER:", mean(Vl_CER),
                      'BPE_CER', mean(Vl_BPE_CER), flush=True)
                if args.plot_fig_validation:
                    attention_record = Val_Output_trainval_dict.get('attention_record')
                    if attention_record is not None:
                        plot_name = join(
                            png_dir, 'val_epoch' + str(epoch) +
                            '_attention_single_file_' + str(vl_smp) + '.png')
                        plotting(plot_name,
                                 attention_record.data.cpu().numpy())

        val_history[epoch] = (mean(Vl_CER) * 100)
        print("val_history:", val_history[:epoch + 1])

        ####saving_weights
        ct = "model_epoch_" + str(epoch) + "_sample_" + str(
            trs_no) + "_" + str(mean(L_train_cost)) + "___" + str(
                mean(L_val_cost)) + "__" + str(mean(Vl_CER))
        print(ct)
        torch.save(model.state_dict(), join(args.model_dir, str(ct)))

        # Open, write and close the log files each epoch to avoid delays.
        with open(args.weight_text_file, 'a+') as weight_saving_file:
            print(join(args.model_dir, str(ct)), file=weight_saving_file)

        with open(args.Res_text_file, 'a+') as Res_saving_file:
            print(float(mean(Vl_CER)), file=Res_saving_file)

        # Learning-rate reduction, delayed regularisation and early stopping.
        if args.reduce_learning_rate_flag:
            # Entries that are still zero are excluded from the history.
            Non_zero_loss = val_history[val_history > 0]
            # argmin raises on an empty array, so skip until there is data.
            if Non_zero_loss.size:
                min_cpts = np.argmin(Non_zero_loss)
                Non_zero_len = len(Non_zero_loss)

                if ((Non_zero_len - min_cpts) > 1
                        and epoch > args.lr_redut_st_th):
                    reduce_learning_rate(optimizer)

                    ###start regularization only when model starts to overfit
                    weight_noise_flag = True
                    spec_aug_flag = True

            for param_group in optimizer.param_groups:
                lr = param_group['lr']
                print("learning rate of the epoch:", epoch, "is", lr)

            if args.early_stopping:
                # `lr` holds the value of the last parameter group.
                if lr <= 1e-8:
                    print("lr reached to a minimum value")
                    # exit() is the interactive helper installed by `site`;
                    # SystemExit is the equivalent that is always available.
                    raise SystemExit(0)
def get_Bleu_for_beam(key, Src_tokens, Src_text, Tgt_tokens, Tgt_text, model,
                      plot_path, args):
    """Decode one example and print CER and sentence BLEU for its n-best beam.

    The hypotheses returned by ``model.predict(Src_tokens, args)`` are
    printed one per line together with their raw score, the softmax of the
    scores across the beam and the token sequence.  When ``Tgt_text`` is
    non-empty, each hypothesis is also scored against it with ``compute_cer``
    (multiplied by 100) and sacrebleu's ``sentence_bleu``.

    How to read the output: if the best hypothesis has a worse error rate
    than the rest of the beam, tune the beam hyper-parameters (Am_wt,
    len_pen, gamma); if it is better than the others, improve the model
    training instead.

    Args:
        key: identifier of the example, used in the output and plot names.
        Src_tokens: model input forwarded to ``model.predict``.
        Src_text: unused here; kept for the callers' signature.
        Tgt_tokens: unused here; kept for the callers' signature.
        Tgt_text: reference text, or a falsy value when there is none.
        model: object whose ``predict(Src_tokens, args)`` returns a list of
            hypothesis dicts with the keys ``'score'``, ``'Text_seq'``,
            ``'yseq'`` and optionally ``'alpha_i_list'``.
        plot_path: directory the attention plots are written into.
        args: settings object; ``plot_decoding_pics`` is read here and the
            whole object is forwarded to ``model.predict``.

    Returns:
        None.  All results are written to stdout.
    """
    SMOOTH_VALUE_DEFAULT = 1e-8

    # Model predictions for this example.
    Output_seq = model.predict(Src_tokens, args)

    # Reference for this example, if one exists.
    True_label = Tgt_text
    if True_label:
        # Imported only when there is a reference to score against, so
        # reference-free decoding does not depend on sacrebleu.
        from sacrebleu import sentence_bleu

    # Normalise the hypothesis scores across the beam.
    llr = [item.get('score').unsqueeze(0) for item in Output_seq]
    norm_llr = torch.nn.functional.softmax(torch.cat(llr, dim=0), dim=0)

    print("final_ouputs", '====', 'key', 'Text_seq', 'LLR', 'Beam_norm_llr',
          'Yseq', 'CER')
    print("True_label", True_label)

    for ind, seq in enumerate(Output_seq):
        Text_seq = seq['Text_seq']
        if len(Text_seq) > 1:
            # Collapse repeated spaces in the first entry.
            Text_seq_formatted = " ".join(
                x for x in Text_seq[0].split(' ') if x.strip())
        else:
            Text_seq_formatted = Text_seq[0]

        # .cpu() keeps the conversion valid for tensors living on a GPU.
        Yseq = seq['yseq'].data.cpu().numpy()
        Ynorm_llr = norm_llr[ind].data.cpu().numpy()
        Yllr = seq['score'].data.cpu().numpy()

        # Attention weights are optional; the non-tensor default simply
        # disables plotting for this hypothesis.
        attention_record = seq.get('alpha_i_list', 'None')
        if torch.is_tensor(attention_record) and args.plot_decoding_pics:
            attention_record = attention_record[:, :, 0].transpose(0, 1)
            attention_record = attention_record.data.cpu().numpy()
            pname = str(key) + '_beam_' + str(ind)
            # The directory is the `plot_path` parameter; the previous name
            # `plot_path_name` was undefined in this function.
            plotting(join(plot_path, pname), attention_record)

        if True_label:
            if Text_seq_formatted.strip():
                CER = compute_cer(Text_seq_formatted, True_label,
                                  'doesnot_matter') * 100
            else:
                # An empty hypothesis is counted as fully wrong.
                CER = 100
            Bleu_score = sentence_bleu(Text_seq_formatted, [True_label],
                                       smooth_value=SMOOTH_VALUE_DEFAULT,
                                       smooth_method='exp',
                                       use_effective_order=True).score
        else:
            CER = None
            Bleu_score = None

        # The top hypothesis gets an extra summary line.
        if ind == 0:
            print("nbest_output", '=', key, '=', Text_seq_formatted, '=',
                  True_label, '=', CER, '=', Bleu_score)

        print("final_ouputs", '=', ind, '=', key, '=', Text_seq_formatted, '=',
              Yllr, '=', Ynorm_llr, '=', Yseq, '=', CER, '=', Bleu_score)