Example #1
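# Evaluates a saved EEG CRNN classifier: per-class test accuracy, then a DataFrame comparing a
# constant-prediction baseline with the RNN on accuracy, ROC AUC, PR AUC and log loss.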
def eval_model():
    model = CRNN()
    model.load_state_dict(torch.load('./model_EEG.pt'))
    # specify the target classes
    classes = ('True', 'False')

    # track test loss
    test_loss = 0.0
    class_correct = list(0. for i in range(num_classes))
    class_total = list(0. for i in range(num_classes))

    model.eval()
    with torch.no_grad():
        for data, target in testloader:
            target = target.long()
            output, _ = model(data)
            #print(output.data)
            # convert output probabilities to predicted class
            _, pred = torch.max(output, 1)
            # print(pred)
            # compare predictions to true label
            correct = (pred == target).squeeze()
            for i, label in enumerate(target):
                class_correct[label] += correct[i].item()
                class_total[label] += 1
        for i in range(len(classes)):
            print('Accuracy of %s : %2d%% out of %d cases' %
                  (classes[i], 100 * class_correct[i] / class_total[i], class_total[i]))

        data = next(iter(testloader))
        inputs, targets = data
        targets = targets.long()
        outputs, _ = model(inputs)
        probability, predicted = torch.max(outputs.data, 1)
        c = (predicted == targets).squeeze()

        eval_metrics = pd.DataFrame(np.empty([2, 4]))
        eval_metrics.index = ["baseline", "RNN"]
        eval_metrics.columns = ["Accuracy", "ROC AUC", "PR AUC", "Log Loss"]
        pred = np.repeat(0, len(y_test.cpu()))
        pred_proba = np.repeat(0.5, len(y_test.cpu()))
        eval_metrics.iloc[0, 0] = accuracy_score(y_test.cpu(), pred)
        eval_metrics.iloc[0, 1] = roc_auc_score(y_test.cpu(), pred_proba)
        eval_metrics.iloc[0, 2] = average_precision_score(y_test.cpu(), pred_proba)
        eval_metrics.iloc[0, 3] = log_loss(y_test.cpu(), pred_proba)
        eval_metrics.iloc[1, 0] = accuracy_score(y_test.cpu(), predicted.cpu())
        eval_metrics.iloc[1, 1] = roc_auc_score(y_test.cpu(), probability.cpu())
        eval_metrics.iloc[1, 2] = average_precision_score(y_test.cpu(), probability.cpu())
        eval_metrics.iloc[1, 3] = 0  # log_loss(y_test.cpu(), pred_proba[:, 1])

        print(eval_metrics)
Example #2
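# Evaluates a CRNN text recognizer on the Synth90k test split with CTC loss and prints test loss and accuracy.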
def main():
    eval_batch_size = config["eval_batch_size"]
    cpu_workers = config["cpu_workers"]
    reload_checkpoint = config["reload_checkpoint"]

    img_height = config["img_height"]
    img_width = config["img_width"]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"device: {device}")

    test_dataset = Synth90kDataset(
        root_dir=config["data_dir"],
        mode="test",
        img_height=img_height,
        img_width=img_width,
    )

    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=cpu_workers,
        collate_fn=synth90k_collate_fn,
    )

    num_class = len(Synth90kDataset.LABEL2CHAR) + 1
    crnn = CRNN(
        1,
        img_height,
        img_width,
        num_class,
        map_to_seq_hidden=config["map_to_seq_hidden"],
        rnn_hidden=config["rnn_hidden"],
        leaky_relu=config["leaky_relu"],
    )
    crnn.load_state_dict(torch.load(reload_checkpoint, map_location=device))
    crnn.to(device)

    criterion = CTCLoss(reduction="sum")
    criterion.to(device)

    evaluation = evaluate(
        crnn,
        test_loader,
        criterion,
        decode_method=config["decode_method"],
        beam_size=config["beam_size"],
    )
    print("test_evaluation: loss={loss}, acc={acc}".format(**evaluation))
Example #3
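# Same evaluation flow as Example #2: load the Synth90k test split, restore a CRNN checkpoint, and report CTC loss and accuracy.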
def main():
    eval_batch_size = config['eval_batch_size']
    cpu_workers = config['cpu_workers']
    reload_checkpoint = config['reload_checkpoint']

    img_height = config['img_height']
    img_width = config['img_width']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'device: {device}')

    test_dataset = Synth90kDataset(root_dir=config['data_dir'],
                                   mode='test',
                                   img_height=img_height,
                                   img_width=img_width)

    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=eval_batch_size,
                             shuffle=False,
                             num_workers=cpu_workers,
                             collate_fn=synth90k_collate_fn)

    num_class = len(Synth90kDataset.LABEL2CHAR) + 1
    crnn = CRNN(1,
                img_height,
                img_width,
                num_class,
                map_to_seq_hidden=config['map_to_seq_hidden'],
                rnn_hidden=config['rnn_hidden'],
                leaky_relu=config['leaky_relu'])
    crnn.load_state_dict(torch.load(reload_checkpoint, map_location=device))
    crnn.to(device)

    criterion = CTCLoss(reduction='sum')
    criterion.to(device)

    evaluation = evaluate(crnn,
                          test_loader,
                          criterion,
                          decode_method=config['decode_method'],
                          beam_size=config['beam_size'])
    print('test_evaluation: loss={loss}, acc={acc}'.format(**evaluation))
Example #4
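# Command-line inference with docopt: build a dataset from the given image paths, restore a CRNN
# checkpoint, decode predictions (greedy or beam search), and display them.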
def main():
    arguments = docopt(__doc__)

    images = arguments['IMAGE']
    reload_checkpoint = arguments['-m']
    batch_size = int(arguments['-s'])
    decode_method = arguments['-d']
    beam_size = int(arguments['-b'])

    img_height = config['img_height']
    img_width = config['img_width']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'device: {device}')

    predict_dataset = Synth90kDataset(paths=images,
                                      img_height=img_height,
                                      img_width=img_width)

    predict_loader = DataLoader(dataset=predict_dataset,
                                batch_size=batch_size,
                                shuffle=False)

    num_class = len(Synth90kDataset.LABEL2CHAR) + 1
    crnn = CRNN(1,
                img_height,
                img_width,
                num_class,
                map_to_seq_hidden=config['map_to_seq_hidden'],
                rnn_hidden=config['rnn_hidden'],
                leaky_relu=config['leaky_relu'])
    crnn.load_state_dict(torch.load(reload_checkpoint, map_location=device))
    crnn.to(device)

    preds = predict(crnn,
                    predict_loader,
                    Synth90kDataset.LABEL2CHAR,
                    decode_method=decode_method,
                    beam_size=beam_size)

    show_result(images, preds)
Example #5
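# Denoises one mel spectrogram: pad and batch the input, run it through a trained CRNN,
# crop back to the original height, and save the result as a .npy file.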
def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    args = parse_args()

    noisy_mel = np.load(args.path_to_file)
    h, w = noisy_mel.shape
    noisy_mel = pad_mel_spectogram(noisy_mel)
    noisy_mel = torch.tensor(noisy_mel, dtype=torch.float32)
    noisy_mel = noisy_mel.unsqueeze(0).to(device)  # move the input to the same device as the model

    model = CRNN().to(device)
    model.load_state_dict(torch.load(args.path_to_model, map_location=device))
    model.eval()

    with torch.no_grad():  # inference only
        clean_mel = model(noisy_mel)
    clean_mel = clean_mel.squeeze(0)
    clean_mel = clean_mel.cpu().numpy()
    clean_mel = clean_mel[:h]

    save_dir = os.path.dirname(args.path_to_save)
    if save_dir and not os.path.exists(save_dir):
        os.mkdir(save_dir)
    np.save(args.path_to_save, clean_mel)
Example #6
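# Inference setup for an address-OCR CRNN with a Chinese character alphabet loaded from JSON:
# build the model, load pretrained weights, and preprocess a single image.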
import json

import cv2
import torch
from torch.autograd import Variable
from torchvision import transforms

from dataset import LabelConverter, Rescale, Normalize
from model import CRNN

IMAGE_HEIGHT = 32

model_path = './ocr-model/crnn_address.pth'
img_path = './ocr_address.jpg'
# alphabet = '0123456789X'
alphabet = ''.join(json.load(open('./cn-alphabet.json', 'rb')))

model = CRNN(IMAGE_HEIGHT, 1, len(alphabet) + 1, 256)
if torch.cuda.is_available():
    model = model.cuda()
print('loading pretrained model from %s' % model_path)
model.load_state_dict(torch.load(model_path))

converter = LabelConverter(alphabet)

image_transform = transforms.Compose(
    [Rescale(IMAGE_HEIGHT),
     transforms.ToTensor(),
     Normalize()])
image = cv2.imread(img_path, 0)
image = image_transform(image)
if torch.cuda.is_available():
    image = image.cuda()
image = image.view(1, *image.size())
image = Variable(image)

model.eval()
Example #7
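# Measures exact-match accuracy of a CRNN on a CAPTCHA test folder, using each file name (stem) as the ground-truth label.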
import glob
import os
import string

import torch
import torchvision.transforms.functional as F
from PIL import Image
from tqdm import tqdm

from dataset import CaptchaImagesDataset
from model import CRNN
from utils import LabelConverter

if __name__ == '__main__':
    device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
    label_converter = LabelConverter(char_set=string.ascii_lowercase +
                                     string.digits)
    vocab_size = label_converter.get_vocab_size()

    model = CRNN(vocab_size=vocab_size).to(device)
    model.load_state_dict(torch.load('output/weight.pth', map_location=device))
    model.eval()

    correct = 0.0
    image_list = glob.glob('data/CAPTCHA Images/test/*')
    for image in tqdm(image_list):
        ground_truth = image.split('/')[-1].split('.')[0]
        image = Image.open(image).convert('RGB')
        image = F.to_tensor(image).unsqueeze(0).to(device)

        output = model(image)
        encoded_text = output.squeeze().argmax(1)
        decoded_text = label_converter.decode(encoded_text)

        if ground_truth == decoded_text:
            correct += 1
Example #8
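# Classifies a single 8 kHz audio clip: rebuild the CRNN from a saved checkpoint's hyperparameters,
# run inference, and print the class probabilities.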
parser.add_argument('-o', type=str, help='path to output result file', default='./result.txt')
args = parser.parse_args()

out_f = open(args.o, 'w')


checkpoint = torch.load('model_saves/vgg_lstm_subset_newloader/best.pth')
model = CRNN(hidden_size=checkpoint['hidden_size'],
             only_cnn=checkpoint['only_cnn'],
             cnn_type=checkpoint['cnn_type'],
             recurrent_type=checkpoint['recurrent_type'],
             lstm_layers=checkpoint['lstm_layers'],
             nheads=checkpoint['nheads'],
             nlayers=checkpoint['nlayers'],
             input_shape=checkpoint['input_shape']).double().to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

audio, sample_rate = torchaudio.load(args.i)
assert sample_rate == 8000

audio = audio.unsqueeze(0)
audio = audio.double().to(device)
with torch.no_grad():
    pred = model(audio)
probs = F.softmax(pred, dim=1)
print(probs)
pred_label = torch.argmax(pred, dim=1)


#out_f.write("Labels- {0: FL, 1: EN}\n")
Example #9
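# Training loop for a CRNN on Synth90k: RMSprop + CTC loss, periodic validation, and checkpointing every save_interval batches.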
def main():
    epochs = config['epochs']
    train_batch_size = config['train_batch_size']
    eval_batch_size = config['eval_batch_size']
    lr = config['lr']
    show_interval = config['show_interval']
    valid_interval = config['valid_interval']
    save_interval = config['save_interval']
    cpu_workers = config['cpu_workers']
    reload_checkpoint = config['reload_checkpoint']
    valid_max_iter = config['valid_max_iter']

    img_width = config['img_width']
    img_height = config['img_height']
    data_dir = config['data_dir']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'device: {device}')

    train_dataset = Synth90kDataset(root_dir=data_dir,
                                    mode='train',
                                    img_height=img_height,
                                    img_width=img_width)
    valid_dataset = Synth90kDataset(root_dir=data_dir,
                                    mode='dev',
                                    img_height=img_height,
                                    img_width=img_width)

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=train_batch_size,
                              shuffle=True,
                              num_workers=cpu_workers,
                              collate_fn=synth90k_collate_fn)
    valid_loader = DataLoader(dataset=valid_dataset,
                              batch_size=eval_batch_size,
                              shuffle=True,
                              num_workers=cpu_workers,
                              collate_fn=synth90k_collate_fn)

    num_class = len(Synth90kDataset.LABEL2CHAR) + 1
    crnn = CRNN(1,
                img_height,
                img_width,
                num_class,
                map_to_seq_hidden=config['map_to_seq_hidden'],
                rnn_hidden=config['rnn_hidden'],
                leaky_relu=config['leaky_relu'])
    if reload_checkpoint:
        crnn.load_state_dict(torch.load(reload_checkpoint,
                                        map_location=device))
    crnn.to(device)

    optimizer = optim.RMSprop(crnn.parameters(), lr=lr)
    criterion = CTCLoss(reduction='sum')
    criterion.to(device)

    assert save_interval % valid_interval == 0
    i = 1
    for epoch in range(1, epochs + 1):
        print(f'epoch: {epoch}')
        tot_train_loss = 0.
        tot_train_count = 0
        for train_data in train_loader:
            loss = train_batch(crnn, train_data, optimizer, criterion, device)
            train_size = train_data[0].size(0)

            tot_train_loss += loss
            tot_train_count += train_size
            if i % show_interval == 0:
                print('train_batch_loss[', i, ']: ', loss / train_size)

            if i % valid_interval == 0:
                evaluation = evaluate(crnn,
                                      valid_loader,
                                      criterion,
                                      decode_method=config['decode_method'],
                                      beam_size=config['beam_size'])
                print('valid_evaluation: loss={loss}, acc={acc}'.format(
                    **evaluation))

                if i % save_interval == 0:
                    prefix = 'crnn'
                    loss = evaluation['loss']
                    save_model_path = os.path.join(
                        config['checkpoints_dir'],
                        f'{prefix}_{i:06}_loss{loss}.pt')
                    torch.save(crnn.state_dict(), save_model_path)
                    print('save model at ', save_model_path)

            i += 1

        print('train_loss: ', tot_train_loss / tot_train_count)
Example #10
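# Training loop for a CRNN on the wbsin image dataset: grayscale/resize/normalize transforms, a seeded
# 80/20 random split, RMSprop + CTC loss, periodic validation, and checkpointing.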
def main():
    epochs = config["epochs"]
    train_batch_size = config["train_batch_size"]
    eval_batch_size = config["eval_batch_size"]
    lr = config["lr"]
    show_interval = config["show_interval"]
    valid_interval = config["valid_interval"]
    save_interval = config["save_interval"]
    cpu_workers = config["cpu_workers"]
    reload_checkpoint = config["reload_checkpoint"]
    valid_max_iter = config["valid_max_iter"]

    img_width = config["img_width"]
    img_height = config["img_height"]
    data_dir = config["data_dir"]

    torch.cuda.empty_cache()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"device: {device}")

    # # Extract metadata (file path, label, image width/height) from the cropped images, then save it:
    # wbsin_dir = Path.cwd() / "data" / "processed" / "cropped_wbsin"
    # wbsin_meta_df = extract_jpg_meta(img_dir=wbsin_dir, img_type="wbsin")
    # interim_path = Path.cwd() / "data" / "interim"
    # interim_path.mkdir(parents=True, exist_ok=True)
    # wbsin_meta_df.to_csv(interim_path / "wbsin_meta.csv", index=False)

    X_transforms = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((160, 1440)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.5, )),
    ])

    wbsin_dataset = WbsinImageDataset(
        meta_file=(Path.cwd() / "data" / "processed" /
                   "processed_wbsin_meta.csv"),
        transform=X_transforms,
    )

    train_size = int(0.8 * len(wbsin_dataset))
    test_size = len(wbsin_dataset) - train_size

    train_dataset, test_dataset = torch.utils.data.random_split(
        wbsin_dataset,
        [train_size, test_size],
        generator=torch.Generator().manual_seed(42),
    )

    # Save the test_dataset

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=cpu_workers,
        collate_fn=wbsin_collate_fn,
    )

    test_dataloader = DataLoader(
        test_dataset,
        batch_size=eval_batch_size,
        shuffle=True,
        num_workers=cpu_workers,
        collate_fn=wbsin_collate_fn,
    )

    num_class = len(WbsinImageDataset.LABEL2CHAR) + 1
    crnn = CRNN(
        1,
        img_height,
        img_width,
        num_class,
        map_to_seq_hidden=config["map_to_seq_hidden"],
        rnn_hidden=config["rnn_hidden"],
        leaky_relu=config["leaky_relu"],
    )
    if reload_checkpoint:
        crnn.load_state_dict(torch.load(reload_checkpoint,
                                        map_location=device))
    crnn.to(device)

    optimizer = optim.RMSprop(crnn.parameters(), lr=lr)
    criterion = CTCLoss(reduction="sum")
    criterion.to(device)

    assert save_interval % valid_interval == 0
    i = 1
    for epoch in range(1, epochs + 1):
        print(f"epoch: {epoch}")
        tot_train_loss = 0.0
        tot_train_count = 0
        for train_data in train_dataloader:
            loss = train_batch(crnn, train_data, optimizer, criterion, device)
            train_size = train_data[0].size(0)

            tot_train_loss += loss
            tot_train_count += train_size
            if i % show_interval == 0:
                print("train_batch_loss[", i, "]: ", loss / train_size)

            if i % valid_interval == 0:
                evaluation = evaluate(
                    crnn,
                    test_dataloader,
                    criterion,
                    decode_method=config["decode_method"],
                    beam_size=config["beam_size"],
                )
                print(
                    "valid_evaluation: loss={loss}, acc={acc}, char_acc={char_acc}"
                    .format(**evaluation))

                if i % save_interval == 0:
                    prefix = "crnn"
                    loss = evaluation["loss"]
                    save_model_path = os.path.join(
                        config["checkpoints_dir"],
                        f"{prefix}_{i:06}_loss{loss}.pt")
                    torch.save(crnn.state_dict(), save_model_path)
                    print("save model at ", save_model_path)
            i += 1

        print("train_loss: ", tot_train_loss / tot_train_count)
Example #11
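# Prediction on the wbsin dataset: rebuild the seeded 80/20 split, restore a CRNN checkpoint,
# run predict() over the training loader, and show the results against the globbed image list.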
def main():
    arguments = docopt(__doc__)

    images = arguments["IMAGE"]
    images = list((Path.cwd() / images[0]).glob("*.jpg"))
    reload_checkpoint = arguments["-m"]
    batch_size = int(arguments["-s"])
    decode_method = arguments["-d"]
    beam_size = int(arguments["-b"])

    img_height = config["img_height"]
    img_width = config["img_width"]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"device: {device}")

    X_transforms = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((160, 1440)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.5, )),
    ])
    wbsin_dataset = WbsinImageDataset(
        meta_file=(Path.cwd() / "data" / "interim" / "cropped_wbsin_meta.csv"),
        transform=X_transforms)

    train_size = int(0.8 * len(wbsin_dataset))
    test_size = len(wbsin_dataset) - train_size

    train_dataset, test_dataset = torch.utils.data.random_split(
        wbsin_dataset,
        [train_size, test_size],
        generator=torch.Generator().manual_seed(42),
    )

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
        collate_fn=wbsin_collate_fn,
    )

    test_dataloader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
        collate_fn=wbsin_collate_fn,
    )

    num_class = len(WbsinImageDataset.LABEL2CHAR) + 1
    crnn = CRNN(
        1,
        img_height,
        img_width,
        num_class,
        map_to_seq_hidden=config["map_to_seq_hidden"],
        rnn_hidden=config["rnn_hidden"],
        leaky_relu=config["leaky_relu"],
    )
    crnn.load_state_dict(torch.load(reload_checkpoint, map_location=device))
    crnn.to(device)

    preds = predict(
        crnn,
        train_dataloader,
        WbsinImageDataset.LABEL2CHAR,
        decode_method=decode_method,
        beam_size=beam_size,
    )

    show_result(images, preds)
Example #12
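# Training setup for a simple OCR CRNN: CTC loss, SGD with momentum, optional reload of a saved
# checkpoint, and per-batch label encoding into preallocated buffers.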
# input channels: training images are converted to grayscale, so this is 1
nc = 1
lr = 0.001
beta1 = 0.5
MOMENTUM = 0.9
EPOCH = 100

# label <-> character converter
converter = utils.strLabelConverter(c.alphabet)
# loss function
criterion = CTCLoss()

crnn = CRNN(imgH, nc, nclass, nh, ngpu)
crnn.apply(weights_init)
if os.path.exists('/home/hecong/temp/data/ocr/simple_ocr.pkl'):
    crnn.load_state_dict(
        torch.load('/home/hecong/temp/data/ocr/simple_ocr.pkl'))

image = torch.FloatTensor(batchSize, nc, imgH, imgH)
text = torch.IntTensor(batchSize * 5)
length = torch.IntTensor(batchSize)

# optimizer = optim.Adam(
#     crnn.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer = optim.SGD(crnn.parameters(), lr=lr, momentum=MOMENTUM)

for epoch in range(EPOCH):
    for step, (t_image, t_label) in enumerate(train_loader):
        batch_size = t_image.size(0)
        utils.loadData(image, t_image)
        t, l = converter.encode(t_label)
        utils.loadData(text, t)
Example #13
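# Trains an EEG CRNN with cross-entropy loss, MultiStepLR, TensorBoard logging, and checkpoint-based
# resume (ckpt_best.pth saved every epoch).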
def train_model():

    model = CRNN().cuda()
    writer = SummaryWriter('./tblogs/%f/' % learning_rate)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), weight_decay=weight_decay, lr=learning_rate)
    lr_schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20, 30, 40, 50], gamma=0.1)
    t0 = time.time()
    cnt = 0
    total_step = len(trainloader)
    hidden = None
    class_correct = list(0. for i in range(num_classes))
    class_total = list(0. for i in range(num_classes))
    classes = ('True', 'False')
    start_epoch = 0
    ################################
    #### Resume from checkpoint ####
    if RESUME:
        path_checkpoint = 'F:/EEG_data/MNE_test/code/reweights/checkpoint/%f/ckpt_best.pth' % (learning_rate)  # checkpoint path
        checkpoint = torch.load(path_checkpoint)  # load the checkpoint

        model.load_state_dict(checkpoint['net'])  # restore model parameters

        optimizer.load_state_dict(checkpoint['optimizer'])  # restore optimizer state
        start_epoch = checkpoint['epoch']  # resume from the saved epoch
        lr_schedule.load_state_dict(checkpoint['lr_schedule'])

    for epoch in range(start_epoch + 1, num_epochs + 1):

        # keep track of training
        train_loss = 0.0
        train_counter = 0
        train_losses = 0.0
        ###################
        # train the model #
        ###################
        model.train()
        for data, target in trainloader:
            data, target = data.cuda(), target.cuda()
            target = target.long()
            optimizer.zero_grad()
            output, hidden = model(data, hidden)
            a, b = hidden
            hidden = (a.data, b.data)
            loss = criterion(output, target)
            # print(target.data)
            loss.backward()
            optimizer.step()
            train_loss += (loss.item() * data.size(0))
            train_counter += data.size(0)
            train_losses = (train_loss / train_counter)
            writer.add_scalar('Train/Loss', train_losses, epoch)
            cnt += 1
            if cnt % 10 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                      .format(epoch, num_epochs, cnt, total_step, loss.item()))
        cnt = 0
        checkpoint = {
            "net": model.state_dict(),
            'optimizer': optimizer.state_dict(),
            "epoch": epoch,
            'lr_schedule': lr_schedule.state_dict()
        }
        if not os.path.isdir("F:/EEG_data/MNE_test/code/reweights/checkpoint/%f" % (learning_rate)):
            os.makedirs("F:/EEG_data/MNE_test/code/reweights/checkpoint/%f" % (learning_rate))
        torch.save(checkpoint, 'F:/EEG_data/MNE_test/code/reweights/checkpoint/%f/ckpt_best.pth' % (learning_rate))
        if epoch % 200 == 0:
            torch.save(model.state_dict(), './model_EEG.pt')
        #############
        # eval
        #############
        #TODO add eval part
    # torch.save(model.state_dict(), './model_EEG.pt')
    time_total = time.time() - t0
    print('Total time: {:4.3f}, average time per epoch: {:4.3f}'.format(time_total, time_total / num_epochs))
Example #14
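# Builds a CRNN, tries to restore a pretrained checkpoint (falling back to fresh weights if it is
# missing or corrupt), and prepares the checkpoints directory.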
#loading other parameters of opt
nclass = len(opt.alphabet) + 1
nc = 1

#adding model
crnn = CRNN(imgH=32,
            nc=1,
            nclass=nclass,
            nh=opt.nhidden,
            n_rnn=2,
            leakyRelu=False)

# load model weights
if opt.crnn_checkpoints != '' and os.path.exists(opt.crnn_checkpoints):
    try:
        crnn.load_state_dict(torch.load(opt.crnn_checkpoints))
        print('loaded pretrained model from %s' % opt.crnn_checkpoints)
    except Exception:
        crnn.apply(weights_init)
        print('Corrupt checkpoint given. Training from scratch')

if opt.checkpoints_folder is None:
    opt.checkpoints_folder = 'model_checkpoints'

print(opt.checkpoints_folder)

if not os.path.exists(opt.checkpoints_folder):
    os.makedirs(opt.checkpoints_folder)

print('Model checkpoints directory created')
Example #15
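# Evaluates a CRNN on a scene-text test set (ic03 by default): greedy CTC decoding, case-insensitive
# comparison with the ground truth, and accuracy logging.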
def main():
	parser = argparse.ArgumentParser(description='CRNN')
	parser.add_argument('--name', default='32x100', type=str)
	parser.add_argument('--exp', default='syn90k', type=str)
	
	## data setting 
	parser.add_argument('--root', default='/data/data/text_recognition/',type=str)
	parser.add_argument('--test_dataset', default='ic03', type=str)
	parser.add_argument('--load_width', default=100, type=int)
	parser.add_argument('--load_height', default=32, type=int)
	parser.add_argument('--batch_size', default=64, type=int)
	parser.add_argument('--num_workers', default=8, type=int)
	## model setting
	parser.add_argument('--snapshot', default='./weights/32x100/syn90k/3_51474.pth', type=str)
	parser.add_argument('--alphabeta', default='0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ', type=str)
	parser.add_argument('--ignore_case', default=True, type=bool)  # note: type=bool treats any non-empty string as True
	## output setting
	parser.add_argument('--out_dir', default='./outputs', type=str)

	args = parser.parse_args()

	if not os.path.exists(args.out_dir):
		os.mkdir(args.out_dir)
	args.out_dir = os.path.join(args.out_dir, args.name)
	if not os.path.exists(args.out_dir):
		os.mkdir(args.out_dir)
	args.out_dir = os.path.join(args.out_dir, args.snapshot.strip().split('/')[-1].split('.')[0])
	if not os.path.exists(args.out_dir):
		os.mkdir(args.out_dir)

	if args.ignore_case:
		args.alphabeta = args.alphabeta[:36]
	args.nClasses = len(args.alphabeta) + 1

	log_path = os.path.join(args.out_dir, args.test_dataset + '.txt')
	setup_logger(log_path)

	logging.info('model will be evaluated on %s'%(args.test_dataset))
	testset =  SceneLoader(args, args.test_dataset, False)
	logging.info('%d test samples' % len(testset))
	test_loader = data.DataLoader(testset, args.batch_size, num_workers=args.num_workers,
	                              shuffle=False,  pin_memory=True)

	## model
	net = CRNN(args)
	print(net)
	net = torch.nn.DataParallel(net).cuda()
	print(net)
	net.load_state_dict(torch.load(args.snapshot))
	net = net.module
	net.eval()
	n_correct = 0
	n_samples = 0
	converter = strLabelConverter(args.alphabeta, args.ignore_case)

	for index, sample in enumerate(test_loader):
		# print('model state', net.training)
		# print('bn1 state', net.cnn[0].training)
		# print('conv1.weight', net.cnn[0].subnet[0].weight[0, 0, 0, 0])
		# print('bn1.weight', net.cnn[4].subnet[1].weight[0])
		# print('bn1.bias', net.cnn[4].subnet[1].weight[0])
		# print('bn1.running_mean', net.cnn[4].subnet[1].running_mean[0])
		# print('bn1.running_var', net.cnn[4].subnet[1].running_var[0])
		imgs, gts, lexicon50, lexicon1k, lexiconfull, img_paths = sample
		imgs = Variable(imgs).cuda()
		preds = net(imgs)
		preds_size = torch.IntTensor([preds.size(0)] * preds.size(1))
		## decode
		_, preds = preds.max(2)
		preds = preds.transpose(1, 0).contiguous().view(-1)
		text_preds = converter.decode(preds.data, preds_size, raw=False)
		for pred, target in zip(text_preds, gts):
			n_samples += 1
			if pred.lower() == target.lower():
				n_correct += 1
				logging.info('pred: %s  gt:%s '%(pred, target))
			else:
				logging.info('pred: %s  gt:%s  -----------------------------!!!!!!'%(pred, target))
	assert n_samples == len(testset)
	acc = n_correct / len(testset)
	logging.info('accuracy=%f' % acc)
Example #16
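# Fragment of a training script: CRNN + CTCLoss + Adam with StepLR, optional weight reload, and
# Syn90k-style train/val loaders (the snippet starts and ends mid-statement).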
        transforms.Normalize([0.5], [0.5])
    ])

    net = CRNN()
    net = net.to(device)
    criterion = nn.CTCLoss()
    optimizer = optim.Adam(net.parameters(),
                           lr=args.learning_rate,
                           weight_decay=args.weight_decay)
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer,
                                             step_size=5000,
                                             gamma=0.5)

    if args.load_model:
        load_path = os.path.join(os.getcwd(), args.load_model)
        net.load_state_dict(torch.load(load_path))

    save_dir = os.path.join(os.getcwd(), args.save_dir)
    if not os.path.isdir(save_dir):
        os.mkdir(save_dir)

    train_data = syn_text(
        "annotation_train.txt",
        "E:\\cam2\\lpcvc-2020\\ctc\\mnt\\ramdisk\\max\\90kDICT32px", transform)
    train_loader = DataLoader(train_data,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=1)

    val_data = syn_text(
        "annotation_val.txt",