示例#1
0
def benchmark_dataset(dataset, title, fname, testname, xlabels, ylabels=None):
    """Evaluate the module-level ``model`` on ``dataset`` and record the results.

    Three artifacts are written under ``logs/{vars.corda_version}/{name}/``:
    a metric summary (``{fname}-metric.txt``), a heatmap of the normalized
    confusion matrix (``{fname}.png``) and a ROC plot (``{fname}-roc.png``).
    One row of summary statistics is then appended to the global
    ``benchmark_data`` mapping (one list per column).

    Args:
        dataset: dataset handed to ``torch.utils.data.DataLoader``.
        title: heatmap title; also used as prefix of the ROC plot title.
        fname: base file name of the generated artifacts.
        testname: value stored under the ``'test'`` key of the result row.
        xlabels: tick labels of the heatmap's x axis.
        ylabels: tick labels of the y axis; ``xlabels`` is used when falsy.
    """
    global benchmark_data
    print(f'Benchmarking {title}.. ')

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=10,
        shuffle=False,
        num_workers=4,
    )

    # Keep a handle on each metric so the results can be read back by name.
    accuracy_metric = metrics.Accuracy()
    roc_metric = metrics.RocAuc()
    fscore_metric = metrics.FScore()
    tracked = [accuracy_metric, roc_metric, fscore_metric]

    logs, cm = trainer.test(
        model=model,
        test_dataloader=loader,
        criterion=criterion,
        metrics=tracked,
        device=device,
    )

    # 1) Textual metric summary.
    with open(f'logs/{vars.corda_version}/{name}/{fname}-metric.txt', 'w') as metric_file:
        metric_file.write(f'{fname}: ' + trainer.summarize_metrics(logs) + '\n')

    # 2) Normalized confusion matrix rendered as a heatmap.
    heatmap_axes = sns.heatmap(
        cm.get(normalized=True),
        annot=True,
        fmt=".2f",
        xticklabels=xlabels,
        yticklabels=ylabels or xlabels,
        vmin=0.,
        vmax=1.,
    )
    heatmap_axes.set_title(title)
    plt.xlabel('predicted')
    plt.ylabel('ground')
    heatmap_fig = heatmap_axes.get_figure()
    heatmap_fig.savefig(f'logs/{vars.corda_version}/{name}/{fname}.png')
    heatmap_fig.clf()

    # 3) ROC curve, drawn on a fresh figure that is closed afterwards.
    roc_fpr, roc_tpr, _ = roc_metric.get_curve()
    auc = roc_metric.get()
    plt.figure()
    plt.plot(
        roc_fpr,
        roc_tpr,
        color='darkorange',
        lw=2,
        label=f'ROC curve (auc = {auc:.2f})',
    )
    plt.title(f'{title} ROC')
    plt.legend(loc='lower right')
    plt.savefig(f'logs/{vars.corda_version}/{name}/{fname}-roc.png')
    plt.clf()
    plt.cla()
    plt.close()

    # The flattened normalized matrix is unpacked as
    # (specificity, false-positive rate, miss rate, sensitivity), i.e. exactly
    # four entries are required.
    specificity, _fp_rate, miss_rate, sensitivity = cm.get(normalized=True).ravel()
    # NOTE(review): the denominator is zero whenever sensitivity or
    # specificity equals 1 -- confirm how that case should be reported.
    dor = (sensitivity*specificity)/((1-sensitivity)*(1-specificity))
    fscore = fscore_metric.get()
    balanced_accuracy = (sensitivity+specificity)/2.

    row = {
        'arch': args.arch,
        'pretrain': args.pretrain,
        'train': args.train.upper(),
        'test': testname,
        'accuracy': accuracy_metric.get(),
        'auc': auc,
        'sensitivity': sensitivity,
        'specificity': specificity,
        'fscore': fscore,
        'ba': balanced_accuracy,
        'missrate': miss_rate,
        'dor': dor,
    }

    for column, value in row.items():
        benchmark_data[column].append(value)
