def __init__(self, config_file=CONFIG_FILE):
    """Load the test config and build the segmentation network with its weights.

    Args:
        config_file: Config file passed to ``config.load_cfg_from_cfg_file``.

    Raises:
        ValueError: If the config's ``arch`` is neither ``'psp'`` nor ``'psa'``.
        RuntimeError: If the config's ``model_path`` is not an existing file.
    """
    # Load parameters.
    self.args_ = config.load_cfg_from_cfg_file(config_file)
    self.logger_ = get_logger()
    # Restrict the GPUs visible to this process to the configured test GPUs.
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
        str(x) for x in self.args_.test_gpu)

    # Per-channel mean/std, scaled from [0, 1] to the [0, 255] pixel range.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    self.mean_ = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    self.std_ = [item * value_scale for item in std]
    # self.colors_ = np.loadtxt(self.args_.colors_path).astype('uint8')

    # Build the model. Weights come from the checkpoint below, so no
    # pretrained backbone is requested.
    if self.args_.arch == 'psp':
        from model.pspnet import PSPNet
        self.model_ = PSPNet(layers=self.args_.layers,
                             classes=self.args_.classes,
                             zoom_factor=self.args_.zoom_factor,
                             pretrained=False)
    elif self.args_.arch == 'psa':
        from model.psanet import PSANet
        self.model_ = PSANet(
            layers=self.args_.layers,
            classes=self.args_.classes,
            zoom_factor=self.args_.zoom_factor,
            compact=self.args_.compact,
            shrink_factor=self.args_.shrink_factor,
            mask_h=self.args_.mask_h,
            mask_w=self.args_.mask_w,
            normalization_factor=self.args_.normalization_factor,
            psa_softmax=self.args_.psa_softmax,
            pretrained=False)
    else:
        raise ValueError("Unsupported arch '{}': expected 'psp' or 'psa'".format(
            self.args_.arch))
    self.model_ = torch.nn.DataParallel(self.model_).cuda()
    cudnn.benchmark = True

    if not os.path.isfile(self.args_.model_path):
        raise RuntimeError("=> no checkpoint found at '{}'".format(
            self.args_.model_path))
    # Log through self.logger_ instead of rebinding it to the result of
    # .info(), which returns None and would leave no usable logger.
    self.logger_.info(
        "=> loading checkpoint '{}'".format(self.args_.model_path))
    checkpoint = torch.load(self.args_.model_path)
    # strict=False: checkpoint keys that do not match the model are ignored.
    self.model_.load_state_dict(checkpoint['state_dict'], strict=False)
    self.logger_.info(
        "=> loaded checkpoint '{}'".format(self.args_.model_path))
def main():
    """Evaluate a PSPNet/PSANet checkpoint on a dataset split.

    Runs ``test`` over the (optionally sliced) test list, writing results under
    ``<save_folder>/gray`` and ``<save_folder>/color``, then scores the gray
    predictions with ``cal_acc`` unless ``split == 'test'``. When
    ``has_prediction`` is set, inference is skipped and only ``cal_acc`` runs
    on the predictions already in ``gray_folder``.

    Raises:
        ValueError: If ``args.arch`` is neither ``'psp'`` nor ``'psa'``.
        RuntimeError: If ``args.model_path`` is not an existing file.
    """
    global args, logger
    args = get_parser()
    check(args)
    logger = get_logger()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.test_gpu)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # Per-channel mean/std, scaled from [0, 1] to the [0, 255] pixel range.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    gray_folder = os.path.join(args.save_folder, 'gray')
    color_folder = os.path.join(args.save_folder, 'color')

    test_transform = transform.Compose([transform.ToTensor()])
    test_data = dataset.SemData(split=args.split, data_root=args.data_root,
                                data_list=args.test_list,
                                transform=test_transform)
    # Evaluate only [index_start, index_start + index_step); step 0 means "to the end".
    index_start = args.index_start
    if args.index_step == 0:
        index_end = len(test_data.data_list)
    else:
        index_end = min(index_start + args.index_step, len(test_data.data_list))
    test_data.data_list = test_data.data_list[index_start:index_end]
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=1,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)
    colors = np.loadtxt(args.colors_path).astype('uint8')
    with open(args.names_path) as names_file:
        names = [line.rstrip('\n') for line in names_file]

    if not args.has_prediction:
        if args.arch == 'psp':
            from model.pspnet import PSPNet
            model = PSPNet(layers=args.layers, classes=args.classes,
                           zoom_factor=args.zoom_factor, pretrained=False)
        elif args.arch == 'psa':
            from model.psanet import PSANet
            model = PSANet(layers=args.layers, classes=args.classes,
                           zoom_factor=args.zoom_factor,
                           compact=args.compact,
                           shrink_factor=args.shrink_factor,
                           mask_h=args.mask_h, mask_w=args.mask_w,
                           normalization_factor=args.normalization_factor,
                           psa_softmax=args.psa_softmax, pretrained=False)
        else:
            raise ValueError("Unsupported arch '{}': expected 'psp' or 'psa'".format(args.arch))
        logger.info(model)
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True
        if os.path.isfile(args.model_path):
            logger.info("=> loading checkpoint '{}'".format(args.model_path))
            checkpoint = torch.load(args.model_path)
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            logger.info("=> loaded checkpoint '{}'".format(args.model_path))
        else:
            raise RuntimeError("=> no checkpoint found at '{}'".format(args.model_path))
        # Inference needs the model, so it only runs when predictions are
        # not already on disk.
        test(test_loader, test_data.data_list, model, args.classes, mean, std,
             args.base_size, args.test_h, args.test_w, args.scales,
             gray_folder, color_folder, colors)
    if args.split != 'test':
        cal_acc(test_data.data_list, gray_folder, args.classes, names)
def main():
    """Run per-image inference on a portrait image list, then write a label list.

    ``args.image`` is a text file whose non-blank lines start with an image
    path relative to the portrait data root; ``test`` is called on each image.
    Afterwards every file in ``label/train_label`` (when ``args.image`` is
    named ``training.txt``) or ``label/val_label`` (otherwise) is written, one
    path per line, to ``label/training.txt`` or ``label/validation.txt``.

    Raises:
        ValueError: If ``args.arch`` is neither ``'psp'`` nor ``'psa'``.
        RuntimeError: If ``args.model_path`` is not an existing file.
    """
    global args, logger
    # Machine-specific dataset location.
    portrait_root = '/scratch/mw3706/dim/Deep_Image_Matting_Reproduce/pspnet/data/portrait/'
    label_root = os.path.join(portrait_root, 'label')

    args = get_parser()
    check(args)
    logger = get_logger()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.test_gpu)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # Per-channel mean/std, scaled from [0, 1] to the [0, 255] pixel range.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]
    colors = np.loadtxt(args.colors_path).astype('uint8')

    if args.arch == 'psp':
        from model.pspnet import PSPNet
        model = PSPNet(layers=args.layers, classes=args.classes,
                       zoom_factor=args.zoom_factor, pretrained=False)
    elif args.arch == 'psa':
        from model.psanet import PSANet
        model = PSANet(layers=args.layers, classes=args.classes,
                       zoom_factor=args.zoom_factor, compact=args.compact,
                       shrink_factor=args.shrink_factor,
                       mask_h=args.mask_h, mask_w=args.mask_w,
                       normalization_factor=args.normalization_factor,
                       psa_softmax=args.psa_softmax, pretrained=False)
    else:
        raise ValueError("Unsupported arch '{}': expected 'psp' or 'psa'".format(args.arch))
    logger.info(model)
    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = False  # cuDNN autotuning left off here (was True)
    if os.path.isfile(args.model_path):
        logger.info("=> loading checkpoint '{}'".format(args.model_path))
        checkpoint = torch.load(args.model_path)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        logger.info("=> loaded checkpoint '{}'".format(args.model_path))
    else:
        raise RuntimeError("=> no checkpoint found at '{}'".format(args.model_path))

    with open(args.image) as list_file:
        image_entries = list_file.read().splitlines()
    for entry in image_entries:
        fields = entry.split()
        if not fields:
            # Skip blank lines instead of failing with IndexError.
            continue
        image_path = os.path.join(portrait_root, fields[0])
        test(model.eval(), image_path, args.classes, mean, std, args.base_size,
             args.test_h, args.test_w, args.scales, colors)

    if os.path.basename(args.image) == 'training.txt':
        label_dir = os.path.join(label_root, 'train_label')
        list_path = os.path.join(label_root, 'training.txt')
    else:
        label_dir = os.path.join(label_root, 'val_label')
        list_path = os.path.join(label_root, 'validation.txt')
    # sorted() makes the written list independent of directory enumeration order.
    label_names = sorted(os.listdir(label_dir))
    with open(list_path, 'w') as out_file:
        out_file.writelines(os.path.join(label_dir, name) + '\n'
                            for name in label_names)
