labels = labels.to(device, dtype=torch.long) labels = labels.squeeze(1) # get loss optimizer.zero_grad() outputs = model(images) loss = criterion(outputs, labels) loss.backward() optimizer.step() train_loss += loss.item() # metrics preds = outputs.detach().max(dim=1)[1].cpu().numpy() targets = labels.cpu().numpy() metrics.update(targets, preds) end = time.time() if step%10==0: print('Epoch: ',str(epoch),' Iter: ',step,'Loss: ',loss.item(),) print('iter time: ',end-start) # update training_loss, training_accuracy and training_iou train_loss = train_loss/float(len(train_loader)) train_loss_list.append(train_loss) results = metrics.get_results() train_iou = results["Mean IoU"] train_iou_list.append(train_iou)
def main():
    """Evaluate a DeepLabv3 model on the validation split.

    Parses command-line options, builds the validation loader and the model,
    restores weights from ``opts.ckpt`` when that file exists, runs inference
    over the validation set while accumulating ``StreamSegMetrics``, and
    prints the final scores. When ``opts.save_path`` is set, it additionally
    writes, per sample, ``<idx>_image.png``, ``<idx>_target.png``,
    ``<idx>_pred.png`` and ``<idx>_overlay.png``, plus a ``score.txt``
    summary, into that directory.
    """
    opts = get_argparser().parse_args()
    opts = modify_command_options(opts)

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Set up random seed
    torch.manual_seed(opts.random_seed)
    torch.cuda.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Set up dataloader: only the validation split is used here.
    # Batching is only enabled when crop_val is set -- presumably because
    # un-cropped validation images can differ in size (TODO confirm).
    _, val_dst = get_dataset(opts)
    val_loader = data.DataLoader(
        val_dst,
        batch_size=opts.batch_size if opts.crop_val else 1,
        shuffle=False,
        num_workers=opts.num_workers,
    )
    print("Dataset: %s, Val set: %d" % (opts.dataset, len(val_dst)))

    # Set up model
    print("Backbone: %s" % opts.backbone)
    model = DeepLabv3(num_classes=opts.num_classes, backbone=opts.backbone, pretrained=True,
                      momentum=opts.bn_mom, output_stride=opts.output_stride,
                      use_separable_conv=opts.use_separable_conv)
    if opts.use_gn:
        print("[!] Replace BatchNorm with GroupNorm!")
        model = utils.convert_bn2gn(model)

    if torch.cuda.device_count() > 1:  # Parallel
        print("%d GPU parallel" % (torch.cuda.device_count()))
        model = torch.nn.DataParallel(model)
        model_ref = model.module  # unwrapped model, used for checkpoint loading
    else:
        model_ref = model
    model = model.to(device)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    if opts.save_path is not None:
        utils.mkdir(opts.save_path)

    # Restore
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        # Deserialize onto the CPU so a checkpoint saved on GPU still loads
        # when `device` above fell back to CPU; load_state_dict copies the
        # tensors into the model's parameters on whatever device they live.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model_ref.load_state_dict(checkpoint["model_state"])
        print("Model restored from %s" % opts.ckpt)
    else:
        # NOTE(review): nothing is trained in this function; without a
        # checkpoint the freshly constructed model is evaluated as-is.
        print("[!] Retrain")

    label2color = utils.Label2Color(cmap=utils.color_map(opts.dataset))  # convert labels to images
    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])  # denormalization for ori images

    model.eval()
    metrics.reset()
    idx = 0  # running counter used to name the saved per-sample files

    if opts.save_path is not None:
        # Non-interactive backend: figures are only rendered to files.
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt

    with torch.no_grad():
        # Wrap the loader itself (not enumerate(loader)) so tqdm can read
        # its length and display a total / ETA.
        for images, labels in tqdm(val_loader):
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            outputs = model(images)
            # Predicted class per pixel = argmax over dim 1 of the output.
            preds = outputs.detach().max(dim=1)[1].cpu().numpy()
            targets = labels.cpu().numpy()

            metrics.update(targets, preds)

            if opts.save_path is not None:
                for image_t, target, pred in zip(images, targets, preds):
                    image = image_t.detach().cpu().numpy()
                    # Undo input normalization and convert CHW float -> HWC uint8.
                    image = (denorm(image) * 255).transpose(1, 2, 0).astype(np.uint8)
                    target = label2color(target).astype(np.uint8)
                    pred = label2color(pred).astype(np.uint8)

                    Image.fromarray(image).save(os.path.join(opts.save_path, '%d_image.png' % idx))
                    Image.fromarray(target).save(os.path.join(opts.save_path, '%d_target.png' % idx))
                    Image.fromarray(pred).save(os.path.join(opts.save_path, '%d_pred.png' % idx))

                    # Overlay: colourised prediction at alpha 0.7 on top of the
                    # input image, saved without axes, ticks or padding.
                    fig = plt.figure()
                    plt.imshow(image)
                    plt.axis('off')
                    plt.imshow(pred, alpha=0.7)
                    ax = plt.gca()
                    ax.xaxis.set_major_locator(matplotlib.ticker.NullLocator())
                    ax.yaxis.set_major_locator(matplotlib.ticker.NullLocator())
                    plt.savefig(os.path.join(opts.save_path, '%d_overlay.png' % idx),
                                bbox_inches='tight', pad_inches=0)
                    plt.close(fig)  # release the figure explicitly to avoid accumulating them
                    idx += 1

    score = metrics.get_results()
    report = metrics.to_str(score)
    print(report)
    if opts.save_path is not None:
        with open(os.path.join(opts.save_path, 'score.txt'), mode='w') as f:
            f.write(report)