def test_accuracy():
    """Check the running value reported by ``metrics.Accuracy``.

    Samples are fed one at a time through ``add`` and the value returned by
    ``get`` is compared with the expected running accuracy after selected
    steps.
    """
    # (first argument of add, row of scores, expected get() or None = no check)
    steps = [
        (1, [.0, .1], 1),
        (0, [.0, .1], 0.5),
        (0, [.1, .0], None),
        (0, [.1, .0], 0.75),
    ]

    m = metrics.Accuracy()
    m.reset()
    for label, scores, expected in steps:
        m.add(torch.tensor([label]).to(DEVICE), torch.tensor([scores]).to(DEVICE))
        if expected is not None:
            assert m.get() == expected
示例#3
0
def benchmark_dataset(dataset, title, fname, xlabels, ylabels=None):
    """Evaluate the module-level ``model`` on ``dataset`` and save the plots.

    Three artifacts are written under ``logs/{vars.corda_version}/{name}/``:
    a metric summary (``{fname}-metric.txt``), a heatmap of the normalized
    confusion matrix (``{fname}.png``) and a ROC plot (``{fname}-roc.png``).

    Args:
        dataset: dataset handed to ``torch.utils.data.DataLoader``.
        title: heatmap title; also used as prefix of the ROC plot title.
        fname: base file name of the generated artifacts.
        xlabels: tick labels of the heatmap's x axis.
        ylabels: tick labels of the y axis; ``xlabels`` is used when falsy.
    """
    print(f'Benchmarking {title}.. ', end='', flush=True)

    loader = torch.utils.data.DataLoader(
        dataset, batch_size=10, shuffle=False, num_workers=4)

    # Keep a handle on the ROC metric: its curve is plotted further down.
    accuracy_metric = metrics.Accuracy()
    roc_metric = metrics.RocAuc()
    fscore_metric = metrics.FScore()

    logs, cm = trainer.test(
        model=model,
        test_dataloader=loader,
        criterion=criterion,
        metrics=[accuracy_metric, roc_metric, fscore_metric],
        device=device,
    )

    # 1) Textual metric summary.
    with open(f'logs/{vars.corda_version}/{name}/{fname}-metric.txt',
              'w') as metric_file:
        metric_file.write(f'{fname}: ' + trainer.summarize_metrics(logs) + '\n')

    # 2) Normalized confusion matrix rendered as a heatmap.
    heatmap_axes = sns.heatmap(
        cm.get(normalized=True),
        annot=True,
        fmt=".2f",
        xticklabels=xlabels,
        yticklabels=ylabels or xlabels,
    )
    heatmap_axes.set_title(title)
    plt.xlabel('predicted')
    plt.ylabel('ground')
    heatmap_fig = heatmap_axes.get_figure()
    heatmap_fig.savefig(f'logs/{vars.corda_version}/{name}/{fname}.png')
    heatmap_fig.clf()

    # 3) ROC curve, drawn on a fresh figure that is closed afterwards.
    roc_fpr, roc_tpr, _ = roc_metric.get_curve()
    auc = roc_metric.get()
    plt.figure()
    plt.plot(
        roc_fpr,
        roc_tpr,
        color='darkorange',
        lw=2,
        label=f'ROC curve (auc = {auc:.2f})',
    )
    plt.title(f'{title} ROC')
    plt.legend(loc='lower right')
    plt.savefig(f'logs/{vars.corda_version}/{name}/{fname}-roc.png')
    plt.clf()
    plt.cla()
    plt.close()
示例#4
0
    ).to(device)
elif args.arch == 'resnet50':
    model = covid_classifier.CovidClassifier50(
        encoder=feature_extractor,
        pretrained=False,
        freeze_conv=False
    ).to(device)

#model = covid_classifier.LeNet1024NoPoolingDeep().to(device)

# Log the learning rate used for this run.
print(f'Using lr {lr}')

# TRAINING
# %%
# Metric objects tracked for this run: accuracy, ROC AUC and F-score.
tracked_metrics = [
    metrics.Accuracy(),
    metrics.RocAuc(),
    metrics.FScore()
]

