Example #1
0
 def __init__(self):
     # Preprocessing pipeline: map the HU window [-1300, 500] onto [0, 1],
     # center-crop to 48x96x96, convert to tensor, then resize to 48x96x96.
     steps = [
         Normalize(bound=[-1300., 500.], cover=[0., 1.]),
         CenterCrop([48, 96, 96]),
         ToTensor(),
         Resize([48, 96, 96]),
     ]
     self.trans = v_transforms.Compose(steps)
Example #2
0
 def __init__(self,
              ann_file='data/coco2017/annotations/instances_train2017.json',
              pipeline=(LoadImageFromFile(), LoadAnnotations(),
                        Resize(img_scale=(1333, 800),
                               keep_ratio=True), RandomFlip(flip_ratio=0.5),
                        Normalize(mean=[123.675, 116.28, 103.53],
                                  std=[58.395, 57.12, 57.375],
                                  to_rgb=True), Pad(size_divisor=32),
                        DefaultFormatBundle(),
                        Collect(keys=['img', 'gt_bboxes', 'gt_labels'])),
              test_mode=False,
              filter_empty_gt=True):
     """COCO-style dataset init: load annotations, drop tiny/empty images
     (train only), set sampler group flags, and build the pipeline."""
     self.ann_file = ann_file
     if test_mode:
         self.img_prefix = 'data/coco2017/test2017/'
     else:
         self.img_prefix = 'data/coco2017/train2017/'
     self.test_mode = test_mode
     self.filter_empty_gt = filter_empty_gt
     # Load per-image annotation records (and proposals, if any).
     self.img_infos = self.load_annotations(self.ann_file)
     if not test_mode:
         # Drop images whose width/height is below 32 px and images
         # without any annotation.
         keep = self._filter_imgs()
         self.img_infos = [self.img_infos[idx] for idx in keep]
         # Group flag for the aspect-ratio-aware sampler.
         self._set_group_flag()
     # Build the data-processing pipeline.
     self.pipeline = Compose(pipeline)
Example #3
0
def train(model, data_loader, optimizer, loss, epoch):
    """Run one training epoch and return the mean batch loss.

    Args:
        model: network under training (inputs already on GPU).
        data_loader: yields (data, target, names) batches.
        optimizer: torch optimizer; its first param group's lr is logged.
        loss: criterion; receives raw targets for "FP"/"BCE" runs and
            integer class labels otherwise (selected via opt.save_dir).
        epoch: epoch index, used only for logging.

    Returns:
        float: np.mean of the per-batch loss values.
    """
    train_loss = []
    lr = optimizer.param_groups[0]['lr']
    # Meters kept for parity with the original code; currently unused.
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    accuracies = AverageMeter()
    for i, (data, target, names) in enumerate(data_loader):
        data = data.cuda(non_blocking=True)
        # Resize each volume to (sample_duration, sample_size, sample_size).
        data = Resize([opt.sample_duration, opt.sample_size,
                       opt.sample_size])(data)
        target = target.cuda(non_blocking=True)

        out = model(data)
        # The original duplicated identical "FP" and "BCE" branches; both
        # use the raw target, everything else needs integer labels.
        if "FP" in opt.save_dir or "BCE" in opt.save_dir:
            cls = loss(out, target)
        else:
            cls = loss(out, target.long())
        optimizer.zero_grad()
        cls.backward()
        optimizer.step()

        pred = torch.sigmoid(out[:, :1])
        if opt.n_classes == 1:
            train_acc = acc_metric(pred.data.cpu().numpy(),
                                   target.data.cpu().numpy())
        else:
            train_acc = calculate_accuracy(out, target.long())

        # `.item()` replaces the legacy `cls.data[0]` access that was
        # guarded by a bare try/except.
        train_loss.append(cls.item())

        if i % 5 == 0:
            # The original wrapped this in try/except with an identical
            # except body; plain print is equivalent.
            print(
                "Training: Epoch %d: %dth batch, loss %2.4f, acc %2.4f, lr: %2.6f!"
                % (epoch, i, cls.item(), train_acc, lr))

    return np.mean(train_loss)
