def main():
    """Adversarial (semi-)supervised training of DeepLab with an FCDiscriminator.

    All settings come from the module-level ``args`` namespace. Each step
    trains the segmentation network with ``loss_calc`` plus an adversarial
    BCE term from ``model_D``. The discriminator is trained to label
    predictions ``pred_label`` and one-hot ground truth ``gt_label``.

    Once ``i_iter >= args.semi_start_adv`` and one of ``args.lambda_semi`` /
    ``args.lambda_semi_adv`` is positive, unlabeled images from the remaining
    split are added:
      * an adversarial term;
      * from ``args.semi_start``, a self-training term on pixels whose
        discriminator score is at least ``args.mask_T``.

    Checkpoints of both networks are written to ``args.snapshot_dir``.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if name in saved_state_dict and param.size() == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model = nn.DataParallel(model)
    model.cuda()

    cudnn.benchmark = True

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D = nn.DataParallel(model_D)
    model_D.train()
    model_D.cuda()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir, args.data_list, crop_size=input_size,
                               scale=args.random_scale, mirror=args.random_mirror,
                               mean=IMG_MEAN)
    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir, args.data_list, crop_size=input_size,
                                    scale=args.random_scale, mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    # NOTE(review): `trainloader_remain` (needed by the semi-supervised step) only
    # exists when args.partial_data is set.
    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset, batch_size=args.batch_size,
                                      shuffle=True, num_workers=5, pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset, batch_size=args.batch_size,
                                         shuffle=True, num_workers=5, pin_memory=True)
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # pickle requires a binary-mode handle; only load trusted id files.
            with open(args.partial_id, 'rb') as f:
                train_ids = pickle.load(f)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)

        with open(osp.join(args.snapshot_dir, 'train_id.pkl'), 'wb') as f:
            pickle.dump(train_ids, f)

        train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset, batch_size=args.batch_size,
                                      sampler=train_sampler, num_workers=3,
                                      pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset, batch_size=args.batch_size,
                                             sampler=train_remain_sampler, num_workers=3,
                                             pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset, batch_size=args.batch_size,
                                         sampler=train_gt_sampler, num_workers=3,
                                         pin_memory=True)

        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.module.optim_parameters(args), lr=args.learning_rate,
                          momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss/ bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    # Checkpoint file prefix, e.g. 'VOC_<script name>'.
    snapshot_prefix = 'VOC_' + os.path.abspath(__file__).split('/')[-1].split('.')[0]

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for _ in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # Detached unlabeled predictions for the optional D_remain branch;
            # they stay None until the semi-supervised step runs.
            pred_remain = None
            ignore_mask_remain = None

            # do semi first
            if ((args.lambda_semi > 0 or args.lambda_semi_adv > 0)
                    and i_iter >= args.semi_start_adv):
                try:
                    _, batch = next(trainloader_remain_iter)
                except StopIteration:
                    trainloader_remain_iter = enumerate(trainloader_remain)
                    _, batch = next(trainloader_remain_iter)

                # only access to img
                images, _, _, _ = batch
                images = Variable(images).cuda()

                pred = interp(model(images))
                pred_remain = pred.detach()

                D_out = interp(model_D(F.softmax(pred, dim=1)))
                D_out_sigmoid = torch.sigmoid(D_out).data.cpu().numpy().squeeze(axis=1)

                # np.bool was removed in NumPy 1.24; the builtin bool is equivalent.
                ignore_mask_remain = np.zeros(D_out_sigmoid.shape).astype(bool)

                loss_semi_adv = args.lambda_semi_adv * bce_loss(
                    D_out, make_D_label(gt_label, ignore_mask_remain))
                loss_semi_adv = loss_semi_adv / args.iter_size

                # Logged un-weighted; skip when the weight is 0 to avoid dividing by 0.
                if args.lambda_semi_adv > 0:
                    loss_semi_adv_value += loss_semi_adv.item() / args.lambda_semi_adv

                if args.lambda_semi <= 0 or i_iter < args.semi_start:
                    loss_semi_adv.backward()
                    loss_semi_value = 0
                else:
                    # produce ignore mask
                    semi_ignore_mask = (D_out_sigmoid < args.mask_T)

                    semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                    semi_gt[semi_ignore_mask] = 255

                    semi_ratio = 1.0 - float(semi_ignore_mask.sum()) / semi_ignore_mask.size
                    print('semi ratio: {:.4f}'.format(semi_ratio))

                    if semi_ratio == 0.0:
                        # No confident pixel, so no self-training target; the
                        # adversarial term must still contribute its gradient.
                        loss_semi_adv.backward()
                    else:
                        semi_gt = torch.FloatTensor(semi_gt)

                        loss_semi = args.lambda_semi * loss_calc(pred, semi_gt)
                        loss_semi = loss_semi / args.iter_size
                        loss_semi_value += loss_semi.item() / args.lambda_semi
                        loss_semi += loss_semi_adv
                        loss_semi.backward()

            else:
                loss_semi = None
                loss_semi_adv = None

            # train with source

            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda()
            ignore_mask = (labels.numpy() == 255)
            pred = interp(model(images))

            loss_seg = loss_calc(pred, labels)

            D_out = interp(model_D(F.softmax(pred, dim=1)))

            loss_adv_pred = bce_loss(D_out, make_D_label(gt_label, ignore_mask))

            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            # .item() works on the 0-dim losses returned by torch>=0.4.
            loss_seg_value += loss_seg.item() / args.iter_size
            loss_adv_pred_value += loss_adv_pred.item() / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred
            pred = pred.detach()

            # Only possible once the semi step has produced unlabeled predictions.
            if args.D_remain and pred_remain is not None:
                pred = torch.cat((pred, pred_remain), 0)
                ignore_mask = np.concatenate((ignore_mask, ignore_mask_remain), axis=0)

            D_out = interp(model_D(F.softmax(pred, dim=1)))
            loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.item()

            # train with gt
            # get gt labels
            try:
                _, batch = next(trainloader_gt_iter)
            except StopIteration:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = next(trainloader_gt_iter)

            _, labels_gt, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cuda()
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.item()

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value, loss_adv_pred_value,
                    loss_D_value, loss_semi_value, loss_semi_adv_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir,
                                snapshot_prefix + '_' + str(args.num_steps) + '.pth'))
            torch.save(model_D.state_dict(),
                       osp.join(args.snapshot_dir,
                                snapshot_prefix + '_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir,
                                snapshot_prefix + '_' + str(i_iter) + '.pth'))
            torch.save(model_D.state_dict(),
                       osp.join(args.snapshot_dir,
                                snapshot_prefix + '_' + str(i_iter) + '_D.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
def evaluate(arch, dataset, ignore_label, restore_from, pretrained_model, save_dir,
             device):
    """Run a segmentation model over the VOC validation split and report IoU.

    Colourised predictions are written to ``save_dir/<name>.png`` and the
    per-class / mean IoU summary to ``save_dir/result.txt``.

    Args:
        arch: 'deeplab2', 'unet_resnet50' or 'resnet101_deeplabv3'.
        dataset: dataset name; only 'pascal_aug' is supported.
        ignore_label: label value skipped when accumulating IoU.
        restore_from: checkpoint path, or a URL starting with 'http'.
        pretrained_model: optional key of ``pretrained_models_dict``; when
            given, its URL replaces ``restore_from``.
        save_dir: output directory (created if missing).
        device: anything accepted by ``torch.device`` (e.g. 'cuda:0', 'cpu').

    Unsupported ``dataset`` or ``arch`` values print a message and return None.
    """
    import os

    import numpy as np
    import torch
    from PIL import Image
    from torch.utils import data, model_zoo

    from dataset.voc_dataset import VOCDataSet
    from model.deeplab import Res_Deeplab
    from model.deeplabv3 import resnet101_deeplabv3
    from model.unet import unet_resnet50

    pretrained_models_dict = {
        'semi0.125':
        'http://vllab1.ucmerced.edu/~whung/adv-semi-seg/AdvSemiSegVOC0.125-03c6f81c.pth',
        'semi0.25':
        'http://vllab1.ucmerced.edu/~whung/adv-semi-seg/AdvSemiSegVOC0.25-473f8a14.pth',
        'semi0.5':
        'http://vllab1.ucmerced.edu/~whung/adv-semi-seg/AdvSemiSegVOC0.5-acf6a654.pth',
        'advFull':
        'http://vllab1.ucmerced.edu/~whung/adv-semi-seg/AdvSegVOCFull-92fbc7ee.pth'
    }

    def color_map(N=256, normalized=False):
        """Return an (N, 3) bit-interleaved colour palette (uint8, or float in [0, 1])."""
        def bitget(byteval, idx):
            return ((byteval & (1 << idx)) != 0)

        dtype = 'float32' if normalized else 'uint8'
        cmap = np.zeros((N, 3), dtype=dtype)
        for i in range(N):
            r = g = b = 0
            c = i
            for j in range(8):
                r = r | (bitget(c, 0) << 7 - j)
                g = g | (bitget(c, 1) << 7 - j)
                b = b | (bitget(c, 2) << 7 - j)
                c = c >> 3

            cmap[i] = np.array([r, g, b])

        cmap = cmap / 255 if normalized else cmap
        return cmap

    class VOCColorize(object):
        """Map a 2-D label image to a (3, H, W) uint8 RGB image; label 255 becomes white."""

        def __init__(self, n=22):
            self.cmap = color_map(22)[:n]

        def __call__(self, gray_image):
            size = gray_image.shape
            color_image = np.zeros((3, size[0], size[1]), dtype=np.uint8)

            for label in range(0, len(self.cmap)):
                mask = (label == gray_image)
                color_image[0][mask] = self.cmap[label][0]
                color_image[1][mask] = self.cmap[label][1]
                color_image[2][mask] = self.cmap[label][2]

            # handle void
            mask = (255 == gray_image)
            color_image[0][mask] = color_image[1][mask] = color_image[2][mask] = 255

            return color_image

    def get_iou(data_list, class_num, ignore_label, class_names, save_path=None):
        """Print (and optionally save) per-class IoU and mean IoU.

        ``data_list`` holds ``[truth, prediction]`` pairs of flattened label
        arrays; scoring is delegated to ``utils.evaluation.EvaluatorIoU``.
        """
        from utils.evaluation import EvaluatorIoU

        evaluator = EvaluatorIoU(class_num)
        for truth, prediction in data_list:
            evaluator.sample(truth, prediction, ignore_value=ignore_label)

        per_class_iou = evaluator.score()
        mean_iou = per_class_iou.mean()

        lines = ['class {:2d} {:12} IU {:.2f}'.format(i, class_name, iou)
                 for i, (class_name, iou) in enumerate(zip(class_names, per_class_iou))]
        for line in lines:
            print(line)
        print('meanIOU: ' + str(mean_iou) + '\n')

        if save_path:
            with open(save_path, 'w') as f:
                for line in lines:
                    f.write(line + '\n')
                f.write('meanIOU: ' + str(mean_iou) + '\n')

    torch_device = torch.device(device)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    if dataset == 'pascal_aug':
        ds = VOCDataSet()
    else:
        print('Dataset {} not yet supported'.format(dataset))
        return

    if arch == 'deeplab2':
        model = Res_Deeplab(num_classes=ds.num_classes)
    elif arch == 'unet_resnet50':
        model = unet_resnet50(num_classes=ds.num_classes)
    elif arch == 'resnet101_deeplabv3':
        model = resnet101_deeplabv3(num_classes=ds.num_classes)
    else:
        print('Architecture {} not supported'.format(arch))
        return

    if pretrained_model is not None:
        restore_from = pretrained_models_dict[pretrained_model]

    if restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(restore_from)
    else:
        # Load on the CPU so checkpoints saved from GPU also work when device is
        # 'cpu'; the model is moved to torch_device afterwards.
        saved_state_dict = torch.load(restore_from, map_location='cpu')
    model.load_state_dict(saved_state_dict)

    model.eval()
    model = model.to(torch_device)

    ds_val_xy = ds.val_xy(crop_size=(505, 505), scale=False, mirror=False,
                          mean=model.MEAN, std=model.STD)

    testloader = data.DataLoader(ds_val_xy, batch_size=1, shuffle=False, pin_memory=True)

    data_list = []

    colorize = VOCColorize()

    with torch.no_grad():
        for index, batch in enumerate(testloader):
            if index % 100 == 0:
                print('%d processd' % (index))
            image, label, size, name = batch
            size = size[0].numpy()
            # The DataLoader already yields tensors: move/cast them instead of
            # re-wrapping with torch.tensor(), which copies and warns.
            image = image.to(device=torch_device, dtype=torch.float)
            output = model(image)
            # (C, H, W) scores of the single image, cropped to its original size.
            output = output[0].cpu().numpy()[:, :size[0], :size[1]]
            # np.int was removed in NumPy 1.24; np.int64 is the explicit equivalent.
            gt = np.asarray(label[0].numpy()[:size[0], :size[1]], dtype=np.int64)

            output = np.asarray(np.argmax(output, axis=0), dtype=np.int64)

            filename = os.path.join(save_dir, '{}.png'.format(name[0]))
            color_file = Image.fromarray(colorize(output).transpose(1, 2, 0), 'RGB')
            color_file.save(filename)

            data_list.append([gt.flatten(), output.flatten()])

    filename = os.path.join(save_dir, 'result.txt')
    get_iou(data_list, ds.num_classes, ignore_label, ds.class_names, filename)
def main():
    """Train DeepLab with a FlawDetector discriminator (semi-supervised).

    All settings come from the module-level ``args`` namespace.

    The segmentation network is trained with ``loss_calc`` plus
    ``detector.MinimumCriterion`` applied to the flaw-detector output. That
    term is used on labeled images and, once ``i_iter >= args.semi_start_adv``,
    on unlabeled images as well.

    The flaw detector is trained with ``detector.FlawDetectorCriterion``.
    Its targets come from ``detector.generate_flaw_detector_gt``, built from
    the detached prediction and the labels.

    ``val(model, args.gpu)`` runs every 1000 iterations. Checkpoints are
    written to ``args.snapshot_dir``.
    """
    # prepare
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if name in saved_state_dict and param.size() == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D = detector.FlawDetector(in_channels=24)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir, args.data_list, crop_size=input_size,
                               scale=args.random_scale, mirror=args.random_mirror,
                               mean=IMG_MEAN)
    train_dataset_size = len(train_dataset)

    # NOTE: the ground-truth loaders are built but not consumed by this trainer
    # (the gt-branch discriminator update is disabled).
    train_gt_dataset = VOCGTDataSet(args.data_dir, args.data_list, crop_size=input_size,
                                    scale=args.random_scale, mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset, batch_size=args.batch_size,
                                      shuffle=True, num_workers=5, pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset, batch_size=args.batch_size,
                                         shuffle=True, num_workers=5, pin_memory=True)
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # pickle requires a binary-mode handle; only load trusted id files.
            with open(args.partial_id, 'rb') as f:
                train_ids = pickle.load(f)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)

        with open(osp.join(args.snapshot_dir, 'train_id.pkl'), 'wb') as f:
            pickle.dump(train_ids, f)

        train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset, batch_size=args.batch_size,
                                      sampler=train_sampler, num_workers=3,
                                      pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset, batch_size=args.batch_size,
                                             sampler=train_remain_sampler, num_workers=3,
                                             pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset, batch_size=args.batch_size,
                                         sampler=train_gt_sampler, num_workers=3,
                                         pin_memory=True)

        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args), lr=args.learning_rate,
                          momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # losses / bilinear upsampling
    minimum_loss = detector.MinimumCriterion()
    detector_loss = detector.FlawDetectorCriterion()

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    for i_iter in range(args.num_steps):
        if i_iter > 0 and i_iter % 1000 == 0:
            val(model, args.gpu)
            model.train()

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for _ in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # do semi first
            if ((args.lambda_semi > 0 or args.lambda_semi_adv > 0)
                    and i_iter >= args.semi_start_adv):
                try:
                    _, batch = next(trainloader_remain_iter)
                except StopIteration:
                    trainloader_remain_iter = enumerate(trainloader_remain)
                    _, batch = next(trainloader_remain_iter)

                # only access to img
                images, _, _, _ = batch
                images = Variable(images).cuda(args.gpu)

                pred = interp(model(images))
                D_out = model_D(images, pred)

                # ke: SSL loss for unlabeled data
                # NOTE(review): unlike the labeled loss below, this term is not
                # divided by args.iter_size before backward(), while the logged
                # value is. Confirm whether that asymmetry is intended.
                loss_semi_adv = args.lambda_semi_adv * minimum_loss(D_out)
                loss_semi_adv_value += loss_semi_adv.item() / args.iter_size
                loss_semi_adv.backward()
                loss_semi_value = 0
            else:
                loss_semi = None
                loss_semi_adv = None

            # train with source (labeled data)
            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            pred = interp(model(images))

            loss_seg = loss_calc(pred, labels, args.gpu)

            D_out = interp(model_D(images, pred))

            # ke: SSL loss for labeled data
            loss_adv_pred = args.lambda_adv_pred * minimum_loss(D_out)

            loss = loss_seg + loss_adv_pred

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += loss_seg.item() / args.iter_size
            loss_adv_pred_value += loss_adv_pred.item() / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred from labeled data
            pred = pred.detach()

            D_out = interp(model_D(images, pred))
            # labels reshaped to (N, 1, H, W) for the flaw-detector target builder.
            labels_4d = labels.view(labels.shape[0], 1, labels.shape[1], labels.shape[2])
            detect_gt = detector.generate_flaw_detector_gt(
                pred, labels_4d.cuda(args.gpu), NUM_CLASSES, IGNORE_LABEL)
            loss_D = detector_loss(D_out, detect_gt)
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.item()

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.6f}, loss_adv_l = {3:.6f}, loss_D = {4:.6f}, loss_semi = {5:.6f}, loss_adv_u = {6:.6f}'
            .format(i_iter, args.num_steps, loss_seg_value, loss_adv_pred_value,
                    loss_D_value, loss_semi_value, loss_semi_adv_value))

        if i_iter >= args.num_steps - 1:
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, 'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(model_D.state_dict(),
                       osp.join(args.snapshot_dir, 'VOC_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
            torch.save(model_D.state_dict(),
                       osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))
def main():
    """Train DeepLab against 20 per-class patch discriminators plus a global one.

    All settings come from the module-level ``args`` namespace.

    Every image becomes a 63-channel tensor. Channel ``c * 21 + k`` holds the
    class-``k`` probability (or one-hot ground truth) multiplied by colour
    channel ``c``. ``model_D2`` scores the whole tensor.

    For each label ``ls`` other than 0 and 255 present in a sample:
      * the three colour channels of class ``ls`` are cut into a 3x3 grid of
        107x107 tiles;
      * the tiles are scored by ``model_DS[ls - 1]``.

    The 21-class repeat and the 0/107/214/321 tile grid are hard-coded. They
    presumably assume ``args.num_classes == 21`` and 321x321 crops; verify this
    against ``args`` before changing either.

    Only the segmentation network is checkpointed, to ``args.snapshot_dir``.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if name in saved_state_dict and param.size() == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model = nn.DataParallel(model)
    model.cuda()

    cudnn.benchmark = True

    # init D: one patch discriminator per foreground class (labels 1..20).
    # NOTE: these are used directly, not wrapped in nn.DataParallel.
    model_DS = [Discriminator2_mul(num_classes=args.num_classes) for _ in range(20)]
    for model_D in model_DS:
        model_D.train()
        model_D.cuda()

    model_D2 = Discriminator2(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D2.load_state_dict(torch.load(args.restore_from_D))
    model_D2 = nn.DataParallel(model_D2)
    model_D2.train()
    model_D2.cuda()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir, args.data_list, crop_size=input_size,
                               scale=args.random_scale, mirror=args.random_mirror,
                               mean=IMG_MEAN)

    train_gt_dataset = VOCGTDataSet(args.data_dir, args.data_list, crop_size=input_size,
                                    scale=args.random_scale, mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if args.partial_data == 0:
        trainloader = data.DataLoader(train_dataset, batch_size=args.batch_size,
                                      shuffle=True, num_workers=5, pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset, batch_size=args.batch_size,
                                         shuffle=True, num_workers=5, pin_memory=True)
    else:
        # Partial-label sampling is not implemented here; fail clearly rather
        # than with a NameError on `trainloader` below.
        raise NotImplementedError(
            'partial_data={} is not supported by this training script; '
            'use 0 (full data).'.format(args.partial_data))

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.module.optim_parameters(args), lr=args.learning_rate,
                          momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizers for discriminator networks
    optimizer_DS = [optim.Adam(model_D.parameters(), lr=args.learning_rate_D,
                               betas=(0.9, 0.99))
                    for model_D in model_DS]
    for optimizer_D in optimizer_DS:
        optimizer_D.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    # loss/ bilinear upsampling
    bce_loss = torch.nn.BCELoss()
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    # Checkpoint file prefix, e.g. 'VOC_<script name>'.
    snapshot_prefix = 'VOC_' + os.path.abspath(__file__).split('/')[-1].split('.')[0]

    # Row/column bounds of the 3x3 tile grid.
    tile_bounds = ((0, 107), (107, 214), (214, 321))

    def _set_discriminators_trainable(flag):
        """Toggle requires_grad on every discriminator parameter."""
        for net in model_DS + [model_D2]:
            for param in net.parameters():
                param.requires_grad = flag

    def _rgb_repeat(imgs):
        """Return [R x21, G x21, B x21] along channels, matching probs.repeat(1, 3, 1, 1)."""
        return torch.cat([imgs[:, c:c + 1].repeat(1, 21, 1, 1) for c in range(3)], 1)

    def _class_tiles(mul, sample, cls):
        """Return the (9, 3, 107, 107) row-major tiles of class ``cls`` for one sample."""
        # Channels cls, cls + 21, cls + 42 are class `cls` weighted by R, G, B.
        img_p = mul[sample, cls::21].unsqueeze(0)
        tiles = [img_p[:, :, r0:r1, c0:c1]
                 for r0, r1 in tile_bounds for c0, c1 in tile_bounds]
        return torch.cat(tiles, 0)

    def _patch_adv_loss(mul, labels_batch, target):
        """Sum per-class discriminator BCE over labels present, excluding 0 and 255."""
        total = 0
        for i_l in range(labels_batch.shape[0]):
            for ls in np.unique(labels_batch[i_l]).tolist():
                if ls == 0 or ls == 255:
                    continue
                ls = int(ls)
                D_out = model_DS[ls - 1](_class_tiles(mul, i_l, ls))
                total = total + bce_loss(D_out, make_D_label(target, D_out))
        return total

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        for optimizer_D in optimizer_DS:
            optimizer_D.zero_grad()
            adjust_learning_rate_D(optimizer_D, i_iter)
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for _ in range(args.iter_size):

            # train G: don't accumulate grads in D
            _set_discriminators_trainable(False)

            # train with source
            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda()
            pred = interp(model(images))

            loss_seg = loss_calc(pred, labels)

            img_re = _rgb_repeat(images)
            mul_img = F.softmax(pred, dim=1).repeat(1, 3, 1, 1) * img_re  # (N, 63, H, W)
            D_out_2 = model_D2(mul_img)

            loss_adv_pred = (_patch_adv_loss(mul_img, labels, gt_label) * 0.5
                             + bce_loss(D_out_2, make_D_label(gt_label, D_out_2)))

            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            # .item() works on the 0-dim losses returned by torch>=0.4.
            loss_seg_value += loss_seg.item() / args.iter_size
            loss_adv_pred_value += loss_adv_pred.item() / args.iter_size

            # train D: bring back requires_grad
            _set_discriminators_trainable(True)

            # train with pred
            pred = pred.detach()
            mul_img2 = F.softmax(pred, dim=1).repeat(1, 3, 1, 1) * img_re
            D_out_2 = model_D2(mul_img2)

            loss_D = (_patch_adv_loss(mul_img2, labels, pred_label) * 0.5
                      + bce_loss(D_out_2, make_D_label(pred_label, D_out_2)))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.item()

            # train with gt
            try:
                _, batch = next(trainloader_gt_iter)
            except StopIteration:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = next(trainloader_gt_iter)

            img_gt, labels_gt, _, _ = batch
            img_gt = Variable(img_gt).cuda()
            D_gt_v = Variable(one_hot(labels_gt)).cuda()

            mul_img3 = D_gt_v.repeat(1, 3, 1, 1) * _rgb_repeat(img_gt)
            D_out_2 = model_D2(mul_img3)

            loss_D = (_patch_adv_loss(mul_img3, labels_gt, gt_label) * 0.5
                      + bce_loss(D_out_2, make_D_label(gt_label, D_out_2)))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.item()

        optimizer.step()
        for optimizer_D in optimizer_DS:
            optimizer_D.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value, loss_adv_pred_value,
                    loss_D_value, loss_semi_value, loss_semi_adv_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir,
                                snapshot_prefix + '_' + str(args.num_steps) + '.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir,
                                snapshot_prefix + '_' + str(i_iter) + '.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
def _multilabel_target(label_maps, batch_size, width, offset):
    """Build the soft image-level target used by the multi-label discriminator loss.

    For each image ``b`` in ``range(batch_size)`` the distinct values of
    ``label_maps[b]`` are collected, dropping 255 (ignore) and 0 (background).
    Every remaining class id ``c`` gets weight ``1 / n_present`` in column
    ``int(c) - 1 + offset``; an image with no such class keeps an all-zero row.

    Args:
        label_maps: CPU label tensor indexable per image (``label_maps[b]``).
        batch_size: number of rows to build (the discriminator batch size).
        width: number of columns (must match the discriminator output width).
        offset: column shift applied to every class index.

    Returns:
        ``Variable`` wrapping a ``(batch_size, width)`` FloatTensor on the GPU.
    """
    target = np.zeros([batch_size, width])
    for b in range(batch_size):
        present = set(label_maps[b].view(-1).numpy().tolist())
        present.discard(255)
        present.discard(0)
        for cls in present:
            target[b][int(cls) - 1 + offset] = 1.0 / len(present)
    return Variable(torch.FloatTensor(target)).cuda()


def main():
    """Train Res_Deeplab against the multi-label discriminator ``Discriminator_mul``.

    The discriminator receives ``cat([softmax(pred) or one-hot GT, image], dim=1)``
    and is trained with a soft image-level target (see ``_multilabel_target``):
    classes present in a ground-truth map are scored in columns ``c - 1``
    ("real"), classes present in the labels of a detached prediction in columns
    ``c - 1 + fake_offset`` ("fake"). The segmentation network minimises
    cross-entropy plus ``args.lambda_adv_pred`` times the loss of its own
    predictions against the "real" target. Snapshots go to ``args.snapshot_dir``.
    """
    # NOTE(review): 40 columns / offset 20 are hard-coded in the original;
    # presumably 2 x 20 VOC foreground classes — confirm against Discriminator_mul.
    d_width = 40
    fake_offset = 20

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if name in saved_state_dict and param.size() == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model = nn.DataParallel(model)
    model.cuda()

    cudnn.benchmark = True

    # init D (deliberately not wrapped in DataParallel)
    model_D = Discriminator_mul(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir, args.data_list, crop_size=input_size,
                               scale=args.random_scale, mirror=args.random_mirror,
                               mean=IMG_MEAN)
    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir, args.data_list, crop_size=input_size,
                                    scale=args.random_scale, mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if not args.partial_data:
        # partial_data of 0 (or None): train on the full list
        trainloader = data.DataLoader(train_dataset, batch_size=args.batch_size,
                                      shuffle=True, num_workers=5, pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset, batch_size=args.batch_size,
                                         shuffle=True, num_workers=5, pin_memory=True)
    else:
        # sample partial data: the first partial_size shuffled ids are the
        # labeled subset for both the segmentation and the ground-truth loader.
        # Previously this branch built no loaders and crashed with NameError.
        import pickle
        partial_size = int(args.partial_data * train_dataset_size)
        partial_id = getattr(args, 'partial_id', None)
        if partial_id is not None:
            with open(partial_id, 'rb') as f:
                train_ids = pickle.load(f)
            print('loading train ids from {}'.format(partial_id))
        else:
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)
        with open(osp.join(args.snapshot_dir, 'train_id.pkl'), 'wb') as f:
            pickle.dump(train_ids, f)

        train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])
        train_gt_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])
        trainloader = data.DataLoader(train_dataset, batch_size=args.batch_size,
                                      sampler=train_sampler, num_workers=3,
                                      pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset, batch_size=args.batch_size,
                                         sampler=train_gt_sampler, num_workers=3,
                                         pin_memory=True)

    trainloader_iter = enumerate(trainloader)
    # initialised explicitly instead of relying on an exception on first use
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting
    # optimizer for segmentation network
    optimizer = optim.SGD(model.module.optim_parameters(args), lr=args.learning_rate,
                          momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # bilinear upsampling back to the crop size; align_corners=True on >= 0.4,
    # presumably to keep the pre-0.4 interpolation behaviour
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    snapshot_prefix = ('VOC_' + os.path.abspath(__file__).split('/')[-1].split('.')[0]
                       + '_')

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        # this variant has no semi-supervised branch; kept so the log line is unchanged
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # ---------------- train G ----------------
            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # train with source
            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda()

            pred = interp(model(images))
            loss_seg = loss_calc(pred, labels)

            D_out = model_D(torch.cat([F.softmax(pred, dim=1), images], 1))
            real_target = _multilabel_target(labels, D_out.size()[0], d_width, 0)
            loss_adv_pred = -torch.mean(F.log_softmax(D_out, dim=1) * real_target)

            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            # .item() instead of .numpy()[0]: losses are 0-d tensors on PyTorch >= 0.4
            loss_seg_value += loss_seg.item() / args.iter_size
            loss_adv_pred_value += loss_adv_pred.item() / args.iter_size

            # ---------------- train D ----------------
            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred: classes of the labeled image go to the "fake" columns
            pred = pred.detach()
            D_out = model_D(torch.cat([F.softmax(pred, dim=1), images], 1))
            fake_target = _multilabel_target(labels, D_out.size()[0], d_width, fake_offset)
            loss_D = -torch.mean(F.log_softmax(D_out, dim=1) * fake_target)
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.item()

            # train with gt
            try:
                _, batch = next(trainloader_gt_iter)
            except StopIteration:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = next(trainloader_gt_iter)

            img2, labels_gt, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cuda()
            img2 = Variable(img2).cuda()
            D_out = model_D(torch.cat([D_gt_v, img2], 1))
            gt_target = _multilabel_target(labels_gt, D_out.size()[0], d_width, 0)
            loss_D = -torch.mean(F.log_softmax(D_out, dim=1) * gt_target)
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.item()

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value, loss_adv_pred_value,
                    loss_D_value, loss_semi_value, loss_semi_adv_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir,
                                snapshot_prefix + str(args.num_steps) + '.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, snapshot_prefix + str(i_iter) + '.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
def main():
    """Semi-supervised adversarial segmentation training with optional D calibration.

    Each iteration runs ``args.iter_size`` sub-iterations:

    1. Unlabeled step (when a semi loss weight is positive and
       ``i_iter >= args.semi_start_adv``): ``lambda_semi_adv`` times the BCE of
       ``D(S(X))`` against the "real" label. From ``args.semi_start`` on, this
       also adds ``lambda_semi`` times a self-training loss on argmax
       pseudo-labels weighted by the discriminator output. Without ``USECALI``
       the pseudo-labels are masked where the sigmoid output is below
       ``args.mask_T``. For ``args.method`` in ('vat', 'vatent'), a weighted
       VAT (+ entropy) term scaled by ``args.alpha`` is added.
    2. Labeled step: cross-entropy, plus ``args.lambda_adv_pred`` times the
       adversarial BCE when ``USED``. With ``USECALI``, discriminator logits on
       labeled pixels are collected and the temperature of ``model_cali`` is
       periodically re-fitted with LBFGS.
    3. Discriminator step (``USED`` only): detached predictions (plus unlabeled
       ones when ``args.D_remain``) as fake, one-hot ground truth as real.

    Snapshots are saved to ``args.snapshot_dir``. Every 100 iterations a summary
    line is written to ``log.txt`` in that directory. The original wrote it to a
    hard-coded ``/home/eungyo/...`` path.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters (weights), e.g.
    # http://vllab1.ucmerced.edu/~whung/adv-semi-seg/resnet101COCO-41f33a49.pth
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if name in saved_state_dict and param.size() == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    # let cuDNN auto-tune convolution algorithms for the fixed crop size
    cudnn.benchmark = True

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(args.gpu)

    if USECALI:
        model_cali = ModelWithTemperature(model, model_D)
        model_cali.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    random.seed(args.random_seed)
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed(args.random_seed)

    train_dataset = VOCDataSet(args.data_dir, args.data_list, crop_size=input_size,
                               scale=args.random_scale, mirror=args.random_mirror,
                               mean=IMG_MEAN)
    train_dataset_remain = VOCDataSet(args.data_dir, args.data_list_remain,
                                      crop_size=input_size, scale=args.random_scale,
                                      mirror=args.random_mirror, mean=IMG_MEAN)
    train_dataset_size = len(train_dataset)
    train_dataset_size_remain = len(train_dataset_remain)
    print(train_dataset_size)
    print(train_dataset_size_remain)

    train_gt_dataset = VOCGTDataSet(args.data_dir, args.data_list, crop_size=input_size,
                                    scale=args.random_scale, mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if args.partial_data is None:
        # not partial: load everything (labeled and unlabeled lists are separate)
        trainloader = data.DataLoader(train_dataset, batch_size=args.batch_size,
                                      shuffle=True, num_workers=5, pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset_remain,
                                             batch_size=args.batch_size, shuffle=True,
                                             num_workers=5, pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset, batch_size=args.batch_size,
                                         shuffle=True, num_workers=5, pin_memory=True)
    else:
        # Labeled and unlabeled images come from separate lists (data_list /
        # data_list_remain), so every sampler covers its whole dataset.
        if args.partial_id is not None:
            with open(args.partial_id, 'rb') as f:
                train_ids = pickle.load(f)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            # list(): np.random.shuffle needs a mutable sequence on Python 3
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)
        # built on both paths (previously undefined when partial_id was given)
        train_ids_remain = list(range(train_dataset_size_remain))
        np.random.shuffle(train_ids_remain)

        with open(osp.join(args.snapshot_dir, 'train_id.pkl'), 'wb') as f:
            pickle.dump(train_ids, f)

        train_sampler = data.sampler.SubsetRandomSampler(train_ids)
        train_remain_sampler = data.sampler.SubsetRandomSampler(train_ids_remain)
        train_gt_sampler = data.sampler.SubsetRandomSampler(train_ids)

        trainloader = data.DataLoader(train_dataset, batch_size=args.batch_size,
                                      sampler=train_sampler, num_workers=3,
                                      pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset_remain,
                                             batch_size=args.batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=3, pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset, batch_size=args.batch_size,
                                         sampler=train_gt_sampler, num_workers=3,
                                         pin_memory=True)

    trainloader_remain_iter = enumerate(trainloader_remain)
    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # optimizer for segmentation network; model.optim_parameters(args) returns
    # parameter groups with their own learning rates
    optimizer = optim.SGD(model.optim_parameters(args), lr=args.learning_rate,
                          momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    if USECALI:
        optimizer_cali = optim.LBFGS([model_cali.temperature], lr=0.01, max_iter=50)
        optimizer_cali.zero_grad()
        nll_criterion = BCEWithLogitsLoss().cuda()
        ece_criterion = ECELoss().cuda()

    bce_loss = BCEWithLogitsLoss2d()

    # bilinear upsampling back to the crop size; align_corners=True on >= 0.4,
    # presumably to keep the pre-0.4 interpolation behaviour
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    # running sums (plain floats) for the every-100-iterations log line
    loss_seg_sum = 0
    loss_adv_sum = 0
    loss_vat_sum = 0
    l_seg_sum = 0
    l_vat_sum = 0
    l_adv_sum = 0
    # discriminator logits / targets collected for temperature calibration
    logits_list = []
    labels_list = []

    # the log sums are rolled over on the last sub-iteration (originally the
    # hard-coded sub_i == 4, i.e. only correct for iter_size == 5)
    last_sub = args.iter_size - 1
    log_path = osp.join(args.snapshot_dir, 'log.txt')

    for i_iter in range(args.num_steps):

        loss_seg_value = 0        # L_seg
        loss_adv_pred_value = 0   # L_adv on labeled data
        loss_D_value = 0          # L_D
        loss_semi_value = 0       # L_semi
        loss_semi_adv_value = 0   # L_adv on unlabeled data
        loss_vat_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        semi_active = ((args.lambda_semi > 0 or args.lambda_semi_adv > 0)
                       and i_iter >= args.semi_start_adv)

        for sub_i in range(args.iter_size):
            log_step = (i_iter % 100 == 0 and sub_i == last_sub)

            # ---------------- train G ----------------
            # don't accumulate grads in D while training G
            for param in model_D.parameters():
                param.requires_grad = False

            # set only when the unlabeled step runs; checked before D_remain use
            pred_remain = None
            ignore_mask_remain = None

            # unlabeled data first: lambda_semi_adv * L_adv + lambda_semi * L_semi
            if semi_active:
                try:
                    _, batch = next(trainloader_remain_iter)
                except StopIteration:
                    trainloader_remain_iter = enumerate(trainloader_remain)
                    _, batch = next(trainloader_remain_iter)

                # only the image is used
                images, _, _, _ = batch
                images = Variable(images).cuda(args.gpu)

                pred = interp(model(images))      # S(X)
                pred_remain = pred.detach()       # reused for D training below

                D_out = interp(model_D(F.softmax(pred, dim=1)))   # D(S(X)), not detached
                D_out_sigmoid = torch.sigmoid(D_out).data.cpu().numpy().squeeze(axis=1)

                # no ignore mask for the unlabeled adversarial loss
                # (dtype=bool: the np.bool alias was removed in NumPy 1.24)
                ignore_mask_remain = np.zeros(D_out_sigmoid.shape, dtype=bool)
                loss_semi_adv = args.lambda_semi_adv * bce_loss(
                    D_out, make_D_label(gt_label, ignore_mask_remain))
                loss_semi_adv = loss_semi_adv / args.iter_size
                loss_semi_adv_value += (loss_semi_adv.data.cpu().numpy()
                                        / args.lambda_semi_adv)

                if args.lambda_semi <= 0 or i_iter < args.semi_start:
                    loss_semi_adv.backward()
                    loss_semi_value = 0
                else:
                    # pseudo-labels: per-pixel argmax of S(X) (class ids, not one-hot)
                    semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                    if not USECALI:
                        # ignore pixels where D's confidence is below mask_T
                        semi_ignore_mask = (D_out_sigmoid < args.mask_T)
                        semi_gt[semi_ignore_mask] = 255
                        semi_ratio = (1.0 - float(semi_ignore_mask.sum())
                                      / semi_ignore_mask.size)
                        print('semi ratio: {:.4f}'.format(semi_ratio))
                        if semi_ratio != 0.0:
                            semi_gt = torch.FloatTensor(semi_gt)
                            confidence = torch.FloatTensor(D_out_sigmoid)
                            loss_semi = args.lambda_semi * weighted_loss_calc(
                                pred, semi_gt, args.gpu, confidence)
                    else:
                        semi_ratio = 1
                        semi_gt = torch.FloatTensor(semi_gt)
                        confidence = torch.FloatTensor(
                            torch.sigmoid(model_cali.temperature_scale(D_out.view(-1)))
                            .data.cpu().numpy())
                        # NOTE(review): accuracies / n_bin are only assigned by the
                        # first temperature fit below; before that this raises
                        # UnboundLocalError — confirm the intended warm-up.
                        loss_semi = args.lambda_semi * calibrated_loss_calc(
                            pred, semi_gt, args.gpu, confidence, accuracies, n_bin)

                    # NOTE(review): when semi_ratio == 0 loss_semi_adv is not
                    # back-propagated (behaviour inherited from the original).
                    if semi_ratio != 0:
                        loss_semi = loss_semi / args.iter_size
                        loss_semi_value += loss_semi.data.cpu().numpy() / args.lambda_semi

                        if args.method == 'vatent' or args.method == 'vat':
                            weighted_v_loss = weighted_vat_loss(
                                model, images, pred, confidence, eps=args.epsilon)
                            if args.method == 'vatent':
                                weighted_v_loss += weighted_entropy_loss(pred, confidence)
                            v_loss = weighted_v_loss / args.iter_size
                            v_loss_value = v_loss.data.cpu().numpy()
                            loss_vat_value += v_loss_value
                            # add only this sub-iteration's value (the running
                            # per-iteration total used before double-counted)
                            loss_vat_sum += v_loss_value
                            loss_semi_adv += args.alpha * v_loss

                        loss_semi += loss_semi_adv
                        loss_semi.backward()

            # labeled data: L_ce (+ lambda_adv_pred * L_adv when USED)
            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            ignore_mask = (labels.numpy() == 255)

            pred = interp(model(images))                  # S(X)
            loss_seg = loss_calc(pred, labels, args.gpu)  # cross-entropy

            if USED:
                softsx = F.softmax(pred, dim=1)
                D_out = interp(model_D(softsx))           # D(S(X))
                # L_adv = -log D(S(X)) outside ignored pixels
                loss_adv_pred = bce_loss(D_out, make_D_label(gt_label, ignore_mask))
                loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

                if USECALI and semi_active:
                    # collect D logits on labeled pixels and their real/fake
                    # targets from make_conf_label (presumably 1 where the
                    # prediction matches the label — confirm)
                    with torch.no_grad():
                        _, prediction = torch.max(softsx, 1)
                        labels_mask = (((labels > 0) * (labels != 255))
                                       | (prediction.data.cpu() > 0))
                        kept_labels = labels[labels_mask]
                        kept_prediction = prediction[labels_mask]
                        fake_mask = (kept_labels.data.cpu().numpy()
                                     != kept_prediction.data.cpu().numpy())
                        real_label = make_conf_label(1, fake_mask)
                        logits_list.append(D_out.squeeze(dim=1)[labels_mask])
                        labels_list.append(real_label)

                    # NOTE(review): mixes a sample count with the raw sub-iteration
                    # index; confirm the intended once-per-epoch trigger.
                    seen = i_iter * args.iter_size * args.batch_size + sub_i + 1
                    if seen % train_dataset_size == 0:
                        all_logits = torch.cat(logits_list).cuda()
                        all_labels = torch.cat(labels_list).cuda()
                        before_temperature_nll = nll_criterion(all_logits, all_labels).item()
                        before_temperature_ece, _, _ = ece_criterion(all_logits, all_labels)
                        before_temperature_ece = before_temperature_ece.item()
                        print('Before temperature - NLL: %.3f, ECE: %.3f' %
                              (before_temperature_nll, before_temperature_ece))

                        def _nll_closure():
                            # LBFGS may call the closure several times per step,
                            # so stale gradients must be cleared on every call.
                            optimizer_cali.zero_grad()
                            loss_cali = nll_criterion(
                                model_cali.temperature_scale(all_logits), all_labels)
                            loss_cali.backward()
                            return loss_cali

                        optimizer_cali.step(_nll_closure)

                        scaled = model_cali.temperature_scale(all_logits)
                        after_temperature_nll = nll_criterion(scaled, all_labels).item()
                        after_temperature_ece, accuracies, n_bin = ece_criterion(
                            model_cali.temperature_scale(all_logits), all_labels)
                        after_temperature_ece = after_temperature_ece.item()
                        print('Optimal temperature: %.3f' % model_cali.temperature.item())
                        print('After temperature - NLL: %.3f, ECE: %.3f' %
                              (after_temperature_nll, after_temperature_ece))
                        logits_list = []
                        labels_list = []
            else:
                loss = loss_seg

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()

            # .item() keeps the sums as floats instead of graph-holding tensors
            loss_seg_sum += loss_seg.item() / args.iter_size
            if USED:
                loss_adv_sum += loss_adv_pred.item()
            if log_step:
                l_seg_sum = loss_seg_sum / 100
                if USED:
                    l_adv_sum = loss_adv_sum / 100
                l_vat_sum = loss_vat_sum / 100
                if i_iter == 0:
                    l_seg_sum = l_seg_sum * 100
                    l_adv_sum = l_adv_sum * 100
                    l_vat_sum = l_vat_sum * 100
                loss_seg_sum = 0
                loss_adv_sum = 0
                loss_vat_sum = 0

            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
            if USED:
                loss_adv_pred_value += loss_adv_pred.data.cpu().numpy() / args.iter_size

            # ---------------- train D ----------------
            if USED:
                # bring back requires_grad
                for param in model_D.parameters():
                    param.requires_grad = True

                # fake: detached predictions (labeled, plus unlabeled with D_remain)
                pred = pred.detach()
                if args.D_remain and pred_remain is not None:
                    pred = torch.cat((pred, pred_remain), 0)
                    ignore_mask = np.concatenate((ignore_mask, ignore_mask_remain), axis=0)

                D_out = interp(model_D(F.softmax(pred, dim=1)))
                # -log(1 - D(S(X)))
                loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
                # / 2 because D sees a fake and a real batch per sub-iteration
                loss_D = loss_D / args.iter_size / 2
                loss_D.backward()
                loss_D_value += loss_D.data.cpu().numpy()

                # real: one-hot ground-truth label maps (labeled data only)
                try:
                    _, batch = next(trainloader_gt_iter)
                except StopIteration:
                    trainloader_gt_iter = enumerate(trainloader_gt)
                    _, batch = next(trainloader_gt_iter)

                _, labels_gt, _, _ = batch
                D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
                ignore_mask_gt = (labels_gt.numpy() == 255)

                D_out = interp(model_D(D_gt_v))    # D(Y)
                loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
                loss_D = loss_D / args.iter_size / 2
                loss_D.backward()
                loss_D_value += loss_D.data.cpu().numpy()

        optimizer.step()
        if USED:
            optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.6f}, loss_semi = {5:.6f}, loss_semi_adv = {6:.3f}, loss_vat = {7: .5f}'
            .format(i_iter, args.num_steps, loss_seg_value, loss_adv_pred_value,
                    loss_D_value, loss_semi_value, loss_semi_adv_value, loss_vat_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, 'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(model_D.state_dict(),
                       osp.join(args.snapshot_dir, 'VOC_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % 100 == 0:
            wdata = "iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.6f}, loss_semi = {5:.8f}, loss_semi_adv = {6:.3f}, l_vat_sum = {7: .5f}, loss_label = {8: .4}\n".format(
                i_iter, args.num_steps, l_seg_sum, l_adv_sum, loss_D_value,
                loss_semi_value, loss_semi_adv_value, l_vat_sum,
                l_seg_sum + 0.01 * l_adv_sum)
            # truncate on the first iteration, append afterwards
            with open(log_path, 'w' if i_iter == 0 else 'a') as log_file:
                log_file.write(wdata)

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
            torch.save(model_D.state_dict(),
                       osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
def main():
    """Train Res_Deeplab with a discriminator feature-matching loss (resume variant).

    Per sub-iteration the segmentation network minimises cross-entropy on a
    labeled batch plus ``fm_loss``. ``fm_loss`` is the mean absolute difference
    between the batch-averaged second ``model_D`` output for one-hot ground
    truth and for softmax predictions on a batch from the full ("all") pool.
    The discriminator's first output is trained with ``criterion`` (predictions
    -> 0, ground truth -> 1).

    Training starts at ``args.start_iter`` when present. Otherwise it starts at
    20001, the original hard-coded resume point; presumably this pairs with
    restoring a 20000-iteration checkpoint (confirm).
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if name in saved_state_dict and param.size() == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir, args.data_list, crop_size=input_size,
                               scale=args.random_scale, mirror=args.random_mirror,
                               mean=IMG_MEAN)
    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir, args.data_list, crop_size=input_size,
                                    scale=args.random_scale, mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset, batch_size=args.batch_size,
                                      shuffle=True, num_workers=5, pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset, batch_size=args.batch_size,
                                         shuffle=True, num_workers=5, pin_memory=True)
        # feature matching draws from the whole set; previously undefined here
        trainloader_all = data.DataLoader(train_dataset, batch_size=args.batch_size,
                                          shuffle=True, num_workers=5, pin_memory=True)
    else:
        # sample partial data: the first partial_size shuffled ids are labeled
        partial_size = int(args.partial_data * train_dataset_size)
        if args.partial_id is not None:
            with open(args.partial_id, 'rb') as f:
                train_ids = pickle.load(f)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = np.arange(train_dataset_size)
            np.random.shuffle(train_ids)

        with open(osp.join(args.snapshot_dir, 'train_id.pkl'), 'wb') as f:
            pickle.dump(train_ids, f)

        train_sampler_all = data.sampler.SubsetRandomSampler(train_ids)
        train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])
        train_gt_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])

        trainloader_all = data.DataLoader(train_dataset, batch_size=args.batch_size,
                                          sampler=train_sampler_all, num_workers=3,
                                          pin_memory=True)
        trainloader = data.DataLoader(train_dataset, batch_size=args.batch_size,
                                      sampler=train_sampler, num_workers=3,
                                      pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset, batch_size=args.batch_size,
                                         sampler=train_gt_sampler, num_workers=3,
                                         pin_memory=True)

    trainloader_all_iter = iter(trainloader_all)
    trainloader_iter = iter(trainloader)
    trainloader_gt_iter = iter(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting
    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args), lr=args.learning_rate,
                          momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # bilinear upsampling back to the crop size
    # NOTE(review): unlike the sibling trainers this does not pass
    # align_corners=True on PyTorch >= 0.4 — confirm which behaviour is intended.
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    start_iter = getattr(args, 'start_iter', 20001)

    for i_iter in range(start_iter, args.num_steps):

        loss_seg_value = 0
        loss_D_value = 0
        loss_fm_value = 0
        loss_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # ---------------- train G ----------------
            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # supervised cross-entropy on a labeled batch
            try:
                batch_l = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = iter(trainloader)
                batch_l = next(trainloader_iter)

            images, labels, _, _ = batch_l
            images = Variable(images).cuda(args.gpu)
            pred_l = interp(model(images))
            loss_seg = loss_calc(pred_l, labels, args.gpu)

            # feature-matching loss: predictions on any image vs one-hot ground truth
            try:
                batch_all = next(trainloader_all_iter)
            except StopIteration:
                trainloader_all_iter = iter(trainloader_all)
                batch_all = next(trainloader_all_iter)

            images, _, _, _ = batch_all
            images = Variable(images).cuda(args.gpu)
            pred = interp(model(images))

            # D features for predictions
            _, D_out_y_pred = model_D(F.softmax(pred, dim=1))

            # D features for ground truth
            try:
                batch_gt = next(trainloader_gt_iter)
            except StopIteration:
                trainloader_gt_iter = iter(trainloader_gt)
                batch_gt = next(trainloader_gt_iter)

            _, labels_gt, _, _ = batch_gt
            D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
            _, D_out_y_gt = model_D(D_gt_v)

            fm_loss = torch.mean(
                torch.abs(torch.mean(D_out_y_gt, 0) - torch.mean(D_out_y_pred, 0)))

            loss = loss_seg + fm_loss

            # proper normalization for gradient accumulation over iter_size
            (loss / args.iter_size).backward()
            # .item() instead of .numpy()[0]: losses are 0-d tensors on PyTorch >= 0.4
            loss_seg_value += loss_seg.item() / args.iter_size
            loss_fm_value += fm_loss.item() / args.iter_size
            loss_value += loss.item() / args.iter_size

            # ---------------- train D ----------------
            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # fake: detached predictions
            pred = pred.detach()
            D_out_z, _ = model_D(F.softmax(pred, dim=1))
            y_fake_ = Variable(torch.zeros(D_out_z.size(0), 1).cuda())
            loss_D_fake = criterion(D_out_z, y_fake_)

            # real: reuse the one-hot ground truth built above
            D_out_z_gt, _ = model_D(D_gt_v)
            y_real_ = Variable(torch.ones(D_out_z_gt.size(0), 1).cuda())
            loss_D_real = criterion(D_out_z_gt, y_real_)

            loss_D = loss_D_fake + loss_D_real
            (loss_D / args.iter_size).backward()
            loss_D_value += loss_D.item() / args.iter_size

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_D = {3:.3f}'.format(
            i_iter, args.num_steps, loss_seg_value, loss_D_value))
        print('fm_loss: ', loss_fm_value, ' g_loss: ', loss_value)

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, 'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(model_D.state_dict(),
                       osp.join(args.snapshot_dir, 'VOC_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
            torch.save(model_D.state_dict(),
                       osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
def main():
    """Train the segmentation network against one discriminator per foreground class.

    For every foreground class ``ls`` (label values other than 0 and 255)
    present in an image, a single-channel class map is stacked with the
    sigmoid of that image and scored by discriminator ``model_DS[ls - 1]``:

    * G step: the softmax map of class ``ls`` is pushed towards ``gt_label``
      by that class's discriminator (loss scaled by
      ``2 * args.lambda_adv_pred``).
    * D step: detached prediction maps (target ``pred_label``) and one-hot
      ground-truth maps (target ``gt_label``) are buffered per class; a class
      discriminator only contributes a loss once its buffer holds
      ``d_batch_size`` samples.

    Uses names defined elsewhere in the file: ``args``, ``start``,
    ``Res_Deeplab``, ``Discriminator_n``, ``VOCDataSet``, ``VOCGTDataSet``,
    ``IMG_MEAN``, ``loss_calc``, ``one_hot``, ``make_D_label``,
    ``adjust_learning_rate``, ``adjust_learning_rate_D`` and ``version``.
    """
    num_fg_classes = 20  # one discriminator per foreground class
    d_batch_size = 20  # buffered samples per class needed for one D update

    def _scalar(value):
        """Return a loss as a Python float.

        Accepts a plain number (the adversarial loss stays 0 when a batch has
        no foreground class) as well as 1-element or 0-dim tensors/Variables.
        """
        if isinstance(value, (int, float)):
            return float(value)
        return float(value.data.cpu().numpy().ravel()[0])

    def _next_batch(loader_iter, loader):
        """Return (iterator, batch), restarting ``loader`` once it is exhausted."""
        try:
            _, batch = next(loader_iter)
        except StopIteration:
            loader_iter = enumerate(loader)
            _, batch = next(loader_iter)
        return loader_iter, batch

    def _foreground_classes(label_map):
        """Integer label values present in ``label_map`` other than 0 and 255."""
        return [int(c) for c in np.unique(label_map).tolist()
                if c != 0 and c != 255]

    def _d_input(class_map, image):
        """Stack one class map with sigmoid(image) on dim 0, then add a batch dim.

        Assumes ``class_map`` is (H, W) and ``image`` is (C, H, W) — TODO confirm.
        """
        return torch.cat([class_map.unsqueeze(0), F.sigmoid(image)],
                         0).unsqueeze(0)

    def _flush_buffers(buffers, target_label):
        """Sum the BCE loss of every class discriminator whose buffer is full.

        A full buffer (>= d_batch_size entries) contributes one batch built
        from its first d_batch_size entries and is then emptied; any extra
        entries are dropped.
        """
        total = Variable(torch.FloatTensor([0.0]), requires_grad=True).cuda()
        for cls_idx in range(num_fg_classes):
            if len(buffers[cls_idx]) >= d_batch_size:
                d_out = model_DS[cls_idx](
                    torch.cat(buffers[cls_idx][:d_batch_size], 0))
                total = total + bce_loss(d_out,
                                         make_D_label(target_label, d_out))
                buffers[cls_idx] = []
        return total

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if name in saved_state_dict and param.size(
        ) == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model = nn.DataParallel(model)
    model.cuda()

    cudnn.benchmark = True

    # init D: one discriminator per foreground class. The DataParallel wrapper
    # is stored in the list (rebinding a loop variable would discard it).
    model_DS = []
    for _ in range(num_fg_classes):
        model_D = nn.DataParallel(
            Discriminator_n(num_classes=args.num_classes))
        model_D.train()
        model_D.cuda()
        model_DS.append(model_D)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)
    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir,
                                    args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if not args.partial_data:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=5,
                                      pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)
    else:
        # sample partial data: the same random subset feeds both loaders
        partial_size = int(args.partial_data * train_dataset_size)
        train_ids = list(range(train_dataset_size))
        np.random.shuffle(train_ids)
        trainloader = data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            sampler=data.sampler.SubsetRandomSampler(train_ids[:partial_size]),
            num_workers=3,
            pin_memory=True)
        trainloader_gt = data.DataLoader(
            train_gt_dataset,
            batch_size=args.batch_size,
            sampler=data.sampler.SubsetRandomSampler(train_ids[:partial_size]),
            num_workers=3,
            pin_memory=True)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.module.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # one optimizer per discriminator network
    optimizer_DS = [
        optim.Adam(model_D.parameters(),
                   lr=args.learning_rate_D,
                   betas=(0.9, 0.99)) for model_D in model_DS
    ]
    for optimizer_D in optimizer_DS:
        optimizer_D.zero_grad()

    # loss/ bilinear upsampling
    bce_loss = torch.nn.BCELoss()
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    # per-class buffers of discriminator inputs, consumed d_batch_size at a time
    pred_c = [[] for _ in range(num_fg_classes)]
    gt_c = [[] for _ in range(num_fg_classes)]

    # checkpoint name stem taken from this script's file name
    script_stem = os.path.abspath(__file__).split('/')[-1].split('.')[0]

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        for optimizer_D in optimizer_DS:
            optimizer_D.zero_grad()
            adjust_learning_rate_D(optimizer_D, i_iter)

        for _ in range(args.iter_size):

            # ---- train G ----
            # don't accumulate grads in D
            for model_D in model_DS:
                for param in model_D.parameters():
                    param.requires_grad = False

            trainloader_iter, batch = _next_batch(trainloader_iter,
                                                  trainloader)
            images, labels, _, _ = batch
            images = Variable(images).cuda()

            pred = interp(model(images))
            loss_seg = loss_calc(pred, labels)

            # Discriminator ls-1 handles foreground class ls; the class-ls map
            # is channel ls of the softmax (assumes channel k holds label k,
            # matching the cross-entropy in loss_calc — confirm).
            pred_soft = F.softmax(pred, dim=1)
            loss_adv_pred = 0  # stays 0 if the batch has no foreground class
            for i_l in range(labels.shape[0]):
                for ls in _foreground_classes(labels[i_l]):
                    D_out = model_DS[ls - 1](_d_input(pred_soft[i_l][ls],
                                                      images[i_l]))
                    loss_adv_pred = loss_adv_pred + bce_loss(
                        D_out, make_D_label(gt_label, D_out))

            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred * 2

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += _scalar(loss_seg) / args.iter_size
            loss_adv_pred_value += _scalar(loss_adv_pred) / args.iter_size

            # ---- train D ----
            # bring back requires_grad
            for model_D in model_DS:
                for param in model_D.parameters():
                    param.requires_grad = True

            # train with pred (detached so no gradient reaches G)
            pred_soft = F.softmax(pred.detach(), dim=1)
            for i_l in range(labels.shape[0]):
                for ls in _foreground_classes(labels[i_l]):
                    pred_c[ls - 1].append(
                        _d_input(pred_soft[i_l][ls], images[i_l]))
            loss_D = _flush_buffers(pred_c, pred_label) / args.iter_size / 2
            loss_D.backward()
            loss_D_value += _scalar(loss_D)

            # train with gt
            trainloader_gt_iter, batch = _next_batch(trainloader_gt_iter,
                                                     trainloader_gt)
            img_gt, labels_gt, _, _ = batch
            img_gt = Variable(img_gt).cuda()
            # assumes one_hot() puts label k in channel k — confirm
            D_gt_v = Variable(one_hot(labels_gt)).cuda()
            for i_l in range(labels_gt.shape[0]):
                for ls in _foreground_classes(labels_gt[i_l]):
                    gt_c[ls - 1].append(
                        _d_input(D_gt_v[i_l][ls], img_gt[i_l]))
            loss_D_gt = _flush_buffers(gt_c, gt_label) / args.iter_size / 2
            loss_D_gt.backward()
            loss_D_value += _scalar(loss_D_gt)

        optimizer.step()
        for optimizer_D in optimizer_DS:
            optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value,
                    loss_semi_adv_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir, 'VOC_' + script_stem + '_' +
                    str(args.num_steps) + '.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + script_stem + '_' + str(i_iter) + '.pth'))

    end = timeit.default_timer()
    # `start` is not defined in this function; presumably a module-level timer.
    print(end - start, 'seconds')
