def main():
    """Run PSPNet inference on one split and dump reference images.

    The function, in order:

    1. Parses the command line into the module-level ``args`` and creates
       the module-level ``logger``.
    2. Sanity-checks ``args.classes``, ``args.zoom_factor``,
       ``args.crop_h`` / ``args.crop_w`` and ``args.split``.
    3. Builds a resize -> to-tensor -> normalize transform, the ``SegData``
       dataset for ``args.val_list1`` and a batch-size-1 loader over it.
    4. Builds ``PSPNet``, wraps it in ``DataParallel`` and loads
       ``args.model_path`` non-strictly.
    5. Reads the first 101 lines (indices 0..100) of the hard-coded instance
       list and copies each image / ground-truth pair into ``result/``.
    6. Calls ``validate`` with the loader, model and the collected flags.

    Each instance-list line is whitespace separated: field 0 is the image
    path; when field 1 contains ``'ignore'`` the ground truth is field 2 and
    the flag is 1, otherwise the ground truth is field 1 and the flag is 0.

    Raises:
        AssertionError: if one of the argument sanity checks fails.
        RuntimeError: if ``args.model_path`` is not an existing file.
    """
    global args, logger
    args = get_parser().parse_args()
    logger = get_logger()
    logger.info(args)

    assert args.classes > 1
    assert args.zoom_factor in [1, 2, 4, 8]
    assert (args.crop_h - 1) % 8 == 0 and (args.crop_w - 1) % 8 == 0
    assert args.split in ['train', 'val', 'test']

    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # Per-channel statistics rescaled from [0, 1] to the 0..255 value range.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    val_transform = transforms.Compose([
        transforms.Resize((args.crop_h, args.crop_w)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])
    val_data1 = datasets.SegData(split=args.split,
                                 data_root=args.data_root,
                                 data_list=args.val_list1,
                                 transform=val_transform)
    val_loader1 = torch.utils.data.DataLoader(val_data1,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    from pspnet import PSPNet
    model = PSPNet(backbone=args.backbone,
                   layers=args.layers,
                   classes=args.classes,
                   zoom_factor=args.zoom_factor,
                   use_softmax=False,
                   use_aux=False,
                   pretrained=False,
                   syncbn=False).cuda()
    logger.info(model)
    model = torch.nn.DataParallel(model).cuda()
    cudnn.enabled = True
    cudnn.benchmark = True

    if os.path.isfile(args.model_path):
        logger.info("=> loading checkpoint '{}'".format(args.model_path))
        checkpoint = torch.load(args.model_path)
        # strict=False: keys missing from / unexpected in the checkpoint are
        # tolerated instead of raising.
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        logger.info("=> loaded checkpoint '{}'".format(args.model_path))
    else:
        raise RuntimeError("=> no checkpoint found at '{}'".format(
            args.model_path))

    cv2.setNumThreads(0)

    # The images below are written under 'result/'; make sure it exists so
    # the dump does not fail on a fresh checkout.
    if not os.path.isdir('result'):
        os.makedirs('result')

    instance_list = ('/mnt/lustre/share/dingmingyu/cityscapes/'
                     'instance_list_new.txt')
    flag = []
    # Iterate the file lazily inside a context manager: the handle is closed
    # deterministically and only the lines actually needed are read.
    with open(instance_list) as list_file:
        for i, line in enumerate(list_file):
            if i > 100:  # indices 0..100 inclusive, i.e. 101 entries
                break
            fields = line.split()
            img = cv2.imread(fields[0])
            if 'ignore' in fields[1]:
                gt = cv2.imread(fields[2])
                flag.append(1)
            else:
                flag.append(0)
                gt = cv2.imread(fields[1])
            cv2.imwrite('result/result_%d_gt.png' % i, gt)
            cv2.imwrite('result/result_%d_ori.png' % i, img)

    validate(val_loader1, val_data1.data_list, model, args.classes, mean,
             std, args.base_size1, args.crop_h, args.crop_w, flag)
def main():
    """Distributed training entry point.

    The function, in order:

    1. Parses the command line into the module-level ``args``, forces the
       ``spawn`` multiprocessing start method, initialises the distributed
       context via ``dist_init(args.port)`` and creates the module-level
       ``logger`` and TensorBoard ``writer``.
    2. Derives ``args.bn_group_comm`` from ``args.bn_group``.
    3. Builds ``PSPNet`` and ``PPM``, one SGD optimizer over selected
       sub-modules of both, and wraps both models in ``DistModule``.
    4. Optionally loads initial weights (``args.weight``) and/or resumes
       from ``args.resume``.
    5. Builds the training loader and, when ``args.evaluate`` is set, the
       validation loader, each with its own ``DistributedSampler``.
    6. Runs epochs ``args.start_epoch .. args.epochs`` inclusive: train, log
       scalars, save checkpoints every ``args.save_step`` epochs, and
       optionally validate.

    TensorBoard scalars and checkpoints are written by rank 0 only. The
    validation scalars are guarded the same way as the training scalars so
    that the ranks do not all log the same tags into ``args.save_path``.

    Raises:
        AssertionError: if ``args.bn_group != 1`` and the world size is not
            a multiple of ``args.bn_group``.
    """
    global args, logger, writer
    args = get_parser().parse_args()

    import multiprocessing as mp
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)
    rank, world_size = dist_init(args.port)
    logger = get_logger()
    writer = SummaryWriter(args.save_path)
    if rank == 0:
        logger.info(args)

    if args.bn_group == 1:
        args.bn_group_comm = None
    else:
        assert world_size % args.bn_group == 0
        args.bn_group_comm = simple_group_split(
            world_size, rank, world_size // args.bn_group)

    if rank == 0:
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))

    from pspnet import PSPNet
    model = PSPNet(backbone=args.backbone,
                   layers=args.layers,
                   classes=args.classes,
                   zoom_factor=args.zoom_factor,
                   syncbn=args.syncbn,
                   group_size=args.bn_group,
                   group=args.bn_group_comm).cuda()
    logger.info(model)
    model_ppm = PPM().cuda()

    # Only the sub-modules listed here are optimised. The two model_ppm
    # sub-modules (newly introduced layers) train at 10x the base rate.
    optimizer = torch.optim.SGD(
        [{'params': model.layer0.parameters()},
         {'params': model.layer1.parameters()},
         {'params': model.layer2.parameters()},
         {'params': model.layer3.parameters()},
         {'params': model.layer4_ICR.parameters()},
         {'params': model.layer4_PFR.parameters()},
         {'params': model.layer4_PRP.parameters()},
         {'params': model_ppm.cls_trans.parameters(),
          'lr': args.base_lr * 10},
         {'params': model_ppm.cls_quat.parameters(),
          'lr': args.base_lr * 10}],
        lr=args.base_lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay)

    model = DistModule(model)
    model_ppm = DistModule(model_ppm)
    cudnn.enabled = True
    cudnn.benchmark = True
    criterion = nn.L1Loss().cuda()

    if args.weight:
        def map_func(storage, location):
            # Move every loaded storage to the current CUDA device.
            return storage.cuda()

        if os.path.isfile(args.weight):
            logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight, map_location=map_func)
            model.load_state_dict(checkpoint['state_dict'])
            logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            # A missing weight file is only logged; training continues.
            logger.info("=> no weight found at '{}'".format(args.weight))

    if args.resume:
        load_state(args.resume, model, model_ppm, optimizer)

    # Per-channel statistics rescaled from [0, 1] to the 0..255 value range.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    train_transform = transforms.Compose([
        transforms.Resize(size=(256, 256)),
        transforms.Crop([args.crop_h, args.crop_w],
                        crop_type='rand',
                        padding=mean,
                        ignore_label=args.ignore_label),
        transforms.ColorJitter([0.4, 0.4, 0.4]),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)])
    train_data = datasets.SegData(split='train',
                                  data_root=args.data_root,
                                  data_list=args.train_list,
                                  transform=train_transform)
    train_sampler = DistributedSampler(train_data)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=False,
                                               sampler=train_sampler)

    if args.evaluate:
        val_transform = transforms.Compose([
            transforms.Resize(size=(256, 256)),
            transforms.Crop([args.crop_h, args.crop_w],
                            crop_type='center',
                            padding=mean,
                            ignore_label=args.ignore_label),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)])
        val_data = datasets.SegData(split='val',
                                    data_root=args.data_root,
                                    data_list=args.val_list,
                                    transform=val_transform)
        val_sampler = DistributedSampler(val_data)
        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=args.batch_size_val,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            sampler=val_sampler)

    for epoch in range(args.start_epoch, args.epochs + 1):
        t_loss_train, r_loss_train = train(
            train_loader, model, model_ppm, criterion, optimizer, epoch,
            args.zoom_factor, args.batch_size, args.aux_weight)
        if rank == 0:
            writer.add_scalar('t_loss_train', t_loss_train, epoch)
            writer.add_scalar('r_loss_train', r_loss_train, epoch)

        if epoch % args.save_step == 0 and rank == 0:
            filename = args.save_path + '/train_epoch_' + str(epoch) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            filename_ppm = (args.save_path + '/train_epoch_' + str(epoch)
                            + '_ppm.pth')
            # Both files carry the same optimizer state, because a single
            # optimizer covers parameters of both models.
            torch.save({'epoch': epoch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict()}, filename)
            torch.save({'epoch': epoch,
                        'state_dict': model_ppm.state_dict(),
                        'optimizer': optimizer.state_dict()}, filename_ppm)

        if args.evaluate:
            # validate() still runs on every rank (each rank owns a sampler
            # shard); only the logging is restricted to rank 0.
            t_loss_val, r_loss_val = validate(val_loader, model, model_ppm,
                                              criterion)
            if rank == 0:
                writer.add_scalar('t_loss_val', t_loss_val, epoch)
                writer.add_scalar('r_loss_val', r_loss_val, epoch)

    writer.close()
