def get_dataset(opts): """ Dataset And Augmentation """ if opts.dataset == 'voc': train_transform = et.ExtCompose([ #et.ExtResize(size=opts.crop_size), et.ExtRandomScale((0.5, 2.0)), et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size), pad_if_needed=True), et.ExtRandomHorizontalFlip(), et.ExtToTensor(), et.ExtNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) if opts.crop_val: val_transform = et.ExtCompose([ et.ExtResize(opts.crop_size), et.ExtCenterCrop(opts.crop_size), et.ExtToTensor(), et.ExtNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) else: val_transform = et.ExtCompose([ et.ExtToTensor(), et.ExtNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) train_dst = VOCSegmentation(root=opts.data_root, year=opts.year, image_set='train', download=opts.download, transform=train_transform) val_dst = VOCSegmentation(root=opts.data_root, year=opts.year, image_set='val', download=False, transform=val_transform) if opts.dataset == 'cityscapes': train_transform = et.ExtCompose([ #et.ExtResize( 512 ), et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size)), et.ExtColorJitter( brightness=0.5, contrast=0.5, saturation=0.5 ), et.ExtRandomHorizontalFlip(), et.ExtToTensor(), et.ExtNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) val_transform = et.ExtCompose([ et.ExtResize( 256 ), et.ExtToTensor(), et.ExtNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) train_dst = Cityscapes(root=opts.data_root, split='train', transform=train_transform) val_dst = Cityscapes(root=opts.data_root, split='val', transform=val_transform) return train_dst, val_dst
def get_dataset(opts): """ Dataset And Augmentation """ if opts.dataset=='voc': train_transform = ExtCompose( [ ExtRandomScale((0.5, 2.0)), ExtRandomCrop(size=(opts.crop_size, opts.crop_size), pad_if_needed=True), ExtRandomHorizontalFlip(), ExtToTensor(), ExtNormalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ), ]) if opts.crop_val: val_transform = ExtCompose([ ExtResize(size=opts.crop_size), ExtCenterCrop(size=opts.crop_size), ExtToTensor(), ExtNormalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ), ]) else: # no crop, batch size = 1 val_transform = ExtCompose([ ExtToTensor(), ExtNormalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ), ]) train_dst = VOCSegmentation(root=opts.data_root, year=opts.year, image_set='train', download=opts.download, transform=train_transform) val_dst = VOCSegmentation(root=opts.data_root, year=opts.year, image_set='val', download=False, transform=val_transform) if opts.dataset=='cityscapes': train_transform = ExtCompose( [ ExtScale(0.5), ExtRandomCrop(size=(opts.crop_size, opts.crop_size)), ExtRandomHorizontalFlip(), ExtToTensor(), ExtNormalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ), ] ) val_transform = ExtCompose( [ ExtScale(0.5), ExtToTensor(), ExtNormalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ), ] ) train_dst = Cityscapes(root=opts.data_root, split='train', download=opts.download, target_type='semantic', transform=train_transform) val_dst = Cityscapes(root=opts.data_root, split='test', target_type='semantic', download=False, transform=val_transform) return train_dst, val_dst
def func_per_iteration(self, data, device):
    img = data['data']
    label = data['label']
    name = data['fn']
    # label = label - 1
    pred = self.sliding_eval(img, config.eval_crop_size,
                             config.eval_stride_rate, device)
    hist_tmp, labeled_tmp, correct_tmp = hist_info(config.num_classes, pred, label)
    results_dict = {
        'hist': hist_tmp,
        'labeled': labeled_tmp,
        'correct': correct_tmp
    }

    if self.save_path is not None:
        fn = name + '.png'
        pred, fn = Cityscapes.transform_label(pred, fn)
        cv2.imwrite(os.path.join(self.save_path, fn), pred)
        logger.info('Save the image ' + fn)

    if self.show_image:
        colors = self.dataset.get_class_colors()  # original was missing the call parentheses
        image = img
        clean = np.zeros(label.shape)
        comp_img = show_img(colors, config.background, image, clean, label, pred)
        cv2.imshow('comp_image', comp_img)
        cv2.waitKey(0)
    return results_dict
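# Hedged sketch of the `hist_info` helper used above, following the common
# TorchSeg-style implementation (the repo's own version may differ): it
# returns a per-class confusion matrix plus labeled/correct pixel counts.
import numpy as np

def hist_info_sketch(n_cl, pred, gt):
    assert pred.shape == gt.shape
    k = (gt >= 0) & (gt < n_cl)      # mask out ignore labels such as 255
    labeled = np.sum(k)
    correct = np.sum(pred[k] == gt[k])
    hist = np.bincount(n_cl * gt[k].astype(int) + pred[k].astype(int),
                       minlength=n_cl ** 2).reshape(n_cl, n_cl)
    return hist, labeled, correct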
def main():
    model = torch.load('unet.pth').to(device)
    # dataset = Cityscapes(opt.root, split='val', resize=opt.resize, crop=opt.crop)
    dataset = Cityscapes(opt.root, resize=opt.resize, crop=opt.crop)

    inputs, labels = random.choice(dataset)
    inputs = inputs.unsqueeze(0).to(device)
    labels = labels.unsqueeze(0).to(device)

    outputs = model(inputs)
    outputs = outputs.detach()

    gt = classes_to_rgb(labels[0], dataset)
    seg = classes_to_rgb(outputs[0], dataset)

    fig, ax = plt.subplots(1, 3)
    # move to CPU before plotting; matplotlib cannot render CUDA tensors
    ax[0].imshow(inputs[0].cpu().permute(1, 2, 0))
    ax[1].imshow(gt.permute(1, 2, 0))
    ax[2].imshow(seg.permute(1, 2, 0))
    plt.show()
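# Hedged sketch of a `classes_to_rgb` helper as used above; the real one is
# defined elsewhere in this repo. It assumes `dataset` exposes a per-class
# color palette (the `palette` attribute below is an assumption) and that a
# multi-channel input holds per-class scores that must be argmax-ed first.
import torch

def classes_to_rgb_sketch(tensor, dataset):
    if tensor.dim() == 3:              # C x H x W logits -> H x W class ids
        tensor = tensor.argmax(0)
    palette = torch.tensor(dataset.palette, dtype=torch.uint8)  # assumed attribute
    rgb = palette[tensor.long().cpu()]                          # H x W x 3
    return rgb.permute(2, 0, 1)        # match the 3 x H x W layout plotted above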
def inference(test_img_dir, test_gt_dir):
    net = ResUnet('resnet34')
    net.load_state_dict(torch.load('./resunet.pth'))
    net.eval()

    test_file_list = load_sem_seg(test_img_dir, test_gt_dir)
    test_dataset = Cityscapes(test_file_list)
    testloader = DataLoader(test_dataset, batch_size=1, shuffle=False,
                            num_workers=8, pin_memory=True)

    img, label = next(iter(testloader))
    pred = net(img)
    pred = pred.squeeze().detach().numpy()
    print(pred)
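# Hedged follow-up (illustrative, not part of the source): the printed array
# holds raw per-class scores. Assuming `pred` is C x H x W after squeeze(),
# collapsing the class axis gives the per-pixel class map one usually wants.
# class_map = pred.argmax(axis=0)   # H x W array of class indices in [0, C)
# print(class_map.shape)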
def main():
    if opt.is_continue:
        model = torch.load('unet.pth').to(device)
    else:
        model = Unet(19).to(device)

    dataset = Cityscapes(opt.root, resize=opt.resize, crop=opt.crop)
    dataloader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True,
                            num_workers=1)
    criterion = BCELoss().to(device)
    optimizer = Adam(model.parameters(), lr=0.001)

    t_now = time.time()
    for epoch in range(opt.n_epochs):
        print('epoch {}'.format(epoch))
        for i, batch in enumerate(dataloader):
            inputs, labels = batch
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if i % 100 == 0:
                print(loss)
                print('time:', time.time() - t_now)
                t_now = time.time()
        print(loss)
        torch.save(model, 'unet.pth')
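# Hedged aside, not the author's code: BCELoss expects probabilities and
# targets of identical shape, so the loop above only works if Unet ends in a
# sigmoid and `labels` arrive as one-hot N x 19 x H x W floats. With integer
# class maps the usual alternative is cross-entropy over raw logits:
from torch import nn

criterion_ce = nn.CrossEntropyLoss()           # takes N x 19 x H x W logits
# loss = criterion_ce(outputs, labels.long())  # and N x H x W class indices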
def main():
    create_exp_dir(config.save,
                   scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))
    logger = SummaryWriter(config.save)

    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(config.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.info("args = %s", str(config))

    # preparation ################
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    seed = config.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # config network and criterion ################
    min_kept = int(config.batch_size * config.image_height *
                   config.image_width // (16 * config.gt_down_sampling ** 2))
    ohem_criterion = ProbOhemCrossEntropy2d(ignore_label=255, thresh=0.7,
                                            min_kept=min_kept, use_weight=False)
    distill_criterion = nn.KLDivLoss()

    # data loader ###########################
    if config.is_test:
        data_setting = {
            'img_root': config.img_root_folder,
            'gt_root': config.gt_root_folder,
            'train_source': config.train_eval_source,
            'eval_source': config.eval_source,
            'test_source': config.test_source,
            'down_sampling': config.down_sampling
        }
    else:
        data_setting = {
            'img_root': config.img_root_folder,
            'gt_root': config.gt_root_folder,
            'train_source': config.train_source,
            'eval_source': config.eval_source,
            'test_source': config.test_source,
            'down_sampling': config.down_sampling
        }
    train_loader = get_train_loader(config, Cityscapes, test=config.is_test)

    # Model #######################################
    models = []
    evaluators = []
    testers = []
    lasts = []
    for idx, arch_idx in enumerate(config.arch_idx):
        if config.load_epoch == "last":
            state = torch.load(
                os.path.join(config.load_path, "arch_%d.pt" % arch_idx))
        else:
            state = torch.load(
                os.path.join(config.load_path,
                             "arch_%d_%d.pt" % (arch_idx, int(config.load_epoch))))

        model = Network(
            [state["alpha_%d_0" % arch_idx].detach(),
             state["alpha_%d_1" % arch_idx].detach(),
             state["alpha_%d_2" % arch_idx].detach()],
            [None,
             state["beta_%d_1" % arch_idx].detach(),
             state["beta_%d_2" % arch_idx].detach()],
            [state["ratio_%d_0" % arch_idx].detach(),
             state["ratio_%d_1" % arch_idx].detach(),
             state["ratio_%d_2" % arch_idx].detach()],
            num_classes=config.num_classes, layers=config.layers,
            Fch=config.Fch, width_mult_list=config.width_mult_list,
            stem_head_width=config.stem_head_width[idx],
            ignore_skip=arch_idx == 0)

        mIoU02 = state["mIoU02"]
        latency02 = state["latency02"]
        obj02 = objective_acc_lat(mIoU02, latency02)
        mIoU12 = state["mIoU12"]
        latency12 = state["latency12"]
        obj12 = objective_acc_lat(mIoU12, latency12)
        last = [2, 0] if obj02 > obj12 else [2, 1]
        lasts.append(last)
        model.build_structure(last)
        logging.info("net: " + str(model))

        for b in last:
            if len(config.width_mult_list) > 1:
                plot_op(getattr(model, "ops%d" % b),
                        getattr(model, "path%d" % b),
                        width=getattr(model, "widths%d" % b),
                        head_width=config.stem_head_width[idx][1],
                        F_base=config.Fch).savefig(
                    os.path.join(config.save, "ops_%d_%d.png" % (arch_idx, b)),
                    bbox_inches="tight")
            else:
                plot_op(getattr(model, "ops%d" % b),
                        getattr(model, "path%d" % b),
                        F_base=config.Fch).savefig(
                    os.path.join(config.save, "ops_%d_%d.png" % (arch_idx, b)),
                    bbox_inches="tight")
        plot_path_width(model.lasts, model.paths, model.widths).savefig(
            os.path.join(config.save, "path_width%d.png" % arch_idx))
        plot_path_width([2, 1, 0], [model.path2, model.path1, model.path0],
                        [model.widths2, model.widths1, model.widths0]).savefig(
            os.path.join(config.save, "path_width_all%d.png" % arch_idx))

        flops, params = profile(model, inputs=(torch.randn(1, 3, 1024, 2048),))
        logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)
        logging.info("ops:" + str(model.ops))
        logging.info("path:" + str(model.paths))
        logging.info("last:" + str(model.lasts))

        model = model.cuda()
        init_weight(model, nn.init.kaiming_normal_, torch.nn.BatchNorm2d,
                    config.bn_eps, config.bn_momentum,
                    mode='fan_in', nonlinearity='relu')

        if arch_idx == 0 and len(config.arch_idx) > 1:
            partial = torch.load(
                os.path.join(config.teacher_path, "weights%d.pt" % arch_idx))
            state = model.state_dict()
            pretrained_dict = {k: v for k, v in partial.items() if k in state}
            state.update(pretrained_dict)
            model.load_state_dict(state)
        elif config.is_eval:
            partial = torch.load(
                os.path.join(config.eval_path, "weights%d.pt" % arch_idx))
            state = model.state_dict()
            pretrained_dict = {k: v for k, v in partial.items() if k in state}
            state.update(pretrained_dict)
            model.load_state_dict(state)

        evaluator = SegEvaluator(Cityscapes(data_setting, 'val', None),
                                 config.num_classes, config.image_mean,
                                 config.image_std, model,
                                 config.eval_scale_array, config.eval_flip,
                                 0, out_idx=0, config=config, verbose=False,
                                 save_path=None, show_image=False)
        evaluators.append(evaluator)
        tester = SegTester(Cityscapes(data_setting, 'test', None),
                           config.num_classes, config.image_mean,
                           config.image_std, model,
                           config.eval_scale_array, config.eval_flip,
                           0, out_idx=0, config=config, verbose=False,
                           save_path=None, show_image=False)
        testers.append(tester)

        # Optimizer ###################################
        base_lr = config.lr
        if arch_idx == 1 or len(config.arch_idx) == 1:
            # optimize teacher solo OR student (w. distill from teacher)
            optimizer = torch.optim.SGD(model.parameters(), lr=base_lr,
                                        momentum=config.momentum,
                                        weight_decay=config.weight_decay)
        models.append(model)

    # Cityscapes ###########################################
    if config.is_eval:
        logging.info(config.load_path)
        logging.info(config.eval_path)
        logging.info(config.save)
        # validation
        print("[validation...]")
        with torch.no_grad():
            valid_mIoUs = infer(models, evaluators, logger)
            for idx, arch_idx in enumerate(config.arch_idx):
                if arch_idx == 0:
                    logger.add_scalar("mIoU/val_teacher", valid_mIoUs[idx], 0)
                    logging.info("teacher's valid_mIoU %.3f" % (valid_mIoUs[idx]))
                else:
                    logger.add_scalar("mIoU/val_student", valid_mIoUs[idx], 0)
                    logging.info("student's valid_mIoU %.3f" % (valid_mIoUs[idx]))
        exit(0)

    tbar = tqdm(range(config.nepochs), ncols=80)
    for epoch in tbar:
        logging.info(config.load_path)
        logging.info(config.save)
        logging.info("lr: " + str(optimizer.param_groups[0]['lr']))

        # training
        tbar.set_description("[Epoch %d/%d][train...]" %
                             (epoch + 1, config.nepochs))
        train_mIoUs = train(train_loader, models, ohem_criterion,
                            distill_criterion, optimizer, logger, epoch)
        torch.cuda.empty_cache()
        for idx, arch_idx in enumerate(config.arch_idx):
            if arch_idx == 0:
                logger.add_scalar("mIoU/train_teacher", train_mIoUs[idx], epoch)
                logging.info("teacher's train_mIoU %.3f" % (train_mIoUs[idx]))
            else:
                logger.add_scalar("mIoU/train_student", train_mIoUs[idx], epoch)
                logging.info("student's train_mIoU %.3f" % (train_mIoUs[idx]))
        adjust_learning_rate(base_lr, 0.992, optimizer, epoch + 1, config.nepochs)

        # validation
        if not config.is_test and ((epoch + 1) % 10 == 0 or epoch == 0):
            tbar.set_description("[Epoch %d/%d][validation...]" %
                                 (epoch + 1, config.nepochs))
            with torch.no_grad():
                valid_mIoUs = infer(models, evaluators, logger)
                for idx, arch_idx in enumerate(config.arch_idx):
                    if arch_idx == 0:
                        logger.add_scalar("mIoU/val_teacher",
                                          valid_mIoUs[idx], epoch)
                        logging.info("teacher's valid_mIoU %.3f" %
                                     (valid_mIoUs[idx]))
                    else:
                        logger.add_scalar("mIoU/val_student",
                                          valid_mIoUs[idx], epoch)
                        logging.info("student's valid_mIoU %.3f" %
                                     (valid_mIoUs[idx]))
                    save(models[idx],
                         os.path.join(config.save, "weights%d.pt" % arch_idx))

        # test
        if config.is_test and (epoch + 1) >= 250 and (epoch + 1) % 10 == 0:
            tbar.set_description("[Epoch %d/%d][test...]" %
                                 (epoch + 1, config.nepochs))
            with torch.no_grad():
                test(epoch, models, testers, logger)

        for idx, arch_idx in enumerate(config.arch_idx):
            save(models[idx], os.path.join(config.save, "weights%d.pt" % arch_idx))
                    default=False, action='store_true')
parser.add_argument('--save_path', '-p', default=None)
args = parser.parse_args()
all_dev = parse_devices(args.devices)

if config.is_test:
    eval_source = config.test_source
else:
    eval_source = config.eval_source

mp_ctx = mp.get_context('spawn')
network = CPNet(config.num_classes, criterion=None)
data_setting = {
    'img_root': config.img_root_folder,
    'gt_root': config.gt_root_folder,
    'train_source': config.train_source,
    'eval_source': eval_source
}
dataset = Cityscapes(data_setting, 'val', None)

with torch.no_grad():
    segmentor = SegEvaluator(dataset, config.num_classes, config.image_mean,
                             config.image_std, network,
                             config.eval_scale_array, config.eval_flip,
                             all_dev, args.verbose, args.save_path,
                             args.show_image)
    segmentor.run(config.snapshot_dir, args.epochs, config.val_log_file,
                  config.link_val_log_file)
def train(network, train_img_dir, train_gt_dir, val_img_dir, val_gt_dir,
          lr, epochs, lbl_conversion=False):
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    net = ResUnet(network)
    net.to(device).train()
    optimizer = optim.SGD(net.head.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    if lbl_conversion:
        label_conversion(train_gt_dir)
        label_conversion(val_gt_dir)

    train_file_list = load_sem_seg(train_img_dir, train_gt_dir)
    train_dataset = Cityscapes(train_file_list)
    trainloader = DataLoader(train_dataset, batch_size=25, shuffle=True,
                             num_workers=8, pin_memory=True)

    val_file_list = load_sem_seg(val_img_dir, val_gt_dir)
    val_dataset = Cityscapes(val_file_list)
    valloader = DataLoader(val_dataset, batch_size=1, shuffle=False,
                           num_workers=8, pin_memory=True)

    check_point_path = './checkpoints/'
    print('Begin of the Training')
    for epoch in range(epochs):
        running_loss = 0.0
        for idx, data in enumerate(trainloader):
            optimizer.zero_grad()
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            print(epoch, idx, loss.item())
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if idx % 20 == 19:  # print every 20 mini-batches
                # average over the 20 batches (the original divided by 2000)
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, idx + 1, running_loss / 20))
                running_loss = 0.0
    print('End of the Training')

    PATH = './resunet.pth'
    torch.save(net.state_dict(), PATH)
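# Hedged sketch of what a `label_conversion` step usually does here: remap raw
# Cityscapes labelIds to the 19 train ids, sending everything else to 255
# (the ignore index used by the criterion above). Assumes `cityscapesscripts`
# is installed; the repo's own helper may differ.
import numpy as np
from cityscapesscripts.helpers.labels import labels as cs_labels

def labelid_to_trainid(label_array):
    lut = np.full(256, 255, dtype=np.uint8)     # default: ignore
    for l in cs_labels:
        if 0 <= l.trainId < 19:
            lut[l.id] = l.trainId
    return lut[label_array]                     # elementwise remap via lookup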
def main():
    args, args_text = _parse_args()

    # dist init
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='tcp://127.0.0.1:26442',
                                         world_size=1, rank=0)
    config.device = 'cuda:%d' % args.local_rank
    torch.cuda.set_device(args.local_rank)
    args.world_size = torch.distributed.get_world_size()
    args.local_rank = torch.distributed.get_rank()
    logging.info("rank: {} world_size: {}".format(args.local_rank, args.world_size))

    if args.local_rank == 0:
        create_exp_dir(config.save,
                       scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))
        logger = SummaryWriter(config.save)
        log_format = '%(asctime)s %(message)s'
        logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                            format=log_format, datefmt='%m/%d %I:%M:%S %p')
        fh = logging.FileHandler(os.path.join(config.save, 'log.txt'))
        fh.setFormatter(logging.Formatter(log_format))
        logging.getLogger().addHandler(fh)
        logging.info("args = %s", str(config))
    else:
        logger = None

    # preparation ################
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # data loader ###########################
    if config.is_test:
        data_setting = {
            'img_root': config.img_root_folder,
            'gt_root': config.gt_root_folder,
            'train_source': config.train_eval_source,
            'eval_source': config.eval_source,
            'test_source': config.test_source,
            'down_sampling': config.down_sampling
        }
    else:
        data_setting = {
            'img_root': config.img_root_folder,
            'gt_root': config.gt_root_folder,
            'train_source': config.train_source,
            'eval_source': config.eval_source,
            'test_source': config.test_source,
            'down_sampling': config.down_sampling
        }

    with open(config.json_file, 'r') as f:
        model_dict = json.loads(f.read())
    model = Network(model_dict["ops"], model_dict["paths"], model_dict["downs"],
                    model_dict["widths"], model_dict["lasts"],
                    num_classes=config.num_classes, layers=config.layers,
                    Fch=config.Fch, width_mult_list=config.width_mult_list,
                    stem_head_width=config.stem_head_width)

    if args.local_rank == 0:
        logging.info("net: " + str(model))
        flops, params = profile(model, inputs=(torch.randn(1, 3, 1024, 2048),),
                                verbose=False)
        logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)
        logging.info("ops:" + str(model.ops))
        logging.info("path:" + str(model.paths))
        logging.info("last:" + str(model.lasts))
        with open(os.path.join(config.save, 'args.yaml'), 'w') as f:
            f.write(args_text)

    model = model.cuda()
    init_weight(model, nn.init.kaiming_normal_, torch.nn.BatchNorm2d,
                config.bn_eps, config.bn_momentum,
                mode='fan_in', nonlinearity='relu')
    model = load_pretrain(model, config.model_path)
    # partial = torch.load(config.model_path)
    # state = model.state_dict()
    # pretrained_dict = {k: v for k, v in partial.items() if k in state}
    # state.update(pretrained_dict)
    # model.load_state_dict(state)
    eval_model = model

    evaluator = SegEvaluator(Cityscapes(data_setting, 'val', None),
                             config.num_classes, config.image_mean,
                             config.image_std, eval_model,
                             config.eval_scale_array, config.eval_flip,
                             0, out_idx=0, config=config, verbose=False,
                             save_path=None, show_image=False,
                             show_prediction=False)
    tester = SegTester(Cityscapes(data_setting, 'test', None),
                       config.num_classes, config.image_mean,
                       config.image_std, eval_model,
                       config.eval_scale_array, config.eval_flip,
                       0, out_idx=0, config=config, verbose=False,
                       save_path=None, show_prediction=False)

    # Cityscapes ###########################################
    logging.info(config.model_path)
    logging.info(config.save)
    with torch.no_grad():
        if config.is_test:
            # test
            print("[test...]")
            test(0, model, tester, logger)
        else:
            # validation
            print("[validation...]")
            valid_mIoU = infer(model, evaluator, logger)
            logger.add_scalar("mIoU/val", valid_mIoU, 0)
            logging.info("Model valid_mIoU %.3f" % (valid_mIoU))
def main(): """Create the model and start the training.""" w, h = map(int, args.input_size.split(',')) input_size = (w, h) w, h = map(int, args.input_size_target.split(',')) input_size_target = (w, h) cudnn.enabled = True gpu = args.gpu tau = torch.ones(1) * args.tau tau = tau.cuda(args.gpu) # Create network if args.model == 'DeepLab': model = DeeplabMulti(num_classes=args.num_classes) if args.restore_from[:4] == 'http': saved_state_dict = model_zoo.load_url(args.restore_from) else: saved_state_dict = torch.load(args.restore_from) new_params = model.state_dict().copy() for i in saved_state_dict: # Scale.layer5.conv2d_list.3.weight i_parts = i.split('.') # print i_parts if not args.num_classes == 19 or not i_parts[1] == 'layer5': new_params['.'.join(i_parts[1:])] = saved_state_dict[i] # print i_parts model.load_state_dict(new_params, False) elif args.model == 'DeepLabVGG': model = DeeplabVGG(pretrained=True, num_classes=args.num_classes) model.train() model.cuda(args.gpu) cudnn.benchmark = True # init D model_D1 = FCDiscriminator(num_classes=args.num_classes) model_D2 = FCDiscriminator(num_classes=args.num_classes) model_D1.train() model_D1.cuda(args.gpu) model_D2.train() model_D2.cuda(args.gpu) if not os.path.exists(args.snapshot_dir): os.makedirs(args.snapshot_dir) mean = [0.485, 0.456, 0.406] std = [0.229, 0.224, 0.225] weak_transform = transforms.Compose([ # transforms.RandomCrop(32, 4), # transforms.RandomRotation(30), # transforms.Resize(1024), transforms.ToTensor(), # transforms.Normalize(mean, std), # RandomCrop(768) ]) target_transform = transforms.Compose([ # transforms.RandomCrop(32, 4), # transforms.RandomRotation(30), # transforms.Normalize(mean, std) # transforms.Resize(1024), # transforms.ToTensor(), # RandomCrop(768) ]) label_set = GTA5( root=args.data_dir, num_cls=19, split='all', remap_labels=True, transform=weak_transform, target_transform=target_transform, scale=input_size, # crop_transform=RandomCrop(int(768*(args.scale/1024))), ) unlabel_set = Cityscapes( root=args.data_dir_target, split=args.set, remap_labels=True, transform=weak_transform, target_transform=target_transform, scale=input_size_target, # crop_transform=RandomCrop(int(768*(args.scale/1024))), ) test_set = Cityscapes( root=args.data_dir_target, split='val', remap_labels=True, transform=weak_transform, target_transform=target_transform, scale=input_size_target, # crop_transform=RandomCrop(768) ) label_loader = data.DataLoader(label_set, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=False) unlabel_loader = data.DataLoader(unlabel_set, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=False) test_loader = data.DataLoader(test_set, batch_size=2, shuffle=False, num_workers=args.num_workers, pin_memory=False) # implement model.optim_parameters(args) to handle different models' lr setting optimizer = optim.SGD(model.optim_parameters(args), lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay) optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99)) optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99)) [model, model_D2, model_D2], [optimizer, optimizer_D1, optimizer_D2 ] = amp.initialize([model, model_D2, model_D2], [optimizer, optimizer_D1, optimizer_D2], opt_level="O1", num_losses=7) optimizer.zero_grad() optimizer_D1.zero_grad() optimizer_D2.zero_grad() if args.gan == 'Vanilla': bce_loss = torch.nn.BCEWithLogitsLoss() elif args.gan == 'LS': 
bce_loss = torch.nn.MSELoss() interp = Interpolate(size=(input_size[1], input_size[0]), mode='bilinear', align_corners=True) interp_target = Interpolate(size=(input_size_target[1], input_size_target[0]), mode='bilinear', align_corners=True) interp_test = Interpolate(size=(input_size_target[1], input_size_target[0]), mode='bilinear', align_corners=True) # interp_test = Interpolate(size=(1024, 2048), mode='bilinear', align_corners=True) normalize_transform = transforms.Compose([ torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) # labels for adversarial training source_label = 0 target_label = 1 max_mIoU = 0 total_loss_seg_value1 = [] total_loss_adv_target_value1 = [] total_loss_D_value1 = [] total_loss_con_value1 = [] total_loss_seg_value2 = [] total_loss_adv_target_value2 = [] total_loss_D_value2 = [] total_loss_con_value2 = [] hist = np.zeros((num_cls, num_cls)) # for i_iter in range(args.num_steps): for i_iter, (batch, batch_un) in enumerate( zip(roundrobin_infinite(label_loader), roundrobin_infinite(unlabel_loader))): loss_seg_value1 = 0 loss_adv_target_value1 = 0 loss_D_value1 = 0 loss_con_value1 = 0 loss_seg_value2 = 0 loss_adv_target_value2 = 0 loss_D_value2 = 0 loss_con_value2 = 0 optimizer.zero_grad() adjust_learning_rate(optimizer, i_iter) optimizer_D1.zero_grad() optimizer_D2.zero_grad() adjust_learning_rate_D(optimizer_D1, i_iter) adjust_learning_rate_D(optimizer_D2, i_iter) # train G # don't accumulate grads in D for param in model_D1.parameters(): param.requires_grad = False for param in model_D2.parameters(): param.requires_grad = False # train with source images, labels = batch images_orig = images images = transform_batch(images, normalize_transform) images = Variable(images).cuda(args.gpu) pred1, pred2 = model(images) pred1 = interp(pred1) pred2 = interp(pred2) loss_seg1 = loss_calc(pred1, labels, args.gpu) loss_seg2 = loss_calc(pred2, labels, args.gpu) loss = loss_seg2 + args.lambda_seg * loss_seg1 # proper normalization loss = loss / args.iter_size with amp.scale_loss(loss, optimizer, loss_id=0) as scaled_loss: scaled_loss.backward() # loss.backward() loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size # train with target images_tar, labels_tar = batch_un images_tar_orig = images_tar images_tar = transform_batch(images_tar, normalize_transform) images_tar = Variable(images_tar).cuda(args.gpu) pred_target1, pred_target2 = model(images_tar) pred_target1 = interp_target(pred_target1) pred_target2 = interp_target(pred_target2) D_out1 = model_D1(F.softmax(pred_target1, dim=1)) D_out2 = model_D2(F.softmax(pred_target2, dim=1)) loss_adv_target1 = bce_loss( D_out1, Variable( torch.FloatTensor( D_out1.data.size()).fill_(source_label)).cuda(args.gpu)) loss_adv_target2 = bce_loss( D_out2, Variable( torch.FloatTensor( D_out2.data.size()).fill_(source_label)).cuda(args.gpu)) loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2 loss = loss / args.iter_size with amp.scale_loss(loss, optimizer, loss_id=1) as scaled_loss: scaled_loss.backward() # loss.backward() loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy( ) / args.iter_size loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy( ) / args.iter_size # train with consistency loss # unsupervise phase policies = RandAugment().get_batch_policy(args.batch_size) rand_p1 = np.random.random(size=args.batch_size) rand_p2 = np.random.random(size=args.batch_size) 
random_dir = np.random.choice([-1, 1], size=[args.batch_size, 2]) images_aug = aug_batch_tensor(images_tar_orig, policies, rand_p1, rand_p2, random_dir) images_aug_orig = images_aug images_aug = transform_batch(images_aug, normalize_transform) images_aug = Variable(images_aug).cuda(args.gpu) pred_target_aug1, pred_target_aug2 = model(images_aug) pred_target_aug1 = interp_target(pred_target_aug1) pred_target_aug2 = interp_target(pred_target_aug2) pred_target1 = pred_target1.detach() pred_target2 = pred_target2.detach() max_pred1, psuedo_label1 = torch.max(F.softmax(pred_target1, dim=1), 1) max_pred2, psuedo_label2 = torch.max(F.softmax(pred_target2, dim=1), 1) psuedo_label1 = psuedo_label1.cpu().numpy().astype(np.float32) psuedo_label1_thre = psuedo_label1.copy() psuedo_label1_thre[(max_pred1 < tau).cpu().numpy().astype( np.bool)] = 255 # threshold to don't care psuedo_label1_thre = aug_batch_numpy(psuedo_label1_thre, policies, rand_p1, rand_p2, random_dir) psuedo_label2 = psuedo_label2.cpu().numpy().astype(np.float32) psuedo_label2_thre = psuedo_label2.copy() psuedo_label2_thre[(max_pred2 < tau).cpu().numpy().astype( np.bool)] = 255 # threshold to don't care psuedo_label2_thre = aug_batch_numpy(psuedo_label2_thre, policies, rand_p1, rand_p2, random_dir) psuedo_label1_thre = Variable(psuedo_label1_thre).cuda(args.gpu) psuedo_label2_thre = Variable(psuedo_label2_thre).cuda(args.gpu) if (psuedo_label1_thre != 255).sum().cpu().numpy() > 0: # nll_loss doesn't support empty tensors loss_con1 = loss_calc(pred_target_aug1, psuedo_label1_thre, args.gpu) loss_con_value1 += loss_con1.data.cpu().numpy() / args.iter_size else: loss_con1 = torch.tensor(0.0, requires_grad=True).cuda(args.gpu) if (psuedo_label2_thre != 255).sum().cpu().numpy() > 0: # nll_loss doesn't support empty tensors loss_con2 = loss_calc(pred_target_aug2, psuedo_label2_thre, args.gpu) loss_con_value2 += loss_con2.data.cpu().numpy() / args.iter_size else: loss_con2 = torch.tensor(0.0, requires_grad=True).cuda(args.gpu) loss = args.lambda_con * loss_con1 + args.lambda_con * loss_con2 # proper normalization loss = loss / args.iter_size with amp.scale_loss(loss, optimizer, loss_id=2) as scaled_loss: scaled_loss.backward() # loss.backward() # train D # bring back requires_grad for param in model_D1.parameters(): param.requires_grad = True for param in model_D2.parameters(): param.requires_grad = True # train with source pred1 = pred1.detach() pred2 = pred2.detach() D_out1 = model_D1(F.softmax(pred1, dim=1)) D_out2 = model_D2(F.softmax(pred2, dim=1)) loss_D1 = bce_loss( D_out1, Variable( torch.FloatTensor( D_out1.data.size()).fill_(source_label)).cuda(args.gpu)) loss_D2 = bce_loss( D_out2, Variable( torch.FloatTensor( D_out2.data.size()).fill_(source_label)).cuda(args.gpu)) loss_D1 = loss_D1 / args.iter_size / 2 loss_D2 = loss_D2 / args.iter_size / 2 with amp.scale_loss(loss_D1, optimizer_D1, loss_id=3) as scaled_loss: scaled_loss.backward() # loss_D1.backward() with amp.scale_loss(loss_D2, optimizer_D2, loss_id=4) as scaled_loss: scaled_loss.backward() # loss_D2.backward() loss_D_value1 += loss_D1.data.cpu().numpy() loss_D_value2 += loss_D2.data.cpu().numpy() # train with target pred_target1 = pred_target1.detach() pred_target2 = pred_target2.detach() D_out1 = model_D1(F.softmax(pred_target1, dim=1)) D_out2 = model_D2(F.softmax(pred_target2, dim=1)) loss_D1 = bce_loss( D_out1, Variable( torch.FloatTensor( D_out1.data.size()).fill_(target_label)).cuda(args.gpu)) loss_D2 = bce_loss( D_out2, Variable( torch.FloatTensor( 
D_out2.data.size()).fill_(target_label)).cuda(args.gpu)) loss_D1 = loss_D1 / args.iter_size / 2 loss_D2 = loss_D2 / args.iter_size / 2 with amp.scale_loss(loss_D1, optimizer_D1, loss_id=5) as scaled_loss: scaled_loss.backward() # loss_D1.backward() with amp.scale_loss(loss_D2, optimizer_D2, loss_id=6) as scaled_loss: scaled_loss.backward() # loss_D2.backward() loss_D_value1 += loss_D1.data.cpu().numpy() loss_D_value2 += loss_D2.data.cpu().numpy() optimizer.step() optimizer_D1.step() optimizer_D2.step() print('exp = {}'.format(args.snapshot_dir)) print( 'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}, loss_con1 = {8:.3f}, loss_con2 = {9:.3f}' .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2, loss_adv_target_value1, loss_adv_target_value2, loss_D_value1, loss_D_value2, loss_con_value1, loss_con_value2)) total_loss_seg_value1.append(loss_seg_value1) total_loss_adv_target_value1.append(loss_adv_target_value1) total_loss_D_value1.append(loss_D_value1) total_loss_con_value1.append(loss_con_value1) total_loss_seg_value2.append(loss_seg_value2) total_loss_adv_target_value2.append(loss_adv_target_value2) total_loss_D_value2.append(loss_D_value2) total_loss_con_value2.append(loss_con_value2) hist += fast_hist( labels.cpu().numpy().flatten().astype(int), torch.argmax(pred2, dim=1).cpu().numpy().flatten().astype(int), num_cls) if i_iter % 10 == 0: print('({}/{})'.format(i_iter + 1, int(args.num_steps))) acc_overall, acc_percls, iu, fwIU = result_stats(hist) mIoU = np.mean(iu) per_class = [[classes[i], acc] for i, acc in list(enumerate(iu))] per_class = np.array(per_class).flatten() print( ('per cls IoU :' + ('\n{:>14s} : {}') * 19).format(*per_class)) print('mIoU : {:0.2f}'.format(np.mean(iu))) print('fwIoU : {:0.2f}'.format(fwIU)) print('pixel acc : {:0.2f}'.format(acc_overall)) per_class = [[classes[i], acc] for i, acc in list(enumerate(acc_percls))] per_class = np.array(per_class).flatten() print( ('per cls acc :' + ('\n{:>14s} : {}') * 19).format(*per_class)) avg_train_acc = acc_overall avg_train_loss_seg1 = np.mean(total_loss_seg_value1) avg_train_loss_adv1 = np.mean(total_loss_adv_target_value1) avg_train_loss_dis1 = np.mean(total_loss_D_value1) avg_train_loss_con1 = np.mean(total_loss_con_value1) avg_train_loss_seg2 = np.mean(total_loss_seg_value2) avg_train_loss_adv2 = np.mean(total_loss_adv_target_value2) avg_train_loss_dis2 = np.mean(total_loss_D_value2) avg_train_loss_con2 = np.mean(total_loss_con_value2) print('avg_train_acc :', avg_train_acc) print('avg_train_loss_seg1 :', avg_train_loss_seg1) print('avg_train_loss_adv1 :', avg_train_loss_adv1) print('avg_train_loss_dis1 :', avg_train_loss_dis1) print('avg_train_loss_con1 :', avg_train_loss_con1) print('avg_train_loss_seg2 :', avg_train_loss_seg2) print('avg_train_loss_adv2 :', avg_train_loss_adv2) print('avg_train_loss_dis2 :', avg_train_loss_dis2) print('avg_train_loss_con2 :', avg_train_loss_con2) writer['train'].add_scalar('log/mIoU', mIoU, i_iter) writer['train'].add_scalar('log/acc', avg_train_acc, i_iter) writer['train'].add_scalar('log1/loss_seg', avg_train_loss_seg1, i_iter) writer['train'].add_scalar('log1/loss_adv', avg_train_loss_adv1, i_iter) writer['train'].add_scalar('log1/loss_dis', avg_train_loss_dis1, i_iter) writer['train'].add_scalar('log1/loss_con', avg_train_loss_con1, i_iter) writer['train'].add_scalar('log2/loss_seg', avg_train_loss_seg2, i_iter) writer['train'].add_scalar('log2/loss_adv', 
avg_train_loss_adv2, i_iter) writer['train'].add_scalar('log2/loss_dis', avg_train_loss_dis2, i_iter) writer['train'].add_scalar('log2/loss_con', avg_train_loss_con2, i_iter) hist = np.zeros((num_cls, num_cls)) total_loss_seg_value1 = [] total_loss_adv_target_value1 = [] total_loss_D_value1 = [] total_loss_con_value1 = [] total_loss_seg_value2 = [] total_loss_adv_target_value2 = [] total_loss_D_value2 = [] total_loss_con_value2 = [] fig = plt.figure(figsize=(15, 15)) labels = labels[0].cpu().numpy().astype(np.float32) ax = fig.add_subplot(331) ax.imshow(print_palette(Image.fromarray(labels).convert('L'))) ax.axis("off") ax.set_title('labels') ax = fig.add_subplot(337) images = images_orig[0].cpu().numpy().transpose((1, 2, 0)) # images += IMG_MEAN ax.imshow(images) ax.axis("off") ax.set_title('datas') _, pred2 = torch.max(pred2, dim=1) pred2 = pred2[0].cpu().numpy().astype(np.float32) ax = fig.add_subplot(334) ax.imshow(print_palette(Image.fromarray(pred2).convert('L'))) ax.axis("off") ax.set_title('predicts') labels_tar = labels_tar[0].cpu().numpy().astype(np.float32) ax = fig.add_subplot(332) ax.imshow(print_palette(Image.fromarray(labels_tar).convert('L'))) ax.axis("off") ax.set_title('tar_labels') ax = fig.add_subplot(338) ax.imshow(images_tar_orig[0].cpu().numpy().transpose((1, 2, 0))) ax.axis("off") ax.set_title('tar_datas') _, pred_target2 = torch.max(pred_target2, dim=1) pred_target2 = pred_target2[0].cpu().numpy().astype(np.float32) ax = fig.add_subplot(335) ax.imshow(print_palette( Image.fromarray(pred_target2).convert('L'))) ax.axis("off") ax.set_title('tar_predicts') print(policies[0], 'p1', rand_p1[0], 'p2', rand_p2[0], 'random_dir', random_dir[0]) psuedo_label2_thre = psuedo_label2_thre[0].cpu().numpy().astype( np.float32) ax = fig.add_subplot(333) ax.imshow( print_palette( Image.fromarray(psuedo_label2_thre).convert('L'))) ax.axis("off") ax.set_title('psuedo_labels') ax = fig.add_subplot(339) ax.imshow(images_aug_orig[0].cpu().numpy().transpose((1, 2, 0))) ax.axis("off") ax.set_title('aug_datas') _, pred_target_aug2 = torch.max(pred_target_aug2, dim=1) pred_target_aug2 = pred_target_aug2[0].cpu().numpy().astype( np.float32) ax = fig.add_subplot(336) ax.imshow( print_palette(Image.fromarray(pred_target_aug2).convert('L'))) ax.axis("off") ax.set_title('aug_predicts') # plt.show() writer['train'].add_figure('image/', fig, global_step=i_iter, close=True) if i_iter % 500 == 0: loss1 = [] loss2 = [] for test_i, batch in enumerate(test_loader): images, labels = batch images_orig = images images = transform_batch(images, normalize_transform) images = Variable(images).cuda(args.gpu) pred1, pred2 = model(images) pred1 = interp_test(pred1) pred1 = pred1.detach() pred2 = interp_test(pred2) pred2 = pred2.detach() loss_seg1 = loss_calc(pred1, labels, args.gpu) loss_seg2 = loss_calc(pred2, labels, args.gpu) loss1.append(loss_seg1.item()) loss2.append(loss_seg2.item()) hist += fast_hist( labels.cpu().numpy().flatten().astype(int), torch.argmax(pred2, dim=1).cpu().numpy().flatten().astype(int), num_cls) print('test') fig = plt.figure(figsize=(15, 15)) labels = labels[-1].cpu().numpy().astype(np.float32) ax = fig.add_subplot(311) ax.imshow(print_palette(Image.fromarray(labels).convert('L'))) ax.axis("off") ax.set_title('labels') ax = fig.add_subplot(313) ax.imshow(images_orig[-1].cpu().numpy().transpose((1, 2, 0))) ax.axis("off") ax.set_title('datas') _, pred2 = torch.max(pred2, dim=1) pred2 = pred2[-1].cpu().numpy().astype(np.float32) ax = fig.add_subplot(312) 
ax.imshow(print_palette(Image.fromarray(pred2).convert('L'))) ax.axis("off") ax.set_title('predicts') # plt.show() writer['test'].add_figure('test_image/', fig, global_step=i_iter, close=True) acc_overall, acc_percls, iu, fwIU = result_stats(hist) mIoU = np.mean(iu) per_class = [[classes[i], acc] for i, acc in list(enumerate(iu))] per_class = np.array(per_class).flatten() print( ('per cls IoU :' + ('\n{:>14s} : {}') * 19).format(*per_class)) print('mIoU : {:0.2f}'.format(mIoU)) print('fwIoU : {:0.2f}'.format(fwIU)) print('pixel acc : {:0.2f}'.format(acc_overall)) per_class = [[classes[i], acc] for i, acc in list(enumerate(acc_percls))] per_class = np.array(per_class).flatten() print( ('per cls acc :' + ('\n{:>14s} : {}') * 19).format(*per_class)) avg_test_loss1 = np.mean(loss1) avg_test_loss2 = np.mean(loss2) avg_test_acc = acc_overall print('avg_test_loss2 :', avg_test_loss1) print('avg_test_loss1 :', avg_test_loss2) print('avg_test_acc :', avg_test_acc) writer['test'].add_scalar('log1/loss_seg', avg_test_loss1, i_iter) writer['test'].add_scalar('log2/loss_seg', avg_test_loss2, i_iter) writer['test'].add_scalar('log/acc', avg_test_acc, i_iter) writer['test'].add_scalar('log/mIoU', mIoU, i_iter) hist = np.zeros((num_cls, num_cls)) if i_iter >= args.num_steps_stop - 1: print('save model ...') torch.save( model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '.pth')) torch.save( model_D1.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '_D1.pth')) torch.save( model_D2.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '_D2.pth')) break if max_mIoU < mIoU: max_mIoU = mIoU torch.save( model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + 'best_iter' + '.pth')) torch.save( model_D1.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + 'best_iter' + '_D1.pth')) torch.save( model_D2.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + 'best_iter' + '_D2.pth'))
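# Hedged sketch of the `fast_hist` confusion-matrix helper used above (a very
# common implementation in domain-adaptation codebases; the repo's own version
# and `result_stats` may differ in detail).
import numpy as np

def fast_hist_sketch(gt, pred, n_cls):
    # ignore pixels whose ground-truth label falls outside [0, n_cls)
    k = (gt >= 0) & (gt < n_cls)
    return np.bincount(n_cls * gt[k] + pred[k],
                       minlength=n_cls ** 2).reshape(n_cls, n_cls)

def iou_from_hist(hist):
    # per-class IoU = diag / (row_sum + col_sum - diag); its mean is mIoU
    return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))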
def main(pretrain=True):
    config.save = 'search-{}-{}'.format(config.save,
                                        time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(config.save,
                   scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))
    logger = SummaryWriter(config.save)

    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(config.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    assert type(pretrain) == bool or type(pretrain) == str
    update_arch = True
    if pretrain == True:
        update_arch = False
    logging.info("args = %s", str(config))

    # preparation ################
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    seed = config.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # config network and criterion ################
    min_kept = int(config.batch_size * config.image_height *
                   config.image_width // (16 * config.gt_down_sampling ** 2))
    ohem_criterion = ProbOhemCrossEntropy2d(ignore_label=255, thresh=0.7,
                                            min_kept=min_kept, use_weight=False)

    # Model #######################################
    model = Network(config.num_classes, config.layers, ohem_criterion,
                    Fch=config.Fch, width_mult_list=config.width_mult_list,
                    prun_modes=config.prun_modes,
                    stem_head_width=config.stem_head_width)
    flops, params = profile(model, inputs=(torch.randn(1, 3, 1024, 2048),),
                            verbose=False)
    logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)
    model = model.cuda()

    if type(pretrain) == str:
        partial = torch.load(pretrain + "/weights.pt", map_location='cuda:0')
        state = model.state_dict()
        pretrained_dict = {
            k: v for k, v in partial.items()
            if k in state and state[k].size() == partial[k].size()
        }
        state.update(pretrained_dict)
        model.load_state_dict(state)
    else:
        init_weight(model, nn.init.kaiming_normal_, nn.BatchNorm2d,
                    config.bn_eps, config.bn_momentum,
                    mode='fan_in', nonlinearity='relu')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    architect = Architect(model, config)

    # Optimizer ###################################
    base_lr = config.lr
    parameters = []
    parameters += list(model.stem.parameters())
    parameters += list(model.cells.parameters())
    parameters += list(model.refine32.parameters())
    parameters += list(model.refine16.parameters())
    parameters += list(model.head0.parameters())
    parameters += list(model.head1.parameters())
    parameters += list(model.head2.parameters())
    parameters += list(model.head02.parameters())
    parameters += list(model.head12.parameters())
    optimizer = torch.optim.SGD(parameters, lr=base_lr,
                                momentum=config.momentum,
                                weight_decay=config.weight_decay)

    # lr policy ##############################
    lr_policy = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.978)

    # data loader ###########################
    data_setting = {
        'img_root': config.img_root_folder,
        'gt_root': config.gt_root_folder,
        'train_source': config.train_source,
        'eval_source': config.eval_source,
        'down_sampling': config.down_sampling
    }
    train_loader_model = get_train_loader(config, Cityscapes,
                                          portion=config.train_portion)
    train_loader_arch = get_train_loader(config, Cityscapes,
                                         portion=config.train_portion - 1)

    evaluator = SegEvaluator(Cityscapes(data_setting, 'val', None),
                             config.num_classes, config.image_mean,
                             config.image_std, model,
                             config.eval_scale_array, config.eval_flip, 0,
                             config=config, verbose=False, save_path=None,
                             show_image=False)

    if update_arch:
        for idx in range(len(config.latency_weight)):
            logger.add_scalar("arch/latency_weight%d" % idx,
                              config.latency_weight[idx], 0)
            logging.info("arch_latency_weight%d = " % idx +
                         str(config.latency_weight[idx]))

    tbar = tqdm(range(config.nepochs), ncols=80)
    valid_mIoU_history = []
    FPSs_history = []
    latency_supernet_history = []
    latency_weight_history = []
    valid_names = ["8s", "16s", "32s", "8s_32s", "16s_32s"]
    arch_names = {0: "teacher", 1: "student"}

    for epoch in tbar:
        logging.info(pretrain)
        logging.info(config.save)
        logging.info("lr: " + str(optimizer.param_groups[0]['lr']))
        logging.info("update arch: " + str(update_arch))

        # training
        tbar.set_description("[Epoch %d/%d][train...]" %
                             (epoch + 1, config.nepochs))
        train(pretrain, train_loader_model, train_loader_arch, model,
              architect, ohem_criterion, optimizer, lr_policy, logger, epoch,
              update_arch=update_arch)
        torch.cuda.empty_cache()
        lr_policy.step()

        # validation
        tbar.set_description("[Epoch %d/%d][validation...]" %
                             (epoch + 1, config.nepochs))
        with torch.no_grad():
            if pretrain == True:
                model.prun_mode = "min"
                valid_mIoUs = infer(epoch, model, evaluator, logger, FPS=False)
                for i in range(5):
                    logger.add_scalar('mIoU/val_min_%s' % valid_names[i],
                                      valid_mIoUs[i], epoch)
                    logging.info("Epoch %d: valid_mIoU_min_%s %.3f" %
                                 (epoch, valid_names[i], valid_mIoUs[i]))
                if len(model._width_mult_list) > 1:
                    model.prun_mode = "max"
                    valid_mIoUs = infer(epoch, model, evaluator, logger,
                                        FPS=False)
                    for i in range(5):
                        logger.add_scalar('mIoU/val_max_%s' % valid_names[i],
                                          valid_mIoUs[i], epoch)
                        logging.info("Epoch %d: valid_mIoU_max_%s %.3f" %
                                     (epoch, valid_names[i], valid_mIoUs[i]))
                    model.prun_mode = "random"
                    valid_mIoUs = infer(epoch, model, evaluator, logger,
                                        FPS=False)
                    for i in range(5):
                        logger.add_scalar('mIoU/val_random_%s' % valid_names[i],
                                          valid_mIoUs[i], epoch)
                        logging.info("Epoch %d: valid_mIoU_random_%s %.3f" %
                                     (epoch, valid_names[i], valid_mIoUs[i]))
            else:
                valid_mIoUss = []
                FPSs = []
                model.prun_mode = None
                for idx in range(len(model._arch_names)):  # arch_idx
                    model.arch_idx = idx
                    valid_mIoUs, fps0, fps1 = infer(epoch, model, evaluator,
                                                    logger)
                    valid_mIoUss.append(valid_mIoUs)
                    FPSs.append([fps0, fps1])
                    for i in range(5):  # preds
                        logger.add_scalar(
                            'mIoU/val_%s_%s' % (arch_names[idx], valid_names[i]),
                            valid_mIoUs[i], epoch)
                        logging.info("Epoch %d: valid_mIoU_%s_%s %.3f" %
                                     (epoch, arch_names[idx], valid_names[i],
                                      valid_mIoUs[i]))
                    if config.latency_weight[idx] > 0:
                        logger.add_scalar(
                            'Objective/val_%s_8s_32s' % arch_names[idx],
                            objective_acc_lat(valid_mIoUs[3], 1000. / fps0),
                            epoch)
                        logging.info(
                            "Epoch %d: Objective_%s_8s_32s %.3f" %
                            (epoch, arch_names[idx],
                             objective_acc_lat(valid_mIoUs[3], 1000. / fps0)))
                        logger.add_scalar(
                            'Objective/val_%s_16s_32s' % arch_names[idx],
                            objective_acc_lat(valid_mIoUs[4], 1000. / fps1),
                            epoch)
                        logging.info(
                            "Epoch %d: Objective_%s_16s_32s %.3f" %
                            (epoch, arch_names[idx],
                             objective_acc_lat(valid_mIoUs[4], 1000. / fps1)))
                valid_mIoU_history.append(valid_mIoUss)
                FPSs_history.append(FPSs)
                if update_arch:
                    latency_supernet_history.append(architect.latency_supernet)
                latency_weight_history.append(architect.latency_weight)

        save(model, os.path.join(config.save, 'weights.pt'))
        if type(pretrain) == str:
            # contains arch_param names:
            # {"alphas": alphas, "betas": betas, "gammas": gammas, "ratios": ratios}
            for idx, arch_name in enumerate(model._arch_names):
                state = {}
                for name in arch_name['alphas']:
                    state[name] = getattr(model, name)
                for name in arch_name['betas']:
                    state[name] = getattr(model, name)
                for name in arch_name['ratios']:
                    state[name] = getattr(model, name)
                state["mIoU02"] = valid_mIoUs[3]
                state["mIoU12"] = valid_mIoUs[4]
                if pretrain is not True:
                    state["latency02"] = 1000. / fps0
                    state["latency12"] = 1000. / fps1
                torch.save(state, os.path.join(config.save,
                                               "arch_%d_%d.pt" % (idx, epoch)))
                torch.save(state, os.path.join(config.save,
                                               "arch_%d.pt" % (idx)))

        if update_arch:
            for idx in range(len(config.latency_weight)):
                if config.latency_weight[idx] > 0:
                    if (int(FPSs[idx][0] >= config.FPS_max[idx]) +
                            int(FPSs[idx][1] >= config.FPS_max[idx])) >= 1:
                        architect.latency_weight[idx] /= 2
                    elif (int(FPSs[idx][0] <= config.FPS_min[idx]) +
                          int(FPSs[idx][1] <= config.FPS_min[idx])) > 0:
                        architect.latency_weight[idx] *= 2
                    logger.add_scalar("arch/latency_weight_%s" % arch_names[idx],
                                      architect.latency_weight[idx], epoch + 1)
                    logging.info("arch_latency_weight_%s = " % arch_names[idx] +
                                 str(architect.latency_weight[idx]))
def main(model_name):
    # TODO: parse args.
    n_classes = 19
    # batch_size = 2
    batch_size = 1  # 24
    n_workers = 12
    n_semantic_pretrain = 0  # 500 # First train only on semantics.
    n_epochs = 500
    validation_step = 15
    # TODO: implement resize as pil_transform
    resize = None  # (256, 512)
    cityscapes_directory = "/home/<someuser>/cityscapes"
    output_directory = "tmp/"
    # Uncomment next line when you've set all directories.
    raise ValueError("Please set the input/output directories.")
    checkpoint = None
    # checkpoint = (
    #     "weights/...pth",
    #     <fill in epoch>)

    # --- Setup loss functions.
    classification_loss = nn.CrossEntropyLoss(ignore_index=255)
    # 'elementwise_mean' is the deprecated spelling of 'mean'
    regression_loss = nn.MSELoss(reduction='mean')

    print("--- Load model.")
    if model_name == 'DRNRegressionDownsampled':
        classification_loss = None
        regression_loss = nn.MSELoss(reduction='mean')
        dataset_kwargs = {
            'pil_transforms': None,
            'gt_pil_transforms': [ModeDownsample(8)],
            'fit_gt_pil_transforms': [
                transforms.Resize(size=(784 // 8, 1792 // 8), interpolation=2)
            ],
            'input_transforms': [
                transforms.Normalize(mean=[0.290101, 0.328081, 0.286964],
                                     std=[0.182954, 0.186566, 0.184475])
            ],
            'tensor_transforms': None
        }
        model = DRNRegressionDownsampled(
            model_name='drn_d_22', classes=n_classes,
            pretrained_dict=torch.load('./weights/drn_d_22_cityscapes.pth'))
        model.cuda()
        parameters = model.parameters()
    else:
        raise ValueError("Model \"{}\" not found!".format(model_name))

    optimizer = optim.Adam(parameters)

    start_epoch = 0
    if checkpoint is not None:
        print("Loading from checkpoint {}".format(checkpoint))
        model.load_state_dict(torch.load(checkpoint[0]))
        optimizer.load_state_dict(torch.load(checkpoint[1]))
        start_epoch = checkpoint[2] + 1

    print("--- Setup dataset and dataloaders.")
    train_set = Cityscapes(data_split='subtrain',
                           cityscapes_directory=cityscapes_directory,
                           **dataset_kwargs)
    train_loader = data.DataLoader(train_set, batch_size=batch_size,
                                   num_workers=n_workers, shuffle=True)
    val_set = Cityscapes(data_split='subtrainval',
                         cityscapes_directory=cityscapes_directory,
                         **dataset_kwargs)
    val_loader = data.DataLoader(val_set, batch_size=batch_size,
                                 num_workers=n_workers)

    # Sample 10 validation indices for visualization.
    # validation_idxs = np.random.choice(np.arange(len(val_set)),
    #                                    size=min(9, len(val_set)),
    #                                    replace=False)
    # Nah, let's pick them ourselves for now.
    validation_idxs = [17, 241, 287, 304, 123, 458, 1, 14, 139, 388]

    if True:
        print("--- Setup visual validation.")
        # Save them for comparison.
        check_mkdir('{}/validationimgs'.format(output_directory))
        check_mkdir('{}/offsets_gt'.format(output_directory))
        check_mkdir('{}/semantic_gt'.format(output_directory))
        for validation_idx in validation_idxs:
            img_pil, _, _ = val_set.load_fit_gt_PIL_images(validation_idx)
            img, semantic_gt, offset_gt = val_set[validation_idx]
            img_pil.save("{}/validationimgs/id{:03}.png".format(
                output_directory, validation_idx))
            visualize_semantics(
                img_pil, semantic_gt,
                "{}/semantic_gt/id{:03}".format(output_directory,
                                                validation_idx))
            visualize_positionplusoffset(
                offset_gt,
                "{}/offsets_gt/id{:03}_mean".format(output_directory,
                                                    validation_idx),
                groundtruth=offset_gt)
            visualize_offsethsv(
                offset_gt,
                "{}/offsets_gt/id{:03}".format(output_directory,
                                               validation_idx))

    print("--- Training.")
    rlosses = []
    closses = []
    for epoch in range(start_epoch, n_epochs):
        model.train()
        total_rloss = 0
        total_closs = 0
        for batch_idx, batch_data in enumerate(train_loader):
            img = batch_data[0].cuda()
            semantic_gt = batch_data[1].cuda()
            instance_offset_gt = batch_data[2].cuda()
            del batch_data

            optimizer.zero_grad()
            outputs = model(img)

            batch_rloss = 0
            batch_closs = 0
            loss = 0
            closs = 0
            rloss = 0
            if regression_loss is not None:
                predicted_offset = outputs[:, -2:]
                rloss = regression_loss(predicted_offset, instance_offset_gt)
                batch_rloss += int(rloss.detach().cpu())
                total_rloss += batch_rloss
                loss += rloss
            if classification_loss is not None:
                closs = classification_loss(outputs[:, :n_classes], semantic_gt)
                batch_closs += int(closs.detach().cpu())
                total_closs += batch_closs
                loss += closs

            loss.backward()
            optimizer.step()

            if batch_idx % 30 == 0 and batch_idx != 0:
                print('\t[batch {}/{}], [batch mean - closs {:5}, rloss {:5}]'
                      .format(batch_idx, len(train_loader),
                              batch_closs / img.size(0),
                              batch_rloss / img.size(0)))
            del img, semantic_gt, instance_offset_gt, outputs, rloss, closs, loss

        total_closs /= len(train_set)
        total_rloss /= len(train_set)
        print('[epoch {}], [mean train - closs {:5}, rloss {:5}]'.format(
            epoch, total_closs, total_rloss))
        rlosses.append(total_rloss)
        closses.append(total_closs)

        plt.plot(np.arange(start_epoch, epoch + 1), rlosses)
        plt.savefig('{}/rlosses.svg'.format(output_directory))
        plt.close('all')
        plt.plot(np.arange(start_epoch, epoch + 1), closses)
        plt.savefig('{}/closses.svg'.format(output_directory))
        plt.close('all')
        plt.plot(np.arange(start_epoch, epoch + 1), np.add(rlosses, closses))
        plt.savefig('{}/losses.svg'.format(output_directory))
        plt.close('all')

        # --- Visual validation.
        if (epoch % validation_step) == 0:
            # Save model parameters.
            check_mkdir('{}/models'.format(output_directory))
            torch.save(model.state_dict(),
                       '{}/models/Net_epoch{}.pth'.format(output_directory,
                                                          epoch))
            torch.save(optimizer.state_dict(),
                       '{}/models/Adam_epoch{}.pth'.format(output_directory,
                                                           epoch))

            # Visualize validation imgs.
            check_mkdir('{}/offsets'.format(output_directory))
            check_mkdir('{}/offsets/means'.format(output_directory))
            check_mkdir('{}/semantics'.format(output_directory))
            check_mkdir('{}/semantics/overlay'.format(output_directory))
            model.eval()
            for validation_idx in validation_idxs:
                img_pil, _, _ = val_set.load_PIL_images(validation_idx)
                img, _, offset_gt = val_set[validation_idx]
                img = img.unsqueeze(0).cuda()
                with torch.no_grad():
                    outputs = model(img)

                epoch_filename = 'id{:03}_epoch{:05}'.format(validation_idx,
                                                             epoch)
                if classification_loss is not None:
                    visualize_semantics(
                        img_pil, outputs,
                        "{}/semantics/{}".format(output_directory,
                                                 epoch_filename),
                        "{}/semantics/overlay/{}".format(output_directory,
                                                         epoch_filename))
                if regression_loss is not None:
                    visualize_offsethsv(
                        outputs.detach(),
                        "{}/offsets/{}".format(output_directory,
                                               epoch_filename))
                    visualize_positionplusoffset(
                        outputs,
                        "{}/offsets/means/{}".format(output_directory,
                                                     epoch_filename),
                        groundtruth=offset_gt)
def get_dataset(opts): """ Dataset And Augmentation """ if opts.dataset == 'voc': train_transform = et.ExtCompose([ #et.ExtResize(size=opts.crop_size), et.ExtRandomScale((0.5, 2.0)), et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size), pad_if_needed=True), et.ExtRandomHorizontalFlip(), et.ExtToTensor(), et.ExtNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) if opts.crop_val: val_transform = et.ExtCompose([ et.ExtResize(opts.crop_size), et.ExtCenterCrop(opts.crop_size), et.ExtToTensor(), et.ExtNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) else: val_transform = et.ExtCompose([ et.ExtToTensor(), et.ExtNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) train_dst = VOCSegmentation(root=opts.data_root, year=opts.year, image_set='train', download=opts.download, transform=train_transform) val_dst = VOCSegmentation(root=opts.data_root, year=opts.year, image_set='val', download=False, transform=val_transform) if opts.dataset == 'cityscapes': train_transform = et.ExtCompose([ #et.ExtResize( 512 ), et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size)), et.ExtColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), et.ExtRandomHorizontalFlip(), et.ExtToTensor(), et.ExtNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) val_transform = et.ExtCompose([ #et.ExtResize( 512 ), et.ExtToTensor(), et.ExtNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) train_dst = Cityscapes(root=opts.data_root, split='train', transform=train_transform) val_dst = Cityscapes(root=opts.data_root, split='val', transform=val_transform) if opts.dataset == 'weedcluster': train_dst = WeedClusterDataset(root=opts.data_root, split='train') val_dst = WeedClusterDataset(root=opts.data_root, split='val') if opts.dataset == 'cloudshadow': train_dst = CloudShadowDataset(root=opts.data_root, split='train') val_dst = CloudShadowDataset(root=opts.data_root, split='val') if opts.dataset == 'doubleplant': train_dst = DoublePlantDataset(root=opts.data_root, split='train') val_dst = DoublePlantDataset(root=opts.data_root, split='val') if opts.dataset == 'planterskip': train_dst = PlanterSkipDataset(root=opts.data_root, split='train') val_dst = PlanterSkipDataset(root=opts.data_root, split='val') if opts.dataset == 'standingwater': train_dst = StandingWaterDataset(root=opts.data_root, split='train') val_dst = StandingWaterDataset(root=opts.data_root, split='val') if opts.dataset == 'waterway': train_dst = WaterwayDataset(root=opts.data_root, split='train') val_dst = WaterwayDataset(root=opts.data_root, split='val') return train_dst, val_dst
def main():
    create_exp_dir(config.save,
                   scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt='%m/%d %I:%M:%S %p')
    logging.info("args = %s", str(config))

    # preparation ################
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    seed = config.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # config network and criterion ################
    min_kept = int(config.batch_size * config.image_height * config.image_width
                   // (16 * config.gt_down_sampling ** 2))

    # data loader ###########################
    data_setting = {
        'img_root': config.img_root_folder,
        'gt_root': config.gt_root_folder,
        'train_source': config.train_source,
        'eval_source': config.eval_source,
        'down_sampling': config.down_sampling
    }

    # Model #######################################
    models = []
    evaluators = []
    lasts = []
    for idx, arch_idx in enumerate(config.arch_idx):
        if config.load_epoch == "last":
            state = torch.load(
                os.path.join(config.load_path, "arch_%d.pt" % arch_idx))
        else:
            state = torch.load(
                os.path.join(config.load_path,
                             "arch_%d_%d.pt" % (arch_idx, int(config.load_epoch))))

        model = Network(
            [state["alpha_%d_0" % arch_idx].detach(),
             state["alpha_%d_1" % arch_idx].detach(),
             state["alpha_%d_2" % arch_idx].detach()],
            [None,
             state["beta_%d_1" % arch_idx].detach(),
             state["beta_%d_2" % arch_idx].detach()],
            [state["ratio_%d_0" % arch_idx].detach(),
             state["ratio_%d_1" % arch_idx].detach(),
             state["ratio_%d_2" % arch_idx].detach()],
            num_classes=config.num_classes, layers=config.layers,
            Fch=config.Fch, width_mult_list=config.width_mult_list,
            stem_head_width=config.stem_head_width[idx],
            ignore_skip=arch_idx == 0)

        # Pick the final downsampling path (2->0 vs. 2->1) with the better
        # accuracy/latency objective.
        mIoU02 = state["mIoU02"]
        latency02 = state["latency02"]
        obj02 = objective_acc_lat(mIoU02, latency02)
        mIoU12 = state["mIoU12"]
        latency12 = state["latency12"]
        obj12 = objective_acc_lat(mIoU12, latency12)
        if obj02 > obj12:
            last = [2, 0]
        else:
            last = [2, 1]
        lasts.append(last)
        model.build_structure(last)
        # logging.info("net: " + str(model))
        for b in last:
            if len(config.width_mult_list) > 1:
                plot_op(getattr(model, "ops%d" % b),
                        getattr(model, "path%d" % b),
                        width=getattr(model, "widths%d" % b),
                        head_width=config.stem_head_width[idx][1],
                        F_base=config.Fch).savefig(
                    os.path.join(config.save, "ops_%d_%d.png" % (arch_idx, b)),
                    bbox_inches="tight")
            else:
                plot_op(getattr(model, "ops%d" % b),
                        getattr(model, "path%d" % b),
                        F_base=config.Fch).savefig(
                    os.path.join(config.save, "ops_%d_%d.png" % (arch_idx, b)),
                    bbox_inches="tight")
        plot_path_width(model.lasts, model.paths, model.widths).savefig(
            os.path.join(config.save, "path_width%d.png" % arch_idx))
        plot_path_width([2, 1, 0], [model.path2, model.path1, model.path0],
                        [model.widths2, model.widths1, model.widths0]).savefig(
            os.path.join(config.save, "path_width_all%d.png" % arch_idx))
        flops, params = profile(model, inputs=(torch.randn(1, 3, 1024, 2048),),
                                verbose=False)
        logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)
        logging.info("ops:" + str(model.ops))
        logging.info("path:" + str(model.paths))
        logging.info("last:" + str(model.lasts))
        model = model.cuda()
        init_weight(model, nn.init.kaiming_normal_, torch.nn.BatchNorm2d,
                    config.bn_eps, config.bn_momentum,
                    mode='fan_in', nonlinearity='relu')

        # Restore only the weights whose names exist in the current model.
        partial = torch.load(
            os.path.join(config.eval_path, "weights%d.pt" % arch_idx))
        state = model.state_dict()
        pretrained_dict = {k: v for k, v in partial.items() if k in state}
        state.update(pretrained_dict)
        model.load_state_dict(state)

        evaluator = SegEvaluator(
            Cityscapes(data_setting, 'val', None), config.num_classes,
            config.image_mean, config.image_std, model,
            config.eval_scale_array, config.eval_flip, 0, out_idx=0,
            config=config, verbose=False,
            save_path=os.path.join(config.save, 'predictions'),
            show_image=True, show_prediction=True)
        evaluators.append(evaluator)
        models.append(model)

    # Cityscapes ###########################################
    logging.info(config.load_path)
    logging.info(config.eval_path)
    logging.info(config.save)

    with torch.no_grad():
        # validation
        print("[validation...]")
        valid_mIoUs = infer(models, evaluators, logger=None)
        for idx, arch_idx in enumerate(config.arch_idx):
            if arch_idx == 0:
                logging.info("teacher's valid_mIoU %.3f" % (valid_mIoUs[idx]))
            else:
                logging.info("student's valid_mIoU %.3f" % (valid_mIoUs[idx]))
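# The partial checkpoint restore above (keep only keys present in the current
# model) is a common pattern; factored into a helper it might look like the
# sketch below. This is an illustration, not a helper from this codebase; the
# extra shape check is an added safeguard.
import torch

def load_partial_weights(model, checkpoint_path):
    """Load only parameters whose names and shapes match the model."""
    partial = torch.load(checkpoint_path, map_location='cpu')
    state = model.state_dict()
    matched = {k: v for k, v in partial.items()
               if k in state and v.shape == state[k].shape}
    state.update(matched)
    model.load_state_dict(state)
    return matched.keys()  # handy for logging what was actually restored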
def main(model_name, initial_validation):
    # TODO: parse args.
    # --- Tunables.
    # 32GB DRNDSOffsetDisparity, cropped -> 18
    # 12GB DRNDSOffsetDisparity, cropped -> 6
    # 12GB DRNOffsetDisparity, cropped -> 4
    # 12GB DRNOffsetDisparity, original -> 3
    # 12GB DRNDSOffsetDisparity, original -> not supported yet:
    #      resize is based on resolution 1792x784
    batch_size = 6  # 6
    n_workers = 21
    n_semantic_pretrain = 0  # 500  # First train only on semantics.
    n_epochs = 500
    validation_step = 5
    train_split = 'subtrain'  # 'train'
    val_split = 'subtrainval'  # 'val'
    validate_on_train = False  # Note: this doesn't include semantic performance.
    train_set_length = 24  # 24 # None
    #cityscapes_directory = "/home/thehn/cityscapes/original"
    cityscapes_directory = "/home/thehn/cityscapes/cropped_cityscapes"
    #cityscapes_directory = "/data/Cityscapes"
    drn_name = 'drn_d_22'  # 'drn_d_22' 'drn_d_38'

    weights = None
    if 'SL' in model_name:
        weights = {'offset_mean_weight': 1e-5,        # 1e-3
                   'offset_variance_weight': 1e-4,    # 1e-3
                   'disparity_mean_weight': 1e-7,     # 1e-3
                   'disparity_variance_weight': 1e-4} # 1e-3

    output_directory = "tmp/train/{}".format(model_name)
    #output_directory = "tmp/train/{}_{}"\
    #        .format(model_name, time.strftime('%m%d-%H%M'))
    #output_directory = "tmp/train_test"
    #output_directory = "tmp/train_combined"
    #raise ValueError("Please set the input/output directories.")
    print("batch_size =", batch_size)
    print("train_split =", train_split)
    print("val_split =", val_split)
    print(locals())
    check_mkdir(output_directory)

    checkpoint = None
    #checkpoint = (
    #    "/home/thomashehn/Code/box2pix/tmp/train/models/Net_epoch6.pth",
    #    "/home/thomashehn/Code/box2pix/tmp/train/models/Adam_epoch6.pth",
    #    6)

    n_classes = 19
    mdl = ModelWrapper(model_name, n_classes, weights, drn_name)
    #for param in parameters:
    #    param.require_grad = False
    #parameters = []

    # weight_decay=1e-6 seems to work so far, but would need more finetuning.
    optimizer = optim.Adam(mdl.parameters, weight_decay=1e-6)

    start_epoch = 1
    if checkpoint is not None:
        print("Loading from checkpoint {}".format(checkpoint))
        mdl.NN.load_state_dict(torch.load(checkpoint[0]))
        optimizer.load_state_dict(torch.load(checkpoint[1]))
        start_epoch = checkpoint[2] + 1

    mdl.NN, optimizer = amp.initialize(mdl.NN, optimizer, opt_level="O1")
    # O0, DRNDSDoubleSegSL,      bs 6, cropped, 2 epochs -> 11949MB memory, time real 19m34.788s
    # O1, DRNDSDoubleSegSL,      bs 6, cropped, 2 epochs ->  7339MB memory, time real 10m32.431s
    # O0, DRNDSOffsetDisparity,  bs 6, cropped, 2 epochs -> 11875MB memory, time real 18m13.491s
    # O1, DRNDSOffsetDisparity,  bs 6, cropped, 2 epochs ->  7259MB memory, time real  8m51.849s
    # O0, DRNDSOffsetDisparity,  bs 7, cropped, 2 epochs -> memory error
    # O1, DRNDSOffsetDisparity,  bs 7, cropped, 2 epochs ->  8701MB memory, time real  9m13.947s
    # O2, DRNDSOffsetDisparity,  bs 7, cropped, 2 epochs ->  8721MB memory, time real  9m8.563s
    # O3, DRNDSOffsetDisparity,  bs 7, cropped, 2 epochs ->  8693MB memory, time real  9m7.476s

    print("--- Setup dataset and dataloaders.")
    mdl.train_set = Cityscapes(mdl.types, data_split=train_split,
                               length=train_set_length,
                               cityscapes_directory=cityscapes_directory,
                               **mdl.dataset_kwargs)
    element = mdl.train_set[0]
    mdl.train_loader = data.DataLoader(mdl.train_set, batch_size=batch_size,
                                       pin_memory=True, num_workers=n_workers,
                                       shuffle=True)
    if not validate_on_train:
        mdl.val_set = Cityscapes(mdl.types, data_split=val_split,
                                 cityscapes_directory=cityscapes_directory,
                                 **mdl.val_dataset_kwargs)
    else:
        mdl.val_set = Cityscapes(mdl.types, data_split=train_split,
                                 length=train_set_length,
                                 cityscapes_directory=cityscapes_directory,
                                 **mdl.val_dataset_kwargs)
    mdl.val_loader = data.DataLoader(mdl.val_set, batch_size=batch_size,
                                     shuffle=False, num_workers=n_workers)

    # Sample 10 validation indices for visualization.
    #validation_idxs = np.random.choice(np.arange(len(val_set)),
    #                                   size=min(9, len(val_set)),
    #                                   replace=False)
    # Nah, let's pick them ourselves for now.
    #validation_idxs = [ 17, 241, 287, 304, 123,
    #                   458,   1,  14, 139, 388]
    validation_idxs = [17, 1, 14]
    #validation_idxs = [53, 11, 77]

    metrics = {
        'train': {'classification': [], 'regression': [], 'epochs': []},
        'validation': {'classification': [], 'regression': [],
                       'semantic': [], 'epochs': []},
        'memory': {'max_cached': [torch.cuda.max_memory_cached()],
                   'max_alloc': [torch.cuda.max_memory_allocated()]}
    }

    if initial_validation:
        print("--- Setup visual validation.")
        model_file = mdl.save_model(output_directory, suffix="e0000")
        mdl.validation_visual(validation_idxs, output_directory, epoch=0)
        semantic_score = mdl.validation_snapshot(
            model_file, path.join(output_directory, 'last_prediction'),
            cityscapes_directory, batch_size, val_split)
        train_losses = mdl.compute_loss(mdl.train_loader)
        val_losses = mdl.compute_loss(mdl.val_loader, separate=True)
        print('Training loss: {:5} (c) + {:5} (r) = {:5}'.format(
            train_losses[0], train_losses[1], sum(train_losses)))
        # Guard: keep val_dict defined for models without separate SL losses.
        val_dict = {}
        if len(val_losses) > 5:
            val_dict = {'offset_mean_loss': [val_losses[2]],
                        'offset_variance_loss': [val_losses[3]],
                        'disparity_mean_loss': [val_losses[4]],
                        'disparity_variance_loss': [val_losses[5]]}
        print('Validation loss: {:5} (c) + {:5} (r) = {:5}'.format(
            val_losses[0], val_losses[1], sum(val_losses[:2])))
        metrics = {
            'train': {'classification': [train_losses[0]],
                      'regression': [train_losses[1]],
                      'epochs': [start_epoch - 1]},
            'validation': {'classification': [val_losses[0]],
                           'regression': [val_losses[1]],
                           'semantic': [semantic_score],
                           'epochs': [start_epoch - 1],
                           **val_dict},
            'memory': {'max_cached': [torch.cuda.max_memory_cached()],
                       'max_alloc': [torch.cuda.max_memory_allocated()]}
        }

    print("--- Training.")
    # First train semantic loss for a while.
    #~regression_loss_stash = None
    #~if n_semantic_pretrain > 0 and regression_loss is not None:
    #~    regression_loss_stash = regression_loss
    #~    regression_loss = None
    #upscale = lambda x: nn.functional.interpolate(x, scale_factor=2,
    #                                              mode='bilinear',
    #                                              align_corners=True)
    for epoch in range(start_epoch, n_epochs + 1):
        #~if epoch >= n_semantic_pretrain and regression_loss_stash is not None:
        #~    regression_loss = regression_loss_stash
        #~    regression_loss_stash = None
        #~if epoch == 10 and False:
        #~    parameters = model.parameters()
        #~    model.train_all = True
        #~    optimizer = optim.Adam(parameters)
        mdl.NN.train()
        total_rloss = 0
        total_closs = 0
        t_sum_batch = 0
        t_sum_opt = 0
        for batch_idx, batch_data in enumerate(mdl.train_loader):
            optimizer.zero_grad()
            batch_losses = mdl.batch_loss(batch_data)
            loss = sum(batch_losses)
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            #loss.backward()
            optimizer.step()
            if batch_idx % 30 == 0 and batch_idx != 0:
                print('\t[batch {}/{}], [batch mean - closs {:5}, rloss {:5}]'
                      .format(batch_idx, len(mdl.train_loader),
                              float(batch_losses[0]) / batch_data[0].size(0),
                              float(batch_losses[1]) / batch_data[0].size(0)))
            total_closs += float(batch_losses[0])
            total_rloss += float(batch_losses[1])
            del loss, batch_data, batch_losses
        total_closs /= len(mdl.train_set)
        total_rloss /= len(mdl.train_set)
        print('[epoch {}], [mean train - closs {:5}, rloss {:5}]'.format(
            epoch, total_closs, total_rloss))
        metrics['train']['classification'].append(total_closs)
        metrics['train']['regression'].append(total_rloss)
        metrics['train']['epochs'].append(epoch)
        metrics['memory']['max_cached'].append(torch.cuda.max_memory_cached())
        metrics['memory']['max_alloc'].append(torch.cuda.max_memory_allocated())

        # --- Visual validation.
        if (epoch % validation_step) == 0:
            print("--- Validation.")
            mdl.validation_visual(validation_idxs, output_directory, epoch)
            model_file = mdl.save_model(output_directory,
                                        suffix="{:04}".format(epoch))
            metrics['validation']['semantic'].append(
                mdl.validation_snapshot(
                    model_file,
                    path.join(output_directory, 'last_prediction'),
                    cityscapes_directory, batch_size, val_split))
            val_losses = mdl.compute_loss(mdl.val_loader, separate=True)
            if len(val_losses) > 5:
                if 'offset_mean_loss' not in metrics['validation'].keys():
                    val_dict = {'offset_mean_loss': [val_losses[2]],
                                'offset_variance_loss': [val_losses[3]],
                                'disparity_mean_loss': [val_losses[4]],
                                'disparity_variance_loss': [val_losses[5]]}
                    metrics['validation'] = {**metrics['validation'], **val_dict}
                else:
                    metrics['validation']['offset_mean_loss']\
                        .append(val_losses[2])
                    metrics['validation']['offset_variance_loss']\
                        .append(val_losses[3])
                    metrics['validation']['disparity_mean_loss']\
                        .append(val_losses[4])
                    metrics['validation']['disparity_variance_loss']\
                        .append(val_losses[5])
                print('Separate validation losses: {:5}, {:5}, {:5}, {:5}'
                      .format(*val_losses[2:]))
            metrics['validation']['classification'].append(val_losses[0])
            metrics['validation']['regression'].append(val_losses[1])
            metrics['validation']['epochs'].append(epoch)
            print('Validation loss: {:5} (c) + {:5} (r) = {:5}'.format(
                val_losses[0], val_losses[1], sum(val_losses[:2])))

            # --- Write losses to disk.
            with open(path.join(output_directory, "metrics.json"), 'w') as outfile:
                json.dump(metrics, outfile)
            for data_set, set_metrics in metrics.items():
                plot_losses(set_metrics,
                            "{}/{}".format(output_directory, data_set))
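# The plot_losses helper called above is not shown in this excerpt. A minimal
# sketch of a compatible implementation (hypothetical; it only assumes the
# metrics layout built above, i.e. lists of values plus an optional 'epochs'
# list per metrics group):
import matplotlib
matplotlib.use('Agg')  # headless rendering, no display needed
import matplotlib.pyplot as plt

def plot_losses(set_metrics, output_prefix):
    """Plot each metric series against its epochs and save one PNG apiece."""
    epochs = set_metrics.get('epochs')
    for name, values in set_metrics.items():
        if name == 'epochs' or not values:
            continue
        plt.figure()
        # Fall back to the sample index when no matching epoch list exists
        # (e.g. the 'memory' group has no 'epochs' key).
        x = epochs if epochs and len(epochs) == len(values) else range(len(values))
        plt.plot(x, values)
        plt.xlabel('epoch')
        plt.ylabel(name)
        plt.savefig("{}_{}.png".format(output_prefix, name),
                    bbox_inches='tight')
        plt.close()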
# (Tail of the multi-model fusion branch; the matching "if" precedes this excerpt.)
    sin_outputs = F.softmax(model(img.to(device)),
                            dim=1).detach().cpu().numpy()

    # Fuse the single-class and multi-class model scores.
    alpha = 0.5
    preds = np.concatenate(preds, 1)
    if dataset.lower() != 'cityscapes':
        # Take the background probability predicted by the single-class model
        # and prepend it as channel 0.
        background_p = np.expand_dims(sin_outputs[:, 0, :, :], axis=1)
        preds = np.concatenate((background_p, preds), 1)
    # Unnormalized weighted sum of the two score maps; only the argmax is used.
    final_preds = alpha * preds + sin_outputs
    preds = np.argmax(final_preds, axis=1)

    if dataset == 'voc':
        pred = voc_cmap()[preds.squeeze(axis=0)].astype(np.uint8)
    else:
        pred = Cityscapes.decode_target(preds.squeeze(axis=0)).astype(np.uint8)
    Image.fromarray(pred).save(result_path)
    print("Prediction is saved in %s" % result_path)
else:
    # Single-model path: load the checkpoint and run inference directly.
    model = model_map[model_name](num_classes=sin_model_class,
                                  output_stride=16)
    weights = torch.load(ckpt_path)["model_state"]
    model.load_state_dict(weights)
    model.to(device)
    model.eval()
    with torch.no_grad():
        print(img.shape)
        img = img.to(device)
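# The score-fusion step above can be isolated for testing. A minimal sketch
# (hypothetical helper, not part of this codebase; it assumes channel 0 of the
# single-class model's softmax output is background, as in the branch above):
import numpy as np

def fuse_scores(sin_outputs, multi_preds, alpha=0.5, prepend_background=True):
    """Fuse single-class softmax scores with stacked multi-class scores.

    sin_outputs: (N, C, H, W) softmax output of the single-class model.
    multi_preds: (N, C-1, H, W) stacked per-class scores, or (N, C, H, W)
                 when prepend_background is False.
    Returns an (N, H, W) label map.
    """
    if prepend_background:
        # Reuse the single-class model's background probability as channel 0.
        background_p = sin_outputs[:, :1, :, :]
        multi_preds = np.concatenate((background_p, multi_preds), axis=1)
    # Same unnormalized weighted sum as in the inference code above.
    fused = alpha * multi_preds + sin_outputs
    return np.argmax(fused, axis=1)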