def main(opts):
    """Distributed training entry point.

    Initializes the NCCL process group, builds train/val dataloaders with
    per-rank DistributedSamplers, optionally restores a checkpoint, trains
    for ``opts.epochs`` epochs (validating every ``opts.val_interval``),
    then evaluates on the test split over all seen classes.

    Args:
        opts: parsed command-line namespace (dataset, logging, optimization
            and distributed settings). ``opts.max_iter`` is written here for
            the scheduler.

    Returns:
        float: Mean IoU obtained on the test set.
    """
    # ===== Setup distributed =====
    distributed.init_process_group(backend='nccl', init_method='env://')
    # Fix: the original computed device_id and then re-branched on
    # opts.device to call set_device with the very same value.
    device_id = opts.device if opts.device is not None else opts.local_rank
    device = torch.device(device_id)
    rank, world_size = distributed.get_rank(), distributed.get_world_size()
    torch.cuda.set_device(device_id)

    # ===== Initialize logging =====
    logdir_full = f"{opts.logdir}/{opts.dataset}/{opts.name}/"
    # Only rank 0 may write visual summaries; other ranks get a quiet logger.
    logger = Logger(logdir_full, rank=rank, debug=opts.debug,
                    summary=opts.visualize if rank == 0 else False)
    logger.print(f"Device: {device}")

    checkpoint_path = f"checkpoints/{opts.dataset}/{opts.name}.pth"
    os.makedirs(f"checkpoints/{opts.dataset}", exist_ok=True)

    # ===== Setup random seed to reproducibility =====
    torch.manual_seed(opts.random_seed)
    torch.cuda.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # ===== Set up dataset =====
    train_dst, val_dst = get_dataset(opts, train=True)
    train_loader = data.DataLoader(train_dst,
                                   batch_size=opts.batch_size,
                                   sampler=DistributedSampler(train_dst,
                                                              num_replicas=world_size,
                                                              rank=rank),
                                   num_workers=opts.num_workers,
                                   drop_last=True,
                                   pin_memory=True)
    val_loader = data.DataLoader(val_dst,
                                 batch_size=opts.batch_size,
                                 sampler=DistributedSampler(val_dst,
                                                            num_replicas=world_size,
                                                            rank=rank),
                                 num_workers=opts.num_workers)
    logger.info(f"Dataset: {opts.dataset}, Train set: {len(train_dst)}, "
                f"Val set: {len(val_dst)}, n_classes {opts.num_classes}")
    logger.info(f"Total batch size is {opts.batch_size * world_size}")
    # This is necessary for computing the scheduler decay.
    # Fix: the original had a duplicated assignment
    # ("opts.max_iter = opts.max_iter = ...").
    opts.max_iter = opts.epochs * len(train_loader)

    # ===== Set up model and ckpt =====
    model = Trainer(device, logger, opts)
    model.distribute()
    cur_epoch = 0
    if opts.continue_ckpt:
        opts.ckpt = checkpoint_path
    if opts.ckpt is not None:
        assert os.path.isfile(
            opts.ckpt), "Error, ckpt not found. Check the correct directory"
        checkpoint = torch.load(opts.ckpt, map_location="cpu")
        cur_epoch = checkpoint["epoch"] + 1
        model.load_state_dict(checkpoint["model_state"])
        logger.info("[!] Model restored from %s" % opts.ckpt)
        del checkpoint  # release the CPU copy of the weights
    else:
        logger.info("[!] Train from scratch")

    # ===== Train procedure =====
    # print opts before starting training to log all parameters
    logger.add_table("Opts", vars(opts))

    label2color = utils.Label2Color(cmap=utils.color_map(
        opts.dataset))  # convert labels to images
    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225
                                    ])  # de-normalization for original images

    train_metrics = StreamSegMetrics(opts.num_classes)
    val_metrics = StreamSegMetrics(opts.num_classes)
    results = {}

    # check if random is equal here.
    logger.print(torch.randint(0, 100, (1, 1)))

    while cur_epoch < opts.epochs and not opts.test:
        # ===== Train =====
        start = time.time()
        epoch_loss = model.train(cur_epoch=cur_epoch,
                                 train_loader=train_loader,
                                 metrics=train_metrics,
                                 print_int=opts.print_interval)
        train_score = train_metrics.get_results()
        end = time.time()
        len_ep = int(end - start)
        logger.info(
            f"End of Epoch {cur_epoch}/{opts.epochs}, Average Loss={epoch_loss[0] + epoch_loss[1]:.4f}, "
            f"Class Loss={epoch_loss[0]:.4f}, Reg Loss={epoch_loss[1]}\n"
            f"Train_Acc={train_score['Overall Acc']:.4f}, Train_Iou={train_score['Mean IoU']:.4f} "
            f"\n -- time: {len_ep // 60}:{len_ep % 60} -- ")
        logger.info(
            f"I will finish in {len_ep * (opts.epochs - cur_epoch) // 60} minutes"
        )
        logger.add_scalar("E-Loss", epoch_loss[0] + epoch_loss[1], cur_epoch)

        # ===== Validation =====
        if (cur_epoch + 1) % opts.val_interval == 0:
            logger.info("validate on val set...")
            val_loss, _ = model.validate(loader=val_loader,
                                         metrics=val_metrics,
                                         ret_samples_ids=None)
            val_score = val_metrics.get_results()
            logger.print("Done validation")
            logger.info(
                f"End of Validation {cur_epoch}/{opts.epochs}, Validation Loss={val_loss}"
            )
            log_val(logger, val_metrics, val_score, val_loss, cur_epoch)
            # keep the metric to print them at the end of training
            results["V-IoU"] = val_score['Class IoU']
            results["V-Acc"] = val_score['Class Acc']

        # ===== Save Model =====
        # Only rank 0 checkpoints; debug runs never write checkpoints.
        if rank == 0 and not opts.debug:
            save_ckpt(checkpoint_path, model, cur_epoch)
            logger.info("[!] Checkpoint saved.")

        cur_epoch += 1
        torch.distributed.barrier()

    # ==== TESTING =====
    logger.info("*** Test the model on all seen classes...")
    # make data loader
    test_dst = get_dataset(opts, train=False)
    test_loader = data.DataLoader(test_dst,
                                  batch_size=opts.batch_size_test,
                                  sampler=DistributedSampler(test_dst,
                                                             num_replicas=world_size,
                                                             rank=rank),
                                  num_workers=opts.num_workers)

    if rank == 0 and opts.sample_num > 0:
        sample_ids = np.random.choice(len(test_loader), opts.sample_num,
                                      replace=False)  # sample idxs for visual.
        logger.info(f"The samples id are {sample_ids}")
    else:
        sample_ids = None

    val_loss, ret_samples = model.validate(loader=test_loader,
                                           metrics=val_metrics,
                                           ret_samples_ids=sample_ids)
    val_score = val_metrics.get_results()
    conf_matrixes = val_metrics.get_conf_matrixes()
    logger.print("Done test on all")
    logger.info(f"*** End of Test on all, Total Loss={val_loss}")
    logger.info(val_metrics.to_str(val_score))

    log_samples(logger, ret_samples, denorm, label2color, 0)
    logger.add_figure("Test_Confusion_Matrix_Recall",
                      conf_matrixes['Confusion Matrix'])
    logger.add_figure("Test_Confusion_Matrix_Precision",
                      conf_matrixes["Confusion Matrix Pred"])
    results["T-IoU"] = val_score['Class IoU']
    results["T-Acc"] = val_score['Class Acc']
    results["T-Prec"] = val_score['Class Prec']
    logger.add_results(results)
    logger.add_scalar("T_Overall_Acc", val_score['Overall Acc'])
    logger.add_scalar("T_MeanIoU", val_score['Mean IoU'])
    logger.add_scalar("T_MeanAcc", val_score['Mean Acc'])
    ret = val_score['Mean IoU']

    logger.close()
    return ret
