def validate(data_loader, model, best_score, global_step, cfg):
    """
    Run one full evaluation pass over the validation set and checkpoint the
    model whenever the accuracy improves on the best seen so far.

    :param data_loader: DataLoader yielding (images, labels) batches.
    :param model: Model wrapped in torch.nn.DataParallel (the code accesses
        ``model.module.probability``).
    :param best_score: Best validation accuracy observed so far.
    :param global_step: Current training step, stored with the checkpoint.
    :param cfg: Config object providing ``gpu``, ``name`` and ``ckpts_dir``.
    :return: Updated best accuracy (max of ``best_score`` and this run's
        accuracy).
    """
    model.eval()
    gts, predictions = [], []

    log.info("Validation started...")
    for data in data_loader:
        imgs, labels = data
        imgs = util.to_device(imgs, gpu=cfg.gpu)

        # Inference only — no gradients needed.
        with torch.no_grad():
            logits = model(imgs)
            probs = model.module.probability(logits)
            preds = torch.argmax(probs, dim=1).cpu().numpy()

        labels = labels.cpu().detach().numpy()

        predictions.extend(preds)
        gts.extend(labels)

    predictions = np.array(predictions, dtype=np.int32)
    gts = np.array(gts, dtype=np.int32)
    acc, f1, prec, rec = util.clf_metrics(predictions=predictions,
                                          targets=gts,
                                          average="macro")
    report = classification_report(gts, predictions, output_dict=True)

    log.info("VALIDATION | Accuracy {:.4f} | F1 {:.4f} | Precision {:.4f} | "
             "Recall {:.4f}".format(acc, f1, prec, rec))

    if acc > best_score:
        # Use the cfg parameter instead of the module-level `config` global
        # (same object when called from main, but the parameter keeps this
        # function self-contained).
        save_config = {
            'name': cfg.name,
            'save_dir': cfg.ckpts_dir,
            'global_step': global_step,
            'clf_report': report
        }
        save_model(model=model, config=save_config)
        best_score = acc
    log.info("Validation end")

    model.train()
    return best_score
# Example #2 (scraped page separator; original score: 0)
def validate(data_loader, model, best_score, global_step, cfg):
    """
    Run one full evaluation pass over the validation set and checkpoint the
    model whenever the macro F1 score improves on the best seen so far.

    :param data_loader: DataLoader yielding (images, labels) batches.
    :param model: Model wrapped in torch.nn.DataParallel (the code accesses
        ``model.module.probability``).
    :param best_score: Best validation macro F1 observed so far.
    :param global_step: Current training step, stored with the checkpoint.
    :param cfg: Config object providing ``gpu``, ``name`` and ``ckpts_dir``.
    :return: Updated best F1 (max of ``best_score`` and this run's F1).
    """
    model.eval()
    gts, predictions = [], []

    log.info("Validation started...")
    for data in data_loader:
        imgs, labels = data
        imgs = util.to_device(imgs, gpu=cfg.gpu)

        # Inference only — no gradients needed.
        with torch.no_grad():
            logits = model(imgs)
            probs = model.module.probability(logits)
            preds = torch.argmax(probs, dim=1).cpu().numpy()

        labels = labels.cpu().detach().numpy()

        predictions.extend(preds)
        gts.extend(labels)

    predictions = np.array(predictions, dtype=np.int32)
    gts = np.array(gts, dtype=np.int32)
    acc, f1, prec, rec = util.clf_metrics(predictions=predictions,
                                          targets=gts,
                                          average="macro")
    report = classification_report(gts, predictions, output_dict=True)

    log.info("VALIDATION | Accuracy {:.4f} | F1 {:.4f} | Precision {:.4f} | "
             "Recall {:.4f}".format(acc, f1, prec, rec))

    if f1 > best_score:
        # Use the cfg parameter instead of the module-level `config` global
        # (same object when called from main, but the parameter keeps this
        # function self-contained).
        save_config = {
            'name': cfg.name,
            'save_dir': cfg.ckpts_dir,
            'global_step': global_step,
            'clf_report': report
        }
        save_model(model=model, config=save_config)
        best_score = f1
    log.info("Validation end")

    model.train()
    return best_score
