Example #1
def test(args, logger):

    # returns Net.train() or Net.eval() depending on phase
    lprnet = build_lprnet(lpr_max_len=args.lpr_max_len, phase=args.phase_train, class_num=len(CHARS), dropout_rate=args.dropout_rate)
    device = torch.device("cuda:0" if args.cuda else "cpu")
    lprnet.to(device)   # after instantiation, move the network to GPU or CPU with .to()
    print("Successfully built the network!")   # model construction is complete at this point

    # load pretrained model
    if args.pretrained_model:
        lprnet.load_state_dict(torch.load(args.pretrained_model, map_location=device))
        print("load pretrained model successful!\n")
    else:
        print("[Error] Can't found pretrained mode, please check!")
        return False

    test_img_dirs = os.path.expanduser(args.test_img_dirs)  # expand "~" and "~user" in the path to the user's home directory
    test_dataset = LPRDataLoader(test_img_dirs.split(','), args.img_size, args.lpr_max_len) # lpr_max_len is the maximum number of plate characters

    # epsilons = np.arange(0, 0.1, 0.005).tolist()
    epsilons = np.arange(0, 0.1, 0.01).tolist()
    for epsilon in epsilons:
        args.epsilon = epsilon
        acc = Greedy_Decode_Eval(lprnet, test_dataset, args)
        logger.info("Epsilon: {}\tAccuracy: {:.4f}".format(args.epsilon, acc))
        print("Epsilon: {}\tAccuracy: {:.4f}".format(args.epsilon, acc))
Example #2
def test(args):
    test_img_dirs = './data/my_test'  # path to the images after the plates have been cropped out

    # returns Net.train() or Net.eval() depending on phase
    lprnet = build_lprnet(lpr_max_len=args.lpr_max_len,
                          phase=args.phase_train,
                          class_num=len(CHARS),
                          dropout_rate=args.dropout_rate)
    device = torch.device("cuda:0" if args.cuda else "cpu")
    lprnet.to(device)  # after instantiation, move the network to GPU or CPU with .to()
    print("Successfully built the network!")  # model construction is complete at this point

    # load pretrained model
    if args.pretrained_model:
        lprnet.load_state_dict(
            torch.load(args.pretrained_model, map_location=device))
        print("load pretrained model successful!\n")
    else:
        print("[Error] Can't found pretrained mode, please check!")
        return False

    test_img_dirs = os.path.expanduser(
        test_img_dirs)  # expand "~" and "~user" in the path to the user's home directory
    test_dataset = LPRDataLoader(test_img_dirs.split(','), args.img_size,
                                 args.lpr_max_len)  # lpr_max_len is the maximum number of plate characters
    try:
        Greedy_Decode_Eval(lprnet, test_dataset, args)
    finally:
        cv2.destroyAllWindows()
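
Greedy_Decode_Eval itself is not shown in these snippets, but its core is a standard greedy CTC decode: per-frame argmax over the class dimension, collapse repeats, drop blanks. A hedged sketch of that step, assuming LPRNet logits shaped [N, C, T] and the blank at index len(CHARS) - 1 (as the CTCLoss setup in Example #8 suggests):

def greedy_ctc_decode(logits, blank):
    # logits: [N, C, T]; returns one label sequence per sample
    preds = logits.argmax(dim=1).cpu().numpy()  # [N, T] best class per time step
    results = []
    for seq in preds:
        labels, prev = [], blank
        for c in seq:
            if c != prev and c != blank:  # collapse repeats, then drop blanks
                labels.append(int(c))
            prev = c
        results.append(labels)
    return results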
Example #3
def test():
    args = get_parser()

    lprnet = build_lprnet(lpr_max_len=args.lpr_max_len,
                          phase=args.phase_train,
                          class_num=len(CHARS),
                          dropout_rate=args.dropout_rate)
    device = torch.device("cuda:0" if args.cuda else "cpu")
    lprnet.to(device)
    print("Successful to build network!")

    # load pretrained model
    if args.pretrained_model:
        lprnet.load_state_dict(torch.load(args.pretrained_model, map_location=device))
        print("load pretrained model successful!")
    else:
        print("[Error] Can't found pretrained mode, please check!")
        return False

    test_img_dirs = os.path.expanduser(args.test_img_dirs)
    test_dataset = LPRDataLoader(test_img_dirs.split(','), args.img_size,
                                 args.lpr_max_len)
    try:
        Greedy_Decode_Eval(lprnet, test_dataset, args)
    finally:
        cv2.destroyAllWindows()
Example #4
def test():
    args = get_parser()
    # epsilons = [0, 0.005, 0.01, 0.015, 0.02, 0.025, 0.03, 0.035, 0.04] # around 0.07 accuracy drops to roughly 10%
    epsilons = np.arange(0, 0.1, 0.005).tolist()
    accuracies = []

    # returns Net.train() or Net.eval() depending on phase
    lprnet = build_lprnet(lpr_max_len=args.lpr_max_len, phase=args.phase_train, class_num=len(CHARS), dropout_rate=args.dropout_rate)
    device = torch.device("cuda:0" if args.cuda else "cpu")
    lprnet.to(device)   # after instantiation, move the network to GPU or CPU with .to()
    print("Successfully built the network!")   # model construction is complete at this point

    # load pretrained model
    if args.pretrained_model:
        lprnet.load_state_dict(torch.load(args.pretrained_model, map_location=device))
        print("load pretrained model successful!\n")
    else:
        print("[Error] Can't found pretrained mode, please check!")
        return False

    test_img_dirs = os.path.expanduser(args.test_img_dirs)  # expand "~" and "~user" in the path to the user's home directory
    test_dataset = LPRDataLoader(test_img_dirs.split(','), args.img_size, args.lpr_max_len) # lpr_max_len is the maximum number of plate characters
    epoch_size = len(test_dataset) // args.test_batch_size # integer division, so the leftover tail samples are excluded
    test_dataset = DataLoader(test_dataset, args.test_batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=collate_fn)
    # collate_fn: controls how samples are batched; a custom function gives exact control over the output
    # shuffle: when True, the dataset is reshuffled every epoch

    examples = []
    for epsilon in epsilons:
        batch_iterator = iter(test_dataset)
        perturbed_image = Greedy_Decode_Eval(lprnet, batch_iterator, args, epsilon, epoch_size)
        b, g, r = cv2.split(perturbed_image)
        perturbed_image = cv2.merge([r, g, b])  # OpenCV uses BGR; reorder to RGB for plotting
        examples.append(perturbed_image)
    draw_all_images(epsilons, examples) # plot the results
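
The cv2.split/cv2.merge pair above reorders OpenCV's BGR channels to RGB before plotting; cv2.cvtColor does the same conversion in one call:

perturbed_image = cv2.cvtColor(perturbed_image, cv2.COLOR_BGR2RGB)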
Example #5
def visualize_stn():
    with torch.no_grad():
        # Get a batch of training data
        dataset = LPRDataLoader([args.img_dirs], args.img_size)
        dataloader = DataLoader(dataset,
                                batch_size=1,
                                shuffle=False,
                                num_workers=2,
                                collate_fn=collate_fn)
        #imgs, labels, lengths = next(iter(dataloader))
        for imgs, labels, lengths in dataloader:
            input_tensor = imgs.cpu()
            transformed_input_tensor = STN(imgs.to(device)).cpu()

            in_grid = convert_image(torchvision.utils.make_grid(input_tensor))
            out_grid = convert_image(
                torchvision.utils.make_grid(transformed_input_tensor))

            # Plot the results side-by-side
            f, axarr = plt.subplots(1, 2)
            axarr[0].imshow(in_grid)
            axarr[0].set_title('Dataset Images')
            axarr[1].imshow(out_grid)
            axarr[1].set_title('Transformed Images')
            plt.show()
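
convert_image is assumed to undo the dataset normalization and reorder the CHW grid tensor into an HWC array for plt.imshow. A minimal sketch under that assumption (the mean/std values are placeholders; use whatever statistics the dataset actually applies):

import numpy as np

def convert_image(tensor):
    img = tensor.numpy().transpose((1, 2, 0))  # CHW -> HWC for matplotlib
    # assumption: placeholder normalization statistics
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    return np.clip(std * img + mean, 0, 1)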
Example #6
def test(args):
    test_img_dirs = './data/my_test'  # path to the images after the plates have been cropped out

    epsilons = [0, 0.005, 0.01, 0.015, 0.02, 0.025, 0.03, 0.035,
                0.04]  # around 0.07 accuracy drops to roughly 10%
    accuracies = []

    # returns Net.train() or Net.eval() depending on phase
    lprnet = build_lprnet(lpr_max_len=args.lpr_max_len,
                          phase=args.phase_train,
                          class_num=len(CHARS),
                          dropout_rate=args.dropout_rate)
    device = torch.device("cuda:0" if args.cuda else "cpu")
    lprnet.to(device)  # after instantiation, move the network to GPU or CPU with .to()
    print("Successfully built the network!")  # model construction is complete at this point

    # load pretrained model
    if args.pretrained_model:
        lprnet.load_state_dict(
            torch.load(args.pretrained_model, map_location=device))
        print("load pretrained model successful!\n")
    else:
        print("[Error] Can't found pretrained mode, please check!")
        return False

    test_img_dirs = os.path.expanduser(
        test_img_dirs)  # expand "~" and "~user" in the path to the user's home directory
    test_dataset = LPRDataLoader(test_img_dirs.split(','), args.img_size,
                                 args.lpr_max_len)  # lpr_max_len is the maximum number of plate characters
    epoch_size = len(test_dataset) // args.test_batch_size  # integer division, so the leftover tail samples are excluded
    test_dataset = DataLoader(test_dataset,
                              args.test_batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              collate_fn=collate_fn)
    # collate_fn: controls how samples are batched; a custom function gives exact control over the output
    # shuffle: when True, the dataset is reshuffled every epoch

    try:
        for epsilon in epsilons:
            batch_iterator = iter(test_dataset)
            accuracies.append(
                Greedy_Decode_Eval(lprnet, batch_iterator, args, epsilon,
                                   epoch_size))
    finally:
        cv2.destroyAllWindows()

    # plot accuracy as a function of epsilon
    plt.figure(figsize=(5, 5))
    plt.plot(epsilons, accuracies, "*-")
    plt.yticks(np.arange(0, 1.1, step=0.1))
    plt.xticks(np.arange(0, 0.045, step=0.005))
    plt.title("Accuracy vs Epsilon")
    plt.xlabel("Epsilon")
    plt.ylabel("Accuracy")
    plt.show()
Example #7
def test():
    args = get_parser()

    lprnet = build_lprnet(lpr_max_len=args.lpr_max_len,
                          phase=args.phase_train,
                          class_num=len(CHARS),
                          dropout_rate=args.dropout_rate)
    device = torch.device("cuda:0" if args.cuda else "cpu")
    lprnet.to(device)
    print("Successful to build network!")

    ## build the spatial transformer network (STN)
    STN = STNet()
    STN.to(device)
    STN.load_state_dict(
        torch.load('STN/weights/STN_Model_LJK_CA_XZH.pth',
                   map_location=lambda storage, loc: storage))
    STN.eval()

    print("空间变换网络搭建完成")

    # load pretrained model
    if args.pretrained_model:
        lprnet.load_state_dict(torch.load(args.pretrained_model, map_location=device))
        print("load pretrained model successful!")
    else:
        print("[Error] Can't found pretrained mode, please check!")
        return False

    test_img_dirs = os.path.expanduser(args.test_img_dirs)
    test_dataset = LPRDataLoader(test_img_dirs.split(','), args.img_size,
                                 args.lpr_max_len)
    try:
        Greedy_Decode_Eval(lprnet, test_dataset, args, STN, device)
    finally:
        cv2.destroyAllWindows()
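
Every DataLoader in these examples passes the same collate_fn. For CTC training the labels have variable length, so the usual LPRNet-style collate stacks the images, flattens all labels into one 1-D target, and returns per-sample lengths. A sketch, assuming each dataset item is an (img, label, length) triple:

import numpy as np
import torch

def collate_fn(batch):
    imgs, labels, lengths = [], [], []
    for img, label, length in batch:
        imgs.append(torch.from_numpy(img))
        labels.extend(label)      # flatten all plate labels into one target vector
        lengths.append(length)    # remember each sample's label length for CTC
    labels = np.asarray(labels).astype(np.float32)
    return torch.stack(imgs, 0), torch.from_numpy(labels), lengths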
Example #8
def train(args, logger, epsilon=0.04, alpha=0.5):

    T_length = 18 # args.lpr_max_len
    epoch = 0 + args.resume_epoch
    loss_val = 0
    cnt = 0
    # lr = args.learning_rate # learning rate

    if not os.path.exists(args.save_folder):
        os.mkdir(args.save_folder)

    lprnet = build_lprnet(lpr_max_len=args.lpr_max_len, phase=args.phase_train, class_num=len(CHARS), dropout_rate=args.dropout_rate)
    device = torch.device("cuda:0" if args.cuda else "cpu")
    lprnet.to(device)
    print("Successful to build network!")

    # load pretrained model
    if args.pretrained_model:
        lprnet.load_state_dict(torch.load(args.pretrained_model, map_location=device))
        print("load pretrained model successful!") # 从模型net.train()和net.eval()得出的结果完全不一样
        # test_img_dirs = os.path.expanduser(args.test_img_dirs) # 测试集
        # test_dataset = LPRDataLoader(test_img_dirs.split(','), args.img_size, args.lpr_max_len)
        # Greedy_Decode_Eval(lprnet, test_dataset, args)
    else:
        def xavier(param):
            nn.init.xavier_uniform_(param)  # in-place variant; nn.init.xavier_uniform is deprecated

        def weights_init(m):
            for key in m.state_dict():
                if key.split('.')[-1] == 'weight':
                    if 'conv' in key:
                        nn.init.kaiming_normal_(m.state_dict()[key], mode='fan_out')
                    if 'bn' in key:
                        m.state_dict()[key][...] = 1  # BN weights set to 1; xavier init needs a 2-D tensor, so xavier(1) would fail
                elif key.split('.')[-1] == 'bias':
                    m.state_dict()[key][...] = 0.01

        lprnet.backbone.apply(weights_init)
        lprnet.container.apply(weights_init)
        print("initial net weights successful!")

    # define optimizer
    # optimizer = optim.SGD(lprnet.parameters(), lr=args.learning_rate,
    #                       momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer = optim.RMSprop(lprnet.parameters(), lr=args.learning_rate, alpha=0.9, eps=1e-08,
                              momentum=args.momentum, weight_decay=args.weight_decay)
    # os.path.expanduser expands "~" and "~user" in the path to the user's home directory
    train_img_dirs = os.path.expanduser(args.train_img_dirs) # training set
    test_img_dirs = os.path.expanduser(args.test_img_dirs) # test set
    train_dataset = LPRDataLoader(train_img_dirs.split(','), args.img_size, args.lpr_max_len)
    test_dataset = LPRDataLoader(test_img_dirs.split(','), args.img_size, args.lpr_max_len)

    epoch_size = len(train_dataset) // args.train_batch_size # number of batches per epoch
    max_iter = args.max_epoch * epoch_size # total number of training iterations

    ctc_loss = nn.CTCLoss(blank=len(CHARS)-1, reduction='mean') # reduction: 'none' | 'mean' | 'sum'

    if args.resume_epoch > 0:
        start_iter = args.resume_epoch * epoch_size
    else:
        start_iter = 0

    for iteration in range(start_iter, max_iter):
        if iteration % epoch_size == 0: # a new epoch begins
            # create batch iterator
            batch_iterator = iter(DataLoader(train_dataset, args.train_batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=collate_fn))
            loss_val = 0
            cnt = 0 # counts how many loss values have been accumulated
            epoch += 1

        if iteration !=0 and iteration % args.save_interval == 0:
            torch.save(lprnet.state_dict(), args.save_folder + 'LPRNet_' + '_iteration_' + repr(iteration) + '.pth') # save a checkpoint every save_interval iterations

        if (iteration + 1) % args.test_interval == 0: # evaluate the model at regular intervals
            Greedy_Decode_Eval(lprnet, test_dataset, args, logger)
            # lprnet.train() # switch back to train mode

        start_time = time.time()
        # load train data
        images, labels, lengths = next(batch_iterator)
        # labels = np.array([el.numpy() for el in labels]).T
        # print(labels)
        # update lr
        lr = adjust_learning_rate(optimizer, epoch, args.learning_rate, args.lr_schedule)

        if args.cuda:
            images = Variable(images.cuda(), requires_grad=True)
            labels = Variable(labels, requires_grad=False).cuda()
        else:
            images = Variable(images, requires_grad=True)
            labels = Variable(labels, requires_grad=False)

        lprnet.eval() # switch to eval mode first
        # forward
        logits = lprnet(images)
        log_probs = logits.permute(2, 0, 1) # for ctc loss: T x N x C
        # print(labels.shape)
        log_probs = log_probs.log_softmax(2).requires_grad_()   # requires_grad_() sets requires_grad to True in place
        # log_probs = log_probs.detach().requires_grad_()
        # print(log_probs.shape)

        # get ctc parameters
        input_lengths, target_lengths = sparse_tuple_for_ctc(T_length, lengths)
        # backprop
        loss1 = ctc_loss(log_probs, labels, input_lengths=input_lengths, target_lengths=target_lengths)
        lprnet.zero_grad()
        loss1.backward(retain_graph=True)
        data_grad = images.grad.data
        perturbed_images = fgsm_attack(images, epsilon, data_grad)  # FGSM attack

        logits = lprnet(perturbed_images)
        log_probs = logits.permute(2, 0, 1) # for ctc loss: T x N x C
        log_probs = log_probs.log_softmax(2).requires_grad_()   # requires_grad_() sets requires_grad to True in place
        loss2 = ctc_loss(log_probs, labels, input_lengths=input_lengths, target_lengths=target_lengths)
        loss = alpha * loss1 + (1-alpha) * loss2 # combined clean + adversarial loss

        lprnet.train()
        lprnet.zero_grad()
        optimizer.zero_grad()   # zero the gradients
        loss.backward()
        optimizer.step()
        loss_val += loss.item() # accumulated so the printed loss can be averaged
        cnt += 1 # number of accumulated loss values
        end_time = time.time()
        if (iteration + 1) % 20 == 0:
            msg = 'Epoch:' + repr(epoch) + ' || epochiter: ' + repr(iteration % epoch_size) + '/' + repr(epoch_size) \
                  + '|| Total iter ' + repr(iteration) + ' || Loss: %.4f || ' % (loss_val / cnt) + \
                  'Batch time: %.4f sec. ||' % (end_time - start_time) + 'LR: %.8f' % (lr)
            print(msg)
            logger.info(msg) # write to the log
        # print('Epoch:' + repr(epoch) + ' || epochiter: ' + repr(iteration % epoch_size) + '/' + repr(epoch_size)
        #           + '|| Totel iter ' + repr(iteration) + ' || Loss: %.4f||' % (loss_val / cnt) +
        #           'Batch time: %.4f sec. ||' % (end_time - start_time) + 'LR: %.8f' % (lr))

    # final test
    print("Final test Accuracy:")
    logger.info("Final test Accuracy:")
    Greedy_Decode_Eval(lprnet, test_dataset, args, logger)

    # save final parameters
    torch.save(lprnet.state_dict(), args.save_folder + 'Ans_LPRNet_model.pth')
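
nn.CTCLoss needs an input length and a target length per sample; sparse_tuple_for_ctc above only has to build those two tuples. A plausible sketch, assuming every LPRNet output spans the full T_length time steps:

def sparse_tuple_for_ctc(T_length, lengths):
    # input_lengths: network output length per sample (always T_length here)
    # target_lengths: number of characters in each ground-truth plate
    input_lengths = tuple(T_length for _ in lengths)
    target_lengths = tuple(lengths)
    return input_lengths, target_lengths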
Example #9
def train():
    args = get_parser()

    T_length = 18 # args.lpr_max_len
    epoch = 0 + args.resume_epoch
    loss_val = 0

    if not os.path.exists(args.save_folder):
        os.mkdir(args.save_folder)

    lprnet = build_lprnet(lpr_max_len=args.lpr_max_len, phase=args.phase_train, class_num=len(CHARS), dropout_rate=args.dropout_rate)
    device = torch.device("cuda:0" if args.cuda else "cpu")
    lprnet.to(device)
    print("Successful to build network!")

    # load pretrained model
    if args.pretrained_model:
        lprnet.load_state_dict(torch.load(args.pretrained_model, map_location=device))
        print("load pretrained model successful!")
    else:
        def xavier(param):
            nn.init.xavier_uniform_(param)  # in-place variant; nn.init.xavier_uniform is deprecated

        def weights_init(m):
            for key in m.state_dict():
                if key.split('.')[-1] == 'weight':
                    if 'conv' in key:
                        nn.init.kaiming_normal_(m.state_dict()[key], mode='fan_out')
                    if 'bn' in key:
                        m.state_dict()[key][...] = 1  # BN weights set to 1; xavier(1) on a scalar would fail
                elif key.split('.')[-1] == 'bias':
                    m.state_dict()[key][...] = 0

        lprnet.backbone.apply(weights_init)
        lprnet.container.apply(weights_init)
        print("initial net weights successful!")

    # define optimizer
    # optimizer = optim.SGD(lprnet.parameters(), lr=args.learning_rate,
    #                       momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer = optim.RMSprop(lprnet.parameters(), lr=args.learning_rate, alpha=0.9, eps=1e-08,
                              momentum=args.momentum, weight_decay=args.weight_decay)
    train_img_dirs = os.path.expanduser(args.train_img_dirs)
    test_img_dirs = os.path.expanduser(args.test_img_dirs)
    train_dataset = LPRDataLoader(train_img_dirs.split(','), args.img_size, args.lpr_max_len)
    test_dataset = LPRDataLoader(test_img_dirs.split(','), args.img_size, args.lpr_max_len)

    epoch_size = len(train_dataset) // args.train_batch_size
    max_iter = args.max_epoch * epoch_size

    ctc_loss = nn.CTCLoss(blank=len(CHARS)-1, reduction='mean') # reduction: 'none' | 'mean' | 'sum'

    if args.resume_epoch > 0:
        start_iter = args.resume_epoch * epoch_size
    else:
        start_iter = 0

    for iteration in range(start_iter, max_iter):
        if iteration % epoch_size == 0:
            # create batch iterator
            batch_iterator = iter(DataLoader(train_dataset, args.train_batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=collate_fn))
            loss_val = 0
            epoch += 1

        if iteration !=0 and iteration % args.save_interval == 0:
            torch.save(lprnet.state_dict(), args.save_folder + 'LPRNet_' + '_iteration_' + repr(iteration) + '.pth')

        if (iteration + 1) % args.test_interval == 0:
            Greedy_Decode_Eval(lprnet, test_dataset, args)
            lprnet.train() # switch back to train mode after evaluation

        start_time = time.time()
        # load train data
        images, labels, lengths = next(batch_iterator)
        # labels = np.array([el.numpy() for el in labels]).T
        # print(labels)
        # get ctc parameters
        input_lengths, target_lengths = sparse_tuple_for_ctc(T_length, lengths)
        # update lr
        lr = adjust_learning_rate(optimizer, epoch, args.learning_rate, args.lr_schedule)

        if args.cuda:
            images = Variable(images.cuda())
            labels = Variable(labels.cuda(), requires_grad=False)
        else:
            images = Variable(images)
            labels = Variable(labels, requires_grad=False)

        # forward
        logits = lprnet(images)
        log_probs = logits.permute(2, 0, 1) # for ctc loss: T x N x C
        # print(labels.shape)
        log_probs = log_probs.log_softmax(2).requires_grad_()
        # log_probs = log_probs.detach().requires_grad_()
        # print(log_probs.shape)
        # backprop
        optimizer.zero_grad()
        loss = ctc_loss(log_probs, labels, input_lengths=input_lengths, target_lengths=target_lengths)
        if loss.item() == np.inf:
            continue
        loss.backward()
        optimizer.step()
        loss_val += loss.item()
        end_time = time.time()
        if iteration % 20 == 0:
            print('Epoch:' + repr(epoch) + ' || epochiter: ' + repr(iteration % epoch_size) + '/' + repr(epoch_size)
                  + '|| Total iter ' + repr(iteration) + ' || Loss: %.4f||' % (loss.item()) +
                  'Batch time: %.4f sec. ||' % (end_time - start_time) + 'LR: %.8f' % (lr))
    # final test
    print("Final test Accuracy:")
    Greedy_Decode_Eval(lprnet, test_dataset, args)

    # save final parameters
    torch.save(lprnet.state_dict(), args.save_folder + 'Final_LPRNet_model.pth')
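
adjust_learning_rate is driven by args.lr_schedule in both train functions. A hedged sketch of a step-decay version (the 0.1 decay factor per milestone is an assumption):

def adjust_learning_rate(optimizer, epoch, base_lr, lr_schedule):
    # multiply the base LR by 0.1 for every milestone epoch already passed
    lr = base_lr * (0.1 ** sum(1 for milestone in lr_schedule if epoch >= milestone))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr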
Example #10
    lprnet = LPRNet(class_num=len(CHARS), dropout_rate=args.dropout_rate)
    lprnet.to(device)
    lprnet.load_state_dict(torch.load('weights/Final_LPRNet_model.pth', map_location=lambda storage, loc: storage))
#    checkpoint = torch.load('saving_ckpt/lprnet_Iter_023400_model.ckpt')
#    lprnet.load_state_dict(checkpoint['net_state_dict'])
    lprnet.eval() 
    print("LPRNet loaded")
    
#    torch.save(lprnet.state_dict(), 'weights/Final_LPRNet_model.pth')
    
    STN = STNet()
    STN.to(device)
    STN.load_state_dict(torch.load('weights/Final_STN_model.pth', map_location=lambda storage, loc: storage))
#    checkpoint = torch.load('saving_ckpt/stn_Iter_023400_model.ckpt')
#    STN.load_state_dict(checkpoint['net_state_dict'])
    STN.eval()
    print("STN loaded")
    
#    torch.save(STN.state_dict(), 'weights/Final_STN_model.pth')
    
    dataset = LPRDataLoader([args.img_dirs], args.img_size)   
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=2, collate_fn=collate_fn) 
    print('dataset loaded with length : {}'.format(len(dataset)))
    
    ACC = eval(lprnet, STN, dataloader, dataset, device)
    print('the accuracy is {:.2f} %'.format(ACC*100))
    
    visualize_stn()
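
The map_location=lambda storage, loc: storage idiom in Example #10 keeps every checkpoint tensor on the CPU no matter where it was saved; map_location='cpu' is the equivalent shorter spelling:

lprnet.load_state_dict(
    torch.load('weights/Final_LPRNet_model.pth', map_location='cpu'))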

Example #11
    #lprnet.load_state_dict(torch.load('weights/Final_LPRNet_model.pth', map_location=lambda storage, loc: storage))
    print("LPRNet loaded")
    print('LPRNet params: ', sum(p.numel() for p in lprnet.parameters()))
    STN = STNet(r=r, resw=resw, resh=resh)
    STN.to(device)
    #summary(STN, (3,24,94))
    summary(STN, (3, resh, resw))
    #STN_load=torch.load('weights/Final_LPRNet_model.pth', map_location=lambda storage, loc: storage.cuda)
    #STN.load_state_dict(torch.load('weights/Final_STN_model.pth', map_location=lambda storage, loc: storage))
    print("STN loaded")
    print('STN params: ', sum(p.numel() for p in STN.parameters()))
    # raise NameError('stop')

    dataset = {
        'train':
        LPRDataLoader([args.img_dirs_train], args.img_size,
                      aug_transform=True),
        'val':
        LPRDataLoader([args.img_dirs_val], args.img_size, aug_transform=False)
    }  ###shuffle
    dataloader = {
        'train':
        DataLoader(dataset['train'],
                   batch_size=args.batch_size,
                   shuffle=False,
                   num_workers=4,
                   collate_fn=collate_fn),
        'val':
        DataLoader(dataset['val'],
                   batch_size=args.batch_size,
                   shuffle=False,
                   num_workers=4,
                   collate_fn=collate_fn)
    }

    # assumption: weights_init matches the helper defined in Example #8 (this snippet is truncated upstream)
    def weights_init(m):
        for key in m.state_dict():
            if key.split('.')[-1] == 'weight':
                if 'conv' in key:
                    nn.init.kaiming_normal_(m.state_dict()[key], mode='fan_out')
                if 'bn' in key:
                    m.state_dict()[key][...] = 1  # BN weights set to 1
            elif key.split('.')[-1] == 'bias':
                m.state_dict()[key][...] = 0.01

    lprnet.backbone.apply(weights_init)
    lprnet.container.apply(weights_init)
    print("initial net weights successful!")

    STN = STNet()
    STN.to(device)
    STN.load_state_dict(
        torch.load('weights/STN_model_Init.pth',
                   map_location=lambda storage, loc: storage))
    print("STN loaded")

    dataset = {
        'train': LPRDataLoader([args.img_dirs_train], args.img_size),
        'val': LPRDataLoader([args.img_dirs_val], args.img_size)
    }
    dataloader = {
        'train':
        DataLoader(dataset['train'],
                   batch_size=args.batch_size,
                   shuffle=False,
                   num_workers=4,
                   collate_fn=collate_fn),
        'val':
        DataLoader(dataset['val'],
                   batch_size=args.batch_size,
                   shuffle=False,
                   num_workers=4,
                   collate_fn=collate_fn)
    }