def prepareDataset(rootpath, trainlist, vallist, mean, std, *,
                   input_hw=None, batch_size=160, val_batch_size=4,
                   num_workers=32, val_num_workers=4, ignore_label=255):
    """Build the training and validation DataLoaders.

    The training pipeline applies random scaling, Gaussian blur, horizontal
    flip and a random crop; validation uses a center crop. Both pipelines end
    with ToTensor + Normalize.

    Args:
        rootpath: data root passed to ``dataset.SemData``.
        trainlist: training list file passed to ``dataset.SemData``.
        vallist: validation list file passed to ``dataset.SemData``.
        mean: per-channel mean for Normalize; also used as the image padding
            value when cropping.
        std: per-channel std for Normalize.
        input_hw: crop size. Defaults to the module-level ``inputHW``.
        batch_size: training batch size (default keeps the previous value).
        val_batch_size: validation batch size.
        num_workers: training loader worker processes.
        val_num_workers: validation loader worker processes.
        ignore_label: label value used to pad cropped label maps.

    Returns:
        ``(train_loader, val_loader)``.
    """
    if input_hw is None:
        # Resolved at call time so the module-level constant may be defined
        # after this function, exactly as before.
        input_hw = inputHW

    train_trans = transform.Compose([
        transform.RandScale([0.5, 2]),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop(input_hw, crop_type='rand', padding=mean,
                       ignore_label=ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std),
    ])
    val_trans = transform.Compose([
        transform.Crop(input_hw, crop_type='center', padding=mean,
                       ignore_label=ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std),
    ])

    train_data = dataset.SemData(split='train', data_root=rootpath,
                                 data_list=trainlist, transform=train_trans)
    # drop_last keeps every training batch at full size.
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True, drop_last=True)

    val_data = dataset.SemData(split='val', data_root=rootpath,
                               data_list=vallist, transform=val_trans)
    val_loader = torch.utils.data.DataLoader(
        val_data, batch_size=val_batch_size, shuffle=False,
        num_workers=val_num_workers, pin_memory=True)

    return train_loader, val_loader
def prepare_dataset(rootpath, trainlist, vallist, mean, std, *,
                    input_hw=None, batch_size=24, val_batch_size=8,
                    num_workers=8, val_num_workers=4):
    """Build the training and validation DataLoaders.

    Both pipelines resize to a fixed size. Training additionally applies
    Gaussian blur and horizontal flip. Both end with ToTensor + Normalize.

    Args:
        rootpath: data root passed to ``data.SemData``.
        trainlist: training list file passed to ``data.SemData``.
        vallist: validation list file passed to ``data.SemData``.
        mean: per-channel mean for Normalize.
        std: per-channel std for Normalize.
        input_hw: resize target. Defaults to the module-level ``INPUTHW``.
        batch_size: training batch size (default keeps the previous value).
        val_batch_size: validation batch size.
        num_workers: training loader worker processes.
        val_num_workers: validation loader worker processes.

    Returns:
        ``(train_dataloader, val_dataloader)``.
    """
    if input_hw is None:
        # Resolved at call time so the module-level constant may be defined
        # after this function, exactly as before.
        input_hw = INPUTHW

    trans = transform.Compose([
        transform.Resize(input_hw),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std),
    ])
    valtrans = transform.Compose([
        transform.Resize(input_hw),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std),
    ])

    train_dataset = data.SemData(split='train', data_root=rootpath,
                                 data_list=trainlist, transform=trans)
    # drop_last keeps every training batch at full size.
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True, drop_last=True)

    val_dataset = data.SemData(split='val', data_root=rootpath,
                               data_list=vallist, transform=valtrans)
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset, batch_size=val_batch_size, shuffle=False,
        num_workers=val_num_workers, pin_memory=True)

    return train_dataloader, val_dataloader