def main():
    """Run per-image inference on every ``scene*/color/*00.jpg`` under ``args.image``.

    Raises:
        ValueError: If ``args.arch`` is neither ``'psp'`` nor ``'psa'``.
        RuntimeError: If ``args.model_path`` is not an existing file.
    """
    global args, logger
    args = get_parser()
    check(args)
    logger = get_logger()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
        str(x) for x in args.test_gpu)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # Per-channel mean/std, scaled from [0, 1] to the [0, 255] pixel range.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]
    colors = np.loadtxt(args.colors_path).astype('uint8')

    if args.arch == 'psp':
        from model.pspnet import PSPNet
        model = PSPNet(layers=args.layers, classes=args.classes,
                       zoom_factor=args.zoom_factor, pretrained=False)
    elif args.arch == 'psa':
        from model.psanet import PSANet
        model = PSANet(layers=args.layers, classes=args.classes,
                       zoom_factor=args.zoom_factor, compact=args.compact,
                       shrink_factor=args.shrink_factor,
                       mask_h=args.mask_h, mask_w=args.mask_w,
                       normalization_factor=args.normalization_factor,
                       psa_softmax=args.psa_softmax, pretrained=False)
    else:
        raise ValueError("Unsupported arch '{}': expected 'psp' or 'psa'".format(args.arch))
    logger.info(model)
    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True
    if os.path.isfile(args.model_path):
        logger.info("=> loading checkpoint '{}'".format(args.model_path))
        checkpoint = torch.load(args.model_path)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        logger.info("=> loaded checkpoint '{}'".format(args.model_path))
    else:
        raise RuntimeError("=> no checkpoint found at '{}'".format(
            args.model_path))

    # glob() returns paths in arbitrary order; sort for a reproducible run.
    paths = sorted(glob.glob(args.image + '/scene*/color/*00.jpg'))
    for path in paths:
        test(model.eval(), path, args.classes, mean, std, args.base_size,
             args.test_h, args.test_w, args.scales, colors)
