def main():
    """Create the model and start the training.

    Single-node multi-GPU training of Res_Deeplab on the LIP dataset via
    DataParallelModel/DataParallelCriterion.  All configuration is read from
    the module-level ``args`` namespace; scalars and image grids are logged
    to TensorBoard under ``args.snapshot_dir``, and a checkpoint is written
    once per epoch.
    """
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
    writer = SummaryWriter(args.snapshot_dir)
    gpus = [int(i) for i in args.gpu.split(',')]  # e.g. "0,1" -> [0, 1]
    if not args.gpu == 'None':
        # NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if set before
        # the first CUDA context is created -- confirm nothing touched CUDA
        # earlier in this process.
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    h, w = map(int, args.input_size.split(','))  # crop size given as "H,W"
    input_size = [h, w]

    cudnn.enabled = True
    # cudnn related setting
    cudnn.benchmark = True  # fixed input size, so autotuned kernels pay off
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.enabled = True

    deeplab = Res_Deeplab(num_classes=args.num_classes)

    # dump_input = torch.rand((args.batch_size, 3, input_size[0], input_size[1]))
    # writer.add_graph(deeplab.cuda(), dump_input.cuda(), verbose=False)

    # Initialise from a pretrained checkpoint, skipping the classifier ('fc')
    # weights that have no counterpart in the segmentation head.
    saved_state_dict = torch.load(args.restore_from)
    new_params = deeplab.state_dict().copy()
    for i in saved_state_dict:
        i_parts = i.split('.')
        # print(i_parts)
        if not i_parts[0] == 'fc':
            new_params['.'.join(i_parts[0:])] = saved_state_dict[i]

    deeplab.load_state_dict(new_params)

    model = DataParallelModel(deeplab)
    model.cuda()

    criterion = CriterionAll()
    criterion = DataParallelCriterion(criterion)
    criterion.cuda()

    # ImageNet mean/std normalisation.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # Per-GPU batch size is args.batch_size; scale by GPU count for the loader.
    trainloader = data.DataLoader(LIPDataSet(args.data_dir,
                                             args.dataset,
                                             crop_size=input_size,
                                             transform=transform),
                                  batch_size=args.batch_size * len(gpus),
                                  shuffle=True,
                                  num_workers=2,
                                  pin_memory=True)
    #lip_dataset = LIPDataSet(args.data_dir, 'val', crop_size=input_size, transform=transform)
    #num_samples = len(lip_dataset)
    #valloader = data.DataLoader(lip_dataset, batch_size=args.batch_size * len(gpus),
    #                            shuffle=False, pin_memory=True)

    optimizer = optim.SGD(model.parameters(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    total_iters = args.epochs * len(trainloader)
    for epoch in range(args.start_epoch, args.epochs):
        model.train()
        for i_iter, batch in enumerate(trainloader):
            # Convert the per-epoch index into a global iteration counter for
            # the LR schedule and TensorBoard steps.
            i_iter += len(trainloader) * epoch
            lr = adjust_learning_rate(optimizer, i_iter, total_iters)

            images, labels, edges, _ = batch
            labels = labels.long().cuda(non_blocking=True)
            edges = edges.long().cuda(non_blocking=True)

            preds = model(images)

            loss = criterion(preds, [labels, edges])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i_iter % 100 == 0:
                writer.add_scalar('learning_rate', lr, i_iter)
                writer.add_scalar('loss', loss.data.cpu().numpy(), i_iter)

            if i_iter % 500 == 0:
                # Periodic visual sanity check: dump input images, ground
                # truth, and predictions as colour grids to TensorBoard.
                images_inv = inv_preprocess(images, args.save_num_images)
                labels_colors = decode_parsing(labels,
                                               args.save_num_images,
                                               args.num_classes,
                                               is_pred=False)
                edges_colors = decode_parsing(edges,
                                              args.save_num_images,
                                              2,
                                              is_pred=False)

                if isinstance(preds, list):
                    # DataParallelModel returns one output list per GPU;
                    # visualise the first replica only.
                    preds = preds[0]
                preds_colors = decode_parsing(preds[0][-1],
                                              args.save_num_images,
                                              args.num_classes,
                                              is_pred=True)
                pred_edges = decode_parsing(preds[1][-1],
                                            args.save_num_images,
                                            2,
                                            is_pred=True)

                img = vutils.make_grid(images_inv, normalize=False, scale_each=True)
                lab = vutils.make_grid(labels_colors, normalize=False, scale_each=True)
                pred = vutils.make_grid(preds_colors, normalize=False, scale_each=True)
                edge = vutils.make_grid(edges_colors, normalize=False, scale_each=True)
                pred_edge = vutils.make_grid(pred_edges, normalize=False, scale_each=True)

                writer.add_image('Images/', img, i_iter)
                writer.add_image('Labels/', lab, i_iter)
                writer.add_image('Preds/', pred, i_iter)
                writer.add_image('Edges/', edge, i_iter)
                writer.add_image('PredEdges/', pred_edge, i_iter)

            print('iter = {} of {} completed, loss = {}'.format(
                i_iter, total_iters, loss.data.cpu().numpy()))

        # Snapshot the (DataParallel-wrapped) weights once per epoch.
        torch.save(
            model.state_dict(),
            osp.join(args.snapshot_dir, 'LIP_epoch_' + str(epoch) + '.pth'))

        #parsing_preds, scales, centers = valid(model, valloader, input_size, num_samples, len(gpus))
        #mIoU = compute_mean_ioU(parsing_preds, scales, centers, args.num_classes, args.data_dir, input_size)
        #print(mIoU)
        #writer.add_scalars('mIoU', mIoU, epoch)

    end = timeit.default_timer()
    # NOTE(review): `start` is presumably a module-level timestamp captured at
    # import time -- confirm it is defined in the surrounding file.
    print(end - start, 'seconds')
def main():
    """Create the model and start the training.

    Trains AGRNet on one of four face/human parsing datasets (Helen, LaPa,
    CelebAMask-HQ, LIP), optionally with multi-process distributed training
    driven by the ``WORLD_SIZE`` environment variable.  Only rank 0 writes
    TensorBoard logs and checkpoints; the best model (by mean F1) is kept as
    ``best.pth``.  Configuration comes from the module-level ``args``.
    """
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    h, w = map(int, args.input_size.split(','))  # crop size given as "H,W"
    input_size = [h, w]
    best_f1 = 0

    torch.cuda.set_device(args.local_rank)
    # Launchers such as torchrun export WORLD_SIZE; fall back to single-GPU
    # when it is absent or malformed.  (Fixed: was a bare `except:` which also
    # swallowed KeyboardInterrupt/SystemExit.)
    try:
        world_size = int(os.environ['WORLD_SIZE'])
        distributed = world_size > 1
    except (KeyError, ValueError):
        distributed = False
        world_size = 1
    if distributed:
        dist.init_process_group(backend=args.dist_backend, init_method='env://')
    rank = 0 if not distributed else dist.get_rank()
    # Only rank 0 owns a SummaryWriter; all writer calls below are guarded by
    # a rank check, so `None` is never dereferenced on other ranks.
    writer = SummaryWriter(osp.join(args.snapshot_dir,
                                    TIMESTAMP)) if rank == 0 else None

    # ImageNet mean/std normalisation.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # Dataset selection also fixes the number of parsing classes.
    if args.type == 'Helen':
        train_dataset = HelenDataSet('dataset/Helen_align_with_hair',
                                     args.dataset,
                                     crop_size=input_size,
                                     transform=transform)
        val_dataset = HelenDataSet('dataset/Helen_align_with_hair',
                                   'test',
                                   crop_size=input_size,
                                   transform=transform)
        args.num_classes = 11
    elif args.type == 'LaPa':
        train_dataset = LapaDataset('dataset/LaPa/origin',
                                    args.dataset,
                                    crop_size=input_size,
                                    transform=transform)
        val_dataset = LapaDataset('dataset/LaPa/origin',
                                  'test',
                                  crop_size=input_size,
                                  transform=transform)
        args.num_classes = 11
    elif args.type == 'Celeb':
        train_dataset = CelebAMaskHQDataSet('dataset/CelebAMask-HQ',
                                            args.dataset,
                                            crop_size=input_size,
                                            transform=transform)
        val_dataset = CelebAMaskHQDataSet('dataset/CelebAMask-HQ',
                                          'test',
                                          crop_size=input_size,
                                          transform=transform)
        args.num_classes = 19
    elif args.type == 'LIP':
        train_dataset = LIPDataSet('dataset/LIP',
                                   args.dataset,
                                   crop_size=input_size,
                                   transform=transform)
        val_dataset = LIPDataSet('dataset/LIP',
                                 'val',
                                 crop_size=input_size,
                                 transform=transform)
        args.num_classes = 20
    else:
        # Fail fast instead of hitting a NameError on `train_dataset` below.
        raise ValueError('Unknown dataset type: {}'.format(args.type))

    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    # shuffle must stay False when a DistributedSampler drives the order.
    trainloader = data.DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=2,
                                  pin_memory=True,
                                  drop_last=True,
                                  sampler=train_sampler)
    num_samples = len(val_dataset)
    valloader = data.DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                pin_memory=True,
                                drop_last=False)

    cudnn.enabled = True
    # cudnn related setting
    cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.enabled = True

    # Under DDP the model uses its default (sync-friendly) norm layers;
    # single-GPU uses InPlaceABN.
    if distributed:
        model = AGRNet(args.num_classes)
    else:
        model = AGRNet(args.num_classes, InPlaceABN)

    if args.restore_from is not None:
        # Resume from a full AGRNet checkpoint (strict loading).
        model.load_state_dict(
            torch.load(args.restore_from,
                       map_location='cuda:{}'.format(args.local_rank)), True)
    else:
        # Cold start: copy ImageNet ResNet-101 weights, skipping 'fc'.
        resnet_params = torch.load(
            os.path.join(args.snapshot_dir, 'resnet101-imagenet.pth'))
        new_params = model.state_dict().copy()
        for i in resnet_params:
            i_parts = i.split('.')
            # print(i_parts)
            if not i_parts[0] == 'fc':
                new_params['.'.join(i_parts[0:])] = resnet_params[i]
        model.load_state_dict(new_params)
    model.cuda()

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)
    else:
        model = SingleGPU(model)

    # CriterionCrossEntropyEdgeParsing_boundary_agrnet_loss for AGRNet,
    # CriterionCrossEntropyEdgeParsing_boundary_eagrnet_loss for EAGRNet.
    criterion = CriterionCrossEntropyEdgeParsing_boundary_agrnet_loss(
        loss_weight=[args.l1, args.l2, args.l3, args.l4],
        num_classes=args.num_classes)
    criterion.cuda()

    optimizer = optim.SGD(model.parameters(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    total_iters = args.epochs * len(trainloader)
    for epoch in range(args.start_epoch, args.epochs):
        model.train()
        if distributed:
            # Re-seed the sampler so each epoch sees a different shard order.
            train_sampler.set_epoch(epoch)
        for i_iter, batch in enumerate(trainloader):
            # Global iteration counter used for LR schedule and logging.
            i_iter += len(trainloader) * epoch
            lr = adjust_learning_rate(optimizer, i_iter, total_iters)

            images, labels, edges, _ = batch
            labels = labels.long().cuda(non_blocking=True)
            edges = edges.long().cuda(non_blocking=True)

            preds = model(images)
            loss = criterion(preds, [labels, edges])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Reduce the batch-weighted loss across ranks so the logged value
            # is a true global mean.
            with torch.no_grad():
                loss = loss.detach() * labels.shape[0]
                count = labels.new_tensor([labels.shape[0]], dtype=torch.long)
                if dist.is_initialized():
                    dist.all_reduce(count, dist.ReduceOp.SUM)
                    dist.all_reduce(loss, dist.ReduceOp.SUM)
                loss /= count.item()

            if not dist.is_initialized() or dist.get_rank() == 0:
                if i_iter % 50 == 0:
                    writer.add_scalar('learning_rate', lr, i_iter)
                    writer.add_scalar('loss', loss.data.cpu().numpy(), i_iter)

                if i_iter % 500 == 0:
                    # Periodic visual sanity check on rank 0.
                    images_inv = inv_preprocess(images, args.save_num_images)
                    labels_colors = decode_parsing(labels,
                                                   args.save_num_images,
                                                   args.num_classes,
                                                   is_pred=False)
                    edges_colors = decode_parsing(edges,
                                                  args.save_num_images,
                                                  2,
                                                  is_pred=False)

                    if isinstance(preds, list):
                        preds = preds[0]
                    preds_colors = decode_parsing(preds[0],
                                                  args.save_num_images,
                                                  args.num_classes,
                                                  is_pred=True)
                    pred_edges = decode_parsing(preds[1],
                                                args.save_num_images,
                                                2,
                                                is_pred=True)

                    img = vutils.make_grid(images_inv, normalize=False, scale_each=True)
                    lab = vutils.make_grid(labels_colors, normalize=False, scale_each=True)
                    pred = vutils.make_grid(preds_colors, normalize=False, scale_each=True)
                    edge = vutils.make_grid(edges_colors, normalize=False, scale_each=True)
                    pred_edge = vutils.make_grid(pred_edges, normalize=False, scale_each=True)

                    writer.add_image('Images/', img, i_iter)
                    writer.add_image('Labels/', lab, i_iter)
                    writer.add_image('Preds/', pred, i_iter)
                    writer.add_image('Edge/', edge, i_iter)
                    writer.add_image('Pred_edge/', pred_edge, i_iter)

                print('iter = {} of {} completed, loss = {}'.format(
                    i_iter, total_iters, loss.data.cpu().numpy()))

        # Per-epoch validation + checkpointing on rank 0 only.
        if not dist.is_initialized() or dist.get_rank() == 0:
            save_path = os.path.join(args.data_dir, TIMESTAMP)
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            parsing_preds, scales, centers = valid(
                model, valloader, input_size, num_samples,
                osp.join(args.snapshot_dir, save_path))
            mIoU, f1 = compute_mean_ioU(parsing_preds,
                                        scales,
                                        centers,
                                        args.num_classes,
                                        val_dataset,
                                        input_size,
                                        'test',
                                        True,
                                        type=args.type)
            if f1['mean'] > best_f1:
                torch.save(model.module.state_dict(),
                           osp.join(args.snapshot_dir, TIMESTAMP, 'best.pth'))
                best_f1 = f1['mean']
            print(mIoU)
            print(f1)
            writer.add_scalars('mIoU', mIoU, epoch)
            writer.add_scalars('f1', f1, epoch)
            if epoch % args.test_fre == 0:
                torch.save(
                    model.module.state_dict(),
                    osp.join(args.snapshot_dir, TIMESTAMP,
                             'epoch_' + str(epoch) + '.pth'))

    end = timeit.default_timer()
    # NOTE(review): `start` is presumably a module-level timestamp -- confirm.
    print(end - start, 'seconds')
def main():
    """Create the model and start the training.

    Trains ParsingNet on PASCAL-style data (parsing + pose heatmaps + PAFs)
    with DataParallel across the GPUs listed in ``args.gpu``.  Validation and
    checkpointing happen every 15 epochs.  Configuration comes from the
    module-level ``args``.
    """
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
    writer = SummaryWriter(args.snapshot_dir)
    gpus = [int(i) for i in args.gpu.split(',')]  # e.g. "0,1" -> [0, 1]
    if not args.gpu == 'None':
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    h, w = [int(i) for i in args.input_size.split(',')]  # "H,W" crop size
    input_size = [h, w]

    cudnn.enabled = True
    # cudnn related setting
    cudnn.benchmark = False
    # NOTE(review): why use non-deterministic convolutions here? (translated
    # from the original Chinese comment)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.enabled = True

    # Head sizes for the three prediction branches.
    NUM_CLASSES = 7  # parsing
    NUM_HEATMAP = 15  # pose
    NUM_PAFS = 28  # pafs

    # ImageNet mean/std normalisation.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # load dataset
    num_samples = 0
    trainloader = data.DataLoader(VOCSegmentation(args.data_dir,
                                                  args.dataset,
                                                  crop_size=input_size,
                                                  stride=args.stride,
                                                  transform=transform),
                                  batch_size=args.batch_size * len(gpus),
                                  shuffle=True,
                                  num_workers=2,
                                  pin_memory=True)
    valloader = None
    if args.print_val != 0:
        valdataset = VOCSegmentation(args.data_dir,
                                     'val',
                                     crop_size=input_size,
                                     transform=transform)
        num_samples = len(valdataset)
        valloader = data.DataLoader(
            valdataset,
            batch_size=8 * len(gpus),  # batchsize
            shuffle=False,
            pin_memory=True)

    parsingnet = ParsingNet(num_classes=NUM_CLASSES,
                            num_heatmaps=NUM_HEATMAP,
                            num_pafs=NUM_PAFS)

    criterion_parsing = Criterion()
    criterion_parsing = DataParallelCriterion(criterion_parsing)
    criterion_parsing.cuda()

    optimizer_parsing = optim.SGD(parsingnet.parameters(),
                                  lr=args.learning_rate,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay)
    optimizer_parsing.zero_grad()

    # Load pretrained parameters: either the ImageNet backbone (cold start)
    # or a full training checkpoint when resuming.
    print(args.train_continue)
    if not args.train_continue:
        checkpoint = torch.load(RESNET_IMAGENET)
        load_state(parsingnet, checkpoint)
    else:
        checkpoint = torch.load(args.restore_from_parsing)
        if 'current_epoch' in checkpoint:
            current_epoch = checkpoint['current_epoch']
            # NOTE(review): resumes AT the saved epoch, so that epoch is
            # re-run once -- confirm this is intended (vs. current_epoch + 1).
            args.start_epoch = current_epoch
        if 'state_dict' in checkpoint:
            checkpoint = checkpoint['state_dict']
        load_state(parsingnet, checkpoint)

    parsingnet = DataParallelModel(parsingnet).cuda()

    total_iters = args.epochs * len(trainloader)
    for epoch in range(args.start_epoch, args.epochs):
        parsingnet.train()
        for i_iter, batch in enumerate(trainloader):
            # Global iteration counter for LR schedule and logging.
            i_iter += len(trainloader) * epoch
            lr = adjust_parsing_lr(optimizer_parsing, i_iter, total_iters)

            images, labels, edges, heatmap, pafs, heatmap_mask, pafs_mask, _ = batch
            images = images.cuda()
            labels = labels.long().cuda(non_blocking=True)
            edges = edges.long().cuda(non_blocking=True)
            heatmap = heatmap.cuda()
            pafs = pafs.cuda()
            heatmap_mask = heatmap_mask.cuda()
            pafs_mask = pafs_mask.cuda()

            preds = parsingnet(images)

            loss_parsing = criterion_parsing(
                preds, [labels, edges, heatmap, pafs, heatmap_mask, pafs_mask],
                writer, i_iter, total_iters)
            optimizer_parsing.zero_grad()
            loss_parsing.backward()
            optimizer_parsing.step()

            if i_iter % 100 == 0:
                writer.add_scalar('parsing_lr', lr, i_iter)
                writer.add_scalar('loss_total', loss_parsing.item(), i_iter)

            if i_iter % 500 == 0:
                # Periodic visual sanity check in TensorBoard.
                if len(gpus) > 1:
                    # DataParallelModel returns one output list per GPU;
                    # visualise the first replica only.
                    preds = preds[0]
                images_inv = inv_preprocess(images, args.save_num_images)
                parsing_labels_c = decode_parsing(labels,
                                                  args.save_num_images,
                                                  is_pred=False)
                preds_colors = decode_parsing(preds[0][-1],
                                              args.save_num_images,
                                              is_pred=True)
                edges_colors = decode_parsing(edges,
                                              args.save_num_images,
                                              is_pred=False)
                pred_edges = decode_parsing(preds[1][-1],
                                            args.save_num_images,
                                            is_pred=True)

                img = vutils.make_grid(images_inv, normalize=False, scale_each=True)
                parsing_lab = vutils.make_grid(parsing_labels_c, normalize=False, scale_each=True)
                pred_v = vutils.make_grid(preds_colors, normalize=False, scale_each=True)
                edge = vutils.make_grid(edges_colors, normalize=False, scale_each=True)
                pred_edges = vutils.make_grid(pred_edges, normalize=False, scale_each=True)

                writer.add_image('Images/', img, i_iter)
                writer.add_image('Parsing_labels/', parsing_lab, i_iter)
                writer.add_image('Parsing_Preds/', pred_v, i_iter)
                writer.add_image('Edges/', edge, i_iter)
                writer.add_image('Edges_preds/', pred_edges, i_iter)

        # Validate and checkpoint every 15 epochs.
        if (epoch + 1) % 15 == 0:
            if args.print_val != 0:
                parsing_preds, scales, centers = valid(parsingnet, valloader,
                                                       input_size, num_samples,
                                                       gpus)
                mIoU = compute_mean_ioU(parsing_preds, scales, centers,
                                        NUM_CLASSES, args.data_dir, input_size)
                # Fixed: use a context manager so the log file is closed even
                # if the write raises (was open()/write()/close()).
                with open(os.path.join(args.snapshot_dir, "val_res.txt"),
                          "a+") as f:
                    f.write(str(epoch) + str(mIoU) + '\n')

            snapshot_name_parsing = osp.join(
                args.snapshot_dir, 'PASCAL_parsing_' + str(epoch) + '.pth')
            # Save enough state to resume: weights, optimizer, epoch counter.
            torch.save(
                {
                    'state_dict': parsingnet.state_dict(),
                    'optimizer': optimizer_parsing.state_dict(),
                    'current_epoch': epoch
                }, snapshot_name_parsing)

    end = timeit.default_timer()
    # NOTE(review): `start` is presumably a module-level timestamp -- confirm.
    print(end - start, 'seconds')
def main():
    """Create the model and start the training.

    Trains DML_CSR with SCHP-style self-correction: a shadow ``schp_model``
    accumulates a moving average of the online model after ``args.schp_start``
    epochs and supplies soft targets for later cycles.  Supports optional
    distributed training via the ``WORLD_SIZE`` environment variable; only
    rank 0 saves checkpoints and TensorBoard summaries.
    """
    cycle_n = 0  # completed self-correction cycles; 0 means no soft targets yet
    start_epoch = args.start_epoch
    # NOTE(review): the writer is created on every rank here (before the
    # distributed check); only rank 0 actually logs below -- confirm the
    # extra event files on other ranks are acceptable.
    writer = SummaryWriter(osp.join(args.snapshot_dir, TIMESTAMP))
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
    h, w = map(int, args.input_size.split(','))  # crop size given as "H,W"
    input_size = [h, w]
    best_f1 = 0

    torch.cuda.set_device(args.local_rank)
    # Launchers such as torchrun export WORLD_SIZE; fall back to single-GPU
    # when it is absent or malformed.  (Fixed: was a bare `except:` which also
    # swallowed KeyboardInterrupt/SystemExit.)
    try:
        world_size = int(os.environ['WORLD_SIZE'])
        distributed = world_size > 1
    except (KeyError, ValueError):
        distributed = False
        world_size = 1
    if distributed:
        dist.init_process_group(backend=args.dist_backend, init_method='env://')
    rank = 0 if not distributed else dist.get_rank()

    log_file = args.snapshot_dir + '/' + TIMESTAMP + 'output.log'
    logger = get_root_logger(log_file=log_file, log_level='INFO')
    logger.info(f'Distributed training: {distributed}')

    cudnn.enabled = True
    cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.enabled = True

    # Under DDP the models use their default norm layers; single-GPU uses
    # InPlaceABN.
    if distributed:
        model = dml_csr.DML_CSR(args.num_classes)
        schp_model = dml_csr.DML_CSR(args.num_classes)
    else:
        model = dml_csr.DML_CSR(args.num_classes, InPlaceABN)
        schp_model = dml_csr.DML_CSR(args.num_classes, InPlaceABN)

    if args.restore_from is not None:
        print('Resume training from {}'.format(args.restore_from))
        model.load_state_dict(torch.load(args.restore_from), True)
        # NOTE(review): recovers the epoch from a filename like
        # 'checkpoint_12.pth' -- fragile if the naming scheme changes.
        start_epoch = int(float(
            args.restore_from.split('.')[0].split('_')[-1])) + 1
    else:
        # Cold start: copy ImageNet ResNet weights, skipping 'fc'.
        resnet_params = torch.load(RESTORE_FROM)
        new_params = model.state_dict().copy()
        for i in resnet_params:
            i_parts = i.split('.')
            if not i_parts[0] == 'fc':
                new_params['.'.join(i_parts[0:])] = resnet_params[i]
        model.load_state_dict(new_params)
    model.cuda()

    # The SCHP shadow model resumes from the best checkpoint when available,
    # otherwise starts from the same ImageNet backbone.
    args.schp_restore = osp.join(args.snapshot_dir, TIMESTAMP, 'best.pth')
    if os.path.exists(args.schp_restore):
        print('Resume schp checkpoint from {}'.format(args.schp_restore))
        schp_model.load_state_dict(torch.load(args.schp_restore), True)
    else:
        schp_resnet_params = torch.load(RESTORE_FROM)
        schp_new_params = schp_model.state_dict().copy()
        for i in schp_resnet_params:
            i_parts = i.split('.')
            if not i_parts[0] == 'fc':
                schp_new_params['.'.join(i_parts[0:])] = schp_resnet_params[i]
        schp_model.load_state_dict(schp_new_params)
    schp_model.cuda()

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)
        schp_model = torch.nn.parallel.DistributedDataParallel(
            schp_model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)
    else:
        model = SingleGPU(model)
        schp_model = SingleGPU(schp_model)

    criterion = Criterion(loss_weight=[1, 1, 1, 4, 1],
                          lambda_1=args.lambda_s,
                          lambda_2=args.lambda_e,
                          lambda_3=args.lambda_c,
                          num_classes=args.num_classes)
    criterion.cuda()

    # ImageNet mean/std normalisation.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    train_dataset = FaceDataSet(args.data_dir,
                                args.train_dataset,
                                crop_size=input_size,
                                transform=transform)
    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None
    # shuffle must stay False when a DistributedSampler drives the order.
    trainloader = data.DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=2,
                                  pin_memory=True,
                                  drop_last=True,
                                  sampler=train_sampler)
    val_dataset = datasets[str(args.model_type)](args.data_dir,
                                                 args.valid_dataset,
                                                 crop_size=input_size,
                                                 transform=transform)
    num_samples = len(val_dataset)
    valloader = data.DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                pin_memory=True,
                                drop_last=False)

    # Optimizer Initialization
    optimizer = optim.SGD(model.parameters(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    # Warm restarts kick in once the self-correction phase starts.
    lr_scheduler = SGDRScheduler(optimizer,
                                 total_epoch=args.epochs,
                                 eta_min=args.learning_rate / 100,
                                 warmup_epoch=10,
                                 start_cyclical=args.schp_start,
                                 cyclical_base_lr=args.learning_rate / 2,
                                 cyclical_epoch=args.cycle_epochs)
    optimizer.zero_grad()

    total_iters = args.epochs * len(trainloader)
    start = timeit.default_timer()
    for epoch in range(start_epoch, args.epochs):
        model.train()
        if distributed:
            # Re-seed the sampler so each epoch sees a different shard order.
            train_sampler.set_epoch(epoch)
        for i_iter, batch in enumerate(trainloader):
            # Global iteration counter for LR schedule and logging.
            i_iter += len(trainloader) * epoch
            # Plain schedule before the SCHP phase, cyclical SGDR after.
            if epoch < args.schp_start:
                lr = adjust_learning_rate(optimizer, i_iter, total_iters)
            else:
                lr = lr_scheduler.get_lr()[0]

            images, labels, edges, semantic_edges, _ = batch
            labels = labels.long().cuda(non_blocking=True)
            edges = edges.long().cuda(non_blocking=True)
            semantic_edges = semantic_edges.long().cuda(non_blocking=True)

            preds = model(images)
            # After the first self-correction cycle the shadow model supplies
            # soft targets; before that the criterion receives None.
            if cycle_n >= 1:
                with torch.no_grad():
                    soft_preds, soft_edges, soft_semantic_edges = schp_model(
                        images)
            else:
                soft_preds = None
                soft_edges = None
                soft_semantic_edges = None

            loss = criterion(preds, [
                labels, edges, semantic_edges, soft_preds, soft_edges,
                soft_semantic_edges
            ], cycle_n)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            # Reduce the batch-weighted loss across ranks so the logged value
            # is a true global mean.
            with torch.no_grad():
                loss = loss.detach() * labels.shape[0]
                count = labels.new_tensor([labels.shape[0]], dtype=torch.long)
                if dist.is_initialized():
                    dist.all_reduce(count, dist.ReduceOp.SUM)
                    dist.all_reduce(loss, dist.ReduceOp.SUM)
                loss /= count.item()

            if not dist.is_initialized() or dist.get_rank() == 0:
                if i_iter % 50 == 0:
                    writer.add_scalar('learning_rate', lr, i_iter)
                    writer.add_scalar('loss', loss.data.cpu().numpy(), i_iter)

                if i_iter % 500 == 0:
                    # Periodic visual sanity check on rank 0.
                    images_inv = inv_preprocess(images, args.save_num_images)
                    labels_colors = decode_parsing(labels,
                                                   args.save_num_images,
                                                   args.num_classes,
                                                   is_pred=False)
                    edges_colors = decode_parsing(edges,
                                                  args.save_num_images,
                                                  2,
                                                  is_pred=False)
                    semantic_edges_colors = decode_parsing(
                        semantic_edges,
                        args.save_num_images,
                        args.num_classes,
                        is_pred=False)

                    if isinstance(preds, list):
                        preds = preds[0]
                    preds_colors = decode_parsing(preds[0],
                                                  args.save_num_images,
                                                  args.num_classes,
                                                  is_pred=True)
                    pred_edges = decode_parsing(preds[1],
                                                args.save_num_images,
                                                2,
                                                is_pred=True)
                    pred_semantic_edges_colors = decode_parsing(
                        preds[2],
                        args.save_num_images,
                        args.num_classes,
                        is_pred=True)

                    img = vutils.make_grid(images_inv, normalize=False, scale_each=True)
                    lab = vutils.make_grid(labels_colors, normalize=False, scale_each=True)
                    pred = vutils.make_grid(preds_colors, normalize=False, scale_each=True)
                    edge = vutils.make_grid(edges_colors, normalize=False, scale_each=True)
                    pred_edge = vutils.make_grid(pred_edges, normalize=False, scale_each=True)
                    # NOTE(review): this semantic-edge grid is built but never
                    # written to TensorBoard in the original -- kept as-is.
                    pred_semantic_edges = vutils.make_grid(
                        pred_semantic_edges_colors,
                        normalize=False,
                        scale_each=True)

                    writer.add_image('Images/', img, i_iter)
                    writer.add_image('Labels/', lab, i_iter)
                    writer.add_image('Preds/', pred, i_iter)
                    writer.add_image('Edge/', edge, i_iter)
                    writer.add_image('Pred_edge/', pred_edge, i_iter)

                cur_loss = loss.data.cpu().numpy()
                logger.info(
                    f'iter = {i_iter} of {total_iters} completed, loss = {cur_loss}, lr = {lr}'
                )

        # Periodic evaluation of the online model.
        if (epoch + 1) % (args.eval_epochs) == 0:
            parsing_preds, scales, centers = valid(model, valloader,
                                                   input_size, num_samples)
            mIoU, f1 = compute_mean_ioU(parsing_preds, scales, centers,
                                        args.num_classes, args.data_dir,
                                        input_size, args.valid_dataset, True)
            if not dist.is_initialized() or dist.get_rank() == 0:
                torch.save(
                    model.module.state_dict(),
                    osp.join(args.snapshot_dir, TIMESTAMP,
                             'checkpoint_{}.pth'.format(epoch + 1)))
                # Helen reports 'overall' F1; other datasets report 'Mean_F1'.
                if 'Helen' in args.data_dir:
                    if f1['overall'] > best_f1:
                        torch.save(
                            model.module.state_dict(),
                            osp.join(args.snapshot_dir, TIMESTAMP, 'best.pth'))
                        best_f1 = f1['overall']
                else:
                    if f1['Mean_F1'] > best_f1:
                        torch.save(
                            model.module.state_dict(),
                            osp.join(args.snapshot_dir, TIMESTAMP, 'best.pth'))
                        best_f1 = f1['Mean_F1']
                writer.add_scalars('mIoU', mIoU, epoch)
                writer.add_scalars('f1', f1, epoch)
                logger.info(
                    f'mIoU = {mIoU}, and f1 = {f1} of epoch = {epoch}, util now, best_f1 = {best_f1}'
                )

        # Self-correction cycle: fold the online weights into the shadow model
        # by moving average, re-estimate its BN statistics, then evaluate it.
        if (epoch + 1) >= args.schp_start and (
                epoch + 1 - args.schp_start) % args.cycle_epochs == 0:
            logger.info(f'Self-correction cycle number {cycle_n}')
            schp.moving_average(schp_model, model, 1.0 / (cycle_n + 1))
            cycle_n += 1
            schp.bn_re_estimate(trainloader, schp_model)
            parsing_preds, scales, centers = valid(schp_model, valloader,
                                                   input_size, num_samples)
            mIoU, f1 = compute_mean_ioU(parsing_preds, scales, centers,
                                        args.num_classes, args.data_dir,
                                        input_size, args.valid_dataset, True)
            if not dist.is_initialized() or dist.get_rank() == 0:
                torch.save(
                    schp_model.module.state_dict(),
                    osp.join(args.snapshot_dir, TIMESTAMP,
                             'schp_{}_checkpoint.pth'.format(cycle_n)))
                if 'Helen' in args.data_dir:
                    if f1['overall'] > best_f1:
                        torch.save(
                            schp_model.module.state_dict(),
                            osp.join(args.snapshot_dir, TIMESTAMP, 'best.pth'))
                        best_f1 = f1['overall']
                else:
                    if f1['Mean_F1'] > best_f1:
                        torch.save(
                            schp_model.module.state_dict(),
                            osp.join(args.snapshot_dir, TIMESTAMP, 'best.pth'))
                        best_f1 = f1['Mean_F1']
                writer.add_scalars('mIoU', mIoU, epoch)
                writer.add_scalars('f1', f1, epoch)
                logger.info(
                    f'mIoU = {mIoU}, and f1 = {f1} of epoch = {epoch}, util now, best_f1 = {best_f1}'
                )

        torch.cuda.empty_cache()
        end = timeit.default_timer()
        print('epoch = {} of {} completed using {} s'.format(
            epoch, args.epochs, (end - start) / (epoch - start_epoch + 1)))

    end = timeit.default_timer()
    print(end - start, 'seconds')
def valid(model, valloader, input_size, num_samples, gpus):
    """Run inference over the validation loader and collect parsing maps.

    Args:
        model: parsing network; called as ``model(image.cuda())`` and expected
            to return nested outputs (parsing branch first).
        valloader: DataLoader yielding (image, parsing label, 4 region labels,
            6 landmark-ish labels, edge label, meta) per batch.
        input_size: (H, W) the predictions are upsampled to before argmax.
        num_samples: total number of validation samples (sizes the buffers).
        gpus: GPU count as an int, or a sequence of GPU ids (both accepted;
            the original compared the raw value with ``> 1``, which raises
            TypeError in Python 3 when a list is passed, as the caller in this
            file does).

    Returns:
        (parsing_preds, scales, centers): uint8 label maps of shape
        (num_samples, H, W) plus the per-sample scale/center metadata needed
        by compute_mean_ioU.
    """
    model.eval()

    parsing_preds = np.zeros((num_samples, input_size[0], input_size[1]),
                             dtype=np.uint8)
    scales = np.zeros((num_samples, 2), dtype=np.float32)
    centers = np.zeros((num_samples, 2), dtype=np.int32)

    idx = 0
    interp = torch.nn.Upsample(size=(input_size[0], input_size[1]),
                               mode='bilinear',
                               align_corners=True)
    interp_1 = torch.nn.Upsample(size=(384, 384),
                                 mode='bilinear',
                                 align_corners=True)

    # Fix: accept either an int GPU count or a list of GPU ids.
    num_gpus = gpus if isinstance(gpus, int) else len(gpus)

    with torch.no_grad():
        for index, batch in enumerate(valloader):
            image, label_parsing, label_r0, label_r1, label_r2, label_r3, label_l0, label_l1, label_l2, label_l3, label_l4, label_l5, label_edge, meta = batch
            num_images = image.size(0)
            if index % 10 == 0:
                print('%d processd' % (index * num_images))

            # Keep the affine metadata so predictions can be mapped back to
            # the original image frame during evaluation.
            c = meta['center'].numpy()
            s = meta['scale'].numpy()
            scales[idx:idx + num_images, :] = s[:, :]
            centers[idx:idx + num_images, :] = c[:, :]

            outputs = model(image.cuda())
            if num_gpus > 1:
                # Multi-GPU: one output list per replica; take the final
                # parsing stage of each and argmax over channels.
                for output in outputs:
                    parsing = output[0][-1]
                    nums = len(parsing)
                    parsing = interp(parsing).data.cpu().numpy()
                    parsing = parsing.transpose(0, 2, 3, 1)  # NCHW NHWC
                    parsing_preds[idx:idx + nums, :, :] = np.asarray(
                        np.argmax(parsing, axis=3), dtype=np.uint8)
                    idx += nums
            else:
                # Single-GPU debug path: dump colourised ground truth and
                # every auxiliary prediction as PNGs under ./pics/.
                # NOTE(review): scipy.misc.toimage was removed in SciPy 1.2;
                # migrating to PIL.Image.fromarray is recommended.
                #gt = torch.from_numpy(parsing_anno)
                gt_parsing_colors = decode_parsing(label_parsing, 2, 20, False)
                gt_r0_colors = decode_parsing(label_r0, 2, 20, False)
                gt_r1_colors = decode_parsing(label_r1, 2, 20, False)
                gt_r2_colors = decode_parsing(label_r2, 2, 20, False)
                gt_r3_colors = decode_parsing(label_r3, 2, 20, False)
                #np.set_printoptions(threshold=np.inf)
                #print(label_l0.numpy())
                gt_l0_colors = decode_parsing(label_l0, 2, 20, False)
                gt_l1_colors = decode_parsing(label_l1, 2, 20, False)
                gt_l2_colors = decode_parsing(label_l2, 2, 20, False)
                gt_l3_colors = decode_parsing(label_l3, 2, 20, False)
                gt_l4_colors = decode_parsing(label_l4, 2, 20, False)
                gt_l5_colors = decode_parsing(label_l5, 2, 20, False)

                for i in range(2):
                    scipy.misc.toimage(gt_parsing_colors[i]).save("./pics/{}_{}_gt.png".format(index, i))
                    scipy.misc.toimage(gt_r0_colors[i]).save("./pics/{}_{}_gt_r0.png".format(index, i))
                    scipy.misc.toimage(gt_r1_colors[i]).save("./pics/{}_{}_gt_r1.png".format(index, i))
                    scipy.misc.toimage(gt_r2_colors[i]).save("./pics/{}_{}_gt_r2.png".format(index, i))
                    scipy.misc.toimage(gt_r3_colors[i]).save("./pics/{}_{}_gt_r3.png".format(index, i))
                    scipy.misc.toimage(gt_l0_colors[i]).save("./pics/{}_{}_gt_l0.png".format(index, i))
                    scipy.misc.toimage(gt_l1_colors[i]).save("./pics/{}_{}_gt_l1.png".format(index, i))
                    scipy.misc.toimage(gt_l2_colors[i]).save("./pics/{}_{}_gt_l2.png".format(index, i))
                    scipy.misc.toimage(gt_l3_colors[i]).save("./pics/{}_{}_gt_l3.png".format(index, i))
                    scipy.misc.toimage(gt_l4_colors[i]).save("./pics/{}_{}_gt_l4.png".format(index, i))
                    scipy.misc.toimage(gt_l5_colors[i]).save("./pics/{}_{}_gt_l5.png".format(index, i))

                # Main parsing prediction at 384x384, with ignored pixels
                # (label 255) zeroed before colourisation.
                parsing = outputs[0][0]
                tmp = interp_1(parsing)
                tmp = torch.argmax(tmp, dim=1, keepdim=False)
                ignore_index = label_parsing == 255
                tmp[ignore_index] = 0
                preds_colors = decode_parsing(tmp, 2, 20, False)

                pred_r0 = outputs[1][0]
                pred_r0 = interp_1(pred_r0)
                pred_r0_colors = decode_parsing(pred_r0, 2, 20, True)
                pred_r1 = outputs[1][1]
                pred_r1 = interp_1(pred_r1)
                pred_r1_colors = decode_parsing(pred_r1, 2, 20, True)
                pred_r2 = outputs[1][2]
                pred_r2 = interp_1(pred_r2)
                pred_r2_colors = decode_parsing(pred_r2, 2, 20, True)
                pred_r3 = outputs[1][3]
                pred_r3 = interp_1(pred_r3)
                pred_r3_colors = decode_parsing(pred_r3, 2, 20, True)

                pred_l0 = outputs[2][0]
                pred_l0 = interp_1(pred_l0)
                pred_l0_colors = decode_parsing(pred_l0, 2, 20, True)
                pred_l1 = outputs[2][1]
                pred_l1 = interp_1(pred_l1)
                pred_l1_colors = decode_parsing(pred_l1, 2, 20, True)
                pred_l2 = outputs[2][2]
                pred_l2 = interp_1(pred_l2)
                pred_l2_colors = decode_parsing(pred_l2, 2, 20, True)
                pred_l3 = outputs[2][3]
                pred_l3 = interp_1(pred_l3)
                pred_l3_colors = decode_parsing(pred_l3, 2, 20, True)
                pred_l4 = outputs[2][4]
                pred_l4 = interp_1(pred_l4)
                pred_l4_colors = decode_parsing(pred_l4, 2, 20, True)
                pred_l5 = outputs[2][5]
                pred_l5 = interp_1(pred_l5)
                pred_l5_colors = decode_parsing(pred_l5, 2, 20, True)

                for i in range(2):
                    scipy.misc.toimage(preds_colors[i]).save("./pics/{}_{}_pred.png".format(index, i))
                    scipy.misc.toimage(pred_r0_colors[i]).save("./pics/{}_{}_pred_r0.png".format(index, i))
                    scipy.misc.toimage(pred_r1_colors[i]).save("./pics/{}_{}_pred_r1.png".format(index, i))
                    scipy.misc.toimage(pred_r2_colors[i]).save("./pics/{}_{}_pred_r2.png".format(index, i))
                    scipy.misc.toimage(pred_r3_colors[i]).save("./pics/{}_{}_pred_r3.png".format(index, i))
                    scipy.misc.toimage(pred_l0_colors[i]).save("./pics/{}_{}_pred_l0.png".format(index, i))
                    scipy.misc.toimage(pred_l1_colors[i]).save("./pics/{}_{}_pred_l1.png".format(index, i))
                    scipy.misc.toimage(pred_l2_colors[i]).save("./pics/{}_{}_pred_l2.png".format(index, i))
                    scipy.misc.toimage(pred_l3_colors[i]).save("./pics/{}_{}_pred_l3.png".format(index, i))
                    scipy.misc.toimage(pred_l4_colors[i]).save("./pics/{}_{}_pred_l4.png".format(index, i))
                    scipy.misc.toimage(pred_l5_colors[i]).save("./pics/{}_{}_pred_l5.png".format(index, i))

                # `parsing` still refers to outputs[0][0] (the raw logits).
                parsing = interp(parsing).data.cpu().numpy()
                parsing = parsing.transpose(0, 2, 3, 1)  # NCHW NHWC
                parsing_preds[idx:idx + num_images, :, :] = np.asarray(
                    np.argmax(parsing, axis=3), dtype=np.uint8)
                idx += num_images

    parsing_preds = parsing_preds[:num_samples, :, :]
    return parsing_preds, scales, centers
def _log_val_sample(parsing_preds, input_size, writer, epoch):
    """Write one randomly-chosen val image's colorized GT and prediction to tensorboard.

    Args:
        parsing_preds: per-image prediction maps (uint8 ndarrays) returned by `valid`.
        input_size: [h, w] the GT image is resized to before colorizing.
        writer: SummaryWriter used for the `Preds_val/` / `Gt_val/` images.
        epoch: global step recorded with the images.
    """
    sample = random.randint(0, len(parsing_preds) - 1)

    # Resize and convert the GT segmentation to a tensor.
    loader = transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
    ])

    # Look up the sampled image id in the val split listing.
    list_path = os.path.join(args.data_dir, 'val' + '_id.txt')
    with open(list_path) as f:  # `with` so the listing file is always closed
        val_id = [i_id.strip() for i_id in f]
    gt_path = os.path.join(args.data_dir, 'val' + '_segmentations',
                           val_id[sample] + '.png')
    gt = loader(Image.open(gt_path))
    # ToTensor scales to [0, 1]; restore the integer label ids.
    gt = (gt * 255).int()

    # Prediction: ndarray -> PIL -> tensor, then restore label ids the same way.
    display_preds = Image.fromarray(parsing_preds[sample])
    tensor_display_preds = transforms.ToTensor()(display_preds)
    tensor_display_preds = (tensor_display_preds * 255).int()

    # Colorize both maps (is_pred=False: inputs are already label maps).
    val_preds_colors = decode_parsing(tensor_display_preds, num_images=1,
                                      num_classes=args.num_classes, is_pred=False)
    gt_color = decode_parsing(gt, num_images=1,
                              num_classes=args.num_classes, is_pred=False)

    pred_val = vutils.make_grid(val_preds_colors, normalize=False, scale_each=True)
    gt_val = vutils.make_grid(gt_color, normalize=False, scale_each=True)
    writer.add_image('Preds_val/', pred_val, epoch)
    writer.add_image('Gt_val/', gt_val, epoch)


