Exemplo n.º 1
0
def main(args):
    """Create the model and start the evaluation process.

    Builds the network selected by ``args.model``, restores its weights from
    ``args.restore_from`` (an ``http`` URL or a local checkpoint path), runs
    it over every sample of the test loader and writes two images per sample
    into ``args.save``: the raw label map and a ``*_color.png`` rendering
    produced by ``colorize_mask``.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``gpu``, ``save``, ``model``, ``num_classes``, ``restore_from``,
            ``data_dir``, ``img_list``, ``lbl_list`` and ``set``.

    Raises:
        ValueError: If ``args.model`` is neither ``'DeeplabMulti'`` nor
            ``'DeeplabVGG'``.
    """
    # Fail fast on an unknown architecture. Without this check ``model`` is
    # never bound and the run dies later with an UnboundLocalError, after the
    # checkpoint has already been read.
    if args.model not in ('DeeplabMulti', 'DeeplabVGG'):
        raise ValueError("Unsupported model %r; expected 'DeeplabMulti' or "
                         "'DeeplabVGG'" % (args.model,))

    gpu0 = args.gpu

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    else:
        model = DeeplabVGG(num_classes=args.num_classes)
        # Only swap in the VGG checkpoint when the default path was left
        # untouched by the caller.
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG

    # Deserialize onto the CPU so loading does not depend on the device the
    # checkpoint was saved from; the model is moved to ``gpu0`` afterwards.
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from,
                                              map_location='cpu')
    else:
        saved_state_dict = torch.load(args.restore_from, map_location='cpu')
    model.load_state_dict(saved_state_dict)

    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(ListDataSet(args.data_dir,
                                             args.img_list,
                                             args.lbl_list,
                                             crop_size=(1024, 512),
                                             mean=IMG_MEAN,
                                             split=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    # Every prediction is resized to a fixed 1024 x 2048 (H x W) map.
    interp = nn.Upsample(size=(1024, 2048),
                         mode='bilinear',
                         align_corners=True)

    with torch.no_grad():
        for index, batch in enumerate(testloader):
            if index % 100 == 0:
                print('%d processd' % index)
            image, _, _, name = batch
            if args.model == 'DeeplabMulti':
                # The multi-output network returns two maps; the second one
                # is the one that gets saved.
                output1, output2 = model(Variable(image).cuda(gpu0))
                output = interp(output2).cpu().data[0].numpy()
            else:
                output = model(Variable(image).cuda(gpu0))
                output = interp(output).cpu().data[0].numpy()

            # (C, H, W) -> (H, W, C), then take the best-scoring class per
            # pixel.
            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

            output_col = colorize_mask(output)
            output = Image.fromarray(output)

            # Keep only the file name so results land flat in args.save.
            name = name[0].split('/')[-1]
            output.save('%s/%s' % (args.save, name))
            output_col.save('%s/%s_color.png' %
                            (args.save, name.split('.')[0]))
Exemplo n.º 2
0
def main():
    """Create the model and start the evaluation process.

    Reads the command-line options via ``get_arguments()``, builds the
    network named by ``args.model``, restores its weights, and writes for
    every test sample a raw label map plus a ``*_color.png`` rendering into
    ``args.save``.

    ``args.input_size`` is parsed as ``"width,height"`` and drives both the
    dataset crop size and the size predictions are upsampled to.

    Raises:
        ValueError: If ``args.model`` is not one of ``'DeeplabMulti'``,
            ``'Oracle'`` or ``'DeeplabVGG'``.
    """
    args = get_arguments()

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Previously ``model`` stayed unbound and the failure surfaced later
        # as an UnboundLocalError.
        raise ValueError('Unsupported model: %r' % (args.model,))

    # Deserialize onto the CPU: a checkpoint saved from a GPU would
    # otherwise fail to load when running with --cpu on a CUDA-less host.
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from,
                                              map_location='cpu')
    else:
        saved_state_dict = torch.load(args.restore_from, map_location='cpu')
    model.load_state_dict(saved_state_dict)

    device = torch.device("cuda" if not args.cpu else "cpu")
    model = model.to(device)

    model.eval()

    testloader = data.DataLoader(cityscapesDataSet(args.data_dir,
                                                   args.data_list,
                                                   crop_size=input_size,
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    # nn.Upsample expects (H, W) while input_size is stored as (W, H).
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)

    # Pure inference: disable autograd so no graph is kept for the
    # full-resolution upsampled outputs.
    with torch.no_grad():
        for index, batch in enumerate(testloader):
            if index % 100 == 0:
                print('%d processd' % index)
            image, _, name = batch
            image = image.to(device)

            if args.model == 'DeeplabMulti':
                # Four values come back; only the second one is saved.
                output1, output2, _, _ = model(image)
                output = interp(output2).cpu().data[0].numpy()
            else:  # 'DeeplabVGG' or 'Oracle'
                output = model(image)
                output = interp(output).cpu().data[0].numpy()

            # (C, H, W) -> (H, W, C), then the best-scoring class per pixel.
            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

            output_col = colorize_mask(output)
            output = Image.fromarray(output)

            name = name[0].split('/')[-1]
            output.save('%s/%s' % (args.save, name))
            output_col.save('%s/%s_color.png' %
                            (args.save, name.split('.')[0]))
Exemplo n.º 3
0
def main():
    """Create the model and start the evaluation process.

    Reads the command-line options via ``get_arguments()``, builds the
    network named by ``args.model``, restores the checkpoint entries that
    match the model's own parameters, and writes for every test sample a raw
    label map plus a ``*_color.png`` rendering into ``args.save``.

    Raises:
        ValueError: If ``args.model`` is not one of ``'DeeplabMulti'``,
            ``'Oracle'`` or ``'DeeplabVGG'``.
    """

    args = get_arguments()

    gpu0 = args.gpu

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Previously ``model`` stayed unbound and the failure surfaced later
        # as an UnboundLocalError.
        raise ValueError('Unsupported model: %r' % (args.model,))

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    ### for running different versions of pytorch
    # Drop checkpoint entries the model does not know about, then overlay the
    # rest on the model's own state so that entries missing from the
    # checkpoint keep their current values.
    model_dict = model.state_dict()
    saved_state_dict = {
        k: v
        for k, v in saved_state_dict.items() if k in model_dict
    }
    model_dict.update(saved_state_dict)
    ###
    # Load the merged dict. Loading the filtered checkpoint instead (as the
    # code used to) still fails whenever the checkpoint lacks a key.
    model.load_state_dict(model_dict)

    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(cityscapesDataSet(args.data_dir,
                                                   args.data_list,
                                                   crop_size=(1024, 512),
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    # ``align_corners`` is only passed on PyTorch >= 0.4.0.
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(1024, 2048),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(1024, 2048), mode='bilinear')

    # NOTE(review): ``volatile=True`` is the pre-0.4 way of disabling
    # autograd; newer PyTorch releases reportedly ignore it, in which case
    # this loop should run under ``torch.no_grad()`` - confirm against the
    # targeted PyTorch version.
    for index, batch in enumerate(testloader):
        if index % 100 == 0:
            # Parenthesised form works on both Python 2 and Python 3.
            print('%d processd' % index)
        image, _, name = batch
        if args.model == 'DeeplabMulti':
            # Two maps come back; only the second one is saved.
            output1, output2 = model(Variable(image, volatile=True).cuda(gpu0))
            output = interp(output2).cpu().data[0].numpy()
        else:  # 'DeeplabVGG' or 'Oracle'
            output = model(Variable(image, volatile=True).cuda(gpu0))
            output = interp(output).cpu().data[0].numpy()

        # (C, H, W) -> (H, W, C), then the best-scoring class per pixel.
        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

        output_col = colorize_mask(output)
        output = Image.fromarray(output)

        name = name[0].split('/')[-1]
        output.save('%s/%s' % (args.save, name))
        output_col.save('%s/%s_color.png' % (args.save, name.split('.')[0]))
def main():
    """Run multi-scale, flip-augmented evaluation and save the results.

    The model type and normalisation options are read from the ``opts.yaml``
    file that sits next to the checkpoint ``args.restore_from``. For the
    ``'DeepLab'`` model, predictions from the original and the 1.25x inputs,
    each also horizontally flipped, are summed. For every image the label
    map, a normalised score map and (``'DeepLab'`` only) a heat map derived
    from the KL divergence between the two network outputs are handed to the
    ``save``, ``save_scoremap`` and ``save_heatmap`` workers.

    Returns:
        str: ``args.save`` after the checkpoint's parent directory name has
        been appended to it.

    Raises:
        ValueError: If the model type found in ``opts.yaml`` is not one of
            ``'DeepLab'``, ``'Oracle'`` or ``'DeeplabVGG'``.
    """
    args = get_arguments()

    config_path = os.path.join(os.path.dirname(args.restore_from), 'opts.yaml')
    with open(config_path, 'r') as stream:
        # yaml.load() without a Loader is rejected by recent PyYAML releases,
        # so name the loader that older releases picked implicitly.
        # NOTE(review): these loaders can construct arbitrary Python objects.
        # If opts.yaml only holds plain scalars/lists/dicts, prefer
        # yaml.safe_load().
        config = yaml.load(stream,
                           Loader=getattr(yaml, 'FullLoader', yaml.Loader))

    args.model = config['model']
    print('ModelType:%s' % args.model)
    print('NormType:%s' % config['norm_style'])
    gpu0 = args.gpu
    batchsize = args.batchsize

    # Results go to "<args.save><checkpoint directory name>"; no separator is
    # inserted, so args.save is expected to carry its own trailing one.
    model_name = os.path.basename(os.path.dirname(args.restore_from))
    args.save += model_name

    os.makedirs(args.save, exist_ok=True)

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes,
                             use_se=config['use_se'],
                             train_bn=False,
                             norm_style=config['norm_style'])
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Previously ``model`` stayed unbound and the failure surfaced later
        # as an UnboundLocalError.
        raise ValueError('Unsupported model: %r' % (args.model,))

    # Deserialize onto the CPU so loading does not depend on the device the
    # checkpoint was saved from; the model is moved to the GPU afterwards.
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from,
                                              map_location='cpu')
    else:
        saved_state_dict = torch.load(args.restore_from, map_location='cpu')

    try:
        model.load_state_dict(saved_state_dict)
    except (RuntimeError, KeyError):
        # Key mismatch: retry with the model wrapped in DataParallel
        # (presumably the checkpoint was saved from a wrapped model).
        model = torch.nn.DataParallel(model)
        model.load_state_dict(saved_state_dict)
    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(cityscapesDataSet(args.data_dir,
                                                   args.data_list,
                                                   crop_size=(512, 1024),
                                                   resize_size=(1024, 512),
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set),
                                 batch_size=batchsize,
                                 shuffle=False,
                                 pin_memory=True,
                                 num_workers=4)

    # Second pass over the same list at 1.25x resolution. Both loaders are
    # unshuffled, so zipping them pairs up the same images.
    scale = 1.25
    testloader2 = data.DataLoader(cityscapesDataSet(
        args.data_dir,
        args.data_list,
        crop_size=(round(512 * scale), round(1024 * scale)),
        resize_size=(round(1024 * scale), round(512 * scale)),
        mean=IMG_MEAN,
        scale=False,
        mirror=False,
        set=args.set),
                                  batch_size=batchsize,
                                  shuffle=False,
                                  pin_memory=True,
                                  num_workers=4)
    # A third loader at scale 0.9 used to be built and iterated here although
    # its batches were never consumed; it was dropped to avoid loading every
    # image a third time.

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(1024, 2048),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(1024, 2048), mode='bilinear')

    sm = torch.nn.Softmax(dim=1)
    log_sm = torch.nn.LogSoftmax(dim=1)
    kl_distance = nn.KLDivLoss(reduction='none')

    for index, (batch, batch2) in enumerate(zip(testloader, testloader2)):
        image, _, _, name = batch
        image2 = batch2[0]

        # Put the inputs on the same device the model was moved to.
        inputs = image.cuda(gpu0)
        inputs2 = image2.cuda(gpu0)
        print('\r>>>>Extracting feature...%03d/%03d' %
              (index * batchsize, NUM_STEPS),
              end='')

        # Only the 'DeepLab' branch produces a heat map; the other models
        # used to hit a NameError when the heat map was accessed below.
        heatmap_batch = None
        with torch.no_grad():
            if args.model == 'DeepLab':
                # Original scale, plain and flipped.
                output1, output2 = model(inputs)
                output_batch = interp(sm(0.5 * output1 + output2))
                heatmap_output1, heatmap_output2 = output1, output2
                output1, output2 = model(fliplr(inputs))
                output1, output2 = fliplr(output1), fliplr(output2)
                output_batch += interp(sm(0.5 * output1 + output2))
                heatmap_output1 = heatmap_output1 + output1
                heatmap_output2 = heatmap_output2 + output2
                del output1, output2, inputs

                # 1.25x scale, plain and flipped.
                output1, output2 = model(inputs2)
                output_batch += interp(sm(0.5 * output1 + output2))
                output1, output2 = model(fliplr(inputs2))
                output1, output2 = fliplr(output1), fliplr(output2)
                output_batch += interp(sm(0.5 * output1 + output2))
                del output1, output2, inputs2
                output_batch = output_batch.cpu().data.numpy()

                # Per-pixel KL divergence between the two outputs, summed
                # over the class dimension.
                heatmap_batch = torch.sum(kl_distance(log_sm(heatmap_output1),
                                                      sm(heatmap_output2)),
                                          dim=1)
                heatmap_batch = torch.log(
                    1 + 10 * heatmap_batch)  # for visualization
                heatmap_batch = heatmap_batch.cpu().data.numpy()
            else:  # 'DeeplabVGG' or 'Oracle'
                output_batch = model(inputs)
                output_batch = interp(output_batch).cpu().data.numpy()

        # (N, C, H, W) -> (N, H, W, C)
        output_batch = output_batch.transpose(0, 2, 3, 1)
        scoremap_batch = np.asarray(np.max(output_batch, axis=3))
        output_batch = np.asarray(np.argmax(output_batch, axis=3),
                                  dtype=np.uint8)
        output_iterator = []
        heatmap_iterator = []
        scoremap_iterator = []

        for i in range(output_batch.shape[0]):
            output_iterator.append(output_batch[i, :, :])
            if heatmap_batch is not None:
                heatmap_iterator.append(heatmap_batch[i, :, :] /
                                        np.max(heatmap_batch[i, :, :]))
            # Inverted, per-image normalised confidence.
            scoremap_iterator.append(1 - scoremap_batch[i, :, :] /
                                     np.max(scoremap_batch[i, :, :]))
            name_tmp = name[i].split('/')[-1]
            name[i] = '%s/%s' % (args.save, name_tmp)
        with Pool(4) as p:
            p.map(save, zip(output_iterator, name))
            if heatmap_iterator:
                p.map(save_heatmap, zip(heatmap_iterator, name))
            p.map(save_scoremap, zip(scoremap_iterator, name))

        del output_batch

    return args.save
