Beispiel #1
0
def main():
    """Create the model and start the evaluation process.

    Reads the run configuration from ``get_arguments()``, restores the
    checkpoint at ``args.restore_from`` into the selected network, runs every
    batch of the Cityscapes test loader through it and writes two files per
    image into ``args.save``: the raw label map (saved under the input's file
    name) and a colorized ``<name>_color.png``.

    Raises:
        ValueError: if ``args.model`` is neither ``'Deeplab'`` nor
            ``'DeeplabVGG'``.
    """
    args = get_arguments()

    gpu0 = args.gpu

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'Deeplab':
        model = Res_Deeplab(num_classes=args.num_classes)
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
    else:
        # Fail early with a clear message instead of an UnboundLocalError.
        raise ValueError('Unsupported model: %s' % args.model)

    saved_state_dict = torch.load(args.restore_from)

    ### for running different versions of pytorch
    # Keep only the checkpoint entries the model knows about and overlay them
    # on the model's own state, so keys missing from the checkpoint keep
    # their current values.
    model_dict = model.state_dict()
    saved_state_dict = {k: v for k, v in saved_state_dict.items() if k in model_dict}
    model_dict.update(saved_state_dict)
    ###
    # Load the merged dict; loading the filtered checkpoint alone would make
    # the merge above pointless and fail on any key the checkpoint lacks.
    model.load_state_dict(model_dict)

    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(
        cityscapesDataSet(args.data_dir, args.data_list, crop_size=(1024, 512),
                          mean=IMG_MEAN, scale=False, mirror=False, set=args.set),
        batch_size=1, shuffle=False, pin_memory=True)

    modern_torch = version.parse(torch.__version__) >= version.parse('0.4.0')
    if modern_torch:
        interp = nn.Upsample(size=(1024, 2048), mode='bilinear', align_corners=True)
    else:
        interp = nn.Upsample(size=(1024, 2048), mode='bilinear')

    for index, batch in enumerate(testloader):
        if index % 100 == 0:
            print('%d processd' % index)
        image, _, name = batch

        # Both supported models share the same single-output forward pass.
        if modern_torch:
            # ``volatile`` is ignored from PyTorch 0.4 on; ``torch.no_grad``
            # is what actually disables gradient tracking there.
            with torch.no_grad():
                output = interp(model(image.cuda(gpu0)))
        else:
            output = interp(model(Variable(image, volatile=True).cuda(gpu0)))
        output = output.cpu().data[0].numpy()

        # Channels-first scores -> channels-last, then per-pixel arg-max.
        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

        output_col = colorize_mask(output)
        output = Image.fromarray(output)

        name = name[0].split('/')[-1]
        output.save('%s/%s' % (args.save, name))
        output_col.save('%s/%s_color.png' % (args.save, name.split('.')[0]))
Beispiel #2
0
def main():
    """Create the model and start the evaluation process (CPU variant).

    Restores a ``Res_Deeplab`` checkpoint (a URL or a local path taken from
    ``args.restore_from``, optionally overridden through
    ``args.pretrained_model``), runs it over the VOC loader on the CPU, writes
    one colorized prediction ``<name>.png`` per image into ``args.save_dir``
    and finally hands the collected (ground truth, prediction) pairs to
    ``get_iou``, which is given ``<save_dir>/result.txt`` as its file name.
    """
    args = get_arguments()

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    model = Res_Deeplab(num_classes=args.num_classes)

    if args.pretrained_model is not None:
        args.restore_from = pretrianed_models_dict[args.pretrained_model]

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        # The model stays on the CPU in this variant, so map every storage to
        # the CPU; this lets checkpoints saved from a GPU load as well.
        saved_state_dict = torch.load(args.restore_from,
                                      map_location=lambda storage, loc: storage)
    model.load_state_dict(saved_state_dict)

    model.eval()

    testloader = data.DataLoader(
        VOCDataSet(args.data_dir, args.data_list, crop_size=(505, 505),
                   mean=IMG_MEAN, scale=False, mirror=False),
        batch_size=1, shuffle=False, pin_memory=True)

    interp = nn.Upsample(size=(505, 505), mode='bilinear')
    data_list = []

    colorize = VOCColorize()

    # ``torch.no_grad`` exists from PyTorch 0.4 on, where the ``volatile``
    # flag is ignored; older releases still need the flag.
    has_no_grad = hasattr(torch, 'no_grad')

    for index, batch in enumerate(testloader):
        if index % 100 == 0:
            print('%d processd'%(index))
        image, label, size, name = batch
        size = size[0].numpy()

        if has_no_grad:
            with torch.no_grad():
                output = interp(model(image.cpu()))
        else:
            output = interp(model(Variable(image, volatile=True).cpu()))
        output = output.cpu().data[0].numpy()

        # Crop the fixed-size output and label back to the extent given by
        # the first two entries of ``size``.
        output = output[:, :size[0], :size[1]]
        # ``np.int`` was removed from NumPy; the builtin ``int`` is the same
        # dtype it used to alias.
        gt = np.asarray(label[0].numpy()[:size[0], :size[1]], dtype=int)

        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=int)

        filename = os.path.join(args.save_dir, '{}.png'.format(name[0]))
        color_file = Image.fromarray(colorize(output).transpose(1, 2, 0), 'RGB')
        color_file.save(filename)

        data_list.append([gt.flatten(), output.flatten()])

    filename = os.path.join(args.save_dir, 'result.txt')
    get_iou(data_list, args.num_classes, filename)
Beispiel #3
0
def eval(pth, cityscapes_eval_dir, i_iter):
    """Create the model and start the evaluation process.

    NOTE: the name shadows the builtin ``eval``; it is kept because callers
    refer to this function by that name.

    Args:
        pth: path of the checkpoint passed to ``torch.load``.
        cityscapes_eval_dir: directory that receives, per image, the raw
            label map (input file name) and ``<name>_color.png``.
        i_iter: iteration counter of the caller; when it is ``0`` only the
            first batch is processed.

    Raises:
        ValueError: if ``args.model`` is neither ``'ResNet'`` nor ``'VGG'``.
    """
    # Function-scope import: only needed to make sure the output dir exists.
    import os

    args = get_arguments()

    gpu0 = args.gpu

    if args.model == 'ResNet':
        model = Res_Deeplab(num_classes=args.num_classes)
    elif args.model == 'VGG':
        model = DeeplabVGG(num_classes=args.num_classes)
    else:
        # Previously surfaced as an UnboundLocalError further down.
        raise ValueError('Unsupported model: %s' % args.model)
    model.load_state_dict(torch.load(pth))

    model.eval()
    model.cuda(gpu0)

    # Saving below fails if the target directory is missing.
    if not os.path.exists(cityscapes_eval_dir):
        os.makedirs(cityscapes_eval_dir)

    cityscapesloader = data.DataLoader(
        cityscapesDataSet(args.cityscapes_data_dir,
                          args.cityscapes_data_list,
                          crop_size=(1024, 512),
                          mean=IMG_MEAN,
                          scale=False,
                          mirror=False,
                          set=args.set),
        batch_size=1,
        shuffle=False,
        pin_memory=True)

    interp = nn.Upsample(size=(1024, 2048),
                         mode='bilinear',
                         align_corners=True)

    # Inference only: no gradient tracking for the whole loop.
    with torch.no_grad():
        for index, batch in enumerate(cityscapesloader):
            if index % 100 == 0:
                print('%d processd' % index)
            image, _, name = batch

            output = model(image.cuda(gpu0))
            output = interp(output).cpu().data[0].numpy()

            # Channels-first scores -> channels-last, then per-pixel arg-max.
            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

            output_col = colorize_mask(output)
            output = Image.fromarray(output)

            name = name[0].split('/')[-1]
            output.save('%s/%s' % (cityscapes_eval_dir, name))
            output_col.save('%s/%s_color.png' %
                            (cityscapes_eval_dir, name.split('.')[0]))

            # At iteration 0 a single image is enough.
            if i_iter == 0:
                break
Beispiel #4
0
def main():
    """Create the model and start the evaluation process.

    Restores a ``Res_Deeplab`` checkpoint (a URL or a local path taken from
    ``args.restore_from``, optionally overridden through
    ``args.pretrained_model``), runs it over the VOC loader on GPU
    ``args.gpu``, writes one colorized prediction ``<name>.png`` per image
    into ``args.save_dir`` and finally hands the collected (ground truth,
    prediction) pairs to ``get_iou``, which is given
    ``<save_dir>/result.txt`` as its file name.
    """
    args = get_arguments()

    gpu0 = args.gpu

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    model = Res_Deeplab(num_classes=args.num_classes)

    if args.pretrained_model is not None:
        args.restore_from = pretrianed_models_dict[args.pretrained_model]

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)
    model.load_state_dict(saved_state_dict)

    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(
        VOCDataSet(args.data_dir, args.data_list, crop_size=(505, 505),
                   mean=IMG_MEAN, scale=False, mirror=False),
        batch_size=1, shuffle=False, pin_memory=True)

    modern_torch = version.parse(torch.__version__) >= version.parse('0.4.0')
    if modern_torch:
        interp = nn.Upsample(size=(505, 505), mode='bilinear', align_corners=True)
    else:
        interp = nn.Upsample(size=(505, 505), mode='bilinear')
    data_list = []

    colorize = VOCColorize()

    for index, batch in enumerate(testloader):
        if index % 100 == 0:
            print('%d processd'%(index))
        image, label, size, name = batch
        size = size[0].numpy()

        if modern_torch:
            # ``volatile`` is ignored from PyTorch 0.4 on; ``torch.no_grad``
            # is what actually disables gradient tracking there.
            with torch.no_grad():
                output = interp(model(image.cuda(gpu0)))
        else:
            output = interp(model(Variable(image, volatile=True).cuda(gpu0)))
        output = output.cpu().data[0].numpy()

        # Crop the fixed-size output and label back to the extent given by
        # the first two entries of ``size``.
        output = output[:, :size[0], :size[1]]
        # ``np.int`` was removed from NumPy; the builtin ``int`` is the same
        # dtype it used to alias.
        gt = np.asarray(label[0].numpy()[:size[0], :size[1]], dtype=int)

        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=int)

        filename = os.path.join(args.save_dir, '{}.png'.format(name[0]))
        color_file = Image.fromarray(colorize(output).transpose(1, 2, 0), 'RGB')
        color_file.save(filename)

        data_list.append([gt.flatten(), output.flatten()])

    filename = os.path.join(args.save_dir, 'result.txt')
    get_iou(data_list, args.num_classes, filename)
def main():
    """Create the model and start the evaluation process.

    Same flow as a plain VOC evaluation, except that ``args.restore_from`` is
    indexed with ``i`` and the result file name embeds characters ``[-10:-4]``
    of the selected checkpoint path.

    NOTE(review): ``i`` is not defined inside this function, so it has to be
    a module-level name (presumably the index of the checkpoint currently
    being evaluated) -- confirm against the caller.
    """
    args = get_arguments()

    gpu0 = args.gpu

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    model = Res_Deeplab(num_classes=args.num_classes)

    if args.pretrained_model is not None:
        # if there is pretrained model, restore_from will be changed
        args.restore_from[i] = pretrianed_models_dict[args.pretrained_model]

    checkpoint = args.restore_from[i]
    if checkpoint[:4] == 'http':
        saved_state_dict = model_zoo.load_url(checkpoint)
    else:
        saved_state_dict = torch.load(checkpoint)
    model.load_state_dict(saved_state_dict)

    model.eval()  # evaluation mode
    model.cuda(gpu0)

    testloader = data.DataLoader(
        VOCDataSet(args.data_dir, args.data_list, crop_size=(505, 505),
                   mean=IMG_MEAN, scale=False, mirror=False),
        batch_size=1, shuffle=False, pin_memory=True)

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(505, 505), mode='bilinear', align_corners=True)
    else:
        interp = nn.Upsample(size=(505, 505), mode='bilinear')
    data_list = []

    colorize = VOCColorize()
    with torch.no_grad():  # added for 0.4
        for index, batch in enumerate(testloader):
            if index % 100 == 0:
                print('%d processd'%(index))
            # original author's example: size is tensor([[366, 500, 3]])
            image, label, size, name = batch
            size = size[0].numpy()
            output = model(image.cuda(gpu0))
            output = interp(output).cpu().data[0].numpy()
            # Crop the fixed-size output back to the extent given by the
            # first two entries of ``size`` (it differs per image).
            output = output[:, :size[0], :size[1]]

            # ``np.int`` was removed from NumPy; the builtin ``int`` is the
            # same dtype it used to alias.
            gt = np.asarray(label[0].numpy()[:size[0], :size[1]], dtype=int)

            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=int)

            filename = os.path.join(args.save_dir, '{}.png'.format(name[0]))
            # colorize the output
            color_file = Image.fromarray(colorize(output).transpose(1, 2, 0), 'RGB')
            color_file.save(filename)

            data_list.append([gt.flatten(), output.flatten()])

    filename = os.path.join(args.save_dir, 'result'+args.restore_from[i][-10:-4]+'.txt')
    get_iou(data_list, args.num_classes, filename)
Beispiel #6
0
def main():
    """Create the model and start the evaluation process.

    Builds the network selected by ``args.model`` (``'DeeplabMulti'``,
    ``'Oracle'`` or ``'DeeplabVGG'``), restores ``args.restore_from`` (URL or
    local path), runs the Cityscapes loader through it on CUDA or, when
    ``args.cpu`` is set, on the CPU, and writes per image a raw label map and
    a ``<name>_color.png`` into ``args.save``.

    ``args.input_size`` is parsed as ``"w,h"``.

    Raises:
        ValueError: if ``args.model`` is not one of the supported names.
    """
    args = get_arguments()

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        # Swap in the model-specific default checkpoint only when the user
        # left the generic default untouched.
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Fail early with a clear message instead of an UnboundLocalError.
        raise ValueError('Unsupported model: %s' % args.model)

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        # Load onto the CPU first so a checkpoint saved from a GPU can also
        # be restored with ``args.cpu``; the model is moved to the target
        # device right below.
        saved_state_dict = torch.load(args.restore_from,
                                      map_location=lambda storage, loc: storage)
    model.load_state_dict(saved_state_dict)

    device = torch.device("cuda" if not args.cpu else "cpu")
    model = model.to(device)

    model.eval()

    testloader = data.DataLoader(
        cityscapesDataSet(args.data_dir, args.data_list, crop_size=input_size,
                          mean=IMG_MEAN, scale=False, mirror=False, set=args.set),
        batch_size=1, shuffle=False, pin_memory=True)

    # ``nn.Upsample`` takes (height, width), hence the swapped order.
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear', align_corners=True)

    # Inference only: no gradient tracking for the whole loop.
    with torch.no_grad():
        for index, batch in enumerate(testloader):
            if index % 100 == 0:
                print('%d processd' % index)
            image, _, name = batch
            image = image.to(device)

            if args.model == 'DeeplabMulti':
                # The multi-output model returns four values; the second one
                # is the prediction that gets saved.
                _, output2, _, _ = model(image)
                output = interp(output2).cpu().data[0].numpy()
            else:
                # 'DeeplabVGG' and 'Oracle' return a single prediction.
                output = model(image)
                output = interp(output).cpu().data[0].numpy()

            # Channels-first scores -> channels-last, then per-pixel arg-max.
            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

            output_col = colorize_mask(output)
            output = Image.fromarray(output)

            name = name[0].split('/')[-1]
            output.save('%s/%s' % (args.save, name))
            output_col.save('%s/%s_color.png' % (args.save, name.split('.')[0]))