def focal_loss(output, target, gamma=2., weight=None):
    """Return the mean focal loss of ``output`` against ``target``.

    The element-wise binary cross-entropy is scaled by ``(1 - pt) ** gamma``,
    with ``pt = target * output + (1 - target) * (1 - output)``, and the
    scaled values are averaged over all elements.

    Args:
        output: tensor used as the input of ``F.binary_cross_entropy``
            (which expects probabilities, not logits).
        target: target tensor passed alongside ``output``.
        gamma: exponent of the ``(1 - pt)`` factor; ``0`` gives the plain
            mean binary cross-entropy.
        weight: optional rescaling weight forwarded to
            ``F.binary_cross_entropy``.

    Returns:
        A scalar tensor holding the averaged loss.
    """
    elementwise_bce = F.binary_cross_entropy(
        output, target, weight=weight, reduction='none')
    # pt equals `output` where target is 1 and `1 - output` where target is 0.
    pt = output * target + (1 - output) * (1 - target)
    modulating_factor = (1 - pt) ** gamma
    return torch.mean(modulating_factor * elementwise_bce)

# Loss: the focal loss defined above, with its default gamma.
criterion = focal_loss
# SGD with weight decay 1e-3; no momentum argument is passed.
optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=1e-3)
# Lower the learning rate once the monitored value has stopped improving for
# more than 15 scheduler steps; `mode` is left at its default.
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=15, verbose=True)

best_model = trainer.fit(
    model=model, train_dataloader=train_dataloader,
示例#5
0
    trainer = Trainer(
        model=model,
        logger=log,
        prefix="classifier",
        checkpoint_dir=ARGS.checkpoint_dir,
        summary_dir=ARGS.summary_dir,
        n_summaries=4,
        start_scratch=ARGS.start_scratch,
    )

    metrics = {
        "tp": m.TruePositives(ARGS.device),
        "tn": m.TrueNegatives(ARGS.device),
        "fp": m.FalsePositives(ARGS.device),
        "fn": m.FalseNegatives(ARGS.device),
        "accuracy": m.Accuracy(ARGS.device),
        "f1": m.F1Score(ARGS.device),
        "precision": m.Precision(ARGS.device),
        "TPR": m.Recall(ARGS.device),
        "FPR": m.FPR(ARGS.device),
    }

    optimizer = optim.Adam(model.parameters(),
                           lr=ARGS.lr,
                           betas=(ARGS.beta1, 0.999))

    metric_mode = "max"
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode=metric_mode,
        patience=patience_lr,
# Log the learning rate used for this run.
print(f'using lr {lr}')

# %%
# Classifier built with pretrained=True and moved to the target device.
model = pneumonia_classifier.PneumoniaClassifierChest(
    pretrained=True).to(device)

# %%
# Cross-entropy loss with mean reduction.
criterion = functools.partial(F.cross_entropy, reduction='mean')
# SGD with weight decay 0.001; no momentum argument is passed.
optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0.001)
# Lower the learning rate once the monitored value has stopped improving for
# more than 15 scheduler steps; `mode` is left at its default.
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                          patience=15,
                                                          verbose=True)

# %%
# Only accuracy is tracked, configured with multiclass=True.
tracked_metrics = [
    metrics.Accuracy(multiclass=True),
]

# Run name (includes the seed); it is also the name of the log sub-directory
# passed to utils.ensure_dir below.
name = f'resnet18-pneumonia-classifier-s{seed}-3-classes-unprocessed'
utils.ensure_dir(f'logs/{vars.corda_version}/{name}')

