def main():
    """Entry point: build the experiment and run epoch-based training.

    Sets up config, tensorboard writer, dataloaders, losses, network and
    optimizer from the global ``args``, optionally restores a snapshot or
    runs evaluation only, then trains for ``args.max_epoch`` epochs.
    """
    assert_and_infer_cfg(args)
    writer = prep_experiment(args, parser)
    train_loader, val_loader, train_obj = datasets.setup_loaders(args)
    criterion, criterion_val = loss.get_loss(args)
    net = network.get_net(args, criterion)
    optim, scheduler = optimizer.get_optimizer(args, net)

    if args.fix_bn:
        # Freeze batch-norm statistics for finetuning.
        net.apply(set_bn_eval)
        print("Fix bn for finetuning")

    if args.fp16:
        net, optim = amp.initialize(net, optim, opt_level="O1")

    net = network.wrap_network_in_dataparallel(net, args.apex)

    if args.snapshot:
        optimizer.load_weights(net, optim, args.snapshot, args.restore_optimizer)

    if args.evaluateF:
        # Evaluation-only mode: weights are mandatory, then bail out.
        assert args.snapshot is not None, "must load weights for evaluation"
        evaluate(val_loader, net, args)
        return

    torch.cuda.empty_cache()

    # Main training loop.
    for epoch in range(args.start_epoch, args.max_epoch):
        # Publish the epoch counter through the (normally immutable) config.
        cfg.immutable(False)
        cfg.EPOCH = epoch
        cfg.immutable(True)

        scheduler.step()
        train(train_loader, net, optim, epoch, writer)
        if args.apex:
            train_loader.sampler.set_epoch(epoch + 1)
        validate(val_loader, net, criterion_val, optim, epoch, writer)

        if args.class_uniform_pct:
            # Rebuild the class-uniform sampling epoch; past max_cu_epoch the
            # uniform portion is cut down.
            if epoch >= args.max_cu_epoch:
                train_obj.build_epoch(cut=True)
                if args.apex:
                    train_loader.sampler.set_num_samples()
            else:
                train_obj.build_epoch()
def main():
    """Entry point: run the extra validation loaders only (no training).

    Builds the network and optimizer from the global ``args``, optionally
    restores a snapshot, then validates each extra dataset without saving
    any checkpoint.
    """
    assert_and_infer_cfg(args)
    prep_experiment(args, parser)
    writer = None  # no tensorboard logging in this evaluation-only script

    # Only the extra validation loaders are needed here.
    _, _, _, extra_val_loaders, _ = datasets.setup_loaders(args)

    criterion, criterion_val = loss.get_loss(args)
    criterion_aux = loss.get_loss_aux(args)
    net = network.get_net(args, criterion, criterion_aux)
    optim, scheduler = optimizer.get_optimizer(args, net)

    # Distributed setup: sync BN statistics across processes, then wrap.
    net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = network.warp_network_in_dataparallel(net, args.local_rank)

    epoch = 0
    i = 0
    if args.snapshot:
        epoch, _ = optimizer.load_weights(net, optim, scheduler,
                                          args.snapshot, args.restore_optimizer)

    print("#### iteration", i)
    torch.cuda.empty_cache()

    # Main Loop
    # for epoch in range(args.start_epoch, args.max_epoch):
    for ds_name, loader in extra_val_loaders.items():
        print("Extra validating... This won't save pth file")
        validate(loader, ds_name, net, criterion_val, optim, scheduler,
                 epoch, writer, i, save_pth=False)