def train(gpu, ngpus_per_node, argss):
    """Iteration-based training of DCTNet, optionally with a PSPNet teacher.

    Steps:
        1. Set up the (distributed) model and optimizer.
        2. Optionally load weights or resume from a checkpoint.
        3. Build the train loader and, when ``args.evaluate`` is set, the
           val loader.
        4. Iterate from ``args.start_iter`` to ``args.max_iter`` with a poly
           learning-rate schedule.

    At every epoch boundary, and at the final iteration, it logs train
    metrics, optionally validates, and saves a checkpoint from the main
    process.

    Args:
        gpu: GPU index used by this process.
        ngpus_per_node: GPUs per node. In distributed mode the batch sizes and
            worker count are divided by it.
        argss: configuration namespace, stored in the module-level ``args``.

    Raises:
        ValueError: if ``args.arch`` is not ``'dct'``, the only architecture
            this function builds.
    """
    # ``global`` is function-wide in Python, so declaring logger/writer here
    # is equivalent to declaring them where they are first assigned.
    global args, logger, writer
    args = argss
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    teacher_model = None
    if args.teacher_model_path:
        teacher_model = PSPNet(layers=args.teacher_layers,
                               classes=args.classes,
                               zoom_factor=args.zoom_factor)
        kd_path = 'alpha_' + str(args.alpha) + '_Temp_' + str(args.temperature)
        args.save_path = os.path.join(args.save_path, kd_path)
        # Every spawned process runs this line; exist_ok avoids the
        # exists()/mkdir() race between ranks.
        os.makedirs(args.save_path, exist_ok=True)

    if args.arch != 'dct':
        raise ValueError(
            "Unsupported arch '{}': train() only builds 'dct'".format(args.arch))
    model = DCTNet(layers=args.layers, classes=args.classes, vec_dim=300)
    modules_ori = [model.cp, model.sp, model.head]
    modules_new = []
    # Param groups with index >= index_split are trained at 10x the base LR.
    # The poly schedule in the loop relies on this split.
    args.index_split = len(modules_ori)
    params_list = [dict(params=module.parameters(), lr=args.base_lr)
                   for module in modules_ori]
    params_list += [dict(params=module.parameters(), lr=args.base_lr * 10)
                    for module in modules_new]
    optimizer = torch.optim.SGD(params_list, lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if teacher_model is not None:
            teacher_model = nn.SyncBatchNorm.convert_sync_batchnorm(teacher_model)

    if main_process():
        logger = get_logger()
        writer = SummaryWriter(args.save_path)  # tensorboardX
        logger.info(args)
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))
        logger.info(model)
        if teacher_model is not None:
            logger.info(teacher_model)

    if args.distributed:
        torch.cuda.set_device(gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
        args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(model.cuda(),
                                                          device_ids=[gpu])
        if teacher_model is not None:
            teacher_model = torch.nn.parallel.DistributedDataParallel(
                teacher_model.cuda(), device_ids=[gpu])
    else:
        model = torch.nn.DataParallel(model.cuda())
        if teacher_model is not None:
            teacher_model = torch.nn.DataParallel(teacher_model.cuda())

    if teacher_model is not None:
        print("=> loading teacher checkpoint '{}'".format(args.teacher_model_path))
        checkpoint = torch.load(args.teacher_model_path,
                                map_location=lambda storage, loc: storage.cuda())
        teacher_model.load_state_dict(checkpoint['state_dict'], strict=False)

    if args.use_ohem:
        criterion = OhemCELoss(thresh=0.7,
                               ignore_index=args.ignore_label).cuda(gpu)
    else:
        criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda(gpu)
    # NOTE(review): kd_criterion is built but the loop below never uses it.
    kd_criterion = None
    if teacher_model is not None:
        kd_criterion = KDLoss(ignore_index=args.ignore_label).cuda(gpu)

    if args.weight:
        if os.path.isfile(args.weight):
            if main_process():
                logger.info("=> loading weight: '{}'".format(args.weight))
            checkpoint = torch.load(args.weight)
            model.load_state_dict(checkpoint['state_dict'])
            if main_process():
                logger.info("=> loaded weight '{}'".format(args.weight))
        elif main_process():
            logger.info("=> no weight found at '{}'".format(args.weight))

    best_mIoU_val = 0.0
    if args.resume:
        if os.path.isfile(args.resume):
            if main_process():
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            # Load all tensors onto GPU.
            checkpoint = torch.load(args.resume,
                                    map_location=lambda storage, loc: storage.cuda())
            args.start_iter = checkpoint['iteration']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            best_mIoU_val = checkpoint['best_mIoU_val']
            if main_process():
                logger.info("=> loaded checkpoint '{}' (iteration {})".format(
                    args.resume, checkpoint['iteration']))
        elif main_process():
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    # ImageNet RGB statistics scaled to the 0-255 pixel range.
    value_scale = 255
    rgb_mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    rgb_std = [item * value_scale for item in [0.229, 0.224, 0.225]]

    train_transform = transform.Compose([
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max],
                             padding=rgb_mean, ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w], crop_type='rand',
                       padding=rgb_mean, ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=rgb_mean, std=rgb_std),
    ])
    train_data = dataset.SemData(split='train', img_type='rgb',
                                 data_root=args.data_root,
                                 data_list=args.train_list,
                                 transform=train_transform)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        shuffle=(train_sampler is None), num_workers=args.workers,
        pin_memory=True, sampler=train_sampler, drop_last=True)

    if args.evaluate:
        val_transform = transform.Compose([
            transform.Crop([args.train_h, args.train_w], crop_type='center',
                           padding=rgb_mean, ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=rgb_mean, std=rgb_std),
        ])
        val_data = dataset.SemData(split='val', img_type='rgb',
                                   data_root=args.data_root,
                                   data_list=args.val_list,
                                   transform=val_transform)
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)
        else:
            val_sampler = None
        val_loader = torch.utils.data.DataLoader(
            val_data, batch_size=args.batch_size_val, shuffle=False,
            num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    # ---- Training loop ----
    batch_time = AverageMeter()
    data_time = AverageMeter()
    main_loss_meter = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    model.train()
    if teacher_model is not None:
        teacher_model.eval()
    end = time.time()
    max_iter = args.max_iter
    data_iter = iter(train_loader)
    epoch = 0
    for current_iter in range(args.start_iter, args.max_iter):
        batch = next(data_iter, None)
        # An exhausted iterator or a short batch marks the end of an epoch.
        if batch is None or batch[1].size(0) != args.batch_size:
            epoch += 1
            if args.distributed:
                train_sampler.set_epoch(epoch)
                if main_process():
                    logger.info('train_sampler.set_epoch({})'.format(epoch))
            data_iter = iter(train_loader)
            batch = next(data_iter)
            # Per-epoch meters restart with each new epoch.
            main_loss_meter = AverageMeter()
            loss_meter = AverageMeter()
            intersection_meter = AverageMeter()
            union_meter = AverageMeter()
            target_meter = AverageMeter()
        images, target = batch
        data_time.update(time.time() - end)
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        main_out = model(images)
        main_loss = criterion(main_out, target)
        if not args.multiprocessing_distributed:
            main_loss = torch.mean(main_loss)
        loss = main_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n = target.size(0)
        if args.multiprocessing_distributed:
            # Average over all ranks, weighted by sample count (ignored pixels
            # are not taken into account). Detached: reporting only.
            main_loss, loss = main_loss.detach() * n, loss.detach() * n
            count = target.new_tensor([n], dtype=torch.long)
            dist.all_reduce(main_loss)
            dist.all_reduce(loss)
            dist.all_reduce(count)
            n = count.item()
            main_loss, loss = main_loss / n, loss / n

        main_out = main_out.detach().max(1)[1]
        intersection, union, area_target = intersectionAndUnionGPU(
            main_out, target, args.classes, args.ignore_label)
        if args.multiprocessing_distributed:
            dist.all_reduce(intersection)
            dist.all_reduce(union)
            dist.all_reduce(area_target)
        intersection = intersection.cpu().numpy()
        union = union.cpu().numpy()
        area_target = area_target.cpu().numpy()
        intersection_meter.update(intersection)
        union_meter.update(union)
        target_meter.update(area_target)

        accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
        main_loss_meter.update(main_loss.item(), n)
        loss_meter.update(loss.item(), n)
        batch_time.update(time.time() - end)
        end = time.time()

        # Poly LR schedule: base groups get current_lr, new groups 10x.
        current_lr = poly_learning_rate(args.base_lr, current_iter, max_iter,
                                        power=args.power)
        for index in range(0, args.index_split):
            optimizer.param_groups[index]['lr'] = current_lr
        for index in range(args.index_split, len(optimizer.param_groups)):
            optimizer.param_groups[index]['lr'] = current_lr * 10

        remain_time = (max_iter - current_iter) * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))

        iter_log = current_iter + 1
        if iter_log % args.print_freq == 0 and main_process():
            logger.info('Iter [{}/{}] '
                        'LR: {lr:.3e}, '
                        'ETA: {remain_time}, '
                        'Data: {data_time.val:.3f} ({data_time.avg:.3f}), '
                        'Batch: {batch_time.val:.3f} ({batch_time.avg:.3f}), '
                        'MainLoss: {main_loss_meter.val:.4f}, '
                        'Loss: {loss_meter.val:.4f}, '
                        'Accuracy: {accuracy:.4f}.'.format(
                            iter_log, args.max_iter,
                            lr=current_lr,
                            remain_time=remain_time,
                            data_time=data_time,
                            batch_time=batch_time,
                            main_loss_meter=main_loss_meter,
                            loss_meter=loss_meter,
                            accuracy=accuracy))
        if main_process():
            writer.add_scalar('loss_train_batch', main_loss_meter.val, iter_log)
            writer.add_scalar('mIoU_train_batch',
                              np.mean(intersection / (union + 1e-10)), iter_log)
            writer.add_scalar('mAcc_train_batch',
                              np.mean(intersection / (area_target + 1e-10)), iter_log)
            writer.add_scalar('allAcc_train_batch', accuracy, iter_log)

        # End of an epoch, or the final iteration.
        if iter_log % len(train_loader) == 0 or iter_log == max_iter:
            iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
            accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
            mIoU_train = np.mean(iou_class)
            mAcc_train = np.mean(accuracy_class)
            allAcc_train = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
            loss_train = main_loss_meter.avg
            if main_process():
                logger.info('Train result at iteration [{}/{}]: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'
                            .format(iter_log, max_iter, mIoU_train, mAcc_train, allAcc_train))
                writer.add_scalar('loss_train', loss_train, iter_log)
                writer.add_scalar('mIoU_train', mIoU_train, iter_log)
                writer.add_scalar('mAcc_train', mAcc_train, iter_log)
                writer.add_scalar('allAcc_train', allAcc_train, iter_log)

            is_best = False
            if args.evaluate:
                loss_val, mIoU_val, mAcc_val, allAcc_val = validate(
                    val_loader, model, criterion)
                model.train()  # validate() leaves the model in eval mode
                if main_process():
                    writer.add_scalar('loss_val', loss_val, iter_log)
                    writer.add_scalar('mIoU_val', mIoU_val, iter_log)
                    writer.add_scalar('mAcc_val', mAcc_val, iter_log)
                    writer.add_scalar('allAcc_val', allAcc_val, iter_log)
                    if best_mIoU_val < mIoU_val:
                        is_best = True
                        best_mIoU_val = mIoU_val
                        logger.info('==>The best val mIoU: %.3f' % (best_mIoU_val))

            if main_process():
                save_checkpoint(
                    {
                        'iteration': iter_log,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_mIoU_val': best_mIoU_val,
                    }, is_best, args.save_path)
                # mIoU_val only exists when validation ran; referencing it
                # otherwise raised NameError.
                if args.evaluate:
                    logger.info('Saving checkpoint to:{}/iter_{}.pth or last.pth with mIoU:{:.3f}'
                                .format(args.save_path, iter_log, mIoU_val))
                else:
                    logger.info('Saving checkpoint to:{}/iter_{}.pth or last.pth'
                                .format(args.save_path, iter_log))
                if is_best:
                    logger.info('Saving checkpoint to:{}/best.pth with mIoU:{:.3f}'
                                .format(args.save_path, best_mIoU_val))

    if main_process():
        # The writer must be closed, otherwise tensorboard files can end
        # truncated (EOFError).
        writer.close()
        logger.info('==>Training done! The best val mIoU during training: %.3f'
                    % (best_mIoU_val))
