running_loss = 0

    loss_all_out = whole_loss / len_train
    print("epoch is %d, the whole loss is %f" % (epoch, loss_all_out))
    # with open("training_data/whole_loss_%.5f_pre_GN_te05_d02_all.txt" % (lr_rate), "a") as f:
    #     f.write("%s\n" % (str(loss_all_out)))

    # Run prediction on the test set and compute the WER loss.
    total_dist = 0
    total_label = 0
    total_line = 0
    total_line_rec = 0
    whole_loss_t = 0

    encoder.eval()
    attn_decoder1.eval()
    print('Now, begin testing!!')

    for step_t, (x_t, y_t) in enumerate(test_loader):
        x_real_high = x_t.size()[2]
        x_real_width = x_t.size()[3]
        if x_t.size()[0] < batch_size_t:
            break
        print('testing for %.3f%%' % (step_t * 100 * batch_size_t / len_test),
              end='\r')
        h_mask_t = []
        w_mask_t = []
        for i in x_t:
            #h*w
            size_mask_t = i[1].size()
            s_w_t = str(i[1][0])
def for_test(x_t):
    """Greedily decode one input batch with the saved encoder/decoder.

    Builds the encoder and the attention decoder, loads their weights from
    the hard-coded checkpoint files under ``../model/``, runs the encoder on
    ``x_t`` (with an all-ones mask concatenated along dim 1) and then calls
    the decoder for at most ``maxlen`` steps, feeding back the arg-max token
    of each step.

    Relies on module-level names defined outside this function:
    ``hidden_size``, ``device_ids``, ``batch_size_t``, ``maxlen``,
    ``worddicts_r``, ``densenet121`` and ``AttnDecoderRNN``.

    Args:
        x_t: 4-D tensor (sizes 0..3 are read). Presumably laid out as
            (batch, channel, height, width) -- TODO confirm with the caller.
            The decoder state is sized with ``batch_size_t``, so this
            assumes ``x_t.size(0) == batch_size_t`` -- TODO confirm.

    Returns:
        A tuple ``(k, prediction_real_show)`` where ``k`` is a numpy array
        stacking, for every decoder step that was executed, the attention
        tensor of the first batch element, and ``prediction_real_show`` is a
        numpy array of the ``worddicts_r`` entries predicted for the first
        batch element, always terminated by ``'<eol>'``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE(review): both networks are rebuilt and their checkpoints re-read
    # on every call; callers that decode many inputs pay that cost each time.
    encoder = torch.nn.DataParallel(densenet121(), device_ids=device_ids)
    attn_decoder1 = torch.nn.DataParallel(
        AttnDecoderRNN(hidden_size, 256, dropout_p=0.5),
        device_ids=device_ids)
    encoder = encoder.to(device)
    attn_decoder1 = attn_decoder1.to(device)

    encoder.load_state_dict(
        torch.load(
            '../model/encoder_lr0.00000_GN_te1_d05_SGD_bs6_mask_conv_bn_b_xavier.pkl',
            map_location=device))
    attn_decoder1.load_state_dict(
        torch.load(
            '../model/attn_decoder_lr0.00000_GN_te1_d05_SGD_bs6_mask_conv_bn_b_xavier.pkl',
            map_location=device))

    encoder.eval()
    attn_decoder1.eval()

    decoder_attention_t_cat = []

    # Pure inference: disabling autograd keeps the graph of up to `maxlen`
    # decoder steps from being retained in memory.
    with torch.no_grad():
        x_t = x_t.to(device)

        # Append an all-ones mask with the same shape as the input.
        x_mask = torch.ones(x_t.size(), device=device)
        x_t = torch.cat((x_t, x_mask), dim=1)

        # The decoder receives the input height/width as one-element lists.
        h_mask_t = [int(x_t.size()[2])]
        w_mask_t = [int(x_t.size()[3])]

        output_highfeature_t = encoder(x_t)
        x_mean_t = float(torch.mean(output_highfeature_t))
        dense_input = output_highfeature_t.size()[2]
        output_area_t = output_highfeature_t.size()[3]

        # 111 is the first decoder input; presumably the start-of-sequence
        # index -- TODO confirm against the vocabulary.
        decoder_input_t = torch.LongTensor([111] * batch_size_t).to(device)

        # Sampled on the CPU and then moved, so a seeded CPU RNG yields the
        # same initial hidden state regardless of the target device.
        decoder_hidden_t = torch.randn(batch_size_t, 1,
                                       hidden_size).to(device)
        decoder_hidden_t = torch.tanh(decoder_hidden_t * x_mean_t)

        prediction = torch.zeros(batch_size_t, maxlen)

        decoder_attention_t = torch.zeros(batch_size_t, 1, dense_input,
                                          output_area_t, device=device)
        attention_sum_t = torch.zeros(batch_size_t, 1, dense_input,
                                      output_area_t, device=device)

        for i in range(maxlen):
            (decoder_output, decoder_hidden_t, decoder_attention_t,
             attention_sum_t) = attn_decoder1(
                 decoder_input_t, decoder_hidden_t, output_highfeature_t,
                 output_area_t, attention_sum_t, decoder_attention_t,
                 dense_input, batch_size_t, h_mask_t, w_mask_t, device_ids)

            decoder_attention_t_cat.append(
                decoder_attention_t[0].detach().cpu().numpy())

            _, topi = torch.max(decoder_output, 2)
            # Stop once every predicted index in the batch is 0.
            if torch.sum(topi) == 0:
                break
            decoder_input_t = topi.view(batch_size_t)
            prediction[:, i] = decoder_input_t

    k = numpy.array(decoder_attention_t_cat)

    # Translate the first sample's indices to symbols, stopping at index 0
    # (the value `prediction` was initialised with).
    prediction_real = []
    for token in prediction[0].tolist():
        token = int(token)
        if token == 0:
            break
        prediction_real.append(worddicts_r[token])
    prediction_real.append('<eol>')

    prediction_real_show = numpy.array(prediction_real)

    return k, prediction_real_show