# metrics preds = outputs.detach().max(dim=1)[1].cpu().numpy() targets = labels.cpu().numpy() metrics.update(targets, preds) end = time.time() if step%10==0: print('Epoch: ',str(epoch),' Iter: ',step,'Loss: ',loss.item(),) print('iter time: ',end-start) # update training_loss, training_accuracy and training_iou train_loss = train_loss/float(len(train_loader)) train_loss_list.append(train_loss) results = metrics.get_results() train_iou = results["Mean IoU"] train_iou_list.append(train_iou) writer.add_scalar("loss/train", train_loss, epoch) writer.add_scalar("iou/train", train_iou, epoch) if epoch%5==0: metrics.reset() model.eval() val_loss = 0.0 for step, (images, labels) in enumerate(val_loader): with torch.no_grad():
def main():
    """Evaluate a DeepLabv3 checkpoint on the validation set.

    Parses CLI options, restores the model from ``opts.ckpt`` if present,
    computes segmentation metrics over the validation loader, and — when
    ``opts.save_path`` is set — saves per-image inputs, targets, predictions
    and overlays plus a ``score.txt`` with the final metrics.
    """
    opts = get_argparser().parse_args()
    opts = modify_command_options(opts)

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Set up random seed
    torch.manual_seed(opts.random_seed)
    torch.cuda.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Set up dataloader; full-size (non-cropped) validation must run batch=1
    _, val_dst = get_dataset(opts)
    val_loader = data.DataLoader(val_dst,
                                 batch_size=opts.batch_size if opts.crop_val else 1,
                                 shuffle=False,
                                 num_workers=opts.num_workers)
    print("Dataset: %s, Val set: %d" % (opts.dataset, len(val_dst)))

    # Set up model
    print("Backbone: %s" % opts.backbone)
    model = DeepLabv3(num_classes=opts.num_classes,
                      backbone=opts.backbone,
                      pretrained=True,
                      momentum=opts.bn_mom,
                      output_stride=opts.output_stride,
                      use_separable_conv=opts.use_separable_conv)
    if opts.use_gn:  # fix: was "== True"
        print("[!] Replace BatchNorm with GroupNorm!")
        model = utils.convert_bn2gn(model)

    if torch.cuda.device_count() > 1:  # Parallel
        print("%d GPU parallel" % (torch.cuda.device_count()))
        model = torch.nn.DataParallel(model)
        model_ref = model.module  # for ckpt
    else:
        model_ref = model
    model = model.to(device)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    if opts.save_path is not None:
        utils.mkdir(opts.save_path)

    # Restore
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        # fix: map_location keeps the load working on CPU-only machines
        checkpoint = torch.load(opts.ckpt, map_location='cpu')
        model_ref.load_state_dict(checkpoint["model_state"])
        print("Model restored from %s" % opts.ckpt)
    else:
        print("[!] Retrain")

    label2color = utils.Label2Color(cmap=utils.color_map(opts.dataset))  # convert labels to images
    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])  # denormalization for ori images

    model.eval()
    metrics.reset()
    idx = 0  # global counter for saved image files

    if opts.save_path is not None:
        # headless backend; imported lazily only when saving figures
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt

    with torch.no_grad():
        # fix: the batch index was previously shadowed by the inner
        # per-image loop variable (both were named "i")
        for _, (images, labels) in tqdm(enumerate(val_loader)):
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            outputs = model(images)
            preds = outputs.detach().max(dim=1)[1].cpu().numpy()
            targets = labels.cpu().numpy()
            metrics.update(targets, preds)

            if opts.save_path is not None:
                for j in range(len(images)):
                    image = images[j].detach().cpu().numpy()
                    target = targets[j]
                    pred = preds[j]

                    image = (denorm(image) * 255).transpose(1, 2, 0).astype(np.uint8)
                    target = label2color(target).astype(np.uint8)
                    pred = label2color(pred).astype(np.uint8)

                    Image.fromarray(image).save(os.path.join(opts.save_path, '%d_image.png' % idx))
                    Image.fromarray(target).save(os.path.join(opts.save_path, '%d_target.png' % idx))
                    Image.fromarray(pred).save(os.path.join(opts.save_path, '%d_pred.png' % idx))

                    # overlay: prediction blended over the input image
                    plt.figure()
                    plt.imshow(image)
                    plt.axis('off')
                    plt.imshow(pred, alpha=0.7)
                    ax = plt.gca()
                    ax.xaxis.set_major_locator(matplotlib.ticker.NullLocator())
                    ax.yaxis.set_major_locator(matplotlib.ticker.NullLocator())
                    plt.savefig(os.path.join(opts.save_path, '%d_overlay.png' % idx),
                                bbox_inches='tight', pad_inches=0)
                    plt.close()
                    idx += 1

    score = metrics.get_results()
    print(metrics.to_str(score))
    if opts.save_path is not None:
        with open(os.path.join(opts.save_path, 'score.txt'), mode='w') as f:
            f.write(metrics.to_str(score))