def main():
    """Run multi-scale, flip-augmented evaluation on the robot dataset.

    The model type and normalisation options are read from the ``opts.yaml``
    file that sits next to the checkpoint ``args.restore_from``. For the
    ``'DeepLab'`` model, predictions from the original and the 1.25x inputs,
    each also horizontally flipped, are summed before the per-pixel argmax.
    Each label map and its ``*_color.png`` rendering are written to
    ``<args.save>/<parent directory of the input image>/``.

    Returns:
        str: The output root directory, ``args.save``.

    Raises:
        ValueError: If the model type found in ``opts.yaml`` is not one of
            ``'DeepLab'``, ``'Oracle'`` or ``'DeeplabVGG'``.
    """

    args = get_arguments()

    config_path = os.path.join(os.path.dirname(args.restore_from), 'opts.yaml')
    with open(config_path, 'r') as stream:
        # yaml.load() without a Loader is rejected by recent PyYAML releases,
        # so name the loader that older releases picked implicitly.
        # NOTE(review): these loaders can construct arbitrary Python objects.
        # If opts.yaml only holds plain scalars/lists/dicts, prefer
        # yaml.safe_load().
        config = yaml.load(stream,
                           Loader=getattr(yaml, 'FullLoader', yaml.Loader))

    args.model = config['model']
    print('ModelType:%s' % args.model)
    print('NormType:%s' % config['norm_style'])
    gpu0 = args.gpu
    batchsize = args.batchsize

    os.makedirs(args.save, exist_ok=True)

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes,
                             use_se=config['use_se'],
                             train_bn=False,
                             norm_style=config['norm_style'])
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Previously ``model`` stayed unbound and the failure surfaced later
        # as an UnboundLocalError.
        raise ValueError('Unsupported model: %r' % (args.model,))

    # Deserialize onto the CPU so loading does not depend on the device the
    # checkpoint was saved from; the model is moved to the GPU afterwards.
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from,
                                              map_location='cpu')
    else:
        saved_state_dict = torch.load(args.restore_from, map_location='cpu')

    try:
        model.load_state_dict(saved_state_dict)
    except (RuntimeError, KeyError):
        # Key mismatch: retry with the model wrapped in DataParallel
        # (presumably the checkpoint was saved from a wrapped model).
        model = torch.nn.DataParallel(model)
        model.load_state_dict(saved_state_dict)
    # Wrap exactly once; the fallback above may already have done it.
    if not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)
    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(robotDataSet(args.data_dir,
                                              args.data_list,
                                              crop_size=(960, 1280),
                                              resize_size=(1280, 960),
                                              mean=IMG_MEAN,
                                              scale=False,
                                              mirror=False,
                                              set=args.set),
                                 batch_size=batchsize,
                                 shuffle=False,
                                 pin_memory=True,
                                 num_workers=4)

    # Second pass over the same list at 1.25x resolution. Both loaders are
    # unshuffled, so zipping them pairs up the same images.
    scale = 1.25
    testloader2 = data.DataLoader(robotDataSet(
        args.data_dir,
        args.data_list,
        crop_size=(round(960 * scale), round(1280 * scale)),
        resize_size=(round(1280 * scale), round(960 * scale)),
        mean=IMG_MEAN,
        scale=False,
        mirror=False,
        set=args.set),
                                  batch_size=batchsize,
                                  shuffle=False,
                                  pin_memory=True,
                                  num_workers=4)

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(960, 1280),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(960, 1280), mode='bilinear')

    sm = torch.nn.Softmax(dim=1)
    for index, (batch, batch2) in enumerate(zip(testloader, testloader2)):
        image, _, _, name = batch
        image2 = batch2[0]
        print(image.shape)

        # Put the inputs on the same device the model was moved to.
        inputs = image.cuda(gpu0)
        inputs2 = image2.cuda(gpu0)
        print('\r>>>>Extracting feature...%04d/%04d' %
              (index * batchsize, NUM_STEPS),
              end='')

        # Pure inference for every model type, so autograd stays off.
        with torch.no_grad():
            if args.model == 'DeepLab':
                # Original scale, plain and flipped.
                output1, output2 = model(inputs)
                output_batch = interp(sm(0.5 * output1 + output2))
                output1, output2 = model(fliplr(inputs))
                output1, output2 = fliplr(output1), fliplr(output2)
                output_batch += interp(sm(0.5 * output1 + output2))
                del output1, output2, inputs

                # 1.25x scale, plain and flipped.
                output1, output2 = model(inputs2)
                output_batch += interp(sm(0.5 * output1 + output2))
                output1, output2 = model(fliplr(inputs2))
                output1, output2 = fliplr(output1), fliplr(output2)
                output_batch += interp(sm(0.5 * output1 + output2))
                del output1, output2, inputs2
                output_batch = output_batch.cpu().data.numpy()
            else:  # 'DeeplabVGG' or 'Oracle'
                output_batch = model(inputs)
                output_batch = interp(output_batch).cpu().data.numpy()

        # (N, C, H, W) -> (N, H, W, C), then the best class per pixel.
        output_batch = output_batch.transpose(0, 2, 3, 1)
        output_batch = np.asarray(np.argmax(output_batch, axis=3),
                                  dtype=np.uint8)

        for i in range(output_batch.shape[0]):
            output = output_batch[i, :, :]
            output_col = colorize_mask(output)
            output = Image.fromarray(output)

            # Mirror the input's parent directory below args.save.
            name_tmp = name[i].split('/')[-1]
            dir_name = name[i].split('/')[-2]
            save_path = args.save + '/' + dir_name
            os.makedirs(save_path, exist_ok=True)
            output.save('%s/%s' % (save_path, name_tmp))
            print('%s/%s' % (save_path, name_tmp))
            output_col.save('%s/%s_color.png' %
                            (save_path, name_tmp.split('.')[0]))

    return args.save
