labels = labels.to(device, dtype=torch.long)
    labels = labels.squeeze(1)
    
    # get loss
    optimizer.zero_grad()
    outputs = model(images)

    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    train_loss += loss.item()

    # metrics
    preds = outputs.detach().max(dim=1)[1].cpu().numpy()
    targets = labels.cpu().numpy()
    metrics.update(targets, preds)

    end = time.time()

    if step%10==0:
      print('Epoch: ',str(epoch),' Iter: ',step,'Loss: ',loss.item(),)

    print('iter time: ',end-start)  

  # update training_loss, training_accuracy and training_iou 
  train_loss = train_loss/float(len(train_loader))
  train_loss_list.append(train_loss)
  results = metrics.get_results()
  train_iou = results["Mean IoU"]
  train_iou_list.append(train_iou)
Example #2
0
def main():
    """Evaluate a DeepLabv3 model on the validation split.

    Parses command-line options, builds the validation loader and the model,
    optionally restores weights from ``opts.ckpt``, runs inference over the
    whole validation set and prints the streaming segmentation metrics.

    When ``opts.save_path`` is set, each sample additionally produces
    ``<idx>_image.png``, ``<idx>_target.png``, ``<idx>_pred.png`` and
    ``<idx>_overlay.png`` in that directory, and the final metrics are
    written to ``score.txt``.
    """
    opts = get_argparser().parse_args()
    opts = modify_command_options(opts)

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Seed every RNG source used so evaluation is reproducible.
    torch.manual_seed(opts.random_seed)
    torch.cuda.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Set up dataloader (only the validation split is used here).
    _, val_dst = get_dataset(opts)
    # Uncropped validation images are loaded one at a time; presumably
    # because full-size images may differ in shape and cannot be batched.
    val_loader = data.DataLoader(
        val_dst,
        batch_size=opts.batch_size if opts.crop_val else 1,
        shuffle=False,
        num_workers=opts.num_workers,
    )
    print("Dataset: %s, Val set: %d" % (opts.dataset, len(val_dst)))

    # Set up model
    print("Backbone: %s" % opts.backbone)
    model = DeepLabv3(num_classes=opts.num_classes, backbone=opts.backbone,
                      pretrained=True, momentum=opts.bn_mom,
                      output_stride=opts.output_stride,
                      use_separable_conv=opts.use_separable_conv)
    if opts.use_gn:
        print("[!] Replace BatchNorm with GroupNorm!")
        model = utils.convert_bn2gn(model)

    # model_ref always points at the unwrapped module, so checkpoint
    # state_dicts are loaded without DataParallel's "module." key prefix.
    if torch.cuda.device_count() > 1:  # Parallel
        print("%d GPU parallel" % (torch.cuda.device_count()))
        model = torch.nn.DataParallel(model)
        model_ref = model.module
    else:
        model_ref = model
    model = model.to(device)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    if opts.save_path is not None:
        utils.mkdir(opts.save_path)

    # Restore
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        # Deserialize onto the CPU first: a checkpoint saved on GPU would
        # otherwise fail to load on a CPU-only host (this script falls back
        # to CPU above). load_state_dict then copies the tensors into the
        # model parameters, which already live on `device`.
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model_ref.load_state_dict(checkpoint["model_state"])
        print("Model restored from %s" % opts.ckpt)
    else:
        # No checkpoint: the freshly constructed model is evaluated as-is.
        print("[!] Retrain")

    label2color = utils.Label2Color(cmap=utils.color_map(opts.dataset))  # convert labels to images
    # Mean/std match the standard ImageNet statistics; presumably the same
    # values the dataset transform normalizes with -- TODO confirm.
    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])  # denormalization for ori images
    model.eval()
    metrics.reset()
    idx = 0  # running sample counter used to name the saved files

    if opts.save_path is not None:
        import matplotlib
        matplotlib.use('Agg')  # non-interactive backend: render straight to files
        import matplotlib.pyplot as plt

    with torch.no_grad():
        # Iterating the DataLoader directly lets tqdm pick up len(val_loader).
        for images, labels in tqdm(val_loader):
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            outputs = model(images)
            # argmax over dim 1 -> per-pixel predicted class index
            preds = outputs.detach().max(dim=1)[1].cpu().numpy()
            targets = labels.cpu().numpy()

            metrics.update(targets, preds)
            if opts.save_path is None:
                continue

            # One device->host copy per batch instead of one per sample.
            images_np = images.detach().cpu().numpy()
            for image, target, pred in zip(images_np, targets, preds):
                image_rgb = (denorm(image) * 255).transpose(1, 2, 0).astype(np.uint8)
                target_rgb = label2color(target).astype(np.uint8)
                pred_rgb = label2color(pred).astype(np.uint8)

                Image.fromarray(image_rgb).save(os.path.join(opts.save_path, '%d_image.png' % idx))
                Image.fromarray(target_rgb).save(os.path.join(opts.save_path, '%d_target.png' % idx))
                Image.fromarray(pred_rgb).save(os.path.join(opts.save_path, '%d_pred.png' % idx))

                # Prediction drawn semi-transparently over the input image.
                fig = plt.figure()
                plt.imshow(image_rgb)
                plt.axis('off')
                plt.imshow(pred_rgb, alpha=0.7)
                ax = plt.gca()
                ax.xaxis.set_major_locator(matplotlib.ticker.NullLocator())
                ax.yaxis.set_major_locator(matplotlib.ticker.NullLocator())
                plt.savefig(os.path.join(opts.save_path, '%d_overlay.png' % idx),
                            bbox_inches='tight', pad_inches=0)
                plt.close(fig)  # release the figure; avoids accumulating open figures
                idx += 1

    score = metrics.get_results()
    score_str = metrics.to_str(score)
    print(score_str)
    if opts.save_path is not None:
        with open(os.path.join(opts.save_path, 'score.txt'), mode='w') as f:
            f.write(score_str)