def evaluate(respth='./res', dspth='/data2/.encoding/data/cityscapes', checkpoint=None):
    """Evaluate a FANet (18 layers, 19 classes) checkpoint with ``MscEval``.

    Configuration comes from ``get_parser()``. Weights are read from
    ``args.model_path``; the validation split comes from ``args.data_root`` /
    ``args.val_list``. Evaluation runs at a single scale (1.0) without
    flipping, and the resulting mIoU is logged.

    Args:
        respth: unused; kept for backward compatibility.
        dspth: unused; kept for backward compatibility.
        checkpoint: unused; kept for backward compatibility. The checkpoint
            path always comes from ``args.model_path``.

    Raises:
        RuntimeError: if ``args.model_path`` is not an existing file.
    """
    args = get_parser()
    logger = get_logger()

    logger.info('\n')
    logger.info('====' * 20)
    logger.info('evaluating the model ...\n')
    logger.info('setup and restore model')
    n_classes = 19
    net = FANet(layers=18, classes=n_classes)

    if not os.path.isfile(args.model_path):
        raise RuntimeError("=> no checkpoint found at '{}'".format(
            args.model_path))
    logger.info("=> loading checkpoint '{}'".format(args.model_path))
    saved = torch.load(args.model_path)
    # A state_dict saved from a DataParallel/DDP-wrapped model prefixes every
    # key with 'module.'. ``net`` is not wrapped here, so strip the prefix.
    # Otherwise strict=False would silently skip every weight.
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in saved['state_dict'].items()
    }
    incompatible = net.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        logger.info('=> missing keys in checkpoint: {}'.format(
            incompatible.missing_keys))
    if incompatible.unexpected_keys:
        logger.info('=> unexpected keys in checkpoint: {}'.format(
            incompatible.unexpected_keys))
    logger.info("=> loaded checkpoint '{}'".format(args.model_path))
    net.cuda()
    net.eval()

    # ImageNet RGB statistics scaled to the 0-255 pixel range.
    value_scale = 255
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    std = [item * value_scale for item in [0.229, 0.224, 0.225]]
    val_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std),
    ])
    val_data = dataset.SemData(split='val', data_root=args.data_root,
                               data_list=args.val_list, transform=val_transform)
    dl = torch.utils.data.DataLoader(val_data, batch_size=1, shuffle=False,
                                     num_workers=4, pin_memory=True)

    logger.info('compute the mIOU')
    evaluator = MscEval(net, dl, scales=[1.0], flip=False)
    mIOU = evaluator.evaluate()
    logger.info('mIOU is: {:.6f}'.format(mIOU))
