Example #1
import string

import numpy as np
import torch
import torch.optim as optim

from model import CRNN
from utils import LabelConverter
# fit, get_loader, write_figure, and write_log are project-local helpers
# assumed to come from the surrounding repo; they are not shown in this example.

def train():
    print('start training ...........')
    batch_size = 16
    num_epochs = 50
    learning_rate = 0.1

    label_converter = LabelConverter(char_set=string.ascii_lowercase + string.digits)
    vocab_size = label_converter.get_vocab_size()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = CRNN(vocab_size=vocab_size).to(device)
    # model.load_state_dict(torch.load('output/weight.pth', map_location=device))

    train_loader, val_loader = get_loader('data/CAPTCHA Images/', batch_size=batch_size)

    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
    # scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 10, 2)

    train_losses, val_losses = [], []
    for epoch in range(num_epochs):
        train_epoch_loss = fit(epoch, model, optimizer, label_converter, device, train_loader, phase='training')
        val_epoch_loss = fit(epoch, model, optimizer, label_converter, device, val_loader, phase='validation')
        print('-----------------------------------------')

        # Save a checkpoint whenever the validation loss reaches a new minimum.
        if epoch == 0 or val_epoch_loss <= np.min(val_losses):
            torch.save(model.state_dict(), 'output/weight.pth')

        train_losses.append(train_epoch_loss)
        val_losses.append(val_epoch_loss)

        write_figure('output', train_losses, val_losses)
        write_log('output', epoch, train_epoch_loss, val_epoch_loss)

        scheduler.step(val_epoch_loss)  # ReduceLROnPlateau monitors the validation loss
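
Neither example shows the fit helper that drives each epoch. Below is a minimal
sketch of one plausible implementation, assuming the standard CTC setup for a
CRNN: time-major model output of shape (seq_len, batch, vocab), nn.CTCLoss with
blank index 0, and a label_converter.encode() that returns flattened targets
plus per-sample lengths. All of these details are assumptions, not the repo's
actual API.

import torch
import torch.nn.functional as F

def fit(epoch, model, optimizer, label_converter, device, loader, phase='training'):
    criterion = torch.nn.CTCLoss(blank=0, zero_infinity=True)
    if phase == 'training':
        model.train()
    else:
        model.eval()

    running_loss = 0.0
    for images, labels in loader:
        images = images.to(device)
        # encode() is assumed to return (targets, target_lengths) for the batch.
        targets, target_lengths = label_converter.encode(labels)

        with torch.set_grad_enabled(phase == 'training'):
            outputs = model(images)                    # (seq_len, batch, vocab)
            log_probs = F.log_softmax(outputs, dim=2)
            # Every sample uses the full output sequence length.
            input_lengths = torch.full((images.size(0),), outputs.size(0),
                                       dtype=torch.long)
            loss = criterion(log_probs, targets, input_lengths, target_lengths)

        if phase == 'training':
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        running_loss += loss.item()

    epoch_loss = running_loss / len(loader)
    print(f'{phase} epoch {epoch}: loss {epoch_loss:.4f}')
    return epoch_loss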
Example #2
import glob
import os
import string

import torch
import torchvision.transforms.functional as F
from PIL import Image
from tqdm import tqdm

from model import CRNN
from utils import LabelConverter

if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    label_converter = LabelConverter(char_set=string.ascii_lowercase + string.digits)
    vocab_size = label_converter.get_vocab_size()

    model = CRNN(vocab_size=vocab_size).to(device)
    model.load_state_dict(torch.load('output/weight.pth', map_location=device))
    model.eval()

    correct = 0
    image_list = glob.glob('data/CAPTCHA Images/test/*')
    for image_path in tqdm(image_list):
        # The file name (minus extension) is the ground-truth label.
        ground_truth = os.path.splitext(os.path.basename(image_path))[0]
        image = Image.open(image_path).convert('RGB')
        image = F.to_tensor(image).unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(image)
        # Greedy decode: take the most likely class index at each time step.
        encoded_text = output.squeeze().argmax(1)
        decoded_text = label_converter.decode(encoded_text)

        # Tally exact-match predictions (assumed metric for this example).
        if decoded_text == ground_truth:
            correct += 1

    print(f'test accuracy: {correct / len(image_list):.4f}')
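
Both examples lean on LabelConverter, which is not shown. Here is a minimal
sketch of what it might look like, assuming the usual CTC convention: index 0
is reserved for the blank, and greedy decoding collapses repeated indices
before dropping blanks. The class and method names come from the examples
above; everything inside the methods is an assumption.

import string

class LabelConverter:
    """Maps characters to integer class indices for a CTC-trained CRNN.
    Index 0 is assumed to be the CTC blank, so real characters start at 1."""

    def __init__(self, char_set=string.ascii_lowercase + string.digits):
        self.char_set = char_set
        self.char_to_idx = {c: i + 1 for i, c in enumerate(char_set)}

    def get_vocab_size(self):
        return len(self.char_set) + 1  # +1 for the CTC blank

    def decode(self, encoded_text):
        # Greedy CTC decode: collapse repeated indices, then drop blanks (0).
        chars = []
        prev = -1
        for idx in encoded_text.tolist():
            if idx != prev and idx != 0:
                chars.append(self.char_set[idx - 1])
            prev = idx
        return ''.join(chars)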