Beispiel #7
0
def main():
    """Train the segmentation network jointly with a flaw-detector network.

    ``args`` is not defined in this function, so it has to be provided at
    module level.

    Steps:
      1. Build ``Res_Deeplab`` and copy into it every pretrained parameter
         whose name and size match.
      2. Build ``detector.FlawDetector`` (optionally restored from
         ``args.restore_from_D``).
      3. Build the data loaders; with ``args.partial_data`` set, the dataset
         is split into a labeled part and a remaining (unlabeled) part, and
         the id permutation is written to ``<snapshot_dir>/train_id.pkl``.
      4. For ``args.num_steps`` iterations, accumulate gradients over
         ``args.iter_size`` sub-iterations for both networks, step both
         optimizers, print the losses and periodically save snapshots.

    When the whole dataset is used (``args.partial_data is None``) there is
    no unlabeled split, so the unlabeled-data step is skipped.
    """
    # prepare
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if name in saved_state_dict and param.size(
        ) == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D = detector.FlawDetector(in_channels=24)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)

    # NOTE: the ground-truth dataset/loader is still constructed, but the
    # training loop below does not draw from it.
    train_gt_dataset = VOCGTDataSet(args.data_dir,
                                    args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=5,
                                      pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)

        # No unlabeled split exists when the full dataset is used.
        trainloader_remain = None
        trainloader_remain_iter = None
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # Pickle data must be read in binary mode on Python 3.
            # NOTE(review): unpickling runs arbitrary code; only point
            # ``partial_id`` at files you trust.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)

        with open(osp.join(args.snapshot_dir, 'train_id.pkl'),
                  'wb') as id_file:
            pickle.dump(train_ids, id_file)

        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=3,
                                      pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=3,
                                             pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         sampler=train_gt_sampler,
                                         num_workers=3,
                                         pin_memory=True)

        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss/ bilinear upsampling
    minimum_loss = detector.MinimumCriterion()
    detector_loss = detector.FlawDetectorCriterion()

    # NOTE(review): ``input_size`` is parsed as (h, w) above but the target
    # size is built as (input_size[1], input_size[0]); this only matters for
    # non-square inputs -- confirm the intended order.
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear')

    for i_iter in range(args.num_steps):

        if i_iter > 0 and i_iter % 1000 == 0:
            val(model, args.gpu)
            # switch back to training mode after validation
            model.train()

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # do semi first
            use_unlabeled = (trainloader_remain is not None
                             and (args.lambda_semi > 0
                                  or args.lambda_semi_adv > 0)
                             and i_iter >= args.semi_start_adv)
            if use_unlabeled:
                try:
                    _, batch = next(trainloader_remain_iter)
                except StopIteration:
                    # loader exhausted: start a new pass over the data
                    trainloader_remain_iter = enumerate(trainloader_remain)
                    _, batch = next(trainloader_remain_iter)

                # only access to img
                images, _, _, _ = batch
                images = Variable(images).cuda(args.gpu)

                pred = interp(model(images))

                D_out = model_D(images, pred)

                # ke: SSL loss for unlabeled data
                loss_semi_adv = args.lambda_semi_adv * minimum_loss(D_out)

                loss_semi_adv_value += loss_semi_adv.data.cpu().numpy(
                ) / args.iter_size

                loss_semi_adv.backward()

            # train with source (labeled data)

            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                # loader exhausted: start a new pass over the data
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            pred = interp(model(images))

            loss_seg = loss_calc(pred, labels, args.gpu)

            D_out = interp(model_D(images, pred))

            # ke: SSL loss for labeled data
            loss_adv_pred = args.lambda_adv_pred * minimum_loss(D_out)

            loss = loss_seg + loss_adv_pred

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
            loss_adv_pred_value += loss_adv_pred.data.cpu().numpy(
            ) / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred from labeled data; detach so that no gradient
            # flows back into the segmentation network
            pred = pred.detach()

            D_out = interp(model_D(images, pred))

            detect_gt = detector.generate_flaw_detector_gt(
                pred,
                labels.view(labels.shape[0], 1, labels.shape[1],
                            labels.shape[2]).cuda(args.gpu), NUM_CLASSES,
                IGNORE_LABEL)
            loss_D = detector_loss(D_out, detect_gt)

            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.6f}, loss_adv_l = {3:.6f}, loss_D = {4:.6f}, loss_semi = {5:.6f}, loss_adv_u = {6:.6f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value,
                    loss_semi_adv_value))

        if i_iter >= args.num_steps - 1:
            # final snapshot, named after the total number of steps
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))
Beispiel #8
0
def train(log_file, arch, dataset, batch_size, iter_size, num_workers,
          partial_data, partial_data_size, partial_id, ignore_label, crop_size,
          eval_crop_size, is_training, learning_rate, learning_rate_d,
          supervised, lambda_adv_pred, lambda_semi, lambda_semi_adv, mask_t,
          semi_start, semi_start_adv, d_remain, momentum, not_restore_last,
          num_steps, power, random_mirror, random_scale, random_seed,
          restore_from, restore_from_d, eval_every, save_snapshot_every,
          snapshot_dir, weight_decay, device):
    settings = locals().copy()

    import cv2
    import torch
    import torch.nn as nn
    from torch.utils import data, model_zoo
    import numpy as np
    import pickle
    import torch.optim as optim
    import torch.nn.functional as F
    import scipy.misc
    import sys
    import os
    import os.path as osp
    import pickle

    from model.deeplab import Res_Deeplab
    from model.unet import unet_resnet50
    from model.deeplabv3 import resnet101_deeplabv3
    from model.discriminator import FCDiscriminator
    from utils.loss import CrossEntropy2d, BCEWithLogitsLoss2d
    from utils.evaluation import EvaluatorIoU
    from dataset.voc_dataset import VOCDataSet
    import logger

    torch_device = torch.device(device)

    import time

    if log_file != '' and log_file != 'none':
        if os.path.exists(log_file):
            print('Log file {} already exists; exiting...'.format(log_file))
            return

    with logger.LogFile(log_file if log_file != 'none' else None):
        if dataset == 'pascal_aug':
            ds = VOCDataSet(augmented_pascal=True)
        elif dataset == 'pascal':
            ds = VOCDataSet(augmented_pascal=False)
        else:
            print('Dataset {} not yet supported'.format(dataset))
            return

        print('Command: {}'.format(sys.argv[0]))
        print('Arguments: {}'.format(' '.join(sys.argv[1:])))
        print('Settings: {}'.format(', '.join([
            '{}={}'.format(k, settings[k])
            for k in sorted(list(settings.keys()))
        ])))

        print('Loaded data')

        def loss_calc(pred, label):
            """
            This function returns cross entropy loss for semantic segmentation
            """
            # out shape batch_size x channels x h x w -> batch_size x channels x h x w
            # label shape h x w x 1 x batch_size  -> batch_size x 1 x h x w
            label = label.long().to(torch_device)
            criterion = CrossEntropy2d()

            return criterion(pred, label)

        def lr_poly(base_lr, iter, max_iter, power):
            return base_lr * ((1 - float(iter) / max_iter)**(power))

        def adjust_learning_rate(optimizer, i_iter):
            lr = lr_poly(learning_rate, i_iter, num_steps, power)
            optimizer.param_groups[0]['lr'] = lr
            if len(optimizer.param_groups) > 1:
                optimizer.param_groups[1]['lr'] = lr * 10

        def adjust_learning_rate_D(optimizer, i_iter):
            lr = lr_poly(learning_rate_d, i_iter, num_steps, power)
            optimizer.param_groups[0]['lr'] = lr
            if len(optimizer.param_groups) > 1:
                optimizer.param_groups[1]['lr'] = lr * 10

        def one_hot(label):
            label = label.numpy()
            one_hot = np.zeros((label.shape[0], ds.num_classes, label.shape[1],
                                label.shape[2]),
                               dtype=label.dtype)
            for i in range(ds.num_classes):
                one_hot[:, i, ...] = (label == i)
            #handle ignore labels
            return torch.tensor(one_hot,
                                dtype=torch.float,
                                device=torch_device)

        def make_D_label(label, ignore_mask):
            ignore_mask = np.expand_dims(ignore_mask, axis=1)
            D_label = np.ones(ignore_mask.shape) * label
            D_label[ignore_mask] = ignore_label
            D_label = torch.tensor(D_label,
                                   dtype=torch.float,
                                   device=torch_device)

            return D_label

        h, w = map(int, eval_crop_size.split(','))
        eval_crop_size = (h, w)

        h, w = map(int, crop_size.split(','))
        crop_size = (h, w)

        # create network
        if arch == 'deeplab2':
            model = Res_Deeplab(num_classes=ds.num_classes)
        elif arch == 'unet_resnet50':
            model = unet_resnet50(num_classes=ds.num_classes)
        elif arch == 'resnet101_deeplabv3':
            model = resnet101_deeplabv3(num_classes=ds.num_classes)
        else:
            print('Architecture {} not supported'.format(arch))
            return

        # load pretrained parameters
        if restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(restore_from)
        else:
            saved_state_dict = torch.load(restore_from)

        # only copy the params that exist in current model (caffe-like)
        new_params = model.state_dict().copy()
        for name, param in new_params.items():
            if name in saved_state_dict and param.size(
            ) == saved_state_dict[name].size():
                new_params[name].copy_(saved_state_dict[name])
        model.load_state_dict(new_params)

        model.train()
        model = model.to(torch_device)

        # init D
        model_D = FCDiscriminator(num_classes=ds.num_classes)
        if restore_from_d is not None:
            model_D.load_state_dict(torch.load(restore_from_d))
        model_D.train()
        model_D = model_D.to(torch_device)

        print('Built model')

        if snapshot_dir is not None:
            if not os.path.exists(snapshot_dir):
                os.makedirs(snapshot_dir)

        ds_train_xy = ds.train_xy(crop_size=crop_size,
                                  scale=random_scale,
                                  mirror=random_mirror,
                                  range01=model.RANGE01,
                                  mean=model.MEAN,
                                  std=model.STD)
        ds_train_y = ds.train_y(crop_size=crop_size,
                                scale=random_scale,
                                mirror=random_mirror,
                                range01=model.RANGE01,
                                mean=model.MEAN,
                                std=model.STD)
        ds_val_xy = ds.val_xy(crop_size=eval_crop_size,
                              scale=False,
                              mirror=False,
                              range01=model.RANGE01,
                              mean=model.MEAN,
                              std=model.STD)

        train_dataset_size = len(ds_train_xy)

        if partial_data_size != -1:
            if partial_data_size > partial_data_size:
                print('partial-data-size > |train|: exiting')
                return

        if partial_data == 1.0 and (partial_data_size == -1 or
                                    partial_data_size == train_dataset_size):
            trainloader = data.DataLoader(ds_train_xy,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=5,
                                          pin_memory=True)

            trainloader_gt = data.DataLoader(ds_train_y,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=5,
                                             pin_memory=True)

            trainloader_remain = None
            print('|train|={}'.format(train_dataset_size))
            print('|val|={}'.format(len(ds_val_xy)))
        else:
            #sample partial data
            if partial_data_size != -1:
                partial_size = partial_data_size
            else:
                partial_size = int(partial_data * train_dataset_size)

            if partial_id is not None:
                train_ids = pickle.load(open(partial_id))
                print('loading train ids from {}'.format(partial_id))
            else:
                rng = np.random.RandomState(random_seed)
                train_ids = list(rng.permutation(train_dataset_size))

            if snapshot_dir is not None:
                pickle.dump(train_ids,
                            open(osp.join(snapshot_dir, 'train_id.pkl'), 'wb'))

            print('|train supervised|={}'.format(partial_size))
            print('|train unsupervised|={}'.format(train_dataset_size -
                                                   partial_size))
            print('|val|={}'.format(len(ds_val_xy)))

            print('supervised={}'.format(list(train_ids[:partial_size])))

            train_sampler = data.sampler.SubsetRandomSampler(
                train_ids[:partial_size])
            train_remain_sampler = data.sampler.SubsetRandomSampler(
                train_ids[partial_size:])
            train_gt_sampler = data.sampler.SubsetRandomSampler(
                train_ids[:partial_size])

            trainloader = data.DataLoader(ds_train_xy,
                                          batch_size=batch_size,
                                          sampler=train_sampler,
                                          num_workers=3,
                                          pin_memory=True)
            trainloader_remain = data.DataLoader(ds_train_xy,
                                                 batch_size=batch_size,
                                                 sampler=train_remain_sampler,
                                                 num_workers=3,
                                                 pin_memory=True)
            trainloader_gt = data.DataLoader(ds_train_y,
                                             batch_size=batch_size,
                                             sampler=train_gt_sampler,
                                             num_workers=3,
                                             pin_memory=True)

            trainloader_remain_iter = enumerate(trainloader_remain)

        testloader = data.DataLoader(ds_val_xy,
                                     batch_size=1,
                                     shuffle=False,
                                     pin_memory=True)

        print('Data loaders ready')

        trainloader_iter = enumerate(trainloader)
        trainloader_gt_iter = enumerate(trainloader_gt)

        # implement model.optim_parameters(args) to handle different models' lr setting

        # optimizer for segmentation network
        optimizer = optim.SGD(model.optim_parameters(learning_rate),
                              lr=learning_rate,
                              momentum=momentum,
                              weight_decay=weight_decay)
        optimizer.zero_grad()

        # optimizer for discriminator network
        optimizer_D = optim.Adam(model_D.parameters(),
                                 lr=learning_rate_d,
                                 betas=(0.9, 0.99))
        optimizer_D.zero_grad()

        # loss/ bilinear upsampling
        bce_loss = BCEWithLogitsLoss2d()

        print('Built optimizer')

        # labels for adversarial training
        pred_label = 0
        gt_label = 1

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_mask_accum = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        t1 = time.time()

        print('Training for {} steps...'.format(num_steps))
        for i_iter in range(num_steps + 1):

            model.train()
            model.freeze_batchnorm()

            optimizer.zero_grad()
            adjust_learning_rate(optimizer, i_iter)
            optimizer_D.zero_grad()
            adjust_learning_rate_D(optimizer_D, i_iter)

            for sub_i in range(iter_size):

                # train G

                if not supervised:
                    # don't accumulate grads in D
                    for param in model_D.parameters():
                        param.requires_grad = False

                # do semi first
                if not supervised and (lambda_semi > 0 or lambda_semi_adv > 0 ) and i_iter >= semi_start_adv and \
                        trainloader_remain is not None:
                    try:
                        _, batch = next(trainloader_remain_iter)
                    except:
                        trainloader_remain_iter = enumerate(trainloader_remain)
                        _, batch = next(trainloader_remain_iter)

                    # only access to img
                    images, _, _, _ = batch
                    images = images.float().to(torch_device)

                    pred = model(images)
                    pred_remain = pred.detach()

                    D_out = model_D(F.softmax(pred, dim=1))
                    D_out_sigmoid = F.sigmoid(
                        D_out).data.cpu().numpy().squeeze(axis=1)

                    ignore_mask_remain = np.zeros(D_out_sigmoid.shape).astype(
                        np.bool)

                    loss_semi_adv = lambda_semi_adv * bce_loss(
                        D_out, make_D_label(gt_label, ignore_mask_remain))
                    loss_semi_adv = loss_semi_adv / iter_size

                    #loss_semi_adv.backward()
                    loss_semi_adv_value += float(
                        loss_semi_adv) / lambda_semi_adv

                    if lambda_semi <= 0 or i_iter < semi_start:
                        loss_semi_adv.backward()
                        loss_semi_value = 0
                    else:
                        # produce ignore mask
                        semi_ignore_mask = (D_out_sigmoid < mask_t)

                        semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                        semi_gt[semi_ignore_mask] = ignore_label

                        semi_ratio = 1.0 - float(
                            semi_ignore_mask.sum()) / semi_ignore_mask.size

                        loss_semi_mask_accum += float(semi_ratio)

                        if semi_ratio == 0.0:
                            loss_semi_value += 0
                        else:
                            semi_gt = torch.FloatTensor(semi_gt)

                            loss_semi = lambda_semi * loss_calc(pred, semi_gt)
                            loss_semi = loss_semi / iter_size
                            loss_semi_value += float(loss_semi) / lambda_semi
                            loss_semi += loss_semi_adv
                            loss_semi.backward()

                else:
                    loss_semi = None
                    loss_semi_adv = None

                # train with source

                try:
                    _, batch = next(trainloader_iter)
                except:
                    trainloader_iter = enumerate(trainloader)
                    _, batch = next(trainloader_iter)

                images, labels, _, _ = batch
                images = images.float().to(torch_device)
                ignore_mask = (labels.numpy() == ignore_label)
                pred = model(images)

                loss_seg = loss_calc(pred, labels)

                if supervised:
                    loss = loss_seg
                else:
                    D_out = model_D(F.softmax(pred, dim=1))

                    loss_adv_pred = bce_loss(
                        D_out, make_D_label(gt_label, ignore_mask))

                    loss = loss_seg + lambda_adv_pred * loss_adv_pred
                    loss_adv_pred_value += float(loss_adv_pred) / iter_size

                # proper normalization
                loss = loss / iter_size
                loss.backward()
                loss_seg_value += float(loss_seg) / iter_size

                if not supervised:
                    # train D

                    # bring back requires_grad
                    for param in model_D.parameters():
                        param.requires_grad = True

                    # train with pred
                    pred = pred.detach()

                    if d_remain:
                        pred = torch.cat((pred, pred_remain), 0)
                        ignore_mask = np.concatenate(
                            (ignore_mask, ignore_mask_remain), axis=0)

                    D_out = model_D(F.softmax(pred, dim=1))
                    loss_D = bce_loss(D_out,
                                      make_D_label(pred_label, ignore_mask))
                    loss_D = loss_D / iter_size / 2
                    loss_D.backward()
                    loss_D_value += float(loss_D)

                    # train with gt
                    # get gt labels
                    try:
                        _, batch = next(trainloader_gt_iter)
                    except:
                        trainloader_gt_iter = enumerate(trainloader_gt)
                        _, batch = next(trainloader_gt_iter)

                    _, labels_gt, _, _ = batch
                    D_gt_v = one_hot(labels_gt)
                    ignore_mask_gt = (labels_gt.numpy() == ignore_label)

                    D_out = model_D(D_gt_v)
                    loss_D = bce_loss(D_out,
                                      make_D_label(gt_label, ignore_mask_gt))
                    loss_D = loss_D / iter_size / 2
                    loss_D.backward()
                    loss_D_value += float(loss_D)

            optimizer.step()
            optimizer_D.step()

            sys.stdout.write('.')
            sys.stdout.flush()

            if i_iter % eval_every == 0 and i_iter != 0:
                model.eval()
                with torch.no_grad():
                    evaluator = EvaluatorIoU(ds.num_classes)
                    for index, batch in enumerate(testloader):
                        image, label, size, name = batch
                        size = size[0].numpy()
                        image = image.float().to(torch_device)
                        output = model(image)
                        output = output.cpu().data[0].numpy()

                        output = output[:, :size[0], :size[1]]
                        gt = np.asarray(label[0].numpy()[:size[0], :size[1]],
                                        dtype=np.int)

                        output = output.transpose(1, 2, 0)
                        output = np.asarray(np.argmax(output, axis=2),
                                            dtype=np.int)

                        evaluator.sample(gt, output, ignore_value=ignore_label)

                        sys.stdout.write('+')
                        sys.stdout.flush()

                per_class_iou = evaluator.score()
                mean_iou = per_class_iou.mean()

                loss_seg_value /= eval_every
                loss_adv_pred_value /= eval_every
                loss_D_value /= eval_every
                loss_semi_mask_accum /= eval_every
                loss_semi_value /= eval_every
                loss_semi_adv_value /= eval_every

                sys.stdout.write('\n')

                t2 = time.time()

                print(
                    'iter = {:8d}/{:8d}, took {:.3f}s, loss_seg = {:.6f}, loss_adv_p = {:.6f}, loss_D = {:.6f}, loss_semi_mask_rate = {:.3%} loss_semi = {:.6f}, loss_semi_adv = {:.3f}'
                    .format(i_iter, num_steps, t2 - t1, loss_seg_value,
                            loss_adv_pred_value, loss_D_value,
                            loss_semi_mask_accum, loss_semi_value,
                            loss_semi_adv_value))

                for i, (class_name,
                        iou) in enumerate(zip(ds.class_names, per_class_iou)):
                    print('class {:2d} {:12} IU {:.2f}'.format(
                        i, class_name, iou))

                print('meanIOU: ' + str(mean_iou) + '\n')

                loss_seg_value = 0
                loss_adv_pred_value = 0
                loss_D_value = 0
                loss_semi_value = 0
                loss_semi_mask_accum = 0
                loss_semi_adv_value = 0

                t1 = t2

            if snapshot_dir is not None and i_iter % save_snapshot_every == 0 and i_iter != 0:
                print('taking snapshot ...')
                torch.save(
                    model.state_dict(),
                    osp.join(snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
                torch.save(
                    model_D.state_dict(),
                    osp.join(snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

        if snapshot_dir is not None:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(snapshot_dir, 'VOC_' + str(num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(snapshot_dir, 'VOC_' + str(num_steps) + '_D.pth'))
def main():
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True
    gpu = args.gpu

    # create network

    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters (weights)
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(
            args.restore_from
        )  ## http://vllab1.ucmerced.edu/~whung/adv-semi-seg/resnet101COCO-41f33a49.pth
    else:
        saved_state_dict = torch.load(args.restore_from)
        #checkpoint = torch.load(args.restore_from)_

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()  # state_dict() is current model
    for name, param in new_params.items():
        #print (name) # 'conv1.weight, name:param(value), dict
        if name in saved_state_dict and param.size(
        ) == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            #print('copy {}'.format(name))
    model.load_state_dict(new_params)
    #model.load_state_dict(checkpoint['state_dict'])
    #optimizer.load_state_dict(args.checkpoint['optim_dict'])

    model.train(
    )  # https://pytorch.org/docs/stable/nn.html, Sets the module in training mode.
    model.cuda(args.gpu)  ##

    cudnn.benchmark = True  # This flag allows you to enable the inbuilt cudnn auto-tuner to find the best algorithm to use for your hardware

    # init D

    model_D = FCDiscriminator(num_classes=args.num_classes)
    #args.restore_from_D = 'snapshots/linear2/VOC_25000_D.pth'
    if args.restore_from_D is not None:  # None
        model_D.load_state_dict(torch.load(args.restore_from_D))
        # checkpoint_D = torch.load(args.restore_from_D)
        # model_D.load_state_dict(checkpoint_D['state_dict'])
        # optimizer_D.load_state_dict(checkpoint_D['optim_dict'])
    model_D.train()
    model_D.cuda(args.gpu)

    if USECALI:
        model_cali = ModelWithTemperature(model, model_D)
        model_cali.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    random.seed(args.random_seed)
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed(args.random_seed)
    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)
    train_dataset_remain = VOCDataSet(args.data_dir,
                                      args.data_list_remain,
                                      crop_size=input_size,
                                      scale=args.random_scale,
                                      mirror=args.random_mirror,
                                      mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)
    train_dataset_size_remain = len(train_dataset_remain)

    print train_dataset_size
    print train_dataset_size_remain

    train_gt_dataset = VOCGTDataSet(args.data_dir,
                                    args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)
    if args.partial_data is None:  #if not partial, load all

        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=5,
                                      pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)
    else:
        #sample partial data
        #args.partial_data = 0.125
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            train_ids = pickle.load(open(args.partial_id))
            print('loading train ids from {}'.format(args.partial_id))
        else:  #args.partial_id is none
            train_ids = range(train_dataset_size)
            train_ids_remain = range(train_dataset_size_remain)
            np.random.shuffle(train_ids)  #shuffle!
            np.random.shuffle(train_ids_remain)

        pickle.dump(train_ids,
                    open(osp.join(args.snapshot_dir, 'train_id.pkl'),
                         'wb'))  #randomly suffled ids

        #sampler
        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:])  # 0~1/8,
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids_remain[:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(train_ids[:])
        # train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size]) # 0~1/8
        # train_remain_sampler = data.sampler.SubsetRandomSampler(train_ids[partial_size:]) # used as unlabeled, 7/8
        # train_gt_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])

        #train loader
        trainloader = data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            sampler=train_sampler,
            num_workers=3,
            pin_memory=True)  # multi-process data loading
        trainloader_remain = data.DataLoader(train_dataset_remain,
                                             batch_size=args.batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=3,
                                             pin_memory=True)
        # trainloader_remain = data.DataLoader(train_dataset,
        #                                      batch_size=args.batch_size, sampler=train_remain_sampler, num_workers=3,
        #                                     pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         sampler=train_gt_sampler,
                                         num_workers=3,
                                         pin_memory=True)

        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network

    # model.optim_parameters(args) = list(dict1, dict2); dict1 has keys 'lr' and 'params'
    # print(type(model.optim_parameters(args)[0]['params'])) # generator
    #print(model.state_dict()['coeff'][0]) #confirmed

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    #optimizer.add_param_group({"params":model.coeff}) # assign new coefficient to the optimizer
    #print(len(optimizer.param_groups))
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()  #initialize

    if USECALI:
        optimizer_cali = optim.LBFGS([model_cali.temperature],
                                     lr=0.01,
                                     max_iter=50)
        optimizer_cali.zero_grad()

        nll_criterion = BCEWithLogitsLoss().cuda()  # BCE!!
        ece_criterion = ECELoss().cuda()

    # loss/ bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()
    interp = nn.Upsample(
        size=(input_size[1], input_size[0]), mode='bilinear'
    )  # okay it automatically change to functional.interpolate
    # 321, 321

    if version.parse(torch.__version__) >= version.parse('0.4.0'):  #0.4.1
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1
    semi_ratio_sum = 0
    semi_sum = 0
    loss_seg_sum = 0
    loss_adv_sum = 0
    loss_vat_sum = 0
    l_seg_sum = 0
    l_vat_sum = 0
    l_adv_sum = 0
    logits_list = []
    labels_list = []

    # https://towardsdatascience.com/understanding-pytorch-with-an-example-a-step-by-step-tutorial-81fc5f8c4e8e

    for i_iter in range(args.num_steps):

        loss_seg_value = 0  # L_seg
        loss_adv_pred_value = 0  # 0.01 L_adv
        loss_D_value = 0  # L_D
        loss_semi_value = 0  # 0.1 L_semi
        loss_semi_adv_value = 0  # 0.001 L_adv
        loss_vat_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)  #changing lr by iteration

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            ###################### train G!!!###########################
            ############################################################
            # don't accumulate grads in D

            for param in model_D.parameters(
            ):  # <class 'torch.nn.parameter.Parameter'>, convolution weights
                param.requires_grad = False  # do not update gradient of D (freeze) while G

            ######### do unlabeled first!! 0.001 L_adv + 0.1 L_semi ###############

            # lambda_semi, lambda_adv for unlabeled
            if (args.lambda_semi > 0 or args.lambda_semi_adv > 0
                ) and i_iter >= args.semi_start_adv:
                try:
                    _, batch = trainloader_remain_iter.next(
                    )  #remain = unlabeled
                    print(trainloader_remain_iter.next())
                except:
                    trainloader_remain_iter = enumerate(
                        trainloader_remain)  # impose counters
                    _, batch = trainloader_remain_iter.next()

                # only access to img
                images, _, _, _ = batch  # <class 'torch.Tensor'>
                images = Variable(images).cuda(
                    args.gpu)  # <class 'torch.Tensor'>

                pred = interp(
                    model(images))  # S(X), pred <class 'torch.Tensor'>
                pred_remain = pred.detach(
                )  #use detach() when attempting to remove a tensor from a computation graph, will be used for D
                # https://discuss.pytorch.org/t/clone-and-detach-in-v0-4-0/16861
                # The difference is that detach refers to only a given variable on which it's called.
                # torch.no_grad affects all operations taking place within the with statement. >> for context,
                # requires_grad is for tensor

                # pred >> (8,21,321,321), L_adv
                D_out = interp(
                    model_D(F.softmax(pred))
                )  # D(S(X)), confidence, 8,1,321,321, not detached, there was not dim
                D_out_sigmoid = F.sigmoid(D_out).data.cpu().numpy().squeeze(
                    axis=1)  # (8,321,321) 0~1

                # 0.001 L_adv!!!!
                ignore_mask_remain = np.zeros(D_out_sigmoid.shape).astype(
                    np.bool)  # no ignore_mask for unlabeled adv
                loss_semi_adv = args.lambda_semi_adv * bce_loss(
                    D_out, make_D_label(gt_label,
                                        ignore_mask_remain))  #gt_label =1,
                # -log(D(S(X)))
                loss_semi_adv = loss_semi_adv / args.iter_size  #normalization

                loss_semi_adv_value += loss_semi_adv.data.cpu().numpy(
                ) / args.lambda_semi_adv

                ##--- visualization, pred(8,21,321,321), D_out_sigmoid(8,321,321)
                """
                if i_iter % 1000 == 0:
                    vpred = pred.transpose(1, 2).transpose(2, 3).contiguous()  # (8,321,321,21)
                    vpred = vpred.view(-1, 21)  # (8*321*321, 21)
                    vlogsx = F.log_softmax(vpred)  # torch.Tensor
                    vsemi_gt = pred.data.cpu().numpy().argmax(axis=1)
                    vsemi_gt = Variable(torch.FloatTensor(vsemi_gt).long()).cuda(gpu)
                    vlogsx = vlogsx.gather(1, vsemi_gt.view(-1, 1))
                    sx = F.softmax(vpred).gather(1, vsemi_gt.view(-1, 1))
                    vD_out_sigmoid = Variable(torch.FloatTensor(D_out_sigmoid)).cuda(gpu).view(-1, 1)
                    vlogsx = (vlogsx*(2.5*vD_out_sigmoid+0.5))
                    vlogsx = -vlogsx.squeeze(dim=1)
                    sx = sx.squeeze(dim=1)
                    vD_out_sigmoid = vD_out_sigmoid.squeeze(dim=1)
                    dsx = vD_out_sigmoid.data.cpu().detach().numpy()
                    vlogsx = vlogsx.data.cpu().detach().numpy()
                    sx = sx.data.cpu().detach().numpy()
                    plt.clf()
                    plt.figure(figsize=(15, 5))
                    plt.subplot(131)
                    plt.ylim(0, 0.004)
                    plt.scatter(dsx, vlogsx, s = 0.1)  # variable requires grad cannot call numpy >> detach
                    plt.xlabel('D(S(X))')
                    plt.ylabel('Loss_Semi per Pixel')
                    plt.subplot(132)
                    plt.scatter(dsx, vlogsx, s = 0.1)  # variable requires grad cannot call numpy >> detach
                    plt.xlabel('D(S(X))')
                    plt.ylabel('Loss_Semi per Pixel')
                    plt.subplot(133)
                    plt.scatter(dsx, sx, s=0.1)
                    plt.xlabel('D(S(X))')
                    plt.ylabel('S(x)')
                    plt.savefig('/home/eungyo/AdvSemiSeg/plot/'  + str(i_iter) + '.png')
                    """

                if args.lambda_semi <= 0 or i_iter < args.semi_start:
                    loss_semi_adv.backward()
                    loss_semi_value = 0
                else:

                    semi_gt = pred.data.cpu().numpy().argmax(
                        axis=1
                    )  # pred=S(X) ((8,21,321,321)), semi_gt is not one-hot, 8,321,321
                    #(8, 321, 321)

                    if not USECALI:
                        semi_ignore_mask = (
                            D_out_sigmoid < args.mask_T
                        )  # both (8,321,321) 0~1threshold!, numpy
                        semi_gt[
                            semi_ignore_mask] = 255  # Yhat, ignore pixel becomes 255
                        semi_ratio = 1.0 - float(semi_ignore_mask.sum(
                        )) / semi_ignore_mask.size  # ignored pixels / H*W
                        print('semi ratio: {:.4f}'.format(semi_ratio))

                        if semi_ratio == 0.0:
                            loss_semi_value += 0
                        else:
                            semi_gt = torch.FloatTensor(semi_gt)
                            confidence = torch.FloatTensor(
                                D_out_sigmoid)  ## added, only pred is on cuda
                            loss_semi = args.lambda_semi * weighted_loss_calc(
                                pred, semi_gt, args.gpu, confidence)

                    else:
                        semi_ratio = 1
                        semi_gt = (torch.FloatTensor(semi_gt))  # (8,321,321)
                        confidence = torch.FloatTensor(
                            F.sigmoid(
                                model_cali.temperature_scale(D_out.view(
                                    -1))).data.cpu().numpy())  # (8*321*321,)
                        loss_semi = args.lambda_semi * calibrated_loss_calc(
                            pred, semi_gt, args.gpu, confidence, accuracies,
                            n_bin
                        )  #  L_semi = Yhat * log(S(X)) # loss_calc(pred, semi_gt, args.gpu)
                        # pred(8,21,321,321)

                    if semi_ratio != 0:
                        loss_semi = loss_semi / args.iter_size
                        loss_semi_value += loss_semi.data.cpu().numpy(
                        ) / args.lambda_semi

                        if args.method == 'vatent' or args.method == 'vat':
                            #v_loss = vat_loss(model, images, pred, eps=args.epsilon[i])  # R_vadv
                            weighted_v_loss = weighted_vat_loss(
                                model,
                                images,
                                pred,
                                confidence,
                                eps=args.epsilon)

                            if args.method == 'vatent':
                                #v_loss += entropy_loss(pred)  # R_cent (conditional entropy loss)
                                weighted_v_loss += weighted_entropy_loss(
                                    pred, confidence)

                            v_loss = weighted_v_loss / args.iter_size
                            loss_vat_value += v_loss.data.cpu().numpy()
                            loss_semi_adv += args.alpha * v_loss

                            loss_vat_sum += loss_vat_value
                            if i_iter % 100 == 0 and sub_i == 4:
                                l_vat_sum = loss_vat_sum / 100
                                if i_iter == 0:
                                    l_vat_sum = l_vat_sum * 100
                                loss_vat_sum = 0

                        loss_semi += loss_semi_adv
                        loss_semi.backward(
                        )  # 0.001 L_adv + 0.1 L_semi, backward == back propagation

            else:
                loss_semi = None
                loss_semi_adv = None

            ###########train with source (labeled data)############### L_ce + 0.01 * L_adv

            try:
                _, batch = trainloader_iter.next()
            except:
                trainloader_iter = enumerate(trainloader)  # safe coding
                _, batch = trainloader_iter.next()  #counter, batch

            images, labels, _, _ = batch  # also get labels images(8,321,321)
            images = Variable(images).cuda(args.gpu)
            ignore_mask = (
                labels.numpy() == 255
            )  # ignored pixels == 255 >> 1, yes ignored mask for labeled data

            pred = interp(model(images))  # S(X), 8,21,321,321
            loss_seg = loss_calc(pred, labels,
                                 args.gpu)  # -Y*logS(X)= L_ce, not detached

            if USED:
                softsx = F.softmax(pred, dim=1)
                D_out = interp(model_D(softsx))  # D(S(X)), L_adv

                loss_adv_pred = bce_loss(
                    D_out, make_D_label(
                        gt_label,
                        ignore_mask))  # both  8,1,321,321, gt_label = 1
                # L_adv =  -log(D(S(X)), make_D_label is all 1 except ignored_region

                loss = loss_seg + args.lambda_adv_pred * loss_adv_pred
                if USECALI:
                    if (args.lambda_semi > 0 or args.lambda_semi_adv > 0
                        ) and i_iter >= args.semi_start_adv:
                        with torch.no_grad():
                            _, prediction = torch.max(softsx, 1)
                            labels_mask = (
                                (labels > 0) *
                                (labels != 255)) | (prediction.data.cpu() > 0)
                            labels = labels[labels_mask]
                            prediction = prediction[labels_mask]
                            fake_mask = (labels.data.cpu().numpy() !=
                                         prediction.data.cpu().numpy())
                            real_label = make_conf_label(
                                1, fake_mask
                            )  # (10*321*321, ) 0 or 1 (fake or real)

                            logits = D_out.squeeze(dim=1)
                            logits = logits[labels_mask]
                            logits_list.append(logits)  # initialize
                            labels_list.append(real_label)

                        if (i_iter * args.iter_size * args.batch_size + sub_i +
                                1) % train_dataset_size == 0:
                            logits = torch.cat(logits_list).cuda(
                            )  # overall 5000 images in val,  #logits >> 5000,100, (1464*321*321,)
                            labels = torch.cat(labels_list).cuda()
                            before_temperature_nll = nll_criterion(
                                logits, labels).item()  ####modify
                            before_temperature_ece, _, _ = ece_criterion(
                                logits, labels)  # (1464*321*321,)
                            before_temperature_ece = before_temperature_ece.item(
                            )
                            print('Before temperature - NLL: %.3f, ECE: %.3f' %
                                  (before_temperature_nll,
                                   before_temperature_ece))

                            def eval():
                                loss_cali = nll_criterion(
                                    model_cali.temperature_scale(logits),
                                    labels)
                                loss_cali.backward()
                                return loss_cali

                            optimizer_cali.step(
                                eval)  # just one backward >> not 50 iterations
                            after_temperature_nll = nll_criterion(
                                model_cali.temperature_scale(logits),
                                labels).item()
                            after_temperature_ece, accuracies, n_bin = ece_criterion(
                                model_cali.temperature_scale(logits), labels)
                            after_temperature_ece = after_temperature_ece.item(
                            )
                            print('Optimal temperature: %.3f' %
                                  model_cali.temperature.item())
                            print(
                                'After temperature - NLL: %.3f, ECE: %.3f' %
                                (after_temperature_nll, after_temperature_ece))

                            logits_list = []
                            labels_list = []

            else:
                loss = loss_seg

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_sum += loss_seg / args.iter_size
            if USED:
                loss_adv_sum += loss_adv_pred
            if i_iter % 100 == 0 and sub_i == 4:
                l_seg_sum = loss_seg_sum / 100
                if USED:
                    l_adv_sum = loss_adv_sum / 100
                if i_iter == 0:
                    l_seg_sum = l_seg_sum * 100
                    l_adv_sum = l_adv_sum * 100
                loss_seg_sum = 0
                loss_adv_sum = 0

            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
            if USED:
                loss_adv_pred_value += loss_adv_pred.data.cpu().numpy(
                ) / args.iter_size

            ##################### train D!!!###########################
            ###########################################################
            # bring back requires_grad
            if USED:
                for param in model_D.parameters():
                    param.requires_grad = True  # before False.

                ############# train with pred S(X)############# labeled + unlabeled
                pred = pred.detach(
                )  #orginally only use labeled data, freeze S(X) when train D,

                # We do train D with the unlabeled data. But the difference is quite small

                if args.D_remain:  #default true
                    pred = torch.cat(
                        (pred, pred_remain), 0
                    )  # pred_remain(unlabeled S(x)) is detached  16,21,321,321
                    ignore_mask = np.concatenate(
                        (ignore_mask, ignore_mask_remain),
                        axis=0)  # 16,321,321

                D_out = interp(
                    model_D(F.softmax(pred, dim=1))
                )  # D(S(X)) 16,1,321,321  # softmax(pred,dim=1) for 0.4, not nessesary

                loss_D = bce_loss(D_out,
                                  make_D_label(pred_label,
                                               ignore_mask))  # pred_label = 0

                # -log(1-D(S(X)))
                loss_D = loss_D / args.iter_size / 2  # iter_size = 1, /2 because there is G and D
                loss_D.backward()
                loss_D_value += loss_D.data.cpu().numpy()

                ################## train with gt################### only labeled
                #VOCGT and VOCdataset can be reduced to one dataset in this repo.
                # get gt labels Y
                #print "before train gt"
                try:
                    print(trainloader_gt_iter.next())  # len 732
                    _, batch = trainloader_gt_iter.next()

                except:
                    trainloader_gt_iter = enumerate(trainloader_gt)
                    _, batch = trainloader_gt_iter.next()
                #print "train with gt?"
                _, labels_gt, _, _ = batch
                D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)  #one_hot
                ignore_mask_gt = (labels_gt.numpy() == 255
                                  )  # same as ignore_mask (8,321,321)
                #print "finish"
                D_out = interp(model_D(D_gt_v))  # D(Y)
                loss_D = bce_loss(D_out,
                                  make_D_label(gt_label,
                                               ignore_mask_gt))  # log(D(Y))
                loss_D = loss_D / args.iter_size / 2

                loss_D.backward()
                loss_D_value += loss_D.data.cpu().numpy()

        optimizer.step()

        if USED:
            optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))  #snapshot
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.6f}, loss_semi = {5:.6f}, loss_semi_adv = {6:.3f}, loss_vat = {7: .5f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value,
                    loss_semi_adv_value, loss_vat_value))
        #                                        L_ce             L_adv for labeled      L_D                L_semi              L_adv for unlabeled
        #loss_adv should be inversely proportional to the loss_D if they are seeing the same data.
        # loss_adv_p is essentially the inverse loss of loss_D. We expect them to achieve a good balance during the adversarial training
        # loss_D is around 0.2-0.5   >> good
        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '_D.pth'))
            #torch.save(state, osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth.tar'))
            #torch.save(state_D, osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth.tar'))
            break

        if i_iter % 100 == 0 and sub_i == 4:  #loss_seg_value
            wdata = "iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.6f}, loss_semi = {5:.8f}, loss_semi_adv = {6:.3f}, l_vat_sum = {7: .5f}, loss_label = {8: .4}\n".format(
                i_iter, args.num_steps, l_seg_sum, l_adv_sum, loss_D_value,
                loss_semi_value, loss_semi_adv_value, l_vat_sum,
                l_seg_sum + 0.01 * l_adv_sum)
            #wdata2 = "{0:8d} {1:s} {2:s} {3:s} {4:s} {5:s} {6:s} {7:s} {8:s}\n".format(i_iter,str(model.coeff[0])[8:14],str(model.coeff[1])[8:14],str(model.coeff[2])[8:14],str(model.coeff[3])[8:14],str(model.coeff[4])[8:14],str(model.coeff[5])[8:14],str(model.coeff[6])[8:14],str(model.coeff[7])[8:14])
            if i_iter == 0:
                f2 = open("/home/eungyo/AdvSemiSeg/snapshots/log.txt", 'w')
                f2.write(wdata)
                f2.close()
                #f3 = open("/home/eungyo/AdvSemiSeg/snapshots/coeff.txt", 'w')
                #f3.write(wdata2)
                #f3.close()
            else:
                f1 = open("/home/eungyo/AdvSemiSeg/snapshots/log.txt", 'a')
                f1.write(wdata)
                f1.close()
                #f4 = open("/home/eungyo/AdvSemiSeg/snapshots/coeff.txt", 'a')
                #f4.write(wdata2)
                #f4.close()

        if i_iter % args.save_pred_every == 0 and i_iter != 0:  # 5000
            print('taking snapshot ...')
            #state = {'epoch':i_iter, 'state_dict':model.state_dict(),'optim_dict':optimizer.state_dict()}
            #state_D = {'epoch':i_iter, 'state_dict': model_D.state_dict(), 'optim_dict': optimizer_D.state_dict()}
            #torch.save(state, osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth.tar'))
            #torch.save(state_D, osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth.tar'))
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
Beispiel #10
0
def main():
    """Train the segmentation network on labeled data, optionally with
    self-training on the remaining (unlabeled) part of the dataset.

    Reads its configuration from the module-level ``args`` and reports the
    elapsed time relative to the module-level ``start`` timestamp.

    Steps:
      1. Build ``Res_Deeplab`` and copy every checkpoint tensor whose name
         and size match the model.
      2. Build the data loaders. When ``args.partial_data`` is set, the
         dataset ids are split into a labeled subset and a "remain" subset;
         the ids are pickled to ``<snapshot_dir>/train_id.pkl``.
      3. For ``args.num_steps`` iterations, accumulate gradients over
         ``args.iter_size`` sub-batches (supervised loss, plus a
         pseudo-label loss weighted by ``args.lambda_semi`` once
         ``i_iter >= args.semi_start``) and step the optimizer.
      4. Save snapshots every ``args.save_pred_every`` iterations and a
         final model on the last iteration.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True
    np.random.seed(args.random_seed)

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if (name in saved_state_dict
                and param.size() == saved_state_dict[name].size()):
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # load dataset
    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir,
                                    args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if args.partial_data is None:
        # NOTE(review): no "remain" loader exists in this configuration, so
        # running with args.lambda_semi > 0 fails with a NameError once
        # i_iter reaches args.semi_start (same as before this edit).
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=5,
                                      pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # Pickle streams are binary; open in 'rb' and close the handle.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = np.arange(train_dataset_size)
            np.random.shuffle(train_ids)

        # Persist the split so the same ids can be reused via partial_id.
        with open(osp.join(args.snapshot_dir, 'train_id.pkl'),
                  'wb') as id_file:
            pickle.dump(train_ids, id_file)

        # labeled data
        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=3,
                                      pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         sampler=train_gt_sampler,
                                         num_workers=3,
                                         pin_memory=True)

        # unlabeled data
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=3,
                                             pin_memory=True)
        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    # Created for parity with the original script; not consumed below.
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # loss/bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()

    # align_corners is only accepted from torch 0.4.0 onwards.
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear')

    for i_iter in range(args.num_steps):
        loss_seg_value = 0
        loss_unlabeled_seg_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        for sub_i in range(args.iter_size):
            # train Segmentation
            # train with labeled images
            try:
                # Built-in next() works on Python 2.6+ and 3; the ``.next()``
                # method does not exist on Python 3 iterators.
                _, batch = next(trainloader_iter)
            except StopIteration:
                # Loader exhausted: start a new pass over the data.
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            pred = interp(model(images))

            # computing loss
            loss_seg = loss_calc(pred, labels, args.gpu)

            # proper normalization
            loss = loss_seg / args.iter_size
            loss.backward()
            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size

            # train with unlabeled
            if args.lambda_semi > 0 and i_iter >= args.semi_start:
                try:
                    _, batch = next(trainloader_remain_iter)
                except StopIteration:
                    trainloader_remain_iter = enumerate(trainloader_remain)
                    _, batch = next(trainloader_remain_iter)

                # only access to img
                images, _, _, _ = batch
                images = Variable(images).cuda(args.gpu)

                pred = interp(model(images))
                # The network's own argmax prediction is used as the target.
                semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                semi_gt = torch.FloatTensor(semi_gt)

                loss_unlabeled_seg = args.lambda_semi * loss_calc(
                    pred, semi_gt, args.gpu)
                loss_unlabeled_seg = loss_unlabeled_seg / args.iter_size
                loss_unlabeled_seg.backward()

                # Logged without the lambda_semi weighting.
                loss_unlabeled_seg_value += loss_unlabeled_seg.data.cpu(
                ).numpy() / args.lambda_semi
            else:
                # 0 while waiting for semi_start, None when the
                # semi-supervised term is disabled altogether.
                loss_unlabeled_seg_value = 0 if args.lambda_semi > 0 else None

        optimizer.step()

        # ``{:.3f}`` cannot format None, so render the disabled case as text.
        if loss_unlabeled_seg_value is None:
            unlabeled_text = 'n/a'
        else:
            unlabeled_text = '{0:.3f}'.format(loss_unlabeled_seg_value)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_unlabeled_seg = {3} '
            .format(i_iter, args.num_steps, loss_seg_value, unlabeled_text))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir, 'VOC_' + str(args.num_steps) + '_' +
                    str(args.lambda_semi) + '_' + str(args.random_seed) +
                    '.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir,
                    'VOC_' + str(i_iter) + '_' + str(args.lambda_semi) + '_' +
                    str(args.random_seed) + '.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
Beispiel #11
0
def main():
    """Create the model and start the evaluation process.

    Loads the checkpoint selected by the command-line arguments, runs the
    network over ``SmokeDataset`` one image at a time and, for every image
    whose prediction contains more than one class:

      * saves the label map and its colorized version under ``args.save``;
      * computes two IoU values with ``iou_pytorch`` (class-0 mask vs. the
        label, and non-zero mask vs. the inverted label);
      * logs input, label, prediction and IoU scalars to a TensorBoard run
        directory created under ``args.save``.

    Mean IoUs over the counted images are printed at the end.

    Raises:
        ValueError: if ``args.model`` is not one of ``'DeeplabMulti'``,
            ``'Oracle'`` or ``'DeeplabVGG'``.
    """
    args = get_arguments()

    gpu0 = args.gpu

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Previously this fell through and failed later with a NameError.
        raise ValueError('Unsupported model type: {}'.format(args.model))

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    ### for running different versions of pytorch
    # Drop checkpoint entries the current model does not define.
    # NOTE(review): the original also merged these entries into a copy of
    # the model's state dict but never loaded that merged copy. The strict
    # load of the filtered checkpoint is kept, so keys missing from the
    # checkpoint still raise.
    model_dict = model.state_dict()
    saved_state_dict = {
        k: v
        for k, v in saved_state_dict.items() if k in model_dict
    }
    ###
    model.load_state_dict(saved_state_dict)

    model.eval()
    model.cuda(gpu0)

    # One TensorBoard run directory per invocation, named by timestamp.
    log_dir = args.save
    if not os.path.isdir(log_dir):
        os.mkdir(log_dir)
    exp_name = datetime.datetime.now().strftime("%H%M%S-%Y%m%d")
    log_dir = os.path.join(log_dir, exp_name)
    writer = SummaryWriter(log_dir)

    testloader = data.DataLoader(SmokeDataset(image_size=(640, 360),
                                              dataset_mean=IMG_MEAN),
                                 batch_size=1,
                                 shuffle=True,
                                 pin_memory=True)

    # NOTE(review): nn.Upsample takes size as (height, width); (640, 360)
    # looks like (width, height) given the dataset's image_size -- confirm.
    # align_corners is only accepted from torch 0.4.0 onwards.
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(640, 360),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(640, 360), mode='bilinear')

    count = 0
    iou_sum_fg = 0
    iou_count_fg = 0

    iou_sum_bg = 0
    iou_count_bg = 0

    for index, batch in enumerate(testloader):
        if (index + 1) % 100 == 0:
            print('%d processd' % index)

        image, label, name = batch

        with torch.no_grad():
            if args.model == 'DeeplabMulti':
                # Only the second head's output is evaluated.
                _, logits = model(Variable(image).cuda(gpu0))
            else:  # 'DeeplabVGG' or 'Oracle'
                logits = model(Variable(image).cuda(gpu0))

        # Keep the upsampled scores for the IoU computation below; this is
        # now available for every model type, not just DeeplabMulti.
        output = interp(logits).cpu()
        orig_output = output.detach().clone()
        output = output.data[0].numpy()

        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
        classes_seen = set(output.ravel().tolist())
        output_col = colorize_mask(output)
        output = Image.fromarray(output)
        name = name[0]

        # Images predicted as a single class are skipped entirely.
        if len(classes_seen) > 1:
            count += 1
            print(classes_seen)
            print(Counter(np.asarray(output).ravel()))

            # Undo mean subtraction, rescale to [0, 1] and swap the first
            # and last channels for display.
            image = image.squeeze()
            for c in range(3):
                image[c, :, :] += IMG_MEAN[c]
            image = (image - image.min()) / (image.max() - image.min())
            image = image[[2, 1, 0], :, :]
            print(image.shape, image.min(), image.max())
            output.save(os.path.join(args.save, name + '.png'))
            output_col.save(os.path.join(args.save, name + '_color.png'))

            output_argmaxs = torch.argmax(orig_output.squeeze(), dim=0)
            mask1 = (output_argmaxs == 0).float() * 255
            label = label.squeeze()

            iou_fg = iou_pytorch(mask1, label)
            print("foreground IoU", iou_fg)
            iou_sum_fg += iou_fg
            iou_count_fg += 1

            mask2 = (output_argmaxs > 0).float() * 255
            label2 = label.max() - label

            iou_bg = iou_pytorch(mask2, label2)
            print("IoU for background: ", iou_bg)
            iou_sum_bg += iou_bg
            iou_count_bg += 1

            writer.add_images('input_images',
                              tf.resize(image[[2, 1, 0]], [1080, 1920]),
                              index,
                              dataformats='CHW')

            print("shape of label", label.shape)
            label_reshaped = tf.resize(label.unsqueeze(0),
                                       [1080, 1920]).squeeze()
            print("label reshaped: ", label_reshaped.shape)
            writer.add_images('labels',
                              label_reshaped,
                              index,
                              dataformats='HW')
            writer.add_images(
                'output/1',
                255 - np.asarray(tf.resize(output, [1080, 1920])) * 255,
                index,
                dataformats='HW')
            writer.add_scalar('iou/smoke', iou_fg, index)
            writer.add_scalar('iou/background', iou_bg, index)
            writer.add_scalar('iou/mean', (iou_bg + iou_fg) / 2, index)
            writer.flush()

    # Release the event file handle once evaluation is finished.
    writer.close()

    if iou_count_fg > 0:
        print("Mean IoU, foreground: {}".format(iou_sum_fg / iou_count_fg))
        print("Mean IoU, background: {}".format(iou_sum_bg / iou_count_bg))
        print("Mean IoU, averaged over classes: {}".format(
            (iou_sum_fg + iou_sum_bg) / (iou_count_fg + iou_count_bg)))
def main():
    """Create the model and start the evaluation process.

    Reads ``opts.yaml`` next to the checkpoint to determine the model type,
    restores the weights, and evaluates the Cityscapes split given by
    ``args.set``. For the ``'DeepLab'`` model the prediction is the sum of
    softmax outputs over two input scales, each with and without a
    horizontal flip; a KL-divergence heatmap between the two heads is also
    produced. Label maps, heatmaps and score maps are written by the
    ``save``, ``save_heatmap`` and ``save_scoremap`` helpers in a worker
    pool.

    Returns:
        The output directory (``args.save`` with the model name appended).

    Raises:
        ValueError: if the configured model type is not ``'DeepLab'``,
            ``'Oracle'`` or ``'DeeplabVGG'``.
    """
    args = get_arguments()

    config_path = os.path.join(os.path.dirname(args.restore_from), 'opts.yaml')
    with open(config_path, 'r') as stream:
        # safe_load only builds plain Python types; yaml.load without an
        # explicit Loader is unsafe and is rejected by PyYAML >= 6.
        config = yaml.safe_load(stream)

    args.model = config['model']
    print('ModelType:%s' % args.model)
    print('NormType:%s' % config['norm_style'])
    gpu0 = args.gpu
    batchsize = args.batchsize

    model_name = os.path.basename(os.path.dirname(args.restore_from))
    # Plain string concatenation, so args.save is presumably expected to
    # end with a path separator -- verify against get_arguments().
    args.save += model_name

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes,
                             use_se=config['use_se'],
                             train_bn=False,
                             norm_style=config['norm_style'])
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Previously this fell through and failed later with a NameError.
        raise ValueError('Unsupported model type: {}'.format(args.model))

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    try:
        model.load_state_dict(saved_state_dict)
    except (RuntimeError, KeyError):
        # Key mismatch: retry with the model wrapped in DataParallel, which
        # prefixes parameter names with 'module.'.
        model = torch.nn.DataParallel(model)
        model.load_state_dict(saved_state_dict)
    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(cityscapesDataSet(args.data_dir,
                                                   args.data_list,
                                                   crop_size=(512, 1024),
                                                   resize_size=(1024, 512),
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set),
                                 batch_size=batchsize,
                                 shuffle=False,
                                 pin_memory=True,
                                 num_workers=4)

    scale = 1.25
    testloader2 = data.DataLoader(cityscapesDataSet(
        args.data_dir,
        args.data_list,
        crop_size=(round(512 * scale), round(1024 * scale)),
        resize_size=(round(1024 * scale), round(512 * scale)),
        mean=IMG_MEAN,
        scale=False,
        mirror=False,
        set=args.set),
                                  batch_size=batchsize,
                                  shuffle=False,
                                  pin_memory=True,
                                  num_workers=4)
    scale = 0.9
    # NOTE(review): this third-scale loader is iterated in lockstep below
    # but its batches are not fed to the model.
    testloader3 = data.DataLoader(cityscapesDataSet(
        args.data_dir,
        args.data_list,
        crop_size=(round(512 * scale), round(1024 * scale)),
        resize_size=(round(1024 * scale), round(512 * scale)),
        mean=IMG_MEAN,
        scale=False,
        mirror=False,
        set=args.set),
                                  batch_size=batchsize,
                                  shuffle=False,
                                  pin_memory=True,
                                  num_workers=4)

    # align_corners is only accepted from torch 0.4.0 onwards.
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(1024, 2048),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(1024, 2048), mode='bilinear')

    sm = torch.nn.Softmax(dim=1)
    log_sm = torch.nn.LogSoftmax(dim=1)
    kl_distance = nn.KLDivLoss(reduction='none')

    for index, img_data in enumerate(zip(testloader, testloader2,
                                         testloader3)):
        batch, batch2, _ = img_data
        image, _, _, name = batch
        image2, _, _, _ = batch2

        # Send inputs to the same device the model was moved to.
        inputs = image.cuda(gpu0)
        inputs2 = image2.cuda(gpu0)
        print('\r>>>>Extracting feature...%03d/%03d' %
              (index * batchsize, NUM_STEPS),
              end='')

        # Only the 'DeepLab' branch produces a heatmap.
        heatmap_batch = None
        if args.model == 'DeepLab':
            with torch.no_grad():
                # Scale 1, original orientation.
                output1, output2 = model(inputs)
                output_batch = interp(sm(0.5 * output1 + output2))
                heatmap_output1, heatmap_output2 = output1, output2

                # Scale 1, horizontally flipped (flipped back afterwards).
                output1, output2 = model(fliplr(inputs))
                output1, output2 = fliplr(output1), fliplr(output2)
                output_batch += interp(sm(0.5 * output1 + output2))
                heatmap_output1 = heatmap_output1 + output1
                heatmap_output2 = heatmap_output2 + output2
                del output1, output2, inputs

                # Scale 1.25, original orientation.
                output1, output2 = model(inputs2)
                output_batch += interp(sm(0.5 * output1 + output2))

                # Scale 1.25, horizontally flipped.
                output1, output2 = model(fliplr(inputs2))
                output1, output2 = fliplr(output1), fliplr(output2)
                output_batch += interp(sm(0.5 * output1 + output2))
                del output1, output2, inputs2

                output_batch = output_batch.cpu().data.numpy()

                # Per-pixel KL divergence between the two heads, summed
                # over the class dimension.
                heatmap_batch = torch.sum(kl_distance(log_sm(heatmap_output1),
                                                      sm(heatmap_output2)),
                                          dim=1)
                heatmap_batch = torch.log(
                    1 + 10 * heatmap_batch)  # for visualization
                heatmap_batch = heatmap_batch.cpu().data.numpy()
        else:  # 'DeeplabVGG' or 'Oracle'
            with torch.no_grad():
                output_batch = model(inputs)
                output_batch = interp(output_batch).cpu().data.numpy()

        # NCHW -> NHWC, then per-pixel best score and best class.
        output_batch = output_batch.transpose(0, 2, 3, 1)
        scoremap_batch = np.asarray(np.max(output_batch, axis=3))
        output_batch = np.asarray(np.argmax(output_batch, axis=3),
                                  dtype=np.uint8)
        output_iterator = []
        heatmap_iterator = []
        scoremap_iterator = []
        save_paths = []

        for i in range(output_batch.shape[0]):
            output_iterator.append(output_batch[i, :, :])
            if heatmap_batch is not None:
                heatmap_iterator.append(heatmap_batch[i, :, :] /
                                        np.max(heatmap_batch[i, :, :]))
            scoremap_iterator.append(1 - scoremap_batch[i, :, :] /
                                     np.max(scoremap_batch[i, :, :]))
            name_tmp = name[i].split('/')[-1]
            save_paths.append('%s/%s' % (args.save, name_tmp))

        with Pool(4) as p:
            p.map(save, zip(output_iterator, save_paths))
            # Empty (and therefore a no-op) when no heatmap was produced.
            p.map(save_heatmap, zip(heatmap_iterator, save_paths))
            p.map(save_scoremap, zip(scoremap_iterator, save_paths))

        del output_batch

    return args.save
Beispiel #13
0
def main():
    """Adversarially train a Res_Deeplab segmentation network on VOC data.

    Configuration is read from the module-level ``args`` object. The
    segmentation network (``model``) is trained with SGD and the
    discriminator (``model_D``) with Adam. Every outer iteration accumulates
    gradients over ``args.iter_size`` sub-batches, each consisting of:

    1. an optional semi-supervised step (when ``args.lambda_semi > 0`` and
       ``i_iter >= args.semi_start``) that uses the discriminator output to
       mask low-confidence pixels of the network's own prediction,
    2. a supervised segmentation loss plus an adversarial loss,
    3. two discriminator updates: one on the (detached) prediction and one
       on one-hot ground-truth labels.

    Snapshots of both networks are written to ``args.snapshot_dir`` every
    ``args.save_pred_every`` iterations and once more at the final step.

    The elapsed time printed at the end relies on a ``start`` timestamp that
    is defined outside this function.
    """
    # Convert the comma-separated input_size string into a pair of integers.
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = False
    gpu = args.gpu

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    # state_dict() maps parameter/buffer names to tensors; an entry is
    # overwritten only when the checkpoint has the same name and size.
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if name in saved_state_dict and param.size(
        ) == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    # switch to training mode
    model.train()
    cudnn.benchmark = True

    model.cuda(gpu)

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)
    train_gt_dataset = VOCGTDataSet(args.data_dir,
                                    args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=5,
                                      pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # Pickle streams are binary, so the file must be opened with
            # 'rb'; the context manager also closes the handle promptly.
            # NOTE(review): unpickling executes arbitrary code -- only pass
            # a partial_id file that comes from a trusted source.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)

        # Persist the id order next to the snapshots.
        with open(osp.join(args.snapshot_dir, 'train_id.pkl'),
                  'wb') as id_file:
            pickle.dump(train_ids, id_file)

        # The first partial_size ids are the labelled subset; the rest are
        # only used (without labels) by the semi-supervised step.
        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=3,
                                      pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=3,
                                             pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         sampler=train_gt_sampler,
                                         num_workers=3,
                                         pin_memory=True)

        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss/ bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()
    # NOTE(review): the size is passed as (input_size[1], input_size[0]),
    # i.e. (w, h); this only matters for non-square inputs -- confirm.
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    for i_iter in range(args.num_steps):
        print("Iter:", i_iter)
        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # do semi first
            # NOTE: trainloader_remain only exists when args.partial_data is
            # set, so this branch requires partial-data mode.
            if args.lambda_semi > 0 and i_iter >= args.semi_start:
                try:
                    _, batch = next(trainloader_remain_iter)
                except StopIteration:
                    # Loader exhausted: start a new pass over the data.
                    trainloader_remain_iter = enumerate(trainloader_remain)
                    _, batch = next(trainloader_remain_iter)

                # only access to img
                images, _, _, _ = batch

                images = Variable(images).cuda(gpu)

                pred = interp(model(images))
                D_out = interp(model_D(F.softmax(pred)))

                D_out_sigmoid = F.sigmoid(D_out).data.cpu().numpy().squeeze(
                    axis=1)

                # produce ignore mask: pixels the discriminator scores below
                # mask_T are excluded from the self-training target
                semi_ignore_mask = (D_out_sigmoid < args.mask_T)

                semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                semi_gt[semi_ignore_mask] = 255

                semi_ratio = 1.0 - float(
                    semi_ignore_mask.sum()) / semi_ignore_mask.size
                print('semi ratio: {:.4f}'.format(semi_ratio))

                if semi_ratio == 0.0:
                    loss_semi_value += 0
                else:
                    semi_gt = torch.FloatTensor(semi_gt)

                    loss_semi = args.lambda_semi * loss_calc(
                        pred, semi_gt, args.gpu)
                    loss_semi = loss_semi / args.iter_size
                    loss_semi.backward()
                    loss_semi_value += loss_semi.data.cpu().numpy(
                    )[0] / args.lambda_semi
            else:
                loss_semi = None

            # train with source

            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch

            images = Variable(images).cuda(gpu)

            ignore_mask = (labels.numpy() == 255)
            pred = interp(model(images))

            loss_seg = loss_calc(pred, labels, args.gpu)

            D_out = interp(model_D(F.softmax(pred)))

            # The generator is rewarded when D labels its prediction as gt.
            loss_adv_pred = bce_loss(D_out,
                                     make_D_label(gt_label, ignore_mask))

            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += loss_seg.data.cpu().numpy()[0] / args.iter_size
            loss_adv_pred_value += loss_adv_pred.data.cpu().numpy(
            )[0] / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred
            pred = pred.detach()

            D_out = interp(model_D(F.softmax(pred)))
            loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()[0]

            # train with gt
            # get gt labels
            try:
                _, batch = next(trainloader_gt_iter)
            except StopIteration:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = next(trainloader_gt_iter)

            _, labels_gt, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()[0]

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
Beispiel #14
0
def main():
    """Evaluate a segmentation checkpoint on VOC and dump diagnostic images.

    Both the segmentation network and the discriminator are restored from
    hard-coded checkpoint paths. For every test image the function writes:

    * the colourised prediction to ``args.save_dir``,
    * a mask that is 255 where the prediction equals the ground truth and
      0 elsewhere (``pred_mis`` directory),
    * the discriminator confidence map thresholded at 0.0001
      (``confidence_line`` directory).

    Finally the per-image (ground truth, prediction) pairs are passed to
    ``get_iou`` together with ``<save_dir>/result.txt``.
    """
    from collections import OrderedDict
    from model.discriminator_pred_concat_img import FCDiscriminator

    def _restore(net, checkpoint):
        """Copy name- and size-compatible checkpoint tensors into ``net``."""
        stripped = OrderedDict()
        for key, value in checkpoint.items():
            # Checkpoints saved through DataParallel prefix every key with
            # `module.`; strip it only when it is really there so that
            # plain checkpoints are not mangled.
            if key.startswith('module.'):
                key = key[7:]
            stripped[key] = value

        params = net.state_dict().copy()
        for param_name, param in params.items():
            print (param_name)
            if (param_name in stripped
                    and param.size() == stripped[param_name].size()):
                params[param_name].copy_(stripped[param_name])
                print('copy {}'.format(param_name))
        net.load_state_dict(params)

    args = get_arguments()

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # NOTE(review): checkpoint and output locations below are hard-coded.
    model = Res_Deeplab(num_classes=args.num_classes)
    state_dict = torch.load('/data1/wyc/AdvSemiSeg/snapshots/VOC_t_concat_pred_img_15000.pth')

    model_D = FCDiscriminator(num_classes=args.num_classes)
    state_dict_d = torch.load('/data1/wyc/AdvSemiSeg/snapshots/VOC_t_concat_pred_img_15000_D.pth')

    _restore(model, state_dict)
    model.eval()
    model.cuda()

    _restore(model_D, state_dict_d)
    model_D.eval()
    model_D.cuda()

    testloader = data.DataLoader(VOCDataSet(args.data_dir, args.data_list, crop_size=(505, 505), mean=IMG_MEAN, scale=False, mirror=False),
                                    batch_size=1, shuffle=False, pin_memory=True)

    # align_corners only exists from PyTorch 0.4.0 on.
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(505, 505), mode='bilinear', align_corners=True)
    else:
        interp = nn.Upsample(size=(505, 505), mode='bilinear')
    data_list = []

    colorize = VOCColorize()

    for index, batch in enumerate(testloader):
        if index % 100 == 0:
            print('%d processd'%(index))
        image, label, size, name = batch
        size = size[0].numpy()
        output = model(Variable(image, volatile=True).cuda())
        image_d = Variable(image, volatile=True).cuda()
        output = interp(output)
        # Keep the full-size logits for the discriminator before cropping.
        output_dout = output.clone()
        output = output.cpu().data[0].numpy()
        # Crop away the padding: size holds the original image extent.
        output = output[:, :size[0], :size[1]]
        # np.int was only an alias of the builtin int and has been removed
        # from NumPy, so the builtin is used directly.
        gt = np.asarray(label[0].numpy()[:size[0], :size[1]], dtype=int)
        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=int)

        # pred result
        filename = os.path.join(args.save_dir, '{}.png'.format(name[0]))
        color_file = Image.fromarray(colorize(output).transpose(1, 2, 0), 'RGB')
        color_file.save(filename)

        # correctness mask: 255 where prediction == gt, 0 where it differs
        output_mistake = np.zeros(output.shape)
        semi_ignore_mask_correct = (output == gt)
        output_mistake[semi_ignore_mask_correct] = 255
        output_mistake = np.expand_dims(output_mistake, axis=2)
        filename2 = os.path.join('/data1/wyc/AdvSemiSeg/pred_mis/', '{}.png'.format(name[0]))
        cv2.imwrite(filename2, output_mistake)

        # discriminator confidence map, binarised at 0.0001
        D_out = interp(model_D(torch.cat([F.softmax(output_dout, dim=1), F.sigmoid(image_d)], 1)))
        D_out_sigmoid = (F.sigmoid(D_out).data[0].cpu().numpy())
        D_out_sigmoid = D_out_sigmoid[:, :size[0], :size[1]]
        # Both masks are computed before either assignment so the second
        # one is not affected by the values written through the first.
        semi_ignore_mask_dout0 = (D_out_sigmoid < 0.0001)
        semi_ignore_mask_dout255 = (D_out_sigmoid >= 0.0001)

        D_out_sigmoid[semi_ignore_mask_dout0] = 0
        D_out_sigmoid[semi_ignore_mask_dout255] = 255
        # 0 black 255 white
        filename2 = os.path.join('/data1/wyc/AdvSemiSeg/confidence_line/', '{}.png'.format(name[0]))
        cv2.imwrite(filename2, D_out_sigmoid.transpose(1, 2, 0))

        data_list.append([gt.flatten(), output.flatten()])

    filename = os.path.join(args.save_dir, 'result.txt')
    get_iou(data_list, args.num_classes, filename)
def main():
    """Create the model and start the evaluation process.

    Builds the network selected by ``args.model``, restores its weights from
    ``args.restore_from`` (URL or local file), runs it over the NYU ``val``
    split and prints two sets of streaming segmentation metrics: one on the
    raw labels and one after remapping a fixed set of class ids to the
    ignore value 255.
    """

    args = get_arguments()

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    # Class ids (looked up as label + 1) that are collapsed to 255 for the
    # second, "remapped" metric; every other label is kept as is.
    nyu_nyu_dict = {11:255, 13:255, 15:255, 17:255, 19:255, 20:255, 21: 255, 23: 255,
            24:255, 25:255, 26:255, 27:255, 28:255, 29:255, 31:255, 32:255, 33:255}
    nyu_nyu_map = lambda x: nyu_nyu_dict.get(x+1,x)
    nyu_nyu_map = np.vectorize(nyu_nyu_map)
    args.nyu_nyu_map = nyu_nyu_map
    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG

    if args.restore_from[:4] == 'http' :
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)
    ### for running different versions of pytorch
    # Drop checkpoint entries the current model does not know about.
    model_dict = model.state_dict()
    saved_state_dict = {k: v for k, v in saved_state_dict.items() if k in model_dict}
    model_dict.update(saved_state_dict)
    ###
    model.load_state_dict(saved_state_dict)

    device = torch.device("cuda" if not args.cpu else "cpu")
    model = model.to(device)

    model.eval()

    metrics = StreamSegMetrics(args.num_classes)
    metrics_remap = StreamSegMetrics(args.num_classes)
    ignore_label = 255
    val_transform = transforms.Compose([
        transforms.Crop([args.height+1, args.width+1], crop_type='center', padding=IMG_MEAN, ignore_label=ignore_label),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMG_MEAN,
                             std=[1, 1, 1]),
    ])
    val_dst = NYU(root=args.data_dir, opt=args,
                  split='val', transform=val_transform,
                  imWidth = args.width, imHeight = args.height, phase="TEST",
                  randomize = False)
    print("Dset Length {}".format(len(val_dst)))
    testloader = data.DataLoader(val_dst,
                                    batch_size=1, shuffle=False, pin_memory=True)

    interp = nn.Upsample(size=(args.height+1, args.width+1), mode='bilinear', align_corners=True)
    metrics.reset()
    for index, batch in enumerate(testloader):
        if index % 100 == 0:
            print('%d processd' % index)
        image, targets, _ = batch
        image = image.to(device)
        print(index)
        # Pure inference: disable autograd so no graph is kept per image.
        with torch.no_grad():
            if args.model == 'DeeplabMulti':
                # Only the second head's output is evaluated.
                _, output2 = model(image)
                output = interp(output2).cpu().data[0].numpy()
            elif args.model == 'DeeplabVGG' or args.model == 'Oracle':
                output = model(image)
                output = interp(output).cpu().data[0].numpy()
        targets = targets.cpu().numpy()
        output = output.transpose(1,2,0)
        output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
        # Re-add the batch axis expected alongside `targets`.
        preds = output[None,:,:]
        metrics.update(targets, preds)
        targets = args.nyu_nyu_map(targets)
        preds = args.nyu_nyu_map(preds)
        metrics_remap.update(targets,preds)
    print(metrics.get_results())
    print(metrics_remap.get_results())
Beispiel #16
0
def main():
    """Create the model and plot a t-SNE embedding of its features.

    Builds the network selected by ``args.model``, restores its weights and,
    for each Cityscapes image, takes the third value returned by the model
    (``features``), pairs every spatial position with a ground-truth label
    resized to 65x129, and passes the 2-D t-SNE embedding of the features
    to ``scatter_plot``.

    NOTE(review): the two ``pdb.set_trace()`` calls are kept from the
    original; they drop into the debugger on every image.
    """

    args = get_arguments()

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG

    if args.restore_from[:4] == 'http' :
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)
    ### for running different versions of pytorch
    # Drop checkpoint entries the current model does not know about.
    model_dict = model.state_dict()
    saved_state_dict = {k: v for k, v in saved_state_dict.items() if k in model_dict}
    model_dict.update(saved_state_dict)
    ###
    model.load_state_dict(saved_state_dict)

    device = torch.device("cuda" if not args.cpu else "cpu")
    model = model.to(device)

    model.eval()

    testloader = data.DataLoader(cityscapesDataSet(args.data_dir, args.data_list, crop_size=(1024, 512), mean=IMG_MEAN, scale=False, mirror=False, set=args.set),
                                    batch_size=1, shuffle=False, pin_memory=True)

    # NOTE(review): hard-coded ground-truth location.
    gt_dir = '/project/AdaptSegNet/data/Cityscapes/data/gtFine/val/'

    for index, batch in enumerate(testloader):
        if index % 100 == 0:
            print('%d processd' % index)
        image, _, name = batch
        image = image.to(device)

        # Pure inference: disable autograd so no graph is kept per image.
        with torch.no_grad():
            _, _, features = model(image)
        features = features.cpu().data[0].numpy()

        #### tsne plot ####
        gt_file = name[0].split('/')[-1]
        gt_file = gt_file.replace('leftImg8bit', 'gtFine_labelIds')

        # Ground truth is stored in one sub-directory per city; only these
        # three city names are recognised, first match wins.
        for city in ('frankfurt', 'lindau', 'munster'):
            if city in gt_file:
                gt_file = gt_dir + city + '/' + gt_file
                break

        labels = np.array(Image.open(gt_file))
        labels = label_mask(labels).astype(np.float64)
        labels = skimage.transform.resize(labels, [65, 129])
        labels = label_mask(labels).astype(np.uint8)

        # Flatten (channels, height, width) features to one row per pixel.
        n_feat, feat_h, feat_w = np.squeeze(features).shape
        features = features.reshape(n_feat, feat_h*feat_w).transpose(1,0)
        labels = labels.reshape(feat_h*feat_w,)

        tsne = TSNE(random_state=RS).fit_transform(features)
        pdb.set_trace()  # NOTE(review): debugger breakpoint before plotting
        scatter_plot(tsne, labels)

        #### tsne plot ####

        pdb.set_trace()  # NOTE(review): debugger breakpoint after each image
def main():
    """Create the model and start the evaluation process.

    The model type and normalisation style are read from the ``opts.yaml``
    file stored next to ``args.restore_from``. Predictions are produced from
    two dataset scales (1.0 and 1.25), each evaluated on the image and on
    its horizontal flip; the four softmax maps are summed before the argmax.
    For every image the label map and a colourised copy are written below
    ``args.save/<parent directory of the image>``.

    Returns:
        ``args.save``, the root directory the results were written to.
    """

    args = get_arguments()

    config_path = os.path.join(os.path.dirname(args.restore_from), 'opts.yaml')
    with open(config_path, 'r') as stream:
        # yaml.load() without an explicit Loader is rejected by current
        # PyYAML; safe_load parses plain YAML without constructing
        # arbitrary Python objects.
        config = yaml.safe_load(stream)

    args.model = config['model']
    print('ModelType:%s' % args.model)
    print('NormType:%s' % config['norm_style'])
    gpu0 = args.gpu
    batchsize = args.batchsize

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes,
                             use_se=config['use_se'],
                             train_bn=False,
                             norm_style=config['norm_style'])
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    try:
        model.load_state_dict(saved_state_dict)
    except (RuntimeError, KeyError):
        # Key mismatch: retry with the model wrapped in DataParallel, which
        # is how the checkpoint is presumably keyed (`module.` prefix).
        model = torch.nn.DataParallel(model)
        model.load_state_dict(saved_state_dict)
    # Wrap exactly once; the fallback above may already have done it.
    if not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)
    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(robotDataSet(args.data_dir,
                                              args.data_list,
                                              crop_size=(960, 1280),
                                              resize_size=(1280, 960),
                                              mean=IMG_MEAN,
                                              scale=False,
                                              mirror=False,
                                              set=args.set),
                                 batch_size=batchsize,
                                 shuffle=False,
                                 pin_memory=True,
                                 num_workers=4)

    # Second loader over the same list at 1.25x resolution. Neither loader
    # shuffles, so zip() below pairs the same image at both scales.
    scale = 1.25
    testloader2 = data.DataLoader(robotDataSet(
        args.data_dir,
        args.data_list,
        crop_size=(round(960 * scale), round(1280 * scale)),
        resize_size=(round(1280 * scale), round(960 * scale)),
        mean=IMG_MEAN,
        scale=False,
        mirror=False,
        set=args.set),
                                  batch_size=batchsize,
                                  shuffle=False,
                                  pin_memory=True,
                                  num_workers=4)

    # align_corners only exists from PyTorch 0.4.0 on.
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(960, 1280),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(960, 1280), mode='bilinear')

    sm = torch.nn.Softmax(dim=1)
    for index, img_data in enumerate(zip(testloader, testloader2)):
        batch, batch2 = img_data
        image, _, _, name = batch
        image2, _, _, name2 = batch2
        print(image.shape)

        inputs = image.cuda()
        inputs2 = image2.cuda()
        print('\r>>>>Extracting feature...%04d/%04d' %
              (index * batchsize, NUM_STEPS),
              end='')
        if args.model == 'DeepLab':
            with torch.no_grad():
                # Scale 1.0: original image plus its horizontal flip.
                output1, output2 = model(inputs)
                output_batch = interp(sm(0.5 * output1 + output2))
                output1, output2 = model(fliplr(inputs))
                output1, output2 = fliplr(output1), fliplr(output2)
                output_batch += interp(sm(0.5 * output1 + output2))
                del output1, output2, inputs

                # Scale 1.25: original image plus its horizontal flip.
                output1, output2 = model(inputs2)
                output_batch += interp(sm(0.5 * output1 + output2))
                output1, output2 = model(fliplr(inputs2))
                output1, output2 = fliplr(output1), fliplr(output2)
                output_batch += interp(sm(0.5 * output1 + output2))
                del output1, output2, inputs2
                output_batch = output_batch.cpu().data.numpy()
        elif args.model == 'DeeplabVGG' or args.model == 'Oracle':
            output_batch = model(Variable(image).cuda())
            output_batch = interp(output_batch).cpu().data.numpy()

        # (N, C, H, W) -> (N, H, W, C) so the class axis is last for argmax.
        output_batch = output_batch.transpose(0, 2, 3, 1)
        output_batch = np.asarray(np.argmax(output_batch, axis=3),
                                  dtype=np.uint8)

        for i in range(output_batch.shape[0]):
            output = output_batch[i, :, :]
            output_col = colorize_mask(output)
            output = Image.fromarray(output)

            # Mirror the image's parent directory below args.save.
            name_tmp = name[i].split('/')[-1]
            dir_name = name[i].split('/')[-2]
            save_path = args.save + '/' + dir_name
            if not os.path.isdir(save_path):
                os.mkdir(save_path)
            output.save('%s/%s' % (save_path, name_tmp))
            print('%s/%s' % (save_path, name_tmp))
            output_col.save('%s/%s_color.png' %
                            (save_path, name_tmp.split('.')[0]))

    return args.save
def main():
    """Create the model and start the evaluation process.

    The model type and normalisation style are read from the ``opts.yaml``
    file stored next to ``args.restore_from``.  Every batch produced by
    ``DarkZurichDataSet`` is predicted at two input scales (1.0 and 1.25).
    For the ``'DeepLab'`` model the final score is the mean of four softmax
    maps (both scales, each plain and horizontally flipped); the other model
    types use a single forward pass at scale 1.0.

    For each image a label map and a colourised label map are written to
    ``<args.save>/<parent directory of the image>``.  When ``args.set`` is
    ``'test'`` or ``'val'`` three more files are written under
    ``<args.save>/submit``: the label map, a copy in which pixels scoring
    below a fixed threshold are set to 255, and a 16-bit confidence image.

    Returns:
        ``args.save``, the root output directory.

    Raises:
        ValueError: if the model type found in ``opts.yaml`` is unsupported.
    """

    args = get_arguments()

    w, h = map(int, args.input_size.split(','))

    config_path = os.path.join(os.path.dirname(args.restore_from), 'opts.yaml')
    with open(config_path, 'r') as stream:
        # yaml.load() without an explicit Loader raises TypeError on
        # PyYAML >= 6 and may construct arbitrary objects on older releases.
        # NOTE(review): assumes opts.yaml only holds plain scalars/containers.
        config = yaml.safe_load(stream)

    args.model = config['model']
    print('ModelType:%s' % args.model)
    print('NormType:%s' % config['norm_style'])
    gpu0 = args.gpu
    batchsize = args.batchsize

    if not os.path.exists(args.save):
        os.makedirs(args.save)
    confidence_path = os.path.join(args.save, 'submit/confidence')
    label_path = os.path.join(args.save, 'submit/labelTrainIds')
    label_invalid_path = os.path.join(args.save,
                                      'submit/labelTrainIds_invalid')
    for path in [confidence_path, label_path, label_invalid_path]:
        if not os.path.exists(path):
            os.makedirs(path)

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes,
                             use_se=config['use_se'],
                             train_bn=False,
                             norm_style=config['norm_style'])
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Fail here with a clear message instead of an unbound `model` later.
        raise ValueError('Unsupported model type: %s' % args.model)

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    try:
        model.load_state_dict(saved_state_dict)
    except RuntimeError:
        # Key mismatch: retry with the model wrapped in DataParallel, which
        # prefixes every parameter name with "module.".
        model = torch.nn.DataParallel(model)
        model.load_state_dict(saved_state_dict)
    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(DarkZurichDataSet(args.data_dir,
                                                   args.data_list,
                                                   crop_size=(h, w),
                                                   resize_size=(w, h),
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set),
                                 batch_size=batchsize,
                                 shuffle=False,
                                 pin_memory=True,
                                 num_workers=4)

    # Second loader over the same list at 1.25x resolution; it is zipped with
    # the first one below, so both must keep shuffle=False to stay aligned.
    scale = 1.25
    testloader2 = data.DataLoader(DarkZurichDataSet(
        args.data_dir,
        args.data_list,
        crop_size=(round(h * scale), round(w * scale)),
        resize_size=(round(w * scale), round(h * scale)),
        mean=IMG_MEAN,
        scale=False,
        mirror=False,
        set=args.set),
                                  batch_size=batchsize,
                                  shuffle=False,
                                  pin_memory=True,
                                  num_workers=4)

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(1080, 1920),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(1080, 1920), mode='bilinear')

    sm = torch.nn.Softmax(dim=1)
    log_sm = torch.nn.LogSoftmax(dim=1)
    kl_distance = nn.KLDivLoss(reduction='none')
    # Only referenced by the disabled prior re-weighting further down.
    prior = np.load('./utils/prior_all.npy').transpose(
        (2, 0, 1))[np.newaxis, :, :, :]
    prior = torch.from_numpy(prior)
    for index, img_data in enumerate(zip(testloader, testloader2)):
        batch, batch2 = img_data
        image, _, name = batch
        image2, _, _ = batch2

        inputs = image.cuda()
        inputs2 = image2.cuda()
        print('\r>>>>Extracting feature...%04d/%04d' %
              (index * batchsize, args.batchsize * len(testloader)),
              end='')
        with torch.no_grad():
            if args.model == 'DeepLab':
                output1, output2 = model(inputs)
                output_batch = interp(sm(0.5 * output1 + output2))

                # Per-pixel KL divergence between the two classifier heads;
                # currently only consumed by the disabled heatmap dump below.
                heatmap_batch = torch.sum(kl_distance(log_sm(output1),
                                                      sm(output2)),
                                          dim=1)

                output1, output2 = model(fliplr(inputs))
                output1, output2 = fliplr(output1), fliplr(output2)
                output_batch += interp(sm(0.5 * output1 + output2))
                del output1, output2, inputs

                output1, output2 = model(inputs2)
                output_batch += interp(sm(0.5 * output1 + output2))
                output1, output2 = model(fliplr(inputs2))
                output1, output2 = fliplr(output1), fliplr(output2)
                output_batch += interp(sm(0.5 * output1 + output2))
                del output1, output2, inputs2
                # Mean of the four softmax maps accumulated above.
                output_batch = output_batch.cpu() / 4
                # output_batch = output_batch *(0.95 + (1 - 0.95) * prior)
                output_batch = output_batch.data.numpy()
                heatmap_batch = heatmap_batch.cpu().data.numpy()
            elif args.model == 'DeeplabVGG' or args.model == 'Oracle':
                # NOTE(review): no softmax is applied on this path, so the
                # score threshold below compares against raw network output.
                output_batch = model(Variable(image).cuda())
                output_batch = interp(output_batch).cpu().data.numpy()

        output_batch = output_batch.transpose(0, 2, 3, 1)
        score_batch = np.max(output_batch, axis=3)
        output_batch = np.asarray(np.argmax(output_batch, axis=3),
                                  dtype=np.uint8)

        threshold = 0.3274
        for i in range(output_batch.shape[0]):
            output_single = output_batch[i, :, :]
            output_col = colorize_mask(output_single)
            output = Image.fromarray(output_single)

            name_tmp = name[i].split('/')[-1]
            dir_name = name[i].split('/')[-2]
            save_path = args.save + '/' + dir_name
            if not os.path.isdir(save_path):
                os.mkdir(save_path)
            output.save('%s/%s' % (save_path, name_tmp))
            print('%s/%s' % (save_path, name_tmp))
            output_col.save('%s/%s_color.png' %
                            (save_path, name_tmp.split('.')[0]))

            # heatmap_tmp = heatmap_batch[i,:,:]/np.max(heatmap_batch[i,:,:])
            # fig = plt.figure()
            # plt.axis('off')
            # heatmap = plt.imshow(heatmap_tmp, cmap='viridis')
            # fig.colorbar(heatmap)
            # fig.savefig('%s/%s_heatmap.png' % (save_path, name_tmp.split('.')[0]))

            if args.set == 'test' or args.set == 'val':
                # label (must be saved before the in-place masking below)
                output.save('%s/%s' % (label_path, name_tmp))
                # label invalid: low-score pixels are overwritten with 255
                output_single[score_batch[i, :, :] < threshold] = 255
                output = Image.fromarray(output_single)
                output.save('%s/%s' % (label_invalid_path, name_tmp))
                # confidence, stretched to the uint16 range
                confidence = score_batch[i, :, :] * 65535
                confidence = np.asarray(confidence, dtype=np.uint16)
                print(confidence.min(), confidence.max())
                iio.imwrite('%s/%s' % (confidence_path, name_tmp), confidence)

    return args.save
def main():
    """Create the model and start the evaluation process.

    Restores the checkpoint named by ``args.restore_from`` into the network
    selected by ``args.model`` and predicts every image yielded by
    ``BDDDataSet`` (batch size 1).  For each image a label map and a
    colourised label map are written to ``args.save``.

    When ``args.save_confidence`` is set, the value returned by
    ``get_confidence`` is additionally written to ``<image>_c.txt`` in
    ``args.save``, and the names of the top third of images (highest
    confidence first) are written to ``list.txt`` in the working directory.

    Raises:
        ValueError: if ``args.model`` is not one of the supported types.
    """

    args = get_arguments()

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    elif args.model == 'DeeplabVGGBN':
        deeplab_vggbn.BatchNorm = SyncBatchNorm2d
        model = deeplab_vggbn.DeeplabVGGBN(num_classes=args.num_classes)
    else:
        # Fail here with a clear message instead of an unbound `model` later.
        raise ValueError('Unsupported model type: %s' % args.model)

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)
    # strict=False: keys missing from / unknown to the model are ignored.
    model.load_state_dict(saved_state_dict, strict=False)

    print(model)

    device = torch.device("cuda" if not args.cpu else "cpu")
    model = model.to(device)

    model.eval()

    testloader = data.DataLoader(BDDDataSet(args.data_dir,
                                            args.data_list,
                                            crop_size=(960, 540),
                                            mean=IMG_MEAN,
                                            scale=False,
                                            mirror=False,
                                            set=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)  # 960 540

    interp = nn.Upsample(size=(720, 1280), mode='bilinear', align_corners=True)

    c_list = []
    # Opened up front (as before) so an unwritable working directory is
    # reported before any inference is done; closed in the `finally` below.
    select = open('list.txt', 'w') if args.save_confidence else None
    try:
        for index, batch in enumerate(testloader):
            if index % 10 == 0:
                print('%d processd' % index)
            image, _, name = batch
            image = image.to(device)

            # Pure inference: do not record an autograd graph.
            with torch.no_grad():
                output = model(image)

                if args.save_confidence:
                    confidence = get_confidence(output).cpu().item()
                    # Keep the untouched name list; list.txt gets its first
                    # element (the full image name) further down.
                    c_list.append([confidence, name])

                output = interp(output).cpu().data[0].numpy()

            name = name[0].split('/')[-1]
            if args.save_confidence:
                save_path = '%s/%s_c.txt' % (args.save, name.split('.')[0])
                with open(save_path, 'w') as record:
                    record.write('%.5f' % confidence)

            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

            output_col = colorize_mask(output)
            output = Image.fromarray(output)

            # name[:-4] drops a three-letter extension before adding ".png".
            output.save('%s/%s' % (args.save, name[:-4] + '.png'))
            output_col.save('%s/%s_color.png' %
                            (args.save, name.split('.')[0]))

        if args.save_confidence:
            c_list.sort(key=lambda entry: entry[0], reverse=True)
            # Keep the most confident third of the evaluated images.
            for confidence, names in c_list[:len(c_list) // 3]:
                print(confidence)
                print(names)
                select.write(names[0])
                select.write('\n')
    finally:
        if select is not None:
            select.close()

    print(args.save)
def main():
    """Create the model and start the evaluation process.

    Loads a hard-coded checkpoint into ``Res_Deeplab``, predicts every image
    yielded by ``VOCDataSet`` (batch size 1), and for each image:

    * writes three debug images under ``/data1/wyc/AdvSemiSeg/vis`` for the
      foreground classes present in the label (the input weighted by the
      class probability, the input itself, and a colour-coded confidence
      map); every class writes to the same file name, so the files left on
      disk belong to the last class processed;
    * writes the colourised prediction to ``args.save_dir``;
    * collects flattened ground truth / prediction pairs.

    The collected pairs are finally passed to ``get_iou`` together with the
    path ``<args.save_dir>/result.txt``.
    """
    args = get_arguments()

    gpu0 = args.gpu

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    model = Res_Deeplab(num_classes=args.num_classes)

    state_dict = torch.load(
        '/data1/wyc/AdvSemiSeg/snapshots/VOC_t_baseline_1adv_mul_20000.pth'
    )  #baseline707 adv 709 nadv 705()*2#n adv0.694

    # The checkpoint was saved from a DataParallel wrapper, which prefixes
    # every key with "module."; strip it only where it is actually present.
    prefix = 'module.'
    new_state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

    # Copy only parameters whose name and shape match the current model.
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if name in new_state_dict and param.size(
        ) == new_state_dict[name].size():
            new_params[name].copy_(new_state_dict[name])
            print('copy {}'.format(name))

    model.load_state_dict(new_params)

    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(VOCDataSet(args.data_dir,
                                            args.data_list,
                                            crop_size=(505, 505),
                                            mean=IMG_MEAN,
                                            scale=False,
                                            mirror=False),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(505, 505),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(505, 505), mode='bilinear')
    data_list = []

    colorize = VOCColorize()

    for index, batch in enumerate(testloader):
        if index % 100 == 0:
            print('%d processd' % (index))
        image, label, size, name = batch
        # Size of the un-padded image; used to crop prediction and label.
        size = size[0].numpy()

        # `volatile=True` is ignored by torch >= 0.4; no_grad is the
        # supported way to run inference without building a graph.
        with torch.no_grad():
            image_gpu = image.cuda(gpu0)
            pred = interp(model(image_gpu))
            pred01 = F.softmax(pred, dim=1)
        output = pred.cpu().data[0].numpy()

        for i_l in range(label.shape[0]):
            for ls in np.unique(label[i_l]).tolist():
                # Skip class 0 and the 255 ignore value.
                if ls == 0 or ls == 255:
                    continue
                ls = int(ls)

                class_prob = pred01[i_l][ls]

                # Input image weighted per pixel by the class probability,
                # rearranged from CHW to HWC for cv2.
                imgs = (class_prob.unsqueeze(0) * image_gpu[i_l]).permute(
                    1, 2, 0)
                imgs = imgs.data.cpu().numpy()

                img_ori = image_gpu[i_l].permute(1, 2, 0)
                img_ori = img_ori.data.cpu().numpy()

                # Colour-code the confidence bands; pixels at or below 0.7
                # stay black.  Channel order is whatever cv2.imwrite expects.
                prob = class_prob.data.cpu().numpy()
                color_image = np.zeros(prob.shape + (3, ), dtype=np.uint8)
                color_image[prob > 0.995] = (0, 255, 0)
                color_image[(prob > 0.9) & (prob <= 0.995)] = (255, 0, 0)
                color_image[(prob > 0.7) & (prob <= 0.9)] = (0, 0, 255)

                cv2.imwrite(
                    osp.join('/data1/wyc/AdvSemiSeg/vis/img_pred',
                             name[0] + '.png'), imgs)
                cv2.imwrite(
                    osp.join('/data1/wyc/AdvSemiSeg/vis/image',
                             name[0] + '.png'), img_ori)
                cv2.imwrite(
                    osp.join('/data1/wyc/AdvSemiSeg/vis/pred',
                             name[0] + '.png'), color_image)

        # `size` is still the original image size here: the visualisation
        # above no longer reuses that name for the prediction shape.
        output = output[:, :size[0], :size[1]]
        gt = np.asarray(label[0].numpy()[:size[0], :size[1]], dtype=int)

        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=int)

        filename = os.path.join(args.save_dir, '{}.png'.format(name[0]))
        color_file = Image.fromarray(
            colorize(output).transpose(1, 2, 0), 'RGB')
        color_file.save(filename)

        # show_all(gt, output)
        data_list.append([gt.flatten(), output.flatten()])

    filename = os.path.join(args.save_dir, 'result.txt')
    get_iou(data_list, args.num_classes, filename)
def main():
    """Create the model and start the evaluation process.

    Restores the checkpoint named by ``args.restore_from`` into the network
    selected by ``args.model``, predicts every batch yielded by
    ``GTA5DataSet`` and writes, for each image, a label map and a colourised
    label map into ``args.save`` (to which the checkpoint's parent directory
    name is appended first).

    Returns:
        The output directory the images were written to.

    Raises:
        ValueError: if ``args.model`` is not one of the supported types.
    """

    args = get_arguments()

    batchsize = args.batchsize

    # Results of different checkpoints go to different directories.
    model_name = os.path.basename(os.path.dirname(args.restore_from))
    args.save += model_name

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes,
                             train_bn=False,
                             norm_style='in')
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Fail here with a clear message instead of an unbound `model` later.
        raise ValueError('Unsupported model type: %s' % args.model)

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    try:
        model.load_state_dict(saved_state_dict)
    except (RuntimeError, KeyError):
        # Key mismatch: retry with the model wrapped in DataParallel, which
        # prefixes every parameter name with "module.".
        model = torch.nn.DataParallel(model)
        model.load_state_dict(saved_state_dict)
    model.eval()
    model.cuda()

    testloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                             args.data_list,
                                             crop_size=(640, 1280),
                                             resize_size=(1280, 640),
                                             mean=IMG_MEAN,
                                             scale=False,
                                             mirror=False),
                                 batch_size=batchsize,
                                 shuffle=False,
                                 pin_memory=True)

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(640, 1280),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(640, 1280), mode='bilinear')

    sm = torch.nn.Softmax(dim=1)
    for index, batch in enumerate(testloader):
        if (index * batchsize) % 100 == 0:
            print('%d processd' % (index * batchsize))
        image, _, _, name = batch
        print(image.shape)

        # Pure inference: do not record an autograd graph.
        with torch.no_grad():
            inputs = Variable(image).cuda()
            if args.model == 'DeeplabMulti':
                # Fuse the two classifier heads before the softmax.
                output1, output2 = model(inputs)
                output_batch = interp(sm(0.5 * output1 +
                                         output2)).cpu().data.numpy()
            elif args.model == 'DeeplabVGG' or args.model == 'Oracle':
                output_batch = model(inputs)
                output_batch = interp(output_batch).cpu().data.numpy()

        output_batch = output_batch.transpose(0, 2, 3, 1)
        output_batch = np.asarray(np.argmax(output_batch, axis=3),
                                  dtype=np.uint8)

        for i in range(output_batch.shape[0]):
            output = output_batch[i, :, :]
            output_col = colorize_mask(output)
            output = Image.fromarray(output)

            name_tmp = name[i].split('/')[-1]
            output.save('%s/%s' % (args.save, name_tmp))
            output_col.save('%s/%s_color.png' %
                            (args.save, name_tmp.split('.')[0]))

    return args.save
def main():
    """Train the segmentation network adversarially, optionally semi-supervised.

    Uses the module-level ``args`` and ``start`` (neither is defined inside
    this function).  Per iteration the segmentation network is updated with
    the segmentation loss plus an adversarial loss from the discriminator,
    and the discriminator is updated on predictions (label 0) and one-hot
    ground truth (label 1).  From iteration ``args.discr_split`` on, a second
    discriminator, initialised as a copy of the first, is trained on
    predictions for the unlabeled subset and used for the semi-supervised
    adversarial loss.

    Snapshots of both networks are written to ``args.snapshot_dir`` every
    ``args.save_pred_every`` iterations and at the final iteration; the
    elapsed wall-clock time is printed at the end.
    """

    def next_batch(iterator, loader):
        """Return ``(batch, iterator)``, restarting ``loader`` when exhausted."""
        try:
            _, batch = next(iterator)
        except StopIteration:
            iterator = enumerate(loader)
            _, batch = next(iterator)
        return batch, iterator

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True
    np.random.seed(args.random_seed)

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if name in saved_state_dict and param.size(
        ) == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)
    model.train()
    model.cuda(args.gpu)

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # load dataset
    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir,
                                    args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if args.partial_data is None:
        # NOTE(review): no unlabeled loader exists on this path, so the
        # semi-supervised branch below cannot run without --partial-data.
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=5,
                                      pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)
    else:
        #sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # Pickles are binary: text mode fails on Python 3.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = np.arange(train_dataset_size)
            np.random.shuffle(train_ids)

        with open(osp.join(args.snapshot_dir, 'train_id.pkl'),
                  'wb') as id_file:
            pickle.dump(train_ids, id_file)

        # labeled data
        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=3,
                                      pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         sampler=train_gt_sampler,
                                         num_workers=3,
                                         pin_memory=True)

        # unlabeled data
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=3,
                                             pin_memory=True)
        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)
    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss/bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    for i_iter in range(args.num_steps):
        loss_seg_value = 0
        loss_adv_pred_value = 0

        loss_D_value = 0
        loss_D_ul_value = 0

        # Never updated below; kept because the log line prints it.
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # creating 2nd discriminator as a copy of the 1st one
        if i_iter == args.discr_split:
            model_D_ul = FCDiscriminator(num_classes=args.num_classes)
            model_D_ul.load_state_dict(model_D.state_dict())
            model_D_ul.train()
            model_D_ul.cuda(args.gpu)

            optimizer_D_ul = optim.Adam(model_D_ul.parameters(),
                                        lr=args.learning_rate_D,
                                        betas=(0.9, 0.99))

        # start training 2nd discriminator after specified number of steps
        if i_iter >= args.discr_split:
            optimizer_D_ul.zero_grad()
            adjust_learning_rate_D(optimizer_D_ul, i_iter)

        for sub_i in range(args.iter_size):
            # train Segmentation

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # don't accumulate grads in D_ul, in case split has already been made
            if i_iter >= args.discr_split:
                for param in model_D_ul.parameters():
                    param.requires_grad = False

            # do semi-supervised training first
            if args.lambda_semi_adv > 0 and i_iter >= args.semi_start_adv:
                batch, trainloader_remain_iter = next_batch(
                    trainloader_remain_iter, trainloader_remain)

                # only access to img
                images, _, _, _ = batch
                images = Variable(images).cuda(args.gpu)

                pred = interp(model(images))
                pred_remain = pred.detach()

                # choose discriminator depending on the iteration
                if i_iter >= args.discr_split:
                    D_out = interp(model_D_ul(F.softmax(pred, dim=1)))
                else:
                    D_out = interp(model_D(F.softmax(pred, dim=1)))

                D_out_sigmoid = torch.sigmoid(
                    D_out).data.cpu().numpy().squeeze(axis=1)
                # Nothing is ignored for unlabeled images (all-False mask).
                ignore_mask_remain = np.zeros(D_out_sigmoid.shape).astype(
                    bool)

                # adversarial loss
                loss_semi_adv = args.lambda_semi_adv * bce_loss(
                    D_out, make_D_label(gt_label, ignore_mask_remain,
                                        args.gpu))
                loss_semi_adv = loss_semi_adv / args.iter_size

                # true loss value without multiplier
                loss_semi_adv_value += loss_semi_adv.data.cpu().numpy(
                ) / args.lambda_semi_adv
                loss_semi_adv.backward()

            # train with labeled images
            batch, trainloader_iter = next_batch(trainloader_iter,
                                                 trainloader)

            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            ignore_mask = (labels.numpy() == 255)

            pred = interp(model(images))
            D_out = interp(model_D(F.softmax(pred, dim=1)))

            # computing loss
            loss_seg = loss_calc(pred, labels, args.gpu)
            loss_adv_pred = bce_loss(
                D_out, make_D_label(gt_label, ignore_mask, args.gpu))
            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
            loss_adv_pred_value += loss_adv_pred.data.cpu().numpy(
            ) / args.iter_size

            # train D and D_ul

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            if i_iter >= args.discr_split:
                for param in model_D_ul.parameters():
                    param.requires_grad = True

            # train D with pred
            pred = pred.detach()

            # before split, traing D with both labeled and unlabeled
            # NOTE(review): pred_remain / ignore_mask_remain only exist once
            # the semi-supervised branch above has run at least once.
            if args.D_remain and i_iter < args.discr_split and (
                    args.lambda_semi > 0 or args.lambda_semi_adv > 0):
                pred = torch.cat((pred, pred_remain), 0)
                ignore_mask = np.concatenate((ignore_mask, ignore_mask_remain),
                                             axis=0)

            D_out = interp(model_D(F.softmax(pred, dim=1)))
            loss_D = bce_loss(D_out,
                              make_D_label(pred_label, ignore_mask, args.gpu))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

            # train D_ul with pred on unlabeled
            if i_iter >= args.discr_split and (args.lambda_semi > 0
                                               or args.lambda_semi_adv > 0):
                D_ul_out = interp(model_D_ul(F.softmax(pred_remain, dim=1)))
                loss_D_ul = bce_loss(
                    D_ul_out,
                    make_D_label(pred_label, ignore_mask_remain, args.gpu))
                loss_D_ul = loss_D_ul / args.iter_size / 2
                loss_D_ul.backward()
                loss_D_ul_value += loss_D_ul.data.cpu().numpy()

            # get gt labels
            batch, trainloader_gt_iter = next_batch(trainloader_gt_iter,
                                                    trainloader_gt)

            images_gt, labels_gt, _, _ = batch
            images_gt = Variable(images_gt).cuda(args.gpu)
            with torch.no_grad():
                pred_l = interp(model(images_gt))

            # train D with gt
            D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out,
                              make_D_label(gt_label, ignore_mask_gt, args.gpu))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

            # train D_ul with pseudo_gt (gt are substituted for pred)
            if i_iter >= args.discr_split:
                D_ul_out = interp(model_D_ul(F.softmax(pred_l, dim=1)))
                loss_D_ul = bce_loss(
                    D_ul_out, make_D_label(gt_label, ignore_mask_gt, args.gpu))
                loss_D_ul = loss_D_ul / args.iter_size / 2
                loss_D_ul.backward()
                loss_D_ul_value += loss_D_ul.data.cpu().numpy()

        optimizer.step()
        optimizer_D.step()
        if i_iter >= args.discr_split:
            optimizer_D_ul.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_D_ul={5:.3f}, loss_semi = {6:.3f}, loss_semi_adv = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_D_ul_value,
                    loss_semi_value, loss_semi_adv_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir, 'VOC_' + str(args.num_steps) + '_' +
                    str(args.random_seed) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(
                    args.snapshot_dir, 'VOC_' + str(args.num_steps) + '_' +
                    str(args.random_seed) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir, 'VOC_' + str(i_iter) + '_' +
                    str(args.random_seed) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(
                    args.snapshot_dir, 'VOC_' + str(i_iter) + '_' +
                    str(args.random_seed) + '_D.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
Beispiel #23
0
def evaluate(arch, dataset, ignore_label, restore_from, pretrained_model,
             save_dir, device):
    """Evaluate a segmentation network on a validation set and report IoU.

    Builds the network named by ``arch``, loads its weights, runs it over the
    validation split of ``dataset``, writes one colourised prediction PNG per
    sample into ``save_dir`` and finally prints the per-class and mean IoU,
    also writing them to ``<save_dir>/result.txt``.

    Args:
        arch: Network name: ``'deeplab2'``, ``'unet_resnet50'`` or
            ``'resnet101_deeplabv3'``.
        dataset: Dataset name; only ``'pascal_aug'`` is handled.
        ignore_label: Label value handed to the IoU evaluator as the value
            to ignore.
        restore_from: Checkpoint path, or a URL starting with ``http``.
            Overridden when ``pretrained_model`` is given.
        pretrained_model: Optional key into the table of released
            checkpoints below; ``None`` means "use ``restore_from``".
        save_dir: Output directory; created when missing.
        device: Anything accepted by ``torch.device``.

    Returns:
        None. An unsupported ``dataset`` or ``arch`` prints a message and
        returns early.

    Raises:
        KeyError: If ``pretrained_model`` is not a key of the checkpoint
            table.
    """
    import os

    import numpy as np
    import torch
    from torch.utils import data, model_zoo

    from model.deeplab import Res_Deeplab
    from model.unet import unet_resnet50
    from model.deeplabv3 import resnet101_deeplabv3
    from dataset.voc_dataset import VOCDataSet

    from PIL import Image

    pretrained_models_dict = {
        'semi0.125':
        'http://vllab1.ucmerced.edu/~whung/adv-semi-seg/AdvSemiSegVOC0.125-03c6f81c.pth',
        'semi0.25':
        'http://vllab1.ucmerced.edu/~whung/adv-semi-seg/AdvSemiSegVOC0.25-473f8a14.pth',
        'semi0.5':
        'http://vllab1.ucmerced.edu/~whung/adv-semi-seg/AdvSemiSegVOC0.5-acf6a654.pth',
        'advFull':
        'http://vllab1.ucmerced.edu/~whung/adv-semi-seg/AdvSegVOCFull-92fbc7ee.pth'
    }

    class VOCColorize(object):
        """Map a 2-D label image to a (3, H, W) uint8 colour image."""

        def __init__(self, n=22):
            # Plain NumPy table: the colours are only ever written into a
            # NumPy image, so no tensor conversion is needed.
            self.cmap = color_map(22)[:n]

        def __call__(self, gray_image):
            size = gray_image.shape
            color_image = np.zeros((3, size[0], size[1]), dtype=np.uint8)

            for label, colour in enumerate(self.cmap):
                mask = (label == gray_image)
                for channel in range(3):
                    color_image[channel][mask] = colour[channel]

            # Void pixels (label 255) are painted white.
            color_image[:, 255 == gray_image] = 255

            return color_image

    def color_map(N=256, normalized=False):
        """Return an (N, 3) colour table built by spreading label bits."""
        def bitget(byteval, idx):
            return ((byteval & (1 << idx)) != 0)

        dtype = 'float32' if normalized else 'uint8'
        cmap = np.zeros((N, 3), dtype=dtype)
        for i in range(N):
            r = g = b = 0
            c = i
            # Each pass consumes three bits of the label, one per channel,
            # filling the channel bytes from the most significant bit down.
            for j in range(8):
                shift = 7 - j
                r = r | (bitget(c, 0) << shift)
                g = g | (bitget(c, 1) << shift)
                b = b | (bitget(c, 2) << shift)
                c = c >> 3

            cmap[i] = np.array([r, g, b])

        cmap = cmap / 255 if normalized else cmap
        return cmap

    def get_iou(data_list,
                class_num,
                ignore_label,
                class_names,
                save_path=None):
        """Print per-class and mean IoU; optionally write them to a file."""
        from utils.evaluation import EvaluatorIoU

        evaluator = EvaluatorIoU(class_num)
        for truth, prediction in data_list:
            evaluator.sample(truth, prediction, ignore_value=ignore_label)

        per_class_iou = evaluator.score()
        mean_iou = per_class_iou.mean()

        class_lines = [
            'class {:2d} {:12} IU {:.2f}'.format(i, class_name, iou)
            for i, (class_name, iou) in enumerate(
                zip(class_names, per_class_iou))
        ]
        mean_line = 'meanIOU: ' + str(mean_iou) + '\n'

        for line in class_lines:
            print(line)
        print(mean_line)

        if save_path:
            with open(save_path, 'w') as f:
                for line in class_lines:
                    f.write(line + '\n')
                f.write(mean_line)

    def show_all(gt, pred):
        """Debug helper: show ground truth and prediction side by side."""
        import matplotlib.pyplot as plt
        from matplotlib import colors

        fig, axes = plt.subplots(1, 2)
        ax1, ax2 = axes

        colormap = [(0, 0, 0), (0.5, 0, 0), (0, 0.5, 0), (0.5, 0.5, 0),
                    (0, 0, 0.5), (0.5, 0, 0.5), (0, 0.5, 0.5), (0.5, 0.5, 0.5),
                    (0.25, 0, 0), (0.75, 0, 0), (0.25, 0.5, 0), (0.75, 0.5, 0),
                    (0.25, 0, 0.5), (0.75, 0, 0.5), (0.25, 0.5, 0.5),
                    (0.75, 0.5, 0.5), (0, 0.25, 0), (0.5, 0.25, 0),
                    (0, 0.75, 0), (0.5, 0.75, 0), (0, 0.25, 0.5)]
        cmap = colors.ListedColormap(colormap)
        bounds = list(range(22))
        norm = colors.BoundaryNorm(bounds, cmap.N)

        ax1.set_title('gt')
        ax1.imshow(gt, cmap=cmap, norm=norm)

        ax2.set_title('pred')
        ax2.imshow(pred, cmap=cmap, norm=norm)

        plt.show()

    torch_device = torch.device(device)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    if dataset == 'pascal_aug':
        ds = VOCDataSet()
    else:
        print('Dataset {} not yet supported'.format(dataset))
        return

    if arch == 'deeplab2':
        model = Res_Deeplab(num_classes=ds.num_classes)
    elif arch == 'unet_resnet50':
        model = unet_resnet50(num_classes=ds.num_classes)
    elif arch == 'resnet101_deeplabv3':
        model = resnet101_deeplabv3(num_classes=ds.num_classes)
    else:
        print('Architecture {} not supported'.format(arch))
        return

    if pretrained_model is not None:
        restore_from = pretrained_models_dict[pretrained_model]

    # Deserialise onto the CPU so a checkpoint saved from a GPU also loads
    # when `device` is a CPU; the model is moved to `torch_device` below.
    if restore_from.startswith('http'):
        saved_state_dict = model_zoo.load_url(restore_from,
                                              map_location='cpu')
    else:
        saved_state_dict = torch.load(restore_from, map_location='cpu')

    model.load_state_dict(saved_state_dict)

    model.eval()
    model = model.to(torch_device)

    ds_val_xy = ds.val_xy(crop_size=(505, 505),
                          scale=False,
                          mirror=False,
                          mean=model.MEAN,
                          std=model.STD)

    testloader = data.DataLoader(ds_val_xy,
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    data_list = []

    colorize = VOCColorize()

    with torch.no_grad():
        for index, batch in enumerate(testloader):
            if index % 100 == 0:
                print('%d processd' % (index))
            image, label, size, name = batch
            size = size[0].numpy()
            # `.to` converts in place of torch.tensor(tensor), which warns
            # about copy-constructing from an existing tensor.
            image = image.to(device=torch_device, dtype=torch.float)
            output = model(image)[0].cpu().numpy()

            # Crop prediction and label back to the original image extent.
            output = output[:, :size[0], :size[1]]
            # Builtin int: the np.int alias was removed in NumPy 1.24.
            gt = np.asarray(label[0].numpy()[:size[0], :size[1]], dtype=int)

            # Class axis is first, so argmax over axis 0 gives the label map.
            output = np.asarray(np.argmax(output, axis=0), dtype=int)

            filename = os.path.join(save_dir, '{}.png'.format(name[0]))
            color_file = Image.fromarray(
                colorize(output).transpose(1, 2, 0), 'RGB')
            color_file.save(filename)

            # show_all(gt, output)
            data_list.append([gt.flatten(), output.flatten()])

    filename = os.path.join(save_dir, 'result.txt')
    get_iou(data_list, ds.num_classes, ignore_label, ds.class_names,
            filename)
Beispiel #24
0
def main():
    """Evaluate a segmentation model gated by an image-level classifier.

    Loads a segmentation network and a classifier network, runs both over
    the VOC evaluation list and, for every class whose classifier sigmoid
    score is below ``args.sigmoid_threshold``, suppresses that class in the
    segmentation scores before taking the per-pixel argmax. The flattened
    (ground truth, prediction) pairs are handed to ``get_iou`` together with
    the path of a ``<checkpoint>_with_classifier_result.txt`` file inside
    ``args.save_dir``.

    No prediction images are written.
    """
    args = get_arguments()
    gpu0 = args.gpu

    print("Evaluating model")
    print(args.restore_from)
    print("classifier model")
    print(args.restore_from_classifier)
    print("sigmoid threshold")
    print(args.sigmoid_threshold)

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    model = Res_Deeplab(num_classes=args.num_classes)
    model_cls = Res_Deeplab_class(num_classes=args.num_classes,
                                  mode=6,
                                  latent_vars=args.latent_vars)

    model.load_state_dict(torch.load(args.restore_from), strict=False)
    model.eval()
    model.cuda(gpu0)

    model_cls.load_state_dict(torch.load(args.restore_from_classifier),
                              strict=False)
    model_cls.eval()
    model_cls.cuda(gpu0)

    testloader = data.DataLoader(VOCDataSet(args.data_dir,
                                            args.data_list,
                                            crop_size=(505, 505),
                                            mean=IMG_MEAN,
                                            scale=False,
                                            mirror=False),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    modern_torch = version.parse(torch.__version__) >= version.parse('0.4.0')
    if modern_torch:
        interp = nn.Upsample(size=(505, 505),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(505, 505), mode='bilinear')

    def forward_no_grad(net, image):
        """Run ``net`` on ``image`` without recording an autograd graph."""
        if modern_torch:
            # `volatile=True` is ignored from 0.4.0 on; no_grad replaces it.
            with torch.no_grad():
                return net(image.cuda(gpu0))
        return net(Variable(image, volatile=True).cuda(gpu0))

    data_list = []

    for index, batch in enumerate(testloader):
        if index % 100 == 0:
            print('%d processd' % (index))
        image, label, size, name = batch
        size = size[0].numpy()

        output = interp(forward_no_grad(model, image))
        output = output.cpu().data[0].numpy()
        # Crop back to the original image extent.
        output = output[:, :size[0], :size[1]]

        cls_pred = torch.sigmoid(forward_no_grad(model_cls, image))
        cls_pred = cls_pred.cpu().data.numpy()[0]

        # Builtin int: the np.int alias was removed in NumPy 1.24.
        gt = np.asarray(label[0].numpy()[:size[0], :size[1]], dtype=int)

        # Class 0 is never suppressed; classifier entry clsID - 1 gates
        # segmentation class clsID.
        for clsID in range(1, args.num_classes):
            if cls_pred[clsID - 1] < args.sigmoid_threshold:
                output[clsID, :, :] = -1000000000

        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=int)

        # show_all(gt, output)
        data_list.append([gt.flatten(), output.flatten()])

    # Name the result file after the checkpoint, whatever its extension.
    checkpoint_stem = os.path.splitext(
        os.path.basename(args.restore_from))[0]
    filename = os.path.join(args.save_dir,
                            checkpoint_stem + '_with_classifier_result.txt')
    get_iou(data_list, args.num_classes, filename)
Beispiel #25
0
def main():
    """Train the segmentation network adversarially against a discriminator.

    Reads its configuration from the module-level ``args``. Each iteration
    accumulates gradients over ``args.iter_size`` sub-batches:

    1. optionally a semi-supervised step on unlabeled images, using the
       discriminator's confidence map to mask pseudo labels,
    2. a supervised segmentation step plus an adversarial term,
    3. two discriminator steps (on predictions and on ground truth).

    Snapshots of both networks are written to ``args.snapshot_dir`` every
    ``args.save_pred_every`` iterations and at the final iteration. The
    elapsed time is printed relative to the module-level ``start``.
    """
    # Parse the "h,w" string into an integer (h, w) tuple.
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = False
    gpu = args.gpu

    def loss_value(loss):
        """Return a scalar loss as a float (0-d or 1-element tensor)."""
        # Indexing `.numpy()[0]` only works for 1-element tensors; losses
        # are 0-d from PyTorch 0.4 on, so flatten before indexing.
        return float(loss.data.cpu().numpy().reshape(-1)[0])

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # Only copy the params that exist in the current model with a matching
    # shape (caffe-like). state_dict() holds every parameter and persistent
    # buffer keyed by name; copy_ writes into the model's own tensors.
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if name in saved_state_dict and param.size() == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    # Switch to training mode.
    model.train()
    cudnn.benchmark = True

    model.cuda(gpu)

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir, args.data_list, crop_size=input_size,
                               scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)
    train_gt_dataset = VOCGTDataSet(args.data_dir, args.data_list, crop_size=input_size,
                                    scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size, shuffle=True, num_workers=5, pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size, shuffle=True, num_workers=5, pin_memory=True)
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # Pickles must be read in binary mode. NOTE: unpickling can run
            # arbitrary code, so only point this at trusted id files.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)

        # Record the split that was used next to the snapshots.
        with open(osp.join(args.snapshot_dir, 'train_id.pkl'), 'wb') as id_file:
            pickle.dump(train_ids, id_file)

        train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size, sampler=train_sampler, num_workers=3, pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size, sampler=train_remain_sampler, num_workers=3,
                                             pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size, sampler=train_gt_sampler, num_workers=3,
                                         pin_memory=True)

        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss / bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()
    # NOTE(review): input_size is (h, w) but the size passed here is (w, h);
    # harmless for square crops - confirm against the dataset's crop order.
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    for i_iter in range(args.num_steps):
        print("Iter:", i_iter)
        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # do semi first
            if args.lambda_semi > 0 and i_iter >= args.semi_start:
                # Restart the unlabeled loader only when it is exhausted;
                # any other error must propagate.
                try:
                    _, batch = next(trainloader_remain_iter)
                except StopIteration:
                    trainloader_remain_iter = enumerate(trainloader_remain)
                    _, batch = next(trainloader_remain_iter)

                # only access to img
                images, _, _, _ = batch

                images = Variable(images).cuda(gpu)

                pred = interp(model(images))
                # dim=1 is the class axis of the 4-D prediction (the value
                # the implicit-dim fallback chose as well).
                D_out = interp(model_D(F.softmax(pred, dim=1)))

                D_out_sigmoid = torch.sigmoid(D_out).data.cpu().numpy().squeeze(axis=1)

                # produce ignore mask: low-confidence pixels are ignored
                semi_ignore_mask = (D_out_sigmoid < args.mask_T)

                semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                semi_gt[semi_ignore_mask] = 255

                semi_ratio = 1.0 - float(semi_ignore_mask.sum()) / semi_ignore_mask.size
                print('semi ratio: {:.4f}'.format(semi_ratio))

                # With every pixel masked there is nothing to learn from.
                if semi_ratio != 0.0:
                    semi_gt = torch.FloatTensor(semi_gt)

                    loss_semi = args.lambda_semi * loss_calc(pred, semi_gt, args.gpu)
                    loss_semi = loss_semi / args.iter_size
                    loss_semi.backward()
                    loss_semi_value += loss_value(loss_semi) / args.lambda_semi

            # train with source

            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch

            images = Variable(images).cuda(gpu)

            ignore_mask = (labels.numpy() == 255)
            pred = interp(model(images))

            loss_seg = loss_calc(pred, labels, args.gpu)

            D_out = interp(model_D(F.softmax(pred, dim=1)))

            # The generator is rewarded when D labels its output as gt.
            loss_adv_pred = bce_loss(D_out, make_D_label(gt_label, ignore_mask))

            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += loss_value(loss_seg) / args.iter_size
            loss_adv_pred_value += loss_value(loss_adv_pred) / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred (detached so no gradient reaches the generator)
            pred = pred.detach()

            D_out = interp(model_D(F.softmax(pred, dim=1)))
            loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_value(loss_D)

            # train with gt
            # get gt labels
            try:
                _, batch = next(trainloader_gt_iter)
            except StopIteration:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = next(trainloader_gt_iter)

            _, labels_gt, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_value(loss_D)

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}'.format(
                i_iter, args.num_steps, loss_seg_value, loss_adv_pred_value, loss_D_value, loss_semi_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(model_D.state_dict(), osp.join(args.snapshot_dir, 'VOC_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
            torch.save(model_D.state_dict(), osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
def main():
    """Evaluate every saved GTA5 checkpoint on Cityscapes and save results.

    For each multiple of ``args.save_pred_every`` up to
    ``args.num_steps_stop`` the matching checkpoint is loaded from
    ``./snapshots/<run>/GTA5_<step>.pth`` and run over the test list. The
    label map and a colourised copy of every image are written to
    ``<args.save>/<run>/step<step>/``. ``<run>`` is ``source_only`` when
    the module-level ``SOURCE_ONLY`` flag is set, otherwise it follows
    ``args.level`` (``single_level`` / ``multi_level``).

    Raises:
        NotImplementedError: For an unknown ``args.model`` or, when
            ``SOURCE_ONLY`` is false, an unknown ``args.level``.
    """
    seed = 1338
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    args = get_arguments()

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Fail with a clear message instead of a NameError further down.
        raise NotImplementedError(
            'model choice {} is not implemented'.format(args.model))

    # One sub-directory name selects both the snapshots that are read and
    # the folder the predictions are written to.
    if SOURCE_ONLY:
        run_dir = 'source_only'
    elif args.level == 'single-level':
        run_dir = 'single_level'
    elif args.level == 'multi-level':
        run_dir = 'multi_level'
    else:
        raise NotImplementedError(
            'level choice {} is not implemented'.format(args.level))

    device = torch.device("cuda" if not args.cpu else "cpu")
    model = model.to(device)
    # Wrap once. Weights are always loaded into the bare `model`, whose
    # state-dict keys carry no `module.` prefix; re-wrapping per checkpoint
    # would make every key lookup miss from the second checkpoint on.
    net = nn.DataParallel(model) if args.multi_gpu else model
    net.eval()

    testloader = data.DataLoader(cityscapesDataSet(args.data_dir,
                                                   args.data_list,
                                                   crop_size=(1024, 512),
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    interp = nn.Upsample(size=(1024, 2048),
                         mode='bilinear',
                         align_corners=True)

    for files in range(int(args.num_steps_stop / args.save_pred_every)):
        step = (files + 1) * args.save_pred_every
        print('Step: ', step)

        # Deserialise onto the CPU so --cpu also works with checkpoints
        # saved from a GPU; load_state_dict copies onto the model's device.
        saved_state_dict = torch.load(
            './snapshots/' + run_dir + '/GTA5_' + str(step) + '.pth',
            map_location='cpu')

        ### for running different versions of pytorch
        # Start from the model's own state so entries the checkpoint lacks
        # keep their current values, and drop checkpoint-only keys.
        model_dict = model.state_dict()
        model_dict.update({
            k: v
            for k, v in saved_state_dict.items() if k in model_dict
        })
        ###
        model.load_state_dict(model_dict)

        out_dir = os.path.join(args.save, run_dir, 'step' + str(step))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            for index, batch in enumerate(testloader):
                if index % 100 == 0:
                    print('%d processd' % index)
                image, _, name = batch
                image = image.to(device)

                if args.model == 'DeeplabMulti':
                    # Two outputs are returned; the second one is evaluated.
                    output1, output2 = net(image)
                    output = interp(output2).cpu().data[0].numpy()
                elif args.model == 'DeeplabVGG' or args.model == 'Oracle':
                    output = net(image)
                    output = interp(output).cpu().data[0].numpy()

                output = output.transpose(1, 2, 0)
                output = np.asarray(np.argmax(output, axis=2),
                                    dtype=np.uint8)

                output_col = colorize_mask(output)
                output = Image.fromarray(output)

                name = name[0].split('/')[-1]
                output.save(os.path.join(out_dir, name))
                output_col.save(
                    os.path.join(out_dir,
                                 name.split('.')[0] + '_color.png'))
def main():
    """Evaluate a DataParallel-trained checkpoint on the VOC list.

    Loads the checkpoint, strips the ``module.`` prefix that
    ``nn.DataParallel`` adds to parameter names, copies every
    shape-compatible tensor into a fresh ``Res_Deeplab``, then writes a
    colourised prediction PNG per image into ``args.save_dir`` and passes
    the flattened (ground truth, prediction) pairs to ``get_iou`` together
    with the path ``<args.save_dir>/result.txt``.
    """
    args = get_arguments()

    gpu0 = args.gpu

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    model = Res_Deeplab(num_classes=args.num_classes)

    # NOTE(review): the checkpoint path is hard-coded and args.restore_from
    # is not consulted - confirm this is intended before reuse.
    state_dict = torch.load(
        '/data1/wyc/AdvSemiSeg/snapshots/VOC_t_baseline_1adv_mul_new_two_patch2_20000.pth'
    )

    # The file was saved from a DataParallel model. Build a new OrderedDict
    # without the `module.` prefix; keys that lack the prefix are kept
    # intact rather than losing their first seven characters.
    from collections import OrderedDict
    prefix = 'module.'
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items())

    # Copy only entries that exist in the model with a matching shape;
    # copy_ writes into the model's own tensors.
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if name in new_state_dict and param.size(
        ) == new_state_dict[name].size():
            new_params[name].copy_(new_state_dict[name])
            print('copy {}'.format(name))

    model.load_state_dict(new_params)

    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(VOCDataSet(args.data_dir,
                                            args.data_list,
                                            crop_size=(505, 505),
                                            mean=IMG_MEAN,
                                            scale=False,
                                            mirror=False),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    modern_torch = version.parse(torch.__version__) >= version.parse('0.4.0')
    if modern_torch:
        interp = nn.Upsample(size=(505, 505),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(505, 505), mode='bilinear')

    def forward_no_grad(image):
        """Run the model on ``image`` without recording an autograd graph."""
        if modern_torch:
            # `volatile=True` is ignored from 0.4.0 on; no_grad replaces it.
            with torch.no_grad():
                return model(image.cuda(gpu0))
        return model(Variable(image, volatile=True).cuda(gpu0))

    data_list = []

    colorize = VOCColorize()

    for index, batch in enumerate(testloader):
        if index % 100 == 0:
            print('%d processd' % (index))
        image, label, size, name = batch
        size = size[0].numpy()
        output = interp(forward_no_grad(image)).cpu().data[0].numpy()

        # Crop prediction and label back to the original image extent.
        output = output[:, :size[0], :size[1]]
        # Builtin int: the np.int alias was removed in NumPy 1.24.
        gt = np.asarray(label[0].numpy()[:size[0], :size[1]], dtype=int)

        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=int)

        filename = os.path.join(args.save_dir, '{}.png'.format(name[0]))
        color_file = Image.fromarray(
            colorize(output).transpose(1, 2, 0), 'RGB')
        color_file.save(filename)

        # show_all(gt, output)
        data_list.append([gt.flatten(), output.flatten()])

    filename = os.path.join(args.save_dir, 'result.txt')
    get_iou(data_list, args.num_classes, filename)
Beispiel #28
0
def main():
    """Train Res_Deeplab on liver CT slices using only the segmentation loss.

    The function relies on names defined outside of it: ``args`` (parsed
    options), ``iter_start`` (first iteration index), ``start`` (timer
    origin) and the project helpers ``Res_Deeplab``, ``loss_calc``,
    ``adjust_learning_rate`` and ``save_model``.

    Steps:
      1. Build the network and copy every pretrained tensor whose name and
         size match the current model.
      2. Build the training/validation ``LiverDataset`` objects.
      3. Run SGD for ``args.num_steps`` iterations, accumulating gradients
         over ``args.iter_size`` batches per step and printing the loss and
         a foreground Dice score for those batches.
      4. Save snapshots every ``args.save_pred_every`` iterations and the
         final weights at the last iteration.

    The adversarial / semi-supervised branches (discriminator, its optimizer
    and losses) and the TensorBoard reporting are disabled in this variant.
    """
    perfix_name = 'Liver'  # file-name prefix of every saved checkpoint

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters (remote URL or local file)
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in the current model with the same
    # size (caffe-like); everything else keeps its fresh initialisation
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if (name in saved_state_dict
                and param.size() == saved_state_dict[name].size()):
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Imported locally, as in the original script.
    from dataset.LiverDataset.liver_dataset import LiverDataset

    # Dataset configuration (hard-coded for this experiment).
    user_name = 'ld'
    validation_interval = 800
    batch_size = 5
    input_size = 400  # side length also used for the upsampled prediction
    max_slice_tries_val = 0
    max_slice_tries_train = 2

    # Keyword arguments shared by the training and validation datasets.
    common_dataset_kwargs = dict(slice_type='axial',
                                 n_neighboringslices=1,
                                 input_size=input_size,
                                 oversample=False,
                                 label_of_interest=1,
                                 label_required=0,
                                 fuse_labels=True,
                                 apply_crop=False,
                                 interval=validation_interval,
                                 batch_size=batch_size)

    train_data_dir = "/home/" + user_name + "/Documents/dataset/ISBI2017/media/nas/01_Datasets/CT/LITS/Training_Batch_2"
    test_data_dir = "/home/" + user_name + "/Documents/dataset/ISBI2017/media/nas/01_Datasets/CT/LITS/Training_Batch_1"
    train_dataset = LiverDataset(data_dir=train_data_dir,
                                 max_slice_tries=max_slice_tries_train,
                                 is_training=True,
                                 data_augmentation=True,
                                 **common_dataset_kwargs)
    # Built as before but not consumed below: the validation pass is
    # currently disabled.
    val_dataset = LiverDataset(data_dir=test_data_dir,
                               max_slice_tries=max_slice_tries_val,
                               is_training=False,
                               **common_dataset_kwargs)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # bilinear upsampling of the network output to the input resolution;
    # align_corners only exists from PyTorch 0.4.0 on
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size, input_size),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size, input_size), mode='bilinear')

    for i_iter in range(iter_start, args.num_steps):

        # Per-step accumulators (reset every optimizer step).
        loss_seg_value = 0
        num_prediction = 0
        num_ground_truth = 0
        num_intersection = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        for _ in range(args.iter_size):
            batch_image, batch_label = train_dataset.get_next_batch()
            # Move the last axis to position 1 and repeat it three times
            # along that axis.
            # NOTE(review): presumably converts single-channel NHWC slices
            # to the 3-channel NCHW input the network expects - confirm.
            batch_image = np.transpose(batch_image, axes=(0, 3, 1, 2))
            batch_image = np.concatenate(
                [batch_image, batch_image, batch_image], axis=1)
            images = Variable(torch.Tensor(batch_image)).cuda(args.gpu)

            pred = interp(model(images))
            # Per-pixel class index: argmax over the class axis (axis 1).
            pred_label_ny = np.squeeze(
                np.argmax(pred.data.cpu().numpy(), axis=1))

            # Dice bookkeeping: count foreground (label >= 1) pixels the
            # same way for prediction, ground truth and their overlap.
            pred_foreground = pred_label_ny >= 1
            gt_foreground = batch_label >= 1
            num_prediction += np.sum(pred_foreground)
            num_ground_truth += np.sum(gt_foreground)
            num_intersection += np.sum(
                np.logical_and(gt_foreground, pred_foreground))

            loss_seg = loss_calc(pred, batch_label, args.gpu)

            # proper normalization: gradients are accumulated over
            # iter_size batches before a single optimizer step
            loss = loss_seg / args.iter_size
            loss.backward()
            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size

        optimizer.step()

        # The small epsilon keeps the ratio defined when both counts are 0.
        dice = (2 * num_intersection + 1e-7) / (num_prediction +
                                                num_ground_truth + 1e-7)
        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}'.format(
            i_iter, args.num_steps, loss_seg_value))
        print(
            'dice: %.4f, num_prediction: %d, num_ground_truth: %d, num_intersection: %d'
            % (dice, num_prediction, num_ground_truth, num_intersection))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         perfix_name + str(args.num_steps) + '.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            save_model(model, args.snapshot_dir, perfix_name, i_iter, 2)

    # The summary writers are not created inside this function, so close
    # them only if the module defines them; calling close() unconditionally
    # raised NameError right after the final checkpoint was written.
    for writer_name in ('training_summary', 'val_summary'):
        writer = globals().get(writer_name)
        if writer is not None:
            writer.close()

    end = timeit.default_timer()
    print(end - start, 'seconds')
Beispiel #29
0
def main():
    """Train the segmentation network with a feature-matching discriminator.

    The function relies on names defined outside of it: ``args`` (parsed
    options), ``criterion`` (discriminator loss), ``start`` (timer origin)
    and the project helpers ``Res_Deeplab``, ``FCDiscriminator``,
    ``VOCDataSet``, ``VOCGTDataSet``, ``loss_calc``, ``one_hot``,
    ``adjust_learning_rate`` and ``adjust_learning_rate_D``.

    Steps:
      1. Build the segmentation network, copying every pretrained tensor
         whose name and size match, and build the discriminator.
      2. Build the data loaders. With ``args.partial_data`` set, only that
         fraction of the ids feeds the supervised loss while all ids feed
         the feature-matching loss; otherwise both use the full dataset.
      3. For ``args.num_steps`` iterations, accumulate gradients over
         ``args.iter_size`` batches: segmentation loss always, and from
         ``args.adv_start`` on also the feature-matching loss and the
         discriminator's real/fake loss.
      4. Save both networks every ``args.save_pred_every`` iterations and
         at the last iteration.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters (remote URL or local file)
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in the current model with the same
    # size (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if (name in saved_state_dict
                and param.size() == saved_state_dict[name].size()):
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir,
                                    args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=16,
                                      pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=16,
                                         pin_memory=True)

        # Every sample is labelled here, so the "all data" loader used by
        # the feature-matching loss is the supervised loader itself.
        # (It used to be undefined on this path -> NameError.)
        trainloader_all = trainloader
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # NOTE: pickle.load can execute arbitrary code; only point
            # --partial-id at files you trust.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = np.arange(train_dataset_size)
            np.random.shuffle(train_ids)

        # Persist the id order next to the snapshots.
        with open(osp.join(args.snapshot_dir, 'train_id.pkl'),
                  'wb') as id_file:
            pickle.dump(train_ids, id_file)

        train_sampler_all = data.sampler.SubsetRandomSampler(train_ids)
        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader_all = data.DataLoader(train_dataset,
                                          batch_size=args.batch_size,
                                          sampler=train_sampler_all,
                                          num_workers=16,
                                          pin_memory=True)
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=16,
                                      pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         sampler=train_gt_sampler,
                                         num_workers=16,
                                         pin_memory=True)

    trainloader_all_iter = iter(trainloader_all)
    trainloader_iter = iter(trainloader)
    trainloader_gt_iter = iter(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # bilinear upsampling of the network output
    # NOTE(review): size is passed as (input_size[1], input_size[0]), i.e.
    # (w, h), while crop_size above is (h, w); this only matters for
    # non-square inputs - confirm against VOCDataSet's convention.
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    for i_iter in range(args.num_steps):

        # Per-step accumulators used only for logging.
        loss_seg_value = 0
        loss_D_value = 0
        loss_fm_value = 0
        loss_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for _ in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # train with source; restart the loader when an epoch ends
            try:
                batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = iter(trainloader)
                batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            pred = interp(model(images))

            loss_seg = loss_calc(pred, labels, args.gpu)
            loss_seg.backward()
            loss_seg_value += loss_seg.data.cpu().numpy()[0] / args.iter_size

            if i_iter >= args.adv_start:

                # fm loss calc on a batch drawn from all training ids.
                # The restart must rebind trainloader_all_iter (it used to
                # rebind trainloader_iter, so the retry failed again).
                try:
                    batch = next(trainloader_all_iter)
                except StopIteration:
                    trainloader_all_iter = iter(trainloader_all)
                    batch = next(trainloader_all_iter)

                images, labels, _, _ = batch
                images = Variable(images).cuda(args.gpu)
                pred = interp(model(images))

                # model_D returns a pair; the second element is used for
                # feature matching, the first for the real/fake decision.
                _, D_out_y_pred = model_D(F.softmax(pred))

                # Ground-truth batch: keep one iterator alive instead of
                # spawning a new set of loader workers on every pass.
                try:
                    batch = next(trainloader_gt_iter)
                except StopIteration:
                    trainloader_gt_iter = iter(trainloader_gt)
                    batch = next(trainloader_gt_iter)

                _, labels_gt, _, _ = batch
                D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)

                _, D_out_y_gt = model_D(D_gt_v)

                # Mean absolute difference between the batch-averaged
                # discriminator features of ground truth and prediction.
                fm_loss = torch.mean(
                    torch.abs(
                        torch.mean(D_out_y_gt, 0) -
                        torch.mean(D_out_y_pred, 0)))

                # `loss` is only logged; the gradient comes from the two
                # separate backward() calls on loss_seg and fm_loss.
                # NOTE(review): args.lambda_fm therefore does not scale the
                # back-propagated fm gradient - confirm this is intended.
                loss = loss_seg + args.lambda_fm * fm_loss

                fm_loss.backward()
                loss_fm_value += fm_loss.data.cpu().numpy()[0] / args.iter_size
                loss_value += loss.data.cpu().numpy()[0] / args.iter_size

                # train D

                # bring back requires_grad
                for param in model_D.parameters():
                    param.requires_grad = True

                # train with pred (detached so G receives no gradient)
                pred = pred.detach()

                D_out_z, _ = model_D(F.softmax(pred))
                # Targets live on the same device as the networks.
                y_fake_ = Variable(
                    torch.zeros(D_out_z.size(0), 1).cuda(args.gpu))
                loss_D_fake = criterion(D_out_z, y_fake_)

                # train with gt: same ground-truth batch as above
                D_out_z_gt, _ = model_D(D_gt_v)

                y_real_ = Variable(
                    torch.ones(D_out_z_gt.size(0), 1).cuda(args.gpu))

                loss_D_real = criterion(D_out_z_gt, y_real_)
                loss_D = loss_D_fake + loss_D_real
                loss_D.backward()
                loss_D_value += loss_D.data.cpu().numpy()[0]

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_D = {3:.3f}'.
              format(i_iter, args.num_steps, loss_seg_value, loss_D_value))
        print('fm_loss: ', loss_fm_value, ' g_loss: ', loss_value)

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
Beispiel #30
0
                             pretrained=True,
                             _print=False)
    elif config.MODEL_SELECTION == 'pretrained_deeplab_multi':
        net = DeepLabv3_plus_multi(nInputChannels=config.NUM_CHANNELS,
                                   n_classes=config.NUM_CLASSES,
                                   os=16,
                                   pretrained=True,
                                   _print=False)
    elif config.MODEL_SELECTION == 'pretrained_deeplab_multi_depth':
        net = DeepLabv3_plus_multi_depth(nInputChannels=config.NUM_CHANNELS,
                                         n_classes=config.NUM_CLASSES,
                                         os=16,
                                         pretrained=True,
                                         _print=False)
    elif config.MODEL_SELECTION == 'og_deeplab':
        net = Res_Deeplab(num_classes=config.NUM_CLASSES)
    else:
        logging.info("*** No Model Selected! ***")
        exit(0)

    upsample = nn.Upsample(size=(config.CROP_H, config.CROP_W),
                           mode='bilinear',
                           align_corners=True)

    #######################
    #######################

    logging.info(f'Loading model {args.model}')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')
