Code Example #1
File: main.py Project: sky-lzy/DailyCode
import os

import cv2
import torch
from torchvision import transforms

# CRNN, LabelConverter and visual_ctc_results are defined elsewhere in this project.


def predict(model_path, im_path, norm_height=32, norm_width=128, device='cpu'):
    '''
    predict a new image using a trained model
    :param model_path: path of the saved model
    :param im_path: path of an image
    :param norm_height: image normalization height
    :param norm_width: image normalization width
    :param device: 'cpu' or 'cuda'
    '''

    # step 1: initialize a model and put it on device
    model = CRNN()
    model = model.to(device)

    # step 2: load state_dict from saved model
    # map the checkpoint onto the requested device rather than whichever
    # device happens to be available
    checkpoint = torch.load(model_path, map_location=torch.device(device))
    model.load_state_dict(checkpoint['state_dict'])
    print('[Info] Load model from {}'.format(model_path))

    # step 3: initialize the label converter
    label_converter = LabelConverter()

    # step 4: read image and normalization
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((norm_height, norm_width)),
        transforms.ToTensor()
    ])
    im = cv2.imread(im_path)
    if im is None:
        raise FileNotFoundError(
            f'the image {im_path} may not exist, please check it.')
    # note: cv2 loads images as BGR; this is consistent as long as the model
    # was trained with the same pipeline
    x = transform(im)
    x = x.unsqueeze(0).to(device)  # add the batch dimension and match the model's device

    # step 5: run model
    model.eval()
    with torch.no_grad():
        logits, _ = model(x)
        raw_pred = logits.argmax(2)
        pred = label_converter.decode(raw_pred)[0]
    print('prediction: {}\n'.format(pred))

    # visualize probabilities output by CTC
    savepath = os.path.splitext(im_path)[0] + '_vis.jpg'
    visual_ctc_results(im, logits, savepath)
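A minimal usage sketch follows; the checkpoint and image paths are illustrative placeholders, not files from this project.

if __name__ == '__main__':
    predict(model_path='checkpoints/crnn.pth',  # hypothetical checkpoint path
            im_path='samples/word.jpg',         # hypothetical test image
            device='cuda' if torch.cuda.is_available() else 'cpu')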
Code Example #2
import glob
import os
import string

import torch
from PIL import Image
from torchvision.transforms import functional as F
from tqdm import tqdm

from dataset import CaptchaImagesDataset
from utils import LabelConverter

# CRNN is defined elsewhere in this project.

if __name__ == '__main__':
    device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
    label_converter = LabelConverter(char_set=string.ascii_lowercase +
                                     string.digits)
    vocab_size = label_converter.get_vocab_size()

    model = CRNN(vocab_size=vocab_size).to(device)
    model.load_state_dict(torch.load('output/weight.pth', map_location=device))
    model.eval()

    correct = 0.0
    image_list = glob.glob('data/CAPTCHA Images/test/*')
    with torch.no_grad():  # no gradients are needed for evaluation
        for image_path in tqdm(image_list):
            # the file name (without extension) is the ground-truth label
            ground_truth = os.path.splitext(os.path.basename(image_path))[0]
            image = Image.open(image_path).convert('RGB')
            image = F.to_tensor(image).unsqueeze(0).to(device)

            output = model(image)
            encoded_text = output.squeeze().argmax(1)
            decoded_text = label_converter.decode(encoded_text)

            if ground_truth == decoded_text:
                correct += 1

    print('accuracy =', correct / len(image_list))
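Both examples delegate decoding to LabelConverter.decode, whose implementation is not shown here. Below is a minimal sketch of the standard CTC greedy (best-path) decoding it presumably performs, assuming blank index 0 and a character table indexed from 1; the function name and layout are illustrative, not the project's actual code.

def ctc_greedy_decode(indices, chars, blank=0):
    # collapse consecutive repeats, then drop blanks (standard CTC best-path decoding)
    decoded = []
    prev = blank
    for idx in indices:
        idx = int(idx)
        if idx != prev and idx != blank:
            decoded.append(chars[idx - 1])  # assumes class k maps to chars[k - 1]
        prev = idx
    return ''.join(decoded)

# e.g. ctc_greedy_decode([1, 1, 0, 2, 2], 'ab') returns 'ab'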
Code Example #3
            # retain_graph=True is only needed if this graph is backpropagated
            # through again later in the iteration; otherwise it can be dropped
            loss.backward(retain_graph=True)
            optimizer.step()

            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                epoch_num + 1, epochs, i + 1, total_steps, loss.item()))

            # greedy decoding: fw_pass_output has shape (T, N, num_classes);
            # take the argmax class at every time step for every sample
            prediction = []
            for b in range(N):
                single_element_pred = []
                for t in fw_pass_output.detach():
                    _, pred = torch.max(t[b], 0)
                    single_element_pred.append(int(pred))
                prediction.append(single_element_pred)

            # use a name other than `i` here so the outer step counter
            # printed above is not shadowed
            for idx, pred in enumerate(prediction):
                prediction_tensor = model_output_to_label(pred)
                model_pred = lc.decode(prediction_tensor)
                real_label = lc.decode(labels[idx])
                print("{i}. Pred = {p}, real = {r}".format(i=idx + 1,
                                                           p=model_pred,
                                                           r=real_label))

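The nested decoding loops above can also be collapsed into a single vectorized argmax. Here is a sketch under the same shape assumption (fw_pass_output of shape (T, N, num_classes)); the helper name is illustrative, not part of the project.

import torch

def greedy_predictions(fw_pass_output):
    """Vectorized equivalent of the nested decoding loops above."""
    pred_indices = fw_pass_output.detach().argmax(dim=2)  # (T, N)
    return pred_indices.t().tolist()  # N lists of T per-time-step indices

# usage sketch with random logits: T=10 time steps, N=4 samples, 37 classes
prediction = greedy_predictions(torch.randn(10, 4, 37))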