# Example #3 (scraped page separator; original score: 0)
def main():
    """
    Train COVIDNext50 on the COVIDx dataset, logging training metrics every
    ``config.log_steps`` steps and validating (with checkpointing on
    improvement) every ``config.eval_steps`` steps.
    """
    if config.gpu and not torch.cuda.is_available():
        raise ValueError("GPU not supported or enabled on this system.")
    use_gpu = config.gpu

    log.info("Loading train dataset")
    train_dataset = COVIDxFolder(
        config.train_imgs, config.train_labels,
        transforms.train_transforms(config.width, config.height))
    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              shuffle=True,
                              drop_last=True,
                              num_workers=config.n_threads,
                              pin_memory=use_gpu)
    log.info("Number of training examples {}".format(len(train_dataset)))

    log.info("Loading val dataset")
    val_dataset = COVIDxFolder(
        config.val_imgs, config.val_labels,
        transforms.val_transforms(config.width, config.height))
    val_loader = DataLoader(val_dataset,
                            batch_size=config.batch_size,
                            shuffle=False,
                            num_workers=config.n_threads,
                            pin_memory=use_gpu)
    log.info("Number of validation examples {}".format(len(val_dataset)))

    if config.weights:
        # NOTE(review): the actual torch.load call is commented out, so
        # `state` is always None here and the log line below is misleading —
        # pretrained weights are never actually restored. Presumably the load
        # was disabled temporarily; confirm and either restore it or drop the
        # log message.
        #state = torch.load(config.weights)
        state = None
        log.info("Loaded model weights from: {}".format(config.weights))
    else:
        state = None

    state_dict = state["state_dict"] if state else None
    model = architecture.COVIDNext50(n_classes=config.n_classes)
    if state_dict:
        model = util.load_model_weights(model=model, state_dict=state_dict)

    if use_gpu:
        model.cuda()
        # DataParallel wrapping is why validate() accesses model.module.
        model = torch.nn.DataParallel(model)
    optim_layers = filter(lambda p: p.requires_grad, model.parameters())

    # optimizer and lr scheduler
    optimizer = Adam(optim_layers,
                     lr=config.lr,
                     weight_decay=config.weight_decay)
    # mode='max' because the scheduler is stepped with a score to maximize.
    scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                  factor=config.lr_reduce_factor,
                                  patience=config.lr_reduce_patience,
                                  mode='max',
                                  min_lr=1e-7)

    # Load the last global_step from the checkpoint if existing
    global_step = 0 if state is None else state['global_step'] + 1

    # NOTE(review): class_weights is computed but never used — it looks like
    # it was meant to be passed as CrossEntropyLoss(weight=class_weights).
    # Confirm intent before wiring it in, as that changes the loss values.
    class_weights = util.to_device(torch.FloatTensor(config.loss_weights),
                                   gpu=use_gpu)
    loss_fn = CrossEntropyLoss()

    # Reset the best metric score
    best_score = -1
    for epoch in range(config.epochs):
        log.info("Started epoch {}/{}".format(epoch + 1, config.epochs))
        for data in train_loader:
            imgs, labels = data
            imgs = util.to_device(imgs, gpu=use_gpu)
            labels = util.to_device(labels, gpu=use_gpu)

            # Standard step: forward, loss, backward, update.
            logits = model(imgs)
            loss = loss_fn(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if global_step % config.log_steps == 0 and global_step > 0:
                probs = model.module.probability(logits)
                preds = torch.argmax(probs, dim=1).detach().cpu().numpy()
                labels = labels.cpu().detach().numpy()
                acc, f1, _, _ = util.clf_metrics(preds, labels)
                lr = util.get_learning_rate(optimizer)

                log.info("Step {} | TRAINING batch: Loss {:.4f} | F1 {:.4f} | "
                         "Accuracy {:.4f} | LR {:.2e}".format(
                             global_step, loss.item(), f1, acc, lr))

            if global_step % config.eval_steps == 0 and global_step > 0:
                best_score = validate(val_loader,
                                      model,
                                      best_score=best_score,
                                      global_step=global_step,
                                      cfg=config)
                # Plateau scheduler keyed on the best validation score.
                scheduler.step(best_score)
            global_step += 1