Example #4
0
def get_predict_transform(size):
    """Return the prediction-time transform pipeline.

    Images are first resized by the 256/224 factor relative to `size`,
    then randomly cropped back to `size`, converted to tensor, and
    normalized with the module-level mean/std.
    """
    resize_to = int(size * (256 / 224))
    return transforms.Compose([
        Resize((resize_to, resize_to)),
        transforms.RandomCrop(size),
        transforms.ToTensor(),
        transforms.Normalize(mean_nums, std_nums),
    ])
Example #5
0
def get_train_transform(size):
    """Return the training-time augmentation pipeline.

    Resizes by the 256/224 factor relative to `size`, applies flip and
    rotation augmentations, random-crops back to `size`, converts to
    tensor, and normalizes with the module-level mean/std.
    """
    resize_to = int(size * (256 / 224))
    return transforms.Compose([
        Resize((resize_to, resize_to)),
        transforms.RandomVerticalFlip(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=30),
        transforms.RandomCrop(size),
        transforms.ToTensor(),
        transforms.Normalize(mean_nums, std_nums),
    ])
Example #6
0
 def __init__(self, cfg):
     """Set up augmentation ops and read the sample list for `cfg.mode`."""
     super(Data, self).__init__()
     self.cfg = cfg
     # Data augmentation / preprocessing operators.
     self.randombrig = RandomBrightness()
     self.normalize = Normalize(mean=cfg.mean, std=cfg.std)
     self.randomcrop = RandomCrop()
     self.blur = RandomBlur()
     self.randomvflip = RandomVorizontalFlip()
     self.randomhflip = RandomHorizontalFlip()
     self.resize = Resize(384, 384)
     self.totensor = ToTensor()
     # Read the sample list: one entry per line of <datapath>/<mode>.txt.
     list_path = cfg.datapath + '/' + cfg.mode + '.txt'
     with open(list_path, 'r') as lines:
         self.samples = [line.strip() for line in lines]
Example #7
0
# Inference script: run a trained FCN road-segmentation model over the KITTI
# 'um' test split and threshold the sigmoid output into a color mask.
# NOTE(review): this snippet appears truncated — the loop body ends inside
# the CLASSES_NUM == 1 branch.
if __name__ == '__main__':
    ITERATION, BATCH_SIZE, CLASSES_NUM = 50, 1, 2
    HEIGHT, WIDTH = 288, 800
    vgg11 = Vgg11()
    fineTuringNet = ProxyNet('vgg11', vgg11.features)
    model = Fcn(scale=8, featureProxyNet=fineTuringNet, classesNum=CLASSES_NUM)

    # Load trained model weights.
    state_dict = torch.load('fcn_road_segment.pth')
    model.load_state_dict(state_dict)

    testSet = KittiRoadTestDataset(
        path='./data_road/training',
        type='um',
        transforms=[
            Resize(HEIGHT, WIDTH, cv2.INTER_LINEAR),
            # HistogramNormalize(),
            # Mixup(shadowImg, random_translation=False, random_rotation=False),
            ToTensor(),
        ])
    testLoader = DataLoader(testSet, batch_size=1, shuffle=True, num_workers=0)
    total = len(testSet)
    for batch, (img, img_path) in enumerate(testLoader):
        with torch.no_grad():
            # NOTE(review): `img = img` is a no-op — presumably a placeholder
            # for a device transfer (e.g. img.to(device)); confirm.
            img = img
            output = model(img)
            output = torch.sigmoid(output)
            # 288/800 duplicate HEIGHT/WIDTH above; kept byte-identical here.
            output = output.reshape((CLASSES_NUM, 288, 800)).numpy()
            segment = np.zeros((3, 288, 800))
            if CLASSES_NUM == 1:
                segment[2, output[0] >= 0.5] = 1