if __name__ == "__main__": value_scale = 255 mean = [74.8559803, 79.1187336, 80.7307415] # mean = [0.485, 0.456, 0.406] # mean = [item * value_scale for item in mean] # std = [0.229, 0.224, 0.225] std = [19.19655189, 19.56021428, 24.39020428] # std = [item * value_scale for item in std] train_transform = transform.Compose([ transform.Resize((512, 512)), # transform.RandScale([0.5, 2]), transform.RandRotate([0, 45], padding=mean, ignore_label=0), transform.RandomGaussianBlur(), transform.RandomHorizontalFlip(), transform.Crop([256, 256], crop_type='rand', padding=mean, ignore_label=0), transform.ToTensor(), transform.Normalize(mean=mean, std=std) ]) dataset = MyDataset("dataset/train", transform=train_transform) batchSize = 16 validation_split = 2 shuffle_dataset = True random_seed = 42 train_size = int(0.8 * len(dataset)) test_size = len(dataset) - train_size train_dataset, val_dataset = random_split(dataset, [train_size, test_size]) train_loader = DataLoader(train_dataset,
torch.distributed.init_process_group(backend='gloo', init_method='env://') synchronize() device = 'cuda' # valid_trans = transform.Compose( # [ # transform.Resize(args.test_min_size, args.test_max_size), # transform.ToTensor(), # transform.Normalize(args.pixel_mean, args.pixel_std) # ] # ) valid_trans = transform.Compose([ transform.Resize_For_Efficientnet(compund_coef=args.backbone_coef), transform.ToTensor(), transform.Normalize(args.pixel_mean, args.pixel_std) ]) valid_set = COCODataset("data/coco2017/", 'val', valid_trans) # backbone = vovnet39(pretrained=False) # backbone = resnet18(pretrained=False) #backbone = resnet50(pretrained=False) #model = ATSS(args, backbone) model = Efficientnet_Bifpn_ATSS(args, compound_coef=args.backbone_coef, load_backboe_weight=False) # load weight model_file = args.weight_path
def evaluate():
    """Evaluate ``args.model_path`` on the val split and log segmentation metrics.

    Builds the network selected by ``args.arch``, wraps it in DataParallel and
    loads the checkpoint strictly. Validation images are resized to
    ``(base_h * scale, base_w * scale)``. The function logs running accuracy
    during the pass, then the overall mIoU/mAcc/allAcc and per-class IoU and
    accuracy.

    Raises:
        ValueError: if ``args.arch`` is not a supported architecture.
        RuntimeError: if ``args.model_path`` is not an existing file.
    """
    args = get_parser()
    logger = get_logger()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
        str(x) for x in args.test_gpu)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    arch = args.arch
    if arch == 'psp':
        model = PSPNet(layers=args.layers, classes=args.classes,
                       zoom_factor=args.zoom_factor)
    elif arch == 'nonlocal':
        model = Nonlocal(layers=args.layers, classes=args.classes,
                         zoom_factor=args.zoom_factor)
    elif arch == 'danet':
        model = DANet(layers=args.layers, classes=args.classes,
                      zoom_factor=args.zoom_factor)
    elif arch == 'sanet':
        model = SANet(layers=args.layers, classes=args.classes,
                      zoom_factor=args.zoom_factor)
    elif arch == 'fanet':
        model = FANet(layers=args.layers, classes=args.classes)
    elif arch == 'fftnet':
        model = FFTNet(layers=args.layers, classes=args.classes)
    elif arch == 'fftnet_23':
        model = FFTNet23(layers=args.layers, classes=args.classes)
    elif arch == 'bise_v1':
        model = BiseNet(layers=args.layers, classes=args.classes,
                        with_sp=args.with_sp)
    elif arch == 'dct':
        model = DCTNet(layers=args.layers, classes=args.classes, vec_dim=300)
    elif arch == 'triple':
        model = TriSeNet(layers=args.layers, classes=args.classes)
    elif arch == 'triple_1':
        model = TriSeNet1(layers=args.layers, classes=args.classes)
    elif arch == 'ppm':
        model = PPM_Net(backbone=args.backbone, layers=args.layers,
                        classes=args.classes)
    elif arch == 'fc':
        model = FC_Net(backbone=args.backbone, layers=args.layers,
                       classes=args.classes)
    else:
        # Previously an unknown arch surfaced later as UnboundLocalError.
        raise ValueError("Unsupported arch '{}'".format(arch))
    logger.info(model)
    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True

    if not os.path.isfile(args.model_path):
        raise RuntimeError("=> no checkpoint found at '{}'".format(
            args.model_path))
    logger.info("=> loading checkpoint '{}'".format(args.model_path))
    checkpoint = torch.load(args.model_path)
    model.load_state_dict(checkpoint['state_dict'], strict=True)
    logger.info("=> loaded checkpoint '{}'".format(args.model_path))

    # ImageNet RGB statistics scaled to the 0-255 pixel range.
    value_scale = 255
    rgb_mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    rgb_std = [item * value_scale for item in [0.229, 0.224, 0.225]]

    val_h = int(args.base_h * args.scale)
    val_w = int(args.base_w * args.scale)
    val_transform = transform.Compose([
        transform.Resize(size=(val_h, val_w)),
        transform.ToTensor(),
        transform.Normalize(mean=rgb_mean, std=rgb_std),
    ])
    val_data = dataset.SemData(split='val', img_type='rgb',
                               data_root=args.data_root,
                               data_list=args.val_list,
                               transform=val_transform)
    val_loader = torch.utils.data.DataLoader(val_data, batch_size=1,
                                             shuffle=False, num_workers=4,
                                             pin_memory=True)

    logger.info('>>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>>')
    batch_time = AverageMeter()
    data_time = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    model.eval()
    end = time.time()
    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            data_time.update(time.time() - end)
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
            # SANet with a teacher path configured returns three outputs; only
            # the first is the segmentation logits.
            if args.teacher_model_path is not None and args.arch == 'sanet':
                output, _, _ = model(images)
            else:
                output = model(images)
            output = output.detach().max(1)[1]

            intersection, union, area_target = intersectionAndUnionGPU(
                output, target, args.classes, args.ignore_label)
            intersection = intersection.cpu().numpy()
            union = union.cpu().numpy()
            area_target = area_target.cpu().numpy()
            intersection_meter.update(intersection)
            union_meter.update(union)
            target_meter.update(area_target)
            accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)

            batch_time.update(time.time() - end)
            end = time.time()
            if (i + 1) % 10 == 0 or i + 1 == len(val_loader):
                logger.info('Val: [{}/{}] '
                            'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                            'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                            'Accuracy {accuracy:.4f}.'.format(
                                i + 1, len(val_loader),
                                data_time=data_time,
                                batch_time=batch_time,
                                accuracy=accuracy))

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU = np.mean(iou_class)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)

    logger.info('Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(
        mIoU, mAcc, allAcc))
    for i in range(args.classes):
        logger.info('Class_{} Result: iou/accuracy {:.4f}/{:.4f}.'.format(
            i, iou_class[i], accuracy_class[i]))
    logger.info('<<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<<<<<')