def main():
    """
    Main Function.

    Sets up the experiment (config, tensorboard writer, dataloaders, losses,
    network, optimizer), optionally resumes from a snapshot, then runs an
    iteration-counted training loop with optional ISW covariance-statistics
    passes, followed by final validation over all validation loaders.
    """
    # Set up the Arguments, Tensorboard Writer, Dataloader, Loss Fn, Optimizer
    assert_and_infer_cfg(args)
    writer = prep_experiment(args, parser)

    train_loader, val_loaders, train_obj, extra_val_loaders, covstat_val_loaders = datasets.setup_loaders(
        args)

    criterion, criterion_val = loss.get_loss(args)
    criterion_aux = loss.get_loss_aux(args)
    net = network.get_net(args, criterion, criterion_aux)
    optim, scheduler = optimizer.get_optimizer(args, net)

    # Distributed training: sync BN statistics across processes, then wrap.
    net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = network.warp_network_in_dataparallel(net, args.local_rank)

    epoch = 0
    i = 0

    if args.snapshot:
        # Resume from a checkpoint; when optimizer state is restored too,
        # fast-forward the global iteration counter to match the epoch.
        epoch, mean_iu = optimizer.load_weights(net, optim, scheduler,
                                                args.snapshot, args.restore_optimizer)
        if args.restore_optimizer is True:
            iter_per_epoch = len(train_loader)
            i = iter_per_epoch * epoch
        else:
            epoch = 0

    print("#### iteration", i)
    torch.cuda.empty_cache()
    # Main Loop
    # for epoch in range(args.start_epoch, args.max_epoch):
    while i < args.max_iter:
        # Update EPOCH CTR: expose the current iteration through the
        # (normally immutable) global config.
        cfg.immutable(False)
        cfg.ITER = i
        cfg.immutable(True)

        # train() returns the updated global iteration counter.
        i = train(train_loader, net, optim, epoch, writer, scheduler, args.max_iter)
        train_loader.sampler.set_epoch(epoch + 1)

        # ISW mask-matrix refresh: dynamic mode refreshes periodically, static
        # mode refreshes exactly once at cov_stat_epoch.
        # NOTE(review): nesting below reconstructed from flattened source --
        # confirm set_mask_matrix() is intended to run after every
        # covariance-statistics pass (per loader, per trial).
        if (args.dynamic and args.use_isw and epoch % (args.cov_stat_epoch + 1) == args.cov_stat_epoch) \
           or (args.dynamic is False and args.use_isw and epoch == args.cov_stat_epoch):
            net.module.reset_mask_matrix()
            for trial in range(args.trials):
                for dataset, val_loader in covstat_val_loaders.items():  # For get the statistics of covariance
                    validate_for_cov_stat(val_loader, dataset, net, criterion_val, optim, scheduler, epoch, writer, i,
                                          save_pth=False)
                    net.module.set_mask_matrix()

        # Only rank 0 writes checkpoints in distributed training.
        if args.local_rank == 0:
            print("Saving pth file...")
            evaluate_eval(args, net, optim, scheduler, None, None, [],
                          writer, epoch, "None", None, i, save_pth=True)

        if args.class_uniform_pct:
            # Rebuild the class-uniform sampling epoch; past max_cu_epoch the
            # uniform portion is cut down.
            if epoch >= args.max_cu_epoch:
                train_obj.build_epoch(cut=True)
                train_loader.sampler.set_num_samples()
            else:
                train_obj.build_epoch()

        epoch += 1

    # Validation after epochs
    if len(val_loaders) == 1:
        # Run validation only one time - To save models
        for dataset, val_loader in val_loaders.items():
            validate(val_loader, dataset, net, criterion_val, optim, scheduler, epoch, writer, i)
    else:
        if args.local_rank == 0:
            print("Saving pth file...")
            evaluate_eval(args, net, optim, scheduler, None, None, [],
                          writer, epoch, "None", None, i, save_pth=True)

    for dataset, val_loader in extra_val_loaders.items():
        print("Extra validating... This won't save pth file")
        validate(val_loader, dataset, net, criterion_val, optim, scheduler, epoch, writer, i, save_pth=False)
def main():
    """Entry point: multi-task (semantic + traversability) training.

    Builds one network with two losses — semantic segmentation and
    traversability — then runs the standard epoch-based training loop.
    """
    assert_and_infer_cfg(args)
    writer = prep_experiment(args, parser)

    # Dataloaders
    train_loader, val_loader, train_obj = datasets.setup_loaders(args)

    # One loss pair per task.
    criterion, criterion_val = loss.get_loss(args, data_type='semantic')
    criterion2, criterion_val2 = loss.get_loss(args, data_type='trav')

    net = network.get_net(args, criterion, criterion2)
    optim, scheduler = optimizer.get_optimizer(args, net)

    if args.fp16:
        net, optim = amp.initialize(net, optim, opt_level="O1")
    net = network.wrap_network_in_dataparallel(net, args.apex)

    if args.snapshot:
        optimizer.load_weights(net, optim, args.snapshot, args.snapshot2,
                               args.restore_optimizer)

    torch.cuda.empty_cache()

    # Main training loop.
    for epoch in range(args.start_epoch, args.max_epoch):
        # Publish the epoch counter through the (normally immutable) config.
        cfg.immutable(False)
        cfg.EPOCH = epoch
        cfg.immutable(True)

        scheduler.step()
        train(train_loader, net, optim, epoch, writer)
        if args.apex:
            train_loader.sampler.set_epoch(epoch + 1)
        # Validation uses both per-task validation criteria.
        validate(val_loader, net, criterion_val, criterion_val2, optim, epoch, writer)

        if args.class_uniform_pct:
            # Rebuild the class-uniform sampling epoch; past max_cu_epoch the
            # uniform portion is cut down.
            if epoch >= args.max_cu_epoch:
                train_obj.build_epoch(cut=True)
                if args.apex:
                    train_loader.sampler.set_num_samples()
            else:
                train_obj.build_epoch()