Example #8
0
def main():
    """Parse CLI args and train the CRNN text-recognition model.

    Builds train/val datasets from a single annotation file (split by
    --val_split), wires up Adam + StepLR, and saves a checkpoint when
    training is interrupted with Ctrl-C.
    """
    parser = ArgumentParser()
    parser.add_argument('-d',
                        '--data_path',
                        dest='data_path',
                        type=str,
                        default=None,
                        help='path to the data')
    parser.add_argument('--epochs',
                        '-e',
                        dest='epochs',
                        type=int,
                        help='number of train epochs',
                        default=100)
    parser.add_argument('--batch_size',
                        '-b',
                        dest='batch_size',
                        type=int,
                        help='batch size',
                        default=128)
    parser.add_argument('--weight_decay',
                        '-wd',
                        dest='weight_decay',
                        type=float,
                        help='weight_decay',
                        default=5e-4)
    parser.add_argument('--lr',
                        '-lr',
                        dest='lr',
                        type=float,
                        help='lr',
                        default=1e-4)
    parser.add_argument('--model',
                        '-m',
                        dest='model',
                        type=str,
                        help='model_name',
                        default='CRNN')
    parser.add_argument('--lr_step',
                        '-lrs',
                        dest='lr_step',
                        type=int,
                        help='lr step',
                        default=10)
    parser.add_argument('--lr_gamma',
                        '-lrg',
                        dest='lr_gamma',
                        type=float,
                        help='lr gamma factor',
                        default=0.5)
    parser.add_argument('--input_wh',
                        '-wh',
                        dest='input_wh',
                        type=str,
                        help='model input size',
                        default='320x64')
    parser.add_argument('--rnn_dropout',
                        '-rdo',
                        dest='rnn_dropout',
                        type=float,
                        help='rnn dropout p',
                        default=0.1)
    parser.add_argument('--rnn_num_directions',
                        '-rnd',
                        dest='rnn_num_directions',
                        type=int,
                        help='bi',
                        default=2)
    parser.add_argument('--augs',
                        '-a',
                        dest='augs',
                        type=float,
                        help='degree of geometric augs',
                        default=0)
    parser.add_argument('--load',
                        '-l',
                        dest='load',
                        type=str,
                        help='pretrained weights',
                        default=None)
    parser.add_argument('-v',
                        '--val_split',
                        dest='val_split',
                        type=float,
                        default=0.7,
                        help='train/val split')
    parser.add_argument('-o',
                        '--output_dir',
                        dest='output_dir',
                        default='./output',
                        help='dir to save log and models')
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    # Parse the "WxH" string (e.g. '320x64') into integer width/height.
    w, h = list(map(int, args.input_wh.split('x')))

    logger = get_logger(os.path.join(args.output_dir, 'train.log'))
    logger.info('Start training with params:')
    for arg, value in sorted(vars(args).items()):
        logger.info("Argument %s: %r", arg, value)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    net = RecognitionModel(dropout=args.rnn_dropout,
                           num_directions=args.rnn_num_directions,
                           input_size=(w, h))
    if args.load is not None:
        net.load_state_dict(torch.load(args.load))
    net = net.to(device)
    criterion = ctc_loss
    logger.info('Model type: {}'.format(net.__class__.__name__))

    optimizer = optim.Adam(net.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step, gamma=args.lr_gamma) \
        if args.lr_step is not None else None

    # Geometric augs are scaled by --augs; 0 disables rotation/padding.
    train_transforms = Compose([
        Rotate(max_angle=args.augs * 7.5, p=0.5),
        Pad(max_size=args.augs / 10, p=0.1),
        Compress(),
        Blur(),
        Resize(size=(w, h)),
        ScaleToZeroOne(),
    ])
    val_transforms = Compose([
        Resize(size=(w, h)),
        ScaleToZeroOne(),
    ])
    train_dataset = RecognitionDataset(args.data_path,
                                       os.path.join(args.data_path,
                                                    'train_rec.json'),
                                       abc=abc,
                                       transforms=train_transforms)
    val_dataset = RecognitionDataset(args.data_path,
                                     None,
                                     abc=abc,
                                     transforms=val_transforms)
    # split dataset into train/val, don't try to do this at home ;)
    train_size = int(len(train_dataset) * args.val_split)
    val_dataset.image_names = train_dataset.image_names[train_size:]
    val_dataset.texts = train_dataset.texts[train_size:]
    train_dataset.image_names = train_dataset.image_names[:train_size]
    train_dataset.texts = train_dataset.texts[:train_size]

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=8,
                                  collate_fn=train_dataset.collate_fn)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=8,
                                collate_fn=val_dataset.collate_fn)
    logger.info('Length of train/val=%d/%d', len(train_dataset),
                len(val_dataset))
    logger.info('Number of batches of train/val=%d/%d', len(train_dataloader),
                len(val_dataloader))

    try:
        train(net,
              criterion,
              optimizer,
              scheduler,
              train_dataloader,
              val_dataloader,
              args=args,
              logger=logger,
              device=device)
    except KeyboardInterrupt:
        # Save a recoverable checkpoint on manual interruption.
        torch.save(
            net.state_dict(),
            os.path.join(args.output_dir, f'{args.model}_INTERRUPTED.pth'))
        logger.info('Saved interrupt')
        sys.exit(0)
