Example #1
def main(args):
    if args.rnn:
        transform = transforms.Compose([
            Normalize([0.3956, 0.5763, 0.5616], [0.1535, 0.1278, 0.1299]),
            Resize((204, 32)),
            ToTensorRGBFlatten()
        ])
    else:
        transform = transforms.Compose([
            Normalize([0.3956, 0.5763, 0.5616], [0.1535, 0.1278, 0.1299]),
            Resize((204, 32)),
            ToTensor()
        ])
    train_set = digitsDataset(args.train_root_path, transform=transform)
    val_set = digitsDataset(args.val_root_path, transform=transform)
    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True)
    val_loader = DataLoader(val_set,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=4,
                            pin_memory=True)
    # trainer parameters
    params = EasyDict()
    params.max_epoch = args.max_epoch
    params.print_freq = args.print_freq
    params.validate_interval = args.validate_interval
    params.save_interval = args.save_interval
    params.expr_path = args.expr_path
    params.rnn = args.rnn
    device = torch.device("cuda")

    # train engine
    ntoken = len(args.alphabet) + 1  # +1 for the CTC blank symbol
    if args.rnn:
        input_dim = 96
        model = LSTMFeatures(input_dim, args.batch_size, ntoken)
    else:
        model = DenseNetFeature(num_classes=ntoken)
    model = model.to(device)
    criterion = CTCLoss()
    criterion = criterion.to(device)
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    converter = LabelConverter(args.alphabet)

    solver = SolverWrapper(params)
    # train
    solver.train(train_loader, val_loader, model, criterion, optimizer, device,
                 converter)
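
The CTCLoss criterion and the LabelConverter used above are not shown in this example. For orientation, a minimal sketch of how a CTC criterion consumes the model output, written against torch.nn.CTCLoss; the shapes, the blank index of 0 and the fixed label length are illustrative assumptions, not values taken from this code.

import torch
import torch.nn as nn

# Sketch only: torch.nn.CTCLoss expects log-probabilities of shape (T, N, C)
# and integer targets that contain no blanks. Blank index 0 and the sizes
# below are assumptions, not taken from the example above.
ctc = nn.CTCLoss(blank=0, zero_infinity=True)

T, N, C = 51, 4, 11                                # time steps, batch, ntoken
log_probs = torch.randn(T, N, C).log_softmax(2)    # stand-in for model output
targets = torch.randint(1, C, (N, 7), dtype=torch.long)   # encoded labels
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), 7, dtype=torch.long)

loss = ctc(log_probs, targets, input_lengths, target_lengths)
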
Example #2
def main():
    # Set parameters of the trainer
    global args, device
    args = parse_args()

    print('=' * 60)
    print(args)
    print('=' * 60)
    random.seed(args.manual_seed)
    np.random.seed(args.manual_seed)
    torch.manual_seed(args.manual_seed)

    if torch.cuda.is_available() and not args.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )
    if args.cuda and torch.cuda.is_available():
        device = torch.device('cuda')
        #cudnn.benchmark = True
    else:
        device = torch.device('cpu')

    # load alphabet from file (expected to hold one character per line)
    if os.path.isfile(args.alphabet):
        alphabet = ''
        with open(args.alphabet, mode='rb') as f:
            for line in f.readlines():
                alphabet += line.decode('utf-8')[0]
        args.alphabet = alphabet

    converter = utils.CTCLabelConverter(args.alphabet, ignore_case=False)

    # data loader
    image_size = (args.image_h, args.image_w)
    collater = DatasetCollater(image_size, keep_ratio=args.keep_ratio)
    train_dataset = Dataset(mode='train',
                            data_root=args.data_root,
                            transform=None)
    #sampler = RandomSequentialSampler(train_dataset, args.batch_size)
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   collate_fn=collater,
                                   shuffle=True,
                                   num_workers=args.workers)

    val_dataset = Dataset(mode='val', data_root=args.data_root, transform=None)
    val_loader = data.DataLoader(val_dataset,
                                 batch_size=args.batch_size,
                                 collate_fn=collater,
                                 shuffle=True,
                                 num_workers=args.workers)

    # network
    num_classes = len(args.alphabet) + 1  # +1 for the CTC blank symbol
    num_channels = 1
    if args.arch == 'crnn':
        model = CRNN(args.image_h, num_channels, num_classes, args.num_hidden)
    elif args.arch == 'densenet':
        model = DenseNet(
            num_channels=num_channels,
            num_classes=num_classes,
            growth_rate=12,
            block_config=(3, 6, 9),  #(3,6,12,16),
            compression=0.5,
            num_init_features=64,
            bn_size=4,
            rnn=args.rnn,
            num_hidden=args.num_hidden,
            drop_rate=0,
            small_inputs=True,
            efficient=False)
    else:
        raise ValueError('unknown architecture {}'.format(args.arch))
    model = model.to(device)
    summary(model, torch.zeros((2, 1, 32, 650)).to(device))
    #print('='*60)
    #print(model)
    #print('='*60)

    # loss
    criterion = CTCLoss()
    criterion = criterion.to(device)

    # setup optimizer
    if args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=0.9,
                              weight_decay=args.weight_decay)
    elif args.optimizer == 'rmsprop':
        optimizer = optim.RMSprop(model.parameters(),
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               betas=(0.9, 0.999),
                               weight_decay=args.weight_decay)
    elif args.optimizer == 'adadelta':
        optimizer = optim.Adadelta(model.parameters(),
                                   weight_decay=args.weight_decay)
    else:
        raise ValueError('unknown optimizer {}'.format(args.optimizer))
    print('=' * 60)
    print(optimizer)
    print('=' * 60)

    # Define learning rate decay schedule
    global scheduler
    #exp_decay = math.exp(-0.1)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer,
                                                 gamma=args.decay_rate)
    #step_size = 10000
    #gamma_decay = 0.8
    #scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma_decay)
    #scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=gamma_decay)

    # initialize model
    if args.pretrained and os.path.isfile(args.pretrained):
        print(">> Using pre-trained model '{}'".format(
            os.path.basename(args.pretrained)))
        state_dict = torch.load(args.pretrained)
        model.load_state_dict(state_dict)
        print("loading pretrained model done.")

    global is_best, best_accuracy
    is_best = False
    best_accuracy = 0.0
    start_epoch = 0
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            # load checkpoint weights and update model and optimizer
            print(">> Loading checkpoint:\n>> '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_accuracy = checkpoint['best_accuracy']
            print(">>>> loaded checkpoint:\n>>>> '{}' (epoch {})".format(
                args.resume, start_epoch))
            model.load_state_dict(checkpoint['state_dict'])
            #optimizer.load_state_dict(checkpoint['optimizer'])
            # important not to forget scheduler updating
            #scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.decay_rate, last_epoch=start_epoch - 1)
        else:
            print(">> No checkpoint found at '{}'".format(args.resume))

    # Create export dir if it doesn't exist
    checkpoint = "{}".format(args.arch)
    checkpoint += "_{}".format(args.optimizer)
    checkpoint += "_lr_{}".format(args.lr)
    checkpoint += "_decay_rate_{}".format(args.decay_rate)
    checkpoint += "_bsize_{}".format(args.batch_size)
    checkpoint += "_height_{}".format(args.image_h)
    checkpoint += "_keep_ratio" if args.keep_ratio else "_width_{}".format(
        image_size[1])

    args.checkpoint = os.path.join(args.checkpoint, checkpoint)
    if not os.path.exists(args.checkpoint):
        os.makedirs(args.checkpoint)

    print('start training...')
    for epoch in range(start_epoch, args.max_epoch):
        # Train for one epoch on the train set
        _ = train(train_loader, val_loader, model, criterion, optimizer, epoch,
                  converter)

        # Adjust the learning rate once per epoch; since PyTorch 1.1 the
        # scheduler should step only after the optimizer has stepped
        scheduler.step()
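
Both examples above depend on a label converter (LabelConverter / utils.CTCLabelConverter) that maps characters to integer class indices and reserves index 0 for the CTC blank. A hypothetical minimal converter illustrating that contract; this is a sketch, not the project's implementation.

import torch

class SimpleCTCLabelConverter:
    """Hypothetical stand-in for utils.CTCLabelConverter (sketch only)."""

    def __init__(self, alphabet):
        # index 0 is reserved for the CTC blank, so characters start at 1
        self.char_to_idx = {ch: i + 1 for i, ch in enumerate(alphabet)}
        self.idx_to_char = {i + 1: ch for i, ch in enumerate(alphabet)}

    def encode(self, texts):
        # flat target tensor plus one length per sample, as CTC expects
        lengths = torch.IntTensor([len(t) for t in texts])
        targets = torch.IntTensor([self.char_to_idx[c] for t in texts for c in t])
        return targets, lengths

    def decode(self, indices):
        # greedy CTC decoding: collapse repeats, then drop blanks
        chars, prev = [], 0
        for idx in indices:
            idx = int(idx)
            if idx != 0 and idx != prev:
                chars.append(self.idx_to_char[idx])
            prev = idx
        return ''.join(chars)
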
Example #3
# pre-allocated buffers for the encoded label text and per-sample label lengths
text = torch.IntTensor(opt.batchSize * 5)
length = torch.IntTensor(opt.batchSize)

if opt.cuda != '-1':
    str_ids = opt.cuda.split(",")
    gpu_ids = []
    for str_id in str_ids:
        gpu_id = int(str_id)
        if gpu_id >= 0:
            gpu_ids.append(gpu_id)
    if len(gpu_ids) > 0:
        torch.cuda.set_device(gpu_ids[0])
        crnn.to(gpu_ids[0])
        crnn = torch.nn.DataParallel(crnn, device_ids=gpu_ids)
        image = image.to(gpu_ids[0])
        criterion = criterion.to(gpu_ids[0])
if opt.pretrained > -1:
    model_path = '{0}/netCRNN_{1}.pth'.format(opt.expr_dir, opt.pretrained)
    print('loading pretrained model from %s' % model_path)
    # crnn.load_state_dict(torch.load(opt.pretrained))
    crnn.load_state_dict(torch.load(model_path))
print(crnn)

# Variable is a no-op since PyTorch 0.4; kept as in the original script
image = Variable(image)
text = Variable(text)
length = Variable(length)

# loss averager
loss_avg = utils.averager()

# setup optimizer
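
This example (and Example #4 below) keeps a running loss with utils.averager(), whose source is not shown here. A hypothetical sketch of what such a running-mean helper usually looks like:

class Averager:
    """Hypothetical sketch of a loss averager like utils.averager()."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.total = 0.0
        self.count = 0

    def add(self, value):
        # accept either a 0-d tensor or a plain float
        v = value.item() if hasattr(value, 'item') else float(value)
        self.total += v
        self.count += 1

    def val(self):
        return self.total / self.count if self.count else 0.0
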
Example #4
crnn.apply(weights_init)
if opt.pretrained != '':
    print('loading pretrained model from %s' % opt.pretrained)
    crnn.load_state_dict(torch.load(opt.pretrained))
print(crnn)

# pre-allocated buffers for input images, encoded label text and label lengths
image = torch.FloatTensor(opt.batchSize, 3, opt.imgH, opt.imgW)
text = torch.IntTensor(opt.batchSize * 5)
length = torch.IntTensor(opt.batchSize)


if torch.cuda.device_count() > 1:
    crnn = nn.DataParallel(crnn)
crnn.to(device)
image = image.to(device)
criterion = criterion.to(device)

# text = text.to(device)
# length = length.to(device)

image = Variable(image)
text = Variable(text)
length = Variable(length)

# loss averager
loss_avg = utils.averager()
epoch_loss_avg = utils.averager()

# setup optimizer
if opt.adam:
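
The listing cuts off at the optimizer switch. For orientation only, a self-contained sketch of a flag-driven optimizer choice of this shape; the helper, its parameter names and defaults are assumptions, not the continuation of the original script.

import torch.optim as optim

def build_optimizer(model, use_adam, use_adadelta, lr=1e-3, beta1=0.5):
    # Hypothetical helper illustrating the flag-driven choice started above;
    # names and defaults are assumptions, not the original code.
    if use_adam:
        return optim.Adam(model.parameters(), lr=lr, betas=(beta1, 0.999))
    if use_adadelta:
        return optim.Adadelta(model.parameters())
    return optim.RMSprop(model.parameters(), lr=lr)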