def __init__(self, data_loader, opts):
    """Build the trainer: model, metrics, optimizer, LR scheduler and loss.

    Args:
        data_loader: two-element sequence ``(train_loader, val_loader)``.
        opts: parsed options; fields used here are ``model``, ``num_classes``,
            ``output_stride``, ``separable_conv``, ``lr``, ``weight_decay``,
            ``lr_policy``, ``total_itrs``/``step_size`` and ``loss_type``.
    """
    #super(Trainer, self).__init__(data_loader, opts)
    #self.opts = opts
    self.train_loader = data_loader[0]
    self.val_loader = data_loader[1]
    # Set up model: map the option string to its network constructor.
    model_map = {
        'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
    }
    self.model = model_map[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)
    if opts.separable_conv and 'plus' in opts.model:
        # Separable convs are only applied to DeepLabV3+ ('plus') heads.
        network.convert_to_separable_conv(self.model.classifier)

    def set_bn_momentum(model, momentum=0.1):
        # Override the momentum of every BatchNorm2d layer in `model`.
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.momentum = momentum

    set_bn_momentum(self.model.backbone, momentum=0.01)
    # NOTE(review): in PyTorch, BN `momentum` weights the *new* batch
    # statistics (running = (1-m)*running + m*batch), so 0.01 means slowly
    # updating running stats. This is the opposite of the TF convention
    # (where 0.99 would be the analogous value) — confirm 0.01 is intended.

    # Set up metrics
    self.metrics = StreamSegMetrics(opts.num_classes)

    # Set up optimizer: the pretrained backbone is fine-tuned at 1/10 of
    # the classifier-head learning rate.
    self.optimizer = torch.optim.SGD(params=[
        {'params': self.model.backbone.parameters(), 'lr': 0.1*opts.lr},
        {'params': self.model.classifier.parameters(), 'lr': opts.lr},
    ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    if opts.lr_policy=='poly':
        self.scheduler = utils.PolyLR(self.optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy=='step':
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=opts.step_size, gamma=0.1)

    # Set up criterion; 255 is the ignored (void) label in both branches.
    if opts.loss_type == 'focal_loss':
        self.criterion = utils.FocalLoss(ignore_index=255, size_average=True)
    elif opts.loss_type == 'cross_entropy':
        self.criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')

    # Progress trackers updated by the training loop elsewhere.
    self.best_mean_iu = 0
    self.iteration = 0
def __init__(self, args):
    """Load a trained segmentation model for inference.

    Builds the network named by ``args.model``, restores weights from
    ``args.ckpt``, wraps it in ``nn.DataParallel``, moves it to the best
    available device, and prepares a per-class color palette plus the
    output directory.

    Args:
        args: an argparse ``Namespace``; converted to a dict via ``vars``.
    """
    opts = vars(args)
    print(opts)
    model_map = {
        'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
    }
    self.opts = opts
    self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    self.model = model_map[opts['model']](num_classes=opts['num_classes'], output_stride=opts['output_stride'])
    # NOTE(review): comparing against the *string* 'True' — this only works
    # if the CLI passes separable_conv as a string, not a bool; verify.
    if opts['separable_conv'] == 'True' and 'plus' in opts['model']:
        # BUG FIX: was `network.convert_to_separable_conv(model.classifier)`;
        # `model` is undefined in this scope (NameError) — the constructed
        # network is `self.model`.
        network.convert_to_separable_conv(self.model.classifier)

    def set_bn_momentum(model, momentum=0.1):
        # Override the momentum of every BatchNorm2d layer in `model`.
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.momentum = momentum

    set_bn_momentum(self.model.backbone, momentum=0.01)
    # Restore weights on CPU first, then move to the target device.
    checkpoint = torch.load(opts['ckpt'], map_location=torch.device('cpu'))
    self.model.load_state_dict(checkpoint['model_state'])
    self.model = nn.DataParallel(self.model)
    self.model.to(self.device)
    self.model.eval()
    if not os.path.exists(opts['output']):
        os.makedirs(opts['output'])
    # create a color pallette, selecting a color for each class
    self.palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
    self.colors = torch.as_tensor([i for i in range(opts['num_classes'])])[:, None] * self.palette
    self.colors = (self.colors % 255).numpy().astype("uint8")
def __init__(self, opts):
    """Load a trained segmentation model for inference (dict-options variant).

    Builds the network named by ``opts['model']``, restores weights from
    ``opts['checkpoint']``, wraps it in ``nn.DataParallel``, moves it to the
    best available device, and prepares output/score directories plus a
    per-class color palette.

    Args:
        opts: dict of options; keys used here are ``model``, ``n_classes``,
            ``output_stride``, ``separable_conv``, ``checkpoint``,
            ``output`` and ``score``.
    """
    # Denormalizer for turning normalized input tensors back into images.
    self.denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])
    model_map = {
        'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
    }
    self.opts = opts
    self.device = torch.device(
        'cuda' if torch.cuda.is_available() else 'cpu')
    self.model = model_map[opts['model']](
        num_classes=opts['n_classes'], output_stride=opts['output_stride'])
    # NOTE(review): string comparison against 'True' — only correct if the
    # option is stored as a string, not a bool; verify against the caller.
    if opts['separable_conv'] == 'True' and 'plus' in opts['model']:
        # BUG FIX: was `network.convert_to_separable_conv(model.classifier)`;
        # `model` is undefined in this scope (NameError) — the constructed
        # network is `self.model`.
        network.convert_to_separable_conv(self.model.classifier)
    utils.set_bn_momentum(self.model.backbone, momentum=0.01)
    # Restore weights on CPU first, then move to the target device.
    checkpoint = torch.load(opts['checkpoint'],
                            map_location=torch.device('cpu'))
    self.model.load_state_dict(checkpoint['model_state'])
    self.model = nn.DataParallel(self.model)
    self.model.to(self.device)
    self.model.eval()
    if not os.path.exists(opts['output']):
        os.makedirs(opts['output'])
    if not os.path.exists(opts['score']):
        os.makedirs(opts['score'])
    # create a color pallette, selecting a color for each class
    self.palette = torch.tensor([2**25 - 1, 2**15 - 1, 2**21 - 1])
    self.colors = torch.as_tensor([i for i in range(opts['n_classes'])
                                   ])[:, None] * self.palette
    self.colors = (self.colors % 255).numpy().astype("uint8")