Example #9
0
import cv2
import torch
from torch.utils.data import DataLoader

from dataset import FightDataset
from transform import Compose, ToTensor, Resize
from model import MyNet

# Script: build the fight-classification dataset/dataloader and iterate 20
# epochs of forward passes with an MSE criterion.
# NOTE(review): this snippet appears truncated — loss/backprop are never
# invoked inside the loop.
torch.cuda.set_device(0)
transform_ = Compose([Resize((112, 112)), ToTensor()])
# NOTE(review): keyword is spelled 'tranform' here — confirm against
# FightDataset's parameter name before renaming.
xx = FightDataset("./fight_classify", tranform=transform_)

dataloader = DataLoader(xx, batch_size=1, shuffle=True)
dev = torch.device("cuda:0")
model = MyNet().to(dev)

criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.9)

for t in range(20):
    # Forward pass: compute predicted y for every batch.
    for i_batch, sample_batched in enumerate(dataloader):
        image = sample_batched["image"]
        label = sample_batched["label"]
        y_pred = model(image)
# -*- coding: utf-8 -*-
"""
Created on Fri Mar  6 16:54:29 2020

@author: fqsfyq
"""
import os
import random
import numpy as np

import transform
from transform import Resize
from transform import Compose

# Module-level pool of augmentation operators.
# NOTE(review): `loading` is bound to the class itself, not an instance —
# unlike every other transform below; confirm this is intentional.
loading = transform.LoadImgFile
# Multi-scale resize candidates.
resize = Resize([224, 320, 416, 512])