def main():
    """Evaluate a segmentation model whose class outputs are gated by a classifier.

    For each test image, every foreground class ``k`` (1 .. num_classes-1)
    whose classifier sigmoid score ``cls_pred[k - 1]`` is below
    ``args.sigmoid_threshold`` has its segmentation logits forced to a large
    negative value before the per-pixel argmax. The collected
    (ground truth, prediction) pairs are passed to ``get_iou``. Its output
    file is ``<save_dir>/<checkpoint stem>_with_classifier_result.txt``.
    """
    args = get_arguments()
    gpu0 = args.gpu

    print("Evaluating model")
    print(args.restore_from)
    print("classifier model")
    print(args.restore_from_classifier)
    print("sigmoid threshold")
    print(args.sigmoid_threshold)

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    model = Res_Deeplab(num_classes=args.num_classes)
    model_cls = Res_Deeplab_class(num_classes=args.num_classes,
                                  mode=6,
                                  latent_vars=args.latent_vars)

    saved_state_dict = torch.load(args.restore_from)
    model.load_state_dict(saved_state_dict, strict=False)
    model.eval()
    model.cuda(gpu0)

    saved_state_dict = torch.load(args.restore_from_classifier)
    model_cls.load_state_dict(saved_state_dict, strict=False)
    model_cls.eval()
    model_cls.cuda(gpu0)

    testloader = data.DataLoader(VOCDataSet(args.data_dir,
                                            args.data_list,
                                            crop_size=(505, 505),
                                            mean=IMG_MEAN,
                                            scale=False,
                                            mirror=False),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(505, 505),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(505, 505), mode='bilinear')

    data_list = []

    for index, batch in enumerate(testloader):
        if index % 100 == 0:
            print('%d processd' % (index))
        image, label, size, name = batch
        size = size[0].numpy()

        # Copy the image to the GPU once; both networks consume it.
        image_v = Variable(image, volatile=True).cuda(gpu0)

        output = model(image_v)
        output = interp(output).cpu().data[0].numpy()
        # crop away the padding added to reach the 505x505 crop
        output = output[:, :size[0], :size[1]]

        cls_pred = F.sigmoid(model_cls(image_v))
        cls_pred = cls_pred.cpu().data.numpy()[0]

        # builtin int replaces the np.int alias removed in NumPy 1.24
        gt = np.asarray(label[0].numpy()[:size[0], :size[1]], dtype=int)

        # Classifier entry k-1 gates segmentation channel k; suppressed
        # classes can never win the argmax below.
        for clsID in range(1, args.num_classes):
            if cls_pred[clsID - 1] < args.sigmoid_threshold:
                output[clsID, :, :] = -1000000000

        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=int)

        data_list.append([gt.flatten(), output.flatten()])

    # checkpoint file name without directory and extension
    ckpt_stem = os.path.splitext(os.path.basename(args.restore_from))[0]
    filename = os.path.join(args.save_dir,
                            ckpt_stem + '_with_classifier_result.txt')
    get_iou(data_list, args.num_classes, filename)