def main_worker(gpu, ngpus_per_node, argss):
    """Train a PSPNet/PSANet segmentation model in one worker process.

    Sets up distributed training (when enabled), builds the model, optimizer
    and data loaders, then runs the epoch loop: training, TensorBoard logging,
    periodic checkpointing and optional validation.

    Args:
        gpu: Local GPU index of this worker.
        ngpus_per_node: Number of GPUs on this node; used to derive the global
            rank and to split batch sizes and worker counts per GPU.
        argss: Parsed configuration; stored in the module-level ``args``.

    Raises:
        ValueError: If ``args.arch`` is neither ``'psp'`` nor ``'psa'``.
    """
    global args, logger, writer
    args = argss

    # Step 1: distributed initialisation.
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # Convert the node rank into this process's global rank.
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # Step 2: build the network (adapt to your own model as needed).
    # Cross-entropy loss; pixels labelled ignore_label do not contribute.
    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label)
    if args.arch == 'psp':
        from model.pspnet import PSPNet
        model = PSPNet(layers=args.layers,
                       classes=args.classes,
                       zoom_factor=args.zoom_factor,
                       criterion=criterion)
        modules_new = [model.ppm, model.cls, model.aux]
    elif args.arch == 'psa':
        from model.psanet import PSANet
        model = PSANet(layers=args.layers,
                       classes=args.classes,
                       zoom_factor=args.zoom_factor,
                       psa_type=args.psa_type,
                       compact=args.compact,
                       shrink_factor=args.shrink_factor,
                       mask_h=args.mask_h,
                       mask_w=args.mask_w,
                       normalization_factor=args.normalization_factor,
                       psa_softmax=args.psa_softmax,
                       criterion=criterion)
        modules_new = [model.psa, model.cls, model.aux]
    else:
        raise ValueError("Unsupported arch '{}': expected 'psp' or 'psa'".format(args.arch))
    # Backbone stages; both architectures expose the same five.
    modules_ori = [
        model.layer0, model.layer1, model.layer2, model.layer3, model.layer4
    ]

    # Step 3: optimizer. Backbone modules train at base_lr, the newly added
    # head modules at 10x base_lr.
    params_list = [dict(params=module.parameters(), lr=args.base_lr)
                   for module in modules_ori]
    params_list += [dict(params=module.parameters(), lr=args.base_lr * 10)
                    for module in modules_new]
    # Number of backbone param groups; presumably read by train() to tell
    # backbone groups from head groups when scheduling the LR -- confirm.
    args.index_split = 5
    optimizer = torch.optim.SGD(params_list,
                                lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # Optionally replace BatchNorm layers with torch.nn.SyncBatchNorm.
    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Step 4: logging (main process only) and multi-GPU wrapping.
    if main_process():
        logger = get_logger()
        writer = SummaryWriter(args.save_path)
        logger.info(args)
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))
        logger.info(model)
    if args.distributed:
        torch.cuda.set_device(gpu)  # bind this process to GPU `gpu`
        # Split the configured batch sizes and workers across this node's
        # GPUs (worker count rounded up).
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
        args.workers = int(
            (args.workers + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(
            model.cuda(), device_ids=[gpu])
    else:
        model = torch.nn.DataParallel(model.cuda())

    # Step 5a: initialise from pretrained weights.
    if args.weight:
        if os.path.isfile(args.weight):
            if main_process():
                logger.info("=> loading weight '{}'".format(args.weight))
            # Load onto this process's current device, as the resume path
            # does, rather than onto the device the tensors were saved from.
            checkpoint = torch.load(
                args.weight, map_location=lambda storage, loc: storage.cuda())
            model.load_state_dict(checkpoint['state_dict'])
            if main_process():
                logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            if main_process():
                logger.info("=> no weight found at '{}'".format(args.weight))

    # Step 5b: resume an interrupted run (model, optimizer and epoch).
    if args.resume:
        if os.path.isfile(args.resume):
            if main_process():
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(
                args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if main_process():
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        else:
            if main_process():
                logger.info("=> no checkpoint found at '{}'".format(
                    args.resume))

    # Step 6: data loaders. Mean/std are scaled to the [0, 255] pixel range.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]
    # Training augmentation pipeline.
    train_transform = transform.Compose([
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max],
                             padding=mean,
                             ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w],
                       crop_type='rand',
                       padding=mean,
                       ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ])
    # Training data (adapt to your own dataset as needed).
    train_data = dataset.SemData(split='train',
                                 data_root=args.data_root,
                                 data_list=args.train_list,
                                 transform=train_transform)
    if args.distributed:
        # Each process reads a disjoint shard of the training set.
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_data)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler,
        drop_last=True)
    if args.evaluate:
        val_transform = transform.Compose([
            transform.Crop([args.train_h, args.train_w],
                           crop_type='center',
                           padding=mean,
                           ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)
        ])
        val_data = dataset.SemData(split='val',
                                   data_root=args.data_root,
                                   data_list=args.val_list,
                                   transform=val_transform)
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_data)
        else:
            val_sampler = None
        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=args.batch_size_val,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            sampler=val_sampler)

    # Step 7: main loop.
    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1
        if args.distributed:
            # Reshuffle the distributed shards differently each epoch.
            train_sampler.set_epoch(epoch)
        # Train for one epoch (adapt train() to your own model as needed).
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(
            train_loader, model, optimizer, epoch)
        if main_process():
            writer.add_scalar('loss_train', loss_train, epoch_log)
            writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
            writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
            writer.add_scalar('allAcc_train', allAcc_train, epoch_log)

        # Periodic checkpoint, written by the main process only.
        if (epoch_log % args.save_freq == 0) and main_process():
            filename = args.save_path + '/train_epoch_' + str(
                epoch_log) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save(
                {
                    'epoch': epoch_log,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, filename)
            if epoch_log / args.save_freq > 2:
                # Keep only the two most recent periodic checkpoints. The
                # older file may be absent (e.g. after resuming into a new
                # save_path); that must not abort training.
                deletename = args.save_path + '/train_epoch_' + str(
                    epoch_log - args.save_freq * 2) + '.pth'
                try:
                    os.remove(deletename)
                except FileNotFoundError:
                    pass

        # Evaluate after each training epoch.
        if args.evaluate:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(
                val_loader, model, criterion)
            if main_process():
                writer.add_scalar('loss_val', loss_val, epoch_log)
                writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
                writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
                writer.add_scalar('allAcc_val', allAcc_val, epoch_log)
def main_worker(gpu, ngpus_per_node, argss):
    """Train a distillation (KDNet) or PSANet segmentation model in one worker.

    For ``arch == 'psp'`` a ``KDNet`` is built and only the modules of its
    ``student_net`` are optimised; ``arch == 'psa'`` builds a plain PSANet.
    The function then sets up data loaders and runs the epoch loop: training,
    TensorBoard logging, periodic checkpointing and optional validation.

    Args:
        gpu: Local GPU index of this worker.
        ngpus_per_node: Number of GPUs on this node; used to derive the global
            rank and to split batch sizes and worker counts per GPU.
        argss: Parsed configuration; stored in the module-level ``args``.

    Raises:
        ValueError: If ``args.arch`` is neither ``'psp'`` nor ``'psa'``.
    """
    global args, logger, writer
    args = argss

    # Distributed initialisation.
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # Convert the node rank into this process's global rank.
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label)
    if args.arch == 'psp':
        from model.kdnet import KDNet
        model = KDNet(layers=args.layers,
                      classes=args.classes,
                      zoom_factor=args.zoom_factor,
                      criterion=criterion,
                      temperature=args.temperature,
                      alpha=args.alpha)
        # Only the student network's modules go into the optimizer.
        student = model.student_net
        modules_ori = [
            student.layer0, student.layer1, student.layer2, student.layer3,
            student.layer4
        ]
        modules_new = [student.ppm, student.cls, student.aux]
        # NOTE(review): bound but never used below.
        teacher_net = model.teacher_loader
    elif args.arch == 'psa':
        from model.psanet import PSANet
        model = PSANet(layers=args.layers,
                       classes=args.classes,
                       zoom_factor=args.zoom_factor,
                       psa_type=args.psa_type,
                       compact=args.compact,
                       shrink_factor=args.shrink_factor,
                       mask_h=args.mask_h,
                       mask_w=args.mask_w,
                       normalization_factor=args.normalization_factor,
                       psa_softmax=args.psa_softmax,
                       criterion=criterion)
        modules_ori = [
            model.layer0, model.layer1, model.layer2, model.layer3,
            model.layer4
        ]
        modules_new = [model.psa, model.cls, model.aux]
    else:
        raise ValueError("Unsupported arch '{}': expected 'psp' or 'psa'".format(args.arch))

    # Backbone modules train at base_lr, head modules at 10x base_lr.
    params_list = [dict(params=module.parameters(), lr=args.base_lr)
                   for module in modules_ori]
    params_list += [dict(params=module.parameters(), lr=args.base_lr * 10)
                    for module in modules_new]
    # Number of backbone param groups; presumably read by train() when
    # scheduling the LR -- confirm.
    args.index_split = 5
    optimizer = torch.optim.SGD(params_list,
                                lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if main_process():
        logger = get_logger()
        writer = SummaryWriter(args.save_path)
        logger.info(args)
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))
        logger.info(model)
    if args.distributed:
        torch.cuda.set_device(gpu)  # bind this process to GPU `gpu`
        # Split the configured batch sizes and workers across this node's
        # GPUs (worker count rounded up).
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
        args.workers = int(
            (args.workers + ngpus_per_node - 1) / ngpus_per_node)
        # find_unused_parameters=True: presumably because some parameters
        # (e.g. the teacher's) receive no gradient -- confirm.
        model = torch.nn.parallel.DistributedDataParallel(
            model.cuda(), device_ids=[gpu], find_unused_parameters=True)
    else:
        model = torch.nn.DataParallel(model.cuda())

    # Initialise from pretrained weights.
    if args.weight:
        if os.path.isfile(args.weight):
            if main_process():
                logger.info("=> loading weight '{}'".format(args.weight))
            # Load onto this process's current device, as the resume path
            # does, rather than onto the device the tensors were saved from.
            checkpoint = torch.load(
                args.weight, map_location=lambda storage, loc: storage.cuda())
            model.load_state_dict(checkpoint['state_dict'])
            if main_process():
                logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            if main_process():
                logger.info("=> no weight found at '{}'".format(args.weight))

    # Resume an interrupted run (model, optimizer and epoch).
    if args.resume:
        if os.path.isfile(args.resume):
            if main_process():
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(
                args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if main_process():
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        else:
            if main_process():
                logger.info("=> no checkpoint found at '{}'".format(
                    args.resume))

    # Per-channel mean/std, scaled from [0, 1] to the [0, 255] pixel range.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]
    train_transform = transform.Compose([
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max],
                             padding=mean,
                             ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w],
                       crop_type='rand',
                       padding=mean,
                       ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ])
    train_data = dataset.SemData(split='train',
                                 data_root=args.data_root,
                                 data_list=args.train_list,
                                 transform=train_transform)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_data)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler,
        drop_last=True)
    if args.evaluate:
        val_transform = transform.Compose([
            transform.Crop([args.train_h, args.train_w],
                           crop_type='center',
                           padding=mean,
                           ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)
        ])
        val_data = dataset.SemData(split='val',
                                   data_root=args.data_root,
                                   data_list=args.val_list,
                                   transform=val_transform)
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_data)
        else:
            val_sampler = None
        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=args.batch_size_val,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            sampler=val_sampler)

    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1
        if args.distributed:
            train_sampler.set_epoch(epoch)
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(
            train_loader, model, optimizer, epoch)
        if main_process():
            writer.add_scalar('loss_train', loss_train, epoch_log)
            writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
            writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
            writer.add_scalar('allAcc_train', allAcc_train, epoch_log)

        if (epoch_log % args.save_freq == 0) and main_process():
            filename = args.save_path + '/train_epoch_' + str(
                epoch_log) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save(
                {
                    'epoch': epoch_log,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, filename)
            if epoch_log / args.save_freq > 2:
                # Keep only the two most recent periodic checkpoints; a
                # missing older file must not abort training.
                deletename = args.save_path + '/train_epoch_' + str(
                    epoch_log - args.save_freq * 2) + '.pth'
                try:
                    os.remove(deletename)
                except FileNotFoundError:
                    pass

        if args.evaluate:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(
                val_loader, model, criterion)
            if main_process():
                writer.add_scalar('loss_val', loss_val, epoch_log)
                writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
                writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
                writer.add_scalar('allAcc_val', allAcc_val, epoch_log)
def main(
        config_name,
        weights_url='https://github.com/deepparrot/semseg/releases/download/0.1/pspnet50-ade20k.pth',
        weights_name='pspnet50-ade20k.pth'):
    """Evaluate a downloaded PSPNet/PSANet checkpoint on the ADE20K validation list.

    Args:
        config_name: Config file passed to ``config.load_cfg_from_cfg_file``.
        weights_url: URL the checkpoint is downloaded from.
        weights_name: Local filename the checkpoint is saved under.

    Raises:
        ValueError: If ``args.arch`` is neither ``'psp'`` nor ``'psa'``.
        RuntimeError: If the downloaded checkpoint file does not exist.
    """
    import urllib.request

    args = config.load_cfg_from_cfg_file(config_name)
    check(args)
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
        str(x) for x in args.test_gpu)

    # Per-channel mean/std, scaled from [0, 1] to the [0, 255] pixel range.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    gray_folder = os.path.join(args.save_folder, 'gray')
    color_folder = os.path.join(args.save_folder, 'color')

    # Dataset locations are fixed here, overriding the config values.
    args.data_root = './.data/vision/ade20k'
    args.val_list = './.data/vision/ade20k/validation.txt'
    args.test_list = './.data/vision/ade20k/validation.txt'
    print(args.data_root)

    test_transform = transform.Compose([transform.ToTensor()])
    test_data = dataset.SemData(split=args.split,
                                data_root=args.data_root,
                                data_list=args.test_list,
                                transform=test_transform)
    # Evaluate only [index_start, index_start + index_step); step 0 means "to the end".
    index_start = args.index_start
    if args.index_step == 0:
        index_end = len(test_data.data_list)
    else:
        index_end = min(index_start + args.index_step,
                        len(test_data.data_list))
    test_data.data_list = test_data.data_list[index_start:index_end]
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)
    colors = np.loadtxt(args.colors_path).astype('uint8')
    # NOTE(review): no class names are loaded here; confirm cal_acc does not
    # index into `names`.
    names = []

    if not args.has_prediction:
        if args.arch == 'psp':
            from model.pspnet import PSPNet
            model = PSPNet(layers=args.layers,
                           classes=args.classes,
                           zoom_factor=args.zoom_factor,
                           pretrained=False)
        elif args.arch == 'psa':
            from model.psanet import PSANet
            model = PSANet(layers=args.layers,
                           classes=args.classes,
                           zoom_factor=args.zoom_factor,
                           compact=args.compact,
                           shrink_factor=args.shrink_factor,
                           mask_h=args.mask_h,
                           mask_w=args.mask_w,
                           normalization_factor=args.normalization_factor,
                           psa_softmax=args.psa_softmax,
                           pretrained=False)
        else:
            raise ValueError("Unsupported arch '{}': expected 'psp' or 'psa'".format(args.arch))
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True
        local_checkpoint, _ = urllib.request.urlretrieve(
            weights_url, weights_name)
        if os.path.isfile(local_checkpoint):
            checkpoint = torch.load(local_checkpoint)
            model.load_state_dict(checkpoint['state_dict'], strict=False)
        else:
            raise RuntimeError(
                "=> no checkpoint found at '{}'".format(local_checkpoint))
        # Inference needs the model, so it only runs when predictions are
        # not already on disk.
        test(test_loader, test_data.data_list, model, args.classes, mean, std,
             args.base_size, args.test_h, args.test_w, args.scales,
             gray_folder, color_folder, colors)
    if args.split != 'test':
        cal_acc(test_data.data_list, gray_folder, args.classes, names)