# Noise injection.
AdditiveGaussianNoise = transform.AdditiveGaussianNoise(scale_limit=(10, 30))
AdditiveLaplaceNoise = transform.AdditiveLaplaceNoise()
AdditivePoissonNoise = transform.AdditivePoissonNoise()
# Cutout-style occlusion.
CoarseDropout = transform.CoarseDropout()
CoarseSaltAndPepper = transform.CoarseSaltAndPepper()
Cutout = transform.Cutout()
# Color jitter.
HueSat = transform.AddToHueAndSaturation()
DropChannel = transform.DropChannel()
# Geometric distortion.
elastic = transform.ElasticTransformation(sigma=9)
Rotate_Shear = transform.Rotate_Shear()
Example #11
0
def test(model, data_loader, loss, epoch, lr, max_acc, max_auc, acc_max,
         auc_max, save_recall, save_prec, save_spec, save_f1score):
    """Evaluate `model` over `data_loader` and track best checkpoints.

    Collects per-sample predictions, computes accuracy/AUC/precision/
    recall/specificity/F1, and updates the running best-by-accuracy
    (max_acc/max_auc) and best-by-AUC (auc_max/acc_max) records.

    Returns:
        (max_acc, max_auc, acc_max, auc_max, test_loss, isave,
         pred_target_dict, isave_lst, save_recall, save_prec, save_spec,
         save_f1score)
    """
    test_acc = []
    loss_lst = []

    pred_lst = []
    label_lst = []
    prob_lst = []
    isave = False
    isave_lst = False
    # name -> [prediction, label(, probabilities)] for every sample seen.
    pred_target_dict = {}

    i = -1  # keeps the summary print defined even for an empty loader
    for i, (data, target, names) in enumerate(data_loader):
        data = data.cuda(non_blocking=True)
        # Resize each volume to (sample_duration, sample_size, sample_size).
        data = Resize([opt.sample_duration, opt.sample_size,
                       opt.sample_size])(data)
        target = target.cuda(non_blocking=True)

        out = model(data).cuda()
        # "FP"/"BCE" runs score against the raw target (the original
        # duplicated these identical branches); otherwise the loss expects
        # integer class labels.
        if "FP" in opt.save_dir or "BCE" in opt.save_dir:
            cls = loss(out, target)
        else:
            cls = loss(out, target.long())

        loss_lst.append(cls.data.cpu().numpy())
        if 'FP' in opt.save_dir or 'BCE' in opt.save_dir:
            # Binary head: probability of the positive class only.
            pred = torch.sigmoid(out[:, :1])
            pred_arr = pred.data.cpu().numpy()
        else:
            # Multi-class: argmax over sigmoid scores.
            pred = torch.sigmoid(out)
            pred_arr = pred.data.cpu().numpy().argmax(axis=1)
        if opt.n_classes != 1:
            prob_arr = pred.data.cpu().numpy()
        label_arr = target.data.cpu().numpy()
        if opt.n_classes == 1:
            _acc = acc_metric(pred_arr, label_arr)
        else:
            _acc = calculate_accuracy(out, target.long())

        # BUG FIX: the inner index no longer reuses `i`, which previously
        # clobbered the batch counter reported by the summary print below.
        for j in range(pred_arr.shape[0]):
            if opt.n_classes == 1:
                pred_target_dict[names[j]] = [pred_arr[j], label_arr[j]]
            else:
                pred_target_dict[names[j]] = [
                    pred_arr[j], label_arr[j], prob_arr[j]
                ]
        pred_lst.append(pred_arr)
        label_lst.append(label_arr)
        if opt.n_classes != 1:
            prob_lst.append(prob_arr)
        test_acc.append(_acc)

    test_loss = np.mean(loss_lst)
    acc = np.mean(test_acc)
    if 'FP' in opt.save_dir:
        # FP labels/predictions are (N, 1); keep only the first column.
        label_lst = np.concatenate(label_lst, axis=0)[:, 0].tolist()
        pred_lst = np.concatenate(pred_lst, axis=0)[:, 0].tolist()
    else:
        label_lst = np.concatenate(label_lst, axis=0).tolist()
        pred_lst = np.concatenate(pred_lst, axis=0).tolist()
    if opt.n_classes != 1:
        prob_lst = np.concatenate(prob_lst, axis=0).tolist()
        auc, prec, recall, spec = multiclass_confusion_matrics(
            label_lst, pred_lst, prob_lst)
    else:
        auc, prec, recall, spec = confusion_matrics(label_lst, pred_lst)
    # Guard against 0/0 when both precision and recall are zero.
    f1_score = (2 * (prec * recall) / (prec + recall)
                if (prec + recall) else 0.0)

    if acc > max_acc:
        # New best accuracy: record metrics and request a save.
        max_acc = acc
        max_auc = auc

        save_recall = recall
        save_prec = prec
        save_spec = spec
        save_f1score = f1_score

        isave = True
        isave_lst = True

    if auc > auc_max:
        # New best AUC: record metrics and request a save.
        auc_max = auc
        acc_max = acc

        save_recall = recall
        save_prec = prec
        save_spec = spec
        save_f1score = f1_score

        isave = True

    print(
        "Testing: Epoch %d:%dth batch, learning rate %2.6f loss %2.4f, acc %2.4f, auc %2.4f,precision %2.4f,recall %2.4f,specificity %2.4f!"
        % (epoch, i, lr, test_loss, acc, auc, prec, recall, spec))
    return max_acc, max_auc, acc_max, auc_max, test_loss, isave, pred_target_dict, isave_lst, save_recall, save_prec, save_spec, save_f1score