def test():
    """Predict labels for the test split and write them out as label-id images.

    Builds the network selected by ``args.arch``, wraps it in DataParallel and
    loads ``args.model_path``. It then predicts one image at a time and hands
    all predictions to ``dataset.results2img`` (``save_dir='./val_result'``,
    ``to_label_id=True``).

    Raises:
        ValueError: if ``args.arch`` is not a supported architecture.
        RuntimeError: if ``args.model_path`` is not an existing file.
    """
    args = get_parser()
    logger = get_logger()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
        str(x) for x in args.test_gpu)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # ImageNet RGB statistics scaled to the 0-255 pixel range.
    value_scale = 255
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    std = [item * value_scale for item in [0.229, 0.224, 0.225]]

    test_transform = transform.Compose(
        [transform.ToTensor(),
         transform.Normalize(mean=mean, std=std)])
    test_data = dataset.SemData(split='test', data_root=args.data_root,
                                data_list=args.test_list,
                                transform=test_transform)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=1,
                                              shuffle=False, num_workers=4,
                                              pin_memory=True)

    arch = args.arch
    if arch == 'psp':
        model = PSPNet(layers=args.layers, classes=args.classes,
                       zoom_factor=args.zoom_factor)
    elif arch == 'nonlocal':
        model = Nonlocal(layers=args.layers, classes=args.classes,
                         zoom_factor=args.zoom_factor)
    elif arch == 'danet':
        model = DANet(layers=args.layers, classes=args.classes,
                      zoom_factor=args.zoom_factor)
    elif arch == 'sanet':
        model = SANet(layers=args.layers, classes=args.classes,
                      zoom_factor=args.zoom_factor)
    elif arch == 'fanet':
        model = FANet(layers=args.layers, classes=args.classes)
    elif arch == 'bise_v1':
        model = BiseNet(layers=args.layers, classes=args.classes,
                        with_sp=args.with_sp)
    elif arch == 'dct':
        model = DCTNet(layers=args.layers, classes=args.classes,
                       use_dct=True, vec_dim=300)
    else:
        # Previously an unknown arch surfaced later as UnboundLocalError.
        raise ValueError("Unsupported arch '{}'".format(arch))
    logger.info(model)
    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True

    if not os.path.isfile(args.model_path):
        raise RuntimeError("=> no checkpoint found at '{}'".format(
            args.model_path))
    logger.info("=> loading checkpoint '{}'".format(args.model_path))
    checkpoint = torch.load(args.model_path)
    model.load_state_dict(checkpoint['state_dict'])
    logger.info("=> loaded checkpoint '{}'".format(args.model_path))

    logger.info('>>>>>>>>>>>>>>>>> Start Testing >>>>>>>>>>>>>>>>>')
    data_time = AverageMeter()
    batch_time = AverageMeter()
    model.eval()
    end = time.time()
    results = []
    with torch.no_grad():
        for i, (images, _) in enumerate(test_loader):
            data_time.update(time.time() - end)
            # The loader already yields a normalised batch of one image. The
            # former numpy squeeze/transpose/transpose-back round trip was an
            # identity and is gone.
            images = images.float().cuda()
            # SANet with a teacher path configured returns three outputs; only
            # the first is the segmentation logits.
            if args.teacher_model_path is not None and args.arch == 'sanet':
                output, _, _ = model(images)
            else:
                output = model(images)
            # Per-pixel class label. softmax is monotonic, so taking the
            # argmax of the logits directly gives the same labels.
            prediction = output[0].argmax(dim=0).cpu().numpy()
            results.append(prediction)

            batch_time.update(time.time() - end)
            end = time.time()
            if (i + 1) % 10 == 0 or i + 1 == len(test_loader):
                logger.info('Test: [{}/{}] '
                            'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                            'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}).'.format(
                                i + 1, len(test_loader),
                                data_time=data_time,
                                batch_time=batch_time))
    logger.info('<<<<<<<<<<<<<<<<<< End Testing <<<<<<<<<<<<<<<<<<<<<')

    logger.info('Convert to Label ID')
    dataset.results2img(results=results, data_root=args.data_root,
                        data_list=args.test_list, save_dir='./val_result',
                        to_label_id=True)
    logger.info('Convert to Label ID Finished')
import torch
import torch.utils.data
from torch.utils.data import DataLoader
from torchvision import transforms
from utils import dataset, transform

# Quick check of the 4-channel normalisation statistics below: it prints the
# mean and std of each normalised sample in one training batch (presumably
# expected to be close to 0 and 1 — confirm against the data).
data_dir = './dataset'
bs = 8
mean = [58.6573, 65.9755, 56.4990, 100.8296]
std = [60.2980, 59.0457, 58.0989, 47.1224]


def main():
    """Load one shuffled training batch and print per-sample mean/std."""
    train_transform = transform.Compose(
        [transform.ToTensor(),
         transform.Normalize(mean=mean, std=std)])
    train_data = dataset.SemData(split='train', data_root=data_dir,
                                 transform=train_transform)
    train_loader = DataLoader(dataset=train_data, batch_size=bs, shuffle=True,
                              num_workers=4, pin_memory=True, drop_last=False)
    # The dataset yields four items per sample; only the images are used here.
    images, _label, _, _ = next(iter(train_loader))
    for sample in images:
        print(sample.mean().item(), sample.std().item())


# Guard required: DataLoader worker processes re-import this module under the
# "spawn" start method, and importing it should not load data.
if __name__ == "__main__":
    main()