def main():
    """Generate predictions for the Cityscapes video-frame list from a checkpoint.

    The checkpoint at ``args.ckpt_path`` is converted with ``transfer_ckpt``
    and loaded non-strictly; predictions are written to the ``gray``
    subfolder of ``save_folder`` with every ``'ss'`` replaced by ``'video'``.
    Inference is skipped when ``has_prediction`` is set.

    Raises:
        ValueError: If ``args.arch`` is not one of the supported names.
        RuntimeError: If ``args.ckpt_path`` is not an existing file.
    """
    global args, logger
    args = get_parser()
    # check(args)
    logger = get_logger()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.gen_gpu)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # Per-channel mean/std, scaled from [0, 1] to the [0, 255] pixel range.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    gray_folder = os.path.join(args.save_folder.replace('ss', 'video'), 'gray')
    test_transform = transform.Compose(
        [transform.ToTensor(),
         transform.Normalize(mean=mean, std=std)])
    test_data = dataset.SemData(
        split='test',
        data_root=args.data_root,
        data_list='./data/list/cityscapes/val_video_img_sam.lst',
        transform=test_transform)
    # Evaluate only [index_start, index_start + index_step); step 0 means "to the end".
    index_start = args.index_start
    if args.index_step == 0:
        index_end = len(test_data.data_list)
    else:
        index_end = min(index_start + args.index_step, len(test_data.data_list))
    test_data.data_list = test_data.data_list[index_start:index_end]
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size_gen,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)
    colors = np.loadtxt(args.colors_path).astype('uint8')

    if not args.has_prediction:
        if args.arch == 'psp':
            from model.origin_pspnet import PSPNet
            model = PSPNet(layers=args.layers, classes=args.classes,
                           zoom_factor=args.zoom_factor, pretrained=False)
        elif args.arch == 'psp18':
            from model.pspnet_18 import PSPNet
            model = PSPNet(layers=args.layers, classes=args.classes,
                           zoom_factor=args.zoom_factor, flow=False,
                           pretrained=False)
        elif args.arch == 'psa':
            from model.psanet import PSANet
            model = PSANet(layers=args.layers, classes=args.classes,
                           zoom_factor=args.zoom_factor,
                           compact=args.compact,
                           shrink_factor=args.shrink_factor,
                           mask_h=args.mask_h, mask_w=args.mask_w,
                           normalization_factor=args.normalization_factor,
                           psa_softmax=args.psa_softmax, pretrained=False)
        elif args.arch == 'mobile':
            from model.mobile import DenseASPP
            model = DenseASPP(layers=args.layers, classes=args.classes,
                              zoom_factor=args.zoom_factor, flow=False)
        elif args.arch == 'antipsp18':
            from model.antipspnet18 import PSPNet
            model = PSPNet(layers=args.layers, classes=args.classes,
                           zoom_factor=args.zoom_factor, flow=False)
        else:
            raise ValueError("Unsupported arch '{}'".format(args.arch))
        logger.info(model)
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True
        if os.path.isfile(args.ckpt_path):
            logger.info("=> loading checkpoint '{}'".format(args.ckpt_path))
            checkpoint = torch.load(args.ckpt_path)
            student_ckpt = transfer_ckpt(checkpoint)
            # load_state_dict returns (missing_keys, unexpected_keys), in
            # that order.
            missing_keys, unexpected_keys = model.load_state_dict(
                student_ckpt, strict=False)
            print('unexpected keys:', unexpected_keys)
            print('missing keys:', missing_keys)
            logger.info("=> loaded checkpoint '{}'".format(args.ckpt_path))
        else:
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.ckpt_path))
        # Inference needs the model, so it only runs when predictions are
        # not already on disk. The 1024x2048 test size is hard-coded.
        test(test_loader, test_data.data_list, model, args.classes, mean, std,
             args.base_size, 1024, 2048, args.scales, gray_folder, colors)