Example #12
0
def main():
    """Parse CLI args and train the UNet segmentation model.

    Builds train/val datasets from one annotation file (split by
    --val_split), trains with a weighted BCE + dice criterion, and saves
    a checkpoint when training is interrupted with Ctrl-C.
    """
    parser = ArgumentParser()
    parser.add_argument('-d',
                        '--data_path',
                        dest='data_path',
                        type=str,
                        default=None,
                        help='path to the data')
    parser.add_argument('-e',
                        '--epochs',
                        dest='epochs',
                        default=20,
                        type=int,
                        help='number of epochs')
    parser.add_argument('-b',
                        '--batch_size',
                        dest='batch_size',
                        default=40,
                        type=int,
                        help='batch size')
    parser.add_argument('-s',
                        '--image_size',
                        dest='image_size',
                        default=256,
                        type=int,
                        help='input image size')
    parser.add_argument('-lr',
                        '--learning_rate',
                        dest='lr',
                        default=0.0001,
                        type=float,
                        help='learning rate')
    parser.add_argument('-wd',
                        '--weight_decay',
                        dest='weight_decay',
                        default=5e-4,
                        type=float,
                        help='weight decay')
    parser.add_argument('-lrs',
                        '--learning_rate_step',
                        dest='lr_step',
                        default=10,
                        type=int,
                        help='learning rate step')
    parser.add_argument('-lrg',
                        '--learning_rate_gamma',
                        dest='lr_gamma',
                        default=0.5,
                        type=float,
                        help='learning rate gamma')
    parser.add_argument('-m',
                        '--model',
                        dest='model',
                        default='unet',
                        choices=('unet', ))
    parser.add_argument('-w',
                        '--weight_bce',
                        default=0.5,
                        type=float,
                        help='weight BCE loss')
    parser.add_argument('-l',
                        '--load',
                        dest='load',
                        default=False,
                        help='load file model')
    parser.add_argument('-v',
                        '--val_split',
                        dest='val_split',
                        default=0.8,
                        help='train/val split')
    parser.add_argument('-o',
                        '--output_dir',
                        dest='output_dir',
                        default='/tmp/logs/',
                        help='dir to save log and models')
    args = parser.parse_args()
    #
    os.makedirs(args.output_dir, exist_ok=True)
    logger = get_logger(os.path.join(args.output_dir, 'train.log'))
    logger.info('Start training with params:')
    for arg, value in sorted(vars(args).items()):
        logger.info("Argument %s: %r", arg, value)
    #
    net = UNet(
    )  # TODO: to use move novel arch or/and more lightweight blocks (mobilenet) to enlarge the batch_size
    # TODO: img_size=256 is rather mediocre, try to optimize network for at least 512
    logger.info('Model type: {}'.format(net.__class__.__name__))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if args.load:
        net.load_state_dict(torch.load(args.load))
    net.to(device)
    # net = nn.DataParallel(net)

    optimizer = optim.Adam(net.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    # TODO: loss experimentation, fight class imbalance, there're many ways you can tackle this challenge
    # NOTE(review): this lambda returns a (bce, dice) TUPLE rather than a
    # summed scalar — confirm train() combines the pair itself.
    criterion = lambda x, y: (args.weight_bce * nn.BCELoss()(x, y),
                              (1. - args.weight_bce) * dice_loss(x, y))
    # TODO: you can always try on plateau scheduler as a default option
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step, gamma=args.lr_gamma) \
        if args.lr_step > 0 else None

    # dataset
    # TODO: to work on transformations a lot, look at albumentations package for inspiration
    train_transforms = Compose([
        Crop(min_size=1 - 1 / 3., min_ratio=1.0, max_ratio=1.0, p=0.5),
        Flip(p=0.05),
        Pad(max_size=0.6, p=0.25),
        Resize(size=(args.image_size, args.image_size), keep_aspect=True)
    ])
    # TODO: don't forget to work class imbalance and data cleansing
    val_transforms = Resize(size=(args.image_size, args.image_size))

    train_dataset = DetectionDataset(args.data_path,
                                     os.path.join(args.data_path,
                                                  'train_mask.json'),
                                     transforms=train_transforms)
    val_dataset = DetectionDataset(args.data_path,
                                   None,
                                   transforms=val_transforms)

    # split dataset into train/val, don't try to do this at home ;)
    train_size = int(len(train_dataset) * args.val_split)
    val_dataset.image_names = train_dataset.image_names[train_size:]
    val_dataset.mask_names = train_dataset.mask_names[train_size:]
    train_dataset.image_names = train_dataset.image_names[:train_size]
    train_dataset.mask_names = train_dataset.mask_names[:train_size]

    # TODO: always work with the data: cleaning, sampling
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=8,
                                  shuffle=True,
                                  drop_last=True)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                num_workers=4,
                                shuffle=False,
                                drop_last=False)
    logger.info('Length of train/val=%d/%d', len(train_dataset),
                len(val_dataset))
    logger.info('Number of batches of train/val=%d/%d', len(train_dataloader),
                len(val_dataloader))

    try:
        train(net,
              optimizer,
              criterion,
              scheduler,
              train_dataloader,
              val_dataloader,
              logger=logger,
              args=args,
              device=device)
    except KeyboardInterrupt:
        # Save a recoverable checkpoint on manual interruption.
        torch.save(net.state_dict(),
                   os.path.join(args.output_dir, 'INTERRUPTED.pth'))
        logger.info('Saved interrupt')
        sys.exit(0)
