Example #1
def main():
    """
    Main Function
    """

    # Set up the Arguments, Tensorboard Writer, Dataloader, Loss Fn, Optimizer
    assert_and_infer_cfg(args)
    writer = prep_experiment(args, parser)
    train_loader, val_loader, train_obj = datasets.setup_loaders(args)
    criterion, criterion_val = loss.get_loss(args)
    net = network.get_net(args, criterion)
    optim, scheduler = optimizer.get_optimizer(args, net)

    if args.fix_bn:
        net.apply(set_bn_eval)
        print("Fix bn for finetuning")

    if args.fp16:
        net, optim = amp.initialize(net, optim, opt_level="O1")

    net = network.wrap_network_in_dataparallel(net, args.apex)
    if args.snapshot:
        optimizer.load_weights(net, optim,
                               args.snapshot, args.restore_optimizer)
    if args.evaluateF:
        assert args.snapshot is not None, "must load weights for evaluation"
        evaluate(val_loader, net, args)
        return
    torch.cuda.empty_cache()
    # Main Loop
    for epoch in range(args.start_epoch, args.max_epoch):
        # Update EPOCH CTR
        cfg.immutable(False)
        cfg.EPOCH = epoch
        cfg.immutable(True)

        scheduler.step()
        train(train_loader, net, optim, epoch, writer)
        if args.apex:
            train_loader.sampler.set_epoch(epoch + 1)
        validate(val_loader, net, criterion_val,
                 optim, epoch, writer)
        if args.class_uniform_pct:
            if epoch >= args.max_cu_epoch:
                train_obj.build_epoch(cut=True)
                if args.apex:
                    train_loader.sampler.set_num_samples()
            else:
                train_obj.build_epoch()
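
The `fix_bn` branch in Example #1 relies on a `set_bn_eval` helper defined elsewhere in the repository. A minimal sketch of what such a hook usually looks like in PyTorch; this is an assumption, not the repository's own definition:

import torch.nn as nn

def set_bn_eval(module):
    # net.apply() calls this once per submodule; switching every
    # BatchNorm layer to eval mode freezes its running statistics
    # during fine-tuning.
    if isinstance(module, nn.modules.batchnorm._BatchNorm):
        module.eval()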
Example #2
def main():
    """
    Main Function
    """
    # Set up the Arguments, Tensorboard Writer, Dataloader, Loss Fn, Optimizer
    assert_and_infer_cfg(args)
    prep_experiment(args, parser)
    writer = None

    _, _, _, extra_val_loaders, _ = datasets.setup_loaders(args)

    criterion, criterion_val = loss.get_loss(args)
    criterion_aux = loss.get_loss_aux(args)
    net = network.get_net(args, criterion, criterion_aux)

    optim, scheduler = optimizer.get_optimizer(args, net)

    net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = network.warp_network_in_dataparallel(net, args.local_rank)
    epoch = 0
    i = 0

    if args.snapshot:
        epoch, mean_iu = optimizer.load_weights(
            net, optim, scheduler, args.snapshot, args.restore_optimizer)

    print("#### iteration", i)
    torch.cuda.empty_cache()
    # Main Loop
    # for epoch in range(args.start_epoch, args.max_epoch):

    for dataset, val_loader in extra_val_loaders.items():
        print("Extra validating... This won't save pth file")
        validate(val_loader, dataset, net, criterion_val, optim, scheduler, epoch, writer, i, save_pth=False)
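
Example #2 converts BatchNorm layers to their synchronized variant before wrapping the network for distributed training. A minimal sketch of that standard PyTorch pattern, assuming torch.distributed.init_process_group() has already been called (the helper name and the local_rank handling are illustrative assumptions):

import torch.nn as nn

def to_distributed(net: nn.Module, local_rank: int) -> nn.Module:
    # Replace every BatchNorm layer with SyncBatchNorm so batch
    # statistics are aggregated across all GPUs in the process group.
    net = nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = net.cuda(local_rank)
    # One process per GPU; gradients are all-reduced automatically.
    return nn.parallel.DistributedDataParallel(
        net, device_ids=[local_rank], output_device=local_rank)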
Example #3
def main():
    """
    Main Function
    """
    # Set up the Arguments, Tensorboard Writer, Dataloader, Loss Fn, Optimizer
    assert_and_infer_cfg(args)
    writer = prep_experiment(args, parser)

    train_loader, val_loaders, train_obj, extra_val_loaders, covstat_val_loaders = datasets.setup_loaders(
        args)

    criterion, criterion_val = loss.get_loss(args)
    criterion_aux = loss.get_loss_aux(args)
    net = network.get_net(args, criterion, criterion_aux)

    optim, scheduler = optimizer.get_optimizer(args, net)

    net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = network.warp_network_in_dataparallel(net, args.local_rank)
    epoch = 0
    i = 0

    if args.snapshot:
        epoch, mean_iu = optimizer.load_weights(net, optim, scheduler,
                                                args.snapshot,
                                                args.restore_optimizer)
        if args.restore_optimizer:
            iter_per_epoch = len(train_loader)
            i = iter_per_epoch * epoch
        else:
            epoch = 0

    print("#### iteration", i)
    torch.cuda.empty_cache()
    # Main Loop
    # for epoch in range(args.start_epoch, args.max_epoch):

    while i < args.max_iter:
        # Update EPOCH CTR
        cfg.immutable(False)
        cfg.ITER = i
        cfg.immutable(True)

        i = train(train_loader, net, optim, epoch, writer, scheduler,
                  args.max_iter)
        train_loader.sampler.set_epoch(epoch + 1)

        if (args.dynamic and args.use_isw and epoch % (args.cov_stat_epoch + 1) == args.cov_stat_epoch) \
           or (args.dynamic is False and args.use_isw and epoch == args.cov_stat_epoch):
            net.module.reset_mask_matrix()
            for trial in range(args.trials):
                for dataset, val_loader in covstat_val_loaders.items():
                    # Gather the covariance statistics needed for ISW
                    validate_for_cov_stat(val_loader, dataset, net,
                                          criterion_val, optim, scheduler,
                                          epoch, writer, i, save_pth=False)
                    net.module.set_mask_matrix()

        if args.local_rank == 0:
            print("Saving pth file...")
            evaluate_eval(args,
                          net,
                          optim,
                          scheduler,
                          None,
                          None, [],
                          writer,
                          epoch,
                          "None",
                          None,
                          i,
                          save_pth=True)

        if args.class_uniform_pct:
            if epoch >= args.max_cu_epoch:
                train_obj.build_epoch(cut=True)
                train_loader.sampler.set_num_samples()
            else:
                train_obj.build_epoch()

        epoch += 1

    # Validation after epochs
    if len(val_loaders) == 1:
        # Run validation only once, to save the model
        for dataset, val_loader in val_loaders.items():
            validate(val_loader, dataset, net, criterion_val, optim, scheduler,
                     epoch, writer, i)
    else:
        if args.local_rank == 0:
            print("Saving pth file...")
            evaluate_eval(args,
                          net,
                          optim,
                          scheduler,
                          None,
                          None, [],
                          writer,
                          epoch,
                          "None",
                          None,
                          i,
                          save_pth=True)

    for dataset, val_loader in extra_val_loaders.items():
        print("Extra validating... This won't save pth file")
        validate(val_loader,
                 dataset,
                 net,
                 criterion_val,
                 optim,
                 scheduler,
                 epoch,
                 writer,
                 i,
                 save_pth=False)
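
Example #3 thaws and refreezes the global config around each counter update (`cfg.immutable(False)` / `cfg.immutable(True)`). A minimal sketch of that freeze/thaw pattern, modeled on the common AttrDict-style config; the details are assumptions rather than this repository's implementation:

class Config(dict):
    """Attribute-style config that can be frozen against edits."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__['_frozen'] = False

    def immutable(self, frozen):
        self.__dict__['_frozen'] = frozen

    def __setattr__(self, name, value):
        if self.__dict__['_frozen']:
            raise AttributeError('config is frozen; call immutable(False) first')
        self[name] = value

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

cfg = Config()
cfg.immutable(False)
cfg.ITER = 0      # allowed while thawed
cfg.immutable(True)

Example #4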
def main():
    """
    Main Function
    """

    # Set up the Arguments, Tensorboard Writer, Dataloader, Loss Fn, Optimizer
    #    args2 = copy.deepcopy(args)
    assert_and_infer_cfg(args)
    #    assert_and_infer_cfg(args2)
    #    args2.dataset = 'kitti_trav'
    #    print(args.dataset)
    #    print(args2.dataset)
    writer = prep_experiment(args, parser)
    #    writer = prep_experiment(args2, parser)

    # Dataset
    train_loader, val_loader, train_obj = datasets.setup_loaders(args)
    #    train_loader2, val_loader2, train_obj2 = datasets.setup_loaders(args2)
    criterion, criterion_val = loss.get_loss(args, data_type='semantic')
    criterion2, criterion_val2 = loss.get_loss(args, data_type='trav')
    net = network.get_net(args, criterion, criterion2)

    # parameter lists
    #    param1_lists = list(net.mod1.parameters()) + list(net.mod2.parameters()) + list(net.mod3.parameters()) + list(net.mod4.parameters()) + list(net.mod5.parameters()) + list(net.mod6.parameters()) + list(net.mod7.parameters()) + list(net.pool2.parameters()) + list(net.pool3.parameters()) + list(net.aspp.parameters()) + list(net.bot_fine.parameters()) + list(net.bot_aspp.parameters()) + list(net.final.parameters()) + [log_sigma_A]
    #    param2_lists = list(net.mod1.parameters()) + list(net.mod2.parameters()) + list(net.mod3.parameters()) + list(net.mod4.parameters()) + list(net.mod5.parameters()) + list(net.mod6.parameters()) + list(net.mod7.parameters()) + list(net.pool2.parameters()) + list(net.pool3.parameters()) + list(net.aspp.parameters()) + list(net.bot_fine.parameters()) + list(net.bot_aspp.parameters()) + list(net.final2.parameters()) + [log_sigma_B]

    # optimizers
    optim, scheduler = optimizer.get_optimizer(args, net)
    #    optim2, scheduler2 = optimizer.get_optimizer(args, param2_lists)

    if args.fp16:
        net, optim = amp.initialize(net, optim, opt_level="O1")

    net = network.wrap_network_in_dataparallel(net, args.apex)
    if args.snapshot:
        optimizer.load_weights(net, optim, args.snapshot, args.snapshot2,
                               args.restore_optimizer)
#        optimizer.load_weights(net, optim2,
#                               args.snapshot, args.snapshot2, args.restore_optimizer)

    torch.cuda.empty_cache()
    # Main Loop
    for epoch in range(args.start_epoch, args.max_epoch):
        # Update EPOCH CTR
        cfg.immutable(False)
        cfg.EPOCH = epoch
        cfg.immutable(True)

        scheduler.step()
        train(train_loader, net, optim, epoch, writer)
        if args.apex:
            train_loader.sampler.set_epoch(epoch + 1)
#            train_loader2.sampler.set_epoch(epoch + 1)
        validate(val_loader, net, criterion_val, criterion_val2, optim, epoch,
                 writer)
        if args.class_uniform_pct:
            if epoch >= args.max_cu_epoch:
                train_obj.build_epoch(cut=True)
                #                train_obj2.build_epoch(cut=True)
                if args.apex:
                    train_loader.sampler.set_num_samples()
#                    train_loader2.sampler.set_num_samples()
            else:
                train_obj.build_epoch()
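
The commented-out parameter lists in Example #4 mention `log_sigma_A` and `log_sigma_B`, which suggests homoscedastic uncertainty weighting of the two task losses (Kendall et al.). A sketch of that weighting under this assumption; it is not the repository's code:

import torch

# Learnable log standard deviations, one per task (names taken from
# the commented-out code above; everything here is illustrative).
log_sigma_A = torch.zeros(1, requires_grad=True)  # semantic head
log_sigma_B = torch.zeros(1, requires_grad=True)  # traversability head

def combined_loss(loss_a, loss_b):
    # L = exp(-2*sA)*L_A + sA + exp(-2*sB)*L_B + sB: a small sigma
    # amplifies a task's loss, and the +sigma term keeps sigma from
    # growing without bound.
    return (torch.exp(-2 * log_sigma_A) * loss_a + log_sigma_A +
            torch.exp(-2 * log_sigma_B) * loss_b + log_sigma_B)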
Example #5
def main():
    """
    Main Function
    """
    prep_experiment(args, parser)
    if args.dataset == 'robotic_instrument':
        from datasets.robotic_instrument import get_dataloader
        if args.task == 'binary':
            args.num_classes = 2
            args.ignore_label = 2
            args.cls_wt = [1.0, 1.0]
        elif args.task == 'parts':
            args.num_classes = 5
            args.ignore_label = 5
            args.cls_wt = [0.1, 1.0, 1.0, 1.0, 1.0]
        elif args.task == 'type':
            args.num_classes = 8
            args.ignore_label = 8
            args.cls_wt = [0.1, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
        train_loader, val_loader = get_dataloader(args.task, batch_size=args.batch_size)
        net_param = {"class_num"   : args.num_classes,
                     "in_chns"     : 3,
                     "bilinear"    : True,
                     "feature_chns": [16, 32, 64, 128, 256],
                     "dropout"     : [0.0, 0.0, 0.3, 0.4, 0.5]}
    elif args.dataset == 'covid19_lesion':
        from datasets.covid19_lesion import get_dataloader
        train_loader, val_loader = get_dataloader(args.task, batch_size=args.batch_size)
        args.ignore_label = 2
        args.num_classes = 2
        net_param = {"class_num"   : args.num_classes,
                     "in_chns"     : 1,
                     "bilinear"    : True,
                     "feature_chns": [16, 32, 64, 128, 256],
                     "dropout"     : [0.0, 0.0, 0.3, 0.4, 0.5]}
    else:
        raise NotImplementedError('The dataset is not supported.')

    if args.margin_loss:
        args.margins = loss.calculate_margins(train_loader, args.num_classes)
        

    criterion = loss.get_loss(args)
    net = COPLENet(net_param).cuda()
    optim, scheduler = optimizer.get_optimizer(args, net)

    epoch = 0
    i = 0
    if args.snapshot:
        epoch, mean_iou = optimizer.load_weights(
            net, optim, scheduler, args.snapshot, args.restore_optimizer)
        if args.restore_optimizer:
            iter_per_epoch = len(train_loader)
            i = iter_per_epoch * epoch
        else:
            epoch = 0

    print("#### iteration", i)
    torch.cuda.empty_cache()

    while i < args.max_iter and epoch < args.max_epoch:
        i = train(train_loader, net, criterion, optim, epoch, scheduler, args.max_iter)
        #iou_train = iou_on_trainset(train_loader, net, criterion)
        #logging.info('Mean IoU on training set: %f' % iou_train)
        val_loss, per_cls_iou = validate(val_loader, net, criterion, optim, scheduler, epoch+1, i)
        epoch += 1
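
Example #5 calls `loss.calculate_margins(train_loader, args.num_classes)` when margin_loss is set, but the helper is not shown. One plausible, LDAM-style implementation that derives per-class margins from pixel frequencies (proportional to n_c ** -0.25) is sketched below; it is an assumption, not the actual helper:

import numpy as np
import torch

def calculate_margins(train_loader, num_classes, max_margin=0.5):
    # Count how many pixels of each class appear in the training set.
    counts = np.zeros(num_classes, dtype=np.int64)
    for _, labels in train_loader:
        counts += np.bincount(labels.numpy().ravel(),
                              minlength=num_classes)[:num_classes]
    # LDAM-style margins: rarer classes get larger margins.
    margins = 1.0 / np.sqrt(np.sqrt(np.maximum(counts, 1)))
    margins = margins * (max_margin / margins.max())
    return torch.tensor(margins, dtype=torch.float32)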
Example #6
def main():
    if args.dataset == 'robotic_instrument':
        from datasets.robotic_instrument import get_testloader, RoboticInstrument
        if args.task == 'binary':
            num_classes = 2
        elif args.task == 'parts':
            num_classes = 5
        elif args.task == 'type':
            num_classes = 8
        dataset = RoboticInstrument(args.task, 'test')
        test_loader = get_testloader(args.task, batch_size=args.batch_size)
        net_param = {"class_num"   : num_classes,
                     "in_chns"     : 3,
                     "bilinear"    : True,
                     "feature_chns": [16, 32, 64, 128, 256],
                     "dropout"     : [0.0, 0.0, 0.3, 0.4, 0.5]}
    elif args.dataset == 'covid19_lesion':
        from datasets.covid19_lesion import get_testloader, Covid19Dataset
        dataset = Covid19Dataset(args.task, 'test')
        test_loader = get_testloader(args.task, batch_size=args.batch_size)
        num_classes = 2
        net_param = {"class_num"   : num_classes,
                     "in_chns"     : 1,
                     "bilinear"    : True,
                     "feature_chns": [16, 32, 64, 128, 256],
                     "dropout"     : [0.0, 0.0, 0.3, 0.4, 0.5]}
    else:
        raise NotImplementedError('The dataset is not supported.')
    
    net = COPLENet(net_param).cuda()
    optimizer.load_weights(net, None, None, args.snapshot, False)    
    torch.cuda.empty_cache()

    net.eval()
    hist = 0
    predictions = []
    groundtruths = []
    for test_idx, data in enumerate(test_loader):
        inputs, gts = data 
        assert len(inputs.size()) == 4 and len(gts.size()) == 3
        assert inputs.size()[2:] == gts.size()[1:]
        inputs, gts = inputs.cuda(), gts.cuda()
        with torch.no_grad():
            output = net(inputs)
        del inputs
        assert output.size()[2:] == gts.size()[1:]
        assert output.size()[1] == num_classes
        
        prediction = output.data.max(1)[1].cpu()
        predictions.append(output.data.cpu().numpy())
        groundtruths.append(gts.cpu().numpy())
        hist += fast_hist(prediction.numpy().flatten(),
                          gts.cpu().numpy().flatten(), num_classes)
        del gts, output, test_idx, data
    
    predictions = np.concatenate(predictions, axis=0)
    groundtruths = np.concatenate(groundtruths, axis=0)
    if args.dump_imgs:
        assert len(dataset) == predictions.shape[0]
    
        dump_dir = './dump_' + args.dataset + '_' + args.task + '_' + args.method
        os.makedirs(dump_dir, exist_ok=True)
        for i in range(len(dataset)):
            img = skimage.io.imread(dataset.img_paths[i])
            if len(img.shape) == 2:
                img = np.stack((img, img, img), axis=2)
            img = skimage.transform.resize(img, (224, 336))
            cm = np.argmax(predictions[i, :, :, :], axis=0)
            color_cm = add_color(cm)
            color_cm = skimage.transform.resize(color_cm, (224, 336))
            gt = np.asarray(groundtruths[i, :, :], np.uint8)
            color_gt = add_color(gt)
            color_gt = skimage.transform.resize(color_gt, (224, 336))
            blend_pred = 0.5 * img + 0.5 * color_cm
            blend_gt = 0.5 * img + 0.5 * color_gt
            blend_pred = np.asarray(blend_pred * 255, np.uint8)
            blend_gt = np.asarray(blend_gt * 255, np.uint8)
            #skimage.io.imsave(os.path.join(dump_dir, 'img_{:03d}.png'.format(i)), img)
            skimage.io.imsave(os.path.join(dump_dir, 'pred_{:03d}.png'.format(i)), blend_pred)
            skimage.io.imsave(os.path.join(dump_dir, 'gt_{:03d}.png'.format(i)), blend_gt)
            if i > 20:
                break
    
    acc = np.diag(hist).sum() / hist.sum()
    acc_cls = np.diag(hist) / hist.sum(axis=1)
    iou = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))    
    id2cat = {i: i for i in range(len(iou))}
    iou_false_positive = hist.sum(axis=1) - np.diag(hist)
    iou_false_negative = hist.sum(axis=0) - np.diag(hist)
    iou_true_positive = np.diag(hist)

    print('IoU:')
    print('label_id      label    IoU    Precision Recall TP       FP      FN      Pixel Acc.')
    for idx, i in enumerate(iou):
        idx_string = "{:2d}".format(idx)
        class_name = "{:>13}".format(id2cat[idx]) if idx in id2cat else ''
        iou_string = '{:5.1f}'.format(i * 100)
        total_pixels = hist.sum()
        tp = '{:5.1f}'.format(100 * iou_true_positive[idx] / total_pixels)
        fp = '{:5.1f}'.format(100 * iou_false_positive[idx] / total_pixels)
        fn = '{:5.1f}'.format(100 * iou_false_negative[idx] / total_pixels)
        precision = '{:5.1f}'.format(
            iou_true_positive[idx] / (iou_true_positive[idx] + iou_false_positive[idx]))
        recall = '{:5.1f}'.format(
            iou_true_positive[idx] / (iou_true_positive[idx] + iou_false_negative[idx]))
        pixel_acc = '{:5.1f}'.format(100*acc_cls[idx])
        print('{}    {}   {}  {}     {}  {}   {}   {}   {}'.format(
            idx_string, class_name, iou_string, precision, recall, tp, fp, fn, pixel_acc))
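
Example #6 accumulates its confusion matrix with `fast_hist`. A minimal sketch of that standard segmentation helper, matching the call signature above (the row/column convention is an assumption about the unshown utility):

import numpy as np

def fast_hist(pred, gtruth, num_classes):
    # Keep only pixels whose ground-truth label is valid.
    mask = (gtruth >= 0) & (gtruth < num_classes)
    # Encode each (gt, pred) pair as a single index and histogram the
    # pairs into a num_classes x num_classes confusion matrix.
    return np.bincount(
        num_classes * gtruth[mask].astype(int) + pred[mask].astype(int),
        minlength=num_classes ** 2).reshape(num_classes, num_classes)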