def main():
    """Evaluate a PSPNet/PSANet checkpoint, with optional NYU label space and cluster paths.

    Resolves data/config paths (relative to ``data_config_path`` and
    ``project_path``), runs ``test`` on the (optionally sliced) test list and
    scores the predictions with ``cal_acc`` unless ``split == 'test'`` and no
    ground truth is available (``test_has_gt``). Inference is skipped when
    ``has_prediction`` is set.

    Raises:
        ValueError: If ``args.arch`` is neither ``'psp'`` nor ``'psa'``.
        RuntimeError: If ``args.model_path`` is not an existing file.
    """
    global args, logger
    args = get_parser()
    if args.test_in_nyu_label_space:
        args.colors_path = 'nyu/nyu_colors.txt'
        args.names_path = 'nyu/nyu_names.txt'
    if args.if_cluster:
        args.data_root = args.data_root_cluster
        args.project_path = args.project_path_cluster
        # NOTE(review): data_config_path is only overridden on the cluster --
        # confirm this matches the intended (collapsed) original structure.
        args.data_config_path = 'data'
    for key in ['train_list', 'val_list', 'test_list', 'colors_path', 'names_path']:
        args[key] = os.path.join(args.data_config_path, args[key])
    for key in ['save_path', 'model_path', 'save_folder']:
        args[key] = os.path.join(args.project_path, args[key])
    # for key in ['save_path', 'model_path', 'save_folder']:
    #     args[key] = args[key] % args.exp_name

    check(args)
    logger = get_logger()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.test_gpu)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # Per-channel mean/std, scaled from [0, 1] to the [0, 255] pixel range.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    gray_folder = os.path.join(args.save_folder, 'gray')
    color_folder = os.path.join(args.save_folder, 'color')

    # Optional resize, then a centre crop to the test size and normalisation.
    transform_list_test = []
    if args.resize:
        transform_list_test.append(
            transform.Resize((args.resize_h_test, args.resize_w_test)))
    transform_list_test += [
        transform.Crop([args.test_h, args.test_w], crop_type='center',
                       padding=mean, ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ]
    test_transform = transform.Compose(transform_list_test)
    test_data = dataset.SemData(split=args.split, data_root=args.data_root,
                                data_list=args.test_list,
                                transform=test_transform, is_master=True,
                                args=args)
    # test_data = dataset.SemData(split='val', data_root=args.data_root, data_list=args.val_list, transform=test_transform, is_master=True, args=args)
    # Evaluate only [index_start, index_start + index_step); step 0 means "to the end".
    index_start = args.index_start
    if args.index_step == 0:
        index_end = len(test_data.data_list)
    else:
        index_end = min(index_start + args.index_step, len(test_data.data_list))
    test_data.data_list = test_data.data_list[index_start:index_end]
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=1,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)
    colors = np.loadtxt(args.colors_path).astype('uint8')
    with open(args.names_path) as names_file:
        names = [line.rstrip('\n') for line in names_file]
    args.read_image = test_data.read_image

    # Prediction/target path lists are only known when test() runs here.
    acc_kwargs = {}
    if not args.has_prediction:
        if args.arch == 'psp':
            from model.pspnet import PSPNet
            model = PSPNet(layers=args.layers, classes=args.classes,
                           zoom_factor=args.zoom_factor, pretrained=False)
        elif args.arch == 'psa':
            from model.psanet import PSANet
            model = PSANet(layers=args.layers, classes=args.classes,
                           zoom_factor=args.zoom_factor,
                           compact=args.compact,
                           shrink_factor=args.shrink_factor,
                           mask_h=args.mask_h, mask_w=args.mask_w,
                           normalization_factor=args.normalization_factor,
                           psa_softmax=args.psa_softmax, pretrained=False)
        else:
            raise ValueError("Unsupported arch '{}': expected 'psp' or 'psa'".format(args.arch))
        logger.info(model)
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True
        if os.path.isfile(args.model_path):
            logger.info("=> loading checkpoint '{}'".format(args.model_path))
            checkpoint = torch.load(args.model_path)
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            logger.info("=> loaded checkpoint '{}'".format(args.model_path))
        else:
            raise RuntimeError("=> no checkpoint found at '{}'".format(args.model_path))
        pred_path_list, target_path_list = test(
            test_loader, test_data.data_list, model, args.classes, mean, std,
            args.base_size, args.test_h, args.test_w, args.scales,
            gray_folder, color_folder, colors)
        acc_kwargs = dict(pred_path_list=pred_path_list,
                          target_path_list=target_path_list)
    # Score unless this is the test split without ground truth.
    if args.split != 'test' or args.test_has_gt:
        # NOTE(review): with has_prediction set, cal_acc is called without
        # path lists; assumes those parameters have defaults -- confirm.
        cal_acc(test_data.data_list, gray_folder, args.classes, names,
                **acc_kwargs)