Example #13
0
def main():
    """Parse CLI args and train the CRNN OCR model with CTC loss.

    Splits train.json marks into train/val by --val_split, trains with
    Adam + ReduceLROnPlateau, and saves a checkpoint on Ctrl-C.
    """

    parser = ArgumentParser()
    parser.add_argument('-d',
                        '--data_path',
                        dest='data_path',
                        type=str,
                        default='../../data/',
                        help='path to the data')
    parser.add_argument('--epochs',
                        '-e',
                        dest='epochs',
                        type=int,
                        help='number of train epochs',
                        default=2)
    parser.add_argument('--batch_size',
                        '-b',
                        dest='batch_size',
                        type=int,
                        help='batch size',
                        default=16)
    parser.add_argument('--load',
                        '-l',
                        dest='load',
                        type=str,
                        help='pretrained weights',
                        default=None)
    parser.add_argument('-v',
                        '--val_split',
                        dest='val_split',
                        default=0.8,
                        type=float,
                        help='train/val split')
    parser.add_argument('--augs',
                        '-a',
                        dest='augs',
                        type=float,
                        help='degree of geometric augs',
                        default=0)

    args = parser.parse_args()
    OCR_MODEL_PATH = '../pretrained/ocr.pt'

    # Split annotation marks into train/val by --val_split.
    all_marks = load_json(os.path.join(args.data_path, 'train.json'))
    test_start = int(args.val_split * len(all_marks))
    train_marks = all_marks[:test_start]
    val_marks = all_marks[test_start:]

    # Fixed model input size (width, height).
    w, h = (320, 64)
    train_transforms = transforms.Compose([
        Resize(size=(w, h)),
        transforms.ToTensor()
    ])
    val_transforms = transforms.Compose(
        [Resize(size=(w, h)), transforms.ToTensor()])
    alphabet = abc

    train_dataset = OCRDataset(marks=train_marks,
                               img_folder=args.data_path,
                               alphabet=alphabet,
                               transforms=train_transforms)
    val_dataset = OCRDataset(marks=val_marks,
                             img_folder=args.data_path,
                             alphabet=alphabet,
                             transforms=val_transforms)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  drop_last=True,
                                  num_workers=0,
                                  collate_fn=collate_fn_ocr,
                                  timeout=0,
                                  shuffle=True)

    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                drop_last=False,
                                num_workers=0,
                                collate_fn=collate_fn_ocr,
                                timeout=0)

    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    model = CRNN(alphabet)
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=3e-4,
                                 amsgrad=True,
                                 weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=10,
                                                           factor=0.5,
                                                           verbose=True)
    criterion = F.ctc_loss

    try:
        train(model, criterion, optimizer, scheduler, train_dataloader,
              val_dataloader, OCR_MODEL_PATH, args.epochs, device)
    except KeyboardInterrupt:
        # Save a recoverable checkpoint on manual interruption.
        torch.save(model.state_dict(), OCR_MODEL_PATH + 'INTERRUPTED_')
        sys.exit(0)