def evaluate(arch, dataset, ignore_label, restore_from, pretrained_model, save_dir, device):
    """Evaluate a trained segmentation network on the validation set.

    Loads a model snapshot (from a local path, URL, or one of the released
    pre-trained checkpoints), runs it over the validation split, writes a
    colourised PNG prediction per image into `save_dir`, and reports
    per-class / mean IoU (also saved to `save_dir/result.txt`).

    :param arch: architecture name: 'deeplab2', 'unet_resnet50' or 'resnet101_deeplabv3'
    :param dataset: dataset name; only 'pascal_aug' is supported
    :param ignore_label: label value treated as void/ignore during scoring
    :param restore_from: path or URL of the checkpoint to load (ignored when
        `pretrained_model` is given)
    :param pretrained_model: key into the released-checkpoint URL table, or None
    :param save_dir: output directory for predictions and the result file
    :param device: torch device string, e.g. 'cuda:0' or 'cpu'
    """
    import argparse
    import scipy
    from scipy import ndimage
    import cv2
    import numpy as np
    import sys
    from collections import OrderedDict
    import os
    import torch
    import torch.nn as nn
    from torch.autograd import Variable
    import torchvision.models as models
    import torch.nn.functional as F
    from torch.utils import data, model_zoo
    from model.deeplab import Res_Deeplab
    from model.unet import unet_resnet50
    from model.deeplabv3 import resnet101_deeplabv3
    from dataset.voc_dataset import VOCDataSet
    from PIL import Image
    import matplotlib.pyplot as plt

    # Released checkpoints from the adv-semi-seg paper, keyed by training setup.
    # (fixed typo: was `pretrianed_models_dict`)
    pretrained_models_dict = {
        'semi0.125': 'http://vllab1.ucmerced.edu/~whung/adv-semi-seg/AdvSemiSegVOC0.125-03c6f81c.pth',
        'semi0.25': 'http://vllab1.ucmerced.edu/~whung/adv-semi-seg/AdvSemiSegVOC0.25-473f8a14.pth',
        'semi0.5': 'http://vllab1.ucmerced.edu/~whung/adv-semi-seg/AdvSemiSegVOC0.5-acf6a654.pth',
        'advFull': 'http://vllab1.ucmerced.edu/~whung/adv-semi-seg/AdvSegVOCFull-92fbc7ee.pth'
    }

    class VOCColorize(object):
        """Maps a 2-D label image to a (3, H, W) uint8 RGB image using the VOC palette."""

        def __init__(self, n=22):
            self.cmap = color_map(22)
            self.cmap = torch.from_numpy(self.cmap[:n])

        def __call__(self, gray_image):
            size = gray_image.shape
            color_image = np.zeros((3, size[0], size[1]), dtype=np.uint8)

            for label in range(0, len(self.cmap)):
                mask = (label == gray_image)
                color_image[0][mask] = self.cmap[label][0]
                color_image[1][mask] = self.cmap[label][1]
                color_image[2][mask] = self.cmap[label][2]

            # handle void: label 255 is rendered white
            mask = (255 == gray_image)
            color_image[0][mask] = color_image[1][mask] = color_image[2][mask] = 255

            return color_image

    def color_map(N=256, normalized=False):
        """Build the standard PASCAL VOC colour map as an (N, 3) array.

        Each label's RGB triple is derived from the bits of its index
        (the canonical VOC devkit algorithm). When `normalized`, values
        are floats in [0, 1]; otherwise uint8 in [0, 255].
        """
        def bitget(byteval, idx):
            return ((byteval & (1 << idx)) != 0)

        dtype = 'float32' if normalized else 'uint8'
        cmap = np.zeros((N, 3), dtype=dtype)
        for i in range(N):
            r = g = b = 0
            c = i
            for j in range(8):
                r = r | (bitget(c, 0) << 7 - j)
                g = g | (bitget(c, 1) << 7 - j)
                b = b | (bitget(c, 2) << 7 - j)
                c = c >> 3
            cmap[i] = np.array([r, g, b])

        cmap = cmap / 255 if normalized else cmap
        return cmap

    def get_iou(data_list, class_num, ignore_label, class_names, save_path=None):
        """Score (truth, prediction) pairs, print per-class and mean IoU,
        and optionally write the same report to `save_path`."""
        # (removed unused `from multiprocessing import Pool`)
        from utils.evaluation import EvaluatorIoU
        evaluator = EvaluatorIoU(class_num)
        for truth, prediction in data_list:
            evaluator.sample(truth, prediction, ignore_value=ignore_label)

        per_class_iou = evaluator.score()
        mean_iou = per_class_iou.mean()
        for i, (class_name, iou) in enumerate(zip(class_names, per_class_iou)):
            print('class {:2d} {:12} IU {:.2f}'.format(i, class_name, iou))

        print('meanIOU: ' + str(mean_iou) + '\n')
        if save_path:
            with open(save_path, 'w') as f:
                for i, (class_name, iou) in enumerate(zip(class_names, per_class_iou)):
                    f.write('class {:2d} {:12} IU {:.2f}'.format(i, class_name, iou) + '\n')
                f.write('meanIOU: ' + str(mean_iou) + '\n')

    def show_all(gt, pred):
        """Debug helper: show ground truth and prediction side by side with the VOC palette."""
        import matplotlib.pyplot as plt
        from matplotlib import colors
        from mpl_toolkits.axes_grid1 import make_axes_locatable

        fig, axes = plt.subplots(1, 2)
        ax1, ax2 = axes

        colormap = [(0, 0, 0), (0.5, 0, 0), (0, 0.5, 0), (0.5, 0.5, 0),
                    (0, 0, 0.5), (0.5, 0, 0.5), (0, 0.5, 0.5), (0.5, 0.5, 0.5),
                    (0.25, 0, 0), (0.75, 0, 0), (0.25, 0.5, 0), (0.75, 0.5, 0),
                    (0.25, 0, 0.5), (0.75, 0, 0.5), (0.25, 0.5, 0.5),
                    (0.75, 0.5, 0.5), (0, 0.25, 0), (0.5, 0.25, 0), (0, 0.75, 0),
                    (0.5, 0.75, 0), (0, 0.25, 0.5)]
        cmap = colors.ListedColormap(colormap)
        bounds = [
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
            19, 20, 21
        ]
        norm = colors.BoundaryNorm(bounds, cmap.N)

        ax1.set_title('gt')
        ax1.imshow(gt, cmap=cmap, norm=norm)

        ax2.set_title('pred')
        ax2.imshow(pred, cmap=cmap, norm=norm)

        plt.show()

    torch_device = torch.device(device)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    if dataset == 'pascal_aug':
        ds = VOCDataSet()
    else:
        print('Dataset {} not yet supported'.format(dataset))
        return

    if arch == 'deeplab2':
        model = Res_Deeplab(num_classes=ds.num_classes)
    elif arch == 'unet_resnet50':
        model = unet_resnet50(num_classes=ds.num_classes)
    elif arch == 'resnet101_deeplabv3':
        model = resnet101_deeplabv3(num_classes=ds.num_classes)
    else:
        print('Architecture {} not supported'.format(arch))
        return

    if pretrained_model is not None:
        restore_from = pretrained_models_dict[pretrained_model]

    if restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(restore_from)
    else:
        saved_state_dict = torch.load(restore_from)
    model.load_state_dict(saved_state_dict)

    model.eval()
    model = model.to(torch_device)

    # FIX: pass range01 so validation preprocessing matches the model's expected
    # input range, consistent with how `train` builds its `ds.val_xy`.
    ds_val_xy = ds.val_xy(crop_size=(505, 505), scale=False, mirror=False,
                          range01=model.RANGE01, mean=model.MEAN, std=model.STD)

    testloader = data.DataLoader(ds_val_xy, batch_size=1, shuffle=False, pin_memory=True)

    data_list = []

    colorize = VOCColorize()

    with torch.no_grad():
        for index, batch in enumerate(testloader):
            if index % 100 == 0:
                print('%d processd' % (index))
            image, label, size, name = batch
            size = size[0].numpy()
            # FIX: `image` is already a tensor; `torch.tensor(image, ...)` copies
            # and warns — move/cast instead (matches `train`).
            image = image.float().to(torch_device)
            output = model(image)
            output = output.cpu().data[0].numpy()

            # crop away the padding added to reach the fixed crop size
            output = output[:, :size[0], :size[1]]
            # FIX: `np.int` alias was removed from NumPy; use np.int64
            gt = np.asarray(label[0].numpy()[:size[0], :size[1]], dtype=np.int64)

            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.int64)

            filename = os.path.join(save_dir, '{}.png'.format(name[0]))
            color_file = Image.fromarray(colorize(output).transpose(1, 2, 0), 'RGB')
            color_file.save(filename)

            # show_all(gt, output)
            data_list.append([gt.flatten(), output.flatten()])

    filename = os.path.join(save_dir, 'result.txt')
    get_iou(data_list, ds.num_classes, ignore_label, ds.class_names, filename)