def main():
    """Train a DeepLab-family segmentation model on VOC or Cityscapes.

    Parses CLI options, builds dataset/model/optimizer/scheduler/criterion,
    optionally restores a checkpoint, then runs the iteration-based training
    loop with periodic validation, best-model checkpointing and optional
    visdom visualization. Returns when ``opts.total_itrs`` is reached.
    """
    opts = get_argparser().parse_args()
    if opts.dataset.lower() == 'voc':
        opts.num_classes = 21
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19

    # Setup visualization
    vis = Visualizer(port=opts.vis_port, env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Setup dataloader. VOC validation without cropping uses batch size 1
    # because images have varying sizes.
    if opts.dataset == 'voc' and not opts.crop_val:
        opts.val_batch_size = 1

    train_dst, val_dst = get_dataset(opts)
    train_loader = data.DataLoader(train_dst, batch_size=opts.batch_size, shuffle=True, num_workers=2)
    val_loader = data.DataLoader(val_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2)
    print("Dataset: %s, Train set: %d, Val set: %d" %
          (opts.dataset, len(train_dst), len(val_dst)))

    # Set up model
    model_map = {
        'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet,
        'doubleattention_resnet50': network.doubleattention_resnet50,
        'doubleattention_resnet101': network.doubleattention_resnet101,
        'head_resnet50': network.head_resnet50,
        'head_resnet101': network.head_resnet101
    }
    model = model_map[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    # Set up optimizer: backbone is fine-tuned at 1/10 of the head LR.
    optimizer = torch.optim.SGD(params=[
        {'params': model.backbone.parameters(), 'lr': 0.1 * opts.lr},
        {'params': model.classifier.parameters(), 'lr': opts.lr},
    ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    # optimizer = torch.optim.SGD(params=model.parameters(), lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    # torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=opts.lr_decay_factor)
    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size, gamma=0.1)

    # Set up criterion; 255 is the ignored (void) label.
    # criterion = utils.get_loss(opts.loss_type)
    if opts.loss_type == 'focal_loss':
        criterion = utils.FocalLoss(ignore_index=255, size_average=True)
        coss_manifode = utils.ManifondLoss(alpha=1).to(device)
    elif opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')
        coss_manifode = utils.ManifondLoss(alpha=1).to(device)

    def save_ckpt(path):
        """Save model/optimizer/scheduler state and best score to `path`."""
        torch.save(
            {
                "cur_itrs": cur_itrs,
                "model_state": model.module.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "best_score": best_score,
            }, path)
        print("Model saved as %s" % path)

    utils.mkdir('checkpoints')
    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        # https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    # ========== Train Loop ==========#
    vis_sample_id = np.random.randint(
        0, len(val_loader), opts.vis_num_samples,
        np.int32) if opts.enable_vis else None  # sample idxs for visualization
    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])  # denormalization for ori images

    if opts.test_only:
        model.eval()
        val_score, ret_samples = validate(opts=opts, model=model, loader=val_loader,
                                          device=device, metrics=metrics,
                                          ret_samples_ids=vis_sample_id)
        print(metrics.to_str(val_score))
        return

    interval_loss = 0
    while True:  # cur_itrs < opts.total_itrs:
        # ===== Train =====
        model.train()
        cur_epochs += 1
        for (images, labels) in train_loader:
            cur_itrs += 1
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            optimizer.zero_grad()
            outputs = model(images)
            # BUG FIX: the original computed
            #     loss = criterion(...) + coss_manifode(...) * 0.01
            # and then immediately overwrote it with plain criterion(...),
            # so the manifold term never contributed to the gradient but its
            # forward pass was paid on every step. Behavior is preserved
            # (plain criterion loss); the wasted forward pass is removed.
            # Re-add `+ coss_manifode(outputs, labels) * 0.01` here if the
            # manifold regularizer was meant to be active.
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            np_loss = loss.detach().cpu().numpy()
            interval_loss += np_loss
            if vis is not None:
                vis.vis_scalar('Loss', cur_itrs, np_loss)

            if (cur_itrs) % 10 == 0:
                interval_loss = interval_loss / 10
                print("Epoch %d, Itrs %d/%d, Loss=%f" %
                      (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
                interval_loss = 0.0

            if (cur_itrs) % opts.val_interval == 0:
                save_ckpt('checkpoints/latest_%s_%s_os%d.pth' %
                          (opts.model, opts.dataset, opts.output_stride))
                print("validation...")
                model.eval()
                val_score, ret_samples = validate(
                    opts=opts, model=model, loader=val_loader, device=device,
                    metrics=metrics, ret_samples_ids=vis_sample_id)
                print(metrics.to_str(val_score))
                if val_score['Mean IoU'] > best_score:  # save best model
                    best_score = val_score['Mean IoU']
                    save_ckpt('checkpoints/best_%s_%s_os%d.pth' %
                              (opts.model, opts.dataset, opts.output_stride))
                if vis is not None:  # visualize validation score and samples
                    vis.vis_scalar("[Val] Overall Acc", cur_itrs, val_score['Overall Acc'])
                    vis.vis_scalar("[Val] Mean IoU", cur_itrs, val_score['Mean IoU'])
                    vis.vis_table("[Val] Class IoU", val_score['Class IoU'])
                    for k, (img, target, lbl) in enumerate(ret_samples):
                        img = (denorm(img) * 255).astype(np.uint8)
                        target = train_dst.decode_target(target).transpose(
                            2, 0, 1).astype(np.uint8)
                        lbl = train_dst.decode_target(lbl).transpose(
                            2, 0, 1).astype(np.uint8)
                        concat_img = np.concatenate(
                            (img, target, lbl), axis=2)  # concat along width
                        vis.vis_image('Sample %d' % k, concat_img)
                model.train()
            scheduler.step()

            if cur_itrs >= opts.total_itrs:
                return
def main():
    """Run single-image inference and save colorized predictions.

    Collects image files from ``--input`` (a file or a directory searched
    recursively), builds the model named by ``--model``, restores weights
    from ``--ckpt``, and writes one colorized PNG per input image to
    ``--save_val_results_to`` (if given).
    """
    opts = get_argparser().parse_args()
    if opts.dataset.lower() == 'voc':
        opts.num_classes = 21
        decode_fn = VOCSegmentation.decode_target
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19
        decode_fn = Cityscapes.decode_target

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup dataloader: gather all image paths to process.
    image_files = []
    if os.path.isdir(opts.input):
        for ext in ['png', 'jpeg', 'jpg', 'JPEG']:
            files = glob(os.path.join(opts.input, '**/*.%s'%(ext)), recursive=True)
            if len(files)>0:
                image_files.extend(files)
    elif os.path.isfile(opts.input):
        image_files.append(opts.input)

    # Set up model (all models are 'constructed at network.modeling)
    model = network.modeling.__dict__[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        # Restore weights on CPU first, then wrap and move to device.
        # https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        print("Resume model from %s" % opts.ckpt)
        del checkpoint
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    #denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # denormalization for ori images

    # ImageNet normalization; with --crop_val images are resized and
    # center-cropped to crop_size first.
    if opts.crop_val:
        transform = T.Compose([
            T.Resize(opts.crop_size),
            T.CenterCrop(opts.crop_size),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
    else:
        transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
    if opts.save_val_results_to is not None:
        os.makedirs(opts.save_val_results_to, exist_ok=True)
    with torch.no_grad():
        model = model.eval()
        for img_path in tqdm(image_files):
            # Strip the extension to name the output PNG.
            ext = os.path.basename(img_path).split('.')[-1]
            img_name = os.path.basename(img_path)[:-len(ext)-1]
            img = Image.open(img_path).convert('RGB')
            img = transform(img).unsqueeze(0)  # To tensor of NCHW
            img = img.to(device)
            # argmax over class logits -> per-pixel class ids.
            pred = model(img).max(1)[1].cpu().numpy()[0]  # HW
            colorized_preds = decode_fn(pred).astype('uint8')
            colorized_preds = Image.fromarray(colorized_preds)
            if opts.save_val_results_to:
                colorized_preds.save(os.path.join(opts.save_val_results_to, img_name+'.png'))
def main():
    """Train a 3-class segmentation model (DeepLab variants or UNet) on DICOM data.

    Uses sigmoid outputs with BCE + multiclass Dice loss, logs progress to
    './train_info.txt', validates every 500 iterations and checkpoints
    whenever the average Dice improves. Runs until ``opts.total_itrs``.
    """
    opts = get_argparser().parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    img_h = 512
    img_w = 512
    torch.cuda.empty_cache()
    train_data, test_data = get_data_dcm(img_h=img_h, img_w=img_w, iscrop=True)
    train_loader = data.DataLoader(train_data, batch_size=opts.batch_size, shuffle=True, num_workers=0)
    val_loader = data.DataLoader(test_data, batch_size=opts.val_batch_size, shuffle=False, num_workers=0)
    model_map = {
        'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
    }
    if opts.model != 'unet':
        opts.num_classes = 3
        model = model_map[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)
        if opts.separable_conv and 'plus' in opts.model:
            network.convert_to_separable_conv(model.classifier)
        utils.set_bn_momentum(model.backbone, momentum=0.01)
        # Set up optimizer: backbone at 1/10 of the classifier LR.
        optimizer = torch.optim.SGD(params=[
            {
                'params': model.backbone.parameters(),
                'lr': 0.1 * opts.lr
            },
            {
                'params': model.classifier.parameters(),
                'lr': opts.lr
            },
        ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    else:
        opts.num_classes = 3
        model = UNet(n_channels=3, n_classes=3, bilinear=True)
        # Set up optimizer
        optimizer = torch.optim.SGD(model.parameters(), lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    criterion_bce = nn.BCELoss(reduction='mean')
    criterion_dice = MulticlassDiceLoss()

    def save_ckpt(path):
        """Save model/optimizer/scheduler state and best score to `path`."""
        torch.save(
            {
                "cur_itrs": cur_itrs,
                "model_state": model.module.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "best_score": best_score,
            }, path)
        print("Model saved as %s" % path)

    utils.mkdir('checkpoints')
    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    if opts.test_only:
        # model.load_state_dict()
        model.eval()
        dice_bl, dice_sb, dice_st, acc = validate2(
            model=model,
            loader=val_loader,
            device=device,
            itrs=cur_itrs,
            lr=scheduler.get_lr()[-1],
            criterion_dice=criterion_dice)
        # save_ckpt("./checkpoints/CT_" + opts.model + "_" + str(round(dice, 3)) + "__" + str(cur_itrs) + ".pkl")
        print("dice值:", dice_bl)
        return

    # Best-so-far Dice per channel. NOTE(review): the meaning of the
    # bl/sb/st channels is defined by validate2/get_data_dcm, not visible
    # here — confirm against those helpers before relying on the names.
    best_dice_bl = 0
    best_dice_sb = 0
    best_dice_st = 0
    best_dice_avg = 0
    interval_loss = 0
    train_iter = iter(train_loader)
    txt_path = './train_info.txt'
    # txtUtils.clearTxt(txt_path)
    while True:  # cur_itrs < opts.total_itrs:
        # ===== Train =====
        model.train()
        # Pull the next batch; restart the iterator when the epoch ends.
        try:
            images, labels = train_iter.__next__()
        except:
            train_iter = iter(train_loader)
            images, labels = train_iter.__next__()
        cur_itrs += 1
        # print(images.size())
        # print(labels.size())
        images = images.to(device, dtype=torch.float32)
        # labels = labels.to(device, dtype=torch.long)
        labels = labels.to(device, dtype=torch.float32)
        # print(images.size())
        outputs = model(images)
        # Sigmoid activations feed both BCE and the Dice loss.
        outputs_ = torch.sigmoid(outputs)
        loss = criterion_bce(outputs_, labels) + criterion_dice(
            outputs_, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        np_loss = loss.item()
        interval_loss += np_loss
        if (cur_itrs) % 50 == 0:
            # Report the mean loss over the last 50 iterations.
            interval_loss = interval_loss / 50
            cur_epochs = int(cur_itrs / train_loader.dataset.__len__())
            print("Epoch %d, Itrs %d/%d, Loss=%f" %
                  (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
            content = ("Epoch {}, Itrs {}/{}, Loss={}").format(
                cur_epochs, cur_itrs, opts.total_itrs, interval_loss)
            txtUtils.writeInfoToTxt(file_path=txt_path,
                                    content=content,
                                    is_add_time=True)
            interval_loss = 0.0
        # opts.val_interval=5
        if (cur_itrs) % 500 == 0:
            print("validation... lr:", scheduler.get_lr())
            content = ("validation... lr:{}").format(scheduler.get_lr())
            txtUtils.writeInfoToTxt(file_path=txt_path,
                                    content=content,
                                    is_add_time=True)
            # print(outputs)
            dice_bl, dice_sb, dice_st, acc = validate2(
                model=model,
                loader=val_loader,
                device=device,
                itrs=cur_itrs,
                lr=scheduler.get_lr()[-1],
                criterion_dice=criterion_dice)
            dice_avg = (dice_bl + dice_sb + dice_st) / 3
            content = (
                "dice_bl:{}, dice_sb:{}, dice_st:{}, acc:{}, dice_avg:{}"
            ).format(dice_bl, dice_sb, dice_st, acc, dice_avg)
            txtUtils.writeInfoToTxt(file_path=txt_path,
                                    content=content,
                                    is_add_time=True)
            # Checkpoint only when the average Dice improves.
            if best_dice_avg < dice_avg:
                best_dice_avg = dice_avg
                save_ckpt("./checkpoints/" + opts.model + "_dice_avg_" +
                          str(round(best_dice_avg, 3)) + "_dice_bl_" +
                          str(round(dice_bl, 3)) + "_dice_sb_" +
                          str(round(dice_sb, 3)) + "_dice_st_" +
                          str(round(dice_st, 3)) + "__" + str(cur_itrs) +
                          ".pkl")
                print("best avg dice:", best_dice_avg)
                content = ("best avg dice: {}").format(best_dice_avg)
                txtUtils.writeInfoToTxt(file_path=txt_path,
                                        content=content,
                                        is_add_time=True)
            if best_dice_bl < dice_bl:
                best_dice_bl = dice_bl
                content = ("best bl dice: {}").format(best_dice_bl)
                txtUtils.writeInfoToTxt(file_path=txt_path,
                                        content=content,
                                        is_add_time=True)
            if best_dice_sb < dice_sb:
                best_dice_sb = dice_sb
                content = ("best sb dice: {}").format(best_dice_sb)
                txtUtils.writeInfoToTxt(file_path=txt_path,
                                        content=content,
                                        is_add_time=True)
            if best_dice_st < dice_st:
                best_dice_st = dice_st
                content = ("best st dice: {}").format(best_dice_st)
                txtUtils.writeInfoToTxt(file_path=txt_path,
                                        content=content,
                                        is_add_time=True)
        scheduler.step()
        if cur_itrs >= opts.total_itrs:
            return
def main():
    """Train one model per class ("multiple model" setup) on VOC/Cityscapes.

    For each class index in ``[opts.start_class, opts.num_classes)`` this
    builds a fresh model, trains it for ``opts.total_itrs`` iterations with
    periodic validation/checkpointing, and saves it under
    'checkpoints/multiple_model2'. With ``--test_only`` it instead runs the
    single- and multiple-model test paths and prints the metrics.
    """
    opts = get_argparser().parse_args()
    if opts.dataset.lower() == 'voc':
        opts.num_classes = 21
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19

    # Setup visualization
    vis = Visualizer(port=opts.vis_port, env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Setup dataloader
    if opts.dataset == 'voc' and not opts.crop_val:
        opts.val_batch_size = 1

    # Set up metrics (hard-coded to 21 classes rather than opts.num_classes).
    # metrics = StreamSegMetrics(opts.num_classes)
    metrics = StreamSegMetrics(21)

    # Set up optimizer
    # criterion = utils.get_loss(opts.loss_type)
    if opts.loss_type == 'focal_loss':
        criterion = utils.FocalLoss(ignore_index=255, size_average=True)
    elif opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')
    elif opts.loss_type == 'logit':
        criterion = nn.BCELoss(reduction='mean')

    def save_ckpt(path):
        """Save model/optimizer/scheduler state and best score to `path`."""
        torch.save({
            "cur_itrs": cur_itrs,
            "model_state": model.module.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_score": best_score,
        }, path)
        print("Model saved as %s" % path)

    utils.mkdir('checkpoints')
    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    # This script always trains from scratch; a provided --ckpt is rejected.
    if opts.ckpt is not None:
        print("Error --ckpt, can't read model")
        return

    _, val_dst, test_dst = get_dataset(opts)
    val_loader = data.DataLoader(
        val_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2)
    test_loader = data.DataLoader(
        test_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2)

    vis_sample_id = np.random.randint(0, len(test_loader), opts.vis_num_samples,
                                      np.int32) if opts.enable_vis else None  # sample idxs for visualization
    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])  # denormalization for ori images

    # ========== Test Loop ==========#
    if opts.test_only:
        print("Dataset: %s, Val set: %d, Test set: %d" %
              (opts.dataset, len(val_dst), len(test_dst)))
        metrics = StreamSegMetrics(21)
        print("val")
        test_score, ret_samples = test_single(opts=opts,
                                              loader=test_loader,
                                              device=device,
                                              metrics=metrics,
                                              ret_samples_ids=vis_sample_id)
        print("test")
        test_score, ret_samples = test_multiple(
            opts=opts,
            loader=test_loader,
            device=device,
            metrics=metrics,
            ret_samples_ids=vis_sample_id)
        print(metrics.to_str(test_score))
        return

    # ========== Train Loop ==========#
    utils.mkdir('checkpoints/multiple_model2')
    for class_num in range(opts.start_class, opts.num_classes):
        # ========== Dataset ==========#
        train_dst, val_dst, test_dst = get_dataset_multiple(opts, class_num)
        train_loader = data.DataLoader(train_dst, batch_size=opts.batch_size, shuffle=True, num_workers=2)
        val_loader = data.DataLoader(val_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2)
        test_loader = data.DataLoader(test_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2)
        print("Dataset: %s Class %d, Train set: %d, Val set: %d, Test set: %d" % (
            opts.dataset, class_num, len(train_dst), len(val_dst), len(test_dst)))

        # ========== Model ==========#
        # NOTE(review): `model_map` is not defined anywhere in this function;
        # unless it exists at module scope, this line raises NameError — verify.
        model = model_map[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)
        if opts.separable_conv and 'plus' in opts.model:
            network.convert_to_separable_conv(model.classifier)
        utils.set_bn_momentum(model.backbone, momentum=0.01)

        # ========== Params and learning rate ==========#
        params_list = [
            {'params': model.backbone.parameters(), 'lr': 0.1 * opts.lr},
            {'params': model.classifier.parameters(), 'lr': 0.1 * opts.lr}  # opts.lr
        ]
        if 'SA' in opts.model:
            params_list.append({'params': model.attention.parameters(), 'lr': 0.1 * opts.lr})
        optimizer = torch.optim.Adam(params=params_list, lr=opts.lr, weight_decay=opts.weight_decay)
        if opts.lr_policy == 'poly':
            scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
        elif opts.lr_policy == 'step':
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size, gamma=0.1)

        model = nn.DataParallel(model)
        model.to(device)

        # Per-class training state is reset for every class_num.
        best_score = 0.0
        cur_itrs = 0
        cur_epochs = 0
        interval_loss = 0
        while True:  # cur_itrs < opts.total_itrs:
            # ===== Train =====
            model.train()
            cur_epochs += 1
            for (images, labels) in train_loader:
                cur_itrs += 1
                images = images.to(device, dtype=torch.float32)
                labels = labels.to(device, dtype=torch.long)
                # labels=(labels==class_num).float()
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                np_loss = loss.detach().cpu().numpy()
                interval_loss += np_loss
                if vis is not None:
                    vis.vis_scalar('Loss', cur_itrs, np_loss)
                if (cur_itrs) % 10 == 0:
                    interval_loss = interval_loss / 10
                    print("Epoch %d, Itrs %d/%d, Loss=%f" %
                          (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
                    interval_loss = 0.0
                if (cur_itrs) % opts.val_interval == 0:
                    save_ckpt('checkpoints/multiple_model2/latest_%s_%s_class%d_os%d.pth' %
                              (opts.model, opts.dataset, class_num, opts.output_stride,))
                    print("validation...")
                    model.eval()
                    val_score, ret_samples = validate(
                        opts=opts, model=model, loader=val_loader, device=device,
                        metrics=metrics, ret_samples_ids=vis_sample_id, class_num=class_num)
                    print(metrics.to_str(val_score))
                    if val_score['Mean IoU'] > best_score:  # save best model
                        best_score = val_score['Mean IoU']
                        save_ckpt('checkpoints/multiple_model2/best_%s_%s_class%d_os%d.pth' %
                                  (opts.model, opts.dataset, class_num, opts.output_stride))
                    if vis is not None:  # visualize validation score and samples
                        vis.vis_scalar("[Val] Overall Acc", cur_itrs, val_score['Overall Acc'])
                        vis.vis_scalar("[Val] Mean IoU", cur_itrs, val_score['Mean IoU'])
                        vis.vis_table("[Val] Class IoU", val_score['Class IoU'])
                        for k, (img, target, lbl) in enumerate(ret_samples):
                            img = (denorm(img) * 255).astype(np.uint8)
                            target = train_dst.decode_target(target).transpose(2, 0, 1).astype(np.uint8)
                            lbl = train_dst.decode_target(lbl).transpose(2, 0, 1).astype(np.uint8)
                            concat_img = np.concatenate((img, target, lbl), axis=2)  # concat along width
                            vis.vis_image('Sample %d' % k, concat_img)
                    model.train()
                scheduler.step()
                if cur_itrs >= opts.total_itrs:
                    save_ckpt('checkpoints/multiple_model2/latest_%s_%s_class%d_os%d.pth' %
                              (opts.model, opts.dataset, class_num, opts.output_stride,))
                    print("Saving..")
                    break
            if cur_itrs >= opts.total_itrs:
                # Reset the counter so the next class starts from zero.
                cur_itrs = 0
                break
        print("Model of class %d is trained and saved " % (class_num))
def main():
    """Profiling variant of the training loop using NVIDIA DALI and NVTX.

    Builds the DALI pipeline/iterator, model, optimizer, scheduler and
    criterion, then runs a short instrumented loop that wraps data loading,
    forward and backward passes in NVTX ranges for profiling (e.g. with
    Nsight Systems). This is profiling scaffolding, not a full training run.
    """
    opts = get_argparser().parse_args()
    if opts.dataset.lower() == 'voc':
        opts.num_classes = 21
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19

    # Setup visualization
    vis = Visualizer(port=opts.vis_port, env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Setup dataloader
    if opts.dataset=='voc' and not opts.crop_val:
        opts.val_batch_size = 1

    # DALI pipeline with a hard-coded data directory and device 0.
    pipe = create_dali_pipeline(batch_size=opts.batch_size, num_threads=8, device_id=0, data_dir="/home/ubuntu/cityscapes")
    pipe.build()
    train_loader = DALIGenericIterator(pipe, output_map=['image', 'label'], last_batch_policy=LastBatchPolicy.PARTIAL)

    # Set up model
    model_map = {
        'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
    }
    model = model_map[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    # Set up optimizer
    optimizer = torch.optim.SGD(params=[
        {'params': model.backbone.parameters(), 'lr': 0.1*opts.lr},
        {'params': model.classifier.parameters(), 'lr': opts.lr},
    ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    #optimizer = torch.optim.SGD(params=model.parameters(), lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    #torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=opts.lr_decay_factor)
    if opts.lr_policy=='poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy=='step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size, gamma=0.1)

    # Set up criterion
    #criterion = utils.get_loss(opts.loss_type)
    if opts.loss_type == 'focal_loss':
        criterion = utils.FocalLoss(ignore_index=255, size_average=True)
    elif opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')

    def save_ckpt(path):
        """Save model/optimizer/scheduler state and best score to `path`."""
        torch.save({
            "cur_itrs": cur_itrs,
            "model_state": model.module.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_score": best_score,
        }, path)
        print("Model saved as %s" % path)

    utils.mkdir('checkpoints')
    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        # https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    #========== Train Loop ==========#
    interval_loss = 0
    # NOTE(review): looks like a Cityscapes labelId->trainId table (255 =
    # ignore), but it is never used in this function — confirm before removal.
    class_conv = [255, 255, 255, 255, 255, 255, 255, 255, 0, 1, 255, 255, 2, 3,
                  4, 255, 255, 255, 5, 255, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
                  255, 255, 16, 17, 18]
    while True: #cur_itrs < opts.total_itrs:
        # ===== Train =====
        model.train()
        #model = model.half()
        cur_epochs += 1
        while True:
            # NOTE(review): a fresh iterator is created on *every* pass of
            # this loop — with a DALI iterator this may restart/refetch from
            # the start of the pipeline each time; confirm this is intended
            # (it may be deliberate for profiling a fixed batch).
            train_iter = iter(train_loader)
            try:
                nvtx.range_push("Batch " + str(cur_itrs))
                nvtx.range_push("Data loading")
                # NOTE(review): `data` shadows the `torch.utils.data` module
                # name used elsewhere in this file.
                data = next(train_iter)
                cur_itrs += 1
                images = data[0]['image'].to(dtype=torch.float32)
                labels = data[0]['label'][:, :, :, 0].to(dtype=torch.long)
                # NOTE(review): the real labels are immediately replaced with
                # zeros — debug/profiling stub; the loss is meaningless here.
                labels = torch.zeros(data[0]['label'][:, :, :, 0].shape).to(device, dtype=torch.long)
                nvtx.range_pop()
                nvtx.range_push("Forward pass")
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                nvtx.range_pop()
                nvtx.range_push("Backward pass")
                loss.backward()
                optimizer.step()
                nvtx.range_pop()
                np_loss = loss.detach().cpu().numpy()
                interval_loss += np_loss
                nvtx.range_pop()
                # Hard-coded stop after 10 iterations — profiling run only.
                if cur_itrs == 10:
                    break
                if vis is not None:
                    vis.vis_scalar('Loss', cur_itrs, np_loss)
                print("Epoch %d, Itrs %d/%d, Loss=%f" %
                      (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
                interval_loss = 0.0
                scheduler.step()
                if cur_itrs >= opts.total_itrs:
                    return
            except StopIteration:
                break
        break
def main():
    """Evaluate a checkpointed DeepLab model on the VOC val set, then run
    CRF post-processing over the dumped logits and report both scores.

    Side effects:
        - writes logits under ``opts.logit_dir`` (created here, consumed by
          ``crf_inference``) and deletes that directory when finished;
        - sets ``CUDA_VISIBLE_DEVICES`` from ``--gpu_id``.

    Raises:
        FileNotFoundError: if ``--ckpt`` is missing or not a file — this
            script is evaluation-only and must not run on random weights.
    """
    import shutil  # local import: only needed for final cleanup

    opts = get_argparser().parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    os.makedirs(opts.logit_dir, exist_ok=True)

    # Setup dataloader. Without --crop_val, validation images keep their
    # native sizes, so they cannot be batched together.
    if not opts.crop_val:
        opts.val_batch_size = 1
    val_dst = get_dataset(opts)
    val_loader = data.DataLoader(val_dst,
                                 batch_size=opts.val_batch_size,
                                 shuffle=False,
                                 num_workers=4)
    print("Dataset: voc, Val set: %d" % (len(val_dst)))

    # Set up model
    model_map = {
        'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
    }
    model = model_map[opts.model](num_classes=opts.num_classes,
                                  output_stride=opts.output_stride)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    # Restore
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        # https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        # BUGFIX: was `assert "no checkpoint"` — a non-empty string literal
        # is always truthy, so the assert could never fire and evaluation
        # proceeded on untrained weights. Fail loudly instead.
        raise FileNotFoundError(
            "no checkpoint: --ckpt must point to an existing file")

    #========== Eval ==========#
    model.eval()
    val_score = validate(opts=opts, model=model, loader=val_loader,
                         device=device, metrics=metrics)
    print(metrics.to_str(val_score))
    print("\n\n----------- crf -------------")
    crf_score = crf_inference(opts, val_dst, metrics)
    print(metrics.to_str(crf_score))
    # BUGFIX: was os.system(f"rm -rf {opts.logit_dir}") — interpolating a
    # CLI-supplied path into a shell string; remove the tree directly.
    shutil.rmtree(opts.logit_dir, ignore_errors=True)
def main(criterion):
    """Train a human-parsing/segmentation model, optionally with SCHP
    (Self-Correction for Human Parsing) online label refinement.

    Args:
        criterion: composite loss object exposing ``start_log`` /
            ``batch_step`` / ``end_log`` / ``display_loss`` hooks plus
            ``.losses`` (per-term values) and ``.loss`` (per-term config
            dicts with 'weight' and 'type' keys).

    NOTE(review): this variant reads ``opts`` and ``device`` from enclosing
    (module) scope rather than parsing args itself — confirm both are
    defined before this is called.
    """
    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Setup dataloader
    train_dst, val_dst = get_dataset(opts)
    train_loader = data.DataLoader(train_dst,
                                   batch_size=opts.batch_size,
                                   shuffle=True,
                                   num_workers=2)
    val_loader = data.DataLoader(val_dst,
                                 batch_size=opts.val_batch_size,
                                 shuffle=False,
                                 num_workers=2)
    print("Dataset: %s, Train set: %d, Val set: %d" %
          (opts.dataset, len(train_dst), len(val_dst)))

    # Set up model. ACE2P variants skip the ImageNet-pretrained backbone.
    pretrained_backbone = False if "ACE2P" in opts.model else True
    model = network.model_map[opts.model](
        num_classes=opts.num_classes,
        output_stride=opts.output_stride,
        pretrained_backbone=pretrained_backbone,
        use_abn=opts.use_abn)
    if opts.use_schp:
        # Second instance of the same architecture: holds the slowly
        # aggregated (moving-average) weights used for label refinement.
        schp_model = network.model_map[opts.model](
            num_classes=opts.num_classes,
            output_stride=opts.output_stride,
            pretrained_backbone=pretrained_backbone,
            use_abn=opts.use_abn)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    # Set up optimizer — backbone trains at 1/100 of the classifier lr here
    # (other variants in this file use 1/10; presumably intentional for
    # fine-tuning — TODO confirm).
    model_params = [
        {
            'params': model.backbone.parameters(),
            'lr': 0.01 * opts.lr
        },
        {
            'params': model.classifier.parameters(),
            'lr': opts.lr
        },
    ]
    optimizer = create_optimizer(opts, model_params=model_params)
    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=opts.step_size,
                                                    gamma=0.1)

    def save_ckpt(path):
        """Save model/optimizer/scheduler state plus progress counters."""
        torch.save(
            {
                "cur_epochs": cur_epochs,
                "cur_itrs": cur_itrs,
                "model_state": model.module.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "best_score": best_score,
            }, path)
        print("Model saved as %s" % path)

    utils.mkdir('checkpoints')
    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    cycle_n = 0  # number of completed SCHP self-correction cycles
    if opts.use_schp and opts.schp_ckpt is not None and os.path.isfile(
            opts.schp_ckpt):
        # TODO: there is a problem with this part.
        checkpoint = torch.load(opts.schp_ckpt,
                                map_location=torch.device('cpu'))
        schp_model.load_state_dict(checkpoint["model_state"])
        print("SCHP Model restored from %s" % opts.schp_ckpt)
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        if opts.use_schp:
            schp_model = nn.DataParallel(schp_model)
            schp_model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_epochs = checkpoint[
                "cur_epochs"] - 1  # to start from the last epoch for schp
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)
        if opts.use_schp:
            schp_model = nn.DataParallel(schp_model)
            schp_model.to(device)

    # ========== Train Loop ==========#
    if opts.test_only:
        model.eval()
        val_score = validate(opts=opts,
                             model=model,
                             loader=val_loader,
                             device=device,
                             metrics=metrics)
        print(metrics.to_str(val_score))
        return

    interval_loss = 0
    while True:  # cur_itrs < opts.total_itrs:
        # ===== Train =====
        criterion.start_log()
        model.train()
        cur_epochs += 1
        for (images, labels) in train_loader:
            cur_itrs += 1
            images, labels = get_input(images, labels, opts, device,
                                       cur_itrs)
            if opts.use_mixup:
                # get_input returns (mixed images, pair of un-mixed images)
                images, main_images = images
            else:
                main_images = None
            images = images[:, [2, 1, 0]]  # RGB->BGR channel swap for backbone
            optimizer.zero_grad()
            outputs = model(images)
            if opts.use_schp:
                # Online Self Correction Cycle with Label Refinement:
                # from cycle 1 onward, the aggregated schp_model produces
                # soft labels that are blended into the loss.
                soft_labels = []
                if cycle_n >= 1:
                    with torch.no_grad():
                        if opts.use_mixup:
                            soft_preds = [
                                schp_model(main_images[0]),
                                schp_model(main_images[1])
                            ]
                            soft_edges = [None, None]
                        else:
                            soft_preds = schp_model(images)
                            soft_edges = None
                            if 'ACE2P' in opts.model:
                                # ACE2P returns (parsing, edge) heads; take
                                # the last stage of each.
                                soft_edges = soft_preds[1][-1]
                                soft_preds = soft_preds[0][-1]
                else:
                    if opts.use_mixup:
                        soft_preds = [None, None]
                        soft_edges = [None, None]
                    else:
                        soft_preds = None
                        soft_edges = None
                soft_labels.append(soft_preds)
                soft_labels.append(soft_edges)
                # labels is rebound to [hard labels, [soft preds, soft edges]]
                labels = [labels, soft_labels]
            loss = calc_loss(criterion, outputs, labels, opts, cycle_n)
            loss.backward()
            optimizer.step()
            criterion.batch_step(len(images))
            np_loss = loss.detach().cpu().numpy()
            interval_loss += np_loss
            # Build a per-term loss readout for active (weight > 0) terms.
            sub_loss_text = ''
            for sub_loss, sub_prop in zip(criterion.losses, criterion.loss):
                if sub_prop['weight'] > 0:
                    sub_loss_text += f", {sub_prop['type']}: {sub_loss.item():.4f}"
            print(
                f"\rEpoch {cur_epochs}, Itrs {cur_itrs}/{opts.total_itrs}, Loss={np_loss:.4f}{sub_loss_text}",
                end='')
            if (cur_itrs) % 10 == 0:
                # Every 10 iters: print the 10-iter average and reset.
                interval_loss = interval_loss / 10
                print(
                    f"\rEpoch {cur_epochs}, Itrs {cur_itrs}/{opts.total_itrs}, Loss={interval_loss:.4f} {criterion.display_loss().replace('][',', ')}"
                )
                interval_loss = 0.0
            torch.cuda.empty_cache()
            if (cur_itrs) % opts.save_interval == 0 and (
                    cur_itrs) % opts.val_interval != 0:
                # Periodic save; skipped when a validation save happens below.
                save_ckpt('checkpoints/latest_%s_%s_os%d.pth' %
                          (opts.model, opts.dataset, opts.output_stride))
            if (cur_itrs) % opts.val_interval == 0:
                save_ckpt('checkpoints/latest_%s_%s_os%d.pth' %
                          (opts.model, opts.dataset, opts.output_stride))
                print("validation...")
                model.eval()
                val_score = validate(opts=opts,
                                     model=model,
                                     loader=val_loader,
                                     device=device,
                                     metrics=metrics)
                print(metrics.to_str(val_score))
                if val_score['Mean IoU'] > best_score:  # save best model
                    best_score = val_score['Mean IoU']
                    save_ckpt('checkpoints/best_%s_%s_os%d.pth' %
                              (opts.model, opts.dataset, opts.output_stride))
                model.train()
            scheduler.step()
            if cur_itrs >= opts.total_itrs:
                criterion.end_log(len(train_loader))
                return
        # Self Correction Cycle with Model Aggregation: at the end of each
        # qualifying epoch, fold the live model into schp_model by moving
        # average, re-estimate its BN stats, and checkpoint it.
        if opts.use_schp:
            if (cur_epochs + 1) >= opts.schp_start and (
                    cur_epochs + 1 -
                    opts.schp_start) % opts.cycle_epochs == 0:
                print(f'\nSelf-correction cycle number {cycle_n}')
                schp.moving_average(schp_model, model, 1.0 / (cycle_n + 1))
                cycle_n += 1
                schp.bn_re_estimate(train_loader, schp_model)
                schp.save_schp_checkpoint(
                    {
                        'state_dict': schp_model.state_dict(),
                        'cycle_n': cycle_n,
                    },
                    False,
                    "checkpoints",
                    filename=
                    f'schp_{opts.model}_{opts.dataset}_cycle{cycle_n}_checkpoint.pth'
                )
        torch.cuda.empty_cache()
        criterion.end_log(len(train_loader))
def main():
    """Train a DeepLabV3+ model, optionally with a reduced-dimension
    embedding head (``--reduce_dim``) and nearest-neighbour cross-entropy,
    logging to TensorBoard and (optionally) Visdom.
    """
    opts = get_argparser().parse_args()
    # Per-dataset class count and label-ignore sentinel.
    if opts.dataset.lower() == 'voc':
        opts.num_classes = 21
        ignore_index = 255
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19
        ignore_index = 255
    elif opts.dataset.lower() == 'ade20k':
        opts.num_classes = 150
        ignore_index = -1
    elif opts.dataset.lower() == 'lvis':
        opts.num_classes = 1284
        ignore_index = -1
    elif opts.dataset.lower() == 'coco':
        opts.num_classes = 182
        ignore_index = 255
    if (opts.reduce_dim == False):
        # Without dimensionality reduction the classifier emits one channel
        # per class.
        opts.num_channels = opts.num_classes
    if (opts.test_only == False):
        # writer is only created (and only used) on the training path;
        # the test_only path returns before any writer access.
        writer = SummaryWriter('summary/' + opts.vis_env)
    # Setup visualization
    vis = Visualizer(port=opts.vis_port,
                     env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))
    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Setup dataloader
    if opts.dataset == 'voc' and not opts.crop_val:
        opts.val_batch_size = 1
    train_dst, val_dst = get_dataset(opts)
    train_loader = data.DataLoader(train_dst,
                                   batch_size=opts.batch_size,
                                   shuffle=True,
                                   num_workers=2)
    val_loader = data.DataLoader(val_dst,
                                 batch_size=opts.val_batch_size,
                                 shuffle=False,
                                 num_workers=2)
    print("Dataset: %s, Train set: %d, Val set: %d" %
          (opts.dataset, len(train_dst), len(val_dst)))
    # Validate roughly once per epoch, capped at 5000 iterations.
    epoch_interval = int(len(train_dst) / opts.batch_size)
    if (epoch_interval > 5000):
        opts.val_interval = 5000
    else:
        opts.val_interval = epoch_interval
    print("Evaluation after %d iterations" % (opts.val_interval))

    # Set up model — only the deeplabv3+ variants are enabled here.
    model_map = {
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
    }
    if (opts.reduce_dim):
        # [embedding channels, class count] when using the embedding head.
        num_classes_input = [opts.num_channels, opts.num_classes]
    else:
        num_classes_input = [opts.num_classes]
    model = model_map[opts.model](num_classes=num_classes_input,
                                  output_stride=opts.output_stride,
                                  reduce_dim=opts.reduce_dim)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    # Set up optimizer
    if opts.reduce_dim:
        # Split classifier parameters into the embedding matrix vs the rest,
        # so the embedding can get its own hyperparameters.
        emb_layer = ['embedding.weight']
        params_classifier = list(
            map(
                lambda x: x[1],
                list(
                    filter(lambda kv: kv[0] not in emb_layer,
                           model.classifier.named_parameters()))))
        params_embedding = list(
            map(
                lambda x: x[1],
                list(
                    filter(lambda kv: kv[0] in emb_layer,
                           model.classifier.named_parameters()))))
        if opts.freeze_backbone:
            for param in model.backbone.parameters():
                param.requires_grad = False
            optimizer = torch.optim.SGD(
                params=[
                    {
                        'params': params_classifier,
                        'lr': opts.lr
                    },
                    {
                        'params': params_embedding,
                        'lr': opts.lr,
                        'momentum': 0.95  # embedding-specific momentum
                    },
                ],
                lr=opts.lr,
                momentum=0.9,
                weight_decay=opts.weight_decay)
        else:
            optimizer = torch.optim.SGD(params=[
                {
                    'params': model.backbone.parameters(),
                    'lr': 0.1 * opts.lr  # backbone fine-tunes at 1/10 lr
                },
                {
                    'params': params_classifier,
                    'lr': opts.lr
                },
                {
                    'params': params_embedding,
                    'lr': opts.lr
                },
            ],
                                        lr=opts.lr,
                                        momentum=0.9,
                                        weight_decay=opts.weight_decay)
    else:
        optimizer = torch.optim.SGD(params=[
            {
                'params': model.backbone.parameters(),
                'lr': 0.1 * opts.lr
            },
            {
                'params': model.classifier.parameters(),
                'lr': opts.lr
            },
        ],
                                    lr=opts.lr,
                                    momentum=0.9,
                                    weight_decay=opts.weight_decay)
    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=opts.step_size,
                                                    gamma=0.1)
    elif opts.lr_policy == 'multi_poly':
        # One poly decay power per param group.
        scheduler = utils.MultiPolyLR(optimizer,
                                      opts.total_itrs,
                                      power=[0.9, 0.9, 0.95])

    # Set up criterion — forced by reduce_dim regardless of CLI loss_type.
    if (opts.reduce_dim):
        opts.loss_type = 'nn_cross_entropy'
    else:
        opts.loss_type = 'cross_entropy'
    if opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss(ignore_index=ignore_index,
                                        reduction='mean')
    elif opts.loss_type == 'nn_cross_entropy':
        criterion = utils.NNCrossEntropy(ignore_index=ignore_index,
                                         reduction='mean',
                                         num_neighbours=opts.num_neighbours,
                                         temp=opts.temp,
                                         dataset=opts.dataset)

    def save_ckpt(path):
        """Save model/optimizer/scheduler state and progress counters."""
        torch.save(
            {
                "cur_itrs": cur_itrs,
                "model_state": model.module.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "best_score": best_score,
            }, path)
        print("Model saved as %s" % path)

    utils.mkdir(opts.checkpoint_dir)
    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        increase_iters = True  # NOTE(review): never read — vestigial?
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("scheduler state dict :", scheduler.state_dict())
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    vis_sample_id = np.random.randint(
        0, len(val_loader), opts.vis_num_samples,
        np.int32) if opts.enable_vis else None  # sample idxs for visualization
    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224,
                                    0.225])  # denormalization for ori images

    if opts.test_only:
        model.eval()
        val_score, ret_samples = validate(opts=opts,
                                          model=model,
                                          loader=val_loader,
                                          device=device,
                                          metrics=metrics,
                                          ret_samples_ids=vis_sample_id)
        print(metrics.to_str(val_score))
        return

    interval_loss = 0
    # Record the run configuration in TensorBoard.
    writer.add_text('lr', str(opts.lr))
    writer.add_text('batch_size', str(opts.batch_size))
    writer.add_text('reduce_dim', str(opts.reduce_dim))
    writer.add_text('checkpoint_dir', opts.checkpoint_dir)
    writer.add_text('dataset', opts.dataset)
    writer.add_text('num_channels', str(opts.num_channels))
    writer.add_text('num_neighbours', str(opts.num_neighbours))
    writer.add_text('loss_type', opts.loss_type)
    writer.add_text('lr_policy', opts.lr_policy)
    writer.add_text('temp', str(opts.temp))
    writer.add_text('crop_size', str(opts.crop_size))
    writer.add_text('model', opts.model)
    accumulation_steps = 1  # NOTE(review): logged but gradient accumulation is not implemented
    writer.add_text('accumulation_steps', str(accumulation_steps))
    j = 0  # NOTE(review): incremented below but never read
    updateflag = False  # NOTE(review): never read — vestigial?
    while True:
        # ===== Train =====
        model.train()
        cur_epochs += 1
        for (images, labels) in train_loader:
            cur_itrs += 1
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)
            if (opts.dataset == 'ade20k' or opts.dataset == 'lvis'):
                # These datasets are 1-indexed with 0 = unlabeled; shifting
                # by -1 maps unlabeled onto ignore_index (-1).
                labels = labels - 1
            optimizer.zero_grad()
            if (opts.reduce_dim):
                # Model also returns the class-embedding matrix, which the
                # NN cross-entropy loss needs.
                outputs, class_emb = model(images)
                loss = criterion(outputs, labels, class_emb)
            else:
                outputs = model(images)
                loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            model.zero_grad()
            j = j + 1
            np_loss = loss.detach().cpu().numpy()
            interval_loss += np_loss
            if vis is not None:
                vis.vis_scalar('Loss', cur_itrs, np_loss)
                vis.vis_scalar('LR', cur_itrs,
                               scheduler.state_dict()['_last_lr'][0])
            # Aggressive per-iteration memory release (large class counts).
            torch.cuda.empty_cache()
            del images, labels, outputs, loss
            if (opts.reduce_dim):
                del class_emb
            gc.collect()
            if (cur_itrs) % 50 == 0:
                interval_loss = interval_loss / 50
                print("Epoch %d, Itrs %d/%d, Loss=%f" %
                      (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
                writer.add_scalar('Loss', interval_loss, cur_itrs)
                writer.add_scalar('lr',
                                  scheduler.state_dict()['_last_lr'][0],
                                  cur_itrs)
            # NOTE(review): the next two conditions are identical — they
            # could be merged into one block.
            if cur_itrs % opts.val_interval == 0:
                save_ckpt(opts.checkpoint_dir + '/latest_%d.pth' %
                          (cur_itrs))
            if cur_itrs % opts.val_interval == 0:
                print("validation...")
                model.eval()
                val_score, ret_samples = validate(
                    opts=opts,
                    model=model,
                    loader=val_loader,
                    device=device,
                    metrics=metrics,
                    ret_samples_ids=vis_sample_id)
                print(metrics.to_str(val_score))
                if val_score['Mean IoU'] > best_score:  # save best model
                    best_score = val_score['Mean IoU']
                    save_ckpt(opts.checkpoint_dir + '/best_%s_%s_os%d.pth' %
                              (opts.model, opts.dataset, opts.output_stride))
                writer.add_scalar('[Val] Overall Acc',
                                  val_score['Overall Acc'], cur_itrs)
                writer.add_scalar('[Val] Mean IoU', val_score['Mean IoU'],
                                  cur_itrs)
                writer.add_scalar('[Val] Mean Acc', val_score['Mean Acc'],
                                  cur_itrs)
                writer.add_scalar('[Val] Freq Acc', val_score['FreqW Acc'],
                                  cur_itrs)
                if vis is not None:  # visualize validation score and samples
                    vis.vis_scalar("[Val] Overall Acc", cur_itrs,
                                   val_score['Overall Acc'])
                    vis.vis_scalar("[Val] Mean IoU", cur_itrs,
                                   val_score['Mean IoU'])
                    vis.vis_table("[Val] Class IoU", val_score['Class IoU'])
                    for k, (img, target, lbl) in enumerate(ret_samples):
                        img = (denorm(img) * 255).astype(np.uint8)
                        if (opts.dataset.lower() == 'coco'):
                            # NOTE(review): uses `numpy.` while the rest of
                            # the function uses `np.` — confirm the module
                            # is imported under both names.
                            target = numpy.asarray(
                                train_dst._colorize_mask(target).convert(
                                    'RGB')).transpose(2, 0,
                                                      1).astype(np.uint8)
                            lbl = numpy.asarray(
                                train_dst._colorize_mask(lbl).convert(
                                    'RGB')).transpose(2, 0,
                                                      1).astype(np.uint8)
                        else:
                            target = train_dst.decode_target(target).transpose(
                                2, 0, 1).astype(np.uint8)
                            lbl = train_dst.decode_target(lbl).transpose(
                                2, 0, 1).astype(np.uint8)
                        concat_img = np.concatenate(
                            (img, target, lbl), axis=2)  # concat along width
                        vis.vis_image('Sample %d' % k, concat_img)
                model.train()
            scheduler.step()
            if cur_itrs >= opts.total_itrs:
                return
    # NOTE(review): unreachable — the loop above only exits via `return`,
    # so the SummaryWriter is never closed explicitly.
    writer.close()
def main():
    """Train/evaluate a 3-class medical (DICOM) segmentation model using
    either a DeepLab variant or a UNet, selected by ``opts.model``.

    NOTE(review): the visible body ends after the ``test_only`` branch —
    the training loop appears to live elsewhere (or was removed).
    """
    opts = get_argparser().parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    img_h = 512
    img_w = 512
    torch.cuda.empty_cache()
    train_data, test_data = get_data_dcm(img_h=img_h, img_w=img_w,
                                         iscrop=True)
    train_loader = data.DataLoader(train_data,
                                   batch_size=opts.batch_size,
                                   shuffle=True,
                                   num_workers=0)
    val_loader = data.DataLoader(test_data,
                                 batch_size=opts.val_batch_size,
                                 shuffle=False,
                                 num_workers=0)

    model_map = {
        'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
    }
    # BUGFIX: was `if base_model != 'unet':` — `base_model` is never defined
    # in this function, so this line always raised NameError. The UNet
    # else-branch makes the intent clear: branch on the chosen model name.
    if opts.model != 'unet':
        opts.num_classes = 3
        model = model_map[opts.model](num_classes=opts.num_classes,
                                      output_stride=opts.output_stride)
        if opts.separable_conv and 'plus' in opts.model:
            network.convert_to_separable_conv(model.classifier)
        utils.set_bn_momentum(model.backbone, momentum=0.01)
        # Set up optimizer — backbone fine-tunes at 1/10 of the head lr.
        optimizer = torch.optim.SGD(params=[
            {
                'params': model.backbone.parameters(),
                'lr': 0.1 * opts.lr
            },
            {
                'params': model.classifier.parameters(),
                'lr': opts.lr
            },
        ],
                                    lr=opts.lr,
                                    momentum=0.9,
                                    weight_decay=opts.weight_decay)
    else:
        opts.num_classes = 3
        model = UNet(n_channels=3, n_classes=3, bilinear=True)
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=opts.lr,
                                    momentum=0.9,
                                    weight_decay=opts.weight_decay)

    scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    # NOTE(review): criterion_bce is never used in the visible body.
    criterion_bce = nn.BCELoss(reduction='mean')
    criterion_dice = MulticlassDiceLoss()

    def save_ckpt(path):
        """Save model/optimizer/scheduler state and progress counters."""
        torch.save(
            {
                "cur_itrs": cur_itrs,
                "model_state": model.module.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "best_score": best_score,
            }, path)
        print("Model saved as %s" % path)

    utils.mkdir('checkpoints')
    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    if opts.test_only:
        model.eval()
        dice_bl, dice_sb, dice_st, acc = validate2(
            model=model,
            loader=val_loader,
            device=device,
            itrs=cur_itrs,
            lr=scheduler.get_lr()[-1],
            criterion_dice=criterion_dice)
        print("dice值:", dice_bl)
        return
def main():
    """Build a DeepLab model + optimizer/scheduler and hand them to the
    project `trainer` wrapper for training or checkpoint validation.

    Raises:
        ValueError: if ``opts.lr_policy`` is neither 'poly' nor 'step'.
    """
    opts = parser.parse_args()
    # Setup visualization
    vis = Visualizer(port=opts.vis_port,
                     env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))
    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Set up model
    model_map = {
        'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
    }
    model = model_map[opts.model](num_classes=opts.num_classes,
                                  output_stride=opts.output_stride)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    # Set up metrics
    # NOTE(review): `metrics` is not passed to the trainer — vestigial?
    metrics = StreamSegMetrics(opts.num_classes)

    # Set up optimizer — backbone fine-tunes at 1/10 of the head lr.
    optimizer = torch.optim.SGD(params=[
        {'params': model.backbone.parameters(), 'lr': 0.1 * opts.lr},
        {'params': model.classifier.parameters(), 'lr': opts.lr},
    ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=opts.step_size,
                                                    gamma=0.1)
    else:
        # BUGFIX: previously printed "please assign a scheduler!" and
        # continued with scheduler=None, deferring the failure into the
        # trainer as a cryptic NoneType error. Fail fast instead.
        raise ValueError(
            "please assign a scheduler! unsupported lr_policy: %r" %
            opts.lr_policy)

    utils.mkdir('checkpoints')
    mytrainer = trainer(model, optimizer, scheduler, device, cfg=opts)

    # ========== Train Loop ==========#
    loss_list = ['focal']
    if opts.test_only:
        loss_i = 'v3'
        ckpt = os.path.join(
            "checkpoints", loss_i,
            "latest_deeplabv3plus_mobilenet_coco_epoch01.pth")
        mytrainer.validate(ckpt, loss_i)
    else:
        for loss_i in loss_list:
            mytrainer.train(loss_i)