import collections
import glob
import os
import random
import timeit

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils import data
from torchvision import transforms
from tqdm import tqdm

# Project-local helpers (get_loader, get_data_path, get_model, convert_state_dict,
# runningScore, AverageMeter, accuracy, mapk, flip, Adptive_loss, RAdam,
# GradualWarmupScheduler, group_weight, multi_scale_loss) are assumed to be imported
# from this repository's own modules.


def train(args):
    if not os.path.exists('checkpoints'):
        os.mkdir('checkpoints')

    # Setup Augmentations
    data_aug = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomAffine(degrees=10, translate=(0.05, 0.05), scale=(0.95, 1.05)),
    ])

    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    t_loader = data_loader(data_path, is_transform=True, split='train', version='simplified',
                           img_size=(args.img_rows, args.img_cols), augmentations=data_aug,
                           train_fold_num=args.train_fold_num, num_train_folds=args.num_train_folds,
                           seed=args.seed)
    v_loader = data_loader(data_path, is_transform=True, split='val', version='simplified',
                           img_size=(args.img_rows, args.img_cols), num_val=args.num_val,
                           seed=args.seed)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader, batch_size=args.batch_size, num_workers=2,
                                  shuffle=True, pin_memory=True, drop_last=True)
    valloader = data.DataLoader(v_loader, batch_size=args.batch_size, num_workers=2,
                                pin_memory=True)

    # Setup Metrics
    running_metrics = runningScore(n_classes)

    # Setup Model
    v_dimension = 300
    model = get_model(args.arch, v_dimension, use_cbam=args.use_cbam)
    model.cuda()

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                                 lr=args.l_rate, weight_decay=args.weight_decay)
    if args.num_cycles > 0:
        len_trainloader = int(5e6)  # 4960414
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.num_train_folds * len_trainloader // args.num_cycles,
            eta_min=args.l_rate * 1e-1)
    else:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[2, 4, 6, 8], gamma=0.5)

    start_epoch = 0
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model_dict = model.state_dict()
            if checkpoint.get('model_state', -1) == -1:
                # Raw state dict
                model_dict.update(convert_state_dict(checkpoint, load_classifier=args.load_classifier))
            else:
                model_dict.update(convert_state_dict(checkpoint['model_state'],
                                                     load_classifier=args.load_classifier))
                print("Loaded checkpoint '{}' (epoch {}, mapk {:.5f}, top1_acc {:7.3f}, top2_acc {:7.3f} top3_acc {:7.3f})"
                      .format(args.resume, checkpoint['epoch'], checkpoint['mapk'],
                              checkpoint['top1_acc'], checkpoint['top2_acc'], checkpoint['top3_acc']))
            model.load_state_dict(model_dict)
            if checkpoint.get('optimizer_state', None) is not None:
                optimizer.load_state_dict(checkpoint['optimizer_state'])
                start_epoch = checkpoint['epoch']
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    # Build the loss once instead of re-instantiating it every iteration.
    a_loss = Adptive_loss().cuda()

    loss_sum = 0.0
    for epoch in range(start_epoch, args.n_epoch):
        start_train_time = timeit.default_timer()

        if args.num_cycles == 0:
            scheduler.step(epoch)

        model.train()
        optimizer.zero_grad()
        for i, (images, labels, recognized, _) in enumerate(trainloader):
            if args.num_cycles > 0:
                iter_num = i + epoch * len_trainloader
                # Cosine Annealing with Restarts
                scheduler.step(iter_num % (args.num_train_folds * len_trainloader // args.num_cycles))

            images = images.cuda()
            labels = labels.cuda()

            outputs = model(images)
            loss = a_loss(outputs, labels)
            loss = loss / float(args.iter_size)  # Accumulated gradients
            loss_sum = loss_sum + loss.item()    # .item() avoids retaining the graph in the running sum
            loss.backward()

            if (i + 1) % args.print_train_freq == 0:
                print("Epoch [%d/%d] Iter [%6d/%6d] Loss: %.4f" %
                      (epoch + 1, args.n_epoch, i + 1, len(trainloader), loss_sum))

            if (i + 1) % args.iter_size == 0 or i == len(trainloader) - 1:
                optimizer.step()
                optimizer.zero_grad()
                loss_sum = 0.0

        elapsed_train_time = timeit.default_timer() - start_train_time
        print('Training time (epoch {0:5d}): {1:10.5f} seconds'.format(epoch + 1, elapsed_train_time))
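
# Illustrative only: the restart trick used in train() above steps CosineAnnealingLR
# with `iter_num % T`, so the learning rate jumps back to its maximum every T iterations.
# The toy values below are assumptions chosen just to show the schedule shape; this helper
# is not part of the original training code.
def _cosine_restart_demo():
    p = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.Adam([p], lr=0.1)
    T = 100
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=T, eta_min=0.01)
    lrs = []
    for it in range(3 * T):      # three full cycles
        sched.step(it % T)       # the modulo produces the warm restart
        lrs.append(opt.param_groups[0]['lr'])
    # lrs now contains three identical cosine decays from 0.1 down to ~0.01
    return lrs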
def train(args):
    if not os.path.exists('checkpoints'):
        os.mkdir('checkpoints')

    # Setup Augmentations
    data_aug = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomAffine(degrees=10, translate=(0.05, 0.05), scale=(0.95, 1.05)),
    ])

    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    t_loader = data_loader(data_path, is_transform=True, split='train', version='simplified',
                           img_size=(args.img_rows, args.img_cols), augmentations=data_aug,
                           train_fold_num=args.train_fold_num, num_train_folds=args.num_train_folds,
                           seed=args.seed)
    v_loader = data_loader(data_path, is_transform=True, split='val', version='simplified',
                           img_size=(args.img_rows, args.img_cols), num_val=args.num_val,
                           seed=args.seed)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader, batch_size=args.batch_size, num_workers=2,
                                  shuffle=True, pin_memory=True, drop_last=True)
    valloader = data.DataLoader(v_loader, batch_size=args.batch_size, num_workers=2,
                                pin_memory=True)

    # Setup Metrics
    running_metrics = runningScore(n_classes)

    # Setup Model
    # model = get_model(args.arch, n_classes, use_cbam=args.use_cbam)
    model = torchvision.models.mobilenet_v2(pretrained=True)
    num_ftrs = model.last_channel
    model.classifier = nn.Sequential(
        nn.Dropout(0.2),
        nn.Linear(num_ftrs, n_classes),
    )
    model.cuda()

    # Check if model has custom optimizer / loss
    if hasattr(model, 'optimizer'):
        optimizer = model.optimizer
    else:
        # optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
        #                             lr=args.l_rate, momentum=args.momentum, weight_decay=args.weight_decay)
        optimizer = torch.optim.Adam(model.parameters(), lr=args.l_rate, weight_decay=args.weight_decay)

    # if args.num_cycles > 0:
    #     len_trainloader = int(5e6)  # 4960414
    #     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    #         optimizer, T_max=args.num_train_folds * len_trainloader // args.num_cycles,
    #         eta_min=args.l_rate * 1e-1)
    # else:
    #     scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2, 4, 6, 8], gamma=0.5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5,
                                                           patience=5, cooldown=5, min_lr=1e-7)

    if hasattr(model, 'loss'):
        print('Using custom loss')
        loss_fn = model.loss
    else:
        loss_fn = F.cross_entropy

    start_epoch = 0
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model_dict = model.state_dict()
            if checkpoint.get('model_state', -1) == -1:
                # Raw state dict
                model_dict.update(convert_state_dict(checkpoint, load_classifier=args.load_classifier))
            else:
                model_dict.update(convert_state_dict(checkpoint['model_state'],
                                                     load_classifier=args.load_classifier))
                print("Loaded checkpoint '{}' (epoch {}, mapk {:.5f}, top1_acc {:7.3f}, top2_acc {:7.3f} top3_acc {:7.3f})"
                      .format(args.resume, checkpoint['epoch'], checkpoint['mapk'],
                              checkpoint['top1_acc'], checkpoint['top2_acc'], checkpoint['top3_acc']))
            model.load_state_dict(model_dict)
            if checkpoint.get('optimizer_state', None) is not None:
                optimizer.load_state_dict(checkpoint['optimizer_state'])
                start_epoch = checkpoint['epoch']
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    loss_sum = 0.0
    for epoch in range(start_epoch, args.n_epoch):
        start_train_time = timeit.default_timer()

        model.train()
        optimizer.zero_grad()
        for i, (images, labels, recognized, _) in enumerate(trainloader):
            images = images.cuda()
            labels = labels.cuda()
            recognized = recognized.cuda()

            outputs = model(images)
            # Per-sample cross entropy weighted by the 'recognized' flag of each drawing.
            loss = (loss_fn(outputs, labels.view(-1), ignore_index=t_loader.ignore_index,
                            reduction='none') * recognized.view(-1)).mean()
            # loss = loss / float(args.iter_size)  # Accumulated gradients
            loss_sum = loss_sum + loss.item()
            loss.backward()

            if (i + 1) % args.print_train_freq == 0:
                print("Epoch [%d/%d] Iter [%6d/%6d] Loss: %.4f" %
                      (epoch + 1, args.n_epoch, i + 1, len(trainloader), loss_sum))

            if (i + 1) % args.iter_size == 0 or i == len(trainloader) - 1:
                optimizer.step()
                optimizer.zero_grad()
                loss_sum = 0.0

        # Validation
        mapk_val = AverageMeter()
        top1_acc_val = AverageMeter()
        top2_acc_val = AverageMeter()
        top3_acc_val = AverageMeter()
        mean_loss_val = AverageMeter()

        model.eval()
        with torch.no_grad():
            for i_val, (images_val, labels_val, recognized_val, _) in tqdm(enumerate(valloader)):
                images_val = images_val.cuda()
                labels_val = labels_val.cuda()
                recognized_val = recognized_val.cuda()

                outputs_val = model(images_val)
                loss_val = (loss_fn(outputs_val, labels_val.view(-1), ignore_index=v_loader.ignore_index,
                                    reduction='none') * recognized_val.view(-1)).mean()
                mean_loss_val.update(loss_val.item(), n=images_val.size(0))

                _, pred = outputs_val.topk(k=3, dim=1, largest=True, sorted=True)
                running_metrics.update(labels_val, pred[:, 0])

                acc1, acc2, acc3 = accuracy(outputs_val, labels_val, topk=(1, 2, 3))
                top1_acc_val.update(acc1, n=images_val.size(0))
                top2_acc_val.update(acc2, n=images_val.size(0))
                top3_acc_val.update(acc3, n=images_val.size(0))

                mapk_v = mapk(labels_val, pred, k=3)
                mapk_val.update(mapk_v, n=images_val.size(0))

        print('Mean Average Precision (MAP) @ 3: {:.5f}'.format(mapk_val.avg))
        print('Top-1/2/3 accuracy: {:7.3f} / {:7.3f} / {:7.3f}'.format(
            top1_acc_val.avg, top2_acc_val.avg, top3_acc_val.avg))
        print('Mean val loss: {:.4f}'.format(mean_loss_val.avg))

        score, class_iou = running_metrics.get_scores()
        for k, v in score.items():
            print(k, v)
        # for i in range(n_classes):
        #     print(i, class_iou[i])

        scheduler.step(mean_loss_val.avg)

        state = {'epoch': epoch + 1,
                 'model_state': model.state_dict(),
                 'optimizer_state': optimizer.state_dict(),
                 'mapk': mapk_val.avg,
                 'top1_acc': top1_acc_val.avg,
                 'top2_acc': top2_acc_val.avg,
                 'top3_acc': top3_acc_val.avg}
        torch.save(state, "checkpoints/{}_{}_{}_{}x{}_{}-{}-{}_model.pth".format(
            args.arch, args.dataset, epoch + 1, args.img_rows, args.img_cols,
            args.train_fold_num, args.num_train_folds, args.num_val))

        running_metrics.reset()
        mapk_val.reset()
        top1_acc_val.reset()
        top2_acc_val.reset()
        top3_acc_val.reset()
        mean_loss_val.reset()

        elapsed_train_time = timeit.default_timer() - start_train_time
        print('Training time (epoch {0:5d}): {1:10.5f} seconds'.format(epoch + 1, elapsed_train_time))
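
# `AverageMeter` and `mapk` are repository helpers used above. The sketches below show the
# standard behaviour they are assumed to have (a running weighted average, and single-label
# MAP@k over ranked top-k predictions); they are not the project's implementations.
class AverageMeterSketch(object):
    """Tracks a running (weighted) average of a scalar metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += float(val) * n
        self.count += n
        self.avg = self.sum / max(self.count, 1)


def mapk_sketch(labels, pred, k=3):
    """Mean Average Precision @ k for single-label data.

    `pred` is (N, k) ranked class indices, `labels` is (N,) or (N, 1); with one relevant
    class per sample, AP@k is 1/rank of the correct class if it appears in the top k, else 0.
    """
    labels = labels.view(-1, 1)                              # (N, 1)
    hits = (pred[:, :k] == labels).float()                   # (N, k), 1 at the matching rank
    weights = 1.0 / torch.arange(1, k + 1, dtype=torch.float32, device=pred.device)
    return (hits * weights).sum(dim=1).mean().item()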
def merge(args):
    if not os.path.exists(args.root_results):
        os.makedirs(args.root_results)

    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    loader = data_loader(data_path, split=args.split, transforms=None, fold_num=0, num_folds=1,
                         no_gt=args.no_gt, seed=args.seed, no_load_images=True)
    n_classes = loader.n_classes
    testloader = data.DataLoader(loader, batch_size=args.batch_size)  # , num_workers=2, pin_memory=True)

    # Average the per-fold probability maps saved by test()
    avg_y_prob = np.zeros((len(loader), 1, 1024, 1024), dtype=np.float32)
    avg_y_pred_sum = np.zeros((len(loader),), dtype=np.int32)
    fold_list = []
    for prob_file_name in glob.glob(os.path.join(args.root_results, '*.npy')):
        prob = np.load(prob_file_name, mmap_mode='r')
        for i in range(len(loader)):
            avg_y_prob[i, :, :, :] += prob[i, :, :, :]
        fold_list.append(prob_file_name)
        print(prob_file_name)
    avg_y_prob = avg_y_prob / len(fold_list)
    # avgprob_file_name = 'prob_{}_avg'.format(len(fold_list))
    # np.save(os.path.join(args.root_results, '{}.npy'.format(avgprob_file_name)), avg_y_prob)

    # Pick a mask-size threshold so that only the largest `non_empty_ratio` of predictions are kept
    avg_y_pred = (avg_y_prob > args.thresh).astype(int)
    avg_y_pred_sum = avg_y_pred.sum(3).sum(2).sum(1)
    avg_y_pred_sum_argsorted = np.argsort(avg_y_pred_sum)[::-1]
    pruned_idx = int(avg_y_pred_sum_argsorted.shape[0] * args.non_empty_ratio)
    mask_sum_thresh = (int(avg_y_pred_sum[avg_y_pred_sum_argsorted[pruned_idx]])
                       if pruned_idx < avg_y_pred_sum_argsorted.shape[0] else 0)

    running_metrics = runningScore(n_classes=2, weight_acc_non_empty=args.weight_acc_non_empty)

    pred_dict = collections.OrderedDict()
    num_non_empty_masks = 0
    for i, (_, labels, names) in tqdm(enumerate(testloader)):
        labels = labels.cuda()
        prob = avg_y_prob[i * args.batch_size:i * args.batch_size + labels.size(0), :, :, :]
        pred = (prob > args.thresh).astype(int)
        pred = torch.from_numpy(pred).long().cuda()
        pred_sum = pred.sum(3).sum(2).sum(1)
        for k in range(labels.size(0)):
            if pred_sum[k] > mask_sum_thresh:
                num_non_empty_masks += 1
            else:
                pred[k, :, :, :] = torch.zeros_like(pred[k, :, :, :])
                if args.only_non_empty:
                    pred[k, :, 0, 0] = 1

        if not args.no_gt:
            running_metrics.update(labels.long(), pred.long())

        for k in range(labels.size(0)):
            name = names[0][k]
            if pred_dict.get(name, None) is None:
                mask = pred[k, 0, :, :].cpu().numpy()
                rle = loader.mask2rle(mask)
                pred_dict[name] = rle

    print('# non-empty masks: {:5d} (non_empty_ratio: {:.5f} / mask_sum_thresh: {:6d})'.format(
        num_non_empty_masks, args.non_empty_ratio, mask_sum_thresh))

    if not args.no_gt:
        dice, dice_empty, dice_non_empty, miou, acc, acc_empty, acc_non_empty = running_metrics.get_scores()
        print('Dice (per image): {:.5f} (empty: {:.5f} / non-empty: {:.5f})'.format(
            dice, dice_empty, dice_non_empty))
        print('Classification accuracy: {:.5f} (empty: {:.5f} / non-empty: {:.5f})'.format(
            acc, acc_empty, acc_non_empty))
        print('Overall mIoU: {:.5f}'.format(miou))
        running_metrics.reset()

    # Create submission
    csv_file_name = 'merged_{}_{}_{}_{}'.format(args.split, len(fold_list), args.thresh, args.non_empty_ratio)
    sub = pd.DataFrame.from_dict(pred_dict, orient='index')
    sub.index.names = ['ImageId']
    sub.columns = ['EncodedPixels']
    sub.to_csv(os.path.join(args.root_results, '{}.csv'.format(csv_file_name)))
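
# `loader.mask2rle` is defined by the dataset loader and its exact run-length format is
# competition-specific. The sketch below shows one common column-major RLE scheme
# (1-indexed start/length pairs); it is an assumption, not the loader's implementation.
def mask2rle_sketch(mask):
    """Encode a binary HxW numpy mask as 'start length start length ...' in column-major order."""
    pixels = mask.T.flatten()                      # column-major (Fortran) order
    pixels = np.concatenate([[0], pixels, [0]])    # pad so runs at the borders are detected
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]                        # convert run ends into run lengths
    return ' '.join(str(x) for x in runs)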
def test(args):
    model_file_name = os.path.split(args.model_path)[1]
    model_name = model_file_name[:model_file_name.find('_')]

    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    loader = data_loader(data_path, split=args.split, is_transform=True,
                         img_size=(args.img_rows, args.img_cols), no_gt=args.no_gt, seed=args.seed)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    n_classes = loader.n_classes
    testloader = data.DataLoader(loader, batch_size=args.batch_size, num_workers=4, pin_memory=True)

    # Setup Model
    model = torchvision.models.mobilenet_v2(pretrained=True)
    num_ftrs = model.last_channel
    model.classifier = nn.Sequential(
        nn.Dropout(0.2),
        nn.Linear(num_ftrs, n_classes),
    )
    model.cuda()

    checkpoint = torch.load(args.model_path)
    state = convert_state_dict(checkpoint['model_state'])
    model_dict = model.state_dict()
    model_dict.update(state)
    model.load_state_dict(model_dict)
    print("Loaded checkpoint '{}' (epoch {}, mapk {:.5f}, top1_acc {:7.3f}, top2_acc {:7.3f} top3_acc {:7.3f})"
          .format(args.model_path, checkpoint['epoch'], checkpoint['mapk'],
                  checkpoint['top1_acc'], checkpoint['top2_acc'], checkpoint['top3_acc']))

    running_metrics = runningScore(n_classes)
    pred_dict = collections.OrderedDict()
    # Keep the meter and the mapk() metric function under different names,
    # otherwise the function is shadowed and can no longer be called.
    mapk_meter = AverageMeter()

    model.eval()
    with torch.no_grad():
        for i, (images, labels, _, names) in tqdm(enumerate(testloader)):
            # Debug visualization of the first image in the batch (min-max normalized)
            img = images[0].numpy().transpose(1, 2, 0)
            plt.imshow((img - np.min(img)) / np.max(img - np.min(img)))
            plt.show()

            images = images.cuda()
            if args.tta:
                images_flip = flip(images, dim=3)

            outputs = model(images)
            if args.tta:
                outputs_flip = model(images_flip)

            prob = F.softmax(outputs, dim=1)
            if args.tta:
                prob_flip = F.softmax(outputs_flip, dim=1)
                prob = (prob + prob_flip) / 2.0

            _, pred = prob.topk(k=3, dim=1, largest=True, sorted=True)

            for k in range(images.size(0)):
                pred_dict[int(names[0][k])] = loader.encode_pred_name(pred[k, :])

            if not args.no_gt:
                running_metrics.update(labels, pred)
                mapk_v = mapk(labels, pred, k=3)
                mapk_meter.update(mapk_v, n=images.size(0))

    if not args.no_gt:
        print('Mean Average Precision (MAP) @ 3: {:.5f}'.format(mapk_meter.avg))
        score, class_iou = running_metrics.get_scores()
        for k, v in score.items():
            print(k, v)
        # for i in range(n_classes):
        #     print(i, class_iou[i])
        running_metrics.reset()
        mapk_meter.reset()

    # Create submission
    sub = pd.DataFrame.from_dict(pred_dict, orient='index')
    sub.index.names = ['key_id']
    sub.columns = ['word']
    sub.to_csv('{}_{}x{}.csv'.format(args.split, args.img_rows, args.img_cols))
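
# The `flip` helper used for horizontal-flip TTA above is defined elsewhere in the repository.
# A minimal sketch of the assumed behaviour (reversing a tensor along one dimension), not the
# project's implementation:
def flip_sketch(x, dim):
    """Reverse tensor `x` along dimension `dim` (e.g. dim=3 flips NCHW images horizontally)."""
    return torch.flip(x, dims=[dim])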
def train(args):
    if not os.path.exists('checkpoints'):
        os.mkdir('checkpoints')

    # Setup Augmentations & Transforms
    # Pretrained group-norm backbones use their original pixel statistics; ImageNet statistics otherwise.
    if args.norm_type == 'gn' and args.load_pretrained:
        rgb_mean = [122.7717 / 255., 115.9465 / 255., 102.9801 / 255.]
        rgb_std = [1. / 255., 1. / 255., 1. / 255.]
    else:
        rgb_mean = [0.485, 0.456, 0.406]
        rgb_std = [0.229, 0.224, 0.225]
    data_trans = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(size=(args.img_rows, args.img_cols)),
        transforms.ToTensor(),
        transforms.Normalize(mean=rgb_mean, std=rgb_std),
    ])

    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    t_loader = data_loader(data_path, transforms=data_trans, in_channels=args.in_channels,
                           split='train', augmentations=True, fold_num=args.fold_num,
                           num_folds=args.num_folds, only_non_empty=args.only_non_empty,
                           seed=args.seed, mask_dilation_size=args.mask_dilation_size)
    v_loader = data_loader(data_path, transforms=data_trans, in_channels=args.in_channels,
                           split='val', fold_num=args.fold_num, num_folds=args.num_folds,
                           only_non_empty=args.only_non_empty, seed=args.seed)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    trainloader = data.DataLoader(t_loader, batch_size=args.batch_size, num_workers=2,
                                  pin_memory=True, shuffle=args.only_non_empty,
                                  drop_last=args.only_non_empty)
    valloader = data.DataLoader(v_loader, batch_size=args.batch_size, num_workers=2, pin_memory=True)

    # Setup Model
    model = get_model(args.arch, n_classes=1, in_channels=args.in_channels, norm_type=args.norm_type,
                      load_pretrained=args.load_pretrained, use_cbam=args.use_cbam)
    model.to(torch.device(args.device))

    running_metrics = runningScore(n_classes=2, weight_acc_non_empty=args.weight_acc_non_empty,
                                   device=args.device)

    # Check if model has custom optimizer / loss
    if hasattr(model, 'optimizer'):
        optimizer = model.optimizer
    else:
        warmup_iter = int(args.n_iter * 5. / 100.)
        milestones = [int(args.n_iter * 30. / 100.) - warmup_iter,
                      int(args.n_iter * 60. / 100.) - warmup_iter,
                      int(args.n_iter * 90. / 100.) - warmup_iter]  # [30, 60, 90]
        gamma = 0.5  # 0.1

        if args.optimizer_type == 'sgd':
            optimizer = torch.optim.SGD(group_weight(model), lr=args.l_rate,
                                        momentum=args.momentum, weight_decay=args.weight_decay)
        elif args.optimizer_type == 'adam':
            optimizer = torch.optim.Adam(group_weight(model), lr=args.l_rate,
                                         weight_decay=args.weight_decay)
        else:  # args.optimizer_type == 'radam'
            optimizer = RAdam(group_weight(model), lr=args.l_rate, weight_decay=args.weight_decay)

        if args.num_cycles > 0:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=(args.n_iter - warmup_iter) // args.num_cycles,
                eta_min=args.l_rate * 0.1)
        else:
            scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)
        scheduler_warmup = GradualWarmupScheduler(optimizer, total_epoch=warmup_iter,
                                                  min_lr_mul=0.1, after_scheduler=scheduler)

    start_iter = 0
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=torch.device(args.device))  # , encoding="latin1")
            model_dict = model.state_dict()
            if checkpoint.get('model_state', None) is not None:
                model_dict.update(convert_state_dict(checkpoint['model_state']))
            else:
                model_dict.update(convert_state_dict(checkpoint))
            start_iter = checkpoint.get('iter', -1)
            dice_val = checkpoint.get('dice', -1)
            wacc_val = checkpoint.get('wacc', -1)
            print("Loaded checkpoint '{}' (iter {}, dice {:.5f}, wAcc {:.5f})".format(
                args.resume, start_iter, dice_val, wacc_val))
            model.load_state_dict(model_dict)
            if checkpoint.get('optimizer_state', None) is not None:
                optimizer.load_state_dict(checkpoint['optimizer_state'])
            del model_dict
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("No checkpoint found at '{}'".format(args.resume))
    start_iter = args.start_iter if args.start_iter >= 0 else start_iter

    scale_weight = torch.tensor([1.0, 0.4, 0.4, 0.4]).to(torch.device(args.device))
    dice_weight = [args.dice_weight0, args.dice_weight1]
    lv_margin = [args.lv_margin0, args.lv_margin1]

    total_loss_sum = 0.0
    ms_loss_sum = 0.0
    cls_loss_sum = 0.0

    t_loader.__gen_batchs__(args.batch_size, ratio=args.ratio)
    trainloader_iter = iter(trainloader)

    optimizer.zero_grad()
    start_train_time = timeit.default_timer()
    elapsed_train_time = 0.0
    best_dice = -100.0
    best_wacc = -100.0
    for i in range(start_iter, args.n_iter):
        model.train()

        if i % args.iter_size == 0:
            if args.num_cycles == 0:
                scheduler_warmup.step(i)
            else:
                scheduler_warmup.step(i // args.num_cycles)

        try:
            images, labels, _ = next(trainloader_iter)
        except StopIteration:
            # Re-generate the balanced batches and restart the iterator once the loader is exhausted.
            t_loader.__gen_batchs__(args.batch_size, ratio=args.ratio)
            trainloader_iter = iter(trainloader)
            images, labels, _ = next(trainloader_iter)

        images = images.to(torch.device(args.device))
        labels = labels.to(torch.device(args.device))

        outputs, outputs_gap = model(images)
        # Image-level label for the auxiliary classification head: 1 if the mask is non-empty.
        labels_gap = torch.where(
            labels.sum(3, keepdim=True).sum(2, keepdim=True) > 0,
            torch.ones(labels.size(0), 1, 1, 1).to(torch.device(args.device)),
            torch.zeros(labels.size(0), 1, 1, 1).to(torch.device(args.device)))

        cls_loss = (F.binary_cross_entropy_with_logits(outputs_gap, labels_gap)
                    if args.lambda_cls > 0 else torch.tensor(0.0).to(labels.device))
        ms_loss = multi_scale_loss(outputs, labels, scale_weight=scale_weight, reduction='mean',
                                   alpha=args.alpha, gamma=args.gamma, dice_weight=dice_weight,
                                   lv_margin=lv_margin, lambda_fl=args.lambda_fl,
                                   lambda_dc=args.lambda_dc, lambda_lv=args.lambda_lv)
        total_loss = ms_loss + args.lambda_cls * cls_loss
        total_loss = total_loss / float(args.iter_size)  # Accumulated gradients
        total_loss.backward()

        total_loss_sum = total_loss_sum + total_loss.item()
        ms_loss_sum = ms_loss_sum + ms_loss.item()
        cls_loss_sum = cls_loss_sum + cls_loss.item()

        if (i + 1) % args.print_train_freq == 0:
            print("Iter [%7d/%7d] Loss: %7.4f (MS: %7.4f / CLS: %7.4f)" %
                  (i + 1, args.n_iter, total_loss_sum, ms_loss_sum, cls_loss_sum))

        if (i + 1) % args.iter_size == 0:
            optimizer.step()
            optimizer.zero_grad()
            total_loss_sum = 0.0
            ms_loss_sum = 0.0
            cls_loss_sum = 0.0

        if args.eval_freq > 0 and (i + 1) % args.eval_freq == 0:
            state = {'iter': i + 1,
                     'model_state': model.state_dict()}
            # 'optimizer_state': optimizer.state_dict(),
            if (i + 1) % int(args.eval_freq / args.save_freq) == 0:
                torch.save(state, "checkpoints/{}_{}_{}_{}x{}_{}-{}_model.pth".format(
                    args.arch, args.dataset, i + 1, args.img_rows, args.img_cols,
                    args.fold_num, args.num_folds))

            dice_val = 0.0
            thresh = 0.5
            mask_sum_thresh = 0
            mean_loss_val = AverageMeter()

            model.eval()
            with torch.no_grad():
                for i_val, (images_val, labels_val, _) in enumerate(valloader):
                    images_val = images_val.to(torch.device(args.device))
                    labels_val = labels_val.to(torch.device(args.device))

                    outputs_val, outputs_gap_val = model(images_val)
                    pred_val = (torch.sigmoid(outputs_val if not isinstance(outputs_val, tuple)
                                              else outputs_val[0]) > thresh).long()  # outputs_val.max(1)[1]
                    pred_val_sum = pred_val.sum(3).sum(2).sum(1)
                    for k in range(labels_val.size(0)):
                        if pred_val_sum[k] < mask_sum_thresh:
                            pred_val[k, :, :, :] = torch.zeros_like(pred_val[k, :, :, :])

                    labels_gap_val = torch.where(
                        labels_val.sum(3, keepdim=True).sum(2, keepdim=True) > 0,
                        torch.ones(labels_val.size(0), 1, 1, 1).to(torch.device(args.device)),
                        torch.zeros(labels_val.size(0), 1, 1, 1).to(torch.device(args.device)))

                    cls_loss_val = (F.binary_cross_entropy_with_logits(outputs_gap_val, labels_gap_val)
                                    if args.lambda_cls > 0 else torch.tensor(0.0).to(labels_val.device))
                    ms_loss_val = multi_scale_loss(outputs_val, labels_val, scale_weight=scale_weight,
                                                   reduction='mean', alpha=args.alpha, gamma=args.gamma,
                                                   dice_weight=dice_weight, lv_margin=lv_margin,
                                                   lambda_fl=args.lambda_fl, lambda_dc=args.lambda_dc,
                                                   lambda_lv=args.lambda_lv)
                    loss_val = ms_loss_val + args.lambda_cls * cls_loss_val
                    mean_loss_val.update(loss_val.item(), n=labels_val.size(0))

                    running_metrics.update(labels_val.long(), pred_val.long())

            dice_val, dice_empty_val, dice_non_empty_val, miou_val, wacc_val, acc_empty_val, acc_non_empty_val = \
                running_metrics.get_scores()
            print('Dice (per image): {:.5f} (empty: {:.5f} / non-empty: {:.5f})'.format(
                dice_val, dice_empty_val, dice_non_empty_val))
            print('wAcc: {:.5f} (empty: {:.5f} / non-empty: {:.5f})'.format(
                wacc_val, acc_empty_val, acc_non_empty_val))
            print('Overall mIoU: {:.5f}'.format(miou_val))
            print('Mean val loss: {:.4f}'.format(mean_loss_val.avg))

            state['dice'] = dice_val
            state['wacc'] = wacc_val
            state['miou'] = miou_val

            running_metrics.reset()
            mean_loss_val.reset()

            if (i + 1) % int(args.eval_freq / args.save_freq) == 0:
                torch.save(state, "checkpoints/{}_{}_{}_{}x{}_{}-{}_model.pth".format(
                    args.arch, args.dataset, i + 1, args.img_rows, args.img_cols,
                    args.fold_num, args.num_folds))
            if best_dice <= dice_val:
                best_dice = dice_val
                torch.save(state, "checkpoints/{}_{}_{}_{}x{}_{}-{}_model.pth".format(
                    args.arch, args.dataset, 'best-dice', args.img_rows, args.img_cols,
                    args.fold_num, args.num_folds))
            if best_wacc <= wacc_val:
                best_wacc = wacc_val
                torch.save(state, "checkpoints/{}_{}_{}_{}x{}_{}-{}_model.pth".format(
                    args.arch, args.dataset, 'best-wacc', args.img_rows, args.img_cols,
                    args.fold_num, args.num_folds))

            elapsed_train_time = timeit.default_timer() - start_train_time
            print('Training time (iter {0:5d}): {1:10.5f} seconds'.format(i + 1, elapsed_train_time))

        if args.saving_last_time > 0 and (i + 1) % args.iter_size == 0 and \
                (timeit.default_timer() - start_train_time) > args.saving_last_time:
            state = {'iter': i + 1,
                     'model_state': model.state_dict(),
                     'optimizer_state': optimizer.state_dict()}
            torch.save(state, "checkpoints/{}_{}_{}_{}x{}_{}-{}_model.pth".format(
                args.arch, args.dataset, i + 1, args.img_rows, args.img_cols,
                args.fold_num, args.num_folds))
            return

    print('best_dice: {:.5f}; best_wacc: {:.5f}'.format(best_dice, best_wacc))
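
# `group_weight` above is a repository helper used to build the optimizer's parameter groups.
# A common pattern it likely follows (an assumption, not the project's code) is to exclude
# biases and normalization parameters from weight decay:
def group_weight_sketch(model):
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.dim() == 1 or name.endswith('.bias'):
            no_decay.append(param)   # biases and norm-layer affine parameters: no decay
        else:
            decay.append(param)      # conv / linear weights: decayed as usual
    return [{'params': decay},
            {'params': no_decay, 'weight_decay': 0.0}]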
def test(args):
    if not os.path.exists(args.root_results):
        os.makedirs(args.root_results)

    model_file_name = os.path.split(args.model_path)[1]
    model_name = model_file_name[:model_file_name.find('_')]

    # Setup Transforms
    if args.norm_type == 'gn' and args.load_pretrained:
        rgb_mean = [122.7717 / 255., 115.9465 / 255., 102.9801 / 255.]
        rgb_std = [1. / 255., 1. / 255., 1. / 255.]
    else:
        rgb_mean = [0.485, 0.456, 0.406]
        rgb_std = [0.229, 0.224, 0.225]
    data_trans = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(size=(args.img_rows, args.img_cols)),
        transforms.ToTensor(),
        transforms.Normalize(mean=rgb_mean, std=rgb_std),
    ])

    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    loader = data_loader(data_path, split=args.split, in_channels=args.in_channels,
                         transforms=data_trans, fold_num=args.fold_num, num_folds=args.num_folds,
                         no_gt=args.no_gt, seed=args.seed)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    n_classes = loader.n_classes
    testloader = data.DataLoader(loader, batch_size=args.batch_size)  # , num_workers=2, pin_memory=True)

    # Setup Model
    model = get_model(model_name, n_classes=1, in_channels=args.in_channels,
                      norm_type=args.norm_type, use_cbam=args.use_cbam)
    model.cuda()

    checkpoint = torch.load(args.model_path)  # , encoding="latin1")
    state = convert_state_dict(checkpoint['model_state'])
    model_dict = model.state_dict()
    model_dict.update(state)
    model.load_state_dict(model_dict)
    saved_iter = checkpoint.get('iter', -1)
    dice_val = checkpoint.get('dice', -1)
    wacc_val = checkpoint.get('wacc', -1)
    print("Loaded checkpoint '{}' (iter {}, dice {:.5f}, wAcc {:.5f})".format(
        args.model_path, saved_iter, dice_val, wacc_val))

    running_metrics = runningScore(n_classes=2, weight_acc_non_empty=args.weight_acc_non_empty)

    y_prob = np.zeros((len(loader), 1, 1024, 1024), dtype=np.float32)
    y_pred_sum = np.zeros((len(loader),), dtype=np.int32)
    pred_dict = collections.OrderedDict()
    num_non_empty_masks = 0

    model.eval()
    with torch.no_grad():
        # First pass: collect (TTA-averaged) probabilities and predicted mask sizes.
        for i, (images, labels, _) in tqdm(enumerate(testloader)):
            images = images.cuda()
            labels = labels.cuda()
            if args.tta:
                bs, c, h, w = images.size()
                images = torch.cat([images, torch.flip(images, dims=[3])], dim=0)  # hflip

            outputs = model(images, return_aux=False)
            prob = torch.sigmoid(outputs)
            if args.tta:
                prob = prob.view(-1, bs, 1, h, w)
                prob[1, :, :, :, :] = torch.flip(prob[1, :, :, :, :], dims=[3])
                prob = prob.mean(0)
            pred = (prob > args.thresh).long()
            pred_sum = pred.sum(3).sum(2).sum(1)

            y_prob[i * args.batch_size:i * args.batch_size + labels.size(0), :, :, :] = prob.cpu().numpy()
            y_pred_sum[i * args.batch_size:i * args.batch_size + labels.size(0)] = pred_sum.cpu().numpy()

        # Mask-size threshold so that only the largest `non_empty_ratio` of predictions survive.
        y_pred_sum_argsorted = np.argsort(y_pred_sum)[::-1]
        pruned_idx = int(y_pred_sum_argsorted.shape[0] * args.non_empty_ratio)
        mask_sum_thresh = (int(y_pred_sum[y_pred_sum_argsorted[pruned_idx]])
                           if pruned_idx < y_pred_sum_argsorted.shape[0] else 0)

        # Second pass: apply the threshold and score / collect the predictions.
        for i, (_, labels, names) in tqdm(enumerate(testloader)):
            labels = labels.cuda()
            prob = torch.from_numpy(
                y_prob[i * args.batch_size:i * args.batch_size + labels.size(0), :, :, :]).float().cuda()
            pred = (prob > args.thresh).long()
            pred_sum = pred.sum(3).sum(2).sum(1)
            for k in range(labels.size(0)):
                if pred_sum[k] > mask_sum_thresh:
                    num_non_empty_masks += 1
                else:
                    pred[k, :, :, :] = torch.zeros_like(pred[k, :, :, :])
                    if args.only_non_empty:
                        pred[k, :, 0, 0] = 1

            if not args.no_gt:
                running_metrics.update(labels.long(), pred.long())

            # Disabled: RLE submission entries are produced by merge() instead.
            # if args.split == 'test':
            #     for k in range(labels.size(0)):
            #         name = names[0][k]
            #         if pred_dict.get(name, None) is None:
            #             mask = pred[k, 0, :, :].cpu().numpy()
            #             rle = loader.mask2rle(mask)
            #             pred_dict[name] = rle

    print('# non-empty masks: {:5d} (non_empty_ratio: {:.5f} / mask_sum_thresh: {:6d})'.format(
        num_non_empty_masks, args.non_empty_ratio, mask_sum_thresh))

    if not args.no_gt:
        dice, dice_empty, dice_non_empty, miou, wacc, acc_empty, acc_non_empty = running_metrics.get_scores()
        print('Dice (per image): {:.5f} (empty: {:.5f} / non-empty: {:.5f})'.format(
            dice, dice_empty, dice_non_empty))
        print('wAcc: {:.5f} (empty: {:.5f} / non-empty: {:.5f})'.format(wacc, acc_empty, acc_non_empty))
        print('Overall mIoU: {:.5f}'.format(miou))
        running_metrics.reset()

    if args.split == 'test':
        # Save this fold's probability maps so that merge() can average them later.
        fold_num, num_folds = model_file_name.split('_')[4].split('-')
        prob_file_name = 'prob-{}_{}x{}_{}_{}_{}-{}'.format(args.split, args.img_rows, args.img_cols,
                                                            model_name, saved_iter, fold_num, num_folds)
        np.save(os.path.join(args.root_results, '{}.npy'.format(prob_file_name)), y_prob)