# %%
best_model = trainer.fit(model=model,
                         train_dataloader=train_dataloader,
                         val_dataloader=val_dataloader,
                         test_dataloader=test_dataloader,
                         test_every=10,
                         criterion=criterion,
                         optimizer=optimizer,
                         scheduler=lr_scheduler,
                         metrics=tracked_metrics,
示例#7
0
def _clone_params_for_global_branch(params, param_types, multi_branch_layers=()):
    """Duplicate loaded parameters under the names of the extra branches.

    For every key of ``params`` containing one of ``param_types``, a copy of
    the value is stored under ``<layer>_g<param_type>``.  Layers listed in
    ``multi_branch_layers`` additionally get copies under the ``_1`` .. ``_4``
    suffixes.  ``params`` is modified in place.
    """
    # Iterate over a snapshot of the keys: new entries are inserted inside the
    # loop, and mutating a dict while iterating its live key view raises
    # "RuntimeError: dictionary changed size during iteration" on Python 3.
    for key in list(params.keys()):
        for param_type in param_types:
            if param_type not in key:
                continue
            layer_name = key[:key.rfind(param_type)]
            params[layer_name + "_g" + param_type] = params[key].copy()
            if layer_name in multi_branch_layers:
                for branch in ("_1", "_2", "_3", "_4"):
                    params[layer_name + branch + param_type] = params[key].copy()
            # Only the first matching parameter type is considered per key.
            break


def train_seg_wrapper(ctx, epoch, lr, model_prefix, symbol, class_num, workspace, init_weight_file,
                      im_root, mask_root, flist_path, use_g_labels, rgb_mean, crop_size, scale_range, label_shrink_scale,
                      epoch_size, max_epoch, batch_size, wd, momentum):
    """Build the segmentation training network and fit it with SGD.

    Args:
        ctx: device context(s) given to ``mx.mod.Module``; must support
            ``len()`` because gradients are rescaled by ``1 / len(ctx)``.
        epoch: ``0`` to start from ``init_weight_file`` (or from scratch when
            that file does not exist); any other value resumes from the
            checkpoint ``model_prefix`` saved at that epoch.
        lr: learning rate of the SGD optimizer.
        model_prefix: checkpoint prefix used for resuming and for the
            per-epoch checkpoint callback.
        symbol: object exposing ``create_training(...)`` that returns the
            training network.
        class_num: number of classes, forwarded to the network builder and
            the data iterator.
        workspace: forwarded to ``symbol.create_training``.
        init_weight_file: initial weights, only read when ``epoch == 0``.
        im_root, mask_root, flist_path: image root, mask root and file list
            handed to ``SegTrainingIter``.
        use_g_labels: when true, the network is built with
            ``gweight=1.0/batch_size``, the extra ``g_logistic_label`` label
            is bound, loaded parameters are duplicated for the additional
            branches and a ``MultiLogisticLoss`` metric is tracked.
        rgb_mean, crop_size, scale_range, label_shrink_scale: forwarded to
            ``SegTrainingIter``.
        epoch_size: forwarded to ``SegTrainingIter``.
        max_epoch: passed to ``Module.fit`` as ``num_epoch``.
        batch_size: batch size of the iterator and of the speedometer.
        wd: weight decay of the SGD optimizer.
        momentum: momentum of the SGD optimizer.

    Returns:
        None.  Checkpoints are written through the epoch-end callback.
    """
    arg_dict = {}
    aux_dict = {}
    if use_g_labels:
        seg_net = symbol.create_training(class_num=class_num, gweight=1.0/batch_size, workspace=workspace)
    else:
        seg_net = symbol.create_training(class_num=class_num, workspace=workspace)

    if epoch == 0:
        if not os.path.exists(init_weight_file):
            # `logging.warn` is a deprecated alias of `logging.warning`.
            logging.warning("No model file found at %s. Start from scratch!", init_weight_file)
        else:
            arg_dict, aux_dict, _ = misc.load_checkpoint(init_weight_file)
            # Copy params for the global branch.
            if use_g_labels:
                param_types = ["_weight", "_bias", "_gamma", "_beta", "_moving_mean", "_moving_var"]
                _clone_params_for_global_branch(arg_dict, param_types,
                                                multi_branch_layers=("fc6", "fc7"))
                _clone_params_for_global_branch(aux_dict, param_types)
    else:
        arg_dict, aux_dict, _ = misc.load_checkpoint(model_prefix, epoch)

    data_iter = SegTrainingIter(
        im_root=im_root,
        mask_root=mask_root,
        file_list_path=flist_path,
        provide_g_labels=use_g_labels,
        class_num=class_num,
        rgb_mean=rgb_mean,
        crop_size=crop_size,
        shuffle=True,
        scale_range=scale_range,
        label_shrink_scale=label_shrink_scale,
        random_flip=True,
        data_queue_size=8,
        epoch_size=epoch_size,
        batch_size=batch_size,
        round_batch=True
    )

    initializer = mx.initializer.Normal()
    initializer.set_verbosity(True)

    if use_g_labels:
        mod = mx.mod.Module(seg_net, context=ctx, label_names=["softmax_label", "g_logistic_label"])
    else:
        mod = mx.mod.Module(seg_net, context=ctx, label_names=["softmax_label"])
    mod.bind(data_shapes=data_iter.provide_data,
             label_shapes=data_iter.provide_label)
    # Missing parameters are tolerated only on a fresh start (epoch == 0);
    # a resumed checkpoint has to provide every parameter.
    mod.init_params(initializer=initializer, arg_params=arg_dict, aux_params=aux_dict, allow_missing=(epoch == 0))

    opt_params = {"learning_rate": lr,
                  "wd": wd,
                  'momentum': momentum,
                  'rescale_grad': 1.0/len(ctx)}

    if use_g_labels:
        eval_metrics = [metrics.Accuracy(), metrics.Loss(), metrics.MultiLogisticLoss(l_index=1, p_index=1)]
    else:
        eval_metrics = [metrics.Accuracy(), metrics.Loss()]
    mod.fit(data_iter,
            optimizer="sgd",
            optimizer_params=opt_params,
            num_epoch=max_epoch,
            epoch_end_callback=callbacks.module_checkpoint(model_prefix),
            batch_end_callback=callbacks.Speedometer(batch_size, frequent=10),
            eval_metric=eval_metrics,
            begin_epoch=epoch+1)
示例#8
0
# Build the classifier matching the requested architecture on top of
# `feature_extractor` and move it to the target device.
# NOTE(review): there is no `else` branch -- for any other `args.arch` value
# `model` is not assigned here; confirm the argument is validated upstream.
if args.arch == 'resnet18':
    model = covid_classifier.CovidClassifier(encoder=feature_extractor,
                                             pretrained=False,
                                             freeze_conv=False).to(device)
elif args.arch == 'resnet50':
    model = covid_classifier.CovidClassifier50(encoder=feature_extractor,
                                               pretrained=False,
                                               freeze_conv=False).to(device)

#model = covid_classifier.LeNet1024NoPoolingDeep().to(device)

# Log the learning rate used for this run.
print(f'Using lr {lr}')

# TRAINING
# %%
# Metric objects tracked for this run: accuracy, ROC AUC and F-score.
tracked_metrics = [metrics.Accuracy(), metrics.RocAuc(), metrics.FScore()]


def focal_loss(output, target, gamma=2., weight=None):
    """Compute the focal loss averaged over every element.

    Each element's binary cross-entropy is multiplied by
    ``(1 - pt) ** gamma``, where
    ``pt = target * output + (1 - target) * (1 - output)``.

    Args:
        output: tensor used as the input of ``F.binary_cross_entropy``
            (which expects probabilities, not logits).
        target: target tensor passed alongside ``output``.
        gamma: exponent of the ``(1 - pt)`` factor; ``0`` gives the plain
            mean binary cross-entropy.
        weight: optional rescaling weight forwarded to
            ``F.binary_cross_entropy``.

    Returns:
        A scalar tensor holding the averaged loss.
    """
    # pt equals `output` where target is 1 and `1 - output` where target is 0.
    pt = output * target + (1 - output) * (1 - target)
    focal_factor = (1 - pt) ** gamma
    unreduced_bce = F.binary_cross_entropy(
        output, target, weight=weight, reduction='none')
    return torch.mean(focal_factor * unreduced_bce)


# Loss: the focal loss defined above, with its default gamma.
criterion = focal_loss
# SGD with weight decay 1e-3; no momentum argument is passed.
optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=1e-3)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                          patience=15,