def train():
    """Train the CRNN captcha recognizer and checkpoint the best weights.

    Runs `num_epochs` epochs of training + validation via `fit`, saves the
    model state dict to 'output/weight.pth' whenever the validation loss
    reaches a new minimum, and writes a loss figure and a log line after
    every epoch. Learning rate is reduced on validation-loss plateaus.
    """
    print('start training ...........')

    batch_size = 16
    num_epochs = 50
    learning_rate = 0.1

    # Captcha charset: lowercase letters + digits.
    label_converter = LabelConverter(char_set=string.ascii_lowercase + string.digits)
    vocab_size = label_converter.get_vocab_size()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = CRNN(vocab_size=vocab_size).to(device)
    # Uncomment to resume from a previous checkpoint:
    # model.load_state_dict(torch.load('output/weight.pth', map_location=device))

    train_loader, val_loader = get_loader('data/CAPTCHA Images/', batch_size=batch_size)

    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)
    # Halve the LR when validation loss stops improving (stepped with
    # val_epoch_loss at the bottom of the epoch loop).
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

    # Make sure the checkpoint/figure/log directory exists before the first save.
    os.makedirs('output', exist_ok=True)

    train_losses, val_losses = [], []
    for epoch in range(num_epochs):
        train_epoch_loss = fit(epoch, model, optimizer, label_converter, device,
                               train_loader, phase='training')
        val_epoch_loss = fit(epoch, model, optimizer, label_converter, device,
                             val_loader, phase='validation')
        print('-----------------------------------------')

        # Checkpoint on the first epoch and on every new best validation loss.
        # (min() on the losses recorded so far; val_losses is still empty here
        # on the first iteration, hence the emptiness guard.)
        if not val_losses or val_epoch_loss <= min(val_losses):
            torch.save(model.state_dict(), 'output/weight.pth')

        train_losses.append(train_epoch_loss)
        val_losses.append(val_epoch_loss)

        write_figure('output', train_losses, val_losses)
        write_log('output', epoch, train_epoch_loss, val_epoch_loss)

        scheduler.step(val_epoch_loss)
import glob
import os
import string

import matplotlib.pyplot as plt
import torch
import torchvision.transforms.functional as F
from PIL import Image
from tqdm import tqdm

from dataset import CaptchaImagesDataset
from model import CRNN
from utils import LabelConverter


if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Same charset the model was trained with: lowercase letters + digits.
    label_converter = LabelConverter(char_set=string.ascii_lowercase + string.digits)
    vocab_size = label_converter.get_vocab_size()

    model = CRNN(vocab_size=vocab_size).to(device)
    model.load_state_dict(torch.load('output/weight.pth', map_location=device))
    model.eval()

    correct = 0.0
    image_list = glob.glob('data/CAPTCHA Images/test/*')

    # Inference only — no_grad avoids building autograd graphs.
    with torch.no_grad():
        for image_path in tqdm(image_list):
            # The ground-truth label is the filename stem: "<text>.<ext>".
            # os.path handles both '/' and '\\' separators, unlike split('/').
            ground_truth = os.path.splitext(os.path.basename(image_path))[0]

            image = Image.open(image_path).convert('RGB')
            image = F.to_tensor(image).unsqueeze(0).to(device)

            output = model(image)
            # Greedy decode: argmax over the vocab at each time step.
            encoded_text = output.squeeze().argmax(1)
            decoded_text = label_converter.decode(encoded_text)

            if decoded_text == ground_truth:
                correct += 1

    if image_list:
        print(f'accuracy: {correct / len(image_list):.4f} '
              f'({int(correct)}/{len(image_list)})')
    else:
        print('no test images found under data/CAPTCHA Images/test/')