Example #1
import torch

# LSTMClassifier, FocalLoss, ICDARDataset, train, validate and
# save_checkpoint are assumed to be defined elsewhere in the project.


def main():
    """
    Training and validation.
    """
    global epochs_since_improvement, start_epoch, label_map, best_F1, epoch, checkpoint

    # Initialize model or load checkpoint
    if checkpoint is None:
        model = LSTMClassifier()
        # Separate bias parameters from the rest so that the commented-out SGD
        # optimizer below can use twice the learning rate for biases, as in the
        # original Caffe repo. The Adam optimizer actually used here treats all
        # parameters identically.
        biases = list()
        not_biases = list()
        for param_name, param in model.named_parameters():
            if param.requires_grad:
                if param_name.endswith('.bias'):
                    biases.append(param)
                else:
                    not_biases.append(param)
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=lr,
                                     betas=(0.9, 0.99))
        # optimizer = torch.optim.SGD(params=[{'params': biases, 'lr': 2 * lr}, {'params': not_biases}],
        #                             lr=lr, momentum=momentum, weight_decay=weight_decay)

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch']
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_F1 = checkpoint['best_F1']
        print('\nLoaded checkpoint from epoch %d. Best F1 so far is %.3f.\n' %
              (start_epoch, best_F1))
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']

    # Move to default device
    model = model.to(device)
    print(model)

    # criterion = torch.nn.CrossEntropyLoss()
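    # Focal loss down-weights well-classified (easy) examples, which helps
    # when the classes are heavily imbalanced.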
    criterion = FocalLoss()

    # Custom dataloaders
    train_dataset = ICDARDataset(data_folder, split='train')
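    # Note: the dataset's 'test' split is used as the validation set here.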
    val_dataset = ICDARDataset(data_folder, split='test')
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=train_dataset.collate_fn,
        num_workers=workers,
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             collate_fn=val_dataset.collate_fn,
                                             num_workers=workers,
                                             pin_memory=True)
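    # Shuffling the validation loader does not change the computed metrics;
    # it is kept here only for parity with the training loader.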

    # Epochs
    for epoch in range(start_epoch, epochs):
        # One epoch's training
        train_loss = train(train_loader=train_loader,
                           model=model,
                           criterion=criterion,
                           optimizer=optimizer,
                           epoch=epoch)

        # One epoch's validation
        val_loss, accuracy, F1 = validate(val_loader=val_loader,
                                          model=model,
                                          criterion=criterion)

        # Previous criterion based on training loss (kept for reference)
        # is_best = train_loss < best_loss
        # best_loss = min(train_loss, best_loss)

        # Did the validation F1 score improve?
        is_best = F1 > best_F1
        best_F1 = max(F1, best_F1)

        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))

        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, optimizer,
                        val_loss, best_F1, is_best)

        with open('log.txt', 'a+') as f:
            f.write('epoch:' + str(epoch) + '  train loss:' + str(train_loss) +
                    '  val loss:' + str(val_loss) + '  accuracy:' +
                    str(accuracy) + '\n')
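
main() relies on module-level configuration and on the globals it declares at the top. A minimal sketch of how they might be set up is shown below; the names are taken from the listing above, but the values are only illustrative and the checkpoint bookkeeping starts from scratch.

import torch

# Illustrative hyper-parameters (assumed values, adjust to your setup)
data_folder = './data'       # folder containing the ICDAR data
batch_size = 32
workers = 4
lr = 1e-3
momentum = 0.9               # only used by the commented-out SGD optimizer
weight_decay = 5e-4          # only used by the commented-out SGD optimizer
epochs = 200
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Checkpoint bookkeeping expected by main()
checkpoint = None            # path to a saved checkpoint, or None to train from scratch
start_epoch = 0
epochs_since_improvement = 0
best_F1 = 0.
label_map = None             # declared global in main() but not used in this listing

if __name__ == '__main__':
    main()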