def main():
    """Evaluate a fixed checkpoint on the VOC list and report IoU.

    The checkpoint path is hard-coded below; ``args.restore_from`` is not
    read. Colorized predictions are saved as ``<save_dir>/<name>.png``. The
    (ground truth, prediction) pairs are passed to ``get_iou``, whose output
    file is ``<save_dir>/result.txt``.
    """
    from collections import OrderedDict

    args = get_arguments()
    gpu0 = args.gpu

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    model = Res_Deeplab(num_classes=args.num_classes)

    state_dict = torch.load(
        '/data1/wyc/AdvSemiSeg/snapshots/VOC_t_baseline_1adv_mul_new_two_patch3_20000.pth'
    )

    # Checkpoints saved from an nn.DataParallel wrapper prefix every key with
    # 'module.'; strip it only where it is actually present.
    prefix = 'module.'
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v

    # copy only the params whose name and shape match the current model
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if name in new_state_dict and param.size(
        ) == new_state_dict[name].size():
            new_params[name].copy_(new_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(VOCDataSet(args.data_dir,
                                            args.data_list,
                                            crop_size=(505, 505),
                                            mean=IMG_MEAN,
                                            scale=False,
                                            mirror=False),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(505, 505),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(505, 505), mode='bilinear')

    data_list = []
    colorize = VOCColorize()

    for index, batch in enumerate(testloader):
        if index % 100 == 0:
            print('%d processd' % (index))
        image, label, size, name = batch
        size = size[0].numpy()
        output = model(Variable(image, volatile=True).cuda(gpu0))
        output = interp(output).cpu().data[0].numpy()

        # crop away the padding added to reach the 505x505 crop
        output = output[:, :size[0], :size[1]]
        # builtin int replaces the np.int alias removed in NumPy 1.24
        gt = np.asarray(label[0].numpy()[:size[0], :size[1]], dtype=int)

        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=int)

        filename = os.path.join(args.save_dir, '{}.png'.format(name[0]))
        color_file = Image.fromarray(
            colorize(output).transpose(1, 2, 0), 'RGB')
        color_file.save(filename)

        data_list.append([gt.flatten(), output.flatten()])

    filename = os.path.join(args.save_dir, 'result.txt')
    get_iou(data_list, args.num_classes, filename)