def _evaluate(model, writer, epoch, valloader, numVal_samples,
              testloader, numTest_samples, input_size, gpus):
    """Compute and log mIoU on the 'val' and 'trainTest' splits at `epoch`.

    Also pushes one sample val GT/prediction pair to tensorboard via
    `_log_val_sample`.
    """
    # mIoU for val set
    parsing_preds, scales, centers = valid(model, valloader, input_size,
                                           numVal_samples, len(gpus))
    _log_val_sample(parsing_preds, input_size, writer, epoch)
    mIoUval = compute_mean_ioU(parsing_preds, scales, centers, args.num_classes,
                               args.data_dir, input_size, 'val')
    print('For val set', mIoUval)
    writer.add_scalars('mIoUval', mIoUval, epoch)

    # mIoU for trainTest set
    parsing_preds, scales, centers = valid(model, testloader, input_size,
                                           numTest_samples, len(gpus))
    mIoUtest = compute_mean_ioU(parsing_preds, scales, centers, args.num_classes,
                                args.data_dir, input_size, 'trainTest')
    print('For trainTest set', mIoUtest)
    writer.add_scalars('mIoUtest', mIoUtest, epoch)


def main():
    """Create the model and start the training."""
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
    writer = SummaryWriter(args.snapshot_dir)
    gpus = [int(i) for i in args.gpu.split(',')]
    if not args.gpu == 'None':
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    h, w = map(int, args.input_size.split(','))
    input_size = [h, w]

    # cudnn related setting
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.enabled = True

    deeplab = Res_Deeplab(num_classes=args.num_classes)
    print(type(deeplab))

    # Initialize from the checkpoint in args.restore_from (e.g. resnet101
    # pretrained on imagenet). Only parameters whose name exists in this model
    # AND whose shape matches are copied; the classifier head ('fc') is skipped.
    saved_state_dict = torch.load(args.restore_from)
    new_params = deeplab.state_dict().copy()
    for key in saved_state_dict:
        key_parts = key.split('.')
        # Checkpoints saved from DataParallel carry a leading 'module.' prefix;
        # strip it so names line up with this (un-wrapped) model.
        if key_parts[0] == 'module' and not key_parts[1] == 'fc':
            target = '.'.join(key_parts[1:])
        elif not key_parts[0] == 'fc':
            target = '.'.join(key_parts)
        else:
            continue
        # Membership check guards against keys the model doesn't have
        # (the original indexed new_params directly and could KeyError).
        if target in new_params and new_params[target].size() == saved_state_dict[key].size():
            new_params[target] = saved_state_dict[key]
    deeplab.load_state_dict(new_params)

    model = DataParallelModel(deeplab)
    model.cuda()

    criterion = CriterionAll()
    criterion = DataParallelCriterion(criterion)
    criterion.cuda()

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    trainloader = data.DataLoader(
        cartoonDataSet(args.data_dir, args.dataset, crop_size=input_size,
                       transform=transform),
        batch_size=args.batch_size * len(gpus),
        shuffle=True, num_workers=8, pin_memory=True)

    # mIoU for val set
    val_dataset = cartoonDataSet(args.data_dir, 'val', crop_size=input_size,
                                 transform=transform)
    numVal_samples = len(val_dataset)
    valloader = data.DataLoader(val_dataset,
                                batch_size=args.batch_size * len(gpus),
                                shuffle=False, pin_memory=True)

    # mIoU for trainTest set
    trainTest_dataset = cartoonDataSet(args.data_dir, 'trainTest',
                                       crop_size=input_size, transform=transform)
    numTest_samples = len(trainTest_dataset)
    testloader = data.DataLoader(trainTest_dataset,
                                 batch_size=args.batch_size * len(gpus),
                                 shuffle=False, pin_memory=True)

    optimizer = optim.SGD(
        model.parameters(),
        lr=args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )
    optimizer.zero_grad()

    total_iters = args.epochs * len(trainloader)
    for epoch in range(args.start_epoch, args.epochs):
        model.train()
        for i_iter, batch in enumerate(trainloader):
            # Convert the per-epoch index into a global iteration count.
            i_iter += len(trainloader) * epoch
            lr = adjust_learning_rate(optimizer, i_iter, total_iters)

            images, labels, _, _ = batch
            labels = labels.long().cuda(non_blocking=True)

            preds = model(images)
            loss = criterion(preds, [labels])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i_iter % 100 == 0:
                writer.add_scalar('learning_rate', lr, i_iter)
                writer.add_scalar('loss', loss.data.cpu().numpy(), i_iter)

            if i_iter % 500 == 0:
                images_inv = inv_preprocess(images, args.save_num_images)
                labels_colors = decode_parsing(labels, args.save_num_images,
                                               args.num_classes, is_pred=False)

                # With multiple GPUs, preds is a list of per-device outputs;
                # extract the tensor structure from the list.
                if len(gpus) > 1:
                    preds = preds[0]

                # preds[0][-1]: the model returns [[seg1, seg2], ...] — take
                # the final segmentation head.
                preds_colors = decode_parsing(preds[0][-1], args.save_num_images,
                                              args.num_classes, is_pred=True)

                img = vutils.make_grid(images_inv, normalize=False, scale_each=True)
                lab = vutils.make_grid(labels_colors, normalize=False, scale_each=True)
                pred = vutils.make_grid(preds_colors, normalize=False, scale_each=True)
                writer.add_image('Images/', img, i_iter)
                writer.add_image('Labels/', lab, i_iter)
                writer.add_image('Preds/', pred, i_iter)

            print('iter = {} of {} completed, loss = {}'.format(
                i_iter, total_iters, loss.data.cpu().numpy()))

        print('end epoch:', epoch)

        if epoch % 99 == 0:
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir,
                                'DFPnet_epoch_' + str(epoch) + '.pth'))

        # Validate every 5 epochs before epoch 500, every 20 epochs afterwards.
        # (The original duplicated ~40 identical lines across the two branches.)
        if (epoch % 5 == 0 and epoch < 500) or (epoch % 20 == 0 and epoch >= 500):
            _evaluate(model, writer, epoch, valloader, numVal_samples,
                      testloader, numTest_samples, input_size, gpus)

    # `start` is set at module import time (not visible in this chunk).
    end = timeit.default_timer()
    print(end - start, 'seconds')