Beispiel #31
0
def _loss_to_float(loss):
    """Return the value of a single-element loss tensor as a Python float.

    ``loss.data.cpu().numpy()[0]`` only works for 1-element tensors; on
    0-dim tensors it raises ``IndexError``.  Flattening first handles both.
    """
    return float(loss.data.cpu().numpy().reshape(-1)[0])


def main():
    """Train the segmentation network adversarially against a discriminator.

    All settings are read from the module-level ``args``.  The function

    1. builds ``Res_Deeplab`` and copies every pretrained parameter whose name
       and size match the current model,
    2. builds ``FCDiscriminator`` (optionally restored from
       ``args.restore_from_D``),
    3. creates the VOC data loaders; when ``args.partial_data`` is set only the
       first ``partial_size`` shuffled ids are used as labelled data and the
       rest feed the semi-supervised branch,
    4. runs ``args.num_steps`` iterations, each accumulating gradients over
       ``args.iter_size`` sub-iterations for the generator (segmentation,
       adversarial and optional semi-supervised losses) and the discriminator,
    5. writes snapshots every ``args.save_pred_every`` iterations and at the end.

    The semi-supervised branch needs ``trainloader_remain``, which is only
    created when ``args.partial_data`` is not ``None``.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in the current model with a matching
    # size (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print (name)
        if name in saved_state_dict and param.size() == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))

    model.load_state_dict(new_params)

    model.train()

    model = nn.DataParallel(model)
    model.cuda()

    cudnn.benchmark = True

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))

    model_D = nn.DataParallel(model_D)
    model_D.train()
    model_D.cuda()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir, args.data_list, crop_size=input_size,
                    scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir, args.data_list, crop_size=input_size,
                       scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                        batch_size=args.batch_size, shuffle=True, num_workers=5, pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                        batch_size=args.batch_size, shuffle=True, num_workers=5, pin_memory=True)
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # Pickle files must be opened in binary mode.
            # NOTE(review): unpickling runs arbitrary code - only load trusted id files.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)

        # keep the id order next to the snapshots so the split can be reproduced
        with open(osp.join(args.snapshot_dir, 'train_id.pkl'), 'wb') as id_file:
            pickle.dump(train_ids, id_file)

        train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                        batch_size=args.batch_size, sampler=train_sampler, num_workers=3, pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset,
                        batch_size=args.batch_size, sampler=train_remain_sampler, num_workers=3, pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                        batch_size=args.batch_size, sampler=train_gt_sampler, num_workers=3, pin_memory=True)

        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.module.optim_parameters(args),
                lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss / bilinear upsampling (align_corners only exists from PyTorch 0.4.0)
    bce_loss = BCEWithLogitsLoss2d()
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear', align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    # File stem of this script, embedded in every snapshot name.  basename()
    # instead of split('/') so the stem is also correct with Windows paths.
    script_tag = osp.basename(osp.abspath(__file__)).split('.')[0]

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # do semi first
            if (args.lambda_semi > 0 or args.lambda_semi_adv > 0) and i_iter >= args.semi_start_adv:
                try:
                    _, batch = next(trainloader_remain_iter)
                except StopIteration:
                    # loader exhausted: start a new pass over the unlabelled data
                    trainloader_remain_iter = enumerate(trainloader_remain)
                    _, batch = next(trainloader_remain_iter)

                # only access to img
                images, _, _, _ = batch
                images = Variable(images).cuda()

                pred = interp(model(images))
                pred_remain = pred.detach()

                mask1 = F.softmax(pred, dim=1).data.cpu().numpy()

                D_out = interp(model_D(F.softmax(pred, dim=1)))

                # nothing is ignored for the unlabelled images: an all-False
                # mask with one entry per (sample, row, column) of D_out
                ignore_mask_remain = np.zeros(
                    (D_out.size(0), D_out.size(2), D_out.size(3)), dtype=bool)

                # Keep the unweighted BCE so the logged value does not have to
                # be divided by lambda_semi_adv (which may be 0).
                semi_adv_bce = bce_loss(D_out, make_D_label(gt_label, ignore_mask_remain))
                loss_semi_adv = args.lambda_semi_adv * semi_adv_bce / args.iter_size
                loss_semi_adv_value += _loss_to_float(semi_adv_bce) / args.iter_size

                if args.lambda_semi <= 0 or i_iter < args.semi_start:
                    loss_semi_adv.backward()
                    loss_semi_value = 0
                else:
                    # Per-pixel confidence = highest softmax probability over
                    # the class axis.  Pixels below 0.999999 are ignored (255)
                    # in the pseudo ground truth.  float64 keeps the threshold
                    # comparison in double precision.
                    confidence = mask1.max(axis=1).astype(np.float64)
                    semi_ignore_mask = (confidence < 0.999999)
                    semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                    semi_gt[semi_ignore_mask] = 255

                    semi_ratio = 1.0 - float(semi_ignore_mask.sum()) / semi_ignore_mask.size
                    print('semi ratio: {:.4f}'.format(semi_ratio))

                    if semi_ratio == 0.0:
                        # NOTE(review): no confident pixel - loss_semi_adv is
                        # not backpropagated in this case either; confirm
                        # that this is intended.
                        loss_semi_value += 0
                    else:
                        semi_gt = torch.FloatTensor(semi_gt)

                        semi_ce = loss_calc(pred, semi_gt)
                        loss_semi = args.lambda_semi * semi_ce / args.iter_size
                        loss_semi_value += _loss_to_float(semi_ce) / args.iter_size
                        loss_semi += loss_semi_adv
                        loss_semi.backward()

            else:
                loss_semi = None
                loss_semi_adv = None
                # no unlabelled prediction this round; checked before D training
                pred_remain = None
                ignore_mask_remain = None

            # train with source

            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda()
            ignore_mask = (labels.numpy() == 255)
            pred = interp(model(images))

            loss_seg = loss_calc(pred, labels)

            D_out = interp(model_D(F.softmax(pred, dim=1)))

            # generator tries to make D label its predictions as ground truth
            loss_adv_pred = bce_loss(D_out, make_D_label(gt_label, ignore_mask))

            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += _loss_to_float(loss_seg) / args.iter_size
            loss_adv_pred_value += _loss_to_float(loss_adv_pred) / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred
            pred = pred.detach()

            # Only append the unlabelled predictions when the semi branch
            # actually produced them in this sub-iteration.
            if args.D_remain and pred_remain is not None:
                pred = torch.cat((pred, pred_remain), 0)
                ignore_mask = np.concatenate((ignore_mask, ignore_mask_remain), axis=0)

            D_out = interp(model_D(F.softmax(pred, dim=1)))
            loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += _loss_to_float(loss_D)

            # train with gt
            # get gt labels
            try:
                _, batch = next(trainloader_gt_iter)
            except StopIteration:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = next(trainloader_gt_iter)

            _, labels_gt, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cuda()
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += _loss_to_float(loss_D)

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'.format(i_iter, args.num_steps, loss_seg_value, loss_adv_pred_value, loss_D_value, loss_semi_value, loss_semi_adv_value))

        if i_iter >= args.num_steps-1:
            print( 'save model ...')
            final_stem = 'VOC_' + script_tag + '_' + str(args.num_steps)
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, final_stem + '.pth'))
            torch.save(model_D.state_dict(), osp.join(args.snapshot_dir, final_stem + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print ('taking snapshot ...')
            snapshot_stem = 'VOC_' + script_tag + '_' + str(i_iter)
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, snapshot_stem + '.pth'))
            torch.save(model_D.state_dict(), osp.join(args.snapshot_dir, snapshot_stem + '_D.pth'))

    end = timeit.default_timer()
    print(end-start,'seconds')
Beispiel #32
0
def main():
    """Create the model and start the training.

    All settings come from the module-level ``args``; the module-level
    ``STAGE`` and ``CHECKPOINT`` select the training mode and the checkpoint
    to resume from.

    * ``STAGE == 1``: the segmentation network is trained on source batches
      (alternating between the translated and the stylized loader) plus an
      output-level adversarial loss on target images, and the discriminator
      is trained to tell source from target predictions.
    * otherwise the target loader also yields labels
      (``cityscapesDataSetLabel``); for ``STAGE == 2`` a segmentation loss on
      those target labels is added and no discriminator update is done.

    Every ``args.save_pred_every`` iterations a checkpoint and the model
    weights are written, the snapshot is evaluated, and the resulting mIoU is
    appended to ``args.output_file``.

    Raises:
        ValueError: if ``args.model`` is neither ``'ResNet'`` nor ``'VGG'``.
    """
    w, h = map(int, args.input_size_source.split(','))
    input_size_source = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'ResNet':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            # e.g. Scale.layer5.conv2d_list.3.weight -> drop the first component
            key_parts = key.split('.')
            # with 19 classes the pretrained 'layer5' weights are not copied
            if args.num_classes != 19 or key_parts[1] != 'layer5':
                new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
        model.load_state_dict(new_params)

    elif args.model == 'VGG':
        model = DeeplabVGG(num_classes=args.num_classes,
                           pretrained=True,
                           vgg16_caffe_path=args.restore_from)

    else:
        # fail early with a clear message instead of a NameError on `model`
        raise ValueError('unsupported model: {!r}'.format(args.model))

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    model.train()
    model.cuda(args.gpu)
    cudnn.benchmark = True

    # Discriminator setting
    model_D = FCDiscriminator(num_classes=args.num_classes)
    model_D.train()
    model_D.cuda(args.gpu)

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    # labels for adversarial training
    source_adv_label = 0
    target_adv_label = 1

    def _domain_target(d_out, domain_label):
        """Return a float tensor shaped like `d_out`, filled with `domain_label`, on args.gpu."""
        return Variable(
            torch.FloatTensor(d_out.data.size()).fill_(domain_label)).cuda(
                args.gpu)

    # Snapshot directory
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Dataloaders; every dataset is given the same total sample budget
    max_iters = args.num_steps * args.iter_size * args.batch_size

    trainloader = data.DataLoader(GTA5DataSet(args.translated_data_dir,
                                              args.data_list,
                                              max_iters=max_iters,
                                              crop_size=input_size_source,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    style_trainloader = data.DataLoader(GTA5DataSet(
        args.stylized_data_dir,
        args.data_list,
        max_iters=max_iters,
        crop_size=input_size_source,
        scale=args.random_scale,
        mirror=args.random_mirror,
        mean=IMG_MEAN),
                                        batch_size=args.batch_size,
                                        shuffle=True,
                                        num_workers=args.num_workers,
                                        pin_memory=True)

    style_trainloader_iter = enumerate(style_trainloader)

    if STAGE == 1:
        targetloader = data.DataLoader(cityscapesDataSet(
            args.data_dir_target,
            args.data_list_target,
            max_iters=max_iters,
            crop_size=input_size_target,
            mean=IMG_MEAN,
            set=args.set),
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=args.num_workers,
                                       pin_memory=True)

        targetloader_iter = enumerate(targetloader)

    else:
        # Dataloader for self-training.
        # NOTE(review): label_folder is a placeholder string and presumably
        # has to be replaced by the real pseudo-label directory - confirm.
        targetloader = data.DataLoader(cityscapesDataSetLabel(
            args.data_dir_target,
            args.data_list_target,
            max_iters=max_iters,
            crop_size=input_size_target,
            mean=IMG_MEAN,
            set=args.set,
            label_folder='Path to generated pseudo labels'),
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=args.num_workers,
                                       pin_memory=True)

        targetloader_iter = enumerate(targetloader)

    interp = nn.Upsample(size=(input_size_source[1], input_size_source[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # load checkpoint; osp.join so the path matches where checkpoints are
    # written below even when snapshot_dir has no trailing separator
    model, model_D, optimizer, start_iter = load_checkpoint(
        model,
        model_D,
        optimizer,
        filename=osp.join(args.snapshot_dir,
                          'checkpoint_' + CHECKPOINT + '.pth.tar'))

    for i_iter in range(start_iter, args.num_steps):
        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # train segmentation network
        # don't accumulate grads in D
        for param in model_D.parameters():
            param.requires_grad = False

        # train with source
        if STAGE == 1:
            # alternate between translated and stylized source images
            if i_iter % 2 == 0:
                _, batch = next(trainloader_iter)
            else:
                _, batch = next(style_trainloader_iter)

        else:
            _, batch = next(trainloader_iter)

        image_source, label, _, _ = batch
        image_source = Variable(image_source).cuda(args.gpu)

        pred_source = model(image_source)
        pred_source = interp(pred_source)

        loss_seg_source = loss_calc(pred_source, label, args.gpu)
        loss_seg_source_value = loss_seg_source.item()
        loss_seg_source.backward()

        if STAGE == 2:
            # train with target
            _, batch = next(targetloader_iter)
            image_target, target_label, _, _ = batch
            image_target = Variable(image_target).cuda(args.gpu)

            pred_target = model(image_target)
            pred_target = interp_target(pred_target)

            # target segmentation loss
            loss_seg_target = loss_calc(pred_target,
                                        target_label,
                                        gpu=args.gpu)
            loss_seg_target.backward()

        if STAGE == 1:
            # train with target
            _, batch = next(targetloader_iter)
            image_target, _, _ = batch
            image_target = Variable(image_target).cuda(args.gpu)

            pred_target = model(image_target)
            pred_target = interp_target(pred_target)

            # output-level adversarial training: push D to label target
            # predictions as source (dim=1 is the class axis of N x C x H x W)
            D_output_target = model_D(F.softmax(pred_target, dim=1))
            loss_adv = bce_loss(
                D_output_target,
                _domain_target(D_output_target, source_adv_label))
            loss_adv = loss_adv * args.lambda_adv
            loss_adv.backward()

        # optimize.  This must come after loss_adv.backward(): stepping before
        # it would leave the adversarial gradients unused, because the next
        # iteration starts with optimizer.zero_grad().
        optimizer.step()

        if STAGE == 1:
            # train discriminator
            for param in model_D.parameters():
                param.requires_grad = True

            pred_source = pred_source.detach()
            pred_target = pred_target.detach()

            D_output_source = model_D(F.softmax(pred_source, dim=1))
            D_output_target = model_D(F.softmax(pred_target, dim=1))

            loss_D_source = bce_loss(
                D_output_source,
                _domain_target(D_output_source, source_adv_label))
            loss_D_target = bce_loss(
                D_output_target,
                _domain_target(D_output_target, target_adv_label))

            loss_D_source = loss_D_source / 2
            loss_D_target = loss_D_target / 2

            loss_D_source.backward()
            loss_D_target.backward()

            # optimize
            optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg_source = {2:.5f}'.format(
            i_iter, args.num_steps, loss_seg_source_value))

        if i_iter % args.save_pred_every == 0:
            print('taking snapshot ...')
            state = {
                'iter': i_iter,
                'model': model.state_dict(),
                'model_D': model_D.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            model_snapshot = osp.join(args.snapshot_dir,
                                      'GTA5_' + str(i_iter) + '.pth')
            torch.save(
                state,
                osp.join(args.snapshot_dir,
                         'checkpoint_' + str(i_iter) + '.pth.tar'))
            torch.save(model.state_dict(), model_snapshot)
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_D_' + str(i_iter) + '.pth'))

            cityscapes_eval_dir = osp.join(args.cityscapes_eval_dir,
                                           str(i_iter))
            if not os.path.exists(cityscapes_eval_dir):
                os.makedirs(cityscapes_eval_dir)

            # NOTE(review): `eval` is presumably the project's evaluation
            # routine shadowing the builtin (it is called with a checkpoint
            # path, an output directory and the iteration) - confirm.
            eval(model_snapshot, cityscapes_eval_dir, i_iter)

            iou19, iou13, iou = compute_mIoU(cityscapes_eval_dir, i_iter)
            # append one tab-separated result line per evaluated snapshot
            with open(args.output_file, 'a') as outputfile:
                outputfile.write(
                    str(i_iter) + '\t' + str(iou19) + '\t' +
                    str(iou.replace('\n', ' ')) + '\n')