def main():
    """Evaluate a fixed checkpoint and dump per-class confidence visualizations.

    The checkpoint path is hard-coded below; ``args.restore_from`` is not
    read. For every foreground class (label values other than 0 and 255)
    present in an image, three images are written under ``vis_root``:

    * ``img_pred/<name>.png``: the input image scaled by that class's softmax map.
    * ``image/<name>.png``: the input image tensor.
    * ``pred/<name>.png``: a threshold map of the class probability
      (> 0.995, > 0.9, > 0.7 each set a different channel to 255).

    Colorized predictions are also saved as ``<save_dir>/<name>.png``. The
    (ground truth, prediction) pairs are passed to ``get_iou``, whose output
    file is ``<save_dir>/result.txt``.
    """
    from collections import OrderedDict

    args = get_arguments()
    gpu0 = args.gpu

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    model = Res_Deeplab(num_classes=args.num_classes)

    state_dict = torch.load(
        '/data1/wyc/AdvSemiSeg/snapshots/VOC_t_baseline_1adv_mul_20000.pth')

    # Checkpoints saved from an nn.DataParallel wrapper prefix every key with
    # 'module.'; strip it only where it is actually present.
    prefix = 'module.'
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v

    # copy only the params whose name and shape match the current model
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if name in new_state_dict and param.size(
        ) == new_state_dict[name].size():
            new_params[name].copy_(new_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(VOCDataSet(args.data_dir,
                                            args.data_list,
                                            crop_size=(505, 505),
                                            mean=IMG_MEAN,
                                            scale=False,
                                            mirror=False),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(505, 505),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(505, 505), mode='bilinear')

    vis_root = '/data1/wyc/AdvSemiSeg/vis'
    data_list = []
    colorize = VOCColorize()

    for index, batch in enumerate(testloader):
        if index % 100 == 0:
            print('%d processd' % (index))
        image, label, size, name = batch
        size = size[0].numpy()

        # One GPU copy of the image serves both the model and the visualization.
        image = Variable(image, volatile=True).cuda(gpu0)
        pred = interp(model(image))
        pred01 = F.softmax(pred, dim=1)
        output = pred.cpu().data[0].numpy()

        # (C, H, W) -> (H, W, C) for cv2
        img_ori = image[0].squeeze().permute(1, 2, 0).data.cpu().numpy()

        # batch_size is 1, so i_l is always 0
        for i_l in range(label.shape[0]):
            for ls in np.unique(label[i_l]).tolist():
                if ls == 0 or ls == 255:
                    continue
                ls = int(ls)

                # each image channel multiplied by the softmax map of class ls
                imgs = (pred01[i_l][ls].unsqueeze(0) *
                        image[i_l]).permute(1, 2, 0).data.cpu().numpy()

                # Threshold bands of the class-ls probability; a pixel falls in
                # at most one band, so the three masks never overlap.
                prob = pred01[i_l][ls].data.cpu().numpy()
                pred_h, pred_w = prob.shape
                above_995 = prob > 0.995
                above_9 = prob > 0.9
                above_7 = prob > 0.7
                color_image = np.zeros((3, pred_h, pred_w), dtype=np.uint8)
                color_image[1][above_995] = 255
                color_image[0][above_9 & ~above_995] = 255
                color_image[2][above_7 & ~above_9] = 255
                color_image = color_image.transpose((1, 2, 0))

                # NOTE(review): these paths depend only on the image name, so
                # with several foreground classes each class overwrites the
                # previous class's files.
                cv2.imwrite(osp.join(vis_root, 'img_pred', name[0] + '.png'),
                            imgs)
                cv2.imwrite(osp.join(vis_root, 'image', name[0] + '.png'),
                            img_ori)
                cv2.imwrite(osp.join(vis_root, 'pred', name[0] + '.png'),
                            color_image)

        # crop away the padding using the image size from the batch
        output = output[:, :size[0], :size[1]]
        # builtin int replaces the np.int alias removed in NumPy 1.24
        gt = np.asarray(label[0].numpy()[:size[0], :size[1]], dtype=int)

        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=int)

        filename = os.path.join(args.save_dir, '{}.png'.format(name[0]))
        color_file = Image.fromarray(
            colorize(output).transpose(1, 2, 0), 'RGB')
        color_file.save(filename)

        data_list.append([gt.flatten(), output.flatten()])

    filename = os.path.join(args.save_dir, 'result.txt')
    get_iou(data_list, args.num_classes, filename)