def main():
    """Evaluate PSPNet together with the PFR and PRP modules.

    Parses the command line into the module-level ``args``, creates the
    module-level ``logger``, builds the centre-cropped validation loader for
    ``args.split`` / ``args.val_list1``, constructs the three
    ``DataParallel`` models, restores their weights non-strictly and hands
    everything to ``validate``.

    The PFR / PRP checkpoint paths are derived from ``args.model_path`` by
    replacing ``'.pth'`` with ``'_pfr.pth'`` / ``'_prp.pth'``.

    Raises:
        RuntimeError: if ``args.model_path`` is not an existing file.
    """
    global args, logger
    args = get_parser().parse_args()
    logger = get_logger()
    logger.info(args)
    logger.info("=> creating model ...")

    # Per-channel statistics expressed in the 0..255 value range.
    value_scale = 255
    mean = [channel * value_scale for channel in [0.485, 0.456, 0.406]]
    std = [channel * value_scale for channel in [0.229, 0.224, 0.225]]

    pipeline = [
        transforms.Resize(size=(256, 256)),
        transforms.Crop([args.crop_h, args.crop_w],
                        crop_type='center',
                        padding=mean,
                        ignore_label=args.ignore_label),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
    val_transform = transforms.Compose(pipeline)

    val_data1 = datasets.SegData(
        split=args.split,
        data_root=args.data_root,
        data_list=args.val_list1,
        transform=val_transform,
    )
    val_loader1 = torch.utils.data.DataLoader(
        val_data1,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )

    # Auxiliary modules: construct on the GPU, then wrap for data parallelism.
    model_pfr = torch.nn.DataParallel(PFR().cuda())
    model_prp = torch.nn.DataParallel(PRP().cuda())

    from pspnet import PSPNet
    psp_kwargs = dict(
        backbone=args.backbone,
        layers=args.layers,
        classes=args.classes,
        zoom_factor=args.zoom_factor,
        use_softmax=True,
        pretrained=False,
        syncbn=False,
    )
    model = PSPNet(**psp_kwargs).cuda()
    logger.info(model)
    model = torch.nn.DataParallel(model).cuda()

    cudnn.enabled = True
    cudnn.benchmark = True

    ckpt_path = args.model_path
    if not os.path.isfile(ckpt_path):
        raise RuntimeError("=> no checkpoint found at '{}'".format(ckpt_path))

    logger.info("=> loading checkpoint '{}'".format(ckpt_path))
    checkpoint = torch.load(ckpt_path)
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    logger.info("=> loaded checkpoint '{}'".format(ckpt_path))

    # Read both sibling checkpoints before applying either of them.
    pfr_path = args.model_path.replace('.pth', '_pfr.pth')
    prp_path = args.model_path.replace('.pth', '_prp.pth')
    checkpoint_pfr = torch.load(pfr_path)
    checkpoint_prp = torch.load(prp_path)
    model_pfr.load_state_dict(checkpoint_pfr['state_dict'], strict=False)
    model_prp.load_state_dict(checkpoint_prp['state_dict'], strict=False)

    cv2.setNumThreads(0)
    validate(val_loader1, val_data1.data_list, model, model_pfr, model_prp)