Example 1
def __init__(self, file_list_path, img_path, txt_path):
    # Assumes torch, torchvision.transforms (as transforms), and the
    # project-local BarcodeDataset are available in the enclosing module.
    transform = transforms.Compose([
        transforms.Resize([285, 285]),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    # Fixed-order loader (shuffle=False) so evaluation runs are repeatable.
    self._loader = torch.utils.data.DataLoader(
        BarcodeDataset(file_list_path, img_path, txt_path, transform),
        batch_size=128,
        shuffle=False)
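This constructor looks like part of an evaluator-style wrapper around a fixed validation set. A minimal sketch of a companion evaluate method follows, assuming (not confirmed by this snippet) that the model returns one logits tensor per digit, as in Example 2; the method name and accuracy definition are illustrative.

def evaluate(self, model):
    # Hypothetical companion method (not from the source): run the model
    # over the fixed-order loader in inference mode and count samples
    # whose every digit is predicted correctly.
    model.eval()
    num_correct, num_total = 0, 0
    with torch.no_grad():
        for images, digits_labels in self._loader:
            images = images.cuda()
            digits_logits = model(images)  # assumed: a tuple of per-digit logits
            predictions = [logits.max(1)[1].cpu() for logits in digits_logits]
            matches = torch.stack([
                prediction.eq(label)
                for prediction, label in zip(predictions, digits_labels)
            ]).all(0)
            num_correct += matches.sum().item()
            num_total += images.size(0)
    return num_correct / num_total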
Example 2
import os
import time
from datetime import datetime

import numpy as np
import torch
import torch.optim as optim
import torch.utils.data
from torch.optim.lr_scheduler import StepLR
from torchvision import transforms

# Model, BarcodeDataset, Evaluator and _loss are project-local helpers,
# assumed to be importable from elsewhere in this codebase.


def _train(train_img_path, train_txt_path, val_img_path, val_txt_path,
           path_to_log_dir, path_to_restore_checkpoint_file, training_options):
    batch_size = training_options['batch_size']
    initial_learning_rate = training_options['learning_rate']
    initial_patience = training_options['patience']
    num_steps_to_show_loss = 100
    num_steps_to_check = 1000

    step = 0
    patience = initial_patience
    best_accuracy = 0.0
    duration = 0.0

    model = Model(21)
    model.cuda()

    transform = transforms.Compose([
        transforms.Resize([285, 285]),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    train_loader = torch.utils.data.DataLoader(
        BarcodeDataset(train_img_path, train_txt_path, transform),
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True)
    evaluator = Evaluator(val_img_path, val_txt_path)
    optimizer = optim.SGD(model.parameters(),
                          lr=initial_learning_rate,
                          momentum=0.9,
                          weight_decay=0.0005)
    scheduler = StepLR(optimizer,
                       step_size=training_options['decay_steps'],
                       gamma=training_options['decay_rate'])

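    # Optionally resume from a checkpoint; realigning scheduler.last_epoch
    # with the restored step keeps the decay schedule consistent.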
    if path_to_restore_checkpoint_file is not None:
        assert os.path.isfile(path_to_restore_checkpoint_file), \
            '%s not found' % path_to_restore_checkpoint_file
        step = model.restore(path_to_restore_checkpoint_file)
        scheduler.last_epoch = step
        print('Model restored from file: %s' % path_to_restore_checkpoint_file)

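    # Persist the loss history so the curve survives checkpoint restarts.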
    path_to_losses_npy_file = os.path.join(path_to_log_dir, 'losses.npy')
    if os.path.isfile(path_to_losses_npy_file):
        losses = np.load(path_to_losses_npy_file)
    else:
        losses = np.empty([0], dtype=np.float32)

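    # Loop over the training set indefinitely; the patience-based early
    # stopping below is the only exit.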
    while True:
        for batch_idx, (images, digits_labels) in enumerate(train_loader):
            start_time = time.time()
            images, digits_labels = images.cuda(), [
                digit_label.cuda() for digit_label in digits_labels
            ]
            model.train()  # make sure dropout/batch-norm are in training mode
            (digit2_logits, digit3_logits, digit4_logits, digit5_logits,
             digit6_logits, digit7_logits, digit8_logits, digit9_logits,
             digit10_logits, digit11_logits, digit12_logits,
             digit13_logits) = model(images)
            loss = _loss(digit2_logits, digit3_logits, digit4_logits,
                         digit5_logits, digit6_logits, digit7_logits,
                         digit8_logits, digit9_logits, digit10_logits,
                         digit11_logits, digit12_logits, digit13_logits,
                         digits_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            step += 1
            duration += time.time() - start_time

            if step % num_steps_to_show_loss == 0:
                examples_per_sec = batch_size * num_steps_to_show_loss / duration
                duration = 0.0
                print(
                    '=> %s: step %d, loss = %f, learning_rate = %f (%.1f examples/sec)'
                    % (datetime.now(), step, loss.item(),
                       scheduler.get_last_lr()[0], examples_per_sec))

            if step % num_steps_to_check != 0:
                continue

            losses = np.append(losses, loss.item())
            np.save(path_to_losses_npy_file, losses)

            print('=> Evaluating on validation dataset...')
            accuracy = evaluator.evaluate(model)
            print('==> accuracy = %f, best accuracy %f' %
                  (accuracy, best_accuracy))

            if accuracy > best_accuracy:
                path_to_checkpoint_file = model.store(path_to_log_dir,
                                                      step=step)
                print('=> Model saved to file: %s' % path_to_checkpoint_file)
                patience = initial_patience
                best_accuracy = accuracy
            else:
                patience -= 1

            print('=> patience = %d' % patience)
            if patience == 0:
                return
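For reference, _train reads five keys from training_options: batch_size, learning_rate, patience, decay_steps, and decay_rate. A hypothetical invocation might look like the following; the paths and hyperparameter values are illustrative, not from the source.

training_options = {
    'batch_size': 32,        # minibatch size for the train loader
    'learning_rate': 1e-2,   # initial SGD learning rate
    'patience': 100,         # failed evaluations tolerated before stopping
    'decay_steps': 10000,    # StepLR step_size
    'decay_rate': 0.9        # StepLR gamma
}
_train('data/train/images', 'data/train/annotations.txt',
       'data/val/images', 'data/val/annotations.txt',
       path_to_log_dir='logs',
       path_to_restore_checkpoint_file=None,
       training_options=training_options)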