def main_worker(local_rank, ngpus_per_node, argss):
    """Per-process training entry point (student model plus optional KD teacher).

    Builds the student model selected by ``args.arch``. When
    ``args.teacher_model_path`` is set, it also builds a PSPNet teacher and
    loads its checkpoint. It then creates the SGD optimizer and the train
    (and optionally val) loaders, and runs epochs ``args.start_epoch`` to
    ``args.epochs - 1`` of ``train``/``validate``. The main process saves a
    checkpoint every ``args.save_freq`` epochs.

    Args:
        local_rank: GPU index driven by this process.
        ngpus_per_node: number of GPUs on this node. With ``args.distributed``
            the train/val batch sizes and the worker count are divided by it.
        argss: parsed configuration; stored in the module-level ``args``.

    Raises:
        ValueError: if ``args.arch`` is neither ``'psp'`` nor ``'bise_v1'``.
        RuntimeError: propagated from ``torch.load`` / ``load_state_dict`` on
            bad checkpoints.
    """
    global args, logger, writer
    args = argss
    dist.init_process_group(backend=args.dist_backend)

    teacher_model = None
    if args.teacher_model_path:
        teacher_model = PSPNet(layers=args.teacher_layers, classes=args.classes,
                               zoom_factor=args.zoom_factor)
        # KD runs are stored in a sub-folder named after alpha / temperature.
        kd_path = 'alpha_' + str(args.alpha) + '_Temp_' + str(args.temperature)
        args.save_path = os.path.join(args.save_path, kd_path)
    # makedirs(exist_ok=True) creates missing parents and does not fail when
    # another rank has already created the directory.
    os.makedirs(args.save_path, exist_ok=True)

    if args.arch == 'psp':
        model = PSPNet(layers=args.layers, classes=args.classes,
                       zoom_factor=args.zoom_factor)
        modules_ori = [model.layer0, model.layer1, model.layer2, model.layer3, model.layer4]
        modules_new = [model.ppm, model.cls, model.aux]
    elif args.arch == 'bise_v1':
        model = BiseNet(num_classes=args.classes)
        modules_ori = [model.sp, model.cp]
        modules_new = [model.ffm, model.conv_out, model.conv_out16, model.conv_out32]
    else:
        raise ValueError(
            "unsupported arch '{}' (expected 'psp' or 'bise_v1')".format(args.arch))

    # Backbone modules train at base_lr; newly added head modules at 10x.
    params_list = [dict(params=module.parameters(), lr=args.base_lr)
                   for module in modules_ori]
    params_list += [dict(params=module.parameters(), lr=args.base_lr * 10)
                    for module in modules_new]
    # Number of leading (base-lr) param groups. It used to be hard-coded to 5,
    # which only matches PSPNet; bise_v1 has 2 backbone groups.
    # NOTE(review): assumes train() uses index_split to separate base-lr from
    # 10x-lr groups when applying the LR schedule -- confirm.
    args.index_split = len(modules_ori)
    optimizer = torch.optim.SGD(params_list, lr=args.base_lr, momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if teacher_model is not None:
            teacher_model = nn.SyncBatchNorm.convert_sync_batchnorm(teacher_model)

    # Only the main process owns a logger and a TensorBoard writer.
    if main_process():
        logger = get_logger()
        writer = SummaryWriter(args.save_path)  # tensorboardX
        logger.info(args)
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))
        logger.info(model)
        if teacher_model is not None:
            logger.info(teacher_model)

    if args.distributed:
        torch.cuda.set_device(local_rank)
        # Batch sizes and workers are configured per node; split them per GPU.
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
        args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(model.cuda(),
                                                          device_ids=[local_rank])
        if teacher_model is not None:
            teacher_model = torch.nn.parallel.DistributedDataParallel(
                teacher_model.cuda(), device_ids=[local_rank])
    else:
        model = torch.nn.DataParallel(model.cuda())
        if teacher_model is not None:
            teacher_model = torch.nn.DataParallel(teacher_model.cuda())

    if teacher_model is not None:
        checkpoint = torch.load(args.teacher_model_path,
                                map_location=lambda storage, loc: storage.cuda())
        teacher_model.load_state_dict(checkpoint['state_dict'], strict=False)
        print("=> loading teacher checkpoint '{}'".format(args.teacher_model_path))

    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda(local_rank)
    kd_criterion = None
    if teacher_model is not None:
        kd_criterion = KDLoss(ignore_index=args.ignore_label).cuda(local_rank)

    if args.weight:
        if os.path.isfile(args.weight):
            if main_process():
                logger.info("=> loading weight: '{}'".format(args.weight))
            checkpoint = torch.load(args.weight)
            model.load_state_dict(checkpoint['state_dict'])
            if main_process():
                logger.info("=> loaded weight '{}'".format(args.weight))
        elif main_process():
            logger.info("=> no weight found at '{}'".format(args.weight))

    best_mIoU_val = 0.0
    if args.resume:
        if os.path.isfile(args.resume):
            if main_process():
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            # Load all tensors onto GPU.
            checkpoint = torch.load(args.resume,
                                    map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            best_mIoU_val = checkpoint['best_mIoU_val']
            if main_process():
                # Checkpoints store the epoch under 'epoch' (see save below);
                # the previous 'point' key does not exist and raised KeyError.
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        elif main_process():
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    # ImageNet mean/std rescaled to the 0-255 pixel range.
    value_scale = 255
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    std = [item * value_scale for item in [0.229, 0.224, 0.225]]

    train_transform = transform.Compose([
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max], padding=mean,
                             ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w], crop_type='rand', padding=mean,
                       ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)])
    train_data = dataset.SemData(split='train', data_root=args.data_root,
                                 data_list=args.train_list, transform=train_transform)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler,
        drop_last=True)

    if args.evaluate:
        val_transform = transform.Compose([
            transform.Crop([args.train_h, args.train_w], crop_type='center',
                           padding=mean, ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)])
        val_data = dataset.SemData(split='val', data_root=args.data_root,
                                   data_list=args.val_list, transform=val_transform)
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)
        else:
            val_sampler = None
        val_loader = torch.utils.data.DataLoader(
            val_data, batch_size=args.batch_size_val, shuffle=False,
            num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1
        if args.distributed:
            # Use .set_epoch() to reshuffle the per-rank partition every epoch.
            train_sampler.set_epoch(epoch)
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(
            local_rank, train_loader, model, teacher_model, criterion,
            kd_criterion, optimizer, epoch)
        if main_process():
            writer.add_scalar('loss_train', loss_train, epoch_log)
            writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
            writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
            writer.add_scalar('allAcc_train', allAcc_train, epoch_log)

        is_best = False
        # Stays None when validation is disabled, so the save log below never
        # touches an undefined name.
        mIoU_val = None
        if args.evaluate:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(
                local_rank, val_loader, model, criterion)
            if main_process():
                writer.add_scalar('loss_val', loss_val, epoch_log)
                writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
                writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
                writer.add_scalar('allAcc_val', allAcc_val, epoch_log)
                if best_mIoU_val < mIoU_val:
                    is_best = True
                    best_mIoU_val = mIoU_val
                logger.info('==>The best val mIoU: %.3f' % (best_mIoU_val))

        if (epoch_log % args.save_freq == 0) and main_process():
            save_checkpoint(
                {
                    'epoch': epoch_log,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_mIoU_val': best_mIoU_val
                }, is_best, args.save_path
            )
            if is_best:
                logger.info('Saving checkpoint to:' + args.save_path
                            + '/best.pth with mIoU: ' + str(best_mIoU_val))
            elif mIoU_val is not None:
                logger.info('Saving checkpoint to:' + args.save_path
                            + '/last.pth with mIoU: ' + str(mIoU_val))
            else:
                logger.info('Saving checkpoint to:' + args.save_path + '/last.pth')

    if main_process():
        # The writer must be closed explicitly, otherwise an EOFError can appear.
        writer.close()
        logger.info('==>Training done!\nBest mIoU: %.3f' % (best_mIoU_val))
def main():
    """Test entry point: predict on ``args.test_list`` and optionally score it.

    Parses and checks the configuration and builds a batch-size-1 test loader.
    The loader covers ``[index_start, index_start + index_step)`` of the list,
    or the whole list when ``index_step == 0``. Unless ``args.has_prediction``
    is set, it builds the model, loads ``args.model_path`` and runs ``test``,
    which writes gray/color predictions under ``args.save_folder``. For splits
    other than 'test' it then calls ``cal_acc``.

    Raises:
        ValueError: if ``args.arch`` is not a supported architecture.
        RuntimeError: if no checkpoint exists at ``args.model_path``.
    """
    global args, logger
    args = get_parser()
    check(args)
    logger = get_logger()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.test_gpu)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # ImageNet mean/std rescaled to the 0-255 pixel range.
    value_scale = 255
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    std = [item * value_scale for item in [0.229, 0.224, 0.225]]

    test_h = int(args.base_h * args.scale + 1)
    test_w = int(args.base_w * args.scale + 1)

    if args.teacher_model_path:
        # KD runs live in a sub-folder named after alpha / temperature.
        kd_path = 'alpha_' + str(args.alpha) + '_Temp_' + str(args.temperature)
        kd_save = kd_path + '/val/ss'
        args.save_folder = os.path.join(args.save_folder, kd_save)
        args.save_path = os.path.join(args.save_path, kd_path)
    # NOTE(review): the original's nesting was ambiguous. Here model_path is
    # resolved relative to save_path for both KD and non-KD runs -- confirm.
    # os.path.join returns model_path unchanged when it is absolute.
    args.model_path = os.path.join(args.save_path, args.model_path)
    gray_folder = os.path.join(args.save_folder, 'gray')
    color_folder = os.path.join(args.save_folder, 'color')

    # Normalisation is left to test(), which receives mean/std.
    test_transform = transform.Compose([transform.ToTensor()])
    test_data = dataset.SemData(split='test', data_root=args.data_root,
                                data_list=args.test_list, transform=test_transform)
    index_start = args.index_start
    if args.index_step == 0:
        index_end = len(test_data.data_list)
    else:
        index_end = min(index_start + args.index_step, len(test_data.data_list))
    test_data.data_list = test_data.data_list[index_start:index_end]
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False,
                                              num_workers=4, pin_memory=True)
    colors = np.loadtxt(args.colors_path).astype('uint8')
    # Close the names file deterministically instead of leaking the handle.
    with open(args.names_path) as names_file:
        names = [line.rstrip('\n') for line in names_file]

    if not args.has_prediction:
        if args.arch == 'psp':
            from model.pspnet import PSPNet
            model = PSPNet(layers=args.layers, classes=args.classes,
                           zoom_factor=args.zoom_factor)
        elif args.arch == 'nonlocal':
            from model.nonlocal_net import Nonlocal
            model = Nonlocal(layers=args.layers, classes=args.classes,
                             zoom_factor=args.zoom_factor)
        elif args.arch == 'sanet':
            from model.sanet import SANet
            model = SANet(layers=args.layers, classes=args.classes,
                          zoom_factor=args.zoom_factor)
        elif args.arch == 'bise_v1':
            from model.bisenet_v1 import BiseNet
            model = BiseNet(layers=args.layers, classes=args.classes,
                            with_sp=args.with_sp)
        elif args.arch == 'fanet':
            from model.fanet import FANet
            model = FANet(layers=args.layers, classes=args.classes)
        else:
            # Previously fell through and failed later with an unbound 'model'.
            raise ValueError("unsupported arch '{}'".format(args.arch))
        logger.info(model)
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True
        if os.path.isfile(args.model_path):
            logger.info("=> loading checkpoint '{}'".format(args.model_path))
            checkpoint = torch.load(args.model_path)
            model.load_state_dict(checkpoint['state_dict'])
            logger.info("=> loaded checkpoint '{}'".format(args.model_path))
        else:
            raise RuntimeError("=> no checkpoint found at '{}'".format(args.model_path))
        test(test_loader, test_data.data_list, model, args.classes, mean, std,
             args.base_size, test_h, test_w, args.scales, gray_folder,
             color_folder, colors)
    if args.split != 'test':
        cal_acc(test_data.data_list, gray_folder, args.classes, names)