class PSPNetSematicSegmentation(object):
    """Single-image PSPNet/PSANet inference wrapper that returns colorized label maps.

    The configuration is loaded with ``config.load_cfg_from_cfg_file``. The
    network is wrapped in ``torch.nn.DataParallel`` on CUDA, and its weights are
    restored non-strictly from ``model_path``.
    """

    def __init__(self, config_file=CONFIG_FILE):
        """Load the config, build the network and restore its checkpoint.

        Raises:
            ValueError: if the configured ``arch`` is neither ``'psp'`` nor ``'psa'``.
            RuntimeError: if no checkpoint file exists at ``model_path``.
        """
        self.args_ = config.load_cfg_from_cfg_file(config_file)
        self.logger_ = get_logger()
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in self.args_.test_gpu)
        # Per-channel normalisation constants, scaled from [0, 1] to [0, 255].
        value_scale = 255
        mean = [0.485, 0.456, 0.406]
        self.mean_ = [item * value_scale for item in mean]
        std = [0.229, 0.224, 0.225]
        self.std_ = [item * value_scale for item in std]

        if self.args_.arch == 'psp':
            from model.pspnet import PSPNet
            self.model_ = PSPNet(layers=self.args_.layers, classes=self.args_.classes,
                                 zoom_factor=self.args_.zoom_factor, pretrained=False)
        elif self.args_.arch == 'psa':
            from model.psanet import PSANet
            self.model_ = PSANet(
                layers=self.args_.layers, classes=self.args_.classes,
                zoom_factor=self.args_.zoom_factor, compact=self.args_.compact,
                shrink_factor=self.args_.shrink_factor, mask_h=self.args_.mask_h,
                mask_w=self.args_.mask_w,
                normalization_factor=self.args_.normalization_factor,
                psa_softmax=self.args_.psa_softmax, pretrained=False)
        else:
            raise ValueError("Unsupported arch '{}'".format(self.args_.arch))
        self.model_ = torch.nn.DataParallel(self.model_).cuda()
        cudnn.benchmark = True
        if not os.path.isfile(self.args_.model_path):
            raise RuntimeError("=> no checkpoint found at '{}'".format(self.args_.model_path))
        # Log through the stored logger. Logger.info returns None, so its result
        # must not be assigned back to self.logger_.
        self.logger_.info("=> loading checkpoint '{}'".format(self.args_.model_path))
        checkpoint = torch.load(self.args_.model_path)
        self.model_.load_state_dict(checkpoint['state_dict'], strict=False)
        self.logger_.info("=> loaded checkpoint '{}'".format(self.args_.model_path))

    def get_label_colors(self, driveable=True):
        """Return a ``{class_index: [R, G, B]}`` palette for 19 classes.

        Args:
            driveable: If ``>= 0``, every class is black except this index,
                which is green ``[55, 195, 55]``. If negative, the full
                per-class palette is returned. The default ``True`` compares
                as ``1`` (bool is an int); it is kept for backward compatibility.
        """
        if driveable >= 0:
            colors = [[0, 0, 0] for _ in range(19)]
            colors[driveable] = [55, 195, 55]
        else:
            colors = [
                [128, 64, 128], [244, 35, 232], [70, 70, 70], [55, 195, 55],
                [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
                [107, 142, 35], [152, 251, 152], [0, 130, 180], [220, 20, 60],
                [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
                [0, 80, 100], [0, 0, 230], [119, 11, 32],
            ]
        return dict(zip(range(19), colors))

    def net_process(self, image, flip=True):
        """Run the network on one HxWx3 image and return HxWxC softmax scores.

        The image is normalised with ``self.mean_``/``self.std_``. When ``flip``
        is set, the scores of the image and of its horizontal mirror are averaged.
        """
        inputs = torch.from_numpy(image.transpose((2, 0, 1))).float()
        if self.std_ is None:
            for channel, m in zip(inputs, self.mean_):
                channel.sub_(m)
        else:
            for channel, m, s in zip(inputs, self.mean_, self.std_):
                channel.sub_(m).div_(s)
        inputs = inputs.unsqueeze(0).cuda()
        if flip:
            inputs = torch.cat([inputs, inputs.flip(3)], 0)
        with torch.no_grad():
            output = self.model_(inputs)
        _, _, h_i, w_i = inputs.shape
        _, _, h_o, w_o = output.shape
        if (h_o != h_i) or (w_o != w_i):
            output = F.interpolate(output, (h_i, w_i), mode='bilinear', align_corners=True)
        output = F.softmax(output, dim=1)
        if flip:
            # Un-mirror the flipped prediction before averaging.
            output = (output[0] + output[1].flip(2)) / 2
        else:
            output = output[0]
        output = output.data.cpu().numpy()
        return output.transpose(1, 2, 0)

    def scale_process(self, model, image, classes, crop_h, crop_w, h, w, mean, std=None,
                      stride_rate=2 / 3):
        """Run sliding-window inference on ``image`` and resize the scores to (h, w).

        The image is padded with ``mean`` up to at least ``crop_h`` x ``crop_w``.
        Overlapping crops with stride ``ceil(crop * stride_rate)`` are scored by
        ``net_process``, and overlapping scores are averaged. ``model`` and ``std``
        are unused here: ``net_process`` uses ``self.model_`` and ``self.std_``.
        """
        ori_h, ori_w, _ = image.shape
        pad_h = max(crop_h - ori_h, 0)
        pad_w = max(crop_w - ori_w, 0)
        pad_h_half = int(pad_h / 2)
        pad_w_half = int(pad_w / 2)
        if pad_h > 0 or pad_w > 0:
            image = cv2.copyMakeBorder(image, pad_h_half, pad_h - pad_h_half, pad_w_half,
                                       pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=mean)
        new_h, new_w, _ = image.shape
        stride_h = int(np.ceil(crop_h * stride_rate))
        stride_w = int(np.ceil(crop_w * stride_rate))
        grid_h = int(np.ceil(float(new_h - crop_h) / stride_h) + 1)
        grid_w = int(np.ceil(float(new_w - crop_w) / stride_w) + 1)
        prediction_crop = np.zeros((new_h, new_w, classes), dtype=float)
        count_crop = np.zeros((new_h, new_w), dtype=float)
        for index_h in range(0, grid_h):
            for index_w in range(0, grid_w):
                # Clamp the last window to the border so every crop is full-sized.
                s_h = index_h * stride_h
                e_h = min(s_h + crop_h, new_h)
                s_h = e_h - crop_h
                s_w = index_w * stride_w
                e_w = min(s_w + crop_w, new_w)
                s_w = e_w - crop_w
                image_crop = image[s_h:e_h, s_w:e_w].copy()
                count_crop[s_h:e_h, s_w:e_w] += 1
                prediction_crop[s_h:e_h, s_w:e_w, :] += self.net_process(image_crop)
        prediction_crop /= np.expand_dims(count_crop, 2)
        prediction_crop = prediction_crop[pad_h_half:pad_h_half + ori_h,
                                          pad_w_half:pad_w_half + ori_w]
        prediction = cv2.resize(prediction_crop, (w, h), interpolation=cv2.INTER_LINEAR)
        return prediction

    def test(self, model, image, classes, base_size, crop_h, crop_w, scales):
        """Return the uint8 HxW label map for a BGR image, averaged over ``scales``.

        The image is converted BGR->RGB. For each scale, its longer side is resized
        to ``round(scale * base_size)`` and scored with ``scale_process``. The
        per-scale scores are averaged before the argmax.

        Raises:
            ValueError: if ``scales`` is empty.
        """
        if len(scales) == 0:
            raise ValueError("scales must contain at least one scale")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w, _ = image.shape
        prediction = np.zeros((h, w, classes), dtype=float)
        for scale in scales:
            long_size = round(scale * base_size)
            new_h = long_size
            new_w = long_size
            if h > w:
                new_w = round(long_size / float(h) * w)
            else:
                new_h = round(long_size / float(w) * h)
            image_scale = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
            prediction += self.scale_process(model, image_scale, classes, crop_h, crop_w, h, w,
                                             self.mean_, self.std_)
        # Average the multi-scale scores. The previous code re-ran the last scale
        # here, which discarded the accumulated sum and doubled inference cost.
        prediction /= len(scales)
        prediction = np.argmax(prediction, axis=2)
        return np.uint8(prediction)

    def colorize(self, labels, label_colors):
        """Map a 2-D integer label image to an HxWx3 uint8 RGB image.

        Args:
            labels: 2-D array of class indices. Values are clipped to
                ``[0, len(label_colors) - 1]``.
            label_colors: Mapping (or sequence) from index ``0..len-1`` to an
                ``[R, G, B]`` triple.
        """
        num_colors = len(label_colors)
        labels = np.clip(labels, 0, num_colors - 1)
        palette = np.array([label_colors[i] for i in range(num_colors)], dtype=np.uint8)
        return palette[labels.astype(np.intp)]

    def process_img_driveable(self, img, size, drivable_idx=-1):
        """Segment ``img`` and return ``(full_palette_image, drivable_highlight_image)``.

        Args:
            img: Input image. ``test`` converts BGR->RGB, so this is presumably a
                BGR image as read by OpenCV (TODO confirm with callers).
            size: Target ``(height, width)``. It is swapped into cv2's
                ``(width, height)`` order for the resize.
            drivable_idx: Class index highlighted in the second output. A negative
                value yields the full palette there as well.
        """
        img_resized = cv2.resize(img, (int(size[1]), int(size[0])))
        labeled_img = self.test(self.model_.eval(), img_resized, self.args_.classes,
                                self.args_.base_size, self.args_.test_h, self.args_.test_w,
                                self.args_.scales)
        segmented_img = self.colorize(labeled_img, self.get_label_colors(driveable=-1))
        drivable_img = self.colorize(labeled_img,
                                     self.get_label_colors(driveable=drivable_idx))
        return segmented_img, drivable_img
def main_worker(gpu, ngpus_per_node, argss):
    """Build, optionally fine-tune or resume, and train a segmentation model on one worker.

    Args:
        gpu: Local GPU index of this worker. It is used to compute the global
            rank and for ``torch.cuda.set_device`` in distributed mode.
        ngpus_per_node: GPUs per node. Per-process batch sizes and loader
            workers are divided by it in distributed mode.
        argss: Parsed configuration, stored in the module-global ``args``.

    Side effects: binds the globals ``args`` (always) and ``logger``/``writer``
    (main process only), writes TensorBoard scalars, and saves
    ``train_epoch_<N>.pth`` checkpoints under ``args.save_path``. When a
    checkpoint is saved, the one from ``2 * save_freq`` epochs earlier is deleted.

    Raises:
        ValueError: if ``args.arch`` is not one of the supported architectures.
    """
    # ``global`` applies to the whole function, so declaring all three here is
    # equivalent to declaring logger/writer inside the main-process branch.
    global args, logger, writer
    args = argss
    if args.sync_bn:
        if args.multiprocessing_distributed:
            BatchNorm = apex.parallel.SyncBatchNorm
        else:
            from lib.sync_bn.modules import BatchNorm2d
            BatchNorm = BatchNorm2d
    else:
        BatchNorm = nn.BatchNorm2d

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # Global rank = node rank * GPUs per node + local GPU index.
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    from util.crit import CriterionDSN
    criterion = CriterionDSN(ignore_index=args.ignore_label)
    if args.arch == 'psp':
        from model.pspnet import PSPNet
        model = PSPNet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor,
                       criterion=criterion, BatchNorm=BatchNorm)
        modules_ori = [model.layer0, model.layer1, model.layer2, model.layer3, model.layer4]
        modules_new = [model.ppm, model.cls, model.aux]
    elif args.arch == 'psp18':
        from model.pspnet_18 import PSPNet
        model = PSPNet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor,
                       criterion=criterion, BatchNorm=BatchNorm, flow=True)
        modules_ori = [model.layer0, model.layer1, model.layer2, model.layer3, model.layer4]
        modules_new = [model.ppm, model.cls, model.aux]
    elif args.arch == 'antipsp18':
        from model.antipspnet18 import PSPNet
        model = PSPNet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor,
                       criterion=criterion, BatchNorm=BatchNorm, flow=True)
        modules_ori = [model.layer0, model.layer1, model.layer2, model.layer3, model.layer4]
        modules_new = [model.ppm, model.cls, model.aux]
    elif args.arch == 'psa':
        from model.psanet import PSANet
        model = PSANet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor,
                       psa_type=args.psa_type, compact=args.compact,
                       shrink_factor=args.shrink_factor, mask_h=args.mask_h,
                       mask_w=args.mask_w, normalization_factor=args.normalization_factor,
                       psa_softmax=args.psa_softmax, criterion=criterion,
                       BatchNorm=BatchNorm)
        modules_ori = [model.layer0, model.layer1, model.layer2, model.layer3, model.layer4]
        modules_new = [model.psa, model.cls, model.aux]
    elif args.arch == 'mobile':
        from model.mobile import DenseASPP
        model = DenseASPP(layers=args.layers, classes=args.classes,
                          zoom_factor=args.zoom_factor, criterion=criterion,
                          BatchNorm=BatchNorm, flow=True)
        modules_ori = [model.features]
        modules_new = [model.ppm, model.cls, model.aux]
    else:
        raise ValueError("Unsupported arch '{}'".format(args.arch))

    args.index_split = 1 if 'mobile' in args.arch else 5

    if args.tune_weight:
        student_ckpt = torch.load(args.tune_weight, map_location=torch.device('cpu'))
        if 'state_dict' in student_ckpt:
            student_ckpt = student_ckpt['state_dict']
        new_params = model.state_dict().copy()
        for key, value in student_ckpt.items():
            # Strip wrapper prefixes. The replacements are no-ops for keys
            # without them.
            new_params[key.replace('module.', '').replace('student.', '')] = value
        missing_keys, _ = model.load_state_dict(new_params, strict=False)
        print('missing keys:', missing_keys)
        del new_params
        del student_ckpt

    # With tune_weight all modules use base_lr; otherwise modules_new uses 10x base_lr.
    head_lr = args.base_lr if args.tune_weight else args.base_lr * 10
    params_list = [dict(params=module.parameters(), lr=args.base_lr) for module in modules_ori]
    params_list += [dict(params=module.parameters(), lr=head_lr) for module in modules_new]
    optimizer = torch.optim.SGD(params_list, lr=args.base_lr, momentum=args.momentum,
                                weight_decay=args.weight_decay)

    criterion_list = nn.ModuleDict({'ce': criterion})
    from model.flow_model import FlowModel
    model = FlowModel(model, criterion_list, args)
    if main_process():
        logger = get_logger()
        writer = SummaryWriter(args.save_path)
        logger.info(args)
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))
        logger.info(model)

    if args.distributed:
        torch.cuda.set_device(gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
        args.workers = int(args.workers / ngpus_per_node)
        if args.use_apex:
            model, optimizer = apex.amp.initialize(model.cuda(), optimizer,
                                                   opt_level=args.opt_level,
                                                   keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                                   loss_scale=args.loss_scale)
            model = apex.parallel.DistributedDataParallel(model)
        else:
            model = torch.nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[gpu])
    else:
        model = torch.nn.DataParallel(model.cuda())

    if args.resume:
        if os.path.isfile(args.resume):
            if main_process():
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if main_process():
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.resume,
                                                                         checkpoint['epoch']))
        elif main_process():
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    # Per-channel normalisation constants, scaled from [0, 1] to [0, 255].
    value_scale = 255
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    std = [item * value_scale for item in [0.229, 0.224, 0.225]]

    train_transform = transform.Compose([
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max], padding=mean,
                             ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w], crop_type='rand', padding=mean,
                       ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)])
    train_data = dataset.VideoData(split='train', data_root=args.data_root,
                                   data_list=args.train_list, transform=train_transform,
                                   frame_gap=args.frame_gap, random_frame=args.random_frame)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers, pin_memory=True,
                                               sampler=train_sampler, drop_last=True)
    if args.evaluate:
        val_transform = transform.Compose([
            transform.Crop([args.train_h, args.train_w], crop_type='center', padding=mean,
                           ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)])
        val_data = dataset.VideoData(split='val', data_root=args.data_root,
                                     data_list=args.val_list, transform=val_transform)
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)
        else:
            val_sampler = None
        val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size_val,
                                                 shuffle=False, num_workers=args.workers,
                                                 pin_memory=True, sampler=val_sampler)

    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1
        if args.distributed:
            train_sampler.set_epoch(epoch)
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(train_loader, model,
                                                                 optimizer, epoch)
        if main_process():
            writer.add_scalar('loss_train', loss_train, epoch_log)
            writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
            writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
            writer.add_scalar('allAcc_train', allAcc_train, epoch_log)

        if (epoch_log % args.save_freq == 0) and main_process():
            filename = args.save_path + '/train_epoch_' + str(epoch_log) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save({'epoch': epoch_log, 'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict()}, filename)
            if epoch_log / args.save_freq > 2:
                deletename = (args.save_path + '/train_epoch_'
                              + str(epoch_log - args.save_freq * 2) + '.pth')
                # The older checkpoint may be absent (e.g. after resuming into a
                # fresh save_path), so don't let rotation abort training.
                if os.path.isfile(deletename):
                    os.remove(deletename)

        if args.evaluate:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(val_loader, model, criterion)
            if main_process():
                writer.add_scalar('loss_val', loss_val, epoch_log)
                writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
                writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
                writer.add_scalar('allAcc_val', allAcc_val, epoch_log)
def main_worker(gpu, ngpus_per_node, argss): global args args = argss if args.distributed: if args.dist_url == "env://" and args.rank == -1: args.rank = int(os.environ["RANK"]) if args.multiprocessing_distributed: args.rank = args.rank * ngpus_per_node + gpu dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank) criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label) if args.arch == 'psp': from model.pspnet import PSPNet model = PSPNet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor, criterion=criterion, args=args) modules_ori = [ model.layer0, model.layer1, model.layer2, model.layer3, model.layer4 ] modules_new = [model.ppm, model.cls, model.aux] elif args.arch == 'psa': from model.psanet import PSANet model = PSANet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor, psa_type=args.psa_type, compact=args.compact, shrink_factor=args.shrink_factor, mask_h=args.mask_h, mask_w=args.mask_w, normalization_factor=args.normalization_factor, psa_softmax=args.psa_softmax, criterion=criterion) modules_ori = [ model.layer0, model.layer1, model.layer2, model.layer3, model.layer4 ] modules_new = [model.psa, model.cls, model.aux] params_list = [] for module in modules_ori: params_list.append(dict(params=module.parameters(), lr=args.base_lr)) for module in modules_new: params_list.append( dict(params=module.parameters(), lr=args.base_lr * 10)) args.index_split = 5 optimizer = torch.optim.SGD(params_list, lr=args.base_lr, momentum=args.momentum, weight_decay=args.weight_decay) if args.sync_bn: model = nn.SyncBatchNorm.convert_sync_batchnorm(model) if main_process(): global logger, writer logger = get_logger() writer = SummaryWriter(args.save_path) logger.info(args) logger.info("=> creating model ...") logger.info("Classes: {}".format(args.classes)) logger.info(model) else: logger = None if args.distributed: torch.cuda.set_device(gpu) args.batch_size = 
int(args.batch_size / ngpus_per_node) args.batch_size_val = int(args.batch_size_val / ngpus_per_node) args.workers = int( (args.workers + ngpus_per_node - 1) / ngpus_per_node) model = torch.nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[gpu]) else: model = torch.nn.DataParallel(model.cuda()) if args.weight: if os.path.isfile(args.weight): if main_process(): logger.info("=> loading weight '{}'".format(args.weight)) checkpoint = torch.load(args.weight) model.load_state_dict(checkpoint['state_dict']) if main_process(): logger.info("=> loaded weight '{}'".format(args.weight)) else: if main_process(): logger.info("=> no weight found at '{}'".format(args.weight)) if args.resume != 'none': if os.path.isfile(args.resume): if main_process(): logger.info("=> loading checkpoint '{}'".format(args.resume)) # checkpoint = torch.load(args.resume) checkpoint = torch.load( args.resume, map_location=lambda storage, loc: storage.cuda()) args.start_epoch = checkpoint['epoch'] # model.load_state_dict(checkpoint['state_dict']) # optimizer.load_state_dict(checkpoint['optimizer']) # print(checkpoint['optimizer'].keys()) if args.if_remove_cls: if main_process(): logger.info( '=====!!!!!!!===== Remove cls layer in resuming...') checkpoint['state_dict'] = { x: checkpoint['state_dict'][x] for x in checkpoint['state_dict'].keys() if ('module.cls' not in x and 'module.aux' not in x) } # checkpoint['optimizer'] = {x: checkpoint['optimizer'][x] for x in checkpoint['optimizer'].keys() if ('module.cls' not in x and 'module.aux' not in x)} # if main_process(): # print('----', checkpoint['state_dict'].keys()) # print('----', checkpoint['optimizer'].keys()) # print('----1', checkpoint['optimizer']['state'].keys()) model.load_state_dict(checkpoint['state_dict'], strict=False) if not args.if_remove_cls: optimizer.load_state_dict(checkpoint['optimizer']) if main_process(): logger.info("=> loaded checkpoint '{}' (epoch {})".format( args.resume, checkpoint['epoch'])) else: if main_process(): 
logger.info("=> no checkpoint found at '{}'".format( args.resume)) value_scale = 255 mean = [0.485, 0.456, 0.406] mean = [item * value_scale for item in mean] std = [0.229, 0.224, 0.225] std = [item * value_scale for item in std] transform_list_train = [] if args.resize: transform_list_train.append( transform.Resize((args.resize_h, args.resize_w))) transform_list_train += [ transform.RandScale([args.scale_min, args.scale_max]), transform.RandRotate([args.rotate_min, args.rotate_max], padding=mean, ignore_label=args.ignore_label), transform.RandomGaussianBlur(), transform.RandomHorizontalFlip(), transform.Crop([args.train_h, args.train_w], crop_type='rand', padding=mean, ignore_label=args.ignore_label), transform.ToTensor(), transform.Normalize(mean=mean, std=std) ] train_transform = transform.Compose(transform_list_train) train_data = dataset.SemData(split='val', data_root=args.data_root, data_list=args.train_list, transform=train_transform, logger=logger, is_master=main_process(), args=args) if args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler( train_data) else: train_sampler = None train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=(train_sampler is None), num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True) if args.evaluate: transform_list_val = [] if args.resize: transform_list_val.append( transform.Resize((args.resize_h, args.resize_w))) transform_list_val += [ transform.Crop([args.train_h, args.train_w], crop_type='center', padding=mean, ignore_label=args.ignore_label), transform.ToTensor(), transform.Normalize(mean=mean, std=std) ] val_transform = transform.Compose(transform_list_val) val_data = dataset.SemData(split='val', data_root=args.data_root, data_list=args.val_list, transform=val_transform, is_master=main_process(), args=args) args.read_image = val_data.read_image if args.distributed: val_sampler = torch.utils.data.distributed.DistributedSampler( 
val_data) else: val_sampler = None val_loader = torch.utils.data.DataLoader( val_data, batch_size=args.batch_size_val, shuffle=False, num_workers=args.workers, pin_memory=True, sampler=val_sampler) for epoch in range(args.start_epoch, args.epochs): epoch_log = epoch + 1 # if args.evaluate and args.val_every_iter == -1: # # logger.info('Validating.....') # loss_val, mIoU_val, mAcc_val, allAcc_val, return_dict = validate(val_loader, model, criterion, args) # if main_process(): # writer.add_scalar('VAL/loss_val', loss_val, epoch_log) # writer.add_scalar('VAL/mIoU_val', mIoU_val, epoch_log) # writer.add_scalar('VAL/mAcc_val', mAcc_val, epoch_log) # writer.add_scalar('VAL/allAcc_val', allAcc_val, epoch_log) # for sample_idx in range(len(return_dict['image_name_list'])): # writer.add_text('VAL-image_name/%d'%sample_idx, return_dict['image_name_list'][sample_idx], epoch) # writer.add_image('VAL-image/%d'%sample_idx, return_dict['im_list'][sample_idx], epoch, dataformats='HWC') # writer.add_image('VAL-color_label/%d'%sample_idx, return_dict['color_GT_list'][sample_idx], epoch, dataformats='HWC') # writer.add_image('VAL-color_pred/%d'%sample_idx, return_dict['color_pred_list'][sample_idx], epoch, dataformats='HWC') if args.distributed: train_sampler.set_epoch(epoch) loss_train, mIoU_train, mAcc_train, allAcc_train = train( train_loader, model, optimizer, epoch, epoch_log, val_loader, criterion) if main_process(): writer.add_scalar('TRAIN/loss_train', loss_train, epoch_log) writer.add_scalar('TRAIN/mIoU_train', mIoU_train, epoch_log) writer.add_scalar('TRAIN/mAcc_train', mAcc_train, epoch_log) writer.add_scalar('TRAIN/allAcc_train', allAcc_train, epoch_log)