Exemplo n.º 6
0
class Test:
    """Evaluate a two-class ``DeeplabMulti`` checkpoint across cities.

    The source city is ``config.dataset``; every other entry of
    ``config.all_dataset`` is treated as a target city. For each city the
    test split is scored with an ``Evaluator``; when ``bn`` is set, the
    target city's train split is first pushed through the model in training
    mode so the batch-norm statistics adapt to that city.
    """

    def __init__(self,
                 model_path,
                 config,
                 bn,
                 save_path,
                 save_batch,
                 cuda=False):
        """Build the data loaders and restore the model.

        Args:
            model_path: Checkpoint path handed to ``torch.load``.
            config: Object providing ``all_dataset``, ``dataset`` and
                ``img_root``.
            bn: When truthy, run the batch-norm adaptation pass on target
                cities and store results under ``*_bn`` names.
            save_path: Root directory for the saved images.
            save_batch: Number of leading test batches whose images are
                written to disk (0 disables saving).
            cuda: Run the model on the GPU when true.
        """
        self.bn = bn
        # Work on a copy: removing the source city from config.all_dataset
        # itself would silently alter the caller's configuration.
        self.target = list(config.all_dataset)
        self.target.remove(config.dataset)
        # load source domain
        self.source_set = spacenet.Spacenet(city=config.dataset,
                                            split='test',
                                            img_root=config.img_root)
        self.source_loader = DataLoader(self.source_set,
                                        batch_size=16,
                                        shuffle=False,
                                        num_workers=2)

        self.save_path = save_path
        self.save_batch = save_batch

        self.target_set = []
        self.target_loader = []

        self.target_trainset = []
        self.target_trainloader = []

        self.config = config

        # load other domains
        for city in self.target:
            test = spacenet.Spacenet(city=city,
                                     split='test',
                                     img_root=config.img_root)
            self.target_set.append(test)
            self.target_loader.append(
                DataLoader(test, batch_size=16, shuffle=False, num_workers=2))
            train = spacenet.Spacenet(city=city,
                                      split='train',
                                      img_root=config.img_root)
            self.target_trainset.append(train)
            self.target_trainloader.append(
                DataLoader(train, batch_size=16, shuffle=False, num_workers=2))

        self.model = DeeplabMulti(num_classes=2)
        if cuda:
            self.checkpoint = torch.load(model_path)
        else:
            # Remap GPU-saved tensors so the checkpoint loads without CUDA.
            self.checkpoint = torch.load(model_path,
                                         map_location=torch.device('cpu'))
        self.model.load_state_dict(self.checkpoint)
        self.evaluator = Evaluator(2)
        self.cuda = cuda
        if cuda:
            self.model = self.model.cuda()

    @staticmethod
    def _unpack_sample(sample):
        """Return ``(image, label)`` from one loader sample.

        Accepts both layouts that appear in this class: a dict with
        ``'image'`` / ``'label'`` keys, or an ``(image, label)`` sequence.
        """
        if isinstance(sample, dict):
            return sample['image'], sample['label']
        image, target = sample
        return image, target

    def get_performance(self, dataloader, trainloader, city):
        """Score the model on ``dataloader`` and optionally save images.

        Args:
            dataloader: Test loader of the city being evaluated.
            trainloader: Train loader of that city; only read when ``bn`` is
                set and ``city`` is not the source city (may be ``None``
                otherwise).
            city: Name of the city, used for logging and output paths.

        Returns:
            tuple: ``(Acc, IoU, mIoU)`` as reported by the evaluator's
            ``Building_Acc``, ``Building_IoU`` and
            ``Mean_Intersection_over_Union``.
        """
        # change mean and var of bn to adapt to the target domain
        if self.bn and city != self.config.dataset:
            print('BN Adaptation on' + city)
            self.model.train()
            for sample in trainloader:
                # Only the forward pass matters here; labels are not needed.
                image, _ = self._unpack_sample(sample)
                if self.cuda:
                    image = image.cuda()
                with torch.no_grad():
                    self.model(image)

        batch = self.save_batch
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(dataloader, desc='\r')

        # save in different directories
        if self.bn:
            save_path = os.path.join(self.save_path, city + '_bn')
        else:
            save_path = os.path.join(self.save_path, city)

        interp = nn.Upsample(size=(400, 400), mode='bilinear')

        # evaluate on the test dataset
        for i, sample in enumerate(tbar):
            image, target = self._unpack_sample(sample)
            if self.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                # The model returns two maps; the second one is scored.
                _, output = self.model(image)
                output = interp(output)
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

            # save pictures for the first ``save_batch`` batches only
            if batch > 0:
                # makedirs also creates self.save_path when it is missing.
                if not os.path.exists(save_path):
                    os.makedirs(save_path)
                image = image.cpu().numpy() * 255
                image = image.transpose(0, 2, 3, 1).astype(int)

                imgs = self.color_images(pred, target)
                self.save_images(imgs, batch, save_path, False)
                self.save_images(image, batch, save_path, True)
                batch -= 1

        Acc = self.evaluator.Building_Acc()
        IoU = self.evaluator.Building_IoU()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        return Acc, IoU, mIoU

    def test(self):
        """Evaluate source and target cities, print and dump the metrics.

        The results are written as JSON to ``train_log/test_bn.json`` when
        ``bn`` is set and to ``train_log/test.json`` otherwise.
        """
        A, I, Im = self.get_performance(self.source_loader, None,
                                        self.config.dataset)
        tA, tI, tIm = [], [], []
        for dl, tl, city in zip(self.target_loader, self.target_trainloader,
                                self.target):
            tA_, tI_, tIm_ = self.get_performance(dl, tl, city)
            tA.append(tA_)
            tI.append(tI_)
            tIm.append(tIm_)

        res = {}
        print("Test for source domain:")
        print("{}: Acc:{}, IoU:{}, mIoU:{}".format(self.config.dataset, A, I,
                                                   Im))
        # Use the instance's config; a bare ``config`` is not defined in
        # this scope.
        res[self.config.dataset] = {'Acc': A, 'IoU': I, 'mIoU': Im}

        print('Test for target domain:')
        for i, city in enumerate(self.target):
            print("{}: Acc:{}, IoU:{}, mIoU:{}".format(city, tA[i], tI[i],
                                                       tIm[i]))
            res[city] = {'Acc': tA[i], 'IoU': tI[i], 'mIoU': tIm[i]}

        if self.bn:
            name = 'train_log/test_bn.json'
        else:
            name = 'train_log/test.json'

        # Make sure the log directory exists before writing into it.
        log_dir = os.path.dirname(name)
        if not os.path.isdir(log_dir):
            os.makedirs(log_dir)
        with open(name, 'w') as f:
            json.dump(res, f)

    def save_images(self, imgs, batch_index, save_path, if_original=False):
        """Write each image of ``imgs`` into ``save_path`` as a JPEG.

        Args:
            imgs: Iterable of H x W x 3 arrays.
            batch_index: Number prefixed to the file names.
            save_path: Existing output directory.
            if_original: True for input images (``*_Original.jpg``), False
                for colourised predictions (``*_Pred.jpg``).
        """
        # NOTE(review): str(batch_index) + str(i) is ambiguous (batch 1 /
        # image 10 and batch 11 / image 0 both give "110"), so files can
        # overwrite each other when more than ten batches are saved.
        suffix = '_Original.jpg' if if_original else '_Pred.jpg'
        for i, img in enumerate(imgs):
            img = img[:, :, ::-1]  # change to BGR
            cv2.imwrite(
                os.path.join(save_path,
                             str(batch_index) + str(i) + suffix), img)

    def color_images(self, pred, target):
        """Colour-code predictions against labels, one image per sample.

        Args:
            pred: Batch of H x W binary prediction maps.
            target: Batch of binary label maps holding the same number of
                pixels per sample (a singleton channel axis is tolerated).

        Returns:
            list: One H x W x 3 RGB array per sample.
        """
        imgs = []
        for p, t in zip(pred, target):
            # Encode the (prediction, label) pair as 2 * pred + label. The
            # label is reshaped to the prediction's shape so that a stray
            # singleton axis cannot break the 2-D indexing below.
            tmp = p * 2 + np.reshape(t, p.shape)
            img = np.zeros((p.shape[0], p.shape[1], 3))
            # bkg:negative, building:postive
            img[np.where(tmp == 0)] = [0, 0, 0]  # Black RGB, for true negative
            img[np.where(tmp == 1)] = [255, 0,
                                       0]  # Red RGB, for false negative
            img[np.where(tmp == 2)] = [0, 255,
                                       0]  # Green RGB, for false positive
            img[np.where(tmp == 3)] = [255, 255,
                                       0]  # Yellow RGB, for true positive
            imgs.append(img)
        return imgs