print(Path.db_root_dir('mf')) data_dir = './dataset' n_class = 9 mean = [58.6573, 65.9755, 56.4990, 100.8296] std = [60.2980, 59.0457, 58.0989, 47.1224] scale_min, scale_max = 0.5, 2.0 rotate_min, rotate_max = -10, 10 train_h = 256 # default: 480 train_w = 512 # default: 640 train_transform = transform.Compose([ transform.RandScale([scale_min, scale_max]), transform.RandRotate([rotate_min, rotate_max], padding=mean, ignore_label=255), transform.RandomGaussianBlur(), transform.RandomHorizontalFlip(), transform.Crop([train_h, train_w], crop_type='rand', padding=mean, ignore_label=255), transform.ToTensor(), transform.Normalize(mean=mean, std=std) ]) val_transform = transform.Compose( [transform.ToTensor(), transform.Normalize(mean=mean, std=std)]) train_data = dataset.SemData(split='train', data_root=data_dir, transform=train_transform) val_data = dataset.SemData(split='val',
def evaluate():
    """Score a trained TriSeNet checkpoint and dump its predictions as label IDs.

    NOTE(review): despite the name, this iterates ``args.train_list``
    (split 'train') with the non-augmenting transform and writes results to
    './visualization/train_result' -- confirm this is intended.

    Prints per-iteration accuracy, then overall mIoU/mAcc/allAcc and the
    per-class IoU/accuracy.

    Raises:
        ValueError: if ``args.arch`` is not ``'triple'``.
        RuntimeError: if no checkpoint exists at ``args.model_path``.
    """
    args = get_parser()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.test_gpu)
    if args.arch != 'triple':
        # Previously fell through and failed later with an unbound 'model'.
        raise ValueError("unsupported arch '{}' (expected 'triple')".format(args.arch))
    model = TriSeNet(layers=args.layers, classes=args.classes)
    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True
    if os.path.isfile(args.model_path):
        print("=> loading checkpoint '{}'".format(args.model_path))
        checkpoint = torch.load(args.model_path)
        model.load_state_dict(checkpoint['state_dict'], strict=True)
        print("=> loaded checkpoint '{}'".format(args.model_path))
    else:
        raise RuntimeError("=> no checkpoint found at '{}'".format(args.model_path))

    # RGB ImageNet mean/std rescaled to the 0-255 pixel range.
    value_scale = 255
    rgb_mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    rgb_std = [item * value_scale for item in [0.229, 0.224, 0.225]]
    val_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize(mean=rgb_mean, std=rgb_std)
    ])
    eval_data = SemData(split='train', data_root=args.data_root,
                        data_list=args.train_list, transform=val_transform)
    # batch_size=1 is relied on below when reshaping each prediction to (H, W).
    val_loader = torch.utils.data.DataLoader(eval_data, batch_size=1, shuffle=False,
                                             num_workers=args.workers, pin_memory=True)

    print('>>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>>')
    batch_time = AverageMeter()
    data_time = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()
    model.eval()
    end = time.time()
    results = []
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            data_time.update(time.time() - end)
            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
            output = model(input)
            _, H, W = target.shape
            # Per-pixel predicted class index.
            output = output.detach().max(1)[1]
            results.append(output.cpu().numpy().reshape(H, W))

            intersection, union, target_area = intersectionAndUnionGPU(
                output, target, args.classes, args.ignore_label)
            intersection_meter.update(intersection.cpu().numpy())
            union_meter.update(union.cpu().numpy())
            target_meter.update(target_area.cpu().numpy())
            accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
            batch_time.update(time.time() - end)
            end = time.time()
            print('Val: [{}/{}] '
                  'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                  'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                  'Accuracy {accuracy:.4f}.'.format(i + 1, len(val_loader),
                                                    data_time=data_time,
                                                    batch_time=batch_time,
                                                    accuracy=accuracy))

    # 1e-10 guards against division by zero for classes that never occur.
    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU = np.mean(iou_class)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
    print('Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(mIoU, mAcc, allAcc))
    for class_idx in range(args.classes):
        print('Class_{} Result: iou/accuracy {:.4f}/{:.4f}.'.format(
            class_idx, iou_class[class_idx], accuracy_class[class_idx]))
    print('<<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<<<<<')

    print('Convert to Label ID')
    results2img(results=results, data_root=args.data_root,
                data_list=args.train_list,
                save_dir='./visualization/train_result', to_label_id=True)
    print('Convert to Label ID Finished')
def train(gpu, ngpus_per_node, argss):
    """Iteration-based single-node training of a TriSeNet model.

    Runs iterations ``args.start_iter`` to ``args.max_iter - 1`` over an
    endlessly restarted training loader. A poly LR schedule is applied after
    every step. Progress is printed every ``args.print_freq`` iterations, and
    a checkpoint is saved at that cadence and after the final iteration.

    Args:
        gpu: CUDA device index for the loss module.
        ngpus_per_node: not referenced here.
        argss: configuration namespace; stored in the module-level ``args``.

    Raises:
        ValueError: if ``args.arch`` is not ``'triple'``.
    """
    global args
    args = argss
    if args.arch != 'triple':
        # Previously fell through and failed later with an unbound 'model'.
        raise ValueError("unsupported arch '{}' (expected 'triple')".format(args.arch))

    model = TriSeNet(layers=args.layers, classes=args.classes)
    modules_ori = [model.layer0, model.layer1, model.layer2, model.layer3, model.layer4]
    # Every top-level child that is not a backbone 'layer*' counts as new.
    modules_new = [module for key, module in model._modules.items() if "layer" not in key]
    # Param groups from index_split onwards get 10x the learning rate.
    args.index_split = len(modules_ori)
    params_list = [dict(params=module.parameters(), lr=args.base_lr)
                   for module in modules_ori]
    params_list += [dict(params=module.parameters(), lr=args.base_lr * 10)
                    for module in modules_new]
    optimizer = torch.optim.SGD(params_list, lr=args.base_lr, momentum=args.momentum,
                                weight_decay=args.weight_decay)

    print("=> creating model ...")
    print("Classes: {}".format(args.classes))
    print(model)
    model = torch.nn.DataParallel(model.cuda())
    cudnn.benchmark = True
    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda(gpu)

    # RGB ImageNet mean/std rescaled to the 0-255 pixel range.
    value_scale = 255
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    std = [item * value_scale for item in [0.229, 0.224, 0.225]]
    train_transform = transform.Compose([
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max], padding=mean,
                             ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w], crop_type='rand', padding=mean,
                       ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ])
    train_data = SemData(split='train', data_root=args.data_root,
                         data_list=args.train_list, transform=train_transform)
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, sampler=None, drop_last=True)

    # Training loop
    batch_time = AverageMeter()
    data_time = AverageMeter()
    main_loss_meter = AverageMeter()
    aux_loss_meter = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    model.train()
    end = time.time()
    max_iter = args.max_iter
    data_iter = iter(train_loader)
    epoch = 0
    for current_iter in range(args.start_iter, args.max_iter):
        try:
            input, target = next(data_iter)
            # Treat a short final batch like the end of the epoch.
            if not input.size(0) == args.batch_size:
                raise StopIteration
        except StopIteration:
            epoch += 1
            data_iter = iter(train_loader)
            input, target = next(data_iter)
            # Restart the per-epoch meters.
            main_loss_meter = AverageMeter()
            aux_loss_meter = AverageMeter()
            loss_meter = AverageMeter()
            intersection_meter = AverageMeter()
            union_meter = AverageMeter()
            target_meter = AverageMeter()
        # Measure data loading time.
        data_time.update(time.time() - end)
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        main_out = model(input)
        main_loss = criterion(main_out, target)
        # No auxiliary head here: use a float zero on the loss's device.
        aux_loss = torch.zeros((), device=main_loss.device)
        loss = main_loss + args.aux_weight * aux_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n = input.size(0)
        main_out = main_out.detach().max(1)[1]
        intersection, union, target_area = intersectionAndUnionGPU(
            main_out, target, args.classes, args.ignore_label)
        intersection_meter.update(intersection.cpu().numpy())
        union_meter.update(union.cpu().numpy())
        target_meter.update(target_area.cpu().numpy())
        accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
        main_loss_meter.update(main_loss.item(), n)
        aux_loss_meter.update(aux_loss.item(), n)
        loss_meter.update(loss.item(), n)
        batch_time.update(time.time() - end)
        end = time.time()

        # Poly LR schedule: backbone groups get current_lr, head groups 10x.
        current_lr = poly_learning_rate(args.base_lr, current_iter, max_iter,
                                        power=args.power)
        for index in range(0, args.index_split):
            optimizer.param_groups[index]['lr'] = current_lr
        for index in range(args.index_split, len(optimizer.param_groups)):
            optimizer.param_groups[index]['lr'] = current_lr * 10

        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))

        iter_log = current_iter + 1
        on_print_step = iter_log % args.print_freq == 0
        if on_print_step:
            print('Iteration: [{}/{}] '
                  'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                  'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                  'ETA {remain_time} '
                  'MainLoss {main_loss_meter.val:.4f} '
                  'AuxLoss {aux_loss_meter.val:.4f} '
                  'Loss {loss_meter.val:.4f} '
                  'Accuracy {accuracy:.4f}.'.format(
                      iter_log, args.max_iter,
                      data_time=data_time,
                      batch_time=batch_time,
                      remain_time=remain_time,
                      main_loss_meter=main_loss_meter,
                      aux_loss_meter=aux_loss_meter,
                      loss_meter=loss_meter,
                      accuracy=accuracy))
        # Also save after the last iteration so the final weights are kept
        # even when max_iter is not a multiple of print_freq.
        if on_print_step or iter_log == max_iter:
            save_checkpoint(
                {
                    'iteration': iter_log,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, False, args.save_path)
model = UNet().to(device) if sys.argv[1] == 'R2U_Net': model = R2U_Net().to(device) if sys.argv[1] == 'AttU_Net': model = AttU_Net().to(device) if sys.argv[1] == 'R2AttU_Net': model = R2AttU_Net().to(device) if sys.argv[1] == 'RAUNet': model = RAUNet().to(device) # model = R2U_Net().to(device) train_transform = transform.Compose([ # transform.RandScale([args.scale_min, args.scale_max]), transform.RandRotate([0, 45], padding=mean, ignore_label=0), transform.RandomGaussianBlur(), transform.RandomHorizontalFlip(), transform.Crop([256, 256], crop_type='rand', padding=mean, ignore_label=0), transform.ToTensor(), transform.Normalize(mean=mean, std=std) ]) dataset = MyDataset("dataset/train", transform=train_transform) # train(unet, optimizer, dataset, epoch_num=int(sys.argv[1])) for i in range(10): optimizer = optim.Adam(model.parameters(), lr=1e-3) train(model, optimizer, dataset, epoch_num=int(sys.argv[2])) # train(model, optimizer, dataset, epoch_num=10)