def main(): h, w = map(int, args.input_size.split(',')) input_size = (h, w) cudnn.enabled = True # create network model = Res_Deeplab(num_classes=args.num_classes) # load pretrained parameters if args.restore_from[:4] == 'http': saved_state_dict = model_zoo.load_url(args.restore_from) else: saved_state_dict = torch.load(args.restore_from) # only copy the params that exist in current model (caffe-like) new_params = model.state_dict().copy() for name, param in new_params.items(): print(name) if name in saved_state_dict and param.size( ) == saved_state_dict[name].size(): new_params[name].copy_(saved_state_dict[name]) print('copy {}'.format(name)) model.load_state_dict(new_params) # only copy the params that exist in current model (caffe-like) # state_dict = torch.load( # '/data1/wyc/AdvSemiSeg/snapshots/VOC_t_baseline_1adv_mul_20000.pth') # baseline707 adv 709 nadv 705()*2 # from collections import OrderedDict # new_state_dict = OrderedDict() # for k, v in state_dict.items(): # name = k[7:] # remove `module.` # new_state_dict[name] = v # # load params # # new_params = model.state_dict().copy() # for name, param in new_params.items(): # print (name) # if name in new_state_dict and param.size() == new_state_dict[name].size(): # new_params[name].copy_(new_state_dict[name]) # print('copy {}'.format(name)) # # model.load_state_dict(new_params) model.train() model = nn.DataParallel(model) model.cuda() cudnn.benchmark = True # init D model_D = Discriminator2(num_classes=args.num_classes) if args.restore_from_D is not None: model_D.load_state_dict(torch.load(args.restore_from_D)) model_D = nn.DataParallel(model_D) model_D.train() model_D.cuda() if not os.path.exists(args.snapshot_dir): os.makedirs(args.snapshot_dir) train_dataset = VOCDataSet(args.data_dir, args.data_list, crop_size=input_size, scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN) train_dataset_size = len(train_dataset) train_gt_dataset = VOCGTDataSet(args.data_dir, args.data_list, 
crop_size=input_size, scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN) if args.partial_data == 0: trainloader = data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=5, pin_memory=True) trainloader_gt = data.DataLoader(train_gt_dataset, batch_size=args.batch_size, shuffle=True, num_workers=5, pin_memory=True) else: #sample partial data partial_size = int(args.partial_data * train_dataset_size) trainloader_iter = enumerate(trainloader) trainloader_gt_iter = enumerate(trainloader_gt) # implement model.optim_parameters(args) to handle different models' lr setting # optimizer for segmentation network optimizer = optim.SGD(model.module.optim_parameters(args), lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay) optimizer.zero_grad() # optimizer for discriminator network optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99)) optimizer_D.zero_grad() # loss/ bilinear upsampling bce_loss = torch.nn.BCELoss() interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear') if version.parse(torch.__version__) >= version.parse('0.4.0'): interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear', align_corners=True) else: interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear') # labels for adversarial training pred_label = 0 gt_label = 1 for i_iter in range(args.num_steps): loss_seg_value = 0 loss_adv_pred_value = 0 loss_D_value = 0 loss_semi_value = 0 loss_semi_adv_value = 0 optimizer.zero_grad() adjust_learning_rate(optimizer, i_iter) optimizer_D.zero_grad() adjust_learning_rate_D(optimizer_D, i_iter) tw = [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ] for sub_i in range(args.iter_size): # train G # don't accumulate grads in D for param in model_D.parameters(): param.requires_grad = False # train with source try: _, batch = trainloader_iter.next() except: trainloader_iter = enumerate(trainloader) 
_, batch = trainloader_iter.next() images, labels, _, _ = batch images = Variable(images).cuda() ignore_mask = (labels.numpy() == 255) pred = interp(model(images)) loss_seg = loss_calc(pred, labels) pred_0 = F.softmax(pred, dim=1) #pred_0 = 1 / (math.e ** (((pred_01 - 0.33) * 30) * (-1)) + 1) labels0 = Variable(one_hot(labels)).cuda() one_s = Variable(torch.ones(labels0.size())).cuda() labels0 = one_s - labels0 labels0 = torch.index_select( labels0, 1, Variable( torch.LongTensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ])).cuda()) pred0 = torch.index_select( pred_0, 1, Variable( torch.LongTensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ])).cuda()) pred_label0 = labels0 * pred0 pred_max = torch.max(pred_0, dim=1)[1] pred_max = Variable(one_hot(pred_max.cpu().data)).cuda() pred_max = torch.index_select( pred_max, 1, Variable( torch.LongTensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ])).cuda()) pred_c = pred_label0 * (pred_max) one_s = Variable(torch.ones(pred_max.size())).cuda() pred_m = pred_label0 * (one_s - pred_max) c3 = 0 c4 = 0 c5 = 0 c6 = 0 c7 = 0 c8 = 0 c9 = 0 c10 = 0 c0 = 0 pred_min_c_list = pred_c.cpu().data.numpy().flatten().tolist() pred_min_c_l = len(pred_min_c_list) for n in pred_min_c_list: if n < 0.00000001: c0 = c0 + 1 elif n < 0.1: c3 = c3 + 1 elif n < 0.2: c4 = c4 + 1 elif n < 0.3: c5 = c5 + 1 elif n < 0.4: c6 = c6 + 1 elif n < 0.5: c7 = c7 + 1 elif n < 0.6: c8 = c8 + 1 elif n < 0.7: c9 = c9 + 1 elif n <= 1: c10 = c10 + 1 else: print n if pred_min_c_l - c0 == 0: print pred_min_c_l else: pred_min_c_l = (pred_min_c_l - c0) * 1.00000 print "correct", 3, ":", c3 / pred_min_c_l, 4, ":", c4 / pred_min_c_l, 5, ":", c5 / pred_min_c_l, 6, ":", c6 / pred_min_c_l, 7, ":", c7 / pred_min_c_l, 8, ":", c8 / pred_min_c_l, 9, ":", c9 / pred_min_c_l, 10, ":", c10 / pred_min_c_l c3 = 0 c4 = 0 c5 = 0 c6 = 0 c7 = 0 c8 = 0 c9 = 0 c10 = 0 c0 = 0 pred_min_m_list = 
pred_m.cpu().data.numpy().flatten().tolist() pred_min_c_l = len(pred_min_m_list) for n in pred_min_m_list: if n < 0.0000001: c0 = c0 + 1 elif n < 0.005: c3 = c3 + 1 elif n < 0.05: c4 = c4 + 1 elif n < 0.3: c5 = c5 + 1 elif n < 0.4: c6 = c6 + 1 elif n < 0.5: c7 = c7 + 1 else: print n if pred_min_c_l - c0 == 0: print pred_min_c_l else: pred_min_c_l = (pred_min_c_l - c0) * 1.00000 print "mistake", 1, ":", c3 / pred_min_c_l, 2, ":", c4 / pred_min_c_l, 3, ":", c5 / pred_min_c_l, 4, ":", c6 / pred_min_c_l, 5, ":", c7 / pred_min_c_l ''' images, labels, _, _ = batch images = Variable(images).cuda() ignore_mask = (labels.numpy() == 255) pred = interp(model(images)) loss_seg = loss_calc(pred, labels) pred_0 = F.softmax(pred, dim=1) labels0 = Variable(one_hot(labels)).cuda() labels0=torch.index_select(labels0, 1, Variable(torch.LongTensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20])).cuda()) pred0 = torch.index_select(pred_0, 1, Variable(torch.LongTensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20])).cuda()) pred_label0=labels0*pred0 pred_max = torch.max(pred_0, dim=1)[1] pred_max = Variable(one_hot(pred_max.cpu().data)).cuda() pred_max = torch.index_select(pred_max, 1, Variable( torch.LongTensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20])).cuda()) pred_c=pred_label0*(pred_max) one_s = Variable(torch.ones(pred_max.size())).cuda() pred_m = pred_label0 * (one_s - pred_max) c3 = 0 c4 = 0 c5 = 0 c6 = 0 c7 = 0 c8 = 0 c9 = 0 c10 = 0 c0 = 0 pred_min_c_list = pred_c.cpu().data.numpy().flatten().tolist() pred_min_c_l = len(pred_min_c_list) for n in pred_min_c_list: if n < 0.00000001: c0 = c0 + 1 elif n < 0.3: c3 = c3 + 1 elif n < 0.4: c4 = c4 + 1 elif n < 0.5: c5 = c5 + 1 elif n < 0.6: c6 = c6 + 1 elif n < 0.7: c7 = c7 + 1 elif n < 0.8: c8 = c8 + 1 elif n < 0.9: c9 = c9 + 1 elif n <= 1: c10 = c10 + 1 else: print n if pred_min_c_l - c0 == 0: print pred_min_c_l else: pred_min_c_l = (pred_min_c_l - c0) * 
1.00000 print "correct", 3, ":", c3 / pred_min_c_l, 4, ":", c4 / pred_min_c_l, 5, ":", c5 / pred_min_c_l, 6, ":", c6 / pred_min_c_l, 7, ":", c7 / pred_min_c_l, 8, ":", c8 / pred_min_c_l, 9, ":", c9 / pred_min_c_l, 10, ":", c10 / pred_min_c_l c3 = 0 c4 = 0 c5 = 0 c6 = 0 c7 = 0 c8 = 0 c9 = 0 c10 = 0 c0 = 0 pred_min_m_list = pred_m.cpu().data.numpy().flatten().tolist() pred_min_c_l = len(pred_min_m_list) for n in pred_min_m_list: if n < 0.0000001: c0 = c0 + 1 elif n < 0.1: c3 = c3 + 1 elif n < 0.2: c4 = c4 + 1 elif n < 0.3: c5 = c5 + 1 elif n < 0.4: c6 = c6 + 1 elif n < 0.5: c7 = c7 + 1 else: print n if pred_min_c_l - c0 == 0: print pred_min_c_l else: pred_min_c_l = (pred_min_c_l - c0) * 1.00000 print "mistake", 1, ":", c3 / pred_min_c_l, 2, ":", c4 / pred_min_c_l, 3, ":", c5 / pred_min_c_l, 4, ":", c6 / pred_min_c_l, 5, ":", c7 / pred_min_c_l ''' # c3=0 # c4=0 # c5=0 # c6=0 # c7=0 # c8=0 # c9=0 # c10=0 # c0=0 # # pred_min_c_list=pred_c.cpu().data.numpy().flatten().tolist() # # pred_min_c=set(pred_c.cpu().data.numpy().flatten().tolist()) # pred_min_c_l=len(pred_min_c_list) # if len(pred_min_c) >1: # pred_min_c.discard(0.0) # pred_min_c=list(pred_min_c) # pred_min_c2 = min(pred_min_c) # pred_max_c2 = max(pred_min_c) # else: # pred_min_c2 = 999 # pred_max_c2 = 999 # for n in pred_min_c_list: # if n<0.00000001: # c0=c0+1 # elif n<0.3: # c3=c3+1 # elif n<0.4: # c4=c4+1 # elif n<0.5: # c5=c5+1 # elif n<0.6: # c6=c6+1 # elif n<0.7: # c7=c7+1 # elif n<0.8: # c8=c8+1 # elif n<0.9: # c9=c9+1 # elif n<=1: # c10=c10+1 # else: # print n # # if pred_min_c_l-c0==0: # print pred_min_c_l # else: # pred_min_c_l=(pred_min_c_l-c0)*1.00000 # # #print c0 + c3 + c4 + c5 + c6 + c7+c8+c9+c10 # # print "correct",3,":",c3/pred_min_c_l,4,":",c4/pred_min_c_l,5,":",c5/pred_min_c_l,6,":",c6/pred_min_c_l,7,":",c7/pred_min_c_l,8,":",c8/pred_min_c_l,9,":",c9/pred_min_c_l,10,":",c10/pred_min_c_l # # c3 = 0 # c4 = 0 # c5 = 0 # c6 = 0 # c7 = 0 # c8 = 0 # c9 = 0 # c10 = 0 # c0 = 0 # # # # # 
pred_min_m_list = pred_m.cpu().data.numpy().flatten().tolist() # pred_min_m = set(pred_m.cpu().data.numpy().flatten().tolist()) # pred_min_c_l = len(pred_min_m_list) # if len(pred_min_m) >1: # pred_min_m.discard(0.0) # pred_min_m=list(pred_min_m) # pred_min_m2 = min(pred_min_m) # pred_max_m2 = max(pred_min_m) # else: # pred_min_m2 = 999 # pred_max_m2 = 999 # # for n in pred_min_m_list: # if n<0.0000001: # c0=c0+1 # elif n<0.1: # c3=c3+1 # elif n<0.2: # c4=c4+1 # elif n<0.3: # c5=c5+1 # elif n<0.4: # c6=c6+1 # elif n<0.5: # c7=c7+1 # else: # print n # if pred_min_c_l-c0==0: # print pred_min_c_l # else: # pred_min_c_l=(pred_min_c_l-c0)*1.00000 # print "mistake",1,":",c3/pred_min_c_l,2,":",c4/pred_min_c_l,3,":",c5/pred_min_c_l,4,":",c6/pred_min_c_l,5,":",c7/pred_min_c_l # # # # print('max c {} min c {} max m {} min m {}'.format( # pred_max_c2,pred_min_c2, pred_max_m2,pred_min_m2)) '''20000 correct 3 : 6.60318801918e-05 4 : 0.00186540061542 5 : 0.00659328323715 6 : 0.0196708971091 7 : 0.0234446190621 8 : 0.0315071116335 9 : 0.0565364958202 10 : 0.860316160642 mistake 1 : 0.326503117461 2 : 0.211204060532 3 : 0.199110192061 4 : 0.15803490303 5 : 0.105147726917 max c 1.0 min c 0.206159025431 max m 0.499729216099 min m 3.0866687678e-11 iter = 0/ 20000, loss_seg = 0.164, loss_adv_p = 0.688, loss_D = 0.687, loss_semi = 0.000, loss_semi_adv = 0.000 correct 3 : 1.03566126972e-05 4 : 0.000749128318431 5 : 0.00340042116892 6 : 0.0126937549625 7 : 0.0161735768288 8 : 0.0261400904478 9 : 0.0492767632133 10 : 0.891555908448 mistake 1 : 0.398261661669 2 : 0.20677876039 3 : 0.167001935704 4 : 0.129388545185 5 : 0.0985690970509 max c 1.0 min c 0.295039653778 max m 0.499661713839 min m 1.87631424287e-10 iter = 1/ 20000, loss_seg = 0.114, loss_adv_p = 0.621, loss_D = 0.666, loss_semi = 0.000, loss_semi_adv = 0.000 correct 3 : 0.000170320224732 4 : 0.00122630561807 5 : 0.00573184329631 6 : 0.0155854360311 7 : 0.0216533779042 8 : 0.0339187050213 9 : 0.0683937894433 10 : 0.853320222461 
mistake 1 : 0.370818679971 2 : 0.185940950356 3 : 0.165003680379 4 : 0.165167252801 5 : 0.113069436493 max c 1.0 min c 0.205323472619 max m 0.499007672071 min m 2.63143761003e-07 iter = 2/ 20000, loss_seg = 0.136, loss_adv_p = 6.856, loss_D = 2.909, loss_semi = 0.000, loss_semi_adv = 0.000 correct 3 : 5.57311529183e-05 4 : 0.00162151116348 5 : 0.00529976725609 6 : 0.0104801106131 7 : 0.0125395094066 8 : 0.0183382031746 9 : 0.032151567505 10 : 0.919513599728 mistake 1 : 0.599129542479 2 : 0.138565514313 3 : 0.114044089027 4 : 0.0929903400446 5 : 0.0552705141361 max c 1.0 min c 0.251725673676 max m 0.499345749617 min m 2.17372370104e-11 iter = 3/ 20000, loss_seg = 0.173, loss_adv_p = 0.211, loss_D = 1.186, loss_semi = 0.000, loss_semi_adv = 0.000 correct 3 : 0.0 4 : 0.000408403449911 5 : 0.00314207170335 6 : 0.00972922412127 7 : 0.0128537300848 8 : 0.0230989478122 9 : 0.0531056227933 10 : 0.897662000035 mistake 1 : 0.376961004034 2 : 0.197477108279 3 : 0.160338093104 4 : 0.144009732983 5 : 0.1212140616 max c 1.0 min c 0.303397655487 max m 0.49908259511 min m 2.31815082538e-13 ''' '''0000 mistake 1 : 1.0 2 : 0.0 3 : 0.0 4 : 0.0 5 : 0.0 iter = 0/ 20000, loss_seg = 3.045, loss_adv_p = 0.677, loss_D = 0.689, loss_semi = 0.000, loss_semi_adv = 0.000 20608200 mistake 1 : 1.0 2 : 0.0 3 : 0.0 4 : 0.0 5 : 0.0 iter = 1/ 20000, loss_seg = 2.311, loss_adv_p = 0.001, loss_D = 4.009, loss_semi = 0.000, loss_semi_adv = 0.000 20608200 mistake 1 : 1.0 2 : 0.0 3 : 0.0 4 : 0.0 5 : 0.0 iter = 2/ 20000, loss_seg = 2.695, loss_adv_p = 1.304, loss_D = 0.996, loss_semi = 0.000, loss_semi_adv = 0.000 correct 3 : 0.215316600114 4 : 0.28849591177 5 : 0.159402928313 6 : 0.113015782468 7 : 0.102595550485 8 : 0.0921563034797 9 : 0.0290169233695 10 : 0.0 mistake 1 : 0.80735288413 2 : 0.095622890725 3 : 0.0764600062066 4 : 0.0203387447147 5 : 0.000225474223205 iter = 3/ 20000, loss_seg = 2.103, loss_adv_p = 0.936, loss_D = 0.720, loss_semi = 0.000, loss_semi_adv = 0.000 correct 3 : 0.0 4 : 
0.0106591946386 5 : 0.0159743486048 6 : 0.0144144664625 7 : 0.0482119128777 8 : 0.1030099948 9 : 0.149690912242 10 : 0.658039170374 mistake 1 : 0.997786787186 2 : 0.00220988967167 3 : 0.0 4 : 3.32314236342e-06 5 : 0.0 iter = 4/ 20000, loss_seg = 3.231, loss_adv_p = 0.352, loss_D = 0.740, loss_semi = 0.000, loss_semi_adv = 0.000 correct 3 : 0.0547200197126 4 : 0.0824408402489 5 : 0.105990337314 6 : 0.0824672410303 7 : 0.0818248220147 8 : 0.106500752422 9 : 0.12377566376 10 : 0.362280323498 mistake 1 : 0.89707796374 2 : 0.0776956426097 3 : 0.0212782533776 4 : 0.0038358600209 5 : 0.000112280251508 iter = 5/ 20000, loss_seg = 2.440, loss_adv_p = 0.285, loss_D = 0.796, loss_semi = 0.000, loss_semi_adv = 0.000 correct 3 : 0.0645113648539 4 : 0.11994706344 5 : 0.133270975533 6 : 0.130346143711 7 : 0.130461978634 8 : 0.155641595163 9 : 0.129387609717 10 : 0.136433268948 mistake 1 : 0.687583743536 2 : 0.172921294404 3 : 0.108947452597 4 : 0.027046718675 5 : 0.00350079078777 iter = 6/ 20000, loss_seg = 1.562, loss_adv_p = 0.471, loss_D = 0.700, loss_semi = 0.000, loss_semi_adv = 0.000 correct 3 : 0.0517260346307 4 : 0.0852126465864 5 : 0.0921211854525 6 : 0.0708675276672 7 : 0.0477657257266 8 : 0.0341796660139 9 : 0.0336846274009 10 : 0.584442586522 mistake 1 : 0.896840471376 2 : 0.0747503216715 3 : 0.0228403760663 4 : 0.00484719754372 5 : 0.000721633342184 iter = 7/ 20000, loss_seg = 1.521, loss_adv_p = 1.028, loss_D = 0.781, loss_semi = 0.000, loss_semi_adv = 0.000 correct 3 : 0.0354795814142 4 : 0.146293048623 5 : 0.182080980555 6 : 0.160111006186 7 : 0.0642910828885 8 : 0.0516294397657 9 : 0.0560234346393 10 : 0.304091425928 mistake 1 : 0.927571328438 2 : 0.0370632102681 3 : 0.0252651475979 4 : 0.00991933537868 5 : 0.000180978317074 iter = 8/ 20000, loss_seg = 2.240, loss_adv_p = 1.128, loss_D = 1.030, loss_semi = 0.000, loss_semi_adv = 0.000 correct 3 : 0.013791534851 4 : 0.0218762276947 5 : 0.0420707290008 6 : 0.0509204695049 7 : 0.0551454624403 8 : 0.0738167607469 9 : 
0.109429384722 10 : 0.632949431039 mistake 1 : 0.82698518928 2 : 0.0871808088324 3 : 0.056324796628 4 : 0.0222555690883 5 : 0.00725363617091 iter = 9/ 20000, loss_seg = 1.620, loss_adv_p = 0.856, loss_D = 0.680, loss_semi = 0.000, loss_semi_adv = 0.000 correct 3 : 0.0111066673332 4 : 0.0219300438268 5 : 0.039156626506 6 : 0.0714017897315 7 : 0.0741513772934 8 : 0.0834458164609 9 : 0.100755719975 10 : 0.598051958873 mistake 1 : 0.666796397198 2 : 0.104720338041 3 : 0.101412209496 4 : 0.0761078060714 5 : 0.0509632491938 iter = 10/ 20000, loss_seg = 1.065, loss_adv_p = 0.474, loss_D = 0.706, loss_semi = 0.000, loss_semi_adv = 0.000 correct 3 : 0.0158021057795 4 : 0.0165904591721 5 : 0.0267981756919 6 : 0.0441302710184 7 : 0.0592082596077 8 : 0.0964185397359 9 : 0.156169887236 10 : 0.584882301758 mistake 1 : 0.693582996679 2 : 0.111521356046 3 : 0.114913868543 4 : 0.0540284587444 5 : 0.0259533199871 iter = 11/ 20000, loss_seg = 1.008, loss_adv_p = 0.442, loss_D = 0.724, loss_semi = 0.000, loss_semi_adv = 0.000 correct 3 : 0.00525433110547 4 : 0.0132143122692 5 : 0.0181547053526 6 : 0.0353758581661 7 : 0.0438108771263 8 : 0.057140850772 9 : 0.0862049023901 10 : 0.740844162818 mistake 1 : 0.698567424322 2 : 0.152478802192 3 : 0.0874738631406 4 : 0.0430223987244 5 : 0.0184575116206 iter = 12/ 20000, loss_seg = 0.957, loss_adv_p = 0.442, loss_D = 0.743, loss_semi = 0.000, loss_semi_adv = 0.000 correct 3 : 0.00325696566801 4 : 0.0222856760435 5 : 0.029201151092 6 : 0.0513157992579 7 : 0.0617187558094 8 : 0.0804648983871 9 : 0.10184338308 10 : 0.649913370662 mistake 1 : 0.910824450383 2 : 0.0481368156486 3 : 0.0238175658895 4 : 0.0128053565929 5 : 0.00441581148608 iter = 13/ 20000, loss_seg = 1.422, loss_adv_p = 0.830, loss_D = 0.618, loss_semi = 0.000, loss_semi_adv = 0.000 correct 3 : 0.0149639294303 4 : 0.0338563367539 5 : 0.0443977524345 6 : 0.0406656984358 7 : 0.0549094069189 8 : 0.0846765553201 9 : 0.131008785505 10 : 0.595521535202 mistake 1 : 0.672923967008 2 : 
0.177251807277 3 : 0.079961731589 4 : 0.0552797386879 5 : 0.0145827554388 iter = 14/ 20000, loss_seg = 0.997, loss_adv_p = 1.154, loss_D = 0.837, loss_semi = 0.000, loss_semi_adv = 0.000 correct 3 : 0.0193227170348 4 : 0.0525052357055 5 : 0.0709388706682 6 : 0.0720057691548 7 : 0.0699806377682 8 : 0.0911210337061 9 : 0.15581657249 10 : 0.468309163473 mistake 1 : 0.815034514788 2 : 0.1099266077 3 : 0.0481236762295 4 : 0.0204414796676 5 : 0.0064737216159 iter = 15/ 20000, loss_seg = 0.938, loss_adv_p = 1.266, loss_D = 0.830, loss_semi = 0.000, loss_semi_adv = 0.000 correct 3 : 0.0138818759925 4 : 0.0214850753368 5 : 0.0245366000015 6 : 0.0524393902805 7 : 0.0717583953517 8 : 0.0829619547321 9 : 0.127629836154 10 : 0.605306872151 mistake 1 : 0.845186224931 2 : 0.0863806795239 3 : 0.037294979522 4 : 0.0189205299287 5 : 0.0122175860942 iter = 16/ 20000, loss_seg = 1.872, loss_adv_p = 0.826, loss_D = 0.847, loss_semi = 0.000, loss_semi_adv = 0.000 correct 3 : 0.00228016623147 4 : 0.0089505718804 5 : 0.0161634364312 6 : 0.0323728439557 7 : 0.0323958295024 8 : 0.0427347284028 9 : 0.0699726012283 10 : 0.795129822368 mistake 1 : 0.815070892754 2 : 0.0785075426593 3 : 0.0488263539692 4 : 0.0359976506471 5 : 0.0215975599703 iter = 17/ 20000, loss_seg = 1.717, loss_adv_p = 0.603, loss_D = 0.620, loss_semi = 0.000, loss_semi_adv = 0.000 correct 3 : 0.00200898622728 4 : 0.00781591308424 5 : 0.0269459263818 6 : 0.0502692998205 7 : 0.0554862862773 8 : 0.072763567832 9 : 0.10771673932 10 : 0.676993281057 ''' pred_re = F.softmax(pred, dim=1).repeat(1, 3, 1, 1) indices_1 = torch.index_select( images, 1, Variable(torch.LongTensor([0])).cuda()) indices_2 = torch.index_select( images, 1, Variable(torch.LongTensor([1])).cuda()) indices_3 = torch.index_select( images, 1, Variable(torch.LongTensor([2])).cuda()) img_re = torch.cat([ indices_1.repeat(1, 21, 1, 1), indices_2.repeat(1, 21, 1, 1), indices_3.repeat(1, 21, 1, 1), ], 1) mul_img = pred_re * img_re D_out = model_D(mul_img) 
loss_adv_pred = bce_loss(D_out, make_D_label(gt_label, D_out)) loss = loss_seg + args.lambda_adv_pred * loss_adv_pred # proper normalization loss = loss / args.iter_size loss.backward() loss_seg_value += loss_seg.data.cpu().numpy()[0] / args.iter_size loss_adv_pred_value += loss_adv_pred.data.cpu().numpy( )[0] / args.iter_size # train D # bring back requires_grad for param in model_D.parameters(): param.requires_grad = True # train with pred pred = pred.detach() pred_re2 = F.softmax(pred, dim=1).repeat(1, 3, 1, 1) mul_img2 = pred_re2 * img_re D_out = model_D(mul_img2) loss_D = bce_loss(D_out, make_D_label(pred_label, D_out)) loss_D = loss_D / args.iter_size / 2 loss_D.backward() loss_D_value += loss_D.data.cpu().numpy()[0] # train with gt # get gt labels try: _, batch = trainloader_gt_iter.next() except: trainloader_gt_iter = enumerate(trainloader_gt) _, batch = trainloader_gt_iter.next() img_gt, labels_gt, _, _ = batch img_gt = Variable(img_gt).cuda() D_gt_v = Variable(one_hot(labels_gt)).cuda() ignore_mask_gt = (labels_gt.numpy() == 255) pred_re3 = D_gt_v.repeat(1, 3, 1, 1) indices_1 = torch.index_select( img_gt, 1, Variable(torch.LongTensor([0])).cuda()) indices_2 = torch.index_select( img_gt, 1, Variable(torch.LongTensor([1])).cuda()) indices_3 = torch.index_select( img_gt, 1, Variable(torch.LongTensor([2])).cuda()) img_re3 = torch.cat([ indices_1.repeat(1, 21, 1, 1), indices_2.repeat(1, 21, 1, 1), indices_3.repeat(1, 21, 1, 1), ], 1) mul_img3 = pred_re3 * img_re3 D_out = model_D(mul_img3) loss_D = bce_loss(D_out, make_D_label(gt_label, D_out)) loss_D = loss_D / args.iter_size / 2 loss_D.backward() loss_D_value += loss_D.data.cpu().numpy()[0] optimizer.step() optimizer_D.step() #print('exp = {}'.format(args.snapshot_dir)) print( 'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}' .format(i_iter, args.num_steps, loss_seg_value, loss_adv_pred_value, loss_D_value, loss_semi_value, 
loss_semi_adv_value)) if i_iter >= args.num_steps - 1: print('save model ...') torch.save( model.state_dict(), osp.join( args.snapshot_dir, 'VOC_' + os.path.abspath(__file__).split('/')[-1].split('.')[0] + '_' + str(args.num_steps) + '.pth')) #torch.save(model_D.state_dict(),osp.join(args.snapshot_dir, 'VOC_'+os.path.abspath(__file__).split('/')[-1].split('.')[0]+'_'+str(args.num_steps)+'_D.pth')) break if i_iter % args.save_pred_every == 0 and i_iter != 0: print('taking snapshot ...') torch.save( model.state_dict(), osp.join( args.snapshot_dir, 'VOC_' + os.path.abspath(__file__).split('/')[-1].split('.')[0] + '_' + str(i_iter) + '.pth')) #torch.save(model_D.state_dict(),osp.join(args.snapshot_dir, 'VOC_'+os.path.abspath(__file__).split('/')[-1].split('.')[0]+'_'+str(i_iter)+'_D.pth')) end = timeit.default_timer() print(end - start, 'seconds')
def main():
    """Train Res_Deeplab adversarially, later splitting the discriminator in two.

    Until ``args.discr_split`` a single discriminator ``model_D`` is trained to
    score segmentation predictions as fake (``pred_label``) and one-hot ground
    truth as real (``gt_label``).  With ``args.D_remain`` it also sees the
    predictions on unlabelled images as fake.

    From step ``args.discr_split`` onwards, a second discriminator
    ``model_D_ul`` is created as a copy of ``model_D``.  It then provides the
    adversarial loss on unlabelled images.  It is trained to score predictions
    on unlabelled images as fake and predictions on labelled images as real.

    Reads the module-level ``args`` namespace and the ``start`` timer.
    NOTE(review): both are assumed to be defined at module scope; confirm.
    Checkpoints are written to ``args.snapshot_dir`` every
    ``args.save_pred_every`` steps and after the final step.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True
    np.random.seed(args.random_seed)

    def next_batch(batch_iter, loader):
        """Return ``(batch_iter, batch)``; restart ``loader`` once it is exhausted."""
        try:
            _, batch = next(batch_iter)
        except StopIteration:
            batch_iter = enumerate(loader)
            _, batch = next(batch_iter)
        return batch_iter, batch

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if name in saved_state_dict and param.size() == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # load dataset
    train_dataset = VOCDataSet(args.data_dir, args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)
    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir, args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    # The unlabelled loader only exists when a partial (labelled) subset is used.
    trainloader_remain = None
    trainloader_remain_iter = None

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=5,
                                      pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # NOTE(review): pickle can execute arbitrary code; only load trusted id files.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = np.arange(train_dataset_size)
            np.random.shuffle(train_ids)

        with open(osp.join(args.snapshot_dir, 'train_id.pkl'), 'wb') as id_file:
            pickle.dump(train_ids, id_file)

        # labeled data
        train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])
        train_gt_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=3,
                                      pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         sampler=train_gt_sampler,
                                         num_workers=3,
                                         pin_memory=True)

        # unlabeled data
        train_remain_sampler = data.sampler.SubsetRandomSampler(train_ids[partial_size:])
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=3,
                                             pin_memory=True)
        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss/bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    # Second discriminator; created the first time i_iter reaches args.discr_split.
    model_D_ul = None
    optimizer_D_ul = None

    use_unlabeled = args.lambda_semi > 0 or args.lambda_semi_adv > 0

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_D_ul_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        split = i_iter >= args.discr_split

        # creating 2nd discriminator as a copy of the 1st one
        if split and model_D_ul is None:
            model_D_ul = FCDiscriminator(num_classes=args.num_classes)
            model_D_ul.load_state_dict(model_D.state_dict())
            model_D_ul.train()
            model_D_ul.cuda(args.gpu)
            optimizer_D_ul = optim.Adam(model_D_ul.parameters(), lr=args.learning_rate_D,
                                        betas=(0.9, 0.99))

        # start training 2nd discriminator after specified number of steps
        if split:
            optimizer_D_ul.zero_grad()
            adjust_learning_rate_D(optimizer_D_ul, i_iter)

        for sub_i in range(args.iter_size):
            # Predictions on unlabelled images for this sub-iteration; stay
            # None unless the semi-supervised branch below runs.
            pred_remain = None
            ignore_mask_remain = None

            # train Segmentation

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False
            # don't accumulate grads in D_ul, in case split has already been made
            if split:
                for param in model_D_ul.parameters():
                    param.requires_grad = False

            # do semi-supervised training first
            if (args.lambda_semi_adv > 0 and i_iter >= args.semi_start_adv
                    and trainloader_remain is not None):
                trainloader_remain_iter, batch = next_batch(trainloader_remain_iter,
                                                            trainloader_remain)

                # only access to img
                images, _, _, _ = batch
                images = Variable(images).cuda(args.gpu)

                pred = interp(model(images))
                pred_remain = pred.detach()

                # choose discriminator depending on the iteration
                discriminator = model_D_ul if split else model_D
                D_out = interp(discriminator(F.softmax(pred, dim=1)))
                D_out_sigmoid = torch.sigmoid(D_out).data.cpu().numpy().squeeze(axis=1)

                ignore_mask_remain = np.zeros(D_out_sigmoid.shape).astype(bool)

                # adversarial loss
                loss_semi_adv = args.lambda_semi_adv * bce_loss(
                    D_out, make_D_label(gt_label, ignore_mask_remain, args.gpu))
                loss_semi_adv = loss_semi_adv / args.iter_size

                # true loss value without multiplier
                loss_semi_adv_value += loss_semi_adv.data.cpu().numpy() / args.lambda_semi_adv
                loss_semi_adv.backward()

            # train with labeled images
            trainloader_iter, batch = next_batch(trainloader_iter, trainloader)

            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            ignore_mask = (labels.numpy() == 255)
            pred = interp(model(images))

            D_out = interp(model_D(F.softmax(pred, dim=1)))

            # computing loss
            loss_seg = loss_calc(pred, labels, args.gpu)
            loss_adv_pred = bce_loss(D_out, make_D_label(gt_label, ignore_mask, args.gpu))
            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
            loss_adv_pred_value += loss_adv_pred.data.cpu().numpy() / args.iter_size

            # train D and D_ul

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True
            if split:
                for param in model_D_ul.parameters():
                    param.requires_grad = True

            # train D with pred
            pred = pred.detach()

            # before split, train D with both labeled and unlabeled
            if args.D_remain and not split and use_unlabeled and pred_remain is not None:
                pred = torch.cat((pred, pred_remain), 0)
                ignore_mask = np.concatenate((ignore_mask, ignore_mask_remain), axis=0)

            D_out = interp(model_D(F.softmax(pred, dim=1)))
            loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask, args.gpu))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

            # train D_ul with pred on unlabeled
            if split and use_unlabeled and pred_remain is not None:
                D_ul_out = interp(model_D_ul(F.softmax(pred_remain, dim=1)))
                loss_D_ul = bce_loss(D_ul_out,
                                     make_D_label(pred_label, ignore_mask_remain, args.gpu))
                loss_D_ul = loss_D_ul / args.iter_size / 2
                loss_D_ul.backward()
                loss_D_ul_value += loss_D_ul.data.cpu().numpy()

            # get gt labels
            trainloader_gt_iter, batch = next_batch(trainloader_gt_iter, trainloader_gt)

            images_gt, labels_gt, _, _ = batch
            images_gt = Variable(images_gt).cuda(args.gpu)
            # Computed every sub-iteration (not only after the split), as before:
            # the forward pass runs with model in train() mode.
            with torch.no_grad():
                pred_l = interp(model(images_gt))

            # train D with gt
            D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt, args.gpu))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

            # train D_ul with pseudo_gt (gt are substituted for pred)
            if split:
                D_ul_out = interp(model_D_ul(F.softmax(pred_l, dim=1)))
                loss_D_ul = bce_loss(D_ul_out,
                                     make_D_label(gt_label, ignore_mask_gt, args.gpu))
                loss_D_ul = loss_D_ul / args.iter_size / 2
                loss_D_ul.backward()
                loss_D_ul_value += loss_D_ul.data.cpu().numpy()

        optimizer.step()
        optimizer_D.step()
        if split:
            optimizer_D_ul.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_D_ul={5:.3f}, loss_semi = {6:.3f}, loss_semi_adv = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value, loss_adv_pred_value,
                    loss_D_value, loss_D_ul_value, loss_semi_value, loss_semi_adv_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '_' + str(args.random_seed) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '_' + str(args.random_seed) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(i_iter) + '_' + str(args.random_seed) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(i_iter) + '_' + str(args.random_seed) + '_D.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
def main():
    """Create the model and start the evaluation process.

    Builds a Res_Deeplab model, loads its weights, and runs it over
    ``args.data_list`` with 505x505 crops.  The weights come from the named
    ``args.pretrained_model`` when given, otherwise from ``args.restore_from``
    (a URL or a local path).  Per-image (ground truth, prediction) label maps
    are collected, and ``get_iou`` writes the IoU report to
    ``args.save_dir/args.save_name``.
    """
    args = get_arguments()

    gpu0 = args.gpu

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    model = Res_Deeplab(num_classes=args.num_classes)

    if args.pretrained_model is not None:
        args.restore_from = pretrianed_models_dict[args.pretrained_model]

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)
    model.load_state_dict(saved_state_dict)

    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(VOCDataSet(args.data_dir, args.data_list,
                                            crop_size=(505, 505), mean=IMG_MEAN,
                                            scale=False, mirror=False),
                                 batch_size=1, shuffle=False, pin_memory=True)

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(505, 505), mode='bilinear', align_corners=True)
    else:
        interp = nn.Upsample(size=(505, 505), mode='bilinear')
    data_list = []

    colorize = VOCColorize()

    for index, batch in enumerate(testloader):
        if index % 100 == 0:
            print('%d processd' % (index))

        image, label, size, name = batch
        size = size[0].numpy()
        # `volatile=True` is ignored on torch >= 0.4, which would build an
        # autograd graph for every test image; no_grad disables tracking.
        with torch.no_grad():
            output = model(image.cuda(gpu0))
            output = interp(output).cpu().data[0].numpy()

        # Crop away padding, then take the per-pixel argmax over classes.
        output = output[:, :size[0], :size[1]]
        gt = np.asarray(label[0].numpy()[:size[0], :size[1]], dtype=int)

        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=int)

        # Uncomment if you want to save images:
        # filename = os.path.join(args.save_dir, '{}.png'.format(name[0]))
        # color_file = Image.fromarray(colorize(output).transpose(1, 2, 0), 'RGB')
        # color_file.save(filename)

        # show_all(gt, output)
        data_list.append([gt.flatten(), output.flatten()])

    filename = os.path.join(args.save_dir, args.save_name)
    get_iou(data_list, args.num_classes, filename)
def train(log_file, arch, dataset, batch_size, iter_size, num_workers,
          partial_data, partial_data_size, partial_id, ignore_label, crop_size,
          eval_crop_size, is_training, learning_rate, learning_rate_d,
          supervised, lambda_adv_pred, lambda_semi, lambda_semi_adv, mask_t,
          semi_start, semi_start_adv, d_remain, momentum, not_restore_last,
          num_steps, power, random_mirror, random_scale, random_seed,
          restore_from, restore_from_d, eval_every, save_snapshot_every,
          snapshot_dir, weight_decay, device):
    """Train a segmentation network, optionally with adversarial/semi-supervised losses.

    The generator is one of the supported ``arch`` networks.  When
    ``supervised`` is False, an ``FCDiscriminator`` is trained alongside it:
    it scores softmax predictions as fake and one-hot ground truth as real.
    Its output feeds an adversarial loss on labelled images (weight
    ``lambda_adv_pred``).  From ``semi_start_adv`` onwards, unlabelled images
    add an adversarial term (``lambda_semi_adv``).  From ``semi_start``
    onwards, they also add a self-training term (``lambda_semi``): pixels the
    discriminator scores at or above ``mask_t`` supply pseudo-labels.

    Every ``eval_every`` steps, per-class IoU on the validation split is
    printed.  When ``snapshot_dir`` is given, checkpoints are written every
    ``save_snapshot_every`` steps and at the end.

    Returns early (``None``) in these cases:
      * ``log_file`` already exists;
      * ``dataset`` or ``arch`` is not supported;
      * ``partial_data_size`` exceeds the training-set size.

    ``num_workers``, ``is_training`` and ``not_restore_last`` are accepted but
    not read by this function.
    """
    # Snapshot the call arguments before any other local exists, so only the
    # parameters are logged below.
    settings = locals().copy()

    import cv2
    import torch
    import torch.nn as nn
    from torch.utils import data, model_zoo
    import numpy as np
    import pickle
    import torch.optim as optim
    import torch.nn.functional as F
    import scipy.misc
    import sys
    import os
    import os.path as osp
    import time

    from model.deeplab import Res_Deeplab
    from model.unet import unet_resnet50
    from model.deeplabv3 import resnet101_deeplabv3
    from model.discriminator import FCDiscriminator
    from utils.loss import CrossEntropy2d, BCEWithLogitsLoss2d
    from utils.evaluation import EvaluatorIoU
    from dataset.voc_dataset import VOCDataSet
    import logger

    torch_device = torch.device(device)

    if log_file != '' and log_file != 'none':
        if os.path.exists(log_file):
            print('Log file {} already exists; exiting...'.format(log_file))
            return

    with logger.LogFile(log_file if log_file != 'none' else None):
        if dataset == 'pascal_aug':
            ds = VOCDataSet(augmented_pascal=True)
        elif dataset == 'pascal':
            ds = VOCDataSet(augmented_pascal=False)
        else:
            print('Dataset {} not yet supported'.format(dataset))
            return

        print('Command: {}'.format(sys.argv[0]))
        print('Arguments: {}'.format(' '.join(sys.argv[1:])))
        print('Settings: {}'.format(', '.join([
            '{}={}'.format(k, settings[k]) for k in sorted(settings)
        ])))

        print('Loaded data')

        def loss_calc(pred, label):
            """Cross-entropy segmentation loss of ``pred`` against integer ``label``."""
            label = label.long().to(torch_device)
            criterion = CrossEntropy2d()
            return criterion(pred, label)

        def lr_poly(base_lr, cur_iter, max_iter, power):
            """Polynomial decay: ``base_lr * (1 - cur_iter / max_iter) ** power``."""
            return base_lr * ((1 - float(cur_iter) / max_iter) ** power)

        def set_poly_lr(optimizer, base_lr, i_iter):
            """Apply the poly schedule; a second param group (if any) gets 10x lr."""
            lr = lr_poly(base_lr, i_iter, num_steps, power)
            optimizer.param_groups[0]['lr'] = lr
            if len(optimizer.param_groups) > 1:
                optimizer.param_groups[1]['lr'] = lr * 10

        def adjust_learning_rate(optimizer, i_iter):
            set_poly_lr(optimizer, learning_rate, i_iter)

        def adjust_learning_rate_D(optimizer, i_iter):
            set_poly_lr(optimizer, learning_rate_d, i_iter)

        def one_hot(label):
            """One-hot encode an (N, H, W) label tensor into an (N, C, H, W) float tensor.

            Pixels whose label is outside ``range(ds.num_classes)`` (e.g. the
            ignore label) become all-zero vectors.
            """
            label = label.numpy()
            encoded = np.zeros((label.shape[0], ds.num_classes, label.shape[1], label.shape[2]),
                               dtype=label.dtype)
            for i in range(ds.num_classes):
                encoded[:, i, ...] = (label == i)
            return torch.tensor(encoded, dtype=torch.float, device=torch_device)

        def make_D_label(label, ignore_mask):
            """Discriminator target filled with ``label``; ignored pixels get ``ignore_label``."""
            ignore_mask = np.expand_dims(ignore_mask, axis=1)
            D_label = np.ones(ignore_mask.shape) * label
            D_label[ignore_mask] = ignore_label
            return torch.tensor(D_label, dtype=torch.float, device=torch_device)

        def next_batch(batch_iter, loader):
            """Return ``(batch_iter, batch)``; restart ``loader`` once it is exhausted."""
            try:
                _, batch = next(batch_iter)
            except StopIteration:
                batch_iter = enumerate(loader)
                _, batch = next(batch_iter)
            return batch_iter, batch

        h, w = map(int, eval_crop_size.split(','))
        eval_crop_size = (h, w)

        h, w = map(int, crop_size.split(','))
        crop_size = (h, w)

        # create network
        if arch == 'deeplab2':
            model = Res_Deeplab(num_classes=ds.num_classes)
        elif arch == 'unet_resnet50':
            model = unet_resnet50(num_classes=ds.num_classes)
        elif arch == 'resnet101_deeplabv3':
            model = resnet101_deeplabv3(num_classes=ds.num_classes)
        else:
            print('Architecture {} not supported'.format(arch))
            return

        # load pretrained parameters
        if restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(restore_from)
        else:
            saved_state_dict = torch.load(restore_from)

        # only copy the params that exist in current model (caffe-like)
        new_params = model.state_dict().copy()
        for name, param in new_params.items():
            if name in saved_state_dict and param.size() == saved_state_dict[name].size():
                new_params[name].copy_(saved_state_dict[name])
        model.load_state_dict(new_params)

        model.train()
        model = model.to(torch_device)

        # init D
        model_D = FCDiscriminator(num_classes=ds.num_classes)
        if restore_from_d is not None:
            model_D.load_state_dict(torch.load(restore_from_d))
        model_D.train()
        model_D = model_D.to(torch_device)

        print('Built model')

        if snapshot_dir is not None:
            if not os.path.exists(snapshot_dir):
                os.makedirs(snapshot_dir)

        ds_train_xy = ds.train_xy(crop_size=crop_size, scale=random_scale, mirror=random_mirror,
                                  range01=model.RANGE01, mean=model.MEAN, std=model.STD)
        ds_train_y = ds.train_y(crop_size=crop_size, scale=random_scale, mirror=random_mirror,
                                range01=model.RANGE01, mean=model.MEAN, std=model.STD)
        ds_val_xy = ds.val_xy(crop_size=eval_crop_size, scale=False, mirror=False,
                              range01=model.RANGE01, mean=model.MEAN, std=model.STD)

        train_dataset_size = len(ds_train_xy)

        if partial_data_size != -1:
            # Previously compared partial_data_size with itself (always False).
            if partial_data_size > train_dataset_size:
                print('partial-data-size > |train|: exiting')
                return

        if partial_data == 1.0 and (partial_data_size == -1
                                    or partial_data_size == train_dataset_size):
            trainloader = data.DataLoader(ds_train_xy, batch_size=batch_size, shuffle=True,
                                          num_workers=5, pin_memory=True)

            trainloader_gt = data.DataLoader(ds_train_y, batch_size=batch_size, shuffle=True,
                                             num_workers=5, pin_memory=True)

            trainloader_remain = None
            trainloader_remain_iter = None
            print('|train|={}'.format(train_dataset_size))
            print('|val|={}'.format(len(ds_val_xy)))
        else:
            # sample partial data
            if partial_data_size != -1:
                partial_size = partial_data_size
            else:
                partial_size = int(partial_data * train_dataset_size)

            if partial_id is not None:
                # NOTE(review): pickle can execute arbitrary code; only load trusted id files.
                with open(partial_id, 'rb') as id_file:
                    train_ids = pickle.load(id_file)
                print('loading train ids from {}'.format(partial_id))
            else:
                rng = np.random.RandomState(random_seed)
                train_ids = list(rng.permutation(train_dataset_size))

            if snapshot_dir is not None:
                with open(osp.join(snapshot_dir, 'train_id.pkl'), 'wb') as id_file:
                    pickle.dump(train_ids, id_file)

            print('|train supervised|={}'.format(partial_size))
            print('|train unsupervised|={}'.format(train_dataset_size - partial_size))
            print('|val|={}'.format(len(ds_val_xy)))

            print('supervised={}'.format(list(train_ids[:partial_size])))

            train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])
            train_remain_sampler = data.sampler.SubsetRandomSampler(train_ids[partial_size:])
            train_gt_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])

            trainloader = data.DataLoader(ds_train_xy, batch_size=batch_size,
                                          sampler=train_sampler, num_workers=3,
                                          pin_memory=True)
            trainloader_remain = data.DataLoader(ds_train_xy, batch_size=batch_size,
                                                 sampler=train_remain_sampler, num_workers=3,
                                                 pin_memory=True)
            trainloader_gt = data.DataLoader(ds_train_y, batch_size=batch_size,
                                             sampler=train_gt_sampler, num_workers=3,
                                             pin_memory=True)

            trainloader_remain_iter = enumerate(trainloader_remain)

        testloader = data.DataLoader(ds_val_xy, batch_size=1, shuffle=False, pin_memory=True)

        print('Data loaders ready')

        trainloader_iter = enumerate(trainloader)
        trainloader_gt_iter = enumerate(trainloader_gt)

        # implement model.optim_parameters(args) to handle different models' lr setting

        # optimizer for segmentation network
        optimizer = optim.SGD(model.optim_parameters(learning_rate),
                              lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
        optimizer.zero_grad()

        # optimizer for discriminator network
        optimizer_D = optim.Adam(model_D.parameters(), lr=learning_rate_d, betas=(0.9, 0.99))
        optimizer_D.zero_grad()

        # loss/ bilinear upsampling
        bce_loss = BCEWithLogitsLoss2d()

        print('Built optimizer')

        # labels for adversarial training
        pred_label = 0
        gt_label = 1

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_mask_accum = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        t1 = time.time()

        print('Training for {} steps...'.format(num_steps))
        for i_iter in range(num_steps + 1):
            model.train()
            model.freeze_batchnorm()

            optimizer.zero_grad()
            adjust_learning_rate(optimizer, i_iter)

            optimizer_D.zero_grad()
            adjust_learning_rate_D(optimizer_D, i_iter)

            for sub_i in range(iter_size):
                # Predictions on unlabelled images for this sub-iteration; stay
                # None unless the semi-supervised branch below runs.
                pred_remain = None
                ignore_mask_remain = None

                # train G

                if not supervised:
                    # don't accumulate grads in D
                    for param in model_D.parameters():
                        param.requires_grad = False

                # do semi first
                if (not supervised and (lambda_semi > 0 or lambda_semi_adv > 0)
                        and i_iter >= semi_start_adv and trainloader_remain is not None):
                    trainloader_remain_iter, batch = next_batch(trainloader_remain_iter,
                                                                trainloader_remain)

                    # only access to img
                    images, _, _, _ = batch
                    images = images.float().to(torch_device)

                    pred = model(images)
                    pred_remain = pred.detach()

                    D_out = model_D(F.softmax(pred, dim=1))
                    D_out_sigmoid = torch.sigmoid(D_out).data.cpu().numpy().squeeze(axis=1)

                    ignore_mask_remain = np.zeros(D_out_sigmoid.shape).astype(bool)

                    # Track the un-weighted loss directly so lambda_semi_adv == 0
                    # (allowed when only lambda_semi > 0) does not divide by zero.
                    loss_semi_adv_raw = bce_loss(D_out, make_D_label(gt_label, ignore_mask_remain))
                    loss_semi_adv = lambda_semi_adv * loss_semi_adv_raw / iter_size
                    loss_semi_adv_value += float(loss_semi_adv_raw) / iter_size

                    if lambda_semi <= 0 or i_iter < semi_start:
                        loss_semi_adv.backward()
                    else:
                        # produce ignore mask: pixels D is not confident about
                        semi_ignore_mask = (D_out_sigmoid < mask_t)

                        semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                        semi_gt[semi_ignore_mask] = ignore_label

                        semi_ratio = 1.0 - float(semi_ignore_mask.sum()) / semi_ignore_mask.size
                        loss_semi_mask_accum += float(semi_ratio)

                        if semi_ratio == 0.0:
                            # No confident pixels: still apply the adversarial term.
                            loss_semi_adv.backward()
                        else:
                            semi_gt = torch.FloatTensor(semi_gt)

                            loss_semi = lambda_semi * loss_calc(pred, semi_gt)
                            loss_semi = loss_semi / iter_size
                            loss_semi_value += float(loss_semi) / lambda_semi
                            loss_semi += loss_semi_adv
                            loss_semi.backward()

                # train with source
                trainloader_iter, batch = next_batch(trainloader_iter, trainloader)

                images, labels, _, _ = batch
                images = images.float().to(torch_device)
                ignore_mask = (labels.numpy() == ignore_label)

                pred = model(images)

                loss_seg = loss_calc(pred, labels)

                if supervised:
                    loss = loss_seg
                else:
                    D_out = model_D(F.softmax(pred, dim=1))
                    loss_adv_pred = bce_loss(D_out, make_D_label(gt_label, ignore_mask))
                    loss = loss_seg + lambda_adv_pred * loss_adv_pred
                    loss_adv_pred_value += float(loss_adv_pred) / iter_size

                # proper normalization
                loss = loss / iter_size
                loss.backward()
                loss_seg_value += float(loss_seg) / iter_size

                if not supervised:
                    # train D

                    # bring back requires_grad
                    for param in model_D.parameters():
                        param.requires_grad = True

                    # train with pred
                    pred = pred.detach()

                    if d_remain and pred_remain is not None:
                        pred = torch.cat((pred, pred_remain), 0)
                        ignore_mask = np.concatenate((ignore_mask, ignore_mask_remain), axis=0)

                    D_out = model_D(F.softmax(pred, dim=1))
                    loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
                    loss_D = loss_D / iter_size / 2
                    loss_D.backward()
                    loss_D_value += float(loss_D)

                    # train with gt
                    trainloader_gt_iter, batch = next_batch(trainloader_gt_iter, trainloader_gt)

                    _, labels_gt, _, _ = batch
                    D_gt_v = one_hot(labels_gt)
                    ignore_mask_gt = (labels_gt.numpy() == ignore_label)

                    D_out = model_D(D_gt_v)
                    loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
                    loss_D = loss_D / iter_size / 2
                    loss_D.backward()
                    loss_D_value += float(loss_D)

            optimizer.step()
            optimizer_D.step()

            sys.stdout.write('.')
            sys.stdout.flush()

            if i_iter % eval_every == 0 and i_iter != 0:
                model.eval()
                with torch.no_grad():
                    evaluator = EvaluatorIoU(ds.num_classes)
                    for index, batch in enumerate(testloader):
                        image, label, size, name = batch
                        size = size[0].numpy()
                        image = image.float().to(torch_device)
                        output = model(image)
                        output = output.cpu().data[0].numpy()
                        output = output[:, :size[0], :size[1]]
                        gt = np.asarray(label[0].numpy()[:size[0], :size[1]], dtype=int)

                        output = output.transpose(1, 2, 0)
                        output = np.asarray(np.argmax(output, axis=2), dtype=int)

                        evaluator.sample(gt, output, ignore_value=ignore_label)

                        sys.stdout.write('+')
                        sys.stdout.flush()

                per_class_iou = evaluator.score()
                mean_iou = per_class_iou.mean()

                # Report averages over the steps since the last evaluation.
                loss_seg_value /= eval_every
                loss_adv_pred_value /= eval_every
                loss_D_value /= eval_every
                loss_semi_mask_accum /= eval_every
                loss_semi_value /= eval_every
                loss_semi_adv_value /= eval_every

                sys.stdout.write('\n')

                t2 = time.time()

                print(
                    'iter = {:8d}/{:8d}, took {:.3f}s, loss_seg = {:.6f}, loss_adv_p = {:.6f}, loss_D = {:.6f}, loss_semi_mask_rate = {:.3%} loss_semi = {:.6f}, loss_semi_adv = {:.3f}'
                    .format(i_iter, num_steps, t2 - t1, loss_seg_value, loss_adv_pred_value,
                            loss_D_value, loss_semi_mask_accum, loss_semi_value,
                            loss_semi_adv_value))

                for i, (class_name, iou) in enumerate(zip(ds.class_names, per_class_iou)):
                    print('class {:2d} {:12} IU {:.2f}'.format(i, class_name, iou))

                print('meanIOU: ' + str(mean_iou) + '\n')

                loss_seg_value = 0
                loss_adv_pred_value = 0
                loss_D_value = 0
                loss_semi_value = 0
                loss_semi_mask_accum = 0
                loss_semi_adv_value = 0

                t1 = t2

            if snapshot_dir is not None and i_iter % save_snapshot_every == 0 and i_iter != 0:
                print('taking snapshot ...')
                torch.save(model.state_dict(),
                           osp.join(snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
                torch.save(model_D.state_dict(),
                           osp.join(snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

        if snapshot_dir is not None:
            print('save model ...')
            torch.save(model.state_dict(),
                       osp.join(snapshot_dir, 'VOC_' + str(num_steps) + '.pth'))
            torch.save(model_D.state_dict(),
                       osp.join(snapshot_dir, 'VOC_' + str(num_steps) + '_D.pth'))
def main():
    """Adversarially train DeepLab against two discriminators on VOC.

    Builds a ``Res_Deeplab`` segmentation network (initialised from
    ``args.restore_from`` wherever parameter names and shapes match) and two
    discriminators (``Discriminator2`` and ``Discriminator2_patch``). Every
    step accumulates gradients over ``args.iter_size`` sub-iterations:

    * the segmentation network is trained with ``loss_calc`` plus
      ``args.lambda_adv_pred`` times the summed BCE of both discriminators,
      which are asked to label its prediction as ``gt_label``;
    * both discriminators are trained to label the detached prediction as
      ``pred_label`` and the one-hot ground truth as ``gt_label``.

    Discriminator inputs are class maps (softmax of the prediction, or the
    one-hot ground truth) multiplied with each of the first three image
    channels, giving ``3 * C`` channels where channel ``k * C + c`` is
    ``class_map[:, c] * image[:, k]``.

    Segmentation-model snapshots are written to ``args.snapshot_dir``.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from.startswith('http'):
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if (name in saved_state_dict
                and param.size() == saved_state_dict[name].size()):
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model = nn.DataParallel(model)
    model.cuda()

    cudnn.benchmark = True

    # init the two discriminators
    model_D = Discriminator2(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D = nn.DataParallel(model_D)
    model_D.train()
    model_D.cuda()

    model_D2 = Discriminator2_patch(num_classes=args.num_classes)
    # NOTE(review): D2 is restored from the same checkpoint as D; this only
    # works if both architectures share a state_dict layout -- confirm.
    if args.restore_from_D is not None:
        model_D2.load_state_dict(torch.load(args.restore_from_D))
    model_D2 = nn.DataParallel(model_D2)
    model_D2.train()
    model_D2.cuda()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir, args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)
    train_dataset_size = len(train_dataset)
    train_gt_dataset = VOCGTDataSet(args.data_dir, args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if not args.partial_data:
        # 0 or None: train on the whole dataset
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=5,
                                      pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)
    else:
        # sample partial data: only the first `partial_size` shuffled ids
        # are used for both the image loader and the ground-truth loader
        partial_size = int(args.partial_data * train_dataset_size)
        # getattr: not every variant of this script defines --partial-id
        partial_id = getattr(args, 'partial_id', None)
        if partial_id is not None:
            # NOTE: pickle files must come from a trusted source.
            with open(partial_id, 'rb') as f:
                train_ids = pickle.load(f)
            print('loading train ids from {}'.format(partial_id))
        else:
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)

        with open(osp.join(args.snapshot_dir, 'train_id.pkl'), 'wb') as f:
            pickle.dump(train_ids, f)

        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=3,
                                      pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         sampler=train_gt_sampler,
                                         num_workers=3,
                                         pin_memory=True)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.module.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizers for the discriminator networks
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D.zero_grad()
    optimizer_D2.zero_grad()

    # loss / bilinear upsampling
    bce_loss = torch.nn.BCELoss()
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    def next_batch(loader, batch_iter):
        """Return ``(iterator, batch)``, restarting ``loader`` once exhausted."""
        try:
            _, batch = next(batch_iter)
        except StopIteration:
            batch_iter = enumerate(loader)
            _, batch = next(batch_iter)
        return batch_iter, batch

    def set_D_requires_grad(flag):
        """Toggle gradient tracking for both discriminators' parameters."""
        for net in (model_D, model_D2):
            for param in net.parameters():
                param.requires_grad = flag

    def expand_image(imgs, num_classes):
        """Repeat each of the first three image channels ``num_classes`` times.

        Output channel ``k * num_classes + c`` equals ``imgs[:, k]``.
        """
        return torch.cat(
            [imgs[:, k:k + 1].repeat(1, num_classes, 1, 1) for k in range(3)],
            1)

    def class_weighted_image(class_maps, img_rep):
        """Multiply class maps (tiled 3x along dim 1) with an expanded image."""
        return class_maps.repeat(1, 3, 1, 1) * img_rep

    def discriminators_loss(d_input, label):
        """Summed BCE of both discriminators on ``d_input`` against ``label``."""
        d_out = model_D(d_input)
        d_out_2 = model_D2(d_input)
        return (bce_loss(d_out, make_D_label(label, d_out)) +
                bce_loss(d_out_2, make_D_label(label, d_out_2)))

    # snapshot prefix derived from this script's file name
    exp_name = os.path.abspath(__file__).split('/')[-1].split('.')[0]

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        # no semi-supervised terms in this script; kept at 0 for the log line
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for _ in range(args.iter_size):

            # ---- train G ----
            # don't accumulate grads in D
            set_D_requires_grad(False)

            # train with source
            trainloader_iter, batch = next_batch(trainloader, trainloader_iter)
            images, labels, _, _ = batch
            images = Variable(images).cuda()

            pred = interp(model(images))
            loss_seg = loss_calc(pred, labels)

            img_re = expand_image(images, pred.size(1))
            mul_img = class_weighted_image(F.softmax(pred, dim=1), img_re)
            loss_adv_pred = discriminators_loss(mul_img, gt_label)

            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += float(loss_seg) / args.iter_size
            loss_adv_pred_value += float(loss_adv_pred) / args.iter_size

            # ---- train D ----
            # bring back requires_grad
            set_D_requires_grad(True)

            # train with pred (detached so no gradient flows back into G)
            pred = pred.detach()
            mul_img2 = class_weighted_image(F.softmax(pred, dim=1), img_re)
            loss_D = discriminators_loss(mul_img2, pred_label)
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += float(loss_D)

            # train with gt
            trainloader_gt_iter, batch = next_batch(trainloader_gt,
                                                    trainloader_gt_iter)
            img_gt, labels_gt, _, _ = batch
            img_gt = Variable(img_gt).cuda()
            D_gt_v = Variable(one_hot(labels_gt)).cuda()

            img_re3 = expand_image(img_gt, D_gt_v.size(1))
            mul_img3 = class_weighted_image(D_gt_v, img_re3)
            loss_D = discriminators_loss(mul_img3, gt_label)
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += float(loss_D)

        optimizer.step()
        optimizer_D.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value,
                    loss_semi_adv_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + exp_name + '_' + str(args.num_steps) +
                         '.pth'))
            # NOTE: discriminator weights are not saved.
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + exp_name + '_' + str(i_iter) + '.pth'))

    end = timeit.default_timer()
    # NOTE(review): `start` is not defined in this function; presumably a
    # module-level timestamp taken at import time -- confirm.
    print(end - start, 'seconds')
def main():
    """Train DeepLab on VOC, optionally adding self-training on unlabeled data.

    Builds a ``Res_Deeplab`` network (initialised from ``args.restore_from``
    wherever parameter names and shapes match) on GPU ``args.gpu`` and trains
    it with SGD. Each step accumulates gradients over ``args.iter_size``
    sub-iterations of the supervised ``loss_calc`` loss.

    When ``args.partial_data`` is set, only that fraction of the (shuffled)
    training ids is used as labeled data; the chosen id order is written to
    ``<snapshot_dir>/train_id.pkl``. The remaining ids form an unlabeled set.
    Once ``i_iter >= args.semi_start`` and ``args.lambda_semi > 0``, each
    sub-iteration also trains on an unlabeled batch. That batch uses the
    network's own per-pixel argmax as the target, and its loss is weighted by
    ``args.lambda_semi``.

    Snapshots are written to ``args.snapshot_dir`` with names containing the
    step, ``args.lambda_semi`` and ``args.random_seed``.

    Raises:
        ValueError: if self-training becomes active but no unlabeled data
            exists (``args.partial_data`` is None).
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True
    np.random.seed(args.random_seed)

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from.startswith('http'):
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if (name in saved_state_dict
                and param.size() == saved_state_dict[name].size()):
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # load dataset
    train_dataset = VOCDataSet(args.data_dir, args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)
    train_dataset_size = len(train_dataset)

    # unlabeled loader; only exists when training on a labeled subset
    trainloader_remain = None
    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=5,
                                      pin_memory=True)
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # NOTE: pickle files must come from a trusted source.
            with open(args.partial_id, 'rb') as f:
                train_ids = pickle.load(f)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = np.arange(train_dataset_size)
            np.random.shuffle(train_ids)

        with open(osp.join(args.snapshot_dir, 'train_id.pkl'), 'wb') as f:
            pickle.dump(train_ids, f)

        # labeled data
        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=3,
                                      pin_memory=True)

        # unlabeled data
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=3,
                                             pin_memory=True)

    if trainloader_remain is not None:
        trainloader_remain_iter = enumerate(trainloader_remain)
    else:
        trainloader_remain_iter = None
    trainloader_iter = enumerate(trainloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # bilinear upsampling of the network output to the crop size
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear')

    def next_batch(loader, batch_iter):
        """Return ``(iterator, batch)``, restarting ``loader`` once exhausted."""
        try:
            _, batch = next(batch_iter)
        except StopIteration:
            batch_iter = enumerate(loader)
            _, batch = next(batch_iter)
        return batch_iter, batch

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_unlabeled_seg_value = 0
        use_semi = args.lambda_semi > 0 and i_iter >= args.semi_start

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        for _ in range(args.iter_size):

            # train with labeled images
            trainloader_iter, batch = next_batch(trainloader, trainloader_iter)
            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred = interp(model(images))
            loss_seg = loss_calc(pred, labels, args.gpu)

            # proper normalization
            loss = loss_seg / args.iter_size
            loss.backward()
            loss_seg_value += float(loss_seg) / args.iter_size

            if use_semi:
                # train with unlabeled images (self-training on own argmax)
                if trainloader_remain is None:
                    raise ValueError(
                        'lambda_semi > 0 needs unlabeled data: set '
                        'partial_data so part of the training set is held '
                        'out as unlabeled')
                trainloader_remain_iter, batch = next_batch(
                    trainloader_remain, trainloader_remain_iter)

                # only access to img
                images, _, _, _ = batch
                images = Variable(images).cuda(args.gpu)

                pred = interp(model(images))
                semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                semi_gt = torch.FloatTensor(semi_gt)

                loss_unlabeled_seg = args.lambda_semi * loss_calc(
                    pred, semi_gt, args.gpu)
                loss_unlabeled_seg = loss_unlabeled_seg / args.iter_size
                loss_unlabeled_seg.backward()
                # log the loss without the lambda_semi weight
                loss_unlabeled_seg_value += (float(loss_unlabeled_seg) /
                                             args.lambda_semi)

        optimizer.step()

        if args.lambda_semi > 0:
            unlabeled_str = '{:.3f}'.format(loss_unlabeled_seg_value)
        else:
            # self-training disabled: there is no unlabeled loss to report
            unlabeled_str = 'None'

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_unlabeled_seg = {3} '
            .format(i_iter, args.num_steps, loss_seg_value, unlabeled_str))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir,
                    'VOC_' + str(args.num_steps) + '_' +
                    str(args.lambda_semi) + '_' + str(args.random_seed) +
                    '.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir,
                    'VOC_' + str(i_iter) + '_' + str(args.lambda_semi) +
                    '_' + str(args.random_seed) + '.pth'))

    end = timeit.default_timer()
    # NOTE(review): `start` is not defined in this function; presumably a
    # module-level timestamp taken at import time -- confirm.
    print(end - start, 'seconds')