def main():
    """Entry point: train COPLENet on the configured dataset and task.

    Configures per-task class counts and weights, builds the dataloaders,
    loss, network and optimizer, optionally resumes from a snapshot, then
    trains until ``args.max_iter`` iterations or ``args.max_epoch`` epochs.
    """
    prep_experiment(args, parser)

    if args.dataset == 'robotic_instrument':
        from datasets.robotic_instrument import get_dataloader
        # Per-task class count, ignore label, and class weights.
        if args.task == 'binary':
            args.num_classes = 2
            args.ignore_label = 2
            args.cls_wt = [1.0, 1.0]
        elif args.task == 'parts':
            args.num_classes = 5
            args.ignore_label = 5
            args.cls_wt = [0.1, 1.0, 1.0, 1.0, 1.0]
        elif args.task == 'type':
            args.num_classes = 8
            args.ignore_label = 8
            args.cls_wt = [0.1, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
        train_loader, val_loader = get_dataloader(args.task, batch_size=args.batch_size)
        net_param = {"class_num": args.num_classes,
                     "in_chns": 3,
                     "bilinear": True,
                     "feature_chns": [16, 32, 64, 128, 256],
                     "dropout": [0.0, 0.0, 0.3, 0.4, 0.5]}
    elif args.dataset == 'covid19_lesion':
        from datasets.covid19_lesion import get_dataloader
        train_loader, val_loader = get_dataloader(args.task, batch_size=args.batch_size)
        args.ignore_label = 2
        args.num_classes = 2
        net_param = {"class_num": args.num_classes,
                     "in_chns": 1,
                     "bilinear": True,
                     "feature_chns": [16, 32, 64, 128, 256],
                     "dropout": [0.0, 0.0, 0.3, 0.4, 0.5]}
    else:
        raise NotImplementedError('The dataset is not supported.')

    if args.margin_loss:
        # Margins are derived from the class distribution of the training set.
        args.margins = loss.calculate_margins(train_loader, args.num_classes)
    criterion = loss.get_loss(args)

    net = COPLENet(net_param).cuda()
    optim, scheduler = optimizer.get_optimizer(args, net)

    epoch = 0
    i = 0
    if args.snapshot:
        # Resume; when optimizer state is restored, fast-forward the global
        # iteration counter to match the restored epoch.
        epoch, mean_iou = optimizer.load_weights(net, optim, scheduler,
                                                 args.snapshot, args.restore_optimizer)
        if args.restore_optimizer:
            iter_per_epoch = len(train_loader)
            i = iter_per_epoch * epoch
        else:
            epoch = 0

    print("#### iteration", i)
    torch.cuda.empty_cache()

    while i < args.max_iter and epoch < args.max_epoch:
        # train() returns the updated global iteration counter.
        i = train(train_loader, net, criterion, optim, epoch, scheduler, args.max_iter)
        #iou_train = iou_on_trainset(train_loader, net, criterion)
        #logging.info('Mean IoU on training set: %f' % iou_train)
        val_loss, per_cls_iou = validate(val_loader, net, criterion, optim,
                                         scheduler, epoch + 1, i)
        epoch += 1
def main():
    """Evaluate a trained COPLENet snapshot on the test split.

    Builds the test loader for ``args.dataset``/``args.task``, restores
    weights from ``args.snapshot``, accumulates a confusion matrix over the
    whole test set, optionally dumps color-blended prediction/ground-truth
    images, and prints a per-class IoU / precision / recall table.

    Raises:
        NotImplementedError: if ``args.dataset`` is not supported.
    """
    if args.dataset == 'robotic_instrument':
        from datasets.robotic_instrument import get_testloader, RoboticInstrument
        if args.task == 'binary':
            num_classes = 2
        elif args.task == 'parts':
            num_classes = 5
        elif args.task == 'type':
            num_classes = 8
        dataset = RoboticInstrument(args.task, 'test')
        test_loader = get_testloader(args.task, batch_size=args.batch_size)
        net_param = {"class_num": num_classes,
                     "in_chns": 3,
                     "bilinear": True,
                     "feature_chns": [16, 32, 64, 128, 256],
                     "dropout": [0.0, 0.0, 0.3, 0.4, 0.5]}
    elif args.dataset == 'covid19_lesion':
        from datasets.covid19_lesion import get_testloader, Covid19Dataset
        dataset = Covid19Dataset(args.task, 'test')
        test_loader = get_testloader(args.task, batch_size=args.batch_size)
        num_classes = 2
        net_param = {"class_num": num_classes,
                     "in_chns": 1,
                     "bilinear": True,
                     "feature_chns": [16, 32, 64, 128, 256],
                     "dropout": [0.0, 0.0, 0.3, 0.4, 0.5]}
    else:
        raise NotImplementedError('The dataset is not supported.')

    net = COPLENet(net_param).cuda()
    # Evaluation only: no optimizer/scheduler state to restore.
    optimizer.load_weights(net, None, None, args.snapshot, False)
    torch.cuda.empty_cache()
    net.eval()

    hist = 0  # running confusion matrix, becomes (num_classes, num_classes)
    predictions = []
    groundtruths = []
    for test_idx, data in enumerate(test_loader):
        inputs, gts = data
        # inputs: NCHW batch; gts: NHW label maps with matching spatial size.
        assert len(inputs.size()) == 4 and len(gts.size()) == 3
        assert inputs.size()[2:] == gts.size()[1:]
        inputs, gts = inputs.cuda(), gts.cuda()
        with torch.no_grad():
            output = net(inputs)
        del inputs
        assert output.size()[2:] == gts.size()[1:]
        assert output.size()[1] == num_classes
        prediction = output.data.max(1)[1].cpu()
        predictions.append(output.data.cpu().numpy())
        groundtruths.append(gts.cpu().numpy())
        hist += fast_hist(prediction.numpy().flatten(),
                          gts.cpu().numpy().flatten(), num_classes)
        del gts, output, test_idx, data

    predictions = np.concatenate(predictions, axis=0)
    groundtruths = np.concatenate(groundtruths, axis=0)

    if args.dump_imgs:
        # Dump the first ~20 test images blended with prediction / GT colors.
        assert len(dataset) == predictions.shape[0]
        dump_dir = './dump_' + args.dataset + '_' + args.task + '_' + args.method
        os.makedirs(dump_dir, exist_ok=True)
        for i in range(len(dataset)):
            img = skimage.io.imread(dataset.img_paths[i])
            if len(img.shape) == 2:
                # Grayscale input: replicate channels to fake RGB.
                img = np.stack((img, img, img), axis=2)
            img = skimage.transform.resize(img, (224, 336))
            cm = np.argmax(predictions[i, :, :, :], axis=0)
            color_cm = add_color(cm)
            color_cm = skimage.transform.resize(color_cm, (224, 336))
            gt = np.asarray(groundtruths[i, :, :], np.uint8)
            color_gt = add_color(gt)
            color_gt = skimage.transform.resize(color_gt, (224, 336))
            blend_pred = 0.5 * img + 0.5 * color_cm
            blend_gt = 0.5 * img + 0.5 * color_gt
            blend_pred = np.asarray(blend_pred * 255, np.uint8)
            blend_gt = np.asarray(blend_gt * 255, np.uint8)
            #skimage.io.imsave(os.path.join(dump_dir, 'img_{:03d}.png'.format(i)), img)
            skimage.io.imsave(os.path.join(dump_dir, 'pred_{:03d}.png'.format(i)), blend_pred)
            skimage.io.imsave(os.path.join(dump_dir, 'gt_{:03d}.png'.format(i)), blend_gt)
            if i > 20:
                break

    # Per-class metrics from the confusion matrix (rows: prediction,
    # cols: ground truth, per fast_hist's argument order above).
    acc = np.diag(hist).sum() / hist.sum()  # overall pixel accuracy
    acc_cls = np.diag(hist) / hist.sum(axis=1)
    iou = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
    id2cat = {i: i for i in range(len(iou))}
    iou_false_positive = hist.sum(axis=1) - np.diag(hist)
    iou_false_negative = hist.sum(axis=0) - np.diag(hist)
    iou_true_positive = np.diag(hist)

    print('IoU:')
    print('label_id label IoU Precision Recall TP FP FN Pixel Acc.')
    total_pixels = hist.sum()  # loop-invariant: hoisted out of the loop
    for idx, i in enumerate(iou):
        idx_string = "{:2d}".format(idx)
        class_name = "{:>13}".format(id2cat[idx]) if idx in id2cat else ''
        iou_string = '{:5.1f}'.format(i * 100)
        tp = '{:5.1f}'.format(100 * iou_true_positive[idx] / total_pixels)
        fp = '{:5.1f}'.format(100 * iou_false_positive[idx] / total_pixels)
        fn = '{:5.1f}'.format(100 * iou_false_negative[idx] / total_pixels)
        # BUG FIX: precision/recall were printed as raw 0-1 fractions while
        # every other column is a percentage; scale by 100 for consistency.
        precision = '{:5.1f}'.format(
            100 * iou_true_positive[idx] /
            (iou_true_positive[idx] + iou_false_positive[idx]))
        recall = '{:5.1f}'.format(
            100 * iou_true_positive[idx] /
            (iou_true_positive[idx] + iou_false_negative[idx]))
        pixel_acc = '{:5.1f}'.format(100 * acc_cls[idx])
        print('{} {} {} {} {} {} {} {} {}'.format(
            idx_string, class_name, iou_string, precision, recall,
            tp, fp, fn, pixel_acc))