Exemplo n.º 7
0
def main():
    """Create the model and start the evaluation process.

    Builds the network selected by ``args.model``, restores its weights from
    ``args.restore_from`` (an ``http`` URL or a local path) and evaluates it
    on 1000 episodes dequeued from ``SSDatalayer(group)``.

    For every episode the query prediction is blended with the query image
    and written to ``save_bins/que_pred/``; per-class TP/FP/FN counts, a
    21x21 confusion histogram and a ``SegScorer`` are accumulated.  The
    per-class IoU list, the mean IoU of the current group, the binary IoU
    and the scorer results are printed at the end.

    Raises:
        ValueError: if ``args.model`` is not one of the supported names.
    """
    args = get_arguments()

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = DeepLab(backbone='resnet', output_stride=8)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Previously an unknown name surfaced later as an unbound ``model``.
        raise ValueError('unsupported model: {}'.format(args.model))

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # Drop checkpoint entries the current model does not define (the original
    # comment gives "running different versions of pytorch" as the reason).
    # Loading stays strict, so keys missing from the checkpoint still raise.
    known_keys = model.state_dict()
    saved_state_dict = {
        k: v
        for k, v in saved_state_dict.items() if k in known_keys
    }
    model.load_state_dict(saved_state_dict)

    device = torch.device("cuda" if not args.cpu else "cpu")
    model = model.to(device)
    model.eval()

    num_classes = 20
    tp_list = [0] * num_classes
    fp_list = [0] * num_classes
    fn_list = [0] * num_classes

    hist = np.zeros((21, 21))
    group = 1
    scorer = SegScorer(num_classes=21)
    datalayer = SSDatalayer(group)

    # cv2.imwrite does not create missing directories, so make sure the
    # target folder exists before the first prediction is written.
    pred_dir = 'save_bins/que_pred'
    if not os.path.exists(pred_dir):
        os.makedirs(pred_dir)

    for count in tqdm(range(1000)):
        dat = datalayer.dequeue()
        # NOTE(review): the reference image comes from 'second_img' while the
        # reference label comes from 'first_label' (and vice versa for the
        # query).  This looks crossed; kept unchanged -- confirm against the
        # data layer.
        ref_img = dat['second_img'][0]
        query_img = dat['first_img'][0]
        query_label = dat['second_label'][0]
        ref_label = dat['first_label'][0]
        deploy_info = dat['deploy_info']

        # Follow the device chosen via args.cpu instead of forcing CUDA.
        ref_img = torch.Tensor(ref_img).to(device)
        ref_label = torch.Tensor(ref_label).to(device)
        query_img = torch.Tensor(query_img).to(device)
        query_label = torch.Tensor(query_label[0, :, :]).to(device)

        ref_img_var = torch.unsqueeze(ref_img, dim=0)  # add batch axis
        ref_label_var = torch.unsqueeze(ref_label, dim=1)
        query_img_var = torch.unsqueeze(query_img, dim=0)  # add batch axis

        samples = torch.cat([ref_img_var, query_img_var], 0)
        with torch.no_grad():  # pure inference, no autograd graph needed
            pred = model(samples, ref_label_var)
            height, width = query_label.size()
            pred = F.upsample(pred, size=(height, width), mode='bilinear')
            pred = F.softmax(pred, dim=1).squeeze()
            _, pred = torch.max(pred, dim=0)
        pred = pred.data.cpu().numpy().astype(np.int32)

        # Blend the predicted mask with the query image and store it.
        org_img = get_org_img(query_img.squeeze().cpu().data.numpy())
        img = mask_to_img(pred, org_img)
        cv2.imwrite('save_bins/que_pred/query_set_1_%d.png' % (count), img)

        query_label = query_label.cpu().numpy().astype(np.int32)
        # Class indices coming from the data layer are 1-based.
        class_ind = int(deploy_info['first_semantic_labels'][0][0]) - 1
        scorer.update(pred, query_label, class_ind + 1)
        tp, _, fp, fn = measure(query_label, pred)
        tp_list[class_ind] += tp
        fp_list[class_ind] += fp
        fn_list[class_ind] += fn

        # Relabel foreground pixels in place with the 1-based class id so the
        # binary masks can be accumulated into the 21-class histogram.
        pred[pred > 0.5] = class_ind + 1
        query_label[query_label > 0.5] = class_ind + 1

        hist += Metrics.fast_hist(pred, query_label, 21)

    # max(..., 1) guards classes for which no pixel was counted at all.
    iou_list = [
        tp_list[ic] / float(max(tp_list[ic] + fp_list[ic] + fn_list[ic], 1))
        for ic in range(num_classes)
    ]

    print("-------------GROUP %d-------------" % (group))
    print(iou_list)
    class_indexes = range(group * 5, (group + 1) * 5)
    print('Mean:', np.mean(np.take(iou_list, class_indexes)))

    # Collapse the histogram to background vs. any foreground class.
    binary_hist = np.array((hist[0, 0], hist[0, 1:].sum(), hist[1:, 0].sum(),
                            hist[1:, 1:].sum())).reshape((2, 2))
    bin_iu = np.diag(binary_hist) / (binary_hist.sum(1) + binary_hist.sum(0) -
                                     np.diag(binary_hist))
    print('Bin_iu:', bin_iu)

    scores = scorer.score()
    for k in scores.keys():
        print(k, np.mean(scores[k]), scores[k])