def train(log_file, arch, dataset, batch_size, iter_size, num_workers,
          partial_data, partial_data_size, partial_id, ignore_label, crop_size,
          eval_crop_size, is_training, learning_rate, learning_rate_d,
          supervised, lambda_adv_pred, lambda_semi, lambda_semi_adv, mask_t,
          semi_start, semi_start_adv, d_remain, momentum, not_restore_last,
          num_steps, power, random_mirror, random_scale, random_seed,
          restore_from, restore_from_d, eval_every, save_snapshot_every,
          snapshot_dir, weight_decay, device):
    """Train a segmentation network, optionally with adversarial / semi-supervised
    losses driven by a fully-convolutional discriminator (adv-semi-seg).

    When `supervised` is False, the generator additionally receives an
    adversarial loss from the discriminator, and — once `lambda_semi` /
    `lambda_semi_adv` kick in after `semi_start` / `semi_start_adv` — a
    self-training loss on the unlabelled portion of the data. Snapshots and
    periodic IoU evaluations are written as configured.
    """
    # Must be the first statement: capture the raw arguments for logging.
    settings = locals().copy()

    import cv2
    import torch
    import torch.nn as nn
    from torch.utils import data, model_zoo
    import numpy as np
    import pickle
    import torch.optim as optim
    import torch.nn.functional as F
    import scipy.misc
    import sys
    import os
    import os.path as osp
    from model.deeplab import Res_Deeplab
    from model.unet import unet_resnet50
    from model.deeplabv3 import resnet101_deeplabv3
    from model.discriminator import FCDiscriminator
    from utils.loss import CrossEntropy2d, BCEWithLogitsLoss2d
    from utils.evaluation import EvaluatorIoU
    from dataset.voc_dataset import VOCDataSet
    import logger

    torch_device = torch.device(device)

    import time

    if log_file != '' and log_file != 'none':
        if os.path.exists(log_file):
            print('Log file {} already exists; exiting...'.format(log_file))
            return

    with logger.LogFile(log_file if log_file != 'none' else None):
        if dataset == 'pascal_aug':
            ds = VOCDataSet(augmented_pascal=True)
        elif dataset == 'pascal':
            ds = VOCDataSet(augmented_pascal=False)
        else:
            print('Dataset {} not yet supported'.format(dataset))
            return

        print('Command: {}'.format(sys.argv[0]))
        print('Arguments: {}'.format(' '.join(sys.argv[1:])))
        print('Settings: {}'.format(', '.join([
            '{}={}'.format(k, settings[k])
            for k in sorted(list(settings.keys()))
        ])))

        print('Loaded data')

        def loss_calc(pred, label):
            """
            This function returns cross entropy loss for semantic segmentation
            """
            # out shape batch_size x channels x h x w -> batch_size x channels x h x w
            # label shape h x w x 1 x batch_size -> batch_size x 1 x h x w
            label = label.long().to(torch_device)
            criterion = CrossEntropy2d()
            return criterion(pred, label)

        def lr_poly(base_lr, iter, max_iter, power):
            # Polynomial learning-rate decay schedule.
            return base_lr * ((1 - float(iter) / max_iter)**(power))

        def adjust_learning_rate(optimizer, i_iter):
            # Segmentation net: second param group (if any) gets 10x the base LR.
            lr = lr_poly(learning_rate, i_iter, num_steps, power)
            optimizer.param_groups[0]['lr'] = lr
            if len(optimizer.param_groups) > 1:
                optimizer.param_groups[1]['lr'] = lr * 10

        def adjust_learning_rate_D(optimizer, i_iter):
            # Discriminator: same schedule with its own base LR.
            lr = lr_poly(learning_rate_d, i_iter, num_steps, power)
            optimizer.param_groups[0]['lr'] = lr
            if len(optimizer.param_groups) > 1:
                optimizer.param_groups[1]['lr'] = lr * 10

        def one_hot(label):
            # (B, H, W) int labels -> (B, C, H, W) one-hot float tensor on the device.
            label = label.numpy()
            one_hot = np.zeros((label.shape[0], ds.num_classes, label.shape[1],
                                label.shape[2]), dtype=label.dtype)
            for i in range(ds.num_classes):
                one_hot[:, i, ...] = (label == i)
            # handle ignore labels
            return torch.tensor(one_hot, dtype=torch.float, device=torch_device)

        def make_D_label(label, ignore_mask):
            # Constant real/fake target map for the discriminator; ignored
            # pixels are set to `ignore_label` so the BCE loss skips them.
            ignore_mask = np.expand_dims(ignore_mask, axis=1)
            D_label = np.ones(ignore_mask.shape) * label
            D_label[ignore_mask] = ignore_label
            D_label = torch.tensor(D_label, dtype=torch.float, device=torch_device)
            return D_label

        # Crop sizes arrive as 'H,W' strings; parse to tuples.
        h, w = map(int, eval_crop_size.split(','))
        eval_crop_size = (h, w)
        h, w = map(int, crop_size.split(','))
        crop_size = (h, w)

        # create network
        if arch == 'deeplab2':
            model = Res_Deeplab(num_classes=ds.num_classes)
        elif arch == 'unet_resnet50':
            model = unet_resnet50(num_classes=ds.num_classes)
        elif arch == 'resnet101_deeplabv3':
            model = resnet101_deeplabv3(num_classes=ds.num_classes)
        else:
            print('Architecture {} not supported'.format(arch))
            return

        # load pretrained parameters
        if restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(restore_from)
        else:
            saved_state_dict = torch.load(restore_from)

        # only copy the params that exist in current model (caffe-like)
        new_params = model.state_dict().copy()
        for name, param in new_params.items():
            if name in saved_state_dict and param.size() == saved_state_dict[name].size():
                new_params[name].copy_(saved_state_dict[name])
        model.load_state_dict(new_params)

        model.train()
        model = model.to(torch_device)

        # init D
        model_D = FCDiscriminator(num_classes=ds.num_classes)
        if restore_from_d is not None:
            model_D.load_state_dict(torch.load(restore_from_d))
        model_D.train()
        model_D = model_D.to(torch_device)

        print('Built model')

        if snapshot_dir is not None:
            if not os.path.exists(snapshot_dir):
                os.makedirs(snapshot_dir)

        ds_train_xy = ds.train_xy(crop_size=crop_size, scale=random_scale,
                                  mirror=random_mirror, range01=model.RANGE01,
                                  mean=model.MEAN, std=model.STD)
        ds_train_y = ds.train_y(crop_size=crop_size, scale=random_scale,
                                mirror=random_mirror, range01=model.RANGE01,
                                mean=model.MEAN, std=model.STD)
        ds_val_xy = ds.val_xy(crop_size=eval_crop_size, scale=False,
                              mirror=False, range01=model.RANGE01,
                              mean=model.MEAN, std=model.STD)

        train_dataset_size = len(ds_train_xy)

        if partial_data_size != -1:
            # FIX: the original compared `partial_data_size > partial_data_size`
            # (always False), so an oversized request was never caught.
            if partial_data_size > train_dataset_size:
                print('partial-data-size > |train|: exiting')
                return

        if partial_data == 1.0 and (partial_data_size == -1
                                    or partial_data_size == train_dataset_size):
            # Fully-supervised on the whole training set: no unlabelled remainder.
            trainloader = data.DataLoader(ds_train_xy, batch_size=batch_size,
                                          shuffle=True, num_workers=5,
                                          pin_memory=True)
            trainloader_gt = data.DataLoader(ds_train_y, batch_size=batch_size,
                                            shuffle=True, num_workers=5,
                                            pin_memory=True)
            trainloader_remain = None
            print('|train|={}'.format(train_dataset_size))
            print('|val|={}'.format(len(ds_val_xy)))
        else:
            # sample partial data
            if partial_data_size != -1:
                partial_size = partial_data_size
            else:
                partial_size = int(partial_data * train_dataset_size)

            if partial_id is not None:
                # FIX: pickle files must be opened in binary mode; also close the
                # handle deterministically.
                with open(partial_id, 'rb') as f:
                    train_ids = pickle.load(f)
                print('loading train ids from {}'.format(partial_id))
            else:
                rng = np.random.RandomState(random_seed)
                train_ids = list(rng.permutation(train_dataset_size))

            if snapshot_dir is not None:
                # FIX: use a context manager so the dump is flushed and closed.
                with open(osp.join(snapshot_dir, 'train_id.pkl'), 'wb') as f:
                    pickle.dump(train_ids, f)

            print('|train supervised|={}'.format(partial_size))
            print('|train unsupervised|={}'.format(train_dataset_size - partial_size))
            print('|val|={}'.format(len(ds_val_xy)))
            print('supervised={}'.format(list(train_ids[:partial_size])))

            train_sampler = data.sampler.SubsetRandomSampler(
                train_ids[:partial_size])
            train_remain_sampler = data.sampler.SubsetRandomSampler(
                train_ids[partial_size:])
            train_gt_sampler = data.sampler.SubsetRandomSampler(
                train_ids[:partial_size])

            trainloader = data.DataLoader(ds_train_xy, batch_size=batch_size,
                                          sampler=train_sampler, num_workers=3,
                                          pin_memory=True)
            trainloader_remain = data.DataLoader(ds_train_xy,
                                                 batch_size=batch_size,
                                                 sampler=train_remain_sampler,
                                                 num_workers=3,
                                                 pin_memory=True)
            trainloader_gt = data.DataLoader(ds_train_y, batch_size=batch_size,
                                             sampler=train_gt_sampler,
                                             num_workers=3, pin_memory=True)

            trainloader_remain_iter = enumerate(trainloader_remain)

        testloader = data.DataLoader(ds_val_xy, batch_size=1, shuffle=False,
                                     pin_memory=True)

        print('Data loaders ready')

        trainloader_iter = enumerate(trainloader)
        trainloader_gt_iter = enumerate(trainloader_gt)

        # implement model.optim_parameters(args) to handle different models' lr setting

        # optimizer for segmentation network
        optimizer = optim.SGD(model.optim_parameters(learning_rate),
                              lr=learning_rate, momentum=momentum,
                              weight_decay=weight_decay)
        optimizer.zero_grad()

        # optimizer for discriminator network
        optimizer_D = optim.Adam(model_D.parameters(), lr=learning_rate_d,
                                 betas=(0.9, 0.99))
        optimizer_D.zero_grad()

        # loss/ bilinear upsampling
        bce_loss = BCEWithLogitsLoss2d()

        print('Built optimizer')

        # labels for adversarial training
        pred_label = 0
        gt_label = 1

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_mask_accum = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        t1 = time.time()

        print('Training for {} steps...'.format(num_steps))
        for i_iter in range(num_steps + 1):
            model.train()
            model.freeze_batchnorm()

            optimizer.zero_grad()
            adjust_learning_rate(optimizer, i_iter)
            optimizer_D.zero_grad()
            adjust_learning_rate_D(optimizer_D, i_iter)

            for sub_i in range(iter_size):
                # train G
                if not supervised:
                    # don't accumulate grads in D
                    for param in model_D.parameters():
                        param.requires_grad = False

                # do semi first
                if not supervised and (lambda_semi > 0 or lambda_semi_adv > 0) \
                        and i_iter >= semi_start_adv and trainloader_remain is not None:
                    try:
                        _, batch = next(trainloader_remain_iter)
                    # FIX: bare `except:` replaced with the exception actually
                    # raised when the loader is exhausted.
                    except StopIteration:
                        trainloader_remain_iter = enumerate(trainloader_remain)
                        _, batch = next(trainloader_remain_iter)

                    # only access to img
                    images, _, _, _ = batch
                    images = images.float().to(torch_device)

                    pred = model(images)
                    pred_remain = pred.detach()

                    D_out = model_D(F.softmax(pred, dim=1))
                    # FIX: F.sigmoid is deprecated/removed; torch.sigmoid is the drop-in.
                    D_out_sigmoid = torch.sigmoid(D_out).data.cpu().numpy().squeeze(axis=1)

                    # FIX: np.bool alias was removed from NumPy; use builtin bool.
                    ignore_mask_remain = np.zeros(D_out_sigmoid.shape).astype(bool)

                    loss_semi_adv = lambda_semi_adv * bce_loss(
                        D_out, make_D_label(gt_label, ignore_mask_remain))
                    loss_semi_adv = loss_semi_adv / iter_size

                    #loss_semi_adv.backward()
                    loss_semi_adv_value += float(loss_semi_adv) / lambda_semi_adv

                    if lambda_semi <= 0 or i_iter < semi_start:
                        loss_semi_adv.backward()
                        loss_semi_value = 0
                    else:
                        # produce ignore mask: low discriminator confidence pixels
                        # are excluded from the self-training target.
                        semi_ignore_mask = (D_out_sigmoid < mask_t)

                        semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                        semi_gt[semi_ignore_mask] = ignore_label

                        semi_ratio = 1.0 - float(semi_ignore_mask.sum()) / semi_ignore_mask.size
                        loss_semi_mask_accum += float(semi_ratio)

                        if semi_ratio == 0.0:
                            loss_semi_value += 0
                        else:
                            semi_gt = torch.FloatTensor(semi_gt)

                            loss_semi = lambda_semi * loss_calc(pred, semi_gt)
                            loss_semi = loss_semi / iter_size
                            loss_semi_value += float(loss_semi) / lambda_semi
                            loss_semi += loss_semi_adv
                            loss_semi.backward()
                else:
                    loss_semi = None
                    loss_semi_adv = None

                # train with source
                try:
                    _, batch = next(trainloader_iter)
                except StopIteration:
                    trainloader_iter = enumerate(trainloader)
                    _, batch = next(trainloader_iter)

                images, labels, _, _ = batch
                images = images.float().to(torch_device)
                ignore_mask = (labels.numpy() == ignore_label)
                pred = model(images)

                loss_seg = loss_calc(pred, labels)

                if supervised:
                    loss = loss_seg
                else:
                    D_out = model_D(F.softmax(pred, dim=1))

                    loss_adv_pred = bce_loss(
                        D_out, make_D_label(gt_label, ignore_mask))

                    loss = loss_seg + lambda_adv_pred * loss_adv_pred
                    loss_adv_pred_value += float(loss_adv_pred) / iter_size

                # proper normalization
                loss = loss / iter_size
                loss.backward()
                loss_seg_value += float(loss_seg) / iter_size

                if not supervised:
                    # train D

                    # bring back requires_grad
                    for param in model_D.parameters():
                        param.requires_grad = True

                    # train with pred
                    pred = pred.detach()

                    if d_remain:
                        pred = torch.cat((pred, pred_remain), 0)
                        ignore_mask = np.concatenate(
                            (ignore_mask, ignore_mask_remain), axis=0)

                    D_out = model_D(F.softmax(pred, dim=1))
                    loss_D = bce_loss(D_out,
                                      make_D_label(pred_label, ignore_mask))
                    loss_D = loss_D / iter_size / 2
                    loss_D.backward()
                    loss_D_value += float(loss_D)

                    # train with gt
                    # get gt labels
                    try:
                        _, batch = next(trainloader_gt_iter)
                    except StopIteration:
                        trainloader_gt_iter = enumerate(trainloader_gt)
                        _, batch = next(trainloader_gt_iter)

                    _, labels_gt, _, _ = batch
                    D_gt_v = one_hot(labels_gt)
                    ignore_mask_gt = (labels_gt.numpy() == ignore_label)

                    D_out = model_D(D_gt_v)
                    loss_D = bce_loss(D_out,
                                      make_D_label(gt_label, ignore_mask_gt))
                    loss_D = loss_D / iter_size / 2
                    loss_D.backward()
                    loss_D_value += float(loss_D)

            optimizer.step()
            optimizer_D.step()

            sys.stdout.write('.')
            sys.stdout.flush()

            if i_iter % eval_every == 0 and i_iter != 0:
                model.eval()
                with torch.no_grad():
                    evaluator = EvaluatorIoU(ds.num_classes)
                    for index, batch in enumerate(testloader):
                        image, label, size, name = batch
                        size = size[0].numpy()
                        image = image.float().to(torch_device)
                        output = model(image)
                        output = output.cpu().data[0].numpy()

                        output = output[:, :size[0], :size[1]]
                        # FIX: np.int alias was removed from NumPy; use np.int64.
                        gt = np.asarray(label[0].numpy()[:size[0], :size[1]],
                                        dtype=np.int64)

                        output = output.transpose(1, 2, 0)
                        output = np.asarray(np.argmax(output, axis=2),
                                            dtype=np.int64)

                        evaluator.sample(gt, output, ignore_value=ignore_label)

                        sys.stdout.write('+')
                        sys.stdout.flush()

                per_class_iou = evaluator.score()
                mean_iou = per_class_iou.mean()

                loss_seg_value /= eval_every
                loss_adv_pred_value /= eval_every
                loss_D_value /= eval_every
                loss_semi_mask_accum /= eval_every
                loss_semi_value /= eval_every
                loss_semi_adv_value /= eval_every

                sys.stdout.write('\n')
                t2 = time.time()
                print(
                    'iter = {:8d}/{:8d}, took {:.3f}s, loss_seg = {:.6f}, loss_adv_p = {:.6f}, loss_D = {:.6f}, loss_semi_mask_rate = {:.3%} loss_semi = {:.6f}, loss_semi_adv = {:.3f}'
                    .format(i_iter, num_steps, t2 - t1, loss_seg_value,
                            loss_adv_pred_value, loss_D_value,
                            loss_semi_mask_accum, loss_semi_value,
                            loss_semi_adv_value))

                for i, (class_name, iou) in enumerate(zip(ds.class_names, per_class_iou)):
                    print('class {:2d} {:12} IU {:.2f}'.format(i, class_name, iou))
                print('meanIOU: ' + str(mean_iou) + '\n')

                loss_seg_value = 0
                loss_adv_pred_value = 0
                loss_D_value = 0
                loss_semi_value = 0
                loss_semi_mask_accum = 0
                loss_semi_adv_value = 0
                t1 = t2

            if snapshot_dir is not None and i_iter % save_snapshot_every == 0 and i_iter != 0:
                print('taking snapshot ...')
                torch.save(
                    model.state_dict(),
                    osp.join(snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
                torch.save(
                    model_D.state_dict(),
                    osp.join(snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

        if snapshot_dir is not None:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(snapshot_dir, 'VOC_' + str(num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(snapshot_dir, 'VOC_' + str(num_steps) + '_D.pth'))