Exemplo n.º 8
0
def main():
    """Create the model and start the evaluation process.

    Builds the network selected by ``args.model``, restores its weights from
    ``args.restore_from`` and runs it over ``SmokeDataset``.  Whenever the
    predicted label map contains more than one class, the prediction and its
    colourised version are saved under ``args.save``, two IoU values are
    computed with ``iou_pytorch`` and the image, label, prediction and IoUs
    are logged to a ``SummaryWriter`` located in a time-stamped
    sub-directory of ``args.save``.  Mean IoUs are printed at the end.

    Raises:
        ValueError: if ``args.model`` is not one of the supported names.
    """
    args = get_arguments()

    gpu0 = args.gpu

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Previously an unknown name surfaced later as an unbound ``model``.
        raise ValueError('unsupported model: {}'.format(args.model))

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # Drop checkpoint entries the current model does not define (the original
    # comment gives "running different versions of pytorch" as the reason).
    # Loading stays strict, so keys missing from the checkpoint still raise.
    known_keys = model.state_dict()
    saved_state_dict = {
        k: v
        for k, v in saved_state_dict.items() if k in known_keys
    }
    model.load_state_dict(saved_state_dict)

    model.eval()
    model.cuda(gpu0)

    # args.save exists at this point; every run logs into its own
    # time-stamped sub-directory.
    exp_name = datetime.datetime.now().strftime("%H%M%S-%Y%m%d")
    log_dir = os.path.join(args.save, exp_name)
    writer = SummaryWriter(log_dir)

    testloader = data.DataLoader(SmokeDataset(image_size=(640, 360),
                                              dataset_mean=IMG_MEAN),
                                 batch_size=1,
                                 shuffle=True,
                                 pin_memory=True)

    # NOTE(review): nn.Upsample takes size as (height, width); (640, 360)
    # mirrors the dataset's image_size tuple -- confirm the intended order.
    interp = nn.Upsample(size=(640, 360),
                         mode='bilinear',
                         align_corners=True)

    iou_sum_fg = 0
    iou_count_fg = 0

    iou_sum_bg = 0
    iou_count_bg = 0

    for index, batch in enumerate(testloader):
        if (index + 1) % 100 == 0:
            print('%d processd' % index)

        image, label, name = batch
        with torch.no_grad():
            if args.model == 'DeeplabMulti':
                # Only the second head is used for the prediction.
                _, logits = model(image.cuda(gpu0))
            else:  # 'DeeplabVGG' or 'Oracle'
                logits = model(image.cuda(gpu0))
            # Keep the upsampled scores for every model type: they are needed
            # below for the IoU masks (previously only the DeeplabMulti branch
            # defined them, so the other models hit a NameError).
            orig_output = interp(logits).cpu()
        output = orig_output.data[0].numpy()

        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
        classes_seen = set(output.ravel().tolist())
        output_col = colorize_mask(output)
        output = Image.fromarray(output)
        name = name[0]

        if len(classes_seen) > 1:
            print(classes_seen)
            print(Counter(np.asarray(output).ravel()))

            # Undo the mean subtraction, rescale to [0, 1] and swap the
            # channel order for visualisation.
            image = image.squeeze()
            for c in range(3):
                image[c, :, :] += IMG_MEAN[c]
            image = (image - image.min()) / (image.max() - image.min())
            image = image[[2, 1, 0], :, :]
            print(image.shape, image.min(), image.max())
            output.save(os.path.join(args.save, name + '.png'))
            output_col.save(os.path.join(args.save, name + '_color.png'))

            output_argmaxs = torch.argmax(orig_output.squeeze(), dim=0)
            mask1 = (output_argmaxs == 0).float() * 255
            label = label.squeeze()

            iou_fg = iou_pytorch(mask1, label)
            print("foreground IoU", iou_fg)
            iou_sum_fg += iou_fg
            iou_count_fg += 1

            mask2 = (output_argmaxs > 0).float() * 255
            label2 = label.max() - label  # inverted label

            iou_bg = iou_pytorch(mask2, label2)
            print("IoU for background: ", iou_bg)
            iou_sum_bg += iou_bg
            iou_count_bg += 1

            writer.add_images('input_images',
                              tf.resize(image[[2, 1, 0]], [1080, 1920]),
                              index,
                              dataformats='CHW')

            print("shape of label", label.shape)
            label_reshaped = tf.resize(label.unsqueeze(0),
                                       [1080, 1920]).squeeze()
            print("label reshaped: ", label_reshaped.shape)
            writer.add_images('labels',
                              label_reshaped,
                              index,
                              dataformats='HW')
            writer.add_images(
                'output/1',
                255 - np.asarray(tf.resize(output, [1080, 1920])) * 255,
                index,
                dataformats='HW')
            writer.add_scalar('iou/smoke', iou_fg, index)
            writer.add_scalar('iou/background', iou_bg, index)
            writer.add_scalar('iou/mean', (iou_bg + iou_fg) / 2, index)
            writer.flush()

    # Release the event file handle once evaluation is finished.
    writer.close()

    if iou_count_fg > 0:
        print("Mean IoU, foreground: {}".format(iou_sum_fg / iou_count_fg))
        print("Mean IoU, background: {}".format(iou_sum_bg / iou_count_bg))
        print("Mean IoU, averaged over classes: {}".format(
            (iou_sum_fg + iou_sum_bg) / (iou_count_fg + iou_count_bg)))
Exemplo n.º 9
0
def main():
    """Create the model and start the evaluation process.

    Seeds the random number generators, builds the network selected by
    ``args.model`` and then evaluates every snapshot
    ``./snapshots/<run>/GTA5_<step>.pth`` for
    ``step = save_pred_every, 2 * save_pred_every, ...`` up to
    ``num_steps_stop``.  ``<run>`` is ``source_only`` when ``SOURCE_ONLY`` is
    set, otherwise ``single_level`` / ``multi_level`` according to
    ``args.level``.  For each snapshot the label map and its colourised
    version of every test image are written to
    ``<args.save>/<run>/step<step>/``.

    Raises:
        ValueError: if ``args.model`` is not one of the supported names.
        NotImplementedError: if ``SOURCE_ONLY`` is unset and ``args.level``
            is neither ``'single-level'`` nor ``'multi-level'``.
    """
    seed = 1338
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    args = get_arguments()

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Previously an unknown name surfaced later as an unbound ``model``.
        raise ValueError('unsupported model: {}'.format(args.model))

    # The same tag names both the snapshot folder and the output folder.
    if SOURCE_ONLY:
        run_tag = 'source_only'
    elif args.level == 'single-level':
        run_tag = 'single_level'
    elif args.level == 'multi-level':
        run_tag = 'multi_level'
    else:
        raise NotImplementedError(
            'level choice {} is not implemented'.format(args.level))

    device = torch.device("cuda" if not args.cpu else "cpu")

    # Data loader and upsampler do not depend on the snapshot: build them once.
    testloader = data.DataLoader(cityscapesDataSet(args.data_dir,
                                                   args.data_list,
                                                   crop_size=(1024, 512),
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    interp = nn.Upsample(size=(1024, 2048),
                         mode='bilinear',
                         align_corners=True)

    for files in range(int(args.num_steps_stop / args.save_pred_every)):
        step = (files + 1) * args.save_pred_every
        print('Step: ', step)
        saved_state_dict = torch.load('./snapshots/' + run_tag + '/GTA5_' +
                                      str(step) + '.pth')

        # Drop checkpoint entries the current model does not define (the
        # original comment gives "running different versions of pytorch" as
        # the reason).  Loading stays strict for keys missing in the file.
        known_keys = model.state_dict()
        saved_state_dict = {
            k: v
            for k, v in saved_state_dict.items() if k in known_keys
        }
        model.load_state_dict(saved_state_dict)
        model = model.to(device)

        # Keep ``model`` as the bare module and wrap it in a separate name.
        # Rebinding ``model`` to the DataParallel wrapper prefixed every
        # state-dict key with 'module.' from the second snapshot on, so the
        # filter above discarded the whole checkpoint.
        net = nn.DataParallel(model) if args.multi_gpu else model
        net.eval()

        out_dir = os.path.join(args.save, run_tag, 'step' + str(step))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        with torch.no_grad():  # pure inference, no autograd graph needed
            for index, batch in enumerate(testloader):
                if index % 100 == 0:
                    print('%d processd' % index)
                image, _, name = batch
                image = image.to(device)

                if args.model == 'DeeplabMulti':
                    # Only the second head is used for the prediction.
                    _, output2 = net(image)
                    output = interp(output2).cpu().data[0].numpy()
                else:  # 'DeeplabVGG' or 'Oracle'
                    output = net(image)
                    output = interp(output).cpu().data[0].numpy()

                output = output.transpose(1, 2, 0)
                output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

                output_col = colorize_mask(output)
                output = Image.fromarray(output)

                name = name[0].split('/')[-1]
                output.save(os.path.join(out_dir, name))
                output_col.save(
                    os.path.join(out_dir, name.split('.')[0] + '_color.png'))
Exemplo n.º 10
0
def main():
    """Create the model and start the evaluation process.

    Reads ``opts.yaml`` next to ``args.restore_from`` to pick the model type,
    restores the weights and predicts every image of ``DarkZurichDataSet`` at
    two scales (1.0 and 1.25), each with and without horizontal flipping; the
    four softmax maps are averaged for the ``'DeepLab'`` model.  Label maps
    and their colourised versions are written under
    ``<args.save>/<image dir>/``.  For the ``test`` and ``val`` sets the
    submission files are written as well: label maps, label maps with
    low-confidence pixels set to 255, and 16-bit confidence maps.

    Returns:
        The output directory ``args.save``.

    Raises:
        ValueError: if the configured model name is not supported.
    """
    args = get_arguments()

    w, h = map(int, args.input_size.split(','))

    config_path = os.path.join(os.path.dirname(args.restore_from), 'opts.yaml')
    with open(config_path, 'r') as stream:
        # yaml.load() without a Loader is rejected by PyYAML >= 6.  FullLoader
        # is what PyYAML 5.x used implicitly.  NOTE: only load trusted files.
        config = yaml.load(stream, Loader=yaml.FullLoader)

    args.model = config['model']
    print('ModelType:%s' % args.model)
    print('NormType:%s' % config['norm_style'])
    gpu0 = args.gpu
    batchsize = args.batchsize

    if not os.path.exists(args.save):
        os.makedirs(args.save)
    confidence_path = os.path.join(args.save, 'submit/confidence')
    label_path = os.path.join(args.save, 'submit/labelTrainIds')
    label_invalid_path = os.path.join(args.save,
                                      'submit/labelTrainIds_invalid')
    for path in [confidence_path, label_path, label_invalid_path]:
        if not os.path.exists(path):
            os.makedirs(path)

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes,
                             use_se=config['use_se'],
                             train_bn=False,
                             norm_style=config['norm_style'])
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Previously an unknown name surfaced later as an unbound ``model``.
        raise ValueError('unsupported model: {}'.format(args.model))

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    try:
        model.load_state_dict(saved_state_dict)
    except RuntimeError:
        # load_state_dict reports key/shape mismatches as RuntimeError; retry
        # with a DataParallel wrapper, whose keys carry a 'module.' prefix.
        model = torch.nn.DataParallel(model)
        model.load_state_dict(saved_state_dict)
    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(DarkZurichDataSet(args.data_dir,
                                                   args.data_list,
                                                   crop_size=(h, w),
                                                   resize_size=(w, h),
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set),
                                 batch_size=batchsize,
                                 shuffle=False,
                                 pin_memory=True,
                                 num_workers=4)

    # Second pass over the same list at 1.25x resolution (multi-scale test).
    scale = 1.25
    testloader2 = data.DataLoader(DarkZurichDataSet(
        args.data_dir,
        args.data_list,
        crop_size=(round(h * scale), round(w * scale)),
        resize_size=(round(w * scale), round(h * scale)),
        mean=IMG_MEAN,
        scale=False,
        mirror=False,
        set=args.set),
                                  batch_size=batchsize,
                                  shuffle=False,
                                  pin_memory=True,
                                  num_workers=4)

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(1080, 1920),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(1080, 1920), mode='bilinear')

    sm = torch.nn.Softmax(dim=1)
    log_sm = torch.nn.LogSoftmax(dim=1)
    kl_distance = nn.KLDivLoss(reduction='none')

    # Pixels whose best class score falls below this value are marked invalid.
    threshold = 0.3274

    for index, (batch, batch2) in enumerate(zip(testloader, testloader2)):
        image, _, name = batch
        image2, _, _ = batch2

        print('\r>>>>Extracting feature...%04d/%04d' %
              (index * batchsize, args.batchsize * len(testloader)),
              end='')
        if args.model == 'DeepLab':
            # Put the inputs on the same GPU as the model.
            inputs = image.cuda(gpu0)
            inputs2 = image2.cuda(gpu0)
            with torch.no_grad():
                output1, output2 = model(inputs)
                output_batch = interp(sm(0.5 * output1 + output2))

                # Per-pixel KL term between the two heads, summed over the
                # class axis; only consumed by the optional heatmap dump below.
                heatmap_batch = torch.sum(kl_distance(log_sm(output1),
                                                      sm(output2)),
                                          dim=1)

                output1, output2 = model(fliplr(inputs))
                output1, output2 = fliplr(output1), fliplr(output2)
                output_batch += interp(sm(0.5 * output1 + output2))
                del output1, output2, inputs

                output1, output2 = model(inputs2)
                output_batch += interp(sm(0.5 * output1 + output2))
                output1, output2 = model(fliplr(inputs2))
                output1, output2 = fliplr(output1), fliplr(output2)
                output_batch += interp(sm(0.5 * output1 + output2))
                del output1, output2, inputs2

                # Average of the four predictions (2 scales x 2 flips).
                output_batch = output_batch.cpu() / 4
                output_batch = output_batch.data.numpy()
                heatmap_batch = heatmap_batch.cpu().data.numpy()
        else:  # 'DeeplabVGG' or 'Oracle'
            with torch.no_grad():
                output_batch = model(image.cuda(gpu0))
                output_batch = interp(output_batch).cpu().data.numpy()

        output_batch = output_batch.transpose(0, 2, 3, 1)
        score_batch = np.max(output_batch, axis=3)
        output_batch = np.asarray(np.argmax(output_batch, axis=3),
                                  dtype=np.uint8)

        for i in range(output_batch.shape[0]):
            output_single = output_batch[i, :, :]
            output_col = colorize_mask(output_single)
            output = Image.fromarray(output_single)

            name_tmp = name[i].split('/')[-1]
            dir_name = name[i].split('/')[-2]
            save_path = args.save + '/' + dir_name
            if not os.path.isdir(save_path):
                os.mkdir(save_path)
            output.save('%s/%s' % (save_path, name_tmp))
            print('%s/%s' % (save_path, name_tmp))
            output_col.save('%s/%s_color.png' %
                            (save_path, name_tmp.split('.')[0]))

            # Optional heatmap dump (DeepLab branch only), kept for reference:
            # heatmap_tmp = heatmap_batch[i,:,:]/np.max(heatmap_batch[i,:,:])
            # fig = plt.figure()
            # plt.axis('off')
            # heatmap = plt.imshow(heatmap_tmp, cmap='viridis')
            # fig.colorbar(heatmap)
            # fig.savefig('%s/%s_heatmap.png' % (save_path, name_tmp.split('.')[0]))

            if args.set == 'test' or args.set == 'val':
                # label
                output.save('%s/%s' % (label_path, name_tmp))
                # label with low-confidence pixels flagged as invalid (255);
                # this must stay after the plain label has been saved because
                # it modifies the label array in place.
                output_single[score_batch[i, :, :] < threshold] = 255
                output = Image.fromarray(output_single)
                output.save('%s/%s' % (label_invalid_path, name_tmp))
                # confidence, rescaled to the 16-bit range
                confidence = score_batch[i, :, :] * 65535
                confidence = np.asarray(confidence, dtype=np.uint16)
                print(confidence.min(), confidence.max())
                iio.imwrite('%s/%s' % (confidence_path, name_tmp), confidence)

    return args.save
Exemplo n.º 11
0
def main():
    """Run segmentation inference over a data list and save the predictions.

    Builds the network selected by ``args.model``, restores its weights from
    ``args.restore_from`` (an ``http`` URL or a local path, loaded with
    ``strict=False``) and, for every image yielded by the loader, writes the
    arg-max label map and a colorized copy into ``args.save``.

    When ``args.save_confidence`` is set, the value returned by
    ``get_confidence`` is stored per image in ``<stem>_c.txt`` and the names
    of the most confident third of the images are written to ``list.txt``
    in the current working directory.

    Raises:
        ValueError: if ``args.model`` is not one of the supported names.
    """

    args = get_arguments()

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    elif args.model == 'DeeplabVGGBN':
        deeplab_vggbn.BatchNorm = SyncBatchNorm2d
        model = deeplab_vggbn.DeeplabVGGBN(num_classes=args.num_classes)
    else:
        # Fail early with a clear message instead of an unbound-name error
        # at the first use of ``model`` below.
        raise ValueError('Unsupported model: %s' % args.model)

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)
    model.load_state_dict(saved_state_dict, strict=False)

    print(model)

    device = torch.device("cuda" if not args.cpu else "cpu")
    model = model.to(device)

    model.eval()

    testloader = data.DataLoader(BDDDataSet(args.data_dir,
                                            args.data_list,
                                            crop_size=(960, 540),
                                            mean=IMG_MEAN,
                                            scale=False,
                                            mirror=False,
                                            set=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)  # 960 540

    interp = nn.Upsample(size=(720, 1280), mode='bilinear', align_corners=True)

    # [confidence, batch_name] pairs; only filled when save_confidence is on.
    c_list = []

    # Pure inference: no autograd graph is needed, so do not build one.
    with torch.no_grad():
        for index, batch in enumerate(testloader):
            if index % 10 == 0:
                print('%d processd' % index)
            image, _, name = batch
            image = image.to(device)

            output = model(image)

            # batch_size is 1, so the first entry is the only image path.
            file_name = name[0].split('/')[-1]
            stem = file_name.split('.')[0]

            if args.save_confidence:
                confidence = get_confidence(output).cpu().item()
                c_list.append([confidence, name])
                with open('%s/%s_c.txt' % (args.save, stem), 'w') as record:
                    record.write('%.5f' % confidence)

            output = interp(output).cpu().data[0].numpy()

            # (C, H, W) -> (H, W, C), then pick the best class per pixel.
            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

            output_col = colorize_mask(output)
            output = Image.fromarray(output)

            output.save('%s/%s' % (args.save, file_name[:-4] + '.png'))
            output_col.save('%s/%s_color.png' % (args.save, stem))

    if args.save_confidence:
        # Highest confidence first; keep the top third.
        c_list.sort(key=lambda entry: entry[0], reverse=True)
        with open('list.txt', 'w') as select:
            for confidence, batch_name in c_list[:len(c_list) // 3]:
                print(confidence)
                print(batch_name)
                select.write(batch_name[0])
                select.write('\n')

    print(args.save)
Exemplo n.º 12
0
def main():
    """Run batched segmentation inference and save the predictions.

    The output directory is ``args.save`` with the name of the checkpoint's
    parent directory appended.  The network selected by ``args.model`` is
    restored from ``args.restore_from`` (an ``http`` URL or a local path);
    if the plain model rejects the checkpoint, loading is retried with the
    model wrapped in ``torch.nn.DataParallel``.  For every image the arg-max
    label map and a colorized copy are written to the output directory.

    Returns:
        The output directory path (``args.save`` after the suffix is added).

    Raises:
        ValueError: if ``args.model`` is not one of the supported names.
    """

    args = get_arguments()

    batchsize = args.batchsize

    model_name = os.path.basename(os.path.dirname(args.restore_from))
    args.save += model_name

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes,
                             train_bn=False,
                             norm_style='in')
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Fail early with a clear message instead of an unbound-name error
        # at the first use of ``model`` below.
        raise ValueError('Unsupported model: %s' % args.model)

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    try:
        model.load_state_dict(saved_state_dict)
    except (RuntimeError, KeyError):
        # Key mismatch: retry with a DataParallel wrapper, as the original
        # fallback did.  Catching only the load errors keeps
        # KeyboardInterrupt / SystemExit from being swallowed.
        model = torch.nn.DataParallel(model)
        model.load_state_dict(saved_state_dict)
    model.eval()
    model.cuda()

    testloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                             args.data_list,
                                             crop_size=(640, 1280),
                                             resize_size=(1280, 640),
                                             mean=IMG_MEAN,
                                             scale=False,
                                             mirror=False),
                                 batch_size=batchsize,
                                 shuffle=False,
                                 pin_memory=True)

    # ``align_corners`` only exists from PyTorch 0.4.0 onwards.
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(640, 1280),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(640, 1280), mode='bilinear')

    # NOTE(review): the loop is not wrapped in torch.no_grad() because this
    # function still supports PyTorch < 0.4 (see the version check above).
    sm = torch.nn.Softmax(dim=1)
    for index, batch in enumerate(testloader):
        processed = index * batchsize
        if processed % 100 == 0:
            print('%d processd' % processed)
        image, _, _, name = batch
        print(image.shape)

        inputs = Variable(image).cuda()
        if args.model == 'DeeplabMulti':
            output1, output2 = model(inputs)
            # Fuse both heads (auxiliary head weighted 0.5) before softmax.
            output_batch = interp(sm(0.5 * output1 +
                                     output2)).cpu().data.numpy()
        else:
            # 'DeeplabVGG' or 'Oracle': single-output networks.  Reuse the
            # tensor that is already on the GPU instead of moving it again.
            output_batch = interp(model(inputs)).cpu().data.numpy()

        # (N, C, H, W) -> (N, H, W, C), then pick the best class per pixel.
        output_batch = output_batch.transpose(0, 2, 3, 1)
        output_batch = np.asarray(np.argmax(output_batch, axis=3),
                                  dtype=np.uint8)

        for i in range(output_batch.shape[0]):
            output = output_batch[i, :, :]
            output_col = colorize_mask(output)
            output = Image.fromarray(output)

            name_tmp = name[i].split('/')[-1]
            output.save('%s/%s' % (args.save, name_tmp))
            output_col.save('%s/%s_color.png' %
                            (args.save, name_tmp.split('.')[0]))

    return args.save
def main():
    """Evaluate a segmentation model on the NYU validation split.

    Builds the network selected by ``args.model``, restores the checkpoint
    from ``args.restore_from`` (an ``http`` URL or a local path) and runs it
    over the validation set.  Two sets of streaming metrics are printed: one
    on the raw labels and one after ``nyu_nyu_map`` has sent every label
    ``x`` with ``x + 1`` in ``nyu_nyu_dict`` to 255.

    Raises:
        ValueError: if ``args.model`` is not one of the supported names.
    """

    args = get_arguments()

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    # Keys are 1-based class ids; a label x is remapped to 255 when x + 1 is
    # a key, otherwise it is left unchanged.
    nyu_nyu_dict = {11: 255, 13: 255, 15: 255, 17: 255, 19: 255, 20: 255,
                    21: 255, 23: 255, 24: 255, 25: 255, 26: 255, 27: 255,
                    28: 255, 29: 255, 31: 255, 32: 255, 33: 255}
    nyu_nyu_map = np.vectorize(lambda x: nyu_nyu_dict.get(x + 1, x))
    args.nyu_nyu_map = nyu_nyu_map

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Fail early with a clear message instead of an unbound-name error
        # at the first use of ``model`` below.
        raise ValueError('Unsupported model: %s' % args.model)

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # For checkpoints written by a different PyTorch version: drop entries
    # the model does not know and keep the model's own values for entries
    # the checkpoint lacks.  The merged dict is the one that must be loaded;
    # loading the filtered checkpoint directly made the merge a no-op and
    # failed the strict load whenever a model key was missing.
    model_dict = model.state_dict()
    saved_state_dict = {k: v for k, v in saved_state_dict.items()
                        if k in model_dict}
    model_dict.update(saved_state_dict)
    model.load_state_dict(model_dict)

    device = torch.device("cuda" if not args.cpu else "cpu")
    model = model.to(device)

    model.eval()

    metrics = StreamSegMetrics(args.num_classes)
    metrics_remap = StreamSegMetrics(args.num_classes)
    ignore_label = 255
    val_transform = transforms.Compose([
        # et.ExtResize( 512 ),
        transforms.Crop([args.height + 1, args.width + 1],
                        crop_type='center',
                        padding=IMG_MEAN,
                        ignore_label=ignore_label),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMG_MEAN, std=[1, 1, 1]),
    ])
    val_dst = NYU(root=args.data_dir, opt=args,
                  split='val', transform=val_transform,
                  imWidth=args.width, imHeight=args.height, phase="TEST",
                  randomize=False)
    print("Dset Length {}".format(len(val_dst)))
    testloader = data.DataLoader(val_dst,
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    interp = nn.Upsample(size=(args.height + 1, args.width + 1),
                         mode='bilinear',
                         align_corners=True)
    metrics.reset()
    metrics_remap.reset()

    # Pure inference: no autograd graph is needed, so do not build one.
    with torch.no_grad():
        for index, batch in enumerate(testloader):
            if index % 100 == 0:
                print('%d processd' % index)
            image, targets, name = batch
            image = image.to(device)
            print(index)
            if args.model == 'DeeplabMulti':
                # Only the second (final) head is evaluated.
                _, output2 = model(image)
                output = interp(output2).cpu().data[0].numpy()
            else:
                # 'DeeplabVGG' or 'Oracle': single-output networks.
                output = interp(model(image)).cpu().data[0].numpy()
            targets = targets.cpu().numpy()

            # (C, H, W) -> (H, W, C), then pick the best class per pixel.
            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
            preds = output[None, :, :]

            metrics.update(targets, preds)
            metrics_remap.update(args.nyu_nyu_map(targets),
                                 args.nyu_nyu_map(preds))

    print(metrics.get_results())
    print(metrics_remap.get_results())
Exemplo n.º 14
0
def main():
    """Report mIoU of every saved checkpoint on the enabled datasets.

    Seeds all random number generators from ``args.random_seed``, then for
    each checkpoint ``./snapshots/<args.dir_name>/<step>.pth`` (``step`` a
    multiple of ``args.save_pred_every`` up to ``args.num_steps_stop``)
    loads it into a ``DeeplabMulti`` model and prints the mean IoU on each
    dataset whose flag (``args.gta5``, ``args.synthia``, ``args.cityscapes``,
    ``args.idd``) is set.  Per-class IoUs are printed as well when
    ``args.mIoUs_per_class`` is set.

    Raises:
        NotImplementedError: if ``args.num_classes`` is neither 13 nor 18.
    """
    args = get_arguments()

    seed = args.random_seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

    input_size = (1024, 512)
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)

    if args.num_classes == 13:
        name_classes = np.asarray([
            "road", "sidewalk", "building", "light", "sign", "vegetation",
            "sky", "person", "rider", "car", "bus", "motorcycle", "bicycle"
        ])
    elif args.num_classes == 18:
        name_classes = np.asarray([
            "road", "sidewalk", "building", "wall", "fence", "pole", "light",
            "sign", "vegetation", "sky", "person", "rider", "car", "truck",
            "bus", "train", "motorcycle", "bicycle"
        ])
    else:
        # The exception used to be constructed but never raised, so an
        # unsupported class count only failed later on ``name_classes``.
        raise NotImplementedError("Unavailable number of classes")

    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

    # Create the model and start the evaluation process
    model = DeeplabMulti(num_classes=args.num_classes)
    model = model.to(device)
    model.eval()

    def evaluate(dataset_cls, data_dir, data_list, label):
        """Print the (per-class and) mean IoU of ``model`` on one dataset."""
        loader = torch.utils.data.DataLoader(
            dataset_cls(data_dir,
                        data_list,
                        crop_size=input_size,
                        ignore_label=args.ignore_label,
                        set=args.set,
                        num_classes=args.num_classes),
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=True)

        hist = np.zeros((args.num_classes, args.num_classes))
        # Pure inference: no autograd graph is needed, so do not build one.
        with torch.no_grad():
            for images_val, labels, _ in loader:
                images_val = images_val.to(device)
                # Only the second (final) head is evaluated.
                _, pred = model(images_val)
                pred = interp(pred)
                _, pred = pred.max(dim=1)

                hist += fast_hist(labels.cpu().numpy().flatten(),
                                  pred.cpu().numpy().flatten(),
                                  args.num_classes)

        mIoUs = per_class_iu(hist)
        if args.mIoUs_per_class:
            for ind_class in range(args.num_classes):
                print('==>' + name_classes[ind_class] + ':\t' +
                      str(round(mIoUs[ind_class] * 100, 2)))
        print('===> mIoU (' + label + '): ' +
              str(round(np.nanmean(mIoUs) * 100, 2)))
        print('=' * 50)

    num_checkpoints = int(args.num_steps_stop / args.save_pred_every)
    for checkpoint_idx in range(num_checkpoints):
        step = (checkpoint_idx + 1) * args.save_pred_every
        print('Step: ', step)
        saved_state_dict = torch.load('./snapshots/' + args.dir_name + '/' +
                                      str(step) + '.pth')
        model.load_state_dict(saved_state_dict)

        if args.gta5:
            evaluate(GTA5DataSet, args.data_dir_gta5, args.data_list_gta5,
                     'GTA5')

        if args.synthia:
            evaluate(SYNTHIADataSet, args.data_dir_synthia,
                     args.data_list_synthia, 'SYNTHIA')

        if args.cityscapes:
            evaluate(cityscapesDataSet, args.data_dir_cityscapes,
                     args.data_list_cityscapes, 'CityScapes')

        if args.idd:
            evaluate(IDDDataSet, args.data_dir_idd, args.data_list_idd,
                     'IDD')
Exemplo n.º 15
0
def main():
    """Run inference with 60 consecutive GTA5-to-Cityscapes checkpoints.

    For every step in 2000, 4000, ..., 120000 the segmentation network and
    its discriminator list are restored from
    ``./snapshots/GTA2Cityscapes/GTA5_<step>.pth`` and
    ``./snapshots/GTA2Cityscapes/GTA5_<step>_D.pth``.  The arg-max label map
    and a colorized copy of every test image are then written to
    ``./result/GTA2Cityscapes_<step>``.
    """

    # Arguments, data loader and upsampler do not depend on the checkpoint,
    # so they are built once instead of once per checkpoint.
    args = get_arguments()
    gpu0 = args.gpu

    testloader = data.DataLoader(cityscapesDataSet(args.data_dir,
                                                   args.data_list,
                                                   crop_size=(1024, 512),
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    interp = nn.Upsample(size=(1024, 2048),
                         mode='bilinear',
                         align_corners=True)

    # Input channel count of each discriminator; the first entry builds an
    # FCDiscriminator, the second an OutspaceDiscriminator.
    num_class_list = [2048, 19]

    for i in range(1, 61):
        step = i * 2000
        model_path = './snapshots/GTA2Cityscapes/GTA5_{0:d}.pth'.format(step)
        model_D_path = './snapshots/GTA2Cityscapes/GTA5_{0:d}_D.pth'.format(
            step)
        save_path = './result/GTA2Cityscapes_{0:d}'.format(step)

        if not os.path.exists(save_path):
            os.makedirs(save_path)

        model = DeeplabMulti(num_classes=args.num_classes)
        saved_state_dict = torch.load(model_path)
        model.load_state_dict(saved_state_dict)
        model.eval()
        model.cuda(gpu0)

        # ``d_idx`` (not ``i``) so the checkpoint index is not shadowed.
        model_D = nn.ModuleList([
            FCDiscriminator(num_classes=num_class_list[d_idx])
            if d_idx < 1 else OutspaceDiscriminator(
                num_classes=num_class_list[d_idx])
            for d_idx in range(2)
        ])
        model_D.load_state_dict(torch.load(model_D_path))
        model_D.eval()
        model_D.cuda(gpu0)

        with torch.no_grad():
            for index, batch in enumerate(testloader):
                if index % 100 == 0:
                    print('%d processd' % index)
                image, _, name = batch
                _, pred = model(
                    Variable(image).cuda(gpu0), model_D, 'target')

                # (C, H, W) -> (H, W, C), then pick the best class per pixel.
                output = interp(pred).cpu().data[0].numpy()
                output = output.transpose(1, 2, 0)
                output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

                output_col = colorize_mask(output)
                output = Image.fromarray(output)

                # batch_size is 1, so the first entry is the only image path.
                name = name[0].split('/')[-1]
                output.save('%s/%s' % (save_path, name))

                output_col.save('%s/%s_color.png' %
                                (save_path, name.split('.')[0]))

        print(save_path)
def main():
    """Create the model and start the training."""

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    gpu = args.gpu

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()

        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            if args.num_classes != 19 and i_parts[1] != 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        # if args.num_classes !=19:
        # print i_parts
        model.load_state_dict(new_params)
    start_time = datetime.datetime.now().strftime('%m-%d_%H-%M')
    writer_dir = os.path.join("./logs/", args.name, start_time)
    writer = tensorboard.SummaryWriter(writer_dir)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(
        MulitviewSegLoader(
            num_classes=args.num_classes,
            root=args.data_dir,
            number_views=2,
            view_idx=1,
            # max_iters=args.num_steps * args.iter_size * args.batch_size,
            # crop_size=input_size,
            # scale=args.random_scale, mirror=args.random_mirror,
            img_mean=IMG_MEAN),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True)

    trainloader_iter = iter(data_loader_cycle(trainloader))

    targetloader = data.DataLoader(
        MulitviewSegLoader(
            root=args.data_dir_target,
            num_classes=args.num_classes,
            number_views=1,
            view_idx=0,
            # max_iters=args.num_steps * args.iter_size * args.batch_size,
            # crop_size=input_size_target,
            # scale=False,
            # mirror=args.random_mirror,
            img_mean=IMG_MEAN,
            # set=args.set
        ),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True)

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear')

    def mdl_val_func(x):
        return interp_target(model(x)[1])

    targetloader_iter = iter(data_loader_cycle(targetloader))
    val_loader = data.DataLoader(
        MulitviewSegLoader(
            root=args.data_dir_val,
            number_views=1,
            view_idx=0,
            num_classes=args.num_classes,
            # max_iters=args.num_steps * args.iter_size * args.batch_size,
            # crop_size=input_size_target,
            # scale=False,
            # mirror=args.random_mirror,
            img_mean=IMG_MEAN,
            # set=args.set
        ),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True)

    criterion = CrossEntropyLoss2d().cuda(args.gpu)
    valhelper = ValHelper(gpu=args.gpu,
                          model=mdl_val_func,
                          val_loader=val_loader,
                          loss=criterion,
                          writer=writer)
    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()

    # labels for adversarial training
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            batch = next(trainloader_iter)
            images, labels, *_ = batch
            images = Variable(images).cuda(args.gpu)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = loss_calc(pred1, labels, args.gpu, criterion)
            loss_seg2 = loss_calc(pred2, labels, args.gpu, criterion)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
            loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

            # train with target

            batch = next(targetloader_iter)
            images, *_ = batch
            images = Variable(images).cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            D_out1 = model_D1(F.softmax(pred_target1))
            D_out2 = model_D2(F.softmax(pred_target2))

            loss_adv_target1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss_adv_target2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy(
            ) / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy(
            ) / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1))
            D_out2 = model_D2(F.softmax(pred2))

            loss_D1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1))
            D_out2 = model_D2(F.softmax(pred_target2))

            loss_D1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(target_label)).cuda(
                            args.gpu))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(target_label)).cuda(
                            args.gpu))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()
        if i_iter % args.val_steps == 0 and i_iter:
            model.eval()
            log = valhelper.valid_epoch(i_iter)
            print('log: {}'.format(log))
            model.train()

        if i_iter % 10 == 0 and i_iter:
            print('exp = {}'.format(args.snapshot_dir))
            print(
                'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
                .format(i_iter, args.num_steps, loss_seg_value1,
                        loss_seg_value2, loss_adv_target_value1,
                        loss_adv_target_value2, loss_D_value1, loss_D_value2))
            writer.add_scalar(f'train/loss_seg_value1', loss_seg_value1,
                              i_iter)
            writer.add_scalar(f'train/loss_seg_value2', loss_seg_value2,
                              i_iter)
            writer.add_scalar(f'train/loss_adv_target_value1',
                              loss_adv_target_value1, i_iter)
            writer.add_scalar(f'train/loss_D_value1', loss_D_value1, i_iter)
            writer.add_scalar(f'train/loss_D_value